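// Package worker contains the task-processing worker: it leases tasks from
// the queue, runs them with resource leases and lease-renewal heartbeats,
// and records execution metrics.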
package worker

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/resources"
)

// Start starts the worker's main processing loop.
func (w *Worker) Start() {
	w.logger.Info("worker started",
		"worker_id", w.id,
		"max_concurrent", w.config.MaxWorkers,
		"poll_interval", w.config.PollInterval)

	go w.heartbeat()
	if w.config != nil && w.config.PrewarmEnabled {
		go w.prewarmNextLoop()
		go w.prewarmImageGCLoop()
	}

	for {
		switch {
		case w.ctx.Err() != nil:
			w.logger.Info("shutdown signal received, waiting for tasks")
			w.waitForTasks()
			return
		default:
		}
		if w.runningCount() >= w.config.MaxWorkers {
			time.Sleep(50 * time.Millisecond)
			continue
		}

		queueStart := time.Now()
		blockTimeout := time.Duration(w.config.PollInterval) * time.Second
		task, err := w.queue.GetNextTaskWithLeaseBlocking(
			w.config.WorkerID,
			w.config.TaskLeaseDuration,
			blockTimeout,
		)
		queueLatency := time.Since(queueStart)
		if err != nil {
			if errors.Is(err, context.DeadlineExceeded) {
				continue
			}
			w.logger.Error("error fetching task",
				"worker_id", w.id,
				"error", err)
			continue
		}

		if task == nil {
			if queueLatency > 200*time.Millisecond {
				w.logger.Debug("queue poll latency",
					"latency_ms", queueLatency.Milliseconds())
			}
			continue
		}

		if depth, derr := w.queue.QueueDepth(); derr == nil {
			if queueLatency > 100*time.Millisecond || depth > 0 {
				w.logger.Debug("queue fetch metrics",
					"latency_ms", queueLatency.Milliseconds(),
					"remaining_depth", depth)
			}
		} else if queueLatency > 100*time.Millisecond {
			w.logger.Debug("queue fetch metrics",
				"latency_ms", queueLatency.Milliseconds(),
				"depth_error", derr)
		}

		// Reserve a running slot *before* starting the goroutine so we don't drain
		// the entire queue while max_workers is 1.
		w.reserveRunningSlot(task.ID)
		go w.executeTaskWithLease(task)
	}
}

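// reserveRunningSlot records taskID in the running set, counting it against
// MaxWorkers, and tracks a cancel func tied to the worker context.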
func (w *Worker) reserveRunningSlot(taskID string) {
	w.runningMu.Lock()
	defer w.runningMu.Unlock()
	if w.running == nil {
		w.running = make(map[string]context.CancelFunc)
	}
	// Track a cancel func for future shutdown handling; currently best-effort.
	_, cancel := context.WithCancel(w.ctx)
	w.running[taskID] = cancel
}

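// releaseRunningSlot cancels and removes the tracking entry for taskID,
// freeing its running slot.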
func (w *Worker) releaseRunningSlot(taskID string) {
	w.runningMu.Lock()
	defer w.runningMu.Unlock()
	if w.running == nil {
		return
	}
	if cancel, ok := w.running[taskID]; ok {
		cancel()
		delete(w.running, taskID)
	}
}

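// prewarmImageGCLoop prunes prewarmed container images on a five-minute
// ticker. It is a no-op unless prewarming is enabled, an environment pool
// exists, the worker is not in local mode, and a Podman image is configured.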
func (w *Worker) prewarmImageGCLoop() {
	if w.config == nil || !w.config.PrewarmEnabled {
		return
	}
	if w.envPool == nil {
		return
	}
	if w.config.LocalMode {
		return
	}
	if strings.TrimSpace(w.config.PodmanImage) == "" {
		return
	}

	lastSeen := ""
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	runGC := func() {
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
		defer cancel()
		_ = w.envPool.PruneImages(ctx, 24*time.Hour)
	}

	for {
		select {
		case <-w.ctx.Done():
			return
		case <-ticker.C:
			if w.queue != nil {
				v, err := w.queue.PrewarmGCRequestValue()
				if err == nil && v != "" && v != lastSeen {
					lastSeen = v
					runGC()
					continue
				}
			}
			runGC()
		}
	}
}

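// heartbeat reports worker liveness to the queue every 30 seconds until the
// worker context is cancelled.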
func (w *Worker) heartbeat() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-w.ctx.Done():
			return
		case <-ticker.C:
			if err := w.queue.Heartbeat(w.id); err != nil {
				w.logger.Warn("heartbeat failed",
					"worker_id", w.id,
					"error", err)
			}
		}
	}
}

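// waitForTasks blocks until all running tasks have finished, polling once per
// second, and gives up after a five-minute timeout.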
func (w *Worker) waitForTasks() {
	timeout := time.After(5 * time.Minute)
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-timeout:
			w.logger.Warn("shutdown timeout, force stopping",
				"running_tasks", w.runningCount())
			return
		case <-ticker.C:
			count := w.runningCount()
			if count == 0 {
				w.logger.Info("all tasks completed, shutting down")
				return
			}
			w.logger.Debug("waiting for tasks to complete",
				"remaining", count)
		}
	}
}

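// runningCount returns the number of tasks currently holding a running slot.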
func (w *Worker) runningCount() int {
	w.runningMu.RLock()
	defer w.runningMu.RUnlock()
	return len(w.running)
}

// GetMetrics returns current worker metrics.
func (w *Worker) GetMetrics() map[string]any {
	stats := w.metrics.GetStats()
	stats["worker_id"] = w.id
	stats["max_workers"] = w.config.MaxWorkers
	return stats
}

// Stop gracefully shuts down the worker.
func (w *Worker) Stop() {
	w.cancel()
	w.waitForTasks()

	// Close connections, logging (rather than discarding) any errors.
	if err := w.server.Close(); err != nil {
		w.logger.Warn("error closing server connection", "error", err)
	}
	if err := w.queue.Close(); err != nil {
		w.logger.Warn("error closing queue connection", "error", err)
	}
	if w.metricsSrv != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := w.metricsSrv.Shutdown(ctx); err != nil {
			w.logger.Warn("metrics exporter shutdown error", "error", err)
		}
	}
	w.logger.Info("worker stopped", "worker_id", w.id)
}

// executeTaskWithLease runs a single leased task end to end: validation,
// provenance checks, resource acquisition, optional dataset fetch, checksum
// verification, execution with panic recovery, and finalization. Transient
// failures are requeued for retry; the lease is released once the task
// reaches a final state.
func (w *Worker) executeTaskWithLease(task *queue.Task) {
	defer w.releaseRunningSlot(task.ID)

	// Track task for graceful shutdown
	w.gracefulWait.Add(1)
	w.activeTasks.Store(task.ID, task)
	defer w.gracefulWait.Done()
	defer w.activeTasks.Delete(task.ID)

	// Create task-specific context with timeout
	taskCtx := logging.EnsureTrace(w.ctx)                // add trace + span if missing
	taskCtx = logging.CtxWithJob(taskCtx, task.JobName)  // add job metadata
	taskCtx = logging.CtxWithTask(taskCtx, task.ID)      // add task metadata

	taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
	defer taskCancel()

	logger := w.logger.Job(taskCtx, task.JobName, task.ID)
	logger.Info("starting task",
		"worker_id", w.id,
		"datasets", task.Datasets,
		"priority", task.Priority)

	// Record task start
	w.metrics.RecordTaskStart()
	defer w.metrics.RecordTaskCompletion()

	// Check for context cancellation
	select {
	case <-taskCtx.Done():
		logger.Info("task cancelled before execution")
		return
	default:
	}

	// Jupyter tasks are executed directly by the worker (no experiment/provenance pipeline).
	if isJupyterTask(task) {
		out, err := w.runJupyterTask(taskCtx, task)
		endTime := time.Now()
		task.EndedAt = &endTime
		if err != nil {
			logger.Error("jupyter task failed", "error", err)
			task.Status = "failed"
			task.Error = err.Error()
			_ = w.queue.UpdateTaskWithMetrics(task, "final")
			w.metrics.RecordTaskFailure()
			_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
			return
		}
		if len(out) > 0 {
			task.Output = string(out)
		}
		task.Status = "completed"
		_ = w.queue.UpdateTaskWithMetrics(task, "final")
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	// Parse datasets from task arguments
	task.Datasets = resolveDatasets(task)

	if err := w.validateTaskForExecution(taskCtx, task); err != nil {
		logger.Error("task validation failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Validation failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error("failed to update task status after validation failure", "error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	if err := w.enforceTaskProvenance(taskCtx, task); err != nil {
		logger.Error("provenance validation failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Provenance validation failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task status after provenance validation failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}

	lease, err := w.resources.Acquire(taskCtx, task)
	if err != nil {
		logger.Error("resource acquisition failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Resource acquisition failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task status after resource acquisition failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
		return
	}
	defer lease.Release()

	// Start heartbeat goroutine
	heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
	defer cancelHeartbeat()

	go w.heartbeatLoop(heartbeatCtx, task.ID)

	// Update task status
	task.Status = "running"
	now := time.Now()
	task.StartedAt = &now
	task.WorkerID = w.id

	// Phase 1 provenance capture: best-effort metadata enrichment before persisting the running state.
	w.recordTaskProvenance(taskCtx, task)

	if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
		logger.Error("failed to update task status", "error", err)
		w.metrics.RecordTaskFailure()
		return
	}

	if w.config.AutoFetchData && len(task.Datasets) > 0 {
		if err := w.fetchDatasets(taskCtx, task); err != nil {
			logger.Error("data fetch failed", "error", err)
			task.Status = "failed"
			task.Error = fmt.Sprintf("Data fetch failed: %v", err)
			endTime := time.Now()
			task.EndedAt = &endTime
			if updateErr := w.queue.UpdateTask(task); updateErr != nil {
				logger.Error("failed to update task status after data fetch failure", "error", updateErr)
			}
			w.metrics.RecordTaskFailure()
			return
		}
	}
	if err := w.verifyDatasetSpecs(taskCtx, task); err != nil {
		logger.Error("dataset checksum verification failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Dataset checksum verification failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task after dataset checksum verification failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		return
	}
	if err := w.verifySnapshot(taskCtx, task); err != nil {
		logger.Error("snapshot checksum verification failed", "error", err)
		task.Status = "failed"
		task.Error = fmt.Sprintf("Snapshot checksum verification failed: %v", err)
		endTime := time.Now()
		task.EndedAt = &endTime
		if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
			logger.Error(
				"failed to update task after snapshot checksum verification failure",
				"error", updateErr)
		}
		w.metrics.RecordTaskFailure()
		return
	}

	// Execute job with panic recovery
	var execErr error
	func() {
		defer func() {
			if r := recover(); r != nil {
				execErr = fmt.Errorf("panic during execution: %v", r)
			}
		}()
		cudaVisible := resources.FormatCUDAVisibleDevices(lease)
		execErr = w.runJob(taskCtx, task, cudaVisible)
	}()

	// Finalize task
	endTime := time.Now()
	task.EndedAt = &endTime

	if execErr != nil {
		task.Error = execErr.Error()

		// Check if transient error (network, timeout, etc)
		if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
			w.logger.Warn("task failed with transient error, will retry",
				"task_id", task.ID,
				"error", execErr,
				"retry_count", task.RetryCount)
			_ = w.queue.RetryTask(task)
		} else {
			task.Status = "failed"
			_ = w.queue.UpdateTaskWithMetrics(task, "final")
		}
	} else {
		task.Status = "completed"
		_ = w.queue.UpdateTaskWithMetrics(task, "final")
	}

	// Release lease
	_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
}

// Heartbeat loop to renew lease.
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
	ticker := time.NewTicker(w.config.HeartbeatInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
				w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
				return
			}
			// Also update worker heartbeat
			_ = w.queue.Heartbeat(w.config.WorkerID)
		}
	}
}

// Shutdown gracefully shuts down the worker.
func (w *Worker) Shutdown() error {
	w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())

	// Wait for active tasks with timeout
	done := make(chan struct{})
	go func() {
		w.gracefulWait.Wait()
		close(done)
	}()

	timeout := time.After(w.config.GracefulTimeout)
	select {
	case <-done:
		w.logger.Info("all tasks completed, shutdown successful")
	case <-timeout:
		w.logger.Warn("graceful shutdown timeout, releasing active leases")
		w.releaseAllLeases()
	}

	return w.queue.Close()
}

// Release all active leases.
func (w *Worker) releaseAllLeases() {
	w.activeTasks.Range(func(key, _ any) bool {
		taskID := key.(string)
		if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
			w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
		}
		return true
	})
}

// Helper functions.

// countActiveTasks returns the number of tasks currently tracked for
// graceful shutdown.
func (w *Worker) countActiveTasks() int {
	count := 0
	w.activeTasks.Range(func(_, _ any) bool {
		count++
		return true
	})
	return count
}

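// isTransientError reports whether err's message matches a known transient
// failure pattern (network, timeout, temporary resource exhaustion) and is
// therefore worth retrying.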
func isTransientError(err error) bool {
	if err == nil {
		return false
	}
	// Check if error is transient (network, timeout, resource unavailable, etc)
	errStr := err.Error()
	transientIndicators := []string{
		"connection refused",
		"timeout",
		"temporary failure",
		"resource temporarily unavailable",
		"no such host",
		"network unreachable",
	}
	for _, indicator := range transientIndicators {
		if strings.Contains(strings.ToLower(errStr), indicator) {
			return true
		}
	}
	return false
}