// fetch_ml/internal/worker/execution.go

package worker
import (
"context"
"encoding/hex"
"fmt"
"io"
"log"
"os"
"os/exec"
"path/filepath"
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/errtypes"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"github.com/jfraeys/fetch_ml/internal/telemetry"
"github.com/jfraeys/fetch_ml/internal/tracking"
)
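
// gpuVisibleDevicesString renders the comma-separated device list for the GPU
// visibility environment variable. Explicit string IDs in
// cfg.GPUVisibleDeviceIDs take precedence over the numeric indices in
// cfg.GPUVisibleDevices; with neither configured, the trimmed fallback
// (typically the scheduler-assigned device string) is returned.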
func gpuVisibleDevicesString(cfg *Config, fallback string) string {
if cfg == nil {
return strings.TrimSpace(fallback)
}
if len(cfg.GPUVisibleDeviceIDs) > 0 {
parts := make([]string, 0, len(cfg.GPUVisibleDeviceIDs))
for _, id := range cfg.GPUVisibleDeviceIDs {
id = strings.TrimSpace(id)
if id == "" {
continue
}
parts = append(parts, id)
}
return strings.Join(parts, ",")
}
if len(cfg.GPUVisibleDevices) == 0 {
return strings.TrimSpace(fallback)
}
parts := make([]string, 0, len(cfg.GPUVisibleDevices))
for _, v := range cfg.GPUVisibleDevices {
if v < 0 {
continue
}
parts = append(parts, strconv.Itoa(v))
}
return strings.Join(parts, ",")
}
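
// filterExistingDevicePaths deduplicates paths, dropping blanks and any path
// that does not exist on the host. Input order is preserved.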
func filterExistingDevicePaths(paths []string) []string {
if len(paths) == 0 {
return nil
}
seen := make(map[string]struct{}, len(paths))
out := make([]string, 0, len(paths))
for _, p := range paths {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if _, ok := seen[p]; ok {
continue
}
if _, err := os.Stat(p); err != nil {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
return out
}
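
// gpuVisibleEnvVarName maps the configured GPU vendor to its device-visibility
// environment variable: HIP_VISIBLE_DEVICES for AMD, empty for Apple or no GPU
// (no masking variable applies), and CUDA_VISIBLE_DEVICES otherwise.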
func gpuVisibleEnvVarName(cfg *Config) string {
if cfg == nil {
return "CUDA_VISIBLE_DEVICES"
}
switch strings.ToLower(strings.TrimSpace(cfg.GPUVendor)) {
case "amd":
return "HIP_VISIBLE_DEVICES"
case string(GPUTypeApple), string(GPUTypeNone):
return ""
default:
return "CUDA_VISIBLE_DEVICES"
}
}
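
// runIDForTask derives a human-readable run ID of the form
// run-<job>-<YYYYMMDD-HHMMSS>-<id-prefix> from the task's job name, its
// creation time (falling back to the current UTC time), and the first eight
// characters of the task ID.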
func runIDForTask(task *queue.Task) string {
if task == nil {
return ""
}
created := task.CreatedAt
if created.IsZero() {
created = time.Now().UTC()
}
short := task.ID
if len(short) > 8 {
short = short[:8]
}
job := strings.TrimSpace(task.JobName)
if job == "" {
job = "job"
}
return fmt.Sprintf("run-%s-%s-%s", job, created.UTC().Format("20060102-150405"), short)
}
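
// buildInitialRunManifest seeds a RunManifest with the worker's identity
// (hostname, build version), the Podman image, and provenance copied from the
// task metadata (commit ID, manifest SHAs, snapshot hash). Non-empty metadata
// keys are also forwarded verbatim for forward compatibility.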
func (w *Worker) buildInitialRunManifest(
task *queue.Task,
podmanImage string,
) *manifest.RunManifest {
if task == nil {
return nil
}
m := manifest.NewRunManifest(runIDForTask(task), task.ID, task.JobName, task.CreatedAt)
m.PodmanImage = strings.TrimSpace(podmanImage)
if host, err := os.Hostname(); err == nil {
m.WorkerHost = strings.TrimSpace(host)
}
if info, ok := debug.ReadBuildInfo(); ok && info != nil {
m.WorkerVersion = strings.TrimSpace(info.Main.Version)
}
if task.Metadata != nil {
m.CommitID = strings.TrimSpace(task.Metadata["commit_id"])
m.ExperimentManifestSHA = strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"])
m.DepsManifestName = strings.TrimSpace(task.Metadata["deps_manifest_name"])
m.DepsManifestSHA = strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
m.SnapshotSHA256 = strings.TrimSpace(task.Metadata["snapshot_sha256"])
// Forward compatibility: copy selected metadata keys verbatim.
for k, v := range task.Metadata {
if strings.TrimSpace(k) == "" || strings.TrimSpace(v) == "" {
continue
}
m.Metadata[k] = v
}
}
m.SnapshotID = strings.TrimSpace(task.SnapshotID)
m.Metadata["task_args"] = strings.TrimSpace(task.Args)
return m
}
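
// upsertRunManifest loads the run manifest from dir (rebuilding the initial
// manifest if none can be read), applies mutate, and writes it back. Failures
// are logged and swallowed: manifest persistence is best-effort and never
// fails the job.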
func (w *Worker) upsertRunManifest(
dir string,
task *queue.Task,
mutate func(m *manifest.RunManifest),
) {
if strings.TrimSpace(dir) == "" {
return
}
if task == nil {
return
}
cur, err := manifest.LoadFromDir(dir)
if err != nil {
cur = w.buildInitialRunManifest(task, w.config.PodmanImage)
}
if cur == nil {
return
}
if mutate != nil {
mutate(cur)
}
if err := cur.WriteToDir(dir); err != nil {
w.logger.Warn(
"failed to write run manifest",
"job", task.JobName,
"task_id", task.ID,
"error", err,
)
}
}
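
// StageSnapshot stages snapshot snapshotID from <dataDir>/snapshots into
// jobDir. An empty snapshot ID is a no-op; the ID is validated against path
// traversal before any filesystem access.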
func StageSnapshot(basePath, dataDir, taskID, snapshotID, jobDir string) error {
sid := strings.TrimSpace(snapshotID)
if sid == "" {
return nil
}
if err := container.ValidateJobName(sid); err != nil {
return err
}
if strings.TrimSpace(taskID) == "" {
return fmt.Errorf("missing task id")
}
if strings.TrimSpace(jobDir) == "" {
return fmt.Errorf("missing job dir")
}
src := filepath.Join(dataDir, "snapshots", sid)
return StageSnapshotFromPath(basePath, taskID, src, jobDir)
}
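
// StageSnapshotFromPath places the snapshot at srcPath into <jobDir>/snapshot,
// replacing any previously staged copy. A prewarmed copy under
// <basePath>/.prewarm/snapshots/<taskID> is renamed into place when present
// (a cheap move); otherwise srcPath is copied recursively.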
func StageSnapshotFromPath(basePath, taskID, srcPath, jobDir string) error {
if strings.TrimSpace(basePath) == "" {
return fmt.Errorf("missing base path")
}
if strings.TrimSpace(taskID) == "" {
return fmt.Errorf("missing task id")
}
if strings.TrimSpace(jobDir) == "" {
return fmt.Errorf("missing job dir")
}
dst := filepath.Join(jobDir, "snapshot")
_ = os.RemoveAll(dst)
prewarmSrc := filepath.Join(basePath, ".prewarm", "snapshots", taskID)
if info, err := os.Stat(prewarmSrc); err == nil && info.IsDir() {
// TODO: Emit Prometheus prewarm snapshot hit metric when available
return os.Rename(prewarmSrc, dst)
}
// TODO: Emit Prometheus prewarm snapshot miss metric when available
return copyDir(srcPath, dst)
}
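
// runJob drives a task through validation, directory setup, staging of
// experiment files and snapshot, and execution, recording each step in the
// run manifest. Staging failures finalize the manifest and move the job
// directory to the failed area before returning.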
func (w *Worker) runJob(ctx context.Context, task *queue.Task, cudaVisibleDevices string) error {
visibleDevices := gpuVisibleDevicesString(w.config, cudaVisibleDevices)
visibleEnvVar := gpuVisibleEnvVarName(w.config)
// Validate job name to prevent path traversal
if err := container.ValidateJobName(task.JobName); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: err,
}
}
jobDir, outputDir, logFile, err := w.setupJobDirectories(task)
if err != nil {
return err
}
// Best-effort: write initial run manifest into pending dir so it follows the job via rename.
w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) {
m.TrainScriptPath = strings.TrimSpace(w.config.TrainScript)
if strings.TrimSpace(w.config.Host) != "" {
m.Metadata["worker_config_host"] = strings.TrimSpace(w.config.Host)
}
m.Metadata["task_args"] = strings.TrimSpace(task.Args)
m.MarkStarted(time.Now().UTC())
m.GPUDevices = w.getGPUDevicePaths()
if strings.TrimSpace(visibleEnvVar) != "" {
m.Metadata["gpu_visible_devices"] = strings.TrimSpace(visibleDevices)
m.Metadata["gpu_visible_env"] = strings.TrimSpace(visibleEnvVar)
}
})
if err := w.stageExperimentFiles(task, jobDir); err != nil {
w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) {
if a, aerr := scanArtifacts(jobDir); aerr == nil {
m.Artifacts = a
} else {
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
}
now := time.Now().UTC()
exitCode := 1
m.MarkFinished(now, &exitCode, err)
m.Metadata["failure_phase"] = "stage_experiment_files"
})
failedDir := filepath.Join(storage.NewJobPaths(w.config.BasePath).FailedPath(), task.JobName)
_ = os.MkdirAll(filepath.Dir(failedDir), 0750)
_ = os.RemoveAll(failedDir)
_ = os.Rename(jobDir, failedDir)
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "staging",
Err: err,
}
}
if err := w.stageSnapshot(ctx, task, jobDir); err != nil {
w.upsertRunManifest(jobDir, task, func(m *manifest.RunManifest) {
if a, aerr := scanArtifacts(jobDir); aerr == nil {
m.Artifacts = a
} else {
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
}
now := time.Now().UTC()
exitCode := 1
m.MarkFinished(now, &exitCode, err)
m.Metadata["failure_phase"] = "stage_snapshot"
})
failedDir := filepath.Join(storage.NewJobPaths(w.config.BasePath).FailedPath(), task.JobName)
_ = os.MkdirAll(filepath.Dir(failedDir), 0750)
_ = os.RemoveAll(failedDir)
_ = os.Rename(jobDir, failedDir)
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "staging",
Err: err,
}
}
return w.executeJob(ctx, task, jobDir, outputDir, logFile, visibleDevices, visibleEnvVar)
}
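
// RunJob executes a queued task end to end. cudaVisibleDevices is the fallback
// device list applied when the worker config does not pin GPU devices
// explicitly.
//
// A minimal caller sketch (assuming a constructed *Worker and a dequeued
// *queue.Task; the "0" pins the first GPU):
//
//	if err := w.RunJob(ctx, task, "0"); err != nil {
//		log.Printf("task %s failed: %v", task.ID, err)
//	}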
func (w *Worker) RunJob(ctx context.Context, task *queue.Task, cudaVisibleDevices string) error {
return w.runJob(ctx, task, cudaVisibleDevices)
}
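
// stageSnapshot resolves the task's snapshot, verifying it against the
// snapshot_sha256 metadata, and stages it into jobDir. Tasks without a
// snapshot ID are a no-op.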
func (w *Worker) stageSnapshot(ctx context.Context, task *queue.Task, jobDir string) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if strings.TrimSpace(task.SnapshotID) == "" {
return nil
}
want := ""
if task.Metadata != nil {
want = strings.TrimSpace(task.Metadata["snapshot_sha256"])
}
if want == "" {
return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
}
resolved, err := ResolveSnapshot(
ctx,
w.config.DataDir,
&w.config.SnapshotStore,
task.SnapshotID,
want,
nil,
)
if err != nil {
return err
}
return StageSnapshotFromPath(w.config.BasePath, task.ID, resolved, jobDir)
}
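
// validateTaskForExecution checks that a task is runnable before any staging:
// a valid job name, a well-formed 40-hex-character commit_id, readable
// experiment metadata matching that commit, the train script and a dependency
// manifest present on disk, and a passing content-integrity check.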
func (w *Worker) validateTaskForExecution(_ context.Context, task *queue.Task) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if err := container.ValidateJobName(task.JobName); err != nil {
return err
}
if task.Metadata == nil {
return fmt.Errorf("missing task metadata")
}
commitID, ok := task.Metadata["commit_id"]
if !ok || commitID == "" {
return fmt.Errorf("missing commit_id")
}
if len(commitID) != 40 {
return fmt.Errorf("invalid commit_id length")
}
if _, err := hex.DecodeString(commitID); err != nil {
return fmt.Errorf("invalid commit_id: %w", err)
}
expMgr := experiment.NewManager(w.config.BasePath)
meta, err := expMgr.ReadMetadata(commitID)
if err != nil {
return fmt.Errorf("missing or unreadable experiment metadata: %w", err)
}
if meta.CommitID != commitID {
return fmt.Errorf("experiment metadata commit_id mismatch")
}
filesPath := expMgr.GetFilesPath(commitID)
trainScriptHostPath := filepath.Join(filesPath, w.config.TrainScript)
if _, err := os.Stat(trainScriptHostPath); err != nil {
return fmt.Errorf("missing train script: %w", err)
}
if _, err := selectDependencyManifest(filesPath); err != nil {
return err
}
// Validate content integrity manifest
if err := expMgr.ValidateManifest(commitID); err != nil {
return fmt.Errorf("content integrity validation failed: %w", err)
}
return nil
}
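
// podmanImageDigest returns the local image ID for imageRef as reported by
// `podman image inspect` (bounded to two seconds), or an empty string when
// the ref is blank or inspection fails.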
func (w *Worker) podmanImageDigest(ctx context.Context, imageRef string) string {
ref := strings.TrimSpace(imageRef)
if ref == "" {
return ""
}
inspectCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
cmd := exec.CommandContext(inspectCtx, "podman", "image", "inspect", "--format", "{{.Id}}", ref)
out, err := cmd.CombinedOutput()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
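
// stageExperimentFiles copies the committed experiment files for the task's
// commit_id into <jobDir>/code.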
func (w *Worker) stageExperimentFiles(task *queue.Task, jobDir string) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if task.Metadata == nil {
return fmt.Errorf("missing task metadata")
}
commitID, ok := task.Metadata["commit_id"]
if !ok || commitID == "" {
return fmt.Errorf("missing commit_id")
}
expMgr := experiment.NewManager(w.config.BasePath)
src := expMgr.GetFilesPath(commitID)
dst := filepath.Join(jobDir, "code")
if err := copyDir(src, dst); err != nil {
return err
}
return nil
}
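
// copyDir recursively copies the directory tree at src into dst, creating
// directories with mode 0750 and files with their source permission bits.
// Relative paths that would escape src are rejected.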
func copyDir(src, dst string) error {
src = filepath.Clean(src)
dst = filepath.Clean(dst)
srcInfo, err := os.Stat(src)
if err != nil {
return err
}
if !srcInfo.IsDir() {
return fmt.Errorf("source is not a directory")
}
if err := os.MkdirAll(dst, 0750); err != nil {
return err
}
return filepath.WalkDir(src, func(path string, d os.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
rel, err := filepath.Rel(src, path)
if err != nil {
return err
}
rel = filepath.Clean(rel)
if rel == "." {
return nil
}
if strings.HasPrefix(rel, "..") {
return fmt.Errorf("invalid relative path")
}
outPath := filepath.Join(dst, rel)
if d.IsDir() {
return os.MkdirAll(outPath, 0750)
}
info, err := d.Info()
if err != nil {
return err
}
mode := info.Mode() & 0777
in, err := os.Open(filepath.Clean(path))
if err != nil {
return err
}
defer func() { _ = in.Close() }()
out, err := fileutil.SecureOpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
if err != nil {
return err
}
defer func() { _ = out.Close() }()
_, err = io.Copy(out, in)
return err
})
}
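
// setupJobDirectories creates the job's directory under pending/ and computes
// the running-directory and log-file paths, sanitizing both directories
// against traversal. The job stays in pending/ until executeJob renames it
// into running/.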
func (w *Worker) setupJobDirectories(
task *queue.Task,
) (jobDir, outputDir, logFile string, err error) {
jobPaths := storage.NewJobPaths(w.config.BasePath)
pendingDir := jobPaths.PendingPath()
jobDir = filepath.Join(pendingDir, task.JobName)
outputDir = filepath.Join(jobPaths.RunningPath(), task.JobName)
logFile = filepath.Join(outputDir, "output.log")
// Create pending directory
if err := os.MkdirAll(pendingDir, 0750); err != nil {
return "", "", "", &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "setup",
Err: fmt.Errorf("failed to create pending dir: %w", err),
}
}
// Create job directory in pending
if err := os.MkdirAll(jobDir, 0750); err != nil {
return "", "", "", &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "setup",
Err: fmt.Errorf("failed to create job dir: %w", err),
}
}
// Sanitize paths
jobDir, err = container.SanitizePath(jobDir)
if err != nil {
return "", "", "", &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: err,
}
}
outputDir, err = container.SanitizePath(outputDir)
if err != nil {
return "", "", "", &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: err,
}
}
return jobDir, outputDir, logFile, nil
}
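
// executeJob moves the staged job from pending/ to running/, records staging
// duration and GPU visibility in the run manifest, and dispatches to local or
// container execution. Local-mode jobs are finalized here (artifact scan,
// manifest, move to finished/ or failed/); container jobs finalize in
// executeContainerJob.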
func (w *Worker) executeJob(
ctx context.Context,
task *queue.Task,
jobDir, outputDir, logFile string,
visibleDevices string,
visibleEnvVar string,
) error {
// Create output directory
if _, err := telemetry.ExecWithMetrics(
w.logger,
"create output dir",
100*time.Millisecond,
func() (string, error) {
if err := os.MkdirAll(outputDir, 0750); err != nil {
return "", fmt.Errorf("mkdir failed: %w", err)
}
return "", nil
},
); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "setup",
Err: fmt.Errorf("failed to create output dir: %w", err),
}
}
// Move job from pending to running
stagingStart := time.Now()
if _, err := telemetry.ExecWithMetrics(
w.logger,
"stage job",
100*time.Millisecond,
func() (string, error) {
// Remove existing directory if it exists
if _, err := os.Stat(outputDir); err == nil {
if err := os.RemoveAll(outputDir); err != nil {
return "", fmt.Errorf("remove existing failed: %w", err)
}
}
if err := os.Rename(jobDir, outputDir); err != nil {
return "", fmt.Errorf("rename failed: %w", err)
}
return "", nil
},
); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "setup",
Err: fmt.Errorf("failed to move job: %w", err),
}
}
stagingDuration := time.Since(stagingStart)
// Best-effort: record staging duration in running dir.
w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) {
m.StagingDurationMS = stagingDuration.Milliseconds()
m.GPUDevices = w.getGPUDevicePaths()
if strings.TrimSpace(visibleEnvVar) != "" {
m.Metadata["gpu_visible_devices"] = strings.TrimSpace(visibleDevices)
m.Metadata["gpu_visible_env"] = strings.TrimSpace(visibleEnvVar)
}
})
// Execute job
if w.config.LocalMode {
execStart := time.Now()
err := w.executeLocalJob(ctx, task, outputDir, logFile, visibleDevices, visibleEnvVar)
execDuration := time.Since(execStart)
w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) {
now := time.Now().UTC()
m.ExecutionDurationMS = execDuration.Milliseconds()
if a, aerr := scanArtifacts(outputDir); aerr == nil {
m.Artifacts = a
} else {
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
}
if err != nil {
exitCode := 1
m.MarkFinished(now, &exitCode, err)
} else {
exitCode := 0
m.MarkFinished(now, &exitCode, nil)
}
})
finalizeStart := time.Now()
jobPaths := storage.NewJobPaths(w.config.BasePath)
var dest string
if err != nil {
dest = filepath.Join(jobPaths.FailedPath(), task.JobName)
} else {
dest = filepath.Join(jobPaths.FinishedPath(), task.JobName)
}
_ = os.MkdirAll(filepath.Dir(dest), 0750)
_ = os.RemoveAll(dest)
w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) {
m.FinalizeDurationMS = time.Since(finalizeStart).Milliseconds()
})
if moveErr := os.Rename(outputDir, dest); moveErr != nil {
w.logger.Warn("failed to move local-mode job dir", "job", task.JobName, "error", moveErr)
}
return err
}
return w.executeContainerJob(
ctx,
task,
outputDir,
logFile,
stagingDuration,
visibleDevices,
visibleEnvVar,
)
}
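
// executeLocalJob writes and runs a simulated experiment script under bash,
// streaming output to the job log. The script only sleeps and echoes canned
// results; local mode exists for development without Podman. GPU visibility
// and the staged snapshot location are passed through environment variables.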
func (w *Worker) executeLocalJob(
ctx context.Context,
task *queue.Task,
outputDir, logFile string,
visibleDevices string,
visibleEnvVar string,
) error {
// Create experiment script
scriptContent := `#!/bin/bash
set -e
echo "Starting experiment: ` + task.JobName + `"
echo "Task ID: ` + task.ID + `"
echo "Timestamp: $(date)"
# Simulate ML experiment
echo "Loading data..."
sleep 1
echo "Training model..."
sleep 2
echo "Evaluating model..."
sleep 1
# Generate results
ACCURACY=0.95
LOSS=0.05
EPOCHS=10
echo ""
echo "=== EXPERIMENT RESULTS ==="
echo "Accuracy: $ACCURACY"
echo "Loss: $LOSS"
echo "Epochs: $EPOCHS"
echo "Status: SUCCESS"
echo "========================="
echo "Experiment completed successfully!"
`
scriptPath := filepath.Join(outputDir, "run.sh")
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0600); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "execution",
Err: fmt.Errorf("failed to write script: %w", err),
}
}
w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) {
m.Command = "bash"
m.Args = scriptPath
})
logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
w.logger.Warn("failed to open log file for local output", "path", logFile, "error", err)
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "execution",
Err: fmt.Errorf("failed to open log file: %w", err),
}
}
defer func() {
if err := logFileHandle.Close(); err != nil {
log.Printf("Warning: failed to close log file: %v", err)
}
}()
// Execute the script directly
localCmd := exec.CommandContext(ctx, "bash", scriptPath)
env := os.Environ()
if strings.TrimSpace(visibleEnvVar) != "" {
env = append(env, fmt.Sprintf("%s=%s", visibleEnvVar, strings.TrimSpace(visibleDevices)))
}
snap := filepath.Join(outputDir, "snapshot")
if info, err := os.Stat(snap); err == nil && info.IsDir() {
env = append(env, fmt.Sprintf("FETCH_ML_SNAPSHOT_DIR=%s", snap))
if strings.TrimSpace(task.SnapshotID) != "" {
env = append(env, fmt.Sprintf("FETCH_ML_SNAPSHOT_ID=%s", strings.TrimSpace(task.SnapshotID)))
}
}
localCmd.Env = env
localCmd.Stdout = logFileHandle
localCmd.Stderr = logFileHandle
w.logger.Info("executing local job",
"job", task.JobName,
"task_id", task.ID,
"script", scriptPath)
if err := localCmd.Run(); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "execution",
Err: fmt.Errorf("execution failed: %w", err),
}
}
return nil
}
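
// executeContainerJob runs the task inside Podman: it provisions any requested
// tracking tools (MLflow, TensorBoard, W&B), mounts the shared framework cache
// and a read-only snapshot, prefers a prewarmed image matching the task's
// dependency manifest when one exists, then executes the container and moves
// the job to finished/ or failed/ with timing and IO telemetry.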
func (w *Worker) executeContainerJob(
ctx context.Context,
task *queue.Task,
outputDir, logFile string,
stagingDuration time.Duration,
visibleDevices string,
visibleEnvVar string,
) error {
containerResults := w.config.ContainerResults
if containerResults == "" {
containerResults = config.DefaultContainerResults
}
containerWorkspace := w.config.ContainerWorkspace
if containerWorkspace == "" {
containerWorkspace = config.DefaultContainerWorkspace
}
jobPaths := storage.NewJobPaths(w.config.BasePath)
stagingStart := time.Now()
// Optional: provision tracking tools for this task
var trackingEnv map[string]string
if w.trackingRegistry != nil && task.Tracking != nil {
configs := make(map[string]tracking.ToolConfig)
if task.Tracking.MLflow != nil && task.Tracking.MLflow.Enabled {
mode := tracking.ModeSidecar
if task.Tracking.MLflow.Mode != "" {
mode = tracking.ToolMode(task.Tracking.MLflow.Mode)
}
configs["mlflow"] = tracking.ToolConfig{
Enabled: true,
Mode: mode,
Settings: map[string]any{
"job_name": task.JobName,
"tracking_uri": task.Tracking.MLflow.TrackingURI,
},
}
}
if task.Tracking.TensorBoard != nil && task.Tracking.TensorBoard.Enabled {
mode := tracking.ModeSidecar
if task.Tracking.TensorBoard.Mode != "" {
mode = tracking.ToolMode(task.Tracking.TensorBoard.Mode)
}
configs["tensorboard"] = tracking.ToolConfig{
Enabled: true,
Mode: mode,
Settings: map[string]any{
"job_name": task.JobName,
},
}
}
if task.Tracking.Wandb != nil && task.Tracking.Wandb.Enabled {
mode := tracking.ModeRemote
if task.Tracking.Wandb.Mode != "" {
mode = tracking.ToolMode(task.Tracking.Wandb.Mode)
}
configs["wandb"] = tracking.ToolConfig{
Enabled: true,
Mode: mode,
Settings: map[string]any{
"api_key": task.Tracking.Wandb.APIKey,
"project": task.Tracking.Wandb.Project,
"entity": task.Tracking.Wandb.Entity,
},
}
}
if len(configs) > 0 {
var err error
trackingEnv, err = w.trackingRegistry.ProvisionAll(ctx, task.ID, configs)
if err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "tracking_provision",
Err: err,
}
}
defer w.trackingRegistry.TeardownAll(context.Background(), task.ID)
}
}
var volumes map[string]string
if val, ok := trackingEnv["TENSORBOARD_HOST_LOG_DIR"]; ok {
volumes = make(map[string]string)
// Mount to /tracking/tensorboard inside container
containerPath := "/tracking/tensorboard"
volumes[val] = containerPath + ":rw"
// Update environment variable for the script
trackingEnv["TENSORBOARD_LOG_DIR"] = containerPath
// Remove the host path from Env to avoid leaking host info
delete(trackingEnv, "TENSORBOARD_HOST_LOG_DIR")
}
if trackingEnv == nil {
trackingEnv = make(map[string]string)
}
cacheRoot := filepath.Join(w.config.BasePath, ".cache")
if err := os.MkdirAll(cacheRoot, 0755); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "cache_setup",
Err: err,
}
}
if volumes == nil {
volumes = make(map[string]string)
}
volumes[cacheRoot] = "/workspace/.cache:rw"
defaultEnv := map[string]string{
"HF_HOME": "/workspace/.cache/huggingface",
"TRANSFORMERS_CACHE": "/workspace/.cache/huggingface/hub",
"HF_DATASETS_CACHE": "/workspace/.cache/huggingface/datasets",
"TORCH_HOME": "/workspace/.cache/torch",
"TORCH_HUB_DIR": "/workspace/.cache/torch/hub",
"KERAS_HOME": "/workspace/.cache/keras",
"CUDA_CACHE_PATH": "/workspace/.cache/cuda",
"PIP_CACHE_DIR": "/workspace/.cache/pip",
}
for k, v := range defaultEnv {
if _, ok := trackingEnv[k]; ok {
continue
}
trackingEnv[k] = v
}
if strings.TrimSpace(visibleEnvVar) != "" {
trackingEnv[visibleEnvVar] = strings.TrimSpace(visibleDevices)
}
snap := filepath.Join(outputDir, "snapshot")
if info, err := os.Stat(snap); err == nil && info.IsDir() {
trackingEnv["FETCH_ML_SNAPSHOT_DIR"] = "/snapshot"
if strings.TrimSpace(task.SnapshotID) != "" {
trackingEnv["FETCH_ML_SNAPSHOT_ID"] = strings.TrimSpace(task.SnapshotID)
}
volumes[snap] = "/snapshot:ro"
}
cpusOverride, memOverride := container.PodmanResourceOverrides(task.CPU, task.MemoryGB)
selectedImage := w.config.PodmanImage
if w.envPool != nil &&
!w.config.LocalMode &&
strings.TrimSpace(w.config.PodmanImage) != "" &&
task != nil &&
task.Metadata != nil {
depsSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
if depsSHA != "" {
if warmTag, err := w.envPool.WarmImageTag(depsSHA); err == nil {
inspectCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
exists, ierr := w.envPool.ImageExists(inspectCtx, warmTag)
cancel()
if ierr == nil && exists {
selectedImage = warmTag
}
}
}
}
podmanCfg := container.PodmanConfig{
Image: selectedImage,
Workspace: filepath.Join(outputDir, "code"),
Results: filepath.Join(outputDir, "results"),
ContainerWorkspace: containerWorkspace,
ContainerResults: containerResults,
AppleGPU: w.config.AppleGPU.Enabled,
GPUDevices: w.getGPUDevicePaths(),
Env: trackingEnv,
Volumes: volumes,
Memory: memOverride,
CPUs: cpusOverride,
}
scriptPath := filepath.Join(containerWorkspace, w.config.TrainScript)
manifestName, err := selectDependencyManifest(filepath.Join(outputDir, "code"))
if err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,
Phase: "validation",
Err: err,
}
}
depsPath := filepath.Join(containerWorkspace, manifestName)
var extraArgs []string
if task.Args != "" {
extraArgs = strings.Fields(task.Args)
}
ioBefore, ioErr := telemetry.ReadProcessIO()
podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, scriptPath, depsPath, extraArgs)
w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) {
m.PodmanImage = strings.TrimSpace(selectedImage)
m.ImageDigest = strings.TrimSpace(w.podmanImageDigest(ctx, selectedImage))
m.Command = podmanCmd.Path
if len(podmanCmd.Args) > 1 {
m.Args = strings.Join(podmanCmd.Args[1:], " ")
} else {
m.Args = ""
}
})
logFileHandle, err := fileutil.SecureOpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err == nil {
// The handle stays attached to the container process for its lifetime;
// close it once the job (including failure handling below) is done.
defer func() { _ = logFileHandle.Close() }()
podmanCmd.Stdout = logFileHandle
podmanCmd.Stderr = logFileHandle
} else {
w.logger.Warn("failed to open log file for podman output", "path", logFile, "error", err)
}
w.logger.Info("executing podman job",
"job", task.JobName,
"image", selectedImage,
"workspace", podmanCfg.Workspace,
"results", podmanCfg.Results)
containerStart := time.Now()
if err := podmanCmd.Run(); err != nil {
containerDuration := time.Since(containerStart)
w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) {
now := time.Now().UTC()
exitCode := 1
m.ExecutionDurationMS = containerDuration.Milliseconds()
if a, aerr := scanArtifacts(outputDir); aerr == nil {
m.Artifacts = a
} else {
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
}
m.MarkFinished(now, &exitCode, err)
})
// Move job to failed directory
failedDir := filepath.Join(jobPaths.FailedPath(), task.JobName)
if _, moveErr := telemetry.ExecWithMetrics(
w.logger,
"move failed job",
100*time.Millisecond,
func() (string, error) {
if err := os.Rename(outputDir, failedDir); err != nil {
return "", fmt.Errorf("rename to failed failed: %w", err)
}
return "", nil
}); moveErr != nil {
w.logger.Warn("failed to move job to failed dir", "job", task.JobName, "error", moveErr)
}
if ioErr == nil {
if after, err := telemetry.ReadProcessIO(); err == nil {
delta := telemetry.DiffIO(ioBefore, after)
w.logger.Debug("worker io stats",
"job", task.JobName,
"read_bytes", delta.ReadBytes,
"write_bytes", delta.WriteBytes)
}
}
w.logger.Info("job timing (failure)",
"job", task.JobName,
"staging_ms", stagingDuration.Milliseconds(),
"container_ms", containerDuration.Milliseconds(),
"finalize_ms", 0,
"total_ms", time.Since(stagingStart).Milliseconds(),
)
return fmt.Errorf("execution failed: %w", err)
}
containerDuration := time.Since(containerStart)
w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) {
m.ExecutionDurationMS = containerDuration.Milliseconds()
})
finalizeStart := time.Now()
// Move job to finished directory
finishedDir := filepath.Join(jobPaths.FinishedPath(), task.JobName)
// Best-effort: finalize manifest before moving the directory.
w.upsertRunManifest(outputDir, task, func(m *manifest.RunManifest) {
now := time.Now().UTC()
exitCode := 0
m.FinalizeDurationMS = time.Since(finalizeStart).Milliseconds()
if a, aerr := scanArtifacts(outputDir); aerr == nil {
m.Artifacts = a
} else {
w.logger.Warn("failed to scan artifacts", "job", task.JobName, "task_id", task.ID, "error", aerr)
}
m.MarkFinished(now, &exitCode, nil)
})
if _, moveErr := telemetry.ExecWithMetrics(
w.logger,
"finalize job",
100*time.Millisecond,
func() (string, error) {
if err := os.Rename(outputDir, finishedDir); err != nil {
return "", fmt.Errorf("rename to finished failed: %w", err)
}
return "", nil
}); moveErr != nil {
w.logger.Warn("failed to move job to finished dir", "job", task.JobName, "error", moveErr)
}
finalizeDuration := time.Since(finalizeStart)
totalDuration := time.Since(stagingStart)
var ioDelta telemetry.IOStats
if ioErr == nil {
if after, err := telemetry.ReadProcessIO(); err == nil {
ioDelta = telemetry.DiffIO(ioBefore, after)
}
}
w.logger.Info("job timing",
"job", task.JobName,
"staging_ms", stagingDuration.Milliseconds(),
"container_ms", containerDuration.Milliseconds(),
"finalize_ms", finalizeDuration.Milliseconds(),
"total_ms", totalDuration.Milliseconds(),
"io_read_bytes", ioDelta.ReadBytes,
"io_write_bytes", ioDelta.WriteBytes,
)
return nil
}
// getGPUDevicePaths returns the appropriate GPU device paths based on configuration
func (w *Worker) getGPUDevicePaths() []string {
if w != nil && w.config != nil {
if len(w.config.GPUDevices) > 0 {
return filterExistingDevicePaths(w.config.GPUDevices)
}
}
detector := w.getGPUDetector()
return filterExistingDevicePaths(detector.GetDevicePaths())
}