package worker

import (
	"context"
	"fmt"
	"log"
	"log/slog"
	"math"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/envpool"
	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/metrics"
	"github.com/jfraeys/fetch_ml/internal/network"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/resources"
	"github.com/jfraeys/fetch_ml/internal/tracking"
	"github.com/jfraeys/fetch_ml/internal/tracking/factory"
	trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/collectors"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// MLServer wraps network.SSHClient for backward compatibility.
type MLServer struct {
	*network.SSHClient
}

// JupyterManager is the subset of the Jupyter service manager used by the worker.
// It exists to keep task execution testable.
type JupyterManager interface {
	StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
	StopService(ctx context.Context, serviceID string) error
	RemoveService(ctx context.Context, serviceID string, purge bool) error
	RestoreWorkspace(ctx context.Context, name string) (string, error)
	ListServices() []*jupyter.JupyterService
}

// isValidName reports whether input is a plausible name: non-empty and
// shorter than 256 characters. It does not restrict the character set, so
// callers that embed names in shell commands need stricter validation.
func isValidName(input string) bool {
	return len(input) > 0 && len(input) < 256
}

// NewMLServer creates a new ML server connection. In local mode it returns a
// client backed by the local filesystem; otherwise it dials SSH.
func NewMLServer(cfg *Config) (*MLServer, error) {
	if cfg.LocalMode {
		return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil
	}
	client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
	if err != nil {
		return nil, err
	}
	return &MLServer{SSHClient: client}, nil
}
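// The length-only check in isValidName is permissive. Below is a minimal
// sketch of a character-level validator; isShellSafeName is a hypothetical
// helper (not called anywhere in this package) showing one way "safe
// characters" could be enforced, assuming names should be limited to
// [A-Za-z0-9._-].
func isShellSafeName(input string) bool {
	if len(input) == 0 || len(input) >= 256 {
		return false
	}
	for _, r := range input {
		switch {
		case r >= 'a' && r <= 'z',
			r >= 'A' && r <= 'Z',
			r >= '0' && r <= '9',
			r == '.', r == '_', r == '-':
			// allowed character
		default:
			return false
		}
	}
	return true
}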
// Worker represents an ML task worker.
type Worker struct {
	id         string
	config     *Config
	server     *MLServer
	queue      queue.Backend
	resources  *resources.Manager
	running    map[string]context.CancelFunc // Cancellation functions for graceful shutdown
	runningMu  sync.RWMutex
	ctx        context.Context
	cancel     context.CancelFunc
	logger     *logging.Logger
	metrics    *metrics.Metrics
	metricsSrv *http.Server

	datasetCache    map[string]time.Time
	datasetCacheMu  sync.RWMutex
	datasetCacheTTL time.Duration

	// Graceful shutdown fields
	shutdownCh   chan struct{}
	activeTasks  sync.Map // map[string]*queue.Task - track active tasks
	gracefulWait sync.WaitGroup

	podman           *container.PodmanManager
	jupyter          JupyterManager
	trackingRegistry *tracking.Registry
	envPool          *envpool.Pool

	prewarmMu        sync.Mutex
	prewarmTargetID  string
	prewarmCancel    context.CancelFunc
	prewarmStartedAt time.Time
}

// envInt reads an integer from the named environment variable, reporting
// whether a valid value was present.
func envInt(name string) (int, bool) {
	v := strings.TrimSpace(os.Getenv(name))
	if v == "" {
		return 0, false
	}
	n, err := strconv.Atoi(v)
	if err != nil {
		return 0, false
	}
	return n, true
}

// parseCPUFromConfig resolves the CPU token budget: the FETCH_ML_TOTAL_CPU
// environment variable wins, then the configured Podman CPU limit (rounded
// down), then the host CPU count.
func parseCPUFromConfig(cfg *Config) int {
	if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 {
		return n
	}
	if cfg != nil {
		if cfg.Resources.PodmanCPUs != "" {
			if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil {
				if f < 0 {
					return 0
				}
				return int(math.Floor(f))
			}
		}
	}
	return runtime.NumCPU()
}

// parseGPUCountFromConfig detects the number of usable GPUs via the
// configured detector. The local variable is named detectorFactory to avoid
// shadowing the imported factory package.
func parseGPUCountFromConfig(cfg *Config) int {
	detectorFactory := &GPUDetectorFactory{}
	detector := detectorFactory.CreateDetector(cfg)
	return detector.DetectGPUCount()
}

func (w *Worker) getGPUDetector() GPUDetector {
	detectorFactory := &GPUDetectorFactory{}
	return detectorFactory.CreateDetector(w.config)
}

// parseGPUSlotsPerGPUFromConfig returns how many concurrent tasks may share
// one GPU, defaulting to 1.
func parseGPUSlotsPerGPUFromConfig() int {
	if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 {
		return n
	}
	return 1
}
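// NewWorker (below) feeds the three helpers above directly into
// resources.NewManager. The following is a minimal sketch of that wiring as
// a standalone helper; defaultResourceOptions is hypothetical and not used
// elsewhere in this file.
func defaultResourceOptions(cfg *Config) resources.Options {
	return resources.Options{
		TotalCPU:    parseCPUFromConfig(cfg),      // env var, then config, then host CPUs
		GPUCount:    parseGPUCountFromConfig(cfg), // detector-reported GPU count
		SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
	}
}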
// setupMetricsExporter registers Prometheus collectors for this worker and,
// when metrics are enabled, serves them over HTTP.
func (w *Worker) setupMetricsExporter() {
	if !w.config.Metrics.Enabled {
		return
	}
	reg := prometheus.NewRegistry()
	reg.MustRegister(
		collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
		collectors.NewGoCollector(),
	)

	labels := prometheus.Labels{"worker_id": w.id}
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_processed_total",
		Help:        "Total tasks processed successfully by this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.TasksProcessed.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_failed_total",
		Help:        "Total tasks failed by this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.TasksFailed.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_active",
		Help:        "Number of tasks currently running on this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.runningCount()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_tasks_queued",
		Help:        "Latest observed queue depth from Redis.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.QueuedTasks.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_data_transferred_bytes_total",
		Help:        "Total bytes transferred while fetching datasets.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.DataTransferred.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_data_fetch_time_seconds_total",
		Help:        "Total time spent fetching datasets (seconds).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_execution_time_seconds_total",
		Help:        "Total execution time for completed tasks (seconds).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_prewarm_env_hit_total",
		Help:        "Total environment prewarm hits (warmed image already existed).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.PrewarmEnvHit.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_prewarm_env_miss_total",
		Help:        "Total environment prewarm misses (warmed image did not exist yet).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.PrewarmEnvMiss.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_prewarm_env_built_total",
		Help:        "Total environment prewarm images built.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.PrewarmEnvBuilt.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_prewarm_env_time_seconds_total",
		Help:        "Total time spent building prewarm images (seconds).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.PrewarmEnvTime.Load()) / float64(time.Second) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_prewarm_snapshot_hit_total",
		Help:        "Total prewarmed snapshot hits (snapshots found in .prewarm/).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.PrewarmSnapshotHit.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_prewarm_snapshot_miss_total",
		Help:        "Total prewarmed snapshot misses (snapshots not found in .prewarm/).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.PrewarmSnapshotMiss.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_prewarm_snapshot_built_total",
		Help:        "Total snapshots prewarmed into .prewarm/.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.PrewarmSnapshotBuilt.Load()) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_prewarm_snapshot_time_seconds_total",
		Help:        "Total time spent prewarming snapshots (seconds).",
		ConstLabels: labels,
	}, func() float64 { return float64(w.metrics.PrewarmSnapshotTime.Load()) / float64(time.Second) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_worker_max_concurrency",
		Help:        "Configured maximum concurrent tasks for this worker.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.config.MaxWorkers) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_resources_cpu_total",
		Help:        "Total CPU tokens managed by the worker resource manager.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.resources.Snapshot().TotalCPU) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_resources_cpu_free",
		Help:        "Free CPU tokens currently available in the worker resource manager.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.resources.Snapshot().FreeCPU) }))
	reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
		Name:        "fetchml_resources_acquire_total",
		Help:        "Total resource acquisition attempts.",
		ConstLabels: labels,
	}, func() float64 { return float64(w.resources.Snapshot().AcquireTotal) }))
Name: "fetchml_resources_acquire_wait_total", Help: "Total resource acquisitions that had to wait for resources.", ConstLabels: labels, }, func() float64 { return float64(w.resources.Snapshot().AcquireWaitTotal) })) reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ Name: "fetchml_resources_acquire_timeout_total", Help: "Total resource acquisition attempts that timed out.", ConstLabels: labels, }, func() float64 { return float64(w.resources.Snapshot().AcquireTimeoutTotal) })) reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ Name: "fetchml_resources_acquire_wait_seconds_total", Help: "Total seconds spent waiting for resources across all acquisitions.", ConstLabels: labels, }, func() float64 { return w.resources.Snapshot().AcquireWaitSeconds })) snap := w.resources.Snapshot() for i := range snap.GPUFree { gpuLabels := prometheus.Labels{"worker_id": w.id, "gpu_index": strconv.Itoa(i)} idx := i reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ Name: "fetchml_resources_gpu_slots_total", Help: "Total GPU slots per GPU index.", ConstLabels: gpuLabels, }, func() float64 { return float64(w.resources.Snapshot().SlotsPerGPU) })) reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ Name: "fetchml_resources_gpu_slots_free", Help: "Free GPU slots per GPU index.", ConstLabels: gpuLabels, }, func() float64 { s := w.resources.Snapshot() if idx < 0 || idx >= len(s.GPUFree) { return 0 } return float64(s.GPUFree[idx]) })) } mux := http.NewServeMux() mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) srv := &http.Server{ Addr: w.config.Metrics.ListenAddr, Handler: mux, ReadHeaderTimeout: 5 * time.Second, } w.metricsSrv = srv go func() { w.logger.Info("metrics exporter listening", "addr", w.config.Metrics.ListenAddr, "enabled", w.config.Metrics.Enabled) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { w.logger.Warn("metrics exporter stopped", "error", err) } }() } // NewWorker creates a new worker instance. 
// NewWorker creates a new worker instance.
func NewWorker(cfg *Config, _ string) (*Worker, error) {
	srv, err := NewMLServer(cfg)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			if closeErr := srv.Close(); closeErr != nil {
				log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr)
			}
		}
	}()

	backendCfg := queue.BackendConfig{
		Backend:              queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
		RedisAddr:            cfg.RedisAddr,
		RedisPassword:        cfg.RedisPassword,
		RedisDB:              cfg.RedisDB,
		SQLitePath:           cfg.Queue.SQLitePath,
		MetricsFlushInterval: cfg.MetricsFlushInterval,
	}
	queueClient, err := queue.NewBackend(backendCfg)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			if closeErr := queueClient.Close(); closeErr != nil {
				log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr)
			}
		}
	}()

	// Create data_dir if it doesn't exist (for production without NAS).
	if cfg.DataDir != "" {
		if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
			log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer func() {
		if err != nil {
			cancel()
		}
	}()
	ctx = logging.EnsureTrace(ctx)
	ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
	baseLogger := logging.NewLogger(slog.LevelInfo, false)
	logger := baseLogger.Component(ctx, "worker")
	metricsObj := &metrics.Metrics{}

	podmanMgr, err := container.NewPodmanManager(logger)
	if err != nil {
		return nil, fmt.Errorf("failed to create podman manager: %w", err)
	}
	jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
	if err != nil {
		return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
	}

	trackingRegistry := tracking.NewRegistry(logger)
	pluginLoader := factory.NewPluginLoader(logger, podmanMgr)
	if len(cfg.Plugins) == 0 {
		logger.Warn("no plugins configured, registering defaults")
		// Register defaults manually for backward compatibility/local dev.
		mlflowPlugin, err := trackingplugins.NewMLflowPlugin(
			logger, podmanMgr, trackingplugins.MLflowOptions{
				ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"),
			},
		)
		if err == nil {
			trackingRegistry.Register(mlflowPlugin)
		} else {
			logger.Warn("failed to create default mlflow plugin", "error", err)
		}
		tensorboardPlugin, err := trackingplugins.NewTensorBoardPlugin(
			logger, podmanMgr, trackingplugins.TensorBoardOptions{
				LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"),
			},
		)
		if err == nil {
			trackingRegistry.Register(tensorboardPlugin)
		} else {
			logger.Warn("failed to create default tensorboard plugin", "error", err)
		}
		trackingRegistry.Register(trackingplugins.NewWandbPlugin())
	} else {
		if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil {
			return nil, fmt.Errorf("failed to load plugins: %w", err)
		}
	}

	worker := &Worker{
		id:               cfg.WorkerID,
		config:           cfg,
		server:           srv,
		queue:            queueClient,
		running:          make(map[string]context.CancelFunc),
		datasetCache:     make(map[string]time.Time),
		datasetCacheTTL:  cfg.DatasetCacheTTL,
		ctx:              ctx,
		cancel:           cancel,
		logger:           logger,
		metrics:          metricsObj,
		shutdownCh:       make(chan struct{}),
		podman:           podmanMgr,
		jupyter:          jupyterMgr,
		trackingRegistry: trackingRegistry,
		envPool:          envpool.New(""),
	}

	rm, rmErr := resources.NewManager(resources.Options{
		TotalCPU:    parseCPUFromConfig(cfg),
		GPUCount:    parseGPUCountFromConfig(cfg),
		SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
	})
	if rmErr != nil {
		return nil, fmt.Errorf("failed to init resource manager: %w", rmErr)
	}
	worker.resources = rm

	if !cfg.LocalMode {
		gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
		if cfg.AppleGPU.Enabled {
			logger.Warn("apple MPS GPU mode is intended for development; do not use in production",
				"gpu_type", "apple",
			)
		}
		if gpuType == "amd" {
			logger.Warn("amd GPU mode is intended for development; do not use in production",
				"gpu_type", "amd",
			)
		}
	}

	worker.setupMetricsExporter()

	// Pre-pull tracking images in background.
	go worker.prePullImages()

	return worker, nil
}

// prePullImages pulls the worker image and all enabled plugin images so the
// first task does not pay the pull latency. It is a no-op in local mode.
func (w *Worker) prePullImages() {
	if w.config.LocalMode {
		return
	}
	w.logger.Info("starting image pre-pulling")

	// Pull worker image.
	if w.config.PodmanImage != "" {
		w.pullImage(w.config.PodmanImage)
	}

	// Pull plugin images.
	for name, cfg := range w.config.Plugins {
		if !cfg.Enabled || cfg.Image == "" {
			continue
		}
		w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
		w.pullImage(cfg.Image)
	}
}

// pullImage pulls a single container image with a 10-minute timeout, logging
// failures as warnings rather than failing the worker.
func (w *Worker) pullImage(image string) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()
	cmd := exec.CommandContext(ctx, "podman", "pull", image)
	if output, err := cmd.CombinedOutput(); err != nil {
		w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
	} else {
		w.logger.Info("image pulled successfully", "image", image)
	}
}
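// pullImage always pulls, even when the image is already in local storage. A
// minimal sketch of skipping redundant pulls with `podman image exists`
// (which signals presence via exit code 0); hasLocalImage is a hypothetical
// helper, not called from this file.
func hasLocalImage(image string) bool {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	// `podman image exists` prints nothing; a nil error means the image is present.
	return exec.CommandContext(ctx, "podman", "image", "exists", image).Run() == nil
}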