fetch_ml/internal/worker/executor/container.go
Jeremie Fraeys 3fb6902fa1
feat(worker): integrate scheduler endpoints and security hardening
Update worker system for scheduler integration:
- Worker server with scheduler registration
- Configuration with scheduler endpoint support
- Artifact handling with integrity verification
- Container executor with supply chain validation
- Local executor enhancements
- GPU detection improvements (cross-platform)
- Error handling with execution context
- Factory pattern for executor instantiation
- Hash integrity with native library support
2026-02-26 12:06:16 -05:00

498 lines
14 KiB
Go

// Package executor provides job execution implementations
package executor
import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/errtypes"
	"github.com/jfraeys/fetch_ml/internal/fileutil"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/manifest"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"
	"github.com/jfraeys/fetch_ml/internal/telemetry"
	"github.com/jfraeys/fetch_ml/internal/tracking"
	"github.com/jfraeys/fetch_ml/internal/worker/interfaces"
)
// ContainerConfig holds configuration for container execution.
type ContainerConfig struct {
	Sandbox            SandboxConfig // sandbox/security settings translated into PodmanSecurityConfig at run time
	PodmanImage        string        // default image; may be replaced by a warm-cache image in selectImage
	ContainerResults   string        // in-container results path; empty falls back to config.DefaultContainerResults
	ContainerWorkspace string        // in-container workspace path; empty falls back to config.DefaultContainerWorkspace
	TrainScript        string        // training script name, joined onto ContainerWorkspace inside the container
	BasePath           string        // host base path used for job storage paths and the shared .cache dir
	AppleGPUEnabled    bool          // forwarded to PodmanConfig.AppleGPU
}
// SandboxConfig exposes sandbox security settings as getters. It is an
// interface (rather than a concrete config struct) to avoid an import cycle
// with the package that defines the configuration. Every value is copied
// verbatim into container.PodmanSecurityConfig in runPodman.
type SandboxConfig interface {
	GetNoNewPrivileges() bool   // disallow privilege escalation in the container
	GetDropAllCaps() bool       // drop all Linux capabilities by default
	GetAllowedCaps() []string   // capabilities to re-add after dropping
	GetUserNS() bool            // run inside a user namespace
	GetRunAsUID() int           // UID the container process runs as
	GetRunAsGID() int           // GID the container process runs as
	GetSeccompProfile() string  // path/name of the seccomp profile
	GetReadOnlyRoot() bool      // mount the root filesystem read-only
	GetNetworkMode() string     // podman network mode (e.g. none/host/bridge)
	GetMaxProcesses() int       // process-count limit
	GetMaxOpenFiles() int       // open-file-descriptor limit
	GetDisableSwap() bool       // disable swap for the container
	GetOOMScoreAdj() int        // OOM score adjustment
	GetTaskUID() int            // UID owning task files on the host
	GetTaskGID() int            // GID owning task files on the host
}
// ContainerExecutor executes jobs in containers using podman.
type ContainerExecutor struct {
	logger   *logging.Logger           // structured logger for execution events
	writer   interfaces.ManifestWriter // optional run-manifest writer; nil-checked before every use
	registry *tracking.Registry        // optional tracking registry, attached via SetRegistry
	envPool  EnvironmentPool           // optional warm-image pool, attached via SetEnvPool
	config   ContainerConfig           // static configuration supplied at construction
}
// EnvironmentPool is the interface for environment image pooling. It lets the
// executor swap the default image for a pre-built ("warm") one keyed by the
// job's dependency-manifest hash.
type EnvironmentPool interface {
	// WarmImageTag maps a dependency-manifest SHA-256 to a warm image tag.
	WarmImageTag(depsSHA string) (string, error)
	// ImageExists reports whether the given tag is present locally.
	ImageExists(ctx context.Context, tag string) (bool, error)
}
// NewContainerExecutor constructs a container-based job executor with the
// given logger, manifest writer, and configuration. The tracking registry and
// environment pool are optional collaborators; attach them afterwards with
// SetRegistry and SetEnvPool if needed.
func NewContainerExecutor(
	logger *logging.Logger,
	writer interfaces.ManifestWriter,
	cfg ContainerConfig,
) *ContainerExecutor {
	return &ContainerExecutor{
		config: cfg,
		logger: logger,
		writer: writer,
	}
}
// SetRegistry sets the tracking registry (optional). When nil, tracking
// provisioning and teardown become no-ops.
func (e *ContainerExecutor) SetRegistry(registry *tracking.Registry) {
	e.registry = registry
}
// SetEnvPool sets the environment pool (optional). When nil, selectImage
// always returns the configured default image.
func (e *ContainerExecutor) SetEnvPool(pool EnvironmentPool) {
	e.envPool = pool
}
// Execute runs a job in a container. It provisions tracking tools, assembles
// volumes and environment variables (GPU device selection, optional snapshot
// mount, cache dirs), picks the image (preferring a warm-cache one), and then
// hands off to runPodman. Tracking is torn down when execution finishes.
func (e *ContainerExecutor) Execute(ctx context.Context, task *queue.Task, env interfaces.ExecutionEnv) error {
	// Resolve mount points, falling back to package defaults when unset.
	resultsMount := e.config.ContainerResults
	if resultsMount == "" {
		resultsMount = config.DefaultContainerResults
	}
	workspaceMount := e.config.ContainerWorkspace
	if workspaceMount == "" {
		workspaceMount = config.DefaultContainerWorkspace
	}

	jobPaths := storage.NewJobPaths(e.config.BasePath)

	// Provision tracking sidecars/remotes; their env vars seed the container env.
	trackingEnv, err := e.setupTrackingEnv(ctx, task)
	if err != nil {
		return err
	}
	defer e.teardownTracking(ctx, task)

	volumes := e.setupVolumes(trackingEnv, env.OutputDir)

	// Expose GPU device selection through the platform-specific env var.
	if strings.TrimSpace(env.GPUEnvVar) != "" {
		trackingEnv[env.GPUEnvVar] = strings.TrimSpace(env.GPUDevicesStr)
	}

	// Mount a snapshot directory read-only when one was staged for this job.
	snapshotDir := filepath.Join(env.OutputDir, "snapshot")
	if info, statErr := os.Stat(snapshotDir); statErr == nil && info.IsDir() {
		trackingEnv["FETCH_ML_SNAPSHOT_DIR"] = "/snapshot"
		if id := strings.TrimSpace(task.SnapshotID); id != "" {
			trackingEnv["FETCH_ML_SNAPSHOT_ID"] = id
		}
		volumes[snapshotDir] = "/snapshot:ro"
	}

	cpusOverride, memOverride := container.PodmanResourceOverrides(task.CPU, task.MemoryGB)

	// Prefer a warm-cache image keyed by the dependency manifest hash.
	selectedImage := e.selectImage(ctx, task)

	podmanCfg := container.PodmanConfig{
		Image:              selectedImage,
		Workspace:          filepath.Join(env.OutputDir, "code"),
		Results:            filepath.Join(env.OutputDir, "results"),
		ContainerWorkspace: workspaceMount,
		ContainerResults:   resultsMount,
		AppleGPU:           e.config.AppleGPUEnabled,
		GPUDevices:         env.GPUDevices,
		Env:                trackingEnv,
		Volumes:            volumes,
		Memory:             memOverride,
		CPUs:               cpusOverride,
	}

	return e.runPodman(ctx, task, env, jobPaths, podmanCfg, selectedImage)
}
// setupTrackingEnv provisions the tracking tools (MLflow, TensorBoard, W&B)
// requested by the task and returns the environment variables they export.
// Returns an empty map when no registry is attached or no tool is enabled,
// and a TaskExecutionError (phase "tracking_provision") on provisioning failure.
func (e *ContainerExecutor) setupTrackingEnv(ctx context.Context, task *queue.Task) (map[string]string, error) {
	if e.registry == nil || task.Tracking == nil {
		return map[string]string{}, nil
	}

	// resolveMode applies a per-tool default when the task leaves the mode blank.
	resolveMode := func(raw string, fallback tracking.ToolMode) tracking.ToolMode {
		if raw == "" {
			return fallback
		}
		return tracking.ToolMode(raw)
	}

	toolConfigs := map[string]tracking.ToolConfig{}

	if mlf := task.Tracking.MLflow; mlf != nil && mlf.Enabled {
		toolConfigs["mlflow"] = tracking.ToolConfig{
			Enabled: true,
			Mode:    resolveMode(mlf.Mode, tracking.ModeSidecar),
			Settings: map[string]any{
				"job_name":     task.JobName,
				"tracking_uri": mlf.TrackingURI,
			},
		}
	}

	if tb := task.Tracking.TensorBoard; tb != nil && tb.Enabled {
		toolConfigs["tensorboard"] = tracking.ToolConfig{
			Enabled: true,
			Mode:    resolveMode(tb.Mode, tracking.ModeSidecar),
			Settings: map[string]any{
				"job_name": task.JobName,
			},
		}
	}

	if wb := task.Tracking.Wandb; wb != nil && wb.Enabled {
		toolConfigs["wandb"] = tracking.ToolConfig{
			Enabled: true,
			Mode:    resolveMode(wb.Mode, tracking.ModeRemote),
			Settings: map[string]any{
				"api_key": wb.APIKey,
				"project": wb.Project,
				"entity":  wb.Entity,
			},
		}
	}

	if len(toolConfigs) == 0 {
		return map[string]string{}, nil
	}

	env, err := e.registry.ProvisionAll(ctx, task.ID, toolConfigs)
	if err != nil {
		return nil, &errtypes.TaskExecutionError{
			TaskID:  task.ID,
			JobName: task.JobName,
			Phase:   "tracking_provision",
			Err:     err,
		}
	}
	return env, nil
}
// teardownTracking releases any tracking resources provisioned for the task.
// Safe to call when no registry is attached or the task has no tracking config.
func (e *ContainerExecutor) teardownTracking(ctx context.Context, task *queue.Task) {
	if e.registry == nil || task.Tracking == nil {
		return
	}
	e.registry.TeardownAll(ctx, task.ID)
}
// setupVolumes builds the host→container volume map and injects default cache
// environment variables into trackingEnv (mutating it in place).
//
// It remaps the TensorBoard host log dir (if the sidecar exported one) to a
// fixed in-container path, and mounts a shared cache directory under BasePath
// so model/dataset downloads persist across runs. Existing trackingEnv keys
// are never overwritten. The second parameter (the job output dir) is
// currently unused but kept for interface stability with callers.
func (e *ContainerExecutor) setupVolumes(trackingEnv map[string]string, _ string) map[string]string {
	volumes := make(map[string]string)

	// TensorBoard sidecar exports a host log dir; remount it at a fixed
	// container path and rewrite the env var the training code reads.
	if hostDir, ok := trackingEnv["TENSORBOARD_HOST_LOG_DIR"]; ok {
		const containerPath = "/tracking/tensorboard"
		volumes[hostDir] = containerPath + ":rw"
		trackingEnv["TENSORBOARD_LOG_DIR"] = containerPath
		delete(trackingEnv, "TENSORBOARD_HOST_LOG_DIR")
	}

	// Shared cache on the host so HF/torch/pip downloads survive across jobs.
	// Previously the MkdirAll error was silently dropped; a failure here means
	// the cache mount may not work, so surface it in the logs (best-effort —
	// the job can still run without a warm cache).
	cacheRoot := filepath.Join(e.config.BasePath, ".cache")
	if err := os.MkdirAll(cacheRoot, 0750); err != nil {
		e.logger.Warn("failed to create shared cache dir", "path", cacheRoot, "error", err)
	}
	volumes[cacheRoot] = "/workspace/.cache:rw"

	// Point common ML framework caches at the mounted cache dir, without
	// overriding values the tracking tools (or the task) already set.
	defaultEnv := map[string]string{
		"HF_HOME":            "/workspace/.cache/huggingface",
		"TRANSFORMERS_CACHE": "/workspace/.cache/huggingface/hub",
		"HF_DATASETS_CACHE":  "/workspace/.cache/huggingface/datasets",
		"TORCH_HOME":         "/workspace/.cache/torch",
		"TORCH_HUB_DIR":      "/workspace/.cache/torch/hub",
		"KERAS_HOME":         "/workspace/.cache/keras",
		"CUDA_CACHE_PATH":    "/workspace/.cache/cuda",
		"PIP_CACHE_DIR":      "/workspace/.cache/pip",
	}
	for k, v := range defaultEnv {
		if _, exists := trackingEnv[k]; !exists {
			trackingEnv[k] = v
		}
	}
	return volumes
}
// selectImage returns the image to run the job with. When an environment pool
// is attached and the task carries a dependency-manifest hash, it prefers the
// matching warm image (verified to exist locally, with a 2s inspect timeout);
// otherwise it falls back to the configured default image. All lookup failures
// silently fall back — warm images are strictly an optimization.
func (e *ContainerExecutor) selectImage(ctx context.Context, task *queue.Task) string {
	fallback := e.config.PodmanImage

	if e.envPool == nil || task.Metadata == nil {
		return fallback
	}

	depsSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
	if depsSHA == "" {
		return fallback
	}

	warmTag, err := e.envPool.WarmImageTag(depsSHA)
	if err != nil {
		return fallback
	}

	// Bound the local image inspection so a wedged podman can't stall the job.
	inspectCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()
	if ok, inspectErr := e.envPool.ImageExists(inspectCtx, warmTag); inspectErr == nil && ok {
		return warmTag
	}
	return fallback
}
// runPodman validates the job's dependency manifest, builds the podman
// command (including the sandbox security settings), records the command in
// the run manifest, executes it with output appended to the job log file,
// and dispatches to handleSuccess/handleFailure with the measured duration.
//
// Returns a *errtypes.TaskExecutionError (phase "validation") when no
// dependency manifest can be selected, the error from handleFailure when the
// container run fails, or nil on success.
func (e *ContainerExecutor) runPodman(
	ctx context.Context,
	task *queue.Task,
	env interfaces.ExecutionEnv,
	jobPaths *storage.JobPaths,
	podmanCfg container.PodmanConfig,
	selectedImage string,
) error {
	// Train script path as seen from inside the container workspace mount.
	scriptPath := filepath.Join(podmanCfg.ContainerWorkspace, e.config.TrainScript)
	// Pick the dependency manifest from the staged code dir; absence is a
	// non-recoverable validation failure.
	manifestName, err := SelectDependencyManifest(filepath.Join(env.OutputDir, "code"))
	if err != nil {
		return &errtypes.TaskExecutionError{
			TaskID:      task.ID,
			JobName:     task.JobName,
			Phase:       "validation",
			Message:     "dependency manifest selection failed",
			Err:         err,
			Context:     map[string]string{"image": selectedImage, "output_dir": env.OutputDir},
			Timestamp:   time.Now().UTC(),
			Recoverable: false,
		}
	}
	depsPath := filepath.Join(podmanCfg.ContainerWorkspace, manifestName)
	// NOTE(review): task.Args is split on whitespace only — quoted arguments
	// containing spaces are not supported here; confirm callers never need them.
	var extraArgs []string
	if task.Args != "" {
		extraArgs = strings.Fields(task.Args)
	}
	// Open log file for container output; on failure we continue without
	// capture (warn only) rather than failing the job.
	logFileHandle, err := fileutil.SecureOpenFile(env.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
	if err != nil {
		e.logger.Warn("failed to open log file for podman output", "path", env.LogFile, "error", err)
	}
	// Convert the SandboxConfig getters into the container package's
	// PodmanSecurityConfig, field for field.
	securityConfig := container.PodmanSecurityConfig{
		NoNewPrivileges: e.config.Sandbox.GetNoNewPrivileges(),
		DropAllCaps:     e.config.Sandbox.GetDropAllCaps(),
		AllowedCaps:     e.config.Sandbox.GetAllowedCaps(),
		UserNS:          e.config.Sandbox.GetUserNS(),
		RunAsUID:        e.config.Sandbox.GetRunAsUID(),
		RunAsGID:        e.config.Sandbox.GetRunAsGID(),
		SeccompProfile:  e.config.Sandbox.GetSeccompProfile(),
		ReadOnlyRoot:    e.config.Sandbox.GetReadOnlyRoot(),
		NetworkMode:     e.config.Sandbox.GetNetworkMode(),
		MaxProcesses:    e.config.Sandbox.GetMaxProcesses(),
		MaxOpenFiles:    e.config.Sandbox.GetMaxOpenFiles(),
		DisableSwap:     e.config.Sandbox.GetDisableSwap(),
		OOMScoreAdj:     e.config.Sandbox.GetOOMScoreAdj(),
		TaskUID:         e.config.Sandbox.GetTaskUID(),
		TaskGID:         e.config.Sandbox.GetTaskGID(),
	}
	podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, securityConfig, scriptPath, depsPath, extraArgs)
	// Record the resolved image and the exact command line in the run manifest
	// before execution, so a crash mid-run still leaves an audit trail.
	if e.writer != nil {
		e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) {
			m.PodmanImage = strings.TrimSpace(selectedImage)
			m.Command = podmanCmd.Path
			if len(podmanCmd.Args) > 1 {
				m.Args = strings.Join(podmanCmd.Args[1:], " ")
			} else {
				m.Args = ""
			}
		})
	}
	// Redirect both streams to the job log; closed when this function returns.
	if logFileHandle != nil {
		podmanCmd.Stdout = logFileHandle
		podmanCmd.Stderr = logFileHandle
		defer logFileHandle.Close()
	}
	e.logger.Info("executing podman job",
		"job", task.JobName,
		"image", selectedImage,
		"workspace", podmanCfg.Workspace,
		"results", podmanCfg.Results)
	// Execute and time the container run; the duration feeds the manifest.
	containerStart := time.Now()
	if err := podmanCmd.Run(); err != nil {
		containerDuration := time.Since(containerStart)
		return e.handleFailure(task, env, jobPaths, err, containerDuration)
	}
	containerDuration := time.Since(containerStart)
	return e.handleSuccess(task, env, jobPaths, containerDuration)
}
// handleFailure records the failed run in the manifest, moves the job output
// into the failed-jobs area, and returns a recoverable TaskExecutionError
// wrapping the container error.
//
// Fixes: the exit code was previously hard-coded to 1; when the error is an
// *exec.ExitError the container's real exit code is now recorded. Errors from
// the best-effort directory preparation are logged instead of silently dropped.
func (e *ContainerExecutor) handleFailure(
	task *queue.Task,
	env interfaces.ExecutionEnv,
	jobPaths *storage.JobPaths,
	runErr error,
	duration time.Duration,
) error {
	if e.writer != nil {
		e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) {
			now := time.Now().UTC()
			// Prefer the container's actual exit code over a generic 1.
			exitCode := 1
			var exitErr *exec.ExitError
			if errors.As(runErr, &exitErr) && exitErr.ExitCode() >= 0 {
				exitCode = exitErr.ExitCode()
			}
			m.ExecutionDurationMS = duration.Milliseconds()
			m.MarkFinished(now, &exitCode, runErr)
		})
	}

	// Move the output dir into failed/, replacing any stale previous attempt.
	// These steps are best-effort: a failure here must not mask runErr.
	failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName)
	if err := os.MkdirAll(filepath.Dir(failedDir), 0750); err != nil {
		e.logger.Warn("failed to create failed-jobs parent dir", "path", filepath.Dir(failedDir), "error", err)
	}
	if err := os.RemoveAll(failedDir); err != nil {
		e.logger.Warn("failed to remove stale failed dir", "path", failedDir, "error", err)
	}
	telemetry.ExecWithMetrics(
		e.logger,
		"move failed job",
		100*time.Millisecond,
		func() (string, error) {
			if err := os.Rename(env.OutputDir, failedDir); err != nil {
				return "", fmt.Errorf("rename to failed failed: %w", err)
			}
			return "", nil
		})

	// Return an enriched error with execution context for the caller.
	return &errtypes.TaskExecutionError{
		TaskID:      task.ID,
		JobName:     task.JobName,
		Phase:       "execution",
		Message:     "container execution failed",
		Err:         runErr,
		Context:     map[string]string{"duration_ms": fmt.Sprintf("%d", duration.Milliseconds())},
		Timestamp:   time.Now().UTC(),
		Recoverable: true, // container failures may be retryable
	}
}
// handleSuccess records the successful run (duration, exit code 0) in the
// manifest and moves the job output directory into the finished-jobs area.
// Always returns nil; finalization steps are best-effort and only logged.
//
// Fixes: errors from os.MkdirAll/os.RemoveAll were previously silently
// dropped; they are now logged so a failed finalize is diagnosable.
func (e *ContainerExecutor) handleSuccess(
	task *queue.Task,
	env interfaces.ExecutionEnv,
	jobPaths *storage.JobPaths,
	duration time.Duration,
) error {
	// Persist the container run duration first.
	if e.writer != nil {
		e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) {
			m.ExecutionDurationMS = duration.Milliseconds()
		})
	}

	finalizeStart := time.Now()
	finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName)

	// Mark the run finished with exit code 0.
	// NOTE(review): FinalizeDurationMS is computed inside this Upsert, so it
	// excludes the rename below — confirm that is the intended measurement window.
	if e.writer != nil {
		e.writer.Upsert(env.OutputDir, task, func(m *manifest.RunManifest) {
			now := time.Now().UTC()
			exitCode := 0
			m.FinalizeDurationMS = time.Since(finalizeStart).Milliseconds()
			m.MarkFinished(now, &exitCode, nil)
		})
	}

	// Move the output dir into finished/, replacing any stale previous result.
	// Best-effort: log (don't fail) if preparation steps error out.
	if err := os.MkdirAll(filepath.Dir(finishedDir), 0750); err != nil {
		e.logger.Warn("failed to create finished-jobs parent dir", "path", filepath.Dir(finishedDir), "error", err)
	}
	if err := os.RemoveAll(finishedDir); err != nil {
		e.logger.Warn("failed to remove stale finished dir", "path", finishedDir, "error", err)
	}
	telemetry.ExecWithMetrics(
		e.logger,
		"finalize job",
		100*time.Millisecond,
		func() (string, error) {
			if err := os.Rename(env.OutputDir, finishedDir); err != nil {
				return "", fmt.Errorf("rename to finished failed: %w", err)
			}
			return "", nil
		})
	return nil
}
func SelectDependencyManifest(filesPath string) (string, error) {
if filesPath == "" {
return "", fmt.Errorf("missing files path")
}
candidates := []string{
"environment.yml",
"environment.yaml",
"poetry.lock",
"pyproject.toml",
"requirements.txt",
}
for _, name := range candidates {
p := filepath.Join(filesPath, name)
if _, err := os.Stat(p); err == nil {
if name == "poetry.lock" {
pyprojectPath := filepath.Join(filesPath, "pyproject.toml")
if _, err := os.Stat(pyprojectPath); err != nil {
return "", fmt.Errorf(
"poetry.lock found but pyproject.toml missing (required for Poetry projects)")
}
}
return name, nil
}
}
return "", fmt.Errorf(
"missing dependency manifest (supported: environment.yml, environment.yaml, " +
"poetry.lock, pyproject.toml, requirements.txt)")
}