407 lines
14 KiB
Go
407 lines
14 KiB
Go
package worker_test
|
|
|
|
import (
|
|
"context"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
|
)
|
|
|
|
func TestSelectDependencyManifestPriority(t *testing.T) {
|
|
base := t.TempDir()
|
|
|
|
// Create all candidates.
|
|
candidates := []string{
|
|
"requirements.txt",
|
|
"pyproject.toml",
|
|
"poetry.lock",
|
|
"environment.yaml",
|
|
"environment.yml",
|
|
}
|
|
for _, name := range candidates {
|
|
p := filepath.Join(base, name)
|
|
if err := os.WriteFile(p, []byte("# test\n"), 0600); err != nil {
|
|
t.Fatalf("write %s: %v", name, err)
|
|
}
|
|
}
|
|
|
|
// With all present, environment.yml should win.
|
|
if got, err := worker.SelectDependencyManifest(base); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
} else if got != "environment.yml" {
|
|
t.Fatalf("expected environment.yml, got %q", got)
|
|
}
|
|
|
|
// Remove environment.yml; environment.yaml should win.
|
|
if err := os.Remove(filepath.Join(base, "environment.yml")); err != nil {
|
|
t.Fatalf("remove environment.yml: %v", err)
|
|
}
|
|
if got, err := worker.SelectDependencyManifest(base); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
} else if got != "environment.yaml" {
|
|
t.Fatalf("expected environment.yaml, got %q", got)
|
|
}
|
|
|
|
// Remove environment.yaml; poetry.lock should win.
|
|
if err := os.Remove(filepath.Join(base, "environment.yaml")); err != nil {
|
|
t.Fatalf("remove environment.yaml: %v", err)
|
|
}
|
|
if got, err := worker.SelectDependencyManifest(base); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
} else if got != "poetry.lock" {
|
|
t.Fatalf("expected poetry.lock, got %q", got)
|
|
}
|
|
|
|
// Remove poetry.lock; pyproject.toml should win.
|
|
if err := os.Remove(filepath.Join(base, "poetry.lock")); err != nil {
|
|
t.Fatalf("remove poetry.lock: %v", err)
|
|
}
|
|
if got, err := worker.SelectDependencyManifest(base); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
} else if got != "pyproject.toml" {
|
|
t.Fatalf("expected pyproject.toml, got %q", got)
|
|
}
|
|
|
|
// Remove pyproject.toml; requirements.txt should win.
|
|
if err := os.Remove(filepath.Join(base, "pyproject.toml")); err != nil {
|
|
t.Fatalf("remove pyproject.toml: %v", err)
|
|
}
|
|
if got, err := worker.SelectDependencyManifest(base); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
} else if got != "requirements.txt" {
|
|
t.Fatalf("expected requirements.txt, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestSelectDependencyManifestPoetryRequiresPyproject(t *testing.T) {
|
|
base := t.TempDir()
|
|
|
|
// poetry.lock exists but pyproject.toml is missing.
|
|
if err := os.WriteFile(filepath.Join(base, "poetry.lock"), []byte("# test\n"), 0600); err != nil {
|
|
t.Fatalf("write poetry.lock: %v", err)
|
|
}
|
|
|
|
if _, err := worker.SelectDependencyManifest(base); err == nil {
|
|
t.Fatalf("expected error when poetry.lock exists without pyproject.toml")
|
|
}
|
|
}
|
|
|
|
func TestSelectDependencyManifestMissing(t *testing.T) {
|
|
base := t.TempDir()
|
|
if _, err := worker.SelectDependencyManifest(base); err == nil {
|
|
t.Fatalf("expected error when no manifest exists")
|
|
}
|
|
}
|
|
|
|
func TestResolveDatasetsPrecedence(t *testing.T) {
|
|
if got := worker.ResolveDatasets(nil); got != nil {
|
|
t.Fatalf("expected nil for nil task")
|
|
}
|
|
|
|
t.Run("DatasetSpecsWins", func(t *testing.T) {
|
|
task := &queue.Task{
|
|
DatasetSpecs: []queue.DatasetSpec{{Name: "ds-spec"}},
|
|
Datasets: []string{"ds-legacy"},
|
|
Args: "--datasets ds-args",
|
|
}
|
|
got := worker.ResolveDatasets(task)
|
|
if len(got) != 1 || got[0] != "ds-spec" {
|
|
t.Fatalf("expected dataset_specs to win, got %v", got)
|
|
}
|
|
})
|
|
|
|
t.Run("DatasetsWinsOverArgs", func(t *testing.T) {
|
|
task := &queue.Task{
|
|
Datasets: []string{"ds-legacy"},
|
|
Args: "--datasets ds-args",
|
|
}
|
|
got := worker.ResolveDatasets(task)
|
|
if len(got) != 1 || got[0] != "ds-legacy" {
|
|
t.Fatalf("expected datasets to win over args, got %v", got)
|
|
}
|
|
})
|
|
|
|
t.Run("ArgsFallback", func(t *testing.T) {
|
|
task := &queue.Task{Args: "--datasets a,b,c"}
|
|
got := worker.ResolveDatasets(task)
|
|
if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "c" {
|
|
t.Fatalf("expected args datasets, got %v", got)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestComputeTaskProvenance(t *testing.T) {
|
|
base := t.TempDir()
|
|
commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex
|
|
|
|
// Create experiment files structure.
|
|
expMgr := experiment.NewManager(base)
|
|
requireNoErr(t, expMgr.CreateExperiment(commitID))
|
|
filesPath := expMgr.GetFilesPath(commitID)
|
|
|
|
// Create a deps manifest in the files path.
|
|
depsPath := filepath.Join(filesPath, "requirements.txt")
|
|
requireNoErr(t, os.WriteFile(depsPath, []byte("numpy==1.0.0\n"), 0600))
|
|
|
|
// Write an experiment manifest.json with deterministic overall sha.
|
|
manifest := &experiment.Manifest{
|
|
CommitID: commitID,
|
|
Files: map[string]string{"train.py": "deadbeef"},
|
|
OverallSHA: "0123456789abcdef",
|
|
Timestamp: 1,
|
|
}
|
|
requireNoErr(t, expMgr.WriteManifest(manifest))
|
|
|
|
// Task references commit_id in metadata.
|
|
task := &queue.Task{
|
|
JobName: "job",
|
|
SnapshotID: "snap-1",
|
|
DatasetSpecs: []queue.DatasetSpec{{Name: "ds1", Version: "v1"}},
|
|
Metadata: map[string]string{"commit_id": commitID},
|
|
}
|
|
|
|
prov, err := worker.ComputeTaskProvenance(base, task)
|
|
if err != nil {
|
|
t.Fatalf("ComputeTaskProvenance error: %v", err)
|
|
}
|
|
if prov["snapshot_id"] != "snap-1" {
|
|
t.Fatalf("expected snapshot_id, got %q", prov["snapshot_id"])
|
|
}
|
|
if prov["datasets"] != "ds1" {
|
|
t.Fatalf("expected datasets=ds1, got %q", prov["datasets"])
|
|
}
|
|
if prov["dataset_specs"] == "" {
|
|
t.Fatalf("expected dataset_specs json")
|
|
}
|
|
if prov["experiment_manifest_overall_sha"] != "0123456789abcdef" {
|
|
t.Fatalf("expected manifest sha, got %q", prov["experiment_manifest_overall_sha"])
|
|
}
|
|
if prov["deps_manifest_name"] != "requirements.txt" {
|
|
t.Fatalf("expected deps_manifest_name requirements.txt, got %q", prov["deps_manifest_name"])
|
|
}
|
|
if prov["deps_manifest_sha256"] == "" {
|
|
t.Fatalf("expected deps_manifest_sha256")
|
|
}
|
|
|
|
// Graceful behavior with missing metadata.
|
|
task2 := &queue.Task{SnapshotID: "snap-2"}
|
|
prov2, err := worker.ComputeTaskProvenance(base, task2)
|
|
if err != nil {
|
|
t.Fatalf("ComputeTaskProvenance (missing metadata) error: %v", err)
|
|
}
|
|
if prov2["snapshot_id"] != "snap-2" {
|
|
t.Fatalf("expected snapshot_id snap-2, got %q", prov2["snapshot_id"])
|
|
}
|
|
}
|
|
|
|
// requireNoErr aborts the calling test with "unexpected error: <err>" when err
// is non-nil. It is marked as a helper so the failure is attributed to the
// caller's line.
func requireNoErr(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("unexpected error: %v", err)
}
|
|
|
|
func TestNormalizeSHA256ChecksumHex(t *testing.T) {
|
|
got, err := worker.NormalizeSHA256ChecksumHex("sha256:" + strings.Repeat("a", 64))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if got != strings.Repeat("a", 64) {
|
|
t.Fatalf("unexpected normalized checksum: %q", got)
|
|
}
|
|
|
|
if _, err := worker.NormalizeSHA256ChecksumHex("sha256:deadbeef"); err == nil {
|
|
t.Fatalf("expected error for short checksum")
|
|
}
|
|
}
|
|
|
|
func TestVerifyDatasetSpecs(t *testing.T) {
|
|
base := t.TempDir()
|
|
dataDir := filepath.Join(base, "data")
|
|
requireNoErr(t, os.MkdirAll(dataDir, 0750))
|
|
|
|
// Create dataset directory with one file.
|
|
dsName := "dataset1"
|
|
dsPath := filepath.Join(dataDir, dsName)
|
|
requireNoErr(t, os.MkdirAll(dsPath, 0750))
|
|
requireNoErr(t, os.WriteFile(filepath.Join(dsPath, "file.txt"), []byte("hello"), 0600))
|
|
|
|
sha, err := worker.DirOverallSHA256Hex(dsPath)
|
|
requireNoErr(t, err)
|
|
|
|
w := worker.NewTestWorker(&worker.Config{DataDir: dataDir})
|
|
task := &queue.Task{
|
|
JobName: "job",
|
|
ID: "t1",
|
|
DatasetSpecs: []queue.DatasetSpec{{Name: dsName, Checksum: "sha256:" + sha}},
|
|
}
|
|
if err := w.VerifyDatasetSpecs(context.Background(), task); err != nil {
|
|
t.Fatalf("expected checksum verification to pass, got %v", err)
|
|
}
|
|
|
|
taskBad := &queue.Task{
|
|
JobName: "job",
|
|
ID: "t2",
|
|
DatasetSpecs: []queue.DatasetSpec{{Name: dsName, Checksum: "sha256:" + strings.Repeat("b", 64)}},
|
|
}
|
|
if err := w.VerifyDatasetSpecs(context.Background(), taskBad); err == nil {
|
|
t.Fatalf("expected checksum mismatch error")
|
|
}
|
|
}
|
|
|
|
func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
|
|
base := t.TempDir()
|
|
commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex
|
|
|
|
expMgr := experiment.NewManager(base)
|
|
requireNoErr(t, expMgr.CreateExperiment(commitID))
|
|
filesPath := expMgr.GetFilesPath(commitID)
|
|
requireNoErr(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
|
|
requireNoErr(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
|
|
|
|
manifest := &experiment.Manifest{
|
|
CommitID: commitID,
|
|
Files: map[string]string{"train.py": "deadbeef"},
|
|
OverallSHA: "0123456789abcdef",
|
|
Timestamp: 1,
|
|
}
|
|
requireNoErr(t, expMgr.WriteManifest(manifest))
|
|
|
|
w := worker.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})
|
|
|
|
// Missing expected fields should fail.
|
|
taskMissing := &queue.Task{JobName: "job", ID: "t1", Metadata: map[string]string{"commit_id": commitID}}
|
|
if err := w.EnforceTaskProvenance(context.Background(), taskMissing); err == nil {
|
|
t.Fatalf("expected missing provenance fields error")
|
|
}
|
|
|
|
// Mismatch should fail.
|
|
taskMismatch := &queue.Task{JobName: "job", ID: "t2", Metadata: map[string]string{
|
|
"commit_id": commitID,
|
|
"experiment_manifest_overall_sha": "bad",
|
|
"deps_manifest_name": "requirements.txt",
|
|
"deps_manifest_sha256": "bad",
|
|
}}
|
|
if err := w.EnforceTaskProvenance(context.Background(), taskMismatch); err == nil {
|
|
t.Fatalf("expected mismatch provenance error")
|
|
}
|
|
|
|
// SnapshotID set but missing snapshot_sha256 should fail in strict mode.
|
|
snapDir := filepath.Join(base, "data", "snapshots", "snap1")
|
|
requireNoErr(t, os.MkdirAll(snapDir, 0750))
|
|
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
|
|
|
|
wSnap := worker.NewTestWorker(&worker.Config{
|
|
BasePath: base,
|
|
DataDir: filepath.Join(base, "data"),
|
|
ProvenanceBestEffort: false,
|
|
})
|
|
taskSnapMissing := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{
|
|
"commit_id": commitID,
|
|
"experiment_manifest_overall_sha": "0123456789abcdef",
|
|
"deps_manifest_name": "requirements.txt",
|
|
"deps_manifest_sha256": "bad", // still mismatch but we're focusing snapshot field presence
|
|
}}
|
|
if err := wSnap.EnforceTaskProvenance(context.Background(), taskSnapMissing); err == nil {
|
|
t.Fatalf("expected strict provenance to fail when snapshot_sha256 missing")
|
|
}
|
|
}
|
|
|
|
func TestEnforceTaskProvenance_BestEffortOverwrites(t *testing.T) {
|
|
base := t.TempDir()
|
|
commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex
|
|
|
|
expMgr := experiment.NewManager(base)
|
|
requireNoErr(t, expMgr.CreateExperiment(commitID))
|
|
filesPath := expMgr.GetFilesPath(commitID)
|
|
requireNoErr(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
|
|
requireNoErr(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
|
|
|
|
manifest := &experiment.Manifest{
|
|
CommitID: commitID,
|
|
Files: map[string]string{"train.py": "deadbeef"},
|
|
OverallSHA: "0123456789abcdef",
|
|
Timestamp: 1,
|
|
}
|
|
requireNoErr(t, expMgr.WriteManifest(manifest))
|
|
|
|
dataDir := filepath.Join(base, "data")
|
|
snapDir := filepath.Join(dataDir, "snapshots", "snap1")
|
|
requireNoErr(t, os.MkdirAll(snapDir, 0750))
|
|
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
|
|
|
|
w := worker.NewTestWorker(&worker.Config{BasePath: base, DataDir: dataDir, ProvenanceBestEffort: true})
|
|
task := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{"commit_id": commitID}}
|
|
if err := w.EnforceTaskProvenance(context.Background(), task); err != nil {
|
|
t.Fatalf("expected best-effort to pass, got %v", err)
|
|
}
|
|
if task.Metadata["experiment_manifest_overall_sha"] == "" ||
|
|
task.Metadata["deps_manifest_sha256"] == "" ||
|
|
task.Metadata["snapshot_sha256"] == "" {
|
|
t.Fatalf("expected best-effort to populate provenance metadata")
|
|
}
|
|
}
|
|
|
|
func TestVerifySnapshot(t *testing.T) {
|
|
base := t.TempDir()
|
|
dataDir := filepath.Join(base, "data")
|
|
requireNoErr(t, os.MkdirAll(dataDir, 0750))
|
|
|
|
snapID := "snap1"
|
|
snapDir := filepath.Join(dataDir, "snapshots", snapID)
|
|
requireNoErr(t, os.MkdirAll(snapDir, 0750))
|
|
requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))
|
|
|
|
sha, err := worker.DirOverallSHA256Hex(snapDir)
|
|
requireNoErr(t, err)
|
|
|
|
w := worker.NewTestWorker(&worker.Config{DataDir: dataDir})
|
|
|
|
t.Run("Ok", func(t *testing.T) {
|
|
task := &queue.Task{
|
|
JobName: "job",
|
|
ID: "t1",
|
|
SnapshotID: snapID,
|
|
Metadata: map[string]string{"snapshot_sha256": "sha256:" + sha},
|
|
}
|
|
if err := w.VerifySnapshot(context.Background(), task); err != nil {
|
|
t.Fatalf("expected snapshot verification to pass, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("MissingChecksum", func(t *testing.T) {
|
|
task := &queue.Task{JobName: "job", ID: "t2", SnapshotID: snapID, Metadata: map[string]string{}}
|
|
if err := w.VerifySnapshot(context.Background(), task); err == nil {
|
|
t.Fatalf("expected error for missing snapshot_sha256")
|
|
}
|
|
})
|
|
|
|
t.Run("Mismatch", func(t *testing.T) {
|
|
task := &queue.Task{
|
|
JobName: "job",
|
|
ID: "t3",
|
|
SnapshotID: snapID,
|
|
Metadata: map[string]string{"snapshot_sha256": "sha256:" + strings.Repeat("b", 64)},
|
|
}
|
|
if err := w.VerifySnapshot(context.Background(), task); err == nil {
|
|
t.Fatalf("expected checksum mismatch")
|
|
}
|
|
})
|
|
|
|
t.Run("MissingDir", func(t *testing.T) {
|
|
task := &queue.Task{
|
|
JobName: "job",
|
|
ID: "t4",
|
|
SnapshotID: "missing",
|
|
Metadata: map[string]string{"snapshot_sha256": "sha256:" + sha},
|
|
}
|
|
if err := w.VerifySnapshot(context.Background(), task); err == nil {
|
|
t.Fatalf("expected missing snapshot directory error")
|
|
}
|
|
})
|
|
}
|