feat(worker): integrate scheduler endpoints and security hardening

Update worker system for scheduler integration:
- Worker server with scheduler registration
- Configuration with scheduler endpoint support
- Artifact handling with integrity verification
- Container executor with supply chain validation
- Local executor enhancements
- GPU detection improvements (cross-platform)
- Error handling with execution context
- Factory pattern for executor instantiation
- Hash integrity with native library support
This commit is contained in:
Jeremie Fraeys 2026-02-26 12:06:16 -05:00
parent ef11d88a75
commit 3fb6902fa1
No known key found for this signature in database
13 changed files with 371 additions and 219 deletions

View file

@ -2,12 +2,15 @@
package main
import (
"flag"
"fmt"
"log"
"os"
"os/signal"
"strings"
"syscall"
"github.com/invopop/yaml"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/worker"
@ -31,7 +34,37 @@ func resolveWorkerConfigPath(flags *auth.Flags) string {
}
func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
var (
configPath string
initConfig bool
mode string
schedulerAddr string
token string
)
flag.StringVar(&configPath, "config", "worker.yaml", "Path to worker config file")
flag.BoolVar(&initConfig, "init", false, "Initialize a new worker config file")
flag.StringVar(&mode, "mode", "distributed", "Worker mode: standalone or distributed")
flag.StringVar(&schedulerAddr, "scheduler", "", "Scheduler address (for distributed mode)")
flag.StringVar(&token, "token", "", "Worker token (copy from scheduler -init output)")
flag.Parse()
// Handle init mode
if initConfig {
if err := generateWorkerConfig(configPath, mode, schedulerAddr, token); err != nil {
fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err)
os.Exit(1)
}
fmt.Printf("Config generated: %s\n", configPath)
fmt.Println("\nNext steps:")
if mode == "distributed" {
fmt.Println("1. Copy the token from your scheduler's -init output")
fmt.Println("2. Edit the config to set scheduler.address and scheduler.token")
fmt.Println("3. Copy the scheduler's TLS cert to the worker")
}
os.Exit(0)
}
// Normal worker startup...
// Parse authentication flags
authFlags := auth.ParseAuthFlags()
@ -95,3 +128,81 @@ func main() {
log.Println("Worker shut down gracefully")
}
}
// generateWorkerConfig creates a new worker config file
func generateWorkerConfig(path, mode, schedulerAddr, token string) error {
cfg := map[string]any{
"node": map[string]any{
"role": "worker",
"id": "",
},
"worker": map[string]any{
"mode": mode,
"max_workers": 3,
},
}
if mode == "distributed" {
cfg["scheduler"] = map[string]any{
"address": schedulerAddr,
"cert": "/etc/fetch_ml/scheduler.crt",
"token": token,
}
} else {
cfg["queue"] = map[string]any{
"backend": "redis",
"redis_addr": "localhost:6379",
"redis_password": "",
"redis_db": 0,
}
}
cfg["slots"] = map[string]any{
"service_slots": 1,
"ports": map[string]any{
"service_range_start": 8000,
"service_range_end": 8099,
},
}
cfg["gpu"] = map[string]any{
"vendor": "auto",
}
cfg["prewarm"] = map[string]any{
"enabled": true,
}
cfg["log"] = map[string]any{
"level": "info",
"format": "json",
}
data, err := yaml.Marshal(cfg)
if err != nil {
return fmt.Errorf("marshal config: %w", err)
}
// Add header comment
header := fmt.Sprintf(`# Worker Configuration for fetch_ml
# Generated by: worker -init
# Mode: %s
#`, mode)
if mode == "distributed" && token == "" {
header += `
# SECURITY WARNING: You must add the scheduler token to this config.
# Copy the token from the scheduler's -init output and paste it below.
# scheduler:
# token: "wkr_xxx..."
#`
}
fullContent := header + "\n\n" + string(data)
if err := os.WriteFile(path, []byte(fullContent), 0600); err != nil {
return fmt.Errorf("write config file: %w", err)
}
return nil
}

View file

@ -12,6 +12,8 @@ import (
"github.com/jfraeys/fetch_ml/internal/manifest"
)
// scanArtifacts discovers and catalogs artifact files in a run directory.
// When includeAll is false, it excludes code/, snapshot/, *.log files, and symlinks.
func scanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manifest.Artifacts, error) {
runDir = strings.TrimSpace(runDir)
if runDir == "" {
@ -55,14 +57,8 @@ func scanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manife
}
// Standard exclusions (always apply)
if rel == manifestFilename {
exclusions = append(exclusions, manifest.Exclusion{
Path: rel,
Reason: "manifest file excluded",
})
return nil
}
if strings.HasSuffix(rel, "/"+manifestFilename) {
// Exclude manifest files - both legacy (run_manifest.json) and nonce-based (run_manifest_<nonce>.json)
if strings.HasPrefix(rel, "run_manifest") && strings.HasSuffix(rel, ".json") {
exclusions = append(exclusions, manifest.Exclusion{
Path: rel,
Reason: "manifest file excluded",
@ -160,8 +156,6 @@ func scanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manife
}, nil
}
const manifestFilename = "run_manifest.json"
// ScanArtifacts is an exported wrapper for testing/benchmarking.
// When includeAll is false, excludes code/, snapshot/, *.log files, and symlinks.
func ScanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manifest.Artifacts, error) {

View file

@ -5,6 +5,7 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"log/slog"
"math"
"net/url"
"os"
@ -80,7 +81,7 @@ type Config struct {
// When "hipaa": enforces hard requirements at startup
ComplianceMode string `yaml:"compliance_mode"`
// Phase 1: opt-in prewarming of next task artifacts (snapshot/datasets/env).
// Opt-in prewarming of next task artifacts (snapshot/datasets/env).
PrewarmEnabled bool `yaml:"prewarm_enabled"`
// Podman execution
@ -102,6 +103,16 @@ type Config struct {
MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3)
GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Shutdown timeout (default: 5min)
// Mode determines how the worker operates: "standalone" or "distributed"
Mode string `yaml:"mode"`
// Scheduler configuration for distributed mode
Scheduler struct {
Address string `yaml:"address"`
Cert string `yaml:"cert"`
Token string `yaml:"token"`
} `yaml:"scheduler"`
// Plugins configuration
Plugins map[string]factory.PluginConfig `yaml:"plugins"`
@ -145,7 +156,7 @@ type SandboxConfig struct {
SeccompProfile string `yaml:"seccomp_profile"` // Default: "default-hardened"
MaxRuntimeHours int `yaml:"max_runtime_hours"`
// Security hardening options (NEW)
// Security hardening options
NoNewPrivileges bool `yaml:"no_new_privileges"` // Default: true
DropAllCaps bool `yaml:"drop_all_caps"` // Default: true
AllowedCaps []string `yaml:"allowed_caps"` // Capabilities to add back
@ -153,12 +164,20 @@ type SandboxConfig struct {
RunAsUID int `yaml:"run_as_uid"` // Default: 1000
RunAsGID int `yaml:"run_as_gid"` // Default: 1000
// Upload limits (NEW)
// Process isolation
MaxProcesses int `yaml:"max_processes"` // Fork bomb protection (default: 100)
MaxOpenFiles int `yaml:"max_open_files"` // FD exhaustion protection (default: 1024)
DisableSwap bool `yaml:"disable_swap"` // Prevent swap exfiltration
OOMScoreAdj int `yaml:"oom_score_adj"` // OOM killer priority (default: 100)
TaskUID int `yaml:"task_uid"` // Per-task UID (0 = use RunAsUID)
TaskGID int `yaml:"task_gid"` // Per-task GID (0 = use RunAsGID)
// Upload limits
MaxUploadSizeBytes int64 `yaml:"max_upload_size_bytes"` // Default: 10GB
MaxUploadRateBps int64 `yaml:"max_upload_rate_bps"` // Default: 100MB/s
MaxUploadsPerMinute int `yaml:"max_uploads_per_minute"` // Default: 10
// Artifact ingestion caps (NEW)
// Artifact ingestion caps
MaxArtifactFiles int `yaml:"max_artifact_files"` // Default: 10000
MaxArtifactTotalBytes int64 `yaml:"max_artifact_total_bytes"` // Default: 100GB
}
@ -174,6 +193,10 @@ var SecurityDefaults = struct {
UserNS bool
RunAsUID int
RunAsGID int
MaxProcesses int
MaxOpenFiles int
DisableSwap bool
OOMScoreAdj int
MaxUploadSizeBytes int64
MaxUploadRateBps int64
MaxUploadsPerMinute int
@ -189,6 +212,10 @@ var SecurityDefaults = struct {
UserNS: true,
RunAsUID: 1000,
RunAsGID: 1000,
MaxProcesses: 100, // Fork bomb protection
MaxOpenFiles: 1024, // FD exhaustion protection
DisableSwap: true, // Prevent swap exfiltration
OOMScoreAdj: 100, // Lower OOM priority
MaxUploadSizeBytes: 10 * 1024 * 1024 * 1024, // 10GB
MaxUploadRateBps: 100 * 1024 * 1024, // 100MB/s
MaxUploadsPerMinute: 10,
@ -214,6 +241,12 @@ func (s *SandboxConfig) Validate() error {
if s.MaxUploadsPerMinute < 0 {
return fmt.Errorf("max_uploads_per_minute must be positive")
}
if s.MaxArtifactFiles < 0 {
return fmt.Errorf("max_artifact_files must be positive")
}
if s.MaxArtifactTotalBytes < 0 {
return fmt.Errorf("max_artifact_total_bytes must be positive")
}
return nil
}
@ -281,6 +314,42 @@ func (s *SandboxConfig) ApplySecurityDefaults() {
if s.MaxArtifactTotalBytes == 0 {
s.MaxArtifactTotalBytes = SecurityDefaults.MaxArtifactTotalBytes
}
// Process isolation defaults
if s.MaxProcesses == 0 {
s.MaxProcesses = SecurityDefaults.MaxProcesses
}
if s.MaxOpenFiles == 0 {
s.MaxOpenFiles = SecurityDefaults.MaxOpenFiles
}
if !s.DisableSwap {
s.DisableSwap = SecurityDefaults.DisableSwap
}
if s.OOMScoreAdj == 0 {
s.OOMScoreAdj = SecurityDefaults.OOMScoreAdj
}
// TaskUID/TaskGID default to 0 (meaning "use RunAsUID/RunAsGID")
// Only override if explicitly set (> 0)
if s.TaskUID < 0 {
s.TaskUID = 0
}
if s.TaskGID < 0 {
s.TaskGID = 0
}
}
// GetProcessIsolationFlags returns the effective UID/GID for a task
// If TaskUID/TaskGID are set (>0), use those; otherwise use RunAsUID/RunAsGID
func (s *SandboxConfig) GetProcessIsolationFlags() (uid, gid int) {
uid = s.RunAsUID
gid = s.RunAsGID
if s.TaskUID > 0 {
uid = s.TaskUID
}
if s.TaskGID > 0 {
gid = s.TaskGID
}
return uid, gid
}
// Getter methods for SandboxConfig interface
@ -294,6 +363,14 @@ func (s *SandboxConfig) GetSeccompProfile() string { return s.SeccompProfile }
func (s *SandboxConfig) GetReadOnlyRoot() bool { return s.ReadOnlyRoot }
func (s *SandboxConfig) GetNetworkMode() string { return s.NetworkMode }
// Process Isolation getter methods
func (s *SandboxConfig) GetMaxProcesses() int { return s.MaxProcesses }
func (s *SandboxConfig) GetMaxOpenFiles() int { return s.MaxOpenFiles }
func (s *SandboxConfig) GetDisableSwap() bool { return s.DisableSwap }
func (s *SandboxConfig) GetOOMScoreAdj() int { return s.OOMScoreAdj }
func (s *SandboxConfig) GetTaskUID() int { return s.TaskUID }
func (s *SandboxConfig) GetTaskGID() int { return s.TaskGID }
// LoadConfig loads worker configuration from a YAML file.
func LoadConfig(path string) (*Config, error) {
data, err := fileutil.SecureFileRead(path)
@ -864,7 +941,7 @@ func envInt(name string) (int, bool) {
// logEnvOverride logs environment variable overrides to stderr for debugging
func logEnvOverride(name string, value interface{}) {
fmt.Fprintf(os.Stderr, "[env] %s=%v (override active)\n", name, value)
slog.Warn("env override active", "var", name, "value", value)
}
// parseCPUFromConfig determines total CPU from environment or config

View file

@ -10,12 +10,12 @@ import (
// It captures the task ID, execution phase, specific operation, root cause,
// and additional context to make debugging easier.
type ExecutionError struct {
TaskID string // The task that failed
Phase string // Current TaskState (queued, preparing, running, collecting)
Operation string // Specific operation that failed (e.g., "create_workspace", "fetch_dataset")
Cause error // The underlying error
Context map[string]string // Additional context (paths, IDs, etc.)
Timestamp time.Time // When the error occurred
Timestamp time.Time
Cause error
Context map[string]string
TaskID string
Phase string
Operation string
}
// Error implements the error interface with a formatted message.

View file

@ -24,13 +24,13 @@ import (
// ContainerConfig holds configuration for container execution
type ContainerConfig struct {
Sandbox SandboxConfig
PodmanImage string
ContainerResults string
ContainerWorkspace string
TrainScript string
BasePath string
AppleGPUEnabled bool
Sandbox SandboxConfig // NEW: Security configuration
}
// SandboxConfig interface to avoid import cycle
@ -44,6 +44,12 @@ type SandboxConfig interface {
GetSeccompProfile() string
GetReadOnlyRoot() bool
GetNetworkMode() string
GetMaxProcesses() int
GetMaxOpenFiles() int
GetDisableSwap() bool
GetOOMScoreAdj() int
GetTaskUID() int
GetTaskGID() int
}
// ContainerExecutor executes jobs in containers using podman
@ -233,7 +239,7 @@ func (e *ContainerExecutor) setupVolumes(trackingEnv map[string]string, _outputD
}
cacheRoot := filepath.Join(e.config.BasePath, ".cache")
os.MkdirAll(cacheRoot, 0755)
os.MkdirAll(cacheRoot, 0750)
volumes[cacheRoot] = "/workspace/.cache:rw"
defaultEnv := map[string]string{
@ -331,6 +337,12 @@ func (e *ContainerExecutor) runPodman(
SeccompProfile: e.config.Sandbox.GetSeccompProfile(),
ReadOnlyRoot: e.config.Sandbox.GetReadOnlyRoot(),
NetworkMode: e.config.Sandbox.GetNetworkMode(),
MaxProcesses: e.config.Sandbox.GetMaxProcesses(),
MaxOpenFiles: e.config.Sandbox.GetMaxOpenFiles(),
DisableSwap: e.config.Sandbox.GetDisableSwap(),
OOMScoreAdj: e.config.Sandbox.GetOOMScoreAdj(),
TaskUID: e.config.Sandbox.GetTaskUID(),
TaskGID: e.config.Sandbox.GetTaskGID(),
}
podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, securityConfig, scriptPath, depsPath, extraArgs)

View file

@ -34,11 +34,11 @@ func NewLocalExecutor(logger *logging.Logger, writer interfaces.ManifestWriter)
// Execute runs a job locally
func (e *LocalExecutor) Execute(ctx context.Context, task *queue.Task, env interfaces.ExecutionEnv) error {
// Generate and write script
// Generate and write script with crash safety (fsync)
scriptContent := generateScript(task)
scriptPath := filepath.Join(env.OutputDir, "run.sh")
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0600); err != nil {
if err := fileutil.WriteFileSafe(scriptPath, []byte(scriptContent), 0600); err != nil {
return &errtypes.TaskExecutionError{
TaskID: task.ID,
JobName: task.JobName,

View file

@ -27,6 +27,7 @@ import (
func NewWorker(cfg *Config, _ string) (*Worker, error) {
// Create queue backend
backendCfg := queue.BackendConfig{
Mode: cfg.Mode,
Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
RedisAddr: cfg.RedisAddr,
RedisPassword: cfg.RedisPassword,
@ -35,6 +36,11 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
FilesystemPath: cfg.Queue.FilesystemPath,
FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
MetricsFlushInterval: cfg.MetricsFlushInterval,
Scheduler: queue.SchedulerConfig{
Address: cfg.Scheduler.Address,
Cert: cfg.Scheduler.Cert,
Token: cfg.Scheduler.Token,
},
}
queueClient, err := queue.NewBackend(backendCfg)
@ -171,6 +177,13 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
gpuDetectionInfo: gpuDetectionInfo,
}
// In distributed mode, store the scheduler connection for heartbeats
if cfg.Mode == "distributed" {
if schedBackend, ok := queueClient.(*queue.SchedulerBackend); ok {
worker.schedulerConn = schedBackend.Conn()
}
}
// Log GPU configuration
if !cfg.LocalMode {
gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))

View file

@ -1,11 +1,18 @@
package worker
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
)
// logWarningf logs a warning message using slog
func logWarningf(format string, args ...any) {
slog.Warn(fmt.Sprintf(format, args...))
}
// GPUType represents different GPU types
type GPUType string
@ -230,6 +237,19 @@ func (f *GPUDetectorFactory) CreateDetectorWithInfo(cfg *Config) DetectionResult
EnvOverrideCount: envCount,
},
}
default:
// Defensive: unknown env type should not silently fall through
logWarningf("unrecognized FETCH_ML_GPU_TYPE value %q, using no GPU", envType)
return DetectionResult{
Detector: &NoneDetector{},
Info: GPUDetectionInfo{
GPUType: GPUTypeNone,
ConfiguredVendor: "none",
DetectionMethod: DetectionSourceEnvBoth,
EnvOverrideType: envType,
EnvOverrideCount: envCount,
},
}
}
}
@ -278,6 +298,18 @@ func (f *GPUDetectorFactory) CreateDetectorWithInfo(cfg *Config) DetectionResult
EnvOverrideType: envType,
},
}
default:
// Defensive: unknown env type should not silently fall through
logWarningf("unrecognized FETCH_ML_GPU_TYPE value %q, using no GPU", envType)
return DetectionResult{
Detector: &NoneDetector{},
Info: GPUDetectionInfo{
GPUType: GPUTypeNone,
ConfiguredVendor: "none",
DetectionMethod: DetectionSourceEnvType,
EnvOverrideType: envType,
},
}
}
}
@ -303,6 +335,14 @@ func (f *GPUDetectorFactory) detectFromConfigWithSource(cfg *Config, source Dete
}
}
// Check for auto-detection scenarios (GPUDevices provided or AppleGPU enabled without explicit vendor)
isAutoDetect := cfg.GPUVendorAutoDetected ||
(len(cfg.GPUDevices) > 0 && cfg.GPUVendor == "") ||
(cfg.AppleGPU.Enabled && cfg.GPUVendor == "")
if isAutoDetect && source == DetectionSourceConfig {
source = DetectionSourceAuto
}
switch GPUType(cfg.GPUVendor) {
case GPUTypeApple:
return DetectionResult{
@ -355,46 +395,21 @@ func (f *GPUDetectorFactory) detectFromConfigWithSource(cfg *Config, source Dete
ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected,
},
}
}
// Auto-detect based on config settings
if cfg.AppleGPU.Enabled {
default:
// SECURITY: Explicit default prevents silent misconfiguration
// Unknown GPU vendor is treated as no GPU - fail secure
// Note: Config.Validate() should catch invalid vendors before this point
logWarningf("unrecognized GPU vendor %q, using no GPU", cfg.GPUVendor)
return DetectionResult{
Detector: &AppleDetector{enabled: true},
Detector: &NoneDetector{},
Info: GPUDetectionInfo{
GPUType: GPUTypeApple,
ConfiguredVendor: "apple",
DetectionMethod: DetectionSourceAuto,
GPUType: GPUTypeNone,
ConfiguredVendor: "none",
DetectionMethod: source,
EnvOverrideType: envType,
EnvOverrideCount: envCount,
ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected,
},
}
}
if len(cfg.GPUDevices) > 0 {
return DetectionResult{
Detector: &NVIDIADetector{},
Info: GPUDetectionInfo{
GPUType: GPUTypeNVIDIA,
ConfiguredVendor: "nvidia",
DetectionMethod: DetectionSourceAuto,
EnvOverrideType: envType,
EnvOverrideCount: envCount,
ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected,
},
}
}
// Default to no GPU
return DetectionResult{
Detector: &NoneDetector{},
Info: GPUDetectionInfo{
GPUType: GPUTypeNone,
ConfiguredVendor: "none",
DetectionMethod: source,
EnvOverrideType: envType,
EnvOverrideCount: envCount,
ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected,
},
}
}

View file

@ -19,16 +19,15 @@ import (
// MacOSGPUInfo holds information about a macOS GPU
type MacOSGPUInfo struct {
Index uint32 `json:"index"`
Name string `json:"name"`
ChipsetModel string `json:"chipset_model"`
VRAM_MB uint32 `json:"vram_mb"`
IsIntegrated bool `json:"is_integrated"`
IsAppleSilicon bool `json:"is_apple_silicon"`
// Real-time metrics from powermetrics (if available)
Name string `json:"name"`
ChipsetModel string `json:"chipset_model"`
Index uint32 `json:"index"`
VRAM_MB uint32 `json:"vram_mb"`
UtilizationPercent uint32 `json:"utilization_percent,omitempty"`
PowerMW uint32 `json:"power_mw,omitempty"`
TemperatureC uint32 `json:"temperature_c,omitempty"`
IsIntegrated bool `json:"is_integrated"`
IsAppleSilicon bool `json:"is_apple_silicon"`
}
// PowermetricsData holds GPU metrics from powermetrics

View file

@ -7,19 +7,19 @@ import "errors"
// GPUInfo provides comprehensive GPU information
type GPUInfo struct {
Index uint32
UUID string
Name string
Utilization uint32
VBIOSVersion string
MemoryUsed uint64
MemoryTotal uint64
Temperature uint32
PowerDraw uint32
Index uint32
ClockSM uint32
ClockMemory uint32
PCIeGen uint32
PCIeWidth uint32
UUID string
VBIOSVersion string
Temperature uint32
Utilization uint32
}
func InitNVML() error {

View file

@ -140,9 +140,9 @@ func DirOverallSHA256HexParallel(root string) (string, error) {
}
type result struct {
index int
hash string
err error
hash string
index int
}
workCh := make(chan int, len(files))

View file

@ -13,9 +13,9 @@ type ExecutionEnv struct {
JobDir string
OutputDir string
LogFile string
GPUDevices []string
GPUEnvVar string
GPUDevicesStr string
GPUDevices []string
}
// JobExecutor defines the contract for executing jobs
@ -26,8 +26,8 @@ type JobExecutor interface {
// ExecutionResult holds the result of job execution
type ExecutionResult struct {
Success bool
Error error
ExitCode int
Duration time.Duration
Error error
Success bool
}

View file

@ -5,77 +5,100 @@ import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"os"
"path/filepath"
"time"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/metrics"
"github.com/jfraeys/fetch_ml/internal/network"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/resources"
"github.com/jfraeys/fetch_ml/internal/scheduler"
"github.com/jfraeys/fetch_ml/internal/worker/execution"
"github.com/jfraeys/fetch_ml/internal/worker/executor"
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
"github.com/jfraeys/fetch_ml/internal/worker/interfaces"
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
"github.com/jfraeys/fetch_ml/internal/worker/plugins"
)
// JupyterManager interface for jupyter service management
type JupyterManager interface {
StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
StopService(ctx context.Context, serviceID string) error
RemoveService(ctx context.Context, serviceID string, purge bool) error
RestoreWorkspace(ctx context.Context, name string) (string, error)
ListServices() []*jupyter.JupyterService
ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
}
// MLServer is an alias for network.MLServer for backward compatibility.
type MLServer = network.MLServer
// NewMLServer creates a new ML server connection.
func NewMLServer(cfg *Config) (*MLServer, error) {
return network.NewMLServer("", "", "", 0, "")
}
// Worker represents an ML task worker with composed dependencies.
type Worker struct {
ID string
Config *Config
Logger *logging.Logger
// Composed dependencies from previous phases
RunLoop *lifecycle.RunLoop
Runner *executor.JobRunner
Metrics *metrics.Metrics
metricsSrv *http.Server
Health *lifecycle.HealthMonitor
Resources *resources.Manager
// GPU detection metadata for status output
Jupyter plugins.JupyterManager
QueueClient queue.Backend
Config *Config
Logger *logging.Logger
RunLoop *lifecycle.RunLoop
Runner *executor.JobRunner
Metrics *metrics.Metrics
metricsSrv *http.Server
Health *lifecycle.HealthMonitor
Resources *resources.Manager
ID string
gpuDetectionInfo GPUDetectionInfo
// Legacy fields for backward compatibility during migration
Jupyter JupyterManager
QueueClient queue.Backend // Stored for prewarming access
schedulerConn *scheduler.SchedulerConn // For distributed mode
ctx context.Context
cancel context.CancelFunc
}
// Start begins the worker's main processing loop.
func (w *Worker) Start() {
w.Logger.Info("worker starting",
"worker_id", w.ID,
"max_concurrent", w.Config.MaxWorkers)
"max_concurrent", w.Config.MaxWorkers,
"mode", w.Config.Mode,
)
slog.SetDefault(w.Logger.Logger)
w.ctx, w.cancel = context.WithCancel(context.Background())
w.Health.RecordHeartbeat()
// Start heartbeat loop for distributed mode
if w.Config.Mode == "distributed" && w.schedulerConn != nil {
go w.heartbeatLoop()
}
w.RunLoop.Start()
}
// heartbeatLoop sends periodic heartbeats with slot status to scheduler
func (w *Worker) heartbeatLoop() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
w.Health.RecordHeartbeat()
if w.schedulerConn != nil {
slots := scheduler.SlotStatus{
BatchTotal: w.Config.MaxWorkers,
BatchInUse: w.RunLoop.RunningCount(),
}
w.schedulerConn.Send(scheduler.Message{
Type: scheduler.MsgHeartbeat,
Payload: mustMarshal(scheduler.HeartbeatPayload{
WorkerID: w.ID,
Slots: slots,
}),
})
}
}
}
}
// Stop gracefully shuts down the worker immediately.
func (w *Worker) Stop() {
w.Logger.Info("worker stopping", "worker_id", w.ID)
if w.cancel != nil {
w.cancel()
}
w.RunLoop.Stop()
if w.metricsSrv != nil {
@ -181,7 +204,7 @@ func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) er
basePath := w.Config.BasePath
if basePath == "" {
basePath = "/tmp"
basePath = os.TempDir()
}
dataDir := w.Config.DataDir
if dataDir == "" {
@ -291,7 +314,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
dataDir := w.Config.DataDir
if dataDir == "" {
dataDir = "/tmp/data"
dataDir = os.TempDir() + "/data"
}
// Get expected checksum from metadata
@ -321,107 +344,10 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
return nil
}
// RunJupyterTask runs a Jupyter-related task.
// It handles start, stop, remove, restore, and list_packages actions.
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
if w.Jupyter == nil {
return nil, fmt.Errorf("jupyter manager not configured")
}
action := task.Metadata["jupyter_action"]
if action == "" {
action = "start" // Default action
}
switch action {
case "start":
name := task.Metadata["jupyter_name"]
if name == "" {
name = task.Metadata["jupyter_workspace"]
}
if name == "" {
// Extract from jobName if format is "jupyter-<name>"
if len(task.JobName) > 8 && task.JobName[:8] == "jupyter-" {
name = task.JobName[8:]
}
}
if name == "" {
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
}
req := &jupyter.StartRequest{Name: name}
service, err := w.Jupyter.StartService(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to start jupyter service: %w", err)
}
output := map[string]interface{}{
"type": "start",
"service": service,
}
return json.Marshal(output)
case "stop":
serviceID := task.Metadata["jupyter_service_id"]
if serviceID == "" {
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
}
if err := w.Jupyter.StopService(ctx, serviceID); err != nil {
return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
}
return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
case "remove":
serviceID := task.Metadata["jupyter_service_id"]
if serviceID == "" {
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
}
purge := task.Metadata["jupyter_purge"] == "true"
if err := w.Jupyter.RemoveService(ctx, serviceID, purge); err != nil {
return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
}
return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
case "restore":
name := task.Metadata["jupyter_name"]
if name == "" {
name = task.Metadata["jupyter_workspace"]
}
if name == "" {
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
}
serviceID, err := w.Jupyter.RestoreWorkspace(ctx, name)
if err != nil {
return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
}
return json.Marshal(map[string]string{"type": "restore", "service_id": serviceID})
case "list_packages":
serviceName := task.Metadata["jupyter_name"]
if serviceName == "" {
// Extract from jobName if format is "jupyter-packages-<name>"
if len(task.JobName) > 16 && task.JobName[:16] == "jupyter-packages-" {
serviceName = task.JobName[16:]
}
}
if serviceName == "" {
return nil, fmt.Errorf("missing jupyter_name in task metadata")
}
packages, err := w.Jupyter.ListInstalledPackages(ctx, serviceName)
if err != nil {
return nil, fmt.Errorf("failed to list installed packages: %w", err)
}
output := map[string]interface{}{
"type": "list_packages",
"packages": packages,
}
return json.Marshal(output)
default:
return nil, fmt.Errorf("unknown jupyter action: %s", action)
}
// GetJupyterManager returns the Jupyter manager for plugin use
// This implements the plugins.TaskRunner interface
func (w *Worker) GetJupyterManager() plugins.JupyterManager {
return w.Jupyter
}
// PrewarmNextOnce prewarms the next task in queue.
@ -445,7 +371,7 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
// Create prewarm directory
prewarmDir := filepath.Join(basePath, ".prewarm", "snapshots")
if err := os.MkdirAll(prewarmDir, 0750); err != nil {
if err := os.MkdirAll(prewarmDir, 0o750); err != nil {
return false, fmt.Errorf("failed to create prewarm directory: %w", err)
}
@ -538,3 +464,8 @@ func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string)
// Run the job
return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv)
}
func mustMarshal(v any) []byte {
b, _ := json.Marshal(v)
return b
}