// fetch_ml/internal/worker/runloop.go

package worker

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/resources"
)

// Start starts the worker's main processing loop.
func (w *Worker) Start() {
	w.logger.Info("worker started",
		"worker_id", w.id,
		"max_concurrent", w.config.MaxWorkers,
		"poll_interval", w.config.PollInterval)

	go w.heartbeat()

	if w.config != nil && w.config.PrewarmEnabled {
		go w.prewarmNextLoop()
		go w.prewarmImageGCLoop()
	}

	for {
		if w.ctx.Err() != nil {
			w.logger.Info("shutdown signal received, waiting for tasks")
			w.waitForTasks()
			return
		}

		if w.runningCount() >= w.config.MaxWorkers {
			time.Sleep(50 * time.Millisecond)
			continue
		}

		queueStart := time.Now()
		// PollInterval is configured as a count of seconds; use it as the blocking-poll timeout.
		blockTimeout := time.Duration(w.config.PollInterval) * time.Second
		task, err := w.queue.GetNextTaskWithLeaseBlocking(
			w.config.WorkerID,
			w.config.TaskLeaseDuration,
			blockTimeout,
		)
		queueLatency := time.Since(queueStart)
		if err != nil {
			// A deadline error just means the blocking poll timed out with no task.
			if errors.Is(err, context.DeadlineExceeded) {
				continue
			}
			w.logger.Error("error fetching task",
				"worker_id", w.id,
				"error", err)
			continue
		}
		if task == nil {
			if queueLatency > 200*time.Millisecond {
				w.logger.Debug("queue poll latency",
					"latency_ms", queueLatency.Milliseconds())
			}
			continue
		}

		if depth, derr := w.queue.QueueDepth(); derr == nil {
			if queueLatency > 100*time.Millisecond || depth > 0 {
				w.logger.Debug("queue fetch metrics",
					"latency_ms", queueLatency.Milliseconds(),
					"remaining_depth", depth)
			}
		} else if queueLatency > 100*time.Millisecond {
			w.logger.Debug("queue fetch metrics",
				"latency_ms", queueLatency.Milliseconds(),
				"depth_error", derr)
		}

		// Reserve a running slot *before* starting the goroutine so we don't drain
		// the entire queue while max_workers is 1.
		w.reserveRunningSlot(task.ID)
		go w.executeTaskWithLease(task)
	}
}
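
// reserveRunningSlot registers taskID in the running map before its goroutine
// starts, so the concurrency check in Start sees the new slot immediately.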
func (w *Worker) reserveRunningSlot(taskID string) {
	w.runningMu.Lock()
	defer w.runningMu.Unlock()
	if w.running == nil {
		w.running = make(map[string]context.CancelFunc)
	}
	// Track a cancel func for future shutdown handling; currently best-effort.
	_, cancel := context.WithCancel(w.ctx)
	w.running[taskID] = cancel
}
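
// releaseRunningSlot cancels and removes the tracking entry for taskID,
// freeing a concurrency slot for the next task.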
func (w *Worker) releaseRunningSlot(taskID string) {
	w.runningMu.Lock()
	defer w.runningMu.Unlock()
	if w.running == nil {
		return
	}
	if cancel, ok := w.running[taskID]; ok {
		cancel()
		delete(w.running, taskID)
	}
}
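
// prewarmImageGCLoop periodically prunes stale prewarmed container images
// (every five minutes, with a two-minute timeout per prune). It is a no-op
// unless prewarming is enabled, an environment pool exists, the worker is not
// in local mode, and a Podman image is configured.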
func (w *Worker) prewarmImageGCLoop() {
	if w.config == nil || !w.config.PrewarmEnabled {
		return
	}
	if w.envPool == nil {
		return
	}
	if w.config.LocalMode {
		return
	}
	if strings.TrimSpace(w.config.PodmanImage) == "" {
		return
	}

	lastSeen := ""
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	runGC := func() {
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
		defer cancel()
		_ = w.envPool.PruneImages(ctx, 24*time.Hour)
	}

	for {
		select {
		case <-w.ctx.Done():
			return
		case <-ticker.C:
			if w.queue != nil {
				v, err := w.queue.PrewarmGCRequestValue()
				if err == nil && v != "" && v != lastSeen {
					lastSeen = v
					runGC()
					continue
				}
			}
			runGC()
		}
	}
}
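
// heartbeat reports worker liveness to the queue every 30 seconds until the
// worker context is cancelled.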
func (w *Worker) heartbeat() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-w.ctx.Done():
			return
		case <-ticker.C:
			if err := w.queue.Heartbeat(w.id); err != nil {
				w.logger.Warn("heartbeat failed",
					"worker_id", w.id,
					"error", err)
			}
		}
	}
}
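
// waitForTasks blocks until all running tasks finish, polling once per
// second, and gives up after a five-minute shutdown timeout.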
func (w *Worker) waitForTasks() {
	timeout := time.After(5 * time.Minute)
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-timeout:
			w.logger.Warn("shutdown timeout, force stopping",
				"running_tasks", w.runningCount())
			return
		case <-ticker.C:
			count := w.runningCount()
			if count == 0 {
				w.logger.Info("all tasks completed, shutting down")
				return
			}
			w.logger.Debug("waiting for tasks to complete",
				"remaining", count)
		}
	}
}
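
// runningCount returns the number of tasks currently holding a running slot.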
func (w *Worker) runningCount() int {
	w.runningMu.RLock()
	defer w.runningMu.RUnlock()
	return len(w.running)
}

// GetMetrics returns current worker metrics.
func (w *Worker) GetMetrics() map[string]any {
	stats := w.metrics.GetStats()
	stats["worker_id"] = w.id
	stats["max_workers"] = w.config.MaxWorkers
	return stats
}

// Stop gracefully shuts down the worker.
func (w *Worker) Stop() {
	w.cancel()
	w.waitForTasks()

	// Close connections, logging errors rather than failing the shutdown.
	if err := w.server.Close(); err != nil {
		w.logger.Warn("error closing server connection", "error", err)
	}
	if err := w.queue.Close(); err != nil {
		w.logger.Warn("error closing queue connection", "error", err)
	}

	if w.metricsSrv != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := w.metricsSrv.Shutdown(ctx); err != nil {
			w.logger.Warn("metrics exporter shutdown error", "error", err)
		}
	}

	w.logger.Info("worker stopped", "worker_id", w.id)
}

// executeTaskWithLease runs a single task under a queue lease: it validates
// the task, renews the lease via a heartbeat loop, executes the job, and
// retries transient failures.
func (w *Worker) executeTaskWithLease(task *queue.Task) {
	defer w.releaseRunningSlot(task.ID)

	// Track task for graceful shutdown.
	w.gracefulWait.Add(1)
	w.activeTasks.Store(task.ID, task)
	defer w.gracefulWait.Done()
	defer w.activeTasks.Delete(task.ID)

	// Create a task-specific context with timeout.
	taskCtx := logging.EnsureTrace(w.ctx)               // add trace + span if missing
	taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
	taskCtx = logging.CtxWithTask(taskCtx, task.ID)     // add task metadata
	taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
	defer taskCancel()

	logger := w.logger.Job(taskCtx, task.JobName, task.ID)
	logger.Info("starting task",
		"worker_id", w.id,
		"datasets", task.Datasets,
		"priority", task.Priority)

	// Record task start.
	w.metrics.RecordTaskStart()
	defer w.metrics.RecordTaskCompletion()

	// Bail out early if the context was cancelled before execution.
	select {
	case <-taskCtx.Done():
		logger.Info("task cancelled before execution")
		return
	default:
	}

	// Jupyter tasks are executed directly by the worker (no experiment/provenance pipeline).
	if isJupyterTask(task) {
		out, err := w.runJupyterTask(taskCtx, task)
		endTime := time.Now()
		task.EndedAt = &endTime
		if err != nil {
			logger.Error("jupyter task failed", "error", err)
			task.Status = "failed"
			task.Error = err.Error()
			_ = w.queue.UpdateTaskWithMetrics(task, "final")
			w.metrics.RecordTaskFailure()
			_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
			return
		}
		if len(out) > 0 {
			task.Output = string(out)
		}
		task.Status = "completed"
		_ = w.queue.UpdateTaskWithMetrics(task, "final")
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	// Parse datasets from task arguments.
	task.Datasets = resolveDatasets(task)

	if err := w.validateTaskForExecution(taskCtx, task); err != nil {
		logger.Error("task validation failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Validation failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error("failed to update task status after validation failure", "error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	if err := w.enforceTaskProvenance(taskCtx, task); err != nil {
		logger.Error("provenance validation failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Provenance validation failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task status after provenance validation failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	lease, err := w.resources.Acquire(taskCtx, task)
	if err != nil {
		logger.Error("resource acquisition failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Resource acquisition failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task status after resource acquisition failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}
	defer lease.Release()

	// Start the lease-renewal heartbeat. It uses context.Background so renewal
	// keeps running until this function returns, even if the worker context is
	// cancelled mid-task.
	heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
	defer cancelHeartbeat()
	go w.heartbeatLoop(heartbeatCtx, task.ID)

	// Mark the task as running.
	task.Status = "running"
	now := time.Now()
	task.StartedAt = &now
	task.WorkerID = w.id

	// Phase 1 provenance capture: best-effort metadata enrichment before persisting the running state.
	w.recordTaskProvenance(taskCtx, task)

	if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
		logger.Error("failed to update task status", "error", err)
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	if w.config.AutoFetchData && len(task.Datasets) > 0 {
		if err := w.fetchDatasets(taskCtx, task); err != nil {
			logger.Error("data fetch failed", "error", err)
			task.Status = "failed"
			task.Error = fmt.Sprintf("Data fetch failed: %v", err)
			endTime := time.Now()
			task.EndedAt = &endTime
			if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
				logger.Error("failed to update task status after data fetch failure", "error", updateErr)
			}
			w.metrics.RecordTaskFailure()
			_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
			return
		}
	}

	if err := w.verifyDatasetSpecs(taskCtx, task); err != nil {
		logger.Error("dataset checksum verification failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Dataset checksum verification failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task after dataset checksum verification failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	if err := w.verifySnapshot(taskCtx, task); err != nil {
		logger.Error("snapshot checksum verification failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Snapshot checksum verification failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task after snapshot checksum verification failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	// Execute the job with panic recovery.
	var execErr error
	func() {
		defer func() {
			if r := recover(); r != nil {
				execErr = fmt.Errorf("panic during execution: %v", r)
			}
		}()
		cudaVisible := resources.FormatCUDAVisibleDevices(lease)
		execErr = w.runJob(taskCtx, task, cudaVisible)
	}()

	// Finalize the task.
	endTime := time.Now()
	task.EndedAt = &endTime
	if execErr != nil {
		task.Error = execErr.Error()
		// Requeue transient failures (network, timeout, etc.) while retries remain.
		if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
			w.logger.Warn("task failed with transient error, will retry",
				"task_id", task.ID,
				"error", execErr,
				"retry_count", task.RetryCount)
			_ = w.queue.RetryTask(task)
		} else {
			task.Status = "failed"
			_ = w.queue.UpdateTaskWithMetrics(task, "final")
		}
	} else {
		task.Status = "completed"
		_ = w.queue.UpdateTaskWithMetrics(task, "final")
	}

	// Release the lease.
	_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
}

// heartbeatLoop renews the task lease at the configured interval and also
// refreshes the worker heartbeat; it exits when renewal fails or ctx is done.
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
	ticker := time.NewTicker(w.config.HeartbeatInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
				w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
				return
			}
			// Also refresh the worker-level heartbeat.
			_ = w.queue.Heartbeat(w.config.WorkerID)
		}
	}
}

// Shutdown gracefully shuts down the worker.
func (w *Worker) Shutdown() error {
	w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())

	// Wait for active tasks, with a timeout.
	done := make(chan struct{})
	go func() {
		w.gracefulWait.Wait()
		close(done)
	}()

	select {
	case <-done:
		w.logger.Info("all tasks completed, shutdown successful")
	case <-time.After(w.config.GracefulTimeout):
		w.logger.Warn("graceful shutdown timeout, releasing active leases")
		w.releaseAllLeases()
	}

	return w.queue.Close()
}

// releaseAllLeases releases the queue lease for every task still tracked as active.
func (w *Worker) releaseAllLeases() {
	w.activeTasks.Range(func(key, _ interface{}) bool {
		taskID := key.(string)
		if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
			w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
		}
		return true
	})
}

// Helper functions.
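
// countActiveTasks returns the number of tasks currently tracked in activeTasks.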
func (w *Worker) countActiveTasks() int {
	count := 0
	w.activeTasks.Range(func(_, _ interface{}) bool {
		count++
		return true
	})
	return count
}
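
// isTransientError reports whether err looks retryable by matching common
// network and timeout substrings, case-insensitively. For example, an error
// containing "connection refused" is treated as transient.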
func isTransientError(err error) bool {
	if err == nil {
		return false
	}
	// Lowercase once, then scan for transient indicators (network, timeout,
	// resource unavailable, etc).
	errStr := strings.ToLower(err.Error())
	transientIndicators := []string{
		"connection refused",
		"timeout",
		"temporary failure",
		"resource temporarily unavailable",
		"no such host",
		"network unreachable",
	}
	for _, indicator := range transientIndicators {
		if strings.Contains(errStr, indicator) {
			return true
		}
	}
	return false
}