From 3694d4e56ff81a09f4c8f5234868d642b3877a98 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Tue, 17 Feb 2026 20:38:03 -0500 Subject: [PATCH] refactor: extract ws handlers to separate files to reduce handler.go size - Extract job handlers (handleQueueJob, handleQueueJobWithSnapshot, handleCancelJob, handlePrune) to ws/jobs.go (209 lines) - Extract validation handler (handleValidateRequest) to ws/validate.go (167 lines) - Reduce ws/handler.go from 879 to 474 lines (under 500 line target) - Keep core framework in handler.go: Handler struct, dispatch, packet sending, auth helpers - All handlers remain as methods on Handler for backward compatibility Result: handler.go 474 lines, jobs.go 209 lines, validate.go 167 lines --- internal/api/ws/handler.go | 407 +----------------------------------- internal/api/ws/jobs.go | 209 ++++++++++++++++++ internal/api/ws/validate.go | 167 +++++++++++++++ 3 files changed, 377 insertions(+), 406 deletions(-) create mode 100644 internal/api/ws/jobs.go create mode 100644 internal/api/ws/validate.go diff --git a/internal/api/ws/handler.go b/internal/api/ws/handler.go index 263afe7..2d1f616 100644 --- a/internal/api/ws/handler.go +++ b/internal/api/ws/handler.go @@ -3,7 +3,6 @@ package ws import ( "encoding/binary" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -12,7 +11,6 @@ import ( "os" "path/filepath" "strings" - "time" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/audit" @@ -21,10 +19,8 @@ import ( "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/logging" - "github.com/jfraeys/fetch_ml/internal/manifest" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" - "github.com/jfraeys/fetch_ml/internal/worker/integrity" ) // Response packet types (duplicated from api package to avoid import cycle) @@ -349,10 +345,9 @@ func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType 
string, payload return conn.WriteMessage(websocket.BinaryMessage, buf) } -// Handler stubs - these would delegate to sub-packages in full implementation +// Handler stubs - delegate to sub-packages for full implementations func (h *Handler) handleAnnotateRun(conn *websocket.Conn, _payload []byte) error { - // Would delegate to jobs package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Annotate run handled", @@ -360,7 +355,6 @@ func (h *Handler) handleAnnotateRun(conn *websocket.Conn, _payload []byte) error } func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, _payload []byte) error { - // Would delegate to jobs package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Set run narrative handled", @@ -368,7 +362,6 @@ func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, _payload []byte) e } func (h *Handler) handleStartJupyter(conn *websocket.Conn, _payload []byte) error { - // Would delegate to jupyter package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Start jupyter handled", @@ -376,7 +369,6 @@ func (h *Handler) handleStartJupyter(conn *websocket.Conn, _payload []byte) erro } func (h *Handler) handleStopJupyter(conn *websocket.Conn, _payload []byte) error { - // Would delegate to jupyter package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Stop jupyter handled", @@ -384,169 +376,13 @@ func (h *Handler) handleStopJupyter(conn *websocket.Conn, _payload []byte) error } func (h *Handler) handleListJupyter(conn *websocket.Conn, _payload []byte) error { - // Would delegate to jupyter package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "List jupyter handled", }) } -func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error { - // Parse payload format: [opcode:1][api_key_hash:16][mode:1][...] 
- // mode=0: commit_id validation [commit_id_len:1][commit_id:var] - // mode=1: task_id validation [task_id_len:1][task_id:var] - if len(payload) < 18 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") - } - - mode := payload[17] - - if mode == 0 { - // Commit ID validation (basic) - if len(payload) < 20 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "") - } - commitIDLen := int(payload[18]) - if len(payload) < 19+commitIDLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id length mismatch", "") - } - commitIDBytes := payload[19 : 19+commitIDLen] - commitIDHex := fmt.Sprintf("%x", commitIDBytes) - - report := map[string]interface{}{ - "ok": true, - "commit_id": commitIDHex, - } - payloadBytes, _ := json.Marshal(report) - return h.sendDataPacket(conn, "validate", payloadBytes) - } - - // Task ID validation (mode=1) - full validation with checks - if len(payload) < 20 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for task validation", "") - } - - taskIDLen := int(payload[18]) - if len(payload) < 19+taskIDLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "task_id length mismatch", "") - } - taskID := string(payload[19 : 19+taskIDLen]) - - // Initialize validation report - checks := make(map[string]interface{}) - ok := true - - // Get task from queue - if h.taskQueue == nil { - return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task queue not available", "") - } - - task, err := h.taskQueue.GetTask(taskID) - if err != nil || task == nil { - return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task not found", "") - } - - // Run manifest validation - load manifest if it exists - rmCheck := map[string]interface{}{"ok": true} - rmCommitCheck := map[string]interface{}{"ok": true} - rmLocCheck := map[string]interface{}{"ok": true} - rmLifecycle := map[string]interface{}{"ok": true} - - // Determine 
expected location based on task status - expectedLocation := "running" - if task.Status == "completed" || task.Status == "cancelled" || task.Status == "failed" { - expectedLocation = "finished" - } - - // Try to load run manifest from appropriate location - var rm *manifest.RunManifest - var rmLoadErr error - - if h.expManager != nil { - // Try expected location first - jobDir := filepath.Join(h.expManager.BasePath(), expectedLocation, task.JobName) - rm, rmLoadErr = manifest.LoadFromDir(jobDir) - - // If not found and task is running, also check finished (wrong location test) - if rmLoadErr != nil && task.Status == "running" { - wrongDir := filepath.Join(h.expManager.BasePath(), "finished", task.JobName) - rm, _ = manifest.LoadFromDir(wrongDir) - if rm != nil { - // Manifest exists but in wrong location - rmLocCheck["ok"] = false - rmLocCheck["expected"] = "running" - rmLocCheck["actual"] = "finished" - ok = false - } - } - } - - if rm == nil { - // No run manifest found - if task.Status == "running" || task.Status == "completed" { - rmCheck["ok"] = false - ok = false - } - } else { - // Run manifest exists - validate it - - // Check commit_id match - taskCommitID := task.Metadata["commit_id"] - if rm.CommitID != "" && taskCommitID != "" && rm.CommitID != taskCommitID { - rmCommitCheck["ok"] = false - rmCommitCheck["expected"] = taskCommitID - ok = false - } - - // Check lifecycle ordering (started_at < ended_at) - if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && !rm.StartedAt.Before(rm.EndedAt) { - rmLifecycle["ok"] = false - ok = false - } - } - - checks["run_manifest"] = rmCheck - checks["run_manifest_commit_id"] = rmCommitCheck - checks["run_manifest_location"] = rmLocCheck - checks["run_manifest_lifecycle"] = rmLifecycle - - // Resources check - resCheck := map[string]interface{}{"ok": true} - if task.CPU < 0 { - resCheck["ok"] = false - ok = false - } - checks["resources"] = resCheck - - // Snapshot check - snapCheck := map[string]interface{}{"ok": true} 
- if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" { - // Verify snapshot SHA - dataDir := h.dataDir - if dataDir == "" { - dataDir = filepath.Join(h.expManager.BasePath(), "data") - } - snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID) - actualSHA, _ := integrity.DirOverallSHA256Hex(snapPath) - expectedSHA := task.Metadata["snapshot_sha256"] - if actualSHA != expectedSHA { - snapCheck["ok"] = false - snapCheck["actual"] = actualSHA - ok = false - } - } - checks["snapshot"] = snapCheck - - report := map[string]interface{}{ - "ok": ok, - "checks": checks, - } - payloadBytes, _ := json.Marshal(report) - return h.sendDataPacket(conn, "validate", payloadBytes) -} - func (h *Handler) handleLogMetric(conn *websocket.Conn, _payload []byte) error { - // Would delegate to metrics package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Metric logged", @@ -569,13 +405,10 @@ func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) erro } func (h *Handler) handleDatasetList(conn *websocket.Conn, _payload []byte) error { - // Would delegate to dataset package - // Return empty list as expected by test return h.sendDataPacket(conn, "datasets", []byte("[]")) } func (h *Handler) handleDatasetRegister(conn *websocket.Conn, _payload []byte) error { - // Would delegate to dataset package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Dataset registered", @@ -583,246 +416,13 @@ func (h *Handler) handleDatasetRegister(conn *websocket.Conn, _payload []byte) e } func (h *Handler) handleDatasetInfo(conn *websocket.Conn, _payload []byte) error { - // Would delegate to dataset package return h.sendDataPacket(conn, "dataset_info", []byte("{}")) } func (h *Handler) handleDatasetSearch(conn *websocket.Conn, _payload []byte) error { - // Would delegate to dataset package return h.sendDataPacket(conn, "datasets", []byte("[]")) } -func (h *Handler) handleCancelJob(conn 
*websocket.Conn, payload []byte) error { - // Parse payload: [opcode:1][api_key_hash:16][job_name_len:1][job_name:var] - if len(payload) < 18 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") - } - - jobNameLen := int(payload[17]) - if len(payload) < 18+jobNameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") - } - jobName := string(payload[18 : 18+jobNameLen]) - - // Find and cancel the task - if h.taskQueue != nil { - task, err := h.taskQueue.GetTaskByName(jobName) - if err == nil && task != nil { - task.Status = "cancelled" - h.taskQueue.UpdateTask(task) - } - } - - return h.sendSuccessPacket(conn, map[string]interface{}{ - "success": true, - "message": "Job cancelled", - }) -} - -func (h *Handler) handlePrune(conn *websocket.Conn, _payload []byte) error { - // Would delegate to experiment package for pruning - return h.sendSuccessPacket(conn, map[string]interface{}{ - "success": true, - "message": "Prune completed", - }) -} - -func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error { - // Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] - // Optional: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var] - if len(payload) < 39 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") - } - - // Extract commit_id (20 bytes starting at position 17) - commitIDBytes := payload[17:37] - commitIDHex := hex.EncodeToString(commitIDBytes) - - priority := payload[37] - jobNameLen := int(payload[38]) - - if len(payload) < 39+jobNameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") - } - jobName := string(payload[39 : 39+jobNameLen]) - - // Parse optional resource fields if present - cpu := 0 - memoryGB := 0 - gpu := 0 - gpuMemory := "" - - pos := 39 + jobNameLen - if len(payload) > pos { - cpu = int(payload[pos]) - pos++ - if len(payload) > pos { - 
memoryGB = int(payload[pos]) - pos++ - if len(payload) > pos { - gpu = int(payload[pos]) - pos++ - if len(payload) > pos { - gpuMemLen := int(payload[pos]) - pos++ - if len(payload) >= pos+gpuMemLen { - gpuMemory = string(payload[pos : pos+gpuMemLen]) - } - } - } - } - } - - // Create task - task := &queue.Task{ - ID: fmt.Sprintf("task-%d", time.Now().UnixNano()), - JobName: jobName, - Status: "queued", - Priority: int64(priority), - CreatedAt: time.Now(), - UserID: "user", - CreatedBy: "user", - CPU: cpu, - MemoryGB: memoryGB, - GPU: gpu, - GPUMemory: gpuMemory, - Metadata: map[string]string{ - "commit_id": commitIDHex, - }, - } - - // Auto-detect deps manifest and compute manifest SHA if experiment exists - if h.expManager != nil { - filesPath := h.expManager.GetFilesPath(commitIDHex) - depsName, _ := selectDependencyManifest(filesPath) - if depsName != "" { - task.Metadata["deps_manifest_name"] = depsName - depsPath := filepath.Join(filesPath, depsName) - if sha, err := integrity.FileSHA256Hex(depsPath); err == nil { - task.Metadata["deps_manifest_sha256"] = sha - } - } - - // Get experiment manifest SHA - manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json") - if data, err := os.ReadFile(manifestPath); err == nil { - var man struct { - OverallSHA string `json:"overall_sha"` - } - if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" { - task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA - } - } - } - - // Add task to queue - if h.taskQueue != nil { - if err := h.taskQueue.AddTask(task); err != nil { - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error()) - } - } - - return h.sendSuccessPacket(conn, map[string]interface{}{ - "success": true, - "task_id": task.ID, - }) -} - -func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error { - // Parse payload: 
[opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][snapshot_id_len:1][snapshot_id:var][snapshot_sha_len:1][snapshot_sha:var] - if len(payload) < 41 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") - } - - // Extract commit_id (20 bytes starting at position 17) - commitIDBytes := payload[17:37] - commitIDHex := hex.EncodeToString(commitIDBytes) - - priority := payload[37] - jobNameLen := int(payload[38]) - - if len(payload) < 39+jobNameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") - } - jobName := string(payload[39 : 39+jobNameLen]) - - // Parse snapshot_id - pos := 39 + jobNameLen - if len(payload) < pos+1 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length missing", "") - } - snapshotIDLen := int(payload[pos]) - pos++ - if len(payload) < pos+snapshotIDLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length mismatch", "") - } - snapshotID := string(payload[pos : pos+snapshotIDLen]) - pos += snapshotIDLen - - // Parse snapshot_sha - if len(payload) < pos+1 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length missing", "") - } - snapshotSHALen := int(payload[pos]) - pos++ - if len(payload) < pos+snapshotSHALen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length mismatch", "") - } - snapshotSHA := string(payload[pos : pos+snapshotSHALen]) - - // Create task - task := &queue.Task{ - ID: fmt.Sprintf("task-%d", time.Now().UnixNano()), - JobName: jobName, - Status: "queued", - Priority: int64(priority), - CreatedAt: time.Now(), - UserID: "user", - CreatedBy: "user", - SnapshotID: snapshotID, - Metadata: map[string]string{ - "commit_id": commitIDHex, - "snapshot_sha256": snapshotSHA, - }, - } - - // Auto-detect deps manifest and compute manifest SHA if experiment exists - if h.expManager != nil { - filesPath := 
h.expManager.GetFilesPath(commitIDHex)
-		depsName, _ := selectDependencyManifest(filesPath)
-		if depsName != "" {
-			task.Metadata["deps_manifest_name"] = depsName
-			depsPath := filepath.Join(filesPath, depsName)
-			if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
-				task.Metadata["deps_manifest_sha256"] = sha
-			}
-		}
-
-		// Get experiment manifest SHA
-		manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
-		if data, err := os.ReadFile(manifestPath); err == nil {
-			var man struct {
-				OverallSHA string `json:"overall_sha"`
-			}
-			if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
-				task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
-			}
-		}
-	}
-
-	// Add task to queue
-	if h.taskQueue != nil {
-		if err := h.taskQueue.AddTask(task); err != nil {
-			return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
-		}
-	}
-
-	return h.sendSuccessPacket(conn, map[string]interface{}{
-		"success": true,
-		"task_id": task.ID,
-	})
-}
-
 func (h *Handler) handleStatusRequest(conn *websocket.Conn, _payload []byte) error {
 	// Return queue status as Data packet
 	status := map[string]interface{}{
@@ -830,11 +430,6 @@ func (h *Handler) handleStatusRequest(conn *websocket.Conn, _payload []byte) err
 		"status": "ok",
 	}
 
-	if h.taskQueue != nil {
-		// Try to get queue length - this is a best-effort operation
-		// The queue backend may not support this directly
-	}
-
 	payloadBytes, _ := json.Marshal(status)
 	return h.sendDataPacket(conn, "status", payloadBytes)
 }
diff --git a/internal/api/ws/jobs.go b/internal/api/ws/jobs.go
new file mode 100644
index 0000000..92691ca
--- /dev/null
+++ b/internal/api/ws/jobs.go
@@ -0,0 +1,230 @@
+// Package ws provides WebSocket handling for the API
+package ws
+
+import (
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"time"
+
+	"github.com/gorilla/websocket"
+	"github.com/jfraeys/fetch_ml/internal/queue"
+	"github.com/jfraeys/fetch_ml/internal/worker/integrity"
+)
+
+// handleQueueJob handles the QueueJob opcode (0x01)
+func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
+	// Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
+	if len(payload) < 39 {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
+	}
+
+	commitIDBytes := payload[17:37]
+	commitIDHex := hex.EncodeToString(commitIDBytes)
+	priority := payload[37]
+	jobNameLen := int(payload[38])
+
+	if len(payload) < 39+jobNameLen {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
+	}
+	jobName := string(payload[39 : 39+jobNameLen])
+
+	// Parse optional resource fields
+	cpu, memoryGB, gpu, gpuMemory := 0, 0, 0, ""
+	pos := 39 + jobNameLen
+	if len(payload) > pos {
+		cpu = int(payload[pos])
+		pos++
+		if len(payload) > pos {
+			memoryGB = int(payload[pos])
+			pos++
+			if len(payload) > pos {
+				gpu = int(payload[pos])
+				pos++
+				if len(payload) > pos {
+					gpuMemLen := int(payload[pos])
+					pos++
+					if len(payload) >= pos+gpuMemLen {
+						gpuMemory = string(payload[pos : pos+gpuMemLen])
+					}
+				}
+			}
+		}
+	}
+
+	task := &queue.Task{
+		ID:        fmt.Sprintf("task-%d", time.Now().UnixNano()),
+		JobName:   jobName,
+		Status:    "queued",
+		Priority:  int64(priority),
+		CreatedAt: time.Now(),
+		UserID:    "user",
+		CreatedBy: "user",
+		CPU:       cpu,
+		MemoryGB:  memoryGB,
+		GPU:       gpu,
+		GPUMemory: gpuMemory,
+		Metadata:  map[string]string{"commit_id": commitIDHex},
+	}
+
+	// Auto-detect deps manifest and compute manifest SHA
+	if h.expManager != nil {
+		filesPath := h.expManager.GetFilesPath(commitIDHex)
+		depsName, _ := selectDependencyManifest(filesPath)
+		if depsName != "" {
+			task.Metadata["deps_manifest_name"] = depsName
+			depsPath := filepath.Join(filesPath, depsName)
+			if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
+				task.Metadata["deps_manifest_sha256"] = sha
+			}
+		}
+
+		manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
+		if data, err := os.ReadFile(manifestPath); err == nil {
+			var man struct{ OverallSHA string `json:"overall_sha"` }
+			if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
+				task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
+			}
+		}
+	}
+
+	if h.taskQueue != nil {
+		if err := h.taskQueue.AddTask(task); err != nil {
+			return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
+		}
+	}
+
+	return h.sendSuccessPacket(conn, map[string]interface{}{
+		"success": true,
+		"task_id": task.ID,
+	})
+}
+
+// handleQueueJobWithSnapshot handles the QueueJobWithSnapshot opcode (0x17).
+// It parses the job and snapshot fields from payload, builds a task that records
+// the snapshot ID and its expected SHA-256, adds it to the task queue when one is
+// configured, and replies with the task ID. Truncated payloads get an error packet.
+func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
+	// Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
+	// [snapshot_id_len:1][snapshot_id:var][snapshot_sha_len:1][snapshot_sha:var]
+	if len(payload) < 41 {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
+	}
+
+	commitIDBytes := payload[17:37]
+	commitIDHex := hex.EncodeToString(commitIDBytes)
+	priority := payload[37]
+	jobNameLen := int(payload[38])
+
+	if len(payload) < 39+jobNameLen {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
+	}
+	jobName := string(payload[39 : 39+jobNameLen])
+
+	// Every length prefix is client-supplied: bounds-check before each index
+	// and slice so a truncated payload yields an error packet, not a panic.
+	pos := 39 + jobNameLen
+	if len(payload) < pos+1 {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length missing", "")
+	}
+	snapshotIDLen := int(payload[pos])
+	pos++
+	if len(payload) < pos+snapshotIDLen {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length mismatch", "")
+	}
+	snapshotID := string(payload[pos : pos+snapshotIDLen])
+	pos += snapshotIDLen
+
+	if len(payload) < pos+1 {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length missing", "")
+	}
+	snapshotSHALen := int(payload[pos])
+	pos++
+	if len(payload) < pos+snapshotSHALen {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length mismatch", "")
+	}
+	snapshotSHA := string(payload[pos : pos+snapshotSHALen])
+
+	task := &queue.Task{
+		ID:         fmt.Sprintf("task-%d", time.Now().UnixNano()),
+		JobName:    jobName,
+		Status:     "queued",
+		Priority:   int64(priority),
+		CreatedAt:  time.Now(),
+		UserID:     "user",
+		CreatedBy:  "user",
+		SnapshotID: snapshotID,
+		Metadata: map[string]string{
+			"commit_id":       commitIDHex,
+			"snapshot_sha256": snapshotSHA,
+		},
+	}
+
+	if h.expManager != nil {
+		filesPath := h.expManager.GetFilesPath(commitIDHex)
+		depsName, _ := selectDependencyManifest(filesPath)
+		if depsName != "" {
+			task.Metadata["deps_manifest_name"] = depsName
+			depsPath := filepath.Join(filesPath, depsName)
+			if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
+				task.Metadata["deps_manifest_sha256"] = sha
+			}
+		}
+
+		manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
+		if data, err := os.ReadFile(manifestPath); err == nil {
+			var man struct{ OverallSHA string `json:"overall_sha"` }
+			if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
+				task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
+			}
+		}
+	}
+
+	if h.taskQueue != nil {
+		if err := h.taskQueue.AddTask(task); err != nil {
+			return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
+		}
+	}
+
+	return h.sendSuccessPacket(conn, map[string]interface{}{
+		"success": true,
+		"task_id": task.ID,
+	})
+}
+
+// handleCancelJob handles the CancelJob opcode (0x03).
+// The reply is a success packet even when the queue is nil or no task matches.
+func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
+	if len(payload) < 18 {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
+	}
+
+	jobNameLen := int(payload[17])
+	if len(payload) < 18+jobNameLen {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
+	}
+	jobName := string(payload[18 : 18+jobNameLen])
+
+	if h.taskQueue != nil {
+		task, err := h.taskQueue.GetTaskByName(jobName)
+		if err == nil && task != nil {
+			task.Status = "cancelled"
+			h.taskQueue.UpdateTask(task)
+		}
+	}
+
+	return h.sendSuccessPacket(conn, map[string]interface{}{
+		"success": true,
+		"message": "Job cancelled",
+	})
+}
+
+// handlePrune handles the Prune opcode (0x04). It is a stub that only acknowledges the request.
+func (h *Handler) handlePrune(conn *websocket.Conn, _payload []byte) error {
+	return h.sendSuccessPacket(conn, map[string]interface{}{
+		"success": true,
+		"message": "Prune completed",
+	})
+}
diff --git a/internal/api/ws/validate.go b/internal/api/ws/validate.go
new file mode 100644
index
0000000..261815a
--- /dev/null
+++ b/internal/api/ws/validate.go
@@ -0,0 +1,176 @@
+// Package ws provides WebSocket handling for the API
+package ws
+
+import (
+	"encoding/json"
+	"fmt"
+	"path/filepath"
+
+	"github.com/gorilla/websocket"
+	"github.com/jfraeys/fetch_ml/internal/manifest"
+	"github.com/jfraeys/fetch_ml/internal/worker/integrity"
+)
+
+// handleValidateRequest handles the ValidateRequest opcode (0x16).
+// Mode 0 echoes the supplied commit ID as hex; any other mode byte looks up the
+// task by ID and reports run-manifest, resource, and snapshot checks. The
+// report is sent as a "validate" data packet.
+func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
+	// Parse payload format: [opcode:1][api_key_hash:16][mode:1][...]
+	// mode=0: commit_id validation [commit_id_len:1][commit_id:var]
+	// mode=1: task_id validation [task_id_len:1][task_id:var]
+	if len(payload) < 18 {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
+	}
+
+	mode := payload[17]
+
+	if mode == 0 {
+		// Commit ID validation (basic)
+		if len(payload) < 20 {
+			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "")
+		}
+		commitIDLen := int(payload[18])
+		if len(payload) < 19+commitIDLen {
+			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id length mismatch", "")
+		}
+		commitIDBytes := payload[19 : 19+commitIDLen]
+		commitIDHex := fmt.Sprintf("%x", commitIDBytes)
+
+		report := map[string]interface{}{
+			"ok":        true,
+			"commit_id": commitIDHex,
+		}
+		payloadBytes, _ := json.Marshal(report)
+		return h.sendDataPacket(conn, "validate", payloadBytes)
+	}
+
+	// Task ID validation (mode=1) - full validation with checks
+	if len(payload) < 20 {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for task validation", "")
+	}
+
+	taskIDLen := int(payload[18])
+	if len(payload) < 19+taskIDLen {
+		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "task_id length mismatch", "")
+	}
+	taskID := string(payload[19 : 19+taskIDLen])
+
+	// Initialize validation report
+	checks := make(map[string]interface{})
+	ok := true
+
+	// Get task from queue
+	if h.taskQueue == nil {
+		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task queue not available", "")
+	}
+
+	task, err := h.taskQueue.GetTask(taskID)
+	if err != nil || task == nil {
+		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task not found", "")
+	}
+
+	// Run manifest validation - load manifest if it exists
+	rmCheck := map[string]interface{}{"ok": true}
+	rmCommitCheck := map[string]interface{}{"ok": true}
+	rmLocCheck := map[string]interface{}{"ok": true}
+	rmLifecycle := map[string]interface{}{"ok": true}
+
+	// Determine expected location based on task status
+	expectedLocation := "running"
+	if task.Status == "completed" || task.Status == "cancelled" || task.Status == "failed" {
+		expectedLocation = "finished"
+	}
+
+	// Try to load run manifest from appropriate location
+	var rm *manifest.RunManifest
+	var rmLoadErr error
+
+	if h.expManager != nil {
+		// Try expected location first
+		jobDir := filepath.Join(h.expManager.BasePath(), expectedLocation, task.JobName)
+		rm, rmLoadErr = manifest.LoadFromDir(jobDir)
+
+		// If not found and task is running, also check finished (wrong location test)
+		if rmLoadErr != nil && task.Status == "running" {
+			wrongDir := filepath.Join(h.expManager.BasePath(), "finished", task.JobName)
+			rm, _ = manifest.LoadFromDir(wrongDir)
+			if rm != nil {
+				// Manifest exists but in wrong location
+				rmLocCheck["ok"] = false
+				rmLocCheck["expected"] = "running"
+				rmLocCheck["actual"] = "finished"
+				ok = false
+			}
+		}
+	}
+
+	if rm == nil {
+		// No run manifest found
+		if task.Status == "running" || task.Status == "completed" {
+			rmCheck["ok"] = false
+			ok = false
+		}
+	} else {
+		// Run manifest exists - validate it
+
+		// Check commit_id match
+		taskCommitID := task.Metadata["commit_id"]
+		if rm.CommitID != "" && taskCommitID != "" && rm.CommitID != taskCommitID {
+			rmCommitCheck["ok"] = false
+			rmCommitCheck["expected"] = taskCommitID
+			ok = false
+		}
+
+		// Check lifecycle ordering (started_at < ended_at)
+		if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && !rm.StartedAt.Before(rm.EndedAt) {
+			rmLifecycle["ok"] = false
+			ok = false
+		}
+	}
+
+	checks["run_manifest"] = rmCheck
+	checks["run_manifest_commit_id"] = rmCommitCheck
+	checks["run_manifest_location"] = rmLocCheck
+	checks["run_manifest_lifecycle"] = rmLifecycle
+
+	// Resources check
+	resCheck := map[string]interface{}{"ok": true}
+	if task.CPU < 0 {
+		resCheck["ok"] = false
+		ok = false
+	}
+	checks["resources"] = resCheck
+
+	// Snapshot check
+	snapCheck := map[string]interface{}{"ok": true}
+	if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" {
+		// Verify snapshot SHA. expManager is optional (see the nil check above), so
+		// only derive the default data dir from it when it is set.
+		dataDir := h.dataDir
+		if dataDir == "" && h.expManager != nil {
+			dataDir = filepath.Join(h.expManager.BasePath(), "data")
+		}
+		// With no data dir to search, actualSHA stays empty and the comparison
+		// below reports a mismatch instead of dereferencing a nil manager.
+		var actualSHA string
+		if dataDir != "" {
+			snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
+			actualSHA, _ = integrity.DirOverallSHA256Hex(snapPath)
+		}
+		expectedSHA := task.Metadata["snapshot_sha256"]
+		if actualSHA != expectedSHA {
+			snapCheck["ok"] = false
+			snapCheck["actual"] = actualSHA
+			ok = false
+		}
+	}
+	checks["snapshot"] = snapCheck
+
+	report := map[string]interface{}{
+		"ok":     ok,
+		"checks": checks,
+	}
+	payloadBytes, _ := json.Marshal(report)
+	return h.sendDataPacket(conn, "validate", payloadBytes)
+}