package worker

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/jfraeys/fetch_ml/internal/container"
	"github.com/jfraeys/fetch_ml/internal/errtypes"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/metrics"
	"github.com/jfraeys/fetch_ml/internal/queue"
)

// fetchDatasets fetches the task's datasets by shelling out to the external
// data_manager tool, skipping any dataset that is still fresh in the TTL cache.
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
	logger := w.logger.Job(ctx, task.JobName, task.ID)
	logger.Info("fetching datasets", "worker_id", w.id, "dataset_count", len(task.Datasets))

	for _, dataset := range task.Datasets {
		if w.datasetIsFresh(dataset) {
			logger.Debug("skipping cached dataset", "dataset", dataset)
			continue
		}

		// Check for cancellation before each dataset fetch.
		select {
		case <-ctx.Done():
			return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err())
		default:
		}

		logger.Info("fetching dataset", "worker_id", w.id, "dataset", dataset)

		// Create command with context for cancellation support.
		cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)

		// Validate inputs to prevent command injection.
		if !isValidName(task.JobName) || !isValidName(dataset) {
			cancel()
			return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters")
		}

		//nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated
		cmd := exec.CommandContext(cmdCtx, w.config.DataManagerPath,
			"fetch", task.JobName, dataset,
		)

		output, err := cmd.CombinedOutput()
		cancel() // Clean up context.
		if err != nil {
			return &errtypes.DataFetchError{
				Dataset: dataset,
				JobName: task.JobName,
				Err:     fmt.Errorf("command failed: %w, output: %s", err, output),
			}
		}

		logger.Info("dataset ready", "worker_id", w.id, "dataset", dataset)
		w.markDatasetFetched(dataset)
	}

	return nil
}

func resolveDatasets(task *queue.Task) []string {
	if task == nil {
		return nil
	}
	if len(task.DatasetSpecs) > 0 {
		out := make([]string, 0, len(task.DatasetSpecs))
		for _, ds := range task.DatasetSpecs {
			if ds.Name != "" {
				out = append(out, ds.Name)
			}
		}
		if len(out) > 0 {
			return out
		}
	}
	if len(task.Datasets) > 0 {
		return task.Datasets
	}
	return parseDatasets(task.Args)
}

func parseDatasets(args string) []string {
	if !strings.Contains(args, "--datasets") {
		return nil
	}
	parts := strings.Fields(args)
	for i, part := range parts {
		if part == "--datasets" && i+1 < len(parts) {
			return strings.Split(parts[i+1], ",")
		}
	}
	return nil
}

func (w *Worker) datasetIsFresh(dataset string) bool {
	w.datasetCacheMu.RLock()
	defer w.datasetCacheMu.RUnlock()
	expires, ok := w.datasetCache[dataset]
	return ok && time.Now().Before(expires)
}

func (w *Worker) markDatasetFetched(dataset string) {
	expires := time.Now().Add(w.datasetCacheTTL)
	w.datasetCacheMu.Lock()
	w.datasetCache[dataset] = expires
	w.datasetCacheMu.Unlock()
}

func (w *Worker) cancelPrewarmLocked() {
	if w.prewarmCancel != nil {
		w.prewarmCancel()
		w.prewarmCancel = nil
	}
	w.prewarmTargetID = ""
}

func (w *Worker) prewarmNextLoop() {
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return
	}
	if w.ctx == nil || w.queue == nil || w.metrics == nil {
		return
	}
	// Phase 1: Best-effort prewarm of the next queued task.
	// This must never be required for correctness.
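	// Design note: the loop below polls with a short ticker rather than
	// subscribing to queue events. A late or skipped prewarm only costs
	// startup latency, so simple polling keeps the failure modes trivial.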
	runOnce := func() {
		_, err := w.PrewarmNextOnce(w.ctx)
		if err != nil {
			w.logger.Warn("prewarm next task failed", "worker_id", w.id, "error", err)
		}
	}

	// Run once immediately so prewarm doesn't lag behind the worker loop.
	runOnce()

	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-w.ctx.Done():
			w.prewarmMu.Lock()
			w.cancelPrewarmLocked()
			w.prewarmMu.Unlock()
			return
		case <-ticker.C:
		}
		runOnce()
	}
}

func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return false, nil
	}
	if ctx == nil || w.queue == nil || w.metrics == nil {
		return false, nil
	}
	next, err := w.queue.PeekNextTask()
	if err != nil {
		return false, err
	}
	if next == nil {
		w.prewarmMu.Lock()
		w.cancelPrewarmLocked()
		w.prewarmMu.Unlock()
		return false, nil
	}
	return w.prewarmTaskOnce(ctx, next)
}

func (w *Worker) prewarmTaskOnce(ctx context.Context, next *queue.Task) (bool, error) {
	if w == nil || w.config == nil || !w.config.PrewarmEnabled {
		return false, nil
	}
	if ctx == nil || w.queue == nil || w.metrics == nil {
		return false, nil
	}
	if next == nil {
		return false, nil
	}

	w.prewarmMu.Lock()
	if w.prewarmTargetID == next.ID {
		w.prewarmMu.Unlock()
		return false, nil
	}
	w.cancelPrewarmLocked()
	prewarmCtx, cancel := context.WithCancel(ctx)
	w.prewarmCancel = cancel
	w.prewarmTargetID = next.ID
	w.prewarmStartedAt = time.Now()
	startedAt := w.prewarmStartedAt.UTC().Format(time.RFC3339Nano)

	phase := "datasets"
	dsCnt := len(resolveDatasets(next))
	snapID := next.SnapshotID
	if strings.TrimSpace(snapID) != "" {
		phase = "snapshot"
	} else if dsCnt == 0 {
		phase = "env"
	}
	_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
		WorkerID:   w.id,
		TaskID:     next.ID,
		SnapshotID: snapID,
		StartedAt:  startedAt,
		UpdatedAt:  time.Now().UTC().Format(time.RFC3339Nano),
		Phase:      phase,
		DatasetCnt: dsCnt,
		EnvHit:     w.metrics.PrewarmEnvHit.Load(),
		EnvMiss:    w.metrics.PrewarmEnvMiss.Load(),
		EnvBuilt:   w.metrics.PrewarmEnvBuilt.Load(),
		EnvTimeNs:  w.metrics.PrewarmEnvTime.Load(),
	})
	w.prewarmMu.Unlock()

	w.logger.Info("prewarm started",
		"worker_id", w.id,
		"task_id", next.ID,
		"snapshot_id", snapID,
		"phase", phase,
	)

	local := *next
	local.Datasets = resolveDatasets(&local)

	hasSnapshot := strings.TrimSpace(local.SnapshotID) != ""
	hasDatasets := w.config.AutoFetchData && len(local.Datasets) > 0
	hasEnv := false
	if w.envPool != nil && !w.config.LocalMode && strings.TrimSpace(w.config.PodmanImage) != "" {
		if local.Metadata != nil {
			depsSHA := strings.TrimSpace(local.Metadata["deps_manifest_sha256"])
			commitID := strings.TrimSpace(local.Metadata["commit_id"])
			if depsSHA != "" && commitID != "" {
				expMgr := experiment.NewManager(w.config.BasePath)
				hostWorkspace := expMgr.GetFilesPath(commitID)
				if name, err := selectDependencyManifest(hostWorkspace); err == nil && name != "" {
					if tag, err := w.envPool.WarmImageTag(depsSHA); err == nil && strings.TrimSpace(tag) != "" {
						hasEnv = true
					}
				}
			}
		}
	}

	if !hasSnapshot && !hasDatasets && !hasEnv {
		_ = w.queue.ClearWorkerPrewarmState(w.id)
		return false, nil
	}

	if hasSnapshot {
		want := ""
		if local.Metadata != nil {
			want = local.Metadata["snapshot_sha256"]
		}
		start := time.Now()
		src, err := ResolveSnapshot(
			prewarmCtx,
			w.config.DataDir,
			&w.config.SnapshotStore,
			local.SnapshotID,
			want,
			nil,
		)
		if err != nil {
			return true, err
		}
		dst := filepath.Join(w.config.BasePath, ".prewarm", "snapshots", local.ID)
		_ = os.RemoveAll(dst)
		if err := copyDir(src, dst); err != nil {
			return true, err
		}
		w.metrics.RecordPrewarmSnapshotBuilt(time.Since(start))
	}

	if hasDatasets {
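		// Dataset prewarm reuses the regular fetch path; the TTL cache that
		// fetchDatasets maintains means the real fetch for this task later
		// becomes a cache hit instead of a second download.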
		if err := w.fetchDatasets(prewarmCtx, &local); err != nil {
			return true, err
		}
	}

	_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
		WorkerID:   w.id,
		TaskID:     local.ID,
		SnapshotID: local.SnapshotID,
		StartedAt:  startedAt,
		UpdatedAt:  time.Now().UTC().Format(time.RFC3339Nano),
		Phase:      "ready",
		DatasetCnt: len(local.Datasets),
		EnvHit:     w.metrics.PrewarmEnvHit.Load(),
		EnvMiss:    w.metrics.PrewarmEnvMiss.Load(),
		EnvBuilt:   w.metrics.PrewarmEnvBuilt.Load(),
		EnvTimeNs:  w.metrics.PrewarmEnvTime.Load(),
	})
	w.logger.Info("prewarm ready",
		"worker_id", w.id,
		"task_id", local.ID,
		"snapshot_id", local.SnapshotID,
	)
	return true, nil
}

func (w *Worker) verifySnapshot(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.SnapshotID == "" {
		return nil
	}
	if err := container.ValidateJobName(task.SnapshotID); err != nil {
		return fmt.Errorf("snapshot %q: invalid snapshot_id: %w", task.SnapshotID, err)
	}
	if task.Metadata == nil {
		return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
	}
	want, err := normalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
	if err != nil {
		return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, err)
	}
	if want == "" {
		return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
	}
	path, err := ResolveSnapshot(
		ctx,
		w.config.DataDir,
		&w.config.SnapshotStore,
		task.SnapshotID,
		want,
		nil,
	)
	if err != nil {
		return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
	}
	got, err := dirOverallSHA256Hex(path)
	if err != nil {
		return fmt.Errorf("snapshot %q: checksum verification failed: %w", task.SnapshotID, err)
	}
	if got != want {
		return fmt.Errorf(
			"snapshot %q: checksum mismatch: expected %s, got %s",
			task.SnapshotID, want, got,
		)
	}
	w.logger.Job(
		ctx, task.JobName, task.ID,
	).Info("snapshot checksum verified", "snapshot_id", task.SnapshotID)
	return nil
}

func fileSHA256Hex(path string) (string, error) {
	f, err := os.Open(filepath.Clean(path))
	if err != nil {
		return "", err
	}
	defer func() { _ = f.Close() }()
	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", h.Sum(nil)), nil
}

func normalizeSHA256ChecksumHex(checksum string) (string, error) {
	checksum = strings.TrimSpace(checksum)
	checksum = strings.TrimPrefix(checksum, "sha256:")
	checksum = strings.TrimPrefix(checksum, "SHA256:")
	checksum = strings.TrimSpace(checksum)
	if checksum == "" {
		return "", nil
	}
	if len(checksum) != 64 {
		return "", fmt.Errorf("expected sha256 hex length 64, got %d", len(checksum))
	}
	if _, err := hex.DecodeString(checksum); err != nil {
		return "", fmt.Errorf("invalid sha256 hex: %w", err)
	}
	return strings.ToLower(checksum), nil
}

func dirOverallSHA256HexGo(root string) (string, error) {
	root = filepath.Clean(root)
	info, err := os.Stat(root)
	if err != nil {
		return "", err
	}
	if !info.IsDir() {
		return "", fmt.Errorf("not a directory")
	}
	var files []string
	err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if d.IsDir() {
			return nil
		}
		rel, err := filepath.Rel(root, path)
		if err != nil {
			return err
		}
		files = append(files, rel)
		return nil
	})
	if err != nil {
		return "", err
	}
	// Deterministic order.
	sort.Strings(files)
	// Hash file hashes to avoid holding all bytes.
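	// Combine per-file digests instead of file bytes: the overall digest is
	// SHA-256 over the concatenation of each file's hex digest in sorted
	// relative-path order, so the result is independent of walk order.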
	overall := sha256.New()
	for _, rel := range files {
		p := filepath.Join(root, rel)
		sum, err := fileSHA256Hex(p)
		if err != nil {
			return "", err
		}
		overall.Write([]byte(sum))
	}
	return fmt.Sprintf("%x", overall.Sum(nil)), nil
}

// dirOverallSHA256HexParallel is a parallel Go implementation for baseline comparison.
// This demonstrates best-effort Go performance before C++ optimization.
// It uses a worker pool to hash files in parallel, then combines the digests deterministically.
func dirOverallSHA256HexParallel(root string) (string, error) {
	root = filepath.Clean(root)
	info, err := os.Stat(root)
	if err != nil {
		return "", err
	}
	if !info.IsDir() {
		return "", fmt.Errorf("not a directory")
	}

	// Collect all files first.
	var files []string
	err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if d.IsDir() {
			return nil
		}
		rel, err := filepath.Rel(root, path)
		if err != nil {
			return err
		}
		files = append(files, rel)
		return nil
	})
	if err != nil {
		return "", err
	}

	// Sort for deterministic order.
	sort.Strings(files)

	// Parallel hashing with a worker pool.
	numWorkers := runtime.NumCPU()
	if numWorkers > 8 {
		numWorkers = 8 // Cap at 8 workers.
	}

	type result struct {
		index int
		hash  string
		err   error
	}
	workCh := make(chan int, len(files))
	resultCh := make(chan result, len(files))

	// Start workers.
	var wg sync.WaitGroup
	for i := 0; i < numWorkers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for idx := range workCh {
				rel := files[idx]
				p := filepath.Join(root, rel)
				hash, err := fileSHA256Hex(p)
				resultCh <- result{index: idx, hash: hash, err: err}
			}
		}()
	}

	// Send work.
	go func() {
		for i := range files {
			workCh <- i
		}
		close(workCh)
	}()

	// Collect results.
	go func() {
		wg.Wait()
		close(resultCh)
	}()

	hashes := make([]string, len(files))
	for r := range resultCh {
		if r.err != nil {
			return "", r.err
		}
		hashes[r.index] = r.hash
	}

	// Combine hashes deterministically.
	overall := sha256.New()
	for _, h := range hashes {
		overall.Write([]byte(h))
	}
	return fmt.Sprintf("%x", overall.Sum(nil)), nil
}

func (w *Worker) verifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if len(task.DatasetSpecs) == 0 {
		return nil
	}
	logger := w.logger.Job(ctx, task.JobName, task.ID)
	for _, ds := range task.DatasetSpecs {
		want, err := normalizeSHA256ChecksumHex(ds.Checksum)
		if err != nil {
			return fmt.Errorf("dataset %q: invalid checksum: %w", ds.Name, err)
		}
		if want == "" {
			continue
		}
		if err := container.ValidateJobName(ds.Name); err != nil {
			return fmt.Errorf("dataset %q: invalid name: %w", ds.Name, err)
		}
		path := filepath.Join(w.config.DataDir, ds.Name)
		got, err := dirOverallSHA256HexGo(path)
		if err != nil {
			return fmt.Errorf("dataset %q: checksum verification failed: %w", ds.Name, err)
		}
		if got != want {
			return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", ds.Name, want, got)
		}
		logger.Info("dataset checksum verified", "dataset", ds.Name)
	}
	return nil
}

func computeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
	if task == nil {
		return nil, fmt.Errorf("task is nil")
	}
	out := map[string]string{}
	if task.SnapshotID != "" {
		out["snapshot_id"] = task.SnapshotID
	}
	datasets := resolveDatasets(task)
	if len(datasets) > 0 {
		out["datasets"] = strings.Join(datasets, ",")
	}
	if len(task.DatasetSpecs) > 0 {
		b, err := json.Marshal(task.DatasetSpecs)
		if err != nil {
			return nil, fmt.Errorf("marshal dataset_specs: %w", err)
		}
		out["dataset_specs"] = string(b)
	}
	if task.Metadata == nil {
		return out, nil
	}
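	// Everything below is derived from the experiment commit: the manifest's
	// overall SHA plus the dependency manifest's name and content hash. These
	// lookups are best-effort; on failure the fields are simply left unset.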
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return out, nil
	}
	expMgr := experiment.NewManager(basePath)
	manifest, err := expMgr.ReadManifest(commitID)
	if err == nil && manifest != nil && manifest.OverallSHA != "" {
		out["experiment_manifest_overall_sha"] = manifest.OverallSHA
	}
	filesPath := expMgr.GetFilesPath(commitID)
	depName, err := selectDependencyManifest(filesPath)
	if err == nil && depName != "" {
		depPath := filepath.Join(filesPath, depName)
		sha, err := fileSHA256Hex(depPath)
		if err == nil && sha != "" {
			out["deps_manifest_name"] = depName
			out["deps_manifest_sha256"] = sha
		}
	}
	return out, nil
}

func (w *Worker) recordTaskProvenance(ctx context.Context, task *queue.Task) {
	if task == nil {
		return
	}
	prov, err := computeTaskProvenance(w.config.BasePath, task)
	if err != nil {
		w.logger.Job(ctx, task.JobName, task.ID).Debug("provenance compute failed", "error", err)
		return
	}
	if len(prov) == 0 {
		return
	}
	if task.Metadata == nil {
		task.Metadata = map[string]string{}
	}
	for k, v := range prov {
		if v == "" {
			continue
		}
		// Phase 1: best-effort only; do not error if overwriting.
		task.Metadata[k] = v
	}
}

func (w *Worker) enforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	if task == nil {
		return fmt.Errorf("task is nil")
	}
	if task.Metadata == nil {
		return fmt.Errorf("missing task metadata")
	}
	commitID := task.Metadata["commit_id"]
	if commitID == "" {
		return fmt.Errorf("missing commit_id")
	}
	current, err := computeTaskProvenance(w.config.BasePath, task)
	if err != nil {
		return err
	}

	snapshotCur := ""
	if task.SnapshotID != "" {
		want := ""
		if task.Metadata != nil {
			want = task.Metadata["snapshot_sha256"]
		}
		wantNorm, nerr := normalizeSHA256ChecksumHex(want)
		if nerr != nil {
			if w.config != nil && w.config.ProvenanceBestEffort {
				w.logger.Warn("invalid snapshot_sha256; unable to compute current snapshot provenance",
					"snapshot_id", task.SnapshotID, "error", nerr)
			} else {
				return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
			}
		} else if wantNorm != "" {
			resolved, err := ResolveSnapshot(
				ctx,
				w.config.DataDir,
				&w.config.SnapshotStore,
				task.SnapshotID,
				wantNorm,
				nil,
			)
			if err != nil {
				if w.config != nil && w.config.ProvenanceBestEffort {
					w.logger.Warn("snapshot resolve failed; unable to compute current snapshot provenance",
						"snapshot_id", task.SnapshotID, "error", err)
				} else {
					return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
				}
			} else {
				sha, err := dirOverallSHA256HexGo(resolved)
				if err == nil {
					snapshotCur = sha
				} else if w.config != nil && w.config.ProvenanceBestEffort {
					w.logger.Warn("snapshot hash failed; unable to compute current snapshot provenance",
						"snapshot_id", task.SnapshotID, "error", err)
				} else {
					return fmt.Errorf("snapshot %q: checksum computation failed: %w", task.SnapshotID, err)
				}
			}
		}
		if snapshotCur == "" && w.config != nil && w.config.ProvenanceBestEffort {
			// Best-effort fallback: if the caller didn't provide snapshot_sha256,
			// compute from the local snapshot directory if it exists.
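			// The fallback only hashes what is already on disk under
			// <DataDir>/snapshots/<id>; it never downloads, so a missing
			// directory simply leaves snapshotCur empty.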
			localPath := filepath.Join(w.config.DataDir, "snapshots", strings.TrimSpace(task.SnapshotID))
			if sha, err := dirOverallSHA256HexGo(localPath); err == nil {
				snapshotCur = sha
			}
		}
	}

	logger := w.logger.Job(ctx, task.JobName, task.ID)

	type requiredField struct {
		Key string
		Cur string
	}
	required := []requiredField{
		{Key: "experiment_manifest_overall_sha", Cur: current["experiment_manifest_overall_sha"]},
		{Key: "deps_manifest_name", Cur: current["deps_manifest_name"]},
		{Key: "deps_manifest_sha256", Cur: current["deps_manifest_sha256"]},
	}
	if task.SnapshotID != "" {
		required = append(required, requiredField{Key: "snapshot_sha256", Cur: snapshotCur})
	}
	for _, f := range required {
		want := strings.TrimSpace(task.Metadata[f.Key])
		if f.Key == "snapshot_sha256" {
			norm, nerr := normalizeSHA256ChecksumHex(want)
			if nerr != nil {
				if w.config != nil && w.config.ProvenanceBestEffort {
					logger.Warn("invalid snapshot_sha256; continuing due to best-effort mode",
						"snapshot_id", task.SnapshotID, "error", nerr)
					want = ""
				} else {
					return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
				}
			} else {
				want = norm
			}
		}
		if want == "" {
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("missing provenance field; continuing due to best-effort mode", "field", f.Key)
				if f.Cur != "" {
					if f.Key == "snapshot_sha256" {
						task.Metadata[f.Key] = "sha256:" + f.Cur
					} else {
						task.Metadata[f.Key] = f.Cur
					}
				}
				continue
			}
			return fmt.Errorf("missing provenance field: %s", f.Key)
		}
		if f.Cur == "" {
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("unable to compute provenance field; continuing due to best-effort mode", "field", f.Key)
				continue
			}
			return fmt.Errorf("unable to compute provenance field: %s", f.Key)
		}
		if want != f.Cur {
			if w.config != nil && w.config.ProvenanceBestEffort {
				logger.Warn("provenance mismatch; continuing due to best-effort mode",
					"field", f.Key, "expected", want, "current", f.Cur)
				if f.Key == "snapshot_sha256" {
					task.Metadata[f.Key] = "sha256:" + f.Cur
				} else {
					task.Metadata[f.Key] = f.Cur
				}
				continue
			}
			return fmt.Errorf("provenance mismatch for %s: expected %s, got %s", f.Key, want, f.Cur)
		}
	}
	return nil
}

func selectDependencyManifest(filesPath string) (string, error) {
	if filesPath == "" {
		return "", fmt.Errorf("missing files path")
	}
	candidates := []string{
		"environment.yml",
		"environment.yaml",
		"poetry.lock",
		"pyproject.toml",
		"requirements.txt",
	}
	for _, name := range candidates {
		p := filepath.Join(filesPath, name)
		if _, err := os.Stat(p); err == nil {
			if name == "poetry.lock" {
				pyprojectPath := filepath.Join(filesPath, "pyproject.toml")
				if _, err := os.Stat(pyprojectPath); err != nil {
					return "", fmt.Errorf(
						"poetry.lock found but pyproject.toml missing (required for Poetry projects)")
				}
			}
			return name, nil
		}
	}
	return "", fmt.Errorf(
		"missing dependency manifest (supported: environment.yml, environment.yaml, " +
			"poetry.lock, pyproject.toml, requirements.txt)")
}

// Exported wrappers for tests under tests/.

func ResolveDatasets(task *queue.Task) []string { return resolveDatasets(task) }

func SelectDependencyManifest(filesPath string) (string, error) {
	return selectDependencyManifest(filesPath)
}

func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
	return normalizeSHA256ChecksumHex(checksum)
}

func DirOverallSHA256Hex(root string) (string, error) { return dirOverallSHA256Hex(root) }

// DirOverallSHA256HexParallel is an exported wrapper for testing/benchmarking.
func DirOverallSHA256HexParallel(root string) (string, error) {
	return dirOverallSHA256HexParallel(root)
}

func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
	return computeTaskProvenance(basePath, task)
}

func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
	return w.enforceTaskProvenance(ctx, task)
}

func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
	return w.verifyDatasetSpecs(ctx, task)
}

func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
	return w.verifySnapshot(ctx, task)
}

func NewTestWorker(cfg *Config) *Worker {
	baseLogger := logging.NewLogger(slog.LevelInfo, false)
	ctx := logging.EnsureTrace(context.Background())
	logger := baseLogger.Component(ctx, "worker")
	if cfg == nil {
		cfg = &Config{}
	}
	if cfg.DatasetCacheTTL == 0 {
		cfg.DatasetCacheTTL = datasetCacheDefaultTTL
	}
	return &Worker{
		id:              cfg.WorkerID,
		config:          cfg,
		logger:          logger,
		datasetCache:    make(map[string]time.Time),
		datasetCacheTTL: cfg.DatasetCacheTTL,
	}
}

func NewTestWorkerWithQueue(cfg *Config, tq queue.Backend) *Worker {
	baseLogger := logging.NewLogger(slog.LevelInfo, false)
	ctx := logging.EnsureTrace(context.Background())
	ctx, cancel := context.WithCancel(ctx)
	logger := baseLogger.Component(ctx, "worker")
	if cfg == nil {
		cfg = &Config{}
	}
	if cfg.DatasetCacheTTL == 0 {
		cfg.DatasetCacheTTL = datasetCacheDefaultTTL
	}
	return &Worker{
		id:              cfg.WorkerID,
		config:          cfg,
		logger:          logger,
		queue:           tq,
		metrics:         &metrics.Metrics{},
		ctx:             ctx,
		cancel:          cancel,
		running:         make(map[string]context.CancelFunc),
		datasetCache:    make(map[string]time.Time),
		datasetCacheTTL: cfg.DatasetCacheTTL,
	}
}

func NewTestWorkerWithJupyter(cfg *Config, tq queue.Backend, jm JupyterManager) *Worker {
	w := NewTestWorkerWithQueue(cfg, tq)
	w.jupyter = jm
	return w
}
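// Usage sketch for the test constructors above (memQueue is a hypothetical
// stand-in for any queue.Backend implementation, not part of this package):
//
//	cfg := &worker.Config{WorkerID: "w1", PrewarmEnabled: true}
//	w := worker.NewTestWorkerWithQueue(cfg, memQueue)
//	warmed, err := w.PrewarmNextOnce(context.Background())
//	_, _ = warmed, err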