refactor: replace panic with error returns and update changelog

- Replace 9 panic() calls in smart_defaults.go with error returns
- Add ErrUnknownProfile error type for better error handling
- Update all callers (worker/config.go, tui/config.go, tui/cli_config.go, tui/main.go)
- Update CHANGELOG.md with recent WebSocket handler improvements
- Add metrics persistence, dataset handlers, and test organization notes
- Config validation passes (make configlint)
- WebSocket unit tests pass (go test ./tests/unit/api/ws)
This commit is contained in:
Jeremie Fraeys 2026-02-18 14:44:21 -05:00
parent 10e6416e11
commit 3775bc3ee0
No known key found for this signature in database
6 changed files with 145 additions and 96 deletions

View file

@ -8,15 +8,20 @@
- Native: modularize C++ libraries with clean layering (common, queue_index, dataset_hash)
### Added
- Tests: add e2e coverage for `wss://` upgrade through a TLS-terminating reverse proxy.
- Worker: verify `dataset_specs[].checksum` when provided and fail tasks on mismatch.
- Worker: verify `snapshot_id` using `snapshot_sha256` and fail-closed (supports local `data_dir/snapshots/<snapshot_id>` and optional S3-backed `snapshot_store`).
- Worker: stage verified `snapshot_id` into each task workspace and expose it to training code via `FETCH_ML_SNAPSHOT_DIR`.
- Worker: provenance enforcement is strict by default (fail-closed); set `provenance_best_effort` to opt in to best-effort behavior.
- CLI/API: add `ml validate` to fetch a validation report (commit/task) for provenance + integrity checks.
- Worker: persist discovered artifacts into `run_manifest.json` (`artifacts.discovery_time`, `artifacts.files[]`, `artifacts.total_size_bytes`) at task completion.
- Worker: best-effort environment prewarm can build a warmed Podman image keyed by `deps_manifest_sha256` and reuse it for subsequent tasks.
- Worker: export env prewarm hit/miss/built counters and total build time via the worker Prometheus metrics endpoint.
- API/Worker: `ml prune` also triggers best-effort garbage collection of warmed env images.
- API: add `/health/ok` (when health checks are enabled) and wrap HTTP handlers with Prometheus HTTP request metrics when Prometheus is enabled.
- CLI/API: add `ml logs` command to fetch and follow job logs from running or completed experiments via WebSocket.
- API/WebSocket: add dataset handlers (list, register, info, search) with DB integration.
- API/WebSocket: add metrics persistence to `handleLogMetric`, backed by the `websocket_metrics` table.
- Storage: add `db_metrics.go` with `RecordMetric`, `GetMetrics`, and `GetMetricSummary` methods.
- Tests: add payload parsing tests for WebSocket handlers.
### Changed
- Config: replace `panic()` with error returns in `smart_defaults.go` so callers can handle an unknown environment profile instead of crashing.
- Tests: move WebSocket handler tests to `tests/unit/api/ws/`.
### Fixed
- Storage: remove duplicate `db_datasets.go`, consolidate with `db_experiments.go`
### Deprecated
- Config: **Breaking:** `ToTUIConfig()` now returns `(*Config, error)` instead of `*Config`; callers must handle the returned error.
### Removed
- Storage: deleted `internal/storage/db_datasets.go` (duplicate implementation)

View file

@ -138,10 +138,21 @@ func parseTOML(data []byte, config *CLIConfig) {
}
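// ToTUIConfig converts the CLI config into a TUI Config, filling TUI-specific
// fields from smart defaults. It returns an error if the default Redis address
// or known_hosts path cannot be resolved.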
func (c *CLIConfig) ToTUIConfig() *Config {
func (c *CLIConfig) ToTUIConfig() (*Config, error) {
// Get smart defaults for current environment
smart := utils.GetSmartDefaults()
// Set defaults for TUI-specific fields using smart defaults
redisAddr, err := smart.RedisAddr()
if err != nil {
return nil, fmt.Errorf("failed to get default redis address: %w", err)
}
knownHosts, err := smart.KnownHostsPath()
if err != nil {
return nil, fmt.Errorf("failed to get default known hosts path: %w", err)
}
tuiConfig := &Config{
Host: c.WorkerHost,
User: c.WorkerUser,
@ -149,35 +160,16 @@ func (c *CLIConfig) ToTUIConfig() *Config {
BasePath: c.WorkerBase,
// Set defaults for TUI-specific fields using smart defaults
RedisAddr: smart.RedisAddr(),
RedisAddr: redisAddr,
RedisDB: 0,
PodmanImage: "ml-worker:latest",
ContainerWorkspace: utils.DefaultContainerWorkspace,
ContainerResults: utils.DefaultContainerResults,
GPUDevices: nil,
KnownHosts: knownHosts,
}
// Set up auth config with CLI API key
tuiConfig.Auth = auth.Config{
Enabled: true,
APIKeys: map[auth.Username]auth.APIKeyEntry{
"cli_user": {
Hash: auth.APIKeyHash(c.APIKey),
Admin: true,
Roles: []string{"user", "admin"},
Permissions: map[string]bool{
"read": true,
"write": true,
"delete": true,
},
},
},
}
// Set known hosts path
tuiConfig.KnownHosts = smart.KnownHostsPath()
return tuiConfig
return tuiConfig, nil
}
// Validate validates the CLI config

View file

@ -54,20 +54,36 @@ func LoadConfig(path string) (*Config, error) {
cfg.Port = utils.DefaultSSHPort
}
if cfg.Host == "" {
cfg.Host = smart.Host()
host, err := smart.Host()
if err != nil {
return nil, fmt.Errorf("failed to get default host: %w", err)
}
cfg.Host = host
}
if cfg.BasePath == "" {
cfg.BasePath = smart.BasePath()
basePath, err := smart.BasePath()
if err != nil {
return nil, fmt.Errorf("failed to get default base path: %w", err)
}
cfg.BasePath = basePath
}
// wrapper_script is deprecated - using secure_runner.py directly via Podman
if cfg.TrainScript == "" {
cfg.TrainScript = utils.DefaultTrainScript
}
if cfg.RedisAddr == "" {
cfg.RedisAddr = smart.RedisAddr()
redisAddr, err := smart.RedisAddr()
if err != nil {
return nil, fmt.Errorf("failed to get default redis address: %w", err)
}
cfg.RedisAddr = redisAddr
}
if cfg.KnownHosts == "" {
cfg.KnownHosts = smart.KnownHostsPath()
knownHosts, err := smart.KnownHostsPath()
if err != nil {
return nil, fmt.Errorf("failed to get default known hosts path: %w", err)
}
cfg.KnownHosts = knownHosts
}
// Apply environment variable overrides with FETCH_ML_TUI_ prefix

View file

@ -95,7 +95,10 @@ func main() {
os.Exit(1)
}
cfg = cliConfig.ToTUIConfig()
cfg, err = cliConfig.ToTUIConfig()
if err != nil {
log.Fatalf("Failed to convert CLI config to TUI config: %v", err)
}
log.Printf("Loaded TOML configuration from %s", cliConfPath)
// Validate authentication configuration

View file

@ -1,6 +1,7 @@
package config
import (
"errors"
"fmt"
"os"
"path/filepath"
@ -8,6 +9,9 @@ import (
"strings"
)
// ErrUnknownProfile is returned (wrapped) when an unrecognized environment
// profile is encountered. Test for it with errors.Is.
var ErrUnknownProfile = errors.New("unknown environment profile")
// EnvironmentProfile represents the deployment environment
type EnvironmentProfile int
@ -59,140 +63,140 @@ func GetSmartDefaults() *SmartDefaults {
}
// Host returns the default host for the active profile. It returns an error
// wrapping ErrUnknownProfile if the profile is not recognized.
func (s *SmartDefaults) Host() string {
func (s *SmartDefaults) Host() (string, error) {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "host.docker.internal" // Docker Desktop/Colima
return "host.docker.internal", nil // Docker Desktop/Colima
case ProfileProduction:
return "0.0.0.0"
return "0.0.0.0", nil
case ProfileLocal:
return "localhost"
return "localhost", nil
default:
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
}
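// BasePath returns the default experiments base path for the active profile.
// For ProfileLocal it prefers <home>/ml-experiments and falls back to
// ./ml-experiments when the home directory cannot be determined. It returns an
// error wrapping ErrUnknownProfile if the profile is not recognized.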
func (s *SmartDefaults) BasePath() string {
func (s *SmartDefaults) BasePath() (string, error) {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "/workspace/ml-experiments"
return "/workspace/ml-experiments", nil
case ProfileProduction:
return "/var/lib/fetch_ml/experiments"
return "/var/lib/fetch_ml/experiments", nil
case ProfileLocal:
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, "ml-experiments")
return filepath.Join(home, "ml-experiments"), nil
}
return "./ml-experiments"
return "./ml-experiments", nil
default:
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
}
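// DataDir returns the default data directory for the active profile. For
// ProfileLocal it prefers <home>/ml-data and falls back to ./data when the
// home directory cannot be determined. It returns an error wrapping
// ErrUnknownProfile if the profile is not recognized.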
func (s *SmartDefaults) DataDir() string {
func (s *SmartDefaults) DataDir() (string, error) {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "/workspace/data"
return "/workspace/data", nil
case ProfileProduction:
return "/var/lib/fetch_ml/data"
return "/var/lib/fetch_ml/data", nil
case ProfileLocal:
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, "ml-data")
return filepath.Join(home, "ml-data"), nil
}
return "./data"
return "./data", nil
default:
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
}
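// RedisAddr returns the default Redis address for the active profile:
// "redis:6379" for container, CI, and production profiles, and
// "localhost:6379" for the local profile. It returns an error wrapping
// ErrUnknownProfile if the profile is not recognized.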
func (s *SmartDefaults) RedisAddr() string {
func (s *SmartDefaults) RedisAddr() (string, error) {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "redis:6379" // Service name in containers
return "redis:6379", nil // Service name in containers
case ProfileProduction:
return "redis:6379"
return "redis:6379", nil
case ProfileLocal:
return "localhost:6379"
return "localhost:6379", nil
default:
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
}
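// SSHKeyPath returns the default SSH private key path for the active profile.
// For ProfileLocal it prefers <home>/.ssh/id_rsa; if the home directory cannot
// be determined it returns the literal "~/.ssh/id_rsa" (the tilde is not
// expanded here). It returns an error wrapping ErrUnknownProfile if the
// profile is not recognized.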
func (s *SmartDefaults) SSHKeyPath() string {
func (s *SmartDefaults) SSHKeyPath() (string, error) {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "/workspace/.ssh/id_rsa"
return "/workspace/.ssh/id_rsa", nil
case ProfileProduction:
return "/etc/fetch_ml/ssh/id_rsa"
return "/etc/fetch_ml/ssh/id_rsa", nil
case ProfileLocal:
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, ".ssh", "id_rsa")
return filepath.Join(home, ".ssh", "id_rsa"), nil
}
return "~/.ssh/id_rsa"
return "~/.ssh/id_rsa", nil
default:
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
}
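// KnownHostsPath returns the default known_hosts path for the active profile.
// For ProfileLocal it prefers <home>/.ssh/known_hosts; if the home directory
// cannot be determined it returns the literal "~/.ssh/known_hosts" (the tilde
// is not expanded here). It returns an error wrapping ErrUnknownProfile if the
// profile is not recognized.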
func (s *SmartDefaults) KnownHostsPath() string {
func (s *SmartDefaults) KnownHostsPath() (string, error) {
switch s.Profile {
case ProfileContainer, ProfileCI:
return "/workspace/.ssh/known_hosts"
return "/workspace/.ssh/known_hosts", nil
case ProfileProduction:
return "/etc/fetch_ml/ssh/known_hosts"
return "/etc/fetch_ml/ssh/known_hosts", nil
case ProfileLocal:
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, ".ssh", "known_hosts")
return filepath.Join(home, ".ssh", "known_hosts"), nil
}
return "~/.ssh/known_hosts"
return "~/.ssh/known_hosts", nil
default:
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
}
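// LogLevel returns the default log level for the active profile: "debug" for
// CI and "info" for production, local, and container profiles. It returns an
// error wrapping ErrUnknownProfile if the profile is not recognized.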
func (s *SmartDefaults) LogLevel() string {
func (s *SmartDefaults) LogLevel() (string, error) {
switch s.Profile {
case ProfileCI:
return "debug" // More verbose for CI debugging
return "debug", nil // More verbose for CI debugging
case ProfileProduction:
return "info"
return "info", nil
case ProfileLocal, ProfileContainer:
return "info"
return "info", nil
default:
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
}
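// MaxWorkers returns the default worker count for the active profile: 1 for
// CI, runtime.NumCPU() for production, and 2 for local and container
// profiles. It returns an error wrapping ErrUnknownProfile if the profile is
// not recognized.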
func (s *SmartDefaults) MaxWorkers() int {
func (s *SmartDefaults) MaxWorkers() (int, error) {
switch s.Profile {
case ProfileCI:
return 1 // Conservative for CI
return 1, nil // Conservative for CI
case ProfileProduction:
return runtime.NumCPU() // Scale with CPU cores
return runtime.NumCPU(), nil // Scale with CPU cores
case ProfileLocal, ProfileContainer:
return 2 // Reasonable default for local dev
return 2, nil // Reasonable default for local dev
default:
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
return 0, fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
}
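// PollInterval returns the default poll interval in seconds for the active
// profile: 1 for CI, 10 for production, and 5 for local and container
// profiles. It returns an error wrapping ErrUnknownProfile if the profile is
// not recognized.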
func (s *SmartDefaults) PollInterval() int {
func (s *SmartDefaults) PollInterval() (int, error) {
switch s.Profile {
case ProfileCI:
return 1 // Fast polling for quick tests
return 1, nil // Fast polling for quick tests
case ProfileProduction:
return 10 // Conservative for production
return 10, nil // Conservative for production
case ProfileLocal, ProfileContainer:
return 5 // Balanced default
return 5, nil // Balanced default
default:
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
return 0, fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
}
}

View file

@ -151,16 +151,32 @@ func LoadConfig(path string) (*Config, error) {
cfg.Port = config.DefaultSSHPort
}
if cfg.Host == "" {
cfg.Host = smart.Host()
host, err := smart.Host()
if err != nil {
return nil, fmt.Errorf("failed to get default host: %w", err)
}
cfg.Host = host
}
if cfg.BasePath == "" {
cfg.BasePath = smart.BasePath()
basePath, err := smart.BasePath()
if err != nil {
return nil, fmt.Errorf("failed to get default base path: %w", err)
}
cfg.BasePath = basePath
}
if cfg.RedisAddr == "" {
cfg.RedisAddr = smart.RedisAddr()
redisAddr, err := smart.RedisAddr()
if err != nil {
return nil, fmt.Errorf("failed to get default redis address: %w", err)
}
cfg.RedisAddr = redisAddr
}
if cfg.KnownHosts == "" {
cfg.KnownHosts = smart.KnownHostsPath()
knownHosts, err := smart.KnownHostsPath()
if err != nil {
return nil, fmt.Errorf("failed to get default known hosts path: %w", err)
}
cfg.KnownHosts = knownHosts
}
if cfg.WorkerID == "" {
cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
@ -169,10 +185,19 @@ func LoadConfig(path string) (*Config, error) {
if cfg.MaxWorkers > 0 {
cfg.Resources.MaxWorkers = cfg.MaxWorkers
} else {
cfg.MaxWorkers = cfg.Resources.MaxWorkers
maxWorkers, err := smart.MaxWorkers()
if err != nil {
return nil, fmt.Errorf("failed to get default max workers: %w", err)
}
cfg.MaxWorkers = maxWorkers
cfg.Resources.MaxWorkers = maxWorkers
}
if cfg.PollInterval == 0 {
cfg.PollInterval = smart.PollInterval()
pollInterval, err := smart.PollInterval()
if err != nil {
return nil, fmt.Errorf("failed to get default poll interval: %w", err)
}
cfg.PollInterval = pollInterval
}
if cfg.DataManagerPath == "" {
cfg.DataManagerPath = "./data_manager"
@ -181,7 +206,11 @@ func LoadConfig(path string) (*Config, error) {
if cfg.Host == "" || !cfg.AutoFetchData {
cfg.DataDir = config.DefaultLocalDataDir
} else {
cfg.DataDir = smart.DataDir()
dataDir, err := smart.DataDir()
if err != nil {
return nil, fmt.Errorf("failed to get default data directory: %w", err)
}
cfg.DataDir = dataDir
}
}
if cfg.SnapshotStore.Timeout == 0 {