feat(worker): integrate scheduler endpoints and security hardening
Update worker system for scheduler integration: - Worker server with scheduler registration - Configuration with scheduler endpoint support - Artifact handling with integrity verification - Container executor with supply chain validation - Local executor enhancements - GPU detection improvements (cross-platform) - Error handling with execution context - Factory pattern for executor instantiation - Hash integrity with native library support
This commit is contained in:
parent
ef11d88a75
commit
3fb6902fa1
13 changed files with 371 additions and 219 deletions
|
|
@ -2,12 +2,15 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/invopop/yaml"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
|
|
@ -31,7 +34,37 @@ func resolveWorkerConfigPath(flags *auth.Flags) string {
|
|||
}
|
||||
|
||||
func main() {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
var (
|
||||
configPath string
|
||||
initConfig bool
|
||||
mode string
|
||||
schedulerAddr string
|
||||
token string
|
||||
)
|
||||
flag.StringVar(&configPath, "config", "worker.yaml", "Path to worker config file")
|
||||
flag.BoolVar(&initConfig, "init", false, "Initialize a new worker config file")
|
||||
flag.StringVar(&mode, "mode", "distributed", "Worker mode: standalone or distributed")
|
||||
flag.StringVar(&schedulerAddr, "scheduler", "", "Scheduler address (for distributed mode)")
|
||||
flag.StringVar(&token, "token", "", "Worker token (copy from scheduler -init output)")
|
||||
flag.Parse()
|
||||
|
||||
// Handle init mode
|
||||
if initConfig {
|
||||
if err := generateWorkerConfig(configPath, mode, schedulerAddr, token); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("Config generated: %s\n", configPath)
|
||||
fmt.Println("\nNext steps:")
|
||||
if mode == "distributed" {
|
||||
fmt.Println("1. Copy the token from your scheduler's -init output")
|
||||
fmt.Println("2. Edit the config to set scheduler.address and scheduler.token")
|
||||
fmt.Println("3. Copy the scheduler's TLS cert to the worker")
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Normal worker startup...
|
||||
|
||||
// Parse authentication flags
|
||||
authFlags := auth.ParseAuthFlags()
|
||||
|
|
@ -95,3 +128,81 @@ func main() {
|
|||
log.Println("Worker shut down gracefully")
|
||||
}
|
||||
}
|
||||
|
||||
// generateWorkerConfig creates a new worker config file
|
||||
func generateWorkerConfig(path, mode, schedulerAddr, token string) error {
|
||||
cfg := map[string]any{
|
||||
"node": map[string]any{
|
||||
"role": "worker",
|
||||
"id": "",
|
||||
},
|
||||
"worker": map[string]any{
|
||||
"mode": mode,
|
||||
"max_workers": 3,
|
||||
},
|
||||
}
|
||||
|
||||
if mode == "distributed" {
|
||||
cfg["scheduler"] = map[string]any{
|
||||
"address": schedulerAddr,
|
||||
"cert": "/etc/fetch_ml/scheduler.crt",
|
||||
"token": token,
|
||||
}
|
||||
} else {
|
||||
cfg["queue"] = map[string]any{
|
||||
"backend": "redis",
|
||||
"redis_addr": "localhost:6379",
|
||||
"redis_password": "",
|
||||
"redis_db": 0,
|
||||
}
|
||||
}
|
||||
|
||||
cfg["slots"] = map[string]any{
|
||||
"service_slots": 1,
|
||||
"ports": map[string]any{
|
||||
"service_range_start": 8000,
|
||||
"service_range_end": 8099,
|
||||
},
|
||||
}
|
||||
|
||||
cfg["gpu"] = map[string]any{
|
||||
"vendor": "auto",
|
||||
}
|
||||
|
||||
cfg["prewarm"] = map[string]any{
|
||||
"enabled": true,
|
||||
}
|
||||
|
||||
cfg["log"] = map[string]any{
|
||||
"level": "info",
|
||||
"format": "json",
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Add header comment
|
||||
header := fmt.Sprintf(`# Worker Configuration for fetch_ml
|
||||
# Generated by: worker -init
|
||||
# Mode: %s
|
||||
#`, mode)
|
||||
|
||||
if mode == "distributed" && token == "" {
|
||||
header += `
|
||||
# ⚠️ SECURITY WARNING: You must add the scheduler token to this config.
|
||||
# Copy the token from the scheduler's -init output and paste it below.
|
||||
# scheduler:
|
||||
# token: "wkr_xxx..."
|
||||
#`
|
||||
}
|
||||
|
||||
fullContent := header + "\n\n" + string(data)
|
||||
|
||||
if err := os.WriteFile(path, []byte(fullContent), 0600); err != nil {
|
||||
return fmt.Errorf("write config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
)
|
||||
|
||||
// scanArtifacts discovers and catalogs artifact files in a run directory.
|
||||
// When includeAll is false, it excludes code/, snapshot/, *.log files, and symlinks.
|
||||
func scanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manifest.Artifacts, error) {
|
||||
runDir = strings.TrimSpace(runDir)
|
||||
if runDir == "" {
|
||||
|
|
@ -55,14 +57,8 @@ func scanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manife
|
|||
}
|
||||
|
||||
// Standard exclusions (always apply)
|
||||
if rel == manifestFilename {
|
||||
exclusions = append(exclusions, manifest.Exclusion{
|
||||
Path: rel,
|
||||
Reason: "manifest file excluded",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
if strings.HasSuffix(rel, "/"+manifestFilename) {
|
||||
// Exclude manifest files - both legacy (run_manifest.json) and nonce-based (run_manifest_<nonce>.json)
|
||||
if strings.HasPrefix(rel, "run_manifest") && strings.HasSuffix(rel, ".json") {
|
||||
exclusions = append(exclusions, manifest.Exclusion{
|
||||
Path: rel,
|
||||
Reason: "manifest file excluded",
|
||||
|
|
@ -160,8 +156,6 @@ func scanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manife
|
|||
}, nil
|
||||
}
|
||||
|
||||
const manifestFilename = "run_manifest.json"
|
||||
|
||||
// ScanArtifacts is an exported wrapper for testing/benchmarking.
|
||||
// When includeAll is false, excludes code/, snapshot/, *.log files, and symlinks.
|
||||
func ScanArtifacts(runDir string, includeAll bool, caps *SandboxConfig) (*manifest.Artifacts, error) {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/url"
|
||||
"os"
|
||||
|
|
@ -80,7 +81,7 @@ type Config struct {
|
|||
// When "hipaa": enforces hard requirements at startup
|
||||
ComplianceMode string `yaml:"compliance_mode"`
|
||||
|
||||
// Phase 1: opt-in prewarming of next task artifacts (snapshot/datasets/env).
|
||||
// Opt-in prewarming of next task artifacts (snapshot/datasets/env).
|
||||
PrewarmEnabled bool `yaml:"prewarm_enabled"`
|
||||
|
||||
// Podman execution
|
||||
|
|
@ -102,6 +103,16 @@ type Config struct {
|
|||
MaxRetries int `yaml:"max_retries"` // Maximum retry attempts (default: 3)
|
||||
GracefulTimeout time.Duration `yaml:"graceful_timeout"` // Shutdown timeout (default: 5min)
|
||||
|
||||
// Mode determines how the worker operates: "standalone" or "distributed"
|
||||
Mode string `yaml:"mode"`
|
||||
|
||||
// Scheduler configuration for distributed mode
|
||||
Scheduler struct {
|
||||
Address string `yaml:"address"`
|
||||
Cert string `yaml:"cert"`
|
||||
Token string `yaml:"token"`
|
||||
} `yaml:"scheduler"`
|
||||
|
||||
// Plugins configuration
|
||||
Plugins map[string]factory.PluginConfig `yaml:"plugins"`
|
||||
|
||||
|
|
@ -145,7 +156,7 @@ type SandboxConfig struct {
|
|||
SeccompProfile string `yaml:"seccomp_profile"` // Default: "default-hardened"
|
||||
MaxRuntimeHours int `yaml:"max_runtime_hours"`
|
||||
|
||||
// Security hardening options (NEW)
|
||||
// Security hardening options
|
||||
NoNewPrivileges bool `yaml:"no_new_privileges"` // Default: true
|
||||
DropAllCaps bool `yaml:"drop_all_caps"` // Default: true
|
||||
AllowedCaps []string `yaml:"allowed_caps"` // Capabilities to add back
|
||||
|
|
@ -153,12 +164,20 @@ type SandboxConfig struct {
|
|||
RunAsUID int `yaml:"run_as_uid"` // Default: 1000
|
||||
RunAsGID int `yaml:"run_as_gid"` // Default: 1000
|
||||
|
||||
// Upload limits (NEW)
|
||||
// Process isolation
|
||||
MaxProcesses int `yaml:"max_processes"` // Fork bomb protection (default: 100)
|
||||
MaxOpenFiles int `yaml:"max_open_files"` // FD exhaustion protection (default: 1024)
|
||||
DisableSwap bool `yaml:"disable_swap"` // Prevent swap exfiltration
|
||||
OOMScoreAdj int `yaml:"oom_score_adj"` // OOM killer priority (default: 100)
|
||||
TaskUID int `yaml:"task_uid"` // Per-task UID (0 = use RunAsUID)
|
||||
TaskGID int `yaml:"task_gid"` // Per-task GID (0 = use RunAsGID)
|
||||
|
||||
// Upload limits
|
||||
MaxUploadSizeBytes int64 `yaml:"max_upload_size_bytes"` // Default: 10GB
|
||||
MaxUploadRateBps int64 `yaml:"max_upload_rate_bps"` // Default: 100MB/s
|
||||
MaxUploadsPerMinute int `yaml:"max_uploads_per_minute"` // Default: 10
|
||||
|
||||
// Artifact ingestion caps (NEW)
|
||||
// Artifact ingestion caps
|
||||
MaxArtifactFiles int `yaml:"max_artifact_files"` // Default: 10000
|
||||
MaxArtifactTotalBytes int64 `yaml:"max_artifact_total_bytes"` // Default: 100GB
|
||||
}
|
||||
|
|
@ -174,6 +193,10 @@ var SecurityDefaults = struct {
|
|||
UserNS bool
|
||||
RunAsUID int
|
||||
RunAsGID int
|
||||
MaxProcesses int
|
||||
MaxOpenFiles int
|
||||
DisableSwap bool
|
||||
OOMScoreAdj int
|
||||
MaxUploadSizeBytes int64
|
||||
MaxUploadRateBps int64
|
||||
MaxUploadsPerMinute int
|
||||
|
|
@ -189,6 +212,10 @@ var SecurityDefaults = struct {
|
|||
UserNS: true,
|
||||
RunAsUID: 1000,
|
||||
RunAsGID: 1000,
|
||||
MaxProcesses: 100, // Fork bomb protection
|
||||
MaxOpenFiles: 1024, // FD exhaustion protection
|
||||
DisableSwap: true, // Prevent swap exfiltration
|
||||
OOMScoreAdj: 100, // Lower OOM priority
|
||||
MaxUploadSizeBytes: 10 * 1024 * 1024 * 1024, // 10GB
|
||||
MaxUploadRateBps: 100 * 1024 * 1024, // 100MB/s
|
||||
MaxUploadsPerMinute: 10,
|
||||
|
|
@ -214,6 +241,12 @@ func (s *SandboxConfig) Validate() error {
|
|||
if s.MaxUploadsPerMinute < 0 {
|
||||
return fmt.Errorf("max_uploads_per_minute must be positive")
|
||||
}
|
||||
if s.MaxArtifactFiles < 0 {
|
||||
return fmt.Errorf("max_artifact_files must be positive")
|
||||
}
|
||||
if s.MaxArtifactTotalBytes < 0 {
|
||||
return fmt.Errorf("max_artifact_total_bytes must be positive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -281,6 +314,42 @@ func (s *SandboxConfig) ApplySecurityDefaults() {
|
|||
if s.MaxArtifactTotalBytes == 0 {
|
||||
s.MaxArtifactTotalBytes = SecurityDefaults.MaxArtifactTotalBytes
|
||||
}
|
||||
|
||||
// Process isolation defaults
|
||||
if s.MaxProcesses == 0 {
|
||||
s.MaxProcesses = SecurityDefaults.MaxProcesses
|
||||
}
|
||||
if s.MaxOpenFiles == 0 {
|
||||
s.MaxOpenFiles = SecurityDefaults.MaxOpenFiles
|
||||
}
|
||||
if !s.DisableSwap {
|
||||
s.DisableSwap = SecurityDefaults.DisableSwap
|
||||
}
|
||||
if s.OOMScoreAdj == 0 {
|
||||
s.OOMScoreAdj = SecurityDefaults.OOMScoreAdj
|
||||
}
|
||||
// TaskUID/TaskGID default to 0 (meaning "use RunAsUID/RunAsGID")
|
||||
// Only override if explicitly set (> 0)
|
||||
if s.TaskUID < 0 {
|
||||
s.TaskUID = 0
|
||||
}
|
||||
if s.TaskGID < 0 {
|
||||
s.TaskGID = 0
|
||||
}
|
||||
}
|
||||
|
||||
// GetProcessIsolationFlags returns the effective UID/GID for a task
|
||||
// If TaskUID/TaskGID are set (>0), use those; otherwise use RunAsUID/RunAsGID
|
||||
func (s *SandboxConfig) GetProcessIsolationFlags() (uid, gid int) {
|
||||
uid = s.RunAsUID
|
||||
gid = s.RunAsGID
|
||||
if s.TaskUID > 0 {
|
||||
uid = s.TaskUID
|
||||
}
|
||||
if s.TaskGID > 0 {
|
||||
gid = s.TaskGID
|
||||
}
|
||||
return uid, gid
|
||||
}
|
||||
|
||||
// Getter methods for SandboxConfig interface
|
||||
|
|
@ -294,6 +363,14 @@ func (s *SandboxConfig) GetSeccompProfile() string { return s.SeccompProfile }
|
|||
func (s *SandboxConfig) GetReadOnlyRoot() bool { return s.ReadOnlyRoot }
|
||||
func (s *SandboxConfig) GetNetworkMode() string { return s.NetworkMode }
|
||||
|
||||
// Process Isolation getter methods
|
||||
func (s *SandboxConfig) GetMaxProcesses() int { return s.MaxProcesses }
|
||||
func (s *SandboxConfig) GetMaxOpenFiles() int { return s.MaxOpenFiles }
|
||||
func (s *SandboxConfig) GetDisableSwap() bool { return s.DisableSwap }
|
||||
func (s *SandboxConfig) GetOOMScoreAdj() int { return s.OOMScoreAdj }
|
||||
func (s *SandboxConfig) GetTaskUID() int { return s.TaskUID }
|
||||
func (s *SandboxConfig) GetTaskGID() int { return s.TaskGID }
|
||||
|
||||
// LoadConfig loads worker configuration from a YAML file.
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
data, err := fileutil.SecureFileRead(path)
|
||||
|
|
@ -864,7 +941,7 @@ func envInt(name string) (int, bool) {
|
|||
|
||||
// logEnvOverride logs environment variable overrides to stderr for debugging
|
||||
func logEnvOverride(name string, value interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "[env] %s=%v (override active)\n", name, value)
|
||||
slog.Warn("env override active", "var", name, "value", value)
|
||||
}
|
||||
|
||||
// parseCPUFromConfig determines total CPU from environment or config
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ import (
|
|||
// It captures the task ID, execution phase, specific operation, root cause,
|
||||
// and additional context to make debugging easier.
|
||||
type ExecutionError struct {
|
||||
TaskID string // The task that failed
|
||||
Phase string // Current TaskState (queued, preparing, running, collecting)
|
||||
Operation string // Specific operation that failed (e.g., "create_workspace", "fetch_dataset")
|
||||
Cause error // The underlying error
|
||||
Context map[string]string // Additional context (paths, IDs, etc.)
|
||||
Timestamp time.Time // When the error occurred
|
||||
Timestamp time.Time
|
||||
Cause error
|
||||
Context map[string]string
|
||||
TaskID string
|
||||
Phase string
|
||||
Operation string
|
||||
}
|
||||
|
||||
// Error implements the error interface with a formatted message.
|
||||
|
|
|
|||
|
|
@ -24,13 +24,13 @@ import (
|
|||
|
||||
// ContainerConfig holds configuration for container execution
|
||||
type ContainerConfig struct {
|
||||
Sandbox SandboxConfig
|
||||
PodmanImage string
|
||||
ContainerResults string
|
||||
ContainerWorkspace string
|
||||
TrainScript string
|
||||
BasePath string
|
||||
AppleGPUEnabled bool
|
||||
Sandbox SandboxConfig // NEW: Security configuration
|
||||
}
|
||||
|
||||
// SandboxConfig interface to avoid import cycle
|
||||
|
|
@ -44,6 +44,12 @@ type SandboxConfig interface {
|
|||
GetSeccompProfile() string
|
||||
GetReadOnlyRoot() bool
|
||||
GetNetworkMode() string
|
||||
GetMaxProcesses() int
|
||||
GetMaxOpenFiles() int
|
||||
GetDisableSwap() bool
|
||||
GetOOMScoreAdj() int
|
||||
GetTaskUID() int
|
||||
GetTaskGID() int
|
||||
}
|
||||
|
||||
// ContainerExecutor executes jobs in containers using podman
|
||||
|
|
@ -233,7 +239,7 @@ func (e *ContainerExecutor) setupVolumes(trackingEnv map[string]string, _outputD
|
|||
}
|
||||
|
||||
cacheRoot := filepath.Join(e.config.BasePath, ".cache")
|
||||
os.MkdirAll(cacheRoot, 0755)
|
||||
os.MkdirAll(cacheRoot, 0750)
|
||||
volumes[cacheRoot] = "/workspace/.cache:rw"
|
||||
|
||||
defaultEnv := map[string]string{
|
||||
|
|
@ -331,6 +337,12 @@ func (e *ContainerExecutor) runPodman(
|
|||
SeccompProfile: e.config.Sandbox.GetSeccompProfile(),
|
||||
ReadOnlyRoot: e.config.Sandbox.GetReadOnlyRoot(),
|
||||
NetworkMode: e.config.Sandbox.GetNetworkMode(),
|
||||
MaxProcesses: e.config.Sandbox.GetMaxProcesses(),
|
||||
MaxOpenFiles: e.config.Sandbox.GetMaxOpenFiles(),
|
||||
DisableSwap: e.config.Sandbox.GetDisableSwap(),
|
||||
OOMScoreAdj: e.config.Sandbox.GetOOMScoreAdj(),
|
||||
TaskUID: e.config.Sandbox.GetTaskUID(),
|
||||
TaskGID: e.config.Sandbox.GetTaskGID(),
|
||||
}
|
||||
|
||||
podmanCmd := container.BuildPodmanCommand(ctx, podmanCfg, securityConfig, scriptPath, depsPath, extraArgs)
|
||||
|
|
|
|||
|
|
@ -34,11 +34,11 @@ func NewLocalExecutor(logger *logging.Logger, writer interfaces.ManifestWriter)
|
|||
|
||||
// Execute runs a job locally
|
||||
func (e *LocalExecutor) Execute(ctx context.Context, task *queue.Task, env interfaces.ExecutionEnv) error {
|
||||
// Generate and write script
|
||||
// Generate and write script with crash safety (fsync)
|
||||
scriptContent := generateScript(task)
|
||||
scriptPath := filepath.Join(env.OutputDir, "run.sh")
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0600); err != nil {
|
||||
if err := fileutil.WriteFileSafe(scriptPath, []byte(scriptContent), 0600); err != nil {
|
||||
return &errtypes.TaskExecutionError{
|
||||
TaskID: task.ID,
|
||||
JobName: task.JobName,
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ import (
|
|||
func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
||||
// Create queue backend
|
||||
backendCfg := queue.BackendConfig{
|
||||
Mode: cfg.Mode,
|
||||
Backend: queue.QueueBackend(strings.ToLower(strings.TrimSpace(cfg.Queue.Backend))),
|
||||
RedisAddr: cfg.RedisAddr,
|
||||
RedisPassword: cfg.RedisPassword,
|
||||
|
|
@ -35,6 +36,11 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
FilesystemPath: cfg.Queue.FilesystemPath,
|
||||
FallbackToFilesystem: cfg.Queue.FallbackToFilesystem,
|
||||
MetricsFlushInterval: cfg.MetricsFlushInterval,
|
||||
Scheduler: queue.SchedulerConfig{
|
||||
Address: cfg.Scheduler.Address,
|
||||
Cert: cfg.Scheduler.Cert,
|
||||
Token: cfg.Scheduler.Token,
|
||||
},
|
||||
}
|
||||
|
||||
queueClient, err := queue.NewBackend(backendCfg)
|
||||
|
|
@ -171,6 +177,13 @@ func NewWorker(cfg *Config, _ string) (*Worker, error) {
|
|||
gpuDetectionInfo: gpuDetectionInfo,
|
||||
}
|
||||
|
||||
// In distributed mode, store the scheduler connection for heartbeats
|
||||
if cfg.Mode == "distributed" {
|
||||
if schedBackend, ok := queueClient.(*queue.SchedulerBackend); ok {
|
||||
worker.schedulerConn = schedBackend.Conn()
|
||||
}
|
||||
}
|
||||
|
||||
// Log GPU configuration
|
||||
if !cfg.LocalMode {
|
||||
gpuType := strings.ToLower(strings.TrimSpace(os.Getenv("FETCH_ML_GPU_TYPE")))
|
||||
|
|
|
|||
|
|
@ -1,11 +1,18 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// logWarningf logs a warning message using slog
|
||||
func logWarningf(format string, args ...any) {
|
||||
slog.Warn(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
// GPUType represents different GPU types
|
||||
type GPUType string
|
||||
|
||||
|
|
@ -230,6 +237,19 @@ func (f *GPUDetectorFactory) CreateDetectorWithInfo(cfg *Config) DetectionResult
|
|||
EnvOverrideCount: envCount,
|
||||
},
|
||||
}
|
||||
default:
|
||||
// Defensive: unknown env type should not silently fall through
|
||||
logWarningf("unrecognized FETCH_ML_GPU_TYPE value %q, using no GPU", envType)
|
||||
return DetectionResult{
|
||||
Detector: &NoneDetector{},
|
||||
Info: GPUDetectionInfo{
|
||||
GPUType: GPUTypeNone,
|
||||
ConfiguredVendor: "none",
|
||||
DetectionMethod: DetectionSourceEnvBoth,
|
||||
EnvOverrideType: envType,
|
||||
EnvOverrideCount: envCount,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -278,6 +298,18 @@ func (f *GPUDetectorFactory) CreateDetectorWithInfo(cfg *Config) DetectionResult
|
|||
EnvOverrideType: envType,
|
||||
},
|
||||
}
|
||||
default:
|
||||
// Defensive: unknown env type should not silently fall through
|
||||
logWarningf("unrecognized FETCH_ML_GPU_TYPE value %q, using no GPU", envType)
|
||||
return DetectionResult{
|
||||
Detector: &NoneDetector{},
|
||||
Info: GPUDetectionInfo{
|
||||
GPUType: GPUTypeNone,
|
||||
ConfiguredVendor: "none",
|
||||
DetectionMethod: DetectionSourceEnvType,
|
||||
EnvOverrideType: envType,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -303,6 +335,14 @@ func (f *GPUDetectorFactory) detectFromConfigWithSource(cfg *Config, source Dete
|
|||
}
|
||||
}
|
||||
|
||||
// Check for auto-detection scenarios (GPUDevices provided or AppleGPU enabled without explicit vendor)
|
||||
isAutoDetect := cfg.GPUVendorAutoDetected ||
|
||||
(len(cfg.GPUDevices) > 0 && cfg.GPUVendor == "") ||
|
||||
(cfg.AppleGPU.Enabled && cfg.GPUVendor == "")
|
||||
if isAutoDetect && source == DetectionSourceConfig {
|
||||
source = DetectionSourceAuto
|
||||
}
|
||||
|
||||
switch GPUType(cfg.GPUVendor) {
|
||||
case GPUTypeApple:
|
||||
return DetectionResult{
|
||||
|
|
@ -355,46 +395,21 @@ func (f *GPUDetectorFactory) detectFromConfigWithSource(cfg *Config, source Dete
|
|||
ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-detect based on config settings
|
||||
if cfg.AppleGPU.Enabled {
|
||||
default:
|
||||
// SECURITY: Explicit default prevents silent misconfiguration
|
||||
// Unknown GPU vendor is treated as no GPU - fail secure
|
||||
// Note: Config.Validate() should catch invalid vendors before this point
|
||||
logWarningf("unrecognized GPU vendor %q, using no GPU", cfg.GPUVendor)
|
||||
return DetectionResult{
|
||||
Detector: &AppleDetector{enabled: true},
|
||||
Detector: &NoneDetector{},
|
||||
Info: GPUDetectionInfo{
|
||||
GPUType: GPUTypeApple,
|
||||
ConfiguredVendor: "apple",
|
||||
DetectionMethod: DetectionSourceAuto,
|
||||
GPUType: GPUTypeNone,
|
||||
ConfiguredVendor: "none",
|
||||
DetectionMethod: source,
|
||||
EnvOverrideType: envType,
|
||||
EnvOverrideCount: envCount,
|
||||
ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected,
|
||||
},
|
||||
}
|
||||
}
|
||||
if len(cfg.GPUDevices) > 0 {
|
||||
return DetectionResult{
|
||||
Detector: &NVIDIADetector{},
|
||||
Info: GPUDetectionInfo{
|
||||
GPUType: GPUTypeNVIDIA,
|
||||
ConfiguredVendor: "nvidia",
|
||||
DetectionMethod: DetectionSourceAuto,
|
||||
EnvOverrideType: envType,
|
||||
EnvOverrideCount: envCount,
|
||||
ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Default to no GPU
|
||||
return DetectionResult{
|
||||
Detector: &NoneDetector{},
|
||||
Info: GPUDetectionInfo{
|
||||
GPUType: GPUTypeNone,
|
||||
ConfiguredVendor: "none",
|
||||
DetectionMethod: source,
|
||||
EnvOverrideType: envType,
|
||||
EnvOverrideCount: envCount,
|
||||
ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,16 +19,15 @@ import (
|
|||
|
||||
// MacOSGPUInfo holds information about a macOS GPU
|
||||
type MacOSGPUInfo struct {
|
||||
Index uint32 `json:"index"`
|
||||
Name string `json:"name"`
|
||||
ChipsetModel string `json:"chipset_model"`
|
||||
VRAM_MB uint32 `json:"vram_mb"`
|
||||
IsIntegrated bool `json:"is_integrated"`
|
||||
IsAppleSilicon bool `json:"is_apple_silicon"`
|
||||
// Real-time metrics from powermetrics (if available)
|
||||
Name string `json:"name"`
|
||||
ChipsetModel string `json:"chipset_model"`
|
||||
Index uint32 `json:"index"`
|
||||
VRAM_MB uint32 `json:"vram_mb"`
|
||||
UtilizationPercent uint32 `json:"utilization_percent,omitempty"`
|
||||
PowerMW uint32 `json:"power_mw,omitempty"`
|
||||
TemperatureC uint32 `json:"temperature_c,omitempty"`
|
||||
IsIntegrated bool `json:"is_integrated"`
|
||||
IsAppleSilicon bool `json:"is_apple_silicon"`
|
||||
}
|
||||
|
||||
// PowermetricsData holds GPU metrics from powermetrics
|
||||
|
|
|
|||
|
|
@ -7,19 +7,19 @@ import "errors"
|
|||
|
||||
// GPUInfo provides comprehensive GPU information
|
||||
type GPUInfo struct {
|
||||
Index uint32
|
||||
UUID string
|
||||
Name string
|
||||
Utilization uint32
|
||||
VBIOSVersion string
|
||||
MemoryUsed uint64
|
||||
MemoryTotal uint64
|
||||
Temperature uint32
|
||||
PowerDraw uint32
|
||||
Index uint32
|
||||
ClockSM uint32
|
||||
ClockMemory uint32
|
||||
PCIeGen uint32
|
||||
PCIeWidth uint32
|
||||
UUID string
|
||||
VBIOSVersion string
|
||||
Temperature uint32
|
||||
Utilization uint32
|
||||
}
|
||||
|
||||
func InitNVML() error {
|
||||
|
|
|
|||
|
|
@ -140,9 +140,9 @@ func DirOverallSHA256HexParallel(root string) (string, error) {
|
|||
}
|
||||
|
||||
type result struct {
|
||||
index int
|
||||
hash string
|
||||
err error
|
||||
hash string
|
||||
index int
|
||||
}
|
||||
|
||||
workCh := make(chan int, len(files))
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ type ExecutionEnv struct {
|
|||
JobDir string
|
||||
OutputDir string
|
||||
LogFile string
|
||||
GPUDevices []string
|
||||
GPUEnvVar string
|
||||
GPUDevicesStr string
|
||||
GPUDevices []string
|
||||
}
|
||||
|
||||
// JobExecutor defines the contract for executing jobs
|
||||
|
|
@ -26,8 +26,8 @@ type JobExecutor interface {
|
|||
|
||||
// ExecutionResult holds the result of job execution
|
||||
type ExecutionResult struct {
|
||||
Success bool
|
||||
Error error
|
||||
ExitCode int
|
||||
Duration time.Duration
|
||||
Error error
|
||||
Success bool
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,77 +5,100 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/metrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/resources"
|
||||
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/execution"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/executor"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/integrity"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/interfaces"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/lifecycle"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker/plugins"
|
||||
)
|
||||
|
||||
// JupyterManager interface for jupyter service management
|
||||
type JupyterManager interface {
|
||||
StartService(ctx context.Context, req *jupyter.StartRequest) (*jupyter.JupyterService, error)
|
||||
StopService(ctx context.Context, serviceID string) error
|
||||
RemoveService(ctx context.Context, serviceID string, purge bool) error
|
||||
RestoreWorkspace(ctx context.Context, name string) (string, error)
|
||||
ListServices() []*jupyter.JupyterService
|
||||
ListInstalledPackages(ctx context.Context, serviceName string) ([]jupyter.InstalledPackage, error)
|
||||
}
|
||||
|
||||
// MLServer is an alias for network.MLServer for backward compatibility.
|
||||
type MLServer = network.MLServer
|
||||
|
||||
// NewMLServer creates a new ML server connection.
|
||||
func NewMLServer(cfg *Config) (*MLServer, error) {
|
||||
return network.NewMLServer("", "", "", 0, "")
|
||||
}
|
||||
|
||||
// Worker represents an ML task worker with composed dependencies.
|
||||
type Worker struct {
|
||||
ID string
|
||||
Config *Config
|
||||
Logger *logging.Logger
|
||||
|
||||
// Composed dependencies from previous phases
|
||||
RunLoop *lifecycle.RunLoop
|
||||
Runner *executor.JobRunner
|
||||
Metrics *metrics.Metrics
|
||||
metricsSrv *http.Server
|
||||
Health *lifecycle.HealthMonitor
|
||||
Resources *resources.Manager
|
||||
|
||||
// GPU detection metadata for status output
|
||||
Jupyter plugins.JupyterManager
|
||||
QueueClient queue.Backend
|
||||
Config *Config
|
||||
Logger *logging.Logger
|
||||
RunLoop *lifecycle.RunLoop
|
||||
Runner *executor.JobRunner
|
||||
Metrics *metrics.Metrics
|
||||
metricsSrv *http.Server
|
||||
Health *lifecycle.HealthMonitor
|
||||
Resources *resources.Manager
|
||||
ID string
|
||||
gpuDetectionInfo GPUDetectionInfo
|
||||
|
||||
// Legacy fields for backward compatibility during migration
|
||||
Jupyter JupyterManager
|
||||
QueueClient queue.Backend // Stored for prewarming access
|
||||
schedulerConn *scheduler.SchedulerConn // For distributed mode
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Start begins the worker's main processing loop.
|
||||
func (w *Worker) Start() {
|
||||
w.Logger.Info("worker starting",
|
||||
"worker_id", w.ID,
|
||||
"max_concurrent", w.Config.MaxWorkers)
|
||||
"max_concurrent", w.Config.MaxWorkers,
|
||||
"mode", w.Config.Mode,
|
||||
)
|
||||
slog.SetDefault(w.Logger.Logger)
|
||||
|
||||
w.ctx, w.cancel = context.WithCancel(context.Background())
|
||||
w.Health.RecordHeartbeat()
|
||||
|
||||
// Start heartbeat loop for distributed mode
|
||||
if w.Config.Mode == "distributed" && w.schedulerConn != nil {
|
||||
go w.heartbeatLoop()
|
||||
}
|
||||
|
||||
w.RunLoop.Start()
|
||||
}
|
||||
|
||||
// heartbeatLoop sends periodic heartbeats with slot status to scheduler
|
||||
func (w *Worker) heartbeatLoop() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.Health.RecordHeartbeat()
|
||||
if w.schedulerConn != nil {
|
||||
slots := scheduler.SlotStatus{
|
||||
BatchTotal: w.Config.MaxWorkers,
|
||||
BatchInUse: w.RunLoop.RunningCount(),
|
||||
}
|
||||
w.schedulerConn.Send(scheduler.Message{
|
||||
Type: scheduler.MsgHeartbeat,
|
||||
Payload: mustMarshal(scheduler.HeartbeatPayload{
|
||||
WorkerID: w.ID,
|
||||
Slots: slots,
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the worker immediately.
|
||||
func (w *Worker) Stop() {
|
||||
w.Logger.Info("worker stopping", "worker_id", w.ID)
|
||||
|
||||
if w.cancel != nil {
|
||||
w.cancel()
|
||||
}
|
||||
|
||||
w.RunLoop.Stop()
|
||||
|
||||
if w.metricsSrv != nil {
|
||||
|
|
@ -181,7 +204,7 @@ func (w *Worker) EnforceTaskProvenance(ctx context.Context, task *queue.Task) er
|
|||
|
||||
basePath := w.Config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/tmp"
|
||||
basePath = os.TempDir()
|
||||
}
|
||||
dataDir := w.Config.DataDir
|
||||
if dataDir == "" {
|
||||
|
|
@ -291,7 +314,7 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
|
|||
|
||||
dataDir := w.Config.DataDir
|
||||
if dataDir == "" {
|
||||
dataDir = "/tmp/data"
|
||||
dataDir = os.TempDir() + "/data"
|
||||
}
|
||||
|
||||
// Get expected checksum from metadata
|
||||
|
|
@ -321,107 +344,10 @@ func (w *Worker) VerifySnapshot(ctx context.Context, task *queue.Task) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// RunJupyterTask runs a Jupyter-related task.
|
||||
// It handles start, stop, remove, restore, and list_packages actions.
|
||||
func (w *Worker) RunJupyterTask(ctx context.Context, task *queue.Task) ([]byte, error) {
|
||||
if w.Jupyter == nil {
|
||||
return nil, fmt.Errorf("jupyter manager not configured")
|
||||
}
|
||||
|
||||
action := task.Metadata["jupyter_action"]
|
||||
if action == "" {
|
||||
action = "start" // Default action
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "start":
|
||||
name := task.Metadata["jupyter_name"]
|
||||
if name == "" {
|
||||
name = task.Metadata["jupyter_workspace"]
|
||||
}
|
||||
if name == "" {
|
||||
// Extract from jobName if format is "jupyter-<name>"
|
||||
if len(task.JobName) > 8 && task.JobName[:8] == "jupyter-" {
|
||||
name = task.JobName[8:]
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
|
||||
}
|
||||
|
||||
req := &jupyter.StartRequest{Name: name}
|
||||
service, err := w.Jupyter.StartService(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start jupyter service: %w", err)
|
||||
}
|
||||
|
||||
output := map[string]interface{}{
|
||||
"type": "start",
|
||||
"service": service,
|
||||
}
|
||||
return json.Marshal(output)
|
||||
|
||||
case "stop":
|
||||
serviceID := task.Metadata["jupyter_service_id"]
|
||||
if serviceID == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
||||
}
|
||||
if err := w.Jupyter.StopService(ctx, serviceID); err != nil {
|
||||
return nil, fmt.Errorf("failed to stop jupyter service: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]string{"type": "stop", "status": "stopped"})
|
||||
|
||||
case "remove":
|
||||
serviceID := task.Metadata["jupyter_service_id"]
|
||||
if serviceID == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_service_id in task metadata")
|
||||
}
|
||||
purge := task.Metadata["jupyter_purge"] == "true"
|
||||
if err := w.Jupyter.RemoveService(ctx, serviceID, purge); err != nil {
|
||||
return nil, fmt.Errorf("failed to remove jupyter service: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]string{"type": "remove", "status": "removed"})
|
||||
|
||||
case "restore":
|
||||
name := task.Metadata["jupyter_name"]
|
||||
if name == "" {
|
||||
name = task.Metadata["jupyter_workspace"]
|
||||
}
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_name or jupyter_workspace in task metadata")
|
||||
}
|
||||
serviceID, err := w.Jupyter.RestoreWorkspace(ctx, name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to restore jupyter workspace: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]string{"type": "restore", "service_id": serviceID})
|
||||
|
||||
case "list_packages":
|
||||
serviceName := task.Metadata["jupyter_name"]
|
||||
if serviceName == "" {
|
||||
// Extract from jobName if format is "jupyter-packages-<name>"
|
||||
if len(task.JobName) > 16 && task.JobName[:16] == "jupyter-packages-" {
|
||||
serviceName = task.JobName[16:]
|
||||
}
|
||||
}
|
||||
if serviceName == "" {
|
||||
return nil, fmt.Errorf("missing jupyter_name in task metadata")
|
||||
}
|
||||
|
||||
packages, err := w.Jupyter.ListInstalledPackages(ctx, serviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list installed packages: %w", err)
|
||||
}
|
||||
|
||||
output := map[string]interface{}{
|
||||
"type": "list_packages",
|
||||
"packages": packages,
|
||||
}
|
||||
return json.Marshal(output)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown jupyter action: %s", action)
|
||||
}
|
||||
// GetJupyterManager returns the Jupyter manager for plugin use
|
||||
// This implements the plugins.TaskRunner interface
|
||||
func (w *Worker) GetJupyterManager() plugins.JupyterManager {
|
||||
return w.Jupyter
|
||||
}
|
||||
|
||||
// PrewarmNextOnce prewarms the next task in queue.
|
||||
|
|
@ -445,7 +371,7 @@ func (w *Worker) PrewarmNextOnce(ctx context.Context) (bool, error) {
|
|||
|
||||
// Create prewarm directory
|
||||
prewarmDir := filepath.Join(basePath, ".prewarm", "snapshots")
|
||||
if err := os.MkdirAll(prewarmDir, 0750); err != nil {
|
||||
if err := os.MkdirAll(prewarmDir, 0o750); err != nil {
|
||||
return false, fmt.Errorf("failed to create prewarm directory: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -538,3 +464,8 @@ func (w *Worker) RunJob(ctx context.Context, task *queue.Task, outputDir string)
|
|||
// Run the job
|
||||
return w.Runner.Run(ctx, task, basePath, mode, w.Config.LocalMode, gpuEnv)
|
||||
}
|
||||
|
||||
func mustMarshal(v any) []byte {
|
||||
b, _ := json.Marshal(v)
|
||||
return b
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue