package api import ( "encoding/binary" "math" "net/http/httptest" "strings" "testing" "time" "github.com/alicebob/miniredis/v2" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func setupTestServer(t *testing.T) (*httptest.Server, *queue.TaskQueue, *experiment.Manager, *miniredis.Miniredis) { // Setup miniredis s, err := miniredis.Run() require.NoError(t, err) // Setup TaskQueue queueCfg := queue.Config{ RedisAddr: s.Addr(), MetricsFlushInterval: 10 * time.Millisecond, } tq, err := queue.NewTaskQueue(queueCfg) require.NoError(t, err) // Setup dependencies logger := logging.NewLogger(0, false) expManager := experiment.NewManager(t.TempDir()) authCfg := &auth.AuthConfig{Enabled: false} // Create handler handler := NewWSHandler(authCfg, logger, expManager, tq) // Setup test server server := httptest.NewServer(handler) return server, tq, expManager, s } func connectWS(t *testing.T, serverURL string) *websocket.Conn { wsURL := "ws" + strings.TrimPrefix(serverURL, "http") ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil) require.NoError(t, err) return ws } func TestWSHandler_QueueJob(t *testing.T) { server, tq, _, s := setupTestServer(t) defer server.Close() defer tq.Close() defer s.Close() ws := connectWS(t, server.URL) defer ws.Close() // Prepare queue_job message // Protocol: [opcode:1][api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var] opcode := byte(OpcodeQueueJob) apiKeyHash := make([]byte, 64) copy(apiKeyHash, []byte(strings.Repeat("0", 64))) commitID := make([]byte, 64) copy(commitID, []byte(strings.Repeat("a", 64))) priority := byte(5) jobName := "test-job" jobNameLen := byte(len(jobName)) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) msg = append(msg, commitID...) msg = append(msg, priority) msg = append(msg, jobNameLen) msg = append(msg, []byte(jobName)...) // Send message err := ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response (PacketTypeSuccess = 0x00) assert.Equal(t, byte(PacketTypeSuccess), resp[0]) // Verify task in Redis time.Sleep(100 * time.Millisecond) task, err := tq.GetNextTask() require.NoError(t, err) require.NotNil(t, task) assert.Equal(t, jobName, task.JobName) } func TestWSHandler_StatusRequest(t *testing.T) { server, tq, _, s := setupTestServer(t) defer server.Close() defer tq.Close() defer s.Close() // Add a task to queue task := &queue.Task{ ID: "task-1", JobName: "job-1", Status: "queued", Priority: 10, CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", } err := tq.AddTask(task) require.NoError(t, err) ws := connectWS(t, server.URL) defer ws.Close() // Prepare status_request message // Protocol: [opcode:1][api_key_hash:64] opcode := byte(OpcodeStatusRequest) apiKeyHash := make([]byte, 64) copy(apiKeyHash, []byte(strings.Repeat("0", 64))) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) // Send message err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response (PacketTypeData = 0x04 for status with payload) assert.Equal(t, byte(PacketTypeData), resp[0]) } func TestWSHandler_CancelJob(t *testing.T) { server, tq, _, s := setupTestServer(t) defer server.Close() defer tq.Close() defer s.Close() // Add a task to queue task := &queue.Task{ ID: "task-1", JobName: "job-to-cancel", Status: "queued", Priority: 10, CreatedAt: time.Now(), UserID: "user", // Auth disabled so this matches any user CreatedBy: "user", } err := tq.AddTask(task) require.NoError(t, err) ws := connectWS(t, server.URL) defer ws.Close() // Prepare cancel_job message // Protocol: [opcode:1][api_key_hash:64][job_name_len:1][job_name:var] opcode := byte(OpcodeCancelJob) apiKeyHash := make([]byte, 64) copy(apiKeyHash, []byte(strings.Repeat("0", 64))) jobName := "job-to-cancel" jobNameLen := byte(len(jobName)) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) msg = append(msg, jobNameLen) msg = append(msg, []byte(jobName)...) // Send message err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response assert.Equal(t, byte(PacketTypeSuccess), resp[0]) // Verify task cancelled updatedTask, err := tq.GetTask("task-1") require.NoError(t, err) assert.Equal(t, "cancelled", updatedTask.Status) } func TestWSHandler_Prune(t *testing.T) { server, tq, expManager, s := setupTestServer(t) defer server.Close() defer tq.Close() defer s.Close() // Create some experiments _ = expManager.CreateExperiment("commit-1") _ = expManager.CreateExperiment("commit-2") ws := connectWS(t, server.URL) defer ws.Close() // Prepare prune message // Protocol: [opcode:1][api_key_hash:64][prune_type:1][value:4] opcode := byte(OpcodePrune) apiKeyHash := make([]byte, 64) copy(apiKeyHash, []byte(strings.Repeat("0", 64))) pruneType := byte(0) // Keep N value := uint32(1) // Keep 1 valueBytes := make([]byte, 4) binary.BigEndian.PutUint32(valueBytes, value) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) msg = append(msg, pruneType) msg = append(msg, valueBytes...) // Send message err := ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response assert.Equal(t, byte(PacketTypeSuccess), resp[0]) } func TestWSHandler_LogMetric(t *testing.T) { server, tq, expManager, s := setupTestServer(t) defer server.Close() defer tq.Close() defer s.Close() // Create experiment commitIDStr := strings.Repeat("a", 64) err := expManager.CreateExperiment(commitIDStr) require.NoError(t, err) ws := connectWS(t, server.URL) defer ws.Close() // Prepare log_metric message // Protocol: [opcode:1][api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var] opcode := byte(OpcodeLogMetric) apiKeyHash := make([]byte, 64) copy(apiKeyHash, []byte(strings.Repeat("0", 64))) commitID := []byte(commitIDStr) step := uint32(100) value := 0.95 valueBits := math.Float64bits(value) metricName := "accuracy" nameLen := byte(len(metricName)) stepBytes := make([]byte, 4) binary.BigEndian.PutUint32(stepBytes, step) valueBytes := make([]byte, 8) binary.BigEndian.PutUint64(valueBytes, valueBits) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) msg = append(msg, commitID...) msg = append(msg, stepBytes...) msg = append(msg, valueBytes...) msg = append(msg, nameLen) msg = append(msg, []byte(metricName)...) // Send message err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response assert.Equal(t, byte(PacketTypeSuccess), resp[0]) } func TestWSHandler_GetExperiment(t *testing.T) { server, tq, expManager, s := setupTestServer(t) defer server.Close() defer tq.Close() defer s.Close() // Create experiment and metadata commitIDStr := strings.Repeat("a", 64) err := expManager.CreateExperiment(commitIDStr) require.NoError(t, err) meta := &experiment.Metadata{ CommitID: commitIDStr, JobName: "test-job", } err = expManager.WriteMetadata(meta) require.NoError(t, err) ws := connectWS(t, server.URL) defer ws.Close() // Prepare get_experiment message // Protocol: [opcode:1][api_key_hash:64][commit_id:64] opcode := byte(OpcodeGetExperiment) apiKeyHash := make([]byte, 64) copy(apiKeyHash, []byte(strings.Repeat("0", 64))) commitID := []byte(commitIDStr) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) msg = append(msg, commitID...) // Send message err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response (PacketTypeData) assert.Equal(t, byte(PacketTypeData), resp[0]) }