// Package worker provides the ML task worker implementation
package worker

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"time"

	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/metrics"
	"github.com/jfraeys/fetch_ml/internal/network"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/resources"
	"github.com/jfraeys/fetch_ml/internal/worker/execution"
	"github.com/jfraeys/fetch_ml/internal/worker/executor"
	"github.com/jfraeys/fetch_ml/internal/worker/integrity"
	"github.com/jfraeys/fetch_ml/internal/worker/interfaces"
	"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
)

// JupyterManager interface for jupyter service management
type JupyterManager interface {
	// StartService launches a jupyter service described by req.
	StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
	// StopService stops the service identified by serviceID.
	StopService(ctx context.Context, serviceID string) error
	// RemoveService removes the service identified by serviceID.
	// NOTE(review): purge presumably also deletes persisted workspace data —
	// confirm against the concrete implementation.
	RemoveService(ctx context.Context, serviceID string, purge bool) error
	// RestoreWorkspace restores a workspace by name and returns a service ID.
	RestoreWorkspace(ctx context.Context, name string) (string, error)
	// ListServices returns the currently known jupyter services.
	ListServices() []*jupyter.JupyterService
	// ListInstalledPackages lists packages installed in the named service.
	ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
}

// MLServer is an alias for network.MLServer for backward compatibility.
type MLServer = network.MLServer

// NewMLServer creates a new ML server connection.
// NOTE(review): cfg is accepted but ignored — the connection is always created
// with empty/zero arguments. Confirm whether cfg fields should be forwarded.
func NewMLServer(cfg *Config) (*MLServer, error) {
	return network.NewMLServer("", "", "", 0, "")
}

// Worker represents an ML task worker with composed dependencies.
type Worker struct { ID string Config *Config Logger *logging.Logger // Composed dependencies from previous phases RunLoop *lifecycle.RunLoop Runner *executor.JobRunner Metrics *metrics.Metrics metricsSrv *http.Server Health *lifecycle.HealthMonitor Resources *resources.Manager // GPU detection metadata for status output gpuDetectionInfo GPUDetectionInfo // Legacy fields for backward compatibility during migration Jupyter JupyterManager QueueClient queue.Backend // Stored for prewarming access } // Start begins the worker's main processing loop. func (w *Worker) Start() { w.Logger.Info("worker starting", "worker_id", w.ID, "max_concurrent", w.Config.MaxWorkers) w.Health.RecordHeartbeat() w.RunLoop.Start() } // Stop gracefully shuts down the worker immediately. func (w *Worker) Stop() { w.Logger.Info("worker stopping", "worker_id", w.ID) w.RunLoop.Stop() if w.metricsSrv != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := w.metricsSrv.Shutdown(ctx); err != nil { w.Logger.Warn("metrics server shutdown error", "error", err) } } w.Logger.Info("worker stopped", "worker_id", w.ID) } // Shutdown performs a graceful shutdown with timeout. func (w *Worker) Shutdown() error { w.Logger.Info("starting graceful shutdown", "worker_id", w.ID) w.RunLoop.Stop() if w.metricsSrv != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := w.metricsSrv.Shutdown(ctx); err != nil { w.Logger.Warn("metrics server shutdown error", "error", err) } } w.Logger.Info("worker shut down gracefully", "worker_id", w.ID) return nil } // IsHealthy returns true if the worker is healthy. func (w *Worker) IsHealthy() bool { return w.Health.IsHealthy(5 * time.Minute) } // GetMetrics returns current worker metrics. 
func (w *Worker) GetMetrics() map[string]any { stats := w.Metrics.GetStats() stats["worker_id"] = w.ID stats["max_workers"] = w.Config.MaxWorkers stats["healthy"] = w.IsHealthy() return stats } // GetID returns the worker ID. func (w *Worker) GetID() string { return w.ID } // SelectDependencyManifest re-exports the executor function for API helpers. // It detects the dependency manifest file in the given directory. func SelectDependencyManifest(filesPath string) (string, error) { return executor.SelectDependencyManifest(filesPath) } // DirOverallSHA256Hex re-exports the integrity function for test compatibility. func DirOverallSHA256Hex(root string) (string, error) { return integrity.DirOverallSHA256Hex(root) } // NormalizeSHA256ChecksumHex re-exports the integrity function for test compatibility. func NormalizeSHA256ChecksumHex(checksum string) (string, error) { return integrity.NormalizeSHA256ChecksumHex(checksum) } // StageSnapshot re-exports the execution function for test compatibility. func StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir string) error { return execution.StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir) } // StageSnapshotFromPath re-exports the execution function for test compatibility. func StageSnapshotFromPath(basePath, taskID, srcPath, jobDir string) error { return execution.StageSnapshotFromPath(basePath, taskID, srcPath, jobDir) } // ComputeTaskProvenance computes provenance information for a task. // This re-exports the integrity function for test compatibility. func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) { pc := integrity.NewProvenanceCalculator(basePath) return pc.ComputeProvenance(task) } // VerifyDatasetSpecs verifies dataset specifications for this task. // This is a test compatibility method that wraps the integrity package. 
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
	// ctx is currently unused; kept for interface stability.
	dataDir := w.Config.DataDir
	if dataDir == "" {
		dataDir = "/tmp/data" // default when no data dir is configured
	}
	verifier := integrity.NewDatasetVerifier(dataDir)
	return verifier.VerifyDatasetSpecs(task)
}

// EnforceTaskProvenance enforces provenance requirements for a task.
// It validates and/or populates provenance metadata based on the ProvenanceBestEffort config.
// In strict mode (ProvenanceBestEffort=false), it returns an error if metadata doesn't match computed values.
// In best-effort mode (ProvenanceBestEffort=true), it populates missing metadata fields.
// May mutate task.Metadata (populating manifest/deps/snapshot SHA fields).
// ctx is currently unused; kept for interface stability.
func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	// Resolve base/data directories, falling back to defaults.
	basePath := w.Config.BasePath
	if basePath == "" {
		basePath = "/tmp"
	}
	dataDir := w.Config.DataDir
	if dataDir == "" {
		dataDir = filepath.Join(basePath, "data")
	}
	bestEffort := w.Config.ProvenanceBestEffort
	// Get commit_id from metadata (reading from a nil map is safe in Go).
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return fmt.Errorf("missing commit_id in task metadata")
	}
	// Compute and verify experiment manifest SHA over <basePath>/<commitID>.
	expPath := filepath.Join(basePath, commitID)
	manifestSHA, err := integrity.DirOverallSHA256Hex(expPath)
	if err != nil {
		if !bestEffort {
			return fmt.Errorf("failed to compute experiment manifest SHA: %w", err)
		}
		// In best-effort mode, we'll use whatever is provided or skip
		manifestSHA = ""
	}
	// Handle experiment_manifest_overall_sha.
	// NOTE(review): this value is compared raw, while snapshot_sha256 below is
	// run through NormalizeSHA256ChecksumHex first — confirm the asymmetry is intended.
	expectedManifestSHA := task.Metadata["experiment_manifest_overall_sha"]
	if expectedManifestSHA == "" {
		if !bestEffort {
			return fmt.Errorf("missing experiment_manifest_overall_sha in task metadata")
		}
		// Populate in best-effort mode
		if task.Metadata == nil {
			task.Metadata = map[string]string{}
		}
		task.Metadata["experiment_manifest_overall_sha"] = manifestSHA
	} else if !bestEffort && expectedManifestSHA != manifestSHA {
		return fmt.Errorf("experiment manifest SHA mismatch: expected %s, got %s",
			expectedManifestSHA, manifestSHA)
	}
	// Handle deps_manifest_sha256 - auto-detect if not provided
	filesPath := filepath.Join(expPath, "files")
	depsManifestName := task.Metadata["deps_manifest_name"]
	if depsManifestName == "" {
		// Auto-detect manifest file; detection failure deliberately ignored —
		// an empty name skips the whole deps check below, even in strict mode.
		depsManifestName, _ = executor.SelectDependencyManifest(filesPath)
	}
	if depsManifestName != "" {
		if task.Metadata == nil {
			task.Metadata = map[string]string{}
		}
		// Record the (possibly auto-detected) manifest name back on the task.
		task.Metadata["deps_manifest_name"] = depsManifestName
		depsPath := filepath.Join(filesPath, depsManifestName)
		depsSHA, err := integrity.FileSHA256Hex(depsPath)
		if err != nil {
			if !bestEffort {
				return fmt.Errorf("failed to compute deps manifest SHA: %w", err)
			}
			depsSHA = ""
		}
		expectedDepsSHA := task.Metadata["deps_manifest_sha256"]
		if expectedDepsSHA == "" {
			if !bestEffort {
				return fmt.Errorf("missing deps_manifest_sha256 in task metadata")
			}
			task.Metadata["deps_manifest_sha256"] = depsSHA
		} else if !bestEffort && expectedDepsSHA != depsSHA {
			return fmt.Errorf("deps manifest SHA mismatch: expected %s, got %s", expectedDepsSHA, depsSHA)
		}
	}
	// Handle snapshot_sha256 if SnapshotID is set
	if task.SnapshotID != "" {
		snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
		snapSHA, err := integrity.DirOverallSHA256Hex(snapPath)
		if err != nil {
			if !bestEffort {
				return fmt.Errorf("failed to compute snapshot SHA: %w", err)
			}
			snapSHA = ""
		}
		// Normalization error deliberately ignored: an unparsable checksum is
		// treated the same as a missing one.
		expectedSnapSHA, _ := integrity.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
		if expectedSnapSHA == "" {
			if !bestEffort {
				return fmt.Errorf("missing snapshot_sha256 in task metadata")
			}
			if task.Metadata == nil {
				task.Metadata = map[string]string{}
			}
			task.Metadata["snapshot_sha256"] = snapSHA
		} else if !bestEffort && expectedSnapSHA != snapSHA {
			return fmt.Errorf("snapshot SHA mismatch: expected %s, got %s", expectedSnapSHA, snapSHA)
		}
	}
	return nil
}

// VerifySnapshot verifies snapshot integrity for this task.
// It computes the SHA256 of the snapshot directory and compares with task metadata.
func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error { if task.SnapshotID == "" { return nil // No snapshot to verify } dataDir := w.Config.DataDir if dataDir == "" { dataDir = "/tmp/data" } // Get expected checksum from metadata expectedChecksum, ok := task.Metadata["snapshot_sha256"] if !ok || expectedChecksum == "" { return fmt.Errorf("missing snapshot_sha256 in task metadata") } // Normalize the checksum (remove sha256: prefix if present) expectedChecksum, err := integrity.NormalizeSHA256ChecksumHex(expectedChecksum) if err != nil { return fmt.Errorf("invalid snapshot_sha256 format: %w", err) } // Compute actual checksum of snapshot directory snapshotDir := filepath.Join(dataDir, "snapshots", task.SnapshotID) actualChecksum, err := integrity.DirOverallSHA256Hex(snapshotDir) if err != nil { return fmt.Errorf("failed to compute snapshot hash: %w", err) } // Compare checksums if actualChecksum != expectedChecksum { return fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum) } return nil } // RunJupyterTask runs a Jupyter-related task. // It handles start, stop, remove, restore, and list_packages actions. 
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) { if w.Jupyter == nil { return nil, fmt.Errorf("jupyter manager not configured") } action := task.Metadata["jupyter_action"] if action == "" { action = "start" // Default action } switch action { case "start": name := task.Metadata["jupyter_name"] if name == "" { name = task.Metadata["jupyter_workspace"] } if name == "" { // Extract from jobName if format is "jupyter-" if len(task.JobName) > 8 && task.JobName[:8] == "jupyter-" { name = task.JobName[8:] } } if name == "" { return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata") } req := &jupyter.StartRequest{Name: name} service, err := w.Jupyter.StartService(ctx, req) if err != nil { return nil, fmt.Errorf("failed to start jupyter service: %w", err) } output := map[string]interface{}{ "type": "start", "service": service, } return json.Marshal(output) case "stop": serviceID := task.Metadata["jupyter_service_id"] if serviceID == "" { return nil, fmt.Errorf("missing jupyter_service_id in task metadata") } if err := w.Jupyter.StopService(ctx, serviceID); err != nil { return nil, fmt.Errorf("failed to stop jupyter service: %w", err) } return json.Marshal(map[string]string{"type": "stop", "status": "stopped"}) case "remove": serviceID := task.Metadata["jupyter_service_id"] if serviceID == "" { return nil, fmt.Errorf("missing jupyter_service_id in task metadata") } purge := task.Metadata["jupyter_purge"] == "true" if err := w.Jupyter.RemoveService(ctx, serviceID, purge); err != nil { return nil, fmt.Errorf("failed to remove jupyter service: %w", err) } return json.Marshal(map[string]string{"type": "remove", "status": "removed"}) case "restore": name := task.Metadata["jupyter_name"] if name == "" { name = task.Metadata["jupyter_workspace"] } if name == "" { return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata") } serviceID, err := w.Jupyter.RestoreWorkspace(ctx, name) if err != 
nil { return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err) } return json.Marshal(map[string]string{"type": "restore", "service_id": serviceID}) case "list_packages": serviceName := task.Metadata["jupyter_name"] if serviceName == "" { // Extract from jobName if format is "jupyter-packages-" if len(task.JobName) > 16 && task.JobName[:16] == "jupyter-packages-" { serviceName = task.JobName[16:] } } if serviceName == "" { return nil, fmt.Errorf("missing jupyter_name in task metadata") } packages, err := w.Jupyter.ListInstalledPackages(ctx, serviceName) if err != nil { return nil, fmt.Errorf("failed to list installed packages: %w", err) } output := map[string]interface{}{ "type": "list_packages", "packages": packages, } return json.Marshal(output) default: return nil, fmt.Errorf("unknown jupyter action: %s", action) } } // PrewarmNextOnce prewarms the next task in queue. // It fetches the next task, verifies its snapshot, and stages it to the prewarm directory. // Returns true if prewarming was performed, false if disabled or queue empty. 
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) { // Check if prewarming is enabled if !w.Config.PrewarmEnabled { return false, nil } // Get base path and data directory basePath := w.Config.BasePath if basePath == "" { basePath = "/tmp" } dataDir := w.Config.DataDir if dataDir == "" { dataDir = filepath.Join(basePath, "data") } // Create prewarm directory prewarmDir := filepath.Join(basePath, ".prewarm", "snapshots") if err := os.MkdirAll(prewarmDir, 0750); err != nil { return false, fmt.Errorf("failed to create prewarm directory: %w", err) } // Try to get next task from queue client if available (peek, don't lease) if w.QueueClient != nil { task, err := w.QueueClient.PeekNextTask() if err != nil { // Queue empty - check if we have existing prewarm state // Return false but preserve any existing state (don't delete) state, _ := w.QueueClient.GetWorkerPrewarmState(w.ID) if state != nil { // We have existing state, return true to indicate prewarm is active return true, nil } return false, nil } if task != nil && task.SnapshotID != "" { // Resolve snapshot path using SHA from metadata if available snapshotSHA := task.Metadata["snapshot_sha256"] if snapshotSHA != "" { snapshotSHA, _ = integrity.NormalizeSHA256ChecksumHex(snapshotSHA) } var srcDir string if snapshotSHA != "" { // Check if snapshot exists in SHA cache directory shaDir := filepath.Join(dataDir, "snapshots", "sha256", snapshotSHA) if info, err := os.Stat(shaDir); err == nil && info.IsDir() { srcDir = shaDir } } // Fall back to direct snapshot path if SHA directory doesn't exist if srcDir == "" { srcDir = filepath.Join(dataDir, "snapshots", task.SnapshotID) } dstDir := filepath.Join(prewarmDir, task.ID) if err := execution.CopyDir(srcDir, dstDir); err != nil { return false, fmt.Errorf("failed to stage snapshot: %w", err) } // Store prewarm state in queue backend if w.QueueClient != nil { now := time.Now().UTC().Format(time.RFC3339) state := queue.PrewarmState{ WorkerID: w.ID, TaskID: 
task.ID, SnapshotID: task.SnapshotID, StartedAt: now, UpdatedAt: now, Phase: "staged", } _ = w.QueueClient.SetWorkerPrewarmState(state) } return true, nil } } // If we have a runLoop but no queue client, use runLoop (for backward compatibility) if w.RunLoop != nil { return true, nil } return false, nil } // RunJob runs a job task. // It uses the JobRunner to execute the job and write the run manifest. func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error { if w.Runner == nil { return fmt.Errorf("job runner not configured") } basePath := w.Config.BasePath if basePath == "" { basePath = "/tmp" } // Determine execution mode mode := executor.ModeAuto if w.Config.LocalMode { mode = executor.ModeLocal } // Create minimal GPU environment (empty for now) gpuEnv := interfaces.ExecutionEnv{} // Run the job return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv) }