// Package main implements the ML task worker package main import ( "context" "fmt" "log" "log/slog" "net/http" "os" "os/exec" "os/signal" "path/filepath" "strings" "sync" "syscall" "time" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/errtypes" "github.com/jfraeys/fetch_ml/internal/fileutil" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/metrics" "github.com/jfraeys/fetch_ml/internal/network" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/telemetry" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/promhttp" ) // MLServer wraps network.SSHClient for backward compatibility. type MLServer struct { *network.SSHClient } // isValidName validates that input strings contain only safe characters. // isValidName checks if the input string is a valid name. func isValidName(input string) bool { return len(input) > 0 && len(input) < 256 } // NewMLServer creates a new ML server connection. // NewMLServer returns a new MLServer instance. func NewMLServer(cfg *Config) (*MLServer, error) { if cfg.LocalMode { return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil } client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts) if err != nil { return nil, err } return &MLServer{SSHClient: client}, nil } // Worker represents an ML task worker. 
type Worker struct {
	id     string
	config *Config
	server *MLServer
	queue  *queue.TaskQueue

	// running maps task ID -> cancel func so in-flight tasks can be
	// cancelled for graceful shutdown; guarded by runningMu.
	running   map[string]context.CancelFunc
	runningMu sync.RWMutex

	// ctx/cancel bound the lifetime of the polling loop and heartbeats.
	ctx    context.Context
	cancel context.CancelFunc

	logger     *logging.Logger
	metrics    *metrics.Metrics
	metricsSrv *http.Server

	// datasetCache maps dataset name -> expiry time of the local copy;
	// guarded by datasetCacheMu.
	datasetCache    map[string]time.Time
	datasetCacheMu  sync.RWMutex
	datasetCacheTTL time.Duration

	// Graceful shutdown fields.
	shutdownCh   chan struct{}
	activeTasks  sync.Map // map[string]*queue.Task - track active tasks
	gracefulWait sync.WaitGroup
}

// setupMetricsExporter registers per-worker collectors on a private
// Prometheus registry and starts an HTTP /metrics endpoint, but only when
// metrics are enabled in the config.
//
// NOTE(review): the cumulative *_total values below are exported via
// GaugeFunc rather than CounterFunc, so Prometheus sees them typed as
// gauges — confirm downstream dashboards/alerts expect this before changing
// the metric names or types (they are an external contract).
func (w *Worker) setupMetricsExporter() {
	if !w.config.Metrics.Enabled {
		return
	}
	// Private registry: only this worker's collectors are exposed.
	reg := prometheus.NewRegistry()
	reg.MustRegister(
		collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
		collectors.NewGoCollector(),
	)
	// Every metric carries this worker's ID as a constant label.
	labels := prometheus.Labels{"worker_id": w.id}
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_processed_total",
		Help:        "Total tasks processed successfully by this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.TasksProcessed.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_failed_total",
		Help:        "Total tasks failed by this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.TasksFailed.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_active",
		Help:        "Number of tasks currently running on this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.runningCount()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_queued",
		Help:        "Latest observed queue depth from Redis.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.QueuedTasks.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_data_transferred_bytes_total",
		Help:        "Total bytes transferred while fetching datasets.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.DataTransferred.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_data_fetch_time_seconds_total",
		Help:        "Total time spent fetching datasets (seconds).",
		ConstLabels: labels,
		// Stored as nanoseconds internally; convert to seconds for export.
	}, func() float64 { return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_execution_time_seconds_total",
		Help:        "Total execution time for completed tasks (seconds).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_worker_max_concurrency",
		Help:        "Configured maximum concurrent tasks for this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.config.MaxWorkers) }))

	mux := http.NewServeMux()
	mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
	srv := &http.Server{
		Addr:              w.config.Metrics.ListenAddr,
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
	}
	w.metricsSrv = srv
	// Serve in the background; Stop() shuts this down via w.metricsSrv.
	go func() {
		w.logger.Info("metrics exporter listening", "addr", w.config.Metrics.ListenAddr, "enabled", w.config.Metrics.Enabled)
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			w.logger.Warn("metrics exporter stopped", "error", err)
		}
	}()
}

// NewWorker creates a new worker instance.
func NewWorker(cfg *Config, _ string) (*Worker, error) { srv, err := NewMLServer(cfg) if err != nil { return nil, err } queueCfg := queue.Config{ RedisAddr: cfg.RedisAddr, RedisPassword: cfg.RedisPassword, RedisDB: cfg.RedisDB, MetricsFlushInterval: cfg.MetricsFlushInterval, } queue, err := queue.NewTaskQueue(queueCfg) if err != nil { return nil, err } // Create data_dir if it doesn't exist (for production without NAS) if cfg.DataDir != "" { if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil { log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err) } } ctx, cancel := context.WithCancel(context.Background()) ctx = logging.EnsureTrace(ctx) ctx = logging.CtxWithWorker(ctx, cfg.WorkerID) baseLogger := logging.NewLogger(slog.LevelInfo, false) logger := baseLogger.Component(ctx, "worker") metrics := &metrics.Metrics{} worker := &Worker{ id: cfg.WorkerID, config: cfg, server: srv, queue: queue, running: make(map[string]context.CancelFunc), datasetCache: make(map[string]time.Time), datasetCacheTTL: cfg.DatasetCacheTTL, ctx: ctx, cancel: cancel, logger: logger, metrics: metrics, shutdownCh: make(chan struct{}), } worker.setupMetricsExporter() return worker, nil } // Start starts the worker's main processing loop. 
func (w *Worker) Start() { w.logger.Info("worker started", "worker_id", w.id, "max_concurrent", w.config.MaxWorkers, "poll_interval", w.config.PollInterval) go w.heartbeat() for { select { case <-w.ctx.Done(): w.logger.Info("shutdown signal received, waiting for tasks") w.waitForTasks() return default: } if w.runningCount() >= w.config.MaxWorkers { time.Sleep(50 * time.Millisecond) continue } queueStart := time.Now() blockTimeout := time.Duration(w.config.PollInterval) * time.Second task, err := w.queue.GetNextTaskWithLeaseBlocking(w.config.WorkerID, w.config.TaskLeaseDuration, blockTimeout) queueLatency := time.Since(queueStart) if err != nil { if err == context.DeadlineExceeded { continue } w.logger.Error("error fetching task", "worker_id", w.id, "error", err) continue } if task == nil { if queueLatency > 200*time.Millisecond { w.logger.Debug("queue poll latency", "latency_ms", queueLatency.Milliseconds()) } continue } if depth, derr := w.queue.QueueDepth(); derr == nil { if queueLatency > 100*time.Millisecond || depth > 0 { w.logger.Debug("queue fetch metrics", "latency_ms", queueLatency.Milliseconds(), "remaining_depth", depth) } } else if queueLatency > 100*time.Millisecond { w.logger.Debug("queue fetch metrics", "latency_ms", queueLatency.Milliseconds(), "depth_error", derr) } go w.executeTaskWithLease(task) } } func (w *Worker) heartbeat() { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case <-w.ctx.Done(): return case <-ticker.C: if err := w.queue.Heartbeat(w.id); err != nil { w.logger.Warn("heartbeat failed", "worker_id", w.id, "error", err) } } } } // NEW: Fetch datasets using data_manager. 
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error { logger := w.logger.Job(ctx, task.JobName, task.ID) logger.Info("fetching datasets", "worker_id", w.id, "dataset_count", len(task.Datasets)) for _, dataset := range task.Datasets { if w.datasetIsFresh(dataset) { logger.Debug("skipping cached dataset", "dataset", dataset) continue } // Check for cancellation before each dataset fetch select { case <-w.ctx.Done(): return fmt.Errorf("dataset fetch cancelled: %w", w.ctx.Err()) default: } logger.Info("fetching dataset", "worker_id", w.id, "dataset", dataset) // Create command with context for cancellation support cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute) // Validate inputs to prevent command injection if !isValidName(task.JobName) || !isValidName(dataset) { cancel() return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters") } //nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated cmd := exec.CommandContext(cmdCtx, w.config.DataManagerPath, "fetch", task.JobName, dataset, ) output, err := cmd.CombinedOutput() cancel() // Clean up context if err != nil { return &errtypes.DataFetchError{ Dataset: dataset, JobName: task.JobName, Err: fmt.Errorf("command failed: %w, output: %s", err, output), } } logger.Info("dataset ready", "worker_id", w.id, "dataset", dataset) w.markDatasetFetched(dataset) } return nil } func (w *Worker) runJob(ctx context.Context, task *queue.Task) error { // Validate job name to prevent path traversal if err := container.ValidateJobName(task.JobName); err != nil { return &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "validation", Err: err, } } jobDir, outputDir, logFile, err := w.setupJobDirectories(task) if err != nil { return err } return w.executeJob(ctx, task, jobDir, outputDir, logFile) } func (w *Worker) setupJobDirectories(task *queue.Task) (jobDir, outputDir, logFile string, err error) { jobPaths := 
config.NewJobPaths(w.config.BasePath) pendingDir := jobPaths.PendingPath() jobDir = filepath.Join(pendingDir, task.JobName) outputDir = filepath.Join(jobPaths.RunningPath(), task.JobName) logFile = filepath.Join(outputDir, "output.log") // Create pending directory if err := os.MkdirAll(pendingDir, 0750); err != nil { return "", "", "", &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "setup", Err: fmt.Errorf("failed to create pending dir: %w", err), } } // Create job directory in pending if err := os.MkdirAll(jobDir, 0750); err != nil { return "", "", "", &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "setup", Err: fmt.Errorf("failed to create job dir: %w", err), } } // Sanitize paths jobDir, err = container.SanitizePath(jobDir) if err != nil { return "", "", "", &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "validation", Err: err, } } outputDir, err = container.SanitizePath(outputDir) if err != nil { return "", "", "", &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "validation", Err: err, } } return jobDir, outputDir, logFile, nil } func (w *Worker) executeJob(ctx context.Context, task *queue.Task, jobDir, outputDir, logFile string) error { // Create output directory if _, err := telemetry.ExecWithMetrics(w.logger, "create output dir", 100*time.Millisecond, func() (string, error) { if err := os.MkdirAll(outputDir, 0750); err != nil { return "", fmt.Errorf("mkdir failed: %w", err) } return "", nil }); err != nil { return &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "setup", Err: fmt.Errorf("failed to create output dir: %w", err), } } // Move job from pending to running stagingStart := time.Now() if _, err := telemetry.ExecWithMetrics(w.logger, "stage job", 100*time.Millisecond, func() (string, error) { // Remove existing directory if it exists if _, err := os.Stat(outputDir); err == nil { if err := 
os.RemoveAll(outputDir); err != nil { return "", fmt.Errorf("remove existing failed: %w", err) } } if err := os.Rename(jobDir, outputDir); err != nil { return "", fmt.Errorf("rename failed: %w", err) } return "", nil }); err != nil { return &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "setup", Err: fmt.Errorf("failed to move job: %w", err), } } stagingDuration := time.Since(stagingStart) // Execute job if w.config.LocalMode { return w.executeLocalJob(ctx, task, outputDir, logFile) } return w.executeContainerJob(ctx, task, outputDir, logFile, stagingDuration) } func (w *Worker) executeLocalJob(ctx context.Context, task *queue.Task, outputDir, logFile string) error { // Create experiment script scriptContent := `#!/bin/bash set -e echo "Starting experiment: ` + task.JobName + `" echo "Task ID: ` + task.ID + `" echo "Timestamp: $(date)" # Simulate ML experiment echo "Loading data..." sleep 1 echo "Training model..." sleep 2 echo "Evaluating model..." sleep 1 # Generate results ACCURACY=0.95 LOSS=0.05 EPOCHS=10 echo "" echo "=== EXPERIMENT RESULTS ===" echo "Accuracy: $ACCURACY" echo "Loss: $LOSS" echo "Epochs: $EPOCHS" echo "Status: SUCCESS" echo "=========================" echo "Experiment completed successfully!" 
` scriptPath := filepath.Join(outputDir, "run.sh") if err := os.WriteFile(scriptPath, []byte(scriptContent), 0600); err != nil { return &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "execution", Err: fmt.Errorf("failed to write script: %w", err), } } logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) if err != nil { w.logger.Warn("failed to open log file for local output", "path", logFile, "error", err) return &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "execution", Err: fmt.Errorf("failed to open log file: %w", err), } } defer func() { if err := logFileHandle.Close(); err != nil { log.Printf("Warning: failed to close log file: %v", err) } }() // Execute the script directly localCmd := exec.CommandContext(ctx, "bash", scriptPath) localCmd.Stdout = logFileHandle localCmd.Stderr = logFileHandle w.logger.Info("executing local job", "job", task.JobName, "task_id", task.ID, "script", scriptPath) if err := localCmd.Run(); err != nil { return &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "execution", Err: fmt.Errorf("execution failed: %w", err), } } return nil } func (w *Worker) executeContainerJob( ctx context.Context, task *queue.Task, outputDir, logFile string, stagingDuration time.Duration, ) error { containerResults := w.config.ContainerResults if containerResults == "" { containerResults = config.DefaultContainerResults } containerWorkspace := w.config.ContainerWorkspace if containerWorkspace == "" { containerWorkspace = config.DefaultContainerWorkspace } jobPaths := config.NewJobPaths(w.config.BasePath) stagingStart := time.Now() podmanCfg := container.PodmanConfig{ Image: w.config.PodmanImage, Workspace: filepath.Join(outputDir, "code"), Results: filepath.Join(outputDir, "results"), ContainerWorkspace: containerWorkspace, ContainerResults: containerResults, GPUAccess: w.config.GPUAccess, } scriptPath := 
filepath.Join(containerWorkspace, w.config.TrainScript) requirementsPath := filepath.Join(containerWorkspace, "requirements.txt") var extraArgs []string if task.Args != "" { extraArgs = strings.Fields(task.Args) } ioBefore, ioErr := telemetry.ReadProcessIO() podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, scriptPath, requirementsPath, extraArgs) logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) if err == nil { podmanCmd.Stdout = logFileHandle podmanCmd.Stderr = logFileHandle } else { w.logger.Warn("failed to open log file for podman output", "path", logFile, "error", err) } w.logger.Info("executing podman job", "job", task.JobName, "image", w.config.PodmanImage, "workspace", podmanCfg.Workspace, "results", podmanCfg.Results) containerStart := time.Now() if err := podmanCmd.Run(); err != nil { containerDuration := time.Since(containerStart) // Move job to failed directory failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName) if _, moveErr := telemetry.ExecWithMetrics(w.logger, "move failed job", 100*time.Millisecond, func() (string, error) { if err := os.Rename(outputDir, failedDir); err != nil { return "", fmt.Errorf("rename to failed failed: %w", err) } return "", nil }); moveErr != nil { w.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", moveErr) } if ioErr == nil { if after, err := telemetry.ReadProcessIO(); err == nil { delta := telemetry.DiffIO(ioBefore, after) w.logger.Debug("worker io stats", "job", task.JobName, "read_bytes", delta.ReadBytes, "write_bytes", delta.WriteBytes) } } w.logger.Info("job timing (failure)", "job", task.JobName, "staging_ms", stagingDuration.Milliseconds(), "container_ms", containerDuration.Milliseconds(), "finalize_ms", 0, "total_ms", time.Since(stagingStart).Milliseconds(), ) return fmt.Errorf("execution failed: %w", err) } containerDuration := time.Since(containerStart) finalizeStart := time.Now() // Move job to finished directory 
finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName) if _, moveErr := telemetry.ExecWithMetrics(w.logger, "finalize job", 100*time.Millisecond, func() (string, error) { if err := os.Rename(outputDir, finishedDir); err != nil { return "", fmt.Errorf("rename to finished failed: %w", err) } return "", nil }); moveErr != nil { w.logger.Warn("failed to move job to finished dir", "job", task.JobName, "error", moveErr) } finalizeDuration := time.Since(finalizeStart) totalDuration := time.Since(stagingStart) var ioDelta telemetry.IOStats if ioErr == nil { if after, err := telemetry.ReadProcessIO(); err == nil { ioDelta = telemetry.DiffIO(ioBefore, after) } } w.logger.Info("job timing", "job", task.JobName, "staging_ms", stagingDuration.Milliseconds(), "container_ms", containerDuration.Milliseconds(), "finalize_ms", finalizeDuration.Milliseconds(), "total_ms", totalDuration.Milliseconds(), "io_read_bytes", ioDelta.ReadBytes, "io_write_bytes", ioDelta.WriteBytes, ) return nil } func parseDatasets(args string) []string { if !strings.Contains(args, "--datasets") { return nil } parts := strings.Fields(args) for i, part := range parts { if part == "--datasets" && i+1 < len(parts) { return strings.Split(parts[i+1], ",") } } return nil } func (w *Worker) waitForTasks() { timeout := time.After(5 * time.Minute) ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { select { case <-timeout: w.logger.Warn("shutdown timeout, force stopping", "running_tasks", len(w.running)) return case <-ticker.C: count := w.runningCount() if count == 0 { w.logger.Info("all tasks completed, shutting down") return } w.logger.Debug("waiting for tasks to complete", "remaining", count) } } } func (w *Worker) runningCount() int { w.runningMu.RLock() defer w.runningMu.RUnlock() return len(w.running) } func (w *Worker) datasetIsFresh(dataset string) bool { w.datasetCacheMu.RLock() defer w.datasetCacheMu.RUnlock() expires, ok := w.datasetCache[dataset] return ok && 
time.Now().Before(expires) } func (w *Worker) markDatasetFetched(dataset string) { expires := time.Now().Add(w.datasetCacheTTL) w.datasetCacheMu.Lock() w.datasetCache[dataset] = expires w.datasetCacheMu.Unlock() } // GetMetrics returns current worker metrics. func (w *Worker) GetMetrics() map[string]any { stats := w.metrics.GetStats() stats["worker_id"] = w.id stats["max_workers"] = w.config.MaxWorkers return stats } // Stop gracefully shuts down the worker. func (w *Worker) Stop() { w.cancel() w.waitForTasks() // FIXED: Check error return values if err := w.server.Close(); err != nil { w.logger.Warn("error closing server connection", "error", err) } if err := w.queue.Close(); err != nil { w.logger.Warn("error closing queue connection", "error", err) } if w.metricsSrv != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := w.metricsSrv.Shutdown(ctx); err != nil { w.logger.Warn("metrics exporter shutdown error", "error", err) } } w.logger.Info("worker stopped", "worker_id", w.id) } // Execute task with lease management and retry. 
func (w *Worker) executeTaskWithLease(task *queue.Task) { // Track task for graceful shutdown w.gracefulWait.Add(1) w.activeTasks.Store(task.ID, task) defer w.gracefulWait.Done() defer w.activeTasks.Delete(task.ID) // Create task-specific context with timeout taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour) defer taskCancel() logger := w.logger.Job(taskCtx, task.JobName, task.ID) logger.Info("starting task", "worker_id", w.id, "datasets", task.Datasets, "priority", task.Priority) // Record task start w.metrics.RecordTaskStart() defer w.metrics.RecordTaskCompletion() // Check for context cancellation select { case <-taskCtx.Done(): logger.Info("task cancelled before execution") return default: } // Parse datasets from task arguments if task.Datasets == nil { task.Datasets = parseDatasets(task.Args) } // Start heartbeat goroutine heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background()) defer cancelHeartbeat() go w.heartbeatLoop(heartbeatCtx, task.ID) // Update task status task.Status = "running" now := time.Now() task.StartedAt = &now task.WorkerID = w.id if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil { logger.Error("failed to update task status", "error", err) w.metrics.RecordTaskFailure() return } if w.config.AutoFetchData && len(task.Datasets) > 0 { if err := w.fetchDatasets(taskCtx, task); err != nil { logger.Error("data fetch failed", "error", err) task.Status = "failed" task.Error = fmt.Sprintf("Data fetch failed: %v", err) endTime := time.Now() task.EndedAt = &endTime err := w.queue.UpdateTask(task) if err != nil { logger.Error("failed to update task status after data fetch failure", "error", err) } w.metrics.RecordTaskFailure() return } } // Execute job with panic recovery var execErr error func() { 
defer func() { if r := recover(); r != nil { execErr = fmt.Errorf("panic during execution: %v", r) } }() execErr = w.runJob(taskCtx, task) }() // Finalize task endTime := time.Now() task.EndedAt = &endTime if execErr != nil { task.Error = execErr.Error() // Check if transient error (network, timeout, etc) if isTransientError(execErr) && task.RetryCount < task.MaxRetries { w.logger.Warn("task failed with transient error, will retry", "task_id", task.ID, "error", execErr, "retry_count", task.RetryCount) _ = w.queue.RetryTask(task) } else { task.Status = "failed" _ = w.queue.UpdateTaskWithMetrics(task, "final") } } else { task.Status = "completed" // Read output file for completed tasks jobPaths := config.NewJobPaths(w.config.BasePath) outputDir := filepath.Join(jobPaths.RunningPath(), task.JobName) logFile := filepath.Join(outputDir, "output.log") if outputBytes, err := os.ReadFile(logFile); err == nil { task.Output = string(outputBytes) } _ = w.queue.UpdateTaskWithMetrics(task, "final") } // Release lease _ = w.queue.ReleaseLease(task.ID, w.config.WorkerID) } // Heartbeat loop to renew lease. func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) { ticker := time.NewTicker(w.config.HeartbeatInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil { w.logger.Error("failed to renew lease", "task_id", taskID, "error", err) return } // Also update worker heartbeat _ = w.queue.Heartbeat(w.config.WorkerID) } } } // Shutdown gracefully shuts down the worker. 
func (w *Worker) Shutdown() error { w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks()) // Wait for active tasks with timeout done := make(chan struct{}) go func() { w.gracefulWait.Wait() close(done) }() timeout := time.After(w.config.GracefulTimeout) select { case <-done: w.logger.Info("all tasks completed, shutdown successful") case <-timeout: w.logger.Warn("graceful shutdown timeout, releasing active leases") w.releaseAllLeases() } return w.queue.Close() } // Release all active leases. func (w *Worker) releaseAllLeases() { w.activeTasks.Range(func(key, _ interface{}) bool { taskID := key.(string) if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil { w.logger.Error("failed to release lease", "task_id", taskID, "error", err) } return true }) } // Helper functions. func (w *Worker) countActiveTasks() int { count := 0 w.activeTasks.Range(func(_, _ interface{}) bool { count++ return true }) return count } func isTransientError(err error) bool { if err == nil { return false } // Check if error is transient (network, timeout, resource unavailable, etc) errStr := err.Error() transientIndicators := []string{ "connection refused", "timeout", "temporary failure", "resource temporarily unavailable", "no such host", "network unreachable", } for _, indicator := range transientIndicators { if strings.Contains(strings.ToLower(errStr), indicator) { return true } } return false } func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) // Parse authentication flags authFlags := auth.ParseAuthFlags() if err := auth.ValidateFlags(authFlags); err != nil { log.Fatalf("Authentication flag error: %v", err) } // Get API key from various sources apiKey := auth.GetAPIKeyFromSources(authFlags) // Load configuration configPath := "config-local.yaml" if authFlags.ConfigFile != "" { configPath = authFlags.ConfigFile } resolvedConfig, err := config.ResolveConfigPath(configPath) if err != nil { log.Fatalf("%v", err) } cfg, err := 
LoadConfig(resolvedConfig) if err != nil { log.Fatalf("Failed to load config: %v", err) } // Validate authentication configuration if err := cfg.Auth.ValidateAuthConfig(); err != nil { log.Fatalf("Invalid authentication configuration: %v", err) } // Validate configuration if err := cfg.Validate(); err != nil { log.Fatalf("Invalid configuration: %v", err) } // Test authentication if enabled if cfg.Auth.Enabled && apiKey != "" { user, err := cfg.Auth.ValidateAPIKey(apiKey) if err != nil { log.Fatalf("Authentication failed: %v", err) } log.Printf("Worker authenticated as user: %s (admin: %v)", user.Name, user.Admin) } else if cfg.Auth.Enabled { log.Fatal("Authentication required but no API key provided") } worker, err := NewWorker(cfg, apiKey) if err != nil { log.Fatalf("Failed to create worker: %v", err) } sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) go worker.Start() sig := <-sigChan log.Printf("Received signal: %v", sig) // Use graceful shutdown if err := worker.Shutdown(); err != nil { log.Printf("Graceful shutdown error: %v", err) worker.Stop() // Fallback to force stop } else { log.Println("Worker shut down gracefully") } }