diff --git a/CHANGELOG.md b/CHANGELOG.md index 24347b3..7da9c4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,15 +8,20 @@ - Native: modularize C++ libraries with clean layering (common, queue_index, dataset_hash) ### Added -- Tests: add e2e coverage for `wss://` upgrade through a TLS-terminating reverse proxy. -- Worker: verify `dataset_specs[].checksum` when provided and fail tasks on mismatch. -- Worker: verify `snapshot_id` using `snapshot_sha256` and fail-closed (supports local `data_dir/snapshots/` and optional S3-backed `snapshot_store`). -- Worker: stage verified `snapshot_id` into each task workspace and expose it to training code via `FETCH_ML_SNAPSHOT_DIR`. -- Worker: provenance enforcement is trustworthiness-by-default (fail-closed) with `provenance_best_effort` opt-in. -- CLI/API: add `ml validate` to fetch a validation report (commit/task) for provenance + integrity checks. -- Worker: persist discovered artifacts into `run_manifest.json` (`artifacts.discovery_time`, `artifacts.files[]`, `artifacts.total_size_bytes`) at task completion. -- Worker: best-effort environment prewarm can build a warmed Podman image keyed by `deps_manifest_sha256` and reuse it for subsequent tasks. -- Worker: export env prewarm hit/miss/built counters and total build time via the worker Prometheus metrics endpoint. -- API/Worker: `ml prune` also triggers best-effort garbage collection of warmed env images. -- API: add `/health/ok` (when health checks are enabled) and wrap HTTP handlers with Prometheus HTTP request metrics when Prometheus is enabled. -- CLI/API: add `ml logs` command to fetch and follow job logs from running or completed experiments via WebSocket. +- API/WebSocket: add dataset handlers (list, register, info, search) with DB integration +- API/WebSocket: add metrics persistence to `handleLogMetric` with `websocket_metrics` table +- Storage: add `db_metrics.go` with `RecordMetric`, `GetMetrics`, `GetMetricSummary` methods +- Tests: add payload parsing tests for WebSocket handlers + +### Changed +- Config: replace `panic()` with error returns in `smart_defaults.go` for better error handling +- Tests: move WebSocket handler tests to `tests/unit/api/ws/` + +### Fixed +- Storage: remove duplicate `db_datasets.go`, consolidate with `db_experiments.go` + +### Deprecated +- Config: `ToTUIConfig()` now returns `(*Config, error)` instead of `*Config` + +### Removed +- Storage: deleted `internal/storage/db_datasets.go` (duplicate implementation) diff --git a/cmd/tui/internal/config/cli_config.go b/cmd/tui/internal/config/cli_config.go index 402321f..8fa3a90 100644 --- a/cmd/tui/internal/config/cli_config.go +++ b/cmd/tui/internal/config/cli_config.go @@ -138,10 +138,21 @@ func parseTOML(data []byte, config *CLIConfig) { } // ToTUIConfig converts CLI config to TUI config structure -func (c *CLIConfig) ToTUIConfig() *Config { +func (c *CLIConfig) ToTUIConfig() (*Config, error) { // Get smart defaults for current environment smart := utils.GetSmartDefaults() + // Set defaults for TUI-specific fields using smart defaults + redisAddr, err := smart.RedisAddr() + if err != nil { + return nil, fmt.Errorf("failed to get default redis address: %w", err) + } + + knownHosts, err := smart.KnownHostsPath() + if err != nil { + return nil, fmt.Errorf("failed to get default known hosts path: %w", err) + } + tuiConfig := &Config{ Host: c.WorkerHost, User: c.WorkerUser, @@ -149,35 +160,16 @@ func (c *CLIConfig) ToTUIConfig() *Config { BasePath: c.WorkerBase, // Set defaults for TUI-specific fields using smart defaults - RedisAddr: smart.RedisAddr(), + RedisAddr: redisAddr, RedisDB: 0, PodmanImage: "ml-worker:latest", ContainerWorkspace: utils.DefaultContainerWorkspace, ContainerResults: utils.DefaultContainerResults, GPUDevices: nil, + KnownHosts: knownHosts, } - // Set up auth config with CLI API key - tuiConfig.Auth = auth.Config{ - Enabled: true, - APIKeys: map[auth.Username]auth.APIKeyEntry{ - "cli_user": { - Hash: auth.APIKeyHash(c.APIKey), - Admin: true, - Roles: []string{"user", "admin"}, - Permissions: map[string]bool{ - "read": true, - "write": true, - "delete": true, - }, - }, - }, - } - - // Set known hosts path - tuiConfig.KnownHosts = smart.KnownHostsPath() - - return tuiConfig + return tuiConfig, nil } // Validate validates the CLI config diff --git a/cmd/tui/internal/config/config.go b/cmd/tui/internal/config/config.go index 032a9b6..bd4bf5d 100644 --- a/cmd/tui/internal/config/config.go +++ b/cmd/tui/internal/config/config.go @@ -54,20 +54,36 @@ func LoadConfig(path string) (*Config, error) { cfg.Port = utils.DefaultSSHPort } if cfg.Host == "" { - cfg.Host = smart.Host() + host, err := smart.Host() + if err != nil { + return nil, fmt.Errorf("failed to get default host: %w", err) + } + cfg.Host = host } if cfg.BasePath == "" { - cfg.BasePath = smart.BasePath() + basePath, err := smart.BasePath() + if err != nil { + return nil, fmt.Errorf("failed to get default base path: %w", err) + } + cfg.BasePath = basePath } // wrapper_script is deprecated - using secure_runner.py directly via Podman if cfg.TrainScript == "" { cfg.TrainScript = utils.DefaultTrainScript } if cfg.RedisAddr == "" { - cfg.RedisAddr = smart.RedisAddr() + redisAddr, err := smart.RedisAddr() + if err != nil { + return nil, fmt.Errorf("failed to get default redis address: %w", err) + } + cfg.RedisAddr = redisAddr } if cfg.KnownHosts == "" { - cfg.KnownHosts = smart.KnownHostsPath() + knownHosts, err := smart.KnownHostsPath() + if err != nil { + return nil, fmt.Errorf("failed to get default known hosts path: %w", err) + } + cfg.KnownHosts = knownHosts } // Apply environment variable overrides with FETCH_ML_TUI_ prefix diff --git a/cmd/tui/main.go b/cmd/tui/main.go index 4d910d2..1bf901a 100644 --- a/cmd/tui/main.go +++ b/cmd/tui/main.go @@ -95,7 +95,10 @@ func main() { os.Exit(1) } - cfg = cliConfig.ToTUIConfig() + cfg, err = cliConfig.ToTUIConfig() + if err != nil { + log.Fatalf("Failed to convert CLI config to TUI config: %v", err) + } log.Printf("Loaded TOML configuration from %s", cliConfPath) // Validate authentication configuration diff --git a/internal/config/smart_defaults.go b/internal/config/smart_defaults.go index 78838b7..5c218df 100644 --- a/internal/config/smart_defaults.go +++ b/internal/config/smart_defaults.go @@ -1,6 +1,7 @@ package config import ( + "errors" "fmt" "os" "path/filepath" @@ -8,6 +9,9 @@ import ( "strings" ) +// ErrUnknownProfile is returned when an unrecognized environment profile is encountered +var ErrUnknownProfile = errors.New("unknown environment profile") + // EnvironmentProfile represents the deployment environment type EnvironmentProfile int @@ -59,140 +63,140 @@ func GetSmartDefaults() *SmartDefaults { } // Host returns the appropriate default host -func (s *SmartDefaults) Host() string { +func (s *SmartDefaults) Host() (string, error) { switch s.Profile { case ProfileContainer, ProfileCI: - return "host.docker.internal" // Docker Desktop/Colima + return "host.docker.internal", nil // Docker Desktop/Colima case ProfileProduction: - return "0.0.0.0" + return "0.0.0.0", nil case ProfileLocal: - return "localhost" + return "localhost", nil default: - panic(fmt.Sprintf("unknown profile: %v", s.Profile)) + return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile) } } // BasePath returns the appropriate default base path -func (s *SmartDefaults) BasePath() string { +func (s *SmartDefaults) BasePath() (string, error) { switch s.Profile { case ProfileContainer, ProfileCI: - return "/workspace/ml-experiments" + return "/workspace/ml-experiments", nil case ProfileProduction: - return "/var/lib/fetch_ml/experiments" + return "/var/lib/fetch_ml/experiments", nil case ProfileLocal: if home, err := os.UserHomeDir(); err == nil { - return filepath.Join(home, "ml-experiments") + return filepath.Join(home, "ml-experiments"), nil } - return "./ml-experiments" + return "./ml-experiments", nil default: - panic(fmt.Sprintf("unknown profile: %v", s.Profile)) + return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile) } } // DataDir returns the appropriate default data directory -func (s *SmartDefaults) DataDir() string { +func (s *SmartDefaults) DataDir() (string, error) { switch s.Profile { case ProfileContainer, ProfileCI: - return "/workspace/data" + return "/workspace/data", nil case ProfileProduction: - return "/var/lib/fetch_ml/data" + return "/var/lib/fetch_ml/data", nil case ProfileLocal: if home, err := os.UserHomeDir(); err == nil { - return filepath.Join(home, "ml-data") + return filepath.Join(home, "ml-data"), nil } - return "./data" + return "./data", nil default: - panic(fmt.Sprintf("unknown profile: %v", s.Profile)) + return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile) } } // RedisAddr returns the appropriate default Redis address -func (s *SmartDefaults) RedisAddr() string { +func (s *SmartDefaults) RedisAddr() (string, error) { switch s.Profile { case ProfileContainer, ProfileCI: - return "redis:6379" // Service name in containers + return "redis:6379", nil // Service name in containers case ProfileProduction: - return "redis:6379" + return "redis:6379", nil case ProfileLocal: - return "localhost:6379" + return "localhost:6379", nil default: - panic(fmt.Sprintf("unknown profile: %v", s.Profile)) + return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile) } } // SSHKeyPath returns the appropriate default SSH key path -func (s *SmartDefaults) SSHKeyPath() string { +func (s *SmartDefaults) SSHKeyPath() (string, error) { switch s.Profile { case ProfileContainer, ProfileCI: - return "/workspace/.ssh/id_rsa" + return "/workspace/.ssh/id_rsa", nil case ProfileProduction: - return "/etc/fetch_ml/ssh/id_rsa" + return "/etc/fetch_ml/ssh/id_rsa", nil case ProfileLocal: if home, err := os.UserHomeDir(); err == nil { - return filepath.Join(home, ".ssh", "id_rsa") + return filepath.Join(home, ".ssh", "id_rsa"), nil } - return "~/.ssh/id_rsa" + return "~/.ssh/id_rsa", nil default: - panic(fmt.Sprintf("unknown profile: %v", s.Profile)) + return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile) } } // KnownHostsPath returns the appropriate default known_hosts path -func (s *SmartDefaults) KnownHostsPath() string { +func (s *SmartDefaults) KnownHostsPath() (string, error) { switch s.Profile { case ProfileContainer, ProfileCI: - return "/workspace/.ssh/known_hosts" + return "/workspace/.ssh/known_hosts", nil case ProfileProduction: - return "/etc/fetch_ml/ssh/known_hosts" + return "/etc/fetch_ml/ssh/known_hosts", nil case ProfileLocal: if home, err := os.UserHomeDir(); err == nil { - return filepath.Join(home, ".ssh", "known_hosts") + return filepath.Join(home, ".ssh", "known_hosts"), nil } - return "~/.ssh/known_hosts" + return "~/.ssh/known_hosts", nil default: - panic(fmt.Sprintf("unknown profile: %v", s.Profile)) + return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile) } } // LogLevel returns the appropriate default log level -func (s *SmartDefaults) LogLevel() string { +func (s *SmartDefaults) LogLevel() (string, error) { switch s.Profile { case ProfileCI: - return "debug" // More verbose for CI debugging + return "debug", nil // More verbose for CI debugging case ProfileProduction: - return "info" + return "info", nil case ProfileLocal, ProfileContainer: - return "info" + return "info", nil default: - panic(fmt.Sprintf("unknown profile: %v", s.Profile)) + return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile) } } // MaxWorkers returns the appropriate default worker count -func (s *SmartDefaults) MaxWorkers() int { +func (s *SmartDefaults) MaxWorkers() (int, error) { switch s.Profile { case ProfileCI: - return 1 // Conservative for CI + return 1, nil // Conservative for CI case ProfileProduction: - return runtime.NumCPU() // Scale with CPU cores + return runtime.NumCPU(), nil // Scale with CPU cores case ProfileLocal, ProfileContainer: - return 2 // Reasonable default for local dev + return 2, nil // Reasonable default for local dev default: - panic(fmt.Sprintf("unknown profile: %v", s.Profile)) + return 0, fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile) } } // PollInterval returns the appropriate default poll interval in seconds -func (s *SmartDefaults) PollInterval() int { +func (s *SmartDefaults) PollInterval() (int, error) { switch s.Profile { case ProfileCI: - return 1 // Fast polling for quick tests + return 1, nil // Fast polling for quick tests case ProfileProduction: - return 10 // Conservative for production + return 10, nil // Conservative for production case ProfileLocal, ProfileContainer: - return 5 // Balanced default + return 5, nil // Balanced default default: - panic(fmt.Sprintf("unknown profile: %v", s.Profile)) + return 0, fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile) } } diff --git a/internal/worker/config.go b/internal/worker/config.go index f002fa8..f296bc5 100644 --- a/internal/worker/config.go +++ b/internal/worker/config.go @@ -151,16 +151,32 @@ func LoadConfig(path string) (*Config, error) { cfg.Port = config.DefaultSSHPort } if cfg.Host == "" { - cfg.Host = smart.Host() + host, err := smart.Host() + if err != nil { + return nil, fmt.Errorf("failed to get default host: %w", err) + } + cfg.Host = host } if cfg.BasePath == "" { - cfg.BasePath = smart.BasePath() + basePath, err := smart.BasePath() + if err != nil { + return nil, fmt.Errorf("failed to get default base path: %w", err) + } + cfg.BasePath = basePath } if cfg.RedisAddr == "" { - cfg.RedisAddr = smart.RedisAddr() + redisAddr, err := smart.RedisAddr() + if err != nil { + return nil, fmt.Errorf("failed to get default redis address: %w", err) + } + cfg.RedisAddr = redisAddr } if cfg.KnownHosts == "" { - cfg.KnownHosts = smart.KnownHostsPath() + knownHosts, err := smart.KnownHostsPath() + if err != nil { + return nil, fmt.Errorf("failed to get default known hosts path: %w", err) + } + cfg.KnownHosts = knownHosts } if cfg.WorkerID == "" { cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8]) @@ -169,10 +185,19 @@ func LoadConfig(path string) (*Config, error) { if cfg.MaxWorkers > 0 { cfg.Resources.MaxWorkers = cfg.MaxWorkers } else { - cfg.MaxWorkers = cfg.Resources.MaxWorkers + maxWorkers, err := smart.MaxWorkers() + if err != nil { + return nil, fmt.Errorf("failed to get default max workers: %w", err) + } + cfg.MaxWorkers = maxWorkers + cfg.Resources.MaxWorkers = maxWorkers } if cfg.PollInterval == 0 { - cfg.PollInterval = smart.PollInterval() + pollInterval, err := smart.PollInterval() + if err != nil { + return nil, fmt.Errorf("failed to get default poll interval: %w", err) + } + cfg.PollInterval = pollInterval } if cfg.DataManagerPath == "" { cfg.DataManagerPath = "./data_manager" @@ -181,7 +206,11 @@ func LoadConfig(path string) (*Config, error) { if cfg.Host == "" || !cfg.AutoFetchData { cfg.DataDir = config.DefaultLocalDataDir } else { - cfg.DataDir = smart.DataDir() + dataDir, err := smart.DataDir() + if err != nil { + return nil, fmt.Errorf("failed to get default data directory: %w", err) + } + cfg.DataDir = dataDir } } if cfg.SnapshotStore.Timeout == 0 {