// Package ws provides WebSocket handling for the API package ws import ( "encoding/hex" "encoding/json" "fmt" "os" "path/filepath" "time" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/worker/integrity" ) // handleQueueJob handles the QueueJob opcode (0x01) func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error { // Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] if len(payload) < 39 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } commitIDBytes := payload[17:37] commitIDHex := hex.EncodeToString(commitIDBytes) priority := payload[37] jobNameLen := int(payload[38]) if len(payload) < 39+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") } jobName := string(payload[39 : 39+jobNameLen]) // Parse optional resource fields cpu, memoryGB, gpu, gpuMemory := 0, 0, 0, "" pos := 39 + jobNameLen if len(payload) > pos { cpu = int(payload[pos]) pos++ if len(payload) > pos { memoryGB = int(payload[pos]) pos++ if len(payload) > pos { gpu = int(payload[pos]) pos++ if len(payload) > pos { gpuMemLen := int(payload[pos]) pos++ if len(payload) >= pos+gpuMemLen { gpuMemory = string(payload[pos : pos+gpuMemLen]) } } } } } task := &queue.Task{ ID: fmt.Sprintf("task-%d", time.Now().UnixNano()), JobName: jobName, Status: "queued", Priority: int64(priority), CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", CPU: cpu, MemoryGB: memoryGB, GPU: gpu, GPUMemory: gpuMemory, Metadata: map[string]string{"commit_id": commitIDHex}, } // Auto-detect deps manifest and compute manifest SHA if h.expManager != nil { filesPath := h.expManager.GetFilesPath(commitIDHex) depsName, _ := selectDependencyManifest(filesPath) if depsName != "" { task.Metadata["deps_manifest_name"] = depsName depsPath := filepath.Join(filesPath, depsName) if sha, err := integrity.FileSHA256Hex(depsPath); err == nil { task.Metadata["deps_manifest_sha256"] = sha } } manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json") if data, err := os.ReadFile(manifestPath); err == nil { var man struct { OverallSHA string `json:"overall_sha"` } if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" { task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA } } } if h.taskQueue != nil { if err := h.taskQueue.AddTask(task); err != nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error()) } } return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "task_id": task.ID, }) } // handleQueueJobWithSnapshot handles the QueueJobWithSnapshot opcode (0x17) func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error { if len(payload) < 41 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } commitIDBytes := payload[17:37] commitIDHex := hex.EncodeToString(commitIDBytes) priority := payload[37] jobNameLen := int(payload[38]) if len(payload) < 39+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") } jobName := string(payload[39 : 39+jobNameLen]) pos := 39 + jobNameLen snapshotIDLen := int(payload[pos]) pos++ snapshotID := string(payload[pos : pos+snapshotIDLen]) pos += snapshotIDLen snapshotSHALen := int(payload[pos]) pos++ snapshotSHA := string(payload[pos : pos+snapshotSHALen]) task := &queue.Task{ ID: fmt.Sprintf("task-%d", time.Now().UnixNano()), JobName: jobName, Status: "queued", Priority: int64(priority), CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", SnapshotID: snapshotID, Metadata: map[string]string{ "commit_id": commitIDHex, "snapshot_sha256": snapshotSHA, }, } if h.expManager != nil { filesPath := h.expManager.GetFilesPath(commitIDHex) depsName, _ := selectDependencyManifest(filesPath) if depsName != "" { task.Metadata["deps_manifest_name"] = depsName depsPath := filepath.Join(filesPath, depsName) if sha, err := integrity.FileSHA256Hex(depsPath); err == nil { task.Metadata["deps_manifest_sha256"] = sha } } manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json") if data, err := os.ReadFile(manifestPath); err == nil { var man struct { OverallSHA string `json:"overall_sha"` } if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" { task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA } } } if h.taskQueue != nil { if err := h.taskQueue.AddTask(task); err != nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error()) } } return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "task_id": task.ID, }) } // handleCancelJob handles the CancelJob opcode (0x03) func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error { if len(payload) < 18 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } jobNameLen := int(payload[17]) if len(payload) < 18+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") } jobName := string(payload[18 : 18+jobNameLen]) if h.taskQueue != nil { task, err := h.taskQueue.GetTaskByName(jobName) if err == nil && task != nil { task.Status = "cancelled" h.taskQueue.UpdateTask(task) } } return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Job cancelled", }) } // handlePrune handles the Prune opcode (0x04) func (h *Handler) handlePrune(conn *websocket.Conn, payload []byte) error { // Parse payload: [api_key_hash:16][prune_type:1][value:4] if len(payload) < 16+1+4 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "") } // Authenticate user // Skip 16-byte API key hash for now (authentication would use it) // offset := 16 // pruneType := payload[offset] // value := binary.BigEndian.Uint32(payload[offset+1 : offset+5]) return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Prune completed", "pruned": 0, }) }