fetch_ml/cmd/tui/internal/config/config.go
Jeremie Fraeys 6028779239
feat: update CLI, TUI, and security documentation
- Add safety checks to Zig build
- Add TUI with job management and narrative views
- Add WebSocket support and export services
- Add smart configuration defaults
- Update API routes with security headers
- Update SECURITY.md with comprehensive policy
- Add Makefile security scanning targets
2026-02-19 15:35:05 -05:00

186 lines
5.2 KiB
Go

package config
import (
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/BurntSushi/toml"
	"github.com/jfraeys/fetch_ml/internal/auth"
	utils "github.com/jfraeys/fetch_ml/internal/config"
)
// Config holds the TUI configuration. It is decoded from TOML by LoadConfig
// (which also applies FETCH_ML_TUI_* environment overrides and defaults) and
// normalized in place by Validate.
type Config struct {
	// Remote connection settings. Port is treated as the SSH port: LoadConfig
	// defaults it to utils.DefaultSSHPort and Validate checks it with
	// utils.ValidatePort.
	Host   string `toml:"host"`
	User   string `toml:"user"`
	SSHKey string `toml:"ssh_key"`
	Port   int    `toml:"port"`

	// BasePath is the root under which the pending, running, finished and
	// failed experiment directories are resolved (see PendingPath etc.).
	BasePath string `toml:"base_path"`
	Mode     string `toml:"mode"` // "dev" or "prod"; Validate fills an empty value from FETCH_ML_TUI_MODE, else "dev"

	// WrapperScript is deprecated: scripts are run through secure_runner.py
	// directly via Podman, and LoadConfig applies no default to it.
	// TrainScript defaults to utils.DefaultTrainScript.
	WrapperScript string `toml:"wrapper_script"`
	TrainScript   string `toml:"train_script"`

	// Redis connection settings; a non-empty RedisAddr is checked by Validate.
	RedisAddr     string `toml:"redis_addr"`
	RedisPassword string `toml:"redis_password"`
	RedisDB       int    `toml:"redis_db"`

	// KnownHosts is a filesystem path (presumably an SSH known_hosts file —
	// confirm); it defaults to the smart-defaults known-hosts path.
	KnownHosts string `toml:"known_hosts"`
	ServerURL  string `toml:"server_url"` // WebSocket server URL (e.g., ws://localhost:8080)

	// Authentication
	Auth auth.Config `toml:"auth"`

	// Podman settings (not read in this file; presumably consumed when
	// launching job containers — confirm against callers).
	PodmanImage        string   `toml:"podman_image"`
	ContainerWorkspace string   `toml:"container_workspace"`
	ContainerResults   string   `toml:"container_results"`
	GPUDevices         []string `toml:"gpu_devices"`
}
// LoadConfig loads the TUI configuration from the TOML file at path.
//
// Values are resolved in three layers, highest precedence first:
//  1. non-empty FETCH_ML_TUI_* environment variables,
//  2. values present in the TOML file,
//  3. defaults (utils.DefaultSSHPort, utils.DefaultTrainScript and the
//     environment-aware values from utils.GetSmartDefaults).
//
// Environment overrides are applied before defaults. A field supplied via
// the environment therefore never triggers a default lookup that could fail
// and abort loading.
//
// It returns an error if the file cannot be read or decoded, if a numeric
// environment override (FETCH_ML_TUI_PORT, FETCH_ML_TUI_REDIS_DB) is not a
// valid integer, or if a needed default cannot be determined.
func LoadConfig(path string) (*Config, error) {
	//nolint:gosec // G304: Config path is user-controlled but trusted
	data, err := os.ReadFile(path)
	if err != nil {
		// Returned unwrapped: the *PathError already names the file, and
		// callers may rely on os.IsNotExist, which does not see through %w.
		return nil, err
	}

	var cfg Config
	if _, err := toml.Decode(string(data), &cfg); err != nil {
		return nil, fmt.Errorf("decoding config %q: %w", path, err)
	}

	if err := applyEnvOverrides(&cfg); err != nil {
		return nil, err
	}
	if err := applySmartDefaults(&cfg); err != nil {
		return nil, err
	}
	return &cfg, nil
}

// applyEnvOverrides copies every non-empty FETCH_ML_TUI_* variable onto cfg.
// Numeric variables that do not parse as integers are reported as errors
// instead of being silently ignored.
func applyEnvOverrides(cfg *Config) error {
	stringVars := []struct {
		name string
		dst  *string
	}{
		{"FETCH_ML_TUI_HOST", &cfg.Host},
		{"FETCH_ML_TUI_USER", &cfg.User},
		{"FETCH_ML_TUI_SSH_KEY", &cfg.SSHKey},
		{"FETCH_ML_TUI_BASE_PATH", &cfg.BasePath},
		{"FETCH_ML_TUI_TRAIN_SCRIPT", &cfg.TrainScript},
		{"FETCH_ML_TUI_REDIS_ADDR", &cfg.RedisAddr},
		{"FETCH_ML_TUI_REDIS_PASSWORD", &cfg.RedisPassword},
		{"FETCH_ML_TUI_KNOWN_HOSTS", &cfg.KnownHosts},
	}
	for _, v := range stringVars {
		if val := os.Getenv(v.name); val != "" {
			*v.dst = val
		}
	}

	intVars := []struct {
		name string
		dst  *int
	}{
		{"FETCH_ML_TUI_PORT", &cfg.Port},
		{"FETCH_ML_TUI_REDIS_DB", &cfg.RedisDB},
	}
	for _, v := range intVars {
		val := os.Getenv(v.name)
		if val == "" {
			continue
		}
		n, err := parseInt(val)
		if err != nil {
			return fmt.Errorf("invalid %s %q: %w", v.name, val, err)
		}
		*v.dst = n
	}
	return nil
}

// applySmartDefaults fills fields that are still empty after the file and
// the environment have been applied.
func applySmartDefaults(cfg *Config) error {
	smart := utils.GetSmartDefaults()

	if cfg.Port == 0 {
		cfg.Port = utils.DefaultSSHPort
	}
	if cfg.Host == "" {
		host, err := smart.Host()
		if err != nil {
			return fmt.Errorf("failed to get default host: %w", err)
		}
		cfg.Host = host
	}
	if cfg.BasePath == "" {
		basePath, err := smart.BasePath()
		if err != nil {
			return fmt.Errorf("failed to get default base path: %w", err)
		}
		cfg.BasePath = basePath
	}
	// wrapper_script is deprecated - using secure_runner.py directly via Podman,
	// so WrapperScript deliberately gets no default.
	if cfg.TrainScript == "" {
		cfg.TrainScript = utils.DefaultTrainScript
	}
	if cfg.RedisAddr == "" {
		redisAddr, err := smart.RedisAddr()
		if err != nil {
			return fmt.Errorf("failed to get default redis address: %w", err)
		}
		cfg.RedisAddr = redisAddr
	}
	if cfg.KnownHosts == "" {
		knownHosts, err := smart.KnownHostsPath()
		if err != nil {
			return fmt.Errorf("failed to get default known hosts path: %w", err)
		}
		cfg.KnownHosts = knownHosts
	}
	return nil
}
// Validate implements the utils.Validator interface.
//
// Besides checking values, Validate normalizes c in place, in this order:
//   - a non-zero Port must pass utils.ValidatePort;
//   - an empty Mode is taken from FETCH_ML_TUI_MODE, or "dev" when that is unset;
//   - an empty BasePath is derived from the mode via utils.ModeBasedBasePath;
//   - a non-empty BasePath is passed through utils.ExpandPath and, if it is
//     still relative, joined onto utils.DefaultBasePath;
//   - a non-empty RedisAddr must pass utils.ValidateRedisAddr.
//
// Because the port is checked first, a port error leaves c unmodified.
func (c *Config) Validate() error {
	if c.Port != 0 {
		if err := utils.ValidatePort(c.Port); err != nil {
			return fmt.Errorf("invalid SSH port: %w", err)
		}
	}

	if c.Mode == "" {
		mode := os.Getenv("FETCH_ML_TUI_MODE")
		if mode == "" {
			mode = "dev" // nothing configured anywhere: fall back to dev mode
		}
		c.Mode = mode
	}

	if c.BasePath == "" {
		c.BasePath = utils.ModeBasedBasePath(c.Mode)
	}
	if base := c.BasePath; base != "" {
		// Anything still relative after expansion is anchored under the default base.
		base = utils.ExpandPath(base)
		if !filepath.IsAbs(base) {
			base = filepath.Join(utils.DefaultBasePath, base)
		}
		c.BasePath = base
	}

	if c.RedisAddr == "" {
		return nil
	}
	if err := utils.ValidateRedisAddr(c.RedisAddr); err != nil {
		return fmt.Errorf("invalid Redis configuration: %w", err)
	}
	return nil
}
// PendingPath returns the directory, under BasePath, that holds experiments
// in the pending state.
func (c *Config) PendingPath() string {
	return filepath.Join(c.BasePath, "pending")
}

// RunningPath returns the directory, under BasePath, that holds experiments
// currently in the running state.
func (c *Config) RunningPath() string {
	return filepath.Join(c.BasePath, "running")
}

// FinishedPath returns the directory, under BasePath, that holds experiments
// that have finished.
func (c *Config) FinishedPath() string {
	return filepath.Join(c.BasePath, "finished")
}

// FailedPath returns the directory, under BasePath, that holds experiments
// that have failed.
func (c *Config) FailedPath() string {
	return filepath.Join(c.BasePath, "failed")
}
// parseInt parses s as a base-10 integer, ignoring leading and trailing
// whitespace.
//
// The whole trimmed string must be a valid integer; otherwise the
// *strconv.NumError from strconv.Atoi is returned. (The previous
// fmt.Sscanf(s, "%d", …) implementation stopped at the first non-digit,
// so malformed input such as "8080abc" was silently accepted as 8080.)
func parseInt(s string) (int, error) {
	return strconv.Atoi(strings.TrimSpace(s))
}