185 lines
4.7 KiB
Go
185 lines
4.7 KiB
Go
package tests
|
|
|
|
import (
|
|
"context"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/config"
|
|
"github.com/jfraeys/fetch_ml/internal/container"
|
|
)
|
|
|
|
func TestBuildRunArgs_CgroupsDisabled(t *testing.T) {
|
|
old := os.Getenv("FETCHML_PODMAN_CGROUPS")
|
|
_ = os.Setenv("FETCHML_PODMAN_CGROUPS", "disabled")
|
|
t.Cleanup(func() {
|
|
_ = os.Setenv("FETCHML_PODMAN_CGROUPS", old)
|
|
})
|
|
|
|
cfg := &container.ContainerConfig{Image: "img", Command: []string{"echo", "hi"}}
|
|
args := container.BuildRunArgs(cfg)
|
|
found := false
|
|
for _, a := range args {
|
|
if a == "--cgroups=disabled" {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatalf("expected --cgroups=disabled in args: %#v", args)
|
|
}
|
|
}
|
|
|
|
func TestParseContainerID_LastLine(t *testing.T) {
|
|
out := "Trying to pull quay.io/jupyter/base-notebook:latest...\nWriting manifest to image destination\nabc123\n"
|
|
id, err := container.ParseContainerID(out)
|
|
if err != nil {
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
}
|
|
if id != "abc123" {
|
|
t.Fatalf("expected abc123, got %q", id)
|
|
}
|
|
}
|
|
|
|
func TestBuildPodmanCommand_DefaultsAndArgs(t *testing.T) {
|
|
cfg := container.PodmanConfig{
|
|
Image: "registry.example/fetch:latest",
|
|
Workspace: "/host/workspace",
|
|
Results: "/host/results",
|
|
ContainerWorkspace: "/workspace",
|
|
ContainerResults: "/results",
|
|
GPUDevices: []string{"/dev/dri"},
|
|
Env: map[string]string{
|
|
"CUDA_VISIBLE_DEVICES": "0,1",
|
|
},
|
|
}
|
|
|
|
cmd := container.BuildPodmanCommand(
|
|
context.Background(),
|
|
cfg,
|
|
"/workspace/train.py",
|
|
"/workspace/requirements.txt",
|
|
[]string{"--foo=bar", "baz"},
|
|
)
|
|
|
|
expected := []string{
|
|
"podman",
|
|
"run", "--rm",
|
|
"--security-opt", "no-new-privileges",
|
|
"--cap-drop", "ALL",
|
|
"--memory", config.DefaultPodmanMemory,
|
|
"--cpus", config.DefaultPodmanCPUs,
|
|
"--userns", "keep-id",
|
|
"-v", "/host/workspace:/workspace:rw",
|
|
"-v", "/host/results:/results:rw",
|
|
"--device", "/dev/dri",
|
|
"-e", "CUDA_VISIBLE_DEVICES=0,1",
|
|
"registry.example/fetch:latest",
|
|
"--workspace", "/workspace",
|
|
"--deps", "/workspace/requirements.txt",
|
|
"--script", "/workspace/train.py",
|
|
"--args",
|
|
"--foo=bar", "baz",
|
|
}
|
|
|
|
if !reflect.DeepEqual(cmd.Args, expected) {
|
|
t.Fatalf("unexpected podman args\nwant: %v\ngot: %v", expected, cmd.Args)
|
|
}
|
|
}
|
|
|
|
func TestBuildPodmanCommand_Overrides(t *testing.T) {
|
|
cfg := container.PodmanConfig{
|
|
Image: "fetch:test",
|
|
Workspace: "/w",
|
|
Results: "/r",
|
|
ContainerWorkspace: "/cw",
|
|
ContainerResults: "/cr",
|
|
Memory: "16g",
|
|
CPUs: "8",
|
|
}
|
|
|
|
cmd := container.BuildPodmanCommand(context.Background(), cfg, "script.py", "reqs.txt", nil)
|
|
|
|
if contains(cmd.Args, "--device") {
|
|
t.Fatalf("expected GPU device flag to be omitted when GPUDevices is empty: %v", cmd.Args)
|
|
}
|
|
|
|
if !containsSequence(cmd.Args, []string{"--memory", "16g"}) {
|
|
t.Fatalf("expected custom memory flag, got %v", cmd.Args)
|
|
}
|
|
|
|
if !containsSequence(cmd.Args, []string{"--cpus", "8"}) {
|
|
t.Fatalf("expected custom cpu flag, got %v", cmd.Args)
|
|
}
|
|
}
|
|
|
|
func TestPodmanResourceOverrides_FromTaskValues(t *testing.T) {
|
|
cpus, mem := container.PodmanResourceOverrides(2, 8)
|
|
if cpus != "2" {
|
|
t.Fatalf("expected cpus override '2', got %q", cpus)
|
|
}
|
|
if mem != "8g" {
|
|
t.Fatalf("expected memory override '8g', got %q", mem)
|
|
}
|
|
|
|
cpus, mem = container.PodmanResourceOverrides(0, 0)
|
|
if cpus != "" || mem != "" {
|
|
t.Fatalf("expected empty overrides for zero values, got cpus=%q mem=%q", cpus, mem)
|
|
}
|
|
}
|
|
|
|
func TestSanitizePath(t *testing.T) {
|
|
input := filepath.Join("/tmp", "..", "tmp", "jobs")
|
|
cleaned, err := container.SanitizePath(input)
|
|
if err != nil {
|
|
t.Fatalf("expected path to sanitize, got error: %v", err)
|
|
}
|
|
|
|
expected := filepath.Clean(input)
|
|
if cleaned != expected {
|
|
t.Fatalf("sanitize mismatch: want %s got %s", expected, cleaned)
|
|
}
|
|
}
|
|
|
|
func TestSanitizePathRejectsTraversal(t *testing.T) {
|
|
if _, err := container.SanitizePath("../../etc/passwd"); err == nil {
|
|
t.Fatal("expected traversal path to be rejected")
|
|
}
|
|
}
|
|
|
|
func TestValidateJobName(t *testing.T) {
|
|
if err := container.ValidateJobName("job-123"); err != nil {
|
|
t.Fatalf("validate job unexpectedly failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateJobNameRejectsBadInput(t *testing.T) {
|
|
cases := []string{"", "bad/name", "job..1"}
|
|
for _, tc := range cases {
|
|
if err := container.ValidateJobName(tc); err == nil {
|
|
t.Fatalf("expected job name %q to be rejected", tc)
|
|
}
|
|
}
|
|
}
|
|
|
|
// contains reports whether target is equal to at least one element of
// values. A nil or empty values slice never contains anything.
func contains(values []string, target string) bool {
	for i := range values {
		if values[i] == target {
			return true
		}
	}
	return false
}
|
|
|
|
// containsSequence reports whether seq appears in values as a contiguous,
// in-order run of elements (e.g. a flag immediately followed by its value).
//
// An empty seq (nil or zero-length) is always contained. The previous
// reflect.DeepEqual comparison treated nil and empty slices differently, so
// the result for an empty seq depended on whether either slice was nil.
// Comparing element by element avoids that and skips reflection per window.
func containsSequence(values []string, seq []string) bool {
window:
	for start := 0; start+len(seq) <= len(values); start++ {
		for offset, want := range seq {
			if values[start+offset] != want {
				continue window
			}
		}
		return true
	}
	return false
}
|