package tests import ( "context" "fmt" "net/http/httptest" "strings" "sync" "testing" "time" "github.com/alicebob/miniredis/v2" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/api" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/stretchr/testify/require" ) func TestWebSocketQueueEndToEnd(t *testing.T) { if testing.Short() { t.Skip("skipping websocket queue integration in short mode") } // Miniredis provides an in-memory Redis compatible server for realistic queue tests. redisServer, err := miniredis.Run() require.NoError(t, err) defer redisServer.Close() taskQueue, err := queue.NewTaskQueue(queue.Config{ RedisAddr: redisServer.Addr(), MetricsFlushInterval: 10 * time.Millisecond, }) require.NoError(t, err) defer func() { _ = taskQueue.Close() }() expMgr := experiment.NewManager(t.TempDir()) require.NoError(t, expMgr.Initialize()) logger := logging.NewLogger(0, false) authCfg := &auth.Config{Enabled: false} wsHandler := api.NewWSHandler(authCfg, logger, expMgr, taskQueue) server := httptest.NewServer(wsHandler) defer server.Close() ctx, cancelWorkers := context.WithCancel(context.Background()) defer cancelWorkers() const ( jobCount = 20 workerCount = 4 clientConcurrency = 8 ) doneCh := make(chan string, jobCount) var workerWG sync.WaitGroup startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh) // Submit jobs concurrently through the real WebSocket protocol. var submitWG sync.WaitGroup sem := make(chan struct{}, clientConcurrency) for i := 0; i < jobCount; i++ { submitWG.Add(1) go func(idx int) { sem <- struct{}{} defer submitWG.Done() defer func() { <-sem }() jobName := fmt.Sprintf("ws-load-job-%02d", idx) commitID := fmt.Sprintf("%064x", idx+1) queueJobViaWebSocket(t, server.URL, jobName, commitID, byte(idx%5)) }(i) } submitWG.Wait() completed := 0 timeout := time.After(30 * time.Second) for completed < jobCount { select { case <-timeout: t.Fatalf("timed out waiting for %d completions, only saw %d", jobCount, completed) case <-doneCh: completed++ } } // Stop workers and ensure they exit cleanly. cancelWorkers() workerWG.Wait() nextTask, err := taskQueue.GetNextTask() require.NoError(t, err) require.Nil(t, nextTask, "queue should be empty after all jobs complete") } func startFakeWorkers( ctx context.Context, t *testing.T, wg *sync.WaitGroup, taskQueue *queue.TaskQueue, workerCount int, doneCh chan<- string, ) { for w := 0; w < workerCount; w++ { wg.Add(1) go func(workerID string) { defer wg.Done() for { select { case <-ctx.Done(): return default: } task, err := taskQueue.GetNextTaskWithLease(workerID, 30*time.Second) if err != nil { t.Logf("worker %s queue error: %v", workerID, err) time.Sleep(5 * time.Millisecond) continue } if task == nil { time.Sleep(5 * time.Millisecond) continue } started := time.Now() completed := started.Add(10 * time.Millisecond) task.Status = statusCompleted task.StartedAt = &started task.EndedAt = &completed task.LeaseExpiry = nil task.LeasedBy = "" if err := taskQueue.UpdateTask(task); err != nil { t.Logf("worker %s failed to update task %s: %v", workerID, task.ID, err) continue } doneCh <- task.JobName } }(fmt.Sprintf("worker-%d", w)) } } func queueJobViaWebSocket(t *testing.T, baseURL, jobName, commitID string, priority byte) { t.Helper() wsURL := "ws" + strings.TrimPrefix(baseURL, "http") conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) if resp != nil && resp.Body != nil { defer func() { if err := resp.Body.Close(); err != nil { t.Logf("Warning: failed to close response body: %v", err) } }() } require.NoError(t, err) defer func() { _ = conn.Close() }() msg := buildQueueJobMessage(jobName, commitID, priority) require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, msg)) _, payload, err := conn.ReadMessage() require.NoError(t, err) require.NotEmpty(t, payload, "expected response payload") require.EqualValues(t, api.PacketTypeSuccess, payload[0], "queue job should return success packet") } func buildQueueJobMessage(jobName, commitID string, priority byte) []byte { jobBytes := []byte(jobName) if len(jobBytes) > 255 { jobBytes = jobBytes[:255] } if len(commitID) < 64 { commitID += strings.Repeat("a", 64-len(commitID)) } buf := make([]byte, 0, 1+64+64+1+1+len(jobBytes)) buf = append(buf, api.OpcodeQueueJob) buf = append(buf, []byte(strings.Repeat("0", 64))...) buf = append(buf, []byte(commitID[:64])...) buf = append(buf, priority) buf = append(buf, byte(len(jobBytes))) buf = append(buf, jobBytes...) return buf }