package main

import (
	"fmt"
	"os"
	"path/filepath"
	"time"

	"github.com/google/uuid"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/config"
	"gopkg.in/yaml.v3"
)

const (
	// defaultMetricsFlushInterval is the value LoadConfig assigns to
	// Config.MetricsFlushInterval when the file leaves it at zero.
	defaultMetricsFlushInterval = 500 * time.Millisecond
	// datasetCacheDefaultTTL is the value LoadConfig assigns to
	// Config.DatasetCacheTTL when the file leaves it at zero.
	datasetCacheDefaultTTL = 30 * time.Minute
)

// Config holds worker configuration.
//
// It is decoded from YAML by LoadConfig, which also replaces zero-valued
// fields with defaults, and is checked by Validate.
type Config struct {
	// Host defaults to the smart-defaults host when empty.
	// NOTE(review): presumably the SSH target host — confirm against callers.
	Host string `yaml:"host"`
	// User has no default applied by LoadConfig.
	User string `yaml:"user"`
	// SSHKey has no default applied by LoadConfig.
	// NOTE(review): presumably a path to a private key file — confirm.
	SSHKey string `yaml:"ssh_key"`
	// Port is the SSH port; defaults to config.DefaultSSHPort and is checked
	// by Validate via config.ValidatePort when non-zero.
	Port int `yaml:"port"`
	// BasePath defaults to the smart-defaults base path. Validate expands it
	// and, if still relative, joins it onto config.DefaultBasePath.
	BasePath string `yaml:"base_path"`
	// TrainScript has no default applied by LoadConfig.
	TrainScript string `yaml:"train_script"`
	// RedisAddr defaults to the smart-defaults Redis address and is checked by
	// Validate via config.ValidateRedisAddr when non-empty.
	RedisAddr string `yaml:"redis_addr"`
	// RedisPassword has no default applied by LoadConfig.
	RedisPassword string `yaml:"redis_password"`
	// RedisDB has no default applied by LoadConfig (zero is kept as-is).
	RedisDB int `yaml:"redis_db"`
	// KnownHosts defaults to the smart-defaults known-hosts path.
	KnownHosts string `yaml:"known_hosts"`
	// WorkerID defaults to "worker-" followed by the first 8 characters of a
	// freshly generated UUID string.
	WorkerID string `yaml:"worker_id"`
	// MaxWorkers defaults to the smart-defaults value when zero; Validate
	// rejects anything below 1.
	MaxWorkers int `yaml:"max_workers"`
	// PollInterval is read from the poll_interval_seconds key and defaults to
	// the smart-defaults value when zero.
	PollInterval int `yaml:"poll_interval_seconds"`

	// Authentication
	Auth auth.AuthConfig `yaml:"auth"`

	// Metrics exporter
	Metrics MetricsConfig `yaml:"metrics"`
	// Metrics buffering (default: defaultMetricsFlushInterval, 500ms)
	MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`

	// Data management
	// DataManagerPath defaults to "./data_manager".
	DataManagerPath string `yaml:"data_manager_path"`
	// AutoFetchData influences which DataDir default LoadConfig picks.
	AutoFetchData bool `yaml:"auto_fetch_data"`
	// DataDir, when empty, defaults to config.DefaultLocalDataDir if Host is
	// empty or AutoFetchData is false, otherwise to the smart-defaults data dir.
	DataDir string `yaml:"data_dir"`
	// DatasetCacheTTL defaults to datasetCacheDefaultTTL (30min).
	DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`

	// Podman execution
	// None of the four fields below receive a default from LoadConfig.
	PodmanImage string `yaml:"podman_image"`
	ContainerWorkspace string `yaml:"container_workspace"`
	ContainerResults string `yaml:"container_results"`
	GPUAccess bool `yaml:"gpu_access"`

	// Task lease and retry settings
	TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // How long the worker holds a lease (default: 30min)
	HeartbeatInterval time.Duration `yaml:"heartbeat_interval"` // How often to renew the lease (default: 1min)
	MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3; a configured 0 is replaced by the default)
	GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Graceful shutdown timeout (default: 5min)
}

// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct { Enabled bool `yaml:"enabled"` ListenAddr string `yaml:"listen_addr"` } func LoadConfig(path string) (*Config, error) { data, err := os.ReadFile(path) if err != nil { return nil, err } var cfg Config if err := yaml.Unmarshal(data, &cfg); err != nil { return nil, err } // Get smart defaults for current environment smart := config.GetSmartDefaults() if cfg.Port == 0 { cfg.Port = config.DefaultSSHPort } if cfg.Host == "" { cfg.Host = smart.Host() } if cfg.BasePath == "" { cfg.BasePath = smart.BasePath() } if cfg.RedisAddr == "" { cfg.RedisAddr = smart.RedisAddr() } if cfg.KnownHosts == "" { cfg.KnownHosts = smart.KnownHostsPath() } if cfg.WorkerID == "" { cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8]) } if cfg.MaxWorkers == 0 { cfg.MaxWorkers = smart.MaxWorkers() } if cfg.PollInterval == 0 { cfg.PollInterval = smart.PollInterval() } if cfg.DataManagerPath == "" { cfg.DataManagerPath = "./data_manager" } if cfg.DataDir == "" { if cfg.Host == "" || !cfg.AutoFetchData { cfg.DataDir = config.DefaultLocalDataDir } else { cfg.DataDir = smart.DataDir() } } if cfg.Metrics.ListenAddr == "" { cfg.Metrics.ListenAddr = ":9100" } if cfg.MetricsFlushInterval == 0 { cfg.MetricsFlushInterval = defaultMetricsFlushInterval } if cfg.DatasetCacheTTL == 0 { cfg.DatasetCacheTTL = datasetCacheDefaultTTL } // Set lease and retry defaults if cfg.TaskLeaseDuration == 0 { cfg.TaskLeaseDuration = 30 * time.Minute } if cfg.HeartbeatInterval == 0 { cfg.HeartbeatInterval = 1 * time.Minute } if cfg.MaxRetries == 0 { cfg.MaxRetries = 3 } if cfg.GracefulTimeout == 0 { cfg.GracefulTimeout = 5 * time.Minute } return &cfg, nil } // Validate implements config.Validator interface func (c *Config) Validate() error { if c.Port != 0 { if err := config.ValidatePort(c.Port); err != nil { return fmt.Errorf("invalid SSH port: %w", err) } } if c.BasePath != "" { // Convert relative paths to absolute c.BasePath = config.ExpandPath(c.BasePath) if 
!filepath.IsAbs(c.BasePath) { c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath) } } if c.RedisAddr != "" { if err := config.ValidateRedisAddr(c.RedisAddr); err != nil { return fmt.Errorf("invalid Redis configuration: %w", err) } } if c.MaxWorkers < 1 { return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers) } return nil } // Task struct and Redis constants moved to internal/queue