fetch_ml/internal/worker/config.go
Jeremie Fraeys 3fb6902fa1
feat(worker): integrate scheduler endpoints and security hardening
Update worker system for scheduler integration:
- Worker server with scheduler registration
- Configuration with scheduler endpoint support
- Artifact handling with integrity verification
- Container executor with supply chain validation
- Local executor enhancements
- GPU detection improvements (cross-platform)
- Error handling with execution context
- Factory pattern for executor instantiation
- Hash integrity with native library support
2026-02-26 12:06:16 -05:00

979 lines
32 KiB
Go

package worker
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"log/slog"
"math"
"net/url"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
"gopkg.in/yaml.v3"
)
// Fallback durations that LoadConfig applies when the matching Config field
// is left at its zero value.
const (
	// defaultMetricsFlushInterval is the fallback for Config.MetricsFlushInterval.
	defaultMetricsFlushInterval = 500 * time.Millisecond
	// datasetCacheDefaultTTL is the fallback for Config.DatasetCacheTTL.
	datasetCacheDefaultTTL = 30 * time.Minute
)
// QueueConfig selects the task queue backend and the locations used by the
// file-backed backends.
type QueueConfig struct {
	// Backend names the queue implementation. LoadConfig and Config.Validate
	// both default a blank value to the Redis backend.
	Backend string `yaml:"backend"`
	// SQLitePath is used by the SQLite backend; LoadConfig defaults it to
	// "queue.db" inside the data directory.
	SQLitePath string `yaml:"sqlite_path"`
	// FilesystemPath is used by the filesystem backend; LoadConfig defaults it
	// to "queue-fs" inside the data directory.
	FilesystemPath string `yaml:"filesystem_path"`
	// FallbackToFilesystem makes FilesystemPath defaulted/required even when a
	// different backend is selected.
	FallbackToFilesystem bool `yaml:"fallback_to_filesystem"`
}
// Config holds worker configuration.
//
// It is decoded from YAML by LoadConfig, which also fills in defaults for most
// zero-valued fields, and it is checked (and partly normalized) by Validate.
type Config struct {
	// Host defaults to the smart-default host when empty.
	Host string `yaml:"host"`
	User string `yaml:"user"`
	// SSHKey must use ${ENV_VAR} syntax in HIPAA compliance mode.
	SSHKey string `yaml:"ssh_key"`
	// Port is validated as the SSH port; defaults to config.DefaultSSHPort.
	Port int `yaml:"port"`
	// BasePath defaults to the PathRegistry experiments directory; Validate
	// expands it and makes it absolute relative to the working directory.
	BasePath string `yaml:"base_path"`
	TrainScript string `yaml:"train_script"`
	// RedisURL, when non-blank, is env-expanded by LoadConfig and copied into
	// RedisAddr; RedisPassword and RedisDB are then reset.
	RedisURL string `yaml:"redis_url"`
	RedisAddr string `yaml:"redis_addr"`
	// RedisPassword is treated as a secret: see ExpandSecrets.
	RedisPassword string `yaml:"redis_password"`
	RedisDB int `yaml:"redis_db"`
	Queue QueueConfig `yaml:"queue"`
	KnownHosts string `yaml:"known_hosts"`
	// WorkerID defaults to "worker-" plus the first 8 characters of a new UUID.
	WorkerID string `yaml:"worker_id"`
	// MaxWorkers is mirrored into Resources.MaxWorkers by LoadConfig and must
	// be at least 1 to pass Validate.
	MaxWorkers int `yaml:"max_workers"`
	// PollInterval is expressed in seconds (see the YAML key).
	PollInterval int `yaml:"poll_interval_seconds"`
	Resources config.ResourceConfig `yaml:"resources"`
	LocalMode bool `yaml:"local_mode"`
	// Authentication
	Auth auth.Config `yaml:"auth"`
	// Metrics exporter
	Metrics MetricsConfig `yaml:"metrics"`
	// Metrics buffering (default: defaultMetricsFlushInterval)
	MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`
	// Data management
	DataManagerPath string `yaml:"data_manager_path"` // default: "./data_manager"
	AutoFetchData bool `yaml:"auto_fetch_data"`
	DataDir string `yaml:"data_dir"` // default: PathRegistry data directory
	DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"` // default: datasetCacheDefaultTTL
	SnapshotStore SnapshotStoreConfig `yaml:"snapshot_store"`
	// Provenance enforcement
	// Default: fail-closed (trustworthiness-by-default). Set true to opt into best-effort.
	ProvenanceBestEffort bool `yaml:"provenance_best_effort"`
	// Compliance mode: "hipaa", "standard", or empty
	// When "hipaa": Validate additionally enforces the hard requirements in
	// validateHIPAARequirements.
	ComplianceMode string `yaml:"compliance_mode"`
	// Opt-in prewarming of next task artifacts (snapshot/datasets/env).
	PrewarmEnabled bool `yaml:"prewarm_enabled"`
	// Podman execution
	PodmanImage string `yaml:"podman_image"`
	ContainerWorkspace string `yaml:"container_workspace"`
	ContainerResults string `yaml:"container_results"`
	GPUDevices []string `yaml:"gpu_devices"`
	// GPUVendor must be one of the vendors accepted by Validate; when blank,
	// LoadConfig infers it from AppleGPU and the GPU device lists.
	GPUVendor string `yaml:"gpu_vendor"`
	GPUVendorAutoDetected bool `yaml:"-"` // Set by LoadConfig when GPUVendor is auto-detected
	// GPUVisibleDevices (indices) and GPUVisibleDeviceIDs ("GPU-"-prefixed
	// IDs) are alternative ways to restrict visible GPUs; see Validate.
	GPUVisibleDevices []int `yaml:"gpu_visible_devices"`
	GPUVisibleDeviceIDs []string `yaml:"gpu_visible_device_ids"`
	// Apple M-series GPU configuration
	AppleGPU AppleGPUConfig `yaml:"apple_gpu"`
	// Task lease and retry settings. LoadConfig replaces a zero value with the
	// stated default, so zero cannot be configured explicitly.
	TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // Worker lease (default: 30min)
	HeartbeatInterval time.Duration `yaml:"heartbeat_interval"` // Renew lease (default: 1min)
	MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3)
	GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Shutdown timeout (default: 5min)
	// Mode determines how the worker operates: "standalone" or "distributed"
	Mode string `yaml:"mode"`
	// Scheduler configuration for distributed mode
	Scheduler struct {
		Address string `yaml:"address"`
		Cert string `yaml:"cert"`
		Token string `yaml:"token"`
	} `yaml:"scheduler"`
	// Plugins configuration
	Plugins map[string]factory.PluginConfig `yaml:"plugins"`
	// Sandboxing configuration; LoadConfig runs ApplySecurityDefaults on it.
	Sandbox SandboxConfig `yaml:"sandbox"`
}
// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct {
	// Enabled toggles the exporter.
	Enabled bool `yaml:"enabled"`
	// ListenAddr is the exporter's listen address; LoadConfig defaults it to ":9100".
	ListenAddr string `yaml:"listen_addr"`
}
// SnapshotStoreConfig describes the remote snapshot store.
//
// When Enabled is true, Config.Validate requires Endpoint and Bucket, requires
// AccessKey and SecretKey to be either both set or both empty, and rejects
// negative Timeout/MaxRetries.
type SnapshotStoreConfig struct {
	Enabled bool `yaml:"enabled"`
	Endpoint string `yaml:"endpoint"`
	// Secure must be true in HIPAA compliance mode when the store is enabled.
	Secure bool `yaml:"secure"`
	Region string `yaml:"region"`
	Bucket string `yaml:"bucket"`
	Prefix string `yaml:"prefix"`
	// AccessKey, SecretKey and SessionToken are treated as secrets: values
	// containing "${" are expanded from the environment by ExpandSecrets.
	AccessKey string `yaml:"access_key"`
	SecretKey string `yaml:"secret_key"`
	SessionToken string `yaml:"session_token"`
	// Timeout defaults to 10 minutes when zero (LoadConfig).
	Timeout time.Duration `yaml:"timeout"`
	// MaxRetries defaults to 3 when zero (LoadConfig).
	MaxRetries int `yaml:"max_retries"`
}
// AppleGPUConfig holds configuration for Apple M-series GPU support.
type AppleGPUConfig struct {
	// Enabled makes LoadConfig select the Apple GPU vendor when gpu_vendor is blank.
	Enabled bool `yaml:"enabled"`
	// MetalDevice and MPSRuntime are passed through unchanged by this file.
	// NOTE(review): presumably a Metal device path and an MPS runtime selector — confirm with the consumer.
	MetalDevice string `yaml:"metal_device"`
	MPSRuntime string `yaml:"mps_runtime"`
}
// SandboxConfig holds container sandbox settings.
//
// Zero values are replaced by ApplySecurityDefaults using SecurityDefaults.
// Because a false bool cannot be told apart from "unset", ApplySecurityDefaults
// assigns the default to ReadOnlyRoot, NoNewPrivileges, DropAllCaps, UserNS and
// DisableSwap whenever they are false; with the stock SecurityDefaults (all
// true) a false value in YAML therefore does not survive LoadConfig. Likewise
// a RunAsUID/RunAsGID of 0 is replaced by the default.
type SandboxConfig struct {
	NetworkMode string `yaml:"network_mode"` // "none", "slirp4netns" or "bridge"; default: "none"
	ReadOnlyRoot bool `yaml:"read_only_root"` // Default: true
	AllowSecrets bool `yaml:"allow_secrets"` // Default: false
	AllowedSecrets []string `yaml:"allowed_secrets"` // e.g., ["HF_TOKEN", "WANDB_API_KEY"]
	SeccompProfile string `yaml:"seccomp_profile"` // Default: "default-hardened"
	MaxRuntimeHours int `yaml:"max_runtime_hours"` // No default is applied in this file
	// Security hardening options
	NoNewPrivileges bool `yaml:"no_new_privileges"` // Default: true
	DropAllCaps bool `yaml:"drop_all_caps"` // Default: true
	AllowedCaps []string `yaml:"allowed_caps"` // Capabilities to add back
	UserNS bool `yaml:"user_ns"` // Default: true
	RunAsUID int `yaml:"run_as_uid"` // Default: 1000
	RunAsGID int `yaml:"run_as_gid"` // Default: 1000
	// Process isolation
	MaxProcesses int `yaml:"max_processes"` // Fork bomb protection (default: 100)
	MaxOpenFiles int `yaml:"max_open_files"` // FD exhaustion protection (default: 1024)
	DisableSwap bool `yaml:"disable_swap"` // Prevent swap exfiltration (default: true)
	OOMScoreAdj int `yaml:"oom_score_adj"` // OOM killer priority (default: 100)
	TaskUID int `yaml:"task_uid"` // Per-task UID (0 = use RunAsUID)
	TaskGID int `yaml:"task_gid"` // Per-task GID (0 = use RunAsGID)
	// Upload limits
	MaxUploadSizeBytes int64 `yaml:"max_upload_size_bytes"` // Default: 10GB
	MaxUploadRateBps int64 `yaml:"max_upload_rate_bps"` // Default: 100MB/s
	MaxUploadsPerMinute int `yaml:"max_uploads_per_minute"` // Default: 10
	// Artifact ingestion caps
	MaxArtifactFiles int `yaml:"max_artifact_files"` // Default: 10000
	MaxArtifactTotalBytes int64 `yaml:"max_artifact_total_bytes"` // Default: 100GB
}
// SecurityDefaults holds the default values that
// SandboxConfig.ApplySecurityDefaults copies into zero-valued sandbox fields.
// It is a package-level variable, so changing it changes the defaults applied
// by every later ApplySecurityDefaults call.
var SecurityDefaults = struct {
	NetworkMode string
	ReadOnlyRoot bool
	AllowSecrets bool
	SeccompProfile string
	NoNewPrivileges bool
	DropAllCaps bool
	UserNS bool
	RunAsUID int
	RunAsGID int
	MaxProcesses int
	MaxOpenFiles int
	DisableSwap bool
	OOMScoreAdj int
	MaxUploadSizeBytes int64
	MaxUploadRateBps int64
	MaxUploadsPerMinute int
	MaxArtifactFiles int
	MaxArtifactTotalBytes int64
}{
	NetworkMode: "none", // No network access
	ReadOnlyRoot: true,
	AllowSecrets: false,
	SeccompProfile: "default-hardened",
	NoNewPrivileges: true,
	DropAllCaps: true,
	UserNS: true,
	RunAsUID: 1000, // Non-root user
	RunAsGID: 1000, // Non-root group
	MaxProcesses: 100, // Fork bomb protection
	MaxOpenFiles: 1024, // FD exhaustion protection
	DisableSwap: true, // Prevent swap exfiltration
	OOMScoreAdj: 100, // OOM score adjustment applied to tasks
	MaxUploadSizeBytes: 10 * 1024 * 1024 * 1024, // 10GB
	MaxUploadRateBps: 100 * 1024 * 1024, // 100MB/s
	MaxUploadsPerMinute: 10,
	MaxArtifactFiles: 10000,
	MaxArtifactTotalBytes: 100 * 1024 * 1024 * 1024, // 100GB
}
// Validate checks the sandbox configuration for values that can never be valid.
//
// NetworkMode must be "none", "slirp4netns", "bridge" or empty (the empty
// value is the one ApplySecurityDefaults replaces with the default). Every
// numeric limit must be zero or greater; zero is accepted because it is the
// "unset" value that ApplySecurityDefaults fills in.
//
// It returns an error describing the first violation found, or nil.
func (s *SandboxConfig) Validate() error {
	switch s.NetworkMode {
	case "", "none", "slirp4netns", "bridge":
		// ok
	default:
		return fmt.Errorf("invalid network_mode: %s", s.NetworkMode)
	}
	// The checks below reject negatives only, so the messages say ">= 0"
	// rather than "positive" (which would wrongly imply that 0 is invalid).
	if s.MaxRuntimeHours < 0 {
		return fmt.Errorf("max_runtime_hours must be >= 0")
	}
	if s.MaxUploadSizeBytes < 0 {
		return fmt.Errorf("max_upload_size_bytes must be >= 0")
	}
	if s.MaxUploadRateBps < 0 {
		return fmt.Errorf("max_upload_rate_bps must be >= 0")
	}
	if s.MaxUploadsPerMinute < 0 {
		return fmt.Errorf("max_uploads_per_minute must be >= 0")
	}
	if s.MaxArtifactFiles < 0 {
		return fmt.Errorf("max_artifact_files must be >= 0")
	}
	if s.MaxArtifactTotalBytes < 0 {
		return fmt.Errorf("max_artifact_total_bytes must be >= 0")
	}
	return nil
}
// ApplySecurityDefaults fills every unset (zero-valued) sandbox field with the
// corresponding value from SecurityDefaults, implementing the "secure by
// default" principle for HIPAA compliance.
//
// A false bool counts as unset, so each bool field ends up as
// "configured value OR default". Negative TaskUID/TaskGID are clamped to 0,
// which means "use RunAsUID/RunAsGID".
func (s *SandboxConfig) ApplySecurityDefaults() {
	def := &SecurityDefaults

	// Small helpers: overwrite the target only when it still holds its zero value.
	fillString := func(target *string, fallback string) {
		if *target == "" {
			*target = fallback
		}
	}
	fillInt := func(target *int, fallback int) {
		if *target == 0 {
			*target = fallback
		}
	}
	fillInt64 := func(target *int64, fallback int64) {
		if *target == 0 {
			*target = fallback
		}
	}

	// Network isolation and seccomp profile.
	fillString(&s.NetworkMode, def.NetworkMode)
	fillString(&s.SeccompProfile, def.SeccompProfile)

	// Boolean hardening switches: a false value takes the default.
	s.ReadOnlyRoot = s.ReadOnlyRoot || def.ReadOnlyRoot
	s.AllowSecrets = s.AllowSecrets || def.AllowSecrets
	s.NoNewPrivileges = s.NoNewPrivileges || def.NoNewPrivileges
	s.DropAllCaps = s.DropAllCaps || def.DropAllCaps
	s.UserNS = s.UserNS || def.UserNS
	s.DisableSwap = s.DisableSwap || def.DisableSwap

	// Default non-root UID/GID.
	fillInt(&s.RunAsUID, def.RunAsUID)
	fillInt(&s.RunAsGID, def.RunAsGID)

	// Upload limits.
	fillInt64(&s.MaxUploadSizeBytes, def.MaxUploadSizeBytes)
	fillInt64(&s.MaxUploadRateBps, def.MaxUploadRateBps)
	fillInt(&s.MaxUploadsPerMinute, def.MaxUploadsPerMinute)

	// Artifact ingestion caps.
	fillInt(&s.MaxArtifactFiles, def.MaxArtifactFiles)
	fillInt64(&s.MaxArtifactTotalBytes, def.MaxArtifactTotalBytes)

	// Process isolation limits.
	fillInt(&s.MaxProcesses, def.MaxProcesses)
	fillInt(&s.MaxOpenFiles, def.MaxOpenFiles)
	fillInt(&s.OOMScoreAdj, def.OOMScoreAdj)

	// TaskUID/TaskGID keep 0 as "use RunAsUID/RunAsGID"; negatives are clamped to 0.
	if s.TaskUID < 0 {
		s.TaskUID = 0
	}
	if s.TaskGID < 0 {
		s.TaskGID = 0
	}
}
// GetProcessIsolationFlags returns the effective UID and GID for a task.
// A per-task TaskUID/TaskGID greater than zero wins; otherwise the sandbox-wide
// RunAsUID/RunAsGID is used. The two IDs are resolved independently.
func (s *SandboxConfig) GetProcessIsolationFlags() (uid, gid int) {
	uid, gid = s.TaskUID, s.TaskGID
	if uid <= 0 {
		uid = s.RunAsUID
	}
	if gid <= 0 {
		gid = s.RunAsGID
	}
	return uid, gid
}
// Getter methods for SandboxConfig interface.

// GetNoNewPrivileges reports the NoNewPrivileges setting.
func (s *SandboxConfig) GetNoNewPrivileges() bool {
	return s.NoNewPrivileges
}

// GetDropAllCaps reports the DropAllCaps setting.
func (s *SandboxConfig) GetDropAllCaps() bool {
	return s.DropAllCaps
}

// GetAllowedCaps returns the AllowedCaps slice itself (not a copy).
func (s *SandboxConfig) GetAllowedCaps() []string {
	return s.AllowedCaps
}

// GetUserNS reports the UserNS setting.
func (s *SandboxConfig) GetUserNS() bool {
	return s.UserNS
}

// GetRunAsUID returns the sandbox-wide RunAsUID.
func (s *SandboxConfig) GetRunAsUID() int {
	return s.RunAsUID
}

// GetRunAsGID returns the sandbox-wide RunAsGID.
func (s *SandboxConfig) GetRunAsGID() int {
	return s.RunAsGID
}

// GetSeccompProfile returns the configured seccomp profile.
func (s *SandboxConfig) GetSeccompProfile() string {
	return s.SeccompProfile
}

// GetReadOnlyRoot reports the ReadOnlyRoot setting.
func (s *SandboxConfig) GetReadOnlyRoot() bool {
	return s.ReadOnlyRoot
}

// GetNetworkMode returns the configured network mode.
func (s *SandboxConfig) GetNetworkMode() string {
	return s.NetworkMode
}

// Process isolation getter methods.

// GetMaxProcesses returns the MaxProcesses limit.
func (s *SandboxConfig) GetMaxProcesses() int {
	return s.MaxProcesses
}

// GetMaxOpenFiles returns the MaxOpenFiles limit.
func (s *SandboxConfig) GetMaxOpenFiles() int {
	return s.MaxOpenFiles
}

// GetDisableSwap reports the DisableSwap setting.
func (s *SandboxConfig) GetDisableSwap() bool {
	return s.DisableSwap
}

// GetOOMScoreAdj returns the OOMScoreAdj setting.
func (s *SandboxConfig) GetOOMScoreAdj() int {
	return s.OOMScoreAdj
}

// GetTaskUID returns the raw per-task UID (0 means "use RunAsUID").
func (s *SandboxConfig) GetTaskUID() int {
	return s.TaskUID
}

// GetTaskGID returns the raw per-task GID (0 means "use RunAsGID").
func (s *SandboxConfig) GetTaskGID() int {
	return s.TaskGID
}
// LoadConfig loads worker configuration from a YAML file.
//
// The file at path is read through fileutil.SecureFileRead and decoded into a
// Config. Zero-valued fields are then filled with defaults (smart defaults for
// the current environment, PathRegistry locations, and the constants in this
// file), the sandbox receives its security defaults, and finally secret
// placeholders are expanded from the environment via ExpandSecrets.
//
// Because defaults are applied to zero values, a setting such as max_retries
// or poll_interval_seconds cannot be configured as an explicit 0.
//
// LoadConfig does not call Validate. It returns the read/parse error
// unchanged, or a wrapped error when a smart default or the secrets expansion
// fails.
func LoadConfig(path string) (*Config, error) {
	data, err := fileutil.SecureFileRead(path)
	if err != nil {
		return nil, err
	}
	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return nil, err
	}
	// A Redis URL takes precedence over the discrete address/password/db
	// settings: it becomes the address and the other two are reset.
	if strings.TrimSpace(cfg.RedisURL) != "" {
		cfg.RedisURL = os.ExpandEnv(strings.TrimSpace(cfg.RedisURL))
		cfg.RedisAddr = cfg.RedisURL
		cfg.RedisPassword = ""
		cfg.RedisDB = 0
	}
	// Get smart defaults for current environment
	smart := config.GetSmartDefaults()
	// Use PathRegistry for consistent path management
	paths := config.FromEnv()
	if cfg.Port == 0 {
		cfg.Port = config.DefaultSSHPort
	}
	if cfg.Host == "" {
		host, err := smart.Host()
		if err != nil {
			return nil, fmt.Errorf("failed to get default host: %w", err)
		}
		cfg.Host = host
	}
	if cfg.BasePath == "" {
		// Prefer PathRegistry over smart defaults for consistency
		cfg.BasePath = paths.ExperimentsDir()
	}
	if cfg.RedisAddr == "" {
		redisAddr, err := smart.RedisAddr()
		if err != nil {
			return nil, fmt.Errorf("failed to get default redis address: %w", err)
		}
		cfg.RedisAddr = redisAddr
	}
	if cfg.KnownHosts == "" {
		knownHosts, err := smart.KnownHostsPath()
		if err != nil {
			return nil, fmt.Errorf("failed to get default known hosts path: %w", err)
		}
		cfg.KnownHosts = knownHosts
	}
	if cfg.WorkerID == "" {
		// Random ID: "worker-" plus the first 8 characters of a fresh UUID.
		cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
	}
	cfg.Resources.ApplyDefaults()
	// Keep the top-level MaxWorkers and Resources.MaxWorkers in sync; a
	// non-positive configured value falls back to the smart default.
	if cfg.MaxWorkers > 0 {
		cfg.Resources.MaxWorkers = cfg.MaxWorkers
	} else {
		maxWorkers, err := smart.MaxWorkers()
		if err != nil {
			return nil, fmt.Errorf("failed to get default max workers: %w", err)
		}
		cfg.MaxWorkers = maxWorkers
		cfg.Resources.MaxWorkers = maxWorkers
	}
	if cfg.PollInterval == 0 {
		pollInterval, err := smart.PollInterval()
		if err != nil {
			return nil, fmt.Errorf("failed to get default poll interval: %w", err)
		}
		cfg.PollInterval = pollInterval
	}
	if cfg.DataManagerPath == "" {
		cfg.DataManagerPath = "./data_manager"
	}
	if cfg.DataDir == "" {
		// Use PathRegistry for consistent data directory
		cfg.DataDir = paths.DataDir()
	}
	if cfg.SnapshotStore.Timeout == 0 {
		cfg.SnapshotStore.Timeout = 10 * time.Minute
	}
	if cfg.SnapshotStore.MaxRetries == 0 {
		cfg.SnapshotStore.MaxRetries = 3
	}
	if cfg.Metrics.ListenAddr == "" {
		cfg.Metrics.ListenAddr = ":9100"
	}
	if cfg.MetricsFlushInterval == 0 {
		cfg.MetricsFlushInterval = defaultMetricsFlushInterval
	}
	if cfg.DatasetCacheTTL == 0 {
		cfg.DatasetCacheTTL = datasetCacheDefaultTTL
	}
	// Queue backend defaults. These depend on cfg.DataDir, which must already
	// have been resolved above.
	if strings.TrimSpace(cfg.Queue.Backend) == "" {
		cfg.Queue.Backend = string(queue.QueueBackendRedis)
	}
	if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendSQLite)) {
		if strings.TrimSpace(cfg.Queue.SQLitePath) == "" {
			cfg.Queue.SQLitePath = filepath.Join(cfg.DataDir, "queue.db")
		}
		cfg.Queue.SQLitePath = storage.ExpandPath(cfg.Queue.SQLitePath)
	}
	if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendFS)) || cfg.Queue.FallbackToFilesystem {
		if strings.TrimSpace(cfg.Queue.FilesystemPath) == "" {
			cfg.Queue.FilesystemPath = filepath.Join(cfg.DataDir, "queue-fs")
		}
		cfg.Queue.FilesystemPath = storage.ExpandPath(cfg.Queue.FilesystemPath)
	}
	// Infer the GPU vendor from the rest of the config when it is not given:
	// Apple if apple_gpu is enabled, NVIDIA if any GPU device list is
	// non-empty, otherwise none. No hardware probing happens here.
	if strings.TrimSpace(cfg.GPUVendor) == "" {
		cfg.GPUVendorAutoDetected = true
		if cfg.AppleGPU.Enabled {
			cfg.GPUVendor = string(GPUTypeApple)
		} else if len(cfg.GPUDevices) > 0 ||
			len(cfg.GPUVisibleDevices) > 0 ||
			len(cfg.GPUVisibleDeviceIDs) > 0 {
			cfg.GPUVendor = string(GPUTypeNVIDIA)
		} else {
			cfg.GPUVendor = string(GPUTypeNone)
		}
	}
	// Set lease and retry defaults
	if cfg.TaskLeaseDuration == 0 {
		cfg.TaskLeaseDuration = 30 * time.Minute
	}
	if cfg.HeartbeatInterval == 0 {
		cfg.HeartbeatInterval = 1 * time.Minute
	}
	if cfg.MaxRetries == 0 {
		cfg.MaxRetries = 3
	}
	if cfg.GracefulTimeout == 0 {
		cfg.GracefulTimeout = 5 * time.Minute
	}
	// Apply security defaults to sandbox configuration
	cfg.Sandbox.ApplySecurityDefaults()
	// Expand secrets from environment variables. This first rejects values
	// that look like plaintext secrets (see ValidateNoPlaintextSecrets).
	if err := cfg.ExpandSecrets(); err != nil {
		return nil, fmt.Errorf("secrets expansion failed: %w", err)
	}
	return &cfg, nil
}
// Validate implements config.Validator interface.
//
// Besides checking values it normalizes a few fields in place: BasePath is
// expanded and made absolute relative to the current working directory, a
// blank queue backend becomes the Redis backend, and the SQLite/filesystem
// queue paths are expanded and, when still relative, joined onto
// config.DefaultLocalDataDir.
//
// Checks performed, in order: SSH port, queue backend and its paths, Redis
// address (either a redis:// URL or a plain address), max_workers >= 1, GPU
// vendor and GPU visibility settings, snapshot store settings when enabled,
// and the HIPAA hard requirements when compliance_mode is "hipaa"
// (case-insensitive, surrounding whitespace ignored).
//
// It returns an error describing the first violation found, or nil.
func (c *Config) Validate() error {
	if c.Port != 0 {
		if err := config.ValidatePort(c.Port); err != nil {
			return fmt.Errorf("invalid SSH port: %w", err)
		}
	}
	if c.BasePath != "" {
		// Convert relative paths to absolute
		c.BasePath = storage.ExpandPath(c.BasePath)
		if !filepath.IsAbs(c.BasePath) {
			// Resolve relative to current working directory, not DefaultBasePath
			cwd, err := os.Getwd()
			if err != nil {
				return fmt.Errorf("failed to get current directory: %w", err)
			}
			c.BasePath = filepath.Join(cwd, c.BasePath)
		}
	}
	backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend))
	if backend == "" {
		backend = string(queue.QueueBackendRedis)
		c.Queue.Backend = backend
	}
	if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) && backend != string(queue.QueueBackendFS) {
		return fmt.Errorf("queue.backend must be one of %q, %q, or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite, queue.QueueBackendFS)
	}
	if backend == string(queue.QueueBackendSQLite) {
		if strings.TrimSpace(c.Queue.SQLitePath) == "" {
			return fmt.Errorf("queue.sqlite_path is required when queue.backend is %q", queue.QueueBackendSQLite)
		}
		c.Queue.SQLitePath = storage.ExpandPath(c.Queue.SQLitePath)
		if !filepath.IsAbs(c.Queue.SQLitePath) {
			c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
		}
	}
	if backend == string(queue.QueueBackendFS) || c.Queue.FallbackToFilesystem {
		if strings.TrimSpace(c.Queue.FilesystemPath) == "" {
			return fmt.Errorf("queue.filesystem_path is required when filesystem queue is enabled")
		}
		c.Queue.FilesystemPath = storage.ExpandPath(c.Queue.FilesystemPath)
		if !filepath.IsAbs(c.Queue.FilesystemPath) {
			c.Queue.FilesystemPath = filepath.Join(config.DefaultLocalDataDir, c.Queue.FilesystemPath)
		}
	}
	if c.RedisAddr != "" {
		addr := strings.TrimSpace(c.RedisAddr)
		if strings.HasPrefix(addr, "redis://") {
			u, err := url.Parse(addr)
			if err != nil {
				return fmt.Errorf("invalid Redis configuration: invalid redis url: %w", err)
			}
			if u.Scheme != "redis" || strings.TrimSpace(u.Host) == "" {
				return fmt.Errorf("invalid Redis configuration: invalid redis url")
			}
		} else {
			if err := config.ValidateRedisAddr(addr); err != nil {
				return fmt.Errorf("invalid Redis configuration: %w", err)
			}
		}
	}
	if c.MaxWorkers < 1 {
		return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers)
	}
	switch strings.ToLower(strings.TrimSpace(c.GPUVendor)) {
	case string(GPUTypeNVIDIA), string(GPUTypeApple), string(GPUTypeNone), "amd":
		// ok
	default:
		return fmt.Errorf(
			"gpu_vendor must be one of %q, %q, %q, %q",
			string(GPUTypeNVIDIA),
			"amd",
			string(GPUTypeApple),
			string(GPUTypeNone),
		)
	}
	// Strict GPU visibility configuration:
	// - gpu_visible_devices and gpu_visible_device_ids are mutually exclusive.
	// - UUID-style gpu_visible_device_ids is NVIDIA-only.
	vendor := strings.ToLower(strings.TrimSpace(c.GPUVendor))
	if len(c.GPUVisibleDevices) > 0 && len(c.GPUVisibleDeviceIDs) > 0 {
		return fmt.Errorf("gpu_visible_devices and gpu_visible_device_ids are mutually exclusive")
	}
	// Validate the IDs whenever they are supplied. Previously this ran only
	// when gpu_visible_devices was also set, so IDs given on their own were
	// never checked and the both-set combination was accepted for NVIDIA.
	if len(c.GPUVisibleDeviceIDs) > 0 {
		if vendor != string(GPUTypeNVIDIA) {
			return fmt.Errorf(
				"visible_device_ids is only supported when gpu_vendor is %q",
				string(GPUTypeNVIDIA),
			)
		}
		for _, id := range c.GPUVisibleDeviceIDs {
			id = strings.TrimSpace(id)
			if id == "" {
				return fmt.Errorf("visible_device_ids contains an empty value")
			}
			if !strings.HasPrefix(id, "GPU-") {
				return fmt.Errorf("gpu_visible_device_ids values must start with %q, got %q", "GPU-", id)
			}
		}
	}
	// Vendor-specific restrictions. The device-ID halves of these checks are
	// already covered by the NVIDIA-only rule above and remain as a backstop.
	if vendor == string(GPUTypeApple) || vendor == string(GPUTypeNone) {
		if len(c.GPUVisibleDevices) > 0 || len(c.GPUVisibleDeviceIDs) > 0 {
			return fmt.Errorf(
				"gpu_visible_devices and gpu_visible_device_ids are not supported when gpu_vendor is %q",
				vendor,
			)
		}
	}
	if vendor == "amd" {
		if len(c.GPUVisibleDeviceIDs) > 0 {
			return fmt.Errorf("gpu_visible_device_ids is not supported when gpu_vendor is %q", vendor)
		}
		for _, idx := range c.GPUVisibleDevices {
			if idx < 0 {
				return fmt.Errorf("gpu_visible_devices contains negative index %d", idx)
			}
		}
	}
	if c.SnapshotStore.Enabled {
		if strings.TrimSpace(c.SnapshotStore.Endpoint) == "" {
			return fmt.Errorf("snapshot_store.endpoint is required when snapshot_store.enabled is true")
		}
		if strings.TrimSpace(c.SnapshotStore.Bucket) == "" {
			return fmt.Errorf("snapshot_store.bucket is required when snapshot_store.enabled is true")
		}
		ak := strings.TrimSpace(c.SnapshotStore.AccessKey)
		sk := strings.TrimSpace(c.SnapshotStore.SecretKey)
		// Exactly one of the two being set is the only invalid combination.
		if (ak == "") != (sk == "") {
			return fmt.Errorf(
				"snapshot_store.access_key and snapshot_store.secret_key must both be set or both be empty",
			)
		}
		if c.SnapshotStore.Timeout < 0 {
			return fmt.Errorf("snapshot_store.timeout must be >= 0")
		}
		if c.SnapshotStore.MaxRetries < 0 {
			return fmt.Errorf("snapshot_store.max_retries must be >= 0")
		}
	}
	// HIPAA mode validation - hard requirements. Whitespace is trimmed so a
	// padded value cannot silently skip the compliance checks.
	if strings.ToLower(strings.TrimSpace(c.ComplianceMode)) == "hipaa" {
		if err := c.validateHIPAARequirements(); err != nil {
			return fmt.Errorf("HIPAA compliance validation failed: %w", err)
		}
	}
	return nil
}
// ExpandSecrets substitutes environment variables into the secret-bearing
// fields: the Redis password and the snapshot store access key, secret key and
// session token.
//
// It first runs ValidateNoPlaintextSecrets and returns its error unchanged.
// A field is expanded with os.ExpandEnv only when it contains "${"; other
// values are left untouched. os.ExpandEnv replaces references to unset
// variables with the empty string.
//
// Exported for testing purposes.
func (c *Config) ExpandSecrets() error {
	// Refuse to continue when a field looks like a hardcoded secret.
	if err := c.ValidateNoPlaintextSecrets(); err != nil {
		return err
	}
	secretFields := []*string{
		&c.RedisPassword,
		&c.SnapshotStore.AccessKey,
		&c.SnapshotStore.SecretKey,
		&c.SnapshotStore.SessionToken,
	}
	for _, field := range secretFields {
		if strings.Contains(*field, "${") {
			*field = os.ExpandEnv(*field)
		}
	}
	return nil
}
// ValidateNoPlaintextSecrets rejects sensitive fields that appear to hold a
// hardcoded secret instead of a ${ENV_VAR} reference. This is a HIPAA
// compliance requirement.
//
// The fields checked, in order, are redis_password and the snapshot store
// access_key, secret_key and session_token. A value passes when it is empty,
// starts with "${", or is not flagged by LooksLikeSecret. The error for the
// first offending field reports its name, length and entropy, not its value.
//
// Exported for testing purposes.
func (c *Config) ValidateNoPlaintextSecrets() error {
	type candidate struct {
		key string
		val string
	}
	candidates := [...]candidate{
		{key: "redis_password", val: c.RedisPassword},
		{key: "snapshot_store.access_key", val: c.SnapshotStore.AccessKey},
		{key: "snapshot_store.secret_key", val: c.SnapshotStore.SecretKey},
		{key: "snapshot_store.session_token", val: c.SnapshotStore.SessionToken},
	}
	for _, cand := range candidates {
		// Empty values and env-var references are fine; so is anything the
		// heuristic does not consider secret-like.
		if cand.val == "" || strings.HasPrefix(cand.val, "${") || !LooksLikeSecret(cand.val) {
			continue
		}
		return fmt.Errorf(
			"%s appears to contain a plaintext secret (length=%d, entropy=%.2f); "+
				"use ${ENV_VAR} syntax to load from environment or secrets manager",
			cand.key, len(cand.val), CalculateEntropy(cand.val),
		)
	}
	return nil
}
// validateHIPAARequirements enforces the hard HIPAA compliance requirements at
// startup. Violations fail loudly rather than silently falling back to
// insecure defaults. The first failing requirement is returned:
//
//  1. an enabled snapshot store must have secure set;
//  2. the sandbox network mode must be "none";
//  3. a seccomp profile must be configured;
//  4. no_new_privileges must be true;
//  5. credentials must not be inlined (validateNoInlineCredentials);
//  6. allowed_secrets must not name PHI fields (validatePHIDenylist).
func (c *Config) validateHIPAARequirements() error {
	sandbox := &c.Sandbox
	switch {
	case c.SnapshotStore.Enabled && !c.SnapshotStore.Secure:
		return fmt.Errorf("snapshot_store.secure must be true in HIPAA mode")
	case sandbox.NetworkMode != "none":
		return fmt.Errorf("sandbox.network_mode must be 'none' in HIPAA mode, got %q", sandbox.NetworkMode)
	case sandbox.SeccompProfile == "":
		return fmt.Errorf("sandbox.seccomp_profile must be non-empty in HIPAA mode")
	case !sandbox.NoNewPrivileges:
		return fmt.Errorf("sandbox.no_new_privileges must be true in HIPAA mode")
	}
	if err := c.validateNoInlineCredentials(); err != nil {
		return err
	}
	return sandbox.validatePHIDenylist()
}
// validateNoInlineCredentials checks that no credentials are hardcoded in the
// config. Each credential field must be empty or start with "${" (an
// environment variable reference). The fields checked are redis_password,
// ssh_key and the snapshot store access_key, secret_key and session_token.
// The session token is included because ExpandSecrets and
// ValidateNoPlaintextSecrets already treat it as a secret.
//
// It returns an error naming the first offending field, or nil.
//
// NOTE(review): LoadConfig calls ExpandSecrets, which replaces "${...}"
// references with their values. If Validate runs on a config returned by
// LoadConfig, env-sourced credentials would no longer carry the "${" prefix
// and would be rejected here — confirm the intended call order.
// NOTE(review): scheduler.token is not checked — confirm whether it should be.
func (c *Config) validateNoInlineCredentials() error {
	// Check Redis password - must be empty or use env var syntax
	if c.RedisPassword != "" && !strings.HasPrefix(c.RedisPassword, "${") {
		return fmt.Errorf("redis_password must use ${ENV_VAR} syntax in HIPAA mode, not inline value")
	}
	// Check SSH key - must use env var syntax
	if c.SSHKey != "" && !strings.HasPrefix(c.SSHKey, "${") {
		return fmt.Errorf("ssh_key must use ${ENV_VAR} syntax in HIPAA mode, not inline value")
	}
	// Check SnapshotStore credentials
	if c.SnapshotStore.AccessKey != "" && !strings.HasPrefix(c.SnapshotStore.AccessKey, "${") {
		return fmt.Errorf("snapshot_store.access_key must use ${ENV_VAR} syntax in HIPAA mode")
	}
	if c.SnapshotStore.SecretKey != "" && !strings.HasPrefix(c.SnapshotStore.SecretKey, "${") {
		return fmt.Errorf("snapshot_store.secret_key must use ${ENV_VAR} syntax in HIPAA mode")
	}
	if c.SnapshotStore.SessionToken != "" && !strings.HasPrefix(c.SnapshotStore.SessionToken, "${") {
		return fmt.Errorf("snapshot_store.session_token must use ${ENV_VAR} syntax in HIPAA mode")
	}
	return nil
}
// phiDenylistPatterns lists lowercase substrings that mark a name as a
// potential PHI (protected health information) field. validatePHIDenylist
// rejects any AllowedSecrets entry whose lowercased name contains one of them.
// Matching is by substring, so short patterns such as "phi" also match
// unrelated names that merely contain those letters.
var phiDenylistPatterns = []string{
	"patient", "phi", "ssn", "social_security", "mrn", "medical_record",
	"dob", "birth_date", "diagnosis", "condition", "medication", "allergy",
}
// validatePHIDenylist checks that AllowedSecrets does not name PHI fields.
// Each entry is lowercased and compared against phiDenylistPatterns by
// substring; the first match is returned as an error quoting both the entry
// and the pattern. It returns nil when nothing matches.
func (s *SandboxConfig) validatePHIDenylist() error {
	for i := range s.AllowedSecrets {
		name := s.AllowedSecrets[i]
		folded := strings.ToLower(name)
		for j := range phiDenylistPatterns {
			denied := phiDenylistPatterns[j]
			if !strings.Contains(folded, denied) {
				continue
			}
			return fmt.Errorf("allowed_secrets contains potential PHI field %q (matches pattern %q); this could allow PHI exfiltration", name, denied)
		}
	}
	return nil
}
// LooksLikeSecret heuristically reports whether s looks like a secret
// credential. Strings shorter than 16 bytes are never flagged. A longer string
// is flagged when its Shannon entropy exceeds 4 bits per character, or when it
// contains one of the well-known credential markers anywhere in the string
// (not only as a prefix).
//
// Exported for testing purposes.
func LooksLikeSecret(s string) bool {
	const (
		minSecretLen     = 16
		entropyThreshold = 4.0
	)
	if len(s) < minSecretLen {
		return false
	}
	// High entropy combined with reasonable length suggests a secret.
	if CalculateEntropy(s) > entropyThreshold {
		return true
	}
	knownMarkers := [...]string{
		"AKIA",     // AWS Access Key ID prefix
		"ASIA",     // AWS temporary credentials
		"ghp_",     // GitHub personal access token
		"gho_",     // GitHub OAuth token
		"glpat-",   // GitLab PAT
		"sk-",      // OpenAI/Stripe key prefix
		"sk_live_", // Stripe live key
		"sk_test_", // Stripe test key
	}
	for _, marker := range knownMarkers {
		if strings.Contains(s, marker) {
			return true
		}
	}
	return false
}
// CalculateEntropy calculates the Shannon entropy of s in bits per character,
// where a character is a Unicode code point (rune). It returns 0 for the empty
// string and for strings made of a single repeated character.
//
// Frequencies are counted per rune, so the total must also be the number of
// runes: dividing by the byte length (len(s)) would make the probabilities of
// non-ASCII text sum to less than 1 and skew the result.
//
// Exported for testing purposes.
func CalculateEntropy(s string) float64 {
	if s == "" {
		return 0
	}
	// Count character frequencies and the total number of runes in one pass.
	freq := make(map[rune]int)
	total := 0
	for _, r := range s {
		freq[r]++
		total++
	}
	// H = -sum(p * log2(p)) over the observed characters. Every count is at
	// least 1, so p is always > 0 and log2 is defined.
	n := float64(total)
	var entropy float64
	for _, count := range freq {
		p := float64(count) / n
		entropy -= p * math.Log2(p)
	}
	return entropy
}
// ComputeResolvedConfigHash returns the hex-encoded SHA-256 digest of the
// resolved configuration.
//
// Call it after environment expansion, after defaults have been applied and
// after Validate(), so that the digest describes the configuration the worker
// actually runs with rather than the raw YAML file. Two different raw files
// that resolve to the same values therefore produce the same hash, which is
// what makes the hash useful for reproducibility.
//
// Only a fixed subset of fields is hashed (see the snapshot struct below);
// volatile fields are deliberately left out. The subset is serialized with
// encoding/json before hashing, and a marshalling failure is returned wrapped.
func (c *Config) ComputeResolvedConfigHash() (string, error) {
	// The field order and JSON tags of this struct define the hashed bytes;
	// changing either changes every hash.
	snapshot := struct {
		Host                 string                `json:"host"`
		Port                 int                   `json:"port"`
		BasePath             string                `json:"base_path"`
		MaxWorkers           int                   `json:"max_workers"`
		Resources            config.ResourceConfig `json:"resources"`
		GPUVendor            string                `json:"gpu_vendor"`
		GPUVisibleDevices    []int                 `json:"gpu_visible_devices,omitempty"`
		GPUVisibleDeviceIDs  []string              `json:"gpu_visible_device_ids,omitempty"`
		Sandbox              SandboxConfig         `json:"sandbox"`
		ComplianceMode       string                `json:"compliance_mode"`
		ProvenanceBestEffort bool                  `json:"provenance_best_effort"`
		SnapshotStoreSecure  bool                  `json:"snapshot_store_secure,omitempty"`
		QueueBackend         string                `json:"queue_backend"`
	}{
		Host:                 c.Host,
		Port:                 c.Port,
		BasePath:             c.BasePath,
		MaxWorkers:           c.MaxWorkers,
		Resources:            c.Resources,
		GPUVendor:            c.GPUVendor,
		GPUVisibleDevices:    c.GPUVisibleDevices,
		GPUVisibleDeviceIDs:  c.GPUVisibleDeviceIDs,
		Sandbox:              c.Sandbox,
		ComplianceMode:       c.ComplianceMode,
		ProvenanceBestEffort: c.ProvenanceBestEffort,
		SnapshotStoreSecure:  c.SnapshotStore.Secure,
		QueueBackend:         c.Queue.Backend,
	}
	encoded, err := json.Marshal(snapshot)
	if err != nil {
		return "", fmt.Errorf("failed to marshal config for hashing: %w", err)
	}
	digest := sha256.Sum256(encoded)
	return hex.EncodeToString(digest[:]), nil
}
// envInt reads the environment variable name as a base-10 integer.
// Surrounding whitespace is ignored. The boolean result is false, and the
// integer 0, when the variable is unset, blank, or not a valid integer.
func envInt(name string) (int, bool) {
	raw := strings.TrimSpace(os.Getenv(name))
	if raw == "" {
		return 0, false
	}
	if parsed, err := strconv.Atoi(raw); err == nil {
		return parsed, true
	}
	return 0, false
}
// logEnvOverride records, at warning level on the default slog logger, that an
// environment variable is overriding configuration. name is the variable and
// value is the effective value taken from it.
func logEnvOverride(name string, value interface{}) {
	slog.Warn(
		"env override active",
		slog.String("var", name),
		slog.Any("value", value),
	)
}
// parseCPUFromConfig determines the total CPU count from the environment or
// the config, in this order of precedence:
//
//  1. FETCH_ML_TOTAL_CPU, when it holds a non-negative integer (the override
//     is logged);
//  2. cfg.Resources.PodmanCPUs, parsed as a float and rounded down, with
//     negative values mapped to 0;
//  3. runtime.NumCPU().
//
// cfg may be nil. A PodmanCPUs value that does not parse, is NaN, or is too
// large to be a meaningful CPU count (including +Inf) is ignored and the host
// CPU count is used instead: converting such a float to int is
// implementation-dependent in Go and must not leak into the result.
func parseCPUFromConfig(cfg *Config) int {
	if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 {
		logEnvOverride("FETCH_ML_TOTAL_CPU", n)
		return n
	}
	if cfg != nil && cfg.Resources.PodmanCPUs != "" {
		f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64)
		switch {
		case err != nil, math.IsNaN(f), f > math.MaxInt32:
			// Unusable value: fall through to the host CPU count.
		case f < 0:
			return 0
		default:
			return int(math.Floor(f))
		}
	}
	return runtime.NumCPU()
}
// parseGPUCountFromConfig returns the GPU count together with the detection
// metadata. It builds a detector for cfg through GPUDetectorFactory and
// reports that detector's count alongside the info the factory produced.
func parseGPUCountFromConfig(cfg *Config) (int, GPUDetectionInfo) {
	// Named so that it does not shadow the imported "factory" package.
	detectors := &GPUDetectorFactory{}
	detection := detectors.CreateDetectorWithInfo(cfg)
	count := detection.Detector.DetectGPUCount()
	return count, detection.Info
}
// parseGPUSlotsPerGPUFromConfig returns the number of slots per GPU taken from
// the FETCH_ML_GPU_SLOTS_PER_GPU environment variable. It returns 1 when the
// variable is unset, not an integer, or not greater than zero.
func parseGPUSlotsPerGPUFromConfig() int {
	slots, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU")
	if !ok || slots <= 0 {
		return 1
	}
	return slots
}