package worker

import (
	"fmt"
	"math"
	"net/url"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/config"
	"github.com/jfraeys/fetch_ml/internal/fileutil"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/storage"
	"github.com/jfraeys/fetch_ml/internal/tracking/factory"
	"gopkg.in/yaml.v3"
)

// Defaults that LoadConfig applies when the corresponding Config fields are zero.
const (
	// Used for Config.MetricsFlushInterval.
	defaultMetricsFlushInterval = 500 * time.Millisecond
	// Used for Config.DatasetCacheTTL.
	datasetCacheDefaultTTL = 30 * time.Minute
)

// QueueConfig selects the task queue backend and the storage paths it uses.
//
// LoadConfig defaults an empty Backend to the Redis backend and, when the
// SQLite or filesystem backend (or the filesystem fallback) is in use, fills
// an empty path with a location under Config.DataDir. Config.Validate accepts
// only the Redis, SQLite, and filesystem backends, compared case-insensitively.
type QueueConfig struct {
	// Backend names the queue implementation (see the queue.QueueBackend* constants).
	Backend string `yaml:"backend"`
	// SQLitePath is the database file for the SQLite backend
	// (LoadConfig default: <data_dir>/queue.db).
	SQLitePath string `yaml:"sqlite_path"`
	// FilesystemPath is the queue directory for the filesystem backend or the
	// filesystem fallback (LoadConfig default: <data_dir>/queue-fs).
	FilesystemPath string `yaml:"filesystem_path"`
	// FallbackToFilesystem makes LoadConfig/Validate resolve FilesystemPath even
	// when another backend is selected. NOTE(review): when the fallback is actually
	// used is decided outside this file — confirm with the queue setup code.
	FallbackToFilesystem bool `yaml:"fallback_to_filesystem"`
}

// Config holds worker configuration.
//
// It is decoded from YAML by LoadConfig, which also fills in defaults for
// unset fields; Validate checks (and partially normalizes) the result.
type Config struct {
	Host        string `yaml:"host"`
	User        string `yaml:"user"`
	SSHKey      string `yaml:"ssh_key"`
	Port        int    `yaml:"port"`
	BasePath    string `yaml:"base_path"`
	TrainScript string `yaml:"train_script"`
	// RedisURL, when non-empty, has environment variables expanded by LoadConfig
	// and replaces RedisAddr; RedisPassword and RedisDB are then reset to zero values.
	RedisURL      string      `yaml:"redis_url"`
	RedisAddr     string      `yaml:"redis_addr"`
	RedisPassword string      `yaml:"redis_password"`
	RedisDB       int         `yaml:"redis_db"`
	Queue         QueueConfig `yaml:"queue"`
	KnownHosts    string      `yaml:"known_hosts"`
	WorkerID      string      `yaml:"worker_id"`
	// MaxWorkers is copied into Resources.MaxWorkers by LoadConfig; values <= 0
	// are replaced with a smart default.
	MaxWorkers int `yaml:"max_workers"`
	// PollInterval is expressed in seconds (see the YAML key).
	PollInterval int                   `yaml:"poll_interval_seconds"`
	Resources    config.ResourceConfig `yaml:"resources"`
	LocalMode    bool                  `yaml:"local_mode"`

	// Authentication
	Auth auth.Config `yaml:"auth"`

	// Metrics exporter
	Metrics MetricsConfig `yaml:"metrics"`

	// Metrics buffering
	MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`

	// Data management
	DataManagerPath string              `yaml:"data_manager_path"`
	AutoFetchData   bool                `yaml:"auto_fetch_data"`
	DataDir         string              `yaml:"data_dir"`
	DatasetCacheTTL time.Duration       `yaml:"dataset_cache_ttl"`
	SnapshotStore   SnapshotStoreConfig `yaml:"snapshot_store"`

	// Provenance enforcement
	// Default: fail-closed (trustworthiness-by-default). Set true to opt into best-effort.
	ProvenanceBestEffort bool `yaml:"provenance_best_effort"`

	// Phase 1: opt-in prewarming of next task artifacts (snapshot/datasets/env).
	PrewarmEnabled bool `yaml:"prewarm_enabled"`

	// Podman execution
	PodmanImage        string   `yaml:"podman_image"`
	ContainerWorkspace string   `yaml:"container_workspace"`
	ContainerResults   string   `yaml:"container_results"`
	GPUDevices         []string `yaml:"gpu_devices"`
	// GPUVendor must be GPUTypeNVIDIA, GPUTypeApple, GPUTypeNone, or "amd".
	// When empty, LoadConfig infers it from AppleGPU and the GPU device fields.
	GPUVendor string `yaml:"gpu_vendor"`
	// GPUVisibleDevices (indices) and GPUVisibleDeviceIDs ("GPU-..." UUIDs,
	// NVIDIA only) are alternative ways to restrict visible GPUs; set at most one.
	GPUVisibleDevices   []int    `yaml:"gpu_visible_devices"`
	GPUVisibleDeviceIDs []string `yaml:"gpu_visible_device_ids"`

	// Apple M-series GPU configuration
	AppleGPU AppleGPUConfig `yaml:"apple_gpu"`

	// Task lease and retry settings; LoadConfig applies the defaults noted below
	// when a field is zero.
	TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // Worker lease (default: 30min)
	HeartbeatInterval time.Duration `yaml:"heartbeat_interval"`  // Renew lease (default: 1min)
	MaxRetries        int           `yaml:"max_retries"`         // Maximum retry attempts (default: 3)
	GracefulTimeout   time.Duration `yaml:"graceful_timeout"`    // Shutdown timeout (default: 5min)

	// Plugins configuration, keyed by plugin name.
	Plugins map[string]factory.PluginConfig `yaml:"plugins"`

	// Sandboxing configuration
	Sandbox SandboxConfig `yaml:"sandbox"`
}

// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct { Enabled bool `yaml:"enabled"` ListenAddr string `yaml:"listen_addr"` } type SnapshotStoreConfig struct { Enabled bool `yaml:"enabled"` Endpoint string `yaml:"endpoint"` Secure bool `yaml:"secure"` Region string `yaml:"region"` Bucket string `yaml:"bucket"` Prefix string `yaml:"prefix"` AccessKey string `yaml:"access_key"` SecretKey string `yaml:"secret_key"` SessionToken string `yaml:"session_token"` Timeout time.Duration `yaml:"timeout"` MaxRetries int `yaml:"max_retries"` } // AppleGPUConfig holds configuration for Apple M-series GPU support type AppleGPUConfig struct { Enabled bool `yaml:"enabled"` MetalDevice string `yaml:"metal_device"` MPSRuntime string `yaml:"mps_runtime"` } // SandboxConfig holds container sandbox settings type SandboxConfig struct { NetworkMode string `yaml:"network_mode"` // "none", "slirp4netns", "bridge" ReadOnlyRoot bool `yaml:"read_only_root"` AllowSecrets bool `yaml:"allow_secrets"` AllowedSecrets []string `yaml:"allowed_secrets"` // e.g., ["HF_TOKEN", "WANDB_API_KEY"] SeccompProfile string `yaml:"seccomp_profile"` MaxRuntimeHours int `yaml:"max_runtime_hours"` } // Validate checks sandbox configuration func (s *SandboxConfig) Validate() error { validNetworks := map[string]bool{"none": true, "slirp4netns": true, "bridge": true, "": true} if !validNetworks[s.NetworkMode] { return fmt.Errorf("invalid network_mode: %s", s.NetworkMode) } if s.MaxRuntimeHours < 0 { return fmt.Errorf("max_runtime_hours must be positive") } return nil } // LoadConfig loads worker configuration from a YAML file. 
func LoadConfig(path string) (*Config, error) { data, err := fileutil.SecureFileRead(path) if err != nil { return nil, err } var cfg Config if err := yaml.Unmarshal(data, &cfg); err != nil { return nil, err } if strings.TrimSpace(cfg.RedisURL) != "" { cfg.RedisURL = os.ExpandEnv(strings.TrimSpace(cfg.RedisURL)) cfg.RedisAddr = cfg.RedisURL cfg.RedisPassword = "" cfg.RedisDB = 0 } // Get smart defaults for current environment smart := config.GetSmartDefaults() // Use PathRegistry for consistent path management paths := config.FromEnv() if cfg.Port == 0 { cfg.Port = config.DefaultSSHPort } if cfg.Host == "" { host, err := smart.Host() if err != nil { return nil, fmt.Errorf("failed to get default host: %w", err) } cfg.Host = host } if cfg.BasePath == "" { // Prefer PathRegistry over smart defaults for consistency cfg.BasePath = paths.ExperimentsDir() } if cfg.RedisAddr == "" { redisAddr, err := smart.RedisAddr() if err != nil { return nil, fmt.Errorf("failed to get default redis address: %w", err) } cfg.RedisAddr = redisAddr } if cfg.KnownHosts == "" { knownHosts, err := smart.KnownHostsPath() if err != nil { return nil, fmt.Errorf("failed to get default known hosts path: %w", err) } cfg.KnownHosts = knownHosts } if cfg.WorkerID == "" { cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8]) } cfg.Resources.ApplyDefaults() if cfg.MaxWorkers > 0 { cfg.Resources.MaxWorkers = cfg.MaxWorkers } else { maxWorkers, err := smart.MaxWorkers() if err != nil { return nil, fmt.Errorf("failed to get default max workers: %w", err) } cfg.MaxWorkers = maxWorkers cfg.Resources.MaxWorkers = maxWorkers } if cfg.PollInterval == 0 { pollInterval, err := smart.PollInterval() if err != nil { return nil, fmt.Errorf("failed to get default poll interval: %w", err) } cfg.PollInterval = pollInterval } if cfg.DataManagerPath == "" { cfg.DataManagerPath = "./data_manager" } if cfg.DataDir == "" { // Use PathRegistry for consistent data directory cfg.DataDir = paths.DataDir() } if 
cfg.SnapshotStore.Timeout == 0 { cfg.SnapshotStore.Timeout = 10 * time.Minute } if cfg.SnapshotStore.MaxRetries == 0 { cfg.SnapshotStore.MaxRetries = 3 } if cfg.Metrics.ListenAddr == "" { cfg.Metrics.ListenAddr = ":9100" } if cfg.MetricsFlushInterval == 0 { cfg.MetricsFlushInterval = defaultMetricsFlushInterval } if cfg.DatasetCacheTTL == 0 { cfg.DatasetCacheTTL = datasetCacheDefaultTTL } if strings.TrimSpace(cfg.Queue.Backend) == "" { cfg.Queue.Backend = string(queue.QueueBackendRedis) } if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendSQLite)) { if strings.TrimSpace(cfg.Queue.SQLitePath) == "" { cfg.Queue.SQLitePath = filepath.Join(cfg.DataDir, "queue.db") } cfg.Queue.SQLitePath = storage.ExpandPath(cfg.Queue.SQLitePath) } if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendFS)) || cfg.Queue.FallbackToFilesystem { if strings.TrimSpace(cfg.Queue.FilesystemPath) == "" { cfg.Queue.FilesystemPath = filepath.Join(cfg.DataDir, "queue-fs") } cfg.Queue.FilesystemPath = storage.ExpandPath(cfg.Queue.FilesystemPath) } if strings.TrimSpace(cfg.GPUVendor) == "" { if cfg.AppleGPU.Enabled { cfg.GPUVendor = string(GPUTypeApple) } else if len(cfg.GPUDevices) > 0 || len(cfg.GPUVisibleDevices) > 0 || len(cfg.GPUVisibleDeviceIDs) > 0 { cfg.GPUVendor = string(GPUTypeNVIDIA) } else { cfg.GPUVendor = string(GPUTypeNone) } } // Set lease and retry defaults if cfg.TaskLeaseDuration == 0 { cfg.TaskLeaseDuration = 30 * time.Minute } if cfg.HeartbeatInterval == 0 { cfg.HeartbeatInterval = 1 * time.Minute } if cfg.MaxRetries == 0 { cfg.MaxRetries = 3 } if cfg.GracefulTimeout == 0 { cfg.GracefulTimeout = 5 * time.Minute } return &cfg, nil } // Validate implements config.Validator interface. 
func (c *Config) Validate() error { if c.Port != 0 { if err := config.ValidatePort(c.Port); err != nil { return fmt.Errorf("invalid SSH port: %w", err) } } if c.BasePath != "" { // Convert relative paths to absolute c.BasePath = storage.ExpandPath(c.BasePath) if !filepath.IsAbs(c.BasePath) { // Resolve relative to current working directory, not DefaultBasePath cwd, err := os.Getwd() if err != nil { return fmt.Errorf("failed to get current directory: %w", err) } c.BasePath = filepath.Join(cwd, c.BasePath) } } backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend)) if backend == "" { backend = string(queue.QueueBackendRedis) c.Queue.Backend = backend } if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) && backend != string(queue.QueueBackendFS) { return fmt.Errorf("queue.backend must be one of %q, %q, or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite, queue.QueueBackendFS) } if backend == string(queue.QueueBackendSQLite) { if strings.TrimSpace(c.Queue.SQLitePath) == "" { return fmt.Errorf("queue.sqlite_path is required when queue.backend is %q", queue.QueueBackendSQLite) } c.Queue.SQLitePath = storage.ExpandPath(c.Queue.SQLitePath) if !filepath.IsAbs(c.Queue.SQLitePath) { c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath) } } if backend == string(queue.QueueBackendFS) || c.Queue.FallbackToFilesystem { if strings.TrimSpace(c.Queue.FilesystemPath) == "" { return fmt.Errorf("queue.filesystem_path is required when filesystem queue is enabled") } c.Queue.FilesystemPath = storage.ExpandPath(c.Queue.FilesystemPath) if !filepath.IsAbs(c.Queue.FilesystemPath) { c.Queue.FilesystemPath = filepath.Join(config.DefaultLocalDataDir, c.Queue.FilesystemPath) } } if c.RedisAddr != "" { addr := strings.TrimSpace(c.RedisAddr) if strings.HasPrefix(addr, "redis://") { u, err := url.Parse(addr) if err != nil { return fmt.Errorf("invalid Redis configuration: invalid redis url: %w", err) } if u.Scheme != 
"redis" || strings.TrimSpace(u.Host) == "" { return fmt.Errorf("invalid Redis configuration: invalid redis url") } } else { if err := config.ValidateRedisAddr(addr); err != nil { return fmt.Errorf("invalid Redis configuration: %w", err) } } } if c.MaxWorkers < 1 { return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers) } switch strings.ToLower(strings.TrimSpace(c.GPUVendor)) { case string(GPUTypeNVIDIA), string(GPUTypeApple), string(GPUTypeNone), "amd": // ok default: return fmt.Errorf( "gpu_vendor must be one of %q, %q, %q, %q", string(GPUTypeNVIDIA), "amd", string(GPUTypeApple), string(GPUTypeNone), ) } // Strict GPU visibility configuration: // - gpu_visible_devices and gpu_visible_device_ids are mutually exclusive. // - UUID-style gpu_visible_device_ids is NVIDIA-only. vendor := strings.ToLower(strings.TrimSpace(c.GPUVendor)) if len(c.GPUVisibleDevices) > 0 && len(c.GPUVisibleDeviceIDs) > 0 { if vendor != string(GPUTypeNVIDIA) { return fmt.Errorf( "visible_device_ids is only supported when gpu_vendor is %q", string(GPUTypeNVIDIA), ) } for _, id := range c.GPUVisibleDeviceIDs { id = strings.TrimSpace(id) if id == "" { return fmt.Errorf("visible_device_ids contains an empty value") } if !strings.HasPrefix(id, "GPU-") { return fmt.Errorf("gpu_visible_device_ids values must start with %q, got %q", "GPU-", id) } } } if vendor == string(GPUTypeApple) || vendor == string(GPUTypeNone) { if len(c.GPUVisibleDevices) > 0 || len(c.GPUVisibleDeviceIDs) > 0 { return fmt.Errorf( "gpu_visible_devices and gpu_visible_device_ids are not supported when gpu_vendor is %q", vendor, ) } } if vendor == "amd" { if len(c.GPUVisibleDeviceIDs) > 0 { return fmt.Errorf("gpu_visible_device_ids is not supported when gpu_vendor is %q", vendor) } for _, idx := range c.GPUVisibleDevices { if idx < 0 { return fmt.Errorf("gpu_visible_devices contains negative index %d", idx) } } } if c.SnapshotStore.Enabled { if strings.TrimSpace(c.SnapshotStore.Endpoint) == "" { return 
fmt.Errorf("snapshot_store.endpoint is required when snapshot_store.enabled is true") } if strings.TrimSpace(c.SnapshotStore.Bucket) == "" { return fmt.Errorf("snapshot_store.bucket is required when snapshot_store.enabled is true") } ak := strings.TrimSpace(c.SnapshotStore.AccessKey) sk := strings.TrimSpace(c.SnapshotStore.SecretKey) if (ak == "") != (sk == "") { return fmt.Errorf( "snapshot_store.access_key and snapshot_store.secret_key must both be set or both be empty", ) } if c.SnapshotStore.Timeout < 0 { return fmt.Errorf("snapshot_store.timeout must be >= 0") } if c.SnapshotStore.MaxRetries < 0 { return fmt.Errorf("snapshot_store.max_retries must be >= 0") } } return nil } // envInt reads an integer from environment variable func envInt(name string) (int, bool) { v := strings.TrimSpace(os.Getenv(name)) if v == "" { return 0, false } n, err := strconv.Atoi(v) if err != nil { return 0, false } return n, true } // parseCPUFromConfig determines total CPU from environment or config func parseCPUFromConfig(cfg *Config) int { if n, ok := envInt("FETCH_ML_TOTAL_CPU"); ok && n >= 0 { return n } if cfg != nil { if cfg.Resources.PodmanCPUs != "" { if f, err := strconv.ParseFloat(strings.TrimSpace(cfg.Resources.PodmanCPUs), 64); err == nil { if f < 0 { return 0 } return int(math.Floor(f)) } } } return runtime.NumCPU() } // parseGPUCountFromConfig detects GPU count from config func parseGPUCountFromConfig(cfg *Config) int { factory := &GPUDetectorFactory{} detector := factory.CreateDetector(cfg) return detector.DetectGPUCount() } // parseGPUSlotsPerGPUFromConfig reads GPU slots per GPU from environment func parseGPUSlotsPerGPUFromConfig() int { if n, ok := envInt("FETCH_ML_GPU_SLOTS_PER_GPU"); ok && n > 0 { return n } return 1 }