refactor: replace panic with error returns and update maintenance
- Replace 9 panic() calls in smart_defaults.go with error returns
- Add ErrUnknownProfile sentinel error for better error handling
- Update all callers (worker/config.go, tui/config.go, tui/cli_config.go, tui/main.go)
- Update CHANGELOG.md with recent WebSocket handler improvements
- Add metrics persistence, dataset handlers, and test organization notes
- Config validation passes (make configlint)
- All tests pass (go test ./tests/unit/api/ws)
This commit is contained in:
parent
10e6416e11
commit
3775bc3ee0
6 changed files with 145 additions and 96 deletions
29
CHANGELOG.md
29
CHANGELOG.md
|
|
@ -8,15 +8,20 @@
|
|||
- Native: modularize C++ libraries with clean layering (common, queue_index, dataset_hash)
|
||||
|
||||
### Added
|
||||
- Tests: add e2e coverage for `wss://` upgrade through a TLS-terminating reverse proxy.
|
||||
- Worker: verify `dataset_specs[].checksum` when provided and fail tasks on mismatch.
|
||||
- Worker: verify `snapshot_id` using `snapshot_sha256` and fail-closed (supports local `data_dir/snapshots/<snapshot_id>` and optional S3-backed `snapshot_store`).
|
||||
- Worker: stage verified `snapshot_id` into each task workspace and expose it to training code via `FETCH_ML_SNAPSHOT_DIR`.
|
||||
- Worker: provenance enforcement is trustworthiness-by-default (fail-closed) with `provenance_best_effort` opt-in.
|
||||
- CLI/API: add `ml validate` to fetch a validation report (commit/task) for provenance + integrity checks.
|
||||
- Worker: persist discovered artifacts into `run_manifest.json` (`artifacts.discovery_time`, `artifacts.files[]`, `artifacts.total_size_bytes`) at task completion.
|
||||
- Worker: best-effort environment prewarm can build a warmed Podman image keyed by `deps_manifest_sha256` and reuse it for subsequent tasks.
|
||||
- Worker: export env prewarm hit/miss/built counters and total build time via the worker Prometheus metrics endpoint.
|
||||
- API/Worker: `ml prune` also triggers best-effort garbage collection of warmed env images.
|
||||
- API: add `/health/ok` (when health checks are enabled) and wrap HTTP handlers with Prometheus HTTP request metrics when Prometheus is enabled.
|
||||
- CLI/API: add `ml logs` command to fetch and follow job logs from running or completed experiments via WebSocket.
|
||||
- API/WebSocket: add dataset handlers (list, register, info, search) with DB integration
|
||||
- API/WebSocket: add metrics persistence to `handleLogMetric` with `websocket_metrics` table
|
||||
- Storage: add `db_metrics.go` with `RecordMetric`, `GetMetrics`, `GetMetricSummary` methods
|
||||
- Tests: add payload parsing tests for WebSocket handlers
|
||||
|
||||
### Changed
|
||||
- Config: replace `panic()` with error returns in `smart_defaults.go`; an unrecognized environment profile now yields an error wrapping `ErrUnknownProfile` instead of crashing the process
|
||||
- Tests: move WebSocket handler tests to `tests/unit/api/ws/`
|
||||
|
||||
### Fixed
|
||||
- Storage: remove duplicate `db_datasets.go`, consolidate with `db_experiments.go`
|
||||
|
||||
### Deprecated
|
||||
- Config: `ToTUIConfig()` now returns `(*Config, error)` instead of `*Config`; this is a signature change, so existing callers must be updated to handle the returned error
|
||||
|
||||
### Removed
|
||||
- Storage: deleted `internal/storage/db_datasets.go` (duplicate implementation)
|
||||
|
|
|
|||
|
|
@ -138,10 +138,21 @@ func parseTOML(data []byte, config *CLIConfig) {
|
|||
}
|
||||
|
||||
// ToTUIConfig converts CLI config to TUI config structure
|
||||
func (c *CLIConfig) ToTUIConfig() *Config {
|
||||
func (c *CLIConfig) ToTUIConfig() (*Config, error) {
|
||||
// Get smart defaults for current environment
|
||||
smart := utils.GetSmartDefaults()
|
||||
|
||||
// Set defaults for TUI-specific fields using smart defaults
|
||||
redisAddr, err := smart.RedisAddr()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default redis address: %w", err)
|
||||
}
|
||||
|
||||
knownHosts, err := smart.KnownHostsPath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default known hosts path: %w", err)
|
||||
}
|
||||
|
||||
tuiConfig := &Config{
|
||||
Host: c.WorkerHost,
|
||||
User: c.WorkerUser,
|
||||
|
|
@ -149,35 +160,16 @@ func (c *CLIConfig) ToTUIConfig() *Config {
|
|||
BasePath: c.WorkerBase,
|
||||
|
||||
// Set defaults for TUI-specific fields using smart defaults
|
||||
RedisAddr: smart.RedisAddr(),
|
||||
RedisAddr: redisAddr,
|
||||
RedisDB: 0,
|
||||
PodmanImage: "ml-worker:latest",
|
||||
ContainerWorkspace: utils.DefaultContainerWorkspace,
|
||||
ContainerResults: utils.DefaultContainerResults,
|
||||
GPUDevices: nil,
|
||||
KnownHosts: knownHosts,
|
||||
}
|
||||
|
||||
// Set up auth config with CLI API key
|
||||
tuiConfig.Auth = auth.Config{
|
||||
Enabled: true,
|
||||
APIKeys: map[auth.Username]auth.APIKeyEntry{
|
||||
"cli_user": {
|
||||
Hash: auth.APIKeyHash(c.APIKey),
|
||||
Admin: true,
|
||||
Roles: []string{"user", "admin"},
|
||||
Permissions: map[string]bool{
|
||||
"read": true,
|
||||
"write": true,
|
||||
"delete": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Set known hosts path
|
||||
tuiConfig.KnownHosts = smart.KnownHostsPath()
|
||||
|
||||
return tuiConfig
|
||||
return tuiConfig, nil
|
||||
}
|
||||
|
||||
// Validate validates the CLI config
|
||||
|
|
|
|||
|
|
@ -54,20 +54,36 @@ func LoadConfig(path string) (*Config, error) {
|
|||
cfg.Port = utils.DefaultSSHPort
|
||||
}
|
||||
if cfg.Host == "" {
|
||||
cfg.Host = smart.Host()
|
||||
host, err := smart.Host()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default host: %w", err)
|
||||
}
|
||||
cfg.Host = host
|
||||
}
|
||||
if cfg.BasePath == "" {
|
||||
cfg.BasePath = smart.BasePath()
|
||||
basePath, err := smart.BasePath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default base path: %w", err)
|
||||
}
|
||||
cfg.BasePath = basePath
|
||||
}
|
||||
// wrapper_script is deprecated - using secure_runner.py directly via Podman
|
||||
if cfg.TrainScript == "" {
|
||||
cfg.TrainScript = utils.DefaultTrainScript
|
||||
}
|
||||
if cfg.RedisAddr == "" {
|
||||
cfg.RedisAddr = smart.RedisAddr()
|
||||
redisAddr, err := smart.RedisAddr()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default redis address: %w", err)
|
||||
}
|
||||
cfg.RedisAddr = redisAddr
|
||||
}
|
||||
if cfg.KnownHosts == "" {
|
||||
cfg.KnownHosts = smart.KnownHostsPath()
|
||||
knownHosts, err := smart.KnownHostsPath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default known hosts path: %w", err)
|
||||
}
|
||||
cfg.KnownHosts = knownHosts
|
||||
}
|
||||
|
||||
// Apply environment variable overrides with FETCH_ML_TUI_ prefix
|
||||
|
|
|
|||
|
|
@ -95,7 +95,10 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
cfg = cliConfig.ToTUIConfig()
|
||||
cfg, err = cliConfig.ToTUIConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to convert CLI config to TUI config: %v", err)
|
||||
}
|
||||
log.Printf("Loaded TOML configuration from %s", cliConfPath)
|
||||
|
||||
// Validate authentication configuration
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -8,6 +9,9 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// ErrUnknownProfile is returned when an unrecognized environment profile is encountered
|
||||
var ErrUnknownProfile = errors.New("unknown environment profile")
|
||||
|
||||
// EnvironmentProfile represents the deployment environment
|
||||
type EnvironmentProfile int
|
||||
|
||||
|
|
@ -59,140 +63,140 @@ func GetSmartDefaults() *SmartDefaults {
|
|||
}
|
||||
|
||||
// Host returns the appropriate default host
|
||||
func (s *SmartDefaults) Host() string {
|
||||
func (s *SmartDefaults) Host() (string, error) {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "host.docker.internal" // Docker Desktop/Colima
|
||||
return "host.docker.internal", nil // Docker Desktop/Colima
|
||||
case ProfileProduction:
|
||||
return "0.0.0.0"
|
||||
return "0.0.0.0", nil
|
||||
case ProfileLocal:
|
||||
return "localhost"
|
||||
return "localhost", nil
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
|
||||
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
// BasePath returns the appropriate default base path
|
||||
func (s *SmartDefaults) BasePath() string {
|
||||
func (s *SmartDefaults) BasePath() (string, error) {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "/workspace/ml-experiments"
|
||||
return "/workspace/ml-experiments", nil
|
||||
case ProfileProduction:
|
||||
return "/var/lib/fetch_ml/experiments"
|
||||
return "/var/lib/fetch_ml/experiments", nil
|
||||
case ProfileLocal:
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return filepath.Join(home, "ml-experiments")
|
||||
return filepath.Join(home, "ml-experiments"), nil
|
||||
}
|
||||
return "./ml-experiments"
|
||||
return "./ml-experiments", nil
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
|
||||
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
// DataDir returns the appropriate default data directory
|
||||
func (s *SmartDefaults) DataDir() string {
|
||||
func (s *SmartDefaults) DataDir() (string, error) {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "/workspace/data"
|
||||
return "/workspace/data", nil
|
||||
case ProfileProduction:
|
||||
return "/var/lib/fetch_ml/data"
|
||||
return "/var/lib/fetch_ml/data", nil
|
||||
case ProfileLocal:
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return filepath.Join(home, "ml-data")
|
||||
return filepath.Join(home, "ml-data"), nil
|
||||
}
|
||||
return "./data"
|
||||
return "./data", nil
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
|
||||
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
// RedisAddr returns the appropriate default Redis address
|
||||
func (s *SmartDefaults) RedisAddr() string {
|
||||
func (s *SmartDefaults) RedisAddr() (string, error) {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "redis:6379" // Service name in containers
|
||||
return "redis:6379", nil // Service name in containers
|
||||
case ProfileProduction:
|
||||
return "redis:6379"
|
||||
return "redis:6379", nil
|
||||
case ProfileLocal:
|
||||
return "localhost:6379"
|
||||
return "localhost:6379", nil
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
|
||||
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
// SSHKeyPath returns the appropriate default SSH key path
|
||||
func (s *SmartDefaults) SSHKeyPath() string {
|
||||
func (s *SmartDefaults) SSHKeyPath() (string, error) {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "/workspace/.ssh/id_rsa"
|
||||
return "/workspace/.ssh/id_rsa", nil
|
||||
case ProfileProduction:
|
||||
return "/etc/fetch_ml/ssh/id_rsa"
|
||||
return "/etc/fetch_ml/ssh/id_rsa", nil
|
||||
case ProfileLocal:
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return filepath.Join(home, ".ssh", "id_rsa")
|
||||
return filepath.Join(home, ".ssh", "id_rsa"), nil
|
||||
}
|
||||
return "~/.ssh/id_rsa"
|
||||
return "~/.ssh/id_rsa", nil
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
|
||||
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
// KnownHostsPath returns the appropriate default known_hosts path
|
||||
func (s *SmartDefaults) KnownHostsPath() string {
|
||||
func (s *SmartDefaults) KnownHostsPath() (string, error) {
|
||||
switch s.Profile {
|
||||
case ProfileContainer, ProfileCI:
|
||||
return "/workspace/.ssh/known_hosts"
|
||||
return "/workspace/.ssh/known_hosts", nil
|
||||
case ProfileProduction:
|
||||
return "/etc/fetch_ml/ssh/known_hosts"
|
||||
return "/etc/fetch_ml/ssh/known_hosts", nil
|
||||
case ProfileLocal:
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return filepath.Join(home, ".ssh", "known_hosts")
|
||||
return filepath.Join(home, ".ssh", "known_hosts"), nil
|
||||
}
|
||||
return "~/.ssh/known_hosts"
|
||||
return "~/.ssh/known_hosts", nil
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
|
||||
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
// LogLevel returns the appropriate default log level
|
||||
func (s *SmartDefaults) LogLevel() string {
|
||||
func (s *SmartDefaults) LogLevel() (string, error) {
|
||||
switch s.Profile {
|
||||
case ProfileCI:
|
||||
return "debug" // More verbose for CI debugging
|
||||
return "debug", nil // More verbose for CI debugging
|
||||
case ProfileProduction:
|
||||
return "info"
|
||||
return "info", nil
|
||||
case ProfileLocal, ProfileContainer:
|
||||
return "info"
|
||||
return "info", nil
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
|
||||
return "", fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
// MaxWorkers returns the appropriate default worker count
|
||||
func (s *SmartDefaults) MaxWorkers() int {
|
||||
func (s *SmartDefaults) MaxWorkers() (int, error) {
|
||||
switch s.Profile {
|
||||
case ProfileCI:
|
||||
return 1 // Conservative for CI
|
||||
return 1, nil // Conservative for CI
|
||||
case ProfileProduction:
|
||||
return runtime.NumCPU() // Scale with CPU cores
|
||||
return runtime.NumCPU(), nil // Scale with CPU cores
|
||||
case ProfileLocal, ProfileContainer:
|
||||
return 2 // Reasonable default for local dev
|
||||
return 2, nil // Reasonable default for local dev
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
|
||||
return 0, fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
// PollInterval returns the appropriate default poll interval in seconds
|
||||
func (s *SmartDefaults) PollInterval() int {
|
||||
func (s *SmartDefaults) PollInterval() (int, error) {
|
||||
switch s.Profile {
|
||||
case ProfileCI:
|
||||
return 1 // Fast polling for quick tests
|
||||
return 1, nil // Fast polling for quick tests
|
||||
case ProfileProduction:
|
||||
return 10 // Conservative for production
|
||||
return 10, nil // Conservative for production
|
||||
case ProfileLocal, ProfileContainer:
|
||||
return 5 // Balanced default
|
||||
return 5, nil // Balanced default
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown profile: %v", s.Profile))
|
||||
return 0, fmt.Errorf("%w: %v", ErrUnknownProfile, s.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -151,16 +151,32 @@ func LoadConfig(path string) (*Config, error) {
|
|||
cfg.Port = config.DefaultSSHPort
|
||||
}
|
||||
if cfg.Host == "" {
|
||||
cfg.Host = smart.Host()
|
||||
host, err := smart.Host()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default host: %w", err)
|
||||
}
|
||||
cfg.Host = host
|
||||
}
|
||||
if cfg.BasePath == "" {
|
||||
cfg.BasePath = smart.BasePath()
|
||||
basePath, err := smart.BasePath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default base path: %w", err)
|
||||
}
|
||||
cfg.BasePath = basePath
|
||||
}
|
||||
if cfg.RedisAddr == "" {
|
||||
cfg.RedisAddr = smart.RedisAddr()
|
||||
redisAddr, err := smart.RedisAddr()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default redis address: %w", err)
|
||||
}
|
||||
cfg.RedisAddr = redisAddr
|
||||
}
|
||||
if cfg.KnownHosts == "" {
|
||||
cfg.KnownHosts = smart.KnownHostsPath()
|
||||
knownHosts, err := smart.KnownHostsPath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default known hosts path: %w", err)
|
||||
}
|
||||
cfg.KnownHosts = knownHosts
|
||||
}
|
||||
if cfg.WorkerID == "" {
|
||||
cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
|
||||
|
|
@ -169,10 +185,19 @@ func LoadConfig(path string) (*Config, error) {
|
|||
if cfg.MaxWorkers > 0 {
|
||||
cfg.Resources.MaxWorkers = cfg.MaxWorkers
|
||||
} else {
|
||||
cfg.MaxWorkers = cfg.Resources.MaxWorkers
|
||||
maxWorkers, err := smart.MaxWorkers()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default max workers: %w", err)
|
||||
}
|
||||
cfg.MaxWorkers = maxWorkers
|
||||
cfg.Resources.MaxWorkers = maxWorkers
|
||||
}
|
||||
if cfg.PollInterval == 0 {
|
||||
cfg.PollInterval = smart.PollInterval()
|
||||
pollInterval, err := smart.PollInterval()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default poll interval: %w", err)
|
||||
}
|
||||
cfg.PollInterval = pollInterval
|
||||
}
|
||||
if cfg.DataManagerPath == "" {
|
||||
cfg.DataManagerPath = "./data_manager"
|
||||
|
|
@ -181,7 +206,11 @@ func LoadConfig(path string) (*Config, error) {
|
|||
if cfg.Host == "" || !cfg.AutoFetchData {
|
||||
cfg.DataDir = config.DefaultLocalDataDir
|
||||
} else {
|
||||
cfg.DataDir = smart.DataDir()
|
||||
dataDir, err := smart.DataDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get default data directory: %w", err)
|
||||
}
|
||||
cfg.DataDir = dataDir
|
||||
}
|
||||
}
|
||||
if cfg.SnapshotStore.Timeout == 0 {
|
||||
|
|
|
|||
Loading…
Reference in a new issue