package worker import ( "errors" "fmt" "log/slog" "os" "os/exec" "path/filepath" "runtime" "strings" "github.com/jfraeys/fetch_ml/internal/scheduler" ) // logWarningf logs a warning message using slog with proper sanitization func logWarningf(format string, args ...any) { // Use structured logging to avoid log injection // Format the message first, then log as a single string attribute msg := fmt.Sprintf(format, args...) slog.Warn("warning", "message", msg) } // GPUType represents different GPU types type GPUType string const ( GPUTypeNVIDIA GPUType = "nvidia" GPUTypeAMD GPUType = "amd" GPUTypeApple GPUType = "apple" GPUTypeNone GPUType = "none" ) // DetectionSource indicates how the GPU detector was selected type DetectionSource string const ( DetectionSourceEnvType DetectionSource = "env_override_type" DetectionSourceEnvCount DetectionSource = "env_override_count" DetectionSourceEnvBoth DetectionSource = "env_override_both" DetectionSourceConfig DetectionSource = "config" DetectionSourceAuto DetectionSource = "auto" DetectionSourceNone DetectionSource = "none" ) // GPUDetectionInfo provides metadata about how GPU detection was determined type GPUDetectionInfo struct { GPUType GPUType `json:"gpu_type"` ConfiguredVendor string `json:"configured_vendor"` DetectionMethod DetectionSource `json:"detection_method"` EnvOverrideType string `json:"env_override_type,omitempty"` EnvOverrideCount int `json:"env_override_count,omitempty"` ConfigLayerAutoDetected bool `json:"config_layer_auto_detected,omitempty"` } // GPUDetector interface for detecting GPU availability type GPUDetector interface { DetectGPUCount() int GetGPUType() GPUType GetDevicePaths() []string } // NVIDIA GPUDetector implementation type NVIDIADetector struct{} func (d *NVIDIADetector) DetectGPUCount() int { // First try NVML for accurate detection if IsNVMLAvailable() { count, err := GetGPUCount() if err == nil && count > 0 { return count } } // Fall back to environment variable if n, ok := 
envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 { return n } return 0 } func (d *NVIDIADetector) GetGPUType() GPUType { return GPUTypeNVIDIA } func (d *NVIDIADetector) GetDevicePaths() []string { // Prefer standard NVIDIA device nodes when present. patterns := []string{ "/dev/nvidiactl", "/dev/nvidia-modeset", "/dev/nvidia-uvm", "/dev/nvidia-uvm-tools", "/dev/nvidia*", } seen := make(map[string]struct{}) out := make([]string, 0, 8) for _, pat := range patterns { if filepath.Base(pat) == pat { continue } if strings.Contains(pat, "*") { matches, _ := filepath.Glob(pat) for _, m := range matches { if _, ok := seen[m]; ok { continue } if _, err := os.Stat(m); err != nil { continue } seen[m] = struct{}{} out = append(out, m) } continue } if _, ok := seen[pat]; ok { continue } if _, err := os.Stat(pat); err != nil { continue } seen[pat] = struct{}{} out = append(out, pat) } // Fallback for non-NVIDIA setups where only generic DRM device exists. if len(out) == 0 { if _, err := os.Stat("/dev/dri"); err == nil { out = append(out, "/dev/dri") } } return out } // Apple M-series GPUDetector implementation type AppleDetector struct { enabled bool } func (d *AppleDetector) DetectGPUCount() int { // First try actual macOS GPU detection if IsMacOS() { count, err := GetMacOSGPUCount() if err == nil && count > 0 { return count } } if n, ok := envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 { return n } if d.enabled { return 1 } return 0 } func (d *AppleDetector) GetGPUType() GPUType { return GPUTypeApple } func (d *AppleDetector) GetDevicePaths() []string { return []string{"/dev/metal", "/dev/mps"} } // None GPUDetector implementation type NoneDetector struct{} func (d *NoneDetector) DetectGPUCount() int { return 0 } func (d *NoneDetector) GetGPUType() GPUType { return GPUTypeNone } func (d *NoneDetector) GetDevicePaths() []string { return nil } // GPUDetectorFactory creates appropriate GPU detector based config type GPUDetectorFactory struct{} // DetectionResult contains both the detector 
and metadata about how it was selected type DetectionResult struct { Detector GPUDetector Info GPUDetectionInfo } // Validate checks if the detection result is valid and returns an error if not. // This ensures users get clear error messages for unimplemented features like AMD GPU. func (r DetectionResult) Validate() error { if r.Detector == nil { switch r.Info.ConfiguredVendor { case "amd": return errors.New( "AMD GPU support is not yet implemented. " + "Use NVIDIA GPUs, Apple Silicon, or CPU-only mode. " + "For development/testing, use FETCH_ML_MOCK_GPU_TYPE=AMD", ) default: return fmt.Errorf("GPU detection failed for vendor %q", r.Info.ConfiguredVendor) } } return nil } func (f *GPUDetectorFactory) CreateDetector(cfg *Config) GPUDetector { result := f.CreateDetectorWithInfo(cfg) return result.Detector } func (f *GPUDetectorFactory) CreateDetectorWithInfo(cfg *Config) DetectionResult { // Check for explicit environment overrides envType := os.Getenv("FETCH_ML_GPU_TYPE") envCount, hasEnvCount := envInt("FETCH_ML_GPU_COUNT") if envType != "" && hasEnvCount { // Both env vars set logEnvOverride("FETCH_ML_GPU_TYPE", envType) logEnvOverride("FETCH_ML_GPU_COUNT", envCount) switch envType { case string(GPUTypeNVIDIA): return DetectionResult{ Detector: &NVIDIADetector{}, Info: GPUDetectionInfo{ GPUType: GPUTypeNVIDIA, ConfiguredVendor: "nvidia", DetectionMethod: DetectionSourceEnvBoth, EnvOverrideType: envType, EnvOverrideCount: envCount, }, } case string(GPUTypeApple): return DetectionResult{ Detector: &AppleDetector{enabled: true}, Info: GPUDetectionInfo{ GPUType: GPUTypeApple, ConfiguredVendor: "apple", DetectionMethod: DetectionSourceEnvBoth, EnvOverrideType: envType, EnvOverrideCount: envCount, }, } case string(GPUTypeNone): return DetectionResult{ Detector: &NoneDetector{}, Info: GPUDetectionInfo{ GPUType: GPUTypeNone, ConfiguredVendor: "none", DetectionMethod: DetectionSourceEnvBoth, EnvOverrideType: envType, EnvOverrideCount: envCount, }, } case "amd": // AMD GPU 
support not yet implemented // Return error so user knows this is a known limitation return DetectionResult{ Detector: nil, // Will cause error when used Info: GPUDetectionInfo{ GPUType: GPUTypeAMD, ConfiguredVendor: "amd", DetectionMethod: DetectionSourceEnvBoth, EnvOverrideType: envType, EnvOverrideCount: envCount, }, } default: // Defensive: unknown env type should not silently fall through logWarningf("unrecognized FETCH_ML_GPU_TYPE value %q, using no GPU", envType) return DetectionResult{ Detector: &NoneDetector{}, Info: GPUDetectionInfo{ GPUType: GPUTypeNone, ConfiguredVendor: "none", DetectionMethod: DetectionSourceEnvBoth, EnvOverrideType: envType, EnvOverrideCount: envCount, }, } } } if envType != "" { // Only FETCH_ML_GPU_TYPE set logEnvOverride("FETCH_ML_GPU_TYPE", envType) switch envType { case string(GPUTypeNVIDIA): return DetectionResult{ Detector: &NVIDIADetector{}, Info: GPUDetectionInfo{ GPUType: GPUTypeNVIDIA, ConfiguredVendor: "nvidia", DetectionMethod: DetectionSourceEnvType, EnvOverrideType: envType, }, } case string(GPUTypeApple): return DetectionResult{ Detector: &AppleDetector{enabled: true}, Info: GPUDetectionInfo{ GPUType: GPUTypeApple, ConfiguredVendor: "apple", DetectionMethod: DetectionSourceEnvType, EnvOverrideType: envType, }, } case string(GPUTypeNone): return DetectionResult{ Detector: &NoneDetector{}, Info: GPUDetectionInfo{ GPUType: GPUTypeNone, ConfiguredVendor: "none", DetectionMethod: DetectionSourceEnvType, EnvOverrideType: envType, }, } case "amd": // AMD GPU support not yet implemented return DetectionResult{ Detector: nil, Info: GPUDetectionInfo{ GPUType: GPUTypeAMD, ConfiguredVendor: "amd", DetectionMethod: DetectionSourceEnvType, EnvOverrideType: envType, }, } default: // Defensive: unknown env type should not silently fall through logWarningf("unrecognized FETCH_ML_GPU_TYPE value %q, using no GPU", envType) return DetectionResult{ Detector: &NoneDetector{}, Info: GPUDetectionInfo{ GPUType: GPUTypeNone, ConfiguredVendor: 
"none", DetectionMethod: DetectionSourceEnvType, EnvOverrideType: envType, }, } } } if hasEnvCount { // Only FETCH_ML_GPU_COUNT set - need to detect vendor from config or auto logEnvOverride("FETCH_ML_GPU_COUNT", envCount) return f.detectFromConfigWithSource(cfg, DetectionSourceEnvCount, "", envCount) } // No env overrides - detect from config return f.detectFromConfigWithSource(cfg, DetectionSourceConfig, "", -1) } func (f *GPUDetectorFactory) detectFromConfigWithSource(cfg *Config, source DetectionSource, envType string, envCount int) DetectionResult { if cfg == nil { return DetectionResult{ Detector: &NoneDetector{}, Info: GPUDetectionInfo{ GPUType: GPUTypeNone, ConfiguredVendor: "none", DetectionMethod: source, }, } } // Check for auto-detection scenarios (GPUDevices provided or AppleGPU enabled without explicit vendor) isAutoDetect := cfg.GPUVendorAutoDetected || (len(cfg.GPUDevices) > 0 && cfg.GPUVendor == "") || (cfg.AppleGPU.Enabled && cfg.GPUVendor == "") if isAutoDetect && source == DetectionSourceConfig { source = DetectionSourceAuto } switch GPUType(cfg.GPUVendor) { case GPUTypeApple: return DetectionResult{ Detector: &AppleDetector{enabled: cfg.AppleGPU.Enabled}, Info: GPUDetectionInfo{ GPUType: GPUTypeApple, ConfiguredVendor: "apple", DetectionMethod: source, EnvOverrideType: envType, EnvOverrideCount: envCount, ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected, }, } case GPUTypeNone: return DetectionResult{ Detector: &NoneDetector{}, Info: GPUDetectionInfo{ GPUType: GPUTypeNone, ConfiguredVendor: "none", DetectionMethod: source, EnvOverrideType: envType, EnvOverrideCount: envCount, ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected, }, } case GPUTypeNVIDIA: return DetectionResult{ Detector: &NVIDIADetector{}, Info: GPUDetectionInfo{ GPUType: GPUTypeNVIDIA, ConfiguredVendor: "nvidia", DetectionMethod: source, EnvOverrideType: envType, EnvOverrideCount: envCount, ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected, }, } case "amd": // AMD GPU support 
not yet implemented - tracked in roadmap logWarningf("AMD GPU detection requested but not yet implemented. Consider using MOCK mode or contributing to the project.") return DetectionResult{ Detector: nil, Info: GPUDetectionInfo{ GPUType: GPUTypeAMD, ConfiguredVendor: "amd", DetectionMethod: source, EnvOverrideType: envType, EnvOverrideCount: envCount, ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected, }, } default: // SECURITY: Explicit default prevents silent misconfiguration // Unknown GPU vendor is treated as no GPU - fail secure // Note: Config.Validate() should catch invalid vendors before this point logWarningf("unrecognized GPU vendor %q, using no GPU", cfg.GPUVendor) return DetectionResult{ Detector: &NoneDetector{}, Info: GPUDetectionInfo{ GPUType: GPUTypeNone, ConfiguredVendor: "none", DetectionMethod: source, EnvOverrideType: envType, EnvOverrideCount: envCount, ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected, }, } } } // DetectCapabilities returns full WorkerCapabilities with backend detection // It tries NVIDIA first, then Metal (Apple Silicon), then Vulkan, then CPU fallback func DetectCapabilities() scheduler.WorkerCapabilities { // Try NVIDIA first if IsNVMLAvailable() { count, err := GetGPUCount() if err == nil && count > 0 { gpus, err := GetAllGPUInfo() if err == nil && len(gpus) > 0 { totalVRAM := float64(0) for _, gpu := range gpus { totalVRAM += float64(gpu.MemoryTotal) / (1024 * 1024 * 1024) // Convert to GB } return scheduler.WorkerCapabilities{ GPUBackend: scheduler.BackendNVIDIA, GPUCount: count, GPUType: gpus[0].Name, VRAMGB: totalVRAM, CPUCount: runtime.NumCPU(), MemoryGB: getSystemMemoryGB(), Hostname: getHostname(), GPUInfo: scheduler.GPUDetectionInfo{ GPUType: "nvidia", Count: count, Devices: getNVIDIADevices(), MemTotal: gpus[0].MemoryTotal, }, } } } } // Try Metal (macOS Apple Silicon) if runtime.GOOS == "darwin" && IsAppleSilicon() { gpus, err := GetMacOSGPUInfo() if err == nil && len(gpus) > 0 { totalVRAM := float64(0) for _, 
gpu := range gpus { totalVRAM += float64(gpu.VRAM_MB) / 1024 // Convert MB to GB } return scheduler.WorkerCapabilities{ GPUBackend: scheduler.BackendMetal, GPUCount: len(gpus), GPUType: gpus[0].ChipsetModel, VRAMGB: totalVRAM, CPUCount: runtime.NumCPU(), MemoryGB: getSystemMemoryGB(), Hostname: getHostname(), GPUInfo: scheduler.GPUDetectionInfo{ GPUType: "apple", Count: len(gpus), MemTotal: uint64(totalVRAM * 1024 * 1024 * 1024), }, } } } // Try Vulkan (check for vulkaninfo or /dev/dri) if hasVulkan() { count := getVulkanGPUCount() if count > 0 { return scheduler.WorkerCapabilities{ GPUBackend: scheduler.BackendVulkan, GPUCount: count, GPUType: "vulkan", VRAMGB: 0, // TODO: Query Vulkan for VRAM CPUCount: runtime.NumCPU(), MemoryGB: getSystemMemoryGB(), Hostname: getHostname(), GPUInfo: scheduler.GPUDetectionInfo{ GPUType: "vulkan", Count: count, Devices: getVulkanDevices(), }, } } } // CPU fallback return scheduler.WorkerCapabilities{ GPUBackend: scheduler.BackendCPU, GPUCount: 0, GPUType: "cpu", VRAMGB: 0, CPUCount: runtime.NumCPU(), MemoryGB: getSystemMemoryGB(), Hostname: getHostname(), GPUInfo: scheduler.GPUDetectionInfo{ GPUType: "cpu", Count: 0, }, } } // hasVulkan checks if Vulkan is available func hasVulkan() bool { // Check for vulkaninfo binary if _, err := exec.LookPath("vulkaninfo"); err == nil { return true } // Check for /dev/dri if _, err := os.Stat("/dev/dri"); err == nil { return true } return false } // getVulkanGPUCount returns the number of Vulkan GPUs func getVulkanGPUCount() int { // Try to get GPU count from vulkaninfo out, err := exec.Command("vulkaninfo", "--summary").Output() if err != nil { return 1 // Assume 1 if vulkaninfo fails but /dev/dri exists } // Count "GPU" occurrences in output count := 0 for _, line := range strings.Split(string(out), "\n") { if strings.Contains(line, "GPU") && strings.Contains(line, "deviceName") { count++ } } if count == 0 { return 1 // Assume at least 1 } return count } // getVulkanDevices returns Vulkan 
// getVulkanDevices returns the Vulkan device paths: /dev/dri when it exists,
// nil otherwise.
func getVulkanDevices() []string {
	if _, err := os.Stat("/dev/dri"); err != nil {
		return nil
	}
	return []string{"/dev/dri"}
}

// getNVIDIADevices returns every path matching /dev/nvidia* that can be
// stat'ed, or nil when there is none.
func getNVIDIADevices() []string {
	var found []string
	// The glob error is ignored: a bad pattern simply yields no matches.
	candidates, _ := filepath.Glob("/dev/nvidia*")
	for _, candidate := range candidates {
		if _, err := os.Stat(candidate); err != nil {
			continue
		}
		found = append(found, candidate)
	}
	return found
}

// getSystemMemoryGB returns the MemTotal value from /proc/meminfo converted
// from kB to GB (divided by 1024*1024). It returns 0, meaning "unknown", when
// the file cannot be read or has no MemTotal line.
func getSystemMemoryGB() float64 {
	raw, err := os.ReadFile("/proc/meminfo")
	if err != nil {
		return 0
	}
	for _, entry := range strings.Split(string(raw), "\n") {
		if !strings.HasPrefix(entry, "MemTotal:") {
			continue
		}
		var kib uint64
		// The scan error is deliberately ignored: whatever was parsed before
		// a mismatch is kept, and a total failure leaves kib at zero.
		_, _ = fmt.Sscanf(entry, "MemTotal: %d kB", &kib)
		return float64(kib) / (1024 * 1024)
	}
	return 0
}

// getHostname returns the system hostname as reported by os.Hostname. The
// error is dropped on purpose (best effort) and the returned name is passed
// through unchanged.
func getHostname() string {
	name, _ := os.Hostname()
	return name
}