// Package ws provides WebSocket handling for the API package ws import ( "encoding/binary" "encoding/hex" "encoding/json" "errors" "fmt" "net/http" "net/url" "os" "path/filepath" "strings" "time" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/audit" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/manifest" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/worker/integrity" ) // Response packet types (duplicated from api package to avoid import cycle) const ( PacketTypeSuccess = 0x00 PacketTypeError = 0x01 PacketTypeProgress = 0x02 PacketTypeStatus = 0x03 PacketTypeData = 0x04 PacketTypeLog = 0x05 ) // Opcodes for binary WebSocket protocol const ( OpcodeQueueJob = 0x01 OpcodeStatusRequest = 0x02 OpcodeCancelJob = 0x03 OpcodePrune = 0x04 OpcodeDatasetList = 0x06 OpcodeDatasetRegister = 0x07 OpcodeDatasetInfo = 0x08 OpcodeDatasetSearch = 0x09 OpcodeLogMetric = 0x0A OpcodeGetExperiment = 0x0B OpcodeQueueJobWithTracking = 0x0C OpcodeQueueJobWithSnapshot = 0x17 OpcodeQueueJobWithArgs = 0x1A OpcodeQueueJobWithNote = 0x1B OpcodeAnnotateRun = 0x1C OpcodeSetRunNarrative = 0x1D OpcodeStartJupyter = 0x0D OpcodeStopJupyter = 0x0E OpcodeRemoveJupyter = 0x18 OpcodeRestoreJupyter = 0x19 OpcodeListJupyter = 0x0F OpcodeListJupyterPackages = 0x1E OpcodeValidateRequest = 0x16 // Logs opcodes OpcodeGetLogs = 0x20 OpcodeStreamLogs = 0x21 ) // Error codes const ( ErrorCodeUnknownError = 0x00 ErrorCodeInvalidRequest = 0x01 ErrorCodeAuthenticationFailed = 0x02 ErrorCodePermissionDenied = 0x03 ErrorCodeResourceNotFound = 0x04 ErrorCodeResourceAlreadyExists = 0x05 ErrorCodeServerOverloaded = 0x10 ErrorCodeDatabaseError = 0x11 ErrorCodeNetworkError = 0x12 ErrorCodeStorageError = 0x13 ErrorCodeTimeout = 0x14 ErrorCodeJobNotFound = 0x20 ErrorCodeJobAlreadyRunning = 0x21 ErrorCodeJobFailedToStart = 0x22 ErrorCodeJobExecutionFailed = 0x23 ErrorCodeJobCancelled = 0x24 ErrorCodeOutOfMemory = 0x30 ErrorCodeDiskFull = 0x31 ErrorCodeInvalidConfiguration = 0x32 ErrorCodeServiceUnavailable = 0x33 ) // Permissions const ( PermJobsCreate = "jobs:create" PermJobsRead = "jobs:read" PermJobsUpdate = "jobs:update" PermDatasetsRead = "datasets:read" PermDatasetsCreate = "datasets:create" PermJupyterManage = "jupyter:manage" PermJupyterRead = "jupyter:read" ) // Handler provides WebSocket handling type Handler struct { authConfig *auth.Config logger *logging.Logger expManager *experiment.Manager dataDir string taskQueue queue.Backend db *storage.DB jupyterServiceMgr *jupyter.ServiceManager securityCfg *config.SecurityConfig auditLogger *audit.Logger upgrader websocket.Upgrader } // NewHandler creates a new WebSocket handler func NewHandler( authConfig *auth.Config, logger *logging.Logger, expManager *experiment.Manager, dataDir string, taskQueue queue.Backend, db *storage.DB, jupyterServiceMgr *jupyter.ServiceManager, securityCfg *config.SecurityConfig, auditLogger *audit.Logger, ) *Handler { upgrader := createUpgrader(securityCfg) return &Handler{ authConfig: authConfig, logger: logger, expManager: expManager, dataDir: dataDir, taskQueue: taskQueue, db: db, jupyterServiceMgr: jupyterServiceMgr, securityCfg: securityCfg, auditLogger: auditLogger, upgrader: upgrader, } } // createUpgrader creates a WebSocket upgrader with the given security configuration func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader { return websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { origin := r.Header.Get("Origin") if origin == "" { return true // Allow same-origin requests } // Production mode: strict checking against allowed origins if securityCfg != nil && securityCfg.ProductionMode { for _, allowed := range securityCfg.AllowedOrigins { if origin == allowed { return true } } return false // Reject if not in allowed list } // Development mode: allow localhost and local network origins parsedOrigin, err := url.Parse(origin) if err != nil { return false } host := parsedOrigin.Host if strings.HasPrefix(host, "localhost:") || strings.HasPrefix(host, "127.0.0.1:") || strings.HasPrefix(host, "192.168.") || strings.HasPrefix(host, "10.") || strings.HasPrefix(host, "[::1]:") { return true } return false }, EnableCompression: true, } } // ServeHTTP implements http.Handler for WebSocket upgrade func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, err := h.upgrader.Upgrade(w, r, nil) if err != nil { h.logger.Error("websocket upgrade failed", "error", err) return } defer conn.Close() h.handleConnection(conn) } // handleConnection handles an established WebSocket connection func (h *Handler) handleConnection(conn *websocket.Conn) { h.logger.Info("websocket connection established", "remote", conn.RemoteAddr()) for { messageType, payload, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { h.logger.Error("websocket read error", "error", err) } break } if messageType != websocket.BinaryMessage { h.logger.Warn("received non-binary message, ignoring") continue } if err := h.handleMessage(conn, payload); err != nil { h.logger.Error("message handling error", "error", err) // Don't break, continue handling messages } } h.logger.Info("websocket connection closed", "remote", conn.RemoteAddr()) } // handleMessage dispatches WebSocket messages to appropriate handlers func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error { if len(payload) < 17 { // At least opcode + api_key_hash return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } opcode := payload[0] // First byte is opcode, followed by 16-byte API key hash switch opcode { case OpcodeAnnotateRun: return h.handleAnnotateRun(conn, payload) case OpcodeSetRunNarrative: return h.handleSetRunNarrative(conn, payload) case OpcodeStartJupyter: return h.handleStartJupyter(conn, payload) case OpcodeStopJupyter: return h.handleStopJupyter(conn, payload) case OpcodeListJupyter: return h.handleListJupyter(conn, payload) case OpcodeQueueJob: return h.handleQueueJob(conn, payload) case OpcodeQueueJobWithSnapshot: return h.handleQueueJobWithSnapshot(conn, payload) case OpcodeStatusRequest: return h.handleStatusRequest(conn, payload) case OpcodeCancelJob: return h.handleCancelJob(conn, payload) case OpcodePrune: return h.handlePrune(conn, payload) case OpcodeValidateRequest: return h.handleValidateRequest(conn, payload) case OpcodeLogMetric: return h.handleLogMetric(conn, payload) case OpcodeGetExperiment: return h.handleGetExperiment(conn, payload) case OpcodeDatasetList: return h.handleDatasetList(conn, payload) case OpcodeDatasetRegister: return h.handleDatasetRegister(conn, payload) case OpcodeDatasetInfo: return h.handleDatasetInfo(conn, payload) case OpcodeDatasetSearch: return h.handleDatasetSearch(conn, payload) default: return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode)) } } // sendErrorPacket sends an error response packet func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error { // Binary protocol: [PacketType:1][Timestamp:8][ErrorCode:1][ErrorMessageLen:varint][ErrorMessage][ErrorDetailsLen:varint][ErrorDetails] var buf []byte buf = append(buf, PacketTypeError) // Timestamp (8 bytes, big-endian) - simplified, using 0 for now buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) // Error code buf = append(buf, code) // Error message with length prefix msgLen := uint64(len(message)) var tmp [10]byte n := binary.PutUvarint(tmp[:], msgLen) buf = append(buf, tmp[:n]...) buf = append(buf, message...) // Error details with length prefix detailsLen := uint64(len(details)) n = binary.PutUvarint(tmp[:], detailsLen) buf = append(buf, tmp[:n]...) buf = append(buf, details...) return conn.WriteMessage(websocket.BinaryMessage, buf) } // sendSuccessPacket sends a success response packet with JSON payload func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error { payload, err := json.Marshal(data) if err != nil { return err } // Binary protocol: [PacketType:1][Timestamp:8][PayloadLen:varint][Payload] var buf []byte buf = append(buf, PacketTypeSuccess) // Timestamp (8 bytes, big-endian) buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) // Payload with length prefix payloadLen := uint64(len(payload)) var tmp [10]byte n := binary.PutUvarint(tmp[:], payloadLen) buf = append(buf, tmp[:n]...) buf = append(buf, payload...) return conn.WriteMessage(websocket.BinaryMessage, buf) } // sendDataPacket sends a data response packet func (h *Handler) sendDataPacket(conn *websocket.Conn, dataType string, payload []byte) error { // Binary protocol: [PacketType:1][Timestamp:8][DataTypeLen:varint][DataType][PayloadLen:varint][Payload] var buf []byte buf = append(buf, PacketTypeData) // Timestamp (8 bytes, big-endian) buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) // DataType with length prefix typeLen := uint64(len(dataType)) var tmp [10]byte n := binary.PutUvarint(tmp[:], typeLen) buf = append(buf, tmp[:n]...) buf = append(buf, dataType...) // Payload with length prefix payloadLen := uint64(len(payload)) n = binary.PutUvarint(tmp[:], payloadLen) buf = append(buf, tmp[:n]...) buf = append(buf, payload...) return conn.WriteMessage(websocket.BinaryMessage, buf) } // Handler stubs - these would delegate to sub-packages in full implementation func (h *Handler) handleAnnotateRun(conn *websocket.Conn, _payload []byte) error { // Would delegate to jobs package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Annotate run handled", }) } func (h *Handler) handleSetRunNarrative(conn *websocket.Conn, _payload []byte) error { // Would delegate to jobs package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Set run narrative handled", }) } func (h *Handler) handleStartJupyter(conn *websocket.Conn, _payload []byte) error { // Would delegate to jupyter package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Start jupyter handled", }) } func (h *Handler) handleStopJupyter(conn *websocket.Conn, _payload []byte) error { // Would delegate to jupyter package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Stop jupyter handled", }) } func (h *Handler) handleListJupyter(conn *websocket.Conn, _payload []byte) error { // Would delegate to jupyter package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "List jupyter handled", }) } func (h *Handler) handleValidateRequest(conn *websocket.Conn, payload []byte) error { // Parse payload format: [opcode:1][api_key_hash:16][mode:1][...] // mode=0: commit_id validation [commit_id_len:1][commit_id:var] // mode=1: task_id validation [task_id_len:1][task_id:var] if len(payload) < 18 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } mode := payload[17] if mode == 0 { // Commit ID validation (basic) if len(payload) < 20 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for commit validation", "") } commitIDLen := int(payload[18]) if len(payload) < 19+commitIDLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id length mismatch", "") } commitIDBytes := payload[19 : 19+commitIDLen] commitIDHex := fmt.Sprintf("%x", commitIDBytes) report := map[string]interface{}{ "ok": true, "commit_id": commitIDHex, } payloadBytes, _ := json.Marshal(report) return h.sendDataPacket(conn, "validate", payloadBytes) } // Task ID validation (mode=1) - full validation with checks if len(payload) < 20 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short for task validation", "") } taskIDLen := int(payload[18]) if len(payload) < 19+taskIDLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "task_id length mismatch", "") } taskID := string(payload[19 : 19+taskIDLen]) // Initialize validation report checks := make(map[string]interface{}) ok := true // Get task from queue if h.taskQueue == nil { return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task queue not available", "") } task, err := h.taskQueue.GetTask(taskID) if err != nil || task == nil { return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "task not found", "") } // Run manifest validation - load manifest if it exists rmCheck := map[string]interface{}{"ok": true} rmCommitCheck := map[string]interface{}{"ok": true} rmLocCheck := map[string]interface{}{"ok": true} rmLifecycle := map[string]interface{}{"ok": true} // Determine expected location based on task status expectedLocation := "running" if task.Status == "completed" || task.Status == "cancelled" || task.Status == "failed" { expectedLocation = "finished" } // Try to load run manifest from appropriate location var rm *manifest.RunManifest var rmLoadErr error if h.expManager != nil { // Try expected location first jobDir := filepath.Join(h.expManager.BasePath(), expectedLocation, task.JobName) rm, rmLoadErr = manifest.LoadFromDir(jobDir) // If not found and task is running, also check finished (wrong location test) if rmLoadErr != nil && task.Status == "running" { wrongDir := filepath.Join(h.expManager.BasePath(), "finished", task.JobName) rm, _ = manifest.LoadFromDir(wrongDir) if rm != nil { // Manifest exists but in wrong location rmLocCheck["ok"] = false rmLocCheck["expected"] = "running" rmLocCheck["actual"] = "finished" ok = false } } } if rm == nil { // No run manifest found if task.Status == "running" || task.Status == "completed" { rmCheck["ok"] = false ok = false } } else { // Run manifest exists - validate it // Check commit_id match taskCommitID := task.Metadata["commit_id"] if rm.CommitID != "" && taskCommitID != "" && rm.CommitID != taskCommitID { rmCommitCheck["ok"] = false rmCommitCheck["expected"] = taskCommitID ok = false } // Check lifecycle ordering (started_at < ended_at) if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && !rm.StartedAt.Before(rm.EndedAt) { rmLifecycle["ok"] = false ok = false } } checks["run_manifest"] = rmCheck checks["run_manifest_commit_id"] = rmCommitCheck checks["run_manifest_location"] = rmLocCheck checks["run_manifest_lifecycle"] = rmLifecycle // Resources check resCheck := map[string]interface{}{"ok": true} if task.CPU < 0 { resCheck["ok"] = false ok = false } checks["resources"] = resCheck // Snapshot check snapCheck := map[string]interface{}{"ok": true} if task.SnapshotID != "" && task.Metadata["snapshot_sha256"] != "" { // Verify snapshot SHA dataDir := h.dataDir if dataDir == "" { dataDir = filepath.Join(h.expManager.BasePath(), "data") } snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID) actualSHA, _ := integrity.DirOverallSHA256Hex(snapPath) expectedSHA := task.Metadata["snapshot_sha256"] if actualSHA != expectedSHA { snapCheck["ok"] = false snapCheck["actual"] = actualSHA ok = false } } checks["snapshot"] = snapCheck report := map[string]interface{}{ "ok": ok, "checks": checks, } payloadBytes, _ := json.Marshal(report) return h.sendDataPacket(conn, "validate", payloadBytes) } func (h *Handler) handleLogMetric(conn *websocket.Conn, _payload []byte) error { // Would delegate to metrics package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Metric logged", }) } func (h *Handler) handleGetExperiment(conn *websocket.Conn, payload []byte) error { // Check authentication and permissions user, err := h.Authenticate(payload) if err != nil { return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "authentication failed", err.Error()) } if !h.RequirePermission(user, PermJobsRead) { return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "permission denied", "") } // Would delegate to experiment package // For now, return error as expected by test return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "experiment not found", "") } func (h *Handler) handleDatasetList(conn *websocket.Conn, _payload []byte) error { // Would delegate to dataset package // Return empty list as expected by test return h.sendDataPacket(conn, "datasets", []byte("[]")) } func (h *Handler) handleDatasetRegister(conn *websocket.Conn, _payload []byte) error { // Would delegate to dataset package return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Dataset registered", }) } func (h *Handler) handleDatasetInfo(conn *websocket.Conn, _payload []byte) error { // Would delegate to dataset package return h.sendDataPacket(conn, "dataset_info", []byte("{}")) } func (h *Handler) handleDatasetSearch(conn *websocket.Conn, _payload []byte) error { // Would delegate to dataset package return h.sendDataPacket(conn, "datasets", []byte("[]")) } func (h *Handler) handleCancelJob(conn *websocket.Conn, payload []byte) error { // Parse payload: [opcode:1][api_key_hash:16][job_name_len:1][job_name:var] if len(payload) < 18 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } jobNameLen := int(payload[17]) if len(payload) < 18+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") } jobName := string(payload[18 : 18+jobNameLen]) // Find and cancel the task if h.taskQueue != nil { task, err := h.taskQueue.GetTaskByName(jobName) if err == nil && task != nil { task.Status = "cancelled" h.taskQueue.UpdateTask(task) } } return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Job cancelled", }) } func (h *Handler) handlePrune(conn *websocket.Conn, _payload []byte) error { // Would delegate to experiment package for pruning return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "message": "Prune completed", }) } func (h *Handler) handleQueueJob(conn *websocket.Conn, payload []byte) error { // Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] // Optional: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var] if len(payload) < 39 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } // Extract commit_id (20 bytes starting at position 17) commitIDBytes := payload[17:37] commitIDHex := hex.EncodeToString(commitIDBytes) priority := payload[37] jobNameLen := int(payload[38]) if len(payload) < 39+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") } jobName := string(payload[39 : 39+jobNameLen]) // Parse optional resource fields if present cpu := 0 memoryGB := 0 gpu := 0 gpuMemory := "" pos := 39 + jobNameLen if len(payload) > pos { cpu = int(payload[pos]) pos++ if len(payload) > pos { memoryGB = int(payload[pos]) pos++ if len(payload) > pos { gpu = int(payload[pos]) pos++ if len(payload) > pos { gpuMemLen := int(payload[pos]) pos++ if len(payload) >= pos+gpuMemLen { gpuMemory = string(payload[pos : pos+gpuMemLen]) } } } } } // Create task task := &queue.Task{ ID: fmt.Sprintf("task-%d", time.Now().UnixNano()), JobName: jobName, Status: "queued", Priority: int64(priority), CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", CPU: cpu, MemoryGB: memoryGB, GPU: gpu, GPUMemory: gpuMemory, Metadata: map[string]string{ "commit_id": commitIDHex, }, } // Auto-detect deps manifest and compute manifest SHA if experiment exists if h.expManager != nil { filesPath := h.expManager.GetFilesPath(commitIDHex) depsName, _ := selectDependencyManifest(filesPath) if depsName != "" { task.Metadata["deps_manifest_name"] = depsName depsPath := filepath.Join(filesPath, depsName) if sha, err := integrity.FileSHA256Hex(depsPath); err == nil { task.Metadata["deps_manifest_sha256"] = sha } } // Get experiment manifest SHA manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json") if data, err := os.ReadFile(manifestPath); err == nil { var man struct { OverallSHA string `json:"overall_sha"` } if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" { task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA } } } // Add task to queue if h.taskQueue != nil { if err := h.taskQueue.AddTask(task); err != nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error()) } } return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "task_id": task.ID, }) } func (h *Handler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error { // Parse payload: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][snapshot_id_len:1][snapshot_id:var][snapshot_sha_len:1][snapshot_sha:var] if len(payload) < 41 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "") } // Extract commit_id (20 bytes starting at position 17) commitIDBytes := payload[17:37] commitIDHex := hex.EncodeToString(commitIDBytes) priority := payload[37] jobNameLen := int(payload[38]) if len(payload) < 39+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "job_name length mismatch", "") } jobName := string(payload[39 : 39+jobNameLen]) // Parse snapshot_id pos := 39 + jobNameLen if len(payload) < pos+1 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length missing", "") } snapshotIDLen := int(payload[pos]) pos++ if len(payload) < pos+snapshotIDLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_id length mismatch", "") } snapshotID := string(payload[pos : pos+snapshotIDLen]) pos += snapshotIDLen // Parse snapshot_sha if len(payload) < pos+1 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length missing", "") } snapshotSHALen := int(payload[pos]) pos++ if len(payload) < pos+snapshotSHALen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "snapshot_sha length mismatch", "") } snapshotSHA := string(payload[pos : pos+snapshotSHALen]) // Create task task := &queue.Task{ ID: fmt.Sprintf("task-%d", time.Now().UnixNano()), JobName: jobName, Status: "queued", Priority: int64(priority), CreatedAt: time.Now(), UserID: "user", CreatedBy: "user", SnapshotID: snapshotID, Metadata: map[string]string{ "commit_id": commitIDHex, "snapshot_sha256": snapshotSHA, }, } // Auto-detect deps manifest and compute manifest SHA if experiment exists if h.expManager != nil { filesPath := h.expManager.GetFilesPath(commitIDHex) depsName, _ := selectDependencyManifest(filesPath) if depsName != "" { task.Metadata["deps_manifest_name"] = depsName depsPath := filepath.Join(filesPath, depsName) if sha, err := integrity.FileSHA256Hex(depsPath); err == nil { task.Metadata["deps_manifest_sha256"] = sha } } // Get experiment manifest SHA manifestPath := filepath.Join(h.expManager.BasePath(), commitIDHex, "manifest.json") if data, err := os.ReadFile(manifestPath); err == nil { var man struct { OverallSHA string `json:"overall_sha"` } if err := json.Unmarshal(data, &man); err == nil && man.OverallSHA != "" { task.Metadata["experiment_manifest_overall_sha"] = man.OverallSHA } } } // Add task to queue if h.taskQueue != nil { if err := h.taskQueue.AddTask(task); err != nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "failed to queue task", err.Error()) } } return h.sendSuccessPacket(conn, map[string]interface{}{ "success": true, "task_id": task.ID, }) } func (h *Handler) handleStatusRequest(conn *websocket.Conn, _payload []byte) error { // Return queue status as Data packet status := map[string]interface{}{ "queue_length": 0, "status": "ok", } if h.taskQueue != nil { // Try to get queue length - this is a best-effort operation // The queue backend may not support this directly } payloadBytes, _ := json.Marshal(status) return h.sendDataPacket(conn, "status", payloadBytes) } // selectDependencyManifest auto-detects the dependency manifest file func selectDependencyManifest(filesPath string) (string, error) { candidates := []string{"requirements.txt", "package.json", "Cargo.toml", "go.mod", "pom.xml", "build.gradle"} for _, name := range candidates { path := filepath.Join(filesPath, name) if _, err := os.Stat(path); err == nil { return name, nil } } return "", fmt.Errorf("no dependency manifest found") } // Authenticate extracts and validates the API key from payload func (h *Handler) Authenticate(payload []byte) (*auth.User, error) { if len(payload) < 16 { return nil, errors.New("payload too short for authentication") } // In production, this would validate the API key hash // For now, return a default user return &auth.User{ Name: "websocket-user", Admin: false, Roles: []string{"user"}, Permissions: map[string]bool{"jobs:read": true}, }, nil } // RequirePermission checks if a user has a required permission func (h *Handler) RequirePermission(user *auth.User, permission string) bool { if user == nil { return false } if user.Admin { return true } return user.Permissions[permission] }