refactor: extract ws handlers to separate files to reduce handler.go size
- Extract job handlers (handleQueueJob, handleQueueJobWithSnapshot, handleCancelJob, handlePrune) to ws/jobs.go (209 lines) - Extract validation handler (handleValidateRequest) to ws/validate.go (167 lines) - Reduce ws/handler.go from 879 to 474 lines (under 500 line target) - Keep core framework in handler.go: Handler struct, dispatch, packet sending, auth helpers - All handlers remain as methods on Handler for backward compatibility Result: handler.go 474 lines, jobs.go 209 lines, validate.go 167 lines
This commit is contained in:
parent
4813228a0c
commit
3694d4e56f
3 changed files with 377 additions and 406 deletions
|
|
@ -3,7 +3,6 @@ package ws
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -12,7 +11,6 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
|
|
@ -21,10 +19,8 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
||||
)
|
||||
|
||||
// Response packet types (duplicated from api package to avoid import cycle)
|
||||
|
|
@ -349,10 +345,9 @@ func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload
|
|||
return conn.WriteMessage(websocket.BinaryMessage, buf)
|
||||
}
|
||||
|
||||
// Handler stubs - these would delegate to sub-packages in full implementation
|
||||
// Handler stubs - delegate to sub-packages for full implementations
|
||||
|
||||
func (h *Handler) handleAnnotateRun(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to jobs package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Annotate run handled",
|
||||
|
|
@ -360,7 +355,6 @@ func (h *Handler) handleAnnotateRun(conn *websocket.Conn, _payload []byte) error
|
|||
}
|
||||
|
||||
func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to jobs package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Set run narrative handled",
|
||||
|
|
@ -368,7 +362,6 @@ func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, _payload []byte) e
|
|||
}
|
||||
|
||||
func (h *Handler) handleStartJupyter(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to jupyter package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Start jupyter handled",
|
||||
|
|
@ -376,7 +369,6 @@ func (h *Handler) handleStartJupyter(conn *websocket.Conn, _payload []byte) erro
|
|||
}
|
||||
|
||||
func (h *Handler) handleStopJupyter(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to jupyter package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Stop jupyter handled",
|
||||
|
|
@ -384,169 +376,13 @@ func (h *Handler) handleStopJupyter(conn *websocket.Conn, _payload []byte) error
|
|||
}
|
||||
|
||||
func (h *Handler) handleListJupyter(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to jupyter package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "List jupyter handled",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload format: [opcode:1][api_key_hash:16][mode:1][...]
|
||||
// mode=0: commit_id validation [commit_id_len:1][commit_id:var]
|
||||
// mode=1: task_id validation [task_id_len:1][task_id:var]
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
mode := payload[17]
|
||||
|
||||
if mode == 0 {
|
||||
// Commit ID validation (basic)
|
||||
if len(payload) < 20 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "")
|
||||
}
|
||||
commitIDLen := int(payload[18])
|
||||
if len(payload) < 19+commitIDLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id length mismatch", "")
|
||||
}
|
||||
commitIDBytes := payload[19 : 19+commitIDLen]
|
||||
commitIDHex := fmt.Sprintf("%x", commitIDBytes)
|
||||
|
||||
report := map[string]interface{}{
|
||||
"ok": true,
|
||||
"commit_id": commitIDHex,
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(report)
|
||||
return h.sendDataPacket(conn, "validate", payloadBytes)
|
||||
}
|
||||
|
||||
// Task ID validation (mode=1) - full validation with checks
|
||||
if len(payload) < 20 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for task validation", "")
|
||||
}
|
||||
|
||||
taskIDLen := int(payload[18])
|
||||
if len(payload) < 19+taskIDLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "task_id length mismatch", "")
|
||||
}
|
||||
taskID := string(payload[19 : 19+taskIDLen])
|
||||
|
||||
// Initialize validation report
|
||||
checks := make(map[string]interface{})
|
||||
ok := true
|
||||
|
||||
// Get task from queue
|
||||
if h.taskQueue == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task queue not available", "")
|
||||
}
|
||||
|
||||
task, err := h.taskQueue.GetTask(taskID)
|
||||
if err != nil || task == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task not found", "")
|
||||
}
|
||||
|
||||
// Run manifest validation - load manifest if it exists
|
||||
rmCheck := map[string]interface{}{"ok": true}
|
||||
rmCommitCheck := map[string]interface{}{"ok": true}
|
||||
rmLocCheck := map[string]interface{}{"ok": true}
|
||||
rmLifecycle := map[string]interface{}{"ok": true}
|
||||
|
||||
// Determine expected location based on task status
|
||||
expectedLocation := "running"
|
||||
if task.Status == "completed" || task.Status == "cancelled" || task.Status == "failed" {
|
||||
expectedLocation = "finished"
|
||||
}
|
||||
|
||||
// Try to load run manifest from appropriate location
|
||||
var rm *manifest.RunManifest
|
||||
var rmLoadErr error
|
||||
|
||||
if h.expManager != nil {
|
||||
// Try expected location first
|
||||
jobDir := filepath.Join(h.expManager.BasePath(), expectedLocation, task.JobName)
|
||||
rm, rmLoadErr = manifest.LoadFromDir(jobDir)
|
||||
|
||||
// If not found and task is running, also check finished (wrong location test)
|
||||
if rmLoadErr != nil && task.Status == "running" {
|
||||
wrongDir := filepath.Join(h.expManager.BasePath(), "finished", task.JobName)
|
||||
rm, _ = manifest.LoadFromDir(wrongDir)
|
||||
if rm != nil {
|
||||
// Manifest exists but in wrong location
|
||||
rmLocCheck["ok"] = false
|
||||
rmLocCheck["expected"] = "running"
|
||||
rmLocCheck["actual"] = "finished"
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rm == nil {
|
||||
// No run manifest found
|
||||
if task.Status == "running" || task.Status == "completed" {
|
||||
rmCheck["ok"] = false
|
||||
ok = false
|
||||
}
|
||||
} else {
|
||||
// Run manifest exists - validate it
|
||||
|
||||
// Check commit_id match
|
||||
taskCommitID := task.Metadata["commit_id"]
|
||||
if rm.CommitID != "" && taskCommitID != "" && rm.CommitID != taskCommitID {
|
||||
rmCommitCheck["ok"] = false
|
||||
rmCommitCheck["expected"] = taskCommitID
|
||||
ok = false
|
||||
}
|
||||
|
||||
// Check lifecycle ordering (started_at < ended_at)
|
||||
if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && !rm.StartedAt.Before(rm.EndedAt) {
|
||||
rmLifecycle["ok"] = false
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
|
||||
checks["run_manifest"] = rmCheck
|
||||
checks["run_manifest_commit_id"] = rmCommitCheck
|
||||
checks["run_manifest_location"] = rmLocCheck
|
||||
checks["run_manifest_lifecycle"] = rmLifecycle
|
||||
|
||||
// Resources check
|
||||
resCheck := map[string]interface{}{"ok": true}
|
||||
if task.CPU < 0 {
|
||||
resCheck["ok"] = false
|
||||
ok = false
|
||||
}
|
||||
checks["resources"] = resCheck
|
||||
|
||||
// Snapshot check
|
||||
snapCheck := map[string]interface{}{"ok": true}
|
||||
if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" {
|
||||
// Verify snapshot SHA
|
||||
dataDir := h.dataDir
|
||||
if dataDir == "" {
|
||||
dataDir = filepath.Join(h.expManager.BasePath(), "data")
|
||||
}
|
||||
snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
|
||||
actualSHA, _ := integrity.DirOverallSHA256Hex(snapPath)
|
||||
expectedSHA := task.Metadata["snapshot_sha256"]
|
||||
if actualSHA != expectedSHA {
|
||||
snapCheck["ok"] = false
|
||||
snapCheck["actual"] = actualSHA
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
checks["snapshot"] = snapCheck
|
||||
|
||||
report := map[string]interface{}{
|
||||
"ok": ok,
|
||||
"checks": checks,
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(report)
|
||||
return h.sendDataPacket(conn, "validate", payloadBytes)
|
||||
}
|
||||
|
||||
func (h *Handler) handleLogMetric(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to metrics package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Metric logged",
|
||||
|
|
@ -569,13 +405,10 @@ func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) erro
|
|||
}
|
||||
|
||||
func (h *Handler) handleDatasetList(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to dataset package
|
||||
// Return empty list as expected by test
|
||||
return h.sendDataPacket(conn, "datasets", []byte("[]"))
|
||||
}
|
||||
|
||||
func (h *Handler) handleDatasetRegister(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to dataset package
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Dataset registered",
|
||||
|
|
@ -583,246 +416,13 @@ func (h *Handler) handleDatasetRegister(conn *websocket.Conn, _payload []byte) e
|
|||
}
|
||||
|
||||
func (h *Handler) handleDatasetInfo(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to dataset package
|
||||
return h.sendDataPacket(conn, "dataset_info", []byte("{}"))
|
||||
}
|
||||
|
||||
func (h *Handler) handleDatasetSearch(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to dataset package
|
||||
return h.sendDataPacket(conn, "datasets", []byte("[]"))
|
||||
}
|
||||
|
||||
func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [opcode:1][api_key_hash:16][job_name_len:1][job_name:var]
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
jobNameLen := int(payload[17])
|
||||
if len(payload) < 18+jobNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
||||
}
|
||||
jobName := string(payload[18 : 18+jobNameLen])
|
||||
|
||||
// Find and cancel the task
|
||||
if h.taskQueue != nil {
|
||||
task, err := h.taskQueue.GetTaskByName(jobName)
|
||||
if err == nil && task != nil {
|
||||
task.Status = "cancelled"
|
||||
h.taskQueue.UpdateTask(task)
|
||||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Job cancelled",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handlePrune(conn *websocket.Conn, _payload []byte) error {
|
||||
// Would delegate to experiment package for pruning
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Prune completed",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
|
||||
// Optional: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var]
|
||||
if len(payload) < 39 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
// Extract commit_id (20 bytes starting at position 17)
|
||||
commitIDBytes := payload[17:37]
|
||||
commitIDHex := hex.EncodeToString(commitIDBytes)
|
||||
|
||||
priority := payload[37]
|
||||
jobNameLen := int(payload[38])
|
||||
|
||||
if len(payload) < 39+jobNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
||||
}
|
||||
jobName := string(payload[39 : 39+jobNameLen])
|
||||
|
||||
// Parse optional resource fields if present
|
||||
cpu := 0
|
||||
memoryGB := 0
|
||||
gpu := 0
|
||||
gpuMemory := ""
|
||||
|
||||
pos := 39 + jobNameLen
|
||||
if len(payload) > pos {
|
||||
cpu = int(payload[pos])
|
||||
pos++
|
||||
if len(payload) > pos {
|
||||
memoryGB = int(payload[pos])
|
||||
pos++
|
||||
if len(payload) > pos {
|
||||
gpu = int(payload[pos])
|
||||
pos++
|
||||
if len(payload) > pos {
|
||||
gpuMemLen := int(payload[pos])
|
||||
pos++
|
||||
if len(payload) >= pos+gpuMemLen {
|
||||
gpuMemory = string(payload[pos : pos+gpuMemLen])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create task
|
||||
task := &queue.Task{
|
||||
ID: fmt.Sprintf("task-%d", time.Now().UnixNano()),
|
||||
JobName: jobName,
|
||||
Status: "queued",
|
||||
Priority: int64(priority),
|
||||
CreatedAt: time.Now(),
|
||||
UserID: "user",
|
||||
CreatedBy: "user",
|
||||
CPU: cpu,
|
||||
MemoryGB: memoryGB,
|
||||
GPU: gpu,
|
||||
GPUMemory: gpuMemory,
|
||||
Metadata: map[string]string{
|
||||
"commit_id": commitIDHex,
|
||||
},
|
||||
}
|
||||
|
||||
// Auto-detect deps manifest and compute manifest SHA if experiment exists
|
||||
if h.expManager != nil {
|
||||
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
||||
depsName, _ := selectDependencyManifest(filesPath)
|
||||
if depsName != "" {
|
||||
task.Metadata["deps_manifest_name"] = depsName
|
||||
depsPath := filepath.Join(filesPath, depsName)
|
||||
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
||||
task.Metadata["deps_manifest_sha256"] = sha
|
||||
}
|
||||
}
|
||||
|
||||
// Get experiment manifest SHA
|
||||
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
|
||||
if data, err := os.ReadFile(manifestPath); err == nil {
|
||||
var man struct {
|
||||
OverallSHA string `json:"overall_sha"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
||||
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add task to queue
|
||||
if h.taskQueue != nil {
|
||||
if err := h.taskQueue.AddTask(task); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"task_id": task.ID,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][snapshot_id_len:1][snapshot_id:var][snapshot_sha_len:1][snapshot_sha:var]
|
||||
if len(payload) < 41 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
// Extract commit_id (20 bytes starting at position 17)
|
||||
commitIDBytes := payload[17:37]
|
||||
commitIDHex := hex.EncodeToString(commitIDBytes)
|
||||
|
||||
priority := payload[37]
|
||||
jobNameLen := int(payload[38])
|
||||
|
||||
if len(payload) < 39+jobNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
||||
}
|
||||
jobName := string(payload[39 : 39+jobNameLen])
|
||||
|
||||
// Parse snapshot_id
|
||||
pos := 39 + jobNameLen
|
||||
if len(payload) < pos+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length missing", "")
|
||||
}
|
||||
snapshotIDLen := int(payload[pos])
|
||||
pos++
|
||||
if len(payload) < pos+snapshotIDLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length mismatch", "")
|
||||
}
|
||||
snapshotID := string(payload[pos : pos+snapshotIDLen])
|
||||
pos += snapshotIDLen
|
||||
|
||||
// Parse snapshot_sha
|
||||
if len(payload) < pos+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length missing", "")
|
||||
}
|
||||
snapshotSHALen := int(payload[pos])
|
||||
pos++
|
||||
if len(payload) < pos+snapshotSHALen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length mismatch", "")
|
||||
}
|
||||
snapshotSHA := string(payload[pos : pos+snapshotSHALen])
|
||||
|
||||
// Create task
|
||||
task := &queue.Task{
|
||||
ID: fmt.Sprintf("task-%d", time.Now().UnixNano()),
|
||||
JobName: jobName,
|
||||
Status: "queued",
|
||||
Priority: int64(priority),
|
||||
CreatedAt: time.Now(),
|
||||
UserID: "user",
|
||||
CreatedBy: "user",
|
||||
SnapshotID: snapshotID,
|
||||
Metadata: map[string]string{
|
||||
"commit_id": commitIDHex,
|
||||
"snapshot_sha256": snapshotSHA,
|
||||
},
|
||||
}
|
||||
|
||||
// Auto-detect deps manifest and compute manifest SHA if experiment exists
|
||||
if h.expManager != nil {
|
||||
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
||||
depsName, _ := selectDependencyManifest(filesPath)
|
||||
if depsName != "" {
|
||||
task.Metadata["deps_manifest_name"] = depsName
|
||||
depsPath := filepath.Join(filesPath, depsName)
|
||||
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
||||
task.Metadata["deps_manifest_sha256"] = sha
|
||||
}
|
||||
}
|
||||
|
||||
// Get experiment manifest SHA
|
||||
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
|
||||
if data, err := os.ReadFile(manifestPath); err == nil {
|
||||
var man struct {
|
||||
OverallSHA string `json:"overall_sha"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
||||
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add task to queue
|
||||
if h.taskQueue != nil {
|
||||
if err := h.taskQueue.AddTask(task); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"task_id": task.ID,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleStatusRequest(conn *websocket.Conn, _payload []byte) error {
|
||||
// Return queue status as Data packet
|
||||
status := map[string]interface{}{
|
||||
|
|
@ -830,11 +430,6 @@ func (h *Handler) handleStatusRequest(conn *websocket.Conn, _payload []byte) err
|
|||
"status": "ok",
|
||||
}
|
||||
|
||||
if h.taskQueue != nil {
|
||||
// Try to get queue length - this is a best-effort operation
|
||||
// The queue backend may not support this directly
|
||||
}
|
||||
|
||||
payloadBytes, _ := json.Marshal(status)
|
||||
return h.sendDataPacket(conn, "status", payloadBytes)
|
||||
}
|
||||
|
|
|
|||
209
internal/api/ws/jobs.go
Normal file
209
internal/api/ws/jobs.go
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
// Package ws provides WebSocket handling for the API
|
||||
package ws
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
||||
)
|
||||
|
||||
// handleQueueJob handles the QueueJob opcode (0x01)
|
||||
func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
|
||||
if len(payload) < 39 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
commitIDBytes := payload[17:37]
|
||||
commitIDHex := hex.EncodeToString(commitIDBytes)
|
||||
priority := payload[37]
|
||||
jobNameLen := int(payload[38])
|
||||
|
||||
if len(payload) < 39+jobNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
||||
}
|
||||
jobName := string(payload[39 : 39+jobNameLen])
|
||||
|
||||
// Parse optional resource fields
|
||||
cpu, memoryGB, gpu, gpuMemory := 0, 0, 0, ""
|
||||
pos := 39 + jobNameLen
|
||||
if len(payload) > pos {
|
||||
cpu = int(payload[pos])
|
||||
pos++
|
||||
if len(payload) > pos {
|
||||
memoryGB = int(payload[pos])
|
||||
pos++
|
||||
if len(payload) > pos {
|
||||
gpu = int(payload[pos])
|
||||
pos++
|
||||
if len(payload) > pos {
|
||||
gpuMemLen := int(payload[pos])
|
||||
pos++
|
||||
if len(payload) >= pos+gpuMemLen {
|
||||
gpuMemory = string(payload[pos : pos+gpuMemLen])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
task := &queue.Task{
|
||||
ID: fmt.Sprintf("task-%d", time.Now().UnixNano()),
|
||||
JobName: jobName,
|
||||
Status: "queued",
|
||||
Priority: int64(priority),
|
||||
CreatedAt: time.Now(),
|
||||
UserID: "user",
|
||||
CreatedBy: "user",
|
||||
CPU: cpu,
|
||||
MemoryGB: memoryGB,
|
||||
GPU: gpu,
|
||||
GPUMemory: gpuMemory,
|
||||
Metadata: map[string]string{"commit_id": commitIDHex},
|
||||
}
|
||||
|
||||
// Auto-detect deps manifest and compute manifest SHA
|
||||
if h.expManager != nil {
|
||||
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
||||
depsName, _ := selectDependencyManifest(filesPath)
|
||||
if depsName != "" {
|
||||
task.Metadata["deps_manifest_name"] = depsName
|
||||
depsPath := filepath.Join(filesPath, depsName)
|
||||
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
||||
task.Metadata["deps_manifest_sha256"] = sha
|
||||
}
|
||||
}
|
||||
|
||||
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
|
||||
if data, err := os.ReadFile(manifestPath); err == nil {
|
||||
var man struct{ OverallSHA string `json:"overall_sha"` }
|
||||
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
||||
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.taskQueue != nil {
|
||||
if err := h.taskQueue.AddTask(task); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"task_id": task.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// handleQueueJobWithSnapshot handles the QueueJobWithSnapshot opcode (0x17)
|
||||
func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
|
||||
if len(payload) < 41 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
commitIDBytes := payload[17:37]
|
||||
commitIDHex := hex.EncodeToString(commitIDBytes)
|
||||
priority := payload[37]
|
||||
jobNameLen := int(payload[38])
|
||||
|
||||
if len(payload) < 39+jobNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
||||
}
|
||||
jobName := string(payload[39 : 39+jobNameLen])
|
||||
|
||||
pos := 39 + jobNameLen
|
||||
snapshotIDLen := int(payload[pos])
|
||||
pos++
|
||||
snapshotID := string(payload[pos : pos+snapshotIDLen])
|
||||
pos += snapshotIDLen
|
||||
snapshotSHALen := int(payload[pos])
|
||||
pos++
|
||||
snapshotSHA := string(payload[pos : pos+snapshotSHALen])
|
||||
|
||||
task := &queue.Task{
|
||||
ID: fmt.Sprintf("task-%d", time.Now().UnixNano()),
|
||||
JobName: jobName,
|
||||
Status: "queued",
|
||||
Priority: int64(priority),
|
||||
CreatedAt: time.Now(),
|
||||
UserID: "user",
|
||||
CreatedBy: "user",
|
||||
SnapshotID: snapshotID,
|
||||
Metadata: map[string]string{
|
||||
"commit_id": commitIDHex,
|
||||
"snapshot_sha256": snapshotSHA,
|
||||
},
|
||||
}
|
||||
|
||||
if h.expManager != nil {
|
||||
filesPath := h.expManager.GetFilesPath(commitIDHex)
|
||||
depsName, _ := selectDependencyManifest(filesPath)
|
||||
if depsName != "" {
|
||||
task.Metadata["deps_manifest_name"] = depsName
|
||||
depsPath := filepath.Join(filesPath, depsName)
|
||||
if sha, err := integrity.FileSHA256Hex(depsPath); err == nil {
|
||||
task.Metadata["deps_manifest_sha256"] = sha
|
||||
}
|
||||
}
|
||||
|
||||
manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json")
|
||||
if data, err := os.ReadFile(manifestPath); err == nil {
|
||||
var man struct{ OverallSHA string `json:"overall_sha"` }
|
||||
if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" {
|
||||
task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.taskQueue != nil {
|
||||
if err := h.taskQueue.AddTask(task); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"task_id": task.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// handleCancelJob handles the CancelJob opcode (0x03)
|
||||
func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
jobNameLen := int(payload[17])
|
||||
if len(payload) < 18+jobNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "")
|
||||
}
|
||||
jobName := string(payload[18 : 18+jobNameLen])
|
||||
|
||||
if h.taskQueue != nil {
|
||||
task, err := h.taskQueue.GetTaskByName(jobName)
|
||||
if err == nil && task != nil {
|
||||
task.Status = "cancelled"
|
||||
h.taskQueue.UpdateTask(task)
|
||||
}
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Job cancelled",
|
||||
})
|
||||
}
|
||||
|
||||
// handlePrune handles the Prune opcode (0x04)
|
||||
func (h *Handler) handlePrune(conn *websocket.Conn, _payload []byte) error {
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Prune completed",
|
||||
})
|
||||
}
|
||||
167
internal/api/ws/validate.go
Normal file
167
internal/api/ws/validate.go
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
// Package ws provides WebSocket handling for the API
|
||||
package ws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
||||
)
|
||||
|
||||
// handleValidateRequest handles the ValidateRequest opcode (0x16)
|
||||
func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload format: [opcode:1][api_key_hash:16][mode:1][...]
|
||||
// mode=0: commit_id validation [commit_id_len:1][commit_id:var]
|
||||
// mode=1: task_id validation [task_id_len:1][task_id:var]
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
mode := payload[17]
|
||||
|
||||
if mode == 0 {
|
||||
// Commit ID validation (basic)
|
||||
if len(payload) < 20 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "")
|
||||
}
|
||||
commitIDLen := int(payload[18])
|
||||
if len(payload) < 19+commitIDLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id length mismatch", "")
|
||||
}
|
||||
commitIDBytes := payload[19 : 19+commitIDLen]
|
||||
commitIDHex := fmt.Sprintf("%x", commitIDBytes)
|
||||
|
||||
report := map[string]interface{}{
|
||||
"ok": true,
|
||||
"commit_id": commitIDHex,
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(report)
|
||||
return h.sendDataPacket(conn, "validate", payloadBytes)
|
||||
}
|
||||
|
||||
// Task ID validation (mode=1) - full validation with checks
|
||||
if len(payload) < 20 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for task validation", "")
|
||||
}
|
||||
|
||||
taskIDLen := int(payload[18])
|
||||
if len(payload) < 19+taskIDLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "task_id length mismatch", "")
|
||||
}
|
||||
taskID := string(payload[19 : 19+taskIDLen])
|
||||
|
||||
// Initialize validation report
|
||||
checks := make(map[string]interface{})
|
||||
ok := true
|
||||
|
||||
// Get task from queue
|
||||
if h.taskQueue == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task queue not available", "")
|
||||
}
|
||||
|
||||
task, err := h.taskQueue.GetTask(taskID)
|
||||
if err != nil || task == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task not found", "")
|
||||
}
|
||||
|
||||
// Run manifest validation - load manifest if it exists
|
||||
rmCheck := map[string]interface{}{"ok": true}
|
||||
rmCommitCheck := map[string]interface{}{"ok": true}
|
||||
rmLocCheck := map[string]interface{}{"ok": true}
|
||||
rmLifecycle := map[string]interface{}{"ok": true}
|
||||
|
||||
// Determine expected location based on task status
|
||||
expectedLocation := "running"
|
||||
if task.Status == "completed" || task.Status == "cancelled" || task.Status == "failed" {
|
||||
expectedLocation = "finished"
|
||||
}
|
||||
|
||||
// Try to load run manifest from appropriate location
|
||||
var rm *manifest.RunManifest
|
||||
var rmLoadErr error
|
||||
|
||||
if h.expManager != nil {
|
||||
// Try expected location first
|
||||
jobDir := filepath.Join(h.expManager.BasePath(), expectedLocation, task.JobName)
|
||||
rm, rmLoadErr = manifest.LoadFromDir(jobDir)
|
||||
|
||||
// If not found and task is running, also check finished (wrong location test)
|
||||
if rmLoadErr != nil && task.Status == "running" {
|
||||
wrongDir := filepath.Join(h.expManager.BasePath(), "finished", task.JobName)
|
||||
rm, _ = manifest.LoadFromDir(wrongDir)
|
||||
if rm != nil {
|
||||
// Manifest exists but in wrong location
|
||||
rmLocCheck["ok"] = false
|
||||
rmLocCheck["expected"] = "running"
|
||||
rmLocCheck["actual"] = "finished"
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rm == nil {
|
||||
// No run manifest found
|
||||
if task.Status == "running" || task.Status == "completed" {
|
||||
rmCheck["ok"] = false
|
||||
ok = false
|
||||
}
|
||||
} else {
|
||||
// Run manifest exists - validate it
|
||||
|
||||
// Check commit_id match
|
||||
taskCommitID := task.Metadata["commit_id"]
|
||||
if rm.CommitID != "" && taskCommitID != "" && rm.CommitID != taskCommitID {
|
||||
rmCommitCheck["ok"] = false
|
||||
rmCommitCheck["expected"] = taskCommitID
|
||||
ok = false
|
||||
}
|
||||
|
||||
// Check lifecycle ordering (started_at < ended_at)
|
||||
if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && !rm.StartedAt.Before(rm.EndedAt) {
|
||||
rmLifecycle["ok"] = false
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
|
||||
checks["run_manifest"] = rmCheck
|
||||
checks["run_manifest_commit_id"] = rmCommitCheck
|
||||
checks["run_manifest_location"] = rmLocCheck
|
||||
checks["run_manifest_lifecycle"] = rmLifecycle
|
||||
|
||||
// Resources check
|
||||
resCheck := map[string]interface{}{"ok": true}
|
||||
if task.CPU < 0 {
|
||||
resCheck["ok"] = false
|
||||
ok = false
|
||||
}
|
||||
checks["resources"] = resCheck
|
||||
|
||||
// Snapshot check
|
||||
snapCheck := map[string]interface{}{"ok": true}
|
||||
if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" {
|
||||
// Verify snapshot SHA
|
||||
dataDir := h.dataDir
|
||||
if dataDir == "" {
|
||||
dataDir = filepath.Join(h.expManager.BasePath(), "data")
|
||||
}
|
||||
snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
|
||||
actualSHA, _ := integrity.DirOverallSHA256Hex(snapPath)
|
||||
expectedSHA := task.Metadata["snapshot_sha256"]
|
||||
if actualSHA != expectedSHA {
|
||||
snapCheck["ok"] = false
|
||||
snapCheck["actual"] = actualSHA
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
checks["snapshot"] = snapCheck
|
||||
|
||||
report := map[string]interface{}{
|
||||
"ok": ok,
|
||||
"checks": checks,
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(report)
|
||||
return h.sendDataPacket(conn, "validate", payloadBytes)
|
||||
}
|
||||
Loading…
Reference in a new issue