// Package container provides Podman container management utilities. package container import ( "context" "fmt" "os" "os/exec" "path/filepath" "strconv" "strings" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/logging" ) // PodmanManager manages Podman containers type PodmanManager struct { logger *logging.Logger } // NewPodmanManager creates a new Podman manager func NewPodmanManager(logger *logging.Logger) (*PodmanManager, error) { return &PodmanManager{ logger: logger, }, nil } // ContainerConfig holds configuration for starting a container type ContainerConfig struct { Name string `json:"name"` Image string `json:"image"` Command []string `json:"command"` Env map[string]string `json:"env"` Volumes map[string]string `json:"volumes"` Ports map[int]int `json:"ports"` SecurityOpts []string `json:"security_opts"` Resources ResourceConfig `json:"resources"` Network NetworkConfig `json:"network"` } // ResourceConfig defines resource limits for containers type ResourceConfig struct { MemoryLimit string `json:"memory_limit"` CPULimit string `json:"cpu_limit"` GPUDevices []string `json:"gpu_devices"` AppleGPU bool `json:"apple_gpu"` } // NetworkConfig defines network settings for containers type NetworkConfig struct { AllowNetwork bool `json:"allow_network"` } func podmanCgroupsMode() string { return strings.TrimSpace(os.Getenv("FETCHML_PODMAN_CGROUPS")) } func BuildRunArgs(config *ContainerConfig) []string { args := []string{"run", "-d"} if podmanCgroupsMode() == "disabled" { args = append(args, "--cgroups=disabled") } // Add name if config.Name != "" { args = append(args, "--name", config.Name) } // Add security options for _, opt := range config.SecurityOpts { args = append(args, "--security-opt", opt) } // Add resource limits if config.Resources.MemoryLimit != "" { args = append(args, "--memory", config.Resources.MemoryLimit) } if config.Resources.CPULimit != "" { args = append(args, "--cpus", config.Resources.CPULimit) } if config.Resources.AppleGPU { args = append(args, "--device", "/dev/metal") args = append(args, "--device", "/dev/mps") } for _, device := range config.Resources.GPUDevices { args = append(args, "--device", device) } // Add volumes for hostPath, containerPath := range config.Volumes { mount := fmt.Sprintf("%s:%s", hostPath, containerPath) args = append(args, "-v", mount) } // Add ports for hostPort, containerPort := range config.Ports { portMapping := fmt.Sprintf("%d:%d", hostPort, containerPort) args = append(args, "-p", portMapping) } // Add environment variables for key, value := range config.Env { args = append(args, "-e", fmt.Sprintf("%s=%s", key, value)) } // Add image and command args = append(args, config.Image) args = append(args, config.Command...) return args } func ParseContainerID(output string) (string, error) { out := strings.TrimSpace(output) if out == "" { return "", fmt.Errorf("no container ID returned") } lines := strings.Split(out, "\n") for i := len(lines) - 1; i >= 0; i-- { line := strings.TrimSpace(lines[i]) if line == "" { continue } return line, nil } return "", fmt.Errorf("no container ID returned") } // StartContainer starts a new container func (pm *PodmanManager) StartContainer( ctx context.Context, config *ContainerConfig, ) (string, error) { args := BuildRunArgs(config) // Execute command cmd := exec.CommandContext(ctx, "podman", args...) output, err := cmd.CombinedOutput() if err != nil { return "", fmt.Errorf("failed to start container: %w, output: %s", err, string(output)) } containerID, err := ParseContainerID(string(output)) if err != nil { return "", err } pm.logger.Info("container started", "container_id", containerID, "name", config.Name) return containerID, nil } // StopContainer stops a container func (pm *PodmanManager) StopContainer(ctx context.Context, containerID string) error { cmd := exec.CommandContext(ctx, "podman", "stop", containerID) output, err := cmd.CombinedOutput() if err != nil { return fmt.Errorf("failed to stop container: %w, output: %s", err, string(output)) } pm.logger.Info("container stopped", "container_id", containerID) return nil } // GetContainerStateStatus returns the container's lifecycle state from `podman inspect`. // Typical values: running, exited, created, paused. func (pm *PodmanManager) GetContainerStateStatus( ctx context.Context, containerID string, ) (string, error) { // Validate containerID to prevent injection if containerID == "" || strings.ContainsAny(containerID, "&;|<>$`\"'") { return "", fmt.Errorf("invalid container ID: %s", containerID) } cmd := exec.CommandContext(ctx, "podman", "inspect", "--format", "{{.State.Status}}", containerID) //nolint:gosec output, err := cmd.CombinedOutput() if err != nil { return "", fmt.Errorf("failed to inspect container: %w, output: %s", err, string(output)) } status := strings.TrimSpace(string(output)) if status == "" { return "unknown", nil } return status, nil } // RemoveContainer removes a container func (pm *PodmanManager) RemoveContainer(ctx context.Context, containerID string) error { cmd := exec.CommandContext(ctx, "podman", "rm", containerID) output, err := cmd.CombinedOutput() if err != nil { return fmt.Errorf("failed to remove container: %w, output: %s", err, string(output)) } pm.logger.Info("container removed", "container_id", containerID) return nil } // GetContainerStatus gets the status of a container func (pm *PodmanManager) GetContainerStatus( ctx context.Context, containerID string, ) (string, error) { // Validate containerID to prevent injection if containerID == "" || strings.ContainsAny(containerID, "&;|<>$`\"'") { return "", fmt.Errorf("invalid container ID: %s", containerID) } cmd := exec.CommandContext(ctx, "podman", "ps", "--filter", "id="+containerID, "--format", "{{.Status}}") //nolint:gosec output, err := cmd.CombinedOutput() if err != nil { return "", fmt.Errorf("failed to get container status: %w, output: %s", err, string(output)) } status := strings.TrimSpace(string(output)) if status == "" { // Container might be stopped, check all containers cmd = exec.CommandContext( ctx, "podman", "ps", "-a", "--filter", "id="+containerID, "--format", "{{.Status}}", ) //nolint:gosec output, err = cmd.CombinedOutput() if err != nil { return "", fmt.Errorf("failed to get container status: %w, output: %s", err, string(output)) } status = strings.TrimSpace(string(output)) if status == "" { return "unknown", nil } } return status, nil } // ExecContainer executes a command inside a running container and returns the output func (pm *PodmanManager) ExecContainer(ctx context.Context, containerID string, command []string) (string, error) { // Validate containerID to prevent injection if containerID == "" || strings.ContainsAny(containerID, "&;|<>$`\"'") { return "", fmt.Errorf("invalid container ID: %s", containerID) } // Validate command to prevent injection for _, arg := range command { if strings.ContainsAny(arg, "&;|<>$`\"'") { return "", fmt.Errorf("invalid command argument: %s", arg) } } // Build podman exec command args := []string{"exec", containerID} args = append(args, command...) cmd := exec.CommandContext(ctx, "podman", args...) //nolint:gosec output, err := cmd.CombinedOutput() if err != nil { return "", fmt.Errorf("failed to execute command in container: %w, output: %s", err, string(output)) } return string(output), nil } // PodmanConfig holds configuration for Podman container execution type PodmanConfig struct { Image string Workspace string Results string ContainerWorkspace string ContainerResults string AppleGPU bool GPUDevices []string Env map[string]string Volumes map[string]string Memory string CPUs string } // PodmanResourceOverrides converts per-task resource requests into Podman-compatible // `--cpus` and `--memory` flag values. // // cpu and memoryGB are treated as optional; values <= 0 return empty overrides. func PodmanResourceOverrides(cpu int, memoryGB int) (cpus string, memory string) { if cpu > 0 { cpus = strconv.Itoa(cpu) } if memoryGB > 0 { memory = fmt.Sprintf("%dg", memoryGB) } return cpus, memory } // BuildPodmanCommand builds a Podman command for executing ML experiments func BuildPodmanCommand( ctx context.Context, cfg PodmanConfig, scriptPath, depsPath string, extraArgs []string, ) *exec.Cmd { args := []string{ "run", "--rm", "--security-opt", "no-new-privileges", "--cap-drop", "ALL", } if cfg.Memory != "" { args = append(args, "--memory", cfg.Memory) } else { args = append(args, "--memory", config.DefaultPodmanMemory) } if cfg.CPUs != "" { args = append(args, "--cpus", cfg.CPUs) } else { args = append(args, "--cpus", config.DefaultPodmanCPUs) } args = append(args, "--userns", "keep-id") // Mount workspace workspaceMount := fmt.Sprintf("%s:%s:rw", cfg.Workspace, cfg.ContainerWorkspace) args = append(args, "-v", workspaceMount) // Mount results resultsMount := fmt.Sprintf("%s:%s:rw", cfg.Results, cfg.ContainerResults) args = append(args, "-v", resultsMount) // Mount additional volumes for hostPath, containerPath := range cfg.Volumes { mount := fmt.Sprintf("%s:%s", hostPath, containerPath) args = append(args, "-v", mount) } // Use injected GPU device paths for Apple GPU or custom configurations for _, device := range cfg.GPUDevices { args = append(args, "--device", device) } // Add environment variables for key, value := range cfg.Env { args = append(args, "-e", fmt.Sprintf("%s=%s", key, value)) } // Image and command args = append(args, cfg.Image, "--workspace", cfg.ContainerWorkspace, "--deps", depsPath, "--script", scriptPath, ) // Add extra arguments via --args flag if len(extraArgs) > 0 { args = append(args, "--args") args = append(args, extraArgs...) } return exec.CommandContext(ctx, "podman", args...) } // SanitizePath ensures a path is safe to use (prevents path traversal) func SanitizePath(path string) (string, error) { // Clean the path to remove any .. or . components cleaned := filepath.Clean(path) // Check for path traversal attempts if strings.Contains(cleaned, "..") { return "", fmt.Errorf("path traversal detected: %s", path) } return cleaned, nil } // ValidateJobName validates a job name is safe func ValidateJobName(jobName string) error { if jobName == "" { return fmt.Errorf("job name cannot be empty") } // Check for dangerous characters if strings.ContainsAny(jobName, "/\\<>:\"|?*") { return fmt.Errorf("job name contains invalid characters: %s", jobName) } // Check for path traversal if strings.Contains(jobName, "..") { return fmt.Errorf("job name contains path traversal: %s", jobName) } return nil }