diff --git a/internal/api/protocol.go b/internal/api/protocol.go index 464e794..fb6d7e6 100644 --- a/internal/api/protocol.go +++ b/internal/api/protocol.go @@ -308,7 +308,6 @@ func GetErrorMessage(code byte) string { return "Resource not found" case ErrorCodeResourceAlreadyExists: return "Resource already exists" - case ErrorCodeServerOverloaded: return "Server is overloaded" case ErrorCodeDatabaseError: @@ -319,7 +318,6 @@ func GetErrorMessage(code byte) string { return "Storage error occurred" case ErrorCodeTimeout: return "Operation timed out" - case ErrorCodeJobNotFound: return "Job not found" case ErrorCodeJobAlreadyRunning: @@ -330,7 +328,6 @@ func GetErrorMessage(code byte) string { return "Job execution failed" case ErrorCodeJobCancelled: return "Job was cancelled" - case ErrorCodeOutOfMemory: return "Server out of memory" case ErrorCodeDiskFull: @@ -339,7 +336,6 @@ func GetErrorMessage(code byte) string { return "Invalid server configuration" case ErrorCodeServiceUnavailable: return "Service temporarily unavailable" - default: return "Unknown error code" } diff --git a/internal/api/server.go b/internal/api/server.go index 8fbdb6b..001c47e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -159,6 +159,8 @@ func (s *Server) initTaskQueue() error { RedisPassword: s.config.Redis.Password, RedisDB: s.config.Redis.DB, SQLitePath: s.config.Queue.SQLitePath, + FilesystemPath: s.config.Queue.FilesystemPath, + FallbackToFilesystem: s.config.Queue.FallbackToFilesystem, MetricsFlushInterval: 0, } diff --git a/internal/api/server_config.go b/internal/api/server_config.go index 3879b83..d82a719 100644 --- a/internal/api/server_config.go +++ b/internal/api/server_config.go @@ -15,8 +15,10 @@ import ( ) type QueueConfig struct { - Backend string `yaml:"backend"` - SQLitePath string `yaml:"sqlite_path"` + Backend string `yaml:"backend"` + SQLitePath string `yaml:"sqlite_path"` + FilesystemPath string `yaml:"filesystem_path"` + FallbackToFilesystem bool `yaml:"fallback_to_filesystem"` } // ServerConfig holds all server configuration @@ -172,8 +174,8 @@ func (c *ServerConfig) Validate() error { backend = "redis" c.Queue.Backend = backend } - if backend != "redis" && backend != "sqlite" { - return fmt.Errorf("queue.backend must be one of 'redis' or 'sqlite'") + if backend != "redis" && backend != "sqlite" && backend != "filesystem" { + return fmt.Errorf("queue.backend must be one of 'redis', 'sqlite', or 'filesystem'") } if backend == "sqlite" { if strings.TrimSpace(c.Queue.SQLitePath) == "" { @@ -184,6 +186,15 @@ func (c *ServerConfig) Validate() error { c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath) } } + if backend == "filesystem" || c.Queue.FallbackToFilesystem { + if strings.TrimSpace(c.Queue.FilesystemPath) == "" { + c.Queue.FilesystemPath = filepath.Join(c.DataDir, "queue-fs") + } + c.Queue.FilesystemPath = config.ExpandPath(c.Queue.FilesystemPath) + if !filepath.IsAbs(c.Queue.FilesystemPath) { + c.Queue.FilesystemPath = filepath.Join(config.DefaultLocalDataDir, c.Queue.FilesystemPath) + } + } return nil } diff --git a/internal/api/ws_handler.go b/internal/api/ws_handler.go index 0df5d39..e2b5612 100644 --- a/internal/api/ws_handler.go +++ b/internal/api/ws_handler.go @@ -35,11 +35,16 @@ const ( OpcodeGetExperiment = 0x0B OpcodeQueueJobWithTracking = 0x0C OpcodeQueueJobWithSnapshot = 0x17 + OpcodeQueueJobWithArgs = 0x1A + OpcodeQueueJobWithNote = 0x1B + OpcodeAnnotateRun = 0x1C + OpcodeSetRunNarrative = 0x1D OpcodeStartJupyter = 0x0D OpcodeStopJupyter = 0x0E OpcodeRemoveJupyter = 0x18 OpcodeRestoreJupyter = 0x19 OpcodeListJupyter = 0x0F + OpcodeListJupyterPackages = 0x1E OpcodeValidateRequest = 0x16 ) @@ -243,6 +248,14 @@ func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error { return h.handleQueueJobWithTracking(conn, payload) case OpcodeQueueJobWithSnapshot: return h.handleQueueJobWithSnapshot(conn, payload) + case OpcodeQueueJobWithArgs: + return h.handleQueueJobWithArgs(conn, payload) + case OpcodeQueueJobWithNote: + return h.handleQueueJobWithNote(conn, payload) + case OpcodeAnnotateRun: + return h.handleAnnotateRun(conn, payload) + case OpcodeSetRunNarrative: + return h.handleSetRunNarrative(conn, payload) case OpcodeStatusRequest: return h.handleStatusRequest(conn, payload) case OpcodeCancelJob: @@ -271,6 +284,8 @@ func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error { return h.handleRestoreJupyter(conn, payload) case OpcodeListJupyter: return h.handleListJupyter(conn, payload) + case OpcodeListJupyterPackages: + return h.handleListJupyterPackages(conn, payload) case OpcodeValidateRequest: return h.handleValidateRequest(conn, payload) default: diff --git a/internal/api/ws_jobs.go b/internal/api/ws_jobs.go index defe033..140529f 100644 --- a/internal/api/ws_jobs.go +++ b/internal/api/ws_jobs.go @@ -18,14 +18,220 @@ import ( "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/config" + "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/fileutil" + "github.com/jfraeys/fetch_ml/internal/manifest" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/telemetry" "github.com/jfraeys/fetch_ml/internal/worker" ) +func (h *WSHandler) handleAnnotateRun(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][job_name_len:1][job_name:var][author_len:1][author:var][note_len:2][note:var] + if len(payload) < 16+1+1+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "annotate run payload too short", "") + } + + apiKeyHash := payload[:16] + offset := 16 + + jobNameLen := int(payload[offset]) + offset += 1 + if jobNameLen <= 0 || len(payload) < offset+jobNameLen+1+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") + } + jobName := string(payload[offset : offset+jobNameLen]) + offset += jobNameLen + + authorLen := int(payload[offset]) + offset += 1 + if authorLen < 0 || len(payload) < offset+authorLen+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid author length", "") + } + author := string(payload[offset : offset+authorLen]) + offset += authorLen + + noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if noteLen <= 0 || len(payload) < offset+noteLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid note length", "") + } + note := string(payload[offset : offset+noteLen]) + + // Validate API key and get user information + var user *auth.User + var err error + if h.authConfig != nil { + user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", "error", err) + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + } + } else { + user = &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + } + } + + // Permission model: if auth is enabled, require jobs:update to mutate shared run artifacts. + if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") { + h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update") + return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to annotate runs", "") + } + + if err := container.ValidateJobName(jobName); err != nil { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error()) + } + + base := strings.TrimSpace(h.expManager.BasePath()) + if base == "" { + return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "") + } + + jobPaths := config.NewJobPaths(base) + typedRoots := []struct { + root string + }{ + {root: jobPaths.RunningPath()}, + {root: jobPaths.PendingPath()}, + {root: jobPaths.FinishedPath()}, + {root: jobPaths.FailedPath()}, + } + + var manifestDir string + for _, item := range typedRoots { + dir := filepath.Join(item.root, jobName) + if info, err := os.Stat(dir); err == nil && info.IsDir() { + if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil { + manifestDir = dir + break + } + } + } + if manifestDir == "" { + return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "") + } + + rm, err := manifest.LoadFromDir(manifestDir) + if err != nil || rm == nil { + return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err)) + } + + // Default author to authenticated user if empty. + if strings.TrimSpace(author) == "" { + author = user.Name + } + rm.AddAnnotation(time.Now().UTC(), author, note) + if err := rm.WriteToDir(manifestDir); err != nil { + return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error()) + } + + return h.sendResponsePacket(conn, NewSuccessPacket("Annotation added")) +} + +func (h *WSHandler) handleSetRunNarrative(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][job_name_len:1][job_name:var][patch_json_len:2][patch_json:var] + if len(payload) < 16+1+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "set run narrative payload too short", "") + } + + apiKeyHash := payload[:16] + offset := 16 + + jobNameLen := int(payload[offset]) + offset += 1 + if jobNameLen <= 0 || len(payload) < offset+jobNameLen+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") + } + jobName := string(payload[offset : offset+jobNameLen]) + offset += jobNameLen + + patchLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if patchLen <= 0 || len(payload) < offset+patchLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch length", "") + } + patchJSON := payload[offset : offset+patchLen] + + var user *auth.User + var err error + if h.authConfig != nil { + user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", "error", err) + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + } + } else { + user = &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + } + } + + if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") { + h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update") + return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to update run narrative", "") + } + if err := container.ValidateJobName(jobName); err != nil { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name", err.Error()) + } + + base := strings.TrimSpace(h.expManager.BasePath()) + if base == "" { + return h.sendErrorPacket(conn, ErrorCodeInvalidConfiguration, "Missing api base_path", "") + } + + jobPaths := config.NewJobPaths(base) + typedRoots := []struct{ root string }{ + {root: jobPaths.RunningPath()}, + {root: jobPaths.PendingPath()}, + {root: jobPaths.FinishedPath()}, + {root: jobPaths.FailedPath()}, + } + + var manifestDir string + for _, item := range typedRoots { + dir := filepath.Join(item.root, jobName) + if info, err := os.Stat(dir); err == nil && info.IsDir() { + if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil { + manifestDir = dir + break + } + } + } + if manifestDir == "" { + return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "run manifest not found", "") + } + + rm, err := manifest.LoadFromDir(manifestDir) + if err != nil || rm == nil { + return h.sendErrorPacket(conn, ErrorCodeStorageError, "unable to read run manifest", fmt.Sprintf("%v", err)) + } + + var patch manifest.NarrativePatch + if err := json.Unmarshal(patchJSON, &patch); err != nil { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid narrative patch JSON", err.Error()) + } + rm.ApplyNarrativePatch(patch) + if err := rm.WriteToDir(manifestDir); err != nil { + return h.sendErrorPacket(conn, ErrorCodeStorageError, "failed to write run manifest", err.Error()) + } + + return h.sendResponsePacket(conn, NewSuccessPacket("Narrative updated")) +} + func fileSHA256Hex(path string) (string, error) { f, err := os.Open(filepath.Clean(path)) if err != nil { @@ -668,6 +874,294 @@ func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []b return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, trackingCfg, resources) } +type queueJobWithArgsPayload struct { + apiKeyHash []byte + commitID []byte + priority int64 + jobName string + args string + resources *resourceRequest +} + +type queueJobWithNotePayload struct { + apiKeyHash []byte + commitID []byte + priority int64 + jobName string + args string + note string + resources *resourceRequest +} + +func parseQueueJobWithNotePayload(payload []byte) (*queueJobWithNotePayload, error) { + // Protocol: + // [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] + // [args_len:2][args:var][note_len:2][note:var][resources?:var] + if len(payload) < 42 { + return nil, fmt.Errorf("queue job with note payload too short") + } + + apiKeyHash := payload[:16] + commitID := payload[16:36] + priority := int64(payload[36]) + jobNameLen := int(payload[37]) + if len(payload) < 38+jobNameLen+2 { + return nil, fmt.Errorf("invalid job name length") + } + jobName := string(payload[38 : 38+jobNameLen]) + + offset := 38 + jobNameLen + argsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if argsLen < 0 || len(payload) < offset+argsLen+2 { + return nil, fmt.Errorf("invalid args length") + } + args := "" + if argsLen > 0 { + args = string(payload[offset : offset+argsLen]) + } + offset += argsLen + + noteLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if noteLen < 0 || len(payload) < offset+noteLen { + return nil, fmt.Errorf("invalid note length") + } + note := "" + if noteLen > 0 { + note = string(payload[offset : offset+noteLen]) + } + offset += noteLen + + resources, resErr := parseOptionalResourceRequest(payload[offset:]) + if resErr != nil { + return nil, resErr + } + + return &queueJobWithNotePayload{ + apiKeyHash: apiKeyHash, + commitID: commitID, + priority: priority, + jobName: jobName, + args: args, + note: note, + resources: resources, + }, nil +} + +func parseQueueJobWithArgsPayload(payload []byte) (*queueJobWithArgsPayload, error) { + // Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var][args_len:2][args:var][resources?:var] + if len(payload) < 40 { + return nil, fmt.Errorf("queue job with args payload too short") + } + + apiKeyHash := payload[:16] + commitID := payload[16:36] + priority := int64(payload[36]) + jobNameLen := int(payload[37]) + if len(payload) < 38+jobNameLen+2 { + return nil, fmt.Errorf("invalid job name length") + } + jobName := string(payload[38 : 38+jobNameLen]) + + offset := 38 + jobNameLen + argsLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if argsLen < 0 || len(payload) < offset+argsLen { + return nil, fmt.Errorf("invalid args length") + } + args := "" + if argsLen > 0 { + args = string(payload[offset : offset+argsLen]) + } + offset += argsLen + + resources, resErr := parseOptionalResourceRequest(payload[offset:]) + if resErr != nil { + return nil, resErr + } + + return &queueJobWithArgsPayload{ + apiKeyHash: apiKeyHash, + commitID: commitID, + priority: priority, + jobName: jobName, + args: args, + resources: resources, + }, nil +} + +func (h *WSHandler) handleQueueJobWithArgs(conn *websocket.Conn, payload []byte) error { + p, err := parseQueueJobWithArgsPayload(payload) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with args payload", err.Error()) + } + + h.logger.Info("queue job request", + "job", p.jobName, + "priority", p.priority, + "commit_id", fmt.Sprintf("%x", p.commitID), + ) + + // Validate API key and get user information + var user *auth.User + if h.authConfig != nil { + user, err = h.authConfig.ValidateAPIKeyHash(p.apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", "error", err) + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + } + } else { + user = &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + } + } + + if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { + h.logger.Info( + "job queued", + "job", p.jobName, + "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", p.commitID)), + "user", user.Name, + ) + } else { + h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") + return h.sendErrorPacket( + conn, + ErrorCodePermissionDenied, + "Insufficient permissions to create jobs", + "", + ) + } + + commitIDStr := fmt.Sprintf("%x", p.commitID) + if _, err := telemetry.ExecWithMetrics( + h.logger, + "experiment.create", + 50*time.Millisecond, + func() (string, error) { + return "", h.expManager.CreateExperiment(commitIDStr) + }, + ); err != nil { + h.logger.Error("failed to create experiment directory", "error", err) + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error()) + } + + meta := &experiment.Metadata{ + CommitID: commitIDStr, + JobName: p.jobName, + User: user.Name, + Timestamp: time.Now().Unix(), + } + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { + return "", h.expManager.WriteMetadata(meta) + }); err != nil { + h.logger.Error("failed to save experiment metadata", "error", err) + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error()) + } + + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) { + return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr) + }); err != nil { + h.logger.Error("failed to ensure minimal experiment files", "error", err) + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to initialize experiment files", err.Error()) + } + + return h.enqueueTaskAndRespondWithArgs(conn, user, p.jobName, p.priority, p.commitID, p.args, nil, p.resources) +} + +func (h *WSHandler) handleQueueJobWithNote(conn *websocket.Conn, payload []byte) error { + p, err := parseQueueJobWithNotePayload(payload) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid queue job with note payload", err.Error()) + } + + h.logger.Info("queue job request", + "job", p.jobName, + "priority", p.priority, + "commit_id", fmt.Sprintf("%x", p.commitID), + ) + + var user *auth.User + if h.authConfig != nil { + user, err = h.authConfig.ValidateAPIKeyHash(p.apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", "error", err) + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + } + } else { + user = &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + } + } + + if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { + h.logger.Info( + "job queued", + "job", p.jobName, + "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", p.commitID)), + "user", user.Name, + ) + } else { + h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") + return h.sendErrorPacket( + conn, + ErrorCodePermissionDenied, + "Insufficient permissions to create jobs", + "", + ) + } + + commitIDStr := fmt.Sprintf("%x", p.commitID) + if _, err := telemetry.ExecWithMetrics( + h.logger, + "experiment.create", + 50*time.Millisecond, + func() (string, error) { + return "", h.expManager.CreateExperiment(commitIDStr) + }, + ); err != nil { + h.logger.Error("failed to create experiment directory", "error", err) + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error()) + } + + meta := &experiment.Metadata{ + CommitID: commitIDStr, + JobName: p.jobName, + User: user.Name, + Timestamp: time.Now().Unix(), + } + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { + return "", h.expManager.WriteMetadata(meta) + }); err != nil { + h.logger.Error("failed to save experiment metadata", "error", err) + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to save experiment metadata", err.Error()) + } + + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) { + return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr) + }); err != nil { + h.logger.Error("failed to ensure minimal experiment files", "error", err) + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to initialize experiment files", err.Error()) + } + + return h.enqueueTaskAndRespondWithArgsAndNote(conn, user, p.jobName, p.priority, p.commitID, p.args, p.note, nil, p.resources) +} + // enqueueTaskAndRespond enqueues a task and sends a success response. func (h *WSHandler) enqueueTaskAndRespond( conn *websocket.Conn, @@ -677,6 +1171,112 @@ func (h *WSHandler) enqueueTaskAndRespond( commitID []byte, tracking *queue.TrackingConfig, resources *resourceRequest, +) error { + return h.enqueueTaskAndRespondWithArgs(conn, user, jobName, priority, commitID, "", tracking, resources) +} + +func (h *WSHandler) enqueueTaskAndRespondWithArgsAndNote( + conn *websocket.Conn, + user *auth.User, + jobName string, + priority int64, + commitID []byte, + args string, + note string, + tracking *queue.TrackingConfig, + resources *resourceRequest, +) error { + packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) + + commitIDStr := fmt.Sprintf("%x", commitID) + prov, provErr := expectedProvenanceForCommit(h.expManager, commitIDStr) + if provErr != nil { + h.logger.Error("failed to compute expected provenance; refusing to enqueue", + "commit_id", commitIDStr, + "error", provErr) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to compute expected provenance", + provErr.Error(), + ) + } + + if h.queue != nil { + taskID := uuid.New().String() + task := &queue.Task{ + ID: taskID, + JobName: jobName, + Args: strings.TrimSpace(args), + Status: "queued", + Priority: priority, + CreatedAt: time.Now(), + UserID: user.Name, + Username: user.Name, + CreatedBy: user.Name, + Metadata: map[string]string{ + "commit_id": commitIDStr, + }, + Tracking: tracking, + } + if strings.TrimSpace(note) != "" { + task.Metadata["note"] = strings.TrimSpace(note) + } + for k, v := range prov { + if v != "" { + task.Metadata[k] = v + } + } + if resources != nil { + task.CPU = resources.CPU + task.MemoryGB = resources.MemoryGB + task.GPU = resources.GPU + task.GPUMemory = resources.GPUMemory + } + + if _, err := telemetry.ExecWithMetrics( + h.logger, + "queue.add_task", + 20*time.Millisecond, + func() (string, error) { + return "", h.queue.AddTask(task) + }, + ); err != nil { + h.logger.Error("failed to enqueue task", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeDatabaseError, + "Failed to enqueue task", + err.Error(), + ) + } + h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name) + } else { + h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) + } + + packetData, err := packet.Serialize() + if err != nil { + h.logger.Error("failed to serialize packet", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeServerOverloaded, + "Internal error", + "Failed to serialize response", + ) + } + return conn.WriteMessage(websocket.BinaryMessage, packetData) +} + +func (h *WSHandler) enqueueTaskAndRespondWithArgs( + conn *websocket.Conn, + user *auth.User, + jobName string, + priority int64, + commitID []byte, + args string, + tracking *queue.TrackingConfig, + resources *resourceRequest, ) error { packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) @@ -700,7 +1300,7 @@ func (h *WSHandler) enqueueTaskAndRespond( task := &queue.Task{ ID: taskID, JobName: jobName, - Args: "", + Args: strings.TrimSpace(args), Status: "queued", Priority: priority, CreatedAt: time.Now(), diff --git a/internal/api/ws_jupyter.go b/internal/api/ws_jupyter.go index 72795cd..7870513 100644 --- a/internal/api/ws_jupyter.go +++ b/internal/api/ws_jupyter.go @@ -13,10 +13,47 @@ import ( "github.com/jfraeys/fetch_ml/internal/queue" ) +func JupyterTaskErrorCode(t *queue.Task) byte { + if t == nil { + return ErrorCodeUnknownError + } + status := strings.ToLower(strings.TrimSpace(t.Status)) + errStr := strings.ToLower(strings.TrimSpace(t.Error)) + + if status == "cancelled" { + return ErrorCodeJobCancelled + } + if strings.Contains(errStr, "out of memory") || strings.Contains(errStr, "oom") { + return ErrorCodeOutOfMemory + } + if strings.Contains(errStr, "no space left") || strings.Contains(errStr, "disk full") { + return ErrorCodeDiskFull + } + if strings.Contains(errStr, "rate limit") || strings.Contains(errStr, "too many requests") || strings.Contains(errStr, "throttle") { + return ErrorCodeServiceUnavailable + } + if strings.Contains(errStr, "timed out") || strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline") { + return ErrorCodeTimeout + } + if strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "connection reset") || strings.Contains(errStr, "network unreachable") { + return ErrorCodeNetworkError + } + if strings.Contains(errStr, "queue") && strings.Contains(errStr, "not configured") { + return ErrorCodeInvalidConfiguration + } + + // Default for worker-side execution failures. + if status == "failed" { + return ErrorCodeJobExecutionFailed + } + return ErrorCodeUnknownError +} + type jupyterTaskOutput struct { Type string `json:"type"` Service json.RawMessage `json:"service,omitempty"` Services json.RawMessage `json:"services,omitempty"` + Packages json.RawMessage `json:"packages,omitempty"` RestorePath string `json:"restore_path,omitempty"` } @@ -76,7 +113,7 @@ func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) e return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") } if result.Status != "completed" { - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error)) + return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error)) } msg := fmt.Sprintf("Restored Jupyter workspace '%s'", strings.TrimSpace(name)) @@ -102,18 +139,98 @@ const ( jupyterTaskTypeKey = "task_type" jupyterTaskTypeValue = "jupyter" - jupyterTaskActionKey = "jupyter_action" - jupyterActionStart = "start" - jupyterActionStop = "stop" - jupyterActionRemove = "remove" - jupyterActionRestore = "restore" - jupyterActionList = "list" + jupyterTaskActionKey = "jupyter_action" + jupyterActionStart = "start" + jupyterActionStop = "stop" + jupyterActionRemove = "remove" + jupyterActionRestore = "restore" + jupyterActionList = "list" + jupyterActionListPkgs = "list_packages" jupyterNameKey = "jupyter_name" jupyterWorkspaceKey = "jupyter_workspace" jupyterServiceIDKey = "jupyter_service_id" ) +func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][name_len:1][name:var] + if len(payload) < 18 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter packages payload too short", "") + } + + apiKeyHash := payload[:16] + + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + user, err := h.validateWSUser(apiKeyHash) + if err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + if user != nil && !user.HasPermission("jupyter:read") { + return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") + } + + offset := 16 + nameLen := int(payload[offset]) + offset++ + if len(payload) < offset+nameLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "") + } + name := string(payload[offset : offset+nameLen]) + name = strings.TrimSpace(name) + if name == "" { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "missing jupyter name", "") + } + + meta := map[string]string{ + jupyterTaskActionKey: jupyterActionListPkgs, + jupyterNameKey: name, + } + jobName := fmt.Sprintf("jupyter-packages-%s", name) + taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) + if err != nil { + h.logger.Error("failed to enqueue jupyter packages list", "error", err) + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter packages list", "") + } + + result, err := h.waitForTask(taskID, 2*time.Minute) + if err != nil { + h.logger.Error("failed waiting for jupyter packages list", "error", err) + return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") + } + if result.Status != "completed" { + return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter packages", strings.TrimSpace(result.Error)) + } + + out := strings.TrimSpace(result.Output) + if out == "" { + return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]"))) + } + var payloadOut jupyterTaskOutput + if err := json.Unmarshal([]byte(out), &payloadOut); err == nil { + payload := payloadOut.Packages + if len(payload) == 0 { + payload = []byte("[]") + } + return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", payload)) + } + + return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", []byte("[]"))) +} + func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) { if h.queue == nil { return "", fmt.Errorf("task queue not configured") @@ -260,7 +377,7 @@ func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) err if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") { return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details) } - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to start Jupyter service", details) + return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to start Jupyter service", details) } msg := fmt.Sprintf("Started Jupyter service '%s'", strings.TrimSpace(name)) @@ -336,7 +453,7 @@ func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) erro return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") } if result.Status != "completed" { - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to stop Jupyter service", strings.TrimSpace(result.Error)) + return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to stop Jupyter service", strings.TrimSpace(result.Error)) } return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID))) } @@ -405,7 +522,7 @@ func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) er return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") } if result.Status != "completed" { - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to remove Jupyter service", strings.TrimSpace(result.Error)) + return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to remove Jupyter service", strings.TrimSpace(result.Error)) } return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID))) } @@ -456,7 +573,7 @@ func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) erro return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") } if result.Status != "completed" { - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to list Jupyter services", strings.TrimSpace(result.Error)) + return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter services", strings.TrimSpace(result.Error)) } out := strings.TrimSpace(result.Output) diff --git a/internal/jupyter/config.go b/internal/jupyter/config.go index ce06616..080cea0 100644 --- a/internal/jupyter/config.go +++ b/internal/jupyter/config.go @@ -12,7 +12,7 @@ import ( "github.com/jfraeys/fetch_ml/internal/logging" ) -var defaultBlockedPackages = []string{"requests", "urllib3", "httpx"} +var defaultBlockedPackages = []string{} func DefaultBlockedPackages() []string { return append([]string{}, defaultBlockedPackages...) diff --git a/internal/jupyter/service_manager.go b/internal/jupyter/service_manager.go index a409b0b..ac9b38d 100644 --- a/internal/jupyter/service_manager.go +++ b/internal/jupyter/service_manager.go @@ -203,6 +203,36 @@ type ServiceManager struct { services map[string]*JupyterService workspaceMetadataMgr *WorkspaceMetadataManager securityMgr *SecurityManager + startupBlockedPkgs []string +} + +func splitPackageList(value string) []string { + value = strings.TrimSpace(value) + if value == "" { + return nil + } + parts := strings.Split(value, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + out = append(out, p) + } + return out +} + +func startupBlockedPackages(installBlocked []string) []string { + val, ok := os.LookupEnv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") + if !ok { + return append([]string{}, installBlocked...) + } + val = strings.TrimSpace(val) + if val == "" || strings.EqualFold(val, "off") || strings.EqualFold(val, "none") || strings.EqualFold(val, "disabled") { + return []string{} + } + return splitPackageList(val) } // ServiceConfig holds configuration for Jupyter services @@ -270,6 +300,12 @@ type JupyterService struct { Metadata map[string]string `json:"metadata"` } +type InstalledPackage struct { + Name string `json:"name"` + Version string `json:"version"` + Source string `json:"source"` +} + // StartRequest defines parameters for starting a Jupyter service type StartRequest struct { Name string `json:"name"` @@ -316,6 +352,7 @@ func NewServiceManager(logger *logging.Logger, config *ServiceConfig) (*ServiceM } securityMgr := NewSecurityManager(logger, securityConfig) + startupBlockedPkgs := startupBlockedPackages(securityConfig.BlockedPackages) sm := &ServiceManager{ logger: logger, @@ -324,6 +361,7 @@ func NewServiceManager(logger *logging.Logger, config *ServiceConfig) (*ServiceM services: make(map[string]*JupyterService), workspaceMetadataMgr: workspaceMetadataMgr, securityMgr: securityMgr, + startupBlockedPkgs: startupBlockedPkgs, } // Load existing services @@ -421,6 +459,10 @@ func (sm *ServiceManager) StartService( // checkPackageBlacklist validates that no blacklisted packages are installed in the container func (sm *ServiceManager) checkPackageBlacklist(ctx context.Context, containerID string) error { + if len(sm.startupBlockedPkgs) == 0 { + return nil + } + // Get list of installed packages from the container // Try both pip and conda package managers packages, err := sm.getInstalledPackages(ctx, containerID) @@ -430,19 +472,14 @@ func (sm *ServiceManager) checkPackageBlacklist(ctx context.Context, containerID return nil } - // Check each installed package against the blacklist + // Check each installed package against the startup blacklist var blockedPackages []string for _, pkg := range packages { - // Create a package request for validation - pkgReq := &PackageRequest{ - PackageName: pkg, - RequestedBy: "system", - Channel: "", - Version: "", - } - - if err := sm.securityMgr.ValidatePackageRequest(pkgReq); err != nil { - blockedPackages = append(blockedPackages, pkg) + for _, blocked := range sm.startupBlockedPkgs { + if strings.EqualFold(blocked, pkg) { + blockedPackages = append(blockedPackages, pkg) + break + } } } @@ -450,7 +487,7 @@ func (sm *ServiceManager) checkPackageBlacklist(ctx context.Context, containerID if len(blockedPackages) > 0 { return fmt.Errorf("container startup failed: blacklisted packages detected: %v. "+ "These packages are blocked by security policy. "+ - "Remove them from the image or use FETCHML_JUPYTER_BLOCKED_PACKAGES to configure the blacklist", + "Remove them from the image or use FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES to configure the startup blacklist", blockedPackages) } @@ -508,6 +545,88 @@ func (sm *ServiceManager) parsePipList(output string) []string { return packages } +func (sm *ServiceManager) serviceByName(name string) *JupyterService { + name = strings.TrimSpace(name) + if name == "" { + return nil + } + for _, svc := range sm.services { + if svc == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(svc.Name), name) { + return svc + } + } + return nil +} + +func (sm *ServiceManager) listInstalledPackages(ctx context.Context, containerID string) ([]InstalledPackage, error) { + var pkgs []InstalledPackage + + // pip + pipJSON, err := sm.podman.ExecContainer(ctx, containerID, []string{"pip", "list", "--format=json"}) + if err == nil { + var parsed []struct { + Name string `json:"name"` + Version string `json:"version"` + } + if json.Unmarshal([]byte(pipJSON), &parsed) == nil { + for _, p := range parsed { + name := strings.TrimSpace(p.Name) + if name == "" { + continue + } + pkgs = append(pkgs, InstalledPackage{Name: name, Version: strings.TrimSpace(p.Version), Source: "pip"}) + } + } + } + + // conda + condaJSON, err := sm.podman.ExecContainer(ctx, containerID, []string{"conda", "list", "--json"}) + if err == nil { + var parsed []struct { + Name string `json:"name"` + Version string `json:"version"` + } + if json.Unmarshal([]byte(condaJSON), &parsed) == nil { + for _, p := range parsed { + name := strings.TrimSpace(p.Name) + if name == "" { + continue + } + pkgs = append(pkgs, InstalledPackage{Name: name, Version: strings.TrimSpace(p.Version), Source: "conda"}) + } + } + } + + seen := make(map[string]bool) + out := make([]InstalledPackage, 0, len(pkgs)) + for _, p := range pkgs { + key := strings.ToLower(strings.TrimSpace(p.Name)) + ":" + strings.ToLower(strings.TrimSpace(p.Source)) + if seen[key] { + continue + } + seen[key] = true + out = append(out, p) + } + return out, nil +} + +func (sm *ServiceManager) ListInstalledPackages(ctx context.Context, serviceName string) ([]InstalledPackage, error) { + if sm == nil { + return nil, fmt.Errorf("service manager is nil") + } + svc := sm.serviceByName(serviceName) + if svc == nil { + return nil, fmt.Errorf("service %s not found", strings.TrimSpace(serviceName)) + } + if strings.TrimSpace(svc.ContainerID) == "" { + return nil, fmt.Errorf("service container not available") + } + return sm.listInstalledPackages(ctx, svc.ContainerID) +} + // parseCondaList parses conda list --export output func (sm *ServiceManager) parseCondaList(output string) []string { var packages []string diff --git a/internal/jupyter/startup_blacklist_test.go b/internal/jupyter/startup_blacklist_test.go new file mode 100644 index 0000000..e3be4c4 --- /dev/null +++ b/internal/jupyter/startup_blacklist_test.go @@ -0,0 +1,57 @@ +package jupyter + +import ( + "os" + "testing" +) + +func TestStartupBlockedPackages_DefaultInheritsInstallBlocked(t *testing.T) { + oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES") + _, hadStartup := os.LookupEnv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") + oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") + + _ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3") + _ = os.Unsetenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") + defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall) + if hadStartup { + defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup) + } else { + defer os.Unsetenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") + } + + cfg := DefaultEnhancedSecurityConfigFromEnv() + startup := startupBlockedPackages(cfg.BlockedPackages) + if len(startup) != 2 { + t.Fatalf("expected startup list to inherit 2 items, got %d", len(startup)) + } +} + +func TestStartupBlockedPackages_Disabled(t *testing.T) { + oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES") + oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") + _ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3") + _ = os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", "off") + defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall) + defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup) + + cfg := DefaultEnhancedSecurityConfigFromEnv() + startup := startupBlockedPackages(cfg.BlockedPackages) + if len(startup) != 0 { + t.Fatalf("expected startup list to be disabled, got %d", len(startup)) + } +} + +func TestStartupBlockedPackages_ExplicitList(t *testing.T) { + oldInstall := os.Getenv("FETCHML_JUPYTER_BLOCKED_PACKAGES") + oldStartup := os.Getenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES") + _ = os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", "requests,urllib3") + _ = os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", "aiohttp") + defer os.Setenv("FETCHML_JUPYTER_BLOCKED_PACKAGES", oldInstall) + defer os.Setenv("FETCHML_JUPYTER_STARTUP_BLOCKED_PACKAGES", oldStartup) + + cfg := DefaultEnhancedSecurityConfigFromEnv() + startup := startupBlockedPackages(cfg.BlockedPackages) + if len(startup) != 1 || startup[0] != "aiohttp" { + t.Fatalf("expected explicit startup list [aiohttp], got %v", startup) + } +} diff --git a/internal/manifest/run_manifest.go b/internal/manifest/run_manifest.go new file mode 100644 index 0000000..25e9b28 --- /dev/null +++ b/internal/manifest/run_manifest.go @@ -0,0 +1,226 @@ +package manifest + +import ( + "encoding/json" + "fmt" + "path/filepath" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/fileutil" +) + +const runManifestFilename = "run_manifest.json" + +type Annotation struct { + Timestamp time.Time `json:"timestamp"` + Author string `json:"author,omitempty"` + Note string `json:"note"` +} + +func (a *Annotation) UnmarshalJSON(data []byte) error { + type annotationWire struct { + Timestamp *time.Time `json:"timestamp,omitempty"` + TS *time.Time `json:"ts,omitempty"` + Author string `json:"author,omitempty"` + Note string `json:"note"` + } + var w annotationWire + if err := json.Unmarshal(data, &w); err != nil { + return err + } + if w.Timestamp != nil { + a.Timestamp = *w.Timestamp + } else if w.TS != nil { + a.Timestamp = *w.TS + } + a.Author = w.Author + a.Note = w.Note + return nil +} + +type Narrative struct { + Hypothesis string `json:"hypothesis,omitempty"` + Context string `json:"context,omitempty"` + Intent string `json:"intent,omitempty"` + ExpectedOutcome string `json:"expected_outcome,omitempty"` + ParentRun string `json:"parent_run,omitempty"` + ExperimentGroup string `json:"experiment_group,omitempty"` + Tags []string `json:"tags,omitempty"` +} + +type NarrativePatch struct { + Hypothesis *string `json:"hypothesis,omitempty"` + Context *string `json:"context,omitempty"` + Intent *string `json:"intent,omitempty"` + ExpectedOutcome *string `json:"expected_outcome,omitempty"` + ParentRun *string `json:"parent_run,omitempty"` + ExperimentGroup *string `json:"experiment_group,omitempty"` + Tags *[]string `json:"tags,omitempty"` +} + +type ArtifactFile struct { + Path string `json:"path"` + SizeBytes int64 `json:"size_bytes"` + Modified time.Time `json:"modified"` +} + +type Artifacts struct { + DiscoveryTime time.Time `json:"discovery_time"` + Files []ArtifactFile `json:"files,omitempty"` + TotalSizeBytes int64 `json:"total_size_bytes,omitempty"` +} + +// RunManifest is a best-effort, self-contained provenance record for a run. +// It is written to /run_manifest.json. +type RunManifest struct { + RunID string `json:"run_id"` + TaskID string `json:"task_id"` + JobName string `json:"job_name"` + CreatedAt time.Time `json:"created_at"` + StartedAt time.Time `json:"started_at,omitempty"` + EndedAt time.Time `json:"ended_at,omitempty"` + + Annotations []Annotation `json:"annotations,omitempty"` + Narrative *Narrative `json:"narrative,omitempty"` + Artifacts *Artifacts `json:"artifacts,omitempty"` + + CommitID string `json:"commit_id,omitempty"` + ExperimentManifestSHA string `json:"experiment_manifest_sha,omitempty"` + DepsManifestName string `json:"deps_manifest_name,omitempty"` + DepsManifestSHA string `json:"deps_manifest_sha,omitempty"` + TrainScriptPath string `json:"train_script_path,omitempty"` + + WorkerVersion string `json:"worker_version,omitempty"` + PodmanImage string `json:"podman_image,omitempty"` + ImageDigest string `json:"image_digest,omitempty"` + + SnapshotID string `json:"snapshot_id,omitempty"` + SnapshotSHA256 string `json:"snapshot_sha256,omitempty"` + + Command string `json:"command,omitempty"` + Args string `json:"args,omitempty"` + ExitCode *int `json:"exit_code,omitempty"` + Error string `json:"error,omitempty"` + + StagingDurationMS int64 `json:"staging_duration_ms,omitempty"` + ExecutionDurationMS int64 `json:"execution_duration_ms,omitempty"` + FinalizeDurationMS int64 `json:"finalize_duration_ms,omitempty"` + TotalDurationMS int64 `json:"total_duration_ms,omitempty"` + + GPUDevices []string `json:"gpu_devices,omitempty"` + WorkerHost string `json:"worker_host,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +func NewRunManifest(runID, taskID, jobName string, createdAt time.Time) *RunManifest { + m := &RunManifest{ + RunID: runID, + TaskID: taskID, + JobName: jobName, + CreatedAt: createdAt, + Metadata: make(map[string]string), + } + return m +} + +func ManifestPath(dir string) string { + return filepath.Join(dir, runManifestFilename) +} + +func (m *RunManifest) WriteToDir(dir string) error { + if m == nil { + return fmt.Errorf("run manifest is nil") + } + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return fmt.Errorf("marshal run manifest: %w", err) + } + if err := fileutil.SecureFileWrite(ManifestPath(dir), data, 0640); err != nil { + return fmt.Errorf("write run manifest: %w", err) + } + return nil +} + +func LoadFromDir(dir string) (*RunManifest, error) { + data, err := fileutil.SecureFileRead(ManifestPath(dir)) + if err != nil { + return nil, fmt.Errorf("read run manifest: %w", err) + } + var m RunManifest + if err := json.Unmarshal(data, &m); err != nil { + return nil, fmt.Errorf("parse run manifest: %w", err) + } + return &m, nil +} + +func (m *RunManifest) MarkStarted(t time.Time) { + m.StartedAt = t +} + +func (m *RunManifest) MarkFinished(t time.Time, exitCode *int, execErr error) { + m.EndedAt = t + m.ExitCode = exitCode + if execErr != nil { + m.Error = execErr.Error() + } else { + m.Error = "" + } + if !m.StartedAt.IsZero() { + m.TotalDurationMS = m.EndedAt.Sub(m.StartedAt).Milliseconds() + } +} + +func (m *RunManifest) AddAnnotation(ts time.Time, author, note string) { + if m == nil { + return + } + n := strings.TrimSpace(note) + if n == "" { + return + } + a := Annotation{ + Timestamp: ts, + Author: strings.TrimSpace(author), + Note: n, + } + m.Annotations = append(m.Annotations, a) +} + +func (m *RunManifest) ApplyNarrativePatch(p NarrativePatch) { + if m == nil { + return + } + if m.Narrative == nil { + m.Narrative = &Narrative{} + } + if p.Hypothesis != nil { + m.Narrative.Hypothesis = strings.TrimSpace(*p.Hypothesis) + } + if p.Context != nil { + m.Narrative.Context = strings.TrimSpace(*p.Context) + } + if p.Intent != nil { + m.Narrative.Intent = strings.TrimSpace(*p.Intent) + } + if p.ExpectedOutcome != nil { + m.Narrative.ExpectedOutcome = strings.TrimSpace(*p.ExpectedOutcome) + } + if p.ParentRun != nil { + m.Narrative.ParentRun = strings.TrimSpace(*p.ParentRun) + } + if p.ExperimentGroup != nil { + m.Narrative.ExperimentGroup = strings.TrimSpace(*p.ExperimentGroup) + } + if p.Tags != nil { + clean := make([]string, 0, len(*p.Tags)) + for _, t := range *p.Tags { + t = strings.TrimSpace(t) + if t == "" { + continue + } + clean = append(clean, t) + } + m.Narrative.Tags = clean + } +} diff --git a/internal/queue/backend.go b/internal/queue/backend.go index 2b2f088..f7f9ab3 100644 --- a/internal/queue/backend.go +++ b/internal/queue/backend.go @@ -2,6 +2,8 @@ package queue import ( "errors" + "fmt" + "strings" "time" ) @@ -48,6 +50,7 @@ type QueueBackend string const ( QueueBackendRedis QueueBackend = "redis" QueueBackendSQLite QueueBackend = "sqlite" + QueueBackendFS QueueBackend = "filesystem" ) type BackendConfig struct { @@ -56,20 +59,49 @@ type BackendConfig struct { RedisPassword string RedisDB int SQLitePath string + FilesystemPath string + FallbackToFilesystem bool MetricsFlushInterval time.Duration } func NewBackend(cfg BackendConfig) (Backend, error) { + mkFallback := func(err error) (Backend, error) { + if !cfg.FallbackToFilesystem { + return nil, err + } + if strings.TrimSpace(cfg.FilesystemPath) == "" { + return nil, fmt.Errorf("filesystem queue path is required for fallback") + } + fsq, fsErr := NewFilesystemQueue(cfg.FilesystemPath) + if fsErr != nil { + return nil, fmt.Errorf("filesystem queue fallback init failed: %w", fsErr) + } + return fsq, nil + } + switch cfg.Backend { + case QueueBackendFS: + if strings.TrimSpace(cfg.FilesystemPath) == "" { + return nil, fmt.Errorf("filesystem queue path is required") + } + return NewFilesystemQueue(cfg.FilesystemPath) case "", QueueBackendRedis: - return NewTaskQueue(Config{ + b, err := NewTaskQueue(Config{ RedisAddr: cfg.RedisAddr, RedisPassword: cfg.RedisPassword, RedisDB: cfg.RedisDB, MetricsFlushInterval: cfg.MetricsFlushInterval, }) + if err != nil { + return mkFallback(err) + } + return b, nil case QueueBackendSQLite: - return NewSQLiteQueue(cfg.SQLitePath) + b, err := NewSQLiteQueue(cfg.SQLitePath) + if err != nil { + return mkFallback(err) + } + return b, nil default: return nil, ErrInvalidQueueBackend } diff --git a/internal/queue/filesystem_queue.go b/internal/queue/filesystem_queue.go new file mode 100644 index 0000000..6d2cfd0 --- /dev/null +++ b/internal/queue/filesystem_queue.go @@ -0,0 +1,572 @@ +package queue + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "time" +) + +type FilesystemQueue struct { + root string + ctx context.Context + cancel context.CancelFunc +} + +type filesystemQueueIndex struct { + Version int `json:"version"` + UpdatedAt string `json:"updated_at"` + Tasks []filesystemQueueIndexTask `json:"tasks"` +} + +type filesystemQueueIndexTask struct { + ID string `json:"id"` + Priority int64 `json:"priority"` + CreatedAt string `json:"created_at"` +} + +func NewFilesystemQueue(root string) (*FilesystemQueue, error) { + root = strings.TrimSpace(root) + if root == "" { + return nil, fmt.Errorf("filesystem queue root is required") + } + root = filepath.Clean(root) + if err := os.MkdirAll(filepath.Join(root, "pending", "entries"), 0750); err != nil { + return nil, err + } + for _, d := range []string{"running", "finished", "failed"} { + if err := os.MkdirAll(filepath.Join(root, d), 0750); err != nil { + return nil, err + } + } + + ctx, cancel := context.WithCancel(context.Background()) + q := &FilesystemQueue{root: root, ctx: ctx, cancel: cancel} + _ = q.rebuildIndex() + return q, nil +} + +func (q *FilesystemQueue) Close() error { + q.cancel() + return nil +} + +func (q *FilesystemQueue) AddTask(task *Task) error { + if task == nil { + return fmt.Errorf("task is nil") + } + if strings.TrimSpace(task.ID) == "" { + return fmt.Errorf("task id is required") + } + if strings.TrimSpace(task.JobName) == "" { + return fmt.Errorf("job name is required") + } + if task.MaxRetries == 0 { + task.MaxRetries = defaultMaxRetries + } + if task.CreatedAt.IsZero() { + task.CreatedAt = time.Now().UTC() + } + if strings.TrimSpace(task.Status) == "" { + task.Status = "queued" + } + if task.Status != "queued" { + // For filesystem backend we only enqueue queued tasks. + // Other status updates should go through UpdateTask. + task.Status = "queued" + } + + payload, err := json.Marshal(task) + if err != nil { + return err + } + + path := q.pendingEntryPath(task.ID) + if err := writeFileAtomic(path, payload, 0640); err != nil { + return err + } + TasksQueued.Inc() + if depth, derr := q.QueueDepth(); derr == nil { + UpdateQueueDepth(depth) + } + _ = q.rebuildIndex() + return nil +} + +func (q *FilesystemQueue) GetNextTask() (*Task, error) { + return q.claimNext("", 0, false) +} + +func (q *FilesystemQueue) PeekNextTask() (*Task, error) { + return q.claimNext("", 0, true) +} + +func (q *FilesystemQueue) GetNextTaskWithLease(workerID string, leaseDuration time.Duration) (*Task, error) { + return q.claimNext(workerID, leaseDuration, false) +} + +func (q *FilesystemQueue) GetNextTaskWithLeaseBlocking( + workerID string, + leaseDuration, blockTimeout time.Duration, +) (*Task, error) { + if blockTimeout <= 0 { + blockTimeout = defaultBlockTimeout + } + deadline := time.Now().Add(blockTimeout) + for { + t, err := q.claimNext(workerID, leaseDuration, false) + if err != nil { + return nil, err + } + if t != nil { + return t, nil + } + if time.Now().After(deadline) { + return nil, nil + } + select { + case <-q.ctx.Done(): + return nil, q.ctx.Err() + case <-time.After(50 * time.Millisecond): + } + } +} + +func (q *FilesystemQueue) RenewLease(taskID string, workerID string, leaseDuration time.Duration) error { + // Single-worker friendly best-effort: update task lease fields if present. + t, err := q.GetTask(taskID) + if err != nil { + return err + } + if t.LeasedBy != "" && workerID != "" && t.LeasedBy != workerID { + return fmt.Errorf("task leased by different worker: %s", t.LeasedBy) + } + if leaseDuration == 0 { + leaseDuration = defaultLeaseDuration + } + exp := time.Now().UTC().Add(leaseDuration) + t.LeaseExpiry = &exp + if workerID != "" { + t.LeasedBy = workerID + } + RecordLeaseRenewal(workerID) + return q.UpdateTask(t) +} + +func (q *FilesystemQueue) ReleaseLease(taskID string, workerID string) error { + t, err := q.GetTask(taskID) + if err != nil { + return err + } + if t.LeasedBy != "" && workerID != "" && t.LeasedBy != workerID { + return fmt.Errorf("task leased by different worker: %s", t.LeasedBy) + } + t.LeaseExpiry = nil + t.LeasedBy = "" + return q.UpdateTask(t) +} + +func (q *FilesystemQueue) RetryTask(task *Task) error { + if task.RetryCount >= task.MaxRetries { + RecordDLQAddition("max_retries") + return q.MoveToDeadLetterQueue(task, "max retries exceeded") + } + + errorCategory := ErrorUnknown + if task.Error != "" { + errorCategory = ClassifyError(fmt.Errorf("%s", task.Error)) + } + if !IsRetryable(errorCategory) { + RecordDLQAddition(string(errorCategory)) + return q.MoveToDeadLetterQueue(task, fmt.Sprintf("non-retryable error: %s", errorCategory)) + } + + task.RetryCount++ + task.Status = "queued" + task.LastError = task.Error + task.Error = "" + + backoffSeconds := RetryDelay(errorCategory, task.RetryCount) + nextRetry := time.Now().UTC().Add(time.Duration(backoffSeconds) * time.Second) + task.NextRetry = &nextRetry + task.LeaseExpiry = nil + task.LeasedBy = "" + + RecordTaskRetry(task.JobName, errorCategory) + return q.AddTask(task) +} + +func (q *FilesystemQueue) MoveToDeadLetterQueue(task *Task, reason string) error { + if task == nil { + return fmt.Errorf("task is nil") + } + task.Status = "failed" + task.Error = fmt.Sprintf("DLQ: %s. Last error: %s", reason, task.LastError) + RecordTaskFailure(task.JobName, ClassifyError(fmt.Errorf("%s", task.LastError))) + return q.UpdateTask(task) +} + +func (q *FilesystemQueue) GetTask(taskID string) (*Task, error) { + path, err := q.findTaskPath(taskID) + if err != nil { + return nil, err + } + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var t Task + if err := json.Unmarshal(data, &t); err != nil { + return nil, err + } + return &t, nil +} + +func (q *FilesystemQueue) GetAllTasks() ([]*Task, error) { + paths := make([]string, 0, 32) + for _, p := range []string{ + filepath.Join(q.root, "pending", "entries"), + filepath.Join(q.root, "running"), + filepath.Join(q.root, "finished"), + filepath.Join(q.root, "failed"), + } { + entries, err := os.ReadDir(p) + if err != nil { + continue + } + for _, e := range entries { + if e.IsDir() { + continue + } + if !strings.HasSuffix(e.Name(), ".json") { + continue + } + paths = append(paths, filepath.Join(p, e.Name())) + } + } + + out := make([]*Task, 0, len(paths)) + for _, path := range paths { + data, err := os.ReadFile(path) + if err != nil { + continue + } + var t Task + if err := json.Unmarshal(data, &t); err != nil { + continue + } + out = append(out, &t) + } + return out, nil +} + +func (q *FilesystemQueue) GetTaskByName(jobName string) (*Task, error) { + jobName = strings.TrimSpace(jobName) + if jobName == "" { + return nil, fmt.Errorf("job name is required") + } + tasks, err := q.GetAllTasks() + if err != nil { + return nil, err + } + var best *Task + for _, t := range tasks { + if t == nil || t.JobName != jobName { + continue + } + if best == nil || t.CreatedAt.After(best.CreatedAt) { + best = t + } + } + if best == nil { + return nil, os.ErrNotExist + } + return best, nil +} + +func (q *FilesystemQueue) CancelTask(taskID string) error { + t, err := q.GetTask(taskID) + if err != nil { + return err + } + t.Status = "cancelled" + now := time.Now().UTC() + t.EndedAt = &now + return q.UpdateTask(t) +} + +func (q *FilesystemQueue) UpdateTask(task *Task) error { + if task == nil { + return fmt.Errorf("task is nil") + } + if strings.TrimSpace(task.ID) == "" { + return fmt.Errorf("task id is required") + } + if strings.TrimSpace(task.Status) == "" { + return fmt.Errorf("task status is required") + } + + payload, err := json.Marshal(task) + if err != nil { + return err + } + + dst := q.pathForStatus(task.Status, task.ID) + if dst == "" { + // For statuses we don't map yet, keep it in running. + dst = filepath.Join(q.root, "running", task.ID+".json") + } + if err := os.MkdirAll(filepath.Dir(dst), 0750); err != nil { + return err + } + + // Best-effort: remove any other copies before writing. + _ = q.removeTaskFromAllDirs(task.ID) + if err := writeFileAtomic(dst, payload, 0640); err != nil { + return err + } + if depth, derr := q.QueueDepth(); derr == nil { + UpdateQueueDepth(depth) + } + _ = q.rebuildIndex() + return nil +} + +func (q *FilesystemQueue) UpdateTaskWithMetrics(task *Task, _ string) error { + return q.UpdateTask(task) +} + +func (q *FilesystemQueue) RecordMetric(_, _ string, _ float64) error { + return nil +} + +func (q *FilesystemQueue) Heartbeat(_ string) error { + return nil +} + +func (q *FilesystemQueue) QueueDepth() (int64, error) { + entries, err := os.ReadDir(filepath.Join(q.root, "pending", "entries")) + if err != nil { + return 0, err + } + var n int64 + for _, e := range entries { + if e.IsDir() { + continue + } + if strings.HasSuffix(e.Name(), ".json") { + n++ + } + } + return n, nil +} + +func (q *FilesystemQueue) SetWorkerPrewarmState(_ PrewarmState) error { return nil } +func (q *FilesystemQueue) ClearWorkerPrewarmState(_ string) error { return nil } +func (q *FilesystemQueue) GetWorkerPrewarmState(_ string) (*PrewarmState, error) { + return nil, nil +} +func (q *FilesystemQueue) GetAllWorkerPrewarmStates() ([]PrewarmState, error) { + return nil, nil +} +func (q *FilesystemQueue) SignalPrewarmGC() error { return nil } +func (q *FilesystemQueue) PrewarmGCRequestValue() (string, error) { + return "", nil +} + +func (q *FilesystemQueue) claimNext(workerID string, leaseDuration time.Duration, peek bool) (*Task, error) { + pendingDir := filepath.Join(q.root, "pending", "entries") + entries, err := os.ReadDir(pendingDir) + if err != nil { + return nil, err + } + + candidates := make([]*Task, 0, len(entries)) + paths := make(map[string]string, len(entries)) + for _, e := range entries { + if e.IsDir() { + continue + } + if !strings.HasSuffix(e.Name(), ".json") { + continue + } + path := filepath.Join(pendingDir, e.Name()) + data, err := os.ReadFile(path) + if err != nil { + continue + } + var t Task + if err := json.Unmarshal(data, &t); err != nil { + continue + } + if t.NextRetry != nil && time.Now().UTC().Before(t.NextRetry.UTC()) { + continue + } + candidates = append(candidates, &t) + paths[t.ID] = path + } + + if len(candidates) == 0 { + return nil, nil + } + + sort.Slice(candidates, func(i, j int) bool { + if candidates[i].Priority != candidates[j].Priority { + return candidates[i].Priority > candidates[j].Priority + } + return candidates[i].CreatedAt.Before(candidates[j].CreatedAt) + }) + + chosen := candidates[0] + if peek { + return chosen, nil + } + + src := paths[chosen.ID] + if src == "" { + return nil, nil + } + dst := filepath.Join(q.root, "running", chosen.ID+".json") + if err := os.Rename(src, dst); err != nil { + // Another process might have claimed it. + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, err + } + + // Refresh from the moved file to avoid race on content. + data, err := os.ReadFile(dst) + if err != nil { + return nil, err + } + var t Task + if err := json.Unmarshal(data, &t); err != nil { + return nil, err + } + + now := time.Now().UTC() + if leaseDuration == 0 { + leaseDuration = defaultLeaseDuration + } + exp := now.Add(leaseDuration) + t.LeaseExpiry = &exp + if strings.TrimSpace(workerID) != "" { + t.LeasedBy = workerID + } + // Note: status transitions are handled by worker UpdateTask calls. + + payload, err := json.Marshal(&t) + if err == nil { + _ = writeFileAtomic(dst, payload, 0640) + } + if depth, derr := q.QueueDepth(); derr == nil { + UpdateQueueDepth(depth) + } + _ = q.rebuildIndex() + return &t, nil +} + +func (q *FilesystemQueue) pendingEntryPath(taskID string) string { + return filepath.Join(q.root, "pending", "entries", taskID+".json") +} + +func (q *FilesystemQueue) pathForStatus(status, taskID string) string { + switch status { + case "queued": + return q.pendingEntryPath(taskID) + case "running": + return filepath.Join(q.root, "running", taskID+".json") + case "completed", "finished": + return filepath.Join(q.root, "finished", taskID+".json") + case "failed", "cancelled": + return filepath.Join(q.root, "failed", taskID+".json") + default: + return "" + } +} + +func (q *FilesystemQueue) findTaskPath(taskID string) (string, error) { + paths := []string{ + q.pendingEntryPath(taskID), + filepath.Join(q.root, "running", taskID+".json"), + filepath.Join(q.root, "finished", taskID+".json"), + filepath.Join(q.root, "failed", taskID+".json"), + } + for _, p := range paths { + if _, err := os.Stat(p); err == nil { + return p, nil + } + } + return "", os.ErrNotExist +} + +func (q *FilesystemQueue) removeTaskFromAllDirs(taskID string) error { + paths := []string{ + q.pendingEntryPath(taskID), + filepath.Join(q.root, "running", taskID+".json"), + filepath.Join(q.root, "finished", taskID+".json"), + filepath.Join(q.root, "failed", taskID+".json"), + } + var outErr error + for _, p := range paths { + if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) { + outErr = err + } + } + return outErr +} + +func (q *FilesystemQueue) rebuildIndex() error { + pendingDir := filepath.Join(q.root, "pending", "entries") + entries, err := os.ReadDir(pendingDir) + if err != nil { + return err + } + idx := filesystemQueueIndex{Version: 1, UpdatedAt: time.Now().UTC().Format(time.RFC3339)} + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".json") { + continue + } + data, err := os.ReadFile(filepath.Join(pendingDir, e.Name())) + if err != nil { + continue + } + var t Task + if err := json.Unmarshal(data, &t); err != nil { + continue + } + idx.Tasks = append(idx.Tasks, filesystemQueueIndexTask{ID: t.ID, Priority: t.Priority, CreatedAt: t.CreatedAt.UTC().Format(time.RFC3339Nano)}) + } + + sort.Slice(idx.Tasks, func(i, j int) bool { + if idx.Tasks[i].Priority != idx.Tasks[j].Priority { + return idx.Tasks[i].Priority > idx.Tasks[j].Priority + } + return idx.Tasks[i].CreatedAt < idx.Tasks[j].CreatedAt + }) + + payload, err := json.MarshalIndent(&idx, "", " ") + if err != nil { + return err + } + path := filepath.Join(q.root, "pending", ".queue.json") + return writeFileAtomic(path, payload, 0640) +} + +func writeFileAtomic(path string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0750); err != nil { + return err + } + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, perm); err != nil { + return err + } + return os.Rename(tmp, path) +} diff --git a/internal/worker/config.go b/internal/worker/config.go index cb3d7ed..6e1fda7 100644 --- a/internal/worker/config.go +++ b/internal/worker/config.go @@ -23,8 +23,10 @@ const ( ) type QueueConfig struct { - Backend string `yaml:"backend"` - SQLitePath string `yaml:"sqlite_path"` + Backend string `yaml:"backend"` + SQLitePath string `yaml:"sqlite_path"` + FilesystemPath string `yaml:"filesystem_path"` + FallbackToFilesystem bool `yaml:"fallback_to_filesystem"` } // Config holds worker configuration. @@ -203,6 +205,12 @@ func LoadConfig(path string) (*Config, error) { } cfg.Queue.SQLitePath = config.ExpandPath(cfg.Queue.SQLitePath) } + if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendFS)) || cfg.Queue.FallbackToFilesystem { + if strings.TrimSpace(cfg.Queue.FilesystemPath) == "" { + cfg.Queue.FilesystemPath = filepath.Join(cfg.DataDir, "queue-fs") + } + cfg.Queue.FilesystemPath = config.ExpandPath(cfg.Queue.FilesystemPath) + } if strings.TrimSpace(cfg.GPUVendor) == "" { if cfg.AppleGPU.Enabled { @@ -254,8 +262,8 @@ func (c *Config) Validate() error { backend = string(queue.QueueBackendRedis) c.Queue.Backend = backend } - if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) { - return fmt.Errorf("queue.backend must be one of %q or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite) + if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) && backend != string(queue.QueueBackendFS) { + return fmt.Errorf("queue.backend must be one of %q, %q, or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite, queue.QueueBackendFS) } if backend == string(queue.QueueBackendSQLite) { @@ -267,6 +275,15 @@ func (c *Config) Validate() error { c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath) } } + if backend == string(queue.QueueBackendFS) || c.Queue.FallbackToFilesystem { + if strings.TrimSpace(c.Queue.FilesystemPath) == "" { + return fmt.Errorf("queue.filesystem_path is required when filesystem queue is enabled") + } + c.Queue.FilesystemPath = config.ExpandPath(c.Queue.FilesystemPath) + if !filepath.IsAbs(c.Queue.FilesystemPath) { + c.Queue.FilesystemPath = filepath.Join(config.DefaultLocalDataDir, c.Queue.FilesystemPath) + } + } if c.RedisAddr != "" { addr := strings.TrimSpace(c.RedisAddr) diff --git a/internal/worker/core.go b/internal/worker/core.go index 1c467bb..03c804c 100644 --- a/internal/worker/core.go +++ b/internal/worker/core.go @@ -45,6 +45,7 @@ type JupyterManager interface { RemoveService(ctx context.Context, serviceID string, purge bool) error RestoreWorkspace(ctx context.Context, name string) (string, error) ListServices() []*jupyter.JupyterService + ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) } // isValidName validates that input strings contain only safe characters. @@ -382,6 +383,8 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) { RedisPassword: cfg.RedisPassword, RedisDB: cfg.RedisDB, SQLitePath: cfg.Queue.SQLitePath, + FilesystemPath: cfg.Queue.FilesystemPath, + FallbackToFilesystem: cfg.Queue.FallbackToFilesystem, MetricsFlushInterval: cfg.MetricsFlushInterval, } queueClient, err := queue.NewBackend(backendCfg) diff --git a/internal/worker/execution.go b/internal/worker/execution.go index 12ce4bd..d8c7f36 100644 --- a/internal/worker/execution.go +++ b/internal/worker/execution.go @@ -253,6 +253,11 @@ func (w *Worker) runJob(ctx context.Context, task *queue.Task, cudaVisibleDevice if err := w.stageExperimentFiles(task, jobDir); err != nil { w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) { + if a, aerr := scanArtifacts(jobDir); aerr == nil { + m.Artifacts = a + } else { + w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr) + } now := time.Now().UTC() exitCode := 1 m.MarkFinished(now, &exitCode, err) @@ -271,6 +276,11 @@ func (w *Worker) runJob(ctx context.Context, task *queue.Task, cudaVisibleDevice } if err := w.stageSnapshot(ctx, task, jobDir); err != nil { w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) { + if a, aerr := scanArtifacts(jobDir); aerr == nil { + m.Artifacts = a + } else { + w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr) + } now := time.Now().UTC() exitCode := 1 m.MarkFinished(now, &exitCode, err) @@ -586,6 +596,11 @@ func (w *Worker) executeJob( w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) { now := time.Now().UTC() m.ExecutionDurationMS = execDuration.Milliseconds() + if a, aerr := scanArtifacts(outputDir); aerr == nil { + m.Artifacts = a + } else { + w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr) + } if err != nil { exitCode := 1 m.MarkFinished(now, &exitCode, err) @@ -832,6 +847,35 @@ func (w *Worker) executeContainerJob( if trackingEnv == nil { trackingEnv = make(map[string]string) } + cacheRoot := filepath.Join(w.config.BasePath, ".cache") + if err := os.MkdirAll(cacheRoot, 0755); err != nil { + return &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "cache_setup", + Err: err, + } + } + if volumes == nil { + volumes = make(map[string]string) + } + volumes[cacheRoot] = "/workspace/.cache:rw" + defaultEnv := map[string]string{ + "HF_HOME": "/workspace/.cache/huggingface", + "TRANSFORMERS_CACHE": "/workspace/.cache/huggingface/hub", + "HF_DATASETS_CACHE": "/workspace/.cache/huggingface/datasets", + "TORCH_HOME": "/workspace/.cache/torch", + "TORCH_HUB_DIR": "/workspace/.cache/torch/hub", + "KERAS_HOME": "/workspace/.cache/keras", + "CUDA_CACHE_PATH": "/workspace/.cache/cuda", + "PIP_CACHE_DIR": "/workspace/.cache/pip", + } + for k, v := range defaultEnv { + if _, ok := trackingEnv[k]; ok { + continue + } + trackingEnv[k] = v + } if strings.TrimSpace(visibleEnvVar) != "" { trackingEnv[visibleEnvVar] = strings.TrimSpace(visibleDevices) } @@ -841,9 +885,6 @@ func (w *Worker) executeContainerJob( if strings.TrimSpace(task.SnapshotID) != "" { trackingEnv["FETCH_ML_SNAPSHOT_ID"] = strings.TrimSpace(task.SnapshotID) } - if volumes == nil { - volumes = make(map[string]string) - } volumes[snap] = "/snapshot:ro" } @@ -932,6 +973,11 @@ func (w *Worker) executeContainerJob( now := time.Now().UTC() exitCode := 1 m.ExecutionDurationMS = containerDuration.Milliseconds() + if a, aerr := scanArtifacts(outputDir); aerr == nil { + m.Artifacts = a + } else { + w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr) + } m.MarkFinished(now, &exitCode, err) }) // Move job to failed directory @@ -981,6 +1027,11 @@ func (w *Worker) executeContainerJob( now := time.Now().UTC() exitCode := 0 m.FinalizeDurationMS = time.Since(finalizeStart).Milliseconds() + if a, aerr := scanArtifacts(outputDir); aerr == nil { + m.Artifacts = a + } else { + w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr) + } m.MarkFinished(now, &exitCode, nil) }) if _, moveErr := telemetry.ExecWithMetrics( diff --git a/internal/worker/jupyter_task.go b/internal/worker/jupyter_task.go index dd9a0dc..4d19a25 100644 --- a/internal/worker/jupyter_task.go +++ b/internal/worker/jupyter_task.go @@ -21,6 +21,7 @@ const ( jupyterActionRemove = "remove" jupyterActionRestore = "restore" jupyterActionList = "list" + jupyterActionListPkgs = "list_packages" jupyterNameKey = "jupyter_name" jupyterWorkspaceKey = "jupyter_workspace" jupyterServiceIDKey = "jupyter_service_id" @@ -28,10 +29,11 @@ const ( ) type jupyterTaskOutput struct { - Type string `json:"type"` - Service *jupyter.JupyterService `json:"service,omitempty"` - Services []*jupyter.JupyterService `json:"services"` - RestorePath string `json:"restore_path,omitempty"` + Type string `json:"type"` + Service *jupyter.JupyterService `json:"service,omitempty"` + Services []*jupyter.JupyterService `json:"services"` + Packages []jupyter.InstalledPackage `json:"packages,omitempty"` + RestorePath string `json:"restore_path,omitempty"` } func isJupyterTask(task *queue.Task) bool { @@ -109,6 +111,17 @@ func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte, services := w.jupyter.ListServices() out := jupyterTaskOutput{Type: jupyterTaskOutputType, Services: services} return json.Marshal(out) + case jupyterActionListPkgs: + name := strings.TrimSpace(task.Metadata[jupyterNameKey]) + if name == "" { + return nil, fmt.Errorf("missing jupyter name") + } + pkgs, err := w.jupyter.ListInstalledPackages(ctx, name) + if err != nil { + return nil, err + } + out := jupyterTaskOutput{Type: jupyterTaskOutputType, Packages: pkgs} + return json.Marshal(out) case jupyterActionRestore: name := strings.TrimSpace(task.Metadata[jupyterNameKey]) if name == "" {