package api

import (
	"context"
	"crypto/sha256"
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/fileutil"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"
	"github.com/jfraeys/fetch_ml/internal/telemetry"
	"github.com/jfraeys/fetch_ml/internal/worker"
)

// fileSHA256Hex returns the hex-encoded SHA-256 digest of the file at path.
func fileSHA256Hex(path string) (string, error) {
	f, err := os.Open(filepath.Clean(path))
	if err != nil {
		return "", err
	}
	defer func() { _ = f.Close() }()
	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return "", err
	}
	return hex.EncodeToString(h.Sum(nil)), nil
}

// expectedProvenanceForCommit builds the provenance key/value pairs a worker
// is expected to reproduce for the given commit: the experiment manifest's
// overall SHA (mandatory) and, best-effort, the name and SHA-256 of the
// dependency manifest selected from the experiment's files directory.
func expectedProvenanceForCommit(
	expMgr *experiment.Manager,
	commitID string,
) (map[string]string, error) {
	out := map[string]string{}
	manifest, err := expMgr.ReadManifest(commitID)
	if err != nil {
		return nil, err
	}
	if manifest == nil || manifest.OverallSHA == "" {
		return nil, fmt.Errorf("missing manifest overall_sha")
	}
	out["experiment_manifest_overall_sha"] = manifest.OverallSHA
	filesPath := expMgr.GetFilesPath(commitID)
	// Dependency manifest info is optional: errors here are deliberately
	// swallowed so a missing/unreadable deps file does not block enqueueing.
	depName, err := worker.SelectDependencyManifest(filesPath)
	if err == nil && strings.TrimSpace(depName) != "" {
		depPath := filepath.Join(filesPath, depName)
		sha, err := fileSHA256Hex(depPath)
		if err == nil && strings.TrimSpace(sha) != "" {
			out["deps_manifest_name"] = depName
			out["deps_manifest_sha256"] = sha
		}
	}
	return out, nil
}

// writeIfMissing writes contents to path only when the file does not yet
// exist; any other stat outcome leaves the file untouched.
func writeIfMissing(path string, contents []byte) error {
	if _, err := os.Stat(path); os.IsNotExist(err) {
		return fileutil.SecureFileWrite(path, contents, 0640)
	}
	return nil
}

// ensureMinimalExperimentFiles seeds the experiment's files directory with
// placeholder train.py and requirements.txt files when they are absent, so
// that manifest generation always has something to hash.
func ensureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) error {
	if expMgr == nil {
		return fmt.Errorf("missing experiment manager")
	}
	commitID = strings.TrimSpace(commitID)
	if commitID == "" {
		return fmt.Errorf("missing commit id")
	}
	filesPath := expMgr.GetFilesPath(commitID)
	if err := os.MkdirAll(filesPath, 0750); err != nil {
		return err
	}
	if err := writeIfMissing(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n")); err != nil {
		return err
	}
	return writeIfMissing(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"))
}

// resolveJobUser maps the 16-byte API key hash to a user. When auth is not
// configured at all, a synthetic admin identity is returned so every request
// is accepted (matches the historical auth-disabled behavior).
func (h *WSHandler) resolveJobUser(apiKeyHash []byte) (*auth.User, error) {
	if h.authConfig != nil {
		return h.authConfig.ValidateAPIKeyHash(apiKeyHash)
	}
	return &auth.User{
		Name:  "default",
		Admin: true,
		Roles: []string{"admin"},
		Permissions: map[string]bool{
			"*": true,
		},
	}, nil
}

// userCanCreateJobs reports whether the user may queue jobs. Permission
// checks are skipped entirely when auth is missing or disabled.
func (h *WSHandler) userCanCreateJobs(user *auth.User) bool {
	return h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create")
}

// prepareExperiment creates the experiment directory, writes metadata,
// optionally seeds minimal placeholder files, and generates + persists the
// content-integrity manifest. On failure it returns the client-facing error
// message alongside the underlying error; the failure is already logged.
func (h *WSHandler) prepareExperiment(
	commitIDStr, jobName string,
	user *auth.User,
	seedMinimalFiles bool,
) (string, error) {
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.create", 50*time.Millisecond,
		func() (string, error) { return "", h.expManager.CreateExperiment(commitIDStr) },
	); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return "Failed to create experiment directory", err
	}
	meta := &experiment.Metadata{
		CommitID:  commitIDStr,
		JobName:   jobName,
		User:      user.Name,
		Timestamp: time.Now().Unix(),
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.write_metadata", 50*time.Millisecond,
		func() (string, error) { return "", h.expManager.WriteMetadata(meta) },
	); err != nil {
		h.logger.Error("failed to save experiment metadata", "error", err)
		return "Failed to save experiment metadata", err
	}
	if seedMinimalFiles {
		if _, err := telemetry.ExecWithMetrics(
			h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond,
			func() (string, error) { return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr) },
		); err != nil {
			h.logger.Error("failed to ensure minimal experiment files", "error", err)
			return "Failed to initialize experiment files", err
		}
	}
	if _, err := telemetry.ExecWithMetrics(
		h.logger, "experiment.generate_manifest", 100*time.Millisecond,
		func() (string, error) {
			manifest, err := h.expManager.GenerateManifest(commitIDStr)
			if err != nil {
				return "", fmt.Errorf("failed to generate manifest: %w", err)
			}
			if err := h.expManager.WriteManifest(manifest); err != nil {
				return "", fmt.Errorf("failed to write manifest: %w", err)
			}
			return "", nil
		},
	); err != nil {
		h.logger.Error("failed to generate/write manifest", "error", err)
		return "Failed to generate content integrity manifest", err
	}
	return "", nil
}

// upsertExperimentAsync records the experiment row in the database on a
// background goroutine (best-effort; failures are only logged) so the
// request path is not blocked by DB latency.
func (h *WSHandler) upsertExperimentAsync(commitIDStr, jobName, userName string) {
	go func() {
		if h.db == nil {
			return
		}
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		exp := &storage.Experiment{
			ID:     commitIDStr,
			Name:   jobName,
			Status: "pending",
			UserID: userName,
		}
		if _, err := telemetry.ExecWithMetrics(
			h.logger, "db.experiments.upsert", 50*time.Millisecond,
			func() (string, error) { return "", h.db.UpsertExperiment(ctx, exp) },
		); err != nil {
			h.logger.Error("failed to upsert experiment row", "error", err)
		}
	}()
}

// handleQueueJob queues a plain job.
// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 38 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
	}
	apiKeyHash := payload[:16]
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	if len(payload) < 38+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	// Any bytes after the job name form an optional resource request.
	resources, resErr := parseOptionalResourceRequest(payload[38+jobNameLen:])
	if resErr != nil {
		return h.sendErrorPacket(
			conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error(),
		)
	}
	commitIDStr := fmt.Sprintf("%x", commitID)
	h.logger.Info("queue job request",
		"job", jobName,
		"priority", priority,
		"commit_id", commitIDStr,
	)
	// Validate API key and get user information.
	user, err := h.resolveJobUser(apiKeyHash)
	if err != nil {
		h.logger.Error("invalid api key", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
	}
	// Check user permissions.
	if !h.userCanCreateJobs(user) {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(
			conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "",
		)
	}
	h.logger.Info(
		"job queued",
		"job", jobName,
		"path", h.expManager.GetExperimentPath(commitIDStr),
		"user", user.Name,
	)
	// Create experiment directory, metadata, minimal files, and manifest.
	if msg, err := h.prepareExperiment(commitIDStr, jobName, user, true); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, msg, err.Error())
	}
	// DB bookkeeping is deferred for performance.
	h.upsertExperimentAsync(commitIDStr, jobName, user.Name)
	h.logger.Info(
		"job queued",
		"job", jobName,
		"path", h.expManager.GetExperimentPath(commitIDStr),
		"user", user.Name,
	)
	return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, nil, resources)
}

// handleQueueJobWithSnapshot queues a job pinned to a data snapshot.
// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
// [snap_id_len:1][snap_id:var][snap_sha_len:1][snap_sha:var]
func (h *WSHandler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 40 {
		return h.sendErrorPacket(
			conn, ErrorCodeInvalidRequest, "queue job with snapshot payload too short", "",
		)
	}
	apiKeyHash := payload[:16]
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	if len(payload) < 38+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	offset := 38 + jobNameLen
	snapIDLen := int(payload[offset])
	offset++
	if snapIDLen < 1 || len(payload) < offset+snapIDLen+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot id length", "")
	}
	snapshotID := string(payload[offset : offset+snapIDLen])
	offset += snapIDLen
	snapSHALen := int(payload[offset])
	offset++
	if snapSHALen < 1 || len(payload) < offset+snapSHALen {
		return h.sendErrorPacket(conn,
			ErrorCodeInvalidRequest, "invalid snapshot sha length", "")
	}
	snapshotSHA := string(payload[offset : offset+snapSHALen])
	offset += snapSHALen
	resources, resErr := parseOptionalResourceRequest(payload[offset:])
	if resErr != nil {
		return h.sendErrorPacket(
			conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error(),
		)
	}
	commitIDStr := fmt.Sprintf("%x", commitID)
	h.logger.Info("queue job with snapshot request",
		"job", jobName,
		"priority", priority,
		"commit_id", commitIDStr,
		"snapshot_id", snapshotID,
	)
	user, err := h.resolveJobUser(apiKeyHash)
	if err != nil {
		h.logger.Error("invalid api key", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
	}
	if !h.userCanCreateJobs(user) {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(
			conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "",
		)
	}
	h.logger.Info(
		"job queued",
		"job", jobName,
		"path", h.expManager.GetExperimentPath(commitIDStr),
		"user", user.Name,
	)
	if msg, err := h.prepareExperiment(commitIDStr, jobName, user, true); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, msg, err.Error())
	}
	h.upsertExperimentAsync(commitIDStr, jobName, user.Name)
	h.logger.Info(
		"job queued",
		"job", jobName,
		"path", h.expManager.GetExperimentPath(commitIDStr),
		"user", user.Name,
	)
	return h.enqueueTaskAndRespondWithSnapshot(
		conn, user, jobName, priority, commitID, nil,
		resources, snapshotID, snapshotSHA,
	)
}

// handleQueueJobWithTracking queues a job with optional tracking configuration.
// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
// [tracking_json_len:2][tracking_json:var]
func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 38+2 { // minimum with zero-length tracking JSON
		return h.sendErrorPacket(
			conn, ErrorCodeInvalidRequest, "queue job with tracking payload too short", "",
		)
	}
	apiKeyHash := payload[:16]
	commitID := payload[16:36]
	priority := int64(payload[36])
	jobNameLen := int(payload[37])
	// Ensure we have the job name and two bytes for the tracking length.
	if len(payload) < 38+jobNameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[38 : 38+jobNameLen])
	offset := 38 + jobNameLen
	// trackingLen comes from an unsigned 16-bit field so it is never negative.
	trackingLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	if len(payload) < offset+trackingLen {
		return h.sendErrorPacket(
			conn, ErrorCodeInvalidRequest, "invalid tracking json length", "",
		)
	}
	var trackingCfg *queue.TrackingConfig
	if trackingLen > 0 {
		var cfg queue.TrackingConfig
		if err := json.Unmarshal(payload[offset:offset+trackingLen], &cfg); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json", err.Error())
		}
		trackingCfg = &cfg
	}
	offset += trackingLen
	resources, resErr := parseOptionalResourceRequest(payload[offset:])
	if resErr != nil {
		return h.sendErrorPacket(
			conn, ErrorCodeInvalidRequest, "invalid resource request", resErr.Error(),
		)
	}
	commitIDStr := fmt.Sprintf("%x", commitID)
	h.logger.Info("queue job with tracking request",
		"job", jobName,
		"priority", priority,
		"commit_id", commitIDStr,
	)
	user, err := h.resolveJobUser(apiKeyHash)
	if err != nil {
		h.logger.Error("invalid api key", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
	}
	if !h.userCanCreateJobs(user) {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(
			conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "",
		)
	}
	h.logger.Info(
		"job queued (with tracking)",
		"job", jobName,
		"path", h.expManager.GetExperimentPath(commitIDStr),
		"user", user.Name,
	)
	// NOTE(review): unlike the other queue handlers, this path historically
	// did not seed minimal experiment files before manifest generation; that
	// behavior is preserved here — confirm whether it is intentional.
	if msg, err := h.prepareExperiment(commitIDStr, jobName, user, false); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, msg, err.Error())
	}
	h.upsertExperimentAsync(commitIDStr, jobName, user.Name)
	return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, trackingCfg, resources)
}

// enqueueTaskAndRespond enqueues a task and sends a success response.
func (h *WSHandler) enqueueTaskAndRespond( conn *websocket.Conn, user *auth.User, jobName string, priority int64, commitID []byte, tracking *queue.TrackingConfig, resources *resourceRequest, ) error { packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) commitIDStr := fmt.Sprintf("%x", commitID) prov, provErr := expectedProvenanceForCommit(h.expManager, commitIDStr) if provErr != nil { h.logger.Error("failed to compute expected provenance; refusing to enqueue", "commit_id", commitIDStr, "error", provErr) return h.sendErrorPacket( conn, ErrorCodeStorageError, "Failed to compute expected provenance", provErr.Error(), ) } // Enqueue task if queue is available if h.queue != nil { taskID := uuid.New().String() task := &queue.Task{ ID: taskID, JobName: jobName, Args: "", Status: "queued", Priority: priority, CreatedAt: time.Now(), UserID: user.Name, Username: user.Name, CreatedBy: user.Name, Metadata: map[string]string{ "commit_id": commitIDStr, }, Tracking: tracking, } for k, v := range prov { if v != "" { task.Metadata[k] = v } } if resources != nil { task.CPU = resources.CPU task.MemoryGB = resources.MemoryGB task.GPU = resources.GPU task.GPUMemory = resources.GPUMemory } if _, err := telemetry.ExecWithMetrics( h.logger, "queue.add_task", 20*time.Millisecond, func() (string, error) { return "", h.queue.AddTask(task) }, ); err != nil { h.logger.Error("failed to enqueue task", "error", err) return h.sendErrorPacket( conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error(), ) } h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name) } else { h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) } packetData, err := packet.Serialize() if err != nil { h.logger.Error("failed to serialize packet", "error", err) return h.sendErrorPacket( conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response", ) } return conn.WriteMessage(websocket.BinaryMessage, packetData) } 
func (h *WSHandler) enqueueTaskAndRespondWithSnapshot( conn *websocket.Conn, user *auth.User, jobName string, priority int64, commitID []byte, tracking *queue.TrackingConfig, resources *resourceRequest, snapshotID string, snapshotSHA string, ) error { packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) commitIDStr := fmt.Sprintf("%x", commitID) prov, provErr := expectedProvenanceForCommit(h.expManager, commitIDStr) if provErr != nil { h.logger.Error("failed to compute expected provenance; refusing to enqueue", "commit_id", commitIDStr, "error", provErr) return h.sendErrorPacket( conn, ErrorCodeStorageError, "Failed to compute expected provenance", provErr.Error(), ) } if h.queue != nil { taskID := uuid.New().String() task := &queue.Task{ ID: taskID, JobName: jobName, Args: "", Status: "queued", Priority: priority, CreatedAt: time.Now(), UserID: user.Name, Username: user.Name, CreatedBy: user.Name, SnapshotID: strings.TrimSpace(snapshotID), Metadata: map[string]string{ "commit_id": commitIDStr, "snapshot_sha256": strings.TrimSpace(snapshotSHA), }, Tracking: tracking, } for k, v := range prov { if v != "" { task.Metadata[k] = v } } if resources != nil { task.CPU = resources.CPU task.MemoryGB = resources.MemoryGB task.GPU = resources.GPU task.GPUMemory = resources.GPUMemory } if _, err := telemetry.ExecWithMetrics( h.logger, "queue.add_task", 20*time.Millisecond, func() (string, error) { return "", h.queue.AddTask(task) }, ); err != nil { h.logger.Error("failed to enqueue task", "error", err) return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error()) } h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name) } else { h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) } packetData, err := packet.Serialize() if err != nil { h.logger.Error("failed to serialize packet", "error", err) return h.sendErrorPacket( conn, ErrorCodeServerOverloaded, "Internal 
error", "Failed to serialize response", ) } return conn.WriteMessage(websocket.BinaryMessage, packetData) } type resourceRequest struct { CPU int MemoryGB int GPU int GPUMemory string } // parseOptionalResourceRequest parses an optional tail encoding: // [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var] // If payload is empty, returns nil. func parseOptionalResourceRequest(payload []byte) (*resourceRequest, error) { if len(payload) == 0 { return nil, nil } if len(payload) < 4 { return nil, fmt.Errorf("resource payload too short") } cpu := int(payload[0]) mem := int(payload[1]) gpu := int(payload[2]) gpuMemLen := int(payload[3]) if gpuMemLen < 0 || len(payload) < 4+gpuMemLen { return nil, fmt.Errorf("invalid gpu memory length") } gpuMem := "" if gpuMemLen > 0 { gpuMem = string(payload[4 : 4+gpuMemLen]) } return &resourceRequest{CPU: cpu, MemoryGB: mem, GPU: gpu, GPUMemory: gpuMem}, nil } func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error { // Protocol: [api_key_hash:16] if len(payload) < 16 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "") } apiKeyHash := payload[:16] h.logger.Info("status request received", "api_key_hash", fmt.Sprintf("%x", apiKeyHash)) // Validate API key and get user information var user *auth.User var err error if h.authConfig != nil { user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) if err != nil { h.logger.Error("invalid api key", "error", err) return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) } } else { // Auth disabled - use default admin user user = &auth.User{ Name: "default", Admin: true, Roles: []string{"admin"}, Permissions: map[string]bool{ "*": true, }, } } // Check user permissions for viewing jobs if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") { h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read") return h.sendErrorPacket( conn, 
ErrorCodePermissionDenied, "Insufficient permissions to view jobs", "", ) } // Get tasks with user filtering var tasks []*queue.Task if h.queue != nil { allTasks, err := h.queue.GetAllTasks() if err != nil { h.logger.Error("failed to get tasks", "error", err) return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error()) } // Filter tasks based on user permissions for _, task := range allTasks { // If auth is disabled or admin can see all tasks if h.authConfig == nil || !h.authConfig.Enabled || user.Admin { tasks = append(tasks, task) continue } // Users can only see their own tasks if task.UserID == user.Name || task.CreatedBy == user.Name { tasks = append(tasks, task) } } } // Build status response as raw JSON for CLI compatibility h.logger.Info("building status response") status := map[string]any{ "user": map[string]any{ "name": user.Name, "admin": user.Admin, "roles": user.Roles, }, "tasks": map[string]any{ "total": len(tasks), "queued": countTasksByStatus(tasks, "queued"), "running": countTasksByStatus(tasks, "running"), "failed": countTasksByStatus(tasks, "failed"), "completed": countTasksByStatus(tasks, "completed"), }, "queue": tasks, } if h.queue != nil { if states, err := h.queue.GetAllWorkerPrewarmStates(); err == nil { sort.Slice(states, func(i, j int) bool { if states[i].WorkerID != states[j].WorkerID { return states[i].WorkerID < states[j].WorkerID } return states[i].TaskID < states[j].TaskID }) status["prewarm"] = states } } h.logger.Info("serializing JSON response") jsonData, err := json.Marshal(status) if err != nil { h.logger.Error("failed to marshal JSON", "error", err) return h.sendErrorPacket( conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response", ) } h.logger.Info("sending websocket JSON response", "len", len(jsonData)) // Send as binary protocol packet packet := NewDataPacket("status", jsonData) return h.sendResponsePacket(conn, packet) } // countTasksByStatus counts tasks by their 
status func countTasksByStatus(tasks []*queue.Task, status string) int { count := 0 for _, task := range tasks { if task.Status == status { count++ } } return count } func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error { // Protocol: [api_key_hash:16][job_name_len:1][job_name:var] if len(payload) < 18 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "") } // Parse 16-byte binary API key hash apiKeyHash := payload[:16] jobNameLen := int(payload[16]) if len(payload) < 17+jobNameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") } jobName := string(payload[17 : 17+jobNameLen]) h.logger.Info("cancel job request", "job", jobName) // Validate API key and get user information var user *auth.User var err error if h.authConfig != nil { user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) if err != nil { h.logger.Error("invalid api key", "error", err) return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) } } else { // Auth disabled - use default admin user user = &auth.User{ Name: "default", Admin: true, Roles: []string{"admin"}, Permissions: map[string]bool{ "*": true, }, } } // Check user permissions for canceling jobs if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") { h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update") return h.sendErrorPacket( conn, ErrorCodePermissionDenied, "Insufficient permissions to cancel jobs", "", ) } // Find the task and verify ownership if h.queue != nil { task, err := h.queue.GetTaskByName(jobName) if err != nil { h.logger.Error("task not found", "job", jobName, "error", err) return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error()) } // Check if user can cancel this task (admin or owner) if h.authConfig != nil && h.authConfig.Enabled && !user.Admin && task.UserID != user.Name && task.CreatedBy 
!= user.Name { h.logger.Error( "unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID, ) return h.sendErrorPacket( conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", "", ) } // Cancel the task if err := h.queue.CancelTask(task.ID); err != nil { h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err) return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error()) } h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name) } else { h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName) } packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName)) packetData, err := packet.Serialize() if err != nil { h.logger.Error("failed to serialize packet", "error", err) return h.sendErrorPacket( conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response", ) } return conn.WriteMessage(websocket.BinaryMessage, packetData) } func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error { // Protocol: [api_key_hash:16][prune_type:1][value:4] if len(payload) < 21 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "") } // Parse 16-byte binary API key hash apiKeyHash := payload[:16] pruneType := payload[16] value := binary.BigEndian.Uint32(payload[17:21]) h.logger.Info("prune request", "type", pruneType, "value", value) // Verify API key if h.authConfig != nil && h.authConfig.Enabled { if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { h.logger.Error("api key verification failed", "error", err) return h.sendErrorPacket( conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error(), ) } } // Convert prune parameters var keepCount int var olderThanDays int switch pruneType { case 0: // keep N keepCount = int(value) olderThanDays = 0 case 1: // older than days keepCount = 0 olderThanDays = 
int(value) default: return h.sendErrorPacket( conn, ErrorCodeInvalidRequest, fmt.Sprintf("invalid prune type: %d", pruneType), "", ) } // Perform pruning pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays) if err != nil { h.logger.Error("prune failed", "error", err) return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error()) } if h.queue != nil { _ = h.queue.SignalPrewarmGC() } h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned) // Send structured success response packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned))) return h.sendResponsePacket(conn, packet) } func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error { // Protocol: [api_key_hash:16][commit_id:20][step:4][value:8][name_len:1][name:var] if len(payload) < 51 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "") } apiKeyHash := payload[:16] commitID := payload[16:36] step := int(binary.BigEndian.Uint32(payload[36:40])) valueBits := binary.BigEndian.Uint64(payload[40:48]) value := math.Float64frombits(valueBits) nameLen := int(payload[48]) if len(payload) < 49+nameLen { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "") } name := string(payload[49 : 49+nameLen]) // Verify API key if h.authConfig != nil && h.authConfig.Enabled { if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { h.logger.Error("api key verification failed", "error", err) return h.sendErrorPacket( conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error(), ) } } if err := h.expManager.LogMetric(fmt.Sprintf("%x", commitID), name, value, step); err != nil { h.logger.Error("failed to log metric", "error", err) return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error()) } return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged")) } func (h *WSHandler) handleGetExperiment(conn 
*websocket.Conn, payload []byte) error { // Protocol: [api_key_hash:16][commit_id:20] if len(payload) < 36 { return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "") } apiKeyHash := payload[:16] commitID := payload[16:36] // Verify API key if h.authConfig != nil && h.authConfig.Enabled { if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { return h.sendErrorPacket( conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error(), ) } } meta, err := h.expManager.ReadMetadata(fmt.Sprintf("%x", commitID)) if err != nil { return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error()) } metrics, err := h.expManager.GetMetrics(fmt.Sprintf("%x", commitID)) if err != nil { return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error()) } var dbMeta *storage.ExperimentWithMetadata if h.db != nil { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() m, err := h.db.GetExperimentWithMetadata(ctx, fmt.Sprintf("%x", commitID)) if err == nil { dbMeta = m } } response := map[string]interface{}{ "metadata": meta, "metrics": metrics, } if dbMeta != nil { response["reproducibility"] = dbMeta } responseData, err := json.Marshal(response) if err != nil { return h.sendErrorPacket( conn, ErrorCodeServerOverloaded, "Failed to serialize response", err.Error(), ) } return h.sendResponsePacket(conn, NewDataPacket("experiment", responseData)) }