package worker_test

import (
	"context"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/worker"
	tests "github.com/jfraeys/fetch_ml/tests/fixtures"
)

const (
	// testCommitID is a syntactically valid 40-character hex commit ID used by
	// the experiment fixtures.
	testCommitID = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
	// testManifestSHA is the OverallSHA written into the fixture manifest.json.
	testManifestSHA = "0123456789abcdef"
)

// TestSelectDependencyManifestPriority verifies the manifest precedence order.
// It starts with every candidate present and removes the current winner each
// round, expecting the next-highest candidate to be selected.
func TestSelectDependencyManifestPriority(t *testing.T) {
	base := t.TempDir()

	// Highest priority first. poetry.lock is removed before pyproject.toml, so
	// the "poetry.lock without pyproject.toml" error state is never reached.
	priority := []string{
		"environment.yml",
		"environment.yaml",
		"poetry.lock",
		"pyproject.toml",
		"requirements.txt",
	}
	for _, name := range priority {
		if err := os.WriteFile(filepath.Join(base, name), []byte("# test\n"), 0600); err != nil {
			t.Fatalf("write %s: %v", name, err)
		}
	}

	for _, want := range priority {
		got, err := worker.SelectDependencyManifest(base)
		if err != nil {
			t.Fatalf("unexpected error while expecting %s: %v", want, err)
		}
		if got != want {
			t.Fatalf("expected %s, got %q", want, got)
		}
		if err := os.Remove(filepath.Join(base, want)); err != nil {
			t.Fatalf("remove %s: %v", want, err)
		}
	}
}

// TestSelectDependencyManifestPoetryRequiresPyproject checks that a lone
// poetry.lock (without pyproject.toml) is rejected.
func TestSelectDependencyManifestPoetryRequiresPyproject(t *testing.T) {
	base := t.TempDir()
	if err := os.WriteFile(filepath.Join(base, "poetry.lock"), []byte("# test\n"), 0600); err != nil {
		t.Fatalf("write poetry.lock: %v", err)
	}
	if _, err := worker.SelectDependencyManifest(base); err == nil {
		t.Fatal("expected error when poetry.lock exists without pyproject.toml")
	}
}

// TestSelectDependencyManifestMissing checks that an empty directory yields an error.
func TestSelectDependencyManifestMissing(t *testing.T) {
	base := t.TempDir()
	if _, err := worker.SelectDependencyManifest(base); err == nil {
		t.Fatal("expected error when no manifest exists")
	}
}

// TestResolveDatasetsPrecedence checks dataset resolution order:
// DatasetSpecs, then Datasets, then the --datasets flag in Args.
func TestResolveDatasetsPrecedence(t *testing.T) {
	if got := tests.ResolveDatasets(nil); got != nil {
		t.Fatal("expected nil for nil task")
	}

	t.Run("DatasetSpecsWins", func(t *testing.T) {
		task := &queue.Task{
			DatasetSpecs: []queue.DatasetSpec{{Name: "ds-spec"}},
			Datasets:     []string{"ds-legacy"},
			Args:         "--datasets ds-args",
		}
		got := tests.ResolveDatasets(task)
		if len(got) != 1 || got[0] != "ds-spec" {
			t.Fatalf("expected dataset_specs to win, got %v", got)
		}
	})

	t.Run("DatasetsWinsOverArgs", func(t *testing.T) {
		task := &queue.Task{
			Datasets: []string{"ds-legacy"},
			Args:     "--datasets ds-args",
		}
		got := tests.ResolveDatasets(task)
		if len(got) != 1 || got[0] != "ds-legacy" {
			t.Fatalf("expected datasets to win over args, got %v", got)
		}
	})

	t.Run("ArgsFallback", func(t *testing.T) {
		task := &queue.Task{Args: "--datasets a,b,c"}
		got := tests.ResolveDatasets(task)
		if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" {
			t.Fatalf("expected args datasets, got %v", got)
		}
	})
}

// TestComputeTaskProvenance checks the provenance map computed for a task that
// references an experiment via metadata, and that missing metadata is tolerated.
func TestComputeTaskProvenance(t *testing.T) {
	base := t.TempDir()
	setupExperiment(t, base)

	task := &queue.Task{
		JobName:      "job",
		SnapshotID:   "snap-1",
		DatasetSpecs: []queue.DatasetSpec{{Name: "ds1", Version: "v1"}},
		Metadata:     map[string]string{"commit_id": testCommitID},
	}
	prov, err := worker.ComputeTaskProvenance(base, task)
	if err != nil {
		t.Fatalf("ComputeTaskProvenance error: %v", err)
	}
	if prov["snapshot_id"] != "snap-1" {
		t.Fatalf("expected snapshot_id, got %q", prov["snapshot_id"])
	}
	if prov["datasets"] != "ds1" {
		t.Fatalf("expected datasets=ds1, got %q", prov["datasets"])
	}
	if prov["dataset_specs"] == "" {
		t.Fatal("expected dataset_specs json")
	}
	if prov["experiment_manifest_overall_sha"] != testManifestSHA {
		t.Fatalf("expected manifest sha, got %q", prov["experiment_manifest_overall_sha"])
	}
	if prov["deps_manifest_name"] != "requirements.txt" {
		t.Fatalf("expected deps_manifest_name requirements.txt, got %q", prov["deps_manifest_name"])
	}
	if prov["deps_manifest_sha256"] == "" {
		t.Fatal("expected deps_manifest_sha256")
	}

	// Graceful behavior with missing metadata.
	task2 := &queue.Task{SnapshotID: "snap-2"}
	prov2, err := worker.ComputeTaskProvenance(base, task2)
	if err != nil {
		t.Fatalf("ComputeTaskProvenance (missing metadata) error: %v", err)
	}
	if prov2["snapshot_id"] != "snap-2" {
		t.Fatalf("expected snapshot_id snap-2, got %q", prov2["snapshot_id"])
	}
}

// TestNormalizeSHA256ChecksumHex checks prefix stripping and length validation.
func TestNormalizeSHA256ChecksumHex(t *testing.T) {
	want := strings.Repeat("a", 64)
	got, err := worker.NormalizeSHA256ChecksumHex("sha256:" + want)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != want {
		t.Fatalf("unexpected normalized checksum: %q", got)
	}
	if _, err := worker.NormalizeSHA256ChecksumHex("sha256:deadbeef"); err == nil {
		t.Fatal("expected error for short checksum")
	}
}

// TestVerifyDatasetSpecs checks that dataset checksums are verified against the
// on-disk dataset directory: a matching checksum passes, a wrong one fails.
func TestVerifyDatasetSpecs(t *testing.T) {
	base := t.TempDir()
	dataDir := filepath.Join(base, "data")

	dsName := "dataset1"
	dsPath := filepath.Join(dataDir, dsName)
	requireNoErr(t, os.MkdirAll(dsPath, 0750))
	requireNoErr(t, os.WriteFile(filepath.Join(dsPath, "file.txt"), []byte("hello"), 0600))

	sha, err := worker.DirOverallSHA256Hex(dsPath)
	requireNoErr(t, err)

	w := tests.NewTestWorker(&worker.Config{DataDir: dataDir})

	task := &queue.Task{
		JobName:      "job",
		ID:           "t1",
		DatasetSpecs: []queue.DatasetSpec{{Name: dsName, Checksum: "sha256:" + sha}},
	}
	if err := w.VerifyDatasetSpecs(context.Background(), task); err != nil {
		t.Fatalf("expected checksum verification to pass, got %v", err)
	}

	taskBad := &queue.Task{
		JobName:      "job",
		ID:           "t2",
		DatasetSpecs: []queue.DatasetSpec{{Name: dsName, Checksum: "sha256:" + strings.Repeat("b", 64)}},
	}
	if err := w.VerifyDatasetSpecs(context.Background(), taskBad); err == nil {
		t.Fatal("expected checksum mismatch error")
	}
}

// TestEnforceTaskProvenance_StrictMissingOrMismatchFails checks that strict
// provenance enforcement rejects missing fields, mismatched values, and a
// snapshot task that lacks snapshot_sha256.
func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
	base := t.TempDir()
	setupExperiment(t, base)
	ctx := context.Background()

	w := tests.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})

	// Missing expected fields should fail.
	taskMissing := &queue.Task{JobName: "job", ID: "t1", Metadata: map[string]string{"commit_id": testCommitID}}
	if err := w.EnforceTaskProvenance(ctx, taskMissing); err == nil {
		t.Fatal("expected missing provenance fields error")
	}

	// Mismatch should fail.
	taskMismatch := &queue.Task{JobName: "job", ID: "t2", Metadata: map[string]string{
		"commit_id":                       testCommitID,
		"experiment_manifest_overall_sha": "bad",
		"deps_manifest_name":              "requirements.txt",
		"deps_manifest_sha256":            "bad",
	}}
	if err := w.EnforceTaskProvenance(ctx, taskMismatch); err == nil {
		t.Fatal("expected mismatch provenance error")
	}

	// SnapshotID set but snapshot_sha256 missing should fail in strict mode.
	dataDir := filepath.Join(base, "data")
	writeSnapshot(t, dataDir, "snap1")

	// Let a best-effort worker compute the correct provenance. The strict
	// checks below then differ from a valid task only in snapshot_sha256, so a
	// failure cannot come from some unrelated mismatch (e.g. a bogus deps hash).
	wBest := tests.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
	seed := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{"commit_id": testCommitID}}
	requireNoErr(t, wBest.EnforceTaskProvenance(ctx, seed))
	if seed.Metadata["snapshot_sha256"] == "" {
		t.Fatal("expected best-effort enforcement to populate snapshot_sha256")
	}

	wStrict := tests.NewTestWorker(&worker.Config{
		BasePath:             base,
		DataDir:              dataDir,
		ProvenanceBestEffort: false,
	})

	// Control: the complete, correct provenance must pass strict enforcement.
	taskComplete := &queue.Task{JobName: "job", ID: "t4", SnapshotID: "snap1", Metadata: cloneMetadata(seed.Metadata)}
	if err := wStrict.EnforceTaskProvenance(ctx, taskComplete); err != nil {
		t.Fatalf("expected strict provenance to pass with complete metadata, got %v", err)
	}

	taskSnapMissing := &queue.Task{JobName: "job", ID: "t5", SnapshotID: "snap1", Metadata: cloneMetadata(seed.Metadata)}
	delete(taskSnapMissing.Metadata, "snapshot_sha256")
	if err := wStrict.EnforceTaskProvenance(ctx, taskSnapMissing); err == nil {
		t.Fatal("expected strict provenance to fail when snapshot_sha256 missing")
	}
}

// TestEnforceTaskProvenance_BestEffortOverwrites checks that best-effort
// enforcement succeeds and fills in the computed provenance metadata.
func TestEnforceTaskProvenance_BestEffortOverwrites(t *testing.T) {
	base := t.TempDir()
	setupExperiment(t, base)
	dataDir := filepath.Join(base, "data")
	writeSnapshot(t, dataDir, "snap1")

	w := tests.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
	task := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{"commit_id": testCommitID}}
	if err := w.EnforceTaskProvenance(context.Background(), task); err != nil {
		t.Fatalf("expected best-effort to pass, got %v", err)
	}
	for _, key := range []string{"experiment_manifest_overall_sha", "deps_manifest_sha256", "snapshot_sha256"} {
		if task.Metadata[key] == "" {
			t.Fatalf("expected best-effort to populate provenance metadata %q", key)
		}
	}
}

// TestVerifySnapshot checks snapshot checksum verification: a match passes, and
// a missing checksum, a wrong checksum, or a missing directory each fail.
func TestVerifySnapshot(t *testing.T) {
	base := t.TempDir()
	dataDir := filepath.Join(base, "data")

	snapID := "snap1"
	snapDir := writeSnapshot(t, dataDir, snapID)

	sha, err := worker.DirOverallSHA256Hex(snapDir)
	requireNoErr(t, err)

	w := tests.NewTestWorker(&worker.Config{DataDir: dataDir})

	t.Run("Ok", func(t *testing.T) {
		task := &queue.Task{
			JobName:    "job",
			ID:         "t1",
			SnapshotID: snapID,
			Metadata:   map[string]string{"snapshot_sha256": "sha256:" + sha},
		}
		if err := w.VerifySnapshot(context.Background(), task); err != nil {
			t.Fatalf("expected snapshot verification to pass, got %v", err)
		}
	})

	t.Run("MissingChecksum", func(t *testing.T) {
		task := &queue.Task{JobName: "job", ID: "t2", SnapshotID: snapID, Metadata: map[string]string{}}
		if err := w.VerifySnapshot(context.Background(), task); err == nil {
			t.Fatal("expected error for missing snapshot_sha256")
		}
	})

	t.Run("Mismatch", func(t *testing.T) {
		task := &queue.Task{
			JobName:    "job",
			ID:         "t3",
			SnapshotID: snapID,
			Metadata:   map[string]string{"snapshot_sha256": "sha256:" + strings.Repeat("b", 64)},
		}
		if err := w.VerifySnapshot(context.Background(), task); err == nil {
			t.Fatal("expected checksum mismatch")
		}
	})

	t.Run("MissingDir", func(t *testing.T) {
		task := &queue.Task{
			JobName:    "job",
			ID:         "t4",
			SnapshotID: "missing",
			Metadata:   map[string]string{"snapshot_sha256": "sha256:" + sha},
		}
		if err := w.VerifySnapshot(context.Background(), task); err == nil {
			t.Fatal("expected missing snapshot directory error")
		}
	})
}

// requireNoErr fails the test immediately if err is non-nil.
func requireNoErr(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}

// setupExperiment creates the testCommitID experiment under base. The files
// directory holds train.py and requirements.txt, and the manifest.json has
// OverallSHA testManifestSHA.
func setupExperiment(t *testing.T, base string) {
	t.Helper()
	expMgr := experiment.NewManager(base)
	requireNoErr(t, expMgr.CreateExperiment(testCommitID))

	filesPath := expMgr.GetFilesPath(testCommitID)
	requireNoErr(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	requireNoErr(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))

	requireNoErr(t, expMgr.WriteManifest(&experiment.Manifest{
		CommitID:   testCommitID,
		Files:      map[string]string{"train.py": "deadbeef"},
		OverallSHA: testManifestSHA,
		Timestamp:  1,
	}))
}

// writeSnapshot creates dataDir/snapshots/<id> containing a single file.txt and
// returns the snapshot directory path.
func writeSnapshot(t *testing.T, dataDir, id string) string {
	t.Helper()
	dir := filepath.Join(dataDir, "snapshots", id)
	requireNoErr(t, os.MkdirAll(dir, 0750))
	requireNoErr(t, os.WriteFile(filepath.Join(dir, "file.txt"), []byte("hello"), 0600))
	return dir
}

// cloneMetadata returns a shallow copy of m, so each task can mutate its own map.
func cloneMetadata(m map[string]string) map[string]string {
	out := make(map[string]string, len(m))
	for k, v := range m {
		out[k] = v
	}
	return out
}