//nolint:revive // Package name 'tests' is appropriate for this integration test package package tests import ( "crypto/sha256" "encoding/binary" "encoding/hex" "encoding/json" "math" "net/http/httptest" "os" "path/filepath" "strings" "testing" "time" "github.com/alicebob/miniredis/v2" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/jfraeys/fetch_ml/internal/api" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/manifest" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/worker" ) func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) ( *httptest.Server, *queue.TaskQueue, *experiment.Manager, *miniredis.Miniredis, *storage.DB, ) { s, err := miniredis.Run() require.NoError(t, err) queueCfg := queue.Config{ RedisAddr: s.Addr(), MetricsFlushInterval: 10 * time.Millisecond, } tq, err := queue.NewTaskQueue(queueCfg) require.NoError(t, err) logger := logging.NewLogger(0, false) expManager := experiment.NewManager(t.TempDir()) authConfig := &auth.Config{Enabled: false} dbPath := filepath.Join(t.TempDir(), "test.db") db, err := storage.NewDBFromPath(dbPath) require.NoError(t, err) schema, err := storage.SchemaForDBType(storage.DBTypeSQLite) require.NoError(t, err) require.NoError(t, db.Initialize(schema)) handler := api.NewWSHandler( authConfig, logger, expManager, dataDir, tq, db, nil, // jupyterServiceMgr nil, // securityConfig nil, // auditLogger ) server := httptest.NewServer(handler) return server, tq, expManager, s, db } func decodeDataPacket(t *testing.T, resp []byte) (string, []byte) { t.Helper() require.GreaterOrEqual(t, len(resp), 1+8) if resp[0] != byte(api.PacketTypeData) { t.Fatalf("expected PacketTypeData=%d, got %d", api.PacketTypeData, resp[0]) } idx := 1 + 8 dataTypeLen, n := binary.Uvarint(resp[idx:]) require.Greater(t, n, 0) idx += n require.GreaterOrEqual(t, len(resp), idx+int(dataTypeLen)) dataType := string(resp[idx : idx+int(dataTypeLen)]) idx += int(dataTypeLen) payloadLen, n := binary.Uvarint(resp[idx:]) require.Greater(t, n, 0) idx += n require.GreaterOrEqual(t, len(resp), idx+int(payloadLen)) return dataType, resp[idx : idx+int(payloadLen)] } func TestWSHandler_ValidateRequest_TaskID_RunManifestMissingForRunning_Fails(t *testing.T) { server, tq, expMgr, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() commitIDStr := strings.Repeat("61", 20) require.NoError(t, expMgr.CreateExperiment(commitIDStr)) filesPath := expMgr.GetFilesPath(commitIDStr) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600)) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600)) man, err := expMgr.GenerateManifest(commitIDStr) require.NoError(t, err) require.NoError(t, expMgr.WriteManifest(man)) reqBytes := []byte("numpy==1.0.0\n") reqSum := sha256.Sum256(reqBytes) depSha := hex.EncodeToString(reqSum[:]) taskID := "task-run-manifest-missing" task := &queue.Task{ ID: taskID, JobName: "job", Status: "running", Priority: 1, CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", Metadata: map[string]string{ "commit_id": commitIDStr, "experiment_manifest_overall_sha": man.OverallSHA, "deps_manifest_name": "requirements.txt", "deps_manifest_sha256": depSha, }, } require.NoError(t, tq.AddTask(task)) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) msg = append(msg, byte(api.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) msg = append(msg, []byte(taskID)...) err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) dataType, payload := decodeDataPacket(t, resp) require.Equal(t, "validate", dataType) var report map[string]any require.NoError(t, json.Unmarshal(payload, &report)) require.Equal(t, false, report["ok"].(bool)) checks := report["checks"].(map[string]any) rm := checks["run_manifest"].(map[string]any) require.Equal(t, false, rm["ok"].(bool)) } func TestWSHandler_ValidateRequest_TaskID_RunManifestCommitMismatch_Fails(t *testing.T) { server, tq, expMgr, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() commitIDStr := strings.Repeat("61", 20) require.NoError(t, expMgr.CreateExperiment(commitIDStr)) filesPath := expMgr.GetFilesPath(commitIDStr) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600)) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600)) man, err := expMgr.GenerateManifest(commitIDStr) require.NoError(t, err) require.NoError(t, expMgr.WriteManifest(man)) reqBytes := []byte("numpy==1.0.0\n") reqSum := sha256.Sum256(reqBytes) depSha := hex.EncodeToString(reqSum[:]) taskID := "task-run-manifest-commit-mismatch" task := &queue.Task{ ID: taskID, JobName: "job", Status: "completed", Priority: 1, CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", Metadata: map[string]string{ "commit_id": commitIDStr, "experiment_manifest_overall_sha": man.OverallSHA, "deps_manifest_name": "requirements.txt", "deps_manifest_sha256": depSha, }, } require.NoError(t, tq.AddTask(task)) jobDir := filepath.Join(expMgr.BasePath(), "finished", task.JobName) require.NoError(t, os.MkdirAll(jobDir, 0750)) rm := manifest.NewRunManifest("run-test", task.ID, task.JobName, task.CreatedAt) rm.CommitID = strings.Repeat("62", 20) rm.DepsManifestName = "requirements.txt" rm.DepsManifestSHA = depSha rm.MarkStarted(time.Now().UTC().Add(-2 * time.Second)) exitCode := 0 rm.MarkFinished(time.Now().UTC(), &exitCode, nil) require.NoError(t, rm.WriteToDir(jobDir)) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) msg = append(msg, byte(api.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) msg = append(msg, []byte(taskID)...) err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) dataType, payload := decodeDataPacket(t, resp) require.Equal(t, "validate", dataType) var report map[string]any require.NoError(t, json.Unmarshal(payload, &report)) require.Equal(t, false, report["ok"].(bool)) checks := report["checks"].(map[string]any) commitCheck := checks["run_manifest_commit_id"].(map[string]any) require.Equal(t, false, commitCheck["ok"].(bool)) require.Equal(t, commitIDStr, commitCheck["expected"].(string)) } func TestWSHandler_ValidateRequest_TaskID_RunManifestLocationMismatch_Fails(t *testing.T) { server, tq, expMgr, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() commitIDStr := strings.Repeat("61", 20) require.NoError(t, expMgr.CreateExperiment(commitIDStr)) filesPath := expMgr.GetFilesPath(commitIDStr) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600)) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600)) man, err := expMgr.GenerateManifest(commitIDStr) require.NoError(t, err) require.NoError(t, expMgr.WriteManifest(man)) reqBytes := []byte("numpy==1.0.0\n") reqSum := sha256.Sum256(reqBytes) depSha := hex.EncodeToString(reqSum[:]) taskID := "task-run-manifest-location-mismatch" task := &queue.Task{ ID: taskID, JobName: "job", Status: "running", Priority: 1, CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", Metadata: map[string]string{ "commit_id": commitIDStr, "experiment_manifest_overall_sha": man.OverallSHA, "deps_manifest_name": "requirements.txt", "deps_manifest_sha256": depSha, }, } require.NoError(t, tq.AddTask(task)) // Intentionally write manifest to the wrong bucket. jobDir := filepath.Join(expMgr.BasePath(), "finished", task.JobName) require.NoError(t, os.MkdirAll(jobDir, 0750)) rm := manifest.NewRunManifest("run-test", task.ID, task.JobName, task.CreatedAt) rm.CommitID = commitIDStr rm.DepsManifestName = "requirements.txt" rm.DepsManifestSHA = depSha rm.MarkStarted(time.Now().UTC().Add(-2 * time.Second)) require.NoError(t, rm.WriteToDir(jobDir)) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) msg = append(msg, byte(api.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) msg = append(msg, []byte(taskID)...) err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) dataType, payload := decodeDataPacket(t, resp) require.Equal(t, "validate", dataType) var report map[string]any require.NoError(t, json.Unmarshal(payload, &report)) require.Equal(t, false, report["ok"].(bool)) checks := report["checks"].(map[string]any) loc := checks["run_manifest_location"].(map[string]any) require.Equal(t, false, loc["ok"].(bool)) require.Equal(t, "running", loc["expected"].(string)) require.Equal(t, "finished", loc["actual"].(string)) } func TestWSHandler_ValidateRequest_TaskID_RunManifestLifecycleOrdering_Fails(t *testing.T) { server, tq, expMgr, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() commitIDStr := strings.Repeat("61", 20) require.NoError(t, expMgr.CreateExperiment(commitIDStr)) filesPath := expMgr.GetFilesPath(commitIDStr) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600)) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600)) man, err := expMgr.GenerateManifest(commitIDStr) require.NoError(t, err) require.NoError(t, expMgr.WriteManifest(man)) reqBytes := []byte("numpy==1.0.0\n") reqSum := sha256.Sum256(reqBytes) depSha := hex.EncodeToString(reqSum[:]) taskID := "task-run-manifest-lifecycle-ordering" task := &queue.Task{ ID: taskID, JobName: "job", Status: "completed", Priority: 1, CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", Metadata: map[string]string{ "commit_id": commitIDStr, "experiment_manifest_overall_sha": man.OverallSHA, "deps_manifest_name": "requirements.txt", "deps_manifest_sha256": depSha, }, } require.NoError(t, tq.AddTask(task)) jobDir := filepath.Join(expMgr.BasePath(), "finished", task.JobName) require.NoError(t, os.MkdirAll(jobDir, 0750)) rm := manifest.NewRunManifest("run-test", task.ID, task.JobName, task.CreatedAt) rm.CommitID = commitIDStr rm.DepsManifestName = "requirements.txt" rm.DepsManifestSHA = depSha start := time.Now().UTC() end := start.Add(-1 * time.Second) rm.MarkStarted(start) exitCode := 0 rm.MarkFinished(end, &exitCode, nil) require.NoError(t, rm.WriteToDir(jobDir)) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) msg = append(msg, byte(api.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) msg = append(msg, []byte(taskID)...) err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) dataType, payload := decodeDataPacket(t, resp) require.Equal(t, "validate", dataType) var report map[string]any require.NoError(t, json.Unmarshal(payload, &report)) require.Equal(t, false, report["ok"].(bool)) checks := report["checks"].(map[string]any) lifecycle := checks["run_manifest_lifecycle"].(map[string]any) require.Equal(t, false, lifecycle["ok"].(bool)) } func TestWSHandler_ValidateRequest_TaskID_InvalidResources(t *testing.T) { server, tq, expMgr, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() commitIDStr := strings.Repeat("61", 20) require.NoError(t, expMgr.CreateExperiment(commitIDStr)) filesPath := expMgr.GetFilesPath(commitIDStr) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600)) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600)) man, err := expMgr.GenerateManifest(commitIDStr) require.NoError(t, err) require.NoError(t, expMgr.WriteManifest(man)) reqBytes := []byte("numpy==1.0.0\n") reqSum := sha256.Sum256(reqBytes) depSha := hex.EncodeToString(reqSum[:]) taskID := "task-invalid-resources" task := &queue.Task{ ID: taskID, JobName: "job", Status: "queued", Priority: 1, CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", Metadata: map[string]string{ "commit_id": commitIDStr, "experiment_manifest_overall_sha": man.OverallSHA, "deps_manifest_name": "requirements.txt", "deps_manifest_sha256": depSha, }, CPU: -1, } require.NoError(t, tq.AddTask(task)) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) msg = append(msg, byte(api.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) msg = append(msg, []byte(taskID)...) err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) dataType, payload := decodeDataPacket(t, resp) require.Equal(t, "validate", dataType) var report map[string]any require.NoError(t, json.Unmarshal(payload, &report)) require.Equal(t, false, report["ok"].(bool)) checks := report["checks"].(map[string]any) res := checks["resources"].(map[string]any) require.Equal(t, false, res["ok"].(bool)) } func TestWSHandler_ValidateRequest_TaskID_SnapshotMismatch(t *testing.T) { dataDir := t.TempDir() server, tq, expMgr, s, db := setupWSIntegrationServerWithDataDir(t, dataDir) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() commitIDStr := strings.Repeat("61", 20) require.NoError(t, expMgr.CreateExperiment(commitIDStr)) filesPath := expMgr.GetFilesPath(commitIDStr) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600)) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600)) man, err := expMgr.GenerateManifest(commitIDStr) require.NoError(t, err) require.NoError(t, expMgr.WriteManifest(man)) reqBytes := []byte("numpy==1.0.0\n") reqSum := sha256.Sum256(reqBytes) depSha := hex.EncodeToString(reqSum[:]) snapshotID := "snap-1" snapPath := filepath.Join(dataDir, "snapshots", snapshotID) require.NoError(t, os.MkdirAll(snapPath, 0750)) require.NoError(t, os.WriteFile(filepath.Join(snapPath, "hello.txt"), []byte("hello"), 0600)) actualSnap, err := worker.DirOverallSHA256Hex(snapPath) require.NoError(t, err) taskID := "task-snap-mismatch" task := &queue.Task{ ID: taskID, JobName: "job", Status: "queued", Priority: 1, CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", SnapshotID: snapshotID, Metadata: map[string]string{ "commit_id": commitIDStr, "experiment_manifest_overall_sha": man.OverallSHA, "deps_manifest_name": "requirements.txt", "deps_manifest_sha256": depSha, "snapshot_sha256": strings.Repeat("0", 64), }, } require.NoError(t, tq.AddTask(task)) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+len(taskID)) msg = append(msg, byte(api.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(1)) msg = append(msg, byte(len(taskID))) msg = append(msg, []byte(taskID)...) err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) dataType, payload := decodeDataPacket(t, resp) require.Equal(t, "validate", dataType) var report map[string]any require.NoError(t, json.Unmarshal(payload, &report)) require.Equal(t, false, report["ok"].(bool)) checks := report["checks"].(map[string]any) snap := checks["snapshot"].(map[string]any) require.Equal(t, false, snap["ok"].(bool)) require.Equal(t, actualSnap, snap["actual"].(string)) } func setupWSIntegrationServer(t *testing.T) ( *httptest.Server, *queue.TaskQueue, *experiment.Manager, *miniredis.Miniredis, *storage.DB, ) { // Setup miniredis s, err := miniredis.Run() require.NoError(t, err) // Setup TaskQueue queueCfg := queue.Config{ RedisAddr: s.Addr(), MetricsFlushInterval: 10 * time.Millisecond, } tq, err := queue.NewTaskQueue(queueCfg) require.NoError(t, err) // Setup dependencies logger := logging.NewLogger(0, false) expManager := experiment.NewManager(t.TempDir()) authConfig := &auth.Config{Enabled: false} // Renamed from authCfg dbPath := filepath.Join(t.TempDir(), "test.db") db, err := storage.NewDBFromPath(dbPath) require.NoError(t, err) schema, err := storage.SchemaForDBType(storage.DBTypeSQLite) require.NoError(t, err) require.NoError(t, db.Initialize(schema)) // Create handler handler := api.NewWSHandler( authConfig, logger, expManager, "", tq, // Renamed from taskQueue db, // db nil, // jupyterServiceMgr nil, // securityConfig nil, // auditLogger ) // Setup test server server := httptest.NewServer(handler) return server, tq, expManager, s, db } func connectWSIntegration(t *testing.T, serverURL string) *websocket.Conn { wsURL := "ws" + strings.TrimPrefix(serverURL, "http") ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) if resp != nil && resp.Body != nil { _ = resp.Body.Close() } require.NoError(t, err) return ws } func TestWSHandler_QueueJob_Integration(t *testing.T) { server, tq, expMgr, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() // Prepare queue_job message // Protocol: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] opcode := byte(api.OpcodeQueueJob) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) commitID := make([]byte, 20) copy(commitID, []byte(strings.Repeat("a", 20))) commitIDStr := strings.Repeat("61", 20) // Pre-create experiment files so enqueue can compute expected provenance (deps manifest + manifest overall sha). require.NoError(t, expMgr.CreateExperiment(commitIDStr)) filesPath := expMgr.GetFilesPath(commitIDStr) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600)) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600)) man, err := expMgr.GenerateManifest(commitIDStr) require.NoError(t, err) require.NoError(t, expMgr.WriteManifest(man)) priority := byte(5) jobName := "test-job" jobNameLen := byte(len(jobName)) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) msg = append(msg, commitID...) msg = append(msg, priority) msg = append(msg, jobNameLen) msg = append(msg, []byte(jobName)...) // Optional resource request tail: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var] msg = append(msg, byte(4)) // cpu msg = append(msg, byte(16)) // memory_gb msg = append(msg, byte(1)) // gpu gpuMem := "8GB" msg = append(msg, byte(len(gpuMem))) msg = append(msg, []byte(gpuMem)...) // Send message err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response (PacketTypeSuccess = 0x00) assert.Equal(t, byte(api.PacketTypeSuccess), resp[0]) // Verify task in Redis time.Sleep(100 * time.Millisecond) task, err := tq.GetNextTask() require.NoError(t, err) require.NotNil(t, task) assert.Equal(t, jobName, task.JobName) assert.Equal(t, 4, task.CPU) assert.Equal(t, 16, task.MemoryGB) assert.Equal(t, 1, task.GPU) assert.Equal(t, gpuMem, task.GPUMemory) } func TestWSHandler_StatusRequest_Integration(t *testing.T) { server, tq, _, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() // Add a task to queue task := &queue.Task{ ID: "task-1", JobName: "job-1", Status: "queued", Priority: 10, CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", } err := tq.AddTask(task) require.NoError(t, err) ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() // Prepare status_request message // Protocol: [opcode:1][api_key_hash:16] opcode := byte(api.OpcodeStatusRequest) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) // Send message err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response (PacketTypeData = 0x04 for status with payload) assert.Equal(t, byte(api.PacketTypeData), resp[0]) } func TestWSHandler_ValidateRequest_Integration(t *testing.T) { server, tq, expMgr, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() commitIDBytes := make([]byte, 20) copy(commitIDBytes, []byte(strings.Repeat("a", 20))) commitIDStr := strings.Repeat("61", 20) require.NoError(t, expMgr.CreateExperiment(commitIDStr)) filesPath := expMgr.GetFilesPath(commitIDStr) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600)) require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600)) man, err := expMgr.GenerateManifest(commitIDStr) require.NoError(t, err) require.NoError(t, expMgr.WriteManifest(man)) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) msg := make([]byte, 0, 1+16+1+1+20) msg = append(msg, byte(api.OpcodeValidateRequest)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(0)) msg = append(msg, byte(20)) msg = append(msg, commitIDBytes...) err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) dataType, payload := decodeDataPacket(t, resp) require.Equal(t, "validate", dataType) var report struct { OK bool `json:"ok"` CommitID string `json:"commit_id"` } require.NoError(t, json.Unmarshal(payload, &report)) require.True(t, report.OK) require.Equal(t, commitIDStr, report.CommitID) } func TestWSHandler_CancelJob_Integration(t *testing.T) { server, tq, _, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() // Add a task to queue task := &queue.Task{ ID: "task-1", JobName: "job-to-cancel", Status: "queued", Priority: 10, CreatedAt: time.Now(), UserID: "user", // Auth disabled so this matches any user CreatedBy: "user", } err := tq.AddTask(task) require.NoError(t, err) ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() // Prepare cancel_job message // Protocol: [opcode:1][api_key_hash:16][job_name_len:1][job_name:var] opcode := byte(api.OpcodeCancelJob) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) jobName := "job-to-cancel" jobNameLen := byte(len(jobName)) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) msg = append(msg, jobNameLen) msg = append(msg, []byte(jobName)...) // Send message err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response assert.Equal(t, byte(api.PacketTypeSuccess), resp[0]) // Verify task cancelled updatedTask, err := tq.GetTask("task-1") require.NoError(t, err) assert.Equal(t, "cancelled", updatedTask.Status) } func TestWSHandler_Prune_Integration(t *testing.T) { server, tq, expManager, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() // Create some experiments _ = expManager.CreateExperiment("commit-1") _ = expManager.CreateExperiment("commit-2") ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() // Prepare prune message // Protocol: [opcode:1][api_key_hash:16][prune_type:1][value:4] opcode := byte(api.OpcodePrune) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) pruneType := byte(0) // Keep N value := uint32(1) // Keep 1 valueBytes := make([]byte, 4) binary.BigEndian.PutUint32(valueBytes, value) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) msg = append(msg, pruneType) msg = append(msg, valueBytes...) // Send message err := ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response assert.Equal(t, byte(api.PacketTypeSuccess), resp[0]) } func TestWSHandler_LogMetric_Integration(t *testing.T) { server, tq, expManager, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() // Create experiment commitIDStr := strings.Repeat("a", 20) err := expManager.CreateExperiment(commitIDStr) require.NoError(t, err) // Write metadata to ensure proper initialization meta := &experiment.Metadata{ CommitID: commitIDStr, JobName: "test-job", } err = expManager.WriteMetadata(meta) require.NoError(t, err) ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() // Prepare log_metric message // Protocol: [opcode:1][api_key_hash:16][commit_id:20][step:4][value:8][name_len:1][name:var] opcode := byte(api.OpcodeLogMetric) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) commitID := []byte(commitIDStr) step := uint32(100) value := 0.95 valueBits := math.Float64bits(value) metricName := "accuracy" nameLen := byte(len(metricName)) stepBytes := make([]byte, 4) binary.BigEndian.PutUint32(stepBytes, step) valueBytes := make([]byte, 8) binary.BigEndian.PutUint64(valueBytes, valueBits) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) msg = append(msg, commitID...) msg = append(msg, stepBytes...) msg = append(msg, valueBytes...) msg = append(msg, nameLen) msg = append(msg, []byte(metricName)...) // Send message err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify success response assert.Equal(t, byte(api.PacketTypeSuccess), resp[0]) } func TestWSHandler_GetExperiment_Integration(t *testing.T) { server, tq, expManager, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() // Create experiment and metadata commitIDStr := strings.Repeat("a", 20) err := expManager.CreateExperiment(commitIDStr) require.NoError(t, err) meta := &experiment.Metadata{ CommitID: commitIDStr, JobName: "test-job", } err = expManager.WriteMetadata(meta) require.NoError(t, err) ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() // Prepare get_experiment message // Protocol: [opcode:1][api_key_hash:16][commit_id:20] opcode := byte(api.OpcodeGetExperiment) apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) commitID := []byte(commitIDStr) var msg []byte msg = append(msg, opcode) msg = append(msg, apiKeyHash...) msg = append(msg, commitID...) // Send message err = ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) // Read response _, resp, err := ws.ReadMessage() require.NoError(t, err) // Verify error response (PacketTypeError) assert.Equal(t, byte(api.PacketTypeError), resp[0]) } func TestWSHandler_DatasetListRegisterInfoSearch_Integration(t *testing.T) { server, tq, _, s, db := setupWSIntegrationServer(t) defer server.Close() defer func() { _ = tq.Close() }() defer s.Close() defer func() { _ = db.Close() }() ws := connectWSIntegration(t, server.URL) defer func() { _ = ws.Close() }() apiKeyHash := make([]byte, 16) copy(apiKeyHash, []byte(strings.Repeat("0", 16))) // 1) List should return empty array { msg := make([]byte, 0, 1+16) msg = append(msg, byte(api.OpcodeDatasetList)) msg = append(msg, apiKeyHash...) err := ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) assert.Equal(t, byte(api.PacketTypeData), resp[0]) } // 2) Register dataset name := "mnist" urlStr := "https://example.com/mnist.tar.gz" { msg := make([]byte, 0, 1+16+1+len(name)+2+len(urlStr)) msg = append(msg, byte(api.OpcodeDatasetRegister)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(len(name))) msg = append(msg, []byte(name)...) urlLen := make([]byte, 2) binary.BigEndian.PutUint16(urlLen, uint16(len(urlStr))) msg = append(msg, urlLen...) msg = append(msg, []byte(urlStr)...) err := ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) assert.Equal(t, byte(api.PacketTypeSuccess), resp[0]) } // 3) Info should return PacketTypeData { msg := make([]byte, 0, 1+16+1+len(name)) msg = append(msg, byte(api.OpcodeDatasetInfo)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(len(name))) msg = append(msg, []byte(name)...) err := ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) assert.Equal(t, byte(api.PacketTypeData), resp[0]) } // 4) Search should return PacketTypeData term := "mn" { msg := make([]byte, 0, 1+16+1+len(term)) msg = append(msg, byte(api.OpcodeDatasetSearch)) msg = append(msg, apiKeyHash...) msg = append(msg, byte(len(term))) msg = append(msg, []byte(term)...) err := ws.WriteMessage(websocket.BinaryMessage, msg) require.NoError(t, err) _, resp, err := ws.ReadMessage() require.NoError(t, err) assert.Equal(t, byte(api.PacketTypeData), resp[0]) } }