package config

import (
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/BurntSushi/toml"
	"github.com/jfraeys/fetch_ml/internal/auth"
	utils "github.com/jfraeys/fetch_ml/internal/config"
)

// Config holds TUI configuration.
type Config struct {
	Host     string `toml:"host"`
	User     string `toml:"user"`
	SSHKey   string `toml:"ssh_key"`
	Port     int    `toml:"port"`
	BasePath string `toml:"base_path"`
	// WrapperScript is deprecated: LoadConfig applies no default or
	// environment override for it, because jobs use secure_runner.py
	// directly via Podman.
	WrapperScript string `toml:"wrapper_script"`
	TrainScript   string `toml:"train_script"`
	RedisAddr     string `toml:"redis_addr"`
	RedisPassword string `toml:"redis_password"`
	RedisDB       int    `toml:"redis_db"`
	KnownHosts    string `toml:"known_hosts"`

	// Authentication
	Auth auth.AuthConfig `toml:"auth"`

	// Podman settings
	PodmanImage        string `toml:"podman_image"`
	ContainerWorkspace string `toml:"container_workspace"`
	ContainerResults   string `toml:"container_results"`
	GPUAccess          bool   `toml:"gpu_access"`
}

// LoadConfig reads the TOML file at path and builds a Config in three steps.
// First the file is decoded. Then unset fields are filled with defaults.
// Finally FETCH_ML_TUI_* environment variables are applied; these take
// precedence over both the file and the defaults. LoadConfig does not call
// Validate.
//
// An error is returned when:
//   - the file cannot be read. This error is returned unwrapped, so
//     os.IsNotExist and similar checks keep working;
//   - the file is not valid TOML;
//   - FETCH_ML_TUI_PORT or FETCH_ML_TUI_REDIS_DB is set to a value that
//     is not a base-10 integer.
func LoadConfig(path string) (*Config, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		// The *os.PathError already names the file. Wrapping it would hide
		// it from os.IsNotExist, which does not follow fmt.Errorf wrappers.
		return nil, err
	}

	var cfg Config
	if _, err := toml.Decode(string(data), &cfg); err != nil {
		return nil, fmt.Errorf("parsing config %q: %w", path, err)
	}

	cfg.applyDefaults()

	if err := cfg.applyEnvOverrides(); err != nil {
		return nil, err
	}
	return &cfg, nil
}

// applyDefaults fills zero-valued fields with defaults for the current
// environment, taken from utils.GetSmartDefaults and the utils constants.
// wrapper_script is deliberately left alone (see Config.WrapperScript).
func (c *Config) applyDefaults() {
	smart := utils.GetSmartDefaults()

	if c.Port == 0 {
		c.Port = utils.DefaultSSHPort
	}
	if c.Host == "" {
		c.Host = smart.Host()
	}
	if c.BasePath == "" {
		c.BasePath = smart.BasePath()
	}
	if c.TrainScript == "" {
		c.TrainScript = utils.DefaultTrainScript
	}
	if c.RedisAddr == "" {
		c.RedisAddr = smart.RedisAddr()
	}
	if c.KnownHosts == "" {
		c.KnownHosts = smart.KnownHostsPath()
	}
}

// applyEnvOverrides replaces fields with the values of the matching
// FETCH_ML_TUI_* environment variables. Unset or empty variables are
// skipped. A numeric variable holding a non-integer value is reported as
// an error rather than silently ignored.
func (c *Config) applyEnvOverrides() error {
	setString := func(name string, dst *string) {
		if v := os.Getenv(name); v != "" {
			*dst = v
		}
	}
	setInt := func(name string, dst *int) error {
		v := os.Getenv(name)
		if v == "" {
			return nil
		}
		n, err := parseInt(v)
		if err != nil {
			return fmt.Errorf("parsing %s: %w", name, err)
		}
		*dst = n
		return nil
	}

	setString("FETCH_ML_TUI_HOST", &c.Host)
	setString("FETCH_ML_TUI_USER", &c.User)
	setString("FETCH_ML_TUI_SSH_KEY", &c.SSHKey)
	if err := setInt("FETCH_ML_TUI_PORT", &c.Port); err != nil {
		return err
	}
	setString("FETCH_ML_TUI_BASE_PATH", &c.BasePath)
	setString("FETCH_ML_TUI_TRAIN_SCRIPT", &c.TrainScript)
	setString("FETCH_ML_TUI_REDIS_ADDR", &c.RedisAddr)
	setString("FETCH_ML_TUI_REDIS_PASSWORD", &c.RedisPassword)
	if err := setInt("FETCH_ML_TUI_REDIS_DB", &c.RedisDB); err != nil {
		return err
	}
	setString("FETCH_ML_TUI_KNOWN_HOSTS", &c.KnownHosts)
	return nil
}

// Validate implements utils.Validator interface.
//
// It checks the SSH port and the Redis address when they are set. It also
// normalizes BasePath in place: the path is expanded with utils.ExpandPath
// and, if still relative, joined onto utils.DefaultBasePath. Because of
// this, Validate mutates c.
func (c *Config) Validate() error {
	if c.Port != 0 {
		if err := utils.ValidatePort(c.Port); err != nil {
			return fmt.Errorf("invalid SSH port: %w", err)
		}
	}

	if c.BasePath != "" {
		// Convert relative paths to absolute.
		c.BasePath = utils.ExpandPath(c.BasePath)
		if !filepath.IsAbs(c.BasePath) {
			c.BasePath = filepath.Join(utils.DefaultBasePath, c.BasePath)
		}
	}

	if c.RedisAddr != "" {
		if err := utils.ValidateRedisAddr(c.RedisAddr); err != nil {
			return fmt.Errorf("invalid Redis configuration: %w", err)
		}
	}

	return nil
}

// PendingPath returns the "pending" subdirectory of BasePath.
func (c *Config) PendingPath() string { return filepath.Join(c.BasePath, "pending") }

// RunningPath returns the "running" subdirectory of BasePath.
func (c *Config) RunningPath() string { return filepath.Join(c.BasePath, "running") }

// FinishedPath returns the "finished" subdirectory of BasePath.
func (c *Config) FinishedPath() string { return filepath.Join(c.BasePath, "finished") }

// FailedPath returns the "failed" subdirectory of BasePath.
func (c *Config) FailedPath() string { return filepath.Join(c.BasePath, "failed") }

// parseInt parses s as a base-10 integer, ignoring surrounding whitespace.
// Unlike fmt.Sscanf with "%d", trailing garbage such as "22abc" is an error
// instead of being silently truncated to 22.
func parseInt(s string) (int, error) {
	return strconv.Atoi(strings.TrimSpace(s))
}