371 lines
12 KiB
Go
371 lines
12 KiB
Go
package worker
|
|
|
|
import (
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/config"
|
|
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/tracking/factory"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
// Fallback durations applied by LoadConfig when the corresponding YAML keys
// are absent or zero.
const (
	// defaultMetricsFlushInterval (500ms) backs metrics_flush_interval.
	defaultMetricsFlushInterval time.Duration = time.Second / 2

	// datasetCacheDefaultTTL (30 minutes) backs dataset_cache_ttl.
	datasetCacheDefaultTTL time.Duration = time.Hour / 2
)
|
|
|
|
// QueueConfig selects and configures the task-queue backend.
type QueueConfig struct {
	// Backend is queue.QueueBackendRedis or queue.QueueBackendSQLite, matched
	// case-insensitively after trimming; empty means redis. Validate rejects
	// any other value.
	Backend string `yaml:"backend"`

	// SQLitePath is only used by the sqlite backend. When unset, LoadConfig
	// defaults it to queue.db inside data_dir; Validate expands it and makes
	// relative values absolute.
	SQLitePath string `yaml:"sqlite_path"`
}
|
|
|
|
// Config holds worker configuration. LoadConfig builds it from YAML and fills
// in defaults for unset fields; Validate checks it and normalizes some fields
// in place.
type Config struct {
	// Remote connection settings. Port is validated as an SSH port and
	// LoadConfig defaults it to config.DefaultSSHPort; Host defaults to the
	// environment-dependent value from config.GetSmartDefaults().
	Host   string `yaml:"host"`
	User   string `yaml:"user"`
	SSHKey string `yaml:"ssh_key"`
	Port   int    `yaml:"port"`

	// BasePath defaults to the smart-defaults base path; Validate expands it
	// and joins relative values onto config.DefaultBasePath.
	BasePath    string `yaml:"base_path"`
	TrainScript string `yaml:"train_script"`

	// Redis connection. When redis_url is set, LoadConfig gives it precedence
	// over redis_addr/redis_password/redis_db. RedisAddr may hold either a
	// plain address (checked by config.ValidateRedisAddr) or a redis:// URL.
	RedisURL      string `yaml:"redis_url"`
	RedisAddr     string `yaml:"redis_addr"`
	RedisPassword string `yaml:"redis_password"`
	RedisDB       int    `yaml:"redis_db"`

	// Queue selects the task-queue backend (redis or sqlite).
	Queue QueueConfig `yaml:"queue"`

	// KnownHosts defaults to the smart-defaults known_hosts path.
	KnownHosts string `yaml:"known_hosts"`

	// WorkerID defaults to "worker-" followed by the first 8 characters of a
	// freshly generated UUID.
	WorkerID string `yaml:"worker_id"`

	// MaxWorkers is kept in sync with Resources.MaxWorkers by LoadConfig (an
	// explicit positive value wins) and must be at least 1 for Validate.
	MaxWorkers int `yaml:"max_workers"`

	// PollInterval is in seconds, per its YAML key; 0 means use the
	// smart-defaults poll interval.
	PollInterval int `yaml:"poll_interval_seconds"`

	// Resources has ApplyDefaults called on it by LoadConfig.
	Resources config.ResourceConfig `yaml:"resources"`

	LocalMode bool `yaml:"local_mode"`

	// Authentication
	Auth auth.Config `yaml:"auth"`

	// Metrics exporter
	Metrics MetricsConfig `yaml:"metrics"`

	// Metrics buffering; LoadConfig defaults this to 500ms when unset.
	MetricsFlushInterval time.Duration `yaml:"metrics_flush_interval"`

	// Data management. DataManagerPath defaults to "./data_manager" and
	// DatasetCacheTTL to 30 minutes. When DataDir is unset, LoadConfig uses
	// config.DefaultLocalDataDir unless AutoFetchData is true and Host is
	// non-empty, in which case it uses the smart-defaults data dir.
	DataManagerPath string        `yaml:"data_manager_path"`
	AutoFetchData   bool          `yaml:"auto_fetch_data"`
	DataDir         string        `yaml:"data_dir"`
	DatasetCacheTTL time.Duration `yaml:"dataset_cache_ttl"`

	// SnapshotStore is only checked by Validate when its Enabled flag is true.
	SnapshotStore SnapshotStoreConfig `yaml:"snapshot_store"`

	// Provenance enforcement
	// Default: fail-closed (trustworthiness-by-default). Set true to opt into best-effort.
	ProvenanceBestEffort bool `yaml:"provenance_best_effort"`

	// Phase 1: opt-in prewarming of next task artifacts (snapshot/datasets/env).
	PrewarmEnabled bool `yaml:"prewarm_enabled"`

	// Podman execution
	PodmanImage        string `yaml:"podman_image"`
	ContainerWorkspace string `yaml:"container_workspace"`
	ContainerResults   string `yaml:"container_results"`

	// GPU selection. When GPUVendor is empty, LoadConfig picks GPUTypeApple if
	// AppleGPU.Enabled, GPUTypeNVIDIA if any of the three device lists is
	// non-empty, and GPUTypeNone otherwise. Validate accepts those three values
	// or "amd" (case-insensitive). GPUVisibleDevices (indices) and
	// GPUVisibleDeviceIDs ("GPU-"-prefixed IDs, NVIDIA only) are mutually
	// exclusive.
	GPUDevices          []string `yaml:"gpu_devices"`
	GPUVendor           string   `yaml:"gpu_vendor"`
	GPUVisibleDevices   []int    `yaml:"gpu_visible_devices"`
	GPUVisibleDeviceIDs []string `yaml:"gpu_visible_device_ids"`

	// Apple M-series GPU configuration
	AppleGPU AppleGPUConfig `yaml:"apple_gpu"`

	// Task lease and retry settings. A zero value is treated as unset and
	// replaced by the default noted on each field.
	TaskLeaseDuration time.Duration `yaml:"task_lease_duration"` // Worker lease (default: 30min)
	HeartbeatInterval time.Duration `yaml:"heartbeat_interval"`  // Renew lease (default: 1min)
	MaxRetries        int           `yaml:"max_retries"`         // Maximum retry attempts (default: 3)
	GracefulTimeout   time.Duration `yaml:"graceful_timeout"`    // Shutdown timeout (default: 5min)

	// Plugins configuration. NOTE(review): the map key is presumably the
	// plugin name; confirm against the tracking factory.
	Plugins map[string]factory.PluginConfig `yaml:"plugins"`
}
|
|
|
|
// MetricsConfig controls the Prometheus exporter.
type MetricsConfig struct {
	Enabled    bool   `yaml:"enabled"`
	ListenAddr string `yaml:"listen_addr"` // LoadConfig defaults this to ":9100"
}
|
|
|
|
// SnapshotStoreConfig configures the remote snapshot store. NOTE(review): the
// field set suggests an S3-compatible object store; confirm with the client.
//
// When Enabled is true, Validate requires Endpoint and Bucket, requires
// AccessKey and SecretKey to be both set or both empty, and rejects negative
// Timeout or MaxRetries.
type SnapshotStoreConfig struct {
	Enabled      bool          `yaml:"enabled"`
	Endpoint     string        `yaml:"endpoint"`
	Secure       bool          `yaml:"secure"`
	Region       string        `yaml:"region"`
	Bucket       string        `yaml:"bucket"`
	Prefix       string        `yaml:"prefix"`
	AccessKey    string        `yaml:"access_key"`
	SecretKey    string        `yaml:"secret_key"`
	SessionToken string        `yaml:"session_token"`
	Timeout      time.Duration `yaml:"timeout"`     // LoadConfig default when 0: 10 minutes
	MaxRetries   int           `yaml:"max_retries"` // LoadConfig default when 0: 3
}
|
|
|
|
// AppleGPUConfig holds configuration for Apple M-series GPU support.
// When Enabled is true and gpu_vendor is left empty, LoadConfig sets the
// worker's GPU vendor to GPUTypeApple.
type AppleGPUConfig struct {
	Enabled     bool   `yaml:"enabled"`
	MetalDevice string `yaml:"metal_device"`
	MPSRuntime  string `yaml:"mps_runtime"`
}
|
|
|
|
// LoadConfig loads worker configuration from a YAML file.
|
|
func LoadConfig(path string) (*Config, error) {
|
|
data, err := fileutil.SecureFileRead(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var cfg Config
|
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if strings.TrimSpace(cfg.RedisURL) != "" {
|
|
cfg.RedisURL = os.ExpandEnv(strings.TrimSpace(cfg.RedisURL))
|
|
cfg.RedisAddr = cfg.RedisURL
|
|
cfg.RedisPassword = ""
|
|
cfg.RedisDB = 0
|
|
}
|
|
|
|
// Get smart defaults for current environment
|
|
smart := config.GetSmartDefaults()
|
|
|
|
if cfg.Port == 0 {
|
|
cfg.Port = config.DefaultSSHPort
|
|
}
|
|
if cfg.Host == "" {
|
|
cfg.Host = smart.Host()
|
|
}
|
|
if cfg.BasePath == "" {
|
|
cfg.BasePath = smart.BasePath()
|
|
}
|
|
if cfg.RedisAddr == "" {
|
|
cfg.RedisAddr = smart.RedisAddr()
|
|
}
|
|
if cfg.KnownHosts == "" {
|
|
cfg.KnownHosts = smart.KnownHostsPath()
|
|
}
|
|
if cfg.WorkerID == "" {
|
|
cfg.WorkerID = fmt.Sprintf("worker-%s", uuid.New().String()[:8])
|
|
}
|
|
cfg.Resources.ApplyDefaults()
|
|
if cfg.MaxWorkers > 0 {
|
|
cfg.Resources.MaxWorkers = cfg.MaxWorkers
|
|
} else {
|
|
cfg.MaxWorkers = cfg.Resources.MaxWorkers
|
|
}
|
|
if cfg.PollInterval == 0 {
|
|
cfg.PollInterval = smart.PollInterval()
|
|
}
|
|
if cfg.DataManagerPath == "" {
|
|
cfg.DataManagerPath = "./data_manager"
|
|
}
|
|
if cfg.DataDir == "" {
|
|
if cfg.Host == "" || !cfg.AutoFetchData {
|
|
cfg.DataDir = config.DefaultLocalDataDir
|
|
} else {
|
|
cfg.DataDir = smart.DataDir()
|
|
}
|
|
}
|
|
if cfg.SnapshotStore.Timeout == 0 {
|
|
cfg.SnapshotStore.Timeout = 10 * time.Minute
|
|
}
|
|
if cfg.SnapshotStore.MaxRetries == 0 {
|
|
cfg.SnapshotStore.MaxRetries = 3
|
|
}
|
|
if cfg.Metrics.ListenAddr == "" {
|
|
cfg.Metrics.ListenAddr = ":9100"
|
|
}
|
|
if cfg.MetricsFlushInterval == 0 {
|
|
cfg.MetricsFlushInterval = defaultMetricsFlushInterval
|
|
}
|
|
if cfg.DatasetCacheTTL == 0 {
|
|
cfg.DatasetCacheTTL = datasetCacheDefaultTTL
|
|
}
|
|
|
|
if strings.TrimSpace(cfg.Queue.Backend) == "" {
|
|
cfg.Queue.Backend = string(queue.QueueBackendRedis)
|
|
}
|
|
if strings.EqualFold(strings.TrimSpace(cfg.Queue.Backend), string(queue.QueueBackendSQLite)) {
|
|
if strings.TrimSpace(cfg.Queue.SQLitePath) == "" {
|
|
cfg.Queue.SQLitePath = filepath.Join(cfg.DataDir, "queue.db")
|
|
}
|
|
cfg.Queue.SQLitePath = config.ExpandPath(cfg.Queue.SQLitePath)
|
|
}
|
|
|
|
if strings.TrimSpace(cfg.GPUVendor) == "" {
|
|
if cfg.AppleGPU.Enabled {
|
|
cfg.GPUVendor = string(GPUTypeApple)
|
|
} else if len(cfg.GPUDevices) > 0 ||
|
|
len(cfg.GPUVisibleDevices) > 0 ||
|
|
len(cfg.GPUVisibleDeviceIDs) > 0 {
|
|
cfg.GPUVendor = string(GPUTypeNVIDIA)
|
|
} else {
|
|
cfg.GPUVendor = string(GPUTypeNone)
|
|
}
|
|
}
|
|
|
|
// Set lease and retry defaults
|
|
if cfg.TaskLeaseDuration == 0 {
|
|
cfg.TaskLeaseDuration = 30 * time.Minute
|
|
}
|
|
if cfg.HeartbeatInterval == 0 {
|
|
cfg.HeartbeatInterval = 1 * time.Minute
|
|
}
|
|
if cfg.MaxRetries == 0 {
|
|
cfg.MaxRetries = 3
|
|
}
|
|
if cfg.GracefulTimeout == 0 {
|
|
cfg.GracefulTimeout = 5 * time.Minute
|
|
}
|
|
|
|
return &cfg, nil
|
|
}
|
|
|
|
// Validate implements config.Validator interface.
|
|
func (c *Config) Validate() error {
|
|
if c.Port != 0 {
|
|
if err := config.ValidatePort(c.Port); err != nil {
|
|
return fmt.Errorf("invalid SSH port: %w", err)
|
|
}
|
|
}
|
|
|
|
if c.BasePath != "" {
|
|
// Convert relative paths to absolute
|
|
c.BasePath = config.ExpandPath(c.BasePath)
|
|
if !filepath.IsAbs(c.BasePath) {
|
|
c.BasePath = filepath.Join(config.DefaultBasePath, c.BasePath)
|
|
}
|
|
}
|
|
|
|
backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend))
|
|
if backend == "" {
|
|
backend = string(queue.QueueBackendRedis)
|
|
c.Queue.Backend = backend
|
|
}
|
|
if backend != string(queue.QueueBackendRedis) && backend != string(queue.QueueBackendSQLite) {
|
|
return fmt.Errorf("queue.backend must be one of %q or %q", queue.QueueBackendRedis, queue.QueueBackendSQLite)
|
|
}
|
|
|
|
if backend == string(queue.QueueBackendSQLite) {
|
|
if strings.TrimSpace(c.Queue.SQLitePath) == "" {
|
|
return fmt.Errorf("queue.sqlite_path is required when queue.backend is %q", queue.QueueBackendSQLite)
|
|
}
|
|
c.Queue.SQLitePath = config.ExpandPath(c.Queue.SQLitePath)
|
|
if !filepath.IsAbs(c.Queue.SQLitePath) {
|
|
c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
|
|
}
|
|
}
|
|
|
|
if c.RedisAddr != "" {
|
|
addr := strings.TrimSpace(c.RedisAddr)
|
|
if strings.HasPrefix(addr, "redis://") {
|
|
u, err := url.Parse(addr)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid Redis configuration: invalid redis url: %w", err)
|
|
}
|
|
if u.Scheme != "redis" || strings.TrimSpace(u.Host) == "" {
|
|
return fmt.Errorf("invalid Redis configuration: invalid redis url")
|
|
}
|
|
} else {
|
|
if err := config.ValidateRedisAddr(addr); err != nil {
|
|
return fmt.Errorf("invalid Redis configuration: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if c.MaxWorkers < 1 {
|
|
return fmt.Errorf("max_workers must be at least 1, got %d", c.MaxWorkers)
|
|
}
|
|
|
|
switch strings.ToLower(strings.TrimSpace(c.GPUVendor)) {
|
|
case string(GPUTypeNVIDIA), string(GPUTypeApple), string(GPUTypeNone), "amd":
|
|
// ok
|
|
default:
|
|
return fmt.Errorf(
|
|
"gpu_vendor must be one of %q, %q, %q, %q",
|
|
string(GPUTypeNVIDIA),
|
|
"amd",
|
|
string(GPUTypeApple),
|
|
string(GPUTypeNone),
|
|
)
|
|
}
|
|
|
|
// Strict GPU visibility configuration:
|
|
// - gpu_visible_devices and gpu_visible_device_ids are mutually exclusive.
|
|
// - UUID-style gpu_visible_device_ids is NVIDIA-only.
|
|
vendor := strings.ToLower(strings.TrimSpace(c.GPUVendor))
|
|
if len(c.GPUVisibleDevices) > 0 && len(c.GPUVisibleDeviceIDs) > 0 {
|
|
return fmt.Errorf("gpu_visible_devices and gpu_visible_device_ids are mutually exclusive")
|
|
}
|
|
if len(c.GPUVisibleDeviceIDs) > 0 {
|
|
if vendor != string(GPUTypeNVIDIA) {
|
|
return fmt.Errorf(
|
|
"gpu_visible_device_ids is only supported when gpu_vendor is %q",
|
|
string(GPUTypeNVIDIA),
|
|
)
|
|
}
|
|
for _, id := range c.GPUVisibleDeviceIDs {
|
|
id = strings.TrimSpace(id)
|
|
if id == "" {
|
|
return fmt.Errorf("gpu_visible_device_ids contains an empty value")
|
|
}
|
|
if !strings.HasPrefix(id, "GPU-") {
|
|
return fmt.Errorf("gpu_visible_device_ids values must start with %q, got %q", "GPU-", id)
|
|
}
|
|
}
|
|
}
|
|
if vendor == string(GPUTypeApple) || vendor == string(GPUTypeNone) {
|
|
if len(c.GPUVisibleDevices) > 0 || len(c.GPUVisibleDeviceIDs) > 0 {
|
|
return fmt.Errorf(
|
|
"gpu_visible_devices and gpu_visible_device_ids are not supported when gpu_vendor is %q",
|
|
vendor,
|
|
)
|
|
}
|
|
}
|
|
if vendor == "amd" {
|
|
if len(c.GPUVisibleDeviceIDs) > 0 {
|
|
return fmt.Errorf("gpu_visible_device_ids is not supported when gpu_vendor is %q", vendor)
|
|
}
|
|
for _, idx := range c.GPUVisibleDevices {
|
|
if idx < 0 {
|
|
return fmt.Errorf("gpu_visible_devices contains negative index %d", idx)
|
|
}
|
|
}
|
|
}
|
|
|
|
if c.SnapshotStore.Enabled {
|
|
if strings.TrimSpace(c.SnapshotStore.Endpoint) == "" {
|
|
return fmt.Errorf("snapshot_store.endpoint is required when snapshot_store.enabled is true")
|
|
}
|
|
if strings.TrimSpace(c.SnapshotStore.Bucket) == "" {
|
|
return fmt.Errorf("snapshot_store.bucket is required when snapshot_store.enabled is true")
|
|
}
|
|
ak := strings.TrimSpace(c.SnapshotStore.AccessKey)
|
|
sk := strings.TrimSpace(c.SnapshotStore.SecretKey)
|
|
if (ak == "") != (sk == "") {
|
|
return fmt.Errorf(
|
|
"snapshot_store.access_key and snapshot_store.secret_key must both be set or both be empty",
|
|
)
|
|
}
|
|
if c.SnapshotStore.Timeout < 0 {
|
|
return fmt.Errorf("snapshot_store.timeout must be >= 0")
|
|
}
|
|
if c.SnapshotStore.MaxRetries < 0 {
|
|
return fmt.Errorf("snapshot_store.max_retries must be >= 0")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|