// fetch_ml/internal/container/podman.go
// (source listing metadata: 396 lines, 11 KiB, Go)

// Package container provides Podman container management utilities.
package container
import (
	"context"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"sort"
	"strconv"
	"strings"

	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/logging"
)
// PodmanManager drives the podman CLI to start, stop, inspect, exec into and
// remove containers, reporting lifecycle events through its logger.
type PodmanManager struct {
	logger *logging.Logger // receives start/stop/remove events
}

// NewPodmanManager returns a PodmanManager that logs through logger.
// The returned error is currently always nil.
func NewPodmanManager(logger *logging.Logger) (*PodmanManager, error) {
	pm := &PodmanManager{logger: logger}
	return pm, nil
}
// ContainerConfig holds configuration for starting a container with
// BuildRunArgs.
type ContainerConfig struct {
	Name         string            `json:"name"`
	Image        string            `json:"image"`
	Command      []string          `json:"command"` // appended after Image
	Env          map[string]string `json:"env"`
	Volumes      map[string]string `json:"volumes"` // host path -> container path
	Ports        map[int]int       `json:"ports"`   // host port -> container port
	SecurityOpts []string          `json:"security_opts"`
	Resources    ResourceConfig    `json:"resources"`
	// NOTE(review): BuildRunArgs does not read Network, so AllowNetwork has
	// no effect on the generated command. Confirm whether false was meant to
	// add `--network none`.
	Network NetworkConfig `json:"network"`
}

// ResourceConfig defines resource limits and devices for containers.
type ResourceConfig struct {
	MemoryLimit string   `json:"memory_limit"` // passed verbatim to --memory
	CPULimit    string   `json:"cpu_limit"`    // passed verbatim to --cpus
	GPUDevices  []string `json:"gpu_devices"`  // each passed as --device
	AppleGPU    bool     `json:"apple_gpu"`    // adds --device /dev/metal and /dev/mps
}

// NetworkConfig defines network settings for containers.
type NetworkConfig struct {
	AllowNetwork bool `json:"allow_network"`
}

// podmanCgroupsMode returns the trimmed value of FETCHML_PODMAN_CGROUPS.
// BuildRunArgs adds --cgroups=disabled when it equals "disabled".
func podmanCgroupsMode() string {
	return strings.TrimSpace(os.Getenv("FETCHML_PODMAN_CGROUPS"))
}

// BuildRunArgs converts cfg into the argument list for `podman run -d`
// (without the leading "podman"). cfg must be non-nil.
//
// Flags are emitted in this order: cgroups override, name, security options,
// memory/CPU limits, devices, volumes, ports, environment. The image comes
// next, followed by cfg.Command.
//
// Volumes, ports and environment variables are stored in maps, whose iteration
// order Go randomizes. They are emitted in sorted key order so the same config
// always yields the same command line.
func BuildRunArgs(cfg *ContainerConfig) []string {
	args := []string{"run", "-d"}
	if podmanCgroupsMode() == "disabled" {
		args = append(args, "--cgroups=disabled")
	}

	if cfg.Name != "" {
		args = append(args, "--name", cfg.Name)
	}
	for _, opt := range cfg.SecurityOpts {
		args = append(args, "--security-opt", opt)
	}

	res := cfg.Resources
	if res.MemoryLimit != "" {
		args = append(args, "--memory", res.MemoryLimit)
	}
	if res.CPULimit != "" {
		args = append(args, "--cpus", res.CPULimit)
	}
	if res.AppleGPU {
		args = append(args, "--device", "/dev/metal", "--device", "/dev/mps")
	}
	for _, device := range res.GPUDevices {
		args = append(args, "--device", device)
	}

	// Volumes: host path -> container path, sorted by host path.
	hostPaths := make([]string, 0, len(cfg.Volumes))
	for hostPath := range cfg.Volumes {
		hostPaths = append(hostPaths, hostPath)
	}
	sort.Strings(hostPaths)
	for _, hostPath := range hostPaths {
		args = append(args, "-v", hostPath+":"+cfg.Volumes[hostPath])
	}

	// Ports: host port -> container port, sorted by host port.
	hostPorts := make([]int, 0, len(cfg.Ports))
	for hostPort := range cfg.Ports {
		hostPorts = append(hostPorts, hostPort)
	}
	sort.Ints(hostPorts)
	for _, hostPort := range hostPorts {
		args = append(args, "-p", fmt.Sprintf("%d:%d", hostPort, cfg.Ports[hostPort]))
	}

	// Environment variables, sorted by name.
	envKeys := make([]string, 0, len(cfg.Env))
	for key := range cfg.Env {
		envKeys = append(envKeys, key)
	}
	sort.Strings(envKeys)
	for _, key := range envKeys {
		args = append(args, "-e", key+"="+cfg.Env[key])
	}

	args = append(args, cfg.Image)
	return append(args, cfg.Command...)
}
// ParseContainerID extracts the container ID from `podman run -d` output.
// The ID is the last non-blank line of output, with surrounding whitespace
// removed. An error is returned when output is entirely blank.
func ParseContainerID(output string) (string, error) {
	trimmed := strings.TrimSpace(output)
	if len(trimmed) == 0 {
		return "", fmt.Errorf("no container ID returned")
	}
	// trimmed ends in a non-space character, so the text after its final
	// newline always holds that character and is never blank.
	lastLine := trimmed
	if idx := strings.LastIndex(trimmed, "\n"); idx >= 0 {
		lastLine = trimmed[idx+1:]
	}
	return strings.TrimSpace(lastLine), nil
}
// StartContainer starts a detached container described by cfg using
// `podman run -d` (arguments from BuildRunArgs). It returns the container ID
// that podman prints on stdout.
//
// Only stdout is parsed for the ID. podman reports warnings on stderr, and
// merging the two streams (as CombinedOutput does) lets a warning that lands
// after the ID be returned in its place. Both streams are still included in
// the error message on failure.
func (pm *PodmanManager) StartContainer(
	ctx context.Context,
	cfg *ContainerConfig,
) (string, error) {
	if cfg == nil {
		return "", fmt.Errorf("container config is nil")
	}

	cmd := exec.CommandContext(ctx, "podman", BuildRunArgs(cfg)...)
	var stderr strings.Builder
	cmd.Stderr = &stderr
	stdout, err := cmd.Output()
	if err != nil {
		return "", fmt.Errorf("failed to start container: %w, output: %s", err, stderr.String()+string(stdout))
	}

	containerID, err := ParseContainerID(string(stdout))
	if err != nil {
		// Surface stderr so an empty-stdout failure is diagnosable.
		return "", fmt.Errorf("%w, output: %s", err, stderr.String())
	}

	pm.logger.Info("container started", "container_id", containerID, "name", cfg.Name)
	return containerID, nil
}
// StopContainer stops the container identified by containerID with
// `podman stop` and logs the event.
//
// containerID is rejected if it:
//   - is empty;
//   - starts with '-', because podman would parse it as a flag (for example
//     "--all" would stop every container);
//   - contains shell metacharacters.
func (pm *PodmanManager) StopContainer(ctx context.Context, containerID string) error {
	if containerID == "" || strings.HasPrefix(containerID, "-") ||
		strings.ContainsAny(containerID, "&;|<>$`\"'") {
		return fmt.Errorf("invalid container ID: %s", containerID)
	}
	cmd := exec.CommandContext(ctx, "podman", "stop", containerID) //nolint:gosec // containerID validated above
	output, err := cmd.CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to stop container: %w, output: %s", err, string(output))
	}
	pm.logger.Info("container stopped", "container_id", containerID)
	return nil
}
// GetContainerStateStatus returns the container's lifecycle state as reported
// by `podman inspect --format {{.State.Status}}`. Typical values are running,
// exited, created and paused. Blank inspect output is reported as "unknown".
//
// containerID must be non-empty and must not contain shell metacharacters.
// It also must not start with '-', because podman would parse it as a flag
// such as --latest and inspect a different container.
func (pm *PodmanManager) GetContainerStateStatus(
	ctx context.Context,
	containerID string,
) (string, error) {
	if containerID == "" || strings.HasPrefix(containerID, "-") ||
		strings.ContainsAny(containerID, "&;|<>$`\"'") {
		return "", fmt.Errorf("invalid container ID: %s", containerID)
	}
	cmd := exec.CommandContext(ctx, "podman", "inspect", "--format", "{{.State.Status}}", containerID) //nolint:gosec // containerID validated above
	output, err := cmd.CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("failed to inspect container: %w, output: %s", err, string(output))
	}
	status := strings.TrimSpace(string(output))
	if status == "" {
		return "unknown", nil
	}
	return status, nil
}
// RemoveContainer removes the container identified by containerID with
// `podman rm` and logs the event.
//
// containerID is rejected if it:
//   - is empty;
//   - starts with '-', because podman would parse it as a flag (for example
//     "--all" would remove every container);
//   - contains shell metacharacters.
func (pm *PodmanManager) RemoveContainer(ctx context.Context, containerID string) error {
	if containerID == "" || strings.HasPrefix(containerID, "-") ||
		strings.ContainsAny(containerID, "&;|<>$`\"'") {
		return fmt.Errorf("invalid container ID: %s", containerID)
	}
	cmd := exec.CommandContext(ctx, "podman", "rm", containerID) //nolint:gosec // containerID validated above
	output, err := cmd.CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to remove container: %w, output: %s", err, string(output))
	}
	pm.logger.Info("container removed", "container_id", containerID)
	return nil
}
// GetContainerStatus returns the Status column printed by `podman ps` for
// containers matching the filter id=<containerID>.
//
// Running containers are queried first. If nothing matches, the query is
// repeated with -a to include stopped containers. If that is also empty, the
// result is "unknown".
func (pm *PodmanManager) GetContainerStatus(
	ctx context.Context,
	containerID string,
) (string, error) {
	// Reject empty IDs and shell metacharacters before handing the ID to podman.
	if containerID == "" || strings.ContainsAny(containerID, "&;|<>$`\"'") {
		return "", fmt.Errorf("invalid container ID: %s", containerID)
	}

	// psStatus runs `podman ps [extra...] --filter id=<containerID> --format {{.Status}}`
	// and returns its trimmed output.
	psStatus := func(extra ...string) (string, error) {
		args := append([]string{"ps"}, extra...)
		args = append(args, "--filter", "id="+containerID, "--format", "{{.Status}}")
		out, err := exec.CommandContext(ctx, "podman", args...).CombinedOutput() //nolint:gosec
		if err != nil {
			return "", fmt.Errorf("failed to get container status: %w, output: %s", err, string(out))
		}
		return strings.TrimSpace(string(out)), nil
	}

	status, err := psStatus()
	if err != nil || status != "" {
		return status, err
	}

	// Not running: it may be stopped, which only `ps -a` reports.
	status, err = psStatus("-a")
	switch {
	case err != nil:
		return "", err
	case status == "":
		return "unknown", nil
	default:
		return status, nil
	}
}
// ExecContainer runs command inside the running container containerID via
// `podman exec` and returns its combined stdout/stderr.
//
// containerID must be non-empty and must not contain shell metacharacters.
// It also must not start with '-', because podman would read it as an exec
// flag such as --latest and run the command in a different container.
//
// Each command argument is also rejected if it contains shell
// metacharacters. os/exec does not invoke a shell, so these character checks
// are defensive. Note that they also forbid otherwise harmless arguments,
// such as ones containing quotes.
func (pm *PodmanManager) ExecContainer(ctx context.Context, containerID string, command []string) (string, error) {
	if containerID == "" || strings.HasPrefix(containerID, "-") ||
		strings.ContainsAny(containerID, "&;|<>$`\"'") {
		return "", fmt.Errorf("invalid container ID: %s", containerID)
	}
	for _, arg := range command {
		if strings.ContainsAny(arg, "&;|<>$`\"'") {
			return "", fmt.Errorf("invalid command argument: %s", arg)
		}
	}

	args := append([]string{"exec", containerID}, command...)
	cmd := exec.CommandContext(ctx, "podman", args...) //nolint:gosec // inputs validated above
	output, err := cmd.CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("failed to execute command in container: %w, output: %s", err, string(output))
	}
	return string(output), nil
}
// PodmanConfig describes a single experiment run built by BuildPodmanCommand.
// Field order is preserved because unkeyed composite literals depend on it.
type PodmanConfig struct {
	Image              string            // container image to run
	Workspace          string            // host workspace directory, mounted read-write
	Results            string            // host results directory, mounted read-write
	ContainerWorkspace string            // mount point of Workspace inside the container
	ContainerResults   string            // mount point of Results inside the container
	AppleGPU           bool              // NOTE(review): not read by BuildPodmanCommand
	GPUDevices         []string          // each becomes a --device flag
	Env                map[string]string // extra environment variables (-e KEY=VALUE)
	Volumes            map[string]string // extra mounts: host path -> container path
	Memory             string            // --memory value; empty selects the package default
	CPUs               string            // --cpus value; empty selects the package default
}
// PodmanResourceOverrides turns per-task resource requests into values for
// podman's `--cpus` and `--memory` flags.
//
// A positive cpu yields its decimal form (4 -> "4"). A positive memoryGB
// yields gigabytes with a "g" suffix (16 -> "16g"). Non-positive inputs
// leave the corresponding result empty, meaning "no override".
func PodmanResourceOverrides(cpu int, memoryGB int) (cpus string, memory string) {
	if memoryGB > 0 {
		memory = strconv.Itoa(memoryGB) + "g"
	}
	if cpu > 0 {
		cpus = strconv.Itoa(cpu)
	}
	return cpus, memory
}
// BuildPodmanCommand builds, but does not start, a `podman run --rm` command
// that executes an ML experiment script inside cfg.Image.
//
// Security and resources:
//   - The container runs with no-new-privileges, all capabilities dropped and
//     --userns keep-id.
//   - Memory and CPU limits come from cfg.Memory and cfg.CPUs. When either is
//     empty, config.DefaultPodmanMemory or config.DefaultPodmanCPUs is used.
//
// Mounts, devices and environment:
//   - cfg.Workspace and cfg.Results are mounted read-write at
//     cfg.ContainerWorkspace and cfg.ContainerResults.
//   - Extra cfg.Volumes (host -> container), cfg.GPUDevices and cfg.Env
//     entries follow. Volumes and Env are emitted in sorted key order so the
//     command line is deterministic, since Go randomizes map iteration.
//
// After the image come --workspace, --deps and --script, presumably consumed
// by the image's entrypoint. When extraArgs is non-empty, "--args" and the
// extra arguments are appended.
//
// NOTE(review): cfg.AppleGPU is not read here; only cfg.GPUDevices produces
// --device flags. Confirm that callers translate AppleGPU into GPUDevices.
func BuildPodmanCommand(
	ctx context.Context,
	cfg PodmanConfig,
	scriptPath, depsPath string,
	extraArgs []string,
) *exec.Cmd {
	memory := cfg.Memory
	if memory == "" {
		memory = config.DefaultPodmanMemory
	}
	cpus := cfg.CPUs
	if cpus == "" {
		cpus = config.DefaultPodmanCPUs
	}

	args := []string{
		"run", "--rm",
		"--security-opt", "no-new-privileges",
		"--cap-drop", "ALL",
		"--memory", memory,
		"--cpus", cpus,
		"--userns", "keep-id",
		"-v", fmt.Sprintf("%s:%s:rw", cfg.Workspace, cfg.ContainerWorkspace),
		"-v", fmt.Sprintf("%s:%s:rw", cfg.Results, cfg.ContainerResults),
	}

	// Additional volumes, sorted by host path for a stable command line.
	hostPaths := make([]string, 0, len(cfg.Volumes))
	for hostPath := range cfg.Volumes {
		hostPaths = append(hostPaths, hostPath)
	}
	sort.Strings(hostPaths)
	for _, hostPath := range hostPaths {
		args = append(args, "-v", fmt.Sprintf("%s:%s", hostPath, cfg.Volumes[hostPath]))
	}

	// Injected device paths (e.g. GPU devices) are passed through verbatim.
	for _, device := range cfg.GPUDevices {
		args = append(args, "--device", device)
	}

	// Environment variables, sorted by name.
	envKeys := make([]string, 0, len(cfg.Env))
	for key := range cfg.Env {
		envKeys = append(envKeys, key)
	}
	sort.Strings(envKeys)
	for _, key := range envKeys {
		args = append(args, "-e", fmt.Sprintf("%s=%s", key, cfg.Env[key]))
	}

	args = append(args, cfg.Image,
		"--workspace", cfg.ContainerWorkspace,
		"--deps", depsPath,
		"--script", scriptPath,
	)
	if len(extraArgs) > 0 {
		args = append(args, "--args")
		args = append(args, extraArgs...)
	}
	return exec.CommandContext(ctx, "podman", args...)
}
// SanitizePath cleans path with filepath.Clean. It rejects the result when it
// still climbs out of its starting directory, i.e. when its first element
// (after any volume name) is "..".
//
// Names that merely contain ".." inside an element, such as "model..v2.pt",
// are accepted. The previous substring check rejected them even though they
// cannot traverse.
//
// Absolute paths are returned cleaned and unchecked. Callers that need
// containment within a base directory must verify that separately.
func SanitizePath(path string) (string, error) {
	cleaned := filepath.Clean(path)
	// After Clean, ".." survives only as leading elements of a relative path
	// ("..", "../x"); Clean removes it from rooted paths.
	rest := cleaned[len(filepath.VolumeName(cleaned)):]
	if rest == ".." || strings.HasPrefix(rest, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("path traversal detected: %s", path)
	}
	return cleaned, nil
}
// ValidateJobName returns an error when jobName is unsafe. It rejects:
//   - the empty string;
//   - any of the characters / \ < > : " | ? *
//   - any occurrence of ".." (a deliberately conservative traversal check);
//   - the name ".".
//
// NOTE(review): the checks suggest job names are used as path elements. In
// that case "." would resolve to the containing directory itself, because
// filepath.Join(base, ".") == base.
func ValidateJobName(jobName string) error {
	switch {
	case jobName == "":
		return fmt.Errorf("job name cannot be empty")
	case strings.ContainsAny(jobName, "/\\<>:\"|?*"):
		return fmt.Errorf("job name contains invalid characters: %s", jobName)
	case strings.Contains(jobName, ".."):
		return fmt.Errorf("job name contains path traversal: %s", jobName)
	case jobName == ".":
		return fmt.Errorf("job name cannot be %q", jobName)
	}
	return nil
}