// Source: fetch_ml/tests/unit/container/podman_test.go

package tests
import (
"context"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
)
// TestBuildRunArgs_CgroupsDisabled checks that setting FETCHML_PODMAN_CGROUPS
// to "disabled" makes BuildRunArgs include the --cgroups=disabled flag.
func TestBuildRunArgs_CgroupsDisabled(t *testing.T) {
	const envKey = "FETCHML_PODMAN_CGROUPS"

	// Record whether the variable existed at all so cleanup can return it to
	// "unset" instead of leaking an empty value into later tests.
	old, had := os.LookupEnv(envKey)
	if err := os.Setenv(envKey, "disabled"); err != nil {
		t.Fatalf("setenv %s: %v", envKey, err)
	}
	t.Cleanup(func() {
		// Restoration is best-effort: a failure here cannot fail the test.
		if had {
			_ = os.Setenv(envKey, old)
			return
		}
		_ = os.Unsetenv(envKey)
	})

	cfg := &container.ContainerConfig{Image: "img", Command: []string{"echo", "hi"}}
	args := container.BuildRunArgs(cfg)
	if !contains(args, "--cgroups=disabled") {
		t.Fatalf("expected --cgroups=disabled in args: %#v", args)
	}
}
// TestParseContainerID_LastLine expects ParseContainerID to skip the pull
// progress lines podman prints first and return the ID on the final line.
func TestParseContainerID_LastLine(t *testing.T) {
	const wantID = "abc123"
	output := "Trying to pull quay.io/jupyter/base-notebook:latest...\n" +
		"Writing manifest to image destination\n" +
		"abc123\n"

	gotID, parseErr := container.ParseContainerID(output)
	switch {
	case parseErr != nil:
		t.Fatalf("expected nil error, got %v", parseErr)
	case gotID != wantID:
		t.Fatalf("expected abc123, got %q", gotID)
	}
}
// TestBuildPodmanCommand_DefaultsAndArgs pins the complete argv built when
// Memory and CPUs are left empty (so config.DefaultPodman* values are
// expected) while GPU devices, an env var and script arguments are supplied.
func TestBuildPodmanCommand_DefaultsAndArgs(t *testing.T) {
	podmanCfg := container.PodmanConfig{
		Image:              "registry.example/fetch:latest",
		Workspace:          "/host/workspace",
		Results:            "/host/results",
		ContainerWorkspace: "/workspace",
		ContainerResults:   "/results",
		GPUDevices:         []string{"/dev/dri"},
		Env:                map[string]string{"CUDA_VISIBLE_DEVICES": "0,1"},
	}

	got := container.BuildPodmanCommand(
		context.Background(),
		podmanCfg,
		"/workspace/train.py",
		"/workspace/requirements.txt",
		[]string{"--foo=bar", "baz"},
	).Args

	// Expected argv, assembled group by group in the order the test requires.
	var want []string
	want = append(want, "podman", "run", "--rm")
	want = append(want, "--security-opt", "no-new-privileges", "--cap-drop", "ALL")
	want = append(want, "--memory", config.DefaultPodmanMemory, "--cpus", config.DefaultPodmanCPUs)
	want = append(want, "--userns", "keep-id")
	want = append(want, "-v", "/host/workspace:/workspace:rw", "-v", "/host/results:/results:rw")
	want = append(want, "--device", "/dev/dri")
	want = append(want, "-e", "CUDA_VISIBLE_DEVICES=0,1")
	want = append(want, "registry.example/fetch:latest")
	want = append(want,
		"--workspace", "/workspace",
		"--deps", "/workspace/requirements.txt",
		"--script", "/workspace/train.py",
	)
	want = append(want, "--args", "--foo=bar", "baz")

	if !reflect.DeepEqual(got, want) {
		t.Fatalf("unexpected podman args\nwant: %v\ngot: %v", want, got)
	}
}
// TestBuildPodmanCommand_Overrides checks that explicit Memory and CPUs values
// appear as flag values and that no --device flag is emitted when GPUDevices
// is left empty.
func TestBuildPodmanCommand_Overrides(t *testing.T) {
	podmanCfg := container.PodmanConfig{
		Image:              "fetch:test",
		Workspace:          "/w",
		Results:            "/r",
		ContainerWorkspace: "/cw",
		ContainerResults:   "/cr",
		Memory:             "16g",
		CPUs:               "8",
	}
	argv := container.BuildPodmanCommand(context.Background(), podmanCfg, "script.py", "reqs.txt", nil).Args

	// Tie the expectations to the config values so they cannot drift apart.
	wantMemory := []string{"--memory", podmanCfg.Memory}
	wantCPUs := []string{"--cpus", podmanCfg.CPUs}

	if contains(argv, "--device") {
		t.Fatalf("expected GPU device flag to be omitted when GPUDevices is empty: %v", argv)
	}
	if !containsSequence(argv, wantMemory) {
		t.Fatalf("expected custom memory flag, got %v", argv)
	}
	if !containsSequence(argv, wantCPUs) {
		t.Fatalf("expected custom cpu flag, got %v", argv)
	}
}
// TestPodmanResourceOverrides_FromTaskValues checks that positive task values
// map to podman flag values ("2" CPUs, "8g" memory) and that zero inputs are
// expected to produce no overrides at all.
func TestPodmanResourceOverrides_FromTaskValues(t *testing.T) {
	gotCPUs, gotMem := container.PodmanResourceOverrides(2, 8)
	if gotCPUs != "2" {
		t.Fatalf("expected cpus override '2', got %q", gotCPUs)
	}
	if gotMem != "8g" {
		t.Fatalf("expected memory override '8g', got %q", gotMem)
	}

	zeroCPUs, zeroMem := container.PodmanResourceOverrides(0, 0)
	if !(zeroCPUs == "" && zeroMem == "") {
		t.Fatalf("expected empty overrides for zero values, got cpus=%q mem=%q", zeroCPUs, zeroMem)
	}
}
// TestSanitizePath checks that SanitizePath accepts an absolute path and
// returns it in filepath.Clean form.
//
// NOTE(review): filepath.Join already Cleans its result, so the ".." segment
// below is collapsed before SanitizePath ever sees it (the input is simply
// "/tmp/jobs" on Unix). This test therefore does not exercise ".." handling;
// consider a literal like "/tmp/../tmp/jobs" once SanitizePath's intended
// behavior for embedded ".." is confirmed.
func TestSanitizePath(t *testing.T) {
	raw := filepath.Join("/tmp", "..", "tmp", "jobs")
	want := filepath.Clean(raw)

	got, sanitizeErr := container.SanitizePath(raw)
	if sanitizeErr != nil {
		t.Fatalf("expected path to sanitize, got error: %v", sanitizeErr)
	}
	if got != want {
		t.Fatalf("sanitize mismatch: want %s got %s", want, got)
	}
}
// TestSanitizePathRejectsTraversal checks that a relative path climbing out
// with leading ".." segments is rejected with an error.
func TestSanitizePathRejectsTraversal(t *testing.T) {
	const traversal = "../../etc/passwd"

	_, err := container.SanitizePath(traversal)
	if err == nil {
		t.Fatal("expected traversal path to be rejected")
	}
}
// TestValidateJobName checks that a plain name made of letters, digits and a
// hyphen passes validation.
func TestValidateJobName(t *testing.T) {
	const name = "job-123"

	err := container.ValidateJobName(name)
	if err != nil {
		t.Fatalf("validate job unexpectedly failed: %v", err)
	}
}
// TestValidateJobNameRejectsBadInput checks that ValidateJobName rejects an
// empty name, a name containing "/", and a name containing "..".
func TestValidateJobNameRejectsBadInput(t *testing.T) {
	cases := []string{"", "bad/name", "job..1"}
	for _, name := range cases {
		// Errorf rather than Fatalf: one wrongly-accepted name must not stop
		// the remaining cases from being checked and reported.
		if err := container.ValidateJobName(name); err == nil {
			t.Errorf("expected job name %q to be rejected", name)
		}
	}
}
// contains reports whether target equals any element of values. A nil or
// empty values slice never contains anything.
func contains(values []string, target string) bool {
	found := false
	for i := 0; i < len(values) && !found; i++ {
		found = values[i] == target
	}
	return found
}
// containsSequence reports whether seq occurs as a contiguous run of elements
// inside values (e.g. a flag immediately followed by its value).
//
// An empty or nil seq is trivially present. The previous reflect.DeepEqual
// comparison treated nil and empty seq differently, so that case is now
// handled explicitly, and elements are compared directly instead of
// reflectively for every window.
func containsSequence(values []string, seq []string) bool {
	if len(seq) == 0 {
		return true
	}
window:
	for start := 0; start+len(seq) <= len(values); start++ {
		for offset, want := range seq {
			if values[start+offset] != want {
				continue window
			}
		}
		return true
	}
	return false
}