fetch_ml/internal/config/smart_defaults.go
Jeremie Fraeys 3775bc3ee0
refactor: replace panic with error returns and update maintenance
- Replace 9 panic() calls in smart_defaults.go with error returns
- Add ErrUnknownProfile error type for better error handling
- Update all callers (worker/config.go, tui/config.go, tui/cli_config.go, tui/main.go)
- Update CHANGELOG.md with recent WebSocket handler improvements
- Add metrics persistence, dataset handlers, and test organization notes
- Config validation passes (make configlint)
- All tests pass (go test ./tests/unit/api/ws)
2026-02-18 14:44:21 -05:00

246 lines
6.6 KiB
Go

package config
import (
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
)
// ErrUnknownProfile is the sentinel error returned when an unrecognized
// environment profile is encountered. The SmartDefaults accessors in this file
// wrap it with %w together with the offending profile value, so callers should
// match it with errors.Is rather than a direct comparison.
var ErrUnknownProfile = errors.New("unknown environment profile")
// EnvironmentProfile identifies the kind of deployment environment the
// process is running in. It is used to select configuration defaults.
type EnvironmentProfile int

// Known environment profiles. The numeric values start at zero and follow
// declaration order, so the zero value of EnvironmentProfile is ProfileLocal.
const (
	// ProfileLocal is the local development profile.
	ProfileLocal = EnvironmentProfile(iota)
	// ProfileContainer is the container profile.
	ProfileContainer
	// ProfileCI is the continuous-integration profile.
	ProfileCI
	// ProfileProduction is the production profile.
	ProfileProduction
)
// DetectEnvironment determines the current environment profile.
//
// Checks run in a fixed priority order and the first match wins:
//
//  1. CI: any of CI, GITHUB_ACTIONS or GITLAB_CI is set to a non-empty value.
//  2. Container: /.dockerenv can be stat'ed, or KUBERNETES_SERVICE_HOST or
//     CONTAINER is set to a non-empty value.
//  3. Production: FETCH_ML_ENV or ENV equals exactly "prod".
//
// When nothing matches, ProfileLocal is returned. Because the container check
// precedes the production check, a containerized process reports
// ProfileContainer even when FETCH_ML_ENV=prod is set.
func DetectEnvironment() EnvironmentProfile {
	envSet := func(name string) bool { return os.Getenv(name) != "" }
	dockerEnvPresent := func() bool {
		_, err := os.Stat("/.dockerenv")
		return err == nil
	}

	// Case expressions are evaluated top-to-bottom, left-to-right, stopping
	// at the first true one, which preserves the documented priority order.
	switch {
	case envSet("CI"), envSet("GITHUB_ACTIONS"), envSet("GITLAB_CI"):
		return ProfileCI
	case dockerEnvPresent(), envSet("KUBERNETES_SERVICE_HOST"), envSet("CONTAINER"):
		return ProfileContainer
	case os.Getenv("FETCH_ML_ENV") == "prod", os.Getenv("ENV") == "prod":
		// Production detection (customizable)
		return ProfileProduction
	default:
		// Default to local development
		return ProfileLocal
	}
}
// SmartDefaults provides environment-aware default values. Its accessor
// methods switch on Profile to choose the value they return.
type SmartDefaults struct {
// Profile is the environment profile that selects which defaults apply.
Profile EnvironmentProfile
}
// GetSmartDefaults returns a new SmartDefaults whose Profile is the result of
// calling DetectEnvironment at the time of the call.
func GetSmartDefaults() *SmartDefaults {
	defaults := new(SmartDefaults)
	defaults.Profile = DetectEnvironment()
	return defaults
}
// Host returns the default host for the active profile.
//
// It returns an error wrapping ErrUnknownProfile when Profile is not one of
// the declared profiles.
func (s *SmartDefaults) Host() (string, error) {
	profile := s.Profile
	if profile == ProfileLocal {
		return "localhost", nil
	}
	if profile == ProfileProduction {
		return "0.0.0.0", nil
	}
	if profile == ProfileContainer || profile == ProfileCI {
		// Docker Desktop/Colima
		return "host.docker.internal", nil
	}
	return "", fmt.Errorf("%w: %v", ErrUnknownProfile, profile)
}
// BasePath returns the default experiments base path for the active profile.
//
// For ProfileLocal the path is "ml-experiments" under the user's home
// directory, falling back to the relative "./ml-experiments" when the home
// directory cannot be determined. It returns an error wrapping
// ErrUnknownProfile when Profile is not one of the declared profiles.
func (s *SmartDefaults) BasePath() (string, error) {
	switch s.Profile {
	case ProfileLocal:
		home, err := os.UserHomeDir()
		if err != nil {
			// No usable home directory: fall back to a relative path.
			return "./ml-experiments", nil
		}
		return filepath.Join(home, "ml-experiments"), nil
	case ProfileProduction:
		return "/var/lib/fetch_ml/experiments", nil
	case ProfileContainer, ProfileCI:
		return "/workspace/ml-experiments", nil
	}
	return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
// DataDir returns the default data directory for the active profile.
//
// For ProfileLocal the path is "ml-data" under the user's home directory,
// falling back to the relative "./data" when the home directory cannot be
// determined. It returns an error wrapping ErrUnknownProfile when Profile is
// not one of the declared profiles.
func (s *SmartDefaults) DataDir() (string, error) {
	profile := s.Profile
	if profile == ProfileContainer || profile == ProfileCI {
		return "/workspace/data", nil
	}
	if profile == ProfileProduction {
		return "/var/lib/fetch_ml/data", nil
	}
	if profile != ProfileLocal {
		return "", fmt.Errorf("%w: %v", ErrUnknownProfile, profile)
	}

	home, err := os.UserHomeDir()
	if err != nil {
		// No usable home directory: fall back to a relative path.
		return "./data", nil
	}
	return filepath.Join(home, "ml-data"), nil
}
// RedisAddr returns the default Redis address for the active profile.
//
// It returns an error wrapping ErrUnknownProfile when Profile is not one of
// the declared profiles.
func (s *SmartDefaults) RedisAddr() (string, error) {
	switch s.Profile {
	case ProfileLocal:
		return "localhost:6379", nil
	case ProfileContainer, ProfileCI, ProfileProduction:
		// Container, CI and production all use the "redis" service name.
		return "redis:6379", nil
	}
	return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
// SSHKeyPath returns the default SSH private key path for the active profile.
//
// For ProfileLocal the path is ".ssh/id_rsa" under the user's home directory;
// when the home directory cannot be determined the literal, unexpanded
// "~/.ssh/id_rsa" is returned instead. It returns an error wrapping
// ErrUnknownProfile when Profile is not one of the declared profiles.
func (s *SmartDefaults) SSHKeyPath() (string, error) {
	var keyPath string
	switch s.Profile {
	case ProfileContainer, ProfileCI:
		keyPath = "/workspace/.ssh/id_rsa"
	case ProfileProduction:
		keyPath = "/etc/fetch_ml/ssh/id_rsa"
	case ProfileLocal:
		// Start from the tilde form and replace it when the home directory
		// resolves.
		keyPath = "~/.ssh/id_rsa"
		if home, err := os.UserHomeDir(); err == nil {
			keyPath = filepath.Join(home, ".ssh", "id_rsa")
		}
	default:
		return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
	}
	return keyPath, nil
}
// KnownHostsPath returns the default known_hosts path for the active profile.
//
// For ProfileLocal the path is ".ssh/known_hosts" under the user's home
// directory; when the home directory cannot be determined the literal,
// unexpanded "~/.ssh/known_hosts" is returned instead. It returns an error
// wrapping ErrUnknownProfile when Profile is not one of the declared profiles.
func (s *SmartDefaults) KnownHostsPath() (string, error) {
	profile := s.Profile
	if profile == ProfileProduction {
		return "/etc/fetch_ml/ssh/known_hosts", nil
	}
	if profile == ProfileContainer || profile == ProfileCI {
		return "/workspace/.ssh/known_hosts", nil
	}
	if profile != ProfileLocal {
		return "", fmt.Errorf("%w: %v", ErrUnknownProfile, profile)
	}

	home, err := os.UserHomeDir()
	if err != nil {
		// Home directory unavailable: hand back the tilde form unchanged.
		return "~/.ssh/known_hosts", nil
	}
	return filepath.Join(home, ".ssh", "known_hosts"), nil
}
// LogLevel returns the default log level for the active profile: "debug" for
// CI and "info" for every other declared profile.
//
// It returns an error wrapping ErrUnknownProfile when Profile is not one of
// the declared profiles.
func (s *SmartDefaults) LogLevel() (string, error) {
	switch s.Profile {
	case ProfileCI:
		// More verbose for CI debugging
		return "debug", nil
	case ProfileLocal, ProfileContainer, ProfileProduction:
		return "info", nil
	}
	return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
// MaxWorkers returns the default worker count for the active profile: 1 for
// CI, runtime.NumCPU() for production, and 2 for local and container.
//
// It returns 0 and an error wrapping ErrUnknownProfile when Profile is not
// one of the declared profiles.
func (s *SmartDefaults) MaxWorkers() (int, error) {
	profile := s.Profile
	if profile == ProfileProduction {
		// Scale with CPU cores
		return runtime.NumCPU(), nil
	}
	if profile == ProfileCI {
		// Conservative for CI
		return 1, nil
	}
	if profile == ProfileLocal || profile == ProfileContainer {
		// Reasonable default for local dev
		return 2, nil
	}
	return 0, fmt.Errorf("%w: %v", ErrUnknownProfile, profile)
}
// PollInterval returns the default poll interval, in seconds, for the active
// profile: 1 for CI, 5 for local and container, and 10 for production.
//
// It returns 0 and an error wrapping ErrUnknownProfile when Profile is not
// one of the declared profiles.
func (s *SmartDefaults) PollInterval() (int, error) {
	var seconds int
	switch s.Profile {
	case ProfileCI:
		seconds = 1 // Fast polling for quick tests
	case ProfileLocal, ProfileContainer:
		seconds = 5 // Balanced default
	case ProfileProduction:
		seconds = 10 // Conservative for production
	default:
		return 0, fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
	}
	return seconds, nil
}
// IsInContainer reports whether the profile is treated as a container
// environment, which is the case for both ProfileContainer and ProfileCI.
func (s *SmartDefaults) IsInContainer() bool {
	switch s.Profile {
	case ProfileContainer, ProfileCI:
		return true
	default:
		return false
	}
}
// IsProduction reports whether the profile is ProfileProduction.
func (s *SmartDefaults) IsProduction() bool {
	switch s.Profile {
	case ProfileProduction:
		return true
	default:
		return false
	}
}
// IsCI reports whether the profile is ProfileCI.
func (s *SmartDefaults) IsCI() bool {
	switch s.Profile {
	case ProfileCI:
		return true
	default:
		return false
	}
}
// ExpandPath expands a leading "~" and any environment variable references
// in path and returns the result.
//
// A bare "~" or a "~/" prefix is replaced with the current user's home
// directory; other tilde forms (for example "~user/dir") are left untouched.
// If the home directory cannot be determined, the tilde is kept as-is.
// Environment references ($VAR or ${VAR}) are then expanded with
// os.ExpandEnv, which substitutes the empty string for undefined variables.
func (s *SmartDefaults) ExpandPath(path string) string {
	if path == "~" || strings.HasPrefix(path, "~/") {
		if home, err := os.UserHomeDir(); err == nil {
			// Drop the "~" and its optional separator; for a bare "~" the
			// remainder is empty and Join yields the home directory itself.
			path = filepath.Join(home, strings.TrimPrefix(path[1:], "/"))
		}
	}
	// Expand environment variables
	return os.ExpandEnv(path)
}
// GetEnvironmentDescription returns a human-readable label for the active
// profile, or "Unknown Environment" when Profile is not one of the declared
// profiles.
func (s *SmartDefaults) GetEnvironmentDescription() string {
	label := "Unknown Environment"
	switch s.Profile {
	case ProfileProduction:
		label = "Production Environment"
	case ProfileCI:
		label = "CI/CD Environment"
	case ProfileContainer:
		label = "Container Environment"
	case ProfileLocal:
		label = "Local Development"
	}
	return label
}