Phase 1: Fix Redis Schema Leak
- Create internal/storage/dataset.go with DatasetStore abstraction
- Remove all direct Redis calls from cmd/data_manager/data_sync.go
- data_manager now uses DatasetStore for transfer tracking and metadata

Phase 2: Simplify TUI Services
- Embed *queue.TaskQueue directly in services.TaskQueue
- Eliminate 60% of wrapper boilerplate (203 -> ~100 lines)
- Keep only TUI-specific methods (EnqueueTask, GetJobStatus, experiment methods)

Phase 5: Clean go.mod Dependencies
- Remove duplicate go-redis/redis/v8 dependency
- Migrate internal/storage/migrate.go to redis/go-redis/v9
- Separate test-only deps (miniredis, testify) into own block

Results:
- Zero direct Redis calls in cmd/
- 60% fewer lines in TUI services
- Cleaner dependency structure
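
For context on Phases 1 and 2, here is a minimal sketch of the two shapes described above; the names and method signatures are illustrative assumptions, not the real API:

```go
// Hypothetical sketch only; the real types live in internal/storage and the
// TUI services package, and their signatures may differ.
package sketch

import "context"

// DatasetStore is the abstraction cmd/data_manager would depend on instead
// of issuing Redis commands directly (method names are assumed).
type DatasetStore interface {
	TrackTransfer(ctx context.Context, datasetID string) error
	SetMetadata(ctx context.Context, datasetID string, meta map[string]string) error
}

// baseQueue stands in for queue.TaskQueue.
type baseQueue struct{}

func (q *baseQueue) Close() error { return nil }

// TaskQueue embeds the underlying queue, so methods like Close are promoted
// automatically; only TUI-specific helpers need explicit wrappers.
type TaskQueue struct {
	*baseQueue
}

// EnqueueTask is one of the TUI-specific methods kept after the refactor (stub).
func (t *TaskQueue) EnqueueTask(name string) error { return nil }
```

Method promotion through the embedded pointer is what removes the pass-through wrappers counted in the 203 -> ~100 line reduction.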
package tests

import (
	"context"
	"fmt"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/alicebob/miniredis/v2"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/api"
	"github.com/jfraeys/fetch_ml/internal/api/datasets"
	"github.com/jfraeys/fetch_ml/internal/api/jobs"
	jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
	wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/stretchr/testify/require"
)
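
// TestWebSocketQueueEndToEnd exercises the full pipeline against a
// miniredis-backed queue: concurrent WebSocket submissions, fake workers
// leasing and completing tasks, and a final empty-queue check.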
func TestWebSocketQueueEndToEnd(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	// Miniredis provides an in-memory Redis compatible server for realistic queue tests.
	redisServer, err := miniredis.Run()
	require.NoError(t, err)
	defer redisServer.Close()

	taskQueue, err := queue.NewTaskQueue(queue.Config{
		RedisAddr:            redisServer.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	})
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	expMgr := experiment.NewManager(t.TempDir())
	require.NoError(t, expMgr.Initialize())

	logger := logging.NewLogger(0, false)
	authCfg := &auth.Config{Enabled: false}
	jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
	datasetsHandler := datasets.NewHandler(logger, nil, "")
	wsHandler := wspkg.NewHandler(
		authCfg,
		logger,
		expMgr,
		"",
		taskQueue,
		nil, // db
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
		jobsHandler,
		jupyterHandler,
		datasetsHandler,
	)
	server := httptest.NewServer(wsHandler)
	defer server.Close()

	ctx, cancelWorkers := context.WithCancel(context.Background())
	defer cancelWorkers()

	const (
		jobCount          = 20
		workerCount       = 4
		clientConcurrency = 8
	)

	doneCh := make(chan string, jobCount)
	var workerWG sync.WaitGroup
	startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh)

	// Submit jobs concurrently through the real WebSocket protocol.
	var submitWG sync.WaitGroup
	sem := make(chan struct{}, clientConcurrency)
	for i := 0; i < jobCount; i++ {
		submitWG.Add(1)
		go func(idx int) {
			sem <- struct{}{}
			defer submitWG.Done()
			defer func() { <-sem }()
			jobName := fmt.Sprintf("ws-load-job-%02d", idx)
			commitBytes := make([]byte, 20)
			for j := range commitBytes {
				commitBytes[j] = byte(idx + 1)
			}
			commitIDStr := fmt.Sprintf("%x", commitBytes)

			require.NoError(t, expMgr.CreateExperiment(commitIDStr))
			filesPath := expMgr.GetFilesPath(commitIDStr)
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
			man, err := expMgr.GenerateManifest(commitIDStr)
			require.NoError(t, err)
			require.NoError(t, expMgr.WriteManifest(man))

			queueJobViaWebSocket(t, server.URL, jobName, commitBytes, byte(idx%5))
		}(i)
	}
	submitWG.Wait()

	completed := 0
	timeout := time.After(30 * time.Second)
	for completed < jobCount {
		select {
		case <-timeout:
			t.Fatalf("timed out waiting for %d completions, only saw %d", jobCount, completed)
		case <-doneCh:
			completed++
		}
	}

	// Stop workers and ensure they exit cleanly.
	cancelWorkers()
	workerWG.Wait()

	nextTask, err := taskQueue.GetNextTask()
	require.NoError(t, err)
	require.Nil(t, nextTask, "queue should be empty after all jobs complete")
}
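
// TestWebSocketQueueEndToEndSQLite runs the same end-to-end flow against the
// SQLite queue backend instead of Redis, with a smaller load.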
func TestWebSocketQueueEndToEndSQLite(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	queuePath := filepath.Join(t.TempDir(), "queue.db")
	taskQueue, err := queue.NewSQLiteQueue(queuePath)
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	expMgr := experiment.NewManager(t.TempDir())
	require.NoError(t, expMgr.Initialize())

	logger := logging.NewLogger(0, false)
	authCfg := &auth.Config{Enabled: false}
	jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
	datasetsHandler := datasets.NewHandler(logger, nil, "")
	wsHandler := wspkg.NewHandler(
		authCfg,
		logger,
		expMgr,
		"",
		taskQueue,
		nil, // db
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
		jobsHandler,
		jupyterHandler,
		datasetsHandler,
	)
	server := httptest.NewServer(wsHandler)
	defer server.Close()

	ctx, cancelWorkers := context.WithCancel(context.Background())
	defer cancelWorkers()

	const (
		jobCount          = 10
		workerCount       = 2
		clientConcurrency = 4
	)

	doneCh := make(chan string, jobCount)
	var workerWG sync.WaitGroup
	startFakeWorkers(ctx, t, &workerWG, taskQueue, workerCount, doneCh)

	var submitWG sync.WaitGroup
	sem := make(chan struct{}, clientConcurrency)
	for i := 0; i < jobCount; i++ {
		submitWG.Add(1)
		go func(idx int) {
			sem <- struct{}{}
			defer submitWG.Done()
			defer func() { <-sem }()

			jobName := fmt.Sprintf("ws-sqlite-job-%02d", idx)
			commitBytes := make([]byte, 20)
			for j := range commitBytes {
				commitBytes[j] = byte(idx + 1)
			}
			commitIDStr := fmt.Sprintf("%x", commitBytes)

			require.NoError(t, expMgr.CreateExperiment(commitIDStr))
			filesPath := expMgr.GetFilesPath(commitIDStr)
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
			require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
			man, err := expMgr.GenerateManifest(commitIDStr)
			require.NoError(t, err)
			require.NoError(t, expMgr.WriteManifest(man))

			queueJobViaWebSocket(t, server.URL, jobName, commitBytes, byte(idx%5))
		}(i)
	}
	submitWG.Wait()

	completed := 0
	timeout := time.After(20 * time.Second)
	for completed < jobCount {
		select {
		case <-timeout:
			t.Fatalf("timed out waiting for %d completions, only saw %d", jobCount, completed)
		case <-doneCh:
			completed++
		}
	}

	cancelWorkers()
	workerWG.Wait()

	nextTask, err := taskQueue.GetNextTask()
	require.NoError(t, err)
	require.Nil(t, nextTask, "queue should be empty after all jobs complete")
}
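
// TestWebSocketQueueWithSnapshotOpcode verifies that a job queued via the
// snapshot opcode carries its snapshot ID and SHA-256 through to the stored task.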
func TestWebSocketQueueWithSnapshotOpcode(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping websocket queue integration in short mode")
	}

	redisServer, err := miniredis.Run()
	require.NoError(t, err)
	defer redisServer.Close()

	taskQueue, err := queue.NewTaskQueue(queue.Config{
		RedisAddr:            redisServer.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	})
	require.NoError(t, err)
	defer func() { _ = taskQueue.Close() }()

	expMgr := experiment.NewManager(t.TempDir())
	require.NoError(t, expMgr.Initialize())

	logger := logging.NewLogger(0, false)
	authCfg := &auth.Config{Enabled: false}
	jobsHandler := jobs.NewHandler(expMgr, logger, taskQueue, nil, authCfg)
	jupyterHandler := jupyterj.NewHandler(logger, nil, authCfg)
	datasetsHandler := datasets.NewHandler(logger, nil, "")
	wsHandler := wspkg.NewHandler(
		authCfg,
		logger,
		expMgr,
		"",
		taskQueue,
		nil, // db
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
		jobsHandler,
		jupyterHandler,
		datasetsHandler,
	)
	server := httptest.NewServer(wsHandler)
	defer server.Close()

	commitBytes := make([]byte, 20)
	for i := range commitBytes {
		commitBytes[i] = byte(i + 1)
	}
	commitIDStr := fmt.Sprintf("%x", commitBytes)
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	filesPath := expMgr.GetFilesPath(commitIDStr)
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))

	queueJobViaWebSocketWithSnapshot(t, server.URL, "ws-snap-job", commitBytes, 1, "snap-1", strings.Repeat("a", 64))

	tasks, err := taskQueue.GetAllTasks()
	require.NoError(t, err)
	require.Len(t, tasks, 1)
	require.Equal(t, "snap-1", tasks[0].SnapshotID)
	require.Equal(t, strings.Repeat("a", 64), tasks[0].Metadata["snapshot_sha256"])
}
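
// startFakeWorkers launches workerCount goroutines that repeatedly lease the
// next task, mark it completed via UpdateTask, and push the finished job name
// onto doneCh; each worker polls until ctx is cancelled.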
func startFakeWorkers(
	ctx context.Context,
	t *testing.T,
	wg *sync.WaitGroup,
	taskQueue queue.Backend,
	workerCount int,
	doneCh chan<- string,
) {
	for w := 0; w < workerCount; w++ {
		wg.Add(1)
		go func(workerID string) {
			defer wg.Done()
			for {
				select {
				case <-ctx.Done():
					return
				default:
				}

				task, err := taskQueue.GetNextTaskWithLease(workerID, 30*time.Second)
				if err != nil {
					t.Logf("worker %s queue error: %v", workerID, err)
					time.Sleep(5 * time.Millisecond)
					continue
				}
				if task == nil {
					time.Sleep(5 * time.Millisecond)
					continue
				}

				started := time.Now()
				completed := started.Add(10 * time.Millisecond)

				task.Status = statusCompleted
				task.StartedAt = &started
				task.EndedAt = &completed
				task.LeaseExpiry = nil
				task.LeasedBy = ""

				if err := taskQueue.UpdateTask(task); err != nil {
					t.Logf("worker %s failed to update task %s: %v", workerID, task.ID, err)
					continue
				}

				doneCh <- task.JobName
			}
		}(fmt.Sprintf("worker-%d", w))
	}
}
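
// queueJobViaWebSocket dials the test server, sends a binary queue-job frame,
// and asserts the handler replies with a success packet.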
func queueJobViaWebSocket(t *testing.T, baseURL, jobName string, commitID []byte, priority byte) {
	t.Helper()

	wsURL := "ws" + strings.TrimPrefix(baseURL, "http")
	conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if resp != nil && resp.Body != nil {
		defer func() {
			if err := resp.Body.Close(); err != nil {
				t.Logf("Warning: failed to close response body: %v", err)
			}
		}()
	}
	require.NoError(t, err)
	defer func() { _ = conn.Close() }()

	msg := buildQueueJobMessage(jobName, commitID, priority)
	require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, msg))

	_, payload, err := conn.ReadMessage()
	require.NoError(t, err)
	require.NotEmpty(t, payload, "expected response payload")
	require.EqualValues(t, api.PacketTypeSuccess, payload[0], "queue job should return success packet")
}
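
// queueJobViaWebSocketWithSnapshot is the snapshot-aware variant: it sends an
// OpcodeQueueJobWithSnapshot frame and expects the same success packet.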
func queueJobViaWebSocketWithSnapshot(
	t *testing.T,
	baseURL, jobName string,
	commitID []byte,
	priority byte,
	snapshotID string,
	snapshotSHA string,
) {
	t.Helper()

	wsURL := "ws" + strings.TrimPrefix(baseURL, "http")
	conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if resp != nil && resp.Body != nil {
		defer func() {
			if err := resp.Body.Close(); err != nil {
				t.Logf("Warning: failed to close response body: %v", err)
			}
		}()
	}
	require.NoError(t, err)
	defer func() { _ = conn.Close() }()

	msg := buildQueueJobWithSnapshotMessage(jobName, commitID, priority, snapshotID, snapshotSHA)
	require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, msg))

	_, payload, err := conn.ReadMessage()
	require.NoError(t, err)
	require.NotEmpty(t, payload, "expected response payload")
	require.EqualValues(t, api.PacketTypeSuccess, payload[0], "queue job should return success packet")
}
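
// buildQueueJobMessage encodes the OpcodeQueueJob frame as laid out below:
//
//	[opcode:1][apiKeyHash:16][commitID:20][priority:1][nameLen:1][name:nameLen]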
func buildQueueJobMessage(jobName string, commitID []byte, priority byte) []byte {
	jobBytes := []byte(jobName)
	if len(jobBytes) > 255 {
		jobBytes = jobBytes[:255]
	}
	if len(commitID) != 20 {
		// In tests we always use 20 bytes per protocol.
		padded := make([]byte, 20)
		copy(padded, commitID)
		commitID = padded
	}

	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes))
	buf = append(buf, wspkg.OpcodeQueueJob)
	buf = append(buf, apiKeyHash...)
	buf = append(buf, commitID...)
	buf = append(buf, priority)
	buf = append(buf, byte(len(jobBytes)))
	buf = append(buf, jobBytes...)
	return buf
}
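
// buildQueueJobWithSnapshotMessage extends the queue-job frame with two
// trailing length-prefixed fields:
//
//	[queue-job frame][snapIDLen:1][snapID][shaLen:1][sha]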
func buildQueueJobWithSnapshotMessage(
	jobName string,
	commitID []byte,
	priority byte,
	snapshotID string,
	snapshotSHA string,
) []byte {
	jobBytes := []byte(jobName)
	if len(jobBytes) > 255 {
		jobBytes = jobBytes[:255]
	}
	if len(commitID) != 20 {
		padded := make([]byte, 20)
		copy(padded, commitID)
		commitID = padded
	}

	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	snapIDBytes := []byte(snapshotID)
	if len(snapIDBytes) > 255 {
		snapIDBytes = snapIDBytes[:255]
	}
	snapSHAB := []byte(snapshotSHA)
	if len(snapSHAB) > 255 {
		snapSHAB = snapSHAB[:255]
	}

	buf := make([]byte, 0, 1+16+20+1+1+len(jobBytes)+1+len(snapIDBytes)+1+len(snapSHAB))
	buf = append(buf, wspkg.OpcodeQueueJobWithSnapshot)
	buf = append(buf, apiKeyHash...)
	buf = append(buf, commitID...)
	buf = append(buf, priority)
	buf = append(buf, byte(len(jobBytes)))
	buf = append(buf, jobBytes...)
	buf = append(buf, byte(len(snapIDBytes)))
	buf = append(buf, snapIDBytes...)
	buf = append(buf, byte(len(snapSHAB)))
	buf = append(buf, snapSHAB...)
	return buf
}