fetch_ml/internal/worker/gpu_detector.go
Jeremie Fraeys 61660dc925
refactor: co-locate security, storage, telemetry, tracking, worker tests
Move unit tests from tests/unit/ to internal/ following Go conventions:

Security tests:
- tests/unit/security/* -> internal/security/* (audit, config_integrity, filetype, gpu_audit, hipaa_validation, manifest_filename, path_traversal, resource_quota, secrets)

Storage tests:
- tests/unit/storage/* -> internal/storage/* (db, experiment_metadata)

Telemetry tests:
- tests/unit/telemetry/* -> internal/telemetry/* (telemetry)

Tracking tests:
- tests/unit/reproducibility/* -> internal/tracking/* (config_hash, environment_capture)

Worker tests:
- tests/unit/worker/* -> internal/worker/* (artifacts, config, hash_bench, plugins/jupyter_task, plugins/vllm, prewarm_v1, run_manifest_execution, snapshot_stage, snapshot_store, worker)

Update import paths in test files to reflect new locations.
2026-03-12 16:37:03 -04:00

613 lines
16 KiB
Go

package worker
import (
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"github.com/jfraeys/fetch_ml/internal/scheduler"
)
// logWarningf formats a printf-style message and emits it at WARN level
// through the default slog logger. The formatted text is attached as the
// "message" attribute (the record's own message is the fixed string
// "warning"), so caller-supplied values go through slog's attribute encoding
// instead of being spliced into the log line verbatim.
func logWarningf(format string, args ...any) {
	slog.Warn("warning", "message", fmt.Sprintf(format, args...))
}
// GPUType identifies a GPU vendor/backend family. Its string values are the
// identifiers compared against FETCH_ML_GPU_TYPE and the configured GPU vendor.
type GPUType string

// Known GPUType values.
const (
	// GPUTypeNVIDIA selects NVIDIA GPUs.
	GPUTypeNVIDIA = GPUType("nvidia")
	// GPUTypeAMD names AMD GPUs; no detector is implemented for it yet.
	GPUTypeAMD = GPUType("amd")
	// GPUTypeApple selects Apple Silicon GPUs.
	GPUTypeApple = GPUType("apple")
	// GPUTypeNone means no GPU is exposed (CPU-only).
	GPUTypeNone = GPUType("none")
)
// DetectionSource records which input decided the GPU detector selection.
type DetectionSource string

// Known DetectionSource values.
const (
	// DetectionSourceEnvType: only FETCH_ML_GPU_TYPE was set.
	DetectionSourceEnvType = DetectionSource("env_override_type")
	// DetectionSourceEnvCount: only FETCH_ML_GPU_COUNT was set.
	DetectionSourceEnvCount = DetectionSource("env_override_count")
	// DetectionSourceEnvBoth: both FETCH_ML_GPU_TYPE and FETCH_ML_GPU_COUNT were set.
	DetectionSourceEnvBoth = DetectionSource("env_override_both")
	// DetectionSourceConfig: the vendor came from explicit configuration.
	DetectionSourceConfig = DetectionSource("config")
	// DetectionSourceAuto: the configuration itself looked auto-detected.
	DetectionSourceAuto = DetectionSource("auto")
	// DetectionSourceNone: no detection source applied.
	DetectionSourceNone = DetectionSource("none")
)
// GPUDetectionInfo records which GPU type was selected and how that decision
// was reached. The JSON tags make it suitable for reporting.
type GPUDetectionInfo struct {
	// GPUType is the GPU type that was selected.
	GPUType GPUType `json:"gpu_type"`
	// ConfiguredVendor is the vendor string the selection resolved to
	// ("nvidia", "apple", "amd" or "none"); unrecognized vendors are
	// reported as "none".
	ConfiguredVendor string `json:"configured_vendor"`
	// DetectionMethod says which input decided the selection.
	DetectionMethod DetectionSource `json:"detection_method"`
	// EnvOverrideType is the raw FETCH_ML_GPU_TYPE value when that override was used.
	EnvOverrideType string `json:"env_override_type,omitempty"`
	// EnvOverrideCount is the FETCH_ML_GPU_COUNT value when that override was set.
	// Config-only detection (non-nil config, no env overrides) stores -1 here;
	// omitempty only drops 0, so that -1 is still serialized.
	EnvOverrideCount int `json:"env_override_count,omitempty"`
	// ConfigLayerAutoDetected is copied from the config's GPUVendorAutoDetected
	// flag when selection went through the config path.
	ConfigLayerAutoDetected bool `json:"config_layer_auto_detected,omitempty"`
}
// GPUDetector reports GPU availability for one vendor/backend.
type GPUDetector interface {
	// DetectGPUCount returns the number of GPUs this detector reports (0 if none).
	DetectGPUCount() int
	// GetGPUType returns the GPU type this detector handles.
	GetGPUType() GPUType
	// GetDevicePaths returns host device paths associated with the GPUs;
	// implementations may return nil when there are none.
	GetDevicePaths() []string
}
// NVIDIADetector is the GPUDetector for NVIDIA GPUs.
type NVIDIADetector struct{}

// DetectGPUCount reports the NVIDIA GPU count. A positive count from NVML is
// preferred. Otherwise a non-negative FETCH_ML_GPU_COUNT is used, and 0 is
// returned when neither source yields a value.
func (d *NVIDIADetector) DetectGPUCount() int {
	if IsNVMLAvailable() {
		if count, err := GetGPUCount(); err == nil && count > 0 {
			return count
		}
	}
	n, ok := envInt("FETCH_ML_GPU_COUNT")
	if !ok || n < 0 {
		return 0
	}
	return n
}

// GetGPUType always reports GPUTypeNVIDIA.
func (d *NVIDIADetector) GetGPUType() GPUType { return GPUTypeNVIDIA }
// GetDevicePaths returns the NVIDIA device nodes present on this host.
//
// The fixed control nodes (/dev/nvidiactl, /dev/nvidia-modeset,
// /dev/nvidia-uvm, /dev/nvidia-uvm-tools) come first. They are followed by any
// other /dev/nvidia* entries in the order filepath.Glob returns them.
// Duplicates are dropped, and only paths os.Stat can resolve are kept.
//
// If no NVIDIA node exists, /dev/dri is returned when present, so a generic
// DRM device is still exposed. The result is never nil; it is empty when
// nothing is found.
func (d *NVIDIADetector) GetDevicePaths() []string {
	patterns := []string{
		"/dev/nvidiactl",
		"/dev/nvidia-modeset",
		"/dev/nvidia-uvm",
		"/dev/nvidia-uvm-tools",
		"/dev/nvidia*",
	}

	seen := make(map[string]struct{}, 8)
	out := make([]string, 0, 8)

	// add records p once, skipping paths that os.Stat cannot resolve
	// (missing nodes, or glob hits such as dangling symlinks).
	add := func(p string) {
		if _, dup := seen[p]; dup {
			return
		}
		if _, err := os.Stat(p); err != nil {
			return
		}
		seen[p] = struct{}{}
		out = append(out, p)
	}

	for _, pat := range patterns {
		if !strings.Contains(pat, "*") {
			add(pat)
			continue
		}
		// Glob only fails on a malformed pattern, and these are constants.
		matches, _ := filepath.Glob(pat)
		for _, m := range matches {
			add(m)
		}
	}

	// Fallback for non-NVIDIA setups where only the generic DRM device exists.
	if len(out) == 0 {
		if _, err := os.Stat("/dev/dri"); err == nil {
			out = append(out, "/dev/dri")
		}
	}
	return out
}
// AppleDetector is the GPUDetector for Apple Silicon (M-series) GPUs.
type AppleDetector struct {
	// enabled makes DetectGPUCount fall back to one GPU when neither macOS
	// detection nor FETCH_ML_GPU_COUNT provides a count.
	enabled bool
}

// DetectGPUCount reports the Apple GPU count. It tries, in order:
//  1. a positive count from macOS detection;
//  2. a non-negative FETCH_ML_GPU_COUNT;
//  3. 1 if the detector is enabled, otherwise 0.
func (d *AppleDetector) DetectGPUCount() int {
	if IsMacOS() {
		if count, err := GetMacOSGPUCount(); err == nil && count > 0 {
			return count
		}
	}
	if n, ok := envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 {
		return n
	}
	fallback := 0
	if d.enabled {
		fallback = 1
	}
	return fallback
}

// GetGPUType always reports GPUTypeApple.
func (d *AppleDetector) GetGPUType() GPUType {
	return GPUTypeApple
}

// GetDevicePaths returns a fixed list of Apple GPU device paths.
// NOTE(review): these paths are returned without checking that they exist;
// confirm they correspond to real nodes on the target macOS hosts.
func (d *AppleDetector) GetDevicePaths() []string {
	paths := []string{"/dev/metal", "/dev/mps"}
	return paths
}
// NoneDetector is the GPUDetector used when no GPU should be exposed. It
// reports zero GPUs, GPUTypeNone, and no device paths.
type NoneDetector struct{}

// DetectGPUCount always returns 0.
func (*NoneDetector) DetectGPUCount() int { return 0 }

// GetGPUType always returns GPUTypeNone.
func (*NoneDetector) GetGPUType() GPUType { return GPUTypeNone }

// GetDevicePaths always returns nil.
func (*NoneDetector) GetDevicePaths() []string { return nil }
// GPUDetectorFactory creates the appropriate GPU detector based on config and
// environment overrides. It has no fields, so its zero value is ready to use.
type GPUDetectorFactory struct{}

// DetectionResult contains both the detector and metadata about how it was selected.
type DetectionResult struct {
	// Detector is the selected detector. It is nil when the selected vendor has
	// no implementation (currently "amd"); call Validate to turn that into an error.
	Detector GPUDetector
	// Info describes how the detector was chosen.
	Info GPUDetectionInfo
}
// Validate reports whether r holds a usable detector. It returns nil when a
// detector was selected. When the detector is nil, it returns an explanatory
// error: a dedicated "not yet implemented" message for the "amd" vendor, and a
// generic detection failure naming the vendor otherwise.
func (r DetectionResult) Validate() error {
	if r.Detector != nil {
		return nil
	}
	if r.Info.ConfiguredVendor == "amd" {
		return errors.New("AMD GPU support is not yet implemented. " +
			"Use NVIDIA GPUs, Apple Silicon, or CPU-only mode. " +
			"For development/testing, use FETCH_ML_MOCK_GPU_TYPE=AMD")
	}
	return fmt.Errorf("GPU detection failed for vendor %q", r.Info.ConfiguredVendor)
}
// CreateDetector selects a GPU detector exactly as CreateDetectorWithInfo does
// but discards the selection metadata. The returned detector is nil when the
// selected vendor is "amd"; use CreateDetectorWithInfo plus Validate to get an
// explanatory error instead.
func (f *GPUDetectorFactory) CreateDetector(cfg *Config) GPUDetector {
	return f.CreateDetectorWithInfo(cfg).Detector
}
// CreateDetectorWithInfo selects a GPU detector and reports how it was chosen.
//
// Selection precedence, highest first:
//  1. FETCH_ML_GPU_TYPE picks the vendor directly. DetectionMethod is
//     env_override_both when FETCH_ML_GPU_COUNT is also set (its value is
//     recorded in EnvOverrideCount), otherwise env_override_type.
//  2. FETCH_ML_GPU_COUNT alone leaves vendor selection to the config path and
//     marks the result env_override_count.
//  3. With no env overrides, the vendor comes from cfg (config or auto).
//
// Every env override that is used is logged via logEnvOverride.
//
// FETCH_ML_GPU_TYPE values are matched case-sensitively:
//   - "amd" yields a nil Detector, because AMD is not implemented; call
//     Validate on the result for a descriptive error.
//   - An unrecognized value logs a warning and falls back to no GPU.
func (f *GPUDetectorFactory) CreateDetectorWithInfo(cfg *Config) DetectionResult {
	envType := os.Getenv("FETCH_ML_GPU_TYPE")
	envCount, hasEnvCount := envInt("FETCH_ML_GPU_COUNT")

	if envType == "" {
		if hasEnvCount {
			// Count-only override: the vendor still comes from config/auto.
			logEnvOverride("FETCH_ML_GPU_COUNT", envCount)
			return f.detectFromConfigWithSource(cfg, DetectionSourceEnvCount, "", envCount)
		}
		// No env overrides; -1 is stored as EnvOverrideCount to mark "no count override".
		return f.detectFromConfigWithSource(cfg, DetectionSourceConfig, "", -1)
	}

	// FETCH_ML_GPU_TYPE is set and decides the vendor. The optional count only
	// changes the recorded detection method and count.
	logEnvOverride("FETCH_ML_GPU_TYPE", envType)
	info := GPUDetectionInfo{
		DetectionMethod: DetectionSourceEnvType,
		EnvOverrideType: envType,
	}
	if hasEnvCount {
		logEnvOverride("FETCH_ML_GPU_COUNT", envCount)
		info.DetectionMethod = DetectionSourceEnvBoth
		info.EnvOverrideCount = envCount
	}

	// A single vendor switch serves both the "type only" and "type + count"
	// cases, so they cannot drift apart.
	var detector GPUDetector
	switch envType {
	case string(GPUTypeNVIDIA):
		detector = &NVIDIADetector{}
		info.GPUType, info.ConfiguredVendor = GPUTypeNVIDIA, "nvidia"
	case string(GPUTypeApple):
		detector = &AppleDetector{enabled: true}
		info.GPUType, info.ConfiguredVendor = GPUTypeApple, "apple"
	case string(GPUTypeNone):
		detector = &NoneDetector{}
		info.GPUType, info.ConfiguredVendor = GPUTypeNone, "none"
	case string(GPUTypeAMD):
		// AMD GPU support is not implemented: leave detector nil so callers
		// (via DetectionResult.Validate) learn this is a known limitation.
		info.GPUType, info.ConfiguredVendor = GPUTypeAMD, "amd"
	default:
		// Defensive: an unknown env type must not silently fall through.
		logWarningf("unrecognized FETCH_ML_GPU_TYPE value %q, using no GPU", envType)
		detector = &NoneDetector{}
		info.GPUType, info.ConfiguredVendor = GPUTypeNone, "none"
	}
	return DetectionResult{Detector: detector, Info: info}
}
// detectFromConfigWithSource picks a detector from cfg.GPUVendor.
//
// source is recorded as the DetectionMethod. DetectionSourceConfig is upgraded
// to DetectionSourceAuto when the config looks auto-detected: either
// GPUVendorAutoDetected is set, or GPU devices / Apple GPU are configured
// without an explicit vendor. envType and envCount are copied verbatim into
// the result.
//
// Special cases:
//   - A nil cfg yields the no-GPU detector without env metadata.
//   - Vendor "amd" yields a nil Detector and a warning.
//   - Any other unrecognized vendor logs a warning and falls back to no GPU.
//
// NOTE(review): an empty GPUVendor that triggered the auto upgrade still lands
// in the "unrecognized vendor" branch. Presumably the config layer fills
// GPUVendor before this point; confirm.
func (f *GPUDetectorFactory) detectFromConfigWithSource(cfg *Config, source DetectionSource, envType string, envCount int) DetectionResult {
	if cfg == nil {
		return DetectionResult{
			Detector: &NoneDetector{},
			Info: GPUDetectionInfo{
				GPUType:          GPUTypeNone,
				ConfiguredVendor: "none",
				DetectionMethod:  source,
			},
		}
	}

	noExplicitVendor := cfg.GPUVendor == ""
	autoDetected := cfg.GPUVendorAutoDetected ||
		(noExplicitVendor && (len(cfg.GPUDevices) > 0 || cfg.AppleGPU.Enabled))
	if source == DetectionSourceConfig && autoDetected {
		source = DetectionSourceAuto
	}

	info := GPUDetectionInfo{
		DetectionMethod:         source,
		EnvOverrideType:         envType,
		EnvOverrideCount:        envCount,
		ConfigLayerAutoDetected: cfg.GPUVendorAutoDetected,
	}

	var detector GPUDetector
	switch GPUType(cfg.GPUVendor) {
	case GPUTypeApple:
		detector = &AppleDetector{enabled: cfg.AppleGPU.Enabled}
		info.GPUType, info.ConfiguredVendor = GPUTypeApple, "apple"
	case GPUTypeNone:
		detector = &NoneDetector{}
		info.GPUType, info.ConfiguredVendor = GPUTypeNone, "none"
	case GPUTypeNVIDIA:
		detector = &NVIDIADetector{}
		info.GPUType, info.ConfiguredVendor = GPUTypeNVIDIA, "nvidia"
	case GPUTypeAMD:
		// AMD GPU support not yet implemented; detector stays nil.
		logWarningf("AMD GPU detection requested but not yet implemented. Consider using MOCK mode or contributing to the project.")
		info.GPUType, info.ConfiguredVendor = GPUTypeAMD, "amd"
	default:
		// SECURITY: unknown vendors fail secure to "no GPU" rather than guessing.
		logWarningf("unrecognized GPU vendor %q, using no GPU", cfg.GPUVendor)
		detector = &NoneDetector{}
		info.GPUType, info.ConfiguredVendor = GPUTypeNone, "none"
	}
	return DetectionResult{Detector: detector, Info: info}
}
// DetectCapabilities probes the host and returns its WorkerCapabilities.
//
// Backends are tried in priority order, and the first that reports at least
// one device wins:
//  1. NVIDIA, when NVML is available and reports GPUs;
//  2. Metal, on darwin with Apple Silicon;
//  3. Vulkan, when hasVulkan succeeds;
//  4. CPU-only fallback.
//
// VRAMGB is the sum across all reported GPUs. It is 0 for Vulkan, which does
// not query VRAM yet, and for CPU. CPUCount, MemoryGB and Hostname are filled
// in for every backend.
func DetectCapabilities() scheduler.WorkerCapabilities {
	// withHost stamps the backend-independent host fields onto caps.
	withHost := func(caps scheduler.WorkerCapabilities) scheduler.WorkerCapabilities {
		caps.CPUCount = runtime.NumCPU()
		caps.MemoryGB = getSystemMemoryGB()
		caps.Hostname = getHostname()
		return caps
	}

	// NVIDIA via NVML.
	if IsNVMLAvailable() {
		if count, err := GetGPUCount(); err == nil && count > 0 {
			if gpus, err := GetAllGPUInfo(); err == nil && len(gpus) > 0 {
				var vramGB float64
				for _, g := range gpus {
					vramGB += float64(g.MemoryTotal) / (1024 * 1024 * 1024) // presumably bytes -> GB
				}
				caps := withHost(scheduler.WorkerCapabilities{
					GPUBackend: scheduler.BackendNVIDIA,
					GPUCount:   count,
					GPUType:    gpus[0].Name,
					VRAMGB:     vramGB,
				})
				// MemTotal carries the first GPU's memory only; VRAMGB is the total.
				caps.GPUInfo = scheduler.GPUDetectionInfo{
					GPUType:  "nvidia",
					Count:    count,
					Devices:  getNVIDIADevices(),
					MemTotal: gpus[0].MemoryTotal,
				}
				return caps
			}
		}
	}

	// Metal on macOS Apple Silicon.
	if runtime.GOOS == "darwin" && IsAppleSilicon() {
		if gpus, err := GetMacOSGPUInfo(); err == nil && len(gpus) > 0 {
			var vramGB float64
			for _, g := range gpus {
				vramGB += float64(g.VRAM_MB) / 1024 // MB -> GB
			}
			caps := withHost(scheduler.WorkerCapabilities{
				GPUBackend: scheduler.BackendMetal,
				GPUCount:   len(gpus),
				GPUType:    gpus[0].ChipsetModel,
				VRAMGB:     vramGB,
			})
			caps.GPUInfo = scheduler.GPUDetectionInfo{
				GPUType:  "apple",
				Count:    len(gpus),
				MemTotal: uint64(vramGB * 1024 * 1024 * 1024),
			}
			return caps
		}
	}

	// Vulkan (vulkaninfo on PATH or /dev/dri present).
	if hasVulkan() {
		if count := getVulkanGPUCount(); count > 0 {
			caps := withHost(scheduler.WorkerCapabilities{
				GPUBackend: scheduler.BackendVulkan,
				GPUCount:   count,
				GPUType:    "vulkan",
				VRAMGB:     0, // TODO: Query Vulkan for VRAM
			})
			caps.GPUInfo = scheduler.GPUDetectionInfo{
				GPUType: "vulkan",
				Count:   count,
				Devices: getVulkanDevices(),
			}
			return caps
		}
	}

	// CPU-only fallback.
	return withHost(scheduler.WorkerCapabilities{
		GPUBackend: scheduler.BackendCPU,
		GPUCount:   0,
		GPUType:    "cpu",
		VRAMGB:     0,
		GPUInfo: scheduler.GPUDetectionInfo{
			GPUType: "cpu",
			Count:   0,
		},
	})
}
// hasVulkan reports whether Vulkan looks available on this host: either the
// vulkaninfo tool is on PATH or the /dev/dri DRM device directory exists.
func hasVulkan() bool {
	_, lookErr := exec.LookPath("vulkaninfo")
	if lookErr == nil {
		return true
	}
	_, statErr := os.Stat("/dev/dri")
	return statErr == nil
}
// getVulkanGPUCount returns the number of Vulkan physical devices listed by
// `vulkaninfo --summary`.
//
// It never returns less than 1. The caller only uses it after hasVulkan has
// found vulkaninfo or /dev/dri, so one device is assumed when vulkaninfo is
// missing, fails, or lists no hardware devices.
func getVulkanGPUCount() int {
	out, err := exec.Command("vulkaninfo", "--summary").Output()
	if err != nil {
		return 1 // Assume 1 if vulkaninfo fails but /dev/dri exists
	}
	if count := countVulkanSummaryGPUs(string(out)); count > 0 {
		return count
	}
	return 1 // Assume at least 1
}

// countVulkanSummaryGPUs counts the device sections ("GPU0:", "GPU1:", ...) in
// `vulkaninfo --summary` output. Devices whose deviceType is
// PHYSICAL_DEVICE_TYPE_CPU (software renderers such as llvmpipe) are skipped.
func countVulkanSummaryGPUs(summary string) int {
	count := 0
	inDevice, cpuDevice := false, false

	// flush counts the device section that just ended, if it was hardware.
	flush := func() {
		if inDevice && !cpuDevice {
			count++
		}
	}

	for _, raw := range strings.Split(summary, "\n") {
		line := strings.TrimSpace(raw)

		// Device header: "GPU" followed by one or more digits and a colon.
		if name, ok := strings.CutSuffix(line, ":"); ok {
			if num, ok := strings.CutPrefix(name, "GPU"); ok && num != "" && strings.Trim(num, "0123456789") == "" {
				flush()
				inDevice, cpuDevice = true, false
				continue
			}
		}

		if inDevice && strings.HasPrefix(line, "deviceType") && strings.Contains(line, "PHYSICAL_DEVICE_TYPE_CPU") {
			cpuDevice = true
		}
	}
	flush()
	return count
}
// getVulkanDevices returns the device paths to expose for Vulkan: just the
// /dev/dri DRM directory when it exists, nil otherwise.
func getVulkanDevices() []string {
	const driPath = "/dev/dri"
	if _, err := os.Stat(driPath); err != nil {
		return nil
	}
	return []string{driPath}
}
// getNVIDIADevices returns every /dev/nvidia* entry that os.Stat can resolve,
// in the order filepath.Glob returns them. The result is nil when none exist.
func getNVIDIADevices() []string {
	// The pattern is a constant, so Glob cannot report a bad-pattern error here.
	candidates, _ := filepath.Glob("/dev/nvidia*")
	var found []string
	for _, path := range candidates {
		if _, err := os.Stat(path); err != nil {
			continue
		}
		found = append(found, path)
	}
	return found
}
// getSystemMemoryGB returns total system memory in GB (base-1024 units), or 0
// when it cannot be determined.
//
// On Linux it reads the MemTotal line of /proc/meminfo, which is reported in
// kB. When /proc/meminfo is unavailable on macOS, it falls back to
// `sysctl -n hw.memsize`, which prints the physical memory size in bytes. This
// lets Apple Silicon workers report their memory instead of 0.
func getSystemMemoryGB() float64 {
	if data, err := os.ReadFile("/proc/meminfo"); err == nil {
		for _, line := range strings.Split(string(data), "\n") {
			if !strings.HasPrefix(line, "MemTotal:") {
				continue
			}
			var kb uint64
			// Only the number matters; a missing " kB" suffix still leaves n == 1.
			if n, _ := fmt.Sscanf(line, "MemTotal: %d kB", &kb); n != 1 {
				return 0 // malformed MemTotal line: report unknown
			}
			return float64(kb) / (1024 * 1024) // kB -> GB
		}
	}

	if runtime.GOOS == "darwin" {
		if out, err := exec.Command("sysctl", "-n", "hw.memsize").Output(); err == nil {
			var memBytes uint64
			if _, err := fmt.Sscan(strings.TrimSpace(string(out)), &memBytes); err == nil {
				return float64(memBytes) / (1024 * 1024 * 1024) // bytes -> GB
			}
		}
	}

	// Unknown platform or every probe failed.
	return 0
}
// getHostname returns the name os.Hostname reports for this machine. Errors
// are deliberately ignored; on failure, whatever name os.Hostname returned
// (normally "") is passed through.
func getHostname() string {
	name, err := os.Hostname()
	_ = err // best-effort: callers only use the hostname for reporting
	return name
}