package api import ( "encoding/binary" "encoding/json" "fmt" "math" "os" "path/filepath" "sort" "strings" "time" "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/api/helpers" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/manifest" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/telemetry" ) func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error { // Protocol: [api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var] if len(payload) < 16+1+1+2 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "") } offset := 16 jobNameLen := int(payload[offset]) offset += 1 if jobNameLen <= 0 || len(payload) < offset+jobNameLen+1+2 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") } jobName := string(payload[offset : offset+jobNameLen]) offset += jobNameLen authorLen := int(payload[offset]) offset += 1 if authorLen < 0 || len(payload) < offset+authorLen+2 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid author length", "") } author := string(payload[offset : offset+authorLen]) offset += authorLen noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) offset += 2 if noteLen <= 0 || len(payload) < offset+noteLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid note length", "") } note := string(payload[offset : offset+noteLen]) user, err := h.authenticate(conn, payload, 16) if err != nil { return err } if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { return err } if err := container.ValidateJobName(jobName); err != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error()) } base := 
strings.TrimSpace(h.expManager.BasePath()) if base == "" { return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "") } jobPaths := config.NewJobPaths(base) typedRoots := []struct{ root string }{ {root: jobPaths.RunningPath()}, {root: jobPaths.PendingPath()}, {root: jobPaths.FinishedPath()}, {root: jobPaths.FailedPath()}, } var manifestDir string for _, item := range typedRoots { dir := filepath.Join(item.root, jobName) if info, err := os.Stat(dir); err == nil && info.IsDir() { if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil { manifestDir = dir break } } } if manifestDir == "" { return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "") } rm, err := manifest.LoadFromDir(manifestDir) if err != nil || rm == nil { return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err)) } if strings.TrimSpace(author) == "" { author = user.Name } rm.AddAnnotation(time.Now().UTC(), author, note) if err := rm.WriteToDir(manifestDir); err != nil { return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error()) } return h.sendResponsePacket(conn, NewSuccessPacket("Annotation added")) } func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error { // Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_json_len:2][patch_json:var] if len(payload) < 16+1+2 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "") } offset := 16 jobNameLen := int(payload[offset]) offset += 1 if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") } jobName := string(payload[offset : offset+jobNameLen]) offset += jobNameLen patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) offset += 2 if patchLen <= 0 || len(payload) < offset+patchLen { return h.sendErrorPacket(conn, 
ErrorCodeInvalidRequest, "invalid narrative patch length", "") } patchJSON := payload[offset : offset+patchLen] user, err := h.authenticate(conn, payload, 16) if err != nil { return err } if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { return err } if err := container.ValidateJobName(jobName); err != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error()) } base := strings.TrimSpace(h.expManager.BasePath()) if base == "" { return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "") } jobPaths := config.NewJobPaths(base) typedRoots := []struct{ root string }{ {root: jobPaths.RunningPath()}, {root: jobPaths.PendingPath()}, {root: jobPaths.FinishedPath()}, {root: jobPaths.FailedPath()}, } var manifestDir string for _, item := range typedRoots { dir := filepath.Join(item.root, jobName) if info, err := os.Stat(dir); err == nil && info.IsDir() { if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil { manifestDir = dir break } } } if manifestDir == "" { return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "") } rm, err := manifest.LoadFromDir(manifestDir) if err != nil || rm == nil { return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err)) } var patch manifest.NarrativePatch if err := json.Unmarshal(patchJSON, &patch); err != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch JSON", err.Error()) } rm.ApplyNarrativePatch(patch) if err := rm.WriteToDir(manifestDir); err != nil { return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error()) } return h.sendResponsePacket(conn, NewSuccessPacket("Narrative updated")) } func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error { // Parse payload first if len(payload) < ProtocolMinQueueJob { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job 
payload too short", "") } commitID := payload[16:36] priority := int64(payload[36]) jobNameLen := int(payload[37]) if len(payload) < 38+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") } jobName := string(payload[38 : 38+jobNameLen]) resources, resErr := parseOptionalResourceRequest(payload[38+jobNameLen:]) if resErr != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error()) } h.logger.Info("queue job request", "job", jobName, "priority", priority, "commit_id", fmt.Sprintf("%x", commitID)) // Authenticate and authorize user, err := h.authenticate(conn, payload, ProtocolMinQueueJob) if err != nil { return err } if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { return err } return h.processAndEnqueueJob(conn, user, jobName, priority, commitID, nil, resources) } func (h *WSHandler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error { if len(payload) < ProtocolMinQueueJobWithSnapshot { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job with snapshot payload too short", "") } commitID := payload[16:36] priority := int64(payload[36]) jobNameLen := int(payload[37]) if len(payload) < 38+jobNameLen+2 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") } jobName := string(payload[38 : 38+jobNameLen]) offset := 38 + jobNameLen snapIDLen := int(payload[offset]) offset++ if snapIDLen < 1 || len(payload) < offset+snapIDLen+1 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot id length", "") } snapshotID := string(payload[offset : offset+snapIDLen]) offset += snapIDLen snapSHALen := int(payload[offset]) offset++ if snapSHALen < 1 || len(payload) < offset+snapSHALen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot sha length", "") } snapshotSHA := string(payload[offset : offset+snapSHALen]) offset += snapSHALen resources, resErr := 
parseOptionalResourceRequest(payload[offset:]) if resErr != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error()) } h.logger.Info("queue job with snapshot request", "job", jobName, "priority", priority, "commit_id", fmt.Sprintf("%x", commitID), "snapshot_id", snapshotID) user, err := h.authenticate(conn, payload, ProtocolMinQueueJobWithSnapshot) if err != nil { return err } if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { return err } return h.processAndEnqueueJobWithSnapshot(conn, user, jobName, priority, commitID, nil, resources, snapshotID, snapshotSHA) } func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []byte) error { if len(payload) < ProtocolMinQueueJobWithTracking { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job with tracking payload too short", "") } commitID := payload[16:36] priority := int64(payload[36]) jobNameLen := int(payload[37]) if len(payload) < 38+jobNameLen+2 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") } jobName := string(payload[38 : 38+jobNameLen]) offset := 38 + jobNameLen trackingLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) offset += 2 if trackingLen < 0 || len(payload) < offset+trackingLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json length", "") } var trackingCfg *queue.TrackingConfig if trackingLen > 0 { var cfg queue.TrackingConfig if err := json.Unmarshal(payload[offset:offset+trackingLen], &cfg); err != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json", err.Error()) } trackingCfg = &cfg } offset += trackingLen resources, resErr := parseOptionalResourceRequest(payload[offset:]) if resErr != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error()) } h.logger.Info("queue job with tracking request", "job", jobName, "priority", priority, 
"commit_id", fmt.Sprintf("%x", commitID)) user, err := h.authenticate(conn, payload, ProtocolMinQueueJobWithTracking) if err != nil { return err } if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { return err } return h.processAndEnqueueJob(conn, user, jobName, priority, commitID, trackingCfg, resources) } type queueJobWithArgsPayload struct { apiKeyHash []byte commitID []byte priority int64 jobName string args string force bool resources *resourceRequest } type queueJobWithNotePayload struct { apiKeyHash []byte commitID []byte priority int64 jobName string args string note string force bool resources *resourceRequest } func parseQueueJobWithNotePayload(payload []byte) (*queueJobWithNotePayload, error) { // Protocol: // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] // [args_len:2][args:var][note_len:2][note:var][force:1][resources?:var] if len(payload) < 43 { return nil, fmt.Errorf("queue job with note payload too short") } apiKeyHash := payload[:16] commitID := payload[16:36] priority := int64(payload[36]) p := helpers.NewPayloadParser(payload, 37) jobName, err := p.ParseLengthPrefixedString() if err != nil { return nil, fmt.Errorf("invalid job name: %w", err) } args, err := p.ParseUint16PrefixedString() if err != nil { return nil, fmt.Errorf("invalid args: %w", err) } note, err := p.ParseUint16PrefixedString() if err != nil { return nil, fmt.Errorf("invalid note: %w", err) } force, err := p.ParseBool() if err != nil { return nil, fmt.Errorf("missing force flag: %w", err) } resources, resErr := helpers.ParseResourceRequest(p.Remaining()) if resErr != nil { return nil, resErr } return &queueJobWithNotePayload{ apiKeyHash: apiKeyHash, commitID: commitID, priority: priority, jobName: jobName, args: args, note: note, force: force, resources: resources, }, nil } func parseQueueJobWithArgsPayload(payload []byte) (*queueJobWithArgsPayload, error) { // Protocol: 
[api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][args_len:2][args:var][force:1][resources?:var] if len(payload) < 41 { return nil, fmt.Errorf("queue job with args payload too short") } apiKeyHash := payload[:16] commitID := payload[16:36] priority := int64(payload[36]) p := helpers.NewPayloadParser(payload, 37) jobName, err := p.ParseLengthPrefixedString() if err != nil { return nil, fmt.Errorf("invalid job name: %w", err) } args, err := p.ParseUint16PrefixedString() if err != nil { return nil, fmt.Errorf("invalid args: %w", err) } force, err := p.ParseBool() if err != nil { return nil, fmt.Errorf("missing force flag: %w", err) } resources, resErr := helpers.ParseResourceRequest(p.Remaining()) if resErr != nil { return nil, resErr } return &queueJobWithArgsPayload{ apiKeyHash: apiKeyHash, commitID: commitID, priority: priority, jobName: jobName, args: args, force: force, resources: resources, }, nil } func (h *WSHandler) handleQueueJobWithArgs(conn *websocket.Conn, payload []byte) error { p, err := parseQueueJobWithArgsPayload(payload) if err != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with args payload", err.Error()) } h.logger.Info("queue job request", "job", p.jobName, "priority", p.priority, "commit_id", fmt.Sprintf("%x", p.commitID)) user, err := h.authenticateWithHash(conn, p.apiKeyHash) if err != nil { return err } if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { return err } return h.processAndEnqueueJobWithArgs(conn, user, p.jobName, p.priority, p.commitID, p.args, p.force, nil, p.resources) } func (h *WSHandler) handleQueueJobWithNote(conn *websocket.Conn, payload []byte) error { p, err := parseQueueJobWithNotePayload(payload) if err != nil { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with note payload", err.Error()) } h.logger.Info("queue job request", "job", p.jobName, "priority", p.priority, "commit_id", fmt.Sprintf("%x", p.commitID)) 
user, err := h.authenticateWithHash(conn, p.apiKeyHash) if err != nil { return err } if err := h.requirePermission(user, PermJobsCreate, conn); err != nil { return err } return h.processAndEnqueueJobWithArgsAndNote(conn, user, p.jobName, p.priority, p.commitID, p.args, p.note, p.force, nil, p.resources) } // findDuplicateTask searches for an existing task with the same composite key // (commit_id + dataset_id + params_hash) to detect truly identical experiments func (h *WSHandler) findDuplicateTask(commitIDStr, datasetID, paramsHash string) *queue.Task { if h.queue == nil { return nil } tasks, err := h.queue.GetAllTasks() if err != nil { return nil } for _, task := range tasks { if task.Metadata == nil { continue } // Check all three components of the composite key if task.Metadata["commit_id"] == commitIDStr && task.Metadata["dataset_id"] == datasetID && task.Metadata["params_hash"] == paramsHash { return task } } return nil } // sendDuplicateResponse sends a data packet response for duplicate jobs func (h *WSHandler) sendDuplicateResponse(conn *websocket.Conn, existingTask *queue.Task) error { response := map[string]interface{}{ "duplicate": true, "existing_id": existingTask.ID, "status": existingTask.Status, "queued_by": existingTask.CreatedBy, "queued_at": existingTask.CreatedAt.Unix(), } // Add duration for completed tasks if existingTask.Status == "completed" && existingTask.EndedAt != nil { duration := existingTask.EndedAt.Sub(existingTask.CreatedAt).Seconds() response["duration_seconds"] = int64(duration) // Try to get metrics for completed tasks if h.expManager != nil { commitID := existingTask.Metadata["commit_id"] if metrics, err := h.expManager.GetMetrics(commitID); err == nil && len(metrics) > 0 { metricsMap := make(map[string]interface{}) for _, m := range metrics { metricsMap[m.Name] = m.Value } response["metrics"] = metricsMap } } } // Add error reason for failed tasks with full failure classification if existingTask.Status == "failed" && 
existingTask.Error != "" { response["error_reason"] = existingTask.Error // Classify failure using exit codes, signals, and error context failureClass := queue.FailureUnknown exitCode := 0 signalName := "" // Extract exit code from error or metadata if code, ok := existingTask.Metadata["exit_code"]; ok { fmt.Sscanf(code, "%d", &exitCode) } if sig, ok := existingTask.Metadata["signal"]; ok { signalName = sig } // Get log tail for classification if available logTail := existingTask.Error if existingTask.LastError != "" { logTail = existingTask.LastError } // Classify failure directly using signals, exit codes, and log content // Note: failureClass declared above at line 536, just reassign here // Override with signal-based classification if available if signalName == "SIGKILL" || signalName == "9" { failureClass = queue.FailureInfrastructure } else if exitCode != 0 { // Use the new ClassifyFailure with error log content logContent := existingTask.Error if existingTask.LastError != "" { logContent = existingTask.LastError } failureClass = queue.ClassifyFailure(exitCode, nil, logContent) } response["failure_class"] = string(failureClass) response["exit_code"] = exitCode response["signal"] = signalName response["log_tail"] = logTail // Add user-facing suggestion response["suggestion"] = queue.GetFailureSuggestion(failureClass, logTail) // Add retry information with class-specific policy response["retry_count"] = existingTask.RetryCount response["retry_cap"] = 3 response["auto_retryable"] = queue.ShouldAutoRetry(failureClass, existingTask.RetryCount) // Add attempts history if available if len(existingTask.Attempts) > 0 { response["attempts"] = existingTask.Attempts } } responseData, err := json.Marshal(response) if err != nil { return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize duplicate response", err.Error()) } packet := NewDataPacket("duplicate", responseData) return h.sendResponsePacket(conn, packet) } // enqueueTaskAndRespond enqueues a 
task and sends a success response. func (h *WSHandler) enqueueTaskAndRespond( conn *websocket.Conn, user *auth.User, jobName string, priority int64, commitID []byte, tracking *queue.TrackingConfig, resources *resourceRequest, ) error { return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, "", false, tracking, resources) } func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote( conn *websocket.Conn, user *auth.User, jobName string, priority int64, commitID []byte, args string, note string, force bool, tracking *queue.TrackingConfig, resources *resourceRequest, ) error { packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) commitIDStr := fmt.Sprintf("%x", commitID) // Compute dataset_id and params_hash from existing data paramsHash := helpers.ComputeParamsHash(args) // Note: dataset_id will be empty here since we don't have DatasetSpecs yet // It will be populated when the task is actually created with datasets datasetID := "" // Check for duplicate tasks before proceeding (skip if force=true) if !force { if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil { h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status) return h.sendDuplicateResponse(conn, existingTask) } } else { h.logger.Info("force flag set, skipping duplicate check", "commit_id", commitIDStr) } prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr) if provErr != nil { h.logger.Error("failed to compute expected provenance; refusing to enqueue", "commit_id", commitIDStr, "error", provErr) return h.sendErrorPacket( conn, ErrorCodeStorageError, "Failed to compute expected provenance", provErr.Error(), ) } if h.queue != nil { taskID := uuid.New().String() task := &queue.Task{ ID: taskID, JobName: jobName, Args: strings.TrimSpace(args), Status: "queued", Priority: 
priority, CreatedAt: time.Now(), UserID: user.Name, Username: user.Name, CreatedBy: user.Name, Metadata: map[string]string{ "commit_id": commitIDStr, "dataset_id": datasetID, "params_hash": paramsHash, }, Tracking: tracking, } if strings.TrimSpace(note) != "" { task.Metadata["note"] = strings.TrimSpace(note) } for k, v := range prov { if v != "" { task.Metadata[k] = v } } if resources != nil { task.CPU = resources.CPU task.MemoryGB = resources.MemoryGB task.GPU = resources.GPU task.GPUMemory = resources.GPUMemory } if _, err := telemetry.ExecWithMetrics( h.logger, "queue.add_task", 20*time.Millisecond, func() (string, error) { return "", h.queue.AddTask(task) }, ); err != nil { h.logger.Error("failed to enqueue task", "error", err) return h.sendErrorPacket( conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error(), ) } h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name, "dataset_id", datasetID, "params_hash", paramsHash) } else { h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) } packetData, err := packet.Serialize() if err != nil { h.logger.Error("failed to serialize packet", "error", err) return h.sendErrorPacket( conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response", ) } return conn.WriteMessage(websocket.BinaryMessage, packetData) } func (h *WSHandler) enqueueTaskAndRespondWithArgs( conn *websocket.Conn, user *auth.User, jobName string, priority int64, commitID []byte, args string, force bool, tracking *queue.TrackingConfig, resources *resourceRequest, ) error { packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) commitIDStr := fmt.Sprintf("%x", commitID) // Compute dataset_id and params_hash from existing data paramsHash := helpers.ComputeParamsHash(args) // Note: dataset_id will be empty here since we don't have DatasetSpecs yet // It will be populated when the task is actually created with datasets datasetID := "" // Check for 
duplicate tasks before proceeding (skip if force=true) if !force { if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != nil { h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status) return h.sendDuplicateResponse(conn, existingTask) } } else { h.logger.Info("force flag set, skipping duplicate check", "commit_id", commitIDStr) } prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr) if provErr != nil { h.logger.Error("failed to compute expected provenance; refusing to enqueue", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "error", provErr) return h.sendErrorPacket( conn, ErrorCodeStorageError, "Failed to compute expected provenance", provErr.Error(), ) } // Enqueue task if queue is available if h.queue != nil { taskID := uuid.New().String() task := &queue.Task{ ID: taskID, JobName: jobName, Args: strings.TrimSpace(args), Status: "queued", Priority: priority, CreatedAt: time.Now(), UserID: user.Name, Username: user.Name, CreatedBy: user.Name, Metadata: map[string]string{ "commit_id": commitIDStr, "dataset_id": datasetID, "params_hash": paramsHash, }, Tracking: tracking, } for k, v := range prov { if v != "" { task.Metadata[k] = v } } if resources != nil { task.CPU = resources.CPU task.MemoryGB = resources.MemoryGB task.GPU = resources.GPU task.GPUMemory = resources.GPUMemory } if _, err := telemetry.ExecWithMetrics( h.logger, "queue.add_task", 20*time.Millisecond, func() (string, error) { return "", h.queue.AddTask(task) }, ); err != nil { h.logger.Error("failed to enqueue task", "error", err) return h.sendErrorPacket( conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error(), ) } h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name, "dataset_id", datasetID, "params_hash", paramsHash) } else { 
h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) } packetData, err := packet.Serialize() if err != nil { h.logger.Error("failed to serialize packet", "error", err) return h.sendErrorPacket( conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response", ) } return conn.WriteMessage(websocket.BinaryMessage, packetData) } // processAndEnqueueJob handles common experiment setup and task enqueueing func (h *WSHandler) processAndEnqueueJob( conn *websocket.Conn, user *auth.User, jobName string, priority int64, commitID []byte, tracking *queue.TrackingConfig, resources *resourceRequest, ) error { commitIDStr, err := helpers.RunExperimentSetup(h.logger, h.expManager, commitID, jobName, user.Name) if err != nil { return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") } helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, tracking, resources) } // processAndEnqueueJobWithSnapshot handles experiment setup and task enqueueing for snapshot jobs func (h *WSHandler) processAndEnqueueJobWithSnapshot( conn *websocket.Conn, user *auth.User, jobName string, priority int64, commitID []byte, tracking *queue.TrackingConfig, resources *resourceRequest, snapshotID string, snapshotSHA string, ) error { commitIDStr, err := helpers.RunExperimentSetup(h.logger, h.expManager, commitID, jobName, user.Name) if err != nil { return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") } helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) return h.enqueueTaskAndRespondWithSnapshot(conn, user, jobName, priority, commitID, tracking, resources, snapshotID, snapshotSHA) } // processAndEnqueueJobWithArgs handles experiment setup and task enqueueing for jobs with args func (h *WSHandler) processAndEnqueueJobWithArgs( conn *websocket.Conn, user *auth.User, jobName string, priority int64, commitID []byte, 
args string, force bool, tracking *queue.TrackingConfig, resources *resourceRequest, ) error { commitIDStr, err := helpers.RunExperimentSetupWithoutManifest(h.logger, h.expManager, commitID, jobName, user.Name) if err != nil { return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") } helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, args, force, tracking, resources) } // processAndEnqueueJobWithArgsAndNote handles experiment setup for jobs with args and note func (h *WSHandler) processAndEnqueueJobWithArgsAndNote( conn *websocket.Conn, user *auth.User, jobName string, priority int64, commitID []byte, args string, note string, force bool, tracking *queue.TrackingConfig, resources *resourceRequest, ) error { commitIDStr, err := helpers.RunExperimentSetupWithoutManifest(h.logger, h.expManager, commitID, jobName, user.Name) if err != nil { return h.sendErrorPacket(conn, ErrorCodeStorageError, err.Error(), "") } helpers.UpsertExperimentDBAsync(h.logger, h.db, commitIDStr, jobName, user.Name) return h.enqueueTaskAndRespondWithArgsAndNote(conn, user, jobName, priority, commitID, args, note, force, tracking, resources) } func (h *WSHandler) enqueueTaskAndRespondWithSnapshot( conn *websocket.Conn, user *auth.User, jobName string, priority int64, commitID []byte, tracking *queue.TrackingConfig, resources *resourceRequest, snapshotID string, snapshotSHA string, ) error { packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) commitIDStr := fmt.Sprintf("%x", commitID) // Compute dataset_id from snapshot SHA (snapshot acts as dataset) datasetID := "" if strings.TrimSpace(snapshotSHA) != "" { datasetID = snapshotSHA[:16] } // Snapshots don't have args, so params_hash is empty paramsHash := "" // Check for duplicate tasks before proceeding if existingTask := h.findDuplicateTask(commitIDStr, datasetID, paramsHash); existingTask != 
nil { h.logger.Info("duplicate task found", "commit_id", commitIDStr, "dataset_id", datasetID, "params_hash", paramsHash, "existing_task", existingTask.ID, "status", existingTask.Status) return h.sendDuplicateResponse(conn, existingTask) } prov, provErr := helpers.ExpectedProvenanceForCommit(h.expManager, commitIDStr) if provErr != nil { h.logger.Error("failed to compute expected provenance; refusing to enqueue", "commit_id", commitIDStr, "error", provErr) return h.sendErrorPacket( conn, ErrorCodeStorageError, "Failed to compute expected provenance", provErr.Error(), ) } if h.queue != nil { taskID := uuid.New().String() task := &queue.Task{ ID: taskID, JobName: jobName, Args: "", Status: "queued", Priority: priority, CreatedAt: time.Now(), UserID: user.Name, Username: user.Name, CreatedBy: user.Name, SnapshotID: strings.TrimSpace(snapshotID), Metadata: map[string]string{ "commit_id": commitIDStr, "snapshot_sha256": strings.TrimSpace(snapshotSHA), }, Tracking: tracking, } for k, v := range prov { if v != "" { task.Metadata[k] = v } } if resources != nil { task.CPU = resources.CPU task.MemoryGB = resources.MemoryGB task.GPU = resources.GPU task.GPUMemory = resources.GPUMemory } if _, err := telemetry.ExecWithMetrics( h.logger, "queue.add_task", 20*time.Millisecond, func() (string, error) { return "", h.queue.AddTask(task) }, ); err != nil { h.logger.Error("failed to enqueue task", "error", err) return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error()) } h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name) } else { h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) } packetData, err := packet.Serialize() if err != nil { h.logger.Error("failed to serialize packet", "error", err) return h.sendErrorPacket( conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response", ) } return conn.WriteMessage(websocket.BinaryMessage, packetData) } // resourceRequest is 
an alias to helpers.ResourceRequest for backward compatibility type resourceRequest = helpers.ResourceRequest // parseOptionalResourceRequest is an alias to helpers.ParseResourceRequest for backward compatibility func parseOptionalResourceRequest(payload []byte) (*resourceRequest, error) { r, err := helpers.ParseResourceRequest(payload) if err != nil { return nil, err } // Type conversion is needed because Go doesn't automatically convert named types even with identical underlying structures if r == nil { return nil, nil } return (*resourceRequest)(r), nil } func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error { user, err := h.authenticate(conn, payload, ProtocolMinStatusRequest) if err != nil { return err } h.logger.Info("status request received", "api_key_hash", fmt.Sprintf("%x", payload[:16])) if err := h.requirePermission(user, PermJobsRead, conn); err != nil { return err } var tasks []*queue.Task if h.queue != nil { allTasks, err := h.queue.GetAllTasks() if err != nil { h.logger.Error("failed to get tasks", "error", err) return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error()) } for _, task := range allTasks { if h.authConfig == nil || !h.authConfig.Enabled || user.Admin { tasks = append(tasks, task) continue } if task.UserID == user.Name || task.CreatedBy == user.Name { tasks = append(tasks, task) } } } h.logger.Info("building status response") status := map[string]any{ "user": map[string]any{ "name": user.Name, "admin": user.Admin, "roles": user.Roles, }, "tasks": map[string]any{ "total": len(tasks), "queued": countTasksByStatus(tasks, "queued"), "running": countTasksByStatus(tasks, "running"), "failed": countTasksByStatus(tasks, "failed"), "completed": countTasksByStatus(tasks, "completed"), }, "queue": tasks, } if h.queue != nil { if states, err := h.queue.GetAllWorkerPrewarmStates(); err == nil { sort.Slice(states, func(i, j int) bool { if states[i].WorkerID != states[j].WorkerID { 
return states[i].WorkerID < states[j].WorkerID } return states[i].TaskID < states[j].TaskID }) status["prewarm"] = states } } h.logger.Info("serializing JSON response") jsonData, err := json.Marshal(status) if err != nil { h.logger.Error("failed to marshal JSON", "error", err) return h.sendErrorPacket( conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response", ) } h.logger.Info("sending websocket JSON response", "len", len(jsonData)) return h.sendResponsePacket(conn, NewDataPacket("status", jsonData)) } // countTasksByStatus counts tasks by their status func countTasksByStatus(tasks []*queue.Task, status string) int { count := 0 for _, task := range tasks { if task.Status == status { count++ } } return count } func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error { user, err := h.authenticate(conn, payload, ProtocolMinCancelJob) if err != nil { return err } if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil { return err } jobNameLen := int(payload[ProtocolAPIKeyHashLen]) if len(payload) < ProtocolAPIKeyHashLen+1+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") } jobName := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+jobNameLen]) h.logger.Info("cancel job request", "job", jobName) if h.queue == nil { h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName) return nil } task, err := h.queue.GetTaskByName(jobName) if err != nil { h.logger.Error("task not found", "job", jobName, "error", err) return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error()) } if h.authConfig != nil && h.authConfig.Enabled && !user.Admin && task.UserID != user.Name && task.CreatedBy != user.Name { h.logger.Error( "unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID, ) return h.sendErrorPacket( conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", 
"",
		)
	}
	if err := h.queue.CancelTask(task.ID); err != nil {
		h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
		return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
	}
	h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName)))
}

// handlePrune deletes old experiments.
// Wire format (after the API key hash): [prune_type:1][value:4 BE uint32].
// prune_type 0 = keep the newest <value> experiments; 1 = drop those older
// than <value> days. Requires PermJobsUpdate.
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinPrune)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
		return err
	}
	pruneType := payload[ProtocolAPIKeyHashLen]
	value := binary.BigEndian.Uint32(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+5])
	h.logger.Info("prune request", "type", pruneType, "value", value)
	// Exactly one of keepCount / olderThanDays is set; the other stays zero
	// and PruneExperiments presumably treats zero as "not specified" — TODO confirm.
	var keepCount int
	var olderThanDays int
	switch pruneType {
	case 0:
		keepCount = int(value)
	case 1:
		olderThanDays = int(value)
	default:
		return h.sendErrorPacket(
			conn,
			ErrorCodeInvalidRequest,
			fmt.Sprintf("invalid prune type: %d", pruneType),
			"",
		)
	}
	pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays)
	if err != nil {
		h.logger.Error("prune failed", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
	}
	if h.queue != nil {
		// Best-effort GC nudge; a failed signal must not fail the prune response.
		_ = h.queue.SignalPrewarmGC()
	}
	h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned)
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned))))
}

// handleLogMetric records one (name, value, step) metric sample against an
// experiment identified by its fixed-length commit ID. Requires PermJobsUpdate.
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinLogMetric)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsUpdate, conn); err != nil {
		return err
	}
	// Raw commit ID bytes follow the API key hash; hex-encoded before storage.
	commitID := payload[ProtocolAPIKeyHashLen : ProtocolAPIKeyHashLen+ProtocolCommitIDLen]
	step :=
int(binary.BigEndian.Uint32(payload[36:40])) valueBits := binary.BigEndian.Uint64(payload[40:48]) value := math.Float64frombits(valueBits) nameLen := int(payload[48]) if len(payload) < 49+nameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "") } name := string(payload[49 : 49+nameLen]) if err := h.expManager.LogMetric(fmt.Sprintf("%x", commitID), name, value, step); err != nil { h.logger.Error("failed to log metric", "error", err) return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error()) } return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged")) } func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error { user, err := h.authenticate(conn, payload, ProtocolMinGetExperiment) if err != nil { return err } if err := h.requirePermission(user, PermJobsRead, conn); err != nil { return err } commitID := payload[ProtocolAPIKeyHashLen : ProtocolAPIKeyHashLen+ProtocolCommitIDLen] meta, err := h.expManager.ReadMetadata(fmt.Sprintf("%x", commitID)) if err != nil { return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error()) } metrics, err := h.expManager.GetMetrics(fmt.Sprintf("%x", commitID)) if err != nil { return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error()) } var dbMeta *storage.ExperimentWithMetadata if h.db != nil { ctx, cancel := helpers.DBContextShort() defer cancel() m, err := h.db.GetExperimentWithMetadata(ctx, fmt.Sprintf("%x", commitID)) if err == nil { dbMeta = m } } response := map[string]interface{}{ "metadata": meta, "metrics": metrics, } if dbMeta != nil { response["reproducibility"] = dbMeta } responseData, err := json.Marshal(response) if err != nil { return h.sendErrorPacket( conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error(), ) } return h.sendResponsePacket(conn, NewDataPacket("experiment", responseData)) } // handleGetLogs handles requests to 
fetch logs for a task/run
func (h *WSHandler) handleGetLogs(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinGetLogs)
	if err != nil {
		return err
	}
	if permErr := h.requirePermission(user, PermJobsRead, conn); permErr != nil {
		return permErr
	}
	// Length-prefixed target ID follows the API key hash.
	idLen := int(payload[ProtocolAPIKeyHashLen])
	need := ProtocolMinGetLogs + idLen
	if len(payload) < need {
		detail := fmt.Sprintf("got %d, need %d", len(payload), need)
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid target ID length", detail)
	}
	start := ProtocolAPIKeyHashLen + 1
	targetID := string(payload[start : start+idLen])
	h.logger.Info("get logs request", "target_id", targetID, "user", user.Name)
	// TODO: Implement actual log fetching from storage
	// For now, return a stub response
	stub := map[string]interface{}{
		"target_id":   targetID,
		"logs":        "[Stub] Log content would appear here\nLine 1: Log output\nLine 2: More output\n",
		"truncated":   false,
		"total_lines": 3,
	}
	body, marshalErr := json.Marshal(stub)
	if marshalErr != nil {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", marshalErr.Error())
	}
	return h.sendResponsePacket(conn, NewDataPacket("logs", body))
}

// handleStreamLogs handles requests to stream logs in real-time
func (h *WSHandler) handleStreamLogs(conn *websocket.Conn, payload []byte) error {
	user, err := h.authenticate(conn, payload, ProtocolMinStreamLogs)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJobsRead, conn); err != nil {
		return err
	}
	targetIDLen := int(payload[ProtocolAPIKeyHashLen])
	if len(payload) < ProtocolMinStreamLogs+targetIDLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid target ID length", "")
	}
	targetID := string(payload[ProtocolAPIKeyHashLen+1 : ProtocolAPIKeyHashLen+1+targetIDLen])
	h.logger.Info("stream logs request", "target_id", targetID, "user", user.Name)
	// TODO: Implement actual log streaming
	// For now, return a stub response
indicating streaming started
	response := map[string]interface{}{
		"target_id": targetID,
		"streaming": true,
		"message":   "[Stub] Log streaming would start here. This feature is not yet fully implemented.",
	}
	responseData, err := json.Marshal(response)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error())
	}
	// Acknowledgement packet only — no actual log frames follow yet.
	return h.sendResponsePacket(conn, NewDataPacket("logs_stream", responseData))
}