fetch_ml/cmd/tui/internal/config/cli_config.go
Jeremie Fraeys 3775bc3ee0
refactor: replace panic with error returns and update maintenance
- Replace 9 panic() calls in smart_defaults.go with error returns
- Add ErrUnknownProfile error type for better error handling
- Update all callers (worker/config.go, tui/config.go, tui/cli_config.go, tui/main.go)
- Update CHANGELOG.md with recent WebSocket handler improvements
- Add metrics persistence, dataset handlers, and test organization notes
- Config validation passes (make configlint)
- All tests pass (go test ./tests/unit/api/ws)
2026-02-18 14:44:21 -05:00

363 lines
9.4 KiB
Go

// Package config provides TUI configuration management.
package config
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/auth"
utils "github.com/jfraeys/fetch_ml/internal/config"
)
// CLIConfig mirrors the TOML file shared with the ml CLI.
//
// The toml struct tags match the keys handled by parseTOML. CurrentUser is
// never read from the file (tag "-"); it is filled in by
// AuthenticateWithServer.
type CLIConfig struct {
	WorkerHost string `toml:"worker_host"`
	WorkerUser string `toml:"worker_user"`
	WorkerBase string `toml:"worker_base"`
	WorkerPort int    `toml:"worker_port"`
	APIKey     string `toml:"api_key"`

	// CurrentUser holds the authenticated identity; it stays nil until
	// AuthenticateWithServer succeeds.
	CurrentUser *UserContext `toml:"-"`
}

// UserContext describes the identity and authorization data attached to a
// CLIConfig after authentication.
type UserContext struct {
	Name        string          `json:"name"`
	Admin       bool            `json:"admin"`
	Roles       []string        `json:"roles"`
	Permissions map[string]bool `json:"permissions"`
}
// LoadCLIConfig loads the CLI's TOML configuration from the provided path.
//
// If configPath is empty, ~/.ml/config.toml is used; otherwise the path is
// expanded with utils.ExpandPath and made absolute when possible. The
// resolved path is returned alongside the config, and also with every error
// except a failure to determine the home directory.
//
// Environment variables with the FETCH_ML_CLI_ prefix (HOST, USER, BASE,
// PORT, API_KEY) override values from the file. Overrides are applied BEFORE
// validation, so that:
//   - a value supplied only through the environment satisfies a required
//     field, and
//   - overridden values are validated (and worker_base expanded) exactly like
//     values read from the file.
//
// An unparsable FETCH_ML_CLI_PORT is ignored with a logged warning.
func LoadCLIConfig(configPath string) (*CLIConfig, string, error) {
	if configPath == "" {
		home, err := os.UserHomeDir()
		if err != nil {
			return nil, "", fmt.Errorf("failed to get home directory: %w", err)
		}
		configPath = filepath.Join(home, ".ml", "config.toml")
	} else {
		configPath = utils.ExpandPath(configPath)
		if !filepath.IsAbs(configPath) {
			// Best effort: keep the relative path if it cannot be resolved.
			if abs, err := filepath.Abs(configPath); err == nil {
				configPath = abs
			}
		}
	}

	// Check if TOML config exists
	if _, err := os.Stat(configPath); err != nil {
		if os.IsNotExist(err) {
			return nil, configPath, fmt.Errorf("CLI config not found at %s (run 'ml init' first)", configPath)
		}
		return nil, configPath, fmt.Errorf("cannot access CLI config %s: %w", configPath, err)
	}

	// Permission problems are reported but do not prevent loading.
	if err := auth.CheckConfigFilePermissions(configPath); err != nil {
		log.Printf("Warning: %v", err)
	}

	//nolint:gosec // G304: Config path is user-controlled but trusted
	data, err := os.ReadFile(configPath)
	if err != nil {
		return nil, configPath, fmt.Errorf("failed to read CLI config: %w", err)
	}

	config := &CLIConfig{}
	parseTOML(data, config)

	// Apply environment variable overrides with FETCH_ML_CLI_ prefix before
	// validating, so the effective configuration is what gets checked.
	if host := os.Getenv("FETCH_ML_CLI_HOST"); host != "" {
		config.WorkerHost = host
	}
	if user := os.Getenv("FETCH_ML_CLI_USER"); user != "" {
		config.WorkerUser = user
	}
	if base := os.Getenv("FETCH_ML_CLI_BASE"); base != "" {
		config.WorkerBase = base
	}
	if port := os.Getenv("FETCH_ML_CLI_PORT"); port != "" {
		p, perr := parseInt(port)
		if perr != nil {
			log.Printf("Warning: ignoring invalid FETCH_ML_CLI_PORT %q: %v", port, perr)
		} else {
			config.WorkerPort = p
		}
	}
	if apiKey := os.Getenv("FETCH_ML_CLI_API_KEY"); apiKey != "" {
		config.APIKey = apiKey
	}

	if err := config.Validate(); err != nil {
		return nil, configPath, err
	}
	return config, configPath, nil
}
// parseTOML is a minimal, lenient parser for the flat `key = value` format of
// the CLI config. It is not a general TOML parser: tables, arrays and
// multi-line strings are not supported.
//
// Behavior:
//   - Blank lines, lines starting with '#', lines without '=', and unknown
//     keys are skipped.
//   - Values may be double-quoted, single-quoted, or bare.
//   - A trailing inline comment ("key = value # note") is ignored.
//   - A worker_port that fails to parse leaves config.WorkerPort unchanged.
func parseTOML(data []byte, config *CLIConfig) {
	for _, line := range strings.Split(string(data), "\n") {
		line = strings.TrimSpace(line) // also drops a trailing '\r' (CRLF files)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		key, rawValue, found := strings.Cut(line, "=")
		if !found {
			continue
		}
		key = strings.TrimSpace(key)
		value := tomlScalarValue(rawValue)

		switch key {
		case "worker_host":
			config.WorkerHost = value
		case "worker_user":
			config.WorkerUser = value
		case "worker_base":
			config.WorkerBase = value
		case "worker_port":
			if p, err := parseInt(value); err == nil {
				config.WorkerPort = p
			}
		case "api_key":
			config.APIKey = value
		}
	}
}

// tomlScalarValue extracts the scalar from the right-hand side of a
// `key = value` line.
//
// A value that starts with a quote (" or ') is returned without its quotes,
// and anything after the closing quote, such as an inline comment, is
// dropped. Inside double quotes a backslash keeps the next character from
// closing the string, but escape sequences are NOT decoded. An unterminated
// quoted value is returned verbatim. For bare values, an inline '#' comment
// is stripped.
func tomlScalarValue(raw string) string {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return raw
	}

	if quote := raw[0]; quote == '"' || quote == '\'' {
		for i := 1; i < len(raw); i++ {
			switch raw[i] {
			case '\\':
				if quote == '"' {
					i++ // skip the escaped character
				}
			case quote:
				return raw[1:i]
			}
		}
		// Unterminated string: keep the original text rather than guessing.
		return raw
	}

	if before, _, found := strings.Cut(raw, "#"); found {
		return strings.TrimSpace(before)
	}
	return raw
}
// ToTUIConfig builds a TUI Config from the CLI connection settings.
//
// Host, User, Port and BasePath are copied from the worker_* fields. The
// TUI-only fields are filled in as follows:
//   - Redis address and known_hosts path come from utils.GetSmartDefaults,
//     which is computed for the current environment.
//   - The remaining fields are fixed: Podman image "ml-worker:latest",
//     utils.DefaultContainerWorkspace / DefaultContainerResults, Redis DB 0,
//     and no GPU devices.
//
// An error is returned when either smart default cannot be determined.
func (c *CLIConfig) ToTUIConfig() (*Config, error) {
	defaults := utils.GetSmartDefaults()

	redis, redisErr := defaults.RedisAddr()
	if redisErr != nil {
		return nil, fmt.Errorf("failed to get default redis address: %w", redisErr)
	}

	hostsFile, hostsErr := defaults.KnownHostsPath()
	if hostsErr != nil {
		return nil, fmt.Errorf("failed to get default known hosts path: %w", hostsErr)
	}

	out := &Config{
		// Connection settings come straight from the CLI file.
		Host:     c.WorkerHost,
		User:     c.WorkerUser,
		Port:     c.WorkerPort,
		BasePath: c.WorkerBase,

		// TUI-only settings use defaults.
		KnownHosts:         hostsFile,
		RedisAddr:          redis,
		RedisDB:            0,
		PodmanImage:        "ml-worker:latest",
		ContainerWorkspace: utils.DefaultContainerWorkspace,
		ContainerResults:   utils.DefaultContainerResults,
		GPUDevices:         nil,
	}
	return out, nil
}
// Validate checks that every required CLI setting is present and well formed.
// All problems are collected into a single "validation failed: ..." error,
// joined with "; ".
//
// Requirements:
//   - worker_host and worker_user must be non-blank.
//   - worker_base must be set, and must be absolute after utils.ExpandPath.
//   - worker_port must be non-zero and pass utils.ValidatePort.
//   - api_key must be at least 16 bytes long (measured with len).
//
// Side effect: a non-empty c.WorkerBase is replaced by its expanded form,
// even when validation fails.
func (c *CLIConfig) Validate() error {
	var problems []string
	report := func(msg string) { problems = append(problems, msg) }

	switch {
	case c.WorkerHost == "":
		report("worker_host is required")
	case strings.TrimSpace(c.WorkerHost) == "":
		report("worker_host cannot be empty or whitespace")
	}

	switch {
	case c.WorkerUser == "":
		report("worker_user is required")
	case strings.TrimSpace(c.WorkerUser) == "":
		report("worker_user cannot be empty or whitespace")
	}

	if c.WorkerBase == "" {
		report("worker_base is required")
	} else {
		// Note: the expanded path is written back to the config.
		c.WorkerBase = utils.ExpandPath(c.WorkerBase)
		if !filepath.IsAbs(c.WorkerBase) {
			report("worker_base must be an absolute path")
		}
	}

	if c.WorkerPort == 0 {
		report("worker_port is required")
	} else if err := utils.ValidatePort(c.WorkerPort); err != nil {
		report(fmt.Sprintf("invalid worker_port: %v", err))
	}

	switch {
	case c.APIKey == "":
		report("api_key is required")
	case len(c.APIKey) < 16:
		report("api_key must be at least 16 characters")
	}

	if len(problems) == 0 {
		return nil
	}
	return fmt.Errorf("validation failed: %s", strings.Join(problems, "; "))
}
// AuthenticateWithServer validates c.APIKey and, on success, stores the
// resulting identity in c.CurrentUser.
//
// It returns an error if no API key is configured, or if ValidateAPIKey
// rejects the key.
//
// NOTE(review): despite its name, this does not contact a server. It builds
// a throwaway auth.Config whose only entry ("temp", non-admin) is the hash of
// c.APIKey itself, then validates the key against that entry. It therefore
// presumably cannot reject any non-empty key, and the identity it records is
// that local "temp" entry. Confirm against auth.Config.ValidateAPIKey and
// replace with a real server-side check.
func (c *CLIConfig) AuthenticateWithServer() error {
	key := c.APIKey
	if key == "" {
		return fmt.Errorf("no API key configured")
	}

	hashed := auth.APIKeyHash(auth.HashAPIKey(key))
	validator := &auth.Config{
		Enabled: true,
		APIKeys: map[auth.Username]auth.APIKeyEntry{
			"temp": {Hash: hashed, Admin: false},
		},
	}

	identity, err := validator.ValidateAPIKey(key)
	if err != nil {
		return fmt.Errorf("API key validation failed: %w", err)
	}

	c.CurrentUser = &UserContext{
		Name:        identity.Name,
		Admin:       identity.Admin,
		Roles:       identity.Roles,
		Permissions: identity.Permissions,
	}
	return nil
}
// CheckPermission reports whether the authenticated user holds permission.
//
// It returns false when no user is authenticated. Admins are granted every
// permission. Other users need either the exact permission or the "*"
// wildcard set to true in their Permissions map.
func (c *CLIConfig) CheckPermission(permission string) bool {
	user := c.CurrentUser
	if user == nil {
		return false
	}
	// Reading from a nil Permissions map is safe and yields false.
	return user.Admin || user.Permissions[permission] || user.Permissions["*"]
}
// CanViewJob reports whether the authenticated user may view a job owned by
// jobUserID. Admins may view any job; other users may view only jobs whose
// owner equals their Name. Without an authenticated user, nothing is
// viewable.
func (c *CLIConfig) CanViewJob(jobUserID string) bool {
	switch user := c.CurrentUser; {
	case user == nil:
		return false
	case user.Admin:
		return true
	default:
		return user.Name == jobUserID
	}
}
// CanModifyJob reports whether the authenticated user may modify a job owned
// by jobUserID.
//
// The user must first pass CheckPermission("jobs:update"). After that,
// admins may modify any job and other users only jobs whose owner equals
// their Name. Without an authenticated user, nothing is modifiable.
//
// NOTE(review): CheckPermission grants admins every permission, so admins
// always pass the jobs:update gate.
func (c *CLIConfig) CanModifyJob(jobUserID string) bool {
	user := c.CurrentUser
	if user == nil || !c.CheckPermission("jobs:update") {
		return false
	}
	return user.Admin || user.Name == jobUserID
}
// Exists reports whether a CLI configuration file exists at configPath.
//
// Path handling:
//   - An empty configPath means ~/.ml/config.toml.
//   - A path of "~" or one starting with "~/" is expanded to the current
//     user's home directory (os.UserHomeDir). Without this, such a path
//     would be looked up relative to the working directory.
//
// It returns false when the home directory is needed but cannot be
// determined. Any os.Stat error other than "does not exist" (e.g. permission
// denied) is treated as the file existing.
//
// NOTE(review): LoadCLIConfig resolves paths with utils.ExpandPath. If that
// helper does more than tilde expansion (e.g. environment variables),
// consider calling it here too so both functions agree.
func Exists(configPath string) bool {
	switch {
	case configPath == "":
		home, err := os.UserHomeDir()
		if err != nil {
			return false
		}
		configPath = filepath.Join(home, ".ml", "config.toml")
	case configPath == "~" || strings.HasPrefix(configPath, "~/"):
		home, err := os.UserHomeDir()
		if err != nil {
			return false
		}
		configPath = filepath.Join(home, configPath[1:])
	}

	_, err := os.Stat(configPath)
	return !os.IsNotExist(err)
}
// GenerateDefaultConfig writes a template CLI configuration to configPath,
// creating the parent directory (mode 0750) if needed. An existing file at
// configPath is overwritten.
//
// The file is forced to mode 0600 because it holds an api_key entry.
// os.WriteFile only applies the mode when it creates the file, so an
// existing file with looser permissions would otherwise keep them.
//
// Each setting's description sits on its own comment line, so every
// assignment is a plain `key = value` line.
func GenerateDefaultConfig(configPath string) error {
	// Create directory if it doesn't exist
	if err := os.MkdirAll(filepath.Dir(configPath), 0750); err != nil {
		return fmt.Errorf("failed to create config directory: %w", err)
	}

	// Generate default configuration
	defaultContent := `# Fetch ML CLI Configuration
# This file contains connection settings for the ML platform

# Worker connection settings
# Hostname or IP of the worker
worker_host = "localhost"
# SSH username for the worker
worker_user = "your_username"
# Base directory for ML jobs on worker
worker_base = "~/ml_jobs"
# SSH port (default: 22)
worker_port = 22

# Authentication
# Your API key (get from admin)
api_key = "your_api_key_here"

# Environment variable overrides:
# FETCH_ML_CLI_HOST, FETCH_ML_CLI_USER, FETCH_ML_CLI_BASE,
# FETCH_ML_CLI_PORT, FETCH_ML_CLI_API_KEY
`

	// Write configuration file
	if err := os.WriteFile(configPath, []byte(defaultContent), 0600); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}

	// Tighten permissions on a pre-existing file too (see doc comment).
	if err := os.Chmod(configPath, 0600); err != nil {
		return fmt.Errorf("failed to set config file permissions: %w", err)
	}

	// Verify permissions; problems are reported but not fatal.
	if err := auth.CheckConfigFilePermissions(configPath); err != nil {
		log.Printf("Warning: %v", err)
	}
	return nil
}