From 3fb6902fa106875b9d08e5ad9c0cdb16401ef913 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Thu, 26 Feb 2026 12:06:16 -0500 Subject: [PATCH] feat(worker): integrate scheduler endpoints and security hardening Update worker system for scheduler integration: - Worker server with scheduler registration - Configuration with scheduler endpoint support - Artifact handling with integrity verification - Container executor with supply chain validation - Local executor enhancements - GPU detection improvements (cross-platform) - Error handling with execution context - Factory pattern for executor instantiation - Hash integrity with native library support --- cmd/worker/worker_server.go | 113 ++++++++++++- internal/worker/artifacts.go | 14 +- internal/worker/config.go | 87 +++++++++- internal/worker/errors/execution.go | 12 +- internal/worker/executor/container.go | 16 +- internal/worker/executor/local.go | 4 +- internal/worker/factory.go | 13 ++ internal/worker/gpu_detector.go | 83 ++++++---- internal/worker/gpu_macos.go | 13 +- internal/worker/gpu_nvml_stub.go | 10 +- internal/worker/integrity/hash.go | 4 +- internal/worker/interfaces/executor.go | 6 +- internal/worker/worker.go | 215 +++++++++---------------- 13 files changed, 371 insertions(+), 219 deletions(-) diff --git a/cmd/worker/worker_server.go b/cmd/worker/worker_server.go index f9034cb..a9c6bcd 100644 --- a/cmd/worker/worker_server.go +++ b/cmd/worker/worker_server.go @@ -2,12 +2,15 @@ package main import ( + "flag" + "fmt" "log" "os" "os/signal" "strings" "syscall" + "github.com/invopop/yaml" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/worker" @@ -31,7 +34,37 @@ func resolveWorkerConfigPath(flags *auth.Flags) string { } func main() { - log.SetFlags(log.LstdFlags | log.Lshortfile) + var ( + configPath string + initConfig bool + mode string + schedulerAddr string + token string + ) + flag.StringVar(&configPath, "config", 
"worker.yaml", "Path to worker config file") + flag.BoolVar(&initConfig, "init", false, "Initialize a new worker config file") + flag.StringVar(&mode, "mode", "distributed", "Worker mode: standalone or distributed") + flag.StringVar(&schedulerAddr, "scheduler", "", "Scheduler address (for distributed mode)") + flag.StringVar(&token, "token", "", "Worker token (copy from scheduler -init output)") + flag.Parse() + + // Handle init mode + if initConfig { + if err := generateWorkerConfig(configPath, mode, schedulerAddr, token); err != nil { + fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err) + os.Exit(1) + } + fmt.Printf("Config generated: %s\n", configPath) + fmt.Println("\nNext steps:") + if mode == "distributed" { + fmt.Println("1. Copy the token from your scheduler's -init output") + fmt.Println("2. Edit the config to set scheduler.address and scheduler.token") + fmt.Println("3. Copy the scheduler's TLS cert to the worker") + } + os.Exit(0) + } + + // Normal worker startup... 
// Parse authentication flags authFlags := auth.ParseAuthFlags() @@ -95,3 +128,81 @@ func main() { log.Println("Worker shut down gracefully") } } + +// generateWorkerConfig creates a new worker config file +func generateWorkerConfig(path, mode, schedulerAddr, token string) error { + cfg := map[string]any{ + "node": map[string]any{ + "role": "worker", + "id": "", + }, + "worker": map[string]any{ + "mode": mode, + "max_workers": 3, + }, + } + + if mode == "distributed" { + cfg["scheduler"] = map[string]any{ + "address": schedulerAddr, + "cert": "/etc/fetch_ml/scheduler.crt", + "token": token, + } + } else { + cfg["queue"] = map[string]any{ + "backend": "redis", + "redis_addr": "localhost:6379", + "redis_password": "", + "redis_db": 0, + } + } + + cfg["slots"] = map[string]any{ + "service_slots": 1, + "ports": map[string]any{ + "service_range_start": 8000, + "service_range_end": 8099, + }, + } + + cfg["gpu"] = map[string]any{ + "vendor": "auto", + } + + cfg["prewarm"] = map[string]any{ + "enabled": true, + } + + cfg["log"] = map[string]any{ + "level": "info", + "format": "json", + } + + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("marshal config: %w", err) + } + + // Add header comment + header := fmt.Sprintf(`# Worker Configuration for fetch_ml +# Generated by: worker -init +# Mode: %s +#`, mode) + + if mode == "distributed" && token == "" { + header += ` +# ⚠️ SECURITY WARNING: You must add the scheduler token to this config. +# Copy the token from the scheduler's -init output and paste it below. +# scheduler: +# token: "wkr_xxx..." 
+#` + } + + fullContent := header + "\n\n" + string(data) + + if err := os.WriteFile(path, []byte(fullContent), 0600); err != nil { + return fmt.Errorf("write config file: %w", err) + } + + return nil +} diff --git a/internal/worker/artifacts.go b/internal/worker/artifacts.go index 10e7b52..d5d85ec 100644 --- a/internal/worker/artifacts.go +++ b/internal/worker/artifacts.go @@ -12,6 +12,8 @@ import ( "github.com/jfraeys/fetch_ml/internal/manifest" ) +// scanArtifacts discovers and catalogs artifact files in a run directory. +// When includeAll is false, it excludes code/, snapshot/, *.log files, and symlinks. func scanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manifest.Artifacts, error) { runDir = strings.TrimSpace(runDir) if runDir == "" { @@ -55,14 +57,8 @@ func scanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manife } // Standard exclusions (always apply) - if rel == manifestFilename { - exclusions = append(exclusions, manifest.Exclusion{ - Path: rel, - Reason: "manifest file excluded", - }) - return nil - } - if strings.HasSuffix(rel, "/"+manifestFilename) { + // Exclude manifest files at any depth - legacy (run_manifest.json) and nonce-based (run_manifest_<nonce>.json); match on the basename so nested manifests stay excluded and unrelated files under run_manifest*/ directories are kept + if base := rel[strings.LastIndex(rel, "/")+1:]; strings.HasPrefix(base, "run_manifest") && strings.HasSuffix(base, ".json") { exclusions = append(exclusions, manifest.Exclusion{ Path: rel, Reason: "manifest file excluded", @@ -160,8 +156,6 @@ func scanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manife }, nil } -const manifestFilename = "run_manifest.json" - // ScanArtifacts is an exported wrapper for testing/benchmarking. // When includeAll is false, excludes code/, snapshot/, *.log files, and symlinks. 
func ScanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manifest.Artifacts, error) { diff --git a/internal/worker/config.go b/internal/worker/config.go index 46ebb34..f42c905 100644 --- a/internal/worker/config.go +++ b/internal/worker/config.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "log/slog" "math" "net/url" "os" @@ -80,7 +81,7 @@ type Config struct { // When "hipaa": enforces hard requirements at startup ComplianceMode string `yaml:"compliance_mode"` - // Phase 1: opt-in prewarming of next task artifacts (snapshot/datasets/env). + // Opt-in prewarming of next task artifacts (snapshot/datasets/env). PrewarmEnabled bool `yaml:"prewarm_enabled"` // Podman execution @@ -102,6 +103,16 @@ type Config struct { MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3) GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Shutdown timeout (default: 5min) + // Mode determines how the worker operates: "standalone" or "distributed" + Mode string `yaml:"mode"` + + // Scheduler configuration for distributed mode + Scheduler struct { + Address string `yaml:"address"` + Cert string `yaml:"cert"` + Token string `yaml:"token"` + } `yaml:"scheduler"` + // Plugins configuration Plugins map[string]factory.PluginConfig `yaml:"plugins"` @@ -145,7 +156,7 @@ type SandboxConfig struct { SeccompProfile string `yaml:"seccomp_profile"` // Default: "default-hardened" MaxRuntimeHours int `yaml:"max_runtime_hours"` - // Security hardening options (NEW) + // Security hardening options NoNewPrivileges bool `yaml:"no_new_privileges"` // Default: true DropAllCaps bool `yaml:"drop_all_caps"` // Default: true AllowedCaps []string `yaml:"allowed_caps"` // Capabilities to add back @@ -153,12 +164,20 @@ type SandboxConfig struct { RunAsUID int `yaml:"run_as_uid"` // Default: 1000 RunAsGID int `yaml:"run_as_gid"` // Default: 1000 - // Upload limits (NEW) + // Process isolation + MaxProcesses int `yaml:"max_processes"` // Fork bomb 
protection (default: 100) + MaxOpenFiles int `yaml:"max_open_files"` // FD exhaustion protection (default: 1024) + DisableSwap bool `yaml:"disable_swap"` // Prevent swap exfiltration + OOMScoreAdj int `yaml:"oom_score_adj"` // OOM killer priority (default: 100) + TaskUID int `yaml:"task_uid"` // Per-task UID (0 = use RunAsUID) + TaskGID int `yaml:"task_gid"` // Per-task GID (0 = use RunAsGID) + + // Upload limits MaxUploadSizeBytes int64 `yaml:"max_upload_size_bytes"` // Default: 10GB MaxUploadRateBps int64 `yaml:"max_upload_rate_bps"` // Default: 100MB/s MaxUploadsPerMinute int `yaml:"max_uploads_per_minute"` // Default: 10 - // Artifact ingestion caps (NEW) + // Artifact ingestion caps MaxArtifactFiles int `yaml:"max_artifact_files"` // Default: 10000 MaxArtifactTotalBytes int64 `yaml:"max_artifact_total_bytes"` // Default: 100GB } @@ -174,6 +193,10 @@ var SecurityDefaults = struct { UserNS bool RunAsUID int RunAsGID int + MaxProcesses int + MaxOpenFiles int + DisableSwap bool + OOMScoreAdj int MaxUploadSizeBytes int64 MaxUploadRateBps int64 MaxUploadsPerMinute int @@ -189,6 +212,10 @@ var SecurityDefaults = struct { UserNS: true, RunAsUID: 1000, RunAsGID: 1000, + MaxProcesses: 100, // Fork bomb protection + MaxOpenFiles: 1024, // FD exhaustion protection + DisableSwap: true, // Prevent swap exfiltration + OOMScoreAdj: 100, // Lower OOM priority MaxUploadSizeBytes: 10 * 1024 * 1024 * 1024, // 10GB MaxUploadRateBps: 100 * 1024 * 1024, // 100MB/s MaxUploadsPerMinute: 10, @@ -214,6 +241,12 @@ func (s *SandboxConfig) Validate() error { if s.MaxUploadsPerMinute < 0 { return fmt.Errorf("max_uploads_per_minute must be positive") } + if s.MaxArtifactFiles < 0 { + return fmt.Errorf("max_artifact_files must be positive") + } + if s.MaxArtifactTotalBytes < 0 { + return fmt.Errorf("max_artifact_total_bytes must be positive") + } return nil } @@ -281,6 +314,42 @@ func (s *SandboxConfig) ApplySecurityDefaults() { if s.MaxArtifactTotalBytes == 0 { s.MaxArtifactTotalBytes = 
SecurityDefaults.MaxArtifactTotalBytes } + + // Process isolation defaults. Zero values are treated as "unset": an explicit disable_swap: false or oom_score_adj: 0 cannot be distinguished from an omitted field and is replaced by the SecurityDefaults value. + if s.MaxProcesses == 0 { + s.MaxProcesses = SecurityDefaults.MaxProcesses + } + if s.MaxOpenFiles == 0 { + s.MaxOpenFiles = SecurityDefaults.MaxOpenFiles + } + if !s.DisableSwap { + s.DisableSwap = SecurityDefaults.DisableSwap + } + if s.OOMScoreAdj == 0 { + s.OOMScoreAdj = SecurityDefaults.OOMScoreAdj + } + // TaskUID/TaskGID default to 0 (meaning "use RunAsUID/RunAsGID") + // Negative values are normalized to 0 here, so they fall back to RunAsUID/RunAsGID as well + if s.TaskUID < 0 { + s.TaskUID = 0 + } + if s.TaskGID < 0 { + s.TaskGID = 0 + } +} + +// GetProcessIsolationFlags returns the effective UID/GID for a task +// If TaskUID/TaskGID are set (>0), use those; otherwise use RunAsUID/RunAsGID +func (s *SandboxConfig) GetProcessIsolationFlags() (uid, gid int) { + uid = s.RunAsUID + gid = s.RunAsGID + if s.TaskUID > 0 { + uid = s.TaskUID + } + if s.TaskGID > 0 { + gid = s.TaskGID + } + return uid, gid } // Getter methods for SandboxConfig interface @@ -294,6 +363,14 @@ func (s *SandboxConfig) GetSeccompProfile() string { return s.SeccompProfile } func (s *SandboxConfig) GetReadOnlyRoot() bool { return s.ReadOnlyRoot } func (s *SandboxConfig) GetNetworkMode() string { return s.NetworkMode } +// Process Isolation getter methods +func (s *SandboxConfig) GetMaxProcesses() int { return s.MaxProcesses } +func (s *SandboxConfig) GetMaxOpenFiles() int { return s.MaxOpenFiles } +func (s *SandboxConfig) GetDisableSwap() bool { return s.DisableSwap } +func (s *SandboxConfig) GetOOMScoreAdj() int { return s.OOMScoreAdj } +func (s *SandboxConfig) GetTaskUID() int { return s.TaskUID } +func (s *SandboxConfig) GetTaskGID() int { return s.TaskGID } + // LoadConfig loads worker configuration from a YAML file. 
func LoadConfig(path string) (*Config, error) { data, err := fileutil.SecureFileRead(path) @@ -864,7 +941,7 @@ func envInt(name string) (int, bool) { // logEnvOverride logs environment variable overrides to stderr for debugging func logEnvOverride(name string, value interface{}) { - fmt.Fprintf(os.Stderr, "[env] %s=%v (override active)\n", name, value) + slog.Warn("env override active", "var", name, "value", value) } // parseCPUFromConfig determines total CPU from environment or config diff --git a/internal/worker/errors/execution.go b/internal/worker/errors/execution.go index 2c0e885..aadfcda 100644 --- a/internal/worker/errors/execution.go +++ b/internal/worker/errors/execution.go @@ -10,12 +10,12 @@ import ( // It captures the task ID, execution phase, specific operation, root cause, // and additional context to make debugging easier. type ExecutionError struct { - TaskID string // The task that failed - Phase string // Current TaskState (queued, preparing, running, collecting) - Operation string // Specific operation that failed (e.g., "create_workspace", "fetch_dataset") - Cause error // The underlying error - Context map[string]string // Additional context (paths, IDs, etc.) - Timestamp time.Time // When the error occurred + Timestamp time.Time + Cause error + Context map[string]string + TaskID string + Phase string + Operation string } // Error implements the error interface with a formatted message. 
diff --git a/internal/worker/executor/container.go b/internal/worker/executor/container.go index d1761be..433e0ea 100644 --- a/internal/worker/executor/container.go +++ b/internal/worker/executor/container.go @@ -24,13 +24,13 @@ import ( // ContainerConfig holds configuration for container execution type ContainerConfig struct { + Sandbox SandboxConfig PodmanImage string ContainerResults string ContainerWorkspace string TrainScript string BasePath string AppleGPUEnabled bool - Sandbox SandboxConfig // NEW: Security configuration } // SandboxConfig interface to avoid import cycle @@ -44,6 +44,12 @@ type SandboxConfig interface { GetSeccompProfile() string GetReadOnlyRoot() bool GetNetworkMode() string + GetMaxProcesses() int + GetMaxOpenFiles() int + GetDisableSwap() bool + GetOOMScoreAdj() int + GetTaskUID() int + GetTaskGID() int } // ContainerExecutor executes jobs in containers using podman @@ -233,7 +239,7 @@ func (e *ContainerExecutor) setupVolumes(trackingEnv map[string]string, _outputD } cacheRoot := filepath.Join(e.config.BasePath, ".cache") - os.MkdirAll(cacheRoot, 0755) + os.MkdirAll(cacheRoot, 0750) volumes[cacheRoot] = "/workspace/.cache:rw" defaultEnv := map[string]string{ @@ -331,6 +337,12 @@ func (e *ContainerExecutor) runPodman( SeccompProfile: e.config.Sandbox.GetSeccompProfile(), ReadOnlyRoot: e.config.Sandbox.GetReadOnlyRoot(), NetworkMode: e.config.Sandbox.GetNetworkMode(), + MaxProcesses: e.config.Sandbox.GetMaxProcesses(), + MaxOpenFiles: e.config.Sandbox.GetMaxOpenFiles(), + DisableSwap: e.config.Sandbox.GetDisableSwap(), + OOMScoreAdj: e.config.Sandbox.GetOOMScoreAdj(), + TaskUID: e.config.Sandbox.GetTaskUID(), + TaskGID: e.config.Sandbox.GetTaskGID(), } podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, securityConfig, scriptPath, depsPath, extraArgs) diff --git a/internal/worker/executor/local.go b/internal/worker/executor/local.go index a3b0c47..6507bd8 100644 --- a/internal/worker/executor/local.go +++ 
b/internal/worker/executor/local.go @@ -34,11 +34,11 @@ func NewLocalExecutor(logger *logging.Logger, writer interfaces.ManifestWriter) // Execute runs a job locally func (e *LocalExecutor) Execute(ctx context.Context, task *queue.Task, env interfaces.ExecutionEnv) error { - // Generate and write script + // Generate and write script with crash safety (fsync) scriptContent := generateScript(task) scriptPath := filepath.Join(env.OutputDir, "run.sh") - if err := os.WriteFile(scriptPath, []byte(scriptContent), 0600); err != nil { + if err := fileutil.WriteFileSafe(scriptPath, []byte(scriptContent), 0600); err != nil { return &errtypes.TaskExecutionError{ TaskID: task.ID, JobName: task.JobName, diff --git a/internal/worker/factory.go b/internal/worker/factory.go index d9e2811..b2b4c5a 100644 --- a/internal/worker/factory.go +++ b/internal/worker/factory.go @@ -27,6 +27,7 @@ import ( func NewWorker(cfg *Config, _ string) (*Worker, error) { // Create queue backend backendCfg := queue.BackendConfig{ + Mode: cfg.Mode, Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))), RedisAddr: cfg.RedisAddr, RedisPassword: cfg.RedisPassword, @@ -35,6 +36,11 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) { FilesystemPath: cfg.Queue.FilesystemPath, FallbackToFilesystem: cfg.Queue.FallbackToFilesystem, MetricsFlushInterval: cfg.MetricsFlushInterval, + Scheduler: queue.SchedulerConfig{ + Address: cfg.Scheduler.Address, + Cert: cfg.Scheduler.Cert, + Token: cfg.Scheduler.Token, + }, } queueClient, err := queue.NewBackend(backendCfg) @@ -171,6 +177,13 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) { gpuDetectionInfo: gpuDetectionInfo, } + // In distributed mode, store the scheduler connection for heartbeats + if cfg.Mode == "distributed" { + if schedBackend, ok := queueClient.(*queue.SchedulerBackend); ok { + worker.schedulerConn = schedBackend.Conn() + } + } + // Log GPU configuration if !cfg.LocalMode { gpuType := 
strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE"))) diff --git a/internal/worker/gpu_detector.go b/internal/worker/gpu_detector.go index ad2ed6c..6e61a86 100644 --- a/internal/worker/gpu_detector.go +++ b/internal/worker/gpu_detector.go @@ -1,11 +1,18 @@ package worker import ( + "fmt" + "log/slog" "os" "path/filepath" "strings" ) +// logWarningf logs a warning message using slog +func logWarningf(format string, args ...any) { + slog.Warn(fmt.Sprintf(format, args...)) +} + // GPUType represents different GPU types type GPUType string @@ -230,6 +237,19 @@ func (f *GPUDetectorFactory) CreateDetectorWithInfo(cfg *Config) DetectionResult EnvOverrideCount: envCount, }, } + default: + // Defensive: unknown env type should not silently fall through + logWarningf("unrecognized FETCH_ML_GPU_TYPE value %q, using no GPU", envType) + return DetectionResult{ + Detector: &NoneDetector{}, + Info: GPUDetectionInfo{ + GPUType: GPUTypeNone, + ConfiguredVendor: "none", + DetectionMethod: DetectionSourceEnvBoth, + EnvOverrideType: envType, + EnvOverrideCount: envCount, + }, + } } } @@ -278,6 +298,18 @@ func (f *GPUDetectorFactory) CreateDetectorWithInfo(cfg *Config) DetectionResult EnvOverrideType: envType, }, } + default: + // Defensive: unknown env type should not silently fall through + logWarningf("unrecognized FETCH_ML_GPU_TYPE value %q, using no GPU", envType) + return DetectionResult{ + Detector: &NoneDetector{}, + Info: GPUDetectionInfo{ + GPUType: GPUTypeNone, + ConfiguredVendor: "none", + DetectionMethod: DetectionSourceEnvType, + EnvOverrideType: envType, + }, + } } } @@ -303,6 +335,14 @@ func (f *GPUDetectorFactory) detectFromConfigWithSource(cfg *Config, source Dete } } + // Check for auto-detection scenarios (GPUDevices provided or AppleGPU enabled without explicit vendor) + isAutoDetect := cfg.GPUVendorAutoDetected || + (len(cfg.GPUDevices) > 0 && cfg.GPUVendor == "") || + (cfg.AppleGPU.Enabled && cfg.GPUVendor == "") + if isAutoDetect && source == 
DetectionSourceConfig { + source = DetectionSourceAuto + } + switch GPUType(cfg.GPUVendor) { case GPUTypeApple: return DetectionResult{ @@ -355,46 +395,21 @@ func (f *GPUDetectorFactory) detectFromConfigWithSource(cfg *Config, source Dete ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected, }, } - } - - // Auto-detect based on config settings - if cfg.AppleGPU.Enabled { + default: + // SECURITY: Explicit default prevents silent misconfiguration + // Unknown GPU vendor is treated as no GPU - fail secure + // Note: Config.Validate() should catch invalid vendors before this point + logWarningf("unrecognized GPU vendor %q, using no GPU", cfg.GPUVendor) return DetectionResult{ - Detector: &AppleDetector{enabled: true}, + Detector: &NoneDetector{}, Info: GPUDetectionInfo{ - GPUType: GPUTypeApple, - ConfiguredVendor: "apple", - DetectionMethod: DetectionSourceAuto, + GPUType: GPUTypeNone, + ConfiguredVendor: "none", + DetectionMethod: source, EnvOverrideType: envType, EnvOverrideCount: envCount, ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected, }, } } - if len(cfg.GPUDevices) > 0 { - return DetectionResult{ - Detector: &NVIDIADetector{}, - Info: GPUDetectionInfo{ - GPUType: GPUTypeNVIDIA, - ConfiguredVendor: "nvidia", - DetectionMethod: DetectionSourceAuto, - EnvOverrideType: envType, - EnvOverrideCount: envCount, - ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected, - }, - } - } - - // Default to no GPU - return DetectionResult{ - Detector: &NoneDetector{}, - Info: GPUDetectionInfo{ - GPUType: GPUTypeNone, - ConfiguredVendor: "none", - DetectionMethod: source, - EnvOverrideType: envType, - EnvOverrideCount: envCount, - ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected, - }, - } } diff --git a/internal/worker/gpu_macos.go b/internal/worker/gpu_macos.go index a482953..19bfcf5 100644 --- a/internal/worker/gpu_macos.go +++ b/internal/worker/gpu_macos.go @@ -19,16 +19,15 @@ import ( // MacOSGPUInfo holds information about a macOS GPU type MacOSGPUInfo struct { - Index 
uint32 `json:"index"` - Name string `json:"name"` - ChipsetModel string `json:"chipset_model"` - VRAM_MB uint32 `json:"vram_mb"` - IsIntegrated bool `json:"is_integrated"` - IsAppleSilicon bool `json:"is_apple_silicon"` - // Real-time metrics from powermetrics (if available) + Name string `json:"name"` + ChipsetModel string `json:"chipset_model"` + Index uint32 `json:"index"` + VRAM_MB uint32 `json:"vram_mb"` UtilizationPercent uint32 `json:"utilization_percent,omitempty"` PowerMW uint32 `json:"power_mw,omitempty"` TemperatureC uint32 `json:"temperature_c,omitempty"` + IsIntegrated bool `json:"is_integrated"` + IsAppleSilicon bool `json:"is_apple_silicon"` } // PowermetricsData holds GPU metrics from powermetrics diff --git a/internal/worker/gpu_nvml_stub.go b/internal/worker/gpu_nvml_stub.go index 47779ff..d042489 100644 --- a/internal/worker/gpu_nvml_stub.go +++ b/internal/worker/gpu_nvml_stub.go @@ -7,19 +7,19 @@ import "errors" // GPUInfo provides comprehensive GPU information type GPUInfo struct { - Index uint32 + UUID string Name string - Utilization uint32 + VBIOSVersion string MemoryUsed uint64 MemoryTotal uint64 - Temperature uint32 PowerDraw uint32 + Index uint32 ClockSM uint32 ClockMemory uint32 PCIeGen uint32 PCIeWidth uint32 - UUID string - VBIOSVersion string + Temperature uint32 + Utilization uint32 } func InitNVML() error { diff --git a/internal/worker/integrity/hash.go b/internal/worker/integrity/hash.go index 02d6da0..9bda6eb 100644 --- a/internal/worker/integrity/hash.go +++ b/internal/worker/integrity/hash.go @@ -140,9 +140,9 @@ func DirOverallSHA256HexParallel(root string) (string, error) { } type result struct { - index int - hash string err error + hash string + index int } workCh := make(chan int, len(files)) diff --git a/internal/worker/interfaces/executor.go b/internal/worker/interfaces/executor.go index 7cbdc3d..e286730 100644 --- a/internal/worker/interfaces/executor.go +++ b/internal/worker/interfaces/executor.go @@ -13,9 +13,9 @@ type 
ExecutionEnv struct { JobDir string OutputDir string LogFile string - GPUDevices []string GPUEnvVar string GPUDevicesStr string + GPUDevices []string } // JobExecutor defines the contract for executing jobs @@ -26,8 +26,8 @@ type JobExecutor interface { // ExecutionResult holds the result of job execution type ExecutionResult struct { - Success bool + Error error ExitCode int Duration time.Duration - Error error + Success bool } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 945e604..382b62b 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -5,77 +5,100 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "net/http" "os" "path/filepath" "time" - "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/metrics" - "github.com/jfraeys/fetch_ml/internal/network" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/resources" + "github.com/jfraeys/fetch_ml/internal/scheduler" "github.com/jfraeys/fetch_ml/internal/worker/execution" "github.com/jfraeys/fetch_ml/internal/worker/executor" "github.com/jfraeys/fetch_ml/internal/worker/integrity" "github.com/jfraeys/fetch_ml/internal/worker/interfaces" "github.com/jfraeys/fetch_ml/internal/worker/lifecycle" + "github.com/jfraeys/fetch_ml/internal/worker/plugins" ) -// JupyterManager interface for jupyter service management -type JupyterManager interface { - StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error) - StopService(ctx context.Context, serviceID string) error - RemoveService(ctx context.Context, serviceID string, purge bool) error - RestoreWorkspace(ctx context.Context, name string) (string, error) - ListServices() []*jupyter.JupyterService - ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error) -} - -// MLServer is an alias for network.MLServer for backward compatibility. 
-type MLServer = network.MLServer - -// NewMLServer creates a new ML server connection. -func NewMLServer(cfg *Config) (*MLServer, error) { - return network.NewMLServer("", "", "", 0, "") -} - // Worker represents an ML task worker with composed dependencies. type Worker struct { - ID string - Config *Config - Logger *logging.Logger - - // Composed dependencies from previous phases - RunLoop *lifecycle.RunLoop - Runner *executor.JobRunner - Metrics *metrics.Metrics - metricsSrv *http.Server - Health *lifecycle.HealthMonitor - Resources *resources.Manager - - // GPU detection metadata for status output + Jupyter plugins.JupyterManager + QueueClient queue.Backend + Config *Config + Logger *logging.Logger + RunLoop *lifecycle.RunLoop + Runner *executor.JobRunner + Metrics *metrics.Metrics + metricsSrv *http.Server + Health *lifecycle.HealthMonitor + Resources *resources.Manager + ID string gpuDetectionInfo GPUDetectionInfo - - // Legacy fields for backward compatibility during migration - Jupyter JupyterManager - QueueClient queue.Backend // Stored for prewarming access + schedulerConn *scheduler.SchedulerConn // For distributed mode + ctx context.Context + cancel context.CancelFunc } // Start begins the worker's main processing loop. 
func (w *Worker) Start() { w.Logger.Info("worker starting", "worker_id", w.ID, - "max_concurrent", w.Config.MaxWorkers) + "max_concurrent", w.Config.MaxWorkers, + "mode", w.Config.Mode, + ) + slog.SetDefault(w.Logger.Logger) + w.ctx, w.cancel = context.WithCancel(context.Background()) w.Health.RecordHeartbeat() + + // Start heartbeat loop for distributed mode + if w.Config.Mode == "distributed" && w.schedulerConn != nil { + go w.heartbeatLoop() + } + w.RunLoop.Start() } +// heartbeatLoop sends periodic heartbeats with slot status to scheduler +func (w *Worker) heartbeatLoop() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-w.ctx.Done(): + return + case <-ticker.C: + w.Health.RecordHeartbeat() + if w.schedulerConn != nil { + slots := scheduler.SlotStatus{ + BatchTotal: w.Config.MaxWorkers, + BatchInUse: w.RunLoop.RunningCount(), + } + w.schedulerConn.Send(scheduler.Message{ + Type: scheduler.MsgHeartbeat, + Payload: mustMarshal(scheduler.HeartbeatPayload{ + WorkerID: w.ID, + Slots: slots, + }), + }) + } + } + } +} + // Stop gracefully shuts down the worker immediately. func (w *Worker) Stop() { w.Logger.Info("worker stopping", "worker_id", w.ID) + + if w.cancel != nil { + w.cancel() + } + w.RunLoop.Stop() if w.metricsSrv != nil { @@ -181,7 +204,7 @@ func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) er basePath := w.Config.BasePath if basePath == "" { - basePath = "/tmp" + basePath = os.TempDir() } dataDir := w.Config.DataDir if dataDir == "" { @@ -291,7 +314,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error { dataDir := w.Config.DataDir if dataDir == "" { - dataDir = "/tmp/data" + dataDir = filepath.Join(os.TempDir(), "data") } // Get expected checksum from metadata @@ -321,107 +344,10 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error { return nil } -// RunJupyterTask runs a Jupyter-related task. 
-// It handles start, stop, remove, restore, and list_packages actions. -func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) { - if w.Jupyter == nil { - return nil, fmt.Errorf("jupyter manager not configured") - } - - action := task.Metadata["jupyter_action"] - if action == "" { - action = "start" // Default action - } - - switch action { - case "start": - name := task.Metadata["jupyter_name"] - if name == "" { - name = task.Metadata["jupyter_workspace"] - } - if name == "" { - // Extract from jobName if format is "jupyter-" - if len(task.JobName) > 8 && task.JobName[:8] == "jupyter-" { - name = task.JobName[8:] - } - } - if name == "" { - return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata") - } - - req := &jupyter.StartRequest{Name: name} - service, err := w.Jupyter.StartService(ctx, req) - if err != nil { - return nil, fmt.Errorf("failed to start jupyter service: %w", err) - } - - output := map[string]interface{}{ - "type": "start", - "service": service, - } - return json.Marshal(output) - - case "stop": - serviceID := task.Metadata["jupyter_service_id"] - if serviceID == "" { - return nil, fmt.Errorf("missing jupyter_service_id in task metadata") - } - if err := w.Jupyter.StopService(ctx, serviceID); err != nil { - return nil, fmt.Errorf("failed to stop jupyter service: %w", err) - } - return json.Marshal(map[string]string{"type": "stop", "status": "stopped"}) - - case "remove": - serviceID := task.Metadata["jupyter_service_id"] - if serviceID == "" { - return nil, fmt.Errorf("missing jupyter_service_id in task metadata") - } - purge := task.Metadata["jupyter_purge"] == "true" - if err := w.Jupyter.RemoveService(ctx, serviceID, purge); err != nil { - return nil, fmt.Errorf("failed to remove jupyter service: %w", err) - } - return json.Marshal(map[string]string{"type": "remove", "status": "removed"}) - - case "restore": - name := task.Metadata["jupyter_name"] - if name == "" { - name = 
task.Metadata["jupyter_workspace"] - } - if name == "" { - return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata") - } - serviceID, err := w.Jupyter.RestoreWorkspace(ctx, name) - if err != nil { - return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err) - } - return json.Marshal(map[string]string{"type": "restore", "service_id": serviceID}) - - case "list_packages": - serviceName := task.Metadata["jupyter_name"] - if serviceName == "" { - // Extract from jobName if format is "jupyter-packages-" - if len(task.JobName) > 16 && task.JobName[:16] == "jupyter-packages-" { - serviceName = task.JobName[16:] - } - } - if serviceName == "" { - return nil, fmt.Errorf("missing jupyter_name in task metadata") - } - - packages, err := w.Jupyter.ListInstalledPackages(ctx, serviceName) - if err != nil { - return nil, fmt.Errorf("failed to list installed packages: %w", err) - } - - output := map[string]interface{}{ - "type": "list_packages", - "packages": packages, - } - return json.Marshal(output) - - default: - return nil, fmt.Errorf("unknown jupyter action: %s", action) - } +// GetJupyterManager returns the Jupyter manager for plugin use +// This implements the plugins.TaskRunner interface +func (w *Worker) GetJupyterManager() plugins.JupyterManager { + return w.Jupyter } // PrewarmNextOnce prewarms the next task in queue. 
@@ -445,7 +371,7 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
 
 	// Create prewarm directory
 	prewarmDir := filepath.Join(basePath, ".prewarm", "snapshots")
-	if err := os.MkdirAll(prewarmDir, 0750); err != nil {
+	if err := os.MkdirAll(prewarmDir, 0o750); err != nil {
 		return false, fmt.Errorf("failed to create prewarm directory: %w", err)
 	}
 
@@ -538,3 +464,15 @@ func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string)
 	// Run the job
 	return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv)
 }
+
+// mustMarshal JSON-encodes v for use as a scheduler message payload.
+// Encoding failures are logged rather than silently dropped; in that case
+// nil is returned so the periodic caller keeps running instead of crashing.
+func mustMarshal(v any) []byte {
+	b, err := json.Marshal(v)
+	if err != nil {
+		slog.Error("marshal scheduler payload", "type", fmt.Sprintf("%T", v), "error", err)
+		return nil
+	}
+	return b
+}