package worker import ( "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "log/slog" "math" "net/url" "os" "path/filepath" "runtime" "strconv" "strings" "time" "github.com/google/uuid" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/fileutil" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/tracking/factory" "gopkg.in/yaml.v3" ) const ( defaultMetricsFlushInterval = 500 * time.Millisecond datasetCacheDefaultTTL = 30 * time.Minute ) type QueueConfig struct { Backend string `yaml:"backend"` SQLitePath string `yaml:"sqlite_path"` FilesystemPath string `yaml:"filesystem_path"` FallbackToFilesystem bool `yaml:"fallback_to_filesystem"` } // Config holds worker configuration. type Config struct { Host string `yaml:"host"` User string `yaml:"user"` SSHKey string `yaml:"ssh_key"` Port int `yaml:"port"` BasePath string `yaml:"base_path"` Entrypoint string `yaml:"entrypoint"` RedisURL string `yaml:"redis_url"` RedisAddr string `yaml:"redis_addr"` RedisPassword string `yaml:"redis_password"` RedisDB int `yaml:"redis_db"` Queue QueueConfig `yaml:"queue"` KnownHosts string `yaml:"known_hosts"` WorkerID string `yaml:"worker_id"` MaxWorkers int `yaml:"max_workers"` PollInterval int `yaml:"poll_interval_seconds"` Resources config.ResourceConfig `yaml:"resources"` LocalMode bool `yaml:"local_mode"` // Authentication Auth auth.Config `yaml:"auth"` // Metrics exporter Metrics MetricsConfig `yaml:"metrics"` // Metrics buffering MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"` // Data management DataManagerPath string `yaml:"data_manager_path"` AutoFetchData bool `yaml:"auto_fetch_data"` DataDir string `yaml:"data_dir"` DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"` SnapshotStore SnapshotStoreConfig `yaml:"snapshot_store"` // Provenance enforcement // Default: fail-closed (trustworthiness-by-default). Set true to opt into best-effort. ProvenanceBestEffort bool `yaml:"provenance_best_effort"` // Compliance mode: "hipaa", "standard", or empty // When "hipaa": enforces hard requirements at startup ComplianceMode string `yaml:"compliance_mode"` // Opt-in prewarming of next task artifacts (snapshot/datasets/env). PrewarmEnabled bool `yaml:"prewarm_enabled"` // Podman execution PodmanImage string `yaml:"podman_image"` ContainerWorkspace string `yaml:"container_workspace"` ContainerResults string `yaml:"container_results"` GPUDevices []string `yaml:"gpu_devices"` GPUVendor string `yaml:"gpu_vendor"` GPUVendorAutoDetected bool `yaml:"-"` // Set by LoadConfig when GPUVendor is auto-detected GPUVisibleDevices []int `yaml:"gpu_visible_devices"` GPUVisibleDeviceIDs []string `yaml:"gpu_visible_device_ids"` // Apple M-series GPU configuration AppleGPU AppleGPUConfig `yaml:"apple_gpu"` // Task lease and retry settings TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // Worker lease (default: 30min) HeartbeatInterval time.Duration `yaml:"heartbeat_interval"` // Renew lease (default: 1min) MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3) GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Shutdown timeout (default: 5min) // Mode determines how the worker operates: "standalone" or "distributed" Mode string `yaml:"mode"` // Scheduler configuration for distributed mode Scheduler SchedulerConfig `yaml:"scheduler"` // Plugins configuration Plugins map[string]factory.PluginConfig `yaml:"plugins"` // Sandboxing configuration Sandbox SandboxConfig `yaml:"sandbox"` } // MetricsConfig controls the Prometheus exporter. type MetricsConfig struct { Enabled bool `yaml:"enabled"` ListenAddr string `yaml:"listen_addr"` } type SnapshotStoreConfig struct { Enabled bool `yaml:"enabled"` Endpoint string `yaml:"endpoint"` Secure bool `yaml:"secure"` Region string `yaml:"region"` Bucket string `yaml:"bucket"` Prefix string `yaml:"prefix"` AccessKey string `yaml:"access_key"` SecretKey string `yaml:"secret_key"` SessionToken string `yaml:"session_token"` Timeout time.Duration `yaml:"timeout"` MaxRetries int `yaml:"max_retries"` } // SchedulerConfig holds configurable heartbeat and lease settings for distributed mode. type SchedulerConfig struct { Address string `yaml:"address"` Cert string `yaml:"cert"` Token string `yaml:"token"` HeartbeatIntervalSecs int `yaml:"heartbeat_interval_secs"` // default: 30 TaskLeaseDurationSecs int `yaml:"task_lease_duration_secs"` // default: 90 (3x heartbeat) } // Validate checks that lease and heartbeat settings are valid. // Enforces 2-10x ratio between lease duration and heartbeat interval. func (sc *SchedulerConfig) Validate() error { // Apply defaults if zero if sc.HeartbeatIntervalSecs == 0 { sc.HeartbeatIntervalSecs = 30 } if sc.TaskLeaseDurationSecs == 0 { sc.TaskLeaseDurationSecs = 90 } heartbeat := time.Duration(sc.HeartbeatIntervalSecs) * time.Second lease := time.Duration(sc.TaskLeaseDurationSecs) * time.Second if lease <= heartbeat { return fmt.Errorf( "task_lease_duration_secs (%s) must be greater than heartbeat_interval_secs (%s)", lease, heartbeat, ) } ratio := lease.Seconds() / heartbeat.Seconds() if ratio < 2.0 { return fmt.Errorf( "task_lease_duration_secs must be at least 2× heartbeat_interval_secs "+ "(got %.1f×) — too small a margin for transient network issues", ratio, ) } if ratio > 10.0 { return fmt.Errorf( "task_lease_duration_secs is %.1f× heartbeat_interval_secs — "+ "dead workers won't be detected for %s, consider reducing lease duration", ratio, lease, ) } return nil } type AppleGPUConfig struct { Enabled bool `yaml:"enabled"` MetalDevice string `yaml:"metal_device"` MPSRuntime string `yaml:"mps_runtime"` } // SandboxConfig holds container sandbox settings type SandboxConfig struct { NetworkMode string `yaml:"network_mode"` // Default: "none" ReadOnlyRoot bool `yaml:"read_only_root"` // Default: true AllowSecrets bool `yaml:"allow_secrets"` // Default: false AllowedSecrets []string `yaml:"allowed_secrets"` // e.g., ["HF_TOKEN", "WANDB_API_KEY"] SeccompProfile string `yaml:"seccomp_profile"` // Default: "default-hardened" MaxRuntimeHours int `yaml:"max_runtime_hours"` // Security hardening options NoNewPrivileges bool `yaml:"no_new_privileges"` // Default: true DropAllCaps bool `yaml:"drop_all_caps"` // Default: true AllowedCaps []string `yaml:"allowed_caps"` // Capabilities to add back UserNS bool `yaml:"user_ns"` // Default: true RunAsUID int `yaml:"run_as_uid"` // Default: 1000 RunAsGID int `yaml:"run_as_gid"` // Default: 1000 // Process isolation MaxProcesses int `yaml:"max_processes"` // Fork bomb protection (default: 100) MaxOpenFiles int `yaml:"max_open_files"` // FD exhaustion protection (default: 1024) DisableSwap bool `yaml:"disable_swap"` // Prevent swap exfiltration OOMScoreAdj int `yaml:"oom_score_adj"` // OOM killer priority (default: 100) TaskUID int `yaml:"task_uid"` // Per-task UID (0 = use RunAsUID) TaskGID int `yaml:"task_gid"` // Per-task GID (0 = use RunAsGID) // Upload limits MaxUploadSizeBytes int64 `yaml:"max_upload_size_bytes"` // Default: 10GB MaxUploadRateBps int64 `yaml:"max_upload_rate_bps"` // Default: 100MB/s MaxUploadsPerMinute int `yaml:"max_uploads_per_minute"` // Default: 10 // Artifact ingestion caps MaxArtifactFiles int `yaml:"max_artifact_files"` // Default: 10000 MaxArtifactTotalBytes int64 `yaml:"max_artifact_total_bytes"` // Default: 100GB } // SecurityDefaults holds default values for security configuration var SecurityDefaults = struct { NetworkMode string ReadOnlyRoot bool AllowSecrets bool SeccompProfile string NoNewPrivileges bool DropAllCaps bool UserNS bool RunAsUID int RunAsGID int MaxProcesses int MaxOpenFiles int DisableSwap bool OOMScoreAdj int MaxUploadSizeBytes int64 MaxUploadRateBps int64 MaxUploadsPerMinute int MaxArtifactFiles int MaxArtifactTotalBytes int64 }{ NetworkMode: "none", ReadOnlyRoot: true, AllowSecrets: false, SeccompProfile: "default-hardened", NoNewPrivileges: true, DropAllCaps: true, UserNS: true, RunAsUID: 1000, RunAsGID: 1000, MaxProcesses: 100, // Fork bomb protection MaxOpenFiles: 1024, // FD exhaustion protection DisableSwap: true, // Prevent swap exfiltration OOMScoreAdj: 100, // Lower OOM priority MaxUploadSizeBytes: 10 * 1024 * 1024 * 1024, // 10GB MaxUploadRateBps: 100 * 1024 * 1024, // 100MB/s MaxUploadsPerMinute: 10, MaxArtifactFiles: 10000, MaxArtifactTotalBytes: 100 * 1024 * 1024 * 1024, // 100GB } // Validate checks sandbox configuration func (s *SandboxConfig) Validate() error { validNetworks := map[string]bool{"none": true, "slirp4netns": true, "bridge": true, "": true} if !validNetworks[s.NetworkMode] { return fmt.Errorf("invalid network_mode: %s", s.NetworkMode) } if s.MaxRuntimeHours < 0 { return fmt.Errorf("max_runtime_hours must be positive") } if s.MaxUploadSizeBytes < 0 { return fmt.Errorf("max_upload_size_bytes must be positive") } if s.MaxUploadRateBps < 0 { return fmt.Errorf("max_upload_rate_bps must be positive") } if s.MaxUploadsPerMinute < 0 { return fmt.Errorf("max_uploads_per_minute must be positive") } if s.MaxArtifactFiles < 0 { return fmt.Errorf("max_artifact_files must be positive") } if s.MaxArtifactTotalBytes < 0 { return fmt.Errorf("max_artifact_total_bytes must be positive") } return nil } // ApplySecurityDefaults applies secure default values to empty fields. // This implements the "secure by default" principle for HIPAA compliance. func (s *SandboxConfig) ApplySecurityDefaults() { // Network isolation: default to "none" (no network access) if s.NetworkMode == "" { s.NetworkMode = SecurityDefaults.NetworkMode } // Read-only root filesystem if !s.ReadOnlyRoot { s.ReadOnlyRoot = SecurityDefaults.ReadOnlyRoot } // Secrets disabled by default if !s.AllowSecrets { s.AllowSecrets = SecurityDefaults.AllowSecrets } // Seccomp profile if s.SeccompProfile == "" { s.SeccompProfile = SecurityDefaults.SeccompProfile } // No new privileges if !s.NoNewPrivileges { s.NoNewPrivileges = SecurityDefaults.NoNewPrivileges } // Drop all capabilities if !s.DropAllCaps { s.DropAllCaps = SecurityDefaults.DropAllCaps } // User namespace if !s.UserNS { s.UserNS = SecurityDefaults.UserNS } // Default non-root UID/GID if s.RunAsUID == 0 { s.RunAsUID = SecurityDefaults.RunAsUID } if s.RunAsGID == 0 { s.RunAsGID = SecurityDefaults.RunAsGID } // Upload limits if s.MaxUploadSizeBytes == 0 { s.MaxUploadSizeBytes = SecurityDefaults.MaxUploadSizeBytes } if s.MaxUploadRateBps == 0 { s.MaxUploadRateBps = SecurityDefaults.MaxUploadRateBps } if s.MaxUploadsPerMinute == 0 { s.MaxUploadsPerMinute = SecurityDefaults.MaxUploadsPerMinute } // Artifact ingestion caps if s.MaxArtifactFiles == 0 { s.MaxArtifactFiles = SecurityDefaults.MaxArtifactFiles } if s.MaxArtifactTotalBytes == 0 { s.MaxArtifactTotalBytes = SecurityDefaults.MaxArtifactTotalBytes } // Process isolation defaults if s.MaxProcesses == 0 { s.MaxProcesses = SecurityDefaults.MaxProcesses } if s.MaxOpenFiles == 0 { s.MaxOpenFiles = SecurityDefaults.MaxOpenFiles } if !s.DisableSwap { s.DisableSwap = SecurityDefaults.DisableSwap } if s.OOMScoreAdj == 0 { s.OOMScoreAdj = SecurityDefaults.OOMScoreAdj } // TaskUID/TaskGID default to 0 (meaning "use RunAsUID/RunAsGID") // Only override if explicitly set (> 0) if s.TaskUID < 0 { s.TaskUID = 0 } if s.TaskGID < 0 { s.TaskGID = 0 } } // GetProcessIsolationFlags returns the effective UID/GID for a task // If TaskUID/TaskGID are set (>0), use those; otherwise use RunAsUID/RunAsGID func (s *SandboxConfig) GetProcessIsolationFlags() (uid, gid int) { uid = s.RunAsUID gid = s.RunAsGID if s.TaskUID > 0 { uid = s.TaskUID } if s.TaskGID > 0 { gid = s.TaskGID } return uid, gid } // Getter methods for SandboxConfig interface func (s *SandboxConfig) GetNoNewPrivileges() bool { return s.NoNewPrivileges } func (s *SandboxConfig) GetDropAllCaps() bool { return s.DropAllCaps } func (s *SandboxConfig) GetAllowedCaps() []string { return s.AllowedCaps } func (s *SandboxConfig) GetUserNS() bool { return s.UserNS } func (s *SandboxConfig) GetRunAsUID() int { return s.RunAsUID } func (s *SandboxConfig) GetRunAsGID() int { return s.RunAsGID } func (s *SandboxConfig) GetSeccompProfile() string { return s.SeccompProfile } func (s *SandboxConfig) GetReadOnlyRoot() bool { return s.ReadOnlyRoot } func (s *SandboxConfig) GetNetworkMode() string { return s.NetworkMode } // Process Isolation getter methods func (s *SandboxConfig) GetMaxProcesses() int { return s.MaxProcesses } func (s *SandboxConfig) GetMaxOpenFiles() int { return s.MaxOpenFiles } func (s *SandboxConfig) GetDisableSwap() bool { return s.DisableSwap } func (s *SandboxConfig) GetOOMScoreAdj() int { return s.OOMScoreAdj } func (s *SandboxConfig) GetTaskUID() int { return s.TaskUID } func (s *SandboxConfig) GetTaskGID() int { return s.TaskGID } // LoadConfig loads worker configuration from a YAML file. func LoadConfig(path string) (*Config, error) { data, err := fileutil.SecureFileRead(path) if err != nil { return nil, err } var cfg Config if err := yaml.Unmarshal(data, &cfg); err != nil { return nil, err } if strings.TrimSpace(cfg.RedisURL) != "" { cfg.RedisURL = os.ExpandEnv(strings.TrimSpace(cfg.RedisURL)) cfg.RedisAddr = cfg.RedisURL cfg.RedisPassword = "" cfg.RedisDB = 0 } // Get smart defaults for current environment smart := config.GetSmartDefaults() // Use PathRegistry for consistent path management paths := config.FromEnv() if cfg.Port == 0 { cfg.Port = config.DefaultSSHPort } if cfg.Host == "" { host, err := smart.Host() if err != nil { return nil, fmt.Errorf("failed to get default host: %w", err) } cfg.Host = host } if cfg.BasePath == "" { // Prefer PathRegistry over smart defaults for consistency cfg.BasePath = paths.ExperimentsDir() } if cfg.RedisAddr == "" { redisAddr, err := smart.RedisAddr() if err != nil { return nil, fmt.Errorf("failed to get default redis address: %w", err) } cfg.RedisAddr = redisAddr } if cfg.KnownHosts == "" { knownHosts, err := smart.KnownHostsPath() if err != nil { return nil, fmt.Errorf("failed to get default known hosts path: %w", err) } cfg.KnownHosts = knownHosts } if cfg.WorkerID == "" { cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8]) } cfg.Resources.ApplyDefaults() if cfg.MaxWorkers > 0 { cfg.Resources.MaxWorkers = cfg.MaxWorkers } else { maxWorkers, err := smart.MaxWorkers() if err != nil { return nil, fmt.Errorf("failed to get default max workers: %w", err) } cfg.MaxWorkers = maxWorkers cfg.Resources.MaxWorkers = maxWorkers } if cfg.PollInterval == 0 { pollInterval, err := smart.PollInterval() if err != nil { return nil, fmt.Errorf("failed to get default poll interval: %w", err) } cfg.PollInterval = pollInterval } if cfg.DataManagerPath == "" { cfg.DataManagerPath = "./data_manager" } if cfg.DataDir == "" { // Use PathRegistry for consistent data directory cfg.DataDir = paths.DataDir() } if cfg.SnapshotStore.Timeout == 0 { cfg.SnapshotStore.Timeout = 10 * time.Minute } if cfg.SnapshotStore.MaxRetries == 0 { cfg.SnapshotStore.MaxRetries = 3 } if cfg.Metrics.ListenAddr == "" { cfg.Metrics.ListenAddr = ":9100" } if cfg.MetricsFlushInterval == 0 { cfg.MetricsFlushInterval = defaultMetricsFlushInterval } if cfg.DatasetCacheTTL == 0 { cfg.DatasetCacheTTL = datasetCacheDefaultTTL } if strings.TrimSpace(cfg.Queue.Backend) == "" { cfg.Queue.Backend = string(queue.QueueBackendRedis) } if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendSQLite)) { if strings.TrimSpace(cfg.Queue.SQLitePath) == "" { cfg.Queue.SQLitePath = filepath.Join(cfg.DataDir, "queue.db") } cfg.Queue.SQLitePath = storage.ExpandPath(cfg.Queue.SQLitePath) } if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendFS)) || cfg.Queue.FallbackToFilesystem { if strings.TrimSpace(cfg.Queue.FilesystemPath) == "" { cfg.Queue.FilesystemPath = filepath.Join(cfg.DataDir, "queue-fs") } cfg.Queue.FilesystemPath = storage.ExpandPath(cfg.Queue.FilesystemPath) } if strings.TrimSpace(cfg.GPUVendor) == "" { cfg.GPUVendorAutoDetected = true if cfg.AppleGPU.Enabled { cfg.GPUVendor = string(GPUTypeApple) } else if len(cfg.GPUDevices) > 0 || len(cfg.GPUVisibleDevices) > 0 || len(cfg.GPUVisibleDeviceIDs) > 0 { cfg.GPUVendor = string(GPUTypeNVIDIA) } else { cfg.GPUVendor = string(GPUTypeNone) } } // Set lease and retry defaults if cfg.TaskLeaseDuration == 0 { cfg.TaskLeaseDuration = 30 * time.Minute } if cfg.HeartbeatInterval == 0 { cfg.HeartbeatInterval = 1 * time.Minute } if cfg.MaxRetries == 0 { cfg.MaxRetries = 3 } if cfg.GracefulTimeout == 0 { cfg.GracefulTimeout = 5 * time.Minute } // Apply security defaults to sandbox configuration cfg.Sandbox.ApplySecurityDefaults() // Expand secrets from environment variables if err := cfg.ExpandSecrets(); err != nil { return nil, fmt.Errorf("secrets expansion failed: %w", err) } return &cfg, nil } // Validate implements config.Validator interface. func (c *Config) Validate() error { if c.Port != 0 { if err := config.ValidatePort(c.Port); err != nil { return fmt.Errorf("invalid SSH port: %w", err) } } if c.BasePath != "" { // Convert relative paths to absolute c.BasePath = storage.ExpandPath(c.BasePath) if !filepath.IsAbs(c.BasePath) { // Resolve relative to current working directory, not DefaultBasePath cwd, err := os.Getwd() if err != nil { return fmt.Errorf("failed to get current directory: %w", err) } c.BasePath = filepath.Join(cwd, c.BasePath) } } backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend)) if backend == "" { backend = string(queue.QueueBackendRedis) c.Queue.Backend = backend } if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) && backend != string(queue.QueueBackendFS) { return fmt.Errorf("queue.backend must be one of %q, %q, or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite, queue.QueueBackendFS) } if backend == string(queue.QueueBackendSQLite) { if strings.TrimSpace(c.Queue.SQLitePath) == "" { return fmt.Errorf("queue.sqlite_path is required when queue.backend is %q", queue.QueueBackendSQLite) } c.Queue.SQLitePath = storage.ExpandPath(c.Queue.SQLitePath) if !filepath.IsAbs(c.Queue.SQLitePath) { c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath) } } if backend == string(queue.QueueBackendFS) || c.Queue.FallbackToFilesystem { if strings.TrimSpace(c.Queue.FilesystemPath) == "" { return fmt.Errorf("queue.filesystem_path is required when filesystem queue is enabled") } c.Queue.FilesystemPath = storage.ExpandPath(c.Queue.FilesystemPath) if !filepath.IsAbs(c.Queue.FilesystemPath) { c.Queue.FilesystemPath = filepath.Join(config.DefaultLocalDataDir, c.Queue.FilesystemPath) } } if c.RedisAddr != "" { addr := strings.TrimSpace(c.RedisAddr) if strings.HasPrefix(addr, "redis://") { u, err := url.Parse(addr) if err != nil { return fmt.Errorf("invalid Redis configuration: invalid redis url: %w", err) } if u.Scheme != "redis" || strings.TrimSpace(u.Host) == "" { return fmt.Errorf("invalid Redis configuration: invalid redis url") } } else { if err := config.ValidateRedisAddr(addr); err != nil { return fmt.Errorf("invalid Redis configuration: %w", err) } } } if c.MaxWorkers < 1 { return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers) } switch strings.ToLower(strings.TrimSpace(c.GPUVendor)) { case string(GPUTypeNVIDIA), string(GPUTypeApple), string(GPUTypeNone), "amd": // ok default: return fmt.Errorf( "gpu_vendor must be one of %q, %q, %q, %q", string(GPUTypeNVIDIA), "amd", string(GPUTypeApple), string(GPUTypeNone), ) } // Strict GPU visibility configuration: // - gpu_visible_devices and gpu_visible_device_ids are mutually exclusive. // - UUID-style gpu_visible_device_ids is NVIDIA-only. vendor := strings.ToLower(strings.TrimSpace(c.GPUVendor)) if len(c.GPUVisibleDevices) > 0 && len(c.GPUVisibleDeviceIDs) > 0 { if vendor != string(GPUTypeNVIDIA) { return fmt.Errorf( "visible_device_ids is only supported when gpu_vendor is %q", string(GPUTypeNVIDIA), ) } for _, id := range c.GPUVisibleDeviceIDs { id = strings.TrimSpace(id) if id == "" { return fmt.Errorf("visible_device_ids contains an empty value") } if !strings.HasPrefix(id, "GPU-") { return fmt.Errorf("gpu_visible_device_ids values must start with %q, got %q", "GPU-", id) } } } if vendor == string(GPUTypeApple) || vendor == string(GPUTypeNone) { if len(c.GPUVisibleDevices) > 0 || len(c.GPUVisibleDeviceIDs) > 0 { return fmt.Errorf( "gpu_visible_devices and gpu_visible_device_ids are not supported when gpu_vendor is %q", vendor, ) } } if vendor == "amd" { if len(c.GPUVisibleDeviceIDs) > 0 { return fmt.Errorf("gpu_visible_device_ids is not supported when gpu_vendor is %q", vendor) } for _, idx := range c.GPUVisibleDevices { if idx < 0 { return fmt.Errorf("gpu_visible_devices contains negative index %d", idx) } } } if c.SnapshotStore.Enabled { if strings.TrimSpace(c.SnapshotStore.Endpoint) == "" { return fmt.Errorf("snapshot_store.endpoint is required when snapshot_store.enabled is true") } if strings.TrimSpace(c.SnapshotStore.Bucket) == "" { return fmt.Errorf("snapshot_store.bucket is required when snapshot_store.enabled is true") } ak := strings.TrimSpace(c.SnapshotStore.AccessKey) sk := strings.TrimSpace(c.SnapshotStore.SecretKey) if (ak == "") != (sk == "") { return fmt.Errorf( "snapshot_store.access_key and snapshot_store.secret_key must both be set or both be empty", ) } if c.SnapshotStore.Timeout < 0 { return fmt.Errorf("snapshot_store.timeout must be >= 0") } if c.SnapshotStore.MaxRetries < 0 { return fmt.Errorf("snapshot_store.max_retries must be >= 0") } } // HIPAA mode validation - hard requirements if strings.ToLower(c.ComplianceMode) == "hipaa" { if err := c.validateHIPAARequirements(); err != nil { return fmt.Errorf("HIPAA compliance validation failed: %w", err) } } return nil } // ExpandSecrets replaces secret placeholders with environment variables // Exported for testing purposes func (c *Config) ExpandSecrets() error { // First validate that secrets use env var syntax (not plaintext) if err := c.ValidateNoPlaintextSecrets(); err != nil { return err } // Expand Redis password from env if using ${...} syntax if strings.Contains(c.RedisPassword, "${") { c.RedisPassword = os.ExpandEnv(c.RedisPassword) } // Expand SnapshotStore credentials if strings.Contains(c.SnapshotStore.AccessKey, "${") { c.SnapshotStore.AccessKey = os.ExpandEnv(c.SnapshotStore.AccessKey) } if strings.Contains(c.SnapshotStore.SecretKey, "${") { c.SnapshotStore.SecretKey = os.ExpandEnv(c.SnapshotStore.SecretKey) } if strings.Contains(c.SnapshotStore.SessionToken, "${") { c.SnapshotStore.SessionToken = os.ExpandEnv(c.SnapshotStore.SessionToken) } return nil } // ValidateNoPlaintextSecrets checks that sensitive fields use env var references // rather than hardcoded plaintext values. This is a HIPAA compliance requirement. // Exported for testing purposes func (c *Config) ValidateNoPlaintextSecrets() error { // Fields that should use ${ENV_VAR} syntax instead of plaintext sensitiveFields := []struct { name string value string }{ {"redis_password", c.RedisPassword}, {"snapshot_store.access_key", c.SnapshotStore.AccessKey}, {"snapshot_store.secret_key", c.SnapshotStore.SecretKey}, {"snapshot_store.session_token", c.SnapshotStore.SessionToken}, } for _, field := range sensitiveFields { if field.value == "" { continue // Empty values are fine } // Check if it looks like a plaintext secret (not env var reference) if !strings.HasPrefix(field.value, "${") && LooksLikeSecret(field.value) { return fmt.Errorf( "%s appears to contain a plaintext secret (length=%d, entropy=%.2f); "+ "use ${ENV_VAR} syntax to load from environment or secrets manager", field.name, len(field.value), CalculateEntropy(field.value), ) } } return nil } // validateHIPAARequirements enforces hard HIPAA compliance requirements at startup. // These must fail loudly rather than silently fall back to insecure defaults. func (c *Config) validateHIPAARequirements() error { // 1. SnapshotStore must be secure if c.SnapshotStore.Enabled && !c.SnapshotStore.Secure { return fmt.Errorf("snapshot_store.secure must be true in HIPAA mode") } // 2. NetworkMode must be "none" (no network access) if c.Sandbox.NetworkMode != "none" { return fmt.Errorf("sandbox.network_mode must be 'none' in HIPAA mode, got %q", c.Sandbox.NetworkMode) } // 3. SeccompProfile must be non-empty if c.Sandbox.SeccompProfile == "" { return fmt.Errorf("sandbox.seccomp_profile must be non-empty in HIPAA mode") } // 4. NoNewPrivileges must be true if !c.Sandbox.NoNewPrivileges { return fmt.Errorf("sandbox.no_new_privileges must be true in HIPAA mode") } // 5. All credentials must be sourced from env vars, not inline YAML if err := c.validateNoInlineCredentials(); err != nil { return err } // 6. AllowedSecrets must not contain PHI field names if err := c.Sandbox.validatePHIDenylist(); err != nil { return err } return nil } // validateNoInlineCredentials checks that no credentials are hardcoded in config func (c *Config) validateNoInlineCredentials() error { // Check Redis password - must be empty or use env var syntax if c.RedisPassword != "" && !strings.HasPrefix(c.RedisPassword, "${") { return fmt.Errorf("redis_password must use ${ENV_VAR} syntax in HIPAA mode, not inline value") } // Check SSH key - must use env var syntax if c.SSHKey != "" && !strings.HasPrefix(c.SSHKey, "${") { return fmt.Errorf("ssh_key must use ${ENV_VAR} syntax in HIPAA mode, not inline value") } // Check SnapshotStore credentials if c.SnapshotStore.AccessKey != "" && !strings.HasPrefix(c.SnapshotStore.AccessKey, "${") { return fmt.Errorf("snapshot_store.access_key must use ${ENV_VAR} syntax in HIPAA mode") } if c.SnapshotStore.SecretKey != "" && !strings.HasPrefix(c.SnapshotStore.SecretKey, "${") { return fmt.Errorf("snapshot_store.secret_key must use ${ENV_VAR} syntax in HIPAA mode") } return nil } // PHI field patterns that should not appear in AllowedSecrets var phiDenylistPatterns = []string{ "patient", "phi", "ssn", "social_security", "mrn", "medical_record", "dob", "birth_date", "diagnosis", "condition", "medication", "allergy", } // validatePHIDenylist checks that AllowedSecrets doesn't contain PHI field names func (s *SandboxConfig) validatePHIDenylist() error { for _, secret := range s.AllowedSecrets { secretLower := strings.ToLower(secret) for _, pattern := range phiDenylistPatterns { if strings.Contains(secretLower, pattern) { return fmt.Errorf("allowed_secrets contains potential PHI field %q (matches pattern %q); this could allow PHI exfiltration", secret, pattern) } } } return nil } // LooksLikeSecret heuristically detects if a string looks like a secret credential // Exported for testing purposes func LooksLikeSecret(s string) bool { // Minimum length for secrets if len(s) < 16 { return false } // Calculate entropy to detect high-entropy strings (likely secrets) entropy := CalculateEntropy(s) // High entropy (>4 bits per char) combined with reasonable length suggests a secret if entropy > 4.0 { return true } // Check for common secret patterns patterns := []string{ "AKIA", // AWS Access Key ID prefix "ASIA", // AWS temporary credentials "ghp_", // GitHub personal access token "gho_", // GitHub OAuth token "glpat-", // GitLab PAT "sk-", // OpenAI/Stripe key prefix "sk_live_", // Stripe live key "sk_test_", // Stripe test key } for _, pattern := range patterns { if strings.Contains(s, pattern) { return true } } return false } // CalculateEntropy calculates Shannon entropy of a string in bits per character // Exported for testing purposes func CalculateEntropy(s string) float64 { if len(s) == 0 { return 0 } // Count character frequencies freq := make(map[rune]int) for _, r := range s { freq[r]++ } // Calculate entropy var entropy float64 length := float64(len(s)) for _, count := range freq { p := float64(count) / length if p > 0 { entropy -= p * math.Log2(p) } } return entropy } // ComputeResolvedConfigHash computes a SHA-256 hash of the resolved config. // This must be called after os.ExpandEnv, after default application, and after Validate(). // The hash captures the actual runtime configuration, not the raw YAML file. // This is critical for reproducibility - two different raw files that resolve // to the same config will produce the same hash. func (c *Config) ComputeResolvedConfigHash() (string, error) { // Marshal config to JSON for consistent serialization // We use a simplified struct to avoid hashing volatile fields hashable := struct { Host string `json:"host"` Port int `json:"port"` BasePath string `json:"base_path"` MaxWorkers int `json:"max_workers"` Resources config.ResourceConfig `json:"resources"` GPUVendor string `json:"gpu_vendor"` GPUVisibleDevices []int `json:"gpu_visible_devices,omitempty"` GPUVisibleDeviceIDs []string `json:"gpu_visible_device_ids,omitempty"` Sandbox SandboxConfig `json:"sandbox"` ComplianceMode string `json:"compliance_mode"` ProvenanceBestEffort bool `json:"provenance_best_effort"` SnapshotStoreSecure bool `json:"snapshot_store_secure,omitempty"` QueueBackend string `json:"queue_backend"` }{ Host: c.Host, Port: c.Port, BasePath: c.BasePath, MaxWorkers: c.MaxWorkers, Resources: c.Resources, GPUVendor: c.GPUVendor, GPUVisibleDevices: c.GPUVisibleDevices, GPUVisibleDeviceIDs: c.GPUVisibleDeviceIDs, Sandbox: c.Sandbox, ComplianceMode: c.ComplianceMode, ProvenanceBestEffort: c.ProvenanceBestEffort, SnapshotStoreSecure: c.SnapshotStore.Secure, QueueBackend: c.Queue.Backend, } data, err := json.Marshal(hashable) if err != nil { return "", fmt.Errorf("failed to marshal config for hashing: %w", err) } // Compute SHA-256 hash hash := sha256.Sum256(data) return hex.EncodeToString(hash[:]), nil } // envInt reads an integer from environment variable func envInt(name string) (int, bool) { v := strings.TrimSpace(os.Getenv(name)) if v == "" { return 0, false } n, err := strconv.Atoi(v) if err != nil { return 0, false } return n, true } // logEnvOverride logs environment variable overrides to stderr for debugging func logEnvOverride(name string, value interface{}) { slog.Warn("env override active", "var", name, "value", value) } // parseCPUFromConfig determines total CPU from environment or config func parseCPUFromConfig(cfg *Config) int { if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 { logEnvOverride("FETCH_ML_TOTAL_CPU", n) return n } if cfg != nil { if cfg.Resources.PodmanCPUs != "" { if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil { if f < 0 { return 0 } return int(math.Floor(f)) } } } return runtime.NumCPU() } // parseGPUCountFromConfig detects GPU count from config and returns detection metadata func parseGPUCountFromConfig(cfg *Config) (int, GPUDetectionInfo) { factory := &GPUDetectorFactory{} result := factory.CreateDetectorWithInfo(cfg) return result.Detector.DetectGPUCount(), result.Info } // parseGPUSlotsPerGPUFromConfig reads GPU slots per GPU from environment func parseGPUSlotsPerGPUFromConfig() int { if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 { return n } return 1 }