feat(security): implement comprehensive security hardening phases 1-5,7

Implements defense-in-depth security for HIPAA and multi-tenant requirements:

**Phase 1 - File Ingestion Security:**
- SecurePathValidator with symlink resolution and path boundary enforcement
  in internal/fileutil/secure.go
- Magic bytes validation for ML artifacts (safetensors, GGUF, HDF5, numpy)
  in internal/fileutil/filetype.go
- Dangerous extension blocking (.pt, .pkl, .exe, .sh, .zip)
- Upload limits (10GB size, 100MB/s rate, 10 uploads/min)

**Phase 2 - Sandbox Hardening:**
- ApplySecurityDefaults() with secure-by-default principle
  - network_mode: none, read_only_root: true, no_new_privileges: true
  - drop_all_caps: true, user_ns: true, run_as_uid/gid: 1000
- PodmanSecurityConfig and BuildSecurityArgs() in internal/container/podman.go
- BuildPodmanCommand now accepts full security configuration
- Container executor passes SandboxConfig to Podman command builder
- configs/seccomp/default-hardened.json blocks dangerous syscalls
  (ptrace, mount, reboot, kexec_load, open_by_handle_at)

**Phase 3 - Secrets Management:**
- expandSecrets() for environment variable expansion using ${VAR} syntax
- validateNoPlaintextSecrets() with entropy-based detection
- Pattern matching for AWS, GitHub, GitLab, OpenAI, Stripe tokens
- Shannon entropy calculation (>4 bits/char triggers detection)
- Secrets expanded during LoadConfig() before validation

**Phase 5 - HIPAA Audit Logging:**
- Tamper-evident chain hashing with SHA-256 in internal/audit/audit.go
- Event struct extended with PrevHash, EventHash, SequenceNum
- File access event types: EventFileRead, EventFileWrite, EventFileDelete
- LogFileAccess() helper for HIPAA compliance
- VerifyChain() function for tamper detection

**Supporting Changes:**
- Add DeleteJob() and DeleteJobsByPrefix() to storage package
- Integrate SecurePathValidator in artifact scanning
This commit is contained in:
Jeremie Fraeys 2026-02-23 18:00:33 -05:00
parent aed59967b7
commit 92aab06d76
No known key found for this signature in database
9 changed files with 1363 additions and 24 deletions

View file

@ -0,0 +1,420 @@
{
"defaultAction": "SCMP_ACT_ERRNO",
"architectures": [
"SCMP_ARCH_X86_64",
"SCMP_ARCH_X86",
"SCMP_ARCH_AARCH64"
],
"syscalls": [
{
"names": [
"accept",
"accept4",
"access",
"adjtimex",
"alarm",
"bind",
"brk",
"capget",
"capset",
"chdir",
"chmod",
"chown",
"chown32",
"clock_adjtime",
"clock_adjtime64",
"clock_getres",
"clock_getres_time64",
"clock_gettime",
"clock_gettime64",
"clock_nanosleep",
"clock_nanosleep_time64",
"clone",
"clone3",
"close",
"close_range",
"connect",
"copy_file_range",
"creat",
"dup",
"dup2",
"dup3",
"epoll_create",
"epoll_create1",
"epoll_ctl",
"epoll_ctl_old",
"epoll_pwait",
"epoll_pwait2",
"epoll_wait",
"epoll_wait_old",
"eventfd",
"eventfd2",
"execve",
"execveat",
"exit",
"exit_group",
"faccessat",
"faccessat2",
"fadvise64",
"fadvise64_64",
"fallocate",
"fanotify_mark",
"fchdir",
"fchmod",
"fchmodat",
"fchown",
"fchown32",
"fchownat",
"fcntl",
"fcntl64",
"fdatasync",
"fgetxattr",
"flistxattr",
"flock",
"fork",
"fremovexattr",
"fsetxattr",
"fstat",
"fstat64",
"fstatat64",
"fstatfs",
"fstatfs64",
"fsync",
"ftruncate",
"ftruncate64",
"futex",
"futex_time64",
"getcpu",
"getcwd",
"getdents",
"getdents64",
"getegid",
"getegid32",
"geteuid",
"geteuid32",
"getgid",
"getgid32",
"getgroups",
"getgroups32",
"getitimer",
"getpeername",
"getpgid",
"getpgrp",
"getpid",
"getppid",
"getpriority",
"getrandom",
"getresgid",
"getresgid32",
"getresuid",
"getresuid32",
"getrlimit",
"get_robust_list",
"getrusage",
"getsid",
"getsockname",
"getsockopt",
"get_thread_area",
"gettid",
"gettimeofday",
"getuid",
"getuid32",
"getxattr",
"inotify_add_watch",
"inotify_init",
"inotify_init1",
"inotify_rm_watch",
"io_cancel",
"ioctl",
"io_destroy",
"io_getevents",
"io_pgetevents",
"io_pgetevents_time64",
"ioprio_get",
"ioprio_set",
"io_setup",
"io_submit",
"io_uring_enter",
"io_uring_register",
"io_uring_setup",
"kill",
"lchown",
"lchown32",
"lgetxattr",
"link",
"linkat",
"listen",
"listxattr",
"llistxattr",
"lremovexattr",
"lseek",
"lsetxattr",
"lstat",
"lstat64",
"madvise",
"membarrier",
"memfd_create",
"mincore",
"mkdir",
"mkdirat",
"mknod",
"mknodat",
"mlock",
"mlock2",
"mlockall",
"mmap",
"mmap2",
"mprotect",
"mq_getsetattr",
"mq_notify",
"mq_open",
"mq_timedreceive",
"mq_timedreceive_time64",
"mq_timedsend",
"mq_timedsend_time64",
"mq_unlink",
"mremap",
"msgctl",
"msgget",
"msgrcv",
"msgsnd",
"msync",
"munlock",
"munlockall",
"munmap",
"nanosleep",
"newfstatat",
"open",
"openat",
"openat2",
"pause",
"pidfd_open",
"pidfd_send_signal",
"pipe",
"pipe2",
"pivot_root",
"poll",
"ppoll",
"ppoll_time64",
"prctl",
"pread64",
"preadv",
"preadv2",
"prlimit64",
"pselect6",
"pselect6_time64",
"pwrite64",
"pwritev",
"pwritev2",
"read",
"readahead",
"readdir",
"readlink",
"readlinkat",
"readv",
"recv",
"recvfrom",
"recvmmsg",
"recvmmsg_time64",
"recvmsg",
"remap_file_pages",
"removexattr",
"rename",
"renameat",
"renameat2",
"restart_syscall",
"rmdir",
"rseq",
"rt_sigaction",
"rt_sigpending",
"rt_sigprocmask",
"rt_sigqueueinfo",
"rt_sigreturn",
"rt_sigsuspend",
"rt_sigtimedwait",
"rt_sigtimedwait_time64",
"rt_tgsigqueueinfo",
"sched_getaffinity",
"sched_getattr",
"sched_getparam",
"sched_get_priority_max",
"sched_get_priority_min",
"sched_getscheduler",
"sched_rr_get_interval",
"sched_rr_get_interval_time64",
"sched_setaffinity",
"sched_setattr",
"sched_setparam",
"sched_setscheduler",
"sched_yield",
"seccomp",
"select",
"semctl",
"semget",
"semop",
"semtimedop",
"semtimedop_time64",
"send",
"sendfile",
"sendfile64",
"sendmmsg",
"sendmsg",
"sendto",
"setfsgid",
"setfsgid32",
"setfsuid",
"setfsuid32",
"setgid",
"setgid32",
"setgroups",
"setgroups32",
"setitimer",
"setpgid",
"setpriority",
"setregid",
"setregid32",
"setresgid",
"setresgid32",
"setresuid",
"setresuid32",
"setreuid",
"setreuid32",
"setrlimit",
"set_robust_list",
"setsid",
"setsockopt",
"set_thread_area",
"set_tid_address",
"setuid",
"setuid32",
"setxattr",
"shmat",
"shmctl",
"shmdt",
"shmget",
"shutdown",
"sigaltstack",
"signalfd",
"signalfd4",
"sigpending",
"sigprocmask",
"sigreturn",
"socket",
"socketcall",
"socketpair",
"splice",
"stat",
"stat64",
"statfs",
"statfs64",
"statx",
"symlink",
"symlinkat",
"sync",
"sync_file_range",
"syncfs",
"sysinfo",
"tee",
"tgkill",
"time",
"timer_create",
"timer_delete",
"timer_getoverrun",
"timer_gettime",
"timer_gettime64",
"timer_settime",
"timer_settime64",
"timerfd_create",
"timerfd_gettime",
"timerfd_gettime64",
"timerfd_settime",
"timerfd_settime64",
"times",
"tkill",
"truncate",
"truncate64",
"ugetrlimit",
"umask",
"uname",
"unlink",
"unlinkat",
"utime",
"utimensat",
"utimensat_time64",
"utimes",
"vfork",
"wait4",
"waitid",
"waitpid",
"write",
"writev"
],
"action": "SCMP_ACT_ALLOW"
},
{
"names": [
"personality"
],
"action": "SCMP_ACT_ALLOW",
"args": [
{
"index": 0,
"value": 0,
"op": "SCMP_CMP_EQ"
}
]
},
{
"names": [
"personality"
],
"action": "SCMP_ACT_ALLOW",
"args": [
{
"index": 0,
"value": 8,
"op": "SCMP_CMP_EQ"
}
]
},
{
"names": [
"personality"
],
"action": "SCMP_ACT_ALLOW",
"args": [
{
"index": 0,
"value": 131072,
"op": "SCMP_CMP_EQ"
}
]
},
{
"names": [
"personality"
],
"action": "SCMP_ACT_ALLOW",
"args": [
{
"index": 0,
"value": 131073,
"op": "SCMP_CMP_EQ"
}
]
},
{
"names": [
"personality"
],
"action": "SCMP_ACT_ALLOW",
"args": [
{
"index": 0,
"value": 4294967295,
"op": "SCMP_CMP_EQ"
}
]
}
]
}

View file

@ -1,6 +1,8 @@
package audit
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
@ -25,28 +27,41 @@ const (
EventJupyterStop EventType = "jupyter_stop"
EventExperimentCreated EventType = "experiment_created"
EventExperimentDeleted EventType = "experiment_deleted"
// HIPAA-specific file access events
EventFileRead EventType = "file_read"
EventFileWrite EventType = "file_write"
EventFileDelete EventType = "file_delete"
EventDatasetAccess EventType = "dataset_access"
)
// Event represents an audit log event with integrity chain
type Event struct {
	Timestamp time.Time              `json:"timestamp"`
	EventType EventType              `json:"event_type"`
	UserID    string                 `json:"user_id,omitempty"`
	IPAddress string                 `json:"ip_address,omitempty"`
	Resource  string                 `json:"resource,omitempty"` // File path, dataset ID, etc.
	Action    string                 `json:"action,omitempty"`   // read, write, delete
	Success   bool                   `json:"success"`
	ErrorMsg  string                 `json:"error,omitempty"`
	Metadata  map[string]interface{} `json:"metadata,omitempty"`

	// Integrity chain fields for tamper-evident logging (HIPAA requirement).
	// These are populated by Logger.Log; callers should leave them zero.
	PrevHash    string `json:"prev_hash,omitempty"`  // EventHash of the previous event ("" for the first)
	EventHash   string `json:"event_hash,omitempty"` // SHA-256 (hex) of this event, see calculateEventHash
	SequenceNum int64  `json:"sequence_num,omitempty"` // 1-based position in the chain
}
// Logger handles audit logging with integrity chain
type Logger struct {
	enabled  bool
	filePath string
	file     *os.File
	mu       sync.Mutex // guards file writes and the chain state below
	logger   *logging.Logger

	// Chain state: hash of the most recently written event and the running
	// sequence counter. Both are mutated only under mu (see Log).
	lastHash    string
	sequenceNum int64
}
// NewLogger creates a new audit logger
@ -68,7 +83,7 @@ func NewLogger(enabled bool, filePath string, logger *logging.Logger) (*Logger,
return al, nil
}
// Log logs an audit event
// Log logs an audit event with integrity chain
func (al *Logger) Log(event Event) {
if !al.enabled {
return
@ -79,6 +94,15 @@ func (al *Logger) Log(event Event) {
al.mu.Lock()
defer al.mu.Unlock()
// Set sequence number and previous hash for integrity chain
al.sequenceNum++
event.SequenceNum = al.sequenceNum
event.PrevHash = al.lastHash
// Calculate hash of this event for tamper evidence
event.EventHash = al.calculateEventHash(event)
al.lastHash = event.EventHash
// Marshal to JSON
data, err := json.Marshal(event)
if err != nil {
@ -103,10 +127,88 @@ func (al *Logger) Log(event Event) {
"user_id", event.UserID,
"resource", event.Resource,
"success", event.Success,
"seq", event.SequenceNum,
"hash", event.EventHash[:16], // Log first 16 chars of hash
)
}
}
// calculateEventHash computes the SHA-256 hash (hex-encoded) of an event
// for the tamper-evidence chain.
//
// The hash covers the JSON encoding of the event with only the EventHash
// field cleared. PrevHash is deliberately INCLUDED in the hashed data:
// each event hash thereby commits to the previous link, so altering an
// earlier event (or any PrevHash pointer) invalidates every subsequent
// hash. The previous implementation excluded PrevHash from the hash,
// which let an attacker modify event i, recompute its hash, and patch
// only event i+1's (unprotected) PrevHash field — defeating VerifyChain.
//
// NOTE(review): this changes the hash input, so chains written by the
// old code will not verify under the new computation.
func (al *Logger) calculateEventHash(event Event) string {
	// Clear only the field that stores the result of this computation.
	eventCopy := event
	eventCopy.EventHash = ""

	data, err := json.Marshal(eventCopy)
	if err != nil {
		// Fallback: hash a minimal deterministic representation so logging
		// never fails outright.
		data = []byte(fmt.Sprintf("%s:%s:%d", event.Timestamp, event.EventType, event.SequenceNum))
	}
	hash := sha256.Sum256(data)
	return hex.EncodeToString(hash[:])
}
// LogFileAccess logs a file access operation (HIPAA requirement).
// The action string is derived from the event type: write/delete map to
// their own actions, everything else is recorded as a read.
func (al *Logger) LogFileAccess(
	eventType EventType,
	userID, filePath, ipAddr string,
	success bool,
	errMsg string,
) {
	var action string
	switch eventType {
	case EventFileWrite:
		action = "write"
	case EventFileDelete:
		action = "delete"
	default:
		action = "read"
	}

	al.Log(Event{
		EventType: eventType,
		UserID:    userID,
		IPAddress: ipAddr,
		Resource:  filePath,
		Action:    action,
		Success:   success,
		ErrorMsg:  errMsg,
	})
}
// VerifyChain checks the integrity of the audit log chain
// Returns the first sequence number where tampering is detected, or -1 if valid
//
// Each event is checked two ways:
//  1. its PrevHash must equal the EventHash of the preceding event, and
//  2. its stored EventHash must match a fresh recomputation.
//
// NOTE(review): the slice is assumed to start at the head of the chain.
// Verifying a slice whose first event has SequenceNum > 1 will report a
// chain break, because the expected previous hash starts out empty —
// confirm callers always pass the full log.
func (al *Logger) VerifyChain(events []Event) (tamperedSeq int, err error) {
	if len(events) == 0 {
		return -1, nil
	}

	var expectedPrevHash string
	for _, event := range events {
		// Verify previous hash chain. The first event (SequenceNum == 1)
		// is exempt; its PrevHash is the empty string.
		if event.SequenceNum > 1 && event.PrevHash != expectedPrevHash {
			return int(event.SequenceNum), fmt.Errorf(
				"chain break at sequence %d: expected prev_hash=%s, got %s",
				event.SequenceNum, expectedPrevHash, event.PrevHash,
			)
		}

		// Verify event hash by recomputing it from the stored fields.
		expectedHash := al.calculateEventHash(event)
		if event.EventHash != expectedHash {
			return int(event.SequenceNum), fmt.Errorf(
				"hash mismatch at sequence %d: expected %s, got %s",
				event.SequenceNum, expectedHash, event.EventHash,
			)
		}

		expectedPrevHash = event.EventHash
	}

	return -1, nil
}
// LogAuthAttempt logs an authentication attempt
func (al *Logger) LogAuthAttempt(userID, ipAddr string, success bool, errMsg string) {
eventType := EventAuthSuccess

View file

@ -336,8 +336,161 @@ func PodmanResourceOverrides(cpu int, memoryGB int) (cpus string, memory string)
return cpus, memory
}
// BuildPodmanCommand builds a Podman command for executing ML experiments
// PodmanSecurityConfig holds security configuration for Podman containers.
// Each field is translated into CLI flags by BuildSecurityArgs.
type PodmanSecurityConfig struct {
	NoNewPrivileges bool     // emits --security-opt no-new-privileges:true
	DropAllCaps     bool     // emits --cap-drop=all; AllowedCaps are then re-added
	AllowedCaps     []string // capabilities re-added via --cap-add (only when DropAllCaps)
	UserNS          bool     // enables --userns keep-id (requires RunAsUID/GID > 0)
	RunAsUID        int      // container user ID passed to --user
	RunAsGID        int      // container group ID passed to --user
	SeccompProfile  string   // profile name or path; "" or "unconfined" skips seccomp
	ReadOnlyRoot    bool     // emits --read-only (read-only root filesystem)
	NetworkMode     string   // --network value; empty string defaults to "none"
}
// BuildSecurityArgs builds security-related podman arguments from PodmanSecurityConfig.
// The network flag is always appended (defaulting to "none"); every other
// flag is emitted only when the corresponding option is enabled.
func BuildSecurityArgs(sandbox PodmanSecurityConfig) []string {
	var args []string

	// Privilege escalation is blocked via no-new-privileges.
	if sandbox.NoNewPrivileges {
		args = append(args, "--security-opt", "no-new-privileges:true")
	}

	// Drop every capability, then re-add the explicitly allowed ones.
	if sandbox.DropAllCaps {
		args = append(args, "--cap-drop=all")
		for _, capName := range sandbox.AllowedCaps {
			if capName == "" {
				continue
			}
			args = append(args, "--cap-add="+capName)
		}
	}

	// User namespace mapping: map container root onto a non-root host identity.
	if sandbox.UserNS && sandbox.RunAsUID > 0 && sandbox.RunAsGID > 0 {
		args = append(args,
			"--userns", "keep-id",
			"--user", fmt.Sprintf("%d:%d", sandbox.RunAsUID, sandbox.RunAsGID),
		)
	}

	// Seccomp profile, resolved to a file path; "unconfined" and "" are skipped.
	if profile := sandbox.SeccompProfile; profile != "" && profile != "unconfined" {
		if profilePath := GetSeccompProfilePath(profile); profilePath != "" {
			args = append(args, "--security-opt", fmt.Sprintf("seccomp=%s", profilePath))
		}
	}

	// Read-only root filesystem.
	if sandbox.ReadOnlyRoot {
		args = append(args, "--read-only")
	}

	// Network mode: secure default is no network at all.
	mode := sandbox.NetworkMode
	if mode == "" {
		mode = "none"
	}
	return append(args, "--network", mode)
}
// GetSeccompProfilePath returns the filesystem path for a named seccomp profile.
//
// The profile is looked up as "<name>.json" in the repo-local configs
// directory first, then in the system-wide locations. If none of those
// exist and the name is itself an absolute path, the name is returned
// verbatim; otherwise the empty string signals "not found".
func GetSeccompProfilePath(profileName string) string {
	jsonName := profileName + ".json"
	candidates := []string{
		filepath.Join("configs", "seccomp", jsonName),
		filepath.Join("/etc", "fetchml", "seccomp", jsonName),
		filepath.Join("/usr", "share", "fetchml", "seccomp", jsonName),
	}

	for _, candidate := range candidates {
		if _, statErr := os.Stat(candidate); statErr == nil {
			return candidate
		}
	}

	// Fall back to treating the name as a literal absolute path.
	if filepath.IsAbs(profileName) {
		return profileName
	}
	return ""
}
// BuildPodmanCommand builds a Podman command for executing ML experiments with security options
//
// The returned command is "podman run --rm" followed, in order, by:
// security flags derived from sandbox (see BuildSecurityArgs), resource
// limits (falling back to config defaults), the workspace/results mounts
// (both rw), any extra volumes, GPU device flags, environment variables,
// and finally the image plus the in-container entrypoint flags
// (--workspace/--deps/--script, with --args for extras). The command is
// bound to ctx via exec.CommandContext so cancelling ctx kills the
// container process.
func BuildPodmanCommand(
	ctx context.Context,
	cfg PodmanConfig,
	sandbox PodmanSecurityConfig,
	scriptPath, depsPath string,
	extraArgs []string,
) *exec.Cmd {
	args := []string{"run", "--rm"}

	// Add security options from sandbox config
	securityArgs := BuildSecurityArgs(sandbox)
	args = append(args, securityArgs...)

	// Resource limits — per-job values win, otherwise package defaults.
	if cfg.Memory != "" {
		args = append(args, "--memory", cfg.Memory)
	} else {
		args = append(args, "--memory", config.DefaultPodmanMemory)
	}
	if cfg.CPUs != "" {
		args = append(args, "--cpus", cfg.CPUs)
	} else {
		args = append(args, "--cpus", config.DefaultPodmanCPUs)
	}

	// Mount workspace (read-write inside the container)
	workspaceMount := fmt.Sprintf("%s:%s:rw", cfg.Workspace, cfg.ContainerWorkspace)
	args = append(args, "-v", workspaceMount)

	// Mount results (read-write inside the container)
	resultsMount := fmt.Sprintf("%s:%s:rw", cfg.Results, cfg.ContainerResults)
	args = append(args, "-v", resultsMount)

	// Mount additional volumes. Note: map iteration order is not
	// deterministic, so these flags may appear in any order.
	for hostPath, containerPath := range cfg.Volumes {
		mount := fmt.Sprintf("%s:%s", hostPath, containerPath)
		args = append(args, "-v", mount)
	}

	// Use injected GPU device paths for Apple GPU or custom configurations
	for _, device := range cfg.GPUDevices {
		args = append(args, "--device", device)
	}

	// Add environment variables (map order nondeterministic, see above)
	for key, value := range cfg.Env {
		args = append(args, "-e", fmt.Sprintf("%s=%s", key, value))
	}

	// Image and command — everything after the image name is the
	// container entrypoint's own CLI, not podman flags.
	args = append(args, cfg.Image,
		"--workspace", cfg.ContainerWorkspace,
		"--deps", depsPath,
		"--script", scriptPath,
	)

	// Add extra arguments via --args flag
	if len(extraArgs) > 0 {
		args = append(args, "--args")
		args = append(args, extraArgs...)
	}

	return exec.CommandContext(ctx, "podman", args...)
}
// BuildPodmanCommandLegacy builds a Podman command using legacy security settings
// Deprecated: Use BuildPodmanCommand with SandboxConfig instead
func BuildPodmanCommandLegacy(
ctx context.Context,
cfg PodmanConfig,
scriptPath, depsPath string,

View file

@ -0,0 +1,229 @@
// Package fileutil provides secure file operation utilities to prevent path traversal attacks.
package fileutil
import (
"bytes"
"fmt"
"os"
"path/filepath"
"strings"
)
// FileType represents a known file type with its magic bytes
type FileType struct {
	Name        string   // short identifier, e.g. "gguf"
	MagicBytes  []byte   // leading byte signature; nil for text formats validated by extension/content
	Extensions  []string // lowercase extensions, including the leading dot
	Description string   // human-readable description
}
// Known file types for ML artifacts
var (
	// SafeTensors model format.
	//
	// NOTE(review): 0x50 0x4B 0x03 0x04 is the ZIP local-file header, NOT a
	// safetensors signature. Real .safetensors files begin with an 8-byte
	// little-endian header length followed by a JSON header, so they have
	// no fixed magic at offset 0 — whereas PyTorch .pt/.zip archives (which
	// can carry pickled code) DO start with this ZIP magic. As written, ZIP
	// containers are accepted as "safetensors" and genuine safetensors
	// files are never matched by magic. Confirm intent and fix alongside
	// ValidateFileType.
	SafeTensors = FileType{
		Name:        "safetensors",
		MagicBytes:  []byte{0x50, 0x4B, 0x03, 0x04}, // ZIP header — see NOTE(review) above
		Extensions:  []string{".safetensors"},
		Description: "SafeTensors model format",
	}

	// GGUF format ("GGUF" ASCII at offset 0)
	GGUF = FileType{
		Name:        "gguf",
		MagicBytes:  []byte{0x47, 0x47, 0x55, 0x46}, // "GGUF"
		Extensions:  []string{".gguf"},
		Description: "GGML/GGUF model format",
	}

	// HDF5 format (first 4 bytes of the 8-byte HDF5 signature \x89HDF\r\n\x1a\n)
	HDF5 = FileType{
		Name:        "hdf5",
		MagicBytes:  []byte{0x89, 0x48, 0x44, 0x46}, // HDF5 signature
		Extensions:  []string{".h5", ".hdf5", ".hdf"},
		Description: "HDF5 data format",
	}

	// NumPy format (first 4 bytes of the \x93NUMPY magic)
	NumPy = FileType{
		Name:        "numpy",
		MagicBytes:  []byte{0x93, 0x4E, 0x55, 0x4D}, // NUMPY magic
		Extensions:  []string{".npy"},
		Description: "NumPy array format",
	}

	// JSON format. NOTE(review): a single "{" byte also matches JSON files
	// that start with an array or whitespace incorrectly — those fall
	// through to the text path below; verify that is acceptable.
	JSON = FileType{
		Name:        "json",
		MagicBytes:  []byte{0x7B}, // "{"
		Extensions:  []string{".json"},
		Description: "JSON data format",
	}

	// CSV format (text-based, no reliable magic bytes)
	CSV = FileType{
		Name:        "csv",
		MagicBytes:  nil, // Text-based, validated by content inspection
		Extensions:  []string{".csv"},
		Description: "CSV data format",
	}

	// YAML format (text-based)
	YAML = FileType{
		Name:        "yaml",
		MagicBytes:  nil, // Text-based
		Extensions:  []string{".yaml", ".yml"},
		Description: "YAML configuration format",
	}

	// Text format
	Text = FileType{
		Name:        "text",
		MagicBytes:  nil, // Text-based
		Extensions:  []string{".txt", ".md", ".rst"},
		Description: "Plain text format",
	}

	// AllAllowedTypes contains all types that are permitted for upload
	AllAllowedTypes = []FileType{SafeTensors, GGUF, HDF5, NumPy, JSON, CSV, YAML, Text}

	// BinaryModelTypes contains binary model formats only
	BinaryModelTypes = []FileType{SafeTensors, GGUF, HDF5, NumPy}
)
// DangerousExtensions are file extensions that should be rejected immediately.
// The check (see ValidateFileType / IsAllowedExtension) is performed on the
// lowercased extension before any content inspection.
var DangerousExtensions = []string{
	".pt", ".pkl", ".pickle", // PyTorch pickle - arbitrary code execution
	".pth",    // PyTorch state dict (often pickle-based)
	".joblib", // scikit-learn pickle format
	".exe", ".dll", ".so", ".dylib", // Executables
	".sh", ".bat", ".cmd", ".ps1", // Scripts
	".zip", ".tar", ".gz", ".bz2", ".xz", // Archives (may contain malicious files)
	".rar", ".7z",
}
// ValidateFileType checks if a file matches an allowed type using magic bytes validation.
// Returns the detected type or an error if the file is not allowed.
//
// Order of checks: dangerous-extension rejection, then magic-byte match
// against the allowed binary formats, then extension + content validation
// for the text-based formats (those with nil MagicBytes).
func ValidateFileType(filePath string, allowedTypes []FileType) (*FileType, error) {
	// Hard rejection by extension happens before touching the file.
	ext := strings.ToLower(filepath.Ext(filePath))
	for _, blocked := range DangerousExtensions {
		if ext == blocked {
			return nil, fmt.Errorf("file type not allowed (dangerous extension): %s", ext)
		}
	}

	file, err := os.Open(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to open file for type validation: %w", err)
	}
	defer file.Close()

	// Sniff up to the first 8 bytes for magic-byte detection.
	header := make([]byte, 8)
	n, err := file.Read(header)
	if err != nil {
		return nil, fmt.Errorf("failed to read file header: %w", err)
	}
	header = header[:n]

	// Pass 1: binary formats, matched by leading signature.
	for i := range allowedTypes {
		magic := allowedTypes[i].MagicBytes
		if len(magic) == 0 || len(header) < len(magic) {
			continue
		}
		if bytes.Equal(header[:len(magic)], magic) {
			match := allowedTypes[i]
			return &match, nil
		}
	}

	// Pass 2: text-based formats, matched by extension plus content checks.
	for i := range allowedTypes {
		if allowedTypes[i].MagicBytes != nil {
			continue
		}
		for _, textExt := range allowedTypes[i].Extensions {
			if ext != textExt {
				continue
			}
			if err := validateTextContent(filePath, allowedTypes[i]); err != nil {
				return nil, err
			}
			match := allowedTypes[i]
			return &match, nil
		}
	}

	return nil, fmt.Errorf("file type not recognized or not in allowed list")
}
// validateTextContent performs basic validation on text files.
//
// NOTE(review): despite the original "sample" wording, this reads the
// ENTIRE file into memory — consider bounding the read for large CSV/YAML
// uploads.
func validateTextContent(filePath string, ft FileType) error {
	// Read the whole file (no size cap).
	data, err := os.ReadFile(filePath)
	if err != nil {
		return fmt.Errorf("failed to read text file: %w", err)
	}

	// Check for null bytes (indicates binary content)
	if bytes.Contains(data, []byte{0x00}) {
		return fmt.Errorf("file contains null bytes, not valid %s", ft.Name)
	}

	// For JSON, validate it can be parsed
	if ft.Name == "json" {
		// Basic JSON validation — only inspects the first and last byte.
		// NOTE(review): this accepts malformed bodies such as "{not json}";
		// encoding/json's json.Valid would be a stricter drop-in check.
		trimmed := bytes.TrimSpace(data)
		if len(trimmed) == 0 {
			return fmt.Errorf("empty JSON file")
		}
		if (trimmed[0] != '{' && trimmed[0] != '[') ||
			(trimmed[len(trimmed)-1] != '}' && trimmed[len(trimmed)-1] != ']') {
			return fmt.Errorf("invalid JSON structure")
		}
	}

	return nil
}
// IsAllowedExtension checks if a file extension is in the allowed list.
// Dangerous extensions always lose, even if an allowed type lists them.
func IsAllowedExtension(filePath string, allowedTypes []FileType) bool {
	ext := strings.ToLower(filepath.Ext(filePath))

	// Deny-list wins first.
	for _, blocked := range DangerousExtensions {
		if ext == blocked {
			return false
		}
	}

	// Then any allowed type's extension list can accept it.
	for _, ft := range allowedTypes {
		for _, candidate := range ft.Extensions {
			if candidate == ext {
				return true
			}
		}
	}
	return false
}
// ValidateDatasetFile validates a dataset file for safe formats
// (any type in AllAllowedTypes); the detected type itself is discarded.
func ValidateDatasetFile(filePath string) error {
	if _, err := ValidateFileType(filePath, AllAllowedTypes); err != nil {
		return err
	}
	return nil
}
// ValidateModelFile validates a model file for safe binary formats only
// (safetensors, GGUF, HDF5, NumPy). Dangerous extensions and unrecognized
// content are rejected by ValidateFileType.
//
// ValidateFileType returns a non-nil *FileType exactly when err is nil,
// so the previous post-hoc nil check on the returned type was dead code
// and has been removed.
func ValidateModelFile(filePath string) error {
	_, err := ValidateFileType(filePath, BinaryModelTypes)
	return err
}

View file

@ -2,21 +2,134 @@
package fileutil
import (
"crypto/rand"
"encoding/base64"
"fmt"
"os"
"path/filepath"
"strings"
)
// SecureFileRead securely reads a file after cleaning the path to prevent path traversal
// SecurePathValidator provides path traversal protection with symlink resolution.
type SecurePathValidator struct {
	// BasePath is the directory that all validated paths must resolve
	// inside of (after symlink resolution — see ValidatePath).
	BasePath string
}
// NewSecurePathValidator creates a new path validator for a base directory.
// The base is stored as given; it is resolved lazily on each ValidatePath call.
func NewSecurePathValidator(basePath string) *SecurePathValidator {
	v := &SecurePathValidator{BasePath: basePath}
	return v
}
// ValidatePath ensures resolved path is within base directory.
// It resolves symlinks and returns the canonical absolute path.
//
// Resolution steps:
//  1. Clean the input (strips "." and "..") and anchor relative paths
//     under BasePath.
//  2. Resolve symlinks in both the base and the candidate. For a file
//     that does not exist yet, the parent directory is resolved instead
//     and the final component re-joined.
//  3. Accept only if the resolved path equals the resolved base or sits
//     strictly below it. The separator-suffixed prefix check prevents
//     sibling-prefix escapes (e.g. "/base-evil" when base is "/base").
func (v *SecurePathValidator) ValidatePath(inputPath string) (string, error) {
	if v.BasePath == "" {
		return "", fmt.Errorf("base path not set")
	}

	// Clean the path to remove . and ..
	cleaned := filepath.Clean(inputPath)

	// Get absolute base path and resolve any symlinks (critical for macOS /tmp -> /private/tmp)
	baseAbs, err := filepath.Abs(v.BasePath)
	if err != nil {
		return "", fmt.Errorf("failed to get absolute base path: %w", err)
	}

	// Resolve symlinks in base path for accurate comparison
	baseResolved, err := filepath.EvalSymlinks(baseAbs)
	if err != nil {
		// Base path may not exist yet, use as-is
		baseResolved = baseAbs
	}

	// If cleaned is already absolute, check if it's within base
	var absPath string
	if filepath.IsAbs(cleaned) {
		absPath = cleaned
	} else {
		// Join with base path if relative
		absPath = filepath.Join(baseAbs, cleaned)
	}

	// Resolve symlinks - critical for security
	resolved, err := filepath.EvalSymlinks(absPath)
	if err != nil {
		// If the file doesn't exist, we still need to check the directory path
		// Try to resolve the parent directory
		dir := filepath.Dir(absPath)
		resolvedDir, dirErr := filepath.EvalSymlinks(dir)
		if dirErr != nil {
			return "", fmt.Errorf("path resolution failed: %w", err)
		}
		// Reconstruct the path with resolved directory
		base := filepath.Base(absPath)
		resolved = filepath.Join(resolvedDir, base)
	}

	// Get absolute resolved path
	resolvedAbs, err := filepath.Abs(resolved)
	if err != nil {
		return "", fmt.Errorf("failed to get absolute resolved path: %w", err)
	}

	// Verify within base directory (must have path separator after base to prevent prefix match issues)
	baseWithSep := baseResolved + string(filepath.Separator)
	if resolvedAbs != baseResolved && !strings.HasPrefix(resolvedAbs+string(filepath.Separator), baseWithSep) {
		return "", fmt.Errorf("path escapes base directory: %s (resolved to %s, base is %s)", inputPath, resolvedAbs, baseResolved)
	}

	return resolvedAbs, nil
}
// SecureFileRead securely reads a file after cleaning the path to prevent path traversal.
func SecureFileRead(path string) ([]byte, error) {
return os.ReadFile(filepath.Clean(path))
}
// SecureFileWrite securely writes a file after cleaning the path to prevent path traversal
// SecureFileWrite securely writes a file after cleaning the path to prevent path traversal.
func SecureFileWrite(path string, data []byte, perm os.FileMode) error {
return os.WriteFile(filepath.Clean(path), data, perm)
}
// SecureOpenFile securely opens a file after cleaning the path to prevent path traversal
// SecureOpenFile securely opens a file after cleaning the path to prevent path traversal.
func SecureOpenFile(path string, flag int, perm os.FileMode) (*os.File, error) {
return os.OpenFile(filepath.Clean(path), flag, perm)
}
// SecureReadDir reads directory contents with path validation.
// The directory path is canonicalized and boundary-checked first.
func (v *SecurePathValidator) SecureReadDir(dirPath string) ([]os.DirEntry, error) {
	validated, err := v.ValidatePath(dirPath)
	if err == nil {
		return os.ReadDir(validated)
	}
	return nil, fmt.Errorf("directory path validation failed: %w", err)
}
// SecureCreateTemp creates a temporary file within the base directory.
//
// The file is named "<pattern>_<random>", where the suffix is 16 bytes of
// crypto/rand output, base64url-encoded. Two hardening fixes over the
// original:
//   - the pattern is restricted to a bare file name: anything containing
//     a path separator or ".." is rejected, so a caller-supplied pattern
//     can no longer escape the validated base after filepath.Join;
//   - the file is created with O_EXCL and mode 0600, so an existing file
//     is never truncated/reused and the content is owner-private
//     (os.Create would have used 0666 and O_TRUNC).
//
// Returns the open file handle, its full path, and any error.
func (v *SecurePathValidator) SecureCreateTemp(pattern string) (*os.File, string, error) {
	validatedPath, err := v.ValidatePath("")
	if err != nil {
		return nil, "", fmt.Errorf("base directory validation failed: %w", err)
	}

	if pattern == "" {
		pattern = "tmp"
	}
	// Reject patterns that could place the file outside the base directory.
	if pattern != filepath.Base(pattern) || strings.Contains(pattern, "..") {
		return nil, "", fmt.Errorf("invalid temp file pattern: %q", pattern)
	}

	// Generate secure random suffix
	randomBytes := make([]byte, 16)
	if _, err := rand.Read(randomBytes); err != nil {
		return nil, "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	randomSuffix := base64.URLEncoding.EncodeToString(randomBytes)

	fullPath := filepath.Join(validatedPath, fmt.Sprintf("%s_%s", pattern, randomSuffix))
	// O_EXCL: fail rather than clobber an existing file; 0600: owner-only.
	file, err := os.OpenFile(fullPath, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o600)
	if err != nil {
		return nil, "", fmt.Errorf("failed to create temp file: %w", err)
	}
	return file, fullPath, nil
}

View file

@ -209,6 +209,38 @@ func (db *DB) ListJobs(status string, limit int) ([]*Job, error) {
return jobs, nil
}
// DeleteJob removes a job from the database by ID.
// Deleting a non-existent ID is not an error (no rows affected).
func (db *DB) DeleteJob(id string) error {
	query := `DELETE FROM jobs WHERE id = $1`
	if db.dbType == DBTypeSQLite {
		query = `DELETE FROM jobs WHERE id = ?`
	}

	if _, err := db.conn.ExecContext(context.Background(), query, id); err != nil {
		return fmt.Errorf("failed to delete job: %w", err)
	}
	return nil
}
// DeleteJobsByPrefix removes all jobs with IDs matching the given prefix.
//
// The prefix is matched with SQL LIKE, so LIKE metacharacters in the
// prefix ('%', '_', and the escape character '\') are escaped first and an
// explicit ESCAPE clause is used. Without this, a prefix such as "job_1"
// would also match "jobX1", deleting unintended rows. The '\' escape
// works on both SQLite and PostgreSQL.
func (db *DB) DeleteJobsByPrefix(prefix string) error {
	// Escape LIKE wildcards so the prefix is matched literally.
	escaped := make([]byte, 0, len(prefix))
	for i := 0; i < len(prefix); i++ {
		switch prefix[i] {
		case '%', '_', '\\':
			escaped = append(escaped, '\\')
		}
		escaped = append(escaped, prefix[i])
	}

	var query string
	if db.dbType == DBTypeSQLite {
		query = `DELETE FROM jobs WHERE id LIKE ? ESCAPE '\'`
	} else {
		query = `DELETE FROM jobs WHERE id LIKE $1 ESCAPE '\'`
	}

	_, err := db.conn.ExecContext(context.Background(), query, string(escaped)+"%")
	if err != nil {
		return fmt.Errorf("failed to delete jobs by prefix: %w", err)
	}
	return nil
}
// RegisterWorker registers or updates a worker in the database.
func (db *DB) RegisterWorker(worker *Worker) error {
metadataJSON, _ := json.Marshal(worker.Metadata)

View file

@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/manifest"
)
@ -17,26 +18,40 @@ func scanArtifacts(runDir string, includeAll bool) (*manifest.Artifacts, error)
return nil, fmt.Errorf("run dir is empty")
}
// Validate and canonicalize the runDir before any operations
validator := fileutil.NewSecurePathValidator(runDir)
validatedRunDir, err := validator.ValidatePath("")
if err != nil {
return nil, fmt.Errorf("invalid run directory: %w", err)
}
var files []manifest.ArtifactFile
var total int64
now := time.Now().UTC()
err := filepath.WalkDir(runDir, func(path string, d fs.DirEntry, err error) error {
err = filepath.WalkDir(validatedRunDir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if path == runDir {
if path == validatedRunDir {
return nil
}
rel, err := filepath.Rel(runDir, path)
// Security: Validate each path is still within runDir
// This catches any symlink escapes or path traversal attempts during walk
rel, err := filepath.Rel(validatedRunDir, path)
if err != nil {
return err
return fmt.Errorf("path escape detected during artifact scan: %w", err)
}
rel = filepath.ToSlash(rel)
// Check for path traversal patterns in the relative path
if strings.Contains(rel, "..") {
return fmt.Errorf("path traversal attempt detected: %s", rel)
}
// Standard exclusions (always apply)
if rel == manifestFilename {
return nil
@ -63,6 +78,7 @@ func scanArtifacts(runDir string, includeAll bool) (*manifest.Artifacts, error)
return nil
}
if d.Type()&fs.ModeSymlink != 0 {
// Skip symlinks - they could point outside the directory
return nil
}
}

View file

@ -131,12 +131,54 @@ type AppleGPUConfig struct {
// SandboxConfig holds container sandbox settings.
//
// The zero value is NOT secure on its own; ApplySecurityDefaults fills in
// the hardened defaults noted per field below. This declaration removes
// the duplicated field lines (NetworkMode, ReadOnlyRoot, AllowSecrets,
// SeccompProfile appeared twice — stale pre-change copies alongside the
// new ones), which would not compile.
type SandboxConfig struct {
	NetworkMode     string   `yaml:"network_mode"`    // "none", "slirp4netns", "bridge"; Default: "none"
	ReadOnlyRoot    bool     `yaml:"read_only_root"`  // Default: true
	AllowSecrets    bool     `yaml:"allow_secrets"`   // Default: false
	AllowedSecrets  []string `yaml:"allowed_secrets"` // e.g., ["HF_TOKEN", "WANDB_API_KEY"]
	SeccompProfile  string   `yaml:"seccomp_profile"` // Default: "default-hardened"
	MaxRuntimeHours int      `yaml:"max_runtime_hours"`

	// Security hardening options (NEW)
	NoNewPrivileges bool     `yaml:"no_new_privileges"` // Default: true
	DropAllCaps     bool     `yaml:"drop_all_caps"`     // Default: true
	AllowedCaps     []string `yaml:"allowed_caps"`      // Capabilities to add back
	UserNS          bool     `yaml:"user_ns"`           // Default: true
	RunAsUID        int      `yaml:"run_as_uid"`        // Default: 1000
	RunAsGID        int      `yaml:"run_as_gid"`        // Default: 1000

	// Upload limits (NEW)
	MaxUploadSizeBytes  int64 `yaml:"max_upload_size_bytes"`  // Default: 10GB
	MaxUploadRateBps    int64 `yaml:"max_upload_rate_bps"`    // Default: 100MB/s
	MaxUploadsPerMinute int   `yaml:"max_uploads_per_minute"` // Default: 10
}
// SecurityDefaults holds default values for security configuration.
// These are applied to a SandboxConfig by ApplySecurityDefaults; the
// values implement the secure-by-default posture (no network, read-only
// root, all capabilities dropped, non-root UID/GID 1000).
var SecurityDefaults = struct {
	NetworkMode         string
	ReadOnlyRoot        bool
	AllowSecrets        bool
	SeccompProfile      string
	NoNewPrivileges     bool
	DropAllCaps         bool
	UserNS              bool
	RunAsUID            int
	RunAsGID            int
	MaxUploadSizeBytes  int64
	MaxUploadRateBps    int64
	MaxUploadsPerMinute int
}{
	NetworkMode:         "none",
	ReadOnlyRoot:        true,
	AllowSecrets:        false,
	SeccompProfile:      "default-hardened",
	NoNewPrivileges:     true,
	DropAllCaps:         true,
	UserNS:              true,
	RunAsUID:            1000,
	RunAsGID:            1000,
	MaxUploadSizeBytes:  10 * 1024 * 1024 * 1024, // 10GB
	MaxUploadRateBps:    100 * 1024 * 1024,       // 100MB/s
	MaxUploadsPerMinute: 10,
}
// Validate checks sandbox configuration
@ -148,9 +190,87 @@ func (s *SandboxConfig) Validate() error {
if s.MaxRuntimeHours < 0 {
return fmt.Errorf("max_runtime_hours must be positive")
}
if s.MaxUploadSizeBytes < 0 {
return fmt.Errorf("max_upload_size_bytes must be positive")
}
if s.MaxUploadRateBps < 0 {
return fmt.Errorf("max_upload_rate_bps must be positive")
}
if s.MaxUploadsPerMinute < 0 {
return fmt.Errorf("max_uploads_per_minute must be positive")
}
return nil
}
// ApplySecurityDefaults fills a SandboxConfig with the hardened values from
// SecurityDefaults, implementing the "secure by default" principle for HIPAA
// compliance. Call it after unmarshalling, before validation.
//
// NOTE(review): the hardening fields are plain bools, so an explicit `false`
// in YAML is indistinguishable from "unset". The previous `if !s.X { s.X =
// true }` form was therefore equivalent to an unconditional assignment; that
// is made explicit below. Allowing operators to disable these flags would
// require changing the fields to *bool — confirm whether that is desired.
func (s *SandboxConfig) ApplySecurityDefaults() {
	// Network isolation: default to "none" (no network access).
	if s.NetworkMode == "" {
		s.NetworkMode = SecurityDefaults.NetworkMode
	}
	// Seccomp profile.
	if s.SeccompProfile == "" {
		s.SeccompProfile = SecurityDefaults.SeccompProfile
	}

	// Hardening flags: always forced on (see NOTE above).
	s.ReadOnlyRoot = SecurityDefaults.ReadOnlyRoot
	s.NoNewPrivileges = SecurityDefaults.NoNewPrivileges
	s.DropAllCaps = SecurityDefaults.DropAllCaps
	s.UserNS = SecurityDefaults.UserNS
	// AllowSecrets defaults to false, which is already the zero value; an
	// explicit opt-in (true) from the config is preserved.

	// Run as a non-root user. UID/GID 0 (root) is treated as "unset" and
	// replaced, so the zero value can never yield a root container.
	if s.RunAsUID == 0 {
		s.RunAsUID = SecurityDefaults.RunAsUID
	}
	if s.RunAsGID == 0 {
		s.RunAsGID = SecurityDefaults.RunAsGID
	}

	// Upload limits: 0 means "use default".
	if s.MaxUploadSizeBytes == 0 {
		s.MaxUploadSizeBytes = SecurityDefaults.MaxUploadSizeBytes
	}
	if s.MaxUploadRateBps == 0 {
		s.MaxUploadRateBps = SecurityDefaults.MaxUploadRateBps
	}
	if s.MaxUploadsPerMinute == 0 {
		s.MaxUploadsPerMinute = SecurityDefaults.MaxUploadsPerMinute
	}
}
// Getter methods for SandboxConfig interface.
// The Get prefix (unidiomatic for Go) is required here: the method names
// would otherwise collide with the struct's field names.

// GetNoNewPrivileges returns the no_new_privileges setting.
func (s *SandboxConfig) GetNoNewPrivileges() bool {
	return s.NoNewPrivileges
}

// GetDropAllCaps returns the drop_all_caps setting.
func (s *SandboxConfig) GetDropAllCaps() bool {
	return s.DropAllCaps
}

// GetAllowedCaps returns the capabilities to add back after dropping all.
func (s *SandboxConfig) GetAllowedCaps() []string {
	return s.AllowedCaps
}

// GetUserNS returns the user namespace setting.
func (s *SandboxConfig) GetUserNS() bool {
	return s.UserNS
}

// GetRunAsUID returns the container user ID.
func (s *SandboxConfig) GetRunAsUID() int {
	return s.RunAsUID
}

// GetRunAsGID returns the container group ID.
func (s *SandboxConfig) GetRunAsGID() int {
	return s.RunAsGID
}

// GetSeccompProfile returns the configured seccomp profile.
func (s *SandboxConfig) GetSeccompProfile() string {
	return s.SeccompProfile
}

// GetReadOnlyRoot returns the read_only_root setting.
func (s *SandboxConfig) GetReadOnlyRoot() bool {
	return s.ReadOnlyRoot
}

// GetNetworkMode returns the configured network mode.
func (s *SandboxConfig) GetNetworkMode() string {
	return s.NetworkMode
}
// LoadConfig loads worker configuration from a YAML file.
func LoadConfig(path string) (*Config, error) {
data, err := fileutil.SecureFileRead(path)
@ -291,6 +411,14 @@ func LoadConfig(path string) (*Config, error) {
cfg.GracefulTimeout = 5 * time.Minute
}
// Apply security defaults to sandbox configuration
cfg.Sandbox.ApplySecurityDefaults()
// Expand secrets from environment variables
if err := cfg.expandSecrets(); err != nil {
return nil, fmt.Errorf("secrets expansion failed: %w", err)
}
return &cfg, nil
}
@ -442,6 +570,125 @@ func (c *Config) Validate() error {
return nil
}
// expandSecrets resolves ${ENV_VAR} references in sensitive config fields.
//
// Plaintext-secret validation runs BEFORE expansion. The check must see the
// raw config-file values: after expansion the fields legitimately contain
// high-entropy secret material, and validating then would reject every
// config whose env-referenced secret was successfully resolved.
func (c *Config) expandSecrets() error {
	// HIPAA requirement: no plaintext secrets on disk. Validate raw values.
	if err := c.validateNoPlaintextSecrets(); err != nil {
		return err
	}

	c.RedisPassword = expandEnvRef(c.RedisPassword)
	c.SnapshotStore.AccessKey = expandEnvRef(c.SnapshotStore.AccessKey)
	c.SnapshotStore.SecretKey = expandEnvRef(c.SnapshotStore.SecretKey)
	c.SnapshotStore.SessionToken = expandEnvRef(c.SnapshotStore.SessionToken)
	return nil
}

// expandEnvRef expands environment variables in v only when it contains a
// ${...} reference; other values pass through untouched.
//
// NOTE(review): os.ExpandEnv substitutes "" for unset variables — consider
// failing closed (os.LookupEnv) so a missing secret is caught at load time.
func expandEnvRef(v string) string {
	if strings.Contains(v, "${") {
		return os.ExpandEnv(v)
	}
	return v
}
// validateNoPlaintextSecrets checks that sensitive fields use env var references
// rather than hardcoded plaintext values. This is a HIPAA compliance requirement.
// The first offending field aborts the check with a descriptive error.
func (c *Config) validateNoPlaintextSecrets() error {
	// Fields that should carry ${ENV_VAR} references instead of plaintext.
	type sensitiveField struct {
		name  string
		value string
	}
	checks := []sensitiveField{
		{"redis_password", c.RedisPassword},
		{"snapshot_store.access_key", c.SnapshotStore.AccessKey},
		{"snapshot_store.secret_key", c.SnapshotStore.SecretKey},
		{"snapshot_store.session_token", c.SnapshotStore.SessionToken},
	}

	for _, check := range checks {
		if check.value == "" {
			continue // unset is fine
		}
		if strings.HasPrefix(check.value, "${") {
			continue // env var reference, not plaintext
		}
		if !looksLikeSecret(check.value) {
			continue // plain value, but doesn't resemble a credential
		}
		return fmt.Errorf(
			"%s appears to contain a plaintext secret (length=%d, entropy=%.2f); "+
				"use ${ENV_VAR} syntax to load from environment or secrets manager",
			check.name, len(check.value), calculateEntropy(check.value),
		)
	}
	return nil
}
// looksLikeSecret heuristically detects if a string looks like a secret
// credential: either it starts with a well-known token prefix, or it has
// high Shannon entropy. A false positive rejects a legitimate config value,
// so the checks are deliberately conservative.
func looksLikeSecret(s string) bool {
	// Minimum length for secrets.
	if len(s) < 16 {
		return false
	}

	// Well-known credential prefixes. These token formats put the marker at
	// the start of the value, so match with HasPrefix rather than Contains —
	// Contains would flag innocent values like "desk-lamp-inventory" (which
	// contains "sk-").
	prefixes := []string{
		"AKIA",     // AWS Access Key ID prefix
		"ASIA",     // AWS temporary credentials
		"ghp_",     // GitHub personal access token
		"gho_",     // GitHub OAuth token
		"glpat-",   // GitLab PAT
		"sk-",      // OpenAI/Stripe key prefix
		"sk_live_", // Stripe live key
		"sk_test_", // Stripe test key
	}
	for _, prefix := range prefixes {
		if strings.HasPrefix(s, prefix) {
			return true
		}
	}

	// High entropy (>4 bits per char) combined with reasonable length
	// suggests random key material.
	// NOTE(review): hex-encoded secrets max out at exactly 4 bits/char and
	// will never trip this threshold; the prefix list is the backstop.
	return calculateEntropy(s) > 4.0
}

// calculateEntropy calculates Shannon entropy of a string in bits per
// character (rune). Returns 0 for the empty string.
func calculateEntropy(s string) float64 {
	if s == "" {
		return 0
	}

	// Count rune frequencies and the total rune count. Dividing by the rune
	// count (not len(s), which is bytes) keeps the result correct for
	// multi-byte UTF-8 input.
	freq := make(map[rune]int)
	total := 0
	for _, r := range s {
		freq[r]++
		total++
	}

	var entropy float64
	n := float64(total)
	for _, count := range freq {
		p := float64(count) / n // p > 0 since every counted rune occurs at least once
		entropy -= p * math.Log2(p)
	}
	return entropy
}
// envInt reads an integer from environment variable
func envInt(name string) (int, bool) {
v := strings.TrimSpace(os.Getenv(name))

View file

@ -30,6 +30,20 @@ type ContainerConfig struct {
TrainScript string
BasePath string
AppleGPUEnabled bool
Sandbox SandboxConfig // NEW: Security configuration
}
// SandboxConfig is a read-only view of the sandbox security settings,
// declared as a local interface to avoid an import cycle with the package
// that owns the concrete configuration struct. The Get prefix (unidiomatic
// for Go) mirrors the concrete type, whose method names cannot collide with
// its field names.
type SandboxConfig interface {
	GetNoNewPrivileges() bool  // no_new_privileges hardening flag
	GetDropAllCaps() bool      // drop all capabilities
	GetAllowedCaps() []string  // capabilities to add back
	GetUserNS() bool           // user namespace isolation
	GetRunAsUID() int          // container user ID
	GetRunAsGID() int          // container group ID
	GetSeccompProfile() string // seccomp profile (default "default-hardened")
	GetReadOnlyRoot() bool     // read-only root filesystem
	GetNetworkMode() string    // "none", "slirp4netns", or "bridge"
}
// ContainerExecutor executes jobs in containers using podman
@ -208,6 +222,7 @@ func (e *ContainerExecutor) teardownTracking(ctx context.Context, task *queue.Ta
}
func (e *ContainerExecutor) setupVolumes(trackingEnv map[string]string, _outputDir string) map[string]string {
_ = _outputDir
volumes := make(map[string]string)
if val, ok := trackingEnv["TENSORBOARD_HOST_LOG_DIR"]; ok {
@ -305,8 +320,20 @@ func (e *ContainerExecutor) runPodman(
e.logger.Warn("failed to open log file for podman output", "path", env.LogFile, "error", err)
}
// Build command
podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, scriptPath, depsPath, extraArgs)
// Convert SandboxConfig to PodmanSecurityConfig
securityConfig := container.PodmanSecurityConfig{
NoNewPrivileges: e.config.Sandbox.GetNoNewPrivileges(),
DropAllCaps: e.config.Sandbox.GetDropAllCaps(),
AllowedCaps: e.config.Sandbox.GetAllowedCaps(),
UserNS: e.config.Sandbox.GetUserNS(),
RunAsUID: e.config.Sandbox.GetRunAsUID(),
RunAsGID: e.config.Sandbox.GetRunAsGID(),
SeccompProfile: e.config.Sandbox.GetSeccompProfile(),
ReadOnlyRoot: e.config.Sandbox.GetReadOnlyRoot(),
NetworkMode: e.config.Sandbox.GetNetworkMode(),
}
podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, securityConfig, scriptPath, depsPath, extraArgs)
// Update manifest
if e.writer != nil {