// Package main implements the ML task worker.
package main

import (
	"context"
	"fmt"
	"log"
	"log/slog"
	"net/http"
	"os"
	"os/exec"
	"os/signal"
	"path/filepath"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/errors"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/metrics"
	"github.com/jfraeys/fetch_ml/internal/network"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/telemetry"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/collectors"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// MLServer wraps network.SSHClient for backward compatibility.
type MLServer struct {
	*network.SSHClient
}

// NewMLServer opens an SSH connection to the configured ML server.
func NewMLServer(cfg *Config) (*MLServer, error) {
	client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
	if err != nil {
		return nil, err
	}
	return &MLServer{SSHClient: client}, nil
}

// Worker pulls tasks from the queue and executes them on the ML server.
type Worker struct {
	id         string
	config     *Config
	server     *MLServer
	queue      *queue.TaskQueue
	running    map[string]context.CancelFunc // cancellation functions for graceful shutdown
	runningMu  sync.RWMutex
	ctx        context.Context
	cancel     context.CancelFunc
	logger     *logging.Logger
	metrics    *metrics.Metrics
	metricsSrv *http.Server

	datasetCache    map[string]time.Time
	datasetCacheMu  sync.RWMutex
	datasetCacheTTL time.Duration

	// Graceful shutdown fields.
	shutdownCh   chan struct{}
	activeTasks  sync.Map // map[string]*queue.Task - track active tasks
	gracefulWait sync.WaitGroup
}

// setupMetricsExporter registers Prometheus collectors and starts the /metrics endpoint.
func (w *Worker) setupMetricsExporter() error {
	if !w.config.Metrics.Enabled {
		return nil
	}

	reg := prometheus.NewRegistry()
	reg.MustRegister(
		collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
		collectors.NewGoCollector(),
	)

	labels := prometheus.Labels{"worker_id": w.id}

	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_processed_total",
		Help:        "Total tasks processed successfully by this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.TasksProcessed.Load()) }))

	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_failed_total",
		Help:        "Total tasks failed by this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.TasksFailed.Load()) }))

	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_active",
		Help:        "Number of tasks currently running on this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.runningCount()) }))

	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_queued",
		Help:        "Latest observed queue depth from Redis.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.QueuedTasks.Load()) }))

	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_data_transferred_bytes_total",
		Help:        "Total bytes transferred while fetching datasets.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.DataTransferred.Load()) }))

	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_data_fetch_time_seconds_total",
		Help:        "Total time spent fetching datasets (seconds).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second) }))

	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_execution_time_seconds_total",
		Help:        "Total execution time for completed tasks (seconds).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second) }))

	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_worker_max_concurrency",
		Help:        "Configured maximum concurrent tasks for this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.config.MaxWorkers) }))

	mux := http.NewServeMux()
	mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))

	srv := &http.Server{
		Addr:              w.config.Metrics.ListenAddr,
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
	}
	w.metricsSrv = srv

	go func() {
		w.logger.Info("metrics exporter listening",
			"addr", w.config.Metrics.ListenAddr,
			"enabled", w.config.Metrics.Enabled)
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			w.logger.Warn("metrics exporter stopped", "error", err)
		}
	}()

	return nil
}

// NewWorker builds a Worker from configuration and connects to the queue and ML server.
func NewWorker(cfg *Config, apiKey string) (*Worker, error) {
	srv, err := NewMLServer(cfg)
	if err != nil {
		return nil, err
	}

	queueCfg := queue.Config{
		RedisAddr:            cfg.RedisAddr,
		RedisPassword:        cfg.RedisPassword,
		RedisDB:              cfg.RedisDB,
		MetricsFlushInterval: cfg.MetricsFlushInterval,
	}
	taskQueue, err := queue.NewTaskQueue(queueCfg)
	if err != nil {
		return nil, err
	}

	// Create data_dir if it doesn't exist (for production without NAS).
	// Quote the remote path so spaces don't break the command.
	if cfg.DataDir != "" {
		if _, err := srv.Exec(fmt.Sprintf("mkdir -p %q", cfg.DataDir)); err != nil {
			log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	ctx = logging.EnsureTrace(ctx)
	ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)

	baseLogger := logging.NewLogger(slog.LevelInfo, false)
	logger := baseLogger.Component(ctx, "worker")

	worker := &Worker{
		id:              cfg.WorkerID,
		config:          cfg,
		server:          srv,
		queue:           taskQueue,
		running:         make(map[string]context.CancelFunc),
		datasetCache:    make(map[string]time.Time),
		datasetCacheTTL: cfg.DatasetCacheTTL,
		ctx:             ctx,
		cancel:          cancel,
		logger:          logger,
		metrics:         &metrics.Metrics{},
		shutdownCh:      make(chan struct{}),
	}

	if err := worker.setupMetricsExporter(); err != nil {
		return nil, err
	}
	return worker, nil
}

// Start runs the polling loop until the worker context is cancelled.
func (w *Worker) Start() {
	w.logger.Info("worker started",
		"worker_id", w.id,
		"max_concurrent", w.config.MaxWorkers,
		"poll_interval", w.config.PollInterval)

	go w.heartbeat()

	for {
		select {
		case <-w.ctx.Done():
			w.logger.Info("shutdown signal received, waiting for tasks")
			w.waitForTasks()
			return
		default:
		}

		if w.runningCount() >= w.config.MaxWorkers {
			time.Sleep(50 * time.Millisecond)
			continue
		}

		queueStart := time.Now()
		task, err := w.queue.GetNextTaskWithLease(w.config.WorkerID, w.config.TaskLeaseDuration)
		queueLatency := time.Since(queueStart)
		if err != nil {
			if err == context.DeadlineExceeded {
				continue
			}
			w.logger.Error("error fetching task", "worker_id", w.id, "error", err)
			continue
		}
		if task == nil {
			if queueLatency > 200*time.Millisecond {
				w.logger.Debug("queue poll latency", "latency_ms", queueLatency.Milliseconds())
			}
			continue
		}

		if depth, derr := w.queue.QueueDepth(); derr == nil {
			if queueLatency > 100*time.Millisecond || depth > 0 {
				w.logger.Debug("queue fetch metrics",
					"latency_ms", queueLatency.Milliseconds(),
					"remaining_depth", depth)
			}
		} else if queueLatency > 100*time.Millisecond {
			w.logger.Debug("queue fetch metrics",
				"latency_ms", queueLatency.Milliseconds(),
				"depth_error", derr)
		}

		go w.executeTaskWithLease(task)
	}
}
// heartbeat periodically reports worker liveness to the queue.
func (w *Worker) heartbeat() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-w.ctx.Done():
			return
		case <-ticker.C:
			if err := w.queue.Heartbeat(w.id); err != nil {
				w.logger.Warn("heartbeat failed", "worker_id", w.id, "error", err)
			}
		}
	}
}

// fetchDatasets fetches a task's datasets using the external data_manager tool.
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
	logger := w.logger.Job(ctx, task.JobName, task.ID)
	logger.Info("fetching datasets", "worker_id", w.id, "dataset_count", len(task.Datasets))

	for _, dataset := range task.Datasets {
		if w.datasetIsFresh(dataset) {
			logger.Debug("skipping cached dataset", "dataset", dataset)
			continue
		}

		// Check for cancellation before each dataset fetch.
		select {
		case <-ctx.Done():
			return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err())
		default:
		}

		logger.Info("fetching dataset", "worker_id", w.id, "dataset", dataset)

		// Run data_manager with a timeout so a stuck fetch cannot block the task forever.
		cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
		cmd := exec.CommandContext(cmdCtx, w.config.DataManagerPath,
			"fetch", task.JobName, dataset,
		)
		output, err := cmd.CombinedOutput()
		cancel() // release the timeout context

		if err != nil {
			return &errors.DataFetchError{
				Dataset: dataset,
				JobName: task.JobName,
				Err:     fmt.Errorf("command failed: %w, output: %s", err, output),
			}
		}

		logger.Info("dataset ready", "worker_id", w.id, "dataset", dataset)
		w.markDatasetFetched(dataset)
	}
	return nil
}

// runJob stages a job, executes it in a Podman container, and moves it to its final directory.
func (w *Worker) runJob(task *queue.Task) error {
	// Validate job name to prevent path traversal.
	if err := container.ValidateJobName(task.JobName); err != nil {
		return &errors.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "validation",
			Err:     err,
		}
	}

	jobPaths := config.NewJobPaths(w.config.BasePath)
	jobDir := filepath.Join(jobPaths.PendingPath(), task.JobName)
	outputDir := filepath.Join(jobPaths.RunningPath(), task.JobName)
	logFile := filepath.Join(outputDir, "output.log")

	// Sanitize paths.
	jobDir, err := container.SanitizePath(jobDir)
	if err != nil {
		return &errors.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "validation",
			Err:     err,
		}
	}
	outputDir, err = container.SanitizePath(outputDir)
	if err != nil {
		return &errors.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "validation",
			Err:     err,
		}
	}

	// Create output directory.
	if _, err := telemetry.ExecWithMetrics(w.logger, "create output dir", 100*time.Millisecond, func() (string, error) {
		if err := os.MkdirAll(outputDir, 0o755); err != nil {
			return "", fmt.Errorf("mkdir failed: %w", err)
		}
		return "", nil
	}); err != nil {
		return &errors.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "setup",
			Err:     fmt.Errorf("failed to create output dir: %w", err),
		}
	}

	// Move job from pending to running.
	stagingStart := time.Now()
	if _, err := telemetry.ExecWithMetrics(w.logger, "stage job", 100*time.Millisecond, func() (string, error) {
		if err := os.Rename(jobDir, outputDir); err != nil {
			return "", fmt.Errorf("rename failed: %w", err)
		}
		return "", nil
	}); err != nil {
		return &errors.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "setup",
			Err:     fmt.Errorf("failed to move job: %w", err),
		}
	}
	stagingDuration := time.Since(stagingStart)

	if w.config.PodmanImage == "" {
		return &errors.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "validation",
			Err:     fmt.Errorf("podman_image must be configured"),
		}
	}

	containerWorkspace := w.config.ContainerWorkspace
	if containerWorkspace == "" {
		containerWorkspace = config.DefaultContainerWorkspace
	}
	containerResults := w.config.ContainerResults
	if containerResults == "" {
		containerResults = config.DefaultContainerResults
	}

	podmanCfg := container.PodmanConfig{
		Image:              w.config.PodmanImage,
		Workspace:          filepath.Join(outputDir, "code"),
		Results:            filepath.Join(outputDir, "results"),
		ContainerWorkspace: containerWorkspace,
		ContainerResults:   containerResults,
		GPUAccess:          w.config.GPUAccess,
	}

	scriptPath := filepath.Join(containerWorkspace, w.config.TrainScript)
	requirementsPath := filepath.Join(containerWorkspace, "requirements.txt")

	var extraArgs []string
	if task.Args != "" {
		extraArgs = strings.Fields(task.Args)
	}

	ioBefore, ioErr := telemetry.ReadProcessIO()

	podmanCmd := container.BuildPodmanCommand(podmanCfg, scriptPath, requirementsPath, extraArgs)

	logFileHandle, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
	if err == nil {
		defer logFileHandle.Close() // ensure the container log is closed when runJob returns
		podmanCmd.Stdout = logFileHandle
		podmanCmd.Stderr = logFileHandle
	} else {
		w.logger.Warn("failed to open log file for podman output", "path", logFile, "error", err)
	}

	w.logger.Info("executing podman job",
		"job", task.JobName,
		"image", w.config.PodmanImage,
		"workspace", podmanCfg.Workspace,
		"results", podmanCfg.Results)

	containerStart := time.Now()
	if err := podmanCmd.Run(); err != nil {
		containerDuration := time.Since(containerStart)

		// Move job to failed directory.
		failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName)
		if _, moveErr := telemetry.ExecWithMetrics(w.logger, "move failed job", 100*time.Millisecond, func() (string, error) {
			if err := os.Rename(outputDir, failedDir); err != nil {
				return "", fmt.Errorf("rename to failed dir failed: %w", err)
			}
			return "", nil
		}); moveErr != nil {
			w.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", moveErr)
		}

		if ioErr == nil {
			if after, err := telemetry.ReadProcessIO(); err == nil {
				delta := telemetry.DiffIO(ioBefore, after)
				w.logger.Debug("worker io stats",
					"job", task.JobName,
					"read_bytes", delta.ReadBytes,
					"write_bytes", delta.WriteBytes)
			}
		}

		w.logger.Info("job timing (failure)",
			"job", task.JobName,
			"staging_ms", stagingDuration.Milliseconds(),
			"container_ms", containerDuration.Milliseconds(),
			"finalize_ms", 0,
			"total_ms", time.Since(stagingStart).Milliseconds(),
		)
		return fmt.Errorf("execution failed: %w", err)
	}
	containerDuration := time.Since(containerStart)

	finalizeStart := time.Now()

	// Move job to finished directory.
	finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName)
	if _, moveErr := telemetry.ExecWithMetrics(w.logger, "finalize job", 100*time.Millisecond, func() (string, error) {
		if err := os.Rename(outputDir, finishedDir); err != nil {
			return "", fmt.Errorf("rename to finished dir failed: %w", err)
		}
		return "", nil
	}); moveErr != nil {
		w.logger.Warn("failed to move job to finished dir", "job", task.JobName, "error", moveErr)
	}
	finalizeDuration := time.Since(finalizeStart)
	totalDuration := time.Since(stagingStart)

	var ioDelta telemetry.IOStats
	if ioErr == nil {
		if after, err := telemetry.ReadProcessIO(); err == nil {
			ioDelta = telemetry.DiffIO(ioBefore, after)
		}
	}

	w.logger.Info("job timing",
		"job", task.JobName,
		"staging_ms", stagingDuration.Milliseconds(),
		"container_ms", containerDuration.Milliseconds(),
		"finalize_ms", finalizeDuration.Milliseconds(),
		"total_ms", totalDuration.Milliseconds(),
		"io_read_bytes", ioDelta.ReadBytes,
		"io_write_bytes", ioDelta.WriteBytes,
	)
	return nil
}
// parseDatasets extracts the comma-separated value of the --datasets flag from a
// task's argument string, e.g. "--epochs 10 --datasets cifar10,imagenet".
func parseDatasets(args string) []string {
	if !strings.Contains(args, "--datasets") {
		return nil
	}
	parts := strings.Fields(args)
	for i, part := range parts {
		if part == "--datasets" && i+1 < len(parts) {
			return strings.Split(parts[i+1], ",")
		}
	}
	return nil
}

// waitForTasks blocks until all running tasks finish or a shutdown timeout elapses.
func (w *Worker) waitForTasks() {
	timeout := time.After(5 * time.Minute)
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-timeout:
			w.logger.Warn("shutdown timeout, force stopping", "running_tasks", w.runningCount())
			return
		case <-ticker.C:
			count := w.runningCount()
			if count == 0 {
				w.logger.Info("all tasks completed, shutting down")
				return
			}
			w.logger.Debug("waiting for tasks to complete", "remaining", count)
		}
	}
}

func (w *Worker) runningCount() int {
	w.runningMu.RLock()
	defer w.runningMu.RUnlock()
	return len(w.running)
}

func (w *Worker) datasetIsFresh(dataset string) bool {
	w.datasetCacheMu.RLock()
	defer w.datasetCacheMu.RUnlock()
	expires, ok := w.datasetCache[dataset]
	return ok && time.Now().Before(expires)
}

func (w *Worker) markDatasetFetched(dataset string) {
	expires := time.Now().Add(w.datasetCacheTTL)
	w.datasetCacheMu.Lock()
	w.datasetCache[dataset] = expires
	w.datasetCacheMu.Unlock()
}

// GetMetrics returns a snapshot of worker statistics.
func (w *Worker) GetMetrics() map[string]any {
	stats := w.metrics.GetStats()
	stats["worker_id"] = w.id
	stats["max_workers"] = w.config.MaxWorkers
	return stats
}

// Stop cancels the worker context, waits for running tasks, and closes connections.
func (w *Worker) Stop() {
	w.cancel()
	w.waitForTasks()

	if err := w.server.Close(); err != nil {
		w.logger.Warn("error closing server connection", "error", err)
	}
	if err := w.queue.Close(); err != nil {
		w.logger.Warn("error closing queue connection", "error", err)
	}

	if w.metricsSrv != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := w.metricsSrv.Shutdown(ctx); err != nil {
			w.logger.Warn("metrics exporter shutdown error", "error", err)
		}
	}

	w.logger.Info("worker stopped", "worker_id", w.id)
}

// executeTaskWithLease executes a task with lease management and retry handling.
func (w *Worker) executeTaskWithLease(task *queue.Task) {
	// Track task for graceful shutdown.
	w.gracefulWait.Add(1)
	w.activeTasks.Store(task.ID, task)
	defer w.gracefulWait.Done()
	defer w.activeTasks.Delete(task.ID)

	// Create task-specific context with timeout.
	taskCtx := logging.EnsureTrace(w.ctx)               // add trace + span if missing
	taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
	taskCtx = logging.CtxWithTask(taskCtx, task.ID)     // add task metadata
	taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
	defer taskCancel()

	// Register the task in the running map so runningCount and waitForTasks see it.
	w.runningMu.Lock()
	w.running[task.ID] = taskCancel
	w.runningMu.Unlock()
	defer func() {
		w.runningMu.Lock()
		delete(w.running, task.ID)
		w.runningMu.Unlock()
	}()

	logger := w.logger.Job(taskCtx, task.JobName, task.ID)
	logger.Info("starting task",
		"worker_id", w.id,
		"datasets", task.Datasets,
		"priority", task.Priority)

	// Record task start.
	w.metrics.RecordTaskStart()
	defer w.metrics.RecordTaskCompletion()

	// Check for context cancellation.
	select {
	case <-taskCtx.Done():
		logger.Info("task cancelled before execution")
		return
	default:
	}

	// Parse datasets from task arguments.
	if task.Datasets == nil {
		task.Datasets = parseDatasets(task.Args)
	}

	// Start heartbeat goroutine to renew the lease while the task runs.
	heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
	defer cancelHeartbeat()
	go w.heartbeatLoop(heartbeatCtx, task.ID)

	// Update task status.
	task.Status = "running"
	now := time.Now()
	task.StartedAt = &now
	task.WorkerID = w.id
	if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
		logger.Error("failed to update task status", "error", err)
		w.metrics.RecordTaskFailure()
		return
	}
	if w.config.AutoFetchData && len(task.Datasets) > 0 {
		if err := w.fetchDatasets(taskCtx, task); err != nil {
			logger.Error("data fetch failed", "error", err)
			task.Status = "failed"
			task.Error = fmt.Sprintf("Data fetch failed: %v", err)
			endTime := time.Now()
			task.EndedAt = &endTime
			if err := w.queue.UpdateTask(task); err != nil {
				logger.Error("failed to update task status after data fetch failure", "error", err)
			}
			w.metrics.RecordTaskFailure()
			return
		}
	}

	// Execute job with panic recovery.
	var execErr error
	func() {
		defer func() {
			if r := recover(); r != nil {
				execErr = fmt.Errorf("panic during execution: %v", r)
			}
		}()
		execErr = w.runJob(task)
	}()

	// Finalize task.
	endTime := time.Now()
	task.EndedAt = &endTime

	if execErr != nil {
		task.Error = execErr.Error()
		// Retry on transient errors (network, timeout, etc.) while retries remain.
		if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
			w.logger.Warn("task failed with transient error, will retry",
				"task_id", task.ID,
				"error", execErr,
				"retry_count", task.RetryCount)
			if err := w.queue.RetryTask(task); err != nil {
				logger.Error("failed to requeue task for retry", "error", err)
			}
		} else {
			task.Status = "failed"
			if err := w.queue.UpdateTaskWithMetrics(task, "final"); err != nil {
				logger.Error("failed to record task failure", "error", err)
			}
			w.metrics.RecordTaskFailure()
		}
	} else {
		task.Status = "completed"
		if err := w.queue.UpdateTaskWithMetrics(task, "final"); err != nil {
			logger.Error("failed to record task completion", "error", err)
		}
	}

	// Release lease.
	if err := w.queue.ReleaseLease(task.ID, w.config.WorkerID); err != nil {
		logger.Error("failed to release lease", "task_id", task.ID, "error", err)
	}
}

// heartbeatLoop renews the task lease and worker heartbeat until ctx is cancelled.
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
	ticker := time.NewTicker(w.config.HeartbeatInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
				w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
				return
			}
			// Also update the worker heartbeat.
			if err := w.queue.Heartbeat(w.config.WorkerID); err != nil {
				w.logger.Warn("heartbeat failed", "worker_id", w.config.WorkerID, "error", err)
			}
		}
	}
}

// Shutdown performs a graceful shutdown: stop accepting work, wait for active
// tasks up to GracefulTimeout, then release any remaining leases.
func (w *Worker) Shutdown() error {
	w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())

	// Stop the polling loop so no new tasks are picked up.
	w.cancel()

	// Wait for active tasks with timeout.
	done := make(chan struct{})
	go func() {
		w.gracefulWait.Wait()
		close(done)
	}()

	timeout := time.After(w.config.GracefulTimeout)
	select {
	case <-done:
		w.logger.Info("all tasks completed, shutdown successful")
	case <-timeout:
		w.logger.Warn("graceful shutdown timeout, releasing active leases")
		w.releaseAllLeases()
	}

	return w.queue.Close()
}

// releaseAllLeases releases the lease for every task still tracked as active.
func (w *Worker) releaseAllLeases() {
	w.activeTasks.Range(func(key, _ interface{}) bool {
		taskID := key.(string)
		if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
			w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
		}
		return true
	})
}

// countActiveTasks returns the number of tasks currently tracked in activeTasks.
func (w *Worker) countActiveTasks() int {
	count := 0
	w.activeTasks.Range(func(_, _ interface{}) bool {
		count++
		return true
	})
	return count
}

// isTransientError reports whether an error looks transient (network, timeout,
// resource unavailable, etc.) and is therefore worth retrying.
func isTransientError(err error) bool {
	if err == nil {
		return false
	}
	errStr := strings.ToLower(err.Error())
	transientIndicators := []string{
		"connection refused",
		"timeout",
		"temporary failure",
		"resource temporarily unavailable",
		"no such host",
		"network unreachable",
	}
	for _, indicator := range transientIndicators {
		if strings.Contains(errStr, indicator) {
			return true
		}
	}
	return false
}
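// Example configuration (sketch, not authoritative): main loads settings via
// LoadConfig from a YAML file such as config-local.yaml. The real YAML keys are
// defined in the config package and are not visible in this file, so the keys
// below are assumptions derived from the Config fields referenced above; adjust
// them to match the actual schema.
//
//	worker_id: "ml-worker-01"
//	host: "ml-server.internal"
//	user: "mluser"
//	ssh_key: "~/.ssh/id_ed25519"
//	port: 22
//	redis_addr: "localhost:6379"
//	max_workers: 2
//	poll_interval: "5s"
//	task_lease_duration: "10m"
//	heartbeat_interval: "30s"
//	graceful_timeout: "5m"
//	dataset_cache_ttl: "24h"
//	podman_image: "your-training-image:latest"
//	train_script: "train.py"
//	auto_fetch_data: true
//	metrics:
//	  enabled: true
//	  listen_addr: ":9090"
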
func main() {
	log.SetFlags(log.LstdFlags | log.Lshortfile)

	// Parse authentication flags
	authFlags := auth.ParseAuthFlags()
	if err := auth.ValidateAuthFlags(authFlags); err != nil {
		log.Fatalf("Authentication flag error: %v", err)
	}

	// Get API key from various sources
	apiKey := auth.GetAPIKeyFromSources(authFlags)

	// Load configuration
	configPath := "config-local.yaml"
	if authFlags.ConfigFile != "" {
		configPath = authFlags.ConfigFile
	}
	resolvedConfig, err := config.ResolveConfigPath(configPath)
	if err != nil {
		log.Fatalf("%v", err)
	}
	cfg, err := LoadConfig(resolvedConfig)
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// Validate authentication configuration
	if err := cfg.Auth.ValidateAuthConfig(); err != nil {
		log.Fatalf("Invalid authentication configuration: %v", err)
	}

	// Validate configuration
	if err := cfg.Validate(); err != nil {
		log.Fatalf("Invalid configuration: %v", err)
	}

	// Test authentication if enabled
	if cfg.Auth.Enabled && apiKey != "" {
		user, err := cfg.Auth.ValidateAPIKey(apiKey)
		if err != nil {
			log.Fatalf("Authentication failed: %v", err)
		}
		log.Printf("Worker authenticated as user: %s (admin: %v)", user.Name, user.Admin)
	} else if cfg.Auth.Enabled {
		log.Fatal("Authentication required but no API key provided")
	}

	worker, err := NewWorker(cfg, apiKey)
	if err != nil {
		log.Fatalf("Failed to create worker: %v", err)
	}

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	go worker.Start()

	sig := <-sigChan
	log.Printf("Received signal: %v", sig)

	// Use graceful shutdown
	if err := worker.Shutdown(); err != nil {
		log.Printf("Graceful shutdown error: %v", err)
		worker.Stop() // fallback to force stop
	} else {
		log.Println("Worker shut down gracefully")
	}
}