fetch_ml/cmd/tui/internal/config/config.go
Jeremie Fraeys 66f262d788
security: improve audit, crypto, and config handling
- Enhance audit checkpoint system
- Update KMS provider and tenant key management
- Refine configuration constants
- Improve TUI config handling
2026-03-04 13:23:42 -05:00

211 lines
6 KiB
Go

package config
import (
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/BurntSushi/toml"
	"github.com/jfraeys/fetch_ml/internal/auth"
	utils "github.com/jfraeys/fetch_ml/internal/config"
)
// Config holds TUI configuration.
//
// Each field is decoded from the TOML key named in its struct tag. After
// decoding, LoadConfig fills some empty fields with defaults and applies
// FETCH_ML_TUI_* environment overrides. Validate may further normalize Mode
// and BasePath in place.
type Config struct {
// Experiment is the [experiment] table (name, entrypoint).
Experiment struct {
Name string `toml:"name"`
Entrypoint string `toml:"entrypoint"`
} `toml:"experiment"`
ProjectRoot string `toml:"project_root"`
ServerURL string `toml:"server_url"`
ContainerResults string `toml:"container_results"`
// BasePath is the root under which the pending/running/finished/failed
// directories are resolved (see PendingPath and friends).
BasePath string `toml:"base_path"`
// Mode selects mode-dependent defaults; Validate sets it from
// FETCH_ML_TUI_MODE or "dev" when empty.
Mode string `toml:"mode"`
// WrapperScript is decoded but not read anywhere in this file; LoadConfig
// notes that wrapper_script is deprecated.
WrapperScript string `toml:"wrapper_script"`
// Entrypoint is read from the top-level `train_script` key. It is separate
// from Experiment.Entrypoint, which is read from [experiment].entrypoint.
Entrypoint string `toml:"train_script"`
RedisAddr string `toml:"redis_addr"`
RedisPassword string `toml:"redis_password"`
ContainerWorkspace string `toml:"container_workspace"`
SSHKey string `toml:"ssh_key"`
// DBPath, when non-empty, makes IsLocalMode report true and is returned
// verbatim by GetDBPath.
DBPath string `toml:"db_path"`
KnownHosts string `toml:"known_hosts"`
PodmanImage string `toml:"podman_image"`
Host string `toml:"host"`
User string `toml:"user"`
Auth auth.Config `toml:"auth"`
GPUDevices []string `toml:"gpu_devices"`
RedisDB int `toml:"redis_db"`
// Port is the SSH port; 0 means "unset" and is replaced by a default in
// LoadConfig (Validate skips the port check when it is 0).
Port int `toml:"port"`
// ForceLocal forces IsLocalMode to report true regardless of DBPath.
ForceLocal bool `toml:"force_local"`
}
// LoadConfig reads the TOML file at path and returns the resolved TUI
// configuration.
//
// Precedence, highest first:
//  1. non-empty FETCH_ML_TUI_* environment variables (applyEnvOverrides);
//  2. values decoded from the TOML file;
//  3. defaults for fields that are still unset (applySmartDefaults).
//
// Environment overrides are applied before defaults. A default is therefore
// only looked up for a field that nothing else has set, and a failing
// default lookup cannot abort loading when the environment already provides
// the value. A port of 0 (from the file or FETCH_ML_TUI_PORT) is treated as
// unset and replaced by utils.DefaultSSHPort.
//
// LoadConfig returns an error in any of these cases:
//   - the file cannot be read or decoded;
//   - FETCH_ML_TUI_PORT or FETCH_ML_TUI_REDIS_DB is set but is not an integer;
//   - a required default cannot be determined.
func LoadConfig(path string) (*Config, error) {
	//nolint:gosec // G304: Config path is user-controlled but trusted
	data, err := os.ReadFile(path)
	if err != nil {
		// Returned unwrapped: the *fs.PathError already names the file, and
		// callers may test it with os.IsNotExist, which does not see through
		// fmt.Errorf %w chains.
		return nil, err
	}

	var cfg Config
	if _, err := toml.Decode(string(data), &cfg); err != nil {
		return nil, fmt.Errorf("decoding TUI config %q: %w", path, err)
	}

	if err := applyEnvOverrides(&cfg); err != nil {
		return nil, err
	}
	if err := applySmartDefaults(&cfg); err != nil {
		return nil, err
	}
	return &cfg, nil
}

// applyEnvOverrides copies every non-empty FETCH_ML_TUI_* variable onto the
// matching field of cfg. Integer variables must parse completely; a malformed
// value is reported instead of being silently ignored.
func applyEnvOverrides(cfg *Config) error {
	stringVars := []struct {
		name string
		dst  *string
	}{
		{"FETCH_ML_TUI_HOST", &cfg.Host},
		{"FETCH_ML_TUI_USER", &cfg.User},
		{"FETCH_ML_TUI_SSH_KEY", &cfg.SSHKey},
		{"FETCH_ML_TUI_BASE_PATH", &cfg.BasePath},
		{"FETCH_ML_TUI_TRAIN_SCRIPT", &cfg.Entrypoint},
		{"FETCH_ML_TUI_REDIS_ADDR", &cfg.RedisAddr},
		{"FETCH_ML_TUI_REDIS_PASSWORD", &cfg.RedisPassword},
		{"FETCH_ML_TUI_KNOWN_HOSTS", &cfg.KnownHosts},
	}
	for _, v := range stringVars {
		if val := os.Getenv(v.name); val != "" {
			*v.dst = val
		}
	}

	intVars := []struct {
		name string
		dst  *int
	}{
		{"FETCH_ML_TUI_PORT", &cfg.Port},
		{"FETCH_ML_TUI_REDIS_DB", &cfg.RedisDB},
	}
	for _, v := range intVars {
		val := os.Getenv(v.name)
		if val == "" {
			continue
		}
		n, err := parseInt(val)
		if err != nil {
			return fmt.Errorf("invalid %s value %q: %w", v.name, val, err)
		}
		*v.dst = n
	}
	return nil
}

// applySmartDefaults fills the fields of cfg that are still empty with
// static defaults and with environment-specific defaults from
// utils.GetSmartDefaults.
func applySmartDefaults(cfg *Config) error {
	if cfg.Port == 0 {
		cfg.Port = utils.DefaultSSHPort
	}
	// wrapper_script is deprecated - using secure_runner.py directly via Podman
	if cfg.Entrypoint == "" {
		cfg.Entrypoint = utils.DefaultEntrypoint
	}

	smart := utils.GetSmartDefaults()
	if cfg.Host == "" {
		host, err := smart.Host()
		if err != nil {
			return fmt.Errorf("failed to get default host: %w", err)
		}
		cfg.Host = host
	}
	if cfg.BasePath == "" {
		basePath, err := smart.BasePath()
		if err != nil {
			return fmt.Errorf("failed to get default base path: %w", err)
		}
		cfg.BasePath = basePath
	}
	if cfg.RedisAddr == "" {
		redisAddr, err := smart.RedisAddr()
		if err != nil {
			return fmt.Errorf("failed to get default redis address: %w", err)
		}
		cfg.RedisAddr = redisAddr
	}
	if cfg.KnownHosts == "" {
		knownHosts, err := smart.KnownHostsPath()
		if err != nil {
			return fmt.Errorf("failed to get default known hosts path: %w", err)
		}
		cfg.KnownHosts = knownHosts
	}
	return nil
}
// Validate implements the utils.Validator interface.
//
// It checks c and also normalizes it in place:
//   - a non-zero Port must pass utils.ValidatePort;
//   - an empty Mode becomes $FETCH_ML_TUI_MODE, or "dev" if that is unset;
//   - an empty BasePath becomes utils.ModeBasedBasePath(Mode);
//   - a non-empty BasePath is expanded with utils.ExpandPath and, if still
//     relative, joined under utils.DefaultBasePath;
//   - a non-empty RedisAddr must pass utils.ValidateRedisAddr.
//
// The port is checked first, so an invalid port returns before any field is
// modified.
func (c *Config) Validate() error {
	if port := c.Port; port != 0 {
		if err := utils.ValidatePort(port); err != nil {
			return fmt.Errorf("invalid SSH port: %w", err)
		}
	}

	if c.Mode == "" {
		mode := os.Getenv("FETCH_ML_TUI_MODE")
		if mode == "" {
			mode = "dev" // Default to dev mode
		}
		c.Mode = mode
	}

	if c.BasePath == "" {
		c.BasePath = utils.ModeBasedBasePath(c.Mode)
	}
	if c.BasePath != "" {
		resolved := utils.ExpandPath(c.BasePath)
		if !filepath.IsAbs(resolved) {
			resolved = filepath.Join(utils.DefaultBasePath, resolved)
		}
		c.BasePath = resolved
	}

	if c.RedisAddr == "" {
		return nil
	}
	if err := utils.ValidateRedisAddr(c.RedisAddr); err != nil {
		return fmt.Errorf("invalid Redis configuration: %w", err)
	}
	return nil
}
// IsLocalMode reports whether the TUI should operate in local-only mode.
// That is the case when ForceLocal is set or when an explicit DBPath has
// been configured.
func (c *Config) IsLocalMode() bool {
	return c.ForceLocal || c.DBPath != ""
}
// GetDBPath returns the SQLite database path used in local mode.
//
// An explicit DBPath is returned unchanged. Otherwise the path is
// <home>/.fetchml/experiments.db, where <home> comes from os.UserHomeDir.
// If the home directory cannot be determined, the relative path
// "fetchml.db" is returned.
func (c *Config) GetDBPath() string {
	if explicit := c.DBPath; explicit != "" {
		return explicit
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "fetchml.db"
	}
	return filepath.Join(home, ".fetchml", "experiments.db")
}
// PendingPath returns the path for pending experiments: <BasePath>/pending.
func (c *Config) PendingPath() string {
	return filepath.Join(c.BasePath, "pending")
}

// RunningPath returns the path for running experiments: <BasePath>/running.
func (c *Config) RunningPath() string {
	return filepath.Join(c.BasePath, "running")
}

// FinishedPath returns the path for finished experiments: <BasePath>/finished.
func (c *Config) FinishedPath() string {
	return filepath.Join(c.BasePath, "finished")
}

// FailedPath returns the path for failed experiments: <BasePath>/failed.
func (c *Config) FailedPath() string {
	return filepath.Join(c.BasePath, "failed")
}
// parseInt parses s as a base-10 int.
//
// Leading and trailing whitespace is ignored, but everything else must form
// a complete integer (an optional sign followed by digits). Inputs with
// trailing garbage such as "22abc" are rejected rather than partially
// parsed, so a malformed setting is reported instead of silently becoming
// a different number. Out-of-range values also return an error.
func parseInt(s string) (int, error) {
	return strconv.Atoi(strings.TrimSpace(s))
}