Reduce the worker polling interval from 5ms to 1ms for faster task pickup. Add a 100ms buffer after job submission to allow the queue to settle. Increase the completion timeout from 30s to 60s to prevent flaky failures. Fixes intermittent timeout issues in integration tests.
468 lines
13 KiB
Go
468 lines
13 KiB
Go
package tests
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/alicebob/miniredis/v2"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/api"
|
|
"github.com/jfraeys/fetch_ml/internal/api/datasets"
|
|
"github.com/jfraeys/fetch_ml/internal/api/jobs"
|
|
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
|
|
wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestWebSocketQueueEndToEnd(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping websocket queue integration in short mode")
|
|
}
|
|
|
|
// Miniredis provides an in-memory Redis compatible server for realistic queue tests.
|
|
redisServer, err := miniredis.Run()
|
|
require.NoError(t, err)
|
|
defer redisServer.Close()
|
|
|
|
taskQueue, err := queue.NewTaskQueue(queue.Config{
|
|
RedisAddr: redisServer.Addr(),
|
|
MetricsFlushInterval: 10 * time.Millisecond,
|
|
})
|
|
require.NoError(t, err)
|
|
defer func() { _ = taskQueue.Close() }()
|
|
|
|
expMgr := experiment.NewManager(t.TempDir())
|
|
require.NoError(t, expMgr.Initialize())
|
|
|
|
logger := logging.NewLogger(0, false)
|
|
authCfg := &auth.Config{Enabled: false}
|
|
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
|
|
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
|
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
|
wsHandler := wspkg.NewHandler(
|
|
authCfg,
|
|
logger,
|
|
expMgr,
|
|
"",
|
|
taskQueue,
|
|
nil, // db
|
|
nil, // jupyterServiceMgr
|
|
nil, // securityConfig
|
|
nil, // auditLogger
|
|
jobsHandler,
|
|
jupyterHandler,
|
|
datasetsHandler,
|
|
)
|
|
server := httptest.NewServer(wsHandler)
|
|
defer server.Close()
|
|
|
|
ctx, cancelWorkers := context.WithCancel(context.Background())
|
|
defer cancelWorkers()
|
|
|
|
const (
|
|
jobCount = 20
|
|
workerCount = 4
|
|
clientConcurrency = 8
|
|
)
|
|
|
|
doneCh := make(chan string, jobCount)
|
|
var workerWG sync.WaitGroup
|
|
startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh)
|
|
|
|
// Submit jobs concurrently through the real WebSocket protocol.
|
|
var submitWG sync.WaitGroup
|
|
sem := make(chan struct{}, clientConcurrency)
|
|
for i := 0; i < jobCount; i++ {
|
|
submitWG.Add(1)
|
|
go func(idx int) {
|
|
sem <- struct{}{}
|
|
defer submitWG.Done()
|
|
defer func() { <-sem }()
|
|
jobName := fmt.Sprintf("ws-load-job-%02d", idx)
|
|
commitBytes := make([]byte, 20)
|
|
for j := range commitBytes {
|
|
commitBytes[j] = byte(idx + 1)
|
|
}
|
|
commitIDStr := fmt.Sprintf("%x", commitBytes)
|
|
|
|
require.NoError(t, expMgr.CreateExperiment(commitIDStr))
|
|
filesPath := expMgr.GetFilesPath(commitIDStr)
|
|
require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
|
|
require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
|
|
man, err := expMgr.GenerateManifest(commitIDStr)
|
|
require.NoError(t, err)
|
|
require.NoError(t, expMgr.WriteManifest(man))
|
|
|
|
queueJobViaWebSocket(t, server.URL, jobName, commitBytes, byte(idx%5))
|
|
}(i)
|
|
}
|
|
submitWG.Wait()
|
|
|
|
// Give workers time to pick up any remaining tasks
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
completed := 0
|
|
timeout := time.After(60 * time.Second)
|
|
for completed < jobCount {
|
|
select {
|
|
case <-timeout:
|
|
t.Fatalf("timed out waiting for %d completions, only saw %d", jobCount, completed)
|
|
case <-doneCh:
|
|
completed++
|
|
}
|
|
}
|
|
|
|
// Stop workers and ensure they exit cleanly.
|
|
cancelWorkers()
|
|
workerWG.Wait()
|
|
|
|
nextTask, err := taskQueue.GetNextTask()
|
|
require.NoError(t, err)
|
|
require.Nil(t, nextTask, "queue should be empty after all jobs complete")
|
|
}
|
|
|
|
func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping websocket queue integration in short mode")
|
|
}
|
|
|
|
queuePath := filepath.Join(t.TempDir(), "queue.db")
|
|
taskQueue, err := queue.NewSQLiteQueue(queuePath)
|
|
require.NoError(t, err)
|
|
defer func() { _ = taskQueue.Close() }()
|
|
|
|
expMgr := experiment.NewManager(t.TempDir())
|
|
require.NoError(t, expMgr.Initialize())
|
|
|
|
logger := logging.NewLogger(0, false)
|
|
authCfg := &auth.Config{Enabled: false}
|
|
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
|
|
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
|
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
|
wsHandler := wspkg.NewHandler(
|
|
authCfg,
|
|
logger,
|
|
expMgr,
|
|
"",
|
|
taskQueue,
|
|
nil, // db
|
|
nil, // jupyterServiceMgr
|
|
nil, // securityConfig
|
|
nil, // auditLogger
|
|
jobsHandler,
|
|
jupyterHandler,
|
|
datasetsHandler,
|
|
)
|
|
server := httptest.NewServer(wsHandler)
|
|
defer server.Close()
|
|
|
|
ctx, cancelWorkers := context.WithCancel(context.Background())
|
|
defer cancelWorkers()
|
|
|
|
const (
|
|
jobCount = 10
|
|
workerCount = 2
|
|
clientConcurrency = 4
|
|
)
|
|
|
|
doneCh := make(chan string, jobCount)
|
|
var workerWG sync.WaitGroup
|
|
startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh)
|
|
|
|
var submitWG sync.WaitGroup
|
|
sem := make(chan struct{}, clientConcurrency)
|
|
for i := 0; i < jobCount; i++ {
|
|
submitWG.Add(1)
|
|
go func(idx int) {
|
|
sem <- struct{}{}
|
|
defer submitWG.Done()
|
|
defer func() { <-sem }()
|
|
|
|
jobName := fmt.Sprintf("ws-sqlite-job-%02d", idx)
|
|
commitBytes := make([]byte, 20)
|
|
for j := range commitBytes {
|
|
commitBytes[j] = byte(idx + 1)
|
|
}
|
|
commitIDStr := fmt.Sprintf("%x", commitBytes)
|
|
|
|
require.NoError(t, expMgr.CreateExperiment(commitIDStr))
|
|
filesPath := expMgr.GetFilesPath(commitIDStr)
|
|
require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
|
|
require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
|
|
man, err := expMgr.GenerateManifest(commitIDStr)
|
|
require.NoError(t, err)
|
|
require.NoError(t, expMgr.WriteManifest(man))
|
|
|
|
queueJobViaWebSocket(t, server.URL, jobName, commitBytes, byte(idx%5))
|
|
}(i)
|
|
}
|
|
submitWG.Wait()
|
|
|
|
// Give workers time to pick up any remaining tasks
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
completed := 0
|
|
timeout := time.After(60 * time.Second)
|
|
for completed < jobCount {
|
|
select {
|
|
case <-timeout:
|
|
t.Fatalf("timed out waiting for %d completions, only saw %d", jobCount, completed)
|
|
case <-doneCh:
|
|
completed++
|
|
}
|
|
}
|
|
|
|
cancelWorkers()
|
|
workerWG.Wait()
|
|
|
|
nextTask, err := taskQueue.GetNextTask()
|
|
require.NoError(t, err)
|
|
require.Nil(t, nextTask, "queue should be empty after all jobs complete")
|
|
}
|
|
|
|
func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping websocket queue integration in short mode")
|
|
}
|
|
|
|
redisServer, err := miniredis.Run()
|
|
require.NoError(t, err)
|
|
defer redisServer.Close()
|
|
|
|
taskQueue, err := queue.NewTaskQueue(queue.Config{
|
|
RedisAddr: redisServer.Addr(),
|
|
MetricsFlushInterval: 10 * time.Millisecond,
|
|
})
|
|
require.NoError(t, err)
|
|
defer func() { _ = taskQueue.Close() }()
|
|
|
|
expMgr := experiment.NewManager(t.TempDir())
|
|
require.NoError(t, expMgr.Initialize())
|
|
|
|
logger := logging.NewLogger(0, false)
|
|
authCfg := &auth.Config{Enabled: false}
|
|
jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg, nil)
|
|
jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
|
|
datasetsHandler := datasets.NewHandler(logger, nil, "")
|
|
wsHandler := wspkg.NewHandler(
|
|
authCfg,
|
|
logger,
|
|
expMgr,
|
|
"",
|
|
taskQueue,
|
|
nil, // db
|
|
nil, // jupyterServiceMgr
|
|
nil, // securityConfig
|
|
nil, // auditLogger
|
|
jobsHandler,
|
|
jupyterHandler,
|
|
datasetsHandler,
|
|
)
|
|
server := httptest.NewServer(wsHandler)
|
|
defer server.Close()
|
|
|
|
commitBytes := make([]byte, 20)
|
|
for i := range commitBytes {
|
|
commitBytes[i] = byte(i + 1)
|
|
}
|
|
commitIDStr := fmt.Sprintf("%x", commitBytes)
|
|
require.NoError(t, expMgr.CreateExperiment(commitIDStr))
|
|
filesPath := expMgr.GetFilesPath(commitIDStr)
|
|
require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
|
|
require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
|
|
man, err := expMgr.GenerateManifest(commitIDStr)
|
|
require.NoError(t, err)
|
|
require.NoError(t, expMgr.WriteManifest(man))
|
|
|
|
queueJobViaWebSocketWithSnapshot(t, server.URL, "ws-snap-job", commitBytes, 1, "snap-1", strings.Repeat("a", 64))
|
|
|
|
tasks, err := taskQueue.GetAllTasks()
|
|
require.NoError(t, err)
|
|
require.Len(t, tasks, 1)
|
|
require.Equal(t, "snap-1", tasks[0].SnapshotID)
|
|
require.Equal(t, strings.Repeat("a", 64), tasks[0].Metadata["snapshot_sha256"])
|
|
}
|
|
|
|
func startFakeWorkers(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
wg *sync.WaitGroup,
|
|
taskQueue queue.Backend,
|
|
workerCount int,
|
|
doneCh chan<- string,
|
|
) {
|
|
for w := 0; w < workerCount; w++ {
|
|
wg.Add(1)
|
|
go func(workerID string) {
|
|
defer wg.Done()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
task, err := taskQueue.GetNextTaskWithLease(workerID, 30*time.Second)
|
|
if err != nil {
|
|
t.Logf("worker %s queue error: %v", workerID, err)
|
|
time.Sleep(time.Millisecond)
|
|
continue
|
|
}
|
|
if task == nil {
|
|
time.Sleep(time.Millisecond)
|
|
continue
|
|
}
|
|
|
|
started := time.Now()
|
|
completed := started.Add(10 * time.Millisecond)
|
|
|
|
task.Status = statusCompleted
|
|
task.StartedAt = &started
|
|
task.EndedAt = &completed
|
|
task.LeaseExpiry = nil
|
|
task.LeasedBy = ""
|
|
|
|
if err := taskQueue.UpdateTask(task); err != nil {
|
|
t.Logf("worker %s failed to update task %s: %v", workerID, task.ID, err)
|
|
continue
|
|
}
|
|
|
|
doneCh <- task.JobName
|
|
}
|
|
}(fmt.Sprintf("worker-%d", w))
|
|
}
|
|
}
|
|
|
|
func queueJobViaWebSocket(t *testing.T, baseURL, jobName string, commitID []byte, priority byte) {
|
|
t.Helper()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(baseURL, "http")
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if resp != nil && resp.Body != nil {
|
|
defer func() {
|
|
if err := resp.Body.Close(); err != nil {
|
|
t.Logf("Warning: failed to close response body: %v", err)
|
|
}
|
|
}()
|
|
}
|
|
require.NoError(t, err)
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
msg := buildQueueJobMessage(jobName, commitID, priority)
|
|
require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, msg))
|
|
|
|
_, payload, err := conn.ReadMessage()
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, payload, "expected response payload")
|
|
require.EqualValues(t, api.PacketTypeSuccess, payload[0], "queue job should return success packet")
|
|
}
|
|
|
|
func queueJobViaWebSocketWithSnapshot(
|
|
t *testing.T,
|
|
baseURL, jobName string,
|
|
commitID []byte,
|
|
priority byte,
|
|
snapshotID string,
|
|
snapshotSHA string,
|
|
) {
|
|
t.Helper()
|
|
|
|
wsURL := "ws" + strings.TrimPrefix(baseURL, "http")
|
|
conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if resp != nil && resp.Body != nil {
|
|
defer func() {
|
|
if err := resp.Body.Close(); err != nil {
|
|
t.Logf("Warning: failed to close response body: %v", err)
|
|
}
|
|
}()
|
|
}
|
|
require.NoError(t, err)
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
msg := buildQueueJobWithSnapshotMessage(jobName, commitID, priority, snapshotID, snapshotSHA)
|
|
require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, msg))
|
|
|
|
_, payload, err := conn.ReadMessage()
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, payload, "expected response payload")
|
|
require.EqualValues(t, api.PacketTypeSuccess, payload[0], "queue job should return success packet")
|
|
}
|
|
|
|
func buildQueueJobMessage(jobName string, commitID []byte, priority byte) []byte {
|
|
jobBytes := []byte(jobName)
|
|
if len(jobBytes) > 255 {
|
|
jobBytes = jobBytes[:255]
|
|
}
|
|
if len(commitID) != 20 {
|
|
// In tests we always use 20 bytes per protocol.
|
|
padded := make([]byte, 20)
|
|
copy(padded, commitID)
|
|
commitID = padded
|
|
}
|
|
|
|
apiKeyHash := make([]byte, 16)
|
|
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
|
|
|
|
buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes))
|
|
buf = append(buf, wspkg.OpcodeQueueJob)
|
|
buf = append(buf, apiKeyHash...)
|
|
buf = append(buf, commitID...)
|
|
buf = append(buf, priority)
|
|
buf = append(buf, byte(len(jobBytes)))
|
|
buf = append(buf, jobBytes...)
|
|
return buf
|
|
}
|
|
|
|
func buildQueueJobWithSnapshotMessage(
|
|
jobName string,
|
|
commitID []byte,
|
|
priority byte,
|
|
snapshotID string,
|
|
snapshotSHA string,
|
|
) []byte {
|
|
jobBytes := []byte(jobName)
|
|
if len(jobBytes) > 255 {
|
|
jobBytes = jobBytes[:255]
|
|
}
|
|
if len(commitID) != 20 {
|
|
padded := make([]byte, 20)
|
|
copy(padded, commitID)
|
|
commitID = padded
|
|
}
|
|
|
|
apiKeyHash := make([]byte, 16)
|
|
copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
|
|
|
|
snapIDBytes := []byte(snapshotID)
|
|
if len(snapIDBytes) > 255 {
|
|
snapIDBytes = snapIDBytes[:255]
|
|
}
|
|
snapSHAB := []byte(snapshotSHA)
|
|
if len(snapSHAB) > 255 {
|
|
snapSHAB = snapSHAB[:255]
|
|
}
|
|
|
|
buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes)+1+len(snapIDBytes)+1+len(snapSHAB))
|
|
buf = append(buf, wspkg.OpcodeQueueJobWithSnapshot)
|
|
buf = append(buf, apiKeyHash...)
|
|
buf = append(buf, commitID...)
|
|
buf = append(buf, priority)
|
|
buf = append(buf, byte(len(jobBytes)))
|
|
buf = append(buf, jobBytes...)
|
|
buf = append(buf, byte(len(snapIDBytes)))
|
|
buf = append(buf, snapIDBytes...)
|
|
buf = append(buf, byte(len(snapSHAB)))
|
|
buf = append(buf, snapSHAB...)
|
|
return buf
|
|
}
|