// Package ws provides WebSocket handling for the API package ws import ( "encoding/hex" "encoding/json" "fmt" "os" "path/filepath" "strings" "time" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/worker/integrity" ) func (h *Handler) populateExperimentIntegrityMetadata( task *queue.Task, commitIDHex string, ) (string, error) { if h.expManager == nil { return "", nil } // Validate commit ID (defense-in-depth) if len(commitIDHex) != 40 { return "", fmt.Errorf("invalid commit id length") } if _, err := hex.DecodeString(commitIDHex); err != nil { return "", fmt.Errorf("invalid commit id format") } filesPath := h.expManager.GetFilesPath(commitIDHex) depsName, err := selectDependencyManifest(filesPath) if err != nil { return "", err } if depsName != "" { task.Metadata["deps_manifest_name"] = depsName depsPath := filepath.Join(filesPath, depsName) if sha, err := integrity.FileSHA256Hex(depsPath); err == nil { task.Metadata["deps_manifest_sha256"] = sha } } basePath := filepath.Clean(h.expManager.BasePath()) manifestPath := filepath.Join(basePath, commitIDHex, "manifest.json") manifestPath = filepath.Clean(manifestPath) if !strings.HasPrefix(manifestPath, basePath+string(os.PathSeparator)) { return "", fmt.Errorf("path traversal detected") } if data, err := os.ReadFile(manifestPath); err == nil { var man struct { OverallSHA string `json:"overall_sha"` } if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" { task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA } } return depsName, nil } // handleQueueJob handles the QueueJob opcode (0x01) func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error { // Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] if len(payload) < 39 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } commitIDBytes := payload[17:37] commitIDHex := hex.EncodeToString(commitIDBytes) priority := payload[37] jobNameLen := int(payload[38]) if len(payload) < 39+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") } jobName := string(payload[39 : 39+jobNameLen]) // Parse optional resource fields cpu, memoryGB, gpu, gpuMemory := 0, 0, 0, "" pos := 39 + jobNameLen if len(payload) > pos { cpu = int(payload[pos]) pos++ if len(payload) > pos { memoryGB = int(payload[pos]) pos++ if len(payload) > pos { gpu = int(payload[pos]) pos++ if len(payload) > pos { gpuMemLen := int(payload[pos]) pos++ if len(payload) >= pos+gpuMemLen { gpuMemory = string(payload[pos : pos+gpuMemLen]) } } } } } task := &queue.Task{ ID: fmt.Sprintf("task-%d", time.Now().UnixNano()), JobName: jobName, Status: "queued", Priority: int64(priority), CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", CPU: cpu, MemoryGB: memoryGB, GPU: gpu, GPUMemory: gpuMemory, Metadata: map[string]string{"commit_id": commitIDHex}, } if _, err := h.populateExperimentIntegrityMetadata(task, commitIDHex); err != nil { return h.sendErrorPacket( conn, ErrorCodeInvalidRequest, "failed to resolve experiment metadata", err.Error(), ) } if h.taskQueue != nil { if err := h.taskQueue.AddTask(task); err != nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error()) } } return h.sendSuccessPacket(conn, map[string]any{"task_id": task.ID}) } // handleQueueJobWithSnapshot handles the QueueJobWithSnapshot opcode (0x17) func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error { if len(payload) < 41 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } commitIDBytes := payload[17:37] commitIDHex := hex.EncodeToString(commitIDBytes) priority := payload[37] jobNameLen := int(payload[38]) if len(payload) < 39+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") } jobName := string(payload[39 : 39+jobNameLen]) pos := 39 + jobNameLen snapshotIDLen := int(payload[pos]) pos++ snapshotID := string(payload[pos : pos+snapshotIDLen]) pos += snapshotIDLen snapshotSHALen := int(payload[pos]) pos++ snapshotSHA := string(payload[pos : pos+snapshotSHALen]) task := &queue.Task{ ID: fmt.Sprintf("task-%d", time.Now().UnixNano()), JobName: jobName, Status: "queued", Priority: int64(priority), CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", SnapshotID: snapshotID, Metadata: map[string]string{ "commit_id": commitIDHex, "snapshot_sha256": snapshotSHA, }, } if _, err := h.populateExperimentIntegrityMetadata(task, commitIDHex); err != nil { return h.sendErrorPacket( conn, ErrorCodeInvalidRequest, "failed to resolve experiment metadata", err.Error(), ) } if h.taskQueue != nil { if err := h.taskQueue.AddTask(task); err != nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error()) } } return h.sendSuccessPacket(conn, map[string]any{"task_id": task.ID}) } // handleCancelJob handles the CancelJob opcode (0x03) func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error { if len(payload) < 18 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } jobNameLen := int(payload[17]) if len(payload) < 18+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") } jobName := string(payload[18 : 18+jobNameLen]) if h.taskQueue != nil { task, err := h.taskQueue.GetTaskByName(jobName) if err == nil && task != nil { task.Status = "cancelled" if err := h.taskQueue.UpdateTask(task); err != nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to cancel task", err.Error()) } } } return h.sendSuccessPacket(conn, map[string]any{"message": "Job cancelled"}) } // handlePrune handles the Prune opcode (0x04) func (h *Handler) handlePrune(conn *websocket.Conn, payload []byte) error { // Parse payload: [api_key_hash:16][prune_type:1][value:4] if len(payload) < 16+1+4 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "") } // Authenticate user // Skip 16-byte API key hash for now (authentication would use it) // offset := 16 // pruneType := payload[offset] // value := binary.BigEndian.Uint32(payload[offset+1 : offset+5]) return h.sendSuccessPacket(conn, map[string]any{"message": "Prune completed", "pruned": 0}) }