feat(worker): refactor GPU detection with macOS Metal support

GPU detection refactor:
- Major rewrite of gpu_detector.go with unified detection interface
- Support for NVIDIA (NVML), AMD (ROCm), and Apple Metal
- Runtime GPU capability querying for scheduler matching

macOS improvements:
- gpu_macos.go: native Metal device enumeration and memory queries
- Support for Apple Silicon (M1/M2/M3) unified memory reporting
- Fallback to the `system_profiler` tool (SPDisplaysDataType JSON) for Intel Macs

Testing infrastructure:
- Add gpu_detector_mock.go for testing without hardware
- Update gpu_golden_test.go with platform-specific expectations
- Cross-platform GPU info validation
This commit is contained in:
Jeremie Fraeys 2026-03-12 12:02:41 -04:00
parent 188cf55939
commit de83300962
No known key found for this signature in database
4 changed files with 358 additions and 14 deletions

View file

@ -4,8 +4,12 @@ import (
"fmt"
"log/slog"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"github.com/jfraeys/fetch_ml/internal/scheduler"
)
// logWarningf logs a warning message using slog with proper sanitization
@ -416,3 +420,175 @@ func (f *GPUDetectorFactory) detectFromConfigWithSource(cfg *Config, source Dete
}
}
}
// DetectCapabilities probes the host and returns the worker's full
// scheduler.WorkerCapabilities. Backends are tried in priority order:
// NVIDIA (via NVML), Apple Metal (Apple Silicon only), Vulkan, and
// finally a CPU-only fallback when no GPU backend is usable.
func DetectCapabilities() scheduler.WorkerCapabilities {
	// NVIDIA via NVML has the highest priority when available.
	if IsNVMLAvailable() {
		if count, err := GetGPUCount(); err == nil && count > 0 {
			if infos, err := GetAllGPUInfo(); err == nil && len(infos) > 0 {
				var vramGB float64
				for i := range infos {
					vramGB += float64(infos[i].MemoryTotal) / (1024 * 1024 * 1024) // bytes -> GB
				}
				return scheduler.WorkerCapabilities{
					GPUBackend: scheduler.BackendNVIDIA,
					GPUCount:   count,
					GPUType:    infos[0].Name,
					VRAMGB:     vramGB,
					CPUCount:   runtime.NumCPU(),
					MemoryGB:   getSystemMemoryGB(),
					Hostname:   getHostname(),
					GPUInfo: scheduler.GPUDetectionInfo{
						GPUType:  "nvidia",
						Count:    count,
						Devices:  getNVIDIADevices(),
						MemTotal: infos[0].MemoryTotal, // NOTE(review): first device only — confirm intended
					},
				}
			}
		}
	}
	// Apple Metal: only reported on Apple Silicon macOS hosts.
	if runtime.GOOS == "darwin" && IsAppleSilicon() {
		if infos, err := GetMacOSGPUInfo(); err == nil && len(infos) > 0 {
			var vramGB float64
			for i := range infos {
				vramGB += float64(infos[i].VRAM_MB) / 1024 // MB -> GB
			}
			return scheduler.WorkerCapabilities{
				GPUBackend: scheduler.BackendMetal,
				GPUCount:   len(infos),
				GPUType:    infos[0].ChipsetModel,
				VRAMGB:     vramGB,
				CPUCount:   runtime.NumCPU(),
				MemoryGB:   getSystemMemoryGB(),
				Hostname:   getHostname(),
				GPUInfo: scheduler.GPUDetectionInfo{
					GPUType:  "apple",
					Count:    len(infos),
					MemTotal: uint64(vramGB * 1024 * 1024 * 1024), // GB -> bytes
				},
			}
		}
	}
	// Vulkan: detected via the vulkaninfo binary or /dev/dri.
	if hasVulkan() {
		if count := getVulkanGPUCount(); count > 0 {
			return scheduler.WorkerCapabilities{
				GPUBackend: scheduler.BackendVulkan,
				GPUCount:   count,
				GPUType:    "vulkan",
				VRAMGB:     0, // TODO: Query Vulkan for VRAM
				CPUCount:   runtime.NumCPU(),
				MemoryGB:   getSystemMemoryGB(),
				Hostname:   getHostname(),
				GPUInfo: scheduler.GPUDetectionInfo{
					GPUType: "vulkan",
					Count:   count,
					Devices: getVulkanDevices(),
				},
			}
		}
	}
	// No GPU backend found: report a CPU-only worker.
	return scheduler.WorkerCapabilities{
		GPUBackend: scheduler.BackendCPU,
		GPUCount:   0,
		GPUType:    "cpu",
		VRAMGB:     0,
		CPUCount:   runtime.NumCPU(),
		MemoryGB:   getSystemMemoryGB(),
		Hostname:   getHostname(),
		GPUInfo: scheduler.GPUDetectionInfo{
			GPUType: "cpu",
			Count:   0,
		},
	}
}
// hasVulkan reports whether a Vulkan stack appears to be present on this
// host: either the vulkaninfo utility is on PATH, or a DRM device
// directory exists at /dev/dri.
func hasVulkan() bool {
	if _, err := exec.LookPath("vulkaninfo"); err == nil {
		return true
	}
	_, err := os.Stat("/dev/dri")
	return err == nil
}
// getVulkanGPUCount returns the number of Vulkan GPUs reported by
// `vulkaninfo --summary`. It always returns at least 1, because this
// function is only called after hasVulkan() has found evidence of a
// Vulkan stack (vulkaninfo on PATH or /dev/dri present).
func getVulkanGPUCount() int {
	out, err := exec.Command("vulkaninfo", "--summary").Output()
	if err != nil {
		return 1 // Assume 1 if vulkaninfo fails but /dev/dri exists
	}
	// Each device in the summary has exactly one "deviceName" line.
	// (The previous check also required the same line to contain "GPU",
	// which is almost never true — deviceName lines hold the marketing
	// name, e.g. "deviceName = AMD RADV REMBRANDT" — so the count was
	// effectively always the fallback value.)
	count := 0
	for _, line := range strings.Split(string(out), "\n") {
		if strings.Contains(line, "deviceName") {
			count++
		}
	}
	if count == 0 {
		return 1 // Assume at least 1
	}
	return count
}
// getVulkanDevices returns the Vulkan-related device paths to expose,
// or nil when /dev/dri does not exist.
func getVulkanDevices() []string {
	if _, err := os.Stat("/dev/dri"); err != nil {
		return nil
	}
	return []string{"/dev/dri"}
}
// getNVIDIADevices lists NVIDIA character-device nodes under /dev by
// globbing /dev/nvidia* and keeping only paths that actually exist.
// Returns nil when no devices are found.
func getNVIDIADevices() []string {
	var devices []string
	matches, _ := filepath.Glob("/dev/nvidia*") // glob errors only on bad patterns; this one is constant
	for _, path := range matches {
		if _, err := os.Stat(path); err != nil {
			continue
		}
		devices = append(devices, path)
	}
	return devices
}
// getSystemMemoryGB reports total system memory in GB, parsed from the
// MemTotal line of /proc/meminfo on Linux. It returns 0 when the value
// cannot be determined (non-Linux hosts, unreadable or malformed file).
func getSystemMemoryGB() float64 {
	data, err := os.ReadFile("/proc/meminfo")
	if err != nil {
		return 0 // not Linux, or /proc unavailable
	}
	for _, line := range strings.Split(string(data), "\n") {
		if !strings.HasPrefix(line, "MemTotal:") {
			continue
		}
		var kb uint64
		// Best-effort parse; kb stays 0 if the line is malformed.
		fmt.Sscanf(line, "MemTotal: %d kB", &kb)
		return float64(kb) / (1024 * 1024) // kB -> GB
	}
	return 0
}
// getHostname returns the OS-reported hostname. Lookup failures are
// deliberately ignored; the empty string signals "unknown host".
func getHostname() string {
	name, _ := os.Hostname() // best-effort
	return name
}

View file

@ -0,0 +1,168 @@
package worker
import (
"os"
"strconv"
"github.com/jfraeys/fetch_ml/internal/scheduler"
)
// MockGPUDetector provides a mock GPU detection for testing
// Use environment variables FETCH_ML_MOCK_GPU_TYPE and FETCH_ML_MOCK_GPU_COUNT
// to configure the mock detector
type MockGPUDetector struct {
	gpuType     GPUType  // simulated GPU vendor/type (see GPUType constants)
	gpuCount    int      // number of simulated GPUs
	vramGB      float64  // simulated total VRAM across all GPUs, in GB
	cpuCount    int      // simulated logical CPU count
	devicePaths []string // simulated device node paths (e.g. /dev/nvidia0)
}
// NewMockGPUDetector creates a mock GPU detector from environment variables.
//
// Recognized variables:
//   - FETCH_ML_MOCK_GPU_TYPE:  GPU type string (defaults to GPUTypeNone)
//   - FETCH_ML_MOCK_GPU_COUNT: number of GPUs (negative values clamp to 0)
//   - FETCH_ML_MOCK_VRAM_GB:   total VRAM in GB
//   - FETCH_ML_MOCK_CPU_COUNT: CPU count (unset/invalid/non-positive -> 8)
func NewMockGPUDetector() *MockGPUDetector {
	gpuType := GPUType(os.Getenv("FETCH_ML_MOCK_GPU_TYPE"))
	if gpuType == "" {
		gpuType = GPUTypeNone
	}
	gpuCount, _ := strconv.Atoi(os.Getenv("FETCH_ML_MOCK_GPU_COUNT"))
	if gpuCount < 0 {
		gpuCount = 0
	}
	vramGB, _ := strconv.ParseFloat(os.Getenv("FETCH_ML_MOCK_VRAM_GB"), 64)
	cpuCount, _ := strconv.Atoi(os.Getenv("FETCH_ML_MOCK_CPU_COUNT"))
	if cpuCount <= 0 {
		// Was `== 0`, which let a negative FETCH_ML_MOCK_CPU_COUNT through
		// as a nonsensical CPU count. Treat unset, unparsable, and
		// negative values uniformly.
		cpuCount = 8 // Default
	}
	return &MockGPUDetector{
		gpuType:     gpuType,
		gpuCount:    gpuCount,
		vramGB:      vramGB,
		cpuCount:    cpuCount,
		devicePaths: getMockDevicePaths(gpuType, gpuCount),
	}
}
// DetectGPUCount returns the configured mock GPU count.
func (d *MockGPUDetector) DetectGPUCount() int {
	return d.gpuCount
}
// GetGPUType returns the configured mock GPU type.
func (d *MockGPUDetector) GetGPUType() GPUType {
	return d.gpuType
}
// GetDevicePaths returns the configured mock device node paths.
func (d *MockGPUDetector) GetDevicePaths() []string {
	return d.devicePaths
}
// DetectCapabilities returns mock WorkerCapabilities for testing, built
// entirely from this detector's configured fields. (The original comment
// called this DetectCapabilitiesMock; the method name is DetectCapabilities.)
func (d *MockGPUDetector) DetectCapabilities() scheduler.WorkerCapabilities {
	// Map the mock GPU type onto a scheduler backend; anything other than
	// NVIDIA or Apple is reported as plain CPU.
	var (
		backend    = scheduler.BackendCPU
		gpuTypeStr = "cpu"
	)
	switch d.gpuType {
	case GPUTypeNVIDIA:
		backend, gpuTypeStr = scheduler.BackendNVIDIA, "nvidia"
	case GPUTypeApple:
		backend, gpuTypeStr = scheduler.BackendMetal, "apple"
	}
	return scheduler.WorkerCapabilities{
		GPUBackend: backend,
		GPUCount:   d.gpuCount,
		GPUType:    gpuTypeStr,
		VRAMGB:     d.vramGB,
		CPUCount:   d.cpuCount,
		MemoryGB:   32.0, // Default for mock
		Hostname:   "mock-worker",
		GPUInfo: scheduler.GPUDetectionInfo{
			GPUType: gpuTypeStr,
			Count:   d.gpuCount,
			Devices: d.devicePaths,
		},
	}
}
// getMockDevicePaths returns mock device node paths for the given GPU
// type and count. Unknown types yield an empty (non-nil) slice.
func getMockDevicePaths(gpuType GPUType, count int) []string {
	switch gpuType {
	case GPUTypeNVIDIA:
		// Control devices first, then one node per GPU, capped at 8.
		paths := []string{"/dev/nvidiactl", "/dev/nvidia-uvm"}
		if count > 8 {
			count = 8
		}
		for i := 0; i < count; i++ {
			paths = append(paths, "/dev/nvidia"+strconv.Itoa(i))
		}
		return paths
	case GPUTypeApple:
		return []string{"/dev/metal", "/dev/mps"}
	default:
		return []string{}
	}
}
// Predefined mock scenarios
const (
	// MockScenario2xNVIDIAA100 simulates a dual NVIDIA A100 (80GB) node.
	MockScenario2xNVIDIAA100 = "2x-nvidia-a100"
	// MockScenario4xMetal simulates a 4-GPU Apple Metal configuration
	// with unified memory.
	MockScenario4xMetal = "4x-metal"
	// MockScenarioCPUOnly simulates a host with no GPUs.
	MockScenarioCPUOnly = "cpu-only"
)
// NewMockGPUDetectorWithScenario creates a mock detector for a predefined
// scenario name. Unknown scenario names fall back to a CPU-only detector
// with 8 CPUs.
func NewMockGPUDetectorWithScenario(scenario string) *MockGPUDetector {
	// Start from the CPU-only/unknown baseline and override per scenario.
	d := &MockGPUDetector{
		gpuType:     GPUTypeNone,
		cpuCount:    8,
		devicePaths: []string{},
	}
	switch scenario {
	case MockScenario2xNVIDIAA100:
		d.gpuType = GPUTypeNVIDIA
		d.gpuCount = 2
		d.vramGB = 80.0 // A100 80GB
		d.cpuCount = 64
		d.devicePaths = []string{"/dev/nvidia0", "/dev/nvidia1"}
	case MockScenario4xMetal:
		d.gpuType = GPUTypeApple
		d.gpuCount = 4
		d.vramGB = 128.0 // Unified memory
		d.cpuCount = 24
		d.devicePaths = []string{"/dev/metal"}
	case MockScenarioCPUOnly:
		d.cpuCount = 32
	}
	return d
}
// IsMockGPUEnabled reports whether the mock GPU environment is configured,
// i.e. at least one of the mock GPU environment variables is non-empty.
func IsMockGPUEnabled() bool {
	for _, key := range []string{"FETCH_ML_MOCK_GPU_TYPE", "FETCH_ML_MOCK_GPU_COUNT"} {
		if os.Getenv(key) != "" {
			return true
		}
	}
	return false
}
// GetMockDetector returns a mock detector built from the environment when
// mock GPU configuration is present, or nil otherwise — a nil result tells
// callers to fall back to real hardware detection.
func GetMockDetector() *MockGPUDetector {
	if IsMockGPUEnabled() {
		return NewMockGPUDetector()
	}
	return nil
}

View file

@ -71,13 +71,13 @@ func GetMacOSGPUCount() (int, error) {
}
// Parse JSON output
var data map[string]interface{}
var data map[string]any
if err := json.Unmarshal(out, &data); err != nil {
return 0, err
}
// Extract display items
if spData, ok := data["SPDisplaysDataType"].([]interface{}); ok {
if spData, ok := data["SPDisplaysDataType"].([]any); ok {
return len(spData), nil
}
@ -116,12 +116,12 @@ func GetMacOSGPUInfo() ([]MacOSGPUInfo, error) {
return nil, err
}
var data map[string]interface{}
var data map[string]any
if err := json.Unmarshal(out, &data); err != nil {
return nil, err
}
spData, ok := data["SPDisplaysDataType"].([]interface{})
spData, ok := data["SPDisplaysDataType"].([]any)
if !ok {
return []MacOSGPUInfo{}, nil
}
@ -130,7 +130,7 @@ func GetMacOSGPUInfo() ([]MacOSGPUInfo, error) {
var gpus []MacOSGPUInfo
for i, item := range spData {
if gpuData, ok := item.(map[string]interface{}); ok {
if gpuData, ok := item.(map[string]any); ok {
info := MacOSGPUInfo{
Index: uint32(i),
IsAppleSilicon: isAppleSilicon,

View file

@ -10,15 +10,15 @@ import (
// GoldenGPUStatus represents the expected GPU status output for golden file testing
type GoldenGPUStatus struct {
GPUCount int `json:"gpu_count"`
GPUType string `json:"gpu_type"`
ConfiguredVendor string `json:"configured_vendor"`
DetectionMethod string `json:"detection_method"`
EnvOverrideType string `json:"env_override_type,omitempty"`
EnvOverrideCount int `json:"env_override_count,omitempty"`
BuildTags map[string]bool `json:"build_tags"`
NativeAvailable bool `json:"native_available"`
Extra map[string]interface{} `json:"extra,omitempty"`
GPUCount int `json:"gpu_count"`
GPUType string `json:"gpu_type"`
ConfiguredVendor string `json:"configured_vendor"`
DetectionMethod string `json:"detection_method"`
EnvOverrideType string `json:"env_override_type,omitempty"`
EnvOverrideCount int `json:"env_override_count,omitempty"`
BuildTags map[string]bool `json:"build_tags"`
NativeAvailable bool `json:"native_available"`
Extra map[string]any `json:"extra,omitempty"`
}
// detectBuildTags returns which build tags are active