fetch_ml/cmd/worker/worker_config.go
Jeremie Fraeys 803677be57 feat: implement Go backend with comprehensive API and internal packages
- Add API server with WebSocket support and REST endpoints
- Implement authentication system with API keys and permissions
- Add task queue system with Redis backend and error handling
- Include storage layer with database migrations and schemas
- Add comprehensive logging, metrics, and telemetry
- Implement security middleware and network utilities
- Add experiment management and container orchestration
- Include configuration management with smart defaults
2025-12-04 16:53:53 -05:00

173 lines
4.5 KiB
Go

package main
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/google/uuid"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"gopkg.in/yaml.v3"
)
// Fallback durations applied by LoadConfig when the corresponding
// configuration field is left at its zero value.
const (
	// defaultMetricsFlushInterval is the fallback for Config.MetricsFlushInterval.
	defaultMetricsFlushInterval time.Duration = 500 * time.Millisecond
	// datasetCacheDefaultTTL is the fallback for Config.DatasetCacheTTL.
	datasetCacheDefaultTTL time.Duration = 30 * time.Minute
)
// Config holds the worker configuration as decoded from YAML.
//
// LoadConfig replaces zero-valued fields with defaults, so for the fields
// it defaults a zero value in the file behaves the same as an omitted key.
type Config struct {
	// Connection, queue and worker identity settings. Port falls back to
	// config.DefaultSSHPort and is validated as the SSH port by Validate.
	Host          string `yaml:"host"`
	User          string `yaml:"user"`
	SSHKey        string `yaml:"ssh_key"`
	Port          int    `yaml:"port"`
	BasePath      string `yaml:"base_path"`
	TrainScript   string `yaml:"train_script"`
	RedisAddr     string `yaml:"redis_addr"`
	RedisPassword string `yaml:"redis_password"`
	RedisDB       int    `yaml:"redis_db"`
	KnownHosts    string `yaml:"known_hosts"`
	WorkerID      string `yaml:"worker_id"`
	MaxWorkers    int    `yaml:"max_workers"`
	PollInterval  int    `yaml:"poll_interval_seconds"` // Seconds, per the YAML key

	// Authentication
	Auth auth.AuthConfig `yaml:"auth"`

	// Metrics exporter
	Metrics MetricsConfig `yaml:"metrics"`

	// Metrics buffering (default: 500ms)
	MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`

	// Data management. DatasetCacheTTL defaults to 30min.
	DataManagerPath string        `yaml:"data_manager_path"`
	AutoFetchData   bool          `yaml:"auto_fetch_data"`
	DataDir         string        `yaml:"data_dir"`
	DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`

	// Podman execution
	PodmanImage        string `yaml:"podman_image"`
	ContainerWorkspace string `yaml:"container_workspace"`
	ContainerResults   string `yaml:"container_results"`
	GPUAccess          bool   `yaml:"gpu_access"`

	// Task lease and retry settings
	TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // How long worker holds lease (default: 30min)
	HeartbeatInterval time.Duration `yaml:"heartbeat_interval"`  // How often to renew lease (default: 1min)
	MaxRetries        int           `yaml:"max_retries"`         // Maximum retry attempts (default: 3)
	GracefulTimeout   time.Duration `yaml:"graceful_timeout"`    // Graceful shutdown timeout (default: 5min)
}
// MetricsConfig holds the settings for the Prometheus exporter.
type MetricsConfig struct {
	// Enabled is read from the YAML "enabled" key; LoadConfig applies no
	// default, so it stays false unless set explicitly.
	Enabled bool `yaml:"enabled"`

	// ListenAddr is read from the YAML "listen_addr" key; LoadConfig falls
	// back to ":9100" when it is empty.
	ListenAddr string `yaml:"listen_addr"`
}
// LoadConfig reads the YAML file at path into a Config and replaces every
// zero-valued field that has a default with that default.
//
// Environment-specific defaults come from config.GetSmartDefaults; the
// remaining ones are fixed values applied by applyFixedDefaults. Errors from
// reading the file or decoding the YAML are returned unchanged together with
// a nil Config. LoadConfig does not call Validate.
func LoadConfig(path string) (*Config, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	cfg := &Config{}
	if err := yaml.Unmarshal(raw, cfg); err != nil {
		return nil, err
	}

	// Smart defaults for the current environment.
	env := config.GetSmartDefaults()

	if cfg.Port == 0 {
		cfg.Port = config.DefaultSSHPort
	}
	if cfg.Host == "" {
		cfg.Host = env.Host()
	}
	if cfg.BasePath == "" {
		cfg.BasePath = env.BasePath()
	}
	if cfg.RedisAddr == "" {
		cfg.RedisAddr = env.RedisAddr()
	}
	if cfg.KnownHosts == "" {
		cfg.KnownHosts = env.KnownHostsPath()
	}
	if cfg.WorkerID == "" {
		// "worker-" followed by the first 8 characters of a fresh UUID string.
		cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
	}
	if cfg.MaxWorkers == 0 {
		cfg.MaxWorkers = env.MaxWorkers()
	}
	if cfg.PollInterval == 0 {
		cfg.PollInterval = env.PollInterval()
	}
	if cfg.DataDir == "" {
		// cfg.Host has already been defaulted above, so this tests the
		// effective host. The smart data dir is used only when a host is
		// set and auto-fetch is enabled.
		if cfg.Host != "" && cfg.AutoFetchData {
			cfg.DataDir = env.DataDir()
		} else {
			cfg.DataDir = config.DefaultLocalDataDir
		}
	}

	cfg.applyFixedDefaults()
	return cfg, nil
}

// applyFixedDefaults fills zero-valued fields whose defaults are constants
// rather than values taken from the smart defaults.
func (c *Config) applyFixedDefaults() {
	if c.DataManagerPath == "" {
		c.DataManagerPath = "./data_manager"
	}
	if c.Metrics.ListenAddr == "" {
		c.Metrics.ListenAddr = ":9100"
	}
	if c.MetricsFlushInterval == 0 {
		c.MetricsFlushInterval = defaultMetricsFlushInterval
	}
	if c.DatasetCacheTTL == 0 {
		c.DatasetCacheTTL = datasetCacheDefaultTTL
	}

	// Lease and retry defaults.
	if c.TaskLeaseDuration == 0 {
		c.TaskLeaseDuration = 30 * time.Minute
	}
	if c.HeartbeatInterval == 0 {
		c.HeartbeatInterval = 1 * time.Minute
	}
	if c.MaxRetries == 0 {
		c.MaxRetries = 3
	}
	if c.GracefulTimeout == 0 {
		c.GracefulTimeout = 5 * time.Minute
	}
}
// Validate implements the config.Validator interface.
//
// It checks Port and RedisAddr when they are set, requires MaxWorkers to be
// at least 1, and returns the first failure encountered. As a side effect, a
// non-empty BasePath is rewritten in place: it is passed through
// config.ExpandPath and, if the result is still relative, joined onto
// config.DefaultBasePath.
func (c *Config) Validate() error {
	if port := c.Port; port != 0 {
		if err := config.ValidatePort(port); err != nil {
			return fmt.Errorf("invalid SSH port: %w", err)
		}
	}

	if c.BasePath != "" {
		// Convert relative paths to absolute.
		base := config.ExpandPath(c.BasePath)
		if !filepath.IsAbs(base) {
			base = filepath.Join(config.DefaultBasePath, base)
		}
		c.BasePath = base
	}

	if addr := c.RedisAddr; addr != "" {
		if err := config.ValidateRedisAddr(addr); err != nil {
			return fmt.Errorf("invalid Redis configuration: %w", err)
		}
	}

	if workers := c.MaxWorkers; workers < 1 {
		return fmt.Errorf("max_workers must be at least 1, got %d", workers)
	}
	return nil
}
// Task struct and Redis constants moved to internal/queue