package tests

import (
	"context"
	"os"
	"path/filepath"
	"reflect"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/container"
)

// TestBuildRunArgs_CgroupsDisabled checks that setting FETCHML_PODMAN_CGROUPS
// to "disabled" makes BuildRunArgs include the --cgroups=disabled flag.
func TestBuildRunArgs_CgroupsDisabled(t *testing.T) {
	const envKey = "FETCHML_PODMAN_CGROUPS"

	// Restore the variable's exact prior state afterwards. If it was unset
	// before the test, unset it again instead of leaving an empty-but-present
	// value behind for later tests to observe.
	old, hadOld := os.LookupEnv(envKey)
	t.Cleanup(func() {
		// Best-effort restore: there is nothing useful to do on failure
		// during cleanup.
		if hadOld {
			_ = os.Setenv(envKey, old)
			return
		}
		_ = os.Unsetenv(envKey)
	})
	if err := os.Setenv(envKey, "disabled"); err != nil {
		t.Fatalf("setting %s: %v", envKey, err)
	}

	cfg := &container.ContainerConfig{Image: "img", Command: []string{"echo", "hi"}}
	args := container.BuildRunArgs(cfg)
	if !contains(args, "--cgroups=disabled") {
		t.Fatalf("expected --cgroups=disabled in args: %#v", args)
	}
}

// TestParseContainerID_LastLine checks that ParseContainerID returns the ID
// found on the last line of podman output, ignoring the image-pull progress
// lines before it and the trailing newline.
func TestParseContainerID_LastLine(t *testing.T) {
	out := "Trying to pull quay.io/jupyter/base-notebook:latest...\nWriting manifest to image destination\nabc123\n"
	id, err := container.ParseContainerID(out)
	if err != nil {
		t.Fatalf("expected nil error, got %v", err)
	}
	if id != "abc123" {
		t.Fatalf("expected abc123, got %q", id)
	}
}

// TestBuildPodmanCommand_DefaultsAndArgs pins the complete argument vector
// produced by BuildPodmanCommandLegacy when Memory and CPUs are left unset,
// so the config.DefaultPodmanMemory / config.DefaultPodmanCPUs values are
// expected in their place.
func TestBuildPodmanCommand_DefaultsAndArgs(t *testing.T) {
	cfg := container.PodmanConfig{
		Image:              "registry.example/fetch:latest",
		Workspace:          "/host/workspace",
		Results:            "/host/results",
		ContainerWorkspace: "/workspace",
		ContainerResults:   "/results",
		GPUDevices:         []string{"/dev/dri"},
		Env: map[string]string{
			"CUDA_VISIBLE_DEVICES": "0,1",
		},
	}

	cmd := container.BuildPodmanCommandLegacy(
		context.Background(),
		cfg,
		"/workspace/train.py",
		"/workspace/requirements.txt",
		[]string{"--foo=bar", "baz"},
	)

	expected := []string{
		"podman",
		"run",
		"--rm",
		"--security-opt", "no-new-privileges",
		"--cap-drop", "ALL",
		"--memory", config.DefaultPodmanMemory,
		"--cpus", config.DefaultPodmanCPUs,
		"--userns", "keep-id",
		"-v", "/host/workspace:/workspace:rw",
		"-v", "/host/results:/results:rw",
		"--device", "/dev/dri",
		"-e", "CUDA_VISIBLE_DEVICES=0,1",
		"registry.example/fetch:latest",
		"--workspace", "/workspace",
		"--deps", "/workspace/requirements.txt",
		"--script", "/workspace/train.py",
		"--args",
		"--foo=bar", "baz",
	}

	if !reflect.DeepEqual(cmd.Args, expected) {
		t.Fatalf("unexpected podman args\nwant: %v\ngot:  %v", expected, cmd.Args)
	}
}

// TestBuildPodmanCommand_Overrides checks that explicit Memory/CPUs values
// are passed through and that no --device flag is emitted when GPUDevices
// is empty.
func TestBuildPodmanCommand_Overrides(t *testing.T) {
	cfg := container.PodmanConfig{
		Image:              "fetch:test",
		Workspace:          "/w",
		Results:            "/r",
		ContainerWorkspace: "/cw",
		ContainerResults:   "/cr",
		Memory:             "16g",
		CPUs:               "8",
	}

	cmd := container.BuildPodmanCommandLegacy(context.Background(), cfg, "script.py", "reqs.txt", nil)

	if contains(cmd.Args, "--device") {
		t.Fatalf("expected GPU device flag to be omitted when GPUDevices is empty: %v", cmd.Args)
	}
	if !containsSequence(cmd.Args, []string{"--memory", "16g"}) {
		t.Fatalf("expected custom memory flag, got %v", cmd.Args)
	}
	if !containsSequence(cmd.Args, []string{"--cpus", "8"}) {
		t.Fatalf("expected custom cpu flag, got %v", cmd.Args)
	}
}

// TestPodmanResourceOverrides_FromTaskValues checks the string overrides
// derived from task CPU and memory values, including the zero-value case
// that must yield empty overrides.
func TestPodmanResourceOverrides_FromTaskValues(t *testing.T) {
	cpus, mem := container.PodmanResourceOverrides(2, 8)
	if cpus != "2" {
		t.Fatalf("expected cpus override '2', got %q", cpus)
	}
	if mem != "8g" {
		t.Fatalf("expected memory override '8g', got %q", mem)
	}

	cpus, mem = container.PodmanResourceOverrides(0, 0)
	if cpus != "" || mem != "" {
		t.Fatalf("expected empty overrides for zero values, got cpus=%q mem=%q", cpus, mem)
	}
}

// TestSanitizePath checks that SanitizePath accepts an absolute path and
// returns its cleaned form.
func TestSanitizePath(t *testing.T) {
	// NOTE(review): filepath.Join already Cleans its result, so input is
	// "/tmp/jobs" here and the ".." segment never reaches SanitizePath.
	// Exercising SanitizePath's own cleaning would need a raw, uncleaned
	// string — confirm SanitizePath's rules for interior ".." before changing.
	input := filepath.Join("/tmp", "..", "tmp", "jobs")
	cleaned, err := container.SanitizePath(input)
	if err != nil {
		t.Fatalf("expected path to sanitize, got error: %v", err)
	}
	expected := filepath.Clean(input)
	if cleaned != expected {
		t.Fatalf("sanitize mismatch: want %s got %s", expected, cleaned)
	}
}

// TestSanitizePathRejectsTraversal checks that a relative path climbing out
// of its base with ".." is rejected.
func TestSanitizePathRejectsTraversal(t *testing.T) {
	if _, err := container.SanitizePath("../../etc/passwd"); err == nil {
		t.Fatal("expected traversal path to be rejected")
	}
}

// TestValidateJobName checks that a simple alphanumeric-with-dash job name
// is accepted.
func TestValidateJobName(t *testing.T) {
	if err := container.ValidateJobName("job-123"); err != nil {
		t.Fatalf("validate job unexpectedly failed: %v", err)
	}
}

// TestValidateJobNameRejectsBadInput checks that empty names, names with a
// path separator, and names containing ".." are all rejected. Every case is
// evaluated so that all wrongly-accepted names are reported in one run.
func TestValidateJobNameRejectsBadInput(t *testing.T) {
	cases := []string{"", "bad/name", "job..1"}
	for _, tc := range cases {
		if err := container.ValidateJobName(tc); err == nil {
			t.Errorf("expected job name %q to be rejected", tc)
		}
	}
}

// contains reports whether target appears anywhere in values.
func contains(values []string, target string) bool {
	for _, v := range values {
		if v == target {
			return true
		}
	}
	return false
}

// containsSequence reports whether seq occurs in values as a contiguous run
// of elements, in order. An empty seq always matches.
func containsSequence(values []string, seq []string) bool {
	if len(seq) == 0 {
		return true
	}
outer:
	for i := 0; i+len(seq) <= len(values); i++ {
		for j := range seq {
			if values[i+j] != seq[j] {
				continue outer
			}
		}
		return true
	}
	return false
}