// Package executor provides job execution implementations package executor import ( "context" "fmt" "os" "path/filepath" "strings" "time" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/container" "github.com/jfraeys/fetch_ml/internal/errtypes" "github.com/jfraeys/fetch_ml/internal/fileutil" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/manifest" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/telemetry" "github.com/jfraeys/fetch_ml/internal/tracking" "github.com/jfraeys/fetch_ml/internal/worker/interfaces" ) // ContainerConfig holds configuration for container execution type ContainerConfig struct { Sandbox SandboxConfig PodmanImage string ContainerResults string ContainerWorkspace string TrainScript string BasePath string AppleGPUEnabled bool } // SandboxConfig interface to avoid import cycle type SandboxConfig interface { GetNoNewPrivileges() bool GetDropAllCaps() bool GetAllowedCaps() []string GetUserNS() bool GetRunAsUID() int GetRunAsGID() int GetSeccompProfile() string GetReadOnlyRoot() bool GetNetworkMode() string GetMaxProcesses() int GetMaxOpenFiles() int GetDisableSwap() bool GetOOMScoreAdj() int GetTaskUID() int GetTaskGID() int } // ContainerExecutor executes jobs in containers using podman type ContainerExecutor struct { logger *logging.Logger writer interfaces.ManifestWriter registry *tracking.Registry envPool EnvironmentPool config ContainerConfig } // EnvironmentPool interface for environment image pooling type EnvironmentPool interface { WarmImageTag(depsSHA string) (string, error) ImageExists(ctx context.Context, tag string) (bool, error) } // NewContainerExecutor creates a new container job executor func NewContainerExecutor( logger *logging.Logger, writer interfaces.ManifestWriter, cfg ContainerConfig, ) *ContainerExecutor { return &ContainerExecutor{ logger: logger, writer: writer, config: cfg, } } // SetRegistry sets the tracking registry (optional) func (e *ContainerExecutor) SetRegistry(registry *tracking.Registry) { e.registry = registry } // SetEnvPool sets the environment pool (optional) func (e *ContainerExecutor) SetEnvPool(pool EnvironmentPool) { e.envPool = pool } // Execute runs a job in a container func (e *ContainerExecutor) Execute(ctx context.Context, task *queue.Task, env interfaces.ExecutionEnv) error { containerResults := e.config.ContainerResults if containerResults == "" { containerResults = config.DefaultContainerResults } containerWorkspace := e.config.ContainerWorkspace if containerWorkspace == "" { containerWorkspace = config.DefaultContainerWorkspace } jobPaths := storage.NewJobPaths(e.config.BasePath) // Setup tracking environment trackingEnv, err := e.setupTrackingEnv(ctx, task) if err != nil { return err } defer e.teardownTracking(ctx, task) // Setup volumes volumes := e.setupVolumes(trackingEnv, env.OutputDir) // Setup container environment if strings.TrimSpace(env.GPUEnvVar) != "" { trackingEnv[env.GPUEnvVar] = strings.TrimSpace(env.GPUDevicesStr) } snap := filepath.Join(env.OutputDir, "snapshot") if info, err := os.Stat(snap); err == nil && info.IsDir() { trackingEnv["FETCH_ML_SNAPSHOT_DIR"] = "/snapshot" if strings.TrimSpace(task.SnapshotID) != "" { trackingEnv["FETCH_ML_SNAPSHOT_ID"] = strings.TrimSpace(task.SnapshotID) } volumes[snap] = "/snapshot:ro" } cpusOverride, memOverride := container.PodmanResourceOverrides(task.CPU, task.MemoryGB) // Select image (with warm cache check) selectedImage := e.selectImage(ctx, task) // Build podman config podmanCfg := container.PodmanConfig{ Image: selectedImage, Workspace: filepath.Join(env.OutputDir, "code"), Results: filepath.Join(env.OutputDir, "results"), ContainerWorkspace: containerWorkspace, ContainerResults: containerResults, AppleGPU: e.config.AppleGPUEnabled, GPUDevices: env.GPUDevices, Env: trackingEnv, Volumes: volumes, Memory: memOverride, CPUs: cpusOverride, } // Build and execute command return e.runPodman(ctx, task, env, jobPaths, podmanCfg, selectedImage) } func (e *ContainerExecutor) setupTrackingEnv(ctx context.Context, task *queue.Task) (map[string]string, error) { if e.registry == nil || task.Tracking == nil { return make(map[string]string), nil } configs := make(map[string]tracking.ToolConfig) if task.Tracking.MLflow != nil && task.Tracking.MLflow.Enabled { mode := tracking.ModeSidecar if task.Tracking.MLflow.Mode != "" { mode = tracking.ToolMode(task.Tracking.MLflow.Mode) } configs["mlflow"] = tracking.ToolConfig{ Enabled: true, Mode: mode, Settings: map[string]any{ "job_name": task.JobName, "tracking_uri": task.Tracking.MLflow.TrackingURI, }, } } if task.Tracking.TensorBoard != nil && task.Tracking.TensorBoard.Enabled { mode := tracking.ModeSidecar if task.Tracking.TensorBoard.Mode != "" { mode = tracking.ToolMode(task.Tracking.TensorBoard.Mode) } configs["tensorboard"] = tracking.ToolConfig{ Enabled: true, Mode: mode, Settings: map[string]any{ "job_name": task.JobName, }, } } if task.Tracking.Wandb != nil && task.Tracking.Wandb.Enabled { mode := tracking.ModeRemote if task.Tracking.Wandb.Mode != "" { mode = tracking.ToolMode(task.Tracking.Wandb.Mode) } configs["wandb"] = tracking.ToolConfig{ Enabled: true, Mode: mode, Settings: map[string]any{ "api_key": task.Tracking.Wandb.APIKey, "project": task.Tracking.Wandb.Project, "entity": task.Tracking.Wandb.Entity, }, } } if len(configs) == 0 { return make(map[string]string), nil } env, err := e.registry.ProvisionAll(ctx, task.ID, configs) if err != nil { return nil, &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "tracking_provision", Err: err, } } return env, nil } func (e *ContainerExecutor) teardownTracking(ctx context.Context, task *queue.Task) { if e.registry != nil && task.Tracking != nil { e.registry.TeardownAll(ctx, task.ID) } } func (e *ContainerExecutor) setupVolumes(trackingEnv map[string]string, _outputDir string) map[string]string { _ = _outputDir volumes := make(map[string]string) if val, ok := trackingEnv["TENSORBOARD_HOST_LOG_DIR"]; ok { containerPath := "/tracking/tensorboard" volumes[val] = containerPath + ":rw" trackingEnv["TENSORBOARD_LOG_DIR"] = containerPath delete(trackingEnv, "TENSORBOARD_HOST_LOG_DIR") } cacheRoot := filepath.Join(e.config.BasePath, ".cache") os.MkdirAll(cacheRoot, 0750) volumes[cacheRoot] = "/workspace/.cache:rw" defaultEnv := map[string]string{ "HF_HOME": "/workspace/.cache/huggingface", "TRANSFORMERS_CACHE": "/workspace/.cache/huggingface/hub", "HF_DATASETS_CACHE": "/workspace/.cache/huggingface/datasets", "TORCH_HOME": "/workspace/.cache/torch", "TORCH_HUB_DIR": "/workspace/.cache/torch/hub", "KERAS_HOME": "/workspace/.cache/keras", "CUDA_CACHE_PATH": "/workspace/.cache/cuda", "PIP_CACHE_DIR": "/workspace/.cache/pip", } for k, v := range defaultEnv { if _, ok := trackingEnv[k]; !ok { trackingEnv[k] = v } } return volumes } func (e *ContainerExecutor) selectImage(ctx context.Context, task *queue.Task) string { selectedImage := e.config.PodmanImage if e.envPool == nil || task.Metadata == nil { return selectedImage } depsSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"]) if depsSHA == "" { return selectedImage } warmTag, err := e.envPool.WarmImageTag(depsSHA) if err != nil { return selectedImage } inspectCtx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() exists, err := e.envPool.ImageExists(inspectCtx, warmTag) if err == nil && exists { return warmTag } return selectedImage } func (e *ContainerExecutor) runPodman( ctx context.Context, task *queue.Task, env interfaces.ExecutionEnv, jobPaths *storage.JobPaths, podmanCfg container.PodmanConfig, selectedImage string, ) error { scriptPath := filepath.Join(podmanCfg.ContainerWorkspace, e.config.TrainScript) manifestName, err := SelectDependencyManifest(filepath.Join(env.OutputDir, "code")) if err != nil { return &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "validation", Message: "dependency manifest selection failed", Err: err, Context: map[string]string{"image": selectedImage, "output_dir": env.OutputDir}, Timestamp: time.Now().UTC(), Recoverable: false, } } depsPath := filepath.Join(podmanCfg.ContainerWorkspace, manifestName) var extraArgs []string if task.Args != "" { extraArgs = strings.Fields(task.Args) } // Open log file logFileHandle, err := fileutil.SecureOpenFile(env.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) if err != nil { e.logger.Warn("failed to open log file for podman output", "path", env.LogFile, "error", err) } // Convert SandboxConfig to PodmanSecurityConfig securityConfig := container.PodmanSecurityConfig{ NoNewPrivileges: e.config.Sandbox.GetNoNewPrivileges(), DropAllCaps: e.config.Sandbox.GetDropAllCaps(), AllowedCaps: e.config.Sandbox.GetAllowedCaps(), UserNS: e.config.Sandbox.GetUserNS(), RunAsUID: e.config.Sandbox.GetRunAsUID(), RunAsGID: e.config.Sandbox.GetRunAsGID(), SeccompProfile: e.config.Sandbox.GetSeccompProfile(), ReadOnlyRoot: e.config.Sandbox.GetReadOnlyRoot(), NetworkMode: e.config.Sandbox.GetNetworkMode(), MaxProcesses: e.config.Sandbox.GetMaxProcesses(), MaxOpenFiles: e.config.Sandbox.GetMaxOpenFiles(), DisableSwap: e.config.Sandbox.GetDisableSwap(), OOMScoreAdj: e.config.Sandbox.GetOOMScoreAdj(), TaskUID: e.config.Sandbox.GetTaskUID(), TaskGID: e.config.Sandbox.GetTaskGID(), } podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, securityConfig, scriptPath, depsPath, extraArgs) // Update manifest if e.writer != nil { e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) { m.PodmanImage = strings.TrimSpace(selectedImage) m.Command = podmanCmd.Path if len(podmanCmd.Args) > 1 { m.Args = strings.Join(podmanCmd.Args[1:], " ") } else { m.Args = "" } }) } if logFileHandle != nil { podmanCmd.Stdout = logFileHandle podmanCmd.Stderr = logFileHandle defer logFileHandle.Close() } e.logger.Info("executing podman job", "job", task.JobName, "image", selectedImage, "workspace", podmanCfg.Workspace, "results", podmanCfg.Results) // Execute containerStart := time.Now() if err := podmanCmd.Run(); err != nil { containerDuration := time.Since(containerStart) return e.handleFailure(task, env, jobPaths, err, containerDuration) } containerDuration := time.Since(containerStart) return e.handleSuccess(task, env, jobPaths, containerDuration) } func (e *ContainerExecutor) handleFailure( task *queue.Task, env interfaces.ExecutionEnv, jobPaths *storage.JobPaths, runErr error, duration time.Duration, ) error { if e.writer != nil { e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) { now := time.Now().UTC() exitCode := 1 m.ExecutionDurationMS = duration.Milliseconds() m.MarkFinished(now, &exitCode, runErr) }) } failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName) os.MkdirAll(filepath.Dir(failedDir), 0750) os.RemoveAll(failedDir) telemetry.ExecWithMetrics( e.logger, "move failed job", 100*time.Millisecond, func() (string, error) { if err := os.Rename(env.OutputDir, failedDir); err != nil { return "", fmt.Errorf("rename to failed failed: %w", err) } return "", nil }) // Return enriched error with context return &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, Phase: "execution", Message: "container execution failed", Err: runErr, Context: map[string]string{"duration_ms": fmt.Sprintf("%d", duration.Milliseconds())}, Timestamp: time.Now().UTC(), Recoverable: true, // Container failures may be retryable } } func (e *ContainerExecutor) handleSuccess( task *queue.Task, env interfaces.ExecutionEnv, jobPaths *storage.JobPaths, duration time.Duration, ) error { if e.writer != nil { e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) { m.ExecutionDurationMS = duration.Milliseconds() }) } finalizeStart := time.Now() finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName) if e.writer != nil { e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) { now := time.Now().UTC() exitCode := 0 m.FinalizeDurationMS = time.Since(finalizeStart).Milliseconds() m.MarkFinished(now, &exitCode, nil) }) } os.MkdirAll(filepath.Dir(finishedDir), 0750) os.RemoveAll(finishedDir) telemetry.ExecWithMetrics( e.logger, "finalize job", 100*time.Millisecond, func() (string, error) { if err := os.Rename(env.OutputDir, finishedDir); err != nil { return "", fmt.Errorf("rename to finished failed: %w", err) } return "", nil }) return nil } func SelectDependencyManifest(filesPath string) (string, error) { if filesPath == "" { return "", fmt.Errorf("missing files path") } candidates := []string{ "environment.yml", "environment.yaml", "poetry.lock", "pyproject.toml", "requirements.txt", } for _, name := range candidates { p := filepath.Join(filesPath, name) if _, err := os.Stat(p); err == nil { if name == "poetry.lock" { pyprojectPath := filepath.Join(filesPath, "pyproject.toml") if _, err := os.Stat(pyprojectPath); err != nil { return "", fmt.Errorf( "poetry.lock found but pyproject.toml missing (required for Poetry projects)") } } return name, nil } } return "", fmt.Errorf( "missing dependency manifest (supported: environment.yml, environment.yaml, " + "poetry.lock, pyproject.toml, requirements.txt)") }