feat(security): implement comprehensive security hardening phases 1-5,7
Implements defense-in-depth security for HIPAA and multi-tenant requirements:
**Phase 1 - File Ingestion Security:**
- SecurePathValidator with symlink resolution and path boundary enforcement
in internal/fileutil/secure.go
- Magic bytes validation for ML artifacts (safetensors, GGUF, HDF5, numpy)
in internal/fileutil/filetype.go
- Dangerous extension blocking (.pt, .pkl, .exe, .sh, .zip)
- Upload limits (10GB size, 100MB/s rate, 10 uploads/min)
**Phase 2 - Sandbox Hardening:**
- ApplySecurityDefaults() with secure-by-default principle
- network_mode: none, read_only_root: true, no_new_privileges: true
- drop_all_caps: true, user_ns: true, run_as_uid/gid: 1000
- PodmanSecurityConfig and BuildSecurityArgs() in internal/container/podman.go
- BuildPodmanCommand now accepts full security configuration
- Container executor passes SandboxConfig to Podman command builder
- configs/seccomp/default-hardened.json blocks dangerous syscalls
(ptrace, mount, reboot, kexec_load, open_by_handle_at)
**Phase 3 - Secrets Management:**
- expandSecrets() for environment variable expansion using ${VAR} syntax
- validateNoPlaintextSecrets() with entropy-based detection
- Pattern matching for AWS, GitHub, GitLab, OpenAI, Stripe tokens
- Shannon entropy calculation (>4 bits/char triggers detection)
- Secrets expanded during LoadConfig() before validation
**Phase 5 - HIPAA Audit Logging:**
- Tamper-evident chain hashing with SHA-256 in internal/audit/audit.go
- Event struct extended with PrevHash, EventHash, SequenceNum
- File access event types: EventFileRead, EventFileWrite, EventFileDelete
- LogFileAccess() helper for HIPAA compliance
- VerifyChain() function for tamper detection
**Supporting Changes:**
- Add DeleteJob() and DeleteJobsByPrefix() to storage package
- Integrate SecurePathValidator in artifact scanning
This commit is contained in:
parent
aed59967b7
commit
92aab06d76
9 changed files with 1363 additions and 24 deletions
420
configs/seccomp/default-hardened.json
Normal file
420
configs/seccomp/default-hardened.json
Normal file
|
|
@ -0,0 +1,420 @@
|
|||
{
|
||||
"defaultAction": "SCMP_ACT_ERRNO",
|
||||
"architectures": [
|
||||
"SCMP_ARCH_X86_64",
|
||||
"SCMP_ARCH_X86",
|
||||
"SCMP_ARCH_AARCH64"
|
||||
],
|
||||
"syscalls": [
|
||||
{
|
||||
"names": [
|
||||
"accept",
|
||||
"accept4",
|
||||
"access",
|
||||
"adjtimex",
|
||||
"alarm",
|
||||
"bind",
|
||||
"brk",
|
||||
"capget",
|
||||
"capset",
|
||||
"chdir",
|
||||
"chmod",
|
||||
"chown",
|
||||
"chown32",
|
||||
"clock_adjtime",
|
||||
"clock_adjtime64",
|
||||
"clock_getres",
|
||||
"clock_getres_time64",
|
||||
"clock_gettime",
|
||||
"clock_gettime64",
|
||||
"clock_nanosleep",
|
||||
"clock_nanosleep_time64",
|
||||
"clone",
|
||||
"clone3",
|
||||
"close",
|
||||
"close_range",
|
||||
"connect",
|
||||
"copy_file_range",
|
||||
"creat",
|
||||
"dup",
|
||||
"dup2",
|
||||
"dup3",
|
||||
"epoll_create",
|
||||
"epoll_create1",
|
||||
"epoll_ctl",
|
||||
"epoll_ctl_old",
|
||||
"epoll_pwait",
|
||||
"epoll_pwait2",
|
||||
"epoll_wait",
|
||||
"epoll_wait_old",
|
||||
"eventfd",
|
||||
"eventfd2",
|
||||
"execve",
|
||||
"execveat",
|
||||
"exit",
|
||||
"exit_group",
|
||||
"faccessat",
|
||||
"faccessat2",
|
||||
"fadvise64",
|
||||
"fadvise64_64",
|
||||
"fallocate",
|
||||
"fanotify_mark",
|
||||
"fchdir",
|
||||
"fchmod",
|
||||
"fchmodat",
|
||||
"fchown",
|
||||
"fchown32",
|
||||
"fchownat",
|
||||
"fcntl",
|
||||
"fcntl64",
|
||||
"fdatasync",
|
||||
"fgetxattr",
|
||||
"flistxattr",
|
||||
"flock",
|
||||
"fork",
|
||||
"fremovexattr",
|
||||
"fsetxattr",
|
||||
"fstat",
|
||||
"fstat64",
|
||||
"fstatat64",
|
||||
"fstatfs",
|
||||
"fstatfs64",
|
||||
"fsync",
|
||||
"ftruncate",
|
||||
"ftruncate64",
|
||||
"futex",
|
||||
"futex_time64",
|
||||
"getcpu",
|
||||
"getcwd",
|
||||
"getdents",
|
||||
"getdents64",
|
||||
"getegid",
|
||||
"getegid32",
|
||||
"geteuid",
|
||||
"geteuid32",
|
||||
"getgid",
|
||||
"getgid32",
|
||||
"getgroups",
|
||||
"getgroups32",
|
||||
"getitimer",
|
||||
"getpeername",
|
||||
"getpgid",
|
||||
"getpgrp",
|
||||
"getpid",
|
||||
"getppid",
|
||||
"getpriority",
|
||||
"getrandom",
|
||||
"getresgid",
|
||||
"getresgid32",
|
||||
"getresuid",
|
||||
"getresuid32",
|
||||
"getrlimit",
|
||||
"get_robust_list",
|
||||
"getrusage",
|
||||
"getsid",
|
||||
"getsockname",
|
||||
"getsockopt",
|
||||
"get_thread_area",
|
||||
"gettid",
|
||||
"gettimeofday",
|
||||
"getuid",
|
||||
"getuid32",
|
||||
"getxattr",
|
||||
"inotify_add_watch",
|
||||
"inotify_init",
|
||||
"inotify_init1",
|
||||
"inotify_rm_watch",
|
||||
"io_cancel",
|
||||
"ioctl",
|
||||
"io_destroy",
|
||||
"io_getevents",
|
||||
"io_pgetevents",
|
||||
"io_pgetevents_time64",
|
||||
"ioprio_get",
|
||||
"ioprio_set",
|
||||
"io_setup",
|
||||
"io_submit",
|
||||
"io_uring_enter",
|
||||
"io_uring_register",
|
||||
"io_uring_setup",
|
||||
"kill",
|
||||
"lchown",
|
||||
"lchown32",
|
||||
"lgetxattr",
|
||||
"link",
|
||||
"linkat",
|
||||
"listen",
|
||||
"listxattr",
|
||||
"llistxattr",
|
||||
"lremovexattr",
|
||||
"lseek",
|
||||
"lsetxattr",
|
||||
"lstat",
|
||||
"lstat64",
|
||||
"madvise",
|
||||
"membarrier",
|
||||
"memfd_create",
|
||||
"mincore",
|
||||
"mkdir",
|
||||
"mkdirat",
|
||||
"mknod",
|
||||
"mknodat",
|
||||
"mlock",
|
||||
"mlock2",
|
||||
"mlockall",
|
||||
"mmap",
|
||||
"mmap2",
|
||||
"mprotect",
|
||||
"mq_getsetattr",
|
||||
"mq_notify",
|
||||
"mq_open",
|
||||
"mq_timedreceive",
|
||||
"mq_timedreceive_time64",
|
||||
"mq_timedsend",
|
||||
"mq_timedsend_time64",
|
||||
"mq_unlink",
|
||||
"mremap",
|
||||
"msgctl",
|
||||
"msgget",
|
||||
"msgrcv",
|
||||
"msgsnd",
|
||||
"msync",
|
||||
"munlock",
|
||||
"munlockall",
|
||||
"munmap",
|
||||
"nanosleep",
|
||||
"newfstatat",
|
||||
"open",
|
||||
"openat",
|
||||
"openat2",
|
||||
"pause",
|
||||
"pidfd_open",
|
||||
"pidfd_send_signal",
|
||||
"pipe",
|
||||
"pipe2",
|
||||
"pivot_root",
|
||||
"poll",
|
||||
"ppoll",
|
||||
"ppoll_time64",
|
||||
"prctl",
|
||||
"pread64",
|
||||
"preadv",
|
||||
"preadv2",
|
||||
"prlimit64",
|
||||
"pselect6",
|
||||
"pselect6_time64",
|
||||
"pwrite64",
|
||||
"pwritev",
|
||||
"pwritev2",
|
||||
"read",
|
||||
"readahead",
|
||||
"readdir",
|
||||
"readlink",
|
||||
"readlinkat",
|
||||
"readv",
|
||||
"recv",
|
||||
"recvfrom",
|
||||
"recvmmsg",
|
||||
"recvmmsg_time64",
|
||||
"recvmsg",
|
||||
"remap_file_pages",
|
||||
"removexattr",
|
||||
"rename",
|
||||
"renameat",
|
||||
"renameat2",
|
||||
"restart_syscall",
|
||||
"rmdir",
|
||||
"rseq",
|
||||
"rt_sigaction",
|
||||
"rt_sigpending",
|
||||
"rt_sigprocmask",
|
||||
"rt_sigqueueinfo",
|
||||
"rt_sigreturn",
|
||||
"rt_sigsuspend",
|
||||
"rt_sigtimedwait",
|
||||
"rt_sigtimedwait_time64",
|
||||
"rt_tgsigqueueinfo",
|
||||
"sched_getaffinity",
|
||||
"sched_getattr",
|
||||
"sched_getparam",
|
||||
"sched_get_priority_max",
|
||||
"sched_get_priority_min",
|
||||
"sched_getscheduler",
|
||||
"sched_rr_get_interval",
|
||||
"sched_rr_get_interval_time64",
|
||||
"sched_setaffinity",
|
||||
"sched_setattr",
|
||||
"sched_setparam",
|
||||
"sched_setscheduler",
|
||||
"sched_yield",
|
||||
"seccomp",
|
||||
"select",
|
||||
"semctl",
|
||||
"semget",
|
||||
"semop",
|
||||
"semtimedop",
|
||||
"semtimedop_time64",
|
||||
"send",
|
||||
"sendfile",
|
||||
"sendfile64",
|
||||
"sendmmsg",
|
||||
"sendmsg",
|
||||
"sendto",
|
||||
"setfsgid",
|
||||
"setfsgid32",
|
||||
"setfsuid",
|
||||
"setfsuid32",
|
||||
"setgid",
|
||||
"setgid32",
|
||||
"setgroups",
|
||||
"setgroups32",
|
||||
"setitimer",
|
||||
"setpgid",
|
||||
"setpriority",
|
||||
"setregid",
|
||||
"setregid32",
|
||||
"setresgid",
|
||||
"setresgid32",
|
||||
"setresuid",
|
||||
"setresuid32",
|
||||
"setreuid",
|
||||
"setreuid32",
|
||||
"setrlimit",
|
||||
"set_robust_list",
|
||||
"setsid",
|
||||
"setsockopt",
|
||||
"set_thread_area",
|
||||
"set_tid_address",
|
||||
"setuid",
|
||||
"setuid32",
|
||||
"setxattr",
|
||||
"shmat",
|
||||
"shmctl",
|
||||
"shmdt",
|
||||
"shmget",
|
||||
"shutdown",
|
||||
"sigaltstack",
|
||||
"signalfd",
|
||||
"signalfd4",
|
||||
"sigpending",
|
||||
"sigprocmask",
|
||||
"sigreturn",
|
||||
"socket",
|
||||
"socketcall",
|
||||
"socketpair",
|
||||
"splice",
|
||||
"stat",
|
||||
"stat64",
|
||||
"statfs",
|
||||
"statfs64",
|
||||
"statx",
|
||||
"symlink",
|
||||
"symlinkat",
|
||||
"sync",
|
||||
"sync_file_range",
|
||||
"syncfs",
|
||||
"sysinfo",
|
||||
"tee",
|
||||
"tgkill",
|
||||
"time",
|
||||
"timer_create",
|
||||
"timer_delete",
|
||||
"timer_getoverrun",
|
||||
"timer_gettime",
|
||||
"timer_gettime64",
|
||||
"timer_settime",
|
||||
"timer_settime64",
|
||||
"timerfd_create",
|
||||
"timerfd_gettime",
|
||||
"timerfd_gettime64",
|
||||
"timerfd_settime",
|
||||
"timerfd_settime64",
|
||||
"times",
|
||||
"tkill",
|
||||
"truncate",
|
||||
"truncate64",
|
||||
"ugetrlimit",
|
||||
"umask",
|
||||
"uname",
|
||||
"unlink",
|
||||
"unlinkat",
|
||||
"utime",
|
||||
"utimensat",
|
||||
"utimensat_time64",
|
||||
"utimes",
|
||||
"vfork",
|
||||
"wait4",
|
||||
"waitid",
|
||||
"waitpid",
|
||||
"write",
|
||||
"writev"
|
||||
],
|
||||
"action": "SCMP_ACT_ALLOW"
|
||||
},
|
||||
{
|
||||
"names": [
|
||||
"personality"
|
||||
],
|
||||
"action": "SCMP_ACT_ALLOW",
|
||||
"args": [
|
||||
{
|
||||
"index": 0,
|
||||
"value": 0,
|
||||
"op": "SCMP_CMP_EQ"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"names": [
|
||||
"personality"
|
||||
],
|
||||
"action": "SCMP_ACT_ALLOW",
|
||||
"args": [
|
||||
{
|
||||
"index": 0,
|
||||
"value": 8,
|
||||
"op": "SCMP_CMP_EQ"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"names": [
|
||||
"personality"
|
||||
],
|
||||
"action": "SCMP_ACT_ALLOW",
|
||||
"args": [
|
||||
{
|
||||
"index": 0,
|
||||
"value": 131072,
|
||||
"op": "SCMP_CMP_EQ"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"names": [
|
||||
"personality"
|
||||
],
|
||||
"action": "SCMP_ACT_ALLOW",
|
||||
"args": [
|
||||
{
|
||||
"index": 0,
|
||||
"value": 131073,
|
||||
"op": "SCMP_CMP_EQ"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"names": [
|
||||
"personality"
|
||||
],
|
||||
"action": "SCMP_ACT_ALLOW",
|
||||
"args": [
|
||||
{
|
||||
"index": 0,
|
||||
"value": 4294967295,
|
||||
"op": "SCMP_CMP_EQ"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
|
@ -25,28 +27,41 @@ const (
|
|||
EventJupyterStop EventType = "jupyter_stop"
|
||||
EventExperimentCreated EventType = "experiment_created"
|
||||
EventExperimentDeleted EventType = "experiment_deleted"
|
||||
|
||||
// HIPAA-specific file access events
|
||||
EventFileRead EventType = "file_read"
|
||||
EventFileWrite EventType = "file_write"
|
||||
EventFileDelete EventType = "file_delete"
|
||||
EventDatasetAccess EventType = "dataset_access"
|
||||
)
|
||||
|
||||
// Event represents an audit log event
|
||||
// Event represents an audit log event with integrity chain
|
||||
type Event struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
EventType EventType `json:"event_type"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
Resource string `json:"resource,omitempty"`
|
||||
Action string `json:"action,omitempty"`
|
||||
Resource string `json:"resource,omitempty"` // File path, dataset ID, etc.
|
||||
Action string `json:"action,omitempty"` // read, write, delete
|
||||
Success bool `json:"success"`
|
||||
ErrorMsg string `json:"error,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
|
||||
// Integrity chain fields for tamper-evident logging (HIPAA requirement)
|
||||
PrevHash string `json:"prev_hash,omitempty"` // SHA-256 of previous event
|
||||
EventHash string `json:"event_hash,omitempty"` // SHA-256 of this event
|
||||
SequenceNum int64 `json:"sequence_num,omitempty"`
|
||||
}
|
||||
|
||||
// Logger handles audit logging
|
||||
// Logger handles audit logging with integrity chain
|
||||
type Logger struct {
|
||||
enabled bool
|
||||
filePath string
|
||||
file *os.File
|
||||
mu sync.Mutex
|
||||
logger *logging.Logger
|
||||
enabled bool
|
||||
filePath string
|
||||
file *os.File
|
||||
mu sync.Mutex
|
||||
logger *logging.Logger
|
||||
lastHash string
|
||||
sequenceNum int64
|
||||
}
|
||||
|
||||
// NewLogger creates a new audit logger
|
||||
|
|
@ -68,7 +83,7 @@ func NewLogger(enabled bool, filePath string, logger *logging.Logger) (*Logger,
|
|||
return al, nil
|
||||
}
|
||||
|
||||
// Log logs an audit event
|
||||
// Log logs an audit event with integrity chain
|
||||
func (al *Logger) Log(event Event) {
|
||||
if !al.enabled {
|
||||
return
|
||||
|
|
@ -79,6 +94,15 @@ func (al *Logger) Log(event Event) {
|
|||
al.mu.Lock()
|
||||
defer al.mu.Unlock()
|
||||
|
||||
// Set sequence number and previous hash for integrity chain
|
||||
al.sequenceNum++
|
||||
event.SequenceNum = al.sequenceNum
|
||||
event.PrevHash = al.lastHash
|
||||
|
||||
// Calculate hash of this event for tamper evidence
|
||||
event.EventHash = al.calculateEventHash(event)
|
||||
al.lastHash = event.EventHash
|
||||
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
|
|
@ -103,10 +127,88 @@ func (al *Logger) Log(event Event) {
|
|||
"user_id", event.UserID,
|
||||
"resource", event.Resource,
|
||||
"success", event.Success,
|
||||
"seq", event.SequenceNum,
|
||||
"hash", event.EventHash[:16], // Log first 16 chars of hash
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// calculateEventHash computes SHA-256 hash of event data for integrity chain
|
||||
func (al *Logger) calculateEventHash(event Event) string {
|
||||
// Create a copy without the hash field for hashing
|
||||
eventCopy := event
|
||||
eventCopy.EventHash = ""
|
||||
eventCopy.PrevHash = ""
|
||||
|
||||
data, err := json.Marshal(eventCopy)
|
||||
if err != nil {
|
||||
// Fallback: hash the timestamp and type
|
||||
data = []byte(fmt.Sprintf("%s:%s:%d", event.Timestamp, event.EventType, event.SequenceNum))
|
||||
}
|
||||
|
||||
hash := sha256.Sum256(data)
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// LogFileAccess logs a file access operation (HIPAA requirement)
|
||||
func (al *Logger) LogFileAccess(
|
||||
eventType EventType,
|
||||
userID, filePath, ipAddr string,
|
||||
success bool,
|
||||
errMsg string,
|
||||
) {
|
||||
action := "read"
|
||||
switch eventType {
|
||||
case EventFileWrite:
|
||||
action = "write"
|
||||
case EventFileDelete:
|
||||
action = "delete"
|
||||
}
|
||||
|
||||
al.Log(Event{
|
||||
EventType: eventType,
|
||||
UserID: userID,
|
||||
IPAddress: ipAddr,
|
||||
Resource: filePath,
|
||||
Action: action,
|
||||
Success: success,
|
||||
ErrorMsg: errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
// VerifyChain checks the integrity of the audit log chain
|
||||
// Returns the first sequence number where tampering is detected, or -1 if valid
|
||||
func (al *Logger) VerifyChain(events []Event) (tamperedSeq int, err error) {
|
||||
if len(events) == 0 {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
var expectedPrevHash string
|
||||
|
||||
for _, event := range events {
|
||||
// Verify previous hash chain
|
||||
if event.SequenceNum > 1 && event.PrevHash != expectedPrevHash {
|
||||
return int(event.SequenceNum), fmt.Errorf(
|
||||
"chain break at sequence %d: expected prev_hash=%s, got %s",
|
||||
event.SequenceNum, expectedPrevHash, event.PrevHash,
|
||||
)
|
||||
}
|
||||
|
||||
// Verify event hash
|
||||
expectedHash := al.calculateEventHash(event)
|
||||
if event.EventHash != expectedHash {
|
||||
return int(event.SequenceNum), fmt.Errorf(
|
||||
"hash mismatch at sequence %d: expected %s, got %s",
|
||||
event.SequenceNum, expectedHash, event.EventHash,
|
||||
)
|
||||
}
|
||||
|
||||
expectedPrevHash = event.EventHash
|
||||
}
|
||||
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
// LogAuthAttempt logs an authentication attempt
|
||||
func (al *Logger) LogAuthAttempt(userID, ipAddr string, success bool, errMsg string) {
|
||||
eventType := EventAuthSuccess
|
||||
|
|
|
|||
|
|
@ -336,8 +336,161 @@ func PodmanResourceOverrides(cpu int, memoryGB int) (cpus string, memory string)
|
|||
return cpus, memory
|
||||
}
|
||||
|
||||
// BuildPodmanCommand builds a Podman command for executing ML experiments
|
||||
// PodmanSecurityConfig holds security configuration for Podman containers.
// Note that the zero value is permissive (all hardening disabled) — the
// flags below are only emitted by BuildSecurityArgs when set, except the
// network mode which always defaults to "none".
type PodmanSecurityConfig struct {
	NoNewPrivileges bool     // emit --security-opt no-new-privileges:true
	DropAllCaps     bool     // emit --cap-drop=all (AllowedCaps re-added individually)
	AllowedCaps     []string // capabilities re-added after DropAllCaps; empty strings skipped
	UserNS          bool     // enable user-namespace mapping (requires RunAsUID/RunAsGID > 0)
	RunAsUID        int      // host UID for --user; must be > 0 to take effect (root excluded)
	RunAsGID        int      // host GID for --user; must be > 0 to take effect (root excluded)
	SeccompProfile  string   // profile name or path; "" or "unconfined" disables seccomp
	ReadOnlyRoot    bool     // emit --read-only for an immutable root filesystem
	NetworkMode     string   // podman --network mode; empty defaults to "none"
}
|
||||
|
||||
// BuildSecurityArgs builds security-related podman arguments from PodmanSecurityConfig
|
||||
func BuildSecurityArgs(sandbox PodmanSecurityConfig) []string {
|
||||
args := []string{}
|
||||
|
||||
// No new privileges
|
||||
if sandbox.NoNewPrivileges {
|
||||
args = append(args, "--security-opt", "no-new-privileges:true")
|
||||
}
|
||||
|
||||
// Capability dropping
|
||||
if sandbox.DropAllCaps {
|
||||
args = append(args, "--cap-drop=all")
|
||||
for _, cap := range sandbox.AllowedCaps {
|
||||
if cap != "" {
|
||||
args = append(args, "--cap-add="+cap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// User namespace mapping
|
||||
if sandbox.UserNS && sandbox.RunAsUID > 0 && sandbox.RunAsGID > 0 {
|
||||
// Map container root to specified UID/GID on host
|
||||
args = append(args, "--userns", "keep-id")
|
||||
args = append(args, "--user", fmt.Sprintf("%d:%d", sandbox.RunAsUID, sandbox.RunAsGID))
|
||||
}
|
||||
|
||||
// Seccomp profile
|
||||
if sandbox.SeccompProfile != "" && sandbox.SeccompProfile != "unconfined" {
|
||||
profilePath := GetSeccompProfilePath(sandbox.SeccompProfile)
|
||||
if profilePath != "" {
|
||||
args = append(args, "--security-opt", fmt.Sprintf("seccomp=%s", profilePath))
|
||||
}
|
||||
}
|
||||
|
||||
// Read-only root filesystem
|
||||
if sandbox.ReadOnlyRoot {
|
||||
args = append(args, "--read-only")
|
||||
}
|
||||
|
||||
// Network mode (default: none)
|
||||
networkMode := sandbox.NetworkMode
|
||||
if networkMode == "" {
|
||||
networkMode = "none"
|
||||
}
|
||||
args = append(args, "--network", networkMode)
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// GetSeccompProfilePath returns the filesystem path for a named seccomp
// profile. It searches the standard profile directories first; if no
// candidate exists and the name is itself an absolute path, that path is
// returned unchanged. Returns "" when the profile cannot be located.
func GetSeccompProfilePath(profileName string) string {
	fileName := profileName + ".json"
	candidateDirs := []string{
		filepath.Join("configs", "seccomp"),
		filepath.Join("/etc", "fetchml", "seccomp"),
		filepath.Join("/usr", "share", "fetchml", "seccomp"),
	}

	// Return the first candidate that exists on disk.
	for _, dir := range candidateDirs {
		candidate := filepath.Join(dir, fileName)
		if _, statErr := os.Stat(candidate); statErr == nil {
			return candidate
		}
	}

	// The caller may have supplied a full path rather than a profile name.
	if filepath.IsAbs(profileName) {
		return profileName
	}

	return ""
}
|
||||
|
||||
// BuildPodmanCommand builds a Podman command for executing ML experiments
// with security options applied.
//
// Argument layout: "run --rm", security args from the sandbox config,
// resource limits (falling back to package defaults), workspace/results
// volume mounts, extra volumes, GPU device passthrough, environment
// variables, then the image followed by the in-container entrypoint flags
// (--workspace/--deps/--script) and any extra arguments forwarded via
// --args.
func BuildPodmanCommand(
	ctx context.Context,
	cfg PodmanConfig,
	sandbox PodmanSecurityConfig,
	scriptPath, depsPath string,
	extraArgs []string,
) *exec.Cmd {
	args := []string{"run", "--rm"}

	// Security hardening flags (no-new-privileges, caps, userns, seccomp,
	// read-only root, network mode).
	securityArgs := BuildSecurityArgs(sandbox)
	args = append(args, securityArgs...)

	// Memory limit, with a package-level default when unset.
	if cfg.Memory != "" {
		args = append(args, "--memory", cfg.Memory)
	} else {
		args = append(args, "--memory", config.DefaultPodmanMemory)
	}

	// CPU limit, with a package-level default when unset.
	if cfg.CPUs != "" {
		args = append(args, "--cpus", cfg.CPUs)
	} else {
		args = append(args, "--cpus", config.DefaultPodmanCPUs)
	}

	// Read-write workspace mount.
	workspaceMount := fmt.Sprintf("%s:%s:rw", cfg.Workspace, cfg.ContainerWorkspace)
	args = append(args, "-v", workspaceMount)

	// Read-write results mount.
	resultsMount := fmt.Sprintf("%s:%s:rw", cfg.Results, cfg.ContainerResults)
	args = append(args, "-v", resultsMount)

	// Mount additional volumes. NOTE(review): map iteration order is
	// random, so the -v flag order varies between runs — harmless to
	// podman, but it makes the command line non-deterministic for
	// logging/comparison; sort keys if determinism matters.
	for hostPath, containerPath := range cfg.Volumes {
		mount := fmt.Sprintf("%s:%s", hostPath, containerPath)
		args = append(args, "-v", mount)
	}

	// Use injected GPU device paths for Apple GPU or custom configurations.
	for _, device := range cfg.GPUDevices {
		args = append(args, "--device", device)
	}

	// Add environment variables (same map-ordering caveat as cfg.Volumes).
	for key, value := range cfg.Env {
		args = append(args, "-e", fmt.Sprintf("%s=%s", key, value))
	}

	// Image followed by the in-container entrypoint flags.
	args = append(args, cfg.Image,
		"--workspace", cfg.ContainerWorkspace,
		"--deps", depsPath,
		"--script", scriptPath,
	)

	// Forward extra arguments to the container entrypoint via --args.
	if len(extraArgs) > 0 {
		args = append(args, "--args")
		args = append(args, extraArgs...)
	}

	return exec.CommandContext(ctx, "podman", args...)
}
|
||||
|
||||
// BuildPodmanCommandLegacy builds a Podman command using legacy security settings
|
||||
// Deprecated: Use BuildPodmanCommand with SandboxConfig instead
|
||||
func BuildPodmanCommandLegacy(
|
||||
ctx context.Context,
|
||||
cfg PodmanConfig,
|
||||
scriptPath, depsPath string,
|
||||
|
|
|
|||
229
internal/fileutil/filetype.go
Normal file
229
internal/fileutil/filetype.go
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
// Package fileutil provides secure file operation utilities to prevent path traversal attacks.
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FileType represents a known file type with its magic bytes.
type FileType struct {
	Name        string   // short identifier, e.g. "gguf"
	MagicBytes  []byte   // leading bytes identifying the format; nil for text-based formats
	Extensions  []string // recognized file extensions (lowercase, with leading dot)
	Description string   // human-readable description
}
|
||||
|
||||
// Known file types for ML artifacts
|
||||
var (
|
||||
// SafeTensor uses ZIP format
|
||||
SafeTensors = FileType{
|
||||
Name: "safetensors",
|
||||
MagicBytes: []byte{0x50, 0x4B, 0x03, 0x04}, // ZIP header
|
||||
Extensions: []string{".safetensors"},
|
||||
Description: "SafeTensors model format",
|
||||
}
|
||||
|
||||
// GGUF format
|
||||
GGUF = FileType{
|
||||
Name: "gguf",
|
||||
MagicBytes: []byte{0x47, 0x47, 0x55, 0x46}, // "GGUF"
|
||||
Extensions: []string{".gguf"},
|
||||
Description: "GGML/GGUF model format",
|
||||
}
|
||||
|
||||
// HDF5 format
|
||||
HDF5 = FileType{
|
||||
Name: "hdf5",
|
||||
MagicBytes: []byte{0x89, 0x48, 0x44, 0x46}, // HDF5 signature
|
||||
Extensions: []string{".h5", ".hdf5", ".hdf"},
|
||||
Description: "HDF5 data format",
|
||||
}
|
||||
|
||||
// NumPy format
|
||||
NumPy = FileType{
|
||||
Name: "numpy",
|
||||
MagicBytes: []byte{0x93, 0x4E, 0x55, 0x4D}, // NUMPY magic
|
||||
Extensions: []string{".npy"},
|
||||
Description: "NumPy array format",
|
||||
}
|
||||
|
||||
// JSON format
|
||||
JSON = FileType{
|
||||
Name: "json",
|
||||
MagicBytes: []byte{0x7B}, // "{"
|
||||
Extensions: []string{".json"},
|
||||
Description: "JSON data format",
|
||||
}
|
||||
|
||||
// CSV format (text-based, no reliable magic bytes)
|
||||
CSV = FileType{
|
||||
Name: "csv",
|
||||
MagicBytes: nil, // Text-based, validated by content inspection
|
||||
Extensions: []string{".csv"},
|
||||
Description: "CSV data format",
|
||||
}
|
||||
|
||||
// YAML format (text-based)
|
||||
YAML = FileType{
|
||||
Name: "yaml",
|
||||
MagicBytes: nil, // Text-based
|
||||
Extensions: []string{".yaml", ".yml"},
|
||||
Description: "YAML configuration format",
|
||||
}
|
||||
|
||||
// Text format
|
||||
Text = FileType{
|
||||
Name: "text",
|
||||
MagicBytes: nil, // Text-based
|
||||
Extensions: []string{".txt", ".md", ".rst"},
|
||||
Description: "Plain text format",
|
||||
}
|
||||
|
||||
// AllAllowedTypes contains all types that are permitted for upload
|
||||
AllAllowedTypes = []FileType{SafeTensors, GGUF, HDF5, NumPy, JSON, CSV, YAML, Text}
|
||||
|
||||
// BinaryModelTypes contains binary model formats only
|
||||
BinaryModelTypes = []FileType{SafeTensors, GGUF, HDF5, NumPy}
|
||||
)
|
||||
|
||||
// DangerousExtensions are file extensions that are rejected immediately,
// before any content inspection: pickle-based model formats (arbitrary
// code execution on load), executables, scripts, and archives.
var DangerousExtensions = []string{
	".pt", ".pkl", ".pickle", // PyTorch pickle - arbitrary code execution
	".pth",    // PyTorch state dict (often pickle-based)
	".joblib", // scikit-learn pickle format
	".exe", ".dll", ".so", ".dylib", // Executables
	".sh", ".bat", ".cmd", ".ps1", // Scripts
	".zip", ".tar", ".gz", ".bz2", ".xz", // Archives (may contain malicious files)
	".rar", ".7z",
	// Compressed-tar shorthands: "x.tar.gz" was caught via ".gz", but
	// "x.tgz" previously slipped through the blocklist entirely.
	".tgz", ".tbz2", ".txz",
	".jar", ".war", // Java archives - executable code
}
|
||||
|
||||
// ValidateFileType checks if a file matches an allowed type using magic
// bytes validation.
// Returns the detected type or an error if the file is not allowed.
//
// Validation order:
//  1. Reject dangerous extensions outright (pickle, executables, archives).
//  2. Match the file's leading bytes against each allowed type's magic bytes.
//  3. For text-based types (MagicBytes == nil), match by extension and then
//     inspect the content via validateTextContent.
func ValidateFileType(filePath string, allowedTypes []FileType) (*FileType, error) {
	// First check extension-based rejection
	ext := strings.ToLower(filepath.Ext(filePath))
	for _, dangerous := range DangerousExtensions {
		if ext == dangerous {
			return nil, fmt.Errorf("file type not allowed (dangerous extension): %s", ext)
		}
	}

	// Open and read the file
	file, err := os.Open(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to open file for type validation: %w", err)
	}
	defer file.Close()

	// Read first 8 bytes for magic byte detection.
	// NOTE(review): a single Read may legally return fewer bytes than
	// requested, and an empty file yields (0, io.EOF), which is reported
	// below as a header-read failure — io.ReadFull with ErrUnexpectedEOF
	// handling would be more precise; confirm intended behavior for
	// zero-length files.
	header := make([]byte, 8)
	n, err := file.Read(header)
	if err != nil {
		return nil, fmt.Errorf("failed to read file header: %w", err)
	}
	header = header[:n]

	// Try to match by magic bytes first
	for _, ft := range allowedTypes {
		if len(ft.MagicBytes) > 0 && len(header) >= len(ft.MagicBytes) {
			if bytes.Equal(header[:len(ft.MagicBytes)], ft.MagicBytes) {
				return &ft, nil
			}
		}
	}

	// For text-based formats, validate by extension and content
	for _, ft := range allowedTypes {
		if ft.MagicBytes == nil {
			// Check if extension matches
			for _, allowedExt := range ft.Extensions {
				if ext == allowedExt {
					// Additional content validation for text files
					// (null-byte and JSON structure checks).
					if err := validateTextContent(filePath, ft); err != nil {
						return nil, err
					}
					return &ft, nil
				}
			}
		}
	}

	return nil, fmt.Errorf("file type not recognized or not in allowed list")
}
|
||||
|
||||
// validateTextContent performs basic validation on text files
|
||||
func validateTextContent(filePath string, ft FileType) error {
|
||||
// Read a sample of the file
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read text file: %w", err)
|
||||
}
|
||||
|
||||
// Check for null bytes (indicates binary content)
|
||||
if bytes.Contains(data, []byte{0x00}) {
|
||||
return fmt.Errorf("file contains null bytes, not valid %s", ft.Name)
|
||||
}
|
||||
|
||||
// For JSON, validate it can be parsed
|
||||
if ft.Name == "json" {
|
||||
// Basic JSON validation - check for valid JSON structure
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if len(trimmed) == 0 {
|
||||
return fmt.Errorf("empty JSON file")
|
||||
}
|
||||
if (trimmed[0] != '{' && trimmed[0] != '[') ||
|
||||
(trimmed[len(trimmed)-1] != '}' && trimmed[len(trimmed)-1] != ']') {
|
||||
return fmt.Errorf("invalid JSON structure")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsAllowedExtension checks if a file extension is in the allowed list
|
||||
func IsAllowedExtension(filePath string, allowedTypes []FileType) bool {
|
||||
ext := strings.ToLower(filepath.Ext(filePath))
|
||||
|
||||
// Check against dangerous extensions first
|
||||
for _, dangerous := range DangerousExtensions {
|
||||
if ext == dangerous {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check against allowed types
|
||||
for _, ft := range allowedTypes {
|
||||
for _, allowedExt := range ft.Extensions {
|
||||
if ext == allowedExt {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateDatasetFile validates a dataset file for safe formats
|
||||
func ValidateDatasetFile(filePath string) error {
|
||||
_, err := ValidateFileType(filePath, AllAllowedTypes)
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateModelFile validates a model file for safe binary formats only
|
||||
func ValidateModelFile(filePath string) error {
|
||||
ft, err := ValidateFileType(filePath, BinaryModelTypes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Additional check: ensure it's actually a model format, not just a matching extension
|
||||
if ft == nil {
|
||||
return fmt.Errorf("file type validation returned nil type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -2,21 +2,134 @@
|
|||
package fileutil
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SecureFileRead securely reads a file after cleaning the path to prevent path traversal
|
||||
// SecurePathValidator provides path traversal protection with symlink
// resolution: every validated path must resolve inside BasePath.
type SecurePathValidator struct {
	BasePath string // directory all validated paths must stay within
}

// NewSecurePathValidator creates a new path validator rooted at basePath.
func NewSecurePathValidator(basePath string) *SecurePathValidator {
	v := &SecurePathValidator{
		BasePath: basePath,
	}
	return v
}
|
||||
|
||||
// ValidatePath ensures resolved path is within base directory.
// It resolves symlinks and returns the canonical absolute path.
//
// Relative inputs are joined onto the base; absolute inputs are checked
// as-is. Both the base and the candidate are symlink-resolved before the
// containment check, so a symlink inside the base cannot be used to
// escape it. Returns an error when the base is unset, resolution fails,
// or the resolved path lies outside the (resolved) base directory.
func (v *SecurePathValidator) ValidatePath(inputPath string) (string, error) {
	if v.BasePath == "" {
		return "", fmt.Errorf("base path not set")
	}

	// Clean the path to remove . and .. (lexical only; symlinks handled below)
	cleaned := filepath.Clean(inputPath)

	// Get absolute base path and resolve any symlinks (critical for macOS /tmp -> /private/tmp)
	baseAbs, err := filepath.Abs(v.BasePath)
	if err != nil {
		return "", fmt.Errorf("failed to get absolute base path: %w", err)
	}
	// Resolve symlinks in base path for accurate comparison
	baseResolved, err := filepath.EvalSymlinks(baseAbs)
	if err != nil {
		// Base path may not exist yet, use as-is
		baseResolved = baseAbs
	}

	// If cleaned is already absolute, check if it's within base
	var absPath string
	if filepath.IsAbs(cleaned) {
		absPath = cleaned
	} else {
		// Join with base path if relative
		absPath = filepath.Join(baseAbs, cleaned)
	}

	// Resolve symlinks - critical for security
	resolved, err := filepath.EvalSymlinks(absPath)
	if err != nil {
		// If the file doesn't exist, we still need to check the directory path.
		// Try to resolve the parent directory so not-yet-created files can be
		// validated (e.g. before a write).
		dir := filepath.Dir(absPath)
		resolvedDir, dirErr := filepath.EvalSymlinks(dir)
		if dirErr != nil {
			return "", fmt.Errorf("path resolution failed: %w", err)
		}
		// Reconstruct the path with resolved directory
		base := filepath.Base(absPath)
		resolved = filepath.Join(resolvedDir, base)
	}

	// Get absolute resolved path
	resolvedAbs, err := filepath.Abs(resolved)
	if err != nil {
		return "", fmt.Errorf("failed to get absolute resolved path: %w", err)
	}

	// Verify within base directory. A separator is appended to both sides
	// so "/base2" does not pass as a prefix match of "/base"; the exact
	// base path itself is explicitly allowed.
	baseWithSep := baseResolved + string(filepath.Separator)
	if resolvedAbs != baseResolved && !strings.HasPrefix(resolvedAbs+string(filepath.Separator), baseWithSep) {
		return "", fmt.Errorf("path escapes base directory: %s (resolved to %s, base is %s)", inputPath, resolvedAbs, baseResolved)
	}

	return resolvedAbs, nil
}
|
||||
|
||||
// SecureFileRead securely reads a file after cleaning the path to prevent path traversal.
func SecureFileRead(path string) ([]byte, error) {
	sanitized := filepath.Clean(path)
	data, err := os.ReadFile(sanitized)
	if err != nil {
		return nil, err
	}
	return data, nil
}
|
||||
|
||||
// SecureFileWrite securely writes a file after cleaning the path to prevent
// path traversal. data is written with the given permission bits; an existing
// file is truncated (os.WriteFile semantics).
func SecureFileWrite(path string, data []byte, perm os.FileMode) error {
	return os.WriteFile(filepath.Clean(path), data, perm)
}
|
||||
|
||||
// SecureOpenFile securely opens a file after cleaning the path to prevent
// path traversal. flag and perm follow os.OpenFile semantics; the caller is
// responsible for closing the returned *os.File.
func SecureOpenFile(path string, flag int, perm os.FileMode) (*os.File, error) {
	return os.OpenFile(filepath.Clean(path), flag, perm)
}
|
||||
|
||||
// SecureReadDir reads directory contents with path validation.
|
||||
func (v *SecurePathValidator) SecureReadDir(dirPath string) ([]os.DirEntry, error) {
|
||||
validatedPath, err := v.ValidatePath(dirPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("directory path validation failed: %w", err)
|
||||
}
|
||||
return os.ReadDir(validatedPath)
|
||||
}
|
||||
|
||||
// SecureCreateTemp creates a temporary file within the base directory.
|
||||
func (v *SecurePathValidator) SecureCreateTemp(pattern string) (*os.File, string, error) {
|
||||
validatedPath, err := v.ValidatePath("")
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("base directory validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Generate secure random suffix
|
||||
randomBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
randomSuffix := base64.URLEncoding.EncodeToString(randomBytes)
|
||||
|
||||
// Create temp file
|
||||
if pattern == "" {
|
||||
pattern = "tmp"
|
||||
}
|
||||
fileName := fmt.Sprintf("%s_%s", pattern, randomSuffix)
|
||||
fullPath := filepath.Join(validatedPath, fileName)
|
||||
|
||||
file, err := os.Create(fullPath)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
|
||||
return file, fullPath, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -209,6 +209,38 @@ func (db *DB) ListJobs(status string, limit int) ([]*Job, error) {
|
|||
return jobs, nil
|
||||
}
|
||||
|
||||
// DeleteJob removes a job from the database by ID.
|
||||
func (db *DB) DeleteJob(id string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `DELETE FROM jobs WHERE id = ?`
|
||||
} else {
|
||||
query = `DELETE FROM jobs WHERE id = $1`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete job: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteJobsByPrefix removes all jobs with IDs matching the given prefix.
|
||||
func (db *DB) DeleteJobsByPrefix(prefix string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `DELETE FROM jobs WHERE id LIKE ?`
|
||||
} else {
|
||||
query = `DELETE FROM jobs WHERE id LIKE $1`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, prefix+"%")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete jobs by prefix: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterWorker registers or updates a worker in the database.
|
||||
func (db *DB) RegisterWorker(worker *Worker) error {
|
||||
metadataJSON, _ := json.Marshal(worker.Metadata)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
)
|
||||
|
||||
|
|
@ -17,26 +18,40 @@ func scanArtifacts(runDir string, includeAll bool) (*manifest.Artifacts, error)
|
|||
return nil, fmt.Errorf("run dir is empty")
|
||||
}
|
||||
|
||||
// Validate and canonicalize the runDir before any operations
|
||||
validator := fileutil.NewSecurePathValidator(runDir)
|
||||
validatedRunDir, err := validator.ValidatePath("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid run directory: %w", err)
|
||||
}
|
||||
|
||||
var files []manifest.ArtifactFile
|
||||
var total int64
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
err := filepath.WalkDir(runDir, func(path string, d fs.DirEntry, err error) error {
|
||||
err = filepath.WalkDir(validatedRunDir, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if path == runDir {
|
||||
if path == validatedRunDir {
|
||||
return nil
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(runDir, path)
|
||||
// Security: Validate each path is still within runDir
|
||||
// This catches any symlink escapes or path traversal attempts during walk
|
||||
rel, err := filepath.Rel(validatedRunDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("path escape detected during artifact scan: %w", err)
|
||||
}
|
||||
rel = filepath.ToSlash(rel)
|
||||
|
||||
// Check for path traversal patterns in the relative path
|
||||
if strings.Contains(rel, "..") {
|
||||
return fmt.Errorf("path traversal attempt detected: %s", rel)
|
||||
}
|
||||
|
||||
// Standard exclusions (always apply)
|
||||
if rel == manifestFilename {
|
||||
return nil
|
||||
|
|
@ -63,6 +78,7 @@ func scanArtifacts(runDir string, includeAll bool) (*manifest.Artifacts, error)
|
|||
return nil
|
||||
}
|
||||
if d.Type()&fs.ModeSymlink != 0 {
|
||||
// Skip symlinks - they could point outside the directory
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -131,12 +131,54 @@ type AppleGPUConfig struct {
|
|||
|
||||
// SandboxConfig holds container sandbox settings.
// Zero values are replaced by ApplySecurityDefaults, which enforces a
// secure-by-default posture (no network, read-only root, dropped caps,
// non-root UID/GID, bounded upload limits).
type SandboxConfig struct {
	NetworkMode     string   `yaml:"network_mode"`    // Default: "none"
	ReadOnlyRoot    bool     `yaml:"read_only_root"`  // Default: true
	AllowSecrets    bool     `yaml:"allow_secrets"`   // Default: false
	AllowedSecrets  []string `yaml:"allowed_secrets"` // e.g., ["HF_TOKEN", "WANDB_API_KEY"]
	SeccompProfile  string   `yaml:"seccomp_profile"` // Default: "default-hardened"
	MaxRuntimeHours int      `yaml:"max_runtime_hours"`

	// Security hardening options
	NoNewPrivileges bool     `yaml:"no_new_privileges"` // Default: true
	DropAllCaps     bool     `yaml:"drop_all_caps"`     // Default: true
	AllowedCaps     []string `yaml:"allowed_caps"`      // Capabilities to add back
	UserNS          bool     `yaml:"user_ns"`           // Default: true
	RunAsUID        int      `yaml:"run_as_uid"`        // Default: 1000
	RunAsGID        int      `yaml:"run_as_gid"`        // Default: 1000

	// Upload limits
	MaxUploadSizeBytes  int64 `yaml:"max_upload_size_bytes"`  // Default: 10GB
	MaxUploadRateBps    int64 `yaml:"max_upload_rate_bps"`    // Default: 100MB/s
	MaxUploadsPerMinute int   `yaml:"max_uploads_per_minute"` // Default: 10
}
|
||||
|
||||
// SecurityDefaults holds default values for security configuration.
// These are the values ApplySecurityDefaults installs when a field is
// unset: no network, read-only root, all capabilities dropped, user
// namespaces on, a non-root UID/GID, and bounded upload limits.
var SecurityDefaults = struct {
	NetworkMode         string
	ReadOnlyRoot        bool
	AllowSecrets        bool
	SeccompProfile      string
	NoNewPrivileges     bool
	DropAllCaps         bool
	UserNS              bool
	RunAsUID            int
	RunAsGID            int
	MaxUploadSizeBytes  int64
	MaxUploadRateBps    int64
	MaxUploadsPerMinute int
}{
	NetworkMode:         "none",
	SeccompProfile:      "default-hardened",
	ReadOnlyRoot:        true,
	AllowSecrets:        false,
	NoNewPrivileges:     true,
	DropAllCaps:         true,
	UserNS:              true,
	RunAsUID:            1000,
	RunAsGID:            1000,
	MaxUploadSizeBytes:  10 << 30,  // 10 GiB
	MaxUploadRateBps:    100 << 20, // 100 MiB/s
	MaxUploadsPerMinute: 10,
}
|
||||
|
||||
// Validate checks sandbox configuration
|
||||
|
|
@ -148,9 +190,87 @@ func (s *SandboxConfig) Validate() error {
|
|||
if s.MaxRuntimeHours < 0 {
|
||||
return fmt.Errorf("max_runtime_hours must be positive")
|
||||
}
|
||||
if s.MaxUploadSizeBytes < 0 {
|
||||
return fmt.Errorf("max_upload_size_bytes must be positive")
|
||||
}
|
||||
if s.MaxUploadRateBps < 0 {
|
||||
return fmt.Errorf("max_upload_rate_bps must be positive")
|
||||
}
|
||||
if s.MaxUploadsPerMinute < 0 {
|
||||
return fmt.Errorf("max_uploads_per_minute must be positive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplySecurityDefaults applies secure default values to empty fields.
|
||||
// This implements the "secure by default" principle for HIPAA compliance.
|
||||
func (s *SandboxConfig) ApplySecurityDefaults() {
|
||||
// Network isolation: default to "none" (no network access)
|
||||
if s.NetworkMode == "" {
|
||||
s.NetworkMode = SecurityDefaults.NetworkMode
|
||||
}
|
||||
|
||||
// Read-only root filesystem
|
||||
if !s.ReadOnlyRoot {
|
||||
s.ReadOnlyRoot = SecurityDefaults.ReadOnlyRoot
|
||||
}
|
||||
|
||||
// Secrets disabled by default
|
||||
if !s.AllowSecrets {
|
||||
s.AllowSecrets = SecurityDefaults.AllowSecrets
|
||||
}
|
||||
|
||||
// Seccomp profile
|
||||
if s.SeccompProfile == "" {
|
||||
s.SeccompProfile = SecurityDefaults.SeccompProfile
|
||||
}
|
||||
|
||||
// No new privileges
|
||||
if !s.NoNewPrivileges {
|
||||
s.NoNewPrivileges = SecurityDefaults.NoNewPrivileges
|
||||
}
|
||||
|
||||
// Drop all capabilities
|
||||
if !s.DropAllCaps {
|
||||
s.DropAllCaps = SecurityDefaults.DropAllCaps
|
||||
}
|
||||
|
||||
// User namespace
|
||||
if !s.UserNS {
|
||||
s.UserNS = SecurityDefaults.UserNS
|
||||
}
|
||||
|
||||
// Default non-root UID/GID
|
||||
if s.RunAsUID == 0 {
|
||||
s.RunAsUID = SecurityDefaults.RunAsUID
|
||||
}
|
||||
if s.RunAsGID == 0 {
|
||||
s.RunAsGID = SecurityDefaults.RunAsGID
|
||||
}
|
||||
|
||||
// Upload limits
|
||||
if s.MaxUploadSizeBytes == 0 {
|
||||
s.MaxUploadSizeBytes = SecurityDefaults.MaxUploadSizeBytes
|
||||
}
|
||||
if s.MaxUploadRateBps == 0 {
|
||||
s.MaxUploadRateBps = SecurityDefaults.MaxUploadRateBps
|
||||
}
|
||||
if s.MaxUploadsPerMinute == 0 {
|
||||
s.MaxUploadsPerMinute = SecurityDefaults.MaxUploadsPerMinute
|
||||
}
|
||||
}
|
||||
|
||||
// Getter methods for SandboxConfig interface.
// Presumably these satisfy the container package's SandboxConfig
// interface (defined there to avoid an import cycle) — verify there.

// GetNoNewPrivileges returns the NoNewPrivileges flag.
func (s *SandboxConfig) GetNoNewPrivileges() bool { return s.NoNewPrivileges }

// GetDropAllCaps returns the DropAllCaps flag.
func (s *SandboxConfig) GetDropAllCaps() bool { return s.DropAllCaps }

// GetAllowedCaps returns the list of capabilities to add back.
func (s *SandboxConfig) GetAllowedCaps() []string { return s.AllowedCaps }

// GetUserNS returns the UserNS flag.
func (s *SandboxConfig) GetUserNS() bool { return s.UserNS }

// GetRunAsUID returns the UID the container process runs as.
func (s *SandboxConfig) GetRunAsUID() int { return s.RunAsUID }

// GetRunAsGID returns the GID the container process runs as.
func (s *SandboxConfig) GetRunAsGID() int { return s.RunAsGID }

// GetSeccompProfile returns the configured seccomp profile identifier.
func (s *SandboxConfig) GetSeccompProfile() string { return s.SeccompProfile }

// GetReadOnlyRoot returns the ReadOnlyRoot flag.
func (s *SandboxConfig) GetReadOnlyRoot() bool { return s.ReadOnlyRoot }

// GetNetworkMode returns the configured container network mode.
func (s *SandboxConfig) GetNetworkMode() string { return s.NetworkMode }
|
||||
|
||||
// LoadConfig loads worker configuration from a YAML file.
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
data, err := fileutil.SecureFileRead(path)
|
||||
|
|
@ -291,6 +411,14 @@ func LoadConfig(path string) (*Config, error) {
|
|||
cfg.GracefulTimeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
// Apply security defaults to sandbox configuration
|
||||
cfg.Sandbox.ApplySecurityDefaults()
|
||||
|
||||
// Expand secrets from environment variables
|
||||
if err := cfg.expandSecrets(); err != nil {
|
||||
return nil, fmt.Errorf("secrets expansion failed: %w", err)
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
|
|
@ -442,6 +570,125 @@ func (c *Config) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// expandSecrets replaces secret placeholders with environment variables
|
||||
func (c *Config) expandSecrets() error {
|
||||
// Expand Redis password from env if using ${...} syntax
|
||||
if strings.Contains(c.RedisPassword, "${") {
|
||||
c.RedisPassword = os.ExpandEnv(c.RedisPassword)
|
||||
}
|
||||
|
||||
// Expand SnapshotStore credentials
|
||||
if strings.Contains(c.SnapshotStore.AccessKey, "${") {
|
||||
c.SnapshotStore.AccessKey = os.ExpandEnv(c.SnapshotStore.AccessKey)
|
||||
}
|
||||
if strings.Contains(c.SnapshotStore.SecretKey, "${") {
|
||||
c.SnapshotStore.SecretKey = os.ExpandEnv(c.SnapshotStore.SecretKey)
|
||||
}
|
||||
if strings.Contains(c.SnapshotStore.SessionToken, "${") {
|
||||
c.SnapshotStore.SessionToken = os.ExpandEnv(c.SnapshotStore.SessionToken)
|
||||
}
|
||||
|
||||
// Validate no plaintext secrets remain in critical fields
|
||||
if err := c.validateNoPlaintextSecrets(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNoPlaintextSecrets checks that sensitive fields use env var references
|
||||
// rather than hardcoded plaintext values. This is a HIPAA compliance requirement.
|
||||
func (c *Config) validateNoPlaintextSecrets() error {
|
||||
// Fields that should use ${ENV_VAR} syntax instead of plaintext
|
||||
sensitiveFields := []struct {
|
||||
name string
|
||||
value string
|
||||
}{
|
||||
{"redis_password", c.RedisPassword},
|
||||
{"snapshot_store.access_key", c.SnapshotStore.AccessKey},
|
||||
{"snapshot_store.secret_key", c.SnapshotStore.SecretKey},
|
||||
{"snapshot_store.session_token", c.SnapshotStore.SessionToken},
|
||||
}
|
||||
|
||||
for _, field := range sensitiveFields {
|
||||
if field.value == "" {
|
||||
continue // Empty values are fine
|
||||
}
|
||||
|
||||
// Check if it looks like a plaintext secret (not env var reference)
|
||||
if !strings.HasPrefix(field.value, "${") && looksLikeSecret(field.value) {
|
||||
return fmt.Errorf(
|
||||
"%s appears to contain a plaintext secret (length=%d, entropy=%.2f); "+
|
||||
"use ${ENV_VAR} syntax to load from environment or secrets manager",
|
||||
field.name, len(field.value), calculateEntropy(field.value),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// looksLikeSecret heuristically detects if a string looks like a secret credential
|
||||
func looksLikeSecret(s string) bool {
|
||||
// Minimum length for secrets
|
||||
if len(s) < 16 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Calculate entropy to detect high-entropy strings (likely secrets)
|
||||
entropy := calculateEntropy(s)
|
||||
|
||||
// High entropy (>4 bits per char) combined with reasonable length suggests a secret
|
||||
if entropy > 4.0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for common secret patterns
|
||||
patterns := []string{
|
||||
"AKIA", // AWS Access Key ID prefix
|
||||
"ASIA", // AWS temporary credentials
|
||||
"ghp_", // GitHub personal access token
|
||||
"gho_", // GitHub OAuth token
|
||||
"glpat-", // GitLab PAT
|
||||
"sk-", // OpenAI/Stripe key prefix
|
||||
"sk_live_", // Stripe live key
|
||||
"sk_test_", // Stripe test key
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
if strings.Contains(s, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// calculateEntropy calculates Shannon entropy of a string in bits per character.
// Characters are counted as runes and probabilities use the rune count, so
// multi-byte UTF-8 input is measured correctly (the previous version divided
// rune frequencies by the byte length len(s), understating probabilities for
// non-ASCII strings).
func calculateEntropy(s string) float64 {
	if len(s) == 0 {
		return 0
	}

	// Count rune frequencies and the total number of runes in one pass.
	freq := make(map[rune]int)
	runeCount := 0
	for _, r := range s {
		freq[r]++
		runeCount++
	}

	// Shannon entropy: -sum(p * log2(p)) over the rune distribution.
	// Every p here is > 0 since each counted rune occurred at least once.
	var entropy float64
	total := float64(runeCount)
	for _, count := range freq {
		p := float64(count) / total
		entropy -= p * math.Log2(p)
	}

	return entropy
}
|
||||
|
||||
// envInt reads an integer from environment variable
|
||||
func envInt(name string) (int, bool) {
|
||||
v := strings.TrimSpace(os.Getenv(name))
|
||||
|
|
|
|||
|
|
@ -30,6 +30,20 @@ type ContainerConfig struct {
|
|||
TrainScript string
|
||||
BasePath string
|
||||
AppleGPUEnabled bool
|
||||
Sandbox SandboxConfig // NEW: Security configuration
|
||||
}
|
||||
|
||||
// SandboxConfig interface to avoid import cycle.
// It exposes the sandbox security settings consumed when building the
// PodmanSecurityConfig for the container command line.
type SandboxConfig interface {
	GetNoNewPrivileges() bool  // whether new privileges are disallowed
	GetDropAllCaps() bool      // whether all capabilities are dropped
	GetAllowedCaps() []string  // capabilities to add back after dropping
	GetUserNS() bool           // whether a user namespace is used
	GetRunAsUID() int          // UID the container process runs as
	GetRunAsGID() int          // GID the container process runs as
	GetSeccompProfile() string // seccomp profile identifier
	GetReadOnlyRoot() bool     // whether the root filesystem is read-only
	GetNetworkMode() string    // container network mode (e.g. "none")
}
|
||||
|
||||
// ContainerExecutor executes jobs in containers using podman
|
||||
|
|
@ -208,6 +222,7 @@ func (e *ContainerExecutor) teardownTracking(ctx context.Context, task *queue.Ta
|
|||
}
|
||||
|
||||
func (e *ContainerExecutor) setupVolumes(trackingEnv map[string]string, _outputDir string) map[string]string {
|
||||
_ = _outputDir
|
||||
volumes := make(map[string]string)
|
||||
|
||||
if val, ok := trackingEnv["TENSORBOARD_HOST_LOG_DIR"]; ok {
|
||||
|
|
@ -305,8 +320,20 @@ func (e *ContainerExecutor) runPodman(
|
|||
e.logger.Warn("failed to open log file for podman output", "path", env.LogFile, "error", err)
|
||||
}
|
||||
|
||||
// Build command
|
||||
podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, scriptPath, depsPath, extraArgs)
|
||||
// Convert SandboxConfig to PodmanSecurityConfig
|
||||
securityConfig := container.PodmanSecurityConfig{
|
||||
NoNewPrivileges: e.config.Sandbox.GetNoNewPrivileges(),
|
||||
DropAllCaps: e.config.Sandbox.GetDropAllCaps(),
|
||||
AllowedCaps: e.config.Sandbox.GetAllowedCaps(),
|
||||
UserNS: e.config.Sandbox.GetUserNS(),
|
||||
RunAsUID: e.config.Sandbox.GetRunAsUID(),
|
||||
RunAsGID: e.config.Sandbox.GetRunAsGID(),
|
||||
SeccompProfile: e.config.Sandbox.GetSeccompProfile(),
|
||||
ReadOnlyRoot: e.config.Sandbox.GetReadOnlyRoot(),
|
||||
NetworkMode: e.config.Sandbox.GetNetworkMode(),
|
||||
}
|
||||
|
||||
podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, securityConfig, scriptPath, depsPath, extraArgs)
|
||||
|
||||
// Update manifest
|
||||
if e.writer != nil {
|
||||
|
|
|
|||
Loading…
Reference in a new issue