- Refactor worker configuration management - Improve container executor lifecycle handling - Update runloop and worker core logic - Enhance scheduler service template generation - Remove obsolete 'scheduler' symlink/directory
1024 lines
33 KiB
Go
1024 lines
33 KiB
Go
package worker
|
||
|
||
import (
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log/slog"
|
||
"math"
|
||
"net/url"
|
||
"os"
|
||
"path/filepath"
|
||
"runtime"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||
"github.com/jfraeys/fetch_ml/internal/config"
|
||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
|
||
"gopkg.in/yaml.v3"
|
||
)
|
||
|
||
const (
|
||
defaultMetricsFlushInterval = 500 * time.Millisecond
|
||
datasetCacheDefaultTTL = 30 * time.Minute
|
||
)
|
||
|
||
type QueueConfig struct {
|
||
Backend string `yaml:"backend"`
|
||
SQLitePath string `yaml:"sqlite_path"`
|
||
FilesystemPath string `yaml:"filesystem_path"`
|
||
FallbackToFilesystem bool `yaml:"fallback_to_filesystem"`
|
||
}
|
||
|
||
// Config holds worker configuration.
|
||
type Config struct {
|
||
Host string `yaml:"host"`
|
||
User string `yaml:"user"`
|
||
SSHKey string `yaml:"ssh_key"`
|
||
Port int `yaml:"port"`
|
||
BasePath string `yaml:"base_path"`
|
||
Entrypoint string `yaml:"entrypoint"`
|
||
RedisURL string `yaml:"redis_url"`
|
||
RedisAddr string `yaml:"redis_addr"`
|
||
RedisPassword string `yaml:"redis_password"`
|
||
RedisDB int `yaml:"redis_db"`
|
||
Queue QueueConfig `yaml:"queue"`
|
||
KnownHosts string `yaml:"known_hosts"`
|
||
WorkerID string `yaml:"worker_id"`
|
||
MaxWorkers int `yaml:"max_workers"`
|
||
PollInterval int `yaml:"poll_interval_seconds"`
|
||
Resources config.ResourceConfig `yaml:"resources"`
|
||
LocalMode bool `yaml:"local_mode"`
|
||
|
||
// Authentication
|
||
Auth auth.Config `yaml:"auth"`
|
||
|
||
// Metrics exporter
|
||
Metrics MetricsConfig `yaml:"metrics"`
|
||
// Metrics buffering
|
||
MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`
|
||
|
||
// Data management
|
||
DataManagerPath string `yaml:"data_manager_path"`
|
||
AutoFetchData bool `yaml:"auto_fetch_data"`
|
||
DataDir string `yaml:"data_dir"`
|
||
DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`
|
||
|
||
SnapshotStore SnapshotStoreConfig `yaml:"snapshot_store"`
|
||
|
||
// Provenance enforcement
|
||
// Default: fail-closed (trustworthiness-by-default). Set true to opt into best-effort.
|
||
ProvenanceBestEffort bool `yaml:"provenance_best_effort"`
|
||
|
||
// Compliance mode: "hipaa", "standard", or empty
|
||
// When "hipaa": enforces hard requirements at startup
|
||
ComplianceMode string `yaml:"compliance_mode"`
|
||
|
||
// Opt-in prewarming of next task artifacts (snapshot/datasets/env).
|
||
PrewarmEnabled bool `yaml:"prewarm_enabled"`
|
||
|
||
// Podman execution
|
||
PodmanImage string `yaml:"podman_image"`
|
||
ContainerWorkspace string `yaml:"container_workspace"`
|
||
ContainerResults string `yaml:"container_results"`
|
||
GPUDevices []string `yaml:"gpu_devices"`
|
||
GPUVendor string `yaml:"gpu_vendor"`
|
||
GPUVendorAutoDetected bool `yaml:"-"` // Set by LoadConfig when GPUVendor is auto-detected
|
||
GPUVisibleDevices []int `yaml:"gpu_visible_devices"`
|
||
GPUVisibleDeviceIDs []string `yaml:"gpu_visible_device_ids"`
|
||
|
||
// Apple M-series GPU configuration
|
||
AppleGPU AppleGPUConfig `yaml:"apple_gpu"`
|
||
|
||
// Task lease and retry settings
|
||
TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // Worker lease (default: 30min)
|
||
HeartbeatInterval time.Duration `yaml:"heartbeat_interval"` // Renew lease (default: 1min)
|
||
MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3)
|
||
GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Shutdown timeout (default: 5min)
|
||
|
||
// Mode determines how the worker operates: "standalone" or "distributed"
|
||
Mode string `yaml:"mode"`
|
||
|
||
// Scheduler configuration for distributed mode
|
||
Scheduler SchedulerConfig `yaml:"scheduler"`
|
||
|
||
// Plugins configuration
|
||
Plugins map[string]factory.PluginConfig `yaml:"plugins"`
|
||
|
||
// Sandboxing configuration
|
||
Sandbox SandboxConfig `yaml:"sandbox"`
|
||
}
|
||
|
||
// MetricsConfig controls the Prometheus exporter.
|
||
type MetricsConfig struct {
|
||
Enabled bool `yaml:"enabled"`
|
||
ListenAddr string `yaml:"listen_addr"`
|
||
}
|
||
|
||
type SnapshotStoreConfig struct {
|
||
Enabled bool `yaml:"enabled"`
|
||
Endpoint string `yaml:"endpoint"`
|
||
Secure bool `yaml:"secure"`
|
||
Region string `yaml:"region"`
|
||
Bucket string `yaml:"bucket"`
|
||
Prefix string `yaml:"prefix"`
|
||
AccessKey string `yaml:"access_key"`
|
||
SecretKey string `yaml:"secret_key"`
|
||
SessionToken string `yaml:"session_token"`
|
||
Timeout time.Duration `yaml:"timeout"`
|
||
MaxRetries int `yaml:"max_retries"`
|
||
}
|
||
|
||
// SchedulerConfig holds configurable heartbeat and lease settings for distributed mode.
|
||
type SchedulerConfig struct {
|
||
Address string `yaml:"address"`
|
||
Cert string `yaml:"cert"`
|
||
Token string `yaml:"token"`
|
||
HeartbeatIntervalSecs int `yaml:"heartbeat_interval_secs"` // default: 30
|
||
TaskLeaseDurationSecs int `yaml:"task_lease_duration_secs"` // default: 90 (3x heartbeat)
|
||
}
|
||
|
||
// Validate checks that lease and heartbeat settings are valid.
|
||
// Enforces 2-10x ratio between lease duration and heartbeat interval.
|
||
func (sc *SchedulerConfig) Validate() error {
|
||
// Apply defaults if zero
|
||
if sc.HeartbeatIntervalSecs == 0 {
|
||
sc.HeartbeatIntervalSecs = 30
|
||
}
|
||
if sc.TaskLeaseDurationSecs == 0 {
|
||
sc.TaskLeaseDurationSecs = 90
|
||
}
|
||
|
||
heartbeat := time.Duration(sc.HeartbeatIntervalSecs) * time.Second
|
||
lease := time.Duration(sc.TaskLeaseDurationSecs) * time.Second
|
||
|
||
if lease <= heartbeat {
|
||
return fmt.Errorf(
|
||
"task_lease_duration_secs (%s) must be greater than heartbeat_interval_secs (%s)",
|
||
lease, heartbeat,
|
||
)
|
||
}
|
||
|
||
ratio := lease.Seconds() / heartbeat.Seconds()
|
||
if ratio < 2.0 {
|
||
return fmt.Errorf(
|
||
"task_lease_duration_secs must be at least 2× heartbeat_interval_secs "+
|
||
"(got %.1f×) — too small a margin for transient network issues",
|
||
ratio,
|
||
)
|
||
}
|
||
|
||
if ratio > 10.0 {
|
||
return fmt.Errorf(
|
||
"task_lease_duration_secs is %.1f× heartbeat_interval_secs — "+
|
||
"dead workers won't be detected for %s, consider reducing lease duration",
|
||
ratio, lease,
|
||
)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
type AppleGPUConfig struct {
|
||
Enabled bool `yaml:"enabled"`
|
||
MetalDevice string `yaml:"metal_device"`
|
||
MPSRuntime string `yaml:"mps_runtime"`
|
||
}
|
||
|
||
// SandboxConfig holds container sandbox settings
|
||
type SandboxConfig struct {
|
||
NetworkMode string `yaml:"network_mode"` // Default: "none"
|
||
ReadOnlyRoot bool `yaml:"read_only_root"` // Default: true
|
||
AllowSecrets bool `yaml:"allow_secrets"` // Default: false
|
||
AllowedSecrets []string `yaml:"allowed_secrets"` // e.g., ["HF_TOKEN", "WANDB_API_KEY"]
|
||
SeccompProfile string `yaml:"seccomp_profile"` // Default: "default-hardened"
|
||
MaxRuntimeHours int `yaml:"max_runtime_hours"`
|
||
|
||
// Security hardening options
|
||
NoNewPrivileges bool `yaml:"no_new_privileges"` // Default: true
|
||
DropAllCaps bool `yaml:"drop_all_caps"` // Default: true
|
||
AllowedCaps []string `yaml:"allowed_caps"` // Capabilities to add back
|
||
UserNS bool `yaml:"user_ns"` // Default: true
|
||
RunAsUID int `yaml:"run_as_uid"` // Default: 1000
|
||
RunAsGID int `yaml:"run_as_gid"` // Default: 1000
|
||
|
||
// Process isolation
|
||
MaxProcesses int `yaml:"max_processes"` // Fork bomb protection (default: 100)
|
||
MaxOpenFiles int `yaml:"max_open_files"` // FD exhaustion protection (default: 1024)
|
||
DisableSwap bool `yaml:"disable_swap"` // Prevent swap exfiltration
|
||
OOMScoreAdj int `yaml:"oom_score_adj"` // OOM killer priority (default: 100)
|
||
TaskUID int `yaml:"task_uid"` // Per-task UID (0 = use RunAsUID)
|
||
TaskGID int `yaml:"task_gid"` // Per-task GID (0 = use RunAsGID)
|
||
|
||
// Upload limits
|
||
MaxUploadSizeBytes int64 `yaml:"max_upload_size_bytes"` // Default: 10GB
|
||
MaxUploadRateBps int64 `yaml:"max_upload_rate_bps"` // Default: 100MB/s
|
||
MaxUploadsPerMinute int `yaml:"max_uploads_per_minute"` // Default: 10
|
||
|
||
// Artifact ingestion caps
|
||
MaxArtifactFiles int `yaml:"max_artifact_files"` // Default: 10000
|
||
MaxArtifactTotalBytes int64 `yaml:"max_artifact_total_bytes"` // Default: 100GB
|
||
}
|
||
|
||
// SecurityDefaults holds default values for security configuration
|
||
var SecurityDefaults = struct {
|
||
NetworkMode string
|
||
ReadOnlyRoot bool
|
||
AllowSecrets bool
|
||
SeccompProfile string
|
||
NoNewPrivileges bool
|
||
DropAllCaps bool
|
||
UserNS bool
|
||
RunAsUID int
|
||
RunAsGID int
|
||
MaxProcesses int
|
||
MaxOpenFiles int
|
||
DisableSwap bool
|
||
OOMScoreAdj int
|
||
MaxUploadSizeBytes int64
|
||
MaxUploadRateBps int64
|
||
MaxUploadsPerMinute int
|
||
MaxArtifactFiles int
|
||
MaxArtifactTotalBytes int64
|
||
}{
|
||
NetworkMode: "none",
|
||
ReadOnlyRoot: true,
|
||
AllowSecrets: false,
|
||
SeccompProfile: "default-hardened",
|
||
NoNewPrivileges: true,
|
||
DropAllCaps: true,
|
||
UserNS: true,
|
||
RunAsUID: 1000,
|
||
RunAsGID: 1000,
|
||
MaxProcesses: 100, // Fork bomb protection
|
||
MaxOpenFiles: 1024, // FD exhaustion protection
|
||
DisableSwap: true, // Prevent swap exfiltration
|
||
OOMScoreAdj: 100, // Lower OOM priority
|
||
MaxUploadSizeBytes: 10 * 1024 * 1024 * 1024, // 10GB
|
||
MaxUploadRateBps: 100 * 1024 * 1024, // 100MB/s
|
||
MaxUploadsPerMinute: 10,
|
||
MaxArtifactFiles: 10000,
|
||
MaxArtifactTotalBytes: 100 * 1024 * 1024 * 1024, // 100GB
|
||
}
|
||
|
||
// Validate checks sandbox configuration
|
||
func (s *SandboxConfig) Validate() error {
|
||
validNetworks := map[string]bool{"none": true, "slirp4netns": true, "bridge": true, "": true}
|
||
if !validNetworks[s.NetworkMode] {
|
||
return fmt.Errorf("invalid network_mode: %s", s.NetworkMode)
|
||
}
|
||
if s.MaxRuntimeHours < 0 {
|
||
return fmt.Errorf("max_runtime_hours must be positive")
|
||
}
|
||
if s.MaxUploadSizeBytes < 0 {
|
||
return fmt.Errorf("max_upload_size_bytes must be positive")
|
||
}
|
||
if s.MaxUploadRateBps < 0 {
|
||
return fmt.Errorf("max_upload_rate_bps must be positive")
|
||
}
|
||
if s.MaxUploadsPerMinute < 0 {
|
||
return fmt.Errorf("max_uploads_per_minute must be positive")
|
||
}
|
||
if s.MaxArtifactFiles < 0 {
|
||
return fmt.Errorf("max_artifact_files must be positive")
|
||
}
|
||
if s.MaxArtifactTotalBytes < 0 {
|
||
return fmt.Errorf("max_artifact_total_bytes must be positive")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ApplySecurityDefaults applies secure default values to empty fields.
|
||
// This implements the "secure by default" principle for HIPAA compliance.
|
||
func (s *SandboxConfig) ApplySecurityDefaults() {
|
||
// Network isolation: default to "none" (no network access)
|
||
if s.NetworkMode == "" {
|
||
s.NetworkMode = SecurityDefaults.NetworkMode
|
||
}
|
||
|
||
// Read-only root filesystem
|
||
if !s.ReadOnlyRoot {
|
||
s.ReadOnlyRoot = SecurityDefaults.ReadOnlyRoot
|
||
}
|
||
|
||
// Secrets disabled by default
|
||
if !s.AllowSecrets {
|
||
s.AllowSecrets = SecurityDefaults.AllowSecrets
|
||
}
|
||
|
||
// Seccomp profile
|
||
if s.SeccompProfile == "" {
|
||
s.SeccompProfile = SecurityDefaults.SeccompProfile
|
||
}
|
||
|
||
// No new privileges
|
||
if !s.NoNewPrivileges {
|
||
s.NoNewPrivileges = SecurityDefaults.NoNewPrivileges
|
||
}
|
||
|
||
// Drop all capabilities
|
||
if !s.DropAllCaps {
|
||
s.DropAllCaps = SecurityDefaults.DropAllCaps
|
||
}
|
||
|
||
// User namespace
|
||
if !s.UserNS {
|
||
s.UserNS = SecurityDefaults.UserNS
|
||
}
|
||
|
||
// Default non-root UID/GID
|
||
if s.RunAsUID == 0 {
|
||
s.RunAsUID = SecurityDefaults.RunAsUID
|
||
}
|
||
if s.RunAsGID == 0 {
|
||
s.RunAsGID = SecurityDefaults.RunAsGID
|
||
}
|
||
|
||
// Upload limits
|
||
if s.MaxUploadSizeBytes == 0 {
|
||
s.MaxUploadSizeBytes = SecurityDefaults.MaxUploadSizeBytes
|
||
}
|
||
if s.MaxUploadRateBps == 0 {
|
||
s.MaxUploadRateBps = SecurityDefaults.MaxUploadRateBps
|
||
}
|
||
if s.MaxUploadsPerMinute == 0 {
|
||
s.MaxUploadsPerMinute = SecurityDefaults.MaxUploadsPerMinute
|
||
}
|
||
|
||
// Artifact ingestion caps
|
||
if s.MaxArtifactFiles == 0 {
|
||
s.MaxArtifactFiles = SecurityDefaults.MaxArtifactFiles
|
||
}
|
||
if s.MaxArtifactTotalBytes == 0 {
|
||
s.MaxArtifactTotalBytes = SecurityDefaults.MaxArtifactTotalBytes
|
||
}
|
||
|
||
// Process isolation defaults
|
||
if s.MaxProcesses == 0 {
|
||
s.MaxProcesses = SecurityDefaults.MaxProcesses
|
||
}
|
||
if s.MaxOpenFiles == 0 {
|
||
s.MaxOpenFiles = SecurityDefaults.MaxOpenFiles
|
||
}
|
||
if !s.DisableSwap {
|
||
s.DisableSwap = SecurityDefaults.DisableSwap
|
||
}
|
||
if s.OOMScoreAdj == 0 {
|
||
s.OOMScoreAdj = SecurityDefaults.OOMScoreAdj
|
||
}
|
||
// TaskUID/TaskGID default to 0 (meaning "use RunAsUID/RunAsGID")
|
||
// Only override if explicitly set (> 0)
|
||
if s.TaskUID < 0 {
|
||
s.TaskUID = 0
|
||
}
|
||
if s.TaskGID < 0 {
|
||
s.TaskGID = 0
|
||
}
|
||
}
|
||
|
||
// GetProcessIsolationFlags returns the effective UID/GID for a task
|
||
// If TaskUID/TaskGID are set (>0), use those; otherwise use RunAsUID/RunAsGID
|
||
func (s *SandboxConfig) GetProcessIsolationFlags() (uid, gid int) {
|
||
uid = s.RunAsUID
|
||
gid = s.RunAsGID
|
||
if s.TaskUID > 0 {
|
||
uid = s.TaskUID
|
||
}
|
||
if s.TaskGID > 0 {
|
||
gid = s.TaskGID
|
||
}
|
||
return uid, gid
|
||
}
|
||
|
||
// Getter methods for SandboxConfig interface
|
||
func (s *SandboxConfig) GetNoNewPrivileges() bool { return s.NoNewPrivileges }
|
||
func (s *SandboxConfig) GetDropAllCaps() bool { return s.DropAllCaps }
|
||
func (s *SandboxConfig) GetAllowedCaps() []string { return s.AllowedCaps }
|
||
func (s *SandboxConfig) GetUserNS() bool { return s.UserNS }
|
||
func (s *SandboxConfig) GetRunAsUID() int { return s.RunAsUID }
|
||
func (s *SandboxConfig) GetRunAsGID() int { return s.RunAsGID }
|
||
func (s *SandboxConfig) GetSeccompProfile() string { return s.SeccompProfile }
|
||
func (s *SandboxConfig) GetReadOnlyRoot() bool { return s.ReadOnlyRoot }
|
||
func (s *SandboxConfig) GetNetworkMode() string { return s.NetworkMode }
|
||
|
||
// Process Isolation getter methods
|
||
func (s *SandboxConfig) GetMaxProcesses() int { return s.MaxProcesses }
|
||
func (s *SandboxConfig) GetMaxOpenFiles() int { return s.MaxOpenFiles }
|
||
func (s *SandboxConfig) GetDisableSwap() bool { return s.DisableSwap }
|
||
func (s *SandboxConfig) GetOOMScoreAdj() int { return s.OOMScoreAdj }
|
||
func (s *SandboxConfig) GetTaskUID() int { return s.TaskUID }
|
||
func (s *SandboxConfig) GetTaskGID() int { return s.TaskGID }
|
||
|
||
// LoadConfig loads worker configuration from a YAML file.
|
||
func LoadConfig(path string) (*Config, error) {
|
||
data, err := fileutil.SecureFileRead(path)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var cfg Config
|
||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if strings.TrimSpace(cfg.RedisURL) != "" {
|
||
cfg.RedisURL = os.ExpandEnv(strings.TrimSpace(cfg.RedisURL))
|
||
cfg.RedisAddr = cfg.RedisURL
|
||
cfg.RedisPassword = ""
|
||
cfg.RedisDB = 0
|
||
}
|
||
|
||
// Get smart defaults for current environment
|
||
smart := config.GetSmartDefaults()
|
||
|
||
// Use PathRegistry for consistent path management
|
||
paths := config.FromEnv()
|
||
|
||
if cfg.Port == 0 {
|
||
cfg.Port = config.DefaultSSHPort
|
||
}
|
||
if cfg.Host == "" {
|
||
host, err := smart.Host()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get default host: %w", err)
|
||
}
|
||
cfg.Host = host
|
||
}
|
||
if cfg.BasePath == "" {
|
||
// Prefer PathRegistry over smart defaults for consistency
|
||
cfg.BasePath = paths.ExperimentsDir()
|
||
}
|
||
if cfg.RedisAddr == "" {
|
||
redisAddr, err := smart.RedisAddr()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get default redis address: %w", err)
|
||
}
|
||
cfg.RedisAddr = redisAddr
|
||
}
|
||
if cfg.KnownHosts == "" {
|
||
knownHosts, err := smart.KnownHostsPath()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get default known hosts path: %w", err)
|
||
}
|
||
cfg.KnownHosts = knownHosts
|
||
}
|
||
if cfg.WorkerID == "" {
|
||
cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
|
||
}
|
||
cfg.Resources.ApplyDefaults()
|
||
if cfg.MaxWorkers > 0 {
|
||
cfg.Resources.MaxWorkers = cfg.MaxWorkers
|
||
} else {
|
||
maxWorkers, err := smart.MaxWorkers()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get default max workers: %w", err)
|
||
}
|
||
cfg.MaxWorkers = maxWorkers
|
||
cfg.Resources.MaxWorkers = maxWorkers
|
||
}
|
||
if cfg.PollInterval == 0 {
|
||
pollInterval, err := smart.PollInterval()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get default poll interval: %w", err)
|
||
}
|
||
cfg.PollInterval = pollInterval
|
||
}
|
||
if cfg.DataManagerPath == "" {
|
||
cfg.DataManagerPath = "./data_manager"
|
||
}
|
||
if cfg.DataDir == "" {
|
||
// Use PathRegistry for consistent data directory
|
||
cfg.DataDir = paths.DataDir()
|
||
}
|
||
if cfg.SnapshotStore.Timeout == 0 {
|
||
cfg.SnapshotStore.Timeout = 10 * time.Minute
|
||
}
|
||
if cfg.SnapshotStore.MaxRetries == 0 {
|
||
cfg.SnapshotStore.MaxRetries = 3
|
||
}
|
||
if cfg.Metrics.ListenAddr == "" {
|
||
cfg.Metrics.ListenAddr = ":9100"
|
||
}
|
||
if cfg.MetricsFlushInterval == 0 {
|
||
cfg.MetricsFlushInterval = defaultMetricsFlushInterval
|
||
}
|
||
if cfg.DatasetCacheTTL == 0 {
|
||
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
|
||
}
|
||
|
||
if strings.TrimSpace(cfg.Queue.Backend) == "" {
|
||
cfg.Queue.Backend = string(queue.QueueBackendRedis)
|
||
}
|
||
if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendSQLite)) {
|
||
if strings.TrimSpace(cfg.Queue.SQLitePath) == "" {
|
||
cfg.Queue.SQLitePath = filepath.Join(cfg.DataDir, "queue.db")
|
||
}
|
||
cfg.Queue.SQLitePath = storage.ExpandPath(cfg.Queue.SQLitePath)
|
||
}
|
||
if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendFS)) || cfg.Queue.FallbackToFilesystem {
|
||
if strings.TrimSpace(cfg.Queue.FilesystemPath) == "" {
|
||
cfg.Queue.FilesystemPath = filepath.Join(cfg.DataDir, "queue-fs")
|
||
}
|
||
cfg.Queue.FilesystemPath = storage.ExpandPath(cfg.Queue.FilesystemPath)
|
||
}
|
||
|
||
if strings.TrimSpace(cfg.GPUVendor) == "" {
|
||
cfg.GPUVendorAutoDetected = true
|
||
if cfg.AppleGPU.Enabled {
|
||
cfg.GPUVendor = string(GPUTypeApple)
|
||
} else if len(cfg.GPUDevices) > 0 ||
|
||
len(cfg.GPUVisibleDevices) > 0 ||
|
||
len(cfg.GPUVisibleDeviceIDs) > 0 {
|
||
cfg.GPUVendor = string(GPUTypeNVIDIA)
|
||
} else {
|
||
cfg.GPUVendor = string(GPUTypeNone)
|
||
}
|
||
}
|
||
|
||
// Set lease and retry defaults
|
||
if cfg.TaskLeaseDuration == 0 {
|
||
cfg.TaskLeaseDuration = 30 * time.Minute
|
||
}
|
||
if cfg.HeartbeatInterval == 0 {
|
||
cfg.HeartbeatInterval = 1 * time.Minute
|
||
}
|
||
if cfg.MaxRetries == 0 {
|
||
cfg.MaxRetries = 3
|
||
}
|
||
if cfg.GracefulTimeout == 0 {
|
||
cfg.GracefulTimeout = 5 * time.Minute
|
||
}
|
||
|
||
// Apply security defaults to sandbox configuration
|
||
cfg.Sandbox.ApplySecurityDefaults()
|
||
|
||
// Expand secrets from environment variables
|
||
if err := cfg.ExpandSecrets(); err != nil {
|
||
return nil, fmt.Errorf("secrets expansion failed: %w", err)
|
||
}
|
||
|
||
return &cfg, nil
|
||
}
|
||
|
||
// Validate implements config.Validator interface.
|
||
func (c *Config) Validate() error {
|
||
if c.Port != 0 {
|
||
if err := config.ValidatePort(c.Port); err != nil {
|
||
return fmt.Errorf("invalid SSH port: %w", err)
|
||
}
|
||
}
|
||
|
||
if c.BasePath != "" {
|
||
// Convert relative paths to absolute
|
||
c.BasePath = storage.ExpandPath(c.BasePath)
|
||
if !filepath.IsAbs(c.BasePath) {
|
||
// Resolve relative to current working directory, not DefaultBasePath
|
||
cwd, err := os.Getwd()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get current directory: %w", err)
|
||
}
|
||
c.BasePath = filepath.Join(cwd, c.BasePath)
|
||
}
|
||
}
|
||
|
||
backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend))
|
||
if backend == "" {
|
||
backend = string(queue.QueueBackendRedis)
|
||
c.Queue.Backend = backend
|
||
}
|
||
if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) && backend != string(queue.QueueBackendFS) {
|
||
return fmt.Errorf("queue.backend must be one of %q, %q, or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite, queue.QueueBackendFS)
|
||
}
|
||
|
||
if backend == string(queue.QueueBackendSQLite) {
|
||
if strings.TrimSpace(c.Queue.SQLitePath) == "" {
|
||
return fmt.Errorf("queue.sqlite_path is required when queue.backend is %q", queue.QueueBackendSQLite)
|
||
}
|
||
c.Queue.SQLitePath = storage.ExpandPath(c.Queue.SQLitePath)
|
||
if !filepath.IsAbs(c.Queue.SQLitePath) {
|
||
c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
|
||
}
|
||
}
|
||
if backend == string(queue.QueueBackendFS) || c.Queue.FallbackToFilesystem {
|
||
if strings.TrimSpace(c.Queue.FilesystemPath) == "" {
|
||
return fmt.Errorf("queue.filesystem_path is required when filesystem queue is enabled")
|
||
}
|
||
c.Queue.FilesystemPath = storage.ExpandPath(c.Queue.FilesystemPath)
|
||
if !filepath.IsAbs(c.Queue.FilesystemPath) {
|
||
c.Queue.FilesystemPath = filepath.Join(config.DefaultLocalDataDir, c.Queue.FilesystemPath)
|
||
}
|
||
}
|
||
|
||
if c.RedisAddr != "" {
|
||
addr := strings.TrimSpace(c.RedisAddr)
|
||
if strings.HasPrefix(addr, "redis://") {
|
||
u, err := url.Parse(addr)
|
||
if err != nil {
|
||
return fmt.Errorf("invalid Redis configuration: invalid redis url: %w", err)
|
||
}
|
||
if u.Scheme != "redis" || strings.TrimSpace(u.Host) == "" {
|
||
return fmt.Errorf("invalid Redis configuration: invalid redis url")
|
||
}
|
||
} else {
|
||
if err := config.ValidateRedisAddr(addr); err != nil {
|
||
return fmt.Errorf("invalid Redis configuration: %w", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
if c.MaxWorkers < 1 {
|
||
return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers)
|
||
}
|
||
|
||
switch strings.ToLower(strings.TrimSpace(c.GPUVendor)) {
|
||
case string(GPUTypeNVIDIA), string(GPUTypeApple), string(GPUTypeNone), "amd":
|
||
// ok
|
||
default:
|
||
return fmt.Errorf(
|
||
"gpu_vendor must be one of %q, %q, %q, %q",
|
||
string(GPUTypeNVIDIA),
|
||
"amd",
|
||
string(GPUTypeApple),
|
||
string(GPUTypeNone),
|
||
)
|
||
}
|
||
|
||
// Strict GPU visibility configuration:
|
||
// - gpu_visible_devices and gpu_visible_device_ids are mutually exclusive.
|
||
// - UUID-style gpu_visible_device_ids is NVIDIA-only.
|
||
vendor := strings.ToLower(strings.TrimSpace(c.GPUVendor))
|
||
if len(c.GPUVisibleDevices) > 0 && len(c.GPUVisibleDeviceIDs) > 0 {
|
||
if vendor != string(GPUTypeNVIDIA) {
|
||
return fmt.Errorf(
|
||
"visible_device_ids is only supported when gpu_vendor is %q",
|
||
string(GPUTypeNVIDIA),
|
||
)
|
||
}
|
||
for _, id := range c.GPUVisibleDeviceIDs {
|
||
id = strings.TrimSpace(id)
|
||
if id == "" {
|
||
return fmt.Errorf("visible_device_ids contains an empty value")
|
||
}
|
||
if !strings.HasPrefix(id, "GPU-") {
|
||
return fmt.Errorf("gpu_visible_device_ids values must start with %q, got %q", "GPU-", id)
|
||
}
|
||
}
|
||
}
|
||
if vendor == string(GPUTypeApple) || vendor == string(GPUTypeNone) {
|
||
if len(c.GPUVisibleDevices) > 0 || len(c.GPUVisibleDeviceIDs) > 0 {
|
||
return fmt.Errorf(
|
||
"gpu_visible_devices and gpu_visible_device_ids are not supported when gpu_vendor is %q",
|
||
vendor,
|
||
)
|
||
}
|
||
}
|
||
if vendor == "amd" {
|
||
if len(c.GPUVisibleDeviceIDs) > 0 {
|
||
return fmt.Errorf("gpu_visible_device_ids is not supported when gpu_vendor is %q", vendor)
|
||
}
|
||
for _, idx := range c.GPUVisibleDevices {
|
||
if idx < 0 {
|
||
return fmt.Errorf("gpu_visible_devices contains negative index %d", idx)
|
||
}
|
||
}
|
||
}
|
||
|
||
if c.SnapshotStore.Enabled {
|
||
if strings.TrimSpace(c.SnapshotStore.Endpoint) == "" {
|
||
return fmt.Errorf("snapshot_store.endpoint is required when snapshot_store.enabled is true")
|
||
}
|
||
if strings.TrimSpace(c.SnapshotStore.Bucket) == "" {
|
||
return fmt.Errorf("snapshot_store.bucket is required when snapshot_store.enabled is true")
|
||
}
|
||
ak := strings.TrimSpace(c.SnapshotStore.AccessKey)
|
||
sk := strings.TrimSpace(c.SnapshotStore.SecretKey)
|
||
if (ak == "") != (sk == "") {
|
||
return fmt.Errorf(
|
||
"snapshot_store.access_key and snapshot_store.secret_key must both be set or both be empty",
|
||
)
|
||
}
|
||
if c.SnapshotStore.Timeout < 0 {
|
||
return fmt.Errorf("snapshot_store.timeout must be >= 0")
|
||
}
|
||
if c.SnapshotStore.MaxRetries < 0 {
|
||
return fmt.Errorf("snapshot_store.max_retries must be >= 0")
|
||
}
|
||
}
|
||
|
||
// HIPAA mode validation - hard requirements
|
||
if strings.ToLower(c.ComplianceMode) == "hipaa" {
|
||
if err := c.validateHIPAARequirements(); err != nil {
|
||
return fmt.Errorf("HIPAA compliance validation failed: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ExpandSecrets replaces secret placeholders with environment variables
|
||
// Exported for testing purposes
|
||
func (c *Config) ExpandSecrets() error {
|
||
// First validate that secrets use env var syntax (not plaintext)
|
||
if err := c.ValidateNoPlaintextSecrets(); err != nil {
|
||
return err
|
||
}
|
||
|
||
// Expand Redis password from env if using ${...} syntax
|
||
if strings.Contains(c.RedisPassword, "${") {
|
||
c.RedisPassword = os.ExpandEnv(c.RedisPassword)
|
||
}
|
||
|
||
// Expand SnapshotStore credentials
|
||
if strings.Contains(c.SnapshotStore.AccessKey, "${") {
|
||
c.SnapshotStore.AccessKey = os.ExpandEnv(c.SnapshotStore.AccessKey)
|
||
}
|
||
if strings.Contains(c.SnapshotStore.SecretKey, "${") {
|
||
c.SnapshotStore.SecretKey = os.ExpandEnv(c.SnapshotStore.SecretKey)
|
||
}
|
||
if strings.Contains(c.SnapshotStore.SessionToken, "${") {
|
||
c.SnapshotStore.SessionToken = os.ExpandEnv(c.SnapshotStore.SessionToken)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ValidateNoPlaintextSecrets checks that sensitive fields use env var references
|
||
// rather than hardcoded plaintext values. This is a HIPAA compliance requirement.
|
||
// Exported for testing purposes
|
||
func (c *Config) ValidateNoPlaintextSecrets() error {
|
||
// Fields that should use ${ENV_VAR} syntax instead of plaintext
|
||
sensitiveFields := []struct {
|
||
name string
|
||
value string
|
||
}{
|
||
{"redis_password", c.RedisPassword},
|
||
{"snapshot_store.access_key", c.SnapshotStore.AccessKey},
|
||
{"snapshot_store.secret_key", c.SnapshotStore.SecretKey},
|
||
{"snapshot_store.session_token", c.SnapshotStore.SessionToken},
|
||
}
|
||
|
||
for _, field := range sensitiveFields {
|
||
if field.value == "" {
|
||
continue // Empty values are fine
|
||
}
|
||
|
||
// Check if it looks like a plaintext secret (not env var reference)
|
||
if !strings.HasPrefix(field.value, "${") && LooksLikeSecret(field.value) {
|
||
return fmt.Errorf(
|
||
"%s appears to contain a plaintext secret (length=%d, entropy=%.2f); "+
|
||
"use ${ENV_VAR} syntax to load from environment or secrets manager",
|
||
field.name, len(field.value), CalculateEntropy(field.value),
|
||
)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// validateHIPAARequirements enforces hard HIPAA compliance requirements at startup.
|
||
// These must fail loudly rather than silently fall back to insecure defaults.
|
||
func (c *Config) validateHIPAARequirements() error {
|
||
// 1. SnapshotStore must be secure
|
||
if c.SnapshotStore.Enabled && !c.SnapshotStore.Secure {
|
||
return fmt.Errorf("snapshot_store.secure must be true in HIPAA mode")
|
||
}
|
||
|
||
// 2. NetworkMode must be "none" (no network access)
|
||
if c.Sandbox.NetworkMode != "none" {
|
||
return fmt.Errorf("sandbox.network_mode must be 'none' in HIPAA mode, got %q", c.Sandbox.NetworkMode)
|
||
}
|
||
|
||
// 3. SeccompProfile must be non-empty
|
||
if c.Sandbox.SeccompProfile == "" {
|
||
return fmt.Errorf("sandbox.seccomp_profile must be non-empty in HIPAA mode")
|
||
}
|
||
|
||
// 4. NoNewPrivileges must be true
|
||
if !c.Sandbox.NoNewPrivileges {
|
||
return fmt.Errorf("sandbox.no_new_privileges must be true in HIPAA mode")
|
||
}
|
||
|
||
// 5. All credentials must be sourced from env vars, not inline YAML
|
||
if err := c.validateNoInlineCredentials(); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 6. AllowedSecrets must not contain PHI field names
|
||
if err := c.Sandbox.validatePHIDenylist(); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// validateNoInlineCredentials checks that no credentials are hardcoded in config
|
||
func (c *Config) validateNoInlineCredentials() error {
|
||
// Check Redis password - must be empty or use env var syntax
|
||
if c.RedisPassword != "" && !strings.HasPrefix(c.RedisPassword, "${") {
|
||
return fmt.Errorf("redis_password must use ${ENV_VAR} syntax in HIPAA mode, not inline value")
|
||
}
|
||
|
||
// Check SSH key - must use env var syntax
|
||
if c.SSHKey != "" && !strings.HasPrefix(c.SSHKey, "${") {
|
||
return fmt.Errorf("ssh_key must use ${ENV_VAR} syntax in HIPAA mode, not inline value")
|
||
}
|
||
|
||
// Check SnapshotStore credentials
|
||
if c.SnapshotStore.AccessKey != "" && !strings.HasPrefix(c.SnapshotStore.AccessKey, "${") {
|
||
return fmt.Errorf("snapshot_store.access_key must use ${ENV_VAR} syntax in HIPAA mode")
|
||
}
|
||
if c.SnapshotStore.SecretKey != "" && !strings.HasPrefix(c.SnapshotStore.SecretKey, "${") {
|
||
return fmt.Errorf("snapshot_store.secret_key must use ${ENV_VAR} syntax in HIPAA mode")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// PHI field patterns that should not appear in AllowedSecrets
|
||
var phiDenylistPatterns = []string{
|
||
"patient", "phi", "ssn", "social_security", "mrn", "medical_record",
|
||
"dob", "birth_date", "diagnosis", "condition", "medication", "allergy",
|
||
}
|
||
|
||
// validatePHIDenylist checks that AllowedSecrets doesn't contain PHI field names
|
||
func (s *SandboxConfig) validatePHIDenylist() error {
|
||
for _, secret := range s.AllowedSecrets {
|
||
secretLower := strings.ToLower(secret)
|
||
for _, pattern := range phiDenylistPatterns {
|
||
if strings.Contains(secretLower, pattern) {
|
||
return fmt.Errorf("allowed_secrets contains potential PHI field %q (matches pattern %q); this could allow PHI exfiltration", secret, pattern)
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// LooksLikeSecret heuristically detects if a string looks like a secret credential
|
||
// Exported for testing purposes
|
||
func LooksLikeSecret(s string) bool {
|
||
// Minimum length for secrets
|
||
if len(s) < 16 {
|
||
return false
|
||
}
|
||
|
||
// Calculate entropy to detect high-entropy strings (likely secrets)
|
||
entropy := CalculateEntropy(s)
|
||
|
||
// High entropy (>4 bits per char) combined with reasonable length suggests a secret
|
||
if entropy > 4.0 {
|
||
return true
|
||
}
|
||
|
||
// Check for common secret patterns
|
||
patterns := []string{
|
||
"AKIA", // AWS Access Key ID prefix
|
||
"ASIA", // AWS temporary credentials
|
||
"ghp_", // GitHub personal access token
|
||
"gho_", // GitHub OAuth token
|
||
"glpat-", // GitLab PAT
|
||
"sk-", // OpenAI/Stripe key prefix
|
||
"sk_live_", // Stripe live key
|
||
"sk_test_", // Stripe test key
|
||
}
|
||
|
||
for _, pattern := range patterns {
|
||
if strings.Contains(s, pattern) {
|
||
return true
|
||
}
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// CalculateEntropy calculates Shannon entropy of a string in bits per character.
// Exported for testing purposes.
//
// Both the per-character frequencies and the total are counted in runes so
// multi-byte UTF-8 input is handled correctly. (The previous implementation
// divided rune frequencies by the byte length len(s), so probabilities did
// not sum to 1 for non-ASCII strings and the result was wrong — e.g. a
// string of one repeated multi-byte rune reported non-zero entropy.)
func CalculateEntropy(s string) float64 {
	if len(s) == 0 {
		return 0
	}

	// Count character (rune) frequencies and the total number of runes.
	freq := make(map[rune]int)
	runeCount := 0
	for _, r := range s {
		freq[r]++
		runeCount++
	}

	// Shannon entropy: H = -sum(p * log2(p)) over the rune distribution.
	// Every count is >= 1, so p > 0 and Log2 is always defined.
	var entropy float64
	total := float64(runeCount)
	for _, count := range freq {
		p := float64(count) / total
		entropy -= p * math.Log2(p)
	}

	return entropy
}
|
||
|
||
// ComputeResolvedConfigHash computes a SHA-256 hash of the resolved config.
// This must be called after os.ExpandEnv, after default application, and after Validate().
// The hash captures the actual runtime configuration, not the raw YAML file.
// This is critical for reproducibility - two different raw files that resolve
// to the same config will produce the same hash.
//
// Returns the lowercase hex-encoded digest, or an error if JSON marshaling
// of the reduced config view fails.
func (c *Config) ComputeResolvedConfigHash() (string, error) {
	// Marshal config to JSON for consistent serialization.
	// We use a simplified struct to avoid hashing volatile fields
	// (credentials, poll intervals, worker IDs, etc. are deliberately omitted).
	//
	// NOTE: the JSON tags AND the declaration order of the fields below are
	// part of the hash contract — encoding/json emits struct fields in
	// declaration order, so renaming or reordering any field changes every
	// previously computed hash. Treat this struct as append-only.
	hashable := struct {
		Host                 string                `json:"host"`
		Port                 int                   `json:"port"`
		BasePath             string                `json:"base_path"`
		MaxWorkers           int                   `json:"max_workers"`
		Resources            config.ResourceConfig `json:"resources"`
		GPUVendor            string                `json:"gpu_vendor"`
		GPUVisibleDevices    []int                 `json:"gpu_visible_devices,omitempty"`
		GPUVisibleDeviceIDs  []string              `json:"gpu_visible_device_ids,omitempty"`
		Sandbox              SandboxConfig         `json:"sandbox"`
		ComplianceMode       string                `json:"compliance_mode"`
		ProvenanceBestEffort bool                  `json:"provenance_best_effort"`
		SnapshotStoreSecure  bool                  `json:"snapshot_store_secure,omitempty"`
		QueueBackend         string                `json:"queue_backend"`
	}{
		Host:                 c.Host,
		Port:                 c.Port,
		BasePath:             c.BasePath,
		MaxWorkers:           c.MaxWorkers,
		Resources:            c.Resources,
		GPUVendor:            c.GPUVendor,
		GPUVisibleDevices:    c.GPUVisibleDevices,
		GPUVisibleDeviceIDs:  c.GPUVisibleDeviceIDs,
		Sandbox:              c.Sandbox,
		ComplianceMode:       c.ComplianceMode,
		ProvenanceBestEffort: c.ProvenanceBestEffort,
		SnapshotStoreSecure:  c.SnapshotStore.Secure,
		QueueBackend:         c.Queue.Backend,
	}

	data, err := json.Marshal(hashable)
	if err != nil {
		return "", fmt.Errorf("failed to marshal config for hashing: %w", err)
	}

	// Compute SHA-256 hash and return it hex-encoded.
	hash := sha256.Sum256(data)
	return hex.EncodeToString(hash[:]), nil
}
|
||
|
||
// envInt reads an integer from environment variable.
// It reports (0, false) when the variable is unset, blank after trimming,
// or not a valid base-10 integer.
func envInt(name string) (int, bool) {
	raw := strings.TrimSpace(os.Getenv(name))
	if raw == "" {
		return 0, false
	}
	if n, err := strconv.Atoi(raw); err == nil {
		return n, true
	}
	return 0, false
}
|
||
|
||
// logEnvOverride logs environment variable overrides to stderr for debugging,
// so operators can see when an env var takes precedence over file config.
// Uses `any` instead of the legacy `interface{}` spelling (identical type;
// the file already requires Go 1.21+ via log/slog).
func logEnvOverride(name string, value any) {
	slog.Warn("env override active", "var", name, "value", value)
}
|
||
|
||
// parseCPUFromConfig determines total CPU from environment or config
|
||
func parseCPUFromConfig(cfg *Config) int {
|
||
if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 {
|
||
logEnvOverride("FETCH_ML_TOTAL_CPU", n)
|
||
return n
|
||
}
|
||
if cfg != nil {
|
||
if cfg.Resources.PodmanCPUs != "" {
|
||
if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil {
|
||
if f < 0 {
|
||
return 0
|
||
}
|
||
return int(math.Floor(f))
|
||
}
|
||
}
|
||
}
|
||
return runtime.NumCPU()
|
||
}
|
||
|
||
// parseGPUCountFromConfig detects GPU count from config and returns detection metadata
|
||
func parseGPUCountFromConfig(cfg *Config) (int, GPUDetectionInfo) {
|
||
factory := &GPUDetectorFactory{}
|
||
result := factory.CreateDetectorWithInfo(cfg)
|
||
return result.Detector.DetectGPUCount(), result.Info
|
||
}
|
||
|
||
// parseGPUSlotsPerGPUFromConfig reads GPU slots per GPU from environment
|
||
func parseGPUSlotsPerGPUFromConfig() int {
|
||
if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 {
|
||
return n
|
||
}
|
||
return 1
|
||
}
|