package tests

import (
	"context"
	"fmt"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/alicebob/miniredis/v2"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/api"
	wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/stretchr/testify/require"
)

// newQueueTestServer wires an auth-disabled WebSocket handler around the given
// queue backend, backed by a fresh experiment manager in a temp dir, and
// returns a running httptest server (closed automatically via t.Cleanup).
func newQueueTestServer(t *testing.T, taskQueue queue.Backend) (*httptest.Server, *experiment.Manager) {
	t.Helper()

	expMgr := experiment.NewManager(t.TempDir())
	require.NoError(t, expMgr.Initialize())

	wsHandler := wspkg.NewHandler(
		&auth.Config{Enabled: false},
		logging.NewLogger(0, false),
		expMgr,
		"",
		taskQueue,
		nil, // db
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
	)

	server := httptest.NewServer(wsHandler)
	t.Cleanup(server.Close)
	return server, expMgr
}

// commitBytesForIndex builds a deterministic 20-byte commit ID whose bytes are
// all idx+1, matching the protocol's fixed-width commit field. Distinct
// indexes therefore yield distinct commit IDs.
func commitBytesForIndex(idx int) []byte {
	commit := make([]byte, 20)
	for i := range commit {
		commit[i] = byte(idx + 1)
	}
	return commit
}

// setupExperimentFiles creates an experiment for the given commit, writes the
// minimal training files (train.py and requirements.txt), and publishes its
// manifest so the queued job has something real to reference.
func setupExperimentFiles(t *testing.T, expMgr *experiment.Manager, commitBytes []byte) {
	t.Helper()

	commitIDStr := fmt.Sprintf("%x", commitBytes)
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))

	filesPath := expMgr.GetFilesPath(commitIDStr)
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))

	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))
}

// submitJobsConcurrently queues jobCount jobs through the real WebSocket
// protocol, at most clientConcurrency connections at a time, and returns once
// every submission goroutine has finished. Job names are namePrefix-NN.
//
// NOTE(review): require.* inside these goroutines calls t.FailNow, which the
// testing docs say must run on the test goroutine. Failures are still
// recorded (the test later fails at the completion timeout), but a cleaner
// design would funnel errors back over a channel — TODO confirm/refactor.
func submitJobsConcurrently(
	t *testing.T,
	server *httptest.Server,
	expMgr *experiment.Manager,
	namePrefix string,
	jobCount, clientConcurrency int,
) {
	t.Helper()

	var submitWG sync.WaitGroup
	sem := make(chan struct{}, clientConcurrency)
	for i := 0; i < jobCount; i++ {
		submitWG.Add(1)
		go func(idx int) {
			sem <- struct{}{}
			defer submitWG.Done()
			defer func() { <-sem }()

			jobName := fmt.Sprintf("%s-%02d", namePrefix, idx)
			commitBytes := commitBytesForIndex(idx)
			setupExperimentFiles(t, expMgr, commitBytes)
			queueJobViaWebSocket(t, server.URL, jobName, commitBytes, byte(idx%5))
		}(i)
	}
	submitWG.Wait()
}

// waitForCompletions blocks until jobCount completions arrive on doneCh,
// failing the test if the timeout elapses first.
func waitForCompletions(t *testing.T, doneCh <-chan string, jobCount int, timeout time.Duration) {
	t.Helper()

	deadline := time.After(timeout)
	for completed := 0; completed < jobCount; {
		select {
		case <-deadline:
			t.Fatalf("timed out waiting for %d completions, only saw %d", jobCount, completed)
		case <-doneCh:
			completed++
		}
	}
}

// TestWebSocketQueueEndToEnd drives the full submit-and-complete cycle over
// real WebSocket connections against a Redis-backed queue.
func TestWebSocketQueueEndToEnd(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	// Miniredis provides an in-memory Redis compatible server for realistic queue tests.
	redisServer, err := miniredis.Run()
	require.NoError(t, err)
	defer redisServer.Close()

	taskQueue, err := queue.NewTaskQueue(queue.Config{
		RedisAddr:            redisServer.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	})
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	server, expMgr := newQueueTestServer(t, taskQueue)

	ctx, cancelWorkers := context.WithCancel(context.Background())
	defer cancelWorkers()

	const (
		jobCount          = 20
		workerCount       = 4
		clientConcurrency = 8
	)

	doneCh := make(chan string, jobCount)
	var workerWG sync.WaitGroup
	startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh)

	// Submit jobs concurrently through the real WebSocket protocol.
	submitJobsConcurrently(t, server, expMgr, "ws-load-job", jobCount, clientConcurrency)

	waitForCompletions(t, doneCh, jobCount, 30*time.Second)

	// Stop workers and ensure they exit cleanly.
	cancelWorkers()
	workerWG.Wait()

	nextTask, err := taskQueue.GetNextTask()
	require.NoError(t, err)
	require.Nil(t, nextTask, "queue should be empty after all jobs complete")
}

// TestWebSocketQueueEndToEndSQLite runs the same end-to-end cycle against the
// SQLite queue backend instead of Redis.
func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	queuePath := filepath.Join(t.TempDir(), "queue.db")
	taskQueue, err := queue.NewSQLiteQueue(queuePath)
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	server, expMgr := newQueueTestServer(t, taskQueue)

	ctx, cancelWorkers := context.WithCancel(context.Background())
	defer cancelWorkers()

	const (
		jobCount          = 10
		workerCount       = 2
		clientConcurrency = 4
	)

	doneCh := make(chan string, jobCount)
	var workerWG sync.WaitGroup
	startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh)

	submitJobsConcurrently(t, server, expMgr, "ws-sqlite-job", jobCount, clientConcurrency)

	waitForCompletions(t, doneCh, jobCount, 20*time.Second)

	cancelWorkers()
	workerWG.Wait()

	nextTask, err := taskQueue.GetNextTask()
	require.NoError(t, err)
	require.Nil(t, nextTask, "queue should be empty after all jobs complete")
}

// TestWebSocketQueueWithSnapshotOpcode verifies that the snapshot variant of
// the queue-job opcode records the snapshot ID and SHA on the stored task.
func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	redisServer, err := miniredis.Run()
	require.NoError(t, err)
	defer redisServer.Close()

	taskQueue, err := queue.NewTaskQueue(queue.Config{
		RedisAddr:            redisServer.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	})
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	server, expMgr := newQueueTestServer(t, taskQueue)

	// This test uses ascending bytes 1..20, unlike the load tests' constant fill.
	commitBytes := make([]byte, 20)
	for i := range commitBytes {
		commitBytes[i] = byte(i + 1)
	}
	setupExperimentFiles(t, expMgr, commitBytes)

	queueJobViaWebSocketWithSnapshot(t, server.URL, "ws-snap-job", commitBytes, 1, "snap-1", strings.Repeat("a", 64))

	tasks, err := taskQueue.GetAllTasks()
	require.NoError(t, err)
	require.Len(t, tasks, 1)
	require.Equal(t, "snap-1", tasks[0].SnapshotID)
	require.Equal(t, strings.Repeat("a", 64), tasks[0].Metadata["snapshot_sha256"])
}

// startFakeWorkers launches workerCount goroutines that lease tasks from the
// queue, mark them completed, and report each finished job name on doneCh.
// Workers poll until ctx is cancelled; callers wait for their exit via wg.
func startFakeWorkers(
	ctx context.Context,
	t *testing.T,
	wg *sync.WaitGroup,
	taskQueue queue.Backend,
	workerCount int,
	doneCh chan<- string,
) {
	for w := 0; w < workerCount; w++ {
		wg.Add(1)
		go func(workerID string) {
			defer wg.Done()
			for {
				select {
				case <-ctx.Done():
					return
				default:
				}

				task, err := taskQueue.GetNextTaskWithLease(workerID, 30*time.Second)
				if err != nil {
					t.Logf("worker %s queue error: %v", workerID, err)
					time.Sleep(5 * time.Millisecond)
					continue
				}
				if task == nil {
					// Queue momentarily empty; back off briefly before re-polling.
					time.Sleep(5 * time.Millisecond)
					continue
				}

				// Simulate a 10ms run, then mark the task completed and release the lease.
				started := time.Now()
				completed := started.Add(10 * time.Millisecond)
				task.Status = statusCompleted
				task.StartedAt = &started
				task.EndedAt = &completed
				task.LeaseExpiry = nil
				task.LeasedBy = ""
				if err := taskQueue.UpdateTask(task); err != nil {
					t.Logf("worker %s failed to update task %s: %v", workerID, task.ID, err)
					continue
				}
				doneCh <- task.JobName
			}
		}(fmt.Sprintf("worker-%d", w))
	}
}

// sendWSMessageExpectSuccess dials the handler's WebSocket endpoint, sends msg
// as one binary frame, and asserts the first response byte is a success packet.
func sendWSMessageExpectSuccess(t *testing.T, baseURL string, msg []byte) {
	t.Helper()

	wsURL := "ws" + strings.TrimPrefix(baseURL, "http")
	conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if resp != nil && resp.Body != nil {
		defer func() {
			if err := resp.Body.Close(); err != nil {
				t.Logf("Warning: failed to close response body: %v", err)
			}
		}()
	}
	require.NoError(t, err)
	defer func() { _ = conn.Close() }()

	require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, msg))

	_, payload, err := conn.ReadMessage()
	require.NoError(t, err)
	require.NotEmpty(t, payload, "expected response payload")
	require.EqualValues(t, api.PacketTypeSuccess, payload[0], "queue job should return success packet")
}

// queueJobViaWebSocket submits a plain queue-job frame and asserts success.
func queueJobViaWebSocket(t *testing.T, baseURL, jobName string, commitID []byte, priority byte) {
	t.Helper()
	sendWSMessageExpectSuccess(t, baseURL, buildQueueJobMessage(jobName, commitID, priority))
}

// queueJobViaWebSocketWithSnapshot submits a queue-job-with-snapshot frame and
// asserts success.
func queueJobViaWebSocketWithSnapshot(
	t *testing.T,
	baseURL, jobName string,
	commitID []byte,
	priority byte,
	snapshotID string,
	snapshotSHA string,
) {
	t.Helper()
	sendWSMessageExpectSuccess(t, baseURL, buildQueueJobWithSnapshotMessage(jobName, commitID, priority, snapshotID, snapshotSHA))
}

// truncateTo255 caps b at 255 bytes so its length fits the protocol's
// single-byte length prefix.
func truncateTo255(b []byte) []byte {
	if len(b) > 255 {
		return b[:255]
	}
	return b
}

// pad20 returns commitID normalized to exactly 20 bytes, the fixed commit-ID
// width on the wire. In tests we always use 20 bytes per protocol.
func pad20(commitID []byte) []byte {
	if len(commitID) == 20 {
		return commitID
	}
	padded := make([]byte, 20)
	copy(padded, commitID)
	return padded
}

// buildQueueJobMessage encodes an OpcodeQueueJob frame:
// opcode | 16-byte api key hash | 20-byte commit ID | priority | nameLen | name.
func buildQueueJobMessage(jobName string, commitID []byte, priority byte) []byte {
	jobBytes := truncateTo255([]byte(jobName))
	commitID = pad20(commitID)
	apiKeyHash := []byte(strings.Repeat("0", 16)) // dummy all-'0' key hash for tests

	buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes))
	buf = append(buf, wspkg.OpcodeQueueJob)
	buf = append(buf, apiKeyHash...)
	buf = append(buf, commitID...)
	buf = append(buf, priority)
	buf = append(buf, byte(len(jobBytes)))
	buf = append(buf, jobBytes...)
	return buf
}

// buildQueueJobWithSnapshotMessage encodes an OpcodeQueueJobWithSnapshot
// frame: the queue-job layout followed by length-prefixed snapshot ID and
// snapshot SHA fields.
func buildQueueJobWithSnapshotMessage(
	jobName string,
	commitID []byte,
	priority byte,
	snapshotID string,
	snapshotSHA string,
) []byte {
	jobBytes := truncateTo255([]byte(jobName))
	commitID = pad20(commitID)
	apiKeyHash := []byte(strings.Repeat("0", 16)) // dummy all-'0' key hash for tests
	snapIDBytes := truncateTo255([]byte(snapshotID))
	snapSHAB := truncateTo255([]byte(snapshotSHA))

	buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes)+1+len(snapIDBytes)+1+len(snapSHAB))
	buf = append(buf, wspkg.OpcodeQueueJobWithSnapshot)
	buf = append(buf, apiKeyHash...)
	buf = append(buf, commitID...)
	buf = append(buf, priority)
	buf = append(buf, byte(len(jobBytes)))
	buf = append(buf, jobBytes...)
	buf = append(buf, byte(len(snapIDBytes)))
	buf = append(buf, snapIDBytes...)
	buf = append(buf, byte(len(snapSHAB)))
	buf = append(buf, snapSHAB...)
	return buf
}