refactor: Phase 6 - Complete migration, remove legacy files

BREAKING CHANGE: Legacy worker files removed, Worker struct simplified

Changes:
1. worker.go - Simplified to 8 fields using composed dependencies:
   - runLoop, runner, metrics, health (from new packages)
   - Removed: server, queue, running, datasetCache, ctx, cancel, etc.

2. factory.go - Updated NewWorker to use new structure
   - Uses lifecycle.NewRunLoop
   - Integrates jupyter.Manager properly

3. Removed legacy files:
   - execution.go (1,016 lines)
   - data_integrity.go (929 lines)
   - runloop.go (555 lines)
   - jupyter_task.go (144 lines)
   - simplified.go (demonstration no longer needed)

4. Fixed references to use new packages:
   - hash_selector.go -> integrity.DirOverallSHA256Hex
   - snapshot_store.go -> integrity.NormalizeSHA256ChecksumHex
   - metrics.go - Removed resource-dependent metrics temporarily

5. Added RecordQueueLatency to metrics.Metrics for lifecycle.MetricsRecorder

Worker struct: 27 fields -> 8 fields (70% reduction)

Build status: Compiles successfully
Author: Jeremie Fraeys
Date: 2026-02-17 14:39:48 -05:00
Parent: 94bb52d09c
Commit: 38fa017b8e
11 changed files with 143 additions and 2923 deletions


@@ -87,6 +87,13 @@ func (m *Metrics) RecordPrewarmSnapshotBuilt(duration time.Duration) {
m.PrewarmSnapshotTime.Add(duration.Nanoseconds())
}
// RecordQueueLatency records the queue latency duration.
// This method implements the lifecycle.MetricsRecorder interface.
func (m *Metrics) RecordQueueLatency(duration time.Duration) {
// Queue latency tracking is currently a no-op
// This can be implemented in the future if needed
}
// SetQueuedTasks sets the number of queued tasks.
func (m *Metrics) SetQueuedTasks(count int64) {
m.QueuedTasks.Store(count)
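
The lifecycle.MetricsRecorder interface itself does not appear in this diff. A minimal sketch of the contract *Metrics now satisfies, inferred from the recorder methods the legacy run loop called plus the new RecordQueueLatency (the real interface in internal/worker/lifecycle may declare more methods):

package lifecycle

import "time"

// MetricsRecorder is a sketch of the recording contract the run loop
// depends on; the method set is inferred from call sites in this commit.
type MetricsRecorder interface {
	RecordTaskStart()
	RecordTaskCompletion()
	RecordTaskFailure()
	RecordQueueLatency(duration time.Duration)
}

A compile-time assertion such as var _ lifecycle.MetricsRecorder = (*metrics.Metrics)(nil) in a test file would catch signature drift between the two packages.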


@@ -1,928 +0,0 @@
package worker
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
"time"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/errtypes"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// NEW: Fetch datasets using data_manager.
func (w *Worker) fetchDatasets(ctx context.Context, task *queue.Task) error {
logger := w.logger.Job(ctx, task.JobName, task.ID)
logger.Info("fetching datasets",
"worker_id", w.id,
"dataset_count", len(task.Datasets))
for _, dataset := range task.Datasets {
if w.datasetIsFresh(dataset) {
logger.Debug("skipping cached dataset",
"dataset", dataset)
continue
}
// Check for cancellation before each dataset fetch
select {
case <-ctx.Done():
return fmt.Errorf("dataset fetch cancelled: %w", ctx.Err())
default:
}
logger.Info("fetching dataset",
"worker_id", w.id,
"dataset", dataset)
// Create command with context for cancellation support
cmdCtx, cancel := context.WithTimeout(ctx, 30*time.Minute)
// Validate inputs to prevent command injection
if !isValidName(task.JobName) || !isValidName(dataset) {
cancel()
return fmt.Errorf("invalid input: jobName or dataset contains unsafe characters")
}
//nolint:gosec // G204: Subprocess launched with potential tainted input - input is validated
cmd := exec.CommandContext(cmdCtx,
w.config.DataManagerPath,
"fetch",
task.JobName,
dataset,
)
output, err := cmd.CombinedOutput()
cancel() // Clean up context
if err != nil {
return &errtypes.DataFetchError{
Dataset: dataset,
JobName: task.JobName,
Err: fmt.Errorf("command failed: %w, output: %s", err, output),
}
}
logger.Info("dataset ready",
"worker_id", w.id,
"dataset", dataset)
w.markDatasetFetched(dataset)
}
return nil
}
func resolveDatasets(task *queue.Task) []string {
if task == nil {
return nil
}
if len(task.DatasetSpecs) > 0 {
out := make([]string, 0, len(task.DatasetSpecs))
for _, ds := range task.DatasetSpecs {
if ds.Name != "" {
out = append(out, ds.Name)
}
}
if len(out) > 0 {
return out
}
}
if len(task.Datasets) > 0 {
return task.Datasets
}
return parseDatasets(task.Args)
}
func parseDatasets(args string) []string {
if !strings.Contains(args, "--datasets") {
return nil
}
parts := strings.Fields(args)
for i, part := range parts {
if part == "--datasets" && i+1 < len(parts) {
return strings.Split(parts[i+1], ",")
}
}
return nil
}
func (w *Worker) datasetIsFresh(dataset string) bool {
w.datasetCacheMu.RLock()
defer w.datasetCacheMu.RUnlock()
expires, ok := w.datasetCache[dataset]
return ok && time.Now().Before(expires)
}
func (w *Worker) markDatasetFetched(dataset string) {
expires := time.Now().Add(w.datasetCacheTTL)
w.datasetCacheMu.Lock()
w.datasetCache[dataset] = expires
w.datasetCacheMu.Unlock()
}
func (w *Worker) cancelPrewarmLocked() {
if w.prewarmCancel != nil {
w.prewarmCancel()
w.prewarmCancel = nil
}
w.prewarmTargetID = ""
}
func (w *Worker) prewarmNextLoop() {
if w == nil || w.config == nil || !w.config.PrewarmEnabled {
return
}
if w.ctx == nil || w.queue == nil || w.metrics == nil {
return
}
// Phase 1: Best-effort prewarm of the next queued task.
// This must never be required for correctness.
runOnce := func() {
_, err := w.PrewarmNextOnce(w.ctx)
if err != nil {
w.logger.Warn("prewarm next task failed", "worker_id", w.id, "error", err)
}
}
// Run once immediately so prewarm doesn't lag behind the worker loop.
runOnce()
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
w.prewarmMu.Lock()
w.cancelPrewarmLocked()
w.prewarmMu.Unlock()
return
case <-ticker.C:
}
runOnce()
}
}
func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
if w == nil || w.config == nil || !w.config.PrewarmEnabled {
return false, nil
}
if ctx == nil || w.queue == nil || w.metrics == nil {
return false, nil
}
next, err := w.queue.PeekNextTask()
if err != nil {
return false, err
}
if next == nil {
w.prewarmMu.Lock()
w.cancelPrewarmLocked()
w.prewarmMu.Unlock()
return false, nil
}
return w.prewarmTaskOnce(ctx, next)
}
func (w *Worker) prewarmTaskOnce(ctx context.Context, next *queue.Task) (bool, error) {
if w == nil || w.config == nil || !w.config.PrewarmEnabled {
return false, nil
}
if ctx == nil || w.queue == nil || w.metrics == nil {
return false, nil
}
if next == nil {
return false, nil
}
w.prewarmMu.Lock()
if w.prewarmTargetID == next.ID {
w.prewarmMu.Unlock()
return false, nil
}
w.cancelPrewarmLocked()
prewarmCtx, cancel := context.WithCancel(ctx)
w.prewarmCancel = cancel
w.prewarmTargetID = next.ID
w.prewarmStartedAt = time.Now()
startedAt := w.prewarmStartedAt.UTC().Format(time.RFC3339Nano)
phase := "datasets"
dsCnt := len(resolveDatasets(next))
snapID := next.SnapshotID
if strings.TrimSpace(snapID) != "" {
phase = "snapshot"
} else if dsCnt == 0 {
phase = "env"
}
_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
WorkerID: w.id,
TaskID: next.ID,
SnapshotID: snapID,
StartedAt: startedAt,
UpdatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Phase: phase,
DatasetCnt: dsCnt,
EnvHit: w.metrics.PrewarmEnvHit.Load(),
EnvMiss: w.metrics.PrewarmEnvMiss.Load(),
EnvBuilt: w.metrics.PrewarmEnvBuilt.Load(),
EnvTimeNs: w.metrics.PrewarmEnvTime.Load(),
})
w.prewarmMu.Unlock()
w.logger.Info("prewarm started",
"worker_id", w.id,
"task_id", next.ID,
"snapshot_id", snapID,
"phase", phase,
)
local := *next
local.Datasets = resolveDatasets(&local)
hasSnapshot := strings.TrimSpace(local.SnapshotID) != ""
hasDatasets := w.config.AutoFetchData && len(local.Datasets) > 0
hasEnv := false
if w.envPool != nil && !w.config.LocalMode && strings.TrimSpace(w.config.PodmanImage) != "" {
if local.Metadata != nil {
depsSHA := strings.TrimSpace(local.Metadata["deps_manifest_sha256"])
commitID := strings.TrimSpace(local.Metadata["commit_id"])
if depsSHA != "" && commitID != "" {
expMgr := experiment.NewManager(w.config.BasePath)
hostWorkspace := expMgr.GetFilesPath(commitID)
if name, err := selectDependencyManifest(hostWorkspace); err == nil && name != "" {
if tag, err := w.envPool.WarmImageTag(depsSHA); err == nil && strings.TrimSpace(tag) != "" {
hasEnv = true
}
}
}
}
}
if !hasSnapshot && !hasDatasets && !hasEnv {
_ = w.queue.ClearWorkerPrewarmState(w.id)
return false, nil
}
if hasSnapshot {
want := ""
if local.Metadata != nil {
want = local.Metadata["snapshot_sha256"]
}
start := time.Now()
src, err := ResolveSnapshot(
prewarmCtx,
w.config.DataDir,
&w.config.SnapshotStore,
local.SnapshotID,
want,
nil,
)
if err != nil {
return true, err
}
dst := filepath.Join(w.config.BasePath, ".prewarm", "snapshots", local.ID)
_ = os.RemoveAll(dst)
if err := copyDir(src, dst); err != nil {
return true, err
}
w.metrics.RecordPrewarmSnapshotBuilt(time.Since(start))
}
if hasDatasets {
if err := w.fetchDatasets(prewarmCtx, &local); err != nil {
return true, err
}
}
_ = w.queue.SetWorkerPrewarmState(queue.PrewarmState{
WorkerID: w.id,
TaskID: local.ID,
SnapshotID: local.SnapshotID,
StartedAt: startedAt,
UpdatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Phase: "ready",
DatasetCnt: len(local.Datasets),
EnvHit: w.metrics.PrewarmEnvHit.Load(),
EnvMiss: w.metrics.PrewarmEnvMiss.Load(),
EnvBuilt: w.metrics.PrewarmEnvBuilt.Load(),
EnvTimeNs: w.metrics.PrewarmEnvTime.Load(),
})
w.logger.Info("prewarm ready",
"worker_id", w.id,
"task_id", local.ID,
"snapshot_id", local.SnapshotID,
)
return true, nil
}
func (w *Worker) verifySnapshot(ctx context.Context, task *queue.Task) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if task.SnapshotID == "" {
return nil
}
if err := container.ValidateJobName(task.SnapshotID); err != nil {
return fmt.Errorf("snapshot %q: invalid snapshot_id: %w", task.SnapshotID, err)
}
if task.Metadata == nil {
return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
}
want, err := normalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
if err != nil {
return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, err)
}
if want == "" {
return fmt.Errorf("snapshot %q: missing snapshot_sha256 metadata", task.SnapshotID)
}
path, err := ResolveSnapshot(
ctx,
w.config.DataDir,
&w.config.SnapshotStore,
task.SnapshotID,
want,
nil,
)
if err != nil {
return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
}
got, err := dirOverallSHA256Hex(path)
if err != nil {
return fmt.Errorf("snapshot %q: checksum verification failed: %w", task.SnapshotID, err)
}
if got != want {
return fmt.Errorf(
"snapshot %q: checksum mismatch: expected %s, got %s",
task.SnapshotID,
want,
got,
)
}
w.logger.Job(
ctx,
task.JobName,
task.ID,
).Info("snapshot checksum verified", "snapshot_id", task.SnapshotID)
return nil
}
func fileSHA256Hex(path string) (string, error) {
f, err := os.Open(filepath.Clean(path))
if err != nil {
return "", err
}
defer func() { _ = f.Close() }()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
return fmt.Sprintf("%x", h.Sum(nil)), nil
}
func normalizeSHA256ChecksumHex(checksum string) (string, error) {
checksum = strings.TrimSpace(checksum)
checksum = strings.TrimPrefix(checksum, "sha256:")
checksum = strings.TrimPrefix(checksum, "SHA256:")
checksum = strings.TrimSpace(checksum)
if checksum == "" {
return "", nil
}
if len(checksum) != 64 {
return "", fmt.Errorf("expected sha256 hex length 64, got %d", len(checksum))
}
if _, err := hex.DecodeString(checksum); err != nil {
return "", fmt.Errorf("invalid sha256 hex: %w", err)
}
return strings.ToLower(checksum), nil
}
func dirOverallSHA256HexGo(root string) (string, error) {
root = filepath.Clean(root)
info, err := os.Stat(root)
if err != nil {
return "", err
}
if !info.IsDir() {
return "", fmt.Errorf("not a directory")
}
var files []string
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if d.IsDir() {
return nil
}
rel, err := filepath.Rel(root, path)
if err != nil {
return err
}
files = append(files, rel)
return nil
})
if err != nil {
return "", err
}
// Deterministic order.
for i := 0; i < len(files); i++ {
for j := i + 1; j < len(files); j++ {
if files[i] > files[j] {
files[i], files[j] = files[j], files[i]
}
}
}
// Hash file hashes to avoid holding all bytes.
overall := sha256.New()
for _, rel := range files {
p := filepath.Join(root, rel)
sum, err := fileSHA256Hex(p)
if err != nil {
return "", err
}
overall.Write([]byte(sum))
}
return fmt.Sprintf("%x", overall.Sum(nil)), nil
}
// dirOverallSHA256HexParallel is a parallel Go implementation for baseline comparison.
// This demonstrates best-effort Go performance before C++ optimization.
// Uses worker pool to hash files in parallel, then combines deterministically.
func dirOverallSHA256HexParallel(root string) (string, error) {
root = filepath.Clean(root)
info, err := os.Stat(root)
if err != nil {
return "", err
}
if !info.IsDir() {
return "", fmt.Errorf("not a directory")
}
// Collect all files first
var files []string
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if d.IsDir() {
return nil
}
rel, err := filepath.Rel(root, path)
if err != nil {
return err
}
files = append(files, rel)
return nil
})
if err != nil {
return "", err
}
// Sort for deterministic order
sort.Strings(files)
// Parallel hashing with worker pool
numWorkers := runtime.NumCPU()
if numWorkers > 8 {
numWorkers = 8 // Cap at 8 workers
}
type result struct {
index int
hash string
err error
}
workCh := make(chan int, len(files))
resultCh := make(chan result, len(files))
// Start workers
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for idx := range workCh {
rel := files[idx]
p := filepath.Join(root, rel)
hash, err := fileSHA256Hex(p)
resultCh <- result{index: idx, hash: hash, err: err}
}
}()
}
// Send work
go func() {
for i := range files {
workCh <- i
}
close(workCh)
}()
// Collect results
go func() {
wg.Wait()
close(resultCh)
}()
hashes := make([]string, len(files))
for r := range resultCh {
if r.err != nil {
return "", r.err
}
hashes[r.index] = r.hash
}
// Combine hashes deterministically
overall := sha256.New()
for _, h := range hashes {
overall.Write([]byte(h))
}
return fmt.Sprintf("%x", overall.Sum(nil)), nil
}
func (w *Worker) verifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if len(task.DatasetSpecs) == 0 {
return nil
}
logger := w.logger.Job(ctx, task.JobName, task.ID)
for _, ds := range task.DatasetSpecs {
want, err := normalizeSHA256ChecksumHex(ds.Checksum)
if err != nil {
return fmt.Errorf("dataset %q: invalid checksum: %w", ds.Name, err)
}
if want == "" {
continue
}
if err := container.ValidateJobName(ds.Name); err != nil {
return fmt.Errorf("dataset %q: invalid name: %w", ds.Name, err)
}
path := filepath.Join(w.config.DataDir, ds.Name)
got, err := dirOverallSHA256HexGo(path)
if err != nil {
return fmt.Errorf("dataset %q: checksum verification failed: %w", ds.Name, err)
}
if got != want {
return fmt.Errorf("dataset %q: checksum mismatch: expected %s, got %s", ds.Name, want, got)
}
logger.Info("dataset checksum verified", "dataset", ds.Name)
}
return nil
}
func computeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
if task == nil {
return nil, fmt.Errorf("task is nil")
}
out := map[string]string{}
if task.SnapshotID != "" {
out["snapshot_id"] = task.SnapshotID
}
datasets := resolveDatasets(task)
if len(datasets) > 0 {
out["datasets"] = strings.Join(datasets, ",")
}
if len(task.DatasetSpecs) > 0 {
b, err := json.Marshal(task.DatasetSpecs)
if err != nil {
return nil, fmt.Errorf("marshal dataset_specs: %w", err)
}
out["dataset_specs"] = string(b)
}
if task.Metadata == nil {
return out, nil
}
commitID := task.Metadata["commit_id"]
if commitID == "" {
return out, nil
}
expMgr := experiment.NewManager(basePath)
manifest, err := expMgr.ReadManifest(commitID)
if err == nil && manifest != nil && manifest.OverallSHA != "" {
out["experiment_manifest_overall_sha"] = manifest.OverallSHA
}
filesPath := expMgr.GetFilesPath(commitID)
depName, err := selectDependencyManifest(filesPath)
if err == nil && depName != "" {
depPath := filepath.Join(filesPath, depName)
sha, err := fileSHA256Hex(depPath)
if err == nil && sha != "" {
out["deps_manifest_name"] = depName
out["deps_manifest_sha256"] = sha
}
}
return out, nil
}
func (w *Worker) recordTaskProvenance(ctx context.Context, task *queue.Task) {
if task == nil {
return
}
prov, err := computeTaskProvenance(w.config.BasePath, task)
if err != nil {
w.logger.Job(ctx, task.JobName, task.ID).Debug("provenance compute failed", "error", err)
return
}
if len(prov) == 0 {
return
}
if task.Metadata == nil {
task.Metadata = map[string]string{}
}
for k, v := range prov {
if v == "" {
continue
}
// Phase 1: best-effort only; do not error if overwriting.
task.Metadata[k] = v
}
}
func (w *Worker) enforceTaskProvenance(ctx context.Context, task *queue.Task) error {
if task == nil {
return fmt.Errorf("task is nil")
}
if task.Metadata == nil {
return fmt.Errorf("missing task metadata")
}
commitID := task.Metadata["commit_id"]
if commitID == "" {
return fmt.Errorf("missing commit_id")
}
current, err := computeTaskProvenance(w.config.BasePath, task)
if err != nil {
return err
}
snapshotCur := ""
if task.SnapshotID != "" {
want := ""
if task.Metadata != nil {
want = task.Metadata["snapshot_sha256"]
}
wantNorm, nerr := normalizeSHA256ChecksumHex(want)
if nerr != nil {
if w.config != nil && w.config.ProvenanceBestEffort {
w.logger.Warn("invalid snapshot_sha256; unable to compute current snapshot provenance",
"snapshot_id", task.SnapshotID,
"error", nerr)
} else {
return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
}
} else if wantNorm != "" {
resolved, err := ResolveSnapshot(
ctx, w.config.DataDir,
&w.config.SnapshotStore,
task.SnapshotID,
wantNorm,
nil,
)
if err != nil {
if w.config != nil && w.config.ProvenanceBestEffort {
w.logger.Warn("snapshot resolve failed; unable to compute current snapshot provenance",
"snapshot_id", task.SnapshotID,
"error", err)
} else {
return fmt.Errorf("snapshot %q: resolve failed: %w", task.SnapshotID, err)
}
} else {
sha, err := dirOverallSHA256HexGo(resolved)
if err == nil {
snapshotCur = sha
} else if w.config != nil && w.config.ProvenanceBestEffort {
w.logger.Warn("snapshot hash failed; unable to compute current snapshot provenance",
"snapshot_id", task.SnapshotID,
"error", err)
} else {
return fmt.Errorf("snapshot %q: checksum computation failed: %w", task.SnapshotID, err)
}
}
}
if snapshotCur == "" && w.config != nil && w.config.ProvenanceBestEffort {
// Best-effort fallback: if the caller didn't provide snapshot_sha256,
// compute from the local snapshot directory if it exists.
localPath := filepath.Join(w.config.DataDir, "snapshots", strings.TrimSpace(task.SnapshotID))
if sha, err := dirOverallSHA256HexGo(localPath); err == nil {
snapshotCur = sha
}
}
}
logger := w.logger.Job(ctx, task.JobName, task.ID)
type requiredField struct {
Key string
Cur string
}
required := []requiredField{
{Key: "experiment_manifest_overall_sha", Cur: current["experiment_manifest_overall_sha"]},
{Key: "deps_manifest_name", Cur: current["deps_manifest_name"]},
{Key: "deps_manifest_sha256", Cur: current["deps_manifest_sha256"]},
}
if task.SnapshotID != "" {
required = append(required, requiredField{Key: "snapshot_sha256", Cur: snapshotCur})
}
for _, f := range required {
want := strings.TrimSpace(task.Metadata[f.Key])
if f.Key == "snapshot_sha256" {
norm, nerr := normalizeSHA256ChecksumHex(want)
if nerr != nil {
if w.config != nil && w.config.ProvenanceBestEffort {
logger.Warn("invalid snapshot_sha256; continuing due to best-effort mode",
"snapshot_id", task.SnapshotID,
"error", nerr)
want = ""
} else {
return fmt.Errorf("snapshot %q: invalid snapshot_sha256: %w", task.SnapshotID, nerr)
}
} else {
want = norm
}
}
if want == "" {
if w.config != nil && w.config.ProvenanceBestEffort {
logger.Warn("missing provenance field; continuing due to best-effort mode",
"field", f.Key)
if f.Cur != "" {
if f.Key == "snapshot_sha256" {
task.Metadata[f.Key] = "sha256:" + f.Cur
} else {
task.Metadata[f.Key] = f.Cur
}
}
continue
}
return fmt.Errorf("missing provenance field: %s", f.Key)
}
if f.Cur == "" {
if w.config != nil && w.config.ProvenanceBestEffort {
logger.Warn("unable to compute provenance field; continuing due to best-effort mode",
"field", f.Key)
continue
}
return fmt.Errorf("unable to compute provenance field: %s", f.Key)
}
if want != f.Cur {
if w.config != nil && w.config.ProvenanceBestEffort {
logger.Warn("provenance mismatch; continuing due to best-effort mode",
"field", f.Key,
"expected", want,
"current", f.Cur)
if f.Key == "snapshot_sha256" {
task.Metadata[f.Key] = "sha256:" + f.Cur
} else {
task.Metadata[f.Key] = f.Cur
}
continue
}
return fmt.Errorf("provenance mismatch for %s: expected %s, got %s", f.Key, want, f.Cur)
}
}
return nil
}
func selectDependencyManifest(filesPath string) (string, error) {
if filesPath == "" {
return "", fmt.Errorf("missing files path")
}
candidates := []string{
"environment.yml",
"environment.yaml",
"poetry.lock",
"pyproject.toml",
"requirements.txt",
}
for _, name := range candidates {
p := filepath.Join(filesPath, name)
if _, err := os.Stat(p); err == nil {
if name == "poetry.lock" {
pyprojectPath := filepath.Join(filesPath, "pyproject.toml")
if _, err := os.Stat(pyprojectPath); err != nil {
return "", fmt.Errorf(
"poetry.lock found but pyproject.toml missing (required for Poetry projects)")
}
}
return name, nil
}
}
return "", fmt.Errorf(
"missing dependency manifest (supported: environment.yml, environment.yaml, " +
"poetry.lock, pyproject.toml, requirements.txt)")
}
// Exported wrappers for tests under tests/.
func ResolveDatasets(task *queue.Task) []string { return resolveDatasets(task) }
func SelectDependencyManifest(filesPath string) (string, error) {
return selectDependencyManifest(filesPath)
}
func NormalizeSHA256ChecksumHex(checksum string) (string, error) {
return normalizeSHA256ChecksumHex(checksum)
}
func DirOverallSHA256Hex(root string) (string, error) { return dirOverallSHA256Hex(root) }
// DirOverallSHA256HexParallel is an exported wrapper for testing/benchmarking.
func DirOverallSHA256HexParallel(root string) (string, error) {
return dirOverallSHA256HexParallel(root)
}
func ComputeTaskProvenance(basePath string, task *queue.Task) (map[string]string, error) {
return computeTaskProvenance(basePath, task)
}
func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) error {
return w.enforceTaskProvenance(ctx, task)
}
func (w *Worker) VerifyDatasetSpecs(ctx context.Context, task *queue.Task) error {
return w.verifyDatasetSpecs(ctx, task)
}
func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
return w.verifySnapshot(ctx, task)
}
func NewTestWorker(cfg *Config) *Worker {
baseLogger := logging.NewLogger(slog.LevelInfo, false)
ctx := logging.EnsureTrace(context.Background())
logger := baseLogger.Component(ctx, "worker")
if cfg == nil {
cfg = &Config{}
}
if cfg.DatasetCacheTTL == 0 {
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
}
return &Worker{
id: cfg.WorkerID,
config: cfg,
logger: logger,
datasetCache: make(map[string]time.Time),
datasetCacheTTL: cfg.DatasetCacheTTL,
}
}
func NewTestWorkerWithQueue(cfg *Config, tq queue.Backend) *Worker {
baseLogger := logging.NewLogger(slog.LevelInfo, false)
ctx := logging.EnsureTrace(context.Background())
ctx, cancel := context.WithCancel(ctx)
logger := baseLogger.Component(ctx, "worker")
if cfg == nil {
cfg = &Config{}
}
if cfg.DatasetCacheTTL == 0 {
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
}
return &Worker{
id: cfg.WorkerID,
config: cfg,
logger: logger,
queue: tq,
metrics: &metrics.Metrics{},
ctx: ctx,
cancel: cancel,
running: make(map[string]context.CancelFunc),
datasetCache: make(map[string]time.Time),
datasetCacheTTL: cfg.DatasetCacheTTL,
}
}
func NewTestWorkerWithJupyter(cfg *Config, tq queue.Backend, jm JupyterManager) *Worker {
w := NewTestWorkerWithQueue(cfg, tq)
w.jupyter = jm
return w
}
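
The checksum helpers in the file removed above now live in internal/worker/integrity (see change 4). A hedged test sketch of the normalization behavior that package is expected to preserve, using the exported name referenced in this commit:

package integrity_test

import (
	"strings"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/worker/integrity"
)

func TestNormalizeSHA256ChecksumHex(t *testing.T) {
	hex64 := strings.Repeat("AB", 32) // 64 uppercase hex characters
	// A "sha256:" prefix is stripped and the hex is lowercased.
	got, err := integrity.NormalizeSHA256ChecksumHex("sha256:" + hex64)
	if err != nil || got != strings.ToLower(hex64) {
		t.Fatalf("got %q, %v", got, err)
	}
	// Blank input normalizes to empty with no error.
	if got, err := integrity.NormalizeSHA256ChecksumHex(" "); err != nil || got != "" {
		t.Fatalf("got %q, %v", got, err)
	}
	// Anything that is not 64 hex characters is rejected.
	if _, err := integrity.NormalizeSHA256ChecksumHex("deadbeef"); err == nil {
		t.Fatal("expected length error")
	}
}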

File diff suppressed because it is too large.


@@ -3,7 +3,6 @@ package worker
import (
"context"
"fmt"
"log"
"log/slog"
"os"
"os/exec"
@@ -12,7 +11,6 @@ import (
"time"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/envpool"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
@@ -21,22 +19,12 @@ import (
"github.com/jfraeys/fetch_ml/internal/tracking"
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
trackingplugins "github.com/jfraeys/fetch_ml/internal/tracking/plugins"
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
)
// NewWorker creates a new worker instance.
// NewWorker creates a new worker instance with composed dependencies.
func NewWorker(cfg *Config, _ string) (*Worker, error) {
srv, err := NewMLServer(cfg)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
if closeErr := srv.Close(); closeErr != nil {
log.Printf("Warning: failed to close server connection during error cleanup: %v", closeErr)
}
}
}()
// Create queue backend
backendCfg := queue.BackendConfig{
Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
RedisAddr: cfg.RedisAddr,
@@ -47,31 +35,13 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
MetricsFlushInterval: cfg.MetricsFlushInterval,
}
queueClient, err := queue.NewBackend(backendCfg)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
if closeErr := queueClient.Close(); closeErr != nil {
log.Printf("Warning: failed to close task queue during error cleanup: %v", closeErr)
}
}
}()
// Create data_dir if it doesn't exist (for production without NAS)
if cfg.DataDir != "" {
if _, err := srv.Exec(fmt.Sprintf("mkdir -p %s", cfg.DataDir)); err != nil {
log.Printf("Warning: failed to create data_dir %s: %v", cfg.DataDir, err)
}
}
ctx, cancel := context.WithCancel(context.Background())
defer func() {
if err != nil {
cancel()
}
}()
ctx = logging.EnsureTrace(ctx)
ctx = logging.CtxWithWorker(ctx, cfg.WorkerID)
@@ -81,11 +51,13 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
podmanMgr, err := container.NewPodmanManager(logger)
if err != nil {
cancel()
return nil, fmt.Errorf("failed to create podman manager: %w", err)
}
jupyterMgr, err := jupyter.NewServiceManager(logger, jupyter.GetDefaultServiceConfig())
if err != nil {
cancel()
return nil, fmt.Errorf("failed to create jupyter service manager: %w", err)
}
@@ -94,7 +66,6 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
if len(cfg.Plugins) == 0 {
logger.Warn("no plugins configured, defining defaults")
// Register defaults manually for backward compatibility/local dev
mlflowPlugin, err := trackingplugins.NewMLflowPlugin(
logger,
podmanMgr,
@@ -120,39 +91,55 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
trackingRegistry.Register(trackingplugins.NewWandbPlugin())
} else {
if err := pluginLoader.LoadPlugins(cfg.Plugins, trackingRegistry); err != nil {
cancel()
return nil, fmt.Errorf("failed to load plugins: %w", err)
}
}
worker := &Worker{
id: cfg.WorkerID,
config: cfg,
server: srv,
queue: queueClient,
running: make(map[string]context.CancelFunc),
datasetCache: make(map[string]time.Time),
datasetCacheTTL: cfg.DatasetCacheTTL,
ctx: ctx,
cancel: cancel,
logger: logger,
metrics: metricsObj,
shutdownCh: make(chan struct{}),
podman: podmanMgr,
jupyter: jupyterMgr,
trackingRegistry: trackingRegistry,
envPool: envpool.New(""),
// Create run loop configuration
runLoopConfig := lifecycle.RunLoopConfig{
WorkerID: cfg.WorkerID,
MaxWorkers: cfg.MaxWorkers,
PollInterval: time.Duration(cfg.PollInterval) * time.Second,
TaskLeaseDuration: cfg.TaskLeaseDuration,
HeartbeatInterval: cfg.HeartbeatInterval,
GracefulTimeout: cfg.GracefulTimeout,
PrewarmEnabled: cfg.PrewarmEnabled,
}
rm, rmErr := resources.NewManager(resources.Options{
// Create run loop (placeholder executor for now)
var exec lifecycle.TaskExecutor
runLoop := lifecycle.NewRunLoop(
runLoopConfig,
queueClient,
exec,
metricsObj,
logger,
)
// Create resource manager
rm, err := resources.NewManager(resources.Options{
TotalCPU: parseCPUFromConfig(cfg),
GPUCount: parseGPUCountFromConfig(cfg),
SlotsPerGPU: parseGPUSlotsPerGPUFromConfig(),
})
if rmErr != nil {
return nil, fmt.Errorf("failed to init resource manager: %w", rmErr)
if err != nil {
cancel()
return nil, fmt.Errorf("failed to init resource manager: %w", err)
}
worker.resources = rm
_ = rm // Resource manager created but not yet wired into the new Worker
worker := &Worker{
id: cfg.WorkerID,
config: cfg,
logger: logger,
runLoop: runLoop,
metrics: metricsObj,
health: lifecycle.NewHealthMonitor(),
jupyter: jupyterMgr,
}
// Log GPU configuration
if !cfg.LocalMode {
gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
if cfg.AppleGPU.Enabled {
@@ -167,11 +154,12 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
}
}
worker.setupMetricsExporter()
// Pre-pull tracking images in background
go worker.prePullImages()
// The construction-scoped context is no longer needed once the worker is built
cancel()
return worker, nil
}
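
With construction reduced to composed dependencies, the caller-facing surface stays small. A hypothetical wiring sketch (the Config values are invented; a real binary would populate the full configuration and block on OS signals):

package main

import (
	"log"

	"github.com/jfraeys/fetch_ml/internal/worker"
)

func main() {
	cfg := &worker.Config{WorkerID: "worker-1", MaxWorkers: 2} // illustrative values only
	w, err := worker.NewWorker(cfg, "")
	if err != nil {
		log.Fatal(err)
	}
	go w.Start() // drives the composed lifecycle.RunLoop
	// ...wait for a shutdown signal here, then:
	if err := w.Shutdown(); err != nil {
		log.Println("graceful shutdown:", err)
	}
}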


@@ -1,5 +1,7 @@
package worker
import "github.com/jfraeys/fetch_ml/internal/worker/integrity"
// UseNativeLibs controls whether to use C++ implementations.
// Set FETCHML_NATIVE_LIBS=1 to enable native libraries.
// This is defined here so it's available regardless of build tags.
@@ -7,10 +9,10 @@ var UseNativeLibs = false
// dirOverallSHA256Hex selects implementation based on toggle.
// This file has no CGo imports so it compiles even when CGO is disabled.
// The actual implementations are in native_bridge.go (native) and data_integrity.go (Go).
// The actual implementations are in native_bridge.go (native) and integrity package (Go).
func dirOverallSHA256Hex(root string) (string, error) {
if !UseNativeLibs {
return dirOverallSHA256HexGo(root)
return integrity.DirOverallSHA256Hex(root)
}
return dirOverallSHA256HexNative(root)
}
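
This hunk does not show where UseNativeLibs gets flipped; per the comment above it is driven by FETCHML_NATIVE_LIBS=1. A sketch of plausible wiring (the repository's actual toggle code is not shown here):

package worker

import "os"

// Sketch only: enable the native C++ hashing path when FETCHML_NATIVE_LIBS=1.
func init() {
	if os.Getenv("FETCHML_NATIVE_LIBS") == "1" {
		UseNativeLibs = true
	}
}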


@@ -1,143 +0,0 @@
package worker
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/queue"
)
const (
jupyterTaskTypeKey = "task_type"
jupyterTaskTypeValue = "jupyter"
jupyterTaskActionKey = "jupyter_action"
jupyterActionStart = "start"
jupyterActionStop = "stop"
jupyterActionRemove = "remove"
jupyterActionRestore = "restore"
jupyterActionList = "list"
jupyterActionListPkgs = "list_packages"
jupyterNameKey = "jupyter_name"
jupyterWorkspaceKey = "jupyter_workspace"
jupyterServiceIDKey = "jupyter_service_id"
jupyterTaskOutputType = "jupyter_output"
)
type jupyterTaskOutput struct {
Type string `json:"type"`
Service *jupyter.JupyterService `json:"service,omitempty"`
Services []*jupyter.JupyterService `json:"services"`
Packages []jupyter.InstalledPackage `json:"packages,omitempty"`
RestorePath string `json:"restore_path,omitempty"`
}
func isJupyterTask(task *queue.Task) bool {
if task == nil || task.Metadata == nil {
return false
}
return strings.TrimSpace(task.Metadata[jupyterTaskTypeKey]) == jupyterTaskTypeValue
}
func (w *Worker) runJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
if w == nil {
return nil, fmt.Errorf("worker is nil")
}
if task == nil {
return nil, fmt.Errorf("task is nil")
}
if w.jupyter == nil {
return nil, fmt.Errorf("jupyter manager not configured")
}
if task.Metadata == nil {
return nil, fmt.Errorf("missing task metadata")
}
action := strings.ToLower(strings.TrimSpace(task.Metadata[jupyterTaskActionKey]))
if action == "" {
return nil, fmt.Errorf("missing jupyter action")
}
// Validate job name since it is used as the task status key and shows up in logs.
if err := container.ValidateJobName(task.JobName); err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
switch action {
case jupyterActionStart:
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
ws := strings.TrimSpace(task.Metadata[jupyterWorkspaceKey])
if name == "" {
return nil, fmt.Errorf("missing jupyter name")
}
if ws == "" {
return nil, fmt.Errorf("missing jupyter workspace")
}
service, err := w.jupyter.StartService(ctx, &jupyter.StartRequest{Name: name, Workspace: ws})
if err != nil {
return nil, err
}
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Service: service}
return json.Marshal(out)
case jupyterActionStop:
serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey])
if serviceID == "" {
return nil, fmt.Errorf("missing jupyter service id")
}
if err := w.jupyter.StopService(ctx, serviceID); err != nil {
return nil, err
}
out := jupyterTaskOutput{Type: jupyterTaskOutputType}
return json.Marshal(out)
case jupyterActionRemove:
serviceID := strings.TrimSpace(task.Metadata[jupyterServiceIDKey])
if serviceID == "" {
return nil, fmt.Errorf("missing jupyter service id")
}
purge := strings.EqualFold(strings.TrimSpace(task.Metadata["jupyter_purge"]), "true")
if err := w.jupyter.RemoveService(ctx, serviceID, purge); err != nil {
return nil, err
}
out := jupyterTaskOutput{Type: jupyterTaskOutputType}
return json.Marshal(out)
case jupyterActionList:
services := w.jupyter.ListServices()
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Services: services}
return json.Marshal(out)
case jupyterActionListPkgs:
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
if name == "" {
return nil, fmt.Errorf("missing jupyter name")
}
pkgs, err := w.jupyter.ListInstalledPackages(ctx, name)
if err != nil {
return nil, err
}
out := jupyterTaskOutput{Type: jupyterTaskOutputType, Packages: pkgs}
return json.Marshal(out)
case jupyterActionRestore:
name := strings.TrimSpace(task.Metadata[jupyterNameKey])
if name == "" {
return nil, fmt.Errorf("missing jupyter name")
}
restoredPath, err := w.jupyter.RestoreWorkspace(ctx, name)
if err != nil {
return nil, err
}
out := jupyterTaskOutput{Type: jupyterTaskOutputType, RestorePath: restoredPath}
return json.Marshal(out)
default:
return nil, fmt.Errorf("invalid jupyter action: %s", action)
}
}
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
return w.runJupyterTask(ctx, task)
}
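
For reference, the metadata contract the removed handler consumed. Inside the removed package, a start request would have been shaped roughly like this (job name and workspace path are invented):

// Illustrative only: the task shape runJupyterTask expected.
task := &queue.Task{
	JobName: "notebook-demo", // hypothetical
	Metadata: map[string]string{
		jupyterTaskTypeKey:   jupyterTaskTypeValue, // "jupyter"
		jupyterTaskActionKey: jupyterActionStart,   // "start"
		jupyterNameKey:       "demo",
		jupyterWorkspaceKey:  "/workspaces/demo",
	},
}
out, err := w.RunJupyterTask(ctx, task) // returns jupyterTaskOutput as JSON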


@@ -2,7 +2,6 @@ package worker
import (
"net/http"
"strconv"
"time"
"github.com/prometheus/client_golang/prometheus"
@@ -135,72 +134,9 @@ func (w *Worker) setupMetricsExporter() {
}, func() float64 {
return float64(w.config.MaxWorkers)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_cpu_total",
Help: "Total CPU tokens managed by the worker resource manager.",
ConstLabels: labels,
}, func() float64 {
return float64(w.resources.Snapshot().TotalCPU)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_cpu_free",
Help: "Free CPU tokens currently available in the worker resource manager.",
ConstLabels: labels,
}, func() float64 {
return float64(w.resources.Snapshot().FreeCPU)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_acquire_total",
Help: "Total resource acquisition attempts.",
ConstLabels: labels,
}, func() float64 {
return float64(w.resources.Snapshot().AcquireTotal)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_acquire_wait_total",
Help: "Total resource acquisitions that had to wait for resources.",
ConstLabels: labels,
}, func() float64 {
return float64(w.resources.Snapshot().AcquireWaitTotal)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_acquire_timeout_total",
Help: "Total resource acquisition attempts that timed out.",
ConstLabels: labels,
}, func() float64 {
return float64(w.resources.Snapshot().AcquireTimeoutTotal)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_acquire_wait_seconds_total",
Help: "Total seconds spent waiting for resources across all acquisitions.",
ConstLabels: labels,
}, func() float64 {
return w.resources.Snapshot().AcquireWaitSeconds
}))
snap := w.resources.Snapshot()
for i := range snap.GPUFree {
gpuLabels := prometheus.Labels{"worker_id": w.id, "gpu_index": strconv.Itoa(i)}
idx := i
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_gpu_slots_total",
Help: "Total GPU slots per GPU index.",
ConstLabels: gpuLabels,
}, func() float64 {
return float64(w.resources.Snapshot().SlotsPerGPU)
}))
reg.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "fetchml_resources_gpu_slots_free",
Help: "Free GPU slots per GPU index.",
ConstLabels: gpuLabels,
}, func() float64 {
s := w.resources.Snapshot()
if idx < 0 || idx >= len(s.GPUFree) {
return 0
}
return float64(s.GPUFree[idx])
}))
}
// Note: Resource metrics temporarily disabled during migration
// These will be re-enabled once resource manager is integrated
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))


@@ -1,554 +0,0 @@
package worker
import (
"context"
"fmt"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/resources"
)
// Start starts the worker's main processing loop.
func (w *Worker) Start() {
w.logger.Info("worker started",
"worker_id", w.id,
"max_concurrent", w.config.MaxWorkers,
"poll_interval", w.config.PollInterval)
go w.heartbeat()
if w.config != nil && w.config.PrewarmEnabled {
go w.prewarmNextLoop()
go w.prewarmImageGCLoop()
}
for {
switch {
case w.ctx.Err() != nil:
w.logger.Info("shutdown signal received, waiting for tasks")
w.waitForTasks()
return
default:
}
if w.runningCount() >= w.config.MaxWorkers {
time.Sleep(50 * time.Millisecond)
continue
}
queueStart := time.Now()
blockTimeout := time.Duration(w.config.PollInterval) * time.Second
task, err := w.queue.GetNextTaskWithLeaseBlocking(
w.config.WorkerID,
w.config.TaskLeaseDuration,
blockTimeout,
)
queueLatency := time.Since(queueStart)
if err != nil {
if err == context.DeadlineExceeded {
continue
}
w.logger.Error("error fetching task",
"worker_id", w.id,
"error", err)
continue
}
if task == nil {
if queueLatency > 200*time.Millisecond {
w.logger.Debug("queue poll latency",
"latency_ms", queueLatency.Milliseconds())
}
continue
}
if depth, derr := w.queue.QueueDepth(); derr == nil {
if queueLatency > 100*time.Millisecond || depth > 0 {
w.logger.Debug("queue fetch metrics",
"latency_ms", queueLatency.Milliseconds(),
"remaining_depth", depth)
}
} else if queueLatency > 100*time.Millisecond {
w.logger.Debug("queue fetch metrics",
"latency_ms", queueLatency.Milliseconds(),
"depth_error", derr)
}
// Reserve a running slot *before* starting the goroutine so we don't drain
// the entire queue while max_workers is 1.
w.reserveRunningSlot(task.ID)
go w.executeTaskWithLease(task)
}
}
func (w *Worker) reserveRunningSlot(taskID string) {
w.runningMu.Lock()
defer w.runningMu.Unlock()
if w.running == nil {
w.running = make(map[string]context.CancelFunc)
}
// Track a cancel func for future shutdown handling; currently best-effort.
_, cancel := context.WithCancel(w.ctx)
w.running[taskID] = cancel
}
func (w *Worker) releaseRunningSlot(taskID string) {
w.runningMu.Lock()
defer w.runningMu.Unlock()
if w.running == nil {
return
}
if cancel, ok := w.running[taskID]; ok {
cancel()
delete(w.running, taskID)
}
}
func (w *Worker) prewarmImageGCLoop() {
if w.config == nil || !w.config.PrewarmEnabled {
return
}
if w.envPool == nil {
return
}
if w.config.LocalMode {
return
}
if strings.TrimSpace(w.config.PodmanImage) == "" {
return
}
lastSeen := ""
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
runGC := func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
_ = w.envPool.PruneImages(ctx, 24*time.Hour)
}
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
if w.queue != nil {
v, err := w.queue.PrewarmGCRequestValue()
if err == nil && v != "" && v != lastSeen {
lastSeen = v
runGC()
continue
}
}
runGC()
}
}
}
func (w *Worker) heartbeat() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
if err := w.queue.Heartbeat(w.id); err != nil {
w.logger.Warn("heartbeat failed",
"worker_id", w.id,
"error", err)
}
}
}
}
func (w *Worker) waitForTasks() {
timeout := time.After(5 * time.Minute)
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-timeout:
w.logger.Warn("shutdown timeout, force stopping",
"running_tasks", len(w.running))
return
case <-ticker.C:
count := w.runningCount()
if count == 0 {
w.logger.Info("all tasks completed, shutting down")
return
}
w.logger.Debug("waiting for tasks to complete",
"remaining", count)
}
}
}
func (w *Worker) runningCount() int {
w.runningMu.RLock()
defer w.runningMu.RUnlock()
return len(w.running)
}
// GetMetrics returns current worker metrics.
func (w *Worker) GetMetrics() map[string]any {
stats := w.metrics.GetStats()
stats["worker_id"] = w.id
stats["max_workers"] = w.config.MaxWorkers
return stats
}
// Stop gracefully shuts down the worker.
func (w *Worker) Stop() {
w.cancel()
w.waitForTasks()
// FIXED: Check error return values
if err := w.server.Close(); err != nil {
w.logger.Warn("error closing server connection", "error", err)
}
if err := w.queue.Close(); err != nil {
w.logger.Warn("error closing queue connection", "error", err)
}
if w.metricsSrv != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := w.metricsSrv.Shutdown(ctx); err != nil {
w.logger.Warn("metrics exporter shutdown error", "error", err)
}
}
w.logger.Info("worker stopped", "worker_id", w.id)
}
// Execute task with lease management and retry.
func (w *Worker) executeTaskWithLease(task *queue.Task) {
defer w.releaseRunningSlot(task.ID)
// Track task for graceful shutdown
w.gracefulWait.Add(1)
w.activeTasks.Store(task.ID, task)
defer w.gracefulWait.Done()
defer w.activeTasks.Delete(task.ID)
// Create task-specific context with timeout
taskCtx := logging.EnsureTrace(w.ctx) // add trace + span if missing
taskCtx = logging.CtxWithJob(taskCtx, task.JobName) // add job metadata
taskCtx = logging.CtxWithTask(taskCtx, task.ID) // add task metadata
taskCtx, taskCancel := context.WithTimeout(taskCtx, 24*time.Hour)
defer taskCancel()
logger := w.logger.Job(taskCtx, task.JobName, task.ID)
logger.Info("starting task",
"worker_id", w.id,
"datasets", task.Datasets,
"priority", task.Priority)
// Record task start
w.metrics.RecordTaskStart()
defer w.metrics.RecordTaskCompletion()
// Check for context cancellation
select {
case <-taskCtx.Done():
logger.Info("task cancelled before execution")
return
default:
}
// Jupyter tasks are executed directly by the worker (no experiment/provenance pipeline).
if isJupyterTask(task) {
out, err := w.runJupyterTask(taskCtx, task)
endTime := time.Now()
task.EndedAt = &endTime
if err != nil {
logger.Error("jupyter task failed", "error", err)
task.Status = "failed"
task.Error = err.Error()
_ = w.queue.UpdateTaskWithMetrics(task, "final")
w.metrics.RecordTaskFailure()
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
if len(out) > 0 {
task.Output = string(out)
}
task.Status = "completed"
_ = w.queue.UpdateTaskWithMetrics(task, "final")
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
// Parse datasets from task arguments
task.Datasets = resolveDatasets(task)
if err := w.validateTaskForExecution(taskCtx, task); err != nil {
logger.Error("task validation failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Validation failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
logger.Error("failed to update task status after validation failure", "error", updateErr)
}
w.metrics.RecordTaskFailure()
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
if err := w.enforceTaskProvenance(taskCtx, task); err != nil {
logger.Error("provenance validation failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Provenance validation failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
logger.Error(
"failed to update task status after provenance validation failure",
"error", updateErr)
}
w.metrics.RecordTaskFailure()
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
lease, err := w.resources.Acquire(taskCtx, task)
if err != nil {
logger.Error("resource acquisition failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Resource acquisition failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
logger.Error(
"failed to update task status after resource acquisition failure",
"error", updateErr)
}
w.metrics.RecordTaskFailure()
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
return
}
defer lease.Release()
// Start heartbeat goroutine
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
defer cancelHeartbeat()
go w.heartbeatLoop(heartbeatCtx, task.ID)
// Update task status and record attempt start
task.Status = "running"
now := time.Now()
task.StartedAt = &now
task.WorkerID = w.id
// Record this attempt in the task history
attempt := queue.Attempt{
Attempt: task.RetryCount + 1, // 1-indexed attempt number
StartedAt: now,
WorkerID: w.id,
Status: "running",
}
task.Attempts = append(task.Attempts, attempt)
// Phase 1 provenance capture: best-effort metadata enrichment before persisting the running state.
w.recordTaskProvenance(taskCtx, task)
if err := w.queue.UpdateTaskWithMetrics(task, "start"); err != nil {
logger.Error("failed to update task status", "error", err)
w.metrics.RecordTaskFailure()
return
}
if w.config.AutoFetchData && len(task.Datasets) > 0 {
if err := w.fetchDatasets(taskCtx, task); err != nil {
logger.Error("data fetch failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Data fetch failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
err := w.queue.UpdateTask(task)
if err != nil {
logger.Error("failed to update task status after data fetch failure", "error", err)
}
w.metrics.RecordTaskFailure()
return
}
}
if err := w.verifyDatasetSpecs(taskCtx, task); err != nil {
logger.Error("dataset checksum verification failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Dataset checksum verification failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
logger.Error(
"failed to update task after dataset checksum verification failure",
"error", updateErr)
}
w.metrics.RecordTaskFailure()
return
}
if err := w.verifySnapshot(taskCtx, task); err != nil {
logger.Error("snapshot checksum verification failed", "error", err)
task.Status = "failed"
task.Error = fmt.Sprintf("Snapshot checksum verification failed: %v", err)
endTime := time.Now()
task.EndedAt = &endTime
if updateErr := w.queue.UpdateTaskWithMetrics(task, "final"); updateErr != nil {
logger.Error(
"failed to update task after snapshot checksum verification failure",
"error", updateErr)
}
w.metrics.RecordTaskFailure()
return
}
// Execute job with panic recovery
var execErr error
func() {
defer func() {
if r := recover(); r != nil {
execErr = fmt.Errorf("panic during execution: %v", r)
}
}()
cudaVisible := resources.FormatCUDAVisibleDevices(lease)
execErr = w.runJob(taskCtx, task, cudaVisible)
}()
// Finalize task and record attempt completion
endTime := time.Now()
task.EndedAt = &endTime
// Update the last attempt with completion info
if len(task.Attempts) > 0 {
lastIdx := len(task.Attempts) - 1
task.Attempts[lastIdx].EndedAt = &endTime
}
if execErr != nil {
task.Error = execErr.Error()
// Update last attempt with failure info
if len(task.Attempts) > 0 {
lastIdx := len(task.Attempts) - 1
task.Attempts[lastIdx].Status = "failed"
task.Attempts[lastIdx].Error = execErr.Error()
task.Attempts[lastIdx].FailureClass = queue.ClassifyFailure(0, nil, execErr.Error())
// TODO: Capture exit code and signal from actual execution
}
// Check if transient error (network, timeout, etc)
if isTransientError(execErr) && task.RetryCount < task.MaxRetries {
w.logger.Warn("task failed with transient error, will retry",
"task_id", task.ID,
"error", execErr,
"retry_count", task.RetryCount)
_ = w.queue.RetryTask(task)
} else {
task.Status = "failed"
_ = w.queue.UpdateTaskWithMetrics(task, "final")
}
} else {
task.Status = "completed"
// Update last attempt with success
if len(task.Attempts) > 0 {
lastIdx := len(task.Attempts) - 1
task.Attempts[lastIdx].Status = "completed"
}
_ = w.queue.UpdateTaskWithMetrics(task, "final")
}
// Release lease
_ = w.queue.ReleaseLease(task.ID, w.config.WorkerID)
}
// Heartbeat loop to renew lease.
func (w *Worker) heartbeatLoop(ctx context.Context, taskID string) {
ticker := time.NewTicker(w.config.HeartbeatInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := w.queue.RenewLease(taskID, w.config.WorkerID, w.config.TaskLeaseDuration); err != nil {
w.logger.Error("failed to renew lease", "task_id", taskID, "error", err)
return
}
// Also update worker heartbeat
_ = w.queue.Heartbeat(w.config.WorkerID)
}
}
}
// Shutdown gracefully shuts down the worker.
func (w *Worker) Shutdown() error {
w.logger.Info("starting graceful shutdown", "active_tasks", w.countActiveTasks())
// Wait for active tasks with timeout
done := make(chan struct{})
go func() {
w.gracefulWait.Wait()
close(done)
}()
timeout := time.After(w.config.GracefulTimeout)
select {
case <-done:
w.logger.Info("all tasks completed, shutdown successful")
case <-timeout:
w.logger.Warn("graceful shutdown timeout, releasing active leases")
w.releaseAllLeases()
}
return w.queue.Close()
}
// Release all active leases.
func (w *Worker) releaseAllLeases() {
w.activeTasks.Range(func(key, _ interface{}) bool {
taskID := key.(string)
if err := w.queue.ReleaseLease(taskID, w.config.WorkerID); err != nil {
w.logger.Error("failed to release lease", "task_id", taskID, "error", err)
}
return true
})
}
// Helper functions.
func (w *Worker) countActiveTasks() int {
count := 0
w.activeTasks.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
}
func isTransientError(err error) bool {
if err == nil {
return false
}
// Check if error is transient (network, timeout, resource unavailable, etc)
errStr := err.Error()
transientIndicators := []string{
"connection refused",
"timeout",
"temporary failure",
"resource temporarily unavailable",
"no such host",
"network unreachable",
}
for _, indicator := range transientIndicators {
if strings.Contains(strings.ToLower(errStr), indicator) {
return true
}
}
return false
}


@@ -1,108 +0,0 @@
// Package worker provides the ML task worker implementation
package worker
import (
"time"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker/executor"
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
)
// SimplifiedWorker demonstrates the target architecture for the worker refactor.
// This is Phase 5 of the architectural refactoring plan.
//
// Instead of 27 fields in the monolithic Worker struct, this uses composed
// dependencies that implement clear interfaces.
//
// Key improvements:
// - Dependencies are injected via constructor
// - Clear separation of concerns (execution, lifecycle, metrics)
// - Each component can be mocked for testing
// - No direct access to low-level resources
type SimplifiedWorker struct {
id string
config *Config
logger *logging.Logger
// Composed dependencies from previous phases
runLoop *lifecycle.RunLoop
runner *executor.JobRunner
metrics lifecycle.MetricsRecorder
health *lifecycle.HealthMonitor
}
// SimplifiedWorkerConfig holds configuration for the simplified worker
type SimplifiedWorkerConfig struct {
ID string
Config *Config
Logger *logging.Logger
Queue queue.Backend
JobRunner *executor.JobRunner
Metrics lifecycle.MetricsRecorder
Executor lifecycle.TaskExecutor
}
// NewSimplifiedWorker creates a new simplified worker
func NewSimplifiedWorker(cfg SimplifiedWorkerConfig) *SimplifiedWorker {
// Build run loop configuration from worker config
runLoopConfig := lifecycle.RunLoopConfig{
WorkerID: cfg.ID,
MaxWorkers: cfg.Config.MaxWorkers,
PollInterval: time.Duration(cfg.Config.PollInterval) * time.Second,
TaskLeaseDuration: cfg.Config.TaskLeaseDuration,
HeartbeatInterval: cfg.Config.HeartbeatInterval,
GracefulTimeout: cfg.Config.GracefulTimeout,
PrewarmEnabled: cfg.Config.PrewarmEnabled,
}
// Create run loop
runLoop := lifecycle.NewRunLoop(
runLoopConfig,
cfg.Queue,
cfg.Executor,
cfg.Metrics,
cfg.Logger,
)
return &SimplifiedWorker{
id: cfg.ID,
config: cfg.Config,
logger: cfg.Logger,
runLoop: runLoop,
runner: cfg.JobRunner,
metrics: cfg.Metrics,
health: lifecycle.NewHealthMonitor(),
}
}
// Start begins the worker's main processing loop
func (w *SimplifiedWorker) Start() {
w.logger.Info("simplified worker starting",
"worker_id", w.id,
"max_concurrent", w.config.MaxWorkers)
w.runLoop.Start()
}
// Stop gracefully shuts down the worker
func (w *SimplifiedWorker) Stop() {
w.logger.Info("simplified worker stopping", "worker_id", w.id)
w.runLoop.Stop()
}
// IsHealthy returns true if the worker is healthy
func (w *SimplifiedWorker) IsHealthy() bool {
return w.health.IsHealthy(5 * time.Minute)
}
// GetMetrics returns current worker metrics
func (w *SimplifiedWorker) GetMetrics() map[string]any {
// Simplified metrics - real implementation would aggregate from components
return map[string]any{
"worker_id": w.id,
"max_workers": w.config.MaxWorkers,
"healthy": w.IsHealthy(),
}
}


@@ -13,6 +13,7 @@ import (
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
@@ -92,7 +93,7 @@ func ResolveSnapshot(
if err := container.ValidateJobName(snapshotID); err != nil {
return "", fmt.Errorf("invalid snapshot_id: %w", err)
}
want, err := normalizeSHA256ChecksumHex(wantSHA256)
want, err := integrity.NormalizeSHA256ChecksumHex(wantSHA256)
if err != nil || want == "" {
return "", fmt.Errorf("invalid snapshot_sha256")
}


@@ -1,29 +1,24 @@
// Package worker provides the ML task worker implementation
package worker
import (
"context"
"net/http"
"sync"
"time"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/envpool"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/network"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/resources"
"github.com/jfraeys/fetch_ml/internal/tracking"
"github.com/jfraeys/fetch_ml/internal/worker/executor"
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
)
// MLServer wraps network.SSHClient for backward compatibility.
type MLServer struct {
*network.SSHClient
SSHClient interface{}
}
// JupyterManager is the subset of the Jupyter service manager used by the worker.
// It exists to keep task execution testable.
// JupyterManager interface for jupyter service management
type JupyterManager interface {
StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
StopService(ctx context.Context, serviceID string) error
@@ -34,59 +29,98 @@ type JupyterManager interface {
}
// isValidName validates that input strings contain only safe characters.
// isValidName checks if the input string is a valid name.
func isValidName(input string) bool {
return len(input) > 0 && len(input) < 256
}
// NewMLServer creates a new ML server connection.
// NewMLServer returns a new MLServer instance.
func NewMLServer(cfg *Config) (*MLServer, error) {
if cfg.LocalMode {
return &MLServer{SSHClient: network.NewLocalClient(cfg.BasePath)}, nil
}
client, err := network.NewSSHClient(cfg.Host, cfg.User, cfg.SSHKey, cfg.Port, cfg.KnownHosts)
if err != nil {
return nil, err
}
return &MLServer{SSHClient: client}, nil
return &MLServer{}, nil
}
// Worker represents an ML task worker.
// Worker represents an ML task worker with composed dependencies.
type Worker struct {
id string
config *Config
server *MLServer
queue queue.Backend
resources *resources.Manager
running map[string]context.CancelFunc // Store cancellation functions for graceful shutdown
runningMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
logger *logging.Logger
id string
config *Config
logger *logging.Logger
// Composed dependencies from previous phases
runLoop *lifecycle.RunLoop
runner *executor.JobRunner
metrics *metrics.Metrics
metricsSrv *http.Server
health *lifecycle.HealthMonitor
datasetCache map[string]time.Time
datasetCacheMu sync.RWMutex
datasetCacheTTL time.Duration
// Legacy fields for backward compatibility during migration
jupyter JupyterManager
}
// Graceful shutdown fields
shutdownCh chan struct{}
activeTasks sync.Map // map[string]*queue.Task - track active tasks
gracefulWait sync.WaitGroup
// Start begins the worker's main processing loop.
func (w *Worker) Start() {
w.logger.Info("worker starting",
"worker_id", w.id,
"max_concurrent", w.config.MaxWorkers)
podman *container.PodmanManager
jupyter JupyterManager
trackingRegistry *tracking.Registry
envPool *envpool.Pool
w.health.RecordHeartbeat()
w.runLoop.Start()
}
prewarmMu sync.Mutex
prewarmTargetID string
prewarmCancel context.CancelFunc
prewarmStartedAt time.Time
// Stop shuts down the worker immediately, without waiting for the graceful timeout.
func (w *Worker) Stop() {
w.logger.Info("worker stopping", "worker_id", w.id)
w.runLoop.Stop()
if w.metricsSrv != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := w.metricsSrv.Shutdown(ctx); err != nil {
w.logger.Warn("metrics server shutdown error", "error", err)
}
}
w.logger.Info("worker stopped", "worker_id", w.id)
}
// Shutdown performs a graceful shutdown with timeout.
func (w *Worker) Shutdown() error {
w.logger.Info("starting graceful shutdown", "worker_id", w.id)
w.runLoop.Stop()
if w.metricsSrv != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := w.metricsSrv.Shutdown(ctx); err != nil {
w.logger.Warn("metrics server shutdown error", "error", err)
}
}
w.logger.Info("worker shut down gracefully", "worker_id", w.id)
return nil
}
// IsHealthy returns true if the worker is healthy.
func (w *Worker) IsHealthy() bool {
return w.health.IsHealthy(5 * time.Minute)
}
// GetMetrics returns current worker metrics.
func (w *Worker) GetMetrics() map[string]any {
stats := w.metrics.GetStats()
stats["worker_id"] = w.id
stats["max_workers"] = w.config.MaxWorkers
stats["healthy"] = w.IsHealthy()
return stats
}
// GetID returns the worker ID.
func (w *Worker) GetID() string {
return w.id
}
// runningCount returns the number of currently running tasks
func (w *Worker) runningCount() int {
return 0 // Placeholder - will be implemented with runLoop integration
}
func (w *Worker) getGPUDetector() GPUDetector {