package worker

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/resources"
)

// Start starts the worker's main processing loop.
func (w *Worker) Start() {
	w.logger.Info("worker started",
		"worker_id", w.id,
		"max_concurrent", w.config.MaxWorkers,
		"poll_interval", w.config.PollInterval)

	go w.heartbeat()

	if w.config != nil && w.config.PrewarmEnabled {
		go w.prewarmNextLoop()
		go w.prewarmImageGCLoop()
	}

	for {
		switch {
		case w.ctx.Err() != nil:
			w.logger.Info("shutdown signal received, waiting for tasks")
			w.waitForTasks()
			return
		default:
		}

		if w.runningCount() >= w.config.MaxWorkers {
			time.Sleep(50 * time.Millisecond)
			continue
		}

		queueStart := time.Now()
		blockTimeout := time.Duration(w.config.PollInterval) * time.Second
		task, err := w.queue.GetNextTaskWithLeaseBlocking(
			w.config.WorkerID,
			w.config.TaskLeaseDuration,
			blockTimeout,
		)
		queueLatency := time.Since(queueStart)
		if err != nil {
			if err == context.DeadlineExceeded {
				continue
			}
			w.logger.Error("error fetching task", "worker_id", w.id, "error", err)
			continue
		}
		if task == nil {
			if queueLatency > 200*time.Millisecond {
				w.logger.Debug("queue poll latency", "latency_ms", queueLatency.Milliseconds())
			}
			continue
		}

		if depth, derr := w.queue.QueueDepth(); derr == nil {
			if queueLatency > 100*time.Millisecond || depth > 0 {
				w.logger.Debug("queue fetch metrics",
					"latency_ms", queueLatency.Milliseconds(),
					"remaining_depth", depth)
			}
		} else if queueLatency > 100*time.Millisecond {
			w.logger.Debug("queue fetch metrics",
				"latency_ms", queueLatency.Milliseconds(),
				"depth_error", derr)
		}

		// Reserve a running slot *before* starting the goroutine so we don't drain
		// the entire queue while max_workers is 1.
		w.reserveRunningSlot(task.ID)
		go w.executeTaskWithLease(task)
	}
}

func (w *Worker) reserveRunningSlot(taskID string) {
	w.runningMu.Lock()
	defer w.runningMu.Unlock()
	if w.running == nil {
		w.running = make(map[string]context.CancelFunc)
	}
	// Track a cancel func for future shutdown handling; currently best-effort.
	_, cancel := context.WithCancel(w.ctx)
	w.running[taskID] = cancel
}

// releaseRunningSlot frees the running slot for a task and cancels its tracked context.
func (w *Worker) releaseRunningSlot(taskID string) {
	w.runningMu.Lock()
	defer w.runningMu.Unlock()
	if w.running == nil {
		return
	}
	if cancel, ok := w.running[taskID]; ok {
		cancel()
		delete(w.running, taskID)
	}
}

// prewarmImageGCLoop periodically prunes stale prewarm images, honoring explicit
// GC requests published on the queue.
func (w *Worker) prewarmImageGCLoop() {
	if w.config == nil || !w.config.PrewarmEnabled {
		return
	}
	if w.envPool == nil {
		return
	}
	if w.config.LocalMode {
		return
	}
	if strings.TrimSpace(w.config.PodmanImage) == "" {
		return
	}

	lastSeen := ""
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	runGC := func() {
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
		defer cancel()
		_ = w.envPool.PruneImages(ctx, 24*time.Hour)
	}

	for {
		select {
		case <-w.ctx.Done():
			return
		case <-ticker.C:
			if w.queue != nil {
				v, err := w.queue.PrewarmGCRequestValue()
				if err == nil && v != "" && v != lastSeen {
					lastSeen = v
					runGC()
					continue
				}
			}
			runGC()
		}
	}
}

// heartbeat periodically reports worker liveness to the queue.
func (w *Worker) heartbeat() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-w.ctx.Done():
			return
		case <-ticker.C:
			if err := w.queue.Heartbeat(w.id); err != nil {
				w.logger.Warn("heartbeat failed", "worker_id", w.id, "error", err)
			}
		}
	}
}

// waitForTasks blocks until all running tasks complete or a 5-minute shutdown timeout elapses.
func (w *Worker) waitForTasks() {
	timeout := time.After(5 * time.Minute)
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-timeout:
			w.logger.Warn("shutdown timeout, force stopping", "running_tasks", w.runningCount())
			return
		case <-ticker.C:
			count := w.runningCount()
			if count == 0 {
				w.logger.Info("all tasks completed, shutting down")
				return
			}
			w.logger.Debug("waiting for tasks to complete", "remaining", count)
		}
	}
}

// runningCount returns the number of tasks currently holding a running slot.
func (w *Worker) runningCount() int {
	w.runningMu.RLock()
	defer w.runningMu.RUnlock()
	return len(w.running)
}

// GetMetrics returns current worker metrics.
func (w *Worker) GetMetrics() map[string]any {
	stats := w.metrics.GetStats()
	stats["worker_id"] = w.id
	stats["max_workers"] = w.config.MaxWorkers
	return stats
}

// Stop gracefully shuts down the worker.
func (w *Worker) Stop() {
	w.cancel()
	w.waitForTasks()

	// Close server and queue connections, logging (but not failing on) close errors.
	if err := w.server.Close(); err != nil {
		w.logger.Warn("error closing server connection", "error", err)
	}
	if err := w.queue.Close(); err != nil {
		w.logger.Warn("error closing queue connection", "error", err)
	}

	if w.metricsSrv != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := w.metricsSrv.Shutdown(ctx); err != nil {
			w.logger.Warn("metrics exporter shutdown error", "error", err)
		}
	}

	w.logger.Info("worker stopped", "worker_id", w.id)
}

// executeTaskWithLease executes a single task under a queue lease, with retry on transient failures.
func (w *Worker) executeTaskWithLease(task *queue.Task) {
	defer w.releaseRunningSlot(task.ID)

	// Track task for graceful shutdown
	w.gracefulWait.Add(1)
	w.activeTasks.Store(task.ID, task)
	defer w.gracefulWait.Done()
	defer w.activeTasks.Delete(task.ID)

	// Create task-specific context with timeout
	taskCtx := logging.EnsureTrace(w.ctx)               // add trace + span if missing
	taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
	taskCtx = logging.CtxWithTask(taskCtx, task.ID)     // add task metadata
	taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
	defer taskCancel()

	logger := w.logger.Job(taskCtx, task.JobName, task.ID)
	logger.Info("starting task",
		"worker_id", w.id,
		"datasets", task.Datasets,
		"priority", task.Priority)

	// Record task start
	w.metrics.RecordTaskStart()
	defer w.metrics.RecordTaskCompletion()

	// Check for context cancellation
	select {
	case <-taskCtx.Done():
		logger.Info("task cancelled before execution")
		return
	default:
	}

	// Jupyter tasks are executed directly by the worker (no experiment/provenance pipeline).
	if isJupyterTask(task) {
		out, err := w.runJupyterTask(taskCtx, task)
		endTime := time.Now()
		task.EndedAt = &endTime
		if err != nil {
			logger.Error("jupyter task failed", "error", err)
			task.Status = "failed"
			task.Error = err.Error()
			_ = w.queue.UpdateTaskWithMetrics(task, "final")
			w.metrics.RecordTaskFailure()
			_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
			return
		}
		if len(out) > 0 {
			task.Output = string(out)
		}
		task.Status = "completed"
		_ = w.queue.UpdateTaskWithMetrics(task, "final")
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	// Parse datasets from task arguments
	task.Datasets = resolveDatasets(task)

	if err := w.validateTaskForExecution(taskCtx, task); err != nil {
		logger.Error("task validation failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Validation failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error("failed to update task status after validation failure", "error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	if err := w.enforceTaskProvenance(taskCtx, task); err != nil {
		logger.Error("provenance validation failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Provenance validation failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task status after provenance validation failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	lease, err := w.resources.Acquire(taskCtx, task)
	if err != nil {
		logger.Error("resource acquisition failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Resource acquisition failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task status after resource acquisition failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}
	defer lease.Release()

	// Start heartbeat goroutine
	heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
	defer cancelHeartbeat()
	go w.heartbeatLoop(heartbeatCtx, task.ID)

	// Update task status and record attempt start
	task.Status = "running"
	now := time.Now()
	task.StartedAt = &now
	task.WorkerID = w.id

	// Record this attempt in the task history
	attempt := queue.Attempt{
		Attempt:   task.RetryCount + 1, // 1-indexed attempt number
		StartedAt: now,
		WorkerID:  w.id,
		Status:    "running",
	}
	task.Attempts = append(task.Attempts, attempt)

	// Phase 1 provenance capture: best-effort metadata enrichment before persisting the running state.
	w.recordTaskProvenance(taskCtx, task)

	if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
		logger.Error("failed to update task status", "error", err)
		w.metrics.RecordTaskFailure()
		return
	}

	if w.config.AutoFetchData && len(task.Datasets) > 0 {
		if err := w.fetchDatasets(taskCtx, task); err != nil {
			logger.Error("data fetch failed", "error", err)
			task.Status = "failed"
			task.Error = fmt.Sprintf("Data fetch failed: %v", err)
			endTime := time.Now()
			task.EndedAt = &endTime
			err := w.queue.UpdateTask(task)
			if err != nil {
				logger.Error("failed to update task status after data fetch failure", "error", err)
			}
			w.metrics.RecordTaskFailure()
			return
		}
	}

	if err := w.verifyDatasetSpecs(taskCtx, task); err != nil {
		logger.Error("dataset checksum verification failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Dataset checksum verification failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task after dataset checksum verification failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		return
	}

	if err := w.verifySnapshot(taskCtx, task); err != nil {
		logger.Error("snapshot checksum verification failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Snapshot checksum verification failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task after snapshot checksum verification failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		return
	}

	// Execute job with panic recovery
	var execErr error
	func() {
		defer func() {
			if r := recover(); r != nil {
				execErr = fmt.Errorf("panic during execution: %v", r)
			}
		}()
		cudaVisible := resources.FormatCUDAVisibleDevices(lease)
		execErr = w.runJob(taskCtx, task, cudaVisible)
	}()

	// Finalize task and record attempt completion
	endTime := time.Now()
	task.EndedAt = &endTime

	// Update the last attempt with completion info
	if len(task.Attempts) > 0 {
		lastIdx := len(task.Attempts) - 1
		task.Attempts[lastIdx].EndedAt = &endTime
	}

	if execErr != nil {
		task.Error = execErr.Error()

		// Update last attempt with failure info
		if len(task.Attempts) > 0 {
			lastIdx := len(task.Attempts) - 1
			task.Attempts[lastIdx].Status = "failed"
			task.Attempts[lastIdx].Error = execErr.Error()
			task.Attempts[lastIdx].FailureClass = queue.ClassifyFailure(0, nil, execErr.Error())
			// TODO: Capture exit code and signal from actual execution
		}

		// Check if transient error (network, timeout, etc)
		if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
			w.logger.Warn("task failed with transient error, will retry",
				"task_id", task.ID,
				"error", execErr,
				"retry_count", task.RetryCount)
			_ = w.queue.RetryTask(task)
		} else {
			task.Status = "failed"
			_ = w.queue.UpdateTaskWithMetrics(task, "final")
		}
	} else {
		task.Status = "completed"

		// Update last attempt with success
		if len(task.Attempts) > 0 {
			lastIdx := len(task.Attempts) - 1
			task.Attempts[lastIdx].Status = "completed"
		}
		_ = w.queue.UpdateTaskWithMetrics(task, "final")
	}

	// Release lease
	_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
}

// heartbeatLoop renews the task lease on each heartbeat interval until ctx is cancelled.
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
	ticker := time.NewTicker(w.config.HeartbeatInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
				w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
				return
			}
			// Also update worker heartbeat
			_ = w.queue.Heartbeat(w.config.WorkerID)
		}
	}
}

// Shutdown gracefully shuts down the worker.
func (w *Worker) Shutdown() error {
	w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())

	// Wait for active tasks with timeout
	done := make(chan struct{})
	go func() {
		w.gracefulWait.Wait()
		close(done)
	}()

	timeout := time.After(w.config.GracefulTimeout)
	select {
	case <-done:
		w.logger.Info("all tasks completed, shutdown successful")
	case <-timeout:
		w.logger.Warn("graceful shutdown timeout, releasing active leases")
		w.releaseAllLeases()
	}

	return w.queue.Close()
}

// releaseAllLeases releases the lease for every task still tracked as active.
func (w *Worker) releaseAllLeases() {
	w.activeTasks.Range(func(key, _ interface{}) bool {
		taskID := key.(string)
		if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
			w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
		}
		return true
	})
}

// Helper functions.

// countActiveTasks returns the number of tasks tracked for graceful shutdown.
func (w *Worker) countActiveTasks() int {
	count := 0
	w.activeTasks.Range(func(_, _ interface{}) bool {
		count++
		return true
	})
	return count
}

// isTransientError reports whether err looks like a transient failure worth retrying.
func isTransientError(err error) bool {
	if err == nil {
		return false
	}

	// Check if error is transient (network, timeout, resource unavailable, etc)
	errStr := strings.ToLower(err.Error())
	transientIndicators := []string{
		"connection refused",
		"timeout",
		"temporary failure",
		"resource temporarily unavailable",
		"no such host",
		"network unreachable",
	}

	for _, indicator := range transientIndicators {
		if strings.Contains(errStr, indicator) {
			return true
		}
	}

	return false
}