From 82034c68f318df5ec2a101797e5d34697961bcb8 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Mon, 5 Jan 2026 12:31:13 -0500 Subject: [PATCH] feat(worker): add integrity checks, snapshot staging, and prewarm support --- cmd/worker/worker_config.go | 179 ----- cmd/worker/worker_server.go | 984 +-------------------------- internal/envpool/envpool.go | 288 ++++++++ internal/resources/manager.go | 323 +++++++++ internal/worker/config.go | 371 +++++++++++ internal/worker/core.go | 547 +++++++++++++++ internal/worker/data_integrity.go | 824 +++++++++++++++++++++++ internal/worker/execution.go | 1029 +++++++++++++++++++++++++++++ internal/worker/gpu_detector.go | 168 +++++ internal/worker/jupyter_task.go | 130 ++++ internal/worker/runloop.go | 525 +++++++++++++++ internal/worker/snapshot_store.go | 270 ++++++++ 12 files changed, 4493 insertions(+), 1145 deletions(-) delete mode 100644 cmd/worker/worker_config.go create mode 100644 internal/envpool/envpool.go create mode 100644 internal/resources/manager.go create mode 100644 internal/worker/config.go create mode 100644 internal/worker/core.go create mode 100644 internal/worker/data_integrity.go create mode 100644 internal/worker/execution.go create mode 100644 internal/worker/gpu_detector.go create mode 100644 internal/worker/jupyter_task.go create mode 100644 internal/worker/runloop.go create mode 100644 internal/worker/snapshot_store.go diff --git a/cmd/worker/worker_config.go b/cmd/worker/worker_config.go deleted file mode 100644 index 9507941..0000000 --- a/cmd/worker/worker_config.go +++ /dev/null @@ -1,179 +0,0 @@ -package main - -import ( - "fmt" - "path/filepath" - "time" - - "github.com/google/uuid" - "github.com/jfraeys/fetch_ml/internal/auth" - "github.com/jfraeys/fetch_ml/internal/config" - "github.com/jfraeys/fetch_ml/internal/fileutil" - "gopkg.in/yaml.v3" -) - -const ( - defaultMetricsFlushInterval = 500 * time.Millisecond - datasetCacheDefaultTTL = 30 * time.Minute -) - -// Config holds worker 
configuration. -type Config struct { - Host string `yaml:"host"` - User string `yaml:"user"` - SSHKey string `yaml:"ssh_key"` - Port int `yaml:"port"` - BasePath string `yaml:"base_path"` - TrainScript string `yaml:"train_script"` - RedisAddr string `yaml:"redis_addr"` - RedisPassword string `yaml:"redis_password"` - RedisDB int `yaml:"redis_db"` - KnownHosts string `yaml:"known_hosts"` - WorkerID string `yaml:"worker_id"` - MaxWorkers int `yaml:"max_workers"` - PollInterval int `yaml:"poll_interval_seconds"` - Resources config.ResourceConfig `yaml:"resources"` - LocalMode bool `yaml:"local_mode"` - - // Authentication - Auth auth.Config `yaml:"auth"` - - // Metrics exporter - Metrics MetricsConfig `yaml:"metrics"` - // Metrics buffering - MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"` - - // Data management - DataManagerPath string `yaml:"data_manager_path"` - AutoFetchData bool `yaml:"auto_fetch_data"` - DataDir string `yaml:"data_dir"` - DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"` - - // Podman execution - PodmanImage string `yaml:"podman_image"` - ContainerWorkspace string `yaml:"container_workspace"` - ContainerResults string `yaml:"container_results"` - GPUAccess bool `yaml:"gpu_access"` - - // Task lease and retry settings - TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // How long worker holds lease (default: 30min) - HeartbeatInterval time.Duration `yaml:"heartbeat_interval"` // How often to renew lease (default: 1min) - MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3) - GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Graceful shutdown timeout (default: 5min) -} - -// MetricsConfig controls the Prometheus exporter. -type MetricsConfig struct { - Enabled bool `yaml:"enabled"` - ListenAddr string `yaml:"listen_addr"` -} - -// LoadConfig loads worker configuration from a YAML file. 
-func LoadConfig(path string) (*Config, error) { - data, err := fileutil.SecureFileRead(path) - if err != nil { - return nil, err - } - - var cfg Config - if err := yaml.Unmarshal(data, &cfg); err != nil { - return nil, err - } - - // Get smart defaults for current environment - smart := config.GetSmartDefaults() - - if cfg.Port == 0 { - cfg.Port = config.DefaultSSHPort - } - if cfg.Host == "" { - cfg.Host = smart.Host() - } - if cfg.BasePath == "" { - cfg.BasePath = smart.BasePath() - } - if cfg.RedisAddr == "" { - cfg.RedisAddr = smart.RedisAddr() - } - if cfg.KnownHosts == "" { - cfg.KnownHosts = smart.KnownHostsPath() - } - if cfg.WorkerID == "" { - cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8]) - } - cfg.Resources.ApplyDefaults() - if cfg.MaxWorkers > 0 { - cfg.Resources.MaxWorkers = cfg.MaxWorkers - } else { - cfg.MaxWorkers = cfg.Resources.MaxWorkers - } - if cfg.PollInterval == 0 { - cfg.PollInterval = smart.PollInterval() - } - if cfg.DataManagerPath == "" { - cfg.DataManagerPath = "./data_manager" - } - if cfg.DataDir == "" { - if cfg.Host == "" || !cfg.AutoFetchData { - cfg.DataDir = config.DefaultLocalDataDir - } else { - cfg.DataDir = smart.DataDir() - } - } - if cfg.Metrics.ListenAddr == "" { - cfg.Metrics.ListenAddr = ":9100" - } - if cfg.MetricsFlushInterval == 0 { - cfg.MetricsFlushInterval = defaultMetricsFlushInterval - } - if cfg.DatasetCacheTTL == 0 { - cfg.DatasetCacheTTL = datasetCacheDefaultTTL - } - - // Set lease and retry defaults - if cfg.TaskLeaseDuration == 0 { - cfg.TaskLeaseDuration = 30 * time.Minute - } - if cfg.HeartbeatInterval == 0 { - cfg.HeartbeatInterval = 1 * time.Minute - } - if cfg.MaxRetries == 0 { - cfg.MaxRetries = 3 - } - if cfg.GracefulTimeout == 0 { - cfg.GracefulTimeout = 5 * time.Minute - } - - return &cfg, nil -} - -// Validate implements config.Validator interface. 
-func (c *Config) Validate() error { - if c.Port != 0 { - if err := config.ValidatePort(c.Port); err != nil { - return fmt.Errorf("invalid SSH port: %w", err) - } - } - - if c.BasePath != "" { - // Convert relative paths to absolute - c.BasePath = config.ExpandPath(c.BasePath) - if !filepath.IsAbs(c.BasePath) { - c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath) - } - } - - if c.RedisAddr != "" { - if err := config.ValidateRedisAddr(c.RedisAddr); err != nil { - return fmt.Errorf("invalid Redis configuration: %w", err) - } - } - - if c.MaxWorkers < 1 { - return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers) - } - - return nil -} - -// Task struct and Redis constants moved to internal/queue diff --git a/cmd/worker/worker_server.go b/cmd/worker/worker_server.go index ab196e2..f9034cb 100644 --- a/cmd/worker/worker_server.go +++ b/cmd/worker/worker_server.go @@ -2,975 +2,32 @@ package main import ( - "context" - "fmt" "log" - "log/slog" - "net/http" "os" - "os/exec" "os/signal" - "path/filepath" "strings" - "sync" "syscall" - "time" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" - "github.com/jfraeys/fetch_ml/internal/container" - "github.com/jfraeys/fetch_ml/internal/errtypes" - "github.com/jfraeys/fetch_ml/internal/fileutil" - "github.com/jfraeys/fetch_ml/internal/logging" - "github.com/jfraeys/fetch_ml/internal/metrics" - "github.com/jfraeys/fetch_ml/internal/network" - "github.com/jfraeys/fetch_ml/internal/queue" - "github.com/jfraeys/fetch_ml/internal/telemetry" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/collectors" - "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/jfraeys/fetch_ml/internal/worker" ) -// MLServer wraps network.SSHClient for backward compatibility. 
-type MLServer struct { - *network.SSHClient -} +const ( + defaultConfigPath = "config-local.yaml" +) -// isValidName validates that input strings contain only safe characters. -// isValidName checks if the input string is a valid name. -func isValidName(input string) bool { - return len(input) > 0 && len(input) < 256 -} - -// NewMLServer creates a new ML server connection. -// NewMLServer returns a new MLServer instance. -func NewMLServer(cfg *Config) (*MLServer, error) { - if cfg.LocalMode { - return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil - } - - client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts) - if err != nil { - return nil, err - } - - return &MLServer{SSHClient: client}, nil -} - -// Worker represents an ML task worker. -type Worker struct { - id string - config *Config - server *MLServer - queue *queue.TaskQueue - running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown - runningMu sync.RWMutex - ctx context.Context - cancel context.CancelFunc - logger *logging.Logger - metrics *metrics.Metrics - metricsSrv *http.Server - - datasetCache map[string]time.Time - datasetCacheMu sync.RWMutex - datasetCacheTTL time.Duration - - // Graceful shutdown fields - shutdownCh chan struct{} - activeTasks sync.Map // map[string]*queue.Task - track active tasks - gracefulWait sync.WaitGroup -} - -func (w *Worker) setupMetricsExporter() { - if !w.config.Metrics.Enabled { - return - } - - reg := prometheus.NewRegistry() - reg.MustRegister( - collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}), - collectors.NewGoCollector(), - ) - - labels := prometheus.Labels{"worker_id": w.id} - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_tasks_processed_total", - Help: "Total tasks processed successfully by this worker.", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.TasksProcessed.Load()) - })) - 
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_tasks_failed_total", - Help: "Total tasks failed by this worker.", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.TasksFailed.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_tasks_active", - Help: "Number of tasks currently running on this worker.", - ConstLabels: labels, - }, func() float64 { - return float64(w.runningCount()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_tasks_queued", - Help: "Latest observed queue depth from Redis.", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.QueuedTasks.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_data_transferred_bytes_total", - Help: "Total bytes transferred while fetching datasets.", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.DataTransferred.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_data_fetch_time_seconds_total", - Help: "Total time spent fetching datasets (seconds).", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_execution_time_seconds_total", - Help: "Total execution time for completed tasks (seconds).", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_worker_max_concurrency", - Help: "Configured maximum concurrent tasks for this worker.", - ConstLabels: labels, - }, func() float64 { - return float64(w.config.MaxWorkers) - })) - - mux := http.NewServeMux() - mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) - - srv := &http.Server{ - 
Addr: w.config.Metrics.ListenAddr, - Handler: mux, - ReadHeaderTimeout: 5 * time.Second, - } - - w.metricsSrv = srv - go func() { - w.logger.Info("metrics exporter listening", - "addr", w.config.Metrics.ListenAddr, - "enabled", w.config.Metrics.Enabled) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - w.logger.Warn("metrics exporter stopped", - "error", err) - } - }() -} - -// NewWorker creates a new worker instance. -func NewWorker(cfg *Config, _ string) (*Worker, error) { - srv, err := NewMLServer(cfg) - if err != nil { - return nil, err - } - - queueCfg := queue.Config{ - RedisAddr: cfg.RedisAddr, - RedisPassword: cfg.RedisPassword, - RedisDB: cfg.RedisDB, - MetricsFlushInterval: cfg.MetricsFlushInterval, - } - queue, err := queue.NewTaskQueue(queueCfg) - if err != nil { - return nil, err - } - - // Create data_dir if it doesn't exist (for production without NAS) - if cfg.DataDir != "" { - if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil { - log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err) +func resolveWorkerConfigPath(flags *auth.Flags) string { + if flags != nil { + p := strings.TrimSpace(flags.ConfigFile) + if p != "" { + return p } } - - ctx, cancel := context.WithCancel(context.Background()) - ctx = logging.EnsureTrace(ctx) - ctx = logging.CtxWithWorker(ctx, cfg.WorkerID) - - baseLogger := logging.NewLogger(slog.LevelInfo, false) - logger := baseLogger.Component(ctx, "worker") - metrics := &metrics.Metrics{} - - worker := &Worker{ - id: cfg.WorkerID, - config: cfg, - server: srv, - queue: queue, - running: make(map[string]context.CancelFunc), - datasetCache: make(map[string]time.Time), - datasetCacheTTL: cfg.DatasetCacheTTL, - ctx: ctx, - cancel: cancel, - logger: logger, - metrics: metrics, - shutdownCh: make(chan struct{}), + if _, err := os.Stat("/app/configs/worker.yaml"); err == nil { + return "/app/configs/worker.yaml" } - - worker.setupMetricsExporter() - - return 
worker, nil -} - -// Start starts the worker's main processing loop. -func (w *Worker) Start() { - w.logger.Info("worker started", - "worker_id", w.id, - "max_concurrent", w.config.MaxWorkers, - "poll_interval", w.config.PollInterval) - - go w.heartbeat() - - for { - select { - case <-w.ctx.Done(): - w.logger.Info("shutdown signal received, waiting for tasks") - w.waitForTasks() - return - default: - } - - if w.runningCount() >= w.config.MaxWorkers { - time.Sleep(50 * time.Millisecond) - continue - } - - queueStart := time.Now() - blockTimeout := time.Duration(w.config.PollInterval) * time.Second - task, err := w.queue.GetNextTaskWithLeaseBlocking(w.config.WorkerID, w.config.TaskLeaseDuration, blockTimeout) - queueLatency := time.Since(queueStart) - if err != nil { - if err == context.DeadlineExceeded { - continue - } - w.logger.Error("error fetching task", - "worker_id", w.id, - "error", err) - continue - } - - if task == nil { - if queueLatency > 200*time.Millisecond { - w.logger.Debug("queue poll latency", - "latency_ms", queueLatency.Milliseconds()) - } - continue - } - - if depth, derr := w.queue.QueueDepth(); derr == nil { - if queueLatency > 100*time.Millisecond || depth > 0 { - w.logger.Debug("queue fetch metrics", - "latency_ms", queueLatency.Milliseconds(), - "remaining_depth", depth) - } - } else if queueLatency > 100*time.Millisecond { - w.logger.Debug("queue fetch metrics", - "latency_ms", queueLatency.Milliseconds(), - "depth_error", derr) - } - - go w.executeTaskWithLease(task) - } -} - -func (w *Worker) heartbeat() { - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for { - select { - case <-w.ctx.Done(): - return - case <-ticker.C: - if err := w.queue.Heartbeat(w.id); err != nil { - w.logger.Warn("heartbeat failed", - "worker_id", w.id, - "error", err) - } - } - } -} - -// NEW: Fetch datasets using data_manager. 
-func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error { - logger := w.logger.Job(ctx, task.JobName, task.ID) - logger.Info("fetching datasets", - "worker_id", w.id, - "dataset_count", len(task.Datasets)) - - for _, dataset := range task.Datasets { - if w.datasetIsFresh(dataset) { - logger.Debug("skipping cached dataset", - "dataset", dataset) - continue - } - // Check for cancellation before each dataset fetch - select { - case <-w.ctx.Done(): - return fmt.Errorf("dataset fetch cancelled: %w", w.ctx.Err()) - default: - } - - logger.Info("fetching dataset", - "worker_id", w.id, - "dataset", dataset) - - // Create command with context for cancellation support - cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute) - // Validate inputs to prevent command injection - if !isValidName(task.JobName) || !isValidName(dataset) { - cancel() - return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters") - } - //nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated - cmd := exec.CommandContext(cmdCtx, - w.config.DataManagerPath, - "fetch", - task.JobName, - dataset, - ) - - output, err := cmd.CombinedOutput() - cancel() // Clean up context - - if err != nil { - return &errtypes.DataFetchError{ - Dataset: dataset, - JobName: task.JobName, - Err: fmt.Errorf("command failed: %w, output: %s", err, output), - } - } - - logger.Info("dataset ready", - "worker_id", w.id, - "dataset", dataset) - w.markDatasetFetched(dataset) - } - - return nil -} - -func (w *Worker) runJob(ctx context.Context, task *queue.Task) error { - // Validate job name to prevent path traversal - if err := container.ValidateJobName(task.JobName); err != nil { - return &errtypes.TaskExecutionError{ - TaskID: task.ID, - JobName: task.JobName, - Phase: "validation", - Err: err, - } - } - - jobDir, outputDir, logFile, err := w.setupJobDirectories(task) - if err != nil { - return err - } - - return w.executeJob(ctx, task, jobDir, 
outputDir, logFile) -} - -func (w *Worker) setupJobDirectories(task *queue.Task) (jobDir, outputDir, logFile string, err error) { - jobPaths := config.NewJobPaths(w.config.BasePath) - pendingDir := jobPaths.PendingPath() - jobDir = filepath.Join(pendingDir, task.JobName) - outputDir = filepath.Join(jobPaths.RunningPath(), task.JobName) - logFile = filepath.Join(outputDir, "output.log") - - // Create pending directory - if err := os.MkdirAll(pendingDir, 0750); err != nil { - return "", "", "", &errtypes.TaskExecutionError{ - TaskID: task.ID, - JobName: task.JobName, - Phase: "setup", - Err: fmt.Errorf("failed to create pending dir: %w", err), - } - } - - // Create job directory in pending - if err := os.MkdirAll(jobDir, 0750); err != nil { - return "", "", "", &errtypes.TaskExecutionError{ - TaskID: task.ID, - JobName: task.JobName, - Phase: "setup", - Err: fmt.Errorf("failed to create job dir: %w", err), - } - } - - // Sanitize paths - jobDir, err = container.SanitizePath(jobDir) - if err != nil { - return "", "", "", &errtypes.TaskExecutionError{ - TaskID: task.ID, - JobName: task.JobName, - Phase: "validation", - Err: err, - } - } - outputDir, err = container.SanitizePath(outputDir) - if err != nil { - return "", "", "", &errtypes.TaskExecutionError{ - TaskID: task.ID, - JobName: task.JobName, - Phase: "validation", - Err: err, - } - } - - return jobDir, outputDir, logFile, nil -} - -func (w *Worker) executeJob(ctx context.Context, task *queue.Task, jobDir, outputDir, logFile string) error { - // Create output directory - if _, err := telemetry.ExecWithMetrics(w.logger, "create output dir", 100*time.Millisecond, func() (string, error) { - if err := os.MkdirAll(outputDir, 0750); err != nil { - return "", fmt.Errorf("mkdir failed: %w", err) - } - return "", nil - }); err != nil { - return &errtypes.TaskExecutionError{ - TaskID: task.ID, - JobName: task.JobName, - Phase: "setup", - Err: fmt.Errorf("failed to create output dir: %w", err), - } - } - - // Move job from 
pending to running - stagingStart := time.Now() - if _, err := telemetry.ExecWithMetrics(w.logger, "stage job", 100*time.Millisecond, func() (string, error) { - // Remove existing directory if it exists - if _, err := os.Stat(outputDir); err == nil { - if err := os.RemoveAll(outputDir); err != nil { - return "", fmt.Errorf("remove existing failed: %w", err) - } - } - if err := os.Rename(jobDir, outputDir); err != nil { - return "", fmt.Errorf("rename failed: %w", err) - } - return "", nil - }); err != nil { - return &errtypes.TaskExecutionError{ - TaskID: task.ID, - JobName: task.JobName, - Phase: "setup", - Err: fmt.Errorf("failed to move job: %w", err), - } - } - stagingDuration := time.Since(stagingStart) - - // Execute job - if w.config.LocalMode { - return w.executeLocalJob(ctx, task, outputDir, logFile) - } - - return w.executeContainerJob(ctx, task, outputDir, logFile, stagingDuration) -} - -func (w *Worker) executeLocalJob(ctx context.Context, task *queue.Task, outputDir, logFile string) error { - // Create experiment script - scriptContent := `#!/bin/bash -set -e - -echo "Starting experiment: ` + task.JobName + `" -echo "Task ID: ` + task.ID + `" -echo "Timestamp: $(date)" - -# Simulate ML experiment -echo "Loading data..." -sleep 1 - -echo "Training model..." -sleep 2 - -echo "Evaluating model..." -sleep 1 - -# Generate results -ACCURACY=0.95 -LOSS=0.05 -EPOCHS=10 - -echo "" -echo "=== EXPERIMENT RESULTS ===" -echo "Accuracy: $ACCURACY" -echo "Loss: $LOSS" -echo "Epochs: $EPOCHS" -echo "Status: SUCCESS" -echo "=========================" -echo "Experiment completed successfully!" 
-` - - scriptPath := filepath.Join(outputDir, "run.sh") - if err := os.WriteFile(scriptPath, []byte(scriptContent), 0600); err != nil { - return &errtypes.TaskExecutionError{ - TaskID: task.ID, - JobName: task.JobName, - Phase: "execution", - Err: fmt.Errorf("failed to write script: %w", err), - } - } - - logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) - if err != nil { - w.logger.Warn("failed to open log file for local output", "path", logFile, "error", err) - return &errtypes.TaskExecutionError{ - TaskID: task.ID, - JobName: task.JobName, - Phase: "execution", - Err: fmt.Errorf("failed to open log file: %w", err), - } - } - defer func() { - if err := logFileHandle.Close(); err != nil { - log.Printf("Warning: failed to close log file: %v", err) - } - }() - - // Execute the script directly - localCmd := exec.CommandContext(ctx, "bash", scriptPath) - localCmd.Stdout = logFileHandle - localCmd.Stderr = logFileHandle - - w.logger.Info("executing local job", - "job", task.JobName, - "task_id", task.ID, - "script", scriptPath) - - if err := localCmd.Run(); err != nil { - return &errtypes.TaskExecutionError{ - TaskID: task.ID, - JobName: task.JobName, - Phase: "execution", - Err: fmt.Errorf("execution failed: %w", err), - } - } - - return nil -} - -func (w *Worker) executeContainerJob( - ctx context.Context, - task *queue.Task, - outputDir, logFile string, - stagingDuration time.Duration, -) error { - containerResults := w.config.ContainerResults - if containerResults == "" { - containerResults = config.DefaultContainerResults - } - - containerWorkspace := w.config.ContainerWorkspace - if containerWorkspace == "" { - containerWorkspace = config.DefaultContainerWorkspace - } - - jobPaths := config.NewJobPaths(w.config.BasePath) - stagingStart := time.Now() - - podmanCfg := container.PodmanConfig{ - Image: w.config.PodmanImage, - Workspace: filepath.Join(outputDir, "code"), - Results: filepath.Join(outputDir, "results"), - 
ContainerWorkspace: containerWorkspace, - ContainerResults: containerResults, - GPUAccess: w.config.GPUAccess, - } - - scriptPath := filepath.Join(containerWorkspace, w.config.TrainScript) - requirementsPath := filepath.Join(containerWorkspace, "requirements.txt") - - var extraArgs []string - if task.Args != "" { - extraArgs = strings.Fields(task.Args) - } - - ioBefore, ioErr := telemetry.ReadProcessIO() - podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, scriptPath, requirementsPath, extraArgs) - logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) - if err == nil { - podmanCmd.Stdout = logFileHandle - podmanCmd.Stderr = logFileHandle - } else { - w.logger.Warn("failed to open log file for podman output", "path", logFile, "error", err) - } - - w.logger.Info("executing podman job", - "job", task.JobName, - "image", w.config.PodmanImage, - "workspace", podmanCfg.Workspace, - "results", podmanCfg.Results) - - containerStart := time.Now() - if err := podmanCmd.Run(); err != nil { - containerDuration := time.Since(containerStart) - // Move job to failed directory - failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName) - if _, moveErr := telemetry.ExecWithMetrics(w.logger, "move failed job", 100*time.Millisecond, func() (string, error) { - if err := os.Rename(outputDir, failedDir); err != nil { - return "", fmt.Errorf("rename to failed failed: %w", err) - } - return "", nil - }); moveErr != nil { - w.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", moveErr) - } - - if ioErr == nil { - if after, err := telemetry.ReadProcessIO(); err == nil { - delta := telemetry.DiffIO(ioBefore, after) - w.logger.Debug("worker io stats", - "job", task.JobName, - "read_bytes", delta.ReadBytes, - "write_bytes", delta.WriteBytes) - } - } - w.logger.Info("job timing (failure)", - "job", task.JobName, - "staging_ms", stagingDuration.Milliseconds(), - "container_ms", containerDuration.Milliseconds(), - 
"finalize_ms", 0, - "total_ms", time.Since(stagingStart).Milliseconds(), - ) - return fmt.Errorf("execution failed: %w", err) - } - containerDuration := time.Since(containerStart) - - finalizeStart := time.Now() - // Move job to finished directory - finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName) - if _, moveErr := telemetry.ExecWithMetrics(w.logger, "finalize job", 100*time.Millisecond, func() (string, error) { - if err := os.Rename(outputDir, finishedDir); err != nil { - return "", fmt.Errorf("rename to finished failed: %w", err) - } - return "", nil - }); moveErr != nil { - w.logger.Warn("failed to move job to finished dir", "job", task.JobName, "error", moveErr) - } - finalizeDuration := time.Since(finalizeStart) - totalDuration := time.Since(stagingStart) - var ioDelta telemetry.IOStats - if ioErr == nil { - if after, err := telemetry.ReadProcessIO(); err == nil { - ioDelta = telemetry.DiffIO(ioBefore, after) - } - } - - w.logger.Info("job timing", - "job", task.JobName, - "staging_ms", stagingDuration.Milliseconds(), - "container_ms", containerDuration.Milliseconds(), - "finalize_ms", finalizeDuration.Milliseconds(), - "total_ms", totalDuration.Milliseconds(), - "io_read_bytes", ioDelta.ReadBytes, - "io_write_bytes", ioDelta.WriteBytes, - ) - - return nil -} - -func parseDatasets(args string) []string { - if !strings.Contains(args, "--datasets") { - return nil - } - - parts := strings.Fields(args) - for i, part := range parts { - if part == "--datasets" && i+1 < len(parts) { - return strings.Split(parts[i+1], ",") - } - } - - return nil -} - -func (w *Worker) waitForTasks() { - timeout := time.After(5 * time.Minute) - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case <-timeout: - w.logger.Warn("shutdown timeout, force stopping", - "running_tasks", len(w.running)) - return - case <-ticker.C: - count := w.runningCount() - if count == 0 { - w.logger.Info("all tasks completed, shutting down") - return 
- } - w.logger.Debug("waiting for tasks to complete", - "remaining", count) - } - } -} - -func (w *Worker) runningCount() int { - w.runningMu.RLock() - defer w.runningMu.RUnlock() - return len(w.running) -} - -func (w *Worker) datasetIsFresh(dataset string) bool { - w.datasetCacheMu.RLock() - defer w.datasetCacheMu.RUnlock() - expires, ok := w.datasetCache[dataset] - return ok && time.Now().Before(expires) -} - -func (w *Worker) markDatasetFetched(dataset string) { - expires := time.Now().Add(w.datasetCacheTTL) - w.datasetCacheMu.Lock() - w.datasetCache[dataset] = expires - w.datasetCacheMu.Unlock() -} - -// GetMetrics returns current worker metrics. -func (w *Worker) GetMetrics() map[string]any { - stats := w.metrics.GetStats() - stats["worker_id"] = w.id - stats["max_workers"] = w.config.MaxWorkers - return stats -} - -// Stop gracefully shuts down the worker. -func (w *Worker) Stop() { - w.cancel() - w.waitForTasks() - - // FIXED: Check error return values - if err := w.server.Close(); err != nil { - w.logger.Warn("error closing server connection", "error", err) - } - if err := w.queue.Close(); err != nil { - w.logger.Warn("error closing queue connection", "error", err) - } - if w.metricsSrv != nil { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := w.metricsSrv.Shutdown(ctx); err != nil { - w.logger.Warn("metrics exporter shutdown error", "error", err) - } - } - w.logger.Info("worker stopped", "worker_id", w.id) -} - -// Execute task with lease management and retry. 
-func (w *Worker) executeTaskWithLease(task *queue.Task) { - // Track task for graceful shutdown - w.gracefulWait.Add(1) - w.activeTasks.Store(task.ID, task) - defer w.gracefulWait.Done() - defer w.activeTasks.Delete(task.ID) - - // Create task-specific context with timeout - taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing - taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata - taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata - - taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour) - defer taskCancel() - - logger := w.logger.Job(taskCtx, task.JobName, task.ID) - logger.Info("starting task", - "worker_id", w.id, - "datasets", task.Datasets, - "priority", task.Priority) - - // Record task start - w.metrics.RecordTaskStart() - defer w.metrics.RecordTaskCompletion() - - // Check for context cancellation - select { - case <-taskCtx.Done(): - logger.Info("task cancelled before execution") - return - default: - } - - // Parse datasets from task arguments - if task.Datasets == nil { - task.Datasets = parseDatasets(task.Args) - } - - // Start heartbeat goroutine - heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background()) - defer cancelHeartbeat() - - go w.heartbeatLoop(heartbeatCtx, task.ID) - - // Update task status - task.Status = "running" - now := time.Now() - task.StartedAt = &now - task.WorkerID = w.id - - if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil { - logger.Error("failed to update task status", "error", err) - w.metrics.RecordTaskFailure() - return - } - - if w.config.AutoFetchData && len(task.Datasets) > 0 { - if err := w.fetchDatasets(taskCtx, task); err != nil { - logger.Error("data fetch failed", "error", err) - task.Status = "failed" - task.Error = fmt.Sprintf("Data fetch failed: %v", err) - endTime := time.Now() - task.EndedAt = &endTime - err := w.queue.UpdateTask(task) - if err != nil { - logger.Error("failed to update task status after data 
fetch failure", "error", err) - } - w.metrics.RecordTaskFailure() - return - } - } - - // Execute job with panic recovery - var execErr error - func() { - defer func() { - if r := recover(); r != nil { - execErr = fmt.Errorf("panic during execution: %v", r) - } - }() - execErr = w.runJob(taskCtx, task) - }() - - // Finalize task - endTime := time.Now() - task.EndedAt = &endTime - - if execErr != nil { - task.Error = execErr.Error() - - // Check if transient error (network, timeout, etc) - if isTransientError(execErr) && task.RetryCount < task.MaxRetries { - w.logger.Warn("task failed with transient error, will retry", - "task_id", task.ID, - "error", execErr, - "retry_count", task.RetryCount) - _ = w.queue.RetryTask(task) - } else { - task.Status = "failed" - _ = w.queue.UpdateTaskWithMetrics(task, "final") - } - } else { - task.Status = "completed" - - // Read output file for completed tasks - jobPaths := config.NewJobPaths(w.config.BasePath) - outputDir := filepath.Join(jobPaths.RunningPath(), task.JobName) - logFile := filepath.Join(outputDir, "output.log") - if outputBytes, err := os.ReadFile(logFile); err == nil { - task.Output = string(outputBytes) - } - - _ = w.queue.UpdateTaskWithMetrics(task, "final") - } - - // Release lease - _ = w.queue.ReleaseLease(task.ID, w.config.WorkerID) -} - -// Heartbeat loop to renew lease. -func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) { - ticker := time.NewTicker(w.config.HeartbeatInterval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil { - w.logger.Error("failed to renew lease", "task_id", taskID, "error", err) - return - } - // Also update worker heartbeat - _ = w.queue.Heartbeat(w.config.WorkerID) - } - } -} - -// Shutdown gracefully shuts down the worker. 
-func (w *Worker) Shutdown() error { - w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks()) - - // Wait for active tasks with timeout - done := make(chan struct{}) - go func() { - w.gracefulWait.Wait() - close(done) - }() - - timeout := time.After(w.config.GracefulTimeout) - select { - case <-done: - w.logger.Info("all tasks completed, shutdown successful") - case <-timeout: - w.logger.Warn("graceful shutdown timeout, releasing active leases") - w.releaseAllLeases() - } - - return w.queue.Close() -} - -// Release all active leases. -func (w *Worker) releaseAllLeases() { - w.activeTasks.Range(func(key, _ interface{}) bool { - taskID := key.(string) - if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil { - w.logger.Error("failed to release lease", "task_id", taskID, "error", err) - } - return true - }) -} - -// Helper functions. -func (w *Worker) countActiveTasks() int { - count := 0 - w.activeTasks.Range(func(_, _ interface{}) bool { - count++ - return true - }) - return count -} - -func isTransientError(err error) bool { - if err == nil { - return false - } - // Check if error is transient (network, timeout, resource unavailable, etc) - errStr := err.Error() - transientIndicators := []string{ - "connection refused", - "timeout", - "temporary failure", - "resource temporarily unavailable", - "no such host", - "network unreachable", - } - for _, indicator := range transientIndicators { - if strings.Contains(strings.ToLower(errStr), indicator) { - return true - } - } - return false + return defaultConfigPath } func main() { @@ -986,17 +43,12 @@ func main() { apiKey := auth.GetAPIKeyFromSources(authFlags) // Load configuration - configPath := "config-local.yaml" - if authFlags.ConfigFile != "" { - configPath = authFlags.ConfigFile - } - - resolvedConfig, err := config.ResolveConfigPath(configPath) + resolvedConfig, err := config.ResolveConfigPath(resolveWorkerConfigPath(authFlags)) if err != nil { log.Fatalf("%v", err) } - 
cfg, err := LoadConfig(resolvedConfig) + cfg, err := worker.LoadConfig(resolvedConfig) if err != nil { log.Fatalf("Failed to load config: %v", err) } @@ -1022,7 +74,7 @@ func main() { log.Fatal("Authentication required but no API key provided") } - worker, err := NewWorker(cfg, apiKey) + wrk, err := worker.NewWorker(cfg, apiKey) if err != nil { log.Fatalf("Failed to create worker: %v", err) } @@ -1030,15 +82,15 @@ func main() { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - go worker.Start() + go wrk.Start() sig := <-sigChan log.Printf("Received signal: %v", sig) // Use graceful shutdown - if err := worker.Shutdown(); err != nil { + if err := wrk.Shutdown(); err != nil { log.Printf("Graceful shutdown error: %v", err) - worker.Stop() // Fallback to force stop + wrk.Stop() // Fallback to force stop } else { log.Println("Worker shut down gracefully") } diff --git a/internal/envpool/envpool.go b/internal/envpool/envpool.go new file mode 100644 index 0000000..f8ef58e --- /dev/null +++ b/internal/envpool/envpool.go @@ -0,0 +1,288 @@ +package envpool + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "os/exec" + "strings" + "sync" + "time" +) + +type CommandRunner interface { + CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error) +} + +type execRunner struct{} + +func (r execRunner) CombinedOutput( + ctx context.Context, + name string, + args ...string, +) ([]byte, error) { + cmd := exec.CommandContext(ctx, name, args...) 
+ return cmd.CombinedOutput() +} + +type Pool struct { + runner CommandRunner + + imagePrefix string + + cacheMu sync.Mutex + cache map[string]cacheEntry + cacheTTL time.Duration +} + +type cacheEntry struct { + exists bool + expires time.Time +} + +func New(imagePrefix string) *Pool { + prefix := strings.TrimSpace(imagePrefix) + if prefix == "" { + prefix = "fetchml-prewarm" + } + return &Pool{ + runner: execRunner{}, + imagePrefix: prefix, + cache: make(map[string]cacheEntry), + cacheTTL: 30 * time.Second, + } +} + +func (p *Pool) WithRunner(r CommandRunner) *Pool { + if r != nil { + p.runner = r + } + return p +} + +func (p *Pool) WithCacheTTL(ttl time.Duration) *Pool { + if ttl > 0 { + p.cacheTTL = ttl + } + return p +} + +func (p *Pool) WarmImageTag(depsManifestSHA256 string) (string, error) { + sha := strings.TrimSpace(depsManifestSHA256) + if sha == "" { + return "", fmt.Errorf("missing deps sha256") + } + if !isLowerHexLen(sha, 64) { + return "", fmt.Errorf("invalid deps sha256") + } + return fmt.Sprintf("%s:%s", p.imagePrefix, sha[:12]), nil +} + +func (p *Pool) ImageExists(ctx context.Context, imageRef string) (bool, error) { + ref := strings.TrimSpace(imageRef) + if ref == "" { + return false, fmt.Errorf("missing image ref") + } + + p.cacheMu.Lock() + if ent, ok := p.cache[ref]; ok && time.Now().Before(ent.expires) { + exists := ent.exists + p.cacheMu.Unlock() + return exists, nil + } + p.cacheMu.Unlock() + + out, err := p.runner.CombinedOutput(ctx, "podman", "image", "inspect", ref) + if err == nil { + p.setCache(ref, true) + return true, nil + } + if looksLikeImageNotFound(out) { + p.setCache(ref, false) + return false, nil + } + var ee *exec.ExitError + if errors.As(err, &ee) { + p.setCache(ref, false) + return false, nil + } + return false, err +} + +func looksLikeImageNotFound(out []byte) bool { + s := strings.ToLower(strings.TrimSpace(string(out))) + if s == "" { + return false + } + return strings.Contains(s, "no such") || + strings.Contains(s, 
"not found") || + strings.Contains(s, "does not exist") +} + +func (p *Pool) setCache(imageRef string, exists bool) { + p.cacheMu.Lock() + p.cache[imageRef] = cacheEntry{exists: exists, expires: time.Now().Add(p.cacheTTL)} + p.cacheMu.Unlock() +} + +type PrepareRequest struct { + BaseImage string + TargetImage string + HostWorkspace string + ContainerWorkspace string + DepsPathInContainer string +} + +func (p *Pool) PruneImages(ctx context.Context, olderThan time.Duration) error { + if olderThan <= 0 { + return fmt.Errorf("invalid olderThan") + } + h := int(olderThan.Round(time.Hour).Hours()) + if h < 1 { + h = 1 + } + until := fmt.Sprintf("%dh", h) + _, err := p.runner.CombinedOutput( + ctx, + "podman", + "image", + "prune", + "-a", + "-f", + "--filter", + "label=fetchml.prewarm=true", + "--filter", + "until="+until, + ) + return err +} + +func (p *Pool) Prepare(ctx context.Context, req PrepareRequest) error { + baseImage := strings.TrimSpace(req.BaseImage) + targetImage := strings.TrimSpace(req.TargetImage) + hostWS := strings.TrimSpace(req.HostWorkspace) + containerWS := strings.TrimSpace(req.ContainerWorkspace) + depsInContainer := strings.TrimSpace(req.DepsPathInContainer) + + if baseImage == "" { + return fmt.Errorf("missing base image") + } + if targetImage == "" { + return fmt.Errorf("missing target image") + } + if hostWS == "" { + return fmt.Errorf("missing host workspace") + } + if containerWS == "" { + return fmt.Errorf("missing container workspace") + } + if depsInContainer == "" { + return fmt.Errorf("missing deps path") + } + if !strings.HasPrefix(depsInContainer, containerWS) { + return fmt.Errorf("deps path must be under container workspace") + } + + exists, err := p.ImageExists(ctx, targetImage) + if err != nil { + return err + } + if exists { + return nil + } + + containerName, err := randomContainerName("fetchml-prewarm") + if err != nil { + return err + } + + // Do not use --rm since we need a container to commit. 
+ runArgs := []string{ + "run", + "--name", containerName, + "--security-opt", "no-new-privileges", + "--cap-drop", "ALL", + "--userns", "keep-id", + "-v", fmt.Sprintf("%s:%s:rw", hostWS, containerWS), + baseImage, + "--workspace", containerWS, + "--deps", depsInContainer, + "--prepare-only", + } + + if out, err := p.runner.CombinedOutput(ctx, "podman", runArgs...); err != nil { + _ = p.cleanupContainer(context.Background(), containerName) + return fmt.Errorf("podman run prewarm failed: %w", scrubOutput(out, err)) + } + + if out, err := p.runner.CombinedOutput( + ctx, + "podman", + "commit", + containerName, + targetImage, + ); err != nil { + _ = p.cleanupContainer(context.Background(), containerName) + return fmt.Errorf("podman commit prewarm failed: %w", scrubOutput(out, err)) + } + _, _ = p.runner.CombinedOutput( + ctx, + "podman", + "image", + "label", + targetImage, + "fetchml.prewarm=true", + ) + + _ = p.cleanupContainer(context.Background(), containerName) + p.setCache(targetImage, true) + return nil +} + +func (p *Pool) cleanupContainer(ctx context.Context, name string) error { + n := strings.TrimSpace(name) + if n == "" { + return nil + } + _, err := p.runner.CombinedOutput(ctx, "podman", "rm", n) + return err +} + +func randomContainerName(prefix string) (string, error) { + p := strings.TrimSpace(prefix) + if p == "" { + p = "fetchml-prewarm" + } + b := make([]byte, 6) + if _, err := rand.Read(b); err != nil { + return "", err + } + return fmt.Sprintf("%s-%s", p, hex.EncodeToString(b)), nil +} + +func isLowerHexLen(s string, want int) bool { + if len(s) != want { + return false + } + for i := 0; i < len(s); i++ { + c := s[i] + if (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') { + continue + } + return false + } + return true +} + +func scrubOutput(out []byte, err error) error { + if len(out) == 0 { + return err + } + s := strings.TrimSpace(string(out)) + if len(s) > 400 { + s = s[:400] + } + return fmt.Errorf("%w (output=%q)", err, s) +} diff --git 
a/internal/resources/manager.go b/internal/resources/manager.go
new file mode 100644
index 0000000..eb899dd
--- /dev/null
+++ b/internal/resources/manager.go
@@ -0,0 +1,323 @@
// Package resources provides an in-process token accountant for CPU cores and
// fractional GPU "slots", used to gate how many tasks a worker starts at once.
package resources

import (
	"context"
	"errors"
	"fmt"
	"math"
	"sort"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/jfraeys/fetch_ml/internal/queue"
)

// Manager tracks free CPU tokens and per-GPU free slot counts. All mutable
// capacity state is guarded by mu; cond is broadcast whenever capacity is
// returned so blocked Acquire calls re-check availability.
type Manager struct {
	mu          sync.Mutex
	cond        *sync.Cond
	totalCPU    int   // immutable after NewManager
	freeCPU     int   // guarded by mu
	slotsPerGPU int   // immutable after NewManager; slots each GPU is divided into
	gpuFree     []int // guarded by mu; free slots per GPU index

	// Lock-free counters, exported via Snapshot for metrics.
	acquireTotal        atomic.Int64
	acquireWaitTotal    atomic.Int64
	acquireTimeoutTotal atomic.Int64
	acquireWaitNanos    atomic.Int64
}

// Snapshot is a point-in-time copy of Manager state for metrics/reporting.
type Snapshot struct {
	TotalCPU    int
	FreeCPU     int
	SlotsPerGPU int
	GPUFree     []int

	AcquireTotal        int64
	AcquireWaitTotal    int64
	AcquireTimeoutTotal int64
	AcquireWaitSeconds  float64
}

// FormatCUDAVisibleDevices renders the lease's GPU indices as a sorted,
// comma-separated value suitable for CUDA_VISIBLE_DEVICES; "-1" means no GPUs.
func FormatCUDAVisibleDevices(lease *Lease) string {
	if lease == nil {
		return "-1"
	}
	if len(lease.gpus) == 0 {
		return "-1"
	}
	idx := make([]int, 0, len(lease.gpus))
	for _, g := range lease.gpus {
		idx = append(idx, g.Index)
	}
	sort.Ints(idx)
	parts := make([]string, 0, len(idx))
	for _, i := range idx {
		parts = append(parts, strconv.Itoa(i))
	}
	return strings.Join(parts, ",")
}

// GPUAllocation records how many slots a lease holds on one GPU.
type GPUAllocation struct {
	Index int
	Slots int
}

// Lease represents capacity held by one task; Release returns it to the Manager.
type Lease struct {
	cpu  int
	gpus []GPUAllocation
	m    *Manager
}

// CPU returns the number of CPU tokens held by the lease.
func (l *Lease) CPU() int { return l.cpu }

// GPUs returns a defensive copy of the lease's GPU allocations.
func (l *Lease) GPUs() []GPUAllocation {
	out := make([]GPUAllocation, len(l.gpus))
	copy(out, l.gpus)
	return out
}

// Release returns the lease's CPU tokens and GPU slots to the manager,
// clamping to configured totals, and wakes all waiters. Safe on a nil lease.
// NOTE(review): Release is not idempotent-guarded — calling it twice would
// double-credit capacity (clamped only at the totals); confirm callers release once.
func (l *Lease) Release() {
	if l == nil || l.m == nil {
		return
	}
	m := l.m
	m.mu.Lock()
	defer m.mu.Unlock()

	if l.cpu > 0 {
		m.freeCPU += l.cpu
		if m.freeCPU > m.totalCPU {
			m.freeCPU = m.totalCPU
		}
	}
	for _, g := range l.gpus {
		if g.Index >= 0 && g.Index < len(m.gpuFree) {
			m.gpuFree[g.Index] += g.Slots
			if m.gpuFree[g.Index] > m.slotsPerGPU {
				m.gpuFree[g.Index] = m.slotsPerGPU
			}
		}
	}
	m.cond.Broadcast()
}

// Options configures NewManager.
type Options struct {
	TotalCPU    int
	GPUCount    int
	SlotsPerGPU int // defaults to 1 when <= 0
}

// NewManager builds a Manager with all CPU tokens and GPU slots free.
func NewManager(opts Options) (*Manager, error) {
	if opts.TotalCPU < 0 {
		return nil, fmt.Errorf("total cpu must be >= 0")
	}
	if opts.GPUCount < 0 {
		return nil, fmt.Errorf("gpu count must be >= 0")
	}
	if opts.SlotsPerGPU <= 0 {
		opts.SlotsPerGPU = 1
	}

	m := &Manager{
		totalCPU:    opts.TotalCPU,
		freeCPU:     opts.TotalCPU,
		slotsPerGPU: opts.SlotsPerGPU,
		gpuFree:     make([]int, opts.GPUCount),
	}
	for i := range m.gpuFree {
		m.gpuFree[i] = m.slotsPerGPU
	}
	m.cond = sync.NewCond(&m.mu)
	return m, nil
}

// Snapshot returns a consistent copy of capacity state plus the atomic
// counters. Safe on a nil Manager (returns the zero Snapshot).
func (m *Manager) Snapshot() Snapshot {
	if m == nil {
		return Snapshot{}
	}

	m.mu.Lock()
	gpuFree := make([]int, len(m.gpuFree))
	copy(gpuFree, m.gpuFree)
	totalCPU := m.totalCPU
	freeCPU := m.freeCPU
	slotsPerGPU := m.slotsPerGPU
	m.mu.Unlock()

	waitNanos := m.acquireWaitNanos.Load()
	return Snapshot{
		TotalCPU:            totalCPU,
		FreeCPU:             freeCPU,
		SlotsPerGPU:         slotsPerGPU,
		GPUFree:             gpuFree,
		AcquireTotal:        m.acquireTotal.Load(),
		AcquireWaitTotal:    m.acquireWaitTotal.Load(),
		AcquireTimeoutTotal: m.acquireTimeoutTotal.Load(),
		AcquireWaitSeconds:  float64(waitNanos) / float64(time.Second),
	}
}

// Acquire blocks until the task's CPU and GPU requests can be satisfied
// together, or ctx is cancelled. Requests exceeding total capacity fail fast
// with an error rather than blocking forever.
func (m *Manager) Acquire(ctx context.Context, task *queue.Task) (*Lease, error) {
	if m == nil {
		return nil, fmt.Errorf("resource manager is nil")
	}
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}
	if ctx == nil {
		return nil, fmt.Errorf("context is nil")
	}

	m.acquireTotal.Add(1)
	start := time.Now()
	waited := false

	// Validate against immutable totals (no lock needed: set once in NewManager).
	reqCPU := task.CPU
	if reqCPU < 0 {
		return nil, fmt.Errorf("cpu request must be >= 0")
	}
	if reqCPU > m.totalCPU {
		return nil, fmt.Errorf("cpu request %d exceeds total cpu %d", reqCPU, m.totalCPU)
	}

	reqGPU := task.GPU
	if reqGPU < 0 {
		return nil, fmt.Errorf("gpu request must be >= 0")
	}
	if reqGPU > len(m.gpuFree) {
		return nil, fmt.Errorf("gpu request %d exceeds available gpus %d", reqGPU, len(m.gpuFree))
	}

	slotsPerTaskGPU, err := m.gpuSlotsForTask(task.GPUMemory)
	if err != nil {
		return nil, err
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	for {
		// Checked each pass: the wakeup goroutine below broadcasts on ctx
		// cancellation so Wait cannot block past cancellation indefinitely.
		if ctx.Err() != nil {
			if errors.Is(ctx.Err(), context.DeadlineExceeded) {
				m.acquireTimeoutTotal.Add(1)
			}
			return nil, ctx.Err()
		}

		// CPU and GPUs are committed atomically under the one lock so a task
		// never holds partial capacity while waiting for the rest.
		gpuAlloc, ok := m.tryAllocateGPUsLocked(reqGPU, slotsPerTaskGPU)
		if ok && (reqCPU == 0 || m.freeCPU >= reqCPU) {
			if reqCPU > 0 {
				m.freeCPU -= reqCPU
			}
			for _, g := range gpuAlloc {
				m.gpuFree[g.Index] -= g.Slots
			}
			if waited {
				m.acquireWaitTotal.Add(1)
				m.acquireWaitNanos.Add(time.Since(start).Nanoseconds())
			}
			return &Lease{cpu: reqCPU, gpus: gpuAlloc, m: m}, nil
		}
		waited = true

		// Per-iteration wakeup helper: converts ctx cancellation into a cond
		// Broadcast so the Wait below is interruptible. NOTE(review): after
		// cancellation the helper may block briefly on m.mu until Acquire
		// returns (the deferred Unlock releases it) — transient, not a leak.
		done := make(chan struct{})
		go func() {
			select {
			case <-ctx.Done():
				m.mu.Lock()
				m.cond.Broadcast()
				m.mu.Unlock()
			case <-done:
			}
		}()
		m.cond.Wait()
		close(done)
	}
}

// gpuSlotsForTask maps a task's GPU-memory request onto per-GPU slots: an
// empty or unparseable value claims a full GPU (all slots); a fraction
// ("0.5" or "50%") claims the proportional slot count, clamped to [1, slotsPerGPU].
func (m *Manager) gpuSlotsForTask(gpuMem string) (int, error) {
	if m.slotsPerGPU <= 0 {
		return 1, nil
	}
	if strings.TrimSpace(gpuMem) == "" {
		return m.slotsPerGPU, nil
	}

	if frac, ok := parseFraction(strings.TrimSpace(gpuMem)); ok {
		if frac <= 0 {
			return 1, nil
		}
		if frac > 1 {
			frac = 1
		}
		slots := int(math.Ceil(frac * float64(m.slotsPerGPU)))
		if slots < 1 {
			slots = 1
		}
		if slots > m.slotsPerGPU {
			slots = m.slotsPerGPU
		}
		return slots, nil
	}

	return m.slotsPerGPU, nil
}

// tryAllocateGPUsLocked greedily picks reqGPU distinct GPUs, preferring the
// one with the most free slots each round (best-fit for load spreading).
// Caller must hold m.mu. Returns (nil, true) for a zero request.
func (m *Manager) tryAllocateGPUsLocked(reqGPU int, slotsPerTaskGPU int) ([]GPUAllocation, bool) {
	if reqGPU == 0 {
		return nil, true
	}
	if slotsPerTaskGPU <= 0 {
		slotsPerTaskGPU = m.slotsPerGPU
	}

	alloc := make([]GPUAllocation, 0, reqGPU)
	used := make(map[int]struct{}, reqGPU)

	for len(alloc) < reqGPU {
		bestIdx := -1
		bestFree := -1
		for i := 0; i < len(m.gpuFree); i++ {
			if _, ok := used[i]; ok {
				continue
			}
			free := m.gpuFree[i]
			if free >= slotsPerTaskGPU && free > bestFree {
				bestFree = free
				bestIdx = i
			}
		}
		if bestIdx < 0 {
			return nil, false
		}
		used[bestIdx] = struct{}{}
		alloc = append(alloc, GPUAllocation{Index: bestIdx, Slots: slotsPerTaskGPU})
	}
	return alloc, true
}

// parseFraction parses "NN%" or a plain float in (0, 1]; plain values > 1 are
// rejected (they are treated as absolute memory requests by the caller).
func parseFraction(s string) (float64, bool) {
	if s == "" {
		return 0, false
	}
	if strings.HasSuffix(s, "%") {
		v := strings.TrimSuffix(s, "%")
		f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
		if err != nil {
			return 0, false
		}
		return f / 100.0, true
	}

	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return 0, false
	}
	if f > 1 {
		return 0, false
	}
	return f, true
}
diff --git a/internal/worker/config.go b/internal/worker/config.go
new file mode 100644
index 0000000..cb3d7ed
--- /dev/null
+++ b/internal/worker/config.go
@@ -0,0 +1,371 @@
// Package worker implements the ML task worker: configuration, task
// execution, integrity checks, and snapshot/prewarm support.
package worker

import (
	"fmt"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/fileutil"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/tracking/factory"
	"gopkg.in/yaml.v3"
)

const (
	// defaultMetricsFlushInterval is how often buffered metrics are flushed
	// when metrics_flush_interval is not configured.
	defaultMetricsFlushInterval = 500 * time.Millisecond
	// datasetCacheDefaultTTL is the default freshness window for fetched datasets.
	datasetCacheDefaultTTL = 30 * time.Minute
)

// QueueConfig selects the task-queue backend and its backend-specific settings.
type QueueConfig struct {
	Backend    string `yaml:"backend"`     // "redis" (default) or "sqlite"
	SQLitePath string `yaml:"sqlite_path"` // database path; required for the sqlite backend
}

// Config holds worker configuration.
type Config struct {
	Host          string                `yaml:"host"`           // SSH host for remote mode
	User          string                `yaml:"user"`           // SSH user
	SSHKey        string                `yaml:"ssh_key"`        // path to SSH private key
	Port          int                   `yaml:"port"`           // SSH port
	BasePath      string                `yaml:"base_path"`      // root directory for job state
	TrainScript   string                `yaml:"train_script"`   // training entry-point script
	RedisURL      string                `yaml:"redis_url"`      // full redis:// URL; overrides addr/password/db when set
	RedisAddr     string                `yaml:"redis_addr"`
	RedisPassword string                `yaml:"redis_password"`
	RedisDB       int                   `yaml:"redis_db"`
	Queue         QueueConfig           `yaml:"queue"`          // queue backend selection (redis or sqlite)
	KnownHosts    string                `yaml:"known_hosts"`    // SSH known_hosts path
	WorkerID      string                `yaml:"worker_id"`      // defaults to "worker-<uuid8>" when empty
	MaxWorkers    int                   `yaml:"max_workers"`    // max concurrent tasks
	PollInterval  int                   `yaml:"poll_interval_seconds"`
	Resources     config.ResourceConfig `yaml:"resources"`
	LocalMode     bool                  `yaml:"local_mode"`     // run locally instead of over SSH

	// Authentication
	Auth auth.Config `yaml:"auth"`

	// Metrics exporter
	Metrics MetricsConfig `yaml:"metrics"`
	// Metrics buffering
	MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`

	// Data management
	DataManagerPath string        `yaml:"data_manager_path"` // data-manager binary path
	AutoFetchData   bool          `yaml:"auto_fetch_data"`   // fetch task datasets before execution
	DataDir         string        `yaml:"data_dir"`
	DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`

	// Remote snapshot store (S3-compatible) settings.
	SnapshotStore SnapshotStoreConfig `yaml:"snapshot_store"`

	// Provenance enforcement
	// Default: fail-closed (trustworthiness-by-default). Set true to opt into best-effort.
	ProvenanceBestEffort bool `yaml:"provenance_best_effort"`

	// Phase 1: opt-in prewarming of next task artifacts (snapshot/datasets/env).
	PrewarmEnabled bool `yaml:"prewarm_enabled"`

	// Podman execution
	PodmanImage         string   `yaml:"podman_image"`
	ContainerWorkspace  string   `yaml:"container_workspace"`
	ContainerResults    string   `yaml:"container_results"`
	GPUDevices          []string `yaml:"gpu_devices"`
	GPUVendor           string   `yaml:"gpu_vendor"`             // "nvidia", "amd", "apple", or "none"; inferred when empty
	GPUVisibleDevices   []int    `yaml:"gpu_visible_devices"`    // numeric GPU indices; mutually exclusive with the IDs field
	GPUVisibleDeviceIDs []string `yaml:"gpu_visible_device_ids"` // "GPU-..." UUIDs; NVIDIA only

	// Apple M-series GPU configuration
	AppleGPU AppleGPUConfig `yaml:"apple_gpu"`

	// Task lease and retry settings
	TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // Worker lease (default: 30min)
	HeartbeatInterval time.Duration `yaml:"heartbeat_interval"`  // Renew lease (default: 1min)
	MaxRetries        int           `yaml:"max_retries"`         // Maximum retry attempts (default: 3)
	GracefulTimeout   time.Duration `yaml:"graceful_timeout"`    // Shutdown timeout (default: 5min)

	// Plugins configuration
	Plugins map[string]factory.PluginConfig `yaml:"plugins"`
}

// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct {
	Enabled    bool   `yaml:"enabled"`
	ListenAddr string `yaml:"listen_addr"` // defaults to ":9100"
}

// SnapshotStoreConfig configures the S3-compatible snapshot store.
// AccessKey and SecretKey must both be set or both be empty (see Validate).
type SnapshotStoreConfig struct {
	Enabled      bool          `yaml:"enabled"`
	Endpoint     string        `yaml:"endpoint"` // required when Enabled
	Secure       bool          `yaml:"secure"`
	Region       string        `yaml:"region"`
	Bucket       string        `yaml:"bucket"`   // required when Enabled
	Prefix       string        `yaml:"prefix"`
	AccessKey    string        `yaml:"access_key"`
	SecretKey    string        `yaml:"secret_key"`
	SessionToken string        `yaml:"session_token"`
	Timeout      time.Duration `yaml:"timeout"`     // defaults to 10min
	MaxRetries   int           `yaml:"max_retries"` // defaults to 3
}

// AppleGPUConfig holds configuration for Apple M-series GPU support
type AppleGPUConfig struct {
	Enabled     bool   `yaml:"enabled"`
	MetalDevice string `yaml:"metal_device"`
	MPSRuntime  string `yaml:"mps_runtime"`
}

// LoadConfig loads worker configuration from a YAML file.
+func LoadConfig(path string) (*Config, error) { + data, err := fileutil.SecureFileRead(path) + if err != nil { + return nil, err + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, err + } + + if strings.TrimSpace(cfg.RedisURL) != "" { + cfg.RedisURL = os.ExpandEnv(strings.TrimSpace(cfg.RedisURL)) + cfg.RedisAddr = cfg.RedisURL + cfg.RedisPassword = "" + cfg.RedisDB = 0 + } + + // Get smart defaults for current environment + smart := config.GetSmartDefaults() + + if cfg.Port == 0 { + cfg.Port = config.DefaultSSHPort + } + if cfg.Host == "" { + cfg.Host = smart.Host() + } + if cfg.BasePath == "" { + cfg.BasePath = smart.BasePath() + } + if cfg.RedisAddr == "" { + cfg.RedisAddr = smart.RedisAddr() + } + if cfg.KnownHosts == "" { + cfg.KnownHosts = smart.KnownHostsPath() + } + if cfg.WorkerID == "" { + cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8]) + } + cfg.Resources.ApplyDefaults() + if cfg.MaxWorkers > 0 { + cfg.Resources.MaxWorkers = cfg.MaxWorkers + } else { + cfg.MaxWorkers = cfg.Resources.MaxWorkers + } + if cfg.PollInterval == 0 { + cfg.PollInterval = smart.PollInterval() + } + if cfg.DataManagerPath == "" { + cfg.DataManagerPath = "./data_manager" + } + if cfg.DataDir == "" { + if cfg.Host == "" || !cfg.AutoFetchData { + cfg.DataDir = config.DefaultLocalDataDir + } else { + cfg.DataDir = smart.DataDir() + } + } + if cfg.SnapshotStore.Timeout == 0 { + cfg.SnapshotStore.Timeout = 10 * time.Minute + } + if cfg.SnapshotStore.MaxRetries == 0 { + cfg.SnapshotStore.MaxRetries = 3 + } + if cfg.Metrics.ListenAddr == "" { + cfg.Metrics.ListenAddr = ":9100" + } + if cfg.MetricsFlushInterval == 0 { + cfg.MetricsFlushInterval = defaultMetricsFlushInterval + } + if cfg.DatasetCacheTTL == 0 { + cfg.DatasetCacheTTL = datasetCacheDefaultTTL + } + + if strings.TrimSpace(cfg.Queue.Backend) == "" { + cfg.Queue.Backend = string(queue.QueueBackendRedis) + } + if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), 
string(queue.QueueBackendSQLite)) { + if strings.TrimSpace(cfg.Queue.SQLitePath) == "" { + cfg.Queue.SQLitePath = filepath.Join(cfg.DataDir, "queue.db") + } + cfg.Queue.SQLitePath = config.ExpandPath(cfg.Queue.SQLitePath) + } + + if strings.TrimSpace(cfg.GPUVendor) == "" { + if cfg.AppleGPU.Enabled { + cfg.GPUVendor = string(GPUTypeApple) + } else if len(cfg.GPUDevices) > 0 || + len(cfg.GPUVisibleDevices) > 0 || + len(cfg.GPUVisibleDeviceIDs) > 0 { + cfg.GPUVendor = string(GPUTypeNVIDIA) + } else { + cfg.GPUVendor = string(GPUTypeNone) + } + } + + // Set lease and retry defaults + if cfg.TaskLeaseDuration == 0 { + cfg.TaskLeaseDuration = 30 * time.Minute + } + if cfg.HeartbeatInterval == 0 { + cfg.HeartbeatInterval = 1 * time.Minute + } + if cfg.MaxRetries == 0 { + cfg.MaxRetries = 3 + } + if cfg.GracefulTimeout == 0 { + cfg.GracefulTimeout = 5 * time.Minute + } + + return &cfg, nil +} + +// Validate implements config.Validator interface. +func (c *Config) Validate() error { + if c.Port != 0 { + if err := config.ValidatePort(c.Port); err != nil { + return fmt.Errorf("invalid SSH port: %w", err) + } + } + + if c.BasePath != "" { + // Convert relative paths to absolute + c.BasePath = config.ExpandPath(c.BasePath) + if !filepath.IsAbs(c.BasePath) { + c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath) + } + } + + backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend)) + if backend == "" { + backend = string(queue.QueueBackendRedis) + c.Queue.Backend = backend + } + if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) { + return fmt.Errorf("queue.backend must be one of %q or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite) + } + + if backend == string(queue.QueueBackendSQLite) { + if strings.TrimSpace(c.Queue.SQLitePath) == "" { + return fmt.Errorf("queue.sqlite_path is required when queue.backend is %q", queue.QueueBackendSQLite) + } + c.Queue.SQLitePath = config.ExpandPath(c.Queue.SQLitePath) + if 
!filepath.IsAbs(c.Queue.SQLitePath) { + c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath) + } + } + + if c.RedisAddr != "" { + addr := strings.TrimSpace(c.RedisAddr) + if strings.HasPrefix(addr, "redis://") { + u, err := url.Parse(addr) + if err != nil { + return fmt.Errorf("invalid Redis configuration: invalid redis url: %w", err) + } + if u.Scheme != "redis" || strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("invalid Redis configuration: invalid redis url") + } + } else { + if err := config.ValidateRedisAddr(addr); err != nil { + return fmt.Errorf("invalid Redis configuration: %w", err) + } + } + } + + if c.MaxWorkers < 1 { + return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers) + } + + switch strings.ToLower(strings.TrimSpace(c.GPUVendor)) { + case string(GPUTypeNVIDIA), string(GPUTypeApple), string(GPUTypeNone), "amd": + // ok + default: + return fmt.Errorf( + "gpu_vendor must be one of %q, %q, %q, %q", + string(GPUTypeNVIDIA), + "amd", + string(GPUTypeApple), + string(GPUTypeNone), + ) + } + + // Strict GPU visibility configuration: + // - gpu_visible_devices and gpu_visible_device_ids are mutually exclusive. + // - UUID-style gpu_visible_device_ids is NVIDIA-only. 
+ vendor := strings.ToLower(strings.TrimSpace(c.GPUVendor)) + if len(c.GPUVisibleDevices) > 0 && len(c.GPUVisibleDeviceIDs) > 0 { + return fmt.Errorf("gpu_visible_devices and gpu_visible_device_ids are mutually exclusive") + } + if len(c.GPUVisibleDeviceIDs) > 0 { + if vendor != string(GPUTypeNVIDIA) { + return fmt.Errorf( + "gpu_visible_device_ids is only supported when gpu_vendor is %q", + string(GPUTypeNVIDIA), + ) + } + for _, id := range c.GPUVisibleDeviceIDs { + id = strings.TrimSpace(id) + if id == "" { + return fmt.Errorf("gpu_visible_device_ids contains an empty value") + } + if !strings.HasPrefix(id, "GPU-") { + return fmt.Errorf("gpu_visible_device_ids values must start with %q, got %q", "GPU-", id) + } + } + } + if vendor == string(GPUTypeApple) || vendor == string(GPUTypeNone) { + if len(c.GPUVisibleDevices) > 0 || len(c.GPUVisibleDeviceIDs) > 0 { + return fmt.Errorf( + "gpu_visible_devices and gpu_visible_device_ids are not supported when gpu_vendor is %q", + vendor, + ) + } + } + if vendor == "amd" { + if len(c.GPUVisibleDeviceIDs) > 0 { + return fmt.Errorf("gpu_visible_device_ids is not supported when gpu_vendor is %q", vendor) + } + for _, idx := range c.GPUVisibleDevices { + if idx < 0 { + return fmt.Errorf("gpu_visible_devices contains negative index %d", idx) + } + } + } + + if c.SnapshotStore.Enabled { + if strings.TrimSpace(c.SnapshotStore.Endpoint) == "" { + return fmt.Errorf("snapshot_store.endpoint is required when snapshot_store.enabled is true") + } + if strings.TrimSpace(c.SnapshotStore.Bucket) == "" { + return fmt.Errorf("snapshot_store.bucket is required when snapshot_store.enabled is true") + } + ak := strings.TrimSpace(c.SnapshotStore.AccessKey) + sk := strings.TrimSpace(c.SnapshotStore.SecretKey) + if (ak == "") != (sk == "") { + return fmt.Errorf( + "snapshot_store.access_key and snapshot_store.secret_key must both be set or both be empty", + ) + } + if c.SnapshotStore.Timeout < 0 { + return fmt.Errorf("snapshot_store.timeout must 
be >= 0") + } + if c.SnapshotStore.MaxRetries < 0 { + return fmt.Errorf("snapshot_store.max_retries must be >= 0") + } + } + + return nil +} diff --git a/internal/worker/core.go b/internal/worker/core.go new file mode 100644 index 0000000..1c467bb --- /dev/null +++ b/internal/worker/core.go @@ -0,0 +1,547 @@ +package worker + +import ( + "context" + "fmt" + "log" + "log/slog" + "math" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/envpool" + "github.com/jfraeys/fetch_ml/internal/jupyter" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/metrics" + "github.com/jfraeys/fetch_ml/internal/network" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/resources" + "github.com/jfraeys/fetch_ml/internal/tracking" + "github.com/jfraeys/fetch_ml/internal/tracking/factory" + trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +// MLServer wraps network.SSHClient for backward compatibility. +type MLServer struct { + *network.SSHClient +} + +// JupyterManager is the subset of the Jupyter service manager used by the worker. +// It exists to keep task execution testable. +type JupyterManager interface { + StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) + StopService(ctx context.Context, serviceID string) error + RemoveService(ctx context.Context, serviceID string, purge bool) error + RestoreWorkspace(ctx context.Context, name string) (string, error) + ListServices() []*jupyter.JupyterService +} + +// isValidName validates that input strings contain only safe characters. 
+// isValidName checks if the input string is a valid name. +func isValidName(input string) bool { + return len(input) > 0 && len(input) < 256 +} + +// NewMLServer creates a new ML server connection. +// NewMLServer returns a new MLServer instance. +func NewMLServer(cfg *Config) (*MLServer, error) { + if cfg.LocalMode { + return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil + } + + client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts) + if err != nil { + return nil, err + } + + return &MLServer{SSHClient: client}, nil +} + +// Worker represents an ML task worker. +type Worker struct { + id string + config *Config + server *MLServer + queue queue.Backend + resources *resources.Manager + running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown + runningMu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + logger *logging.Logger + metrics *metrics.Metrics + metricsSrv *http.Server + + datasetCache map[string]time.Time + datasetCacheMu sync.RWMutex + datasetCacheTTL time.Duration + + // Graceful shutdown fields + shutdownCh chan struct{} + activeTasks sync.Map // map[string]*queue.Task - track active tasks + gracefulWait sync.WaitGroup + + podman *container.PodmanManager + jupyter JupyterManager + trackingRegistry *tracking.Registry + envPool *envpool.Pool + + prewarmMu sync.Mutex + prewarmTargetID string + prewarmCancel context.CancelFunc + prewarmStartedAt time.Time +} + +func envInt(name string) (int, bool) { + v := strings.TrimSpace(os.Getenv(name)) + if v == "" { + return 0, false + } + n, err := strconv.Atoi(v) + if err != nil { + return 0, false + } + return n, true +} + +func parseCPUFromConfig(cfg *Config) int { + if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 { + return n + } + if cfg != nil { + if cfg.Resources.PodmanCPUs != "" { + if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil { + if f < 0 { + 
return 0 + } + return int(math.Floor(f)) + } + } + } + return runtime.NumCPU() +} + +func parseGPUCountFromConfig(cfg *Config) int { + factory := &GPUDetectorFactory{} + detector := factory.CreateDetector(cfg) + return detector.DetectGPUCount() +} + +func (w *Worker) getGPUDetector() GPUDetector { + factory := &GPUDetectorFactory{} + return factory.CreateDetector(w.config) +} + +func parseGPUSlotsPerGPUFromConfig() int { + if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 { + return n + } + return 1 +} + +func (w *Worker) setupMetricsExporter() { + if !w.config.Metrics.Enabled { + return + } + + reg := prometheus.NewRegistry() + reg.MustRegister( + collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}), + collectors.NewGoCollector(), + ) + + labels := prometheus.Labels{"worker_id": w.id} + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_tasks_processed_total", + Help: "Total tasks processed successfully by this worker.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.TasksProcessed.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_tasks_failed_total", + Help: "Total tasks failed by this worker.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.TasksFailed.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_tasks_active", + Help: "Number of tasks currently running on this worker.", + ConstLabels: labels, + }, func() float64 { + return float64(w.runningCount()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_tasks_queued", + Help: "Latest observed queue depth from Redis.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.QueuedTasks.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_data_transferred_bytes_total", + Help: "Total bytes transferred while fetching 
datasets.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.DataTransferred.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_data_fetch_time_seconds_total", + Help: "Total time spent fetching datasets (seconds).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_execution_time_seconds_total", + Help: "Total execution time for completed tasks (seconds).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_env_hit_total", + Help: "Total environment prewarm hits (warmed image already existed).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmEnvHit.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_env_miss_total", + Help: "Total environment prewarm misses (warmed image did not exist yet).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmEnvMiss.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_env_built_total", + Help: "Total environment prewarm images built.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmEnvBuilt.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_env_time_seconds_total", + Help: "Total time spent building prewarm images (seconds).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmEnvTime.Load()) / float64(time.Second) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_snapshot_hit_total", + Help: "Total prewarmed snapshot hits (snapshots 
found in .prewarm/).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmSnapshotHit.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_snapshot_miss_total", + Help: "Total prewarmed snapshot misses (snapshots not found in .prewarm/).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmSnapshotMiss.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_snapshot_built_total", + Help: "Total snapshots prewarmed into .prewarm/.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmSnapshotBuilt.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_snapshot_time_seconds_total", + Help: "Total time spent prewarming snapshots (seconds).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmSnapshotTime.Load()) / float64(time.Second) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_worker_max_concurrency", + Help: "Configured maximum concurrent tasks for this worker.", + ConstLabels: labels, + }, func() float64 { + return float64(w.config.MaxWorkers) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_cpu_total", + Help: "Total CPU tokens managed by the worker resource manager.", + ConstLabels: labels, + }, func() float64 { + return float64(w.resources.Snapshot().TotalCPU) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_cpu_free", + Help: "Free CPU tokens currently available in the worker resource manager.", + ConstLabels: labels, + }, func() float64 { + return float64(w.resources.Snapshot().FreeCPU) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_acquire_total", + Help: "Total resource acquisition attempts.", + ConstLabels: 
labels, + }, func() float64 { + return float64(w.resources.Snapshot().AcquireTotal) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_acquire_wait_total", + Help: "Total resource acquisitions that had to wait for resources.", + ConstLabels: labels, + }, func() float64 { + return float64(w.resources.Snapshot().AcquireWaitTotal) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_acquire_timeout_total", + Help: "Total resource acquisition attempts that timed out.", + ConstLabels: labels, + }, func() float64 { + return float64(w.resources.Snapshot().AcquireTimeoutTotal) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_acquire_wait_seconds_total", + Help: "Total seconds spent waiting for resources across all acquisitions.", + ConstLabels: labels, + }, func() float64 { + return w.resources.Snapshot().AcquireWaitSeconds + })) + + snap := w.resources.Snapshot() + for i := range snap.GPUFree { + gpuLabels := prometheus.Labels{"worker_id": w.id, "gpu_index": strconv.Itoa(i)} + idx := i + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_gpu_slots_total", + Help: "Total GPU slots per GPU index.", + ConstLabels: gpuLabels, + }, func() float64 { + return float64(w.resources.Snapshot().SlotsPerGPU) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_gpu_slots_free", + Help: "Free GPU slots per GPU index.", + ConstLabels: gpuLabels, + }, func() float64 { + s := w.resources.Snapshot() + if idx < 0 || idx >= len(s.GPUFree) { + return 0 + } + return float64(s.GPUFree[idx]) + })) + } + + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) + + srv := &http.Server{ + Addr: w.config.Metrics.ListenAddr, + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + } + + w.metricsSrv = srv + go func() { + 
w.logger.Info("metrics exporter listening", + "addr", w.config.Metrics.ListenAddr, + "enabled", w.config.Metrics.Enabled) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + w.logger.Warn("metrics exporter stopped", + "error", err) + } + }() +} + +// NewWorker creates a new worker instance. +func NewWorker(cfg *Config, _ string) (*Worker, error) { + srv, err := NewMLServer(cfg) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + if closeErr := srv.Close(); closeErr != nil { + log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr) + } + } + }() + + backendCfg := queue.BackendConfig{ + Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))), + RedisAddr: cfg.RedisAddr, + RedisPassword: cfg.RedisPassword, + RedisDB: cfg.RedisDB, + SQLitePath: cfg.Queue.SQLitePath, + MetricsFlushInterval: cfg.MetricsFlushInterval, + } + queueClient, err := queue.NewBackend(backendCfg) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + if closeErr := queueClient.Close(); closeErr != nil { + log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr) + } + } + }() + + // Create data_dir if it doesn't exist (for production without NAS) + if cfg.DataDir != "" { + if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil { + log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err) + } + } + + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + if err != nil { + cancel() + } + }() + ctx = logging.EnsureTrace(ctx) + ctx = logging.CtxWithWorker(ctx, cfg.WorkerID) + + baseLogger := logging.NewLogger(slog.LevelInfo, false) + logger := baseLogger.Component(ctx, "worker") + metricsObj := &metrics.Metrics{} + + podmanMgr, err := container.NewPodmanManager(logger) + if err != nil { + return nil, fmt.Errorf("failed to create podman manager: %w", err) + } + + jupyterMgr, err := 
jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig()) + if err != nil { + return nil, fmt.Errorf("failed to create jupyter service manager: %w", err) + } + + trackingRegistry := tracking.NewRegistry(logger) + pluginLoader := factory.NewPluginLoader(logger, podmanMgr) + + if len(cfg.Plugins) == 0 { + logger.Warn("no plugins configured, defining defaults") + // Register defaults manually for backward compatibility/local dev + mlflowPlugin, err := trackingplugins.NewMLflowPlugin( + logger, + podmanMgr, + trackingplugins.MLflowOptions{ + ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"), + }, + ) + if err == nil { + trackingRegistry.Register(mlflowPlugin) + } + + tensorboardPlugin, err := trackingplugins.NewTensorBoardPlugin( + logger, + podmanMgr, + trackingplugins.TensorBoardOptions{ + LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"), + }, + ) + if err == nil { + trackingRegistry.Register(tensorboardPlugin) + } + + trackingRegistry.Register(trackingplugins.NewWandbPlugin()) + } else { + if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil { + return nil, fmt.Errorf("failed to load plugins: %w", err) + } + } + + worker := &Worker{ + id: cfg.WorkerID, + config: cfg, + server: srv, + queue: queueClient, + running: make(map[string]context.CancelFunc), + datasetCache: make(map[string]time.Time), + datasetCacheTTL: cfg.DatasetCacheTTL, + ctx: ctx, + cancel: cancel, + logger: logger, + metrics: metricsObj, + shutdownCh: make(chan struct{}), + podman: podmanMgr, + jupyter: jupyterMgr, + trackingRegistry: trackingRegistry, + envPool: envpool.New(""), + } + + rm, rmErr := resources.NewManager(resources.Options{ + TotalCPU: parseCPUFromConfig(cfg), + GPUCount: parseGPUCountFromConfig(cfg), + SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(), + }) + if rmErr != nil { + return nil, fmt.Errorf("failed to init resource manager: %w", rmErr) + } + worker.resources = rm + + if !cfg.LocalMode { + gpuType := 
strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE"))) + if cfg.AppleGPU.Enabled { + logger.Warn("apple MPS GPU mode is intended for development; do not use in production", + "gpu_type", "apple", + ) + } + if gpuType == "amd" { + logger.Warn("amd GPU mode is intended for development; do not use in production", + "gpu_type", "amd", + ) + } + } + + worker.setupMetricsExporter() + + // Pre-pull tracking images in background + go worker.prePullImages() + + return worker, nil +} + +func (w *Worker) prePullImages() { + if w.config.LocalMode { + return + } + + w.logger.Info("starting image pre-pulling") + + // Pull worker image + if w.config.PodmanImage != "" { + w.pullImage(w.config.PodmanImage) + } + + // Pull plugin images + for name, cfg := range w.config.Plugins { + if !cfg.Enabled || cfg.Image == "" { + continue + } + w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image) + w.pullImage(cfg.Image) + } +} + +func (w *Worker) pullImage(image string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + cmd := exec.CommandContext(ctx, "podman", "pull", image) + if output, err := cmd.CombinedOutput(); err != nil { + w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output)) + } else { + w.logger.Info("image pulled successfully", "image", image) + } +} diff --git a/internal/worker/data_integrity.go b/internal/worker/data_integrity.go new file mode 100644 index 0000000..b223a32 --- /dev/null +++ b/internal/worker/data_integrity.go @@ -0,0 +1,824 @@ +package worker + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/errtypes" + "github.com/jfraeys/fetch_ml/internal/experiment" + "github.com/jfraeys/fetch_ml/internal/logging" + 
"github.com/jfraeys/fetch_ml/internal/metrics" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +// NEW: Fetch datasets using data_manager. +func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error { + logger := w.logger.Job(ctx, task.JobName, task.ID) + logger.Info("fetching datasets", + "worker_id", w.id, + "dataset_count", len(task.Datasets)) + + for _, dataset := range task.Datasets { + if w.datasetIsFresh(dataset) { + logger.Debug("skipping cached dataset", + "dataset", dataset) + continue + } + // Check for cancellation before each dataset fetch + select { + case <-ctx.Done(): + return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err()) + default: + } + + logger.Info("fetching dataset", + "worker_id", w.id, + "dataset", dataset) + + // Create command with context for cancellation support + cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute) + // Validate inputs to prevent command injection + if !isValidName(task.JobName) || !isValidName(dataset) { + cancel() + return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters") + } + //nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated + cmd := exec.CommandContext(cmdCtx, + w.config.DataManagerPath, + "fetch", + task.JobName, + dataset, + ) + + output, err := cmd.CombinedOutput() + cancel() // Clean up context + + if err != nil { + return &errtypes.DataFetchError{ + Dataset: dataset, + JobName: task.JobName, + Err: fmt.Errorf("command failed: %w, output: %s", err, output), + } + } + + logger.Info("dataset ready", + "worker_id", w.id, + "dataset", dataset) + w.markDatasetFetched(dataset) + } + + return nil +} + +func resolveDatasets(task *queue.Task) []string { + if task == nil { + return nil + } + if len(task.DatasetSpecs) > 0 { + out := make([]string, 0, len(task.DatasetSpecs)) + for _, ds := range task.DatasetSpecs { + if ds.Name != "" { + out = append(out, ds.Name) + } + } + if len(out) > 0 { + return out + } + } + if 
// parseDatasets extracts the dataset list from a task's raw argument
// string.
//
// Supported forms:
//   - "--datasets name1,name2"   (space-separated flag value)
//   - "--datasets=name1,name2"   (equals form; previously unsupported)
//
// Entries are comma-separated; surrounding whitespace is trimmed and
// empty entries are dropped (the original returned "" entries for
// inputs like "a,,b" or a trailing comma, which then fail downstream
// name validation). Returns nil when the flag is absent, dangling, or
// carries no usable names.
func parseDatasets(args string) []string {
	if !strings.Contains(args, "--datasets") {
		return nil
	}

	parts := strings.Fields(args)
	for i, part := range parts {
		var raw string
		switch {
		case part == "--datasets" && i+1 < len(parts):
			raw = parts[i+1]
		case strings.HasPrefix(part, "--datasets="):
			raw = strings.TrimPrefix(part, "--datasets=")
		default:
			continue
		}

		var out []string
		for _, name := range strings.Split(raw, ",") {
			if name = strings.TrimSpace(name); name != "" {
				out = append(out, name)
			}
		}
		return out
	}

	return nil
}
+ runOnce() + + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-w.ctx.Done(): + w.prewarmMu.Lock() + w.cancelPrewarmLocked() + w.prewarmMu.Unlock() + return + case <-ticker.C: + } + runOnce() + } +} + +func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) { + if w == nil || w.config == nil || !w.config.PrewarmEnabled { + return false, nil + } + if ctx == nil || w.queue == nil || w.metrics == nil { + return false, nil + } + + next, err := w.queue.PeekNextTask() + if err != nil { + return false, err + } + if next == nil { + w.prewarmMu.Lock() + w.cancelPrewarmLocked() + w.prewarmMu.Unlock() + return false, nil + } + + return w.prewarmTaskOnce(ctx, next) +} + +func (w *Worker) prewarmTaskOnce(ctx context.Context, next *queue.Task) (bool, error) { + if w == nil || w.config == nil || !w.config.PrewarmEnabled { + return false, nil + } + if ctx == nil || w.queue == nil || w.metrics == nil { + return false, nil + } + if next == nil { + return false, nil + } + + w.prewarmMu.Lock() + if w.prewarmTargetID == next.ID { + w.prewarmMu.Unlock() + return false, nil + } + w.cancelPrewarmLocked() + prewarmCtx, cancel := context.WithCancel(ctx) + w.prewarmCancel = cancel + w.prewarmTargetID = next.ID + w.prewarmStartedAt = time.Now() + startedAt := w.prewarmStartedAt.UTC().Format(time.RFC3339Nano) + phase := "datasets" + dsCnt := len(resolveDatasets(next)) + snapID := next.SnapshotID + if strings.TrimSpace(snapID) != "" { + phase = "snapshot" + } else if dsCnt == 0 { + phase = "env" + } + _ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{ + WorkerID: w.id, + TaskID: next.ID, + SnapshotID: snapID, + StartedAt: startedAt, + UpdatedAt: time.Now().UTC().Format(time.RFC3339Nano), + Phase: phase, + DatasetCnt: dsCnt, + EnvHit: w.metrics.PrewarmEnvHit.Load(), + EnvMiss: w.metrics.PrewarmEnvMiss.Load(), + EnvBuilt: w.metrics.PrewarmEnvBuilt.Load(), + EnvTimeNs: w.metrics.PrewarmEnvTime.Load(), + }) + w.prewarmMu.Unlock() + 
+ w.logger.Info("prewarm started", + "worker_id", w.id, + "task_id", next.ID, + "snapshot_id", snapID, + "phase", phase, + ) + + local := *next + local.Datasets = resolveDatasets(&local) + + hasSnapshot := strings.TrimSpace(local.SnapshotID) != "" + hasDatasets := w.config.AutoFetchData && len(local.Datasets) > 0 + hasEnv := false + if w.envPool != nil && !w.config.LocalMode && strings.TrimSpace(w.config.PodmanImage) != "" { + if local.Metadata != nil { + depsSHA := strings.TrimSpace(local.Metadata["deps_manifest_sha256"]) + commitID := strings.TrimSpace(local.Metadata["commit_id"]) + if depsSHA != "" && commitID != "" { + expMgr := experiment.NewManager(w.config.BasePath) + hostWorkspace := expMgr.GetFilesPath(commitID) + if name, err := selectDependencyManifest(hostWorkspace); err == nil && name != "" { + if tag, err := w.envPool.WarmImageTag(depsSHA); err == nil && strings.TrimSpace(tag) != "" { + hasEnv = true + } + } + } + } + } + if !hasSnapshot && !hasDatasets && !hasEnv { + _ = w.queue.ClearWorkerPrewarmState(w.id) + return false, nil + } + + if hasSnapshot { + want := "" + if local.Metadata != nil { + want = local.Metadata["snapshot_sha256"] + } + start := time.Now() + src, err := ResolveSnapshot( + prewarmCtx, + w.config.DataDir, + &w.config.SnapshotStore, + local.SnapshotID, + want, + nil, + ) + if err != nil { + return true, err + } + dst := filepath.Join(w.config.BasePath, ".prewarm", "snapshots", local.ID) + _ = os.RemoveAll(dst) + if err := copyDir(src, dst); err != nil { + return true, err + } + w.metrics.RecordPrewarmSnapshotBuilt(time.Since(start)) + } + + if hasDatasets { + if err := w.fetchDatasets(prewarmCtx, &local); err != nil { + return true, err + } + } + + _ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{ + WorkerID: w.id, + TaskID: local.ID, + SnapshotID: local.SnapshotID, + StartedAt: startedAt, + UpdatedAt: time.Now().UTC().Format(time.RFC3339Nano), + Phase: "ready", + DatasetCnt: len(local.Datasets), + EnvHit: 
w.metrics.PrewarmEnvHit.Load(), + EnvMiss: w.metrics.PrewarmEnvMiss.Load(), + EnvBuilt: w.metrics.PrewarmEnvBuilt.Load(), + EnvTimeNs: w.metrics.PrewarmEnvTime.Load(), + }) + + w.logger.Info("prewarm ready", + "worker_id", w.id, + "task_id", local.ID, + "snapshot_id", local.SnapshotID, + ) + + return true, nil +} + +func (w *Worker) verifySnapshot(ctx context.Context, task *queue.Task) error { + if task == nil { + return fmt.Errorf("task is nil") + } + if task.SnapshotID == "" { + return nil + } + if err := container.ValidateJobName(task.SnapshotID); err != nil { + return fmt.Errorf("snapshot %q: invalid snapshot_id: %w", task.SnapshotID, err) + } + if task.Metadata == nil { + return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID) + } + want, err := normalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"]) + if err != nil { + return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, err) + } + if want == "" { + return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID) + } + path, err := ResolveSnapshot( + ctx, + w.config.DataDir, + &w.config.SnapshotStore, + task.SnapshotID, + want, + nil, + ) + if err != nil { + return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err) + } + got, err := dirOverallSHA256Hex(path) + if err != nil { + return fmt.Errorf("snapshot %q: checksum verification failed: %w", task.SnapshotID, err) + } + if got != want { + return fmt.Errorf( + "snapshot %q: checksum mismatch: expected %s, got %s", + task.SnapshotID, + want, + got, + ) + } + w.logger.Job( + ctx, + task.JobName, + task.ID, + ).Info("snapshot checksum verified", "snapshot_id", task.SnapshotID) + return nil +} + +func fileSHA256Hex(path string) (string, error) { + f, err := os.Open(filepath.Clean(path)) + if err != nil { + return "", err + } + defer func() { _ = f.Close() }() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return fmt.Sprintf("%x", 
h.Sum(nil)), nil +} + +func normalizeSHA256ChecksumHex(checksum string) (string, error) { + checksum = strings.TrimSpace(checksum) + checksum = strings.TrimPrefix(checksum, "sha256:") + checksum = strings.TrimPrefix(checksum, "SHA256:") + checksum = strings.TrimSpace(checksum) + if checksum == "" { + return "", nil + } + if len(checksum) != 64 { + return "", fmt.Errorf("expected sha256 hex length 64, got %d", len(checksum)) + } + if _, err := hex.DecodeString(checksum); err != nil { + return "", fmt.Errorf("invalid sha256 hex: %w", err) + } + return strings.ToLower(checksum), nil +} + +func dirOverallSHA256Hex(root string) (string, error) { + root = filepath.Clean(root) + info, err := os.Stat(root) + if err != nil { + return "", err + } + if !info.IsDir() { + return "", fmt.Errorf("not a directory") + } + + var files []string + err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if d.IsDir() { + return nil + } + rel, err := filepath.Rel(root, path) + if err != nil { + return err + } + files = append(files, rel) + return nil + }) + if err != nil { + return "", err + } + + // Deterministic order. + for i := 0; i < len(files); i++ { + for j := i + 1; j < len(files); j++ { + if files[i] > files[j] { + files[i], files[j] = files[j], files[i] + } + } + } + + // Hash file hashes to avoid holding all bytes. 
+ overall := sha256.New() + for _, rel := range files { + p := filepath.Join(root, rel) + sum, err := fileSHA256Hex(p) + if err != nil { + return "", err + } + overall.Write([]byte(sum)) + } + return fmt.Sprintf("%x", overall.Sum(nil)), nil +} + +func (w *Worker) verifyDatasetSpecs(ctx context.Context, task *queue.Task) error { + if task == nil { + return fmt.Errorf("task is nil") + } + if len(task.DatasetSpecs) == 0 { + return nil + } + + logger := w.logger.Job(ctx, task.JobName, task.ID) + for _, ds := range task.DatasetSpecs { + want, err := normalizeSHA256ChecksumHex(ds.Checksum) + if err != nil { + return fmt.Errorf("dataset %q: invalid checksum: %w", ds.Name, err) + } + if want == "" { + continue + } + if err := container.ValidateJobName(ds.Name); err != nil { + return fmt.Errorf("dataset %q: invalid name: %w", ds.Name, err) + } + path := filepath.Join(w.config.DataDir, ds.Name) + got, err := dirOverallSHA256Hex(path) + if err != nil { + return fmt.Errorf("dataset %q: checksum verification failed: %w", ds.Name, err) + } + if got != want { + return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", ds.Name, want, got) + } + logger.Info("dataset checksum verified", "dataset", ds.Name) + } + return nil +} + +func computeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) { + if task == nil { + return nil, fmt.Errorf("task is nil") + } + out := map[string]string{} + + if task.SnapshotID != "" { + out["snapshot_id"] = task.SnapshotID + } + + datasets := resolveDatasets(task) + if len(datasets) > 0 { + out["datasets"] = strings.Join(datasets, ",") + } + if len(task.DatasetSpecs) > 0 { + b, err := json.Marshal(task.DatasetSpecs) + if err != nil { + return nil, fmt.Errorf("marshal dataset_specs: %w", err) + } + out["dataset_specs"] = string(b) + } + + if task.Metadata == nil { + return out, nil + } + commitID := task.Metadata["commit_id"] + if commitID == "" { + return out, nil + } + + expMgr := experiment.NewManager(basePath) + 
manifest, err := expMgr.ReadManifest(commitID) + if err == nil && manifest != nil && manifest.OverallSHA != "" { + out["experiment_manifest_overall_sha"] = manifest.OverallSHA + } + + filesPath := expMgr.GetFilesPath(commitID) + depName, err := selectDependencyManifest(filesPath) + if err == nil && depName != "" { + depPath := filepath.Join(filesPath, depName) + sha, err := fileSHA256Hex(depPath) + if err == nil && sha != "" { + out["deps_manifest_name"] = depName + out["deps_manifest_sha256"] = sha + } + } + + return out, nil +} + +func (w *Worker) recordTaskProvenance(ctx context.Context, task *queue.Task) { + if task == nil { + return + } + prov, err := computeTaskProvenance(w.config.BasePath, task) + if err != nil { + w.logger.Job(ctx, task.JobName, task.ID).Debug("provenance compute failed", "error", err) + return + } + if len(prov) == 0 { + return + } + if task.Metadata == nil { + task.Metadata = map[string]string{} + } + for k, v := range prov { + if v == "" { + continue + } + // Phase 1: best-effort only; do not error if overwriting. 
+ task.Metadata[k] = v + } +} + +func (w *Worker) enforceTaskProvenance(ctx context.Context, task *queue.Task) error { + if task == nil { + return fmt.Errorf("task is nil") + } + if task.Metadata == nil { + return fmt.Errorf("missing task metadata") + } + commitID := task.Metadata["commit_id"] + if commitID == "" { + return fmt.Errorf("missing commit_id") + } + + current, err := computeTaskProvenance(w.config.BasePath, task) + if err != nil { + return err + } + + snapshotCur := "" + if task.SnapshotID != "" { + want := "" + if task.Metadata != nil { + want = task.Metadata["snapshot_sha256"] + } + wantNorm, nerr := normalizeSHA256ChecksumHex(want) + if nerr != nil { + if w.config != nil && w.config.ProvenanceBestEffort { + w.logger.Warn("invalid snapshot_sha256; unable to compute current snapshot provenance", + "snapshot_id", task.SnapshotID, + "error", nerr) + } else { + return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr) + } + } else if wantNorm != "" { + resolved, err := ResolveSnapshot( + ctx, w.config.DataDir, + &w.config.SnapshotStore, + task.SnapshotID, + wantNorm, + nil, + ) + if err != nil { + if w.config != nil && w.config.ProvenanceBestEffort { + w.logger.Warn("snapshot resolve failed; unable to compute current snapshot provenance", + "snapshot_id", task.SnapshotID, + "error", err) + } else { + return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err) + } + } else { + sha, err := dirOverallSHA256Hex(resolved) + if err == nil { + snapshotCur = sha + } else if w.config != nil && w.config.ProvenanceBestEffort { + w.logger.Warn("snapshot hash failed; unable to compute current snapshot provenance", + "snapshot_id", task.SnapshotID, + "error", err) + } else { + return fmt.Errorf("snapshot %q: checksum computation failed: %w", task.SnapshotID, err) + } + } + } + if snapshotCur == "" && w.config != nil && w.config.ProvenanceBestEffort { + // Best-effort fallback: if the caller didn't provide snapshot_sha256, + // 
compute from the local snapshot directory if it exists. + localPath := filepath.Join(w.config.DataDir, "snapshots", strings.TrimSpace(task.SnapshotID)) + if sha, err := dirOverallSHA256Hex(localPath); err == nil { + snapshotCur = sha + } + } + } + + logger := w.logger.Job(ctx, task.JobName, task.ID) + + type requiredField struct { + Key string + Cur string + } + required := []requiredField{ + {Key: "experiment_manifest_overall_sha", Cur: current["experiment_manifest_overall_sha"]}, + {Key: "deps_manifest_name", Cur: current["deps_manifest_name"]}, + {Key: "deps_manifest_sha256", Cur: current["deps_manifest_sha256"]}, + } + if task.SnapshotID != "" { + required = append(required, requiredField{Key: "snapshot_sha256", Cur: snapshotCur}) + } + + for _, f := range required { + want := strings.TrimSpace(task.Metadata[f.Key]) + if f.Key == "snapshot_sha256" { + norm, nerr := normalizeSHA256ChecksumHex(want) + if nerr != nil { + if w.config != nil && w.config.ProvenanceBestEffort { + logger.Warn("invalid snapshot_sha256; continuing due to best-effort mode", + "snapshot_id", task.SnapshotID, + "error", nerr) + want = "" + } else { + return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr) + } + } else { + want = norm + } + } + if want == "" { + if w.config != nil && w.config.ProvenanceBestEffort { + logger.Warn("missing provenance field; continuing due to best-effort mode", + "field", f.Key) + if f.Cur != "" { + if f.Key == "snapshot_sha256" { + task.Metadata[f.Key] = "sha256:" + f.Cur + } else { + task.Metadata[f.Key] = f.Cur + } + } + continue + } + return fmt.Errorf("missing provenance field: %s", f.Key) + } + if f.Cur == "" { + if w.config != nil && w.config.ProvenanceBestEffort { + logger.Warn("unable to compute provenance field; continuing due to best-effort mode", + "field", f.Key) + continue + } + return fmt.Errorf("unable to compute provenance field: %s", f.Key) + } + if want != f.Cur { + if w.config != nil && 
w.config.ProvenanceBestEffort { + logger.Warn("provenance mismatch; continuing due to best-effort mode", + "field", f.Key, + "expected", want, + "current", f.Cur) + if f.Key == "snapshot_sha256" { + task.Metadata[f.Key] = "sha256:" + f.Cur + } else { + task.Metadata[f.Key] = f.Cur + } + continue + } + return fmt.Errorf("provenance mismatch for %s: expected %s, got %s", f.Key, want, f.Cur) + } + } + + return nil +} + +func selectDependencyManifest(filesPath string) (string, error) { + if filesPath == "" { + return "", fmt.Errorf("missing files path") + } + candidates := []string{ + "environment.yml", + "environment.yaml", + "poetry.lock", + "pyproject.toml", + "requirements.txt", + } + for _, name := range candidates { + p := filepath.Join(filesPath, name) + if _, err := os.Stat(p); err == nil { + if name == "poetry.lock" { + pyprojectPath := filepath.Join(filesPath, "pyproject.toml") + if _, err := os.Stat(pyprojectPath); err != nil { + return "", fmt.Errorf( + "poetry.lock found but pyproject.toml missing (required for Poetry projects)") + } + } + return name, nil + } + } + return "", fmt.Errorf( + "missing dependency manifest (supported: environment.yml, environment.yaml, " + + "poetry.lock, pyproject.toml, requirements.txt)") +} + +// Exported wrappers for tests under tests/. 
+ +func ResolveDatasets(task *queue.Task) []string { return resolveDatasets(task) } + +func SelectDependencyManifest(filesPath string) (string, error) { + return selectDependencyManifest(filesPath) +} + +func NormalizeSHA256ChecksumHex(checksum string) (string, error) { + return normalizeSHA256ChecksumHex(checksum) +} + +func DirOverallSHA256Hex(root string) (string, error) { return dirOverallSHA256Hex(root) } + +func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) { + return computeTaskProvenance(basePath, task) +} + +func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error { + return w.enforceTaskProvenance(ctx, task) +} + +func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error { + return w.verifyDatasetSpecs(ctx, task) +} + +func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error { + return w.verifySnapshot(ctx, task) +} + +func NewTestWorker(cfg *Config) *Worker { + baseLogger := logging.NewLogger(slog.LevelInfo, false) + ctx := logging.EnsureTrace(context.Background()) + logger := baseLogger.Component(ctx, "worker") + if cfg == nil { + cfg = &Config{} + } + if cfg.DatasetCacheTTL == 0 { + cfg.DatasetCacheTTL = datasetCacheDefaultTTL + } + return &Worker{ + id: cfg.WorkerID, + config: cfg, + logger: logger, + datasetCache: make(map[string]time.Time), + datasetCacheTTL: cfg.DatasetCacheTTL, + } +} + +func NewTestWorkerWithQueue(cfg *Config, tq queue.Backend) *Worker { + baseLogger := logging.NewLogger(slog.LevelInfo, false) + ctx := logging.EnsureTrace(context.Background()) + ctx, cancel := context.WithCancel(ctx) + logger := baseLogger.Component(ctx, "worker") + if cfg == nil { + cfg = &Config{} + } + if cfg.DatasetCacheTTL == 0 { + cfg.DatasetCacheTTL = datasetCacheDefaultTTL + } + return &Worker{ + id: cfg.WorkerID, + config: cfg, + logger: logger, + queue: tq, + metrics: &metrics.Metrics{}, + ctx: ctx, + cancel: cancel, + running: 
make(map[string]context.CancelFunc), + datasetCache: make(map[string]time.Time), + datasetCacheTTL: cfg.DatasetCacheTTL, + } +} + +func NewTestWorkerWithJupyter(cfg *Config, tq queue.Backend, jm JupyterManager) *Worker { + w := NewTestWorkerWithQueue(cfg, tq) + w.jupyter = jm + return w +} diff --git a/internal/worker/execution.go b/internal/worker/execution.go new file mode 100644 index 0000000..12ce4bd --- /dev/null +++ b/internal/worker/execution.go @@ -0,0 +1,1029 @@ +package worker + +import ( + "context" + "encoding/hex" + "fmt" + "io" + "log" + "os" + "os/exec" + "path/filepath" + "runtime/debug" + "strconv" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/config" + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/errtypes" + "github.com/jfraeys/fetch_ml/internal/experiment" + "github.com/jfraeys/fetch_ml/internal/fileutil" + "github.com/jfraeys/fetch_ml/internal/manifest" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/telemetry" + "github.com/jfraeys/fetch_ml/internal/tracking" +) + +func gpuVisibleDevicesString(cfg *Config, fallback string) string { + if cfg == nil { + return strings.TrimSpace(fallback) + } + if len(cfg.GPUVisibleDeviceIDs) > 0 { + parts := make([]string, 0, len(cfg.GPUVisibleDeviceIDs)) + for _, id := range cfg.GPUVisibleDeviceIDs { + id = strings.TrimSpace(id) + if id == "" { + continue + } + parts = append(parts, id) + } + return strings.Join(parts, ",") + } + if len(cfg.GPUVisibleDevices) == 0 { + return strings.TrimSpace(fallback) + } + parts := make([]string, 0, len(cfg.GPUVisibleDevices)) + for _, v := range cfg.GPUVisibleDevices { + if v < 0 { + continue + } + parts = append(parts, strconv.Itoa(v)) + } + return strings.Join(parts, ",") +} + +func filterExistingDevicePaths(paths []string) []string { + if len(paths) == 0 { + return nil + } + seen := make(map[string]struct{}, len(paths)) + out := make([]string, 0, len(paths)) + for _, p := 
range paths { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if _, ok := seen[p]; ok { + continue + } + if _, err := os.Stat(p); err != nil { + continue + } + seen[p] = struct{}{} + out = append(out, p) + } + return out +} + +func gpuVisibleEnvVarName(cfg *Config) string { + if cfg == nil { + return "CUDA_VISIBLE_DEVICES" + } + switch strings.ToLower(strings.TrimSpace(cfg.GPUVendor)) { + case "amd": + return "HIP_VISIBLE_DEVICES" + case string(GPUTypeApple), string(GPUTypeNone): + return "" + default: + return "CUDA_VISIBLE_DEVICES" + } +} + +func runIDForTask(task *queue.Task) string { + if task == nil { + return "" + } + created := task.CreatedAt + if created.IsZero() { + created = time.Now().UTC() + } + short := task.ID + if len(short) > 8 { + short = short[:8] + } + job := strings.TrimSpace(task.JobName) + if job == "" { + job = "job" + } + return fmt.Sprintf("run-%s-%s-%s", job, created.UTC().Format("20060102-150405"), short) +} + +func (w *Worker) buildInitialRunManifest( + task *queue.Task, + podmanImage string, +) *manifest.RunManifest { + if task == nil { + return nil + } + m := manifest.NewRunManifest(runIDForTask(task), task.ID, task.JobName, task.CreatedAt) + m.PodmanImage = strings.TrimSpace(podmanImage) + if host, err := os.Hostname(); err == nil { + m.WorkerHost = strings.TrimSpace(host) + } + if info, ok := debug.ReadBuildInfo(); ok && info != nil { + m.WorkerVersion = strings.TrimSpace(info.Main.Version) + } + if task.Metadata != nil { + m.CommitID = strings.TrimSpace(task.Metadata["commit_id"]) + m.ExperimentManifestSHA = strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"]) + m.DepsManifestName = strings.TrimSpace(task.Metadata["deps_manifest_name"]) + m.DepsManifestSHA = strings.TrimSpace(task.Metadata["deps_manifest_sha256"]) + m.SnapshotSHA256 = strings.TrimSpace(task.Metadata["snapshot_sha256"]) + // Forward compatibility: copy selected metadata keys verbatim. 
+		// Skip blank keys/values so the manifest stays clean.
+		for k, v := range task.Metadata {
+			if strings.TrimSpace(k) == "" || strings.TrimSpace(v) == "" {
+				continue
+			}
+			m.Metadata[k] = v
+		}
+	}
+	m.SnapshotID = strings.TrimSpace(task.SnapshotID)
+	m.Metadata["task_args"] = strings.TrimSpace(task.Args)
+	return m
+}
+
+// upsertRunManifest loads the run manifest from dir (building an initial one
+// when none can be loaded), applies mutate, and writes it back. Best-effort:
+// write failures are logged, never returned to the caller.
+func (w *Worker) upsertRunManifest(
+	dir string,
+	task *queue.Task,
+	mutate func(m *manifest.RunManifest),
+) {
+	if strings.TrimSpace(dir) == "" {
+		return
+	}
+	if task == nil {
+		return
+	}
+
+	cur, err := manifest.LoadFromDir(dir)
+	if err != nil {
+		// No readable manifest yet — start from a fresh one for this task.
+		cur = w.buildInitialRunManifest(task, w.config.PodmanImage)
+	}
+	if cur == nil {
+		return
+	}
+	if mutate != nil {
+		mutate(cur)
+	}
+	if err := cur.WriteToDir(dir); err != nil {
+		w.logger.Warn(
+			"failed to write run manifest",
+			"job", task.JobName,
+			"task_id", task.ID,
+			"error", err,
+		)
+	}
+}
+
+// StageSnapshot stages snapshot snapshotID from dataDir/snapshots/<id> into
+// jobDir. A blank snapshot ID is a no-op. The ID is run through the job-name
+// validator so it cannot carry path separators into the join below.
+func StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir string) error {
+	sid := strings.TrimSpace(snapshotID)
+	if sid == "" {
+		return nil
+	}
+	if err := container.ValidateJobName(sid); err != nil {
+		return err
+	}
+	if strings.TrimSpace(taskID) == "" {
+		return fmt.Errorf("missing task id")
+	}
+	if strings.TrimSpace(jobDir) == "" {
+		return fmt.Errorf("missing job dir")
+	}
+	src := filepath.Join(dataDir, "snapshots", sid)
+	return StageSnapshotFromPath(basePath, taskID, src, jobDir)
+}
+
+// StageSnapshotFromPath stages srcPath into jobDir/snapshot. When a prewarmed
+// copy exists under basePath/.prewarm/snapshots/<taskID>, it is moved into
+// place via rename instead of copied.
+func StageSnapshotFromPath(basePath, taskID, srcPath, jobDir string) error {
+	if strings.TrimSpace(basePath) == "" {
+		return fmt.Errorf("missing base path")
+	}
+	if strings.TrimSpace(taskID) == "" {
+		return fmt.Errorf("missing task id")
+	}
+	if strings.TrimSpace(jobDir) == "" {
+		return fmt.Errorf("missing job dir")
+	}
+
+	// Any stale snapshot dir from a previous attempt is discarded first.
+	dst := filepath.Join(jobDir, "snapshot")
+	_ = os.RemoveAll(dst)
+
+	prewarmSrc := filepath.Join(basePath, ".prewarm", "snapshots", taskID)
+	if info, err := os.Stat(prewarmSrc); err == nil && info.IsDir() {
+		// TODO: Emit Prometheus prewarm snapshot hit metric when available
+		// NOTE(review): os.Rename assumes the prewarm dir and jobDir live on
+		// the same filesystem (rename fails with EXDEV across mounts) — confirm.
+		return os.Rename(prewarmSrc, dst)
+	}
+	// 
TODO: Emit Prometheus prewarm snapshot miss metric when available
+
+	return copyDir(srcPath, dst)
+}
+
+// runJob stages and executes a single task: validates the job name, creates
+// the pending job dir, writes a best-effort run manifest, stages experiment
+// files and (optionally) a snapshot, then hands off to executeJob. On staging
+// failures the job dir is moved to the failed area before returning.
+func (w *Worker) runJob(ctx context.Context, task *queue.Task, cudaVisibleDevices string) error {
+	visibleDevices := gpuVisibleDevicesString(w.config, cudaVisibleDevices)
+	visibleEnvVar := gpuVisibleEnvVarName(w.config)
+
+	// Validate job name to prevent path traversal
+	if err := container.ValidateJobName(task.JobName); err != nil {
+		return &errtypes.TaskExecutionError{
+			TaskID:  task.ID,
+			JobName: task.JobName,
+			Phase:   "validation",
+			Err:     err,
+		}
+	}
+
+	jobDir, outputDir, logFile, err := w.setupJobDirectories(task)
+	if err != nil {
+		return err
+	}
+
+	// Best-effort: write initial run manifest into pending dir so it follows the job via rename.
+	w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) {
+		m.TrainScriptPath = strings.TrimSpace(w.config.TrainScript)
+		if strings.TrimSpace(w.config.Host) != "" {
+			m.Metadata["worker_config_host"] = strings.TrimSpace(w.config.Host)
+		}
+		m.Metadata["task_args"] = strings.TrimSpace(task.Args)
+		m.MarkStarted(time.Now().UTC())
+		m.GPUDevices = w.getGPUDevicePaths()
+		if strings.TrimSpace(visibleEnvVar) != "" {
+			m.Metadata["gpu_visible_devices"] = strings.TrimSpace(visibleDevices)
+			m.Metadata["gpu_visible_env"] = strings.TrimSpace(visibleEnvVar)
+		}
+	})
+
+	if err := w.stageExperimentFiles(task, jobDir); err != nil {
+		// Record the failure in the manifest, then park the job dir in failed/.
+		w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) {
+			now := time.Now().UTC()
+			exitCode := 1
+			m.MarkFinished(now, &exitCode, err)
+			m.Metadata["failure_phase"] = "stage_experiment_files"
+		})
+		failedDir := filepath.Join(config.NewJobPaths(w.config.BasePath).FailedPath(), task.JobName)
+		_ = os.MkdirAll(filepath.Dir(failedDir), 0750)
+		_ = os.RemoveAll(failedDir)
+		_ = os.Rename(jobDir, failedDir)
+		// NOTE(review): Phase is "validation" although the failure happened
+		// while staging files — confirm this is intended for retry policy.
+		return &errtypes.TaskExecutionError{
+			TaskID:  task.ID,
+			JobName: task.JobName,
+			Phase:   "validation",
+			Err:     err,
+		}
+	}
+	if err := w.stageSnapshot(ctx, task, jobDir); err != nil {
+		w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) {
+			now := time.Now().UTC()
+			exitCode := 1
+			m.MarkFinished(now, &exitCode, err)
+			m.Metadata["failure_phase"] = "stage_snapshot"
+		})
+		failedDir := filepath.Join(config.NewJobPaths(w.config.BasePath).FailedPath(), task.JobName)
+		_ = os.MkdirAll(filepath.Dir(failedDir), 0750)
+		_ = os.RemoveAll(failedDir)
+		_ = os.Rename(jobDir, failedDir)
+		return &errtypes.TaskExecutionError{
+			TaskID:  task.ID,
+			JobName: task.JobName,
+			Phase:   "validation",
+			Err:     err,
+		}
+	}
+
+	return w.executeJob(ctx, task, jobDir, outputDir, logFile, visibleDevices, visibleEnvVar)
+}
+
+// RunJob is the exported entry point wrapping runJob (used by tests/callers).
+func (w *Worker) RunJob(ctx context.Context, task *queue.Task, cudaVisibleDevices string) error {
+	return w.runJob(ctx, task, cudaVisibleDevices)
+}
+
+// stageSnapshot resolves the task's snapshot (verifying against the
+// snapshot_sha256 metadata) and stages it into jobDir. Tasks without a
+// snapshot ID are a no-op.
+func (w *Worker) stageSnapshot(ctx context.Context, task *queue.Task, jobDir string) error {
+	if task == nil {
+		return fmt.Errorf("task is nil")
+	}
+	if strings.TrimSpace(task.SnapshotID) == "" {
+		return nil
+	}
+	if task.Metadata == nil {
+		return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
+	}
+	// NOTE(review): when the key exists but is empty, want is "" and is passed
+	// through unchecked — confirm ResolveSnapshot rejects blank checksums.
+	want := task.Metadata["snapshot_sha256"]
+	resolved, err := ResolveSnapshot(
+		ctx,
+		w.config.DataDir,
+		&w.config.SnapshotStore,
+		task.SnapshotID,
+		want,
+		nil,
+	)
+	if err != nil {
+		return err
+	}
+	return StageSnapshotFromPath(w.config.BasePath, task.ID, resolved, jobDir)
+}
+
+// validateTaskForExecution checks a task's integrity before execution: job
+// name, presence of a well-formed 40-hex commit_id, readable experiment
+// metadata, train script, dependency manifest, and the content manifest.
+func (w *Worker) validateTaskForExecution(_ context.Context, task *queue.Task) error {
+	if task == nil {
+		return fmt.Errorf("task is nil")
+	}
+	if err := container.ValidateJobName(task.JobName); err != nil {
+		return err
+	}
+	if task.Metadata == nil {
+		return fmt.Errorf("missing task metadata")
+	}
+	commitID, ok := task.Metadata["commit_id"]
+	if !ok || commitID == "" {
+		return fmt.Errorf("missing commit_id")
+	}
+	if len(commitID) != 40 {
+		return fmt.Errorf("invalid commit_id length")
+	}
+	if _, err := hex.DecodeString(commitID); err != nil {
+		
return fmt.Errorf("invalid commit_id: %w", err) + } + + expMgr := experiment.NewManager(w.config.BasePath) + meta, err := expMgr.ReadMetadata(commitID) + if err != nil { + return fmt.Errorf("missing or unreadable experiment metadata: %w", err) + } + if meta.CommitID != commitID { + return fmt.Errorf("experiment metadata commit_id mismatch") + } + + filesPath := expMgr.GetFilesPath(commitID) + trainScriptHostPath := filepath.Join(filesPath, w.config.TrainScript) + if _, err := os.Stat(trainScriptHostPath); err != nil { + return fmt.Errorf("missing train script: %w", err) + } + if _, err := selectDependencyManifest(filesPath); err != nil { + return err + } + + // Validate content integrity manifest + if err := expMgr.ValidateManifest(commitID); err != nil { + return fmt.Errorf("content integrity validation failed: %w", err) + } + + return nil +} + +func (w *Worker) podmanImageDigest(ctx context.Context, imageRef string) string { + ref := strings.TrimSpace(imageRef) + if ref == "" { + return "" + } + inspectCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + cmd := exec.CommandContext(inspectCtx, "podman", "image", "inspect", "--format", "{{.Id}}", ref) + out, err := cmd.CombinedOutput() + if err != nil { + return "" + } + return strings.TrimSpace(string(out)) +} + +func (w *Worker) stageExperimentFiles(task *queue.Task, jobDir string) error { + if task == nil { + return fmt.Errorf("task is nil") + } + if task.Metadata == nil { + return fmt.Errorf("missing task metadata") + } + commitID, ok := task.Metadata["commit_id"] + if !ok || commitID == "" { + return fmt.Errorf("missing commit_id") + } + + expMgr := experiment.NewManager(w.config.BasePath) + src := expMgr.GetFilesPath(commitID) + dst := filepath.Join(jobDir, "code") + + if err := copyDir(src, dst); err != nil { + return err + } + + return nil +} + +func copyDir(src, dst string) error { + src = filepath.Clean(src) + dst = filepath.Clean(dst) + + srcInfo, err := os.Stat(src) + if err != nil 
{ + return err + } + if !srcInfo.IsDir() { + return fmt.Errorf("source is not a directory") + } + + if err := os.MkdirAll(dst, 0750); err != nil { + return err + } + + return filepath.WalkDir(src, func(path string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + rel, err := filepath.Rel(src, path) + if err != nil { + return err + } + rel = filepath.Clean(rel) + if rel == "." { + return nil + } + if strings.HasPrefix(rel, "..") { + return fmt.Errorf("invalid relative path") + } + outPath := filepath.Join(dst, rel) + if d.IsDir() { + return os.MkdirAll(outPath, 0750) + } + + info, err := d.Info() + if err != nil { + return err + } + mode := info.Mode() & 0777 + in, err := os.Open(filepath.Clean(path)) + if err != nil { + return err + } + defer func() { _ = in.Close() }() + out, err := fileutil.SecureOpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) + if err != nil { + return err + } + defer func() { _ = out.Close() }() + _, err = io.Copy(out, in) + return err + }) +} + +func (w *Worker) setupJobDirectories( + task *queue.Task, +) (jobDir, outputDir, logFile string, err error) { + jobPaths := config.NewJobPaths(w.config.BasePath) + pendingDir := jobPaths.PendingPath() + jobDir = filepath.Join(pendingDir, task.JobName) + outputDir = filepath.Join(jobPaths.RunningPath(), task.JobName) + logFile = filepath.Join(outputDir, "output.log") + + // Create pending directory + if err := os.MkdirAll(pendingDir, 0750); err != nil { + return "", "", "", &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "setup", + Err: fmt.Errorf("failed to create pending dir: %w", err), + } + } + + // Create job directory in pending + if err := os.MkdirAll(jobDir, 0750); err != nil { + return "", "", "", &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "setup", + Err: fmt.Errorf("failed to create job dir: %w", err), + } + } + + // Sanitize paths + jobDir, err = 
container.SanitizePath(jobDir) + if err != nil { + return "", "", "", &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "validation", + Err: err, + } + } + outputDir, err = container.SanitizePath(outputDir) + if err != nil { + return "", "", "", &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "validation", + Err: err, + } + } + + return jobDir, outputDir, logFile, nil +} + +func (w *Worker) executeJob( + ctx context.Context, + task *queue.Task, + jobDir, outputDir, logFile string, + visibleDevices string, + visibleEnvVar string, +) error { + // Create output directory + if _, err := telemetry.ExecWithMetrics( + w.logger, + "create output dir", + 100*time.Millisecond, + func() (string, error) { + if err := os.MkdirAll(outputDir, 0750); err != nil { + return "", fmt.Errorf("mkdir failed: %w", err) + } + return "", nil + }, + ); err != nil { + return &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "setup", + Err: fmt.Errorf("failed to create output dir: %w", err), + } + } + + // Move job from pending to running + stagingStart := time.Now() + if _, err := telemetry.ExecWithMetrics( + w.logger, + "stage job", + 100*time.Millisecond, + func() (string, error) { + // Remove existing directory if it exists + if _, err := os.Stat(outputDir); err == nil { + if err := os.RemoveAll(outputDir); err != nil { + return "", fmt.Errorf("remove existing failed: %w", err) + } + } + if err := os.Rename(jobDir, outputDir); err != nil { + return "", fmt.Errorf("rename failed: %w", err) + } + return "", nil + }, + ); err != nil { + return &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "setup", + Err: fmt.Errorf("failed to move job: %w", err), + } + } + stagingDuration := time.Since(stagingStart) + + // Best-effort: record staging duration in running dir. 
+ w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) { + m.StagingDurationMS = stagingDuration.Milliseconds() + m.GPUDevices = w.getGPUDevicePaths() + if strings.TrimSpace(visibleEnvVar) != "" { + m.Metadata["gpu_visible_devices"] = strings.TrimSpace(visibleDevices) + m.Metadata["gpu_visible_env"] = strings.TrimSpace(visibleEnvVar) + } + }) + + // Execute job + if w.config.LocalMode { + execStart := time.Now() + err := w.executeLocalJob(ctx, task, outputDir, logFile, visibleDevices, visibleEnvVar) + execDuration := time.Since(execStart) + w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) { + now := time.Now().UTC() + m.ExecutionDurationMS = execDuration.Milliseconds() + if err != nil { + exitCode := 1 + m.MarkFinished(now, &exitCode, err) + } else { + exitCode := 0 + m.MarkFinished(now, &exitCode, nil) + } + }) + + finalizeStart := time.Now() + jobPaths := config.NewJobPaths(w.config.BasePath) + var dest string + if err != nil { + dest = filepath.Join(jobPaths.FailedPath(), task.JobName) + } else { + dest = filepath.Join(jobPaths.FinishedPath(), task.JobName) + } + _ = os.MkdirAll(filepath.Dir(dest), 0750) + _ = os.RemoveAll(dest) + w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) { + m.FinalizeDurationMS = time.Since(finalizeStart).Milliseconds() + }) + if moveErr := os.Rename(outputDir, dest); moveErr != nil { + w.logger.Warn("failed to move local-mode job dir", "job", task.JobName, "error", moveErr) + } + return err + } + + return w.executeContainerJob( + ctx, + task, + outputDir, + logFile, + stagingDuration, + visibleDevices, + visibleEnvVar, + ) +} + +func (w *Worker) executeLocalJob( + ctx context.Context, + task *queue.Task, + outputDir, logFile string, + visibleDevices string, + visibleEnvVar string, +) error { + // Create experiment script + scriptContent := `#!/bin/bash +set -e + +echo "Starting experiment: ` + task.JobName + `" +echo "Task ID: ` + task.ID + `" +echo "Timestamp: $(date)" + +# Simulate ML 
experiment +echo "Loading data..." +sleep 1 + +echo "Training model..." +sleep 2 + +echo "Evaluating model..." +sleep 1 + +# Generate results +ACCURACY=0.95 +LOSS=0.05 +EPOCHS=10 + +echo "" +echo "=== EXPERIMENT RESULTS ===" +echo "Accuracy: $ACCURACY" +echo "Loss: $LOSS" +echo "Epochs: $EPOCHS" +echo "Status: SUCCESS" +echo "=========================" +echo "Experiment completed successfully!" +` + + scriptPath := filepath.Join(outputDir, "run.sh") + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0600); err != nil { + return &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "execution", + Err: fmt.Errorf("failed to write script: %w", err), + } + } + + w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) { + m.Command = "bash" + m.Args = scriptPath + }) + + logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err != nil { + w.logger.Warn("failed to open log file for local output", "path", logFile, "error", err) + return &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "execution", + Err: fmt.Errorf("failed to open log file: %w", err), + } + } + defer func() { + if err := logFileHandle.Close(); err != nil { + log.Printf("Warning: failed to close log file: %v", err) + } + }() + + // Execute the script directly + localCmd := exec.CommandContext(ctx, "bash", scriptPath) + env := os.Environ() + if strings.TrimSpace(visibleEnvVar) != "" { + env = append(env, fmt.Sprintf("%s=%s", visibleEnvVar, strings.TrimSpace(visibleDevices))) + } + snap := filepath.Join(outputDir, "snapshot") + if info, err := os.Stat(snap); err == nil && info.IsDir() { + env = append(env, fmt.Sprintf("FETCH_ML_SNAPSHOT_DIR=%s", snap)) + if strings.TrimSpace(task.SnapshotID) != "" { + env = append(env, fmt.Sprintf("FETCH_ML_SNAPSHOT_ID=%s", strings.TrimSpace(task.SnapshotID))) + } + } + localCmd.Env = env + localCmd.Stdout = logFileHandle + localCmd.Stderr = 
logFileHandle + + w.logger.Info("executing local job", + "job", task.JobName, + "task_id", task.ID, + "script", scriptPath) + + if err := localCmd.Run(); err != nil { + return &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "execution", + Err: fmt.Errorf("execution failed: %w", err), + } + } + + return nil +} + +func (w *Worker) executeContainerJob( + ctx context.Context, + task *queue.Task, + outputDir, logFile string, + stagingDuration time.Duration, + visibleDevices string, + visibleEnvVar string, +) error { + containerResults := w.config.ContainerResults + if containerResults == "" { + containerResults = config.DefaultContainerResults + } + + containerWorkspace := w.config.ContainerWorkspace + if containerWorkspace == "" { + containerWorkspace = config.DefaultContainerWorkspace + } + + jobPaths := config.NewJobPaths(w.config.BasePath) + stagingStart := time.Now() + + // Optional: provision tracking tools for this task + var trackingEnv map[string]string + if w.trackingRegistry != nil && task.Tracking != nil { + configs := make(map[string]tracking.ToolConfig) + + if task.Tracking.MLflow != nil && task.Tracking.MLflow.Enabled { + mode := tracking.ModeSidecar + if task.Tracking.MLflow.Mode != "" { + mode = tracking.ToolMode(task.Tracking.MLflow.Mode) + } + configs["mlflow"] = tracking.ToolConfig{ + Enabled: true, + Mode: mode, + Settings: map[string]any{ + "job_name": task.JobName, + "tracking_uri": task.Tracking.MLflow.TrackingURI, + }, + } + } + + if task.Tracking.TensorBoard != nil && task.Tracking.TensorBoard.Enabled { + mode := tracking.ModeSidecar + if task.Tracking.TensorBoard.Mode != "" { + mode = tracking.ToolMode(task.Tracking.TensorBoard.Mode) + } + configs["tensorboard"] = tracking.ToolConfig{ + Enabled: true, + Mode: mode, + Settings: map[string]any{ + "job_name": task.JobName, + }, + } + } + + if task.Tracking.Wandb != nil && task.Tracking.Wandb.Enabled { + mode := tracking.ModeRemote + if task.Tracking.Wandb.Mode 
!= "" { + mode = tracking.ToolMode(task.Tracking.Wandb.Mode) + } + configs["wandb"] = tracking.ToolConfig{ + Enabled: true, + Mode: mode, + Settings: map[string]any{ + "api_key": task.Tracking.Wandb.APIKey, + "project": task.Tracking.Wandb.Project, + "entity": task.Tracking.Wandb.Entity, + }, + } + } + + if len(configs) > 0 { + var err error + trackingEnv, err = w.trackingRegistry.ProvisionAll(ctx, task.ID, configs) + if err != nil { + return &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "tracking_provision", + Err: err, + } + } + defer w.trackingRegistry.TeardownAll(context.Background(), task.ID) + } + } + + var volumes map[string]string + if val, ok := trackingEnv["TENSORBOARD_HOST_LOG_DIR"]; ok { + volumes = make(map[string]string) + // Mount to /tracking/tensorboard inside container + containerPath := "/tracking/tensorboard" + volumes[val] = containerPath + ":rw" + + // Update environment variable for the script + trackingEnv["TENSORBOARD_LOG_DIR"] = containerPath + // Remove the host path from Env to avoid leaking host info + delete(trackingEnv, "TENSORBOARD_HOST_LOG_DIR") + } + + if trackingEnv == nil { + trackingEnv = make(map[string]string) + } + if strings.TrimSpace(visibleEnvVar) != "" { + trackingEnv[visibleEnvVar] = strings.TrimSpace(visibleDevices) + } + snap := filepath.Join(outputDir, "snapshot") + if info, err := os.Stat(snap); err == nil && info.IsDir() { + trackingEnv["FETCH_ML_SNAPSHOT_DIR"] = "/snapshot" + if strings.TrimSpace(task.SnapshotID) != "" { + trackingEnv["FETCH_ML_SNAPSHOT_ID"] = strings.TrimSpace(task.SnapshotID) + } + if volumes == nil { + volumes = make(map[string]string) + } + volumes[snap] = "/snapshot:ro" + } + + cpusOverride, memOverride := container.PodmanResourceOverrides(task.CPU, task.MemoryGB) + + selectedImage := w.config.PodmanImage + if w.envPool != nil && + !w.config.LocalMode && + strings.TrimSpace(w.config.PodmanImage) != "" && + task != nil && + task.Metadata != nil { + depsSHA 
:= strings.TrimSpace(task.Metadata["deps_manifest_sha256"]) + if depsSHA != "" { + if warmTag, err := w.envPool.WarmImageTag(depsSHA); err == nil { + inspectCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + exists, ierr := w.envPool.ImageExists(inspectCtx, warmTag) + cancel() + if ierr == nil && exists { + selectedImage = warmTag + } + } + } + } + + podmanCfg := container.PodmanConfig{ + Image: selectedImage, + Workspace: filepath.Join(outputDir, "code"), + Results: filepath.Join(outputDir, "results"), + ContainerWorkspace: containerWorkspace, + ContainerResults: containerResults, + AppleGPU: w.config.AppleGPU.Enabled, + GPUDevices: w.getGPUDevicePaths(), + Env: trackingEnv, + Volumes: volumes, + Memory: memOverride, + CPUs: cpusOverride, + } + + scriptPath := filepath.Join(containerWorkspace, w.config.TrainScript) + manifestName, err := selectDependencyManifest(filepath.Join(outputDir, "code")) + if err != nil { + return &errtypes.TaskExecutionError{ + TaskID: task.ID, + JobName: task.JobName, + Phase: "validation", + Err: err, + } + } + depsPath := filepath.Join(containerWorkspace, manifestName) + + var extraArgs []string + if task.Args != "" { + extraArgs = strings.Fields(task.Args) + } + + ioBefore, ioErr := telemetry.ReadProcessIO() + podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, scriptPath, depsPath, extraArgs) + w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) { + m.PodmanImage = strings.TrimSpace(selectedImage) + m.ImageDigest = strings.TrimSpace(w.podmanImageDigest(ctx, selectedImage)) + m.Command = podmanCmd.Path + if len(podmanCmd.Args) > 1 { + m.Args = strings.Join(podmanCmd.Args[1:], " ") + } else { + m.Args = "" + } + }) + logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err == nil { + podmanCmd.Stdout = logFileHandle + podmanCmd.Stderr = logFileHandle + } else { + w.logger.Warn("failed to open log file for podman output", "path", logFile, "error", err) + } + + 
w.logger.Info("executing podman job", + "job", task.JobName, + "image", selectedImage, + "workspace", podmanCfg.Workspace, + "results", podmanCfg.Results) + + containerStart := time.Now() + if err := podmanCmd.Run(); err != nil { + containerDuration := time.Since(containerStart) + w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) { + now := time.Now().UTC() + exitCode := 1 + m.ExecutionDurationMS = containerDuration.Milliseconds() + m.MarkFinished(now, &exitCode, err) + }) + // Move job to failed directory + failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName) + if _, moveErr := telemetry.ExecWithMetrics( + w.logger, + "move failed job", + 100*time.Millisecond, + func() (string, error) { + if err := os.Rename(outputDir, failedDir); err != nil { + return "", fmt.Errorf("rename to failed failed: %w", err) + } + return "", nil + }); moveErr != nil { + w.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", moveErr) + } + + if ioErr == nil { + if after, err := telemetry.ReadProcessIO(); err == nil { + delta := telemetry.DiffIO(ioBefore, after) + w.logger.Debug("worker io stats", + "job", task.JobName, + "read_bytes", delta.ReadBytes, + "write_bytes", delta.WriteBytes) + } + } + w.logger.Info("job timing (failure)", + "job", task.JobName, + "staging_ms", stagingDuration.Milliseconds(), + "container_ms", containerDuration.Milliseconds(), + "finalize_ms", 0, + "total_ms", time.Since(stagingStart).Milliseconds(), + ) + return fmt.Errorf("execution failed: %w", err) + } + containerDuration := time.Since(containerStart) + + w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) { + m.ExecutionDurationMS = containerDuration.Milliseconds() + }) + + finalizeStart := time.Now() + // Move job to finished directory + finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName) + // Best-effort: finalize manifest before moving the directory. 
+ w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) { + now := time.Now().UTC() + exitCode := 0 + m.FinalizeDurationMS = time.Since(finalizeStart).Milliseconds() + m.MarkFinished(now, &exitCode, nil) + }) + if _, moveErr := telemetry.ExecWithMetrics( + w.logger, + "finalize job", + 100*time.Millisecond, + func() (string, error) { + if err := os.Rename(outputDir, finishedDir); err != nil { + return "", fmt.Errorf("rename to finished failed: %w", err) + } + return "", nil + }); moveErr != nil { + w.logger.Warn("failed to move job to finished dir", "job", task.JobName, "error", moveErr) + } + finalizeDuration := time.Since(finalizeStart) + totalDuration := time.Since(stagingStart) + var ioDelta telemetry.IOStats + if ioErr == nil { + if after, err := telemetry.ReadProcessIO(); err == nil { + ioDelta = telemetry.DiffIO(ioBefore, after) + } + } + + w.logger.Info("job timing", + "job", task.JobName, + "staging_ms", stagingDuration.Milliseconds(), + "container_ms", containerDuration.Milliseconds(), + "finalize_ms", finalizeDuration.Milliseconds(), + "total_ms", totalDuration.Milliseconds(), + "io_read_bytes", ioDelta.ReadBytes, + "io_write_bytes", ioDelta.WriteBytes, + ) + + return nil +} + +// getGPUDevicePaths returns the appropriate GPU device paths based on configuration +func (w *Worker) getGPUDevicePaths() []string { + if w != nil && w.config != nil { + if len(w.config.GPUDevices) > 0 { + return filterExistingDevicePaths(w.config.GPUDevices) + } + } + detector := w.getGPUDetector() + return filterExistingDevicePaths(detector.GetDevicePaths()) +} diff --git a/internal/worker/gpu_detector.go b/internal/worker/gpu_detector.go new file mode 100644 index 0000000..86b1a4c --- /dev/null +++ b/internal/worker/gpu_detector.go @@ -0,0 +1,168 @@ +package worker + +import ( + "os" + "path/filepath" + "strings" +) + +// GPUType represents different GPU types +type GPUType string + +const ( + GPUTypeNVIDIA GPUType = "nvidia" + GPUTypeApple GPUType = "apple" + 
GPUTypeNone     GPUType = "none"
+)
+
+// GPUDetector interface for detecting GPU availability
+type GPUDetector interface {
+	// DetectGPUCount reports the number of usable GPUs (0 when none).
+	DetectGPUCount() int
+	// GetGPUType identifies the vendor/family this detector handles.
+	GetGPUType() GPUType
+	// GetDevicePaths lists existing host device nodes to expose to jobs.
+	GetDevicePaths() []string
+}
+
+// NVIDIA GPUDetector implementation
+type NVIDIADetector struct{}
+
+func (d *NVIDIADetector) DetectGPUCount() int {
+	// FETCH_ML_GPU_COUNT takes precedence as an explicit operator override.
+	if n, ok := envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 {
+		return n
+	}
+	// Could use nvidia-smi or other detection methods here
+	return 0
+}
+
+func (d *NVIDIADetector) GetGPUType() GPUType {
+	return GPUTypeNVIDIA
+}
+
+func (d *NVIDIADetector) GetDevicePaths() []string {
+	// Prefer standard NVIDIA device nodes when present.
+	patterns := []string{
+		"/dev/nvidiactl",
+		"/dev/nvidia-modeset",
+		"/dev/nvidia-uvm",
+		"/dev/nvidia-uvm-tools",
+		"/dev/nvidia*",
+	}
+	// seen dedupes paths; only nodes that Stat() confirms exist are returned.
+	seen := make(map[string]struct{})
+	out := make([]string, 0, 8)
+	for _, pat := range patterns {
+		// Skip bare filenames (defensive; all entries above are absolute).
+		if filepath.Base(pat) == pat {
+			continue
+		}
+		if strings.Contains(pat, "*") {
+			// Glob entries expand to whatever nodes the host actually has.
+			matches, _ := filepath.Glob(pat)
+			for _, m := range matches {
+				if _, ok := seen[m]; ok {
+					continue
+				}
+				if _, err := os.Stat(m); err != nil {
+					continue
+				}
+				seen[m] = struct{}{}
+				out = append(out, m)
+			}
+			continue
+		}
+		if _, ok := seen[pat]; ok {
+			continue
+		}
+		if _, err := os.Stat(pat); err != nil {
+			continue
+		}
+		seen[pat] = struct{}{}
+		out = append(out, pat)
+	}
+	// Fallback for non-NVIDIA setups where only generic DRM device exists.
+	if len(out) == 0 {
+		if _, err := os.Stat("/dev/dri"); err == nil {
+			out = append(out, "/dev/dri")
+		}
+	}
+	return out
+}
+
+// Apple M-series GPUDetector implementation
+type AppleDetector struct {
+	// enabled gates whether the unified Apple GPU is reported as available.
+	enabled bool
+}
+
+func (d *AppleDetector) DetectGPUCount() int {
+	// FETCH_ML_GPU_COUNT takes precedence as an explicit operator override.
+	if n, ok := envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 {
+		return n
+	}
+	// Apple silicon exposes a single unified GPU when enabled.
+	if d.enabled {
+		return 1
+	}
+	return 0
+}
+
+func (d *AppleDetector) GetGPUType() GPUType {
+	return GPUTypeApple
+}
+
+func (d *AppleDetector) GetDevicePaths() []string {
+	// NOTE(review): /dev/metal and /dev/mps are not checked for existence
+	// here (unlike the NVIDIA path) — confirm these nodes exist on targets.
+	return []string{"/dev/metal", "/dev/mps"}
+}
+
+// None GPUDetector implementation: reports no GPUs and no device paths.
+type NoneDetector struct{}
+
+func (d *NoneDetector) DetectGPUCount() int {
+	return 0
+}
+
+func (d *NoneDetector) GetGPUType() GPUType {
+	return GPUTypeNone
+}
+
+func (d *NoneDetector) GetDevicePaths() []string {
+	return nil
+}
+
+// GPUDetectorFactory creates appropriate GPU detector based on config
+type GPUDetectorFactory struct{}
+
+// CreateDetector resolves a detector with this precedence:
+// FETCH_ML_GPU_TYPE env override, then cfg.GPUVendor, then config
+// heuristics (AppleGPU.Enabled / GPUDevices), and finally NoneDetector.
+func (f *GPUDetectorFactory) CreateDetector(cfg *Config) GPUDetector {
+	// Check for explicit environment override
+	if gpuType := os.Getenv("FETCH_ML_GPU_TYPE"); gpuType != "" {
+		switch gpuType {
+		case string(GPUTypeNVIDIA):
+			return &NVIDIADetector{}
+		case string(GPUTypeApple):
+			return &AppleDetector{enabled: true}
+		case string(GPUTypeNone):
+			return &NoneDetector{}
+		}
+	}
+
+	// Respect configured vendor when explicitly set.
+	if cfg != nil {
+		switch GPUType(cfg.GPUVendor) {
+		case GPUTypeApple:
+			return &AppleDetector{enabled: cfg.AppleGPU.Enabled}
+		case GPUTypeNone:
+			return &NoneDetector{}
+		case GPUTypeNVIDIA:
+			return &NVIDIADetector{}
+		case "amd":
+			// AMD uses similar device exposure patterns in this codebase.
+ return &NVIDIADetector{} + } + } + + // Auto-detect based on config + if cfg != nil { + if cfg.AppleGPU.Enabled { + return &AppleDetector{enabled: true} + } + if len(cfg.GPUDevices) > 0 { + return &NVIDIADetector{} + } + } + + // Default to no GPU + return &NoneDetector{} +} diff --git a/internal/worker/jupyter_task.go b/internal/worker/jupyter_task.go new file mode 100644 index 0000000..dd9a0dc --- /dev/null +++ b/internal/worker/jupyter_task.go @@ -0,0 +1,130 @@ +package worker + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/jupyter" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +const ( + jupyterTaskTypeKey = "task_type" + jupyterTaskTypeValue = "jupyter" + jupyterTaskActionKey = "jupyter_action" + jupyterActionStart = "start" + jupyterActionStop = "stop" + jupyterActionRemove = "remove" + jupyterActionRestore = "restore" + jupyterActionList = "list" + jupyterNameKey = "jupyter_name" + jupyterWorkspaceKey = "jupyter_workspace" + jupyterServiceIDKey = "jupyter_service_id" + jupyterTaskOutputType = "jupyter_output" +) + +type jupyterTaskOutput struct { + Type string `json:"type"` + Service *jupyter.JupyterService `json:"service,omitempty"` + Services []*jupyter.JupyterService `json:"services"` + RestorePath string `json:"restore_path,omitempty"` +} + +func isJupyterTask(task *queue.Task) bool { + if task == nil || task.Metadata == nil { + return false + } + return strings.TrimSpace(task.Metadata[jupyterTaskTypeKey]) == jupyterTaskTypeValue +} + +func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) { + if w == nil { + return nil, fmt.Errorf("worker is nil") + } + if task == nil { + return nil, fmt.Errorf("task is nil") + } + if w.jupyter == nil { + return nil, fmt.Errorf("jupyter manager not configured") + } + if task.Metadata == nil { + return nil, fmt.Errorf("missing task metadata") + } + + action := 
strings.ToLower(strings.TrimSpace(task.Metadata[jupyterTaskActionKey])) + if action == "" { + return nil, fmt.Errorf("missing jupyter action") + } + + // Validate job name since it is used as the task status key and shows up in logs. + if err := container.ValidateJobName(task.JobName); err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(ctx, 2*time.Minute) + defer cancel() + + switch action { + case jupyterActionStart: + name := strings.TrimSpace(task.Metadata[jupyterNameKey]) + ws := strings.TrimSpace(task.Metadata[jupyterWorkspaceKey]) + if name == "" { + return nil, fmt.Errorf("missing jupyter name") + } + if ws == "" { + return nil, fmt.Errorf("missing jupyter workspace") + } + service, err := w.jupyter.StartService(ctx, &jupyter.StartRequest{Name: name, Workspace: ws}) + if err != nil { + return nil, err + } + out := jupyterTaskOutput{Type: jupyterTaskOutputType, Service: service} + return json.Marshal(out) + case jupyterActionStop: + serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey]) + if serviceID == "" { + return nil, fmt.Errorf("missing jupyter service id") + } + if err := w.jupyter.StopService(ctx, serviceID); err != nil { + return nil, err + } + out := jupyterTaskOutput{Type: jupyterTaskOutputType} + return json.Marshal(out) + case jupyterActionRemove: + serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey]) + if serviceID == "" { + return nil, fmt.Errorf("missing jupyter service id") + } + purge := strings.EqualFold(strings.TrimSpace(task.Metadata["jupyter_purge"]), "true") + if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil { + return nil, err + } + out := jupyterTaskOutput{Type: jupyterTaskOutputType} + return json.Marshal(out) + case jupyterActionList: + services := w.jupyter.ListServices() + out := jupyterTaskOutput{Type: jupyterTaskOutputType, Services: services} + return json.Marshal(out) + case jupyterActionRestore: + name := strings.TrimSpace(task.Metadata[jupyterNameKey]) + if 
name == "" { + return nil, fmt.Errorf("missing jupyter name") + } + restoredPath, err := w.jupyter.RestoreWorkspace(ctx, name) + if err != nil { + return nil, err + } + out := jupyterTaskOutput{Type: jupyterTaskOutputType, RestorePath: restoredPath} + return json.Marshal(out) + default: + return nil, fmt.Errorf("invalid jupyter action: %s", action) + } +} + +func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) { + return w.runJupyterTask(ctx, task) +} diff --git a/internal/worker/runloop.go b/internal/worker/runloop.go new file mode 100644 index 0000000..3ced976 --- /dev/null +++ b/internal/worker/runloop.go @@ -0,0 +1,525 @@ +package worker + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/resources" +) + +// Start starts the worker's main processing loop. +func (w *Worker) Start() { + w.logger.Info("worker started", + "worker_id", w.id, + "max_concurrent", w.config.MaxWorkers, + "poll_interval", w.config.PollInterval) + + go w.heartbeat() + if w.config != nil && w.config.PrewarmEnabled { + go w.prewarmNextLoop() + go w.prewarmImageGCLoop() + } + + for { + switch { + case w.ctx.Err() != nil: + w.logger.Info("shutdown signal received, waiting for tasks") + w.waitForTasks() + return + default: + } + if w.runningCount() >= w.config.MaxWorkers { + time.Sleep(50 * time.Millisecond) + continue + } + + queueStart := time.Now() + blockTimeout := time.Duration(w.config.PollInterval) * time.Second + task, err := w.queue.GetNextTaskWithLeaseBlocking( + w.config.WorkerID, + w.config.TaskLeaseDuration, + blockTimeout, + ) + queueLatency := time.Since(queueStart) + if err != nil { + if err == context.DeadlineExceeded { + continue + } + w.logger.Error("error fetching task", + "worker_id", w.id, + "error", err) + continue + } + + if task == nil { + if queueLatency > 200*time.Millisecond { + 
w.logger.Debug("queue poll latency", + "latency_ms", queueLatency.Milliseconds()) + } + continue + } + + if depth, derr := w.queue.QueueDepth(); derr == nil { + if queueLatency > 100*time.Millisecond || depth > 0 { + w.logger.Debug("queue fetch metrics", + "latency_ms", queueLatency.Milliseconds(), + "remaining_depth", depth) + } + } else if queueLatency > 100*time.Millisecond { + w.logger.Debug("queue fetch metrics", + "latency_ms", queueLatency.Milliseconds(), + "depth_error", derr) + } + + // Reserve a running slot *before* starting the goroutine so we don't drain + // the entire queue while max_workers is 1. + w.reserveRunningSlot(task.ID) + go w.executeTaskWithLease(task) + } +} + +func (w *Worker) reserveRunningSlot(taskID string) { + w.runningMu.Lock() + defer w.runningMu.Unlock() + if w.running == nil { + w.running = make(map[string]context.CancelFunc) + } + // Track a cancel func for future shutdown handling; currently best-effort. + _, cancel := context.WithCancel(w.ctx) + w.running[taskID] = cancel +} + +func (w *Worker) releaseRunningSlot(taskID string) { + w.runningMu.Lock() + defer w.runningMu.Unlock() + if w.running == nil { + return + } + if cancel, ok := w.running[taskID]; ok { + cancel() + delete(w.running, taskID) + } +} + +func (w *Worker) prewarmImageGCLoop() { + if w.config == nil || !w.config.PrewarmEnabled { + return + } + if w.envPool == nil { + return + } + if w.config.LocalMode { + return + } + if strings.TrimSpace(w.config.PodmanImage) == "" { + return + } + + lastSeen := "" + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + runGC := func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + _ = w.envPool.PruneImages(ctx, 24*time.Hour) + } + + for { + select { + case <-w.ctx.Done(): + return + case <-ticker.C: + if w.queue != nil { + v, err := w.queue.PrewarmGCRequestValue() + if err == nil && v != "" && v != lastSeen { + lastSeen = v + runGC() + continue + } + } + runGC() + } 
+	}
+}
+
+func (w *Worker) heartbeat() {
+	ticker := time.NewTicker(30 * time.Second)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-w.ctx.Done():
+			return
+		case <-ticker.C:
+			if err := w.queue.Heartbeat(w.id); err != nil {
+				w.logger.Warn("heartbeat failed",
+					"worker_id", w.id,
+					"error", err)
+			}
+		}
+	}
+}
+
+func (w *Worker) waitForTasks() {
+	timeout := time.After(5 * time.Minute)
+	ticker := time.NewTicker(1 * time.Second)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-timeout:
+			w.logger.Warn("shutdown timeout, force stopping",
+				"running_tasks", len(w.running))
+			return
+		case <-ticker.C:
+			count := w.runningCount()
+			if count == 0 {
+				w.logger.Info("all tasks completed, shutting down")
+				return
+			}
+			w.logger.Debug("waiting for tasks to complete",
+				"remaining", count)
+		}
+	}
+}
+
+func (w *Worker) runningCount() int {
+	w.runningMu.RLock()
+	defer w.runningMu.RUnlock()
+	return len(w.running)
+}
+
+// GetMetrics returns current worker metrics.
+func (w *Worker) GetMetrics() map[string]any {
+	stats := w.metrics.GetStats()
+	stats["worker_id"] = w.id
+	stats["max_workers"] = w.config.MaxWorkers
+	return stats
+}
+
+// Stop gracefully shuts down the worker.
+func (w *Worker) Stop() {
+	w.cancel()
+	w.waitForTasks()
+
+	// Close connections; close errors are logged but do not abort shutdown.
+	if err := w.server.Close(); err != nil {
+		w.logger.Warn("error closing server connection", "error", err)
+	}
+	if err := w.queue.Close(); err != nil {
+		w.logger.Warn("error closing queue connection", "error", err)
+	}
+	if w.metricsSrv != nil {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		if err := w.metricsSrv.Shutdown(ctx); err != nil {
+			w.logger.Warn("metrics exporter shutdown error", "error", err)
+		}
+	}
+	w.logger.Info("worker stopped", "worker_id", w.id)
+}
+
+// Execute task with lease management and retry.
+func (w *Worker) executeTaskWithLease(task *queue.Task) { + defer w.releaseRunningSlot(task.ID) + + // Track task for graceful shutdown + w.gracefulWait.Add(1) + w.activeTasks.Store(task.ID, task) + defer w.gracefulWait.Done() + defer w.activeTasks.Delete(task.ID) + + // Create task-specific context with timeout + taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing + taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata + taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata + + taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour) + defer taskCancel() + + logger := w.logger.Job(taskCtx, task.JobName, task.ID) + logger.Info("starting task", + "worker_id", w.id, + "datasets", task.Datasets, + "priority", task.Priority) + + // Record task start + w.metrics.RecordTaskStart() + defer w.metrics.RecordTaskCompletion() + + // Check for context cancellation + select { + case <-taskCtx.Done(): + logger.Info("task cancelled before execution") + return + default: + } + + // Jupyter tasks are executed directly by the worker (no experiment/provenance pipeline). 
+ if isJupyterTask(task) { + out, err := w.runJupyterTask(taskCtx, task) + endTime := time.Now() + task.EndedAt = &endTime + if err != nil { + logger.Error("jupyter task failed", "error", err) + task.Status = "failed" + task.Error = err.Error() + _ = w.queue.UpdateTaskWithMetrics(task, "final") + w.metrics.RecordTaskFailure() + _ = w.queue.ReleaseLease(task.ID, w.config.WorkerID) + return + } + if len(out) > 0 { + task.Output = string(out) + } + task.Status = "completed" + _ = w.queue.UpdateTaskWithMetrics(task, "final") + _ = w.queue.ReleaseLease(task.ID, w.config.WorkerID) + return + } + + // Parse datasets from task arguments + task.Datasets = resolveDatasets(task) + + if err := w.validateTaskForExecution(taskCtx, task); err != nil { + logger.Error("task validation failed", "error", err) + task.Status = "failed" + task.Error = fmt.Sprintf("Validation failed: %v", err) + endTime := time.Now() + task.EndedAt = &endTime + if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil { + logger.Error("failed to update task status after validation failure", "error", updateErr) + } + w.metrics.RecordTaskFailure() + _ = w.queue.ReleaseLease(task.ID, w.config.WorkerID) + return + } + + if err := w.enforceTaskProvenance(taskCtx, task); err != nil { + logger.Error("provenance validation failed", "error", err) + task.Status = "failed" + task.Error = fmt.Sprintf("Provenance validation failed: %v", err) + endTime := time.Now() + task.EndedAt = &endTime + if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil { + logger.Error( + "failed to update task status after provenance validation failure", + "error", updateErr) + } + w.metrics.RecordTaskFailure() + _ = w.queue.ReleaseLease(task.ID, w.config.WorkerID) + return + } + + lease, err := w.resources.Acquire(taskCtx, task) + if err != nil { + logger.Error("resource acquisition failed", "error", err) + task.Status = "failed" + task.Error = fmt.Sprintf("Resource acquisition failed: %v", 
err) + endTime := time.Now() + task.EndedAt = &endTime + if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil { + logger.Error( + "failed to update task status after resource acquisition failure", + "error", updateErr) + } + w.metrics.RecordTaskFailure() + _ = w.queue.ReleaseLease(task.ID, w.config.WorkerID) + return + } + defer lease.Release() + + // Start heartbeat goroutine + heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background()) + defer cancelHeartbeat() + + go w.heartbeatLoop(heartbeatCtx, task.ID) + + // Update task status + task.Status = "running" + now := time.Now() + task.StartedAt = &now + task.WorkerID = w.id + + // Phase 1 provenance capture: best-effort metadata enrichment before persisting the running state. + w.recordTaskProvenance(taskCtx, task) + + if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil { + logger.Error("failed to update task status", "error", err) + w.metrics.RecordTaskFailure() + return + } + + if w.config.AutoFetchData && len(task.Datasets) > 0 { + if err := w.fetchDatasets(taskCtx, task); err != nil { + logger.Error("data fetch failed", "error", err) + task.Status = "failed" + task.Error = fmt.Sprintf("Data fetch failed: %v", err) + endTime := time.Now() + task.EndedAt = &endTime + err := w.queue.UpdateTask(task) + if err != nil { + logger.Error("failed to update task status after data fetch failure", "error", err) + } + w.metrics.RecordTaskFailure() + return + } + } + if err := w.verifyDatasetSpecs(taskCtx, task); err != nil { + logger.Error("dataset checksum verification failed", "error", err) + task.Status = "failed" + task.Error = fmt.Sprintf("Dataset checksum verification failed: %v", err) + endTime := time.Now() + task.EndedAt = &endTime + if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil { + logger.Error( + "failed to update task after dataset checksum verification failure", + "error", updateErr) + } + w.metrics.RecordTaskFailure() + 
return + } + if err := w.verifySnapshot(taskCtx, task); err != nil { + logger.Error("snapshot checksum verification failed", "error", err) + task.Status = "failed" + task.Error = fmt.Sprintf("Snapshot checksum verification failed: %v", err) + endTime := time.Now() + task.EndedAt = &endTime + if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil { + logger.Error( + "failed to update task after snapshot checksum verification failure", + "error", updateErr) + } + w.metrics.RecordTaskFailure() + return + } + + // Execute job with panic recovery + var execErr error + func() { + defer func() { + if r := recover(); r != nil { + execErr = fmt.Errorf("panic during execution: %v", r) + } + }() + cudaVisible := resources.FormatCUDAVisibleDevices(lease) + execErr = w.runJob(taskCtx, task, cudaVisible) + }() + + // Finalize task + endTime := time.Now() + task.EndedAt = &endTime + + if execErr != nil { + task.Error = execErr.Error() + + // Check if transient error (network, timeout, etc) + if isTransientError(execErr) && task.RetryCount < task.MaxRetries { + w.logger.Warn("task failed with transient error, will retry", + "task_id", task.ID, + "error", execErr, + "retry_count", task.RetryCount) + _ = w.queue.RetryTask(task) + } else { + task.Status = "failed" + _ = w.queue.UpdateTaskWithMetrics(task, "final") + } + } else { + task.Status = "completed" + _ = w.queue.UpdateTaskWithMetrics(task, "final") + } + + // Release lease + _ = w.queue.ReleaseLease(task.ID, w.config.WorkerID) +} + +// Heartbeat loop to renew lease. 
+func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
+	// Renew the task lease on a fixed cadence until ctx is cancelled by the
+	// task's owner or a renewal fails.
+	ticker := time.NewTicker(w.config.HeartbeatInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
+				w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
+				// Stop renewing; once the lease expires another worker may
+				// reclaim the task. NOTE(review): the local job keeps running,
+				// so confirm this double-execution window is acceptable.
+				return
+			}
+			// Also update worker heartbeat
+			_ = w.queue.Heartbeat(w.config.WorkerID)
+		}
+	}
+}
+
+// Shutdown gracefully shuts down the worker.
+func (w *Worker) Shutdown() error {
+	w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())
+
+	// Wait for active tasks with timeout
+	done := make(chan struct{})
+	go func() {
+		w.gracefulWait.Wait()
+		close(done)
+	}()
+
+	timeout := time.After(w.config.GracefulTimeout)
+	select {
+	case <-done:
+		w.logger.Info("all tasks completed, shutdown successful")
+	case <-timeout:
+		// Give up waiting; release leases so other workers can pick the tasks
+		// up, even though local goroutines may still be finishing.
+		w.logger.Warn("graceful shutdown timeout, releasing active leases")
+		w.releaseAllLeases()
+	}
+
+	return w.queue.Close()
+}
+
+// Release all active leases.
+func (w *Worker) releaseAllLeases() {
+	// Best-effort: a failed release is logged and iteration continues.
+	w.activeTasks.Range(func(key, _ interface{}) bool {
+		taskID := key.(string)
+		if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
+			w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
+		}
+		return true
+	})
+}
+
+// Helper functions.
+// countActiveTasks returns the number of tasks currently tracked in activeTasks.
+func (w *Worker) countActiveTasks() int {
+	count := 0
+	w.activeTasks.Range(func(_, _ interface{}) bool {
+		count++
+		return true
+	})
+	return count
+}
+
+// isTransientError reports whether err looks retryable. This is a heuristic
+// case-insensitive substring match on the error text, not a typed check.
+func isTransientError(err error) bool {
+	if err == nil {
+		return false
+	}
+	// Check if error is transient (network, timeout, resource unavailable, etc)
+	errStr := err.Error()
+	transientIndicators := []string{
+		"connection refused",
+		"timeout",
+		"temporary failure",
+		"resource temporarily unavailable",
+		"no such host",
+		"network unreachable",
+	}
+	for _, indicator := range transientIndicators {
+		if strings.Contains(strings.ToLower(errStr), indicator) {
+			return true
+		}
+	}
+	return false
+}
diff --git a/internal/worker/snapshot_store.go b/internal/worker/snapshot_store.go
new file mode 100644
index 0000000..72a4ffd
--- /dev/null
+++ b/internal/worker/snapshot_store.go
+package worker
+
+import (
+	"archive/tar"
+	"compress/gzip"
+	"context"
+	"fmt"
+	"io"
+	"os"
+	"path"
+	"path/filepath"
+	"strings"
+
+	"github.com/jfraeys/fetch_ml/internal/container"
+	"github.com/jfraeys/fetch_ml/internal/fileutil"
+	"github.com/minio/minio-go/v7"
+	"github.com/minio/minio-go/v7/pkg/credentials"
+)
+
+// SnapshotFetcher abstracts object retrieval so ResolveSnapshot can be driven
+// by a fake in tests instead of a live MinIO/S3 endpoint.
+type SnapshotFetcher interface {
+	Get(ctx context.Context, bucket, key string) (io.ReadCloser, error)
+}
+
+// minioSnapshotFetcher is the production SnapshotFetcher backed by a MinIO client.
+type minioSnapshotFetcher struct {
+	client *minio.Client
+}
+
+// Get returns a reader for bucket/key.
+// NOTE(review): minio GetObject is lazy — some errors may only surface on the
+// first Read of the returned object, not here; callers must handle read errors.
+func (f *minioSnapshotFetcher) Get(ctx context.Context, bucket, key string) (io.ReadCloser, error) {
+	obj, err := f.client.GetObject(ctx, bucket, key, minio.GetObjectOptions{})
+	if err != nil {
+		return nil, err
+	}
+	return obj, nil
+}
+
+// newMinioSnapshotFetcher builds a MinIO-backed fetcher from cfg, validating
+// that an endpoint and bucket are configured.
+func newMinioSnapshotFetcher(cfg *SnapshotStoreConfig) (*minioSnapshotFetcher, error) {
+	if cfg == nil {
+		return nil, fmt.Errorf("missing snapshot store config")
+	}
+	endpoint := strings.TrimSpace(cfg.Endpoint)
+	if endpoint == "" {
+		return nil, fmt.Errorf("missing snapshot store endpoint")
+	}
+	bucket := strings.TrimSpace(cfg.Bucket)
+	if bucket == "" {
+		return nil, fmt.Errorf("missing snapshot store bucket")
+	}
+	creds := cfg.credentials()
+	client, err := minio.New(endpoint, &minio.Options{
+		Creds:      creds,
+		Secure:     cfg.Secure,
+		Region:     strings.TrimSpace(cfg.Region),
+		MaxRetries: cfg.MaxRetries,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return &minioSnapshotFetcher{client: client}, nil
+}
+
+// credentials returns static credentials when both access and secret keys are
+// configured; otherwise it falls back to a chain of MinIO and AWS env providers.
+func (c *SnapshotStoreConfig) credentials() *credentials.Credentials {
+	if c != nil {
+		ak := strings.TrimSpace(c.AccessKey)
+		sk := strings.TrimSpace(c.SecretKey)
+		st := strings.TrimSpace(c.SessionToken)
+		if ak != "" && sk != "" {
+			return credentials.NewStaticV4(ak, sk, st)
+		}
+	}
+	return credentials.NewChainCredentials([]credentials.Provider{
+		&credentials.EnvMinio{},
+		&credentials.EnvAWS{},
+	})
+}
+
+// ResolveSnapshot returns a local directory containing the snapshot identified
+// by snapshotID with content checksum wantSHA256. It checks the content-addressed
+// cache (<dataDir>/snapshots/sha256/<sha>) first; on a miss (and when the remote
+// store is enabled) it downloads <prefix>/<snapshotID>.tar.gz, extracts it into a
+// temp dir, verifies the overall checksum, and promotes it into the cache via
+// rename. When cfg is nil or disabled it falls back to the local, unverified
+// path <dataDir>/snapshots/<snapshotID>. fetcher may be nil, in which case a
+// MinIO-backed fetcher is constructed from cfg.
+func ResolveSnapshot(
+	ctx context.Context,
+	dataDir string,
+	cfg *SnapshotStoreConfig,
+	snapshotID string,
+	wantSHA256 string,
+	fetcher SnapshotFetcher,
+) (string, error) {
+	dataDir = strings.TrimSpace(dataDir)
+	if dataDir == "" {
+		return "", fmt.Errorf("missing data_dir")
+	}
+	snapshotID = strings.TrimSpace(snapshotID)
+	if snapshotID == "" {
+		return "", fmt.Errorf("missing snapshot_id")
+	}
+	// snapshotID is used to build filesystem and object-store paths, so it must
+	// pass the same validation as job names.
+	if err := container.ValidateJobName(snapshotID); err != nil {
+		return "", fmt.Errorf("invalid snapshot_id: %w", err)
+	}
+	want, err := normalizeSHA256ChecksumHex(wantSHA256)
+	if err != nil || want == "" {
+		return "", fmt.Errorf("invalid snapshot_sha256")
+	}
+
+	// Cache hit: the directory name encodes the verified checksum.
+	cacheDir := filepath.Join(dataDir, "snapshots", "sha256", want)
+	if info, err := os.Stat(cacheDir); err == nil && info.IsDir() {
+		return cacheDir, nil
+	}
+
+	// Remote store disabled: return the conventional local path without
+	// verifying existence or checksum here (verified by the caller).
+	if cfg == nil || !cfg.Enabled {
+		return filepath.Join(dataDir, "snapshots", snapshotID), nil
+	}
+
+	bucket := strings.TrimSpace(cfg.Bucket)
+	if bucket == "" {
+		return "", fmt.Errorf("missing snapshot store bucket")
+	}
+	prefix := strings.Trim(strings.TrimSpace(cfg.Prefix), "/")
+	key := snapshotID + ".tar.gz"
+	if prefix != "" {
+		key = path.Join(prefix, key)
+	}
+
+	if fetcher == nil {
+		mf, err := newMinioSnapshotFetcher(cfg)
+		if err != nil {
+			return "", err
+		}
+		fetcher = mf
+	}
+
+	fetchCtx := ctx
+	if cfg.Timeout > 0 {
+		var cancel context.CancelFunc
+		fetchCtx, cancel = context.WithTimeout(ctx, cfg.Timeout)
+		defer cancel()
+	}
+
+	rc, err := fetcher.Get(fetchCtx, bucket, key)
+	if err != nil {
+		return "", err
+	}
+	defer func() { _ = rc.Close() }()
+
+	// Stage everything under a per-call temp dir so partial downloads/extracts
+	// never pollute the cache; the deferred RemoveAll cleans up on any failure.
+	tmpRoot := filepath.Join(dataDir, "snapshots", ".tmp")
+	if err := os.MkdirAll(tmpRoot, 0750); err != nil {
+		return "", err
+	}
+	workDir, err := os.MkdirTemp(tmpRoot, "fetchml-snapshot-")
+	if err != nil {
+		return "", err
+	}
+	defer func() { _ = os.RemoveAll(workDir) }()
+
+	archivePath := filepath.Join(workDir, "snapshot.tar.gz")
+	f, err := fileutil.SecureOpenFile(archivePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
+	if err != nil {
+		return "", err
+	}
+	_, copyErr := io.Copy(f, rc)
+	closeErr := f.Close()
+	if copyErr != nil {
+		return "", copyErr
+	}
+	if closeErr != nil {
+		return "", closeErr
+	}
+
+	extractDir := filepath.Join(workDir, "extracted")
+	if err := os.MkdirAll(extractDir, 0750); err != nil {
+		return "", err
+	}
+	if err := extractTarGz(archivePath, extractDir); err != nil {
+		return "", err
+	}
+
+	got, err := dirOverallSHA256Hex(extractDir)
+	if err != nil {
+		return "", err
+	}
+	if got != want {
+		return "", fmt.Errorf("snapshot checksum mismatch: expected %s, got %s", want, got)
+	}
+
+	if err := os.MkdirAll(filepath.Dir(cacheDir), 0750); err != nil {
+		return "", err
+	}
+	// Promote into the cache; if a concurrent caller won the rename race, the
+	// existing verified cache entry is returned instead of an error.
+	if err := os.Rename(extractDir, cacheDir); err != nil {
+		if info, statErr := os.Stat(cacheDir); statErr == nil && info.IsDir() {
+			return cacheDir, nil
+		}
+		return "", err
+	}
+
+	return cacheDir, nil
+}
+
+// extractTarGz extracts a gzip-compressed tar archive into dstDir. It rejects
+// absolute paths and `..` traversal, and only allows directory and regular-file
+// entries (symlinks, hard links, and devices are refused).
+// NOTE(review): there is no cap on total extracted size, so a crafted archive
+// could exhaust disk space (decompression bomb) — confirm upstream size limits.
+func extractTarGz(archivePath, dstDir string) error {
+	archivePath = filepath.Clean(archivePath)
+	dstDir = filepath.Clean(dstDir)
+
+	f, err := os.Open(archivePath)
+	if err != nil {
+		return err
+	}
+	defer func() { _ = f.Close() }()
+
+	gz, err := gzip.NewReader(f)
+	if err != nil {
+		return err
+	}
+	defer func() { _ = gz.Close() }()
+
+	tr := tar.NewReader(gz)
+	for {
+		hdr, err := tr.Next()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return err
+		}
+		name := strings.TrimSpace(hdr.Name)
+		name = strings.TrimPrefix(name, "./")
+		clean := path.Clean(name)
+		if clean == "." {
+			continue
+		}
+		if strings.HasPrefix(clean, "../") || clean == ".." || strings.HasPrefix(clean, "/") {
+			return fmt.Errorf("invalid tar entry")
+		}
+
+		target, err := safeJoin(dstDir, filepath.FromSlash(clean))
+		if err != nil {
+			return err
+		}
+
+		switch hdr.Typeflag {
+		case tar.TypeDir:
+			if err := os.MkdirAll(target, 0750); err != nil {
+				return err
+			}
+		case tar.TypeReg:
+			if err := os.MkdirAll(filepath.Dir(target), 0750); err != nil {
+				return err
+			}
+			// Mask to permission bits only; setuid/setgid/sticky are dropped.
+			mode := hdr.FileInfo().Mode() & 0777
+			out, err := fileutil.SecureOpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
+			if err != nil {
+				return err
+			}
+			// Copy at most hdr.Size bytes so a lying header cannot overread.
+			if _, err := io.CopyN(out, tr, hdr.Size); err != nil {
+				_ = out.Close()
+				return err
+			}
+			if err := out.Close(); err != nil {
+				return err
+			}
+		default:
+			return fmt.Errorf("unsupported tar entry type")
+		}
+	}
+	return nil
+}
+
+// safeJoin joins rel onto baseDir and fails if the cleaned result escapes
+// baseDir (defense-in-depth against path traversal).
+func safeJoin(baseDir, rel string) (string, error) {
+	baseDir = filepath.Clean(baseDir)
+	joined := filepath.Join(baseDir, rel)
+	joined = filepath.Clean(joined)
+	basePrefix := baseDir + string(os.PathSeparator)
+	if joined != baseDir && !strings.HasPrefix(joined, basePrefix) {
+		return "", fmt.Errorf("invalid relative path")
+	}
+	return joined, nil
+}