diff --git a/internal/worker/config.go b/internal/worker/config.go index 041e488..9f01a40 100644 --- a/internal/worker/config.go +++ b/internal/worker/config.go @@ -2,9 +2,12 @@ package worker import ( "fmt" + "math" "net/url" "os" "path/filepath" + "runtime" + "strconv" "strings" "time" @@ -391,3 +394,49 @@ func (c *Config) Validate() error { return nil } + +// envInt reads an integer from environment variable +func envInt(name string) (int, bool) { + v := strings.TrimSpace(os.Getenv(name)) + if v == "" { + return 0, false + } + n, err := strconv.Atoi(v) + if err != nil { + return 0, false + } + return n, true +} + +// parseCPUFromConfig determines total CPU from environment or config +func parseCPUFromConfig(cfg *Config) int { + if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 { + return n + } + if cfg != nil { + if cfg.Resources.PodmanCPUs != "" { + if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil { + if f < 0 { + return 0 + } + return int(math.Floor(f)) + } + } + } + return runtime.NumCPU() +} + +// parseGPUCountFromConfig detects GPU count from config +func parseGPUCountFromConfig(cfg *Config) int { + factory := &GPUDetectorFactory{} + detector := factory.CreateDetector(cfg) + return detector.DetectGPUCount() +} + +// parseGPUSlotsPerGPUFromConfig reads GPU slots per GPU from environment +func parseGPUSlotsPerGPUFromConfig() int { + if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 { + return n + } + return 1 +} diff --git a/internal/worker/core.go b/internal/worker/core.go deleted file mode 100644 index 03c804c..0000000 --- a/internal/worker/core.go +++ /dev/null @@ -1,550 +0,0 @@ -package worker - -import ( - "context" - "fmt" - "log" - "log/slog" - "math" - "net/http" - "os" - "os/exec" - "path/filepath" - "runtime" - "strconv" - "strings" - "sync" - "time" - - "github.com/jfraeys/fetch_ml/internal/container" - "github.com/jfraeys/fetch_ml/internal/envpool" - "github.com/jfraeys/fetch_ml/internal/jupyter" - "github.com/jfraeys/fetch_ml/internal/logging" - "github.com/jfraeys/fetch_ml/internal/metrics" - "github.com/jfraeys/fetch_ml/internal/network" - "github.com/jfraeys/fetch_ml/internal/queue" - "github.com/jfraeys/fetch_ml/internal/resources" - "github.com/jfraeys/fetch_ml/internal/tracking" - "github.com/jfraeys/fetch_ml/internal/tracking/factory" - trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/collectors" - "github.com/prometheus/client_golang/prometheus/promhttp" -) - -// MLServer wraps network.SSHClient for backward compatibility. -type MLServer struct { - *network.SSHClient -} - -// JupyterManager is the subset of the Jupyter service manager used by the worker. -// It exists to keep task execution testable. -type JupyterManager interface { - StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) - StopService(ctx context.Context, serviceID string) error - RemoveService(ctx context.Context, serviceID string, purge bool) error - RestoreWorkspace(ctx context.Context, name string) (string, error) - ListServices() []*jupyter.JupyterService - ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) -} - -// isValidName validates that input strings contain only safe characters. -// isValidName checks if the input string is a valid name. -func isValidName(input string) bool { - return len(input) > 0 && len(input) < 256 -} - -// NewMLServer creates a new ML server connection. -// NewMLServer returns a new MLServer instance. -func NewMLServer(cfg *Config) (*MLServer, error) { - if cfg.LocalMode { - return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil - } - - client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts) - if err != nil { - return nil, err - } - - return &MLServer{SSHClient: client}, nil -} - -// Worker represents an ML task worker. -type Worker struct { - id string - config *Config - server *MLServer - queue queue.Backend - resources *resources.Manager - running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown - runningMu sync.RWMutex - ctx context.Context - cancel context.CancelFunc - logger *logging.Logger - metrics *metrics.Metrics - metricsSrv *http.Server - - datasetCache map[string]time.Time - datasetCacheMu sync.RWMutex - datasetCacheTTL time.Duration - - // Graceful shutdown fields - shutdownCh chan struct{} - activeTasks sync.Map // map[string]*queue.Task - track active tasks - gracefulWait sync.WaitGroup - - podman *container.PodmanManager - jupyter JupyterManager - trackingRegistry *tracking.Registry - envPool *envpool.Pool - - prewarmMu sync.Mutex - prewarmTargetID string - prewarmCancel context.CancelFunc - prewarmStartedAt time.Time -} - -func envInt(name string) (int, bool) { - v := strings.TrimSpace(os.Getenv(name)) - if v == "" { - return 0, false - } - n, err := strconv.Atoi(v) - if err != nil { - return 0, false - } - return n, true -} - -func parseCPUFromConfig(cfg *Config) int { - if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 { - return n - } - if cfg != nil { - if cfg.Resources.PodmanCPUs != "" { - if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil { - if f < 0 { - return 0 - } - return int(math.Floor(f)) - } - } - } - return runtime.NumCPU() -} - -func parseGPUCountFromConfig(cfg *Config) int { - factory := &GPUDetectorFactory{} - detector := factory.CreateDetector(cfg) - return detector.DetectGPUCount() -} - -func (w *Worker) getGPUDetector() GPUDetector { - factory := &GPUDetectorFactory{} - return factory.CreateDetector(w.config) -} - -func parseGPUSlotsPerGPUFromConfig() int { - if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 { - return n - } - return 1 -} - -func (w *Worker) setupMetricsExporter() { - if !w.config.Metrics.Enabled { - return - } - - reg := prometheus.NewRegistry() - reg.MustRegister( - collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}), - collectors.NewGoCollector(), - ) - - labels := prometheus.Labels{"worker_id": w.id} - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_tasks_processed_total", - Help: "Total tasks processed successfully by this worker.", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.TasksProcessed.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_tasks_failed_total", - Help: "Total tasks failed by this worker.", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.TasksFailed.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_tasks_active", - Help: "Number of tasks currently running on this worker.", - ConstLabels: labels, - }, func() float64 { - return float64(w.runningCount()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_tasks_queued", - Help: "Latest observed queue depth from Redis.", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.QueuedTasks.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_data_transferred_bytes_total", - Help: "Total bytes transferred while fetching datasets.", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.DataTransferred.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_data_fetch_time_seconds_total", - Help: "Total time spent fetching datasets (seconds).", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_execution_time_seconds_total", - Help: "Total execution time for completed tasks (seconds).", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_prewarm_env_hit_total", - Help: "Total environment prewarm hits (warmed image already existed).", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.PrewarmEnvHit.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_prewarm_env_miss_total", - Help: "Total environment prewarm misses (warmed image did not exist yet).", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.PrewarmEnvMiss.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_prewarm_env_built_total", - Help: "Total environment prewarm images built.", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.PrewarmEnvBuilt.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_prewarm_env_time_seconds_total", - Help: "Total time spent building prewarm images (seconds).", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.PrewarmEnvTime.Load()) / float64(time.Second) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_prewarm_snapshot_hit_total", - Help: "Total prewarmed snapshot hits (snapshots found in .prewarm/).", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.PrewarmSnapshotHit.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_prewarm_snapshot_miss_total", - Help: "Total prewarmed snapshot misses (snapshots not found in .prewarm/).", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.PrewarmSnapshotMiss.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_prewarm_snapshot_built_total", - Help: "Total snapshots prewarmed into .prewarm/.", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.PrewarmSnapshotBuilt.Load()) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_prewarm_snapshot_time_seconds_total", - Help: "Total time spent prewarming snapshots (seconds).", - ConstLabels: labels, - }, func() float64 { - return float64(w.metrics.PrewarmSnapshotTime.Load()) / float64(time.Second) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_worker_max_concurrency", - Help: "Configured maximum concurrent tasks for this worker.", - ConstLabels: labels, - }, func() float64 { - return float64(w.config.MaxWorkers) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_resources_cpu_total", - Help: "Total CPU tokens managed by the worker resource manager.", - ConstLabels: labels, - }, func() float64 { - return float64(w.resources.Snapshot().TotalCPU) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_resources_cpu_free", - Help: "Free CPU tokens currently available in the worker resource manager.", - ConstLabels: labels, - }, func() float64 { - return float64(w.resources.Snapshot().FreeCPU) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_resources_acquire_total", - Help: "Total resource acquisition attempts.", - ConstLabels: labels, - }, func() float64 { - return float64(w.resources.Snapshot().AcquireTotal) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_resources_acquire_wait_total", - Help: "Total resource acquisitions that had to wait for resources.", - ConstLabels: labels, - }, func() float64 { - return float64(w.resources.Snapshot().AcquireWaitTotal) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_resources_acquire_timeout_total", - Help: "Total resource acquisition attempts that timed out.", - ConstLabels: labels, - }, func() float64 { - return float64(w.resources.Snapshot().AcquireTimeoutTotal) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_resources_acquire_wait_seconds_total", - Help: "Total seconds spent waiting for resources across all acquisitions.", - ConstLabels: labels, - }, func() float64 { - return w.resources.Snapshot().AcquireWaitSeconds - })) - - snap := w.resources.Snapshot() - for i := range snap.GPUFree { - gpuLabels := prometheus.Labels{"worker_id": w.id, "gpu_index": strconv.Itoa(i)} - idx := i - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_resources_gpu_slots_total", - Help: "Total GPU slots per GPU index.", - ConstLabels: gpuLabels, - }, func() float64 { - return float64(w.resources.Snapshot().SlotsPerGPU) - })) - reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ - Name: "fetchml_resources_gpu_slots_free", - Help: "Free GPU slots per GPU index.", - ConstLabels: gpuLabels, - }, func() float64 { - s := w.resources.Snapshot() - if idx < 0 || idx >= len(s.GPUFree) { - return 0 - } - return float64(s.GPUFree[idx]) - })) - } - - mux := http.NewServeMux() - mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) - - srv := &http.Server{ - Addr: w.config.Metrics.ListenAddr, - Handler: mux, - ReadHeaderTimeout: 5 * time.Second, - } - - w.metricsSrv = srv - go func() { - w.logger.Info("metrics exporter listening", - "addr", w.config.Metrics.ListenAddr, - "enabled", w.config.Metrics.Enabled) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - w.logger.Warn("metrics exporter stopped", - "error", err) - } - }() -} - -// NewWorker creates a new worker instance. -func NewWorker(cfg *Config, _ string) (*Worker, error) { - srv, err := NewMLServer(cfg) - if err != nil { - return nil, err - } - defer func() { - if err != nil { - if closeErr := srv.Close(); closeErr != nil { - log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr) - } - } - }() - - backendCfg := queue.BackendConfig{ - Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))), - RedisAddr: cfg.RedisAddr, - RedisPassword: cfg.RedisPassword, - RedisDB: cfg.RedisDB, - SQLitePath: cfg.Queue.SQLitePath, - FilesystemPath: cfg.Queue.FilesystemPath, - FallbackToFilesystem: cfg.Queue.FallbackToFilesystem, - MetricsFlushInterval: cfg.MetricsFlushInterval, - } - queueClient, err := queue.NewBackend(backendCfg) - if err != nil { - return nil, err - } - defer func() { - if err != nil { - if closeErr := queueClient.Close(); closeErr != nil { - log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr) - } - } - }() - - // Create data_dir if it doesn't exist (for production without NAS) - if cfg.DataDir != "" { - if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil { - log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err) - } - } - - ctx, cancel := context.WithCancel(context.Background()) - defer func() { - if err != nil { - cancel() - } - }() - ctx = logging.EnsureTrace(ctx) - ctx = logging.CtxWithWorker(ctx, cfg.WorkerID) - - baseLogger := logging.NewLogger(slog.LevelInfo, false) - logger := baseLogger.Component(ctx, "worker") - metricsObj := &metrics.Metrics{} - - podmanMgr, err := container.NewPodmanManager(logger) - if err != nil { - return nil, fmt.Errorf("failed to create podman manager: %w", err) - } - - jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig()) - if err != nil { - return nil, fmt.Errorf("failed to create jupyter service manager: %w", err) - } - - trackingRegistry := tracking.NewRegistry(logger) - pluginLoader := factory.NewPluginLoader(logger, podmanMgr) - - if len(cfg.Plugins) == 0 { - logger.Warn("no plugins configured, defining defaults") - // Register defaults manually for backward compatibility/local dev - mlflowPlugin, err := trackingplugins.NewMLflowPlugin( - logger, - podmanMgr, - trackingplugins.MLflowOptions{ - ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"), - }, - ) - if err == nil { - trackingRegistry.Register(mlflowPlugin) - } - - tensorboardPlugin, err := trackingplugins.NewTensorBoardPlugin( - logger, - podmanMgr, - trackingplugins.TensorBoardOptions{ - LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"), - }, - ) - if err == nil { - trackingRegistry.Register(tensorboardPlugin) - } - - trackingRegistry.Register(trackingplugins.NewWandbPlugin()) - } else { - if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil { - return nil, fmt.Errorf("failed to load plugins: %w", err) - } - } - - worker := &Worker{ - id: cfg.WorkerID, - config: cfg, - server: srv, - queue: queueClient, - running: make(map[string]context.CancelFunc), - datasetCache: make(map[string]time.Time), - datasetCacheTTL: cfg.DatasetCacheTTL, - ctx: ctx, - cancel: cancel, - logger: logger, - metrics: metricsObj, - shutdownCh: make(chan struct{}), - podman: podmanMgr, - jupyter: jupyterMgr, - trackingRegistry: trackingRegistry, - envPool: envpool.New(""), - } - - rm, rmErr := resources.NewManager(resources.Options{ - TotalCPU: parseCPUFromConfig(cfg), - GPUCount: parseGPUCountFromConfig(cfg), - SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(), - }) - if rmErr != nil { - return nil, fmt.Errorf("failed to init resource manager: %w", rmErr) - } - worker.resources = rm - - if !cfg.LocalMode { - gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE"))) - if cfg.AppleGPU.Enabled { - logger.Warn("apple MPS GPU mode is intended for development; do not use in production", - "gpu_type", "apple", - ) - } - if gpuType == "amd" { - logger.Warn("amd GPU mode is intended for development; do not use in production", - "gpu_type", "amd", - ) - } - } - - worker.setupMetricsExporter() - - // Pre-pull tracking images in background - go worker.prePullImages() - - return worker, nil -} - -func (w *Worker) prePullImages() { - if w.config.LocalMode { - return - } - - w.logger.Info("starting image pre-pulling") - - // Pull worker image - if w.config.PodmanImage != "" { - w.pullImage(w.config.PodmanImage) - } - - // Pull plugin images - for name, cfg := range w.config.Plugins { - if !cfg.Enabled || cfg.Image == "" { - continue - } - w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image) - w.pullImage(cfg.Image) - } -} - -func (w *Worker) pullImage(image string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) - defer cancel() - - cmd := exec.CommandContext(ctx, "podman", "pull", image) - if output, err := cmd.CombinedOutput(); err != nil { - w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output)) - } else { - w.logger.Info("image pulled successfully", "image", image) - } -} diff --git a/internal/worker/factory.go b/internal/worker/factory.go new file mode 100644 index 0000000..8362a6f --- /dev/null +++ b/internal/worker/factory.go @@ -0,0 +1,212 @@ +package worker + +import ( + "context" + "fmt" + "log" + "log/slog" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/envpool" + "github.com/jfraeys/fetch_ml/internal/jupyter" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/metrics" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/resources" + "github.com/jfraeys/fetch_ml/internal/tracking" + "github.com/jfraeys/fetch_ml/internal/tracking/factory" + trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins" +) + +// NewWorker creates a new worker instance. +func NewWorker(cfg *Config, _ string) (*Worker, error) { + srv, err := NewMLServer(cfg) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + if closeErr := srv.Close(); closeErr != nil { + log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr) + } + } + }() + + backendCfg := queue.BackendConfig{ + Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))), + RedisAddr: cfg.RedisAddr, + RedisPassword: cfg.RedisPassword, + RedisDB: cfg.RedisDB, + SQLitePath: cfg.Queue.SQLitePath, + FilesystemPath: cfg.Queue.FilesystemPath, + FallbackToFilesystem: cfg.Queue.FallbackToFilesystem, + MetricsFlushInterval: cfg.MetricsFlushInterval, + } + queueClient, err := queue.NewBackend(backendCfg) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + if closeErr := queueClient.Close(); closeErr != nil { + log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr) + } + } + }() + + // Create data_dir if it doesn't exist (for production without NAS) + if cfg.DataDir != "" { + if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil { + log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err) + } + } + + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + if err != nil { + cancel() + } + }() + ctx = logging.EnsureTrace(ctx) + ctx = logging.CtxWithWorker(ctx, cfg.WorkerID) + + baseLogger := logging.NewLogger(slog.LevelInfo, false) + logger := baseLogger.Component(ctx, "worker") + metricsObj := &metrics.Metrics{} + + podmanMgr, err := container.NewPodmanManager(logger) + if err != nil { + return nil, fmt.Errorf("failed to create podman manager: %w", err) + } + + jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig()) + if err != nil { + return nil, fmt.Errorf("failed to create jupyter service manager: %w", err) + } + + trackingRegistry := tracking.NewRegistry(logger) + pluginLoader := factory.NewPluginLoader(logger, podmanMgr) + + if len(cfg.Plugins) == 0 { + logger.Warn("no plugins configured, defining defaults") + // Register defaults manually for backward compatibility/local dev + mlflowPlugin, err := trackingplugins.NewMLflowPlugin( + logger, + podmanMgr, + trackingplugins.MLflowOptions{ + ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"), + }, + ) + if err == nil { + trackingRegistry.Register(mlflowPlugin) + } + + tensorboardPlugin, err := trackingplugins.NewTensorBoardPlugin( + logger, + podmanMgr, + trackingplugins.TensorBoardOptions{ + LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"), + }, + ) + if err == nil { + trackingRegistry.Register(tensorboardPlugin) + } + + trackingRegistry.Register(trackingplugins.NewWandbPlugin()) + } else { + if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil { + return nil, fmt.Errorf("failed to load plugins: %w", err) + } + } + + worker := &Worker{ + id: cfg.WorkerID, + config: cfg, + server: srv, + queue: queueClient, + running: make(map[string]context.CancelFunc), + datasetCache: make(map[string]time.Time), + datasetCacheTTL: cfg.DatasetCacheTTL, + ctx: ctx, + cancel: cancel, + logger: logger, + metrics: metricsObj, + shutdownCh: make(chan struct{}), + podman: podmanMgr, + jupyter: jupyterMgr, + trackingRegistry: trackingRegistry, + envPool: envpool.New(""), + } + + rm, rmErr := resources.NewManager(resources.Options{ + TotalCPU: parseCPUFromConfig(cfg), + GPUCount: parseGPUCountFromConfig(cfg), + SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(), + }) + if rmErr != nil { + return nil, fmt.Errorf("failed to init resource manager: %w", rmErr) + } + worker.resources = rm + + if !cfg.LocalMode { + gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE"))) + if cfg.AppleGPU.Enabled { + logger.Warn("apple MPS GPU mode is intended for development; do not use in production", + "gpu_type", "apple", + ) + } + if gpuType == "amd" { + logger.Warn("amd GPU mode is intended for development; do not use in production", + "gpu_type", "amd", + ) + } + } + + worker.setupMetricsExporter() + + // Pre-pull tracking images in background + go worker.prePullImages() + + return worker, nil +} + +// prePullImages pulls required container images in the background +func (w *Worker) prePullImages() { + if w.config.LocalMode { + return + } + + w.logger.Info("starting image pre-pulling") + + // Pull worker image + if w.config.PodmanImage != "" { + w.pullImage(w.config.PodmanImage) + } + + // Pull plugin images + for name, cfg := range w.config.Plugins { + if !cfg.Enabled || cfg.Image == "" { + continue + } + w.logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image) + w.pullImage(cfg.Image) + } +} + +// pullImage pulls a single container image +func (w *Worker) pullImage(image string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + cmd := exec.CommandContext(ctx, "podman", "pull", image) + if output, err := cmd.CombinedOutput(); err != nil { + w.logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output)) + } else { + w.logger.Info("image pulled successfully", "image", image) + } +} diff --git a/internal/worker/metrics.go b/internal/worker/metrics.go new file mode 100644 index 0000000..12a595a --- /dev/null +++ b/internal/worker/metrics.go @@ -0,0 +1,224 @@ +package worker + +import ( + "net/http" + "strconv" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +// setupMetricsExporter initializes the Prometheus metrics exporter +func (w *Worker) setupMetricsExporter() { + if !w.config.Metrics.Enabled { + return + } + + reg := prometheus.NewRegistry() + reg.MustRegister( + collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}), + collectors.NewGoCollector(), + ) + + labels := prometheus.Labels{"worker_id": w.id} + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_tasks_processed_total", + Help: "Total tasks processed successfully by this worker.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.TasksProcessed.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_tasks_failed_total", + Help: "Total tasks failed by this worker.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.TasksFailed.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_tasks_active", + Help: "Number of tasks currently running on this worker.", + ConstLabels: labels, + }, func() float64 { + return float64(w.runningCount()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_tasks_queued", + Help: "Latest observed queue depth from Redis.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.QueuedTasks.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_data_transferred_bytes_total", + Help: "Total bytes transferred while fetching datasets.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.DataTransferred.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_data_fetch_time_seconds_total", + Help: "Total time spent fetching datasets (seconds).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.DataFetchTime.Load()) / float64(time.Second) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_execution_time_seconds_total", + Help: "Total execution time for completed tasks (seconds).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.ExecutionTime.Load()) / float64(time.Second) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_env_hit_total", + Help: "Total environment prewarm hits (warmed image already existed).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmEnvHit.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_env_miss_total", + Help: "Total environment prewarm misses (warmed image did not exist yet).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmEnvMiss.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_env_built_total", + Help: "Total environment prewarm images built.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmEnvBuilt.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_env_time_seconds_total", + Help: "Total time spent building prewarm images (seconds).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmEnvTime.Load()) / float64(time.Second) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_snapshot_hit_total", + Help: "Total prewarmed snapshot hits (snapshots found in .prewarm/).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmSnapshotHit.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_snapshot_miss_total", + Help: "Total prewarmed snapshot misses (snapshots not found in .prewarm/).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmSnapshotMiss.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_snapshot_built_total", + Help: "Total snapshots prewarmed into .prewarm/.", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmSnapshotBuilt.Load()) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_prewarm_snapshot_time_seconds_total", + Help: "Total time spent prewarming snapshots (seconds).", + ConstLabels: labels, + }, func() float64 { + return float64(w.metrics.PrewarmSnapshotTime.Load()) / float64(time.Second) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_worker_max_concurrency", + Help: "Configured maximum concurrent tasks for this worker.", + ConstLabels: labels, + }, func() float64 { + return float64(w.config.MaxWorkers) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_cpu_total", + Help: "Total CPU tokens managed by the worker resource manager.", + ConstLabels: labels, + }, func() float64 { + return float64(w.resources.Snapshot().TotalCPU) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_cpu_free", + Help: "Free CPU tokens currently available in the worker resource manager.", + ConstLabels: labels, + }, func() float64 { + return float64(w.resources.Snapshot().FreeCPU) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_acquire_total", + Help: "Total resource acquisition attempts.", + ConstLabels: labels, + }, func() float64 { + return float64(w.resources.Snapshot().AcquireTotal) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_acquire_wait_total", + Help: "Total resource acquisitions that had to wait for resources.", + ConstLabels: labels, + }, func() float64 { + return float64(w.resources.Snapshot().AcquireWaitTotal) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_acquire_timeout_total", + Help: "Total resource acquisition attempts that timed out.", + ConstLabels: labels, + }, func() float64 { + return float64(w.resources.Snapshot().AcquireTimeoutTotal) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_acquire_wait_seconds_total", + Help: "Total seconds spent waiting for resources across all acquisitions.", + ConstLabels: labels, + }, func() float64 { + return w.resources.Snapshot().AcquireWaitSeconds + })) + + snap := w.resources.Snapshot() + for i := range snap.GPUFree { + gpuLabels := prometheus.Labels{"worker_id": w.id, "gpu_index": strconv.Itoa(i)} + idx := i + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_gpu_slots_total", + Help: "Total GPU slots per GPU index.", + ConstLabels: gpuLabels, + }, func() float64 { + return float64(w.resources.Snapshot().SlotsPerGPU) + })) + reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + Name: "fetchml_resources_gpu_slots_free", + Help: "Free GPU slots per GPU index.", + ConstLabels: gpuLabels, + }, func() float64 { + s := w.resources.Snapshot() + if idx < 0 || idx >= len(s.GPUFree) { + return 0 + } + return float64(s.GPUFree[idx]) + })) + } + + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) + + srv := &http.Server{ + Addr: w.config.Metrics.ListenAddr, + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + } + + w.metricsSrv = srv + go func() { + w.logger.Info("metrics exporter listening", + "addr", w.config.Metrics.ListenAddr, + "enabled", w.config.Metrics.Enabled) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + w.logger.Warn("metrics exporter stopped", + "error", err) + } + }() +} diff --git a/internal/worker/worker.go b/internal/worker/worker.go new file mode 100644 index 0000000..e30a04d --- /dev/null +++ b/internal/worker/worker.go @@ -0,0 +1,95 @@ +package worker + +import ( + "context" + "net/http" + "sync" + "time" + + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/envpool" + "github.com/jfraeys/fetch_ml/internal/jupyter" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/metrics" + "github.com/jfraeys/fetch_ml/internal/network" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/resources" + "github.com/jfraeys/fetch_ml/internal/tracking" +) + +// MLServer wraps network.SSHClient for backward compatibility. +type MLServer struct { + *network.SSHClient +} + +// JupyterManager is the subset of the Jupyter service manager used by the worker. +// It exists to keep task execution testable. +type JupyterManager interface { + StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) + StopService(ctx context.Context, serviceID string) error + RemoveService(ctx context.Context, serviceID string, purge bool) error + RestoreWorkspace(ctx context.Context, name string) (string, error) + ListServices() []*jupyter.JupyterService + ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) +} + +// isValidName validates that input strings contain only safe characters. +// isValidName checks if the input string is a valid name. +func isValidName(input string) bool { + return len(input) > 0 && len(input) < 256 +} + +// NewMLServer creates a new ML server connection. +// NewMLServer returns a new MLServer instance. +func NewMLServer(cfg *Config) (*MLServer, error) { + if cfg.LocalMode { + return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil + } + + client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts) + if err != nil { + return nil, err + } + + return &MLServer{SSHClient: client}, nil +} + +// Worker represents an ML task worker. +type Worker struct { + id string + config *Config + server *MLServer + queue queue.Backend + resources *resources.Manager + running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown + runningMu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + logger *logging.Logger + metrics *metrics.Metrics + metricsSrv *http.Server + + datasetCache map[string]time.Time + datasetCacheMu sync.RWMutex + datasetCacheTTL time.Duration + + // Graceful shutdown fields + shutdownCh chan struct{} + activeTasks sync.Map // map[string]*queue.Task - track active tasks + gracefulWait sync.WaitGroup + + podman *container.PodmanManager + jupyter JupyterManager + trackingRegistry *tracking.Registry + envPool *envpool.Pool + + prewarmMu sync.Mutex + prewarmTargetID string + prewarmCancel context.CancelFunc + prewarmStartedAt time.Time +} + +func (w *Worker) getGPUDetector() GPUDetector { + factory := &GPUDetectorFactory{} + return factory.CreateDetector(w.config) +}