fetch_ml/tests/integration/websocket_queue_integration_test.go
Jeremie Fraeys 8ecdd36155
test(integration): add websocket queue and hash benchmarks
- Add websocket queue integration test
- Add worker hash benchmark test
- Add native detection script
2026-02-18 12:46:06 -05:00

package tests

import (
	"context"
	"fmt"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/alicebob/miniredis/v2"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/api"
	"github.com/jfraeys/fetch_ml/internal/api/datasets"
	"github.com/jfraeys/fetch_ml/internal/api/jobs"
	jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
	wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/stretchr/testify/require"
)
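
// TestWebSocketQueueEndToEnd exercises the full path from WebSocket job
// submission through the Redis-backed queue (via miniredis) to lease-based
// worker completion, with concurrent clients and fake workers.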
func TestWebSocketQueueEndToEnd(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	// Miniredis provides an in-memory, Redis-compatible server for realistic queue tests.
	redisServer, err := miniredis.Run()
	require.NoError(t, err)
	defer redisServer.Close()

	taskQueue, err := queue.NewTaskQueue(queue.Config{
		RedisAddr:            redisServer.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	})
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	expMgr := experiment.NewManager(t.TempDir())
	require.NoError(t, expMgr.Initialize())

	logger := logging.NewLogger(0, false)
	authCfg := &auth.Config{Enabled: false}
	jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
	datasetsHandler := datasets.NewHandler(logger, nil, "")

	wsHandler := wspkg.NewHandler(
		authCfg,
		logger,
		expMgr,
		"",
		taskQueue,
		nil, // db
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
		jobsHandler,
		jupyterHandler,
		datasetsHandler,
	)

	server := httptest.NewServer(wsHandler)
	defer server.Close()

	ctx, cancelWorkers := context.WithCancel(context.Background())
	defer cancelWorkers()

	const (
		jobCount          = 20
		workerCount       = 4
		clientConcurrency = 8
	)

	doneCh := make(chan string, jobCount)
	var workerWG sync.WaitGroup
	startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh)

	// Submit jobs concurrently through the real WebSocket protocol.
	var submitWG sync.WaitGroup
	sem := make(chan struct{}, clientConcurrency)
	for i := 0; i < jobCount; i++ {
		submitWG.Add(1)
		go func(idx int) {
			sem <- struct{}{}
			defer submitWG.Done()
			defer func() { <-sem }()

			jobName := fmt.Sprintf("ws-load-job-%02d", idx)
			commitBytes := make([]byte, 20)
			for j := range commitBytes {
				commitBytes[j] = byte(idx + 1)
			}
			commitIDStr := fmt.Sprintf("%x", commitBytes)
			require.NoError(t, expMgr.CreateExperiment(commitIDStr))
			filesPath := expMgr.GetFilesPath(commitIDStr)
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
			man, err := expMgr.GenerateManifest(commitIDStr)
			require.NoError(t, err)
			require.NoError(t, expMgr.WriteManifest(man))

			queueJobViaWebSocket(t, server.URL, jobName, commitBytes, byte(idx%5))
		}(i)
	}
	submitWG.Wait()

	completed := 0
	timeout := time.After(30 * time.Second)
	for completed < jobCount {
		select {
		case <-timeout:
			t.Fatalf("timed out waiting for %d completions, only saw %d", jobCount, completed)
		case <-doneCh:
			completed++
		}
	}

	// Stop workers and ensure they exit cleanly.
	cancelWorkers()
	workerWG.Wait()

	nextTask, err := taskQueue.GetNextTask()
	require.NoError(t, err)
	require.Nil(t, nextTask, "queue should be empty after all jobs complete")
}
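
// TestWebSocketQueueEndToEndSQLite runs the same end-to-end flow against the
// SQLite queue backend instead of Redis, with smaller job and worker counts.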
func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	queuePath := filepath.Join(t.TempDir(), "queue.db")
	taskQueue, err := queue.NewSQLiteQueue(queuePath)
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	expMgr := experiment.NewManager(t.TempDir())
	require.NoError(t, expMgr.Initialize())

	logger := logging.NewLogger(0, false)
	authCfg := &auth.Config{Enabled: false}
	jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
	datasetsHandler := datasets.NewHandler(logger, nil, "")

	wsHandler := wspkg.NewHandler(
		authCfg,
		logger,
		expMgr,
		"",
		taskQueue,
		nil, // db
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
		jobsHandler,
		jupyterHandler,
		datasetsHandler,
	)

	server := httptest.NewServer(wsHandler)
	defer server.Close()

	ctx, cancelWorkers := context.WithCancel(context.Background())
	defer cancelWorkers()

	const (
		jobCount          = 10
		workerCount       = 2
		clientConcurrency = 4
	)

	doneCh := make(chan string, jobCount)
	var workerWG sync.WaitGroup
	startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh)

	var submitWG sync.WaitGroup
	sem := make(chan struct{}, clientConcurrency)
	for i := 0; i < jobCount; i++ {
		submitWG.Add(1)
		go func(idx int) {
			sem <- struct{}{}
			defer submitWG.Done()
			defer func() { <-sem }()

			jobName := fmt.Sprintf("ws-sqlite-job-%02d", idx)
			commitBytes := make([]byte, 20)
			for j := range commitBytes {
				commitBytes[j] = byte(idx + 1)
			}
			commitIDStr := fmt.Sprintf("%x", commitBytes)
			require.NoError(t, expMgr.CreateExperiment(commitIDStr))
			filesPath := expMgr.GetFilesPath(commitIDStr)
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
			man, err := expMgr.GenerateManifest(commitIDStr)
			require.NoError(t, err)
			require.NoError(t, expMgr.WriteManifest(man))

			queueJobViaWebSocket(t, server.URL, jobName, commitBytes, byte(idx%5))
		}(i)
	}
	submitWG.Wait()

	completed := 0
	timeout := time.After(30 * time.Second)
	for completed < jobCount {
		select {
		case <-timeout:
			t.Fatalf("timed out waiting for %d completions, only saw %d", jobCount, completed)
		case <-doneCh:
			completed++
		}
	}

	cancelWorkers()
	workerWG.Wait()

	nextTask, err := taskQueue.GetNextTask()
	require.NoError(t, err)
	require.Nil(t, nextTask, "queue should be empty after all jobs complete")
}
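
// TestWebSocketQueueWithSnapshotOpcode verifies that the snapshot variant of
// the queue-job opcode propagates the snapshot ID and SHA-256 into the queued task.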
func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	redisServer, err := miniredis.Run()
	require.NoError(t, err)
	defer redisServer.Close()

	taskQueue, err := queue.NewTaskQueue(queue.Config{
		RedisAddr:            redisServer.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	})
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	expMgr := experiment.NewManager(t.TempDir())
	require.NoError(t, expMgr.Initialize())

	logger := logging.NewLogger(0, false)
	authCfg := &auth.Config{Enabled: false}
	jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
	datasetsHandler := datasets.NewHandler(logger, nil, "")

	wsHandler := wspkg.NewHandler(
		authCfg,
		logger,
		expMgr,
		"",
		taskQueue,
		nil, // db
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
		jobsHandler,
		jupyterHandler,
		datasetsHandler,
	)

	server := httptest.NewServer(wsHandler)
	defer server.Close()

	commitBytes := make([]byte, 20)
	for i := range commitBytes {
		commitBytes[i] = byte(i + 1)
	}
	commitIDStr := fmt.Sprintf("%x", commitBytes)
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	filesPath := expMgr.GetFilesPath(commitIDStr)
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))

	queueJobViaWebSocketWithSnapshot(t, server.URL, "ws-snap-job", commitBytes, 1, "snap-1", strings.Repeat("a", 64))

	tasks, err := taskQueue.GetAllTasks()
	require.NoError(t, err)
	require.Len(t, tasks, 1)
	require.Equal(t, "snap-1", tasks[0].SnapshotID)
	require.Equal(t, strings.Repeat("a", 64), tasks[0].Metadata["snapshot_sha256"])
}
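
// startFakeWorkers launches workerCount goroutines that lease tasks from the
// queue, mark them completed, and report each finished job name on doneCh.
// Workers poll with a short sleep when the queue is empty and exit once ctx
// is cancelled.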
func startFakeWorkers(
	ctx context.Context,
	t *testing.T,
	wg *sync.WaitGroup,
	taskQueue queue.Backend,
	workerCount int,
	doneCh chan<- string,
) {
	for w := 0; w < workerCount; w++ {
		wg.Add(1)
		go func(workerID string) {
			defer wg.Done()
			for {
				select {
				case <-ctx.Done():
					return
				default:
				}

				task, err := taskQueue.GetNextTaskWithLease(workerID, 30*time.Second)
				if err != nil {
					t.Logf("worker %s queue error: %v", workerID, err)
					time.Sleep(5 * time.Millisecond)
					continue
				}
				if task == nil {
					time.Sleep(5 * time.Millisecond)
					continue
				}

				started := time.Now()
				completed := started.Add(10 * time.Millisecond)
				task.Status = statusCompleted
				task.StartedAt = &started
				task.EndedAt = &completed
				task.LeaseExpiry = nil
				task.LeasedBy = ""
				if err := taskQueue.UpdateTask(task); err != nil {
					t.Logf("worker %s failed to update task %s: %v", workerID, task.ID, err)
					continue
				}
				doneCh <- task.JobName
			}
		}(fmt.Sprintf("worker-%d", w))
	}
}
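
// queueJobViaWebSocket dials the test server, sends a binary queue-job frame,
// and asserts that the server replies with a success packet.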
func queueJobViaWebSocket(t *testing.T, baseURL, jobName string, commitID []byte, priority byte) {
	t.Helper()
	wsURL := "ws" + strings.TrimPrefix(baseURL, "http")
	conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if resp != nil && resp.Body != nil {
		defer func() {
			if err := resp.Body.Close(); err != nil {
				t.Logf("Warning: failed to close response body: %v", err)
			}
		}()
	}
	require.NoError(t, err)
	defer func() { _ = conn.Close() }()

	msg := buildQueueJobMessage(jobName, commitID, priority)
	require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, msg))

	_, payload, err := conn.ReadMessage()
	require.NoError(t, err)
	require.NotEmpty(t, payload, "expected response payload")
	require.EqualValues(t, api.PacketTypeSuccess, payload[0], "queue job should return success packet")
}
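
// queueJobViaWebSocketWithSnapshot does the same as queueJobViaWebSocket but
// uses the snapshot opcode and appends the length-prefixed snapshot ID and SHA.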
func queueJobViaWebSocketWithSnapshot(
	t *testing.T,
	baseURL, jobName string,
	commitID []byte,
	priority byte,
	snapshotID string,
	snapshotSHA string,
) {
	t.Helper()
	wsURL := "ws" + strings.TrimPrefix(baseURL, "http")
	conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if resp != nil && resp.Body != nil {
		defer func() {
			if err := resp.Body.Close(); err != nil {
				t.Logf("Warning: failed to close response body: %v", err)
			}
		}()
	}
	require.NoError(t, err)
	defer func() { _ = conn.Close() }()

	msg := buildQueueJobWithSnapshotMessage(jobName, commitID, priority, snapshotID, snapshotSHA)
	require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, msg))

	_, payload, err := conn.ReadMessage()
	require.NoError(t, err)
	require.NotEmpty(t, payload, "expected response payload")
	require.EqualValues(t, api.PacketTypeSuccess, payload[0], "queue job should return success packet")
}
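
// buildQueueJobMessage encodes a queue-job frame: opcode, 16-byte API key
// hash, 20-byte commit ID, priority byte, then a length-prefixed job name.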
func buildQueueJobMessage(jobName string, commitID []byte, priority byte) []byte {
	jobBytes := []byte(jobName)
	if len(jobBytes) > 255 {
		jobBytes = jobBytes[:255]
	}
	if len(commitID) != 20 {
		// The protocol expects exactly 20 commit bytes; pad (or truncate) to that length.
		padded := make([]byte, 20)
		copy(padded, commitID)
		commitID = padded
	}
	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes))
	buf = append(buf, wspkg.OpcodeQueueJob)
	buf = append(buf, apiKeyHash...)
	buf = append(buf, commitID...)
	buf = append(buf, priority)
	buf = append(buf, byte(len(jobBytes)))
	buf = append(buf, jobBytes...)
	return buf
}
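
// buildQueueJobWithSnapshotMessage extends the queue-job frame with
// length-prefixed snapshot ID and snapshot SHA-256 fields.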
func buildQueueJobWithSnapshotMessage(
	jobName string,
	commitID []byte,
	priority byte,
	snapshotID string,
	snapshotSHA string,
) []byte {
	jobBytes := []byte(jobName)
	if len(jobBytes) > 255 {
		jobBytes = jobBytes[:255]
	}
	if len(commitID) != 20 {
		padded := make([]byte, 20)
		copy(padded, commitID)
		commitID = padded
	}
	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	snapIDBytes := []byte(snapshotID)
	if len(snapIDBytes) > 255 {
		snapIDBytes = snapIDBytes[:255]
	}
	snapSHAB := []byte(snapshotSHA)
	if len(snapSHAB) > 255 {
		snapSHAB = snapSHAB[:255]
	}

	buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes)+1+len(snapIDBytes)+1+len(snapSHAB))
	buf = append(buf, wspkg.OpcodeQueueJobWithSnapshot)
	buf = append(buf, apiKeyHash...)
	buf = append(buf, commitID...)
	buf = append(buf, priority)
	buf = append(buf, byte(len(jobBytes)))
	buf = append(buf, jobBytes...)
	buf = append(buf, byte(len(snapIDBytes)))
	buf = append(buf, snapIDBytes...)
	buf = append(buf, byte(len(snapSHAB)))
	buf = append(buf, snapSHAB...)
	return buf
}
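
// The helper below is an illustrative sketch of how a receiver could decode
// the queue-job frame produced by buildQueueJobMessage (opcode, 16-byte API
// key hash, 20-byte commit ID, priority, length-prefixed job name). It is not
// part of the server implementation; the decodedQueueJob type and the
// parseQueueJobFrame name are hypothetical and exist only to document the
// wire layout these tests assume.
type decodedQueueJob struct {
	Opcode     byte
	APIKeyHash [16]byte
	CommitID   [20]byte
	Priority   byte
	JobName    string
}

func parseQueueJobFrame(frame []byte) (*decodedQueueJob, error) {
	const headerLen = 1 + 16 + 20 + 1 + 1 // opcode + hash + commit ID + priority + name length
	if len(frame) < headerLen {
		return nil, fmt.Errorf("frame too short: %d bytes", len(frame))
	}
	out := &decodedQueueJob{Opcode: frame[0], Priority: frame[37]}
	copy(out.APIKeyHash[:], frame[1:17])
	copy(out.CommitID[:], frame[17:37])
	nameLen := int(frame[38])
	if len(frame) < headerLen+nameLen {
		return nil, fmt.Errorf("frame truncated: want %d job name bytes", nameLen)
	}
	out.JobName = string(frame[39 : 39+nameLen])
	return out, nil
}

// A round-trip such as parseQueueJobFrame(buildQueueJobMessage("job", id, 3))
// should recover the job name and priority, which is a cheap way to sanity
// check the offsets above against the builder.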