package tests

import (
	"path/filepath"
	"reflect"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/container"
)

// TestBuildPodmanCommand_DefaultsAndArgs pins the exact argument list built for
// a config with GPUAccess enabled and Memory/CPUs left unset. In that case the
// command is expected to carry config.DefaultPodmanMemory and
// config.DefaultPodmanCPUs plus the /dev/dri device flag.
func TestBuildPodmanCommand_DefaultsAndArgs(t *testing.T) {
	cfg := container.PodmanConfig{
		Image:              "registry.example/fetch:latest",
		Workspace:          "/host/workspace",
		Results:            "/host/results",
		ContainerWorkspace: "/workspace",
		ContainerResults:   "/results",
		GPUAccess:          true,
	}

	cmd := container.BuildPodmanCommand(cfg, "/workspace/train.py", "/workspace/requirements.txt", []string{"--foo=bar", "baz"})

	expected := []string{
		"podman", "run", "--rm",
		"--security-opt", "no-new-privileges",
		"--cap-drop", "ALL",
		"--memory", config.DefaultPodmanMemory,
		"--cpus", config.DefaultPodmanCPUs,
		"--userns", "keep-id",
		"-v", "/host/workspace:/workspace:rw",
		"-v", "/host/results:/results:rw",
		"--device", "/dev/dri",
		"registry.example/fetch:latest",
		"--workspace", "/workspace",
		"--requirements", "/workspace/requirements.txt",
		"--script", "/workspace/train.py",
		"--args", "--foo=bar", "baz",
	}

	if !reflect.DeepEqual(cmd.Args, expected) {
		t.Fatalf("unexpected podman args\nwant: %v\ngot: %v", expected, cmd.Args)
	}
}

// TestBuildPodmanCommand_Overrides checks that explicit Memory/CPUs values are
// emitted instead of the defaults, and that no --device flag is added when
// GPUAccess is false.
func TestBuildPodmanCommand_Overrides(t *testing.T) {
	cfg := container.PodmanConfig{
		Image:              "fetch:test",
		Workspace:          "/w",
		Results:            "/r",
		ContainerWorkspace: "/cw",
		ContainerResults:   "/cr",
		GPUAccess:          false,
		Memory:             "16g",
		CPUs:               "8",
	}

	cmd := container.BuildPodmanCommand(cfg, "script.py", "reqs.txt", nil)

	if contains(cmd.Args, "--device") {
		t.Fatalf("expected GPU device flag to be omitted when GPUAccess is false: %v", cmd.Args)
	}
	if !containsSequence(cmd.Args, []string{"--memory", "16g"}) {
		t.Fatalf("expected custom memory flag, got %v", cmd.Args)
	}
	if !containsSequence(cmd.Args, []string{"--cpus", "8"}) {
		t.Fatalf("expected custom cpu flag, got %v", cmd.Args)
	}
}

// TestSanitizePath checks that SanitizePath accepts a rooted path and returns
// its lexically cleaned form (as computed by filepath.Clean).
func TestSanitizePath(t *testing.T) {
	// Deliberately not built with filepath.Join: Join cleans its result, which
	// would hand SanitizePath an already-canonical path and let an
	// implementation that never cleans pass unnoticed. The "." element and the
	// trailing separators below must be removed by sanitization.
	input := filepath.FromSlash("/tmp/./jobs//")
	expected := filepath.Clean(input)
	if input == expected {
		t.Fatalf("test setup: input %q is already clean", input)
	}

	cleaned, err := container.SanitizePath(input)
	if err != nil {
		t.Fatalf("expected path to sanitize, got error: %v", err)
	}
	if cleaned != expected {
		t.Fatalf("sanitize mismatch: want %s got %s", expected, cleaned)
	}
}

// TestSanitizePathRejectsTraversal checks that a relative path climbing out
// via ".." components is rejected with an error.
func TestSanitizePathRejectsTraversal(t *testing.T) {
	if _, err := container.SanitizePath("../../etc/passwd"); err == nil {
		t.Fatal("expected traversal path to be rejected")
	}
}

// TestValidateJobName checks that a simple hyphenated alphanumeric name is
// accepted.
func TestValidateJobName(t *testing.T) {
	if err := container.ValidateJobName("job-123"); err != nil {
		t.Fatalf("validate job unexpectedly failed: %v", err)
	}
}

// TestValidateJobNameRejectsBadInput checks that empty names, names with a
// path separator, and names containing ".." are all rejected.
func TestValidateJobNameRejectsBadInput(t *testing.T) {
	cases := []string{"", "bad/name", "job..1"}
	for _, tc := range cases {
		// Errorf (not Fatalf) so every accepted bad name is reported in one run.
		if err := container.ValidateJobName(tc); err == nil {
			t.Errorf("expected job name %q to be rejected", tc)
		}
	}
}

// contains reports whether target appears anywhere in values.
func contains(values []string, target string) bool {
	for _, v := range values {
		if v == target {
			return true
		}
	}
	return false
}

// containsSequence reports whether seq occurs as a contiguous run inside
// values. An empty seq is trivially contained.
func containsSequence(values []string, seq []string) bool {
outer:
	for start := 0; start+len(seq) <= len(values); start++ {
		for i, s := range seq {
			if values[start+i] != s {
				continue outer
			}
		}
		return true
	}
	return false
}