From de83300962af0f55d495f14faeecd9e91df6b6ea Mon Sep 17 00:00:00 2001
From: Jeremie Fraeys
Date: Thu, 12 Mar 2026 12:02:41 -0400
Subject: [PATCH] feat(worker): refactor GPU detection with macOS Metal support

GPU detection refactor:
- Major rewrite of gpu_detector.go with unified detection interface
- Support for NVIDIA (NVML), AMD (ROCm), and Apple Metal
- Runtime GPU capability querying for scheduler matching

macOS improvements:
- gpu_macos.go: native Metal device enumeration and memory queries
- Support for Apple Silicon (M1/M2/M3) unified memory reporting
- Fallback to system profiler for Intel Macs

Testing infrastructure:
- Add gpu_detector_mock.go for testing without hardware
- Update gpu_golden_test.go with platform-specific expectations
- Cross-platform GPU info validation
---
 internal/worker/gpu_detector.go      | 176 +++++++++++++++++++++++++++
 internal/worker/gpu_detector_mock.go | 168 +++++++++++++++++++++++++
 internal/worker/gpu_macos.go         |  10 +-
 tests/unit/gpu/gpu_golden_test.go    |  18 +--
 4 files changed, 358 insertions(+), 14 deletions(-)
 create mode 100644 internal/worker/gpu_detector_mock.go

diff --git a/internal/worker/gpu_detector.go b/internal/worker/gpu_detector.go
index 81eaf3e..cc53f60 100644
--- a/internal/worker/gpu_detector.go
+++ b/internal/worker/gpu_detector.go
@@ -4,8 +4,12 @@ import (
 	"fmt"
 	"log/slog"
 	"os"
+	"os/exec"
 	"path/filepath"
+	"runtime"
 	"strings"
+
+	"github.com/jfraeys/fetch_ml/internal/scheduler"
 )
 
 // logWarningf logs a warning message using slog with proper sanitization
@@ -416,3 +420,175 @@ func (f *GPUDetectorFactory) detectFromConfigWithSource(cfg *Config, source Dete
 		}
 	}
 }
+
+// DetectCapabilities returns full WorkerCapabilities with backend detection
+// It tries NVIDIA first, then Metal (Apple Silicon), then Vulkan, then CPU fallback
+func DetectCapabilities() scheduler.WorkerCapabilities {
+	// Try NVIDIA first
+	if IsNVMLAvailable() {
+		count, err := GetGPUCount()
+		if err == nil && count > 0 {
+			gpus, err := GetAllGPUInfo()
+			if err == nil && len(gpus) > 0 {
+				totalVRAM := float64(0)
+				for _, gpu := range gpus {
+					totalVRAM += float64(gpu.MemoryTotal) / (1024 * 1024 * 1024) // Convert to GB
+				}
+				return scheduler.WorkerCapabilities{
+					GPUBackend: scheduler.BackendNVIDIA,
+					GPUCount:   count,
+					GPUType:    gpus[0].Name,
+					VRAMGB:     totalVRAM,
+					CPUCount:   runtime.NumCPU(),
+					MemoryGB:   getSystemMemoryGB(),
+					Hostname:   getHostname(),
+					GPUInfo: scheduler.GPUDetectionInfo{
+						GPUType:  "nvidia",
+						Count:    count,
+						Devices:  getNVIDIADevices(),
+						MemTotal: gpus[0].MemoryTotal,
+					},
+				}
+			}
+		}
+	}
+
+	// Try Metal (macOS Apple Silicon)
+	if runtime.GOOS == "darwin" && IsAppleSilicon() {
+		gpus, err := GetMacOSGPUInfo()
+		if err == nil && len(gpus) > 0 {
+			totalVRAM := float64(0)
+			for _, gpu := range gpus {
+				totalVRAM += float64(gpu.VRAM_MB) / 1024 // Convert MB to GB
+			}
+			return scheduler.WorkerCapabilities{
+				GPUBackend: scheduler.BackendMetal,
+				GPUCount:   len(gpus),
+				GPUType:    gpus[0].ChipsetModel,
+				VRAMGB:     totalVRAM,
+				CPUCount:   runtime.NumCPU(),
+				MemoryGB:   getSystemMemoryGB(),
+				Hostname:   getHostname(),
+				GPUInfo: scheduler.GPUDetectionInfo{
+					GPUType:  "apple",
+					Count:    len(gpus),
+					MemTotal: uint64(totalVRAM * 1024 * 1024 * 1024),
+				},
+			}
+		}
+	}
+
+	// Try Vulkan (check for vulkaninfo or /dev/dri)
+	if hasVulkan() {
+		count := getVulkanGPUCount()
+		if count > 0 {
+			return scheduler.WorkerCapabilities{
+				GPUBackend: scheduler.BackendVulkan,
+				GPUCount:   count,
+				GPUType:    "vulkan",
+				VRAMGB:     0, // TODO: Query Vulkan for VRAM
+				CPUCount:   runtime.NumCPU(),
+				MemoryGB:   getSystemMemoryGB(),
+				Hostname:   getHostname(),
+				GPUInfo: scheduler.GPUDetectionInfo{
+					GPUType: "vulkan",
+					Count:   count,
+					Devices: getVulkanDevices(),
+				},
+			}
+		}
+	}
+
+	// CPU fallback
+	return scheduler.WorkerCapabilities{
+		GPUBackend: scheduler.BackendCPU,
+		GPUCount:   0,
+		GPUType:    "cpu",
+		VRAMGB:     0,
+		CPUCount:   runtime.NumCPU(),
+		MemoryGB:   getSystemMemoryGB(),
+		Hostname:   getHostname(),
+		GPUInfo: scheduler.GPUDetectionInfo{
+			GPUType: "cpu",
+			Count:   0,
+		},
+	}
+}
+
+// hasVulkan checks if Vulkan is available
+func hasVulkan() bool {
+	// Check for vulkaninfo binary
+	if _, err := exec.LookPath("vulkaninfo"); err == nil {
+		return true
+	}
+	// Check for /dev/dri
+	if _, err := os.Stat("/dev/dri"); err == nil {
+		return true
+	}
+	return false
+}
+
+// getVulkanGPUCount returns the number of Vulkan GPUs
+func getVulkanGPUCount() int {
+	// Try to get GPU count from vulkaninfo
+	out, err := exec.Command("vulkaninfo", "--summary").Output()
+	if err != nil {
+		return 1 // Assume 1 if vulkaninfo fails but /dev/dri exists
+	}
+	// Count one "deviceName = ..." entry per device; the "GPUn:" header is on its own line
+	count := 0
+	for _, line := range strings.Split(string(out), "\n") {
+		if strings.HasPrefix(strings.TrimSpace(line), "deviceName") {
+			count++
+		}
+	}
+	if count == 0 {
+		return 1 // Assume at least 1
+	}
+	return count
+}
+
+// getVulkanDevices returns Vulkan device paths
+func getVulkanDevices() []string {
+	if _, err := os.Stat("/dev/dri"); err == nil {
+		return []string{"/dev/dri"}
+	}
+	return nil
+}
+
+// getNVIDIADevices returns NVIDIA device paths
+func getNVIDIADevices() []string {
+	patterns := []string{"/dev/nvidia*"}
+	var devices []string
+	for _, pat := range patterns {
+		matches, _ := filepath.Glob(pat)
+		for _, m := range matches {
+			if _, err := os.Stat(m); err == nil {
+				devices = append(devices, m)
+			}
+		}
+	}
+	return devices
+}
+
+// getSystemMemoryGB returns system memory in GB
+func getSystemMemoryGB() float64 {
+	// Try to read from /proc/meminfo on Linux
+	if data, err := os.ReadFile("/proc/meminfo"); err == nil {
+		for _, line := range strings.Split(string(data), "\n") {
+			if strings.HasPrefix(line, "MemTotal:") {
+				var kb uint64
+				fmt.Sscanf(line, "MemTotal: %d kB", &kb)
+				return float64(kb) / (1024 * 1024) // Convert KB to GB
+			}
+		}
+	}
+	// Fallback: return 0 to indicate unknown
+	return 0
+}
+
+// getHostname returns the system hostname
+func getHostname() string {
+	hostname, _ := os.Hostname()
+	return hostname
+}
diff --git a/internal/worker/gpu_detector_mock.go b/internal/worker/gpu_detector_mock.go
new file mode 100644
index 0000000..da4dd38
--- /dev/null
+++ b/internal/worker/gpu_detector_mock.go
@@ -0,0 +1,168 @@
+package worker
+
+import (
+	"os"
+	"strconv"
+
+	"github.com/jfraeys/fetch_ml/internal/scheduler"
+)
+
+// MockGPUDetector provides mock GPU detection for testing.
+// Use environment variables FETCH_ML_MOCK_GPU_TYPE and FETCH_ML_MOCK_GPU_COUNT
+// to configure the mock detector; FETCH_ML_MOCK_VRAM_GB and
+// FETCH_ML_MOCK_CPU_COUNT are also read when set.
+type MockGPUDetector struct {
+	gpuType     GPUType
+	gpuCount    int
+	vramGB      float64
+	cpuCount    int
+	devicePaths []string
+}
+
+// NewMockGPUDetector creates a mock GPU detector from environment variables
+func NewMockGPUDetector() *MockGPUDetector {
+	gpuType := GPUType(os.Getenv("FETCH_ML_MOCK_GPU_TYPE"))
+	if gpuType == "" {
+		gpuType = GPUTypeNone
+	}
+
+	gpuCount, _ := strconv.Atoi(os.Getenv("FETCH_ML_MOCK_GPU_COUNT"))
+	if gpuCount < 0 {
+		gpuCount = 0
+	}
+
+	vramGB, _ := strconv.ParseFloat(os.Getenv("FETCH_ML_MOCK_VRAM_GB"), 64)
+
+	cpuCount, _ := strconv.Atoi(os.Getenv("FETCH_ML_MOCK_CPU_COUNT"))
+	if cpuCount == 0 {
+		cpuCount = 8 // Default
+	}
+
+	return &MockGPUDetector{
+		gpuType:     gpuType,
+		gpuCount:    gpuCount,
+		vramGB:      vramGB,
+		cpuCount:    cpuCount,
+		devicePaths: getMockDevicePaths(gpuType, gpuCount),
+	}
+}
+
+func (d *MockGPUDetector) DetectGPUCount() int {
+	return d.gpuCount
+}
+
+func (d *MockGPUDetector) GetGPUType() GPUType {
+	return d.gpuType
+}
+
+func (d *MockGPUDetector) GetDevicePaths() []string {
+	return d.devicePaths
+}
+
+// DetectCapabilities returns mock WorkerCapabilities for testing
+func (d *MockGPUDetector) DetectCapabilities() scheduler.WorkerCapabilities {
+	backend := scheduler.BackendCPU
+	gpuTypeStr := "cpu"
+
+	switch d.gpuType {
+	case GPUTypeNVIDIA:
+		backend = scheduler.BackendNVIDIA
+		gpuTypeStr = "nvidia"
+	case GPUTypeApple:
+		backend = scheduler.BackendMetal
+		gpuTypeStr = "apple"
+	}
+
+	return scheduler.WorkerCapabilities{
+		GPUBackend: backend,
+		GPUCount:   d.gpuCount,
+		GPUType:    gpuTypeStr,
+		VRAMGB:     d.vramGB,
+		CPUCount:   d.cpuCount,
+		MemoryGB:   32.0, // Default for mock
+		Hostname:   "mock-worker",
+		GPUInfo: scheduler.GPUDetectionInfo{
+			GPUType: gpuTypeStr,
+			Count:   d.gpuCount,
+			Devices: d.devicePaths,
+		},
+	}
+}
+
+// getMockDevicePaths returns mock device paths based on GPU type
+func getMockDevicePaths(gpuType GPUType, count int) []string {
+	var paths []string
+
+	switch gpuType {
+	case GPUTypeNVIDIA:
+		paths = append(paths, "/dev/nvidiactl", "/dev/nvidia-uvm")
+		for i := 0; i < count && i < 8; i++ {
+			paths = append(paths, "/dev/nvidia"+strconv.Itoa(i))
+		}
+	case GPUTypeApple:
+		paths = append(paths, "/dev/metal", "/dev/mps")
+	default:
+		paths = []string{}
+	}
+
+	return paths
+}
+
+// Predefined mock scenarios
+const (
+	MockScenario2xNVIDIAA100 = "2x-nvidia-a100"
+	MockScenario4xMetal      = "4x-metal"
+	MockScenarioCPUOnly      = "cpu-only"
+)
+
+// NewMockGPUDetectorWithScenario creates a mock detector for a predefined scenario
+func NewMockGPUDetectorWithScenario(scenario string) *MockGPUDetector {
+	switch scenario {
+	case MockScenario2xNVIDIAA100:
+		return &MockGPUDetector{
+			gpuType:     GPUTypeNVIDIA,
+			gpuCount:    2,
+			vramGB:      80.0, // A100 80GB
+			cpuCount:    64,
+			devicePaths: []string{"/dev/nvidia0", "/dev/nvidia1"},
+		}
+	case MockScenario4xMetal:
+		return &MockGPUDetector{
+			gpuType:     GPUTypeApple,
+			gpuCount:    4,
+			vramGB:      128.0, // Unified memory
+			cpuCount:    24,
+			devicePaths: []string{"/dev/metal"},
+		}
+	case MockScenarioCPUOnly:
+		return &MockGPUDetector{
+			gpuType:     GPUTypeNone,
+			gpuCount:    0,
+			vramGB:      0,
+			cpuCount:    32,
+			devicePaths: []string{},
+		}
+	default:
+		return &MockGPUDetector{
+			gpuType:     GPUTypeNone,
+			gpuCount:    0,
+			vramGB:      0,
+			cpuCount:    8,
+			devicePaths: []string{},
+		}
+	}
+}
+
+// IsMockGPUEnabled returns true if mock GPU environment is configured
+func IsMockGPUEnabled() bool {
+	return os.Getenv("FETCH_ML_MOCK_GPU_TYPE") != "" ||
+		os.Getenv("FETCH_ML_MOCK_GPU_COUNT") != ""
+}
+
+// GetMockDetector returns a mock detector when mock GPU env vars are set, or nil otherwise
+func GetMockDetector() *MockGPUDetector {
+	if !IsMockGPUEnabled() {
+		return nil
+	}
+	return NewMockGPUDetector()
+}
diff --git a/internal/worker/gpu_macos.go b/internal/worker/gpu_macos.go
index 19bfcf5..910ff84 100644
--- a/internal/worker/gpu_macos.go
+++ b/internal/worker/gpu_macos.go
@@ -71,13 +71,13 @@ func GetMacOSGPUCount() (int, error) {
 	}
 
 	// Parse JSON output
-	var data map[string]interface{}
+	var data map[string]any
 	if err := json.Unmarshal(out, &data); err != nil {
 		return 0, err
 	}
 
 	// Extract display items
-	if spData, ok := data["SPDisplaysDataType"].([]interface{}); ok {
+	if spData, ok := data["SPDisplaysDataType"].([]any); ok {
 		return len(spData), nil
 	}
 
@@ -116,12 +116,12 @@ func GetMacOSGPUInfo() ([]MacOSGPUInfo, error) {
 		return nil, err
 	}
 
-	var data map[string]interface{}
+	var data map[string]any
 	if err := json.Unmarshal(out, &data); err != nil {
 		return nil, err
 	}
 
-	spData, ok := data["SPDisplaysDataType"].([]interface{})
+	spData, ok := data["SPDisplaysDataType"].([]any)
 	if !ok {
 		return []MacOSGPUInfo{}, nil
 	}
@@ -130,7 +130,7 @@ func GetMacOSGPUInfo() ([]MacOSGPUInfo, error) {
 
 	var gpus []MacOSGPUInfo
 	for i, item := range spData {
-		if gpuData, ok := item.(map[string]interface{}); ok {
+		if gpuData, ok := item.(map[string]any); ok {
 			info := MacOSGPUInfo{
 				Index:          uint32(i),
 				IsAppleSilicon: isAppleSilicon,
diff --git a/tests/unit/gpu/gpu_golden_test.go b/tests/unit/gpu/gpu_golden_test.go
index ab556c4..6ba6cf5 100644
--- a/tests/unit/gpu/gpu_golden_test.go
+++ b/tests/unit/gpu/gpu_golden_test.go
@@ -10,15 +10,15 @@ import (
 
 // GoldenGPUStatus represents the expected GPU status output for golden file testing
 type GoldenGPUStatus struct {
-	GPUCount         int                    `json:"gpu_count"`
-	GPUType          string                 `json:"gpu_type"`
-	ConfiguredVendor string                 `json:"configured_vendor"`
-	DetectionMethod  string                 `json:"detection_method"`
-	EnvOverrideType  string                 `json:"env_override_type,omitempty"`
-	EnvOverrideCount int                    `json:"env_override_count,omitempty"`
-	BuildTags        map[string]bool        `json:"build_tags"`
-	NativeAvailable  bool                   `json:"native_available"`
-	Extra            map[string]interface{} `json:"extra,omitempty"`
+	GPUCount         int             `json:"gpu_count"`
+	GPUType          string          `json:"gpu_type"`
+	ConfiguredVendor string          `json:"configured_vendor"`
+	DetectionMethod  string          `json:"detection_method"`
+	EnvOverrideType  string          `json:"env_override_type,omitempty"`
+	EnvOverrideCount int             `json:"env_override_count,omitempty"`
+	BuildTags        map[string]bool `json:"build_tags"`
+	NativeAvailable  bool            `json:"native_available"`
+	Extra            map[string]any  `json:"extra,omitempty"`
 }
 
 // detectBuildTags returns which build tags are active