package config import ( "fmt" "os" "path/filepath" "github.com/BurntSushi/toml" "github.com/jfraeys/fetch_ml/internal/auth" utils "github.com/jfraeys/fetch_ml/internal/config" ) // Config holds TUI configuration type Config struct { Experiment struct { Name string `toml:"name"` Entrypoint string `toml:"entrypoint"` } `toml:"experiment"` ProjectRoot string `toml:"project_root"` ServerURL string `toml:"server_url"` ContainerResults string `toml:"container_results"` BasePath string `toml:"base_path"` Mode string `toml:"mode"` WrapperScript string `toml:"wrapper_script"` Entrypoint string `toml:"train_script"` RedisAddr string `toml:"redis_addr"` RedisPassword string `toml:"redis_password"` ContainerWorkspace string `toml:"container_workspace"` SSHKey string `toml:"ssh_key"` DBPath string `toml:"db_path"` KnownHosts string `toml:"known_hosts"` PodmanImage string `toml:"podman_image"` Host string `toml:"host"` User string `toml:"user"` Auth auth.Config `toml:"auth"` GPUDevices []string `toml:"gpu_devices"` RedisDB int `toml:"redis_db"` Port int `toml:"port"` ForceLocal bool `toml:"force_local"` } // LoadConfig loads configuration from a TOML file func LoadConfig(path string) (*Config, error) { //nolint:gosec // G304: Config path is user-controlled but trusted data, err := os.ReadFile(path) if err != nil { return nil, err } var cfg Config if _, err := toml.Decode(string(data), &cfg); err != nil { return nil, err } // Get smart defaults for current environment smart := utils.GetSmartDefaults() if cfg.Port == 0 { cfg.Port = utils.DefaultSSHPort } if cfg.Host == "" { host, err := smart.Host() if err != nil { return nil, fmt.Errorf("failed to get default host: %w", err) } cfg.Host = host } if cfg.BasePath == "" { basePath, err := smart.BasePath() if err != nil { return nil, fmt.Errorf("failed to get default base path: %w", err) } cfg.BasePath = basePath } // wrapper_script is deprecated - using secure_runner.py directly via Podman if cfg.Entrypoint == "" { cfg.Entrypoint = utils.DefaultEntrypoint } if cfg.RedisAddr == "" { redisAddr, err := smart.RedisAddr() if err != nil { return nil, fmt.Errorf("failed to get default redis address: %w", err) } cfg.RedisAddr = redisAddr } if cfg.KnownHosts == "" { knownHosts, err := smart.KnownHostsPath() if err != nil { return nil, fmt.Errorf("failed to get default known hosts path: %w", err) } cfg.KnownHosts = knownHosts } // Apply environment variable overrides with FETCH_ML_TUI_ prefix if host := os.Getenv("FETCH_ML_TUI_HOST"); host != "" { cfg.Host = host } if user := os.Getenv("FETCH_ML_TUI_USER"); user != "" { cfg.User = user } if sshKey := os.Getenv("FETCH_ML_TUI_SSH_KEY"); sshKey != "" { cfg.SSHKey = sshKey } if port := os.Getenv("FETCH_ML_TUI_PORT"); port != "" { if p, err := parseInt(port); err == nil { cfg.Port = p } } if basePath := os.Getenv("FETCH_ML_TUI_BASE_PATH"); basePath != "" { cfg.BasePath = basePath } if trainScript := os.Getenv("FETCH_ML_TUI_TRAIN_SCRIPT"); trainScript != "" { cfg.Entrypoint = trainScript } if redisAddr := os.Getenv("FETCH_ML_TUI_REDIS_ADDR"); redisAddr != "" { cfg.RedisAddr = redisAddr } if redisPassword := os.Getenv("FETCH_ML_TUI_REDIS_PASSWORD"); redisPassword != "" { cfg.RedisPassword = redisPassword } if redisDB := os.Getenv("FETCH_ML_TUI_REDIS_DB"); redisDB != "" { if db, err := parseInt(redisDB); err == nil { cfg.RedisDB = db } } if knownHosts := os.Getenv("FETCH_ML_TUI_KNOWN_HOSTS"); knownHosts != "" { cfg.KnownHosts = knownHosts } return &cfg, nil } // Validate implements utils.Validator interface func (c *Config) Validate() error { if c.Port != 0 { if err := utils.ValidatePort(c.Port); err != nil { return fmt.Errorf("invalid SSH port: %w", err) } } // Set default mode if not specified if c.Mode == "" { if os.Getenv("FETCH_ML_TUI_MODE") != "" { c.Mode = os.Getenv("FETCH_ML_TUI_MODE") } else { c.Mode = "dev" // Default to dev mode } } // Set mode-appropriate default paths using project-relative paths if c.BasePath == "" { c.BasePath = utils.ModeBasedBasePath(c.Mode) } if c.BasePath != "" { // Convert relative paths to absolute c.BasePath = utils.ExpandPath(c.BasePath) if !filepath.IsAbs(c.BasePath) { c.BasePath = filepath.Join(utils.DefaultBasePath, c.BasePath) } } if c.RedisAddr != "" { if err := utils.ValidateRedisAddr(c.RedisAddr); err != nil { return fmt.Errorf("invalid Redis configuration: %w", err) } } return nil } // IsLocalMode returns true if the TUI should operate in local-only mode func (c *Config) IsLocalMode() bool { if c.ForceLocal { return true } // Check if tracking_uri indicates local mode (sqlite:// prefix) if c.DBPath != "" { return true } return false } // GetDBPath returns the SQLite database path for local mode func (c *Config) GetDBPath() string { if c.DBPath != "" { return c.DBPath } // Default location: ~/.fetchml/experiments.db if home, err := os.UserHomeDir(); err == nil { return filepath.Join(home, ".fetchml", "experiments.db") } return "fetchml.db" } func (c *Config) PendingPath() string { return filepath.Join(c.BasePath, "pending") } // RunningPath returns the path for running experiments func (c *Config) RunningPath() string { return filepath.Join(c.BasePath, "running") } // FinishedPath returns the path for finished experiments func (c *Config) FinishedPath() string { return filepath.Join(c.BasePath, "finished") } // FailedPath returns the path for failed experiments func (c *Config) FailedPath() string { return filepath.Join(c.BasePath, "failed") } // parseInt parses a string to integer func parseInt(s string) (int, error) { var result int _, err := fmt.Sscanf(s, "%d", &result) return result, err }