- Enhance audit checkpoint system - Update KMS provider and tenant key management - Refine configuration constants - Improve TUI config handling
211 lines
6 KiB
Go
211 lines
6 KiB
Go
package config
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"github.com/BurntSushi/toml"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
utils "github.com/jfraeys/fetch_ml/internal/config"
|
|
)
|
|
|
|
// Config holds TUI configuration
|
|
type Config struct {
|
|
Experiment struct {
|
|
Name string `toml:"name"`
|
|
Entrypoint string `toml:"entrypoint"`
|
|
} `toml:"experiment"`
|
|
ProjectRoot string `toml:"project_root"`
|
|
ServerURL string `toml:"server_url"`
|
|
ContainerResults string `toml:"container_results"`
|
|
BasePath string `toml:"base_path"`
|
|
Mode string `toml:"mode"`
|
|
WrapperScript string `toml:"wrapper_script"`
|
|
Entrypoint string `toml:"train_script"`
|
|
RedisAddr string `toml:"redis_addr"`
|
|
RedisPassword string `toml:"redis_password"`
|
|
ContainerWorkspace string `toml:"container_workspace"`
|
|
SSHKey string `toml:"ssh_key"`
|
|
DBPath string `toml:"db_path"`
|
|
KnownHosts string `toml:"known_hosts"`
|
|
PodmanImage string `toml:"podman_image"`
|
|
Host string `toml:"host"`
|
|
User string `toml:"user"`
|
|
Auth auth.Config `toml:"auth"`
|
|
GPUDevices []string `toml:"gpu_devices"`
|
|
RedisDB int `toml:"redis_db"`
|
|
Port int `toml:"port"`
|
|
ForceLocal bool `toml:"force_local"`
|
|
}
|
|
|
|
// LoadConfig loads configuration from a TOML file
|
|
func LoadConfig(path string) (*Config, error) {
|
|
//nolint:gosec // G304: Config path is user-controlled but trusted
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var cfg Config
|
|
if _, err := toml.Decode(string(data), &cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get smart defaults for current environment
|
|
smart := utils.GetSmartDefaults()
|
|
|
|
if cfg.Port == 0 {
|
|
cfg.Port = utils.DefaultSSHPort
|
|
}
|
|
if cfg.Host == "" {
|
|
host, err := smart.Host()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get default host: %w", err)
|
|
}
|
|
cfg.Host = host
|
|
}
|
|
if cfg.BasePath == "" {
|
|
basePath, err := smart.BasePath()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get default base path: %w", err)
|
|
}
|
|
cfg.BasePath = basePath
|
|
}
|
|
// wrapper_script is deprecated - using secure_runner.py directly via Podman
|
|
if cfg.Entrypoint == "" {
|
|
cfg.Entrypoint = utils.DefaultEntrypoint
|
|
}
|
|
if cfg.RedisAddr == "" {
|
|
redisAddr, err := smart.RedisAddr()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get default redis address: %w", err)
|
|
}
|
|
cfg.RedisAddr = redisAddr
|
|
}
|
|
if cfg.KnownHosts == "" {
|
|
knownHosts, err := smart.KnownHostsPath()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get default known hosts path: %w", err)
|
|
}
|
|
cfg.KnownHosts = knownHosts
|
|
}
|
|
|
|
// Apply environment variable overrides with FETCH_ML_TUI_ prefix
|
|
if host := os.Getenv("FETCH_ML_TUI_HOST"); host != "" {
|
|
cfg.Host = host
|
|
}
|
|
if user := os.Getenv("FETCH_ML_TUI_USER"); user != "" {
|
|
cfg.User = user
|
|
}
|
|
if sshKey := os.Getenv("FETCH_ML_TUI_SSH_KEY"); sshKey != "" {
|
|
cfg.SSHKey = sshKey
|
|
}
|
|
if port := os.Getenv("FETCH_ML_TUI_PORT"); port != "" {
|
|
if p, err := parseInt(port); err == nil {
|
|
cfg.Port = p
|
|
}
|
|
}
|
|
if basePath := os.Getenv("FETCH_ML_TUI_BASE_PATH"); basePath != "" {
|
|
cfg.BasePath = basePath
|
|
}
|
|
if trainScript := os.Getenv("FETCH_ML_TUI_TRAIN_SCRIPT"); trainScript != "" {
|
|
cfg.Entrypoint = trainScript
|
|
}
|
|
if redisAddr := os.Getenv("FETCH_ML_TUI_REDIS_ADDR"); redisAddr != "" {
|
|
cfg.RedisAddr = redisAddr
|
|
}
|
|
if redisPassword := os.Getenv("FETCH_ML_TUI_REDIS_PASSWORD"); redisPassword != "" {
|
|
cfg.RedisPassword = redisPassword
|
|
}
|
|
if redisDB := os.Getenv("FETCH_ML_TUI_REDIS_DB"); redisDB != "" {
|
|
if db, err := parseInt(redisDB); err == nil {
|
|
cfg.RedisDB = db
|
|
}
|
|
}
|
|
if knownHosts := os.Getenv("FETCH_ML_TUI_KNOWN_HOSTS"); knownHosts != "" {
|
|
cfg.KnownHosts = knownHosts
|
|
}
|
|
|
|
return &cfg, nil
|
|
}
|
|
|
|
// Validate implements utils.Validator interface
|
|
func (c *Config) Validate() error {
|
|
if c.Port != 0 {
|
|
if err := utils.ValidatePort(c.Port); err != nil {
|
|
return fmt.Errorf("invalid SSH port: %w", err)
|
|
}
|
|
}
|
|
|
|
// Set default mode if not specified
|
|
if c.Mode == "" {
|
|
if os.Getenv("FETCH_ML_TUI_MODE") != "" {
|
|
c.Mode = os.Getenv("FETCH_ML_TUI_MODE")
|
|
} else {
|
|
c.Mode = "dev" // Default to dev mode
|
|
}
|
|
}
|
|
|
|
// Set mode-appropriate default paths using project-relative paths
|
|
if c.BasePath == "" {
|
|
c.BasePath = utils.ModeBasedBasePath(c.Mode)
|
|
}
|
|
|
|
if c.BasePath != "" {
|
|
// Convert relative paths to absolute
|
|
c.BasePath = utils.ExpandPath(c.BasePath)
|
|
if !filepath.IsAbs(c.BasePath) {
|
|
c.BasePath = filepath.Join(utils.DefaultBasePath, c.BasePath)
|
|
}
|
|
}
|
|
|
|
if c.RedisAddr != "" {
|
|
if err := utils.ValidateRedisAddr(c.RedisAddr); err != nil {
|
|
return fmt.Errorf("invalid Redis configuration: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsLocalMode returns true if the TUI should operate in local-only mode
|
|
func (c *Config) IsLocalMode() bool {
|
|
if c.ForceLocal {
|
|
return true
|
|
}
|
|
// Check if tracking_uri indicates local mode (sqlite:// prefix)
|
|
if c.DBPath != "" {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// GetDBPath returns the SQLite database path for local mode
|
|
func (c *Config) GetDBPath() string {
|
|
if c.DBPath != "" {
|
|
return c.DBPath
|
|
}
|
|
// Default location: ~/.fetchml/experiments.db
|
|
if home, err := os.UserHomeDir(); err == nil {
|
|
return filepath.Join(home, ".fetchml", "experiments.db")
|
|
}
|
|
return "fetchml.db"
|
|
}
|
|
func (c *Config) PendingPath() string { return filepath.Join(c.BasePath, "pending") }
|
|
|
|
// RunningPath returns the path for running experiments
|
|
func (c *Config) RunningPath() string { return filepath.Join(c.BasePath, "running") }
|
|
|
|
// FinishedPath returns the path for finished experiments
|
|
func (c *Config) FinishedPath() string { return filepath.Join(c.BasePath, "finished") }
|
|
|
|
// FailedPath returns the path for failed experiments
|
|
func (c *Config) FailedPath() string { return filepath.Join(c.BasePath, "failed") }
|
|
|
|
// parseInt parses s as a base-10 integer, allowing an optional sign and
// surrounding spaces.
//
// A bare fmt.Sscanf(s, "%d", ...) stops at the first non-digit and never
// reports leftover input, so "22abc", "0x10" and "22 33" would be accepted
// as 22, 0 and 22. The extra %s verb captures any trailing token. If one is
// found, the input is rejected. When only the integer is present, %s fails
// (EOF, or a newline in the input), and that error is deliberately ignored.
// Text after a newline is therefore not inspected.
func parseInt(s string) (int, error) {
	var (
		result   int
		trailing string
	)
	n, err := fmt.Sscanf(s, "%d%s", &result, &trailing)
	switch {
	case n == 0:
		// The integer itself failed to parse (empty, non-numeric, overflow).
		return 0, err
	case n > 1:
		return 0, fmt.Errorf("invalid integer %q: unexpected trailing input %q", s, trailing)
	default:
		return result, nil
	}
}
|