package config

import (
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/BurntSushi/toml"
	"github.com/jfraeys/fetch_ml/internal/auth"
	utils "github.com/jfraeys/fetch_ml/internal/config"
)

// Config holds TUI configuration.
//
// Port is the SSH port (it defaults to utils.DefaultSSHPort and is validated
// as an SSH port). BasePath is the root under which the pending/running/
// finished/failed experiment directories live. KnownHosts is a known-hosts
// file path. WrapperScript is deprecated: jobs run secure_runner.py directly
// via Podman, so no default is applied to it.
type Config struct {
	Host          string `toml:"host"`
	User          string `toml:"user"`
	SSHKey        string `toml:"ssh_key"`
	Port          int    `toml:"port"`
	BasePath      string `toml:"base_path"`
	WrapperScript string `toml:"wrapper_script"`
	TrainScript   string `toml:"train_script"`
	RedisAddr     string `toml:"redis_addr"`
	RedisPassword string `toml:"redis_password"`
	RedisDB       int    `toml:"redis_db"`
	KnownHosts    string `toml:"known_hosts"`

	// Authentication
	Auth auth.Config `toml:"auth"`

	// Podman settings
	PodmanImage        string   `toml:"podman_image"`
	ContainerWorkspace string   `toml:"container_workspace"`
	ContainerResults   string   `toml:"container_results"`
	GPUDevices         []string `toml:"gpu_devices"`
}

// LoadConfig loads configuration from the TOML file at path.
//
// Each field is resolved with the following precedence (highest first):
//  1. the matching FETCH_ML_TUI_* environment variable, if non-empty
//  2. the value in the TOML file
//  3. a default: utils.DefaultSSHPort, utils.DefaultTrainScript, or a value
//     from utils.GetSmartDefaults()
//
// An empty string (or a zero Port) counts as "unset". Environment overrides
// are applied before defaults are computed, so a smart-default lookup is never
// attempted (and so cannot fail) for a field the environment already provides.
// FETCH_ML_TUI_PORT and FETCH_ML_TUI_REDIS_DB must be integers; an unparsable
// value is reported as an error rather than silently ignored.
//
// LoadConfig does not call Validate.
func LoadConfig(path string) (*Config, error) {
	data, err := os.ReadFile(path) //nolint:gosec // G304: config path is user-controlled but trusted
	if err != nil {
		// os.ReadFile's *PathError already names the file; returned unwrapped so
		// callers using os.IsNotExist keep working.
		return nil, err
	}

	var cfg Config
	if _, err := toml.Decode(string(data), &cfg); err != nil {
		return nil, fmt.Errorf("parsing config %q: %w", path, err)
	}

	if err := cfg.applyEnvOverrides(); err != nil {
		return nil, err
	}
	if err := cfg.applyDefaults(); err != nil {
		return nil, err
	}
	return &cfg, nil
}

// applyEnvOverrides copies every non-empty FETCH_ML_TUI_* environment variable
// over its corresponding field. Integer-valued variables that fail to parse
// produce an error, so a typo cannot quietly fall back to another value.
func (c *Config) applyEnvOverrides() error {
	stringOverrides := []struct {
		name string
		dst  *string
	}{
		{"FETCH_ML_TUI_HOST", &c.Host},
		{"FETCH_ML_TUI_USER", &c.User},
		{"FETCH_ML_TUI_SSH_KEY", &c.SSHKey},
		{"FETCH_ML_TUI_BASE_PATH", &c.BasePath},
		{"FETCH_ML_TUI_TRAIN_SCRIPT", &c.TrainScript},
		{"FETCH_ML_TUI_REDIS_ADDR", &c.RedisAddr},
		{"FETCH_ML_TUI_REDIS_PASSWORD", &c.RedisPassword},
		{"FETCH_ML_TUI_KNOWN_HOSTS", &c.KnownHosts},
	}
	for _, o := range stringOverrides {
		if v := os.Getenv(o.name); v != "" {
			*o.dst = v
		}
	}

	intOverrides := []struct {
		name string
		dst  *int
	}{
		{"FETCH_ML_TUI_PORT", &c.Port},
		{"FETCH_ML_TUI_REDIS_DB", &c.RedisDB},
	}
	for _, o := range intOverrides {
		v := os.Getenv(o.name)
		if v == "" {
			continue
		}
		n, err := parseInt(v)
		if err != nil {
			return fmt.Errorf("invalid %s: %w", o.name, err)
		}
		*o.dst = n
	}
	return nil
}

// applyDefaults fills in fields that are still unset after the file and the
// environment have been applied. Smart-default lookups run only for fields
// that actually need a value.
func (c *Config) applyDefaults() error {
	smart := utils.GetSmartDefaults()

	if c.Port == 0 {
		c.Port = utils.DefaultSSHPort
	}
	if c.Host == "" {
		host, err := smart.Host()
		if err != nil {
			return fmt.Errorf("failed to get default host: %w", err)
		}
		c.Host = host
	}
	if c.BasePath == "" {
		basePath, err := smart.BasePath()
		if err != nil {
			return fmt.Errorf("failed to get default base path: %w", err)
		}
		c.BasePath = basePath
	}
	// wrapper_script is deprecated - using secure_runner.py directly via
	// Podman - so WrapperScript intentionally gets no default.
	if c.TrainScript == "" {
		c.TrainScript = utils.DefaultTrainScript
	}
	if c.RedisAddr == "" {
		redisAddr, err := smart.RedisAddr()
		if err != nil {
			return fmt.Errorf("failed to get default redis address: %w", err)
		}
		c.RedisAddr = redisAddr
	}
	if c.KnownHosts == "" {
		knownHosts, err := smart.KnownHostsPath()
		if err != nil {
			return fmt.Errorf("failed to get default known hosts path: %w", err)
		}
		c.KnownHosts = knownHosts
	}
	return nil
}

// Validate implements utils.Validator interface.
//
// It checks the SSH port and the Redis address when they are set, and
// normalizes BasePath in place. BasePath is passed through utils.ExpandPath
// and, if the result is still relative, joined onto utils.DefaultBasePath.
// Validate therefore mutates c. Checks run in order (port, base path, Redis),
// and the first failure is returned.
func (c *Config) Validate() error {
	if c.Port != 0 {
		if err := utils.ValidatePort(c.Port); err != nil {
			return fmt.Errorf("invalid SSH port: %w", err)
		}
	}
	if c.BasePath != "" {
		// Convert relative paths to absolute.
		c.BasePath = utils.ExpandPath(c.BasePath)
		if !filepath.IsAbs(c.BasePath) {
			c.BasePath = filepath.Join(utils.DefaultBasePath, c.BasePath)
		}
	}
	if c.RedisAddr != "" {
		if err := utils.ValidateRedisAddr(c.RedisAddr); err != nil {
			return fmt.Errorf("invalid Redis configuration: %w", err)
		}
	}
	return nil
}

// PendingPath returns the path for pending experiments.
func (c *Config) PendingPath() string { return filepath.Join(c.BasePath, "pending") }

// RunningPath returns the path for running experiments.
func (c *Config) RunningPath() string { return filepath.Join(c.BasePath, "running") }

// FinishedPath returns the path for finished experiments.
func (c *Config) FinishedPath() string { return filepath.Join(c.BasePath, "finished") }

// FailedPath returns the path for failed experiments.
func (c *Config) FailedPath() string { return filepath.Join(c.BasePath, "failed") }

// parseInt parses s as a base-10 integer, ignoring surrounding whitespace.
// The entire trimmed string must be a number. fmt.Sscanf with "%d" would
// instead accept "22abc" as 22.
func parseInt(s string) (int, error) {
	return strconv.Atoi(strings.TrimSpace(s))
}