package worker

import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/jupyter"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/metrics"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/resources"
	"github.com/jfraeys/fetch_ml/internal/tracking"
	"github.com/jfraeys/fetch_ml/internal/tracking/factory"
	trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"
	"github.com/jfraeys/fetch_ml/internal/worker/executor"
	"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
)

// NewWorker creates a new worker instance with composed dependencies.
//
// It builds the queue backend, the worker logger, the podman and Jupyter
// managers, the tracking plugin registry (configured plugins, or a default
// MLflow/TensorBoard/W&B set when none are configured), the local and
// container executors, the run loop, and the resource manager, and wires them
// into a Worker. The second (unnamed) argument is ignored.
//
// Outside local mode, FETCH_ML_GPU_TYPE=amd is rejected with an error before
// any dependency is created. On success, a background goroutine is started to
// pre-pull container images.
func NewWorker(cfg *Config, _ string) (*Worker, error) {
	// Reject unsupported GPU modes before acquiring any resources (queue
	// connection, podman manager, ...), so a misconfigured production worker
	// fails fast instead of leaving half-built dependencies behind.
	if !cfg.LocalMode {
		gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
		if gpuType == "amd" {
			return nil, fmt.Errorf(
				"AMD GPU mode is not supported in production (FETCH_ML_GPU_TYPE=amd). " +
					"Use 'nvidia', 'apple', 'none', or GPUDevices config. " +
					"AMD support is available in local mode for experimental development",
			)
		}
	}

	// Create queue backend
	backendCfg := queue.BackendConfig{
		Mode:                 cfg.Mode,
		Backend:              queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
		RedisAddr:            cfg.RedisAddr,
		RedisPassword:        cfg.RedisPassword,
		RedisDB:              cfg.RedisDB,
		SQLitePath:           cfg.Queue.SQLitePath,
		FilesystemPath:       cfg.Queue.FilesystemPath,
		FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
		MetricsFlushInterval: cfg.MetricsFlushInterval,
		Scheduler: queue.SchedulerConfig{
			Address: cfg.Scheduler.Address,
			Cert:    cfg.Scheduler.Cert,
			Token:   cfg.Scheduler.Token,
		},
	}
	queueClient, err := queue.NewBackend(backendCfg)
	if err != nil {
		return nil, err
	}
	// NOTE(review): queueClient is not released on the error paths below;
	// consider closing it there if queue.Backend exposes a Close method.

	// This context only scopes logger construction; it is cancelled when
	// NewWorker returns, on both the success and error paths.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	ctx = logging.EnsureTrace(ctx)
	ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)

	baseLogger := logging.NewLogger(slog.LevelInfo, false)
	logger := baseLogger.Component(ctx, "worker")
	metricsObj := &metrics.Metrics{}

	podmanMgr, err := container.NewPodmanManager(logger)
	if err != nil {
		return nil, fmt.Errorf("failed to create podman manager: %w", err)
	}

	jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
	if err != nil {
		return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
	}

	trackingRegistry := tracking.NewRegistry(logger)
	pluginLoader := factory.NewPluginLoader(logger, podmanMgr)

	if len(cfg.Plugins) == 0 {
		logger.Warn("no plugins configured, defining defaults")

		// Default plugins are best-effort: one that fails to construct is
		// skipped, but the failure is logged instead of being silently dropped.
		mlflowPlugin, mlflowErr := trackingplugins.NewMLflowPlugin(
			logger,
			podmanMgr,
			trackingplugins.MLflowOptions{
				ArtifactBasePath: filepath.Join(cfg.BasePath, "tracking", "mlflow"),
			},
		)
		if mlflowErr != nil {
			logger.Warn("failed to create default tracking plugin", "plugin", "mlflow", "error", mlflowErr)
		} else {
			trackingRegistry.Register(mlflowPlugin)
		}

		tensorboardPlugin, tensorboardErr := trackingplugins.NewTensorBoardPlugin(
			logger,
			podmanMgr,
			trackingplugins.TensorBoardOptions{
				LogBasePath: filepath.Join(cfg.BasePath, "tracking", "tensorboard"),
			},
		)
		if tensorboardErr != nil {
			logger.Warn("failed to create default tracking plugin", "plugin", "tensorboard", "error", tensorboardErr)
		} else {
			trackingRegistry.Register(tensorboardPlugin)
		}

		trackingRegistry.Register(trackingplugins.NewWandbPlugin())
	} else if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil {
		return nil, fmt.Errorf("failed to load plugins: %w", err)
	}

	// Create run loop configuration
	runLoopConfig := lifecycle.RunLoopConfig{
		WorkerID:          cfg.WorkerID,
		MaxWorkers:        cfg.MaxWorkers,
		PollInterval:      time.Duration(cfg.PollInterval) * time.Second,
		TaskLeaseDuration: cfg.TaskLeaseDuration,
		HeartbeatInterval: cfg.HeartbeatInterval,
		GracefulTimeout:   cfg.GracefulTimeout,
		PrewarmEnabled:    cfg.PrewarmEnabled,
	}

	// Create executors
	localExecutor := executor.NewLocalExecutor(logger, nil)
	containerExecutor := executor.NewContainerExecutor(
		logger,
		nil,
		executor.ContainerConfig{
			PodmanImage: cfg.PodmanImage,
			BasePath:    cfg.BasePath,
		},
	)

	// Create task executor adapter. Named taskExecutor (not exec) so it does
	// not shadow the os/exec package imported by this file.
	taskExecutor := executor.NewTaskExecutorAdapter(
		localExecutor,
		containerExecutor,
		cfg.LocalMode,
	)

	// Create state manager for task lifecycle management
	stateMgr := lifecycle.NewStateManager(nil) // Can pass audit logger if available

	// Create job runner
	jobRunner := executor.NewJobRunner(
		localExecutor,
		containerExecutor,
		nil, // ManifestWriter - can be added later if needed
		logger,
	)

	runLoop := lifecycle.NewRunLoop(
		runLoopConfig,
		queueClient,
		taskExecutor,
		metricsObj,
		logger,
		stateMgr,
	)

	// Create resource manager
	gpuCount, gpuDetectionInfo := parseGPUCountFromConfig(cfg)
	rm, err := resources.NewManager(resources.Options{
		TotalCPU:    parseCPUFromConfig(cfg),
		GPUCount:    gpuCount,
		SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to init resource manager: %w", err)
	}

	worker := &Worker{
		ID:               cfg.WorkerID,
		Config:           cfg,
		Logger:           logger,
		RunLoop:          runLoop,
		Runner:           jobRunner,
		Metrics:          metricsObj,
		Health:           lifecycle.NewHealthMonitor(),
		Resources:        rm,
		Jupyter:          jupyterMgr,
		gpuDetectionInfo: gpuDetectionInfo,
	}

	// In distributed mode, store the scheduler connection for heartbeats
	if cfg.Mode == "distributed" {
		if schedBackend, ok := queueClient.(*queue.SchedulerBackend); ok {
			worker.schedulerConn = schedBackend.Conn()
		}
	}

	// Apple MPS is allowed outside local mode but flagged as development-only.
	// (AMD outside local mode was already rejected at the top.)
	if !cfg.LocalMode && cfg.AppleGPU.Enabled {
		logger.Warn(
			"apple MPS GPU mode is intended for development; do not use in production",
			"gpu_type", "apple",
		)
	}

	// Pre-pull tracking images in background.
	// NOTE(review): this goroutine is not tied to the worker's lifecycle; each
	// pull is bounded only by pullImage's own timeout.
	go worker.prePullImages()

	return worker, nil
}

// prePullImages pulls the configured worker image and the image of every
// enabled plugin that declares one. It does nothing in local mode. Pull
// failures are logged by pullImage and do not stop the remaining pulls.
func (w *Worker) prePullImages() {
	if w.Config.LocalMode {
		return
	}

	w.Logger.Info("starting image pre-pulling")

	// Pull worker image
	if w.Config.PodmanImage != "" {
		w.pullImage(w.Config.PodmanImage)
	}

	// Pull plugin images (map iteration order, so pull order is unspecified)
	for name, cfg := range w.Config.Plugins {
		if !cfg.Enabled || cfg.Image == "" {
			continue
		}
		w.Logger.Info("pre-pulling plugin image", "plugin", name, "image", cfg.Image)
		w.pullImage(cfg.Image)
	}
}

// pullImage runs `podman pull <image>` with a 10-minute timeout and logs the
// outcome; on failure the combined stdout/stderr is included in the warning.
func (w *Worker) pullImage(image string) {
	const pullTimeout = 10 * time.Minute

	ctx, cancel := context.WithTimeout(context.Background(), pullTimeout)
	defer cancel()

	cmd := exec.CommandContext(ctx, "podman", "pull", image)
	if output, err := cmd.CombinedOutput(); err != nil {
		w.Logger.Warn("failed to pull image", "image", image, "error", err, "output", string(output))
	} else {
		w.Logger.Info("image pulled successfully", "image", image)
	}
}