fetch_ml/tests/integration/websocket_queue_integration_test.go
Jeremie Fraeys ec9e845bb6
fix(test): Fix WebSocketQueue test timeout and race conditions

- Reduce worker polling interval from 5ms to 1ms for faster task pickup
- Add 100ms buffer after job submission to allow queue to settle
- Increase timeout from 30s to 60s to prevent flaky failures

Fixes intermittent timeout issues in integration tests
2026-02-23 14:38:18 -05:00

package tests

import (
	"context"
	"fmt"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/alicebob/miniredis/v2"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/api"
	"github.com/jfraeys/fetch_ml/internal/api/datasets"
	"github.com/jfraeys/fetch_ml/internal/api/jobs"
	jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
	wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/stretchr/testify/require"
)
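
// TestWebSocketQueueEndToEnd drives the full queue-job path over the binary
// WebSocket protocol against a miniredis-backed queue: jobCount jobs are
// submitted by up to clientConcurrency concurrent clients and drained by
// workerCount fake workers, after which the queue must be empty. Workers poll
// every millisecond and completions are given up to 60 seconds to arrive, so
// the test stays fast without being flaky under load.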
func TestWebSocketQueueEndToEnd(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	// Miniredis provides an in-memory Redis compatible server for realistic queue tests.
	redisServer, err := miniredis.Run()
	require.NoError(t, err)
	defer redisServer.Close()

	taskQueue, err := queue.NewTaskQueue(queue.Config{
		RedisAddr:            redisServer.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	})
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	expMgr := experiment.NewManager(t.TempDir())
	require.NoError(t, expMgr.Initialize())

	logger := logging.NewLogger(0, false)
	authCfg := &auth.Config{Enabled: false}
	jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
	datasetsHandler := datasets.NewHandler(logger, nil, "")
	wsHandler := wspkg.NewHandler(
		authCfg,
		logger,
		expMgr,
		"",
		taskQueue,
		nil, // db
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
		jobsHandler,
		jupyterHandler,
		datasetsHandler,
	)

	server := httptest.NewServer(wsHandler)
	defer server.Close()

	ctx, cancelWorkers := context.WithCancel(context.Background())
	defer cancelWorkers()

	const (
		jobCount          = 20
		workerCount       = 4
		clientConcurrency = 8
	)

	doneCh := make(chan string, jobCount)
	var workerWG sync.WaitGroup
	startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh)

	// Submit jobs concurrently through the real WebSocket protocol.
	var submitWG sync.WaitGroup
	sem := make(chan struct{}, clientConcurrency)
	for i := 0; i < jobCount; i++ {
		submitWG.Add(1)
		go func(idx int) {
			sem <- struct{}{}
			defer submitWG.Done()
			defer func() { <-sem }()

			jobName := fmt.Sprintf("ws-load-job-%02d", idx)
			commitBytes := make([]byte, 20)
			for j := range commitBytes {
				commitBytes[j] = byte(idx + 1)
			}
			commitIDStr := fmt.Sprintf("%x", commitBytes)
			require.NoError(t, expMgr.CreateExperiment(commitIDStr))
			filesPath := expMgr.GetFilesPath(commitIDStr)
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
			man, err := expMgr.GenerateManifest(commitIDStr)
			require.NoError(t, err)
			require.NoError(t, expMgr.WriteManifest(man))

			queueJobViaWebSocket(t, server.URL, jobName, commitBytes, byte(idx%5))
		}(i)
	}
	submitWG.Wait()

	// Give workers time to pick up any remaining tasks
	time.Sleep(100 * time.Millisecond)

	completed := 0
	timeout := time.After(60 * time.Second)
	for completed < jobCount {
		select {
		case <-timeout:
			t.Fatalf("timed out waiting for %d completions, only saw %d", jobCount, completed)
		case <-doneCh:
			completed++
		}
	}

	// Stop workers and ensure they exit cleanly.
	cancelWorkers()
	workerWG.Wait()

	nextTask, err := taskQueue.GetNextTask()
	require.NoError(t, err)
	require.Nil(t, nextTask, "queue should be empty after all jobs complete")
}
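
// TestWebSocketQueueEndToEndSQLite runs the same end-to-end flow against the
// SQLite queue backend with a smaller job, worker, and client count.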
func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	queuePath := filepath.Join(t.TempDir(), "queue.db")
	taskQueue, err := queue.NewSQLiteQueue(queuePath)
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	expMgr := experiment.NewManager(t.TempDir())
	require.NoError(t, expMgr.Initialize())

	logger := logging.NewLogger(0, false)
	authCfg := &auth.Config{Enabled: false}
	jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
	datasetsHandler := datasets.NewHandler(logger, nil, "")
	wsHandler := wspkg.NewHandler(
		authCfg,
		logger,
		expMgr,
		"",
		taskQueue,
		nil, // db
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
		jobsHandler,
		jupyterHandler,
		datasetsHandler,
	)

	server := httptest.NewServer(wsHandler)
	defer server.Close()

	ctx, cancelWorkers := context.WithCancel(context.Background())
	defer cancelWorkers()

	const (
		jobCount          = 10
		workerCount       = 2
		clientConcurrency = 4
	)

	doneCh := make(chan string, jobCount)
	var workerWG sync.WaitGroup
	startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh)

	var submitWG sync.WaitGroup
	sem := make(chan struct{}, clientConcurrency)
	for i := 0; i < jobCount; i++ {
		submitWG.Add(1)
		go func(idx int) {
			sem <- struct{}{}
			defer submitWG.Done()
			defer func() { <-sem }()

			jobName := fmt.Sprintf("ws-sqlite-job-%02d", idx)
			commitBytes := make([]byte, 20)
			for j := range commitBytes {
				commitBytes[j] = byte(idx + 1)
			}
			commitIDStr := fmt.Sprintf("%x", commitBytes)
			require.NoError(t, expMgr.CreateExperiment(commitIDStr))
			filesPath := expMgr.GetFilesPath(commitIDStr)
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
			man, err := expMgr.GenerateManifest(commitIDStr)
			require.NoError(t, err)
			require.NoError(t, expMgr.WriteManifest(man))

			queueJobViaWebSocket(t, server.URL, jobName, commitBytes, byte(idx%5))
		}(i)
	}
	submitWG.Wait()

	// Give workers time to pick up any remaining tasks
	time.Sleep(100 * time.Millisecond)

	completed := 0
	timeout := time.After(60 * time.Second)
	for completed < jobCount {
		select {
		case <-timeout:
			t.Fatalf("timed out waiting for %d completions, only saw %d", jobCount, completed)
		case <-doneCh:
			completed++
		}
	}

	cancelWorkers()
	workerWG.Wait()

	nextTask, err := taskQueue.GetNextTask()
	require.NoError(t, err)
	require.Nil(t, nextTask, "queue should be empty after all jobs complete")
}
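
// TestWebSocketQueueWithSnapshotOpcode submits a single job using the
// snapshot opcode and verifies the queued task carries the snapshot ID and
// SHA-256 metadata.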
func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	redisServer, err := miniredis.Run()
	require.NoError(t, err)
	defer redisServer.Close()

	taskQueue, err := queue.NewTaskQueue(queue.Config{
		RedisAddr:            redisServer.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	})
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	expMgr := experiment.NewManager(t.TempDir())
	require.NoError(t, expMgr.Initialize())

	logger := logging.NewLogger(0, false)
	authCfg := &auth.Config{Enabled: false}
	jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
	datasetsHandler := datasets.NewHandler(logger, nil, "")
	wsHandler := wspkg.NewHandler(
		authCfg,
		logger,
		expMgr,
		"",
		taskQueue,
		nil, // db
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
		jobsHandler,
		jupyterHandler,
		datasetsHandler,
	)

	server := httptest.NewServer(wsHandler)
	defer server.Close()

	commitBytes := make([]byte, 20)
	for i := range commitBytes {
		commitBytes[i] = byte(i + 1)
	}
	commitIDStr := fmt.Sprintf("%x", commitBytes)
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	filesPath := expMgr.GetFilesPath(commitIDStr)
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))

	queueJobViaWebSocketWithSnapshot(t, server.URL, "ws-snap-job", commitBytes, 1, "snap-1", strings.Repeat("a", 64))

	tasks, err := taskQueue.GetAllTasks()
	require.NoError(t, err)
	require.Len(t, tasks, 1)
	require.Equal(t, "snap-1", tasks[0].SnapshotID)
	require.Equal(t, strings.Repeat("a", 64), tasks[0].Metadata["snapshot_sha256"])
}
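
// startFakeWorkers launches workerCount goroutines that lease tasks from the
// queue, mark each one completed, and report the finished job name on doneCh
// until ctx is cancelled.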
func startFakeWorkers(
	ctx context.Context,
	t *testing.T,
	wg *sync.WaitGroup,
	taskQueue queue.Backend,
	workerCount int,
	doneCh chan<- string,
) {
	for w := 0; w < workerCount; w++ {
		wg.Add(1)
		go func(workerID string) {
			defer wg.Done()
			for {
				select {
				case <-ctx.Done():
					return
				default:
				}

				task, err := taskQueue.GetNextTaskWithLease(workerID, 30*time.Second)
				if err != nil {
					t.Logf("worker %s queue error: %v", workerID, err)
					time.Sleep(time.Millisecond)
					continue
				}
				if task == nil {
					time.Sleep(time.Millisecond)
					continue
				}

				started := time.Now()
				completed := started.Add(10 * time.Millisecond)
				task.Status = statusCompleted
				task.StartedAt = &started
				task.EndedAt = &completed
				task.LeaseExpiry = nil
				task.LeasedBy = ""
				if err := taskQueue.UpdateTask(task); err != nil {
					t.Logf("worker %s failed to update task %s: %v", workerID, task.ID, err)
					continue
				}
				doneCh <- task.JobName
			}
		}(fmt.Sprintf("worker-%d", w))
	}
}
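
// queueJobViaWebSocket dials the test server, sends a binary queue-job frame,
// and asserts the server answers with a success packet.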
func queueJobViaWebSocket(t *testing.T, baseURL, jobName string, commitID []byte, priority byte) {
	t.Helper()
	wsURL := "ws" + strings.TrimPrefix(baseURL, "http")
	conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if resp != nil && resp.Body != nil {
		defer func() {
			if err := resp.Body.Close(); err != nil {
				t.Logf("Warning: failed to close response body: %v", err)
			}
		}()
	}
	require.NoError(t, err)
	defer func() { _ = conn.Close() }()

	msg := buildQueueJobMessage(jobName, commitID, priority)
	require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, msg))

	_, payload, err := conn.ReadMessage()
	require.NoError(t, err)
	require.NotEmpty(t, payload, "expected response payload")
	require.EqualValues(t, api.PacketTypeSuccess, payload[0], "queue job should return success packet")
}
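
// queueJobViaWebSocketWithSnapshot behaves like queueJobViaWebSocket but uses
// the snapshot variant of the queue-job frame.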
func queueJobViaWebSocketWithSnapshot(
	t *testing.T,
	baseURL, jobName string,
	commitID []byte,
	priority byte,
	snapshotID string,
	snapshotSHA string,
) {
	t.Helper()
	wsURL := "ws" + strings.TrimPrefix(baseURL, "http")
	conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if resp != nil && resp.Body != nil {
		defer func() {
			if err := resp.Body.Close(); err != nil {
				t.Logf("Warning: failed to close response body: %v", err)
			}
		}()
	}
	require.NoError(t, err)
	defer func() { _ = conn.Close() }()

	msg := buildQueueJobWithSnapshotMessage(jobName, commitID, priority, snapshotID, snapshotSHA)
	require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, msg))

	_, payload, err := conn.ReadMessage()
	require.NoError(t, err)
	require.NotEmpty(t, payload, "expected response payload")
	require.EqualValues(t, api.PacketTypeSuccess, payload[0], "queue job should return success packet")
}
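
// buildQueueJobMessage encodes a queue-job frame:
//
//	opcode (1) | api key hash (16) | commit ID (20) | priority (1) |
//	name length (1) | name (<=255 bytes)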
func buildQueueJobMessage(jobName string, commitID []byte, priority byte) []byte {
	jobBytes := []byte(jobName)
	if len(jobBytes) > 255 {
		jobBytes = jobBytes[:255]
	}
	if len(commitID) != 20 {
		// In tests we always use 20 bytes per protocol.
		padded := make([]byte, 20)
		copy(padded, commitID)
		commitID = padded
	}
	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes))
	buf = append(buf, wspkg.OpcodeQueueJob)
	buf = append(buf, apiKeyHash...)
	buf = append(buf, commitID...)
	buf = append(buf, priority)
	buf = append(buf, byte(len(jobBytes)))
	buf = append(buf, jobBytes...)
	return buf
}
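
// buildQueueJobWithSnapshotMessage extends the queue-job frame with
// length-prefixed snapshot ID and snapshot SHA-256 fields:
//
//	opcode (1) | api key hash (16) | commit ID (20) | priority (1) |
//	name length (1) | name | snapshot ID length (1) | snapshot ID |
//	snapshot SHA length (1) | snapshot SHA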
func buildQueueJobWithSnapshotMessage(
	jobName string,
	commitID []byte,
	priority byte,
	snapshotID string,
	snapshotSHA string,
) []byte {
	jobBytes := []byte(jobName)
	if len(jobBytes) > 255 {
		jobBytes = jobBytes[:255]
	}
	if len(commitID) != 20 {
		padded := make([]byte, 20)
		copy(padded, commitID)
		commitID = padded
	}
	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
	snapIDBytes := []byte(snapshotID)
	if len(snapIDBytes) > 255 {
		snapIDBytes = snapIDBytes[:255]
	}
	snapSHAB := []byte(snapshotSHA)
	if len(snapSHAB) > 255 {
		snapSHAB = snapSHAB[:255]
	}

	buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes)+1+len(snapIDBytes)+1+len(snapSHAB))
	buf = append(buf, wspkg.OpcodeQueueJobWithSnapshot)
	buf = append(buf, apiKeyHash...)
	buf = append(buf, commitID...)
	buf = append(buf, priority)
	buf = append(buf, byte(len(jobBytes)))
	buf = append(buf, jobBytes...)
	buf = append(buf, byte(len(snapIDBytes)))
	buf = append(buf, snapIDBytes...)
	buf = append(buf, byte(len(snapSHAB)))
	buf = append(buf, snapSHAB...)
	return buf
}