// Package worker provides the ML task worker implementation
package worker

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"math/rand"
	"net/http"
	"os"
	"path/filepath"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/metrics"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/resources"
	"github.com/jfraeys/fetch_ml/internal/scheduler"
	"github.com/jfraeys/fetch_ml/internal/worker/execution"
	"github.com/jfraeys/fetch_ml/internal/worker/executor"
	"github.com/jfraeys/fetch_ml/internal/worker/integrity"
	"github.com/jfraeys/fetch_ml/internal/worker/interfaces"
	"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
	"github.com/jfraeys/fetch_ml/internal/worker/plugins"
)

// Worker represents an ML task worker with composed dependencies.
type Worker struct {
	Jupyter     plugins.JupyterManager     // Jupyter plugin manager, exposed via GetJupyterManager
	QueueClient queue.Backend              // task queue backend; nil-checked before use (see PrewarmNextOnce)
	Config      *Config                    // worker configuration (paths, mode, limits)
	Logger      *logging.Logger            // structured logger; also installed as slog default in Start
	RunLoop     *lifecycle.RunLoop         // main task-processing loop, started/stopped by Start/Stop
	Runner      *executor.JobRunner        // job executor used by RunJob
	Metrics     *metrics.Metrics           // worker metrics source for GetMetrics
	metricsSrv  *http.Server               // optional metrics HTTP server, shut down with a 5s timeout
	Health      *lifecycle.HealthMonitor   // heartbeat tracker backing IsHealthy
	Resources   *resources.Manager         // resource manager (not referenced in this file's methods)
	ID          string                     // unique worker identifier
	gpuDetectionInfo GPUDetectionInfo      // GPU detection results (set elsewhere; unused in this file)
	schedulerConn    *scheduler.SchedulerConn // For distributed mode
	// ctx/cancel are created in Start and cancelled by Stop to end
	// background goroutines such as heartbeatLoop.
	ctx    context.Context
	cancel context.CancelFunc
}

// Start begins the worker's main processing loop.
func (w *Worker) Start() { w.Logger.Info("worker starting", "worker_id", w.ID, "max_concurrent", w.Config.MaxWorkers, "mode", w.Config.Mode, ) slog.SetDefault(w.Logger.Logger) w.ctx, w.cancel = context.WithCancel(context.Background()) w.Health.RecordHeartbeat() // Start heartbeat loop for distributed mode if w.Config.Mode == "distributed" && w.schedulerConn != nil { go w.heartbeatLoop() } w.RunLoop.Start() } // heartbeatLoop sends periodic heartbeats with slot status to scheduler func (w *Worker) heartbeatLoop() { // Use configured interval or default to 10s intervalSecs := w.Config.Scheduler.HeartbeatIntervalSecs if intervalSecs == 0 { intervalSecs = 10 } // Add jitter (0-5s) to prevent thundering herd jitter := time.Duration(rand.Intn(5)) * time.Second interval := time.Duration(intervalSecs)*time.Second + jitter ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-w.ctx.Done(): return case <-ticker.C: w.Health.RecordHeartbeat() if w.schedulerConn != nil { slots := scheduler.SlotStatus{ BatchTotal: w.Config.MaxWorkers, BatchInUse: w.RunLoop.RunningCount(), } w.schedulerConn.Send(scheduler.Message{ Type: scheduler.MsgHeartbeat, Payload: mustMarshal(scheduler.HeartbeatPayload{ WorkerID: w.ID, Slots: slots, }), }) } } } } // Stop gracefully shuts down the worker immediately. func (w *Worker) Stop() { w.Logger.Info("worker stopping", "worker_id", w.ID) if w.cancel != nil { w.cancel() } w.RunLoop.Stop() if w.metricsSrv != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := w.metricsSrv.Shutdown(ctx); err != nil { w.Logger.Warn("metrics server shutdown error", "error", err) } } w.Logger.Info("worker stopped", "worker_id", w.ID) } // Shutdown performs a graceful shutdown with timeout. 
func (w *Worker) Shutdown() error { w.Logger.Info("starting graceful shutdown", "worker_id", w.ID) w.RunLoop.Stop() if w.metricsSrv != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := w.metricsSrv.Shutdown(ctx); err != nil { w.Logger.Warn("metrics server shutdown error", "error", err) } } w.Logger.Info("worker shut down gracefully", "worker_id", w.ID) return nil } // IsHealthy returns true if the worker is healthy. func (w *Worker) IsHealthy() bool { return w.Health.IsHealthy(5 * time.Minute) } // GetMetrics returns current worker metrics. func (w *Worker) GetMetrics() map[string]any { stats := w.Metrics.GetStats() stats["worker_id"] = w.ID stats["max_workers"] = w.Config.MaxWorkers stats["healthy"] = w.IsHealthy() return stats } // GetID returns the worker ID. func (w *Worker) GetID() string { return w.ID } // SelectDependencyManifest re-exports the executor function for API helpers. // It detects the dependency manifest file in the given directory. func SelectDependencyManifest(filesPath string) (string, error) { return executor.SelectDependencyManifest(filesPath) } // DirOverallSHA256Hex re-exports the integrity function for test compatibility. func DirOverallSHA256Hex(root string) (string, error) { return integrity.DirOverallSHA256Hex(root) } // NormalizeSHA256ChecksumHex re-exports the integrity function for test compatibility. func NormalizeSHA256ChecksumHex(checksum string) (string, error) { return integrity.NormalizeSHA256ChecksumHex(checksum) } // StageSnapshot re-exports the execution function for test compatibility. func StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir string) error { return execution.StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir) } // StageSnapshotFromPath re-exports the execution function for test compatibility. 
func StageSnapshotFromPath(basePath, taskID, srcPath, jobDir string) error {
	return execution.StageSnapshotFromPath(basePath, taskID, srcPath, jobDir)
}

// ComputeTaskProvenance computes provenance information for a task.
// This re-exports the integrity function for test compatibility.
func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
	pc := integrity.NewProvenanceCalculator(basePath)
	return pc.ComputeProvenance(task)
}

// VerifyDatasetSpecs verifies dataset specifications for this task.
// This is a test compatibility method that wraps the integrity package.
//
// NOTE(review): ctx is accepted for interface symmetry but not used —
// the underlying verifier call takes no context.
// NOTE(review): the "/tmp/data" fallback differs from other methods in
// this file, which default to filepath.Join(basePath, "data") — confirm
// whether this divergence is intentional.
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
	dataDir := w.Config.DataDir
	if dataDir == "" {
		dataDir = "/tmp/data"
	}
	verifier := integrity.NewDatasetVerifier(dataDir)
	return verifier.VerifyDatasetSpecs(task)
}

// EnforceTaskProvenance enforces provenance requirements for a task.
// It validates and/or populates provenance metadata based on the ProvenanceBestEffort config.
// In strict mode (ProvenanceBestEffort=false), it returns an error if metadata doesn't match computed values.
// In best-effort mode (ProvenanceBestEffort=true), it populates missing metadata fields.
func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}

	// Resolve working directories, falling back to temp-dir defaults
	// when unconfigured.
	basePath := w.Config.BasePath
	if basePath == "" {
		basePath = os.TempDir()
	}
	dataDir := w.Config.DataDir
	if dataDir == "" {
		dataDir = filepath.Join(basePath, "data")
	}
	bestEffort := w.Config.ProvenanceBestEffort

	// Get commit_id from metadata. commit_id is mandatory in both
	// modes: without it the experiment directory cannot be located.
	// (Reading from a nil map is safe and yields "".)
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return fmt.Errorf("missing commit_id in task metadata")
	}

	// Compute and verify experiment manifest SHA over the experiment
	// directory <basePath>/<commitID>.
	expPath := filepath.Join(basePath, commitID)
	manifestSHA, err := integrity.DirOverallSHA256Hex(expPath)
	if err != nil {
		if !bestEffort {
			return fmt.Errorf("failed to compute experiment manifest SHA: %w", err)
		}
		// In best-effort mode, we'll use whatever is provided or skip
		manifestSHA = ""
	}

	// Handle experiment_manifest_overall_sha: strict mode requires the
	// field and an exact match; best-effort mode back-fills it with the
	// computed value (possibly "" if computation failed above).
	expectedManifestSHA := task.Metadata["experiment_manifest_overall_sha"]
	if expectedManifestSHA == "" {
		if !bestEffort {
			return fmt.Errorf("missing experiment_manifest_overall_sha in task metadata")
		}
		// Populate in best-effort mode
		if task.Metadata == nil {
			task.Metadata = map[string]string{}
		}
		task.Metadata["experiment_manifest_overall_sha"] = manifestSHA
	} else if !bestEffort && expectedManifestSHA != manifestSHA {
		return fmt.Errorf("experiment manifest SHA mismatch: expected %s, got %s", expectedManifestSHA, manifestSHA)
	}

	// Handle deps_manifest_sha256 - auto-detect if not provided
	filesPath := filepath.Join(expPath, "files")
	depsManifestName := task.Metadata["deps_manifest_name"]
	if depsManifestName == "" {
		// Auto-detect manifest file. Detection failure is tolerated in
		// both modes: an empty name simply skips the deps check below.
		depsManifestName, _ = executor.SelectDependencyManifest(filesPath)
	}
	if depsManifestName != "" {
		if task.Metadata == nil {
			task.Metadata = map[string]string{}
		}
		// Record the (possibly auto-detected) manifest name back onto
		// the task so downstream consumers see a resolved value.
		task.Metadata["deps_manifest_name"] = depsManifestName
		depsPath := filepath.Join(filesPath, depsManifestName)
		depsSHA, err := integrity.FileSHA256Hex(depsPath)
		if err != nil {
			if !bestEffort {
				return fmt.Errorf("failed to compute deps manifest SHA: %w", err)
			}
			depsSHA = ""
		}
		// Same strict/best-effort split as the experiment manifest:
		// missing -> error or back-fill; present -> strict must match.
		expectedDepsSHA := task.Metadata["deps_manifest_sha256"]
		if expectedDepsSHA == "" {
			if !bestEffort {
				return fmt.Errorf("missing deps_manifest_sha256 in task metadata")
			}
			task.Metadata["deps_manifest_sha256"] = depsSHA
		} else if !bestEffort && expectedDepsSHA != depsSHA {
			return fmt.Errorf("deps manifest SHA mismatch: expected %s, got %s", expectedDepsSHA, depsSHA)
		}
	}

	// Handle snapshot_sha256 if SnapshotID is set
	if task.SnapshotID != "" {
		snapPath := filepath.Join(dataDir, "snapshots", task.SnapshotID)
		snapSHA, err := integrity.DirOverallSHA256Hex(snapPath)
		if err != nil {
			if !bestEffort {
				return fmt.Errorf("failed to compute snapshot SHA: %w", err)
			}
			snapSHA = ""
		}
		// Normalization error is ignored: a malformed checksum yields
		// "" and is therefore treated as missing rather than invalid.
		expectedSnapSHA, _ := integrity.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
		if expectedSnapSHA == "" {
			if !bestEffort {
				return fmt.Errorf("missing snapshot_sha256 in task metadata")
			}
			if task.Metadata == nil {
				task.Metadata = map[string]string{}
			}
			task.Metadata["snapshot_sha256"] = snapSHA
		} else if !bestEffort && expectedSnapSHA != snapSHA {
			return fmt.Errorf("snapshot SHA mismatch: expected %s, got %s", expectedSnapSHA, snapSHA)
		}
	}

	return nil
}

// VerifySnapshot verifies snapshot integrity for this task.
// It computes the SHA256 of the snapshot directory and compares with task metadata.
func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error { if task.SnapshotID == "" { return nil // No snapshot to verify } dataDir := w.Config.DataDir if dataDir == "" { dataDir = os.TempDir() + "/data" } // Get expected checksum from metadata expectedChecksum, ok := task.Metadata["snapshot_sha256"] if !ok || expectedChecksum == "" { return fmt.Errorf("missing snapshot_sha256 in task metadata") } // Normalize the checksum (remove sha256: prefix if present) expectedChecksum, err := integrity.NormalizeSHA256ChecksumHex(expectedChecksum) if err != nil { return fmt.Errorf("invalid snapshot_sha256 format: %w", err) } // Compute actual checksum of snapshot directory snapshotDir := filepath.Join(dataDir, "snapshots", task.SnapshotID) actualChecksum, err := integrity.DirOverallSHA256Hex(snapshotDir) if err != nil { return fmt.Errorf("failed to compute snapshot hash: %w", err) } // Compare checksums if actualChecksum != expectedChecksum { return fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum) } return nil } // GetJupyterManager returns the Jupyter manager for plugin use // This implements the plugins.TaskRunner interface func (w *Worker) GetJupyterManager() plugins.JupyterManager { return w.Jupyter } // PrewarmNextOnce prewarms the next task in queue. // It fetches the next task, verifies its snapshot, and stages it to the prewarm directory. // Returns true if prewarming was performed, false if disabled or queue empty. 
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) { // Check if prewarming is enabled if !w.Config.PrewarmEnabled { return false, nil } // Get base path and data directory basePath := w.Config.BasePath if basePath == "" { basePath = "/tmp" } dataDir := w.Config.DataDir if dataDir == "" { dataDir = filepath.Join(basePath, "data") } // Create prewarm directory prewarmDir := filepath.Join(basePath, ".prewarm", "snapshots") if err := os.MkdirAll(prewarmDir, 0o750); err != nil { return false, fmt.Errorf("failed to create prewarm directory: %w", err) } // Try to get next task from queue client if available (peek, don't lease) if w.QueueClient != nil { task, err := w.QueueClient.PeekNextTask() if err != nil { // Queue empty - check if we have existing prewarm state // Return false but preserve any existing state (don't delete) state, _ := w.QueueClient.GetWorkerPrewarmState(w.ID) if state != nil { // We have existing state, return true to indicate prewarm is active return true, nil } return false, nil } if task != nil && task.SnapshotID != "" { // Resolve snapshot path using SHA from metadata if available snapshotSHA := task.Metadata["snapshot_sha256"] if snapshotSHA != "" { snapshotSHA, _ = integrity.NormalizeSHA256ChecksumHex(snapshotSHA) } var srcDir string if snapshotSHA != "" { // Check if snapshot exists in SHA cache directory shaDir := filepath.Join(dataDir, "snapshots", "sha256", snapshotSHA) if info, err := os.Stat(shaDir); err == nil && info.IsDir() { srcDir = shaDir } } // Fall back to direct snapshot path if SHA directory doesn't exist if srcDir == "" { srcDir = filepath.Join(dataDir, "snapshots", task.SnapshotID) } dstDir := filepath.Join(prewarmDir, task.ID) if err := execution.CopyDir(srcDir, dstDir); err != nil { return false, fmt.Errorf("failed to stage snapshot: %w", err) } // Store prewarm state in queue backend if w.QueueClient != nil { now := time.Now().UTC().Format(time.RFC3339) state := queue.PrewarmState{ WorkerID: w.ID, TaskID: 
task.ID, SnapshotID: task.SnapshotID, StartedAt: now, UpdatedAt: now, Phase: "staged", } _ = w.QueueClient.SetWorkerPrewarmState(state) } return true, nil } } // If we have a runLoop but no queue client, use runLoop (for backward compatibility) if w.RunLoop != nil { return true, nil } return false, nil } // RunJob runs a job task. // It uses the JobRunner to execute the job and write the run manifest. func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string) error { if w.Runner == nil { return fmt.Errorf("job runner not configured") } basePath := w.Config.BasePath if basePath == "" { basePath = "/tmp" } // Determine execution mode mode := executor.ModeAuto if w.Config.LocalMode { mode = executor.ModeLocal } // Create minimal GPU environment (empty for now) gpuEnv := interfaces.ExecutionEnv{} // Run the job return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv) } func mustMarshal(v any) []byte { b, _ := json.Marshal(v) return b }