fetch_ml/tests/unit/worker/worker_test.go
Jeremie Fraeys d87c556afa
test(all): update test suite for scheduler and security features
Update comprehensive test coverage:
- E2E tests with scheduler integration
- Integration tests with tenant isolation
- Unit tests with security assertions
- Security tests with audit validation
- Audit verification tests
- Auth tests with tenant scoping
- Config validation tests
- Container security tests
- Worker tests with scheduler mock
- Environment pool tests
- Load tests with distributed patterns
- Test fixtures with scheduler support
- Update go.mod/go.sum with new dependencies
2026-02-26 12:08:46 -05:00

408 lines
14 KiB
Go

package worker_test
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
tests "github.com/jfraeys/fetch_ml/tests/fixtures"
)
// TestSelectDependencyManifestPriority checks the order in which
// worker.SelectDependencyManifest prefers manifests when several are present.
//
// Every candidate file is created up front. The test then asserts the selected
// manifest and deletes it, step by step, so that each lower-priority candidate
// becomes the expected winner in turn.
func TestSelectDependencyManifestPriority(t *testing.T) {
	base := t.TempDir()

	// Listed from highest to lowest expected priority.
	priority := []string{
		"environment.yml",
		"environment.yaml",
		"poetry.lock",
		"pyproject.toml",
		"requirements.txt",
	}

	// Create all candidates.
	for _, name := range priority {
		p := filepath.Join(base, name)
		if err := os.WriteFile(p, []byte("# test\n"), 0600); err != nil {
			t.Fatalf("write %s: %v", name, err)
		}
	}

	for i, want := range priority {
		got, err := worker.SelectDependencyManifest(base)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got != want {
			t.Fatalf("expected %s, got %q", want, got)
		}

		// The lowest-priority candidate is left on disk; selecting from an
		// empty directory is out of scope for this test.
		if i == len(priority)-1 {
			break
		}

		// Drop the current winner so the next candidate is exercised.
		if err := os.Remove(filepath.Join(base, want)); err != nil {
			t.Fatalf("remove %s: %v", want, err)
		}
	}
}
// TestSelectDependencyManifestPoetryRequiresPyproject expects
// worker.SelectDependencyManifest to return an error when the directory holds
// a poetry.lock but no pyproject.toml.
func TestSelectDependencyManifestPoetryRequiresPyproject(t *testing.T) {
	dir := t.TempDir()

	// Only the lock file is written; pyproject.toml is deliberately absent.
	lockPath := filepath.Join(dir, "poetry.lock")
	if writeErr := os.WriteFile(lockPath, []byte("# test\n"), 0600); writeErr != nil {
		t.Fatalf("write poetry.lock: %v", writeErr)
	}

	_, err := worker.SelectDependencyManifest(dir)
	if err == nil {
		t.Fatalf("expected error when poetry.lock exists without pyproject.toml")
	}
}
// TestSelectDependencyManifestMissing expects worker.SelectDependencyManifest
// to return an error for a freshly created, empty directory.
func TestSelectDependencyManifestMissing(t *testing.T) {
	emptyDir := t.TempDir()

	_, err := worker.SelectDependencyManifest(emptyDir)
	if err != nil {
		return
	}
	t.Fatalf("expected error when no manifest exists")
}
// TestResolveDatasetsPrecedence checks which source of dataset names
// tests.ResolveDatasets picks for a task: DatasetSpecs first, then Datasets,
// then the "--datasets" flag inside Args. A nil task must resolve to nil.
func TestResolveDatasetsPrecedence(t *testing.T) {
	nilResult := tests.ResolveDatasets(nil)
	if nilResult != nil {
		t.Fatalf("expected nil for nil task")
	}

	t.Run("DatasetSpecsWins", func(t *testing.T) {
		// All three sources are populated; the spec name is expected back.
		resolved := tests.ResolveDatasets(&queue.Task{
			DatasetSpecs: []queue.DatasetSpec{{Name: "ds-spec"}},
			Datasets:     []string{"ds-legacy"},
			Args:         "--datasets ds-args",
		})
		if !(len(resolved) == 1 && resolved[0] == "ds-spec") {
			t.Fatalf("expected dataset_specs to win, got %v", resolved)
		}
	})

	t.Run("DatasetsWinsOverArgs", func(t *testing.T) {
		// No specs: the Datasets slice is expected to beat Args.
		resolved := tests.ResolveDatasets(&queue.Task{
			Datasets: []string{"ds-legacy"},
			Args:     "--datasets ds-args",
		})
		if !(len(resolved) == 1 && resolved[0] == "ds-legacy") {
			t.Fatalf("expected datasets to win over args, got %v", resolved)
		}
	})

	t.Run("ArgsFallback", func(t *testing.T) {
		// Only Args is set; the comma-separated list is expected as three names.
		resolved := tests.ResolveDatasets(&queue.Task{Args: "--datasets a,b,c"})
		matches := len(resolved) == 3 &&
			resolved[0] == "a" &&
			resolved[1] == "b" &&
			resolved[2] == "c"
		if !matches {
			t.Fatalf("expected args datasets, got %v", resolved)
		}
	})
}
// TestComputeTaskProvenance sets up an experiment with a requirements.txt and
// a manifest on disk, then checks the provenance map returned by
// worker.ComputeTaskProvenance for a task that references that experiment.
// It also checks that a task without metadata still yields its snapshot_id
// and no error.
func TestComputeTaskProvenance(t *testing.T) {
	const (
		commitID   = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex
		overallSHA = "0123456789abcdef"
	)
	root := t.TempDir()

	// Lay out the experiment directory for the commit.
	mgr := experiment.NewManager(root)
	requireNoErr(t, mgr.CreateExperiment(commitID))

	// Drop a dependency manifest into the experiment's files directory.
	requirements := filepath.Join(mgr.GetFilesPath(commitID), "requirements.txt")
	requireNoErr(t, os.WriteFile(requirements, []byte("numpy==1.0.0\n"), 0600))

	// Persist an experiment manifest with a fixed overall sha.
	requireNoErr(t, mgr.WriteManifest(&experiment.Manifest{
		CommitID:   commitID,
		Files:      map[string]string{"train.py": "deadbeef"},
		OverallSHA: overallSHA,
		Timestamp:  1,
	}))

	// The task points at the experiment through its commit_id metadata.
	full := &queue.Task{
		JobName:      "job",
		SnapshotID:   "snap-1",
		DatasetSpecs: []queue.DatasetSpec{{Name: "ds1", Version: "v1"}},
		Metadata:     map[string]string{"commit_id": commitID},
	}
	prov, provErr := worker.ComputeTaskProvenance(root, full)
	if provErr != nil {
		t.Fatalf("ComputeTaskProvenance error: %v", provErr)
	}

	// Cases are evaluated top to bottom; the first failing check aborts the test.
	switch {
	case prov["snapshot_id"] != "snap-1":
		t.Fatalf("expected snapshot_id, got %q", prov["snapshot_id"])
	case prov["datasets"] != "ds1":
		t.Fatalf("expected datasets=ds1, got %q", prov["datasets"])
	case prov["dataset_specs"] == "":
		t.Fatalf("expected dataset_specs json")
	case prov["experiment_manifest_overall_sha"] != overallSHA:
		t.Fatalf("expected manifest sha, got %q", prov["experiment_manifest_overall_sha"])
	case prov["deps_manifest_name"] != "requirements.txt":
		t.Fatalf("expected deps_manifest_name requirements.txt, got %q", prov["deps_manifest_name"])
	case prov["deps_manifest_sha256"] == "":
		t.Fatalf("expected deps_manifest_sha256")
	}

	// A task with no metadata must not produce an error.
	bare := &queue.Task{SnapshotID: "snap-2"}
	bareProv, bareErr := worker.ComputeTaskProvenance(root, bare)
	if bareErr != nil {
		t.Fatalf("ComputeTaskProvenance (missing metadata) error: %v", bareErr)
	}
	if got := bareProv["snapshot_id"]; got != "snap-2" {
		t.Fatalf("expected snapshot_id snap-2, got %q", got)
	}
}
// requireNoErr aborts the calling test with t.Fatalf when err is non-nil and
// does nothing otherwise. It is marked as a helper so failures are reported at
// the caller's line.
func requireNoErr(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("unexpected error: %v", err)
}
// TestNormalizeSHA256ChecksumHex checks that a "sha256:"-prefixed 64-character
// checksum comes back as the bare hex string, and that a prefixed checksum
// that is too short is rejected with an error.
func TestNormalizeSHA256ChecksumHex(t *testing.T) {
	wantHex := strings.Repeat("a", 64)

	normalized, err := worker.NormalizeSHA256ChecksumHex("sha256:" + strings.Repeat("a", 64))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if normalized != wantHex {
		t.Fatalf("unexpected normalized checksum: %q", normalized)
	}

	_, shortErr := worker.NormalizeSHA256ChecksumHex("sha256:deadbeef")
	if shortErr == nil {
		t.Fatalf("expected error for short checksum")
	}
}
// TestVerifyDatasetSpecs builds a one-file dataset directory under a temporary
// data dir and checks that VerifyDatasetSpecs accepts the checksum reported by
// worker.DirOverallSHA256Hex for that directory and rejects a different one.
func TestVerifyDatasetSpecs(t *testing.T) {
	root := t.TempDir()
	dataDir := filepath.Join(root, "data")
	requireNoErr(t, os.MkdirAll(dataDir, 0750))

	// The dataset is a directory containing a single file.
	datasetName := "dataset1"
	datasetDir := filepath.Join(dataDir, datasetName)
	requireNoErr(t, os.MkdirAll(datasetDir, 0750))
	requireNoErr(t, os.WriteFile(filepath.Join(datasetDir, "file.txt"), []byte("hello"), 0600))

	digest, err := worker.DirOverallSHA256Hex(datasetDir)
	requireNoErr(t, err)

	w := tests.NewTestWorker(&worker.Config{DataDir: dataDir})

	// Matching checksum: verification must succeed.
	matching := &queue.Task{
		JobName:      "job",
		ID:           "t1",
		DatasetSpecs: []queue.DatasetSpec{{Name: datasetName, Checksum: "sha256:" + digest}},
	}
	okErr := w.VerifyDatasetSpecs(context.Background(), matching)
	if okErr != nil {
		t.Fatalf("expected checksum verification to pass, got %v", okErr)
	}

	// Well-formed but different checksum: verification must fail.
	wrongChecksum := "sha256:" + strings.Repeat("b", 64)
	mismatching := &queue.Task{
		JobName:      "job",
		ID:           "t2",
		DatasetSpecs: []queue.DatasetSpec{{Name: datasetName, Checksum: wrongChecksum}},
	}
	badErr := w.VerifyDatasetSpecs(context.Background(), mismatching)
	if badErr == nil {
		t.Fatalf("expected checksum mismatch error")
	}
}
// TestEnforceTaskProvenance_StrictMissingOrMismatchFails checks that, with
// ProvenanceBestEffort disabled, EnforceTaskProvenance returns an error when:
//
//  1. the task carries no expected provenance fields,
//  2. the expected fields are present but do not match, and
//  3. the task names a snapshot but carries no snapshot_sha256.
//
// For case 3 the deps manifest checksum is taken from
// worker.ComputeTaskProvenance rather than hard-coded to a wrong value, so a
// deps mismatch cannot mask a missing snapshot_sha256 check.
func TestEnforceTaskProvenance_StrictMissingOrMismatchFails(t *testing.T) {
	base := t.TempDir()
	commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex

	// Experiment layout: a script, a deps manifest and an experiment manifest.
	expMgr := experiment.NewManager(base)
	requireNoErr(t, expMgr.CreateExperiment(commitID))
	filesPath := expMgr.GetFilesPath(commitID)
	requireNoErr(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	requireNoErr(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
	manifest := &experiment.Manifest{
		CommitID:   commitID,
		Files:      map[string]string{"train.py": "deadbeef"},
		OverallSHA: "0123456789abcdef",
		Timestamp:  1,
	}
	requireNoErr(t, expMgr.WriteManifest(manifest))

	w := tests.NewTestWorker(&worker.Config{BasePath: base, ProvenanceBestEffort: false})

	// Missing expected fields should fail.
	taskMissing := &queue.Task{JobName: "job", ID: "t1", Metadata: map[string]string{"commit_id": commitID}}
	if err := w.EnforceTaskProvenance(context.Background(), taskMissing); err == nil {
		t.Fatalf("expected missing provenance fields error")
	}

	// Mismatch should fail.
	taskMismatch := &queue.Task{JobName: "job", ID: "t2", Metadata: map[string]string{
		"commit_id":                       commitID,
		"experiment_manifest_overall_sha": "bad",
		"deps_manifest_name":              "requirements.txt",
		"deps_manifest_sha256":            "bad",
	}}
	if err := w.EnforceTaskProvenance(context.Background(), taskMismatch); err == nil {
		t.Fatalf("expected mismatch provenance error")
	}

	// SnapshotID set but missing snapshot_sha256 should fail in strict mode.
	snapDir := filepath.Join(base, "data", "snapshots", "snap1")
	requireNoErr(t, os.MkdirAll(snapDir, 0750))
	requireNoErr(t, os.WriteFile(filepath.Join(snapDir, "file.txt"), []byte("hello"), 0600))

	// Obtain the real deps manifest checksum so the manifest and deps fields
	// below are intended to match, leaving the absent snapshot_sha256 as the
	// only intended discrepancy.
	prov, err := worker.ComputeTaskProvenance(base, &queue.Task{
		Metadata: map[string]string{"commit_id": commitID},
	})
	requireNoErr(t, err)
	depsSHA := prov["deps_manifest_sha256"]
	if depsSHA == "" {
		t.Fatalf("expected deps_manifest_sha256")
	}

	wSnap := tests.NewTestWorker(&worker.Config{
		BasePath:             base,
		DataDir:              filepath.Join(base, "data"),
		ProvenanceBestEffort: false,
	})
	taskSnapMissing := &queue.Task{JobName: "job", ID: "t3", SnapshotID: "snap1", Metadata: map[string]string{
		"commit_id":                       commitID,
		"experiment_manifest_overall_sha": "0123456789abcdef",
		"deps_manifest_name":              "requirements.txt",
		"deps_manifest_sha256":            depsSHA,
	}}
	if err := wSnap.EnforceTaskProvenance(context.Background(), taskSnapMissing); err == nil {
		t.Fatalf("expected strict provenance to fail when snapshot_sha256 missing")
	}
}
// TestEnforceTaskProvenance_BestEffortOverwrites checks that, with
// ProvenanceBestEffort enabled, EnforceTaskProvenance succeeds for a task that
// only carries a commit_id and leaves the manifest, deps and snapshot
// checksum entries of the task metadata non-empty.
func TestEnforceTaskProvenance_BestEffortOverwrites(t *testing.T) {
	root := t.TempDir()
	commitID := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" // 40 hex

	// Experiment layout: a script, a deps manifest and an experiment manifest.
	mgr := experiment.NewManager(root)
	requireNoErr(t, mgr.CreateExperiment(commitID))
	files := mgr.GetFilesPath(commitID)
	requireNoErr(t, os.WriteFile(filepath.Join(files, "train.py"), []byte("print('ok')\n"), 0600))
	requireNoErr(t, os.WriteFile(filepath.Join(files, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
	requireNoErr(t, mgr.WriteManifest(&experiment.Manifest{
		CommitID:   commitID,
		Files:      map[string]string{"train.py": "deadbeef"},
		OverallSHA: "0123456789abcdef",
		Timestamp:  1,
	}))

	// Snapshot directory with a single file under <data>/snapshots/snap1.
	dataDir := filepath.Join(root, "data")
	snapshotDir := filepath.Join(dataDir, "snapshots", "snap1")
	requireNoErr(t, os.MkdirAll(snapshotDir, 0750))
	requireNoErr(t, os.WriteFile(filepath.Join(snapshotDir, "file.txt"), []byte("hello"), 0600))

	cfg := &worker.Config{BasePath: root, DataDir: dataDir, ProvenanceBestEffort: true}
	w := tests.NewTestWorker(cfg)

	task := &queue.Task{
		JobName:    "job",
		ID:         "t3",
		SnapshotID: "snap1",
		Metadata:   map[string]string{"commit_id": commitID},
	}
	err := w.EnforceTaskProvenance(context.Background(), task)
	if err != nil {
		t.Fatalf("expected best-effort to pass, got %v", err)
	}

	// Each of these metadata entries must be non-empty after the call.
	populated := []string{
		"experiment_manifest_overall_sha",
		"deps_manifest_sha256",
		"snapshot_sha256",
	}
	for _, key := range populated {
		if task.Metadata[key] == "" {
			t.Fatalf("expected best-effort to populate provenance metadata")
		}
	}
}
// TestVerifySnapshot creates a one-file snapshot directory and exercises
// VerifySnapshot with a matching checksum, no checksum, a different checksum,
// and a snapshot ID that has no directory on disk. Only the first case is
// expected to succeed.
func TestVerifySnapshot(t *testing.T) {
	const snapshotID = "snap1"

	root := t.TempDir()
	dataDir := filepath.Join(root, "data")
	requireNoErr(t, os.MkdirAll(dataDir, 0750))

	snapshotDir := filepath.Join(dataDir, "snapshots", snapshotID)
	requireNoErr(t, os.MkdirAll(snapshotDir, 0750))
	requireNoErr(t, os.WriteFile(filepath.Join(snapshotDir, "file.txt"), []byte("hello"), 0600))

	digest, err := worker.DirOverallSHA256Hex(snapshotDir)
	requireNoErr(t, err)

	w := tests.NewTestWorker(&worker.Config{DataDir: dataDir})

	// newTask assembles a task for the given snapshot and metadata.
	newTask := func(id, snapshot string, meta map[string]string) *queue.Task {
		return &queue.Task{JobName: "job", ID: id, SnapshotID: snapshot, Metadata: meta}
	}

	t.Run("Ok", func(t *testing.T) {
		task := newTask("t1", snapshotID, map[string]string{"snapshot_sha256": "sha256:" + digest})
		verifyErr := w.VerifySnapshot(context.Background(), task)
		if verifyErr != nil {
			t.Fatalf("expected snapshot verification to pass, got %v", verifyErr)
		}
	})

	t.Run("MissingChecksum", func(t *testing.T) {
		// Metadata is present but has no snapshot_sha256 entry.
		task := newTask("t2", snapshotID, map[string]string{})
		verifyErr := w.VerifySnapshot(context.Background(), task)
		if verifyErr == nil {
			t.Fatalf("expected error for missing snapshot_sha256")
		}
	})

	t.Run("Mismatch", func(t *testing.T) {
		// Well-formed checksum that differs from the directory's digest.
		wrong := "sha256:" + strings.Repeat("b", 64)
		task := newTask("t3", snapshotID, map[string]string{"snapshot_sha256": wrong})
		verifyErr := w.VerifySnapshot(context.Background(), task)
		if verifyErr == nil {
			t.Fatalf("expected checksum mismatch")
		}
	})

	t.Run("MissingDir", func(t *testing.T) {
		// No directory was created for the "missing" snapshot ID.
		task := newTask("t4", "missing", map[string]string{"snapshot_sha256": "sha256:" + digest})
		verifyErr := w.VerifySnapshot(context.Background(), task)
		if verifyErr == nil {
			t.Fatalf("expected missing snapshot directory error")
		}
	})
}