168 lines
3.5 KiB
Go
168 lines
3.5 KiB
Go
package worker
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
)
|
|
|
|
// GPUType names the kind of GPU a detector reports.
type GPUType string

// Recognized GPUType values.
const (
	// GPUTypeNVIDIA is the value reported by NVIDIADetector.
	GPUTypeNVIDIA = GPUType("nvidia")
	// GPUTypeApple is the value reported by AppleDetector.
	GPUTypeApple = GPUType("apple")
	// GPUTypeNone is the value reported by NoneDetector.
	GPUTypeNone = GPUType("none")
)
|
|
|
|
// GPUDetector interface for detecting GPU availability
|
|
type GPUDetector interface {
|
|
DetectGPUCount() int
|
|
GetGPUType() GPUType
|
|
GetDevicePaths() []string
|
|
}
|
|
|
|
// NVIDIA GPUDetector implementation
|
|
type NVIDIADetector struct{}
|
|
|
|
func (d *NVIDIADetector) DetectGPUCount() int {
|
|
if n, ok := envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 {
|
|
return n
|
|
}
|
|
// Could use nvidia-sml or other detection methods here
|
|
return 0
|
|
}
|
|
|
|
func (d *NVIDIADetector) GetGPUType() GPUType {
|
|
return GPUTypeNVIDIA
|
|
}
|
|
|
|
func (d *NVIDIADetector) GetDevicePaths() []string {
|
|
// Prefer standard NVIDIA device nodes when present.
|
|
patterns := []string{
|
|
"/dev/nvidiactl",
|
|
"/dev/nvidia-modeset",
|
|
"/dev/nvidia-uvm",
|
|
"/dev/nvidia-uvm-tools",
|
|
"/dev/nvidia*",
|
|
}
|
|
seen := make(map[string]struct{})
|
|
out := make([]string, 0, 8)
|
|
for _, pat := range patterns {
|
|
if filepath.Base(pat) == pat {
|
|
continue
|
|
}
|
|
if strings.Contains(pat, "*") {
|
|
matches, _ := filepath.Glob(pat)
|
|
for _, m := range matches {
|
|
if _, ok := seen[m]; ok {
|
|
continue
|
|
}
|
|
if _, err := os.Stat(m); err != nil {
|
|
continue
|
|
}
|
|
seen[m] = struct{}{}
|
|
out = append(out, m)
|
|
}
|
|
continue
|
|
}
|
|
if _, ok := seen[pat]; ok {
|
|
continue
|
|
}
|
|
if _, err := os.Stat(pat); err != nil {
|
|
continue
|
|
}
|
|
seen[pat] = struct{}{}
|
|
out = append(out, pat)
|
|
}
|
|
// Fallback for non-NVIDIA setups where only generic DRM device exists.
|
|
if len(out) == 0 {
|
|
if _, err := os.Stat("/dev/dri"); err == nil {
|
|
out = append(out, "/dev/dri")
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
// Apple M-series GPUDetector implementation
|
|
type AppleDetector struct {
|
|
enabled bool
|
|
}
|
|
|
|
func (d *AppleDetector) DetectGPUCount() int {
|
|
if n, ok := envInt("FETCH_ML_GPU_COUNT"); ok && n >= 0 {
|
|
return n
|
|
}
|
|
if d.enabled {
|
|
return 1
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (d *AppleDetector) GetGPUType() GPUType {
|
|
return GPUTypeApple
|
|
}
|
|
|
|
func (d *AppleDetector) GetDevicePaths() []string {
|
|
return []string{"/dev/metal", "/dev/mps"}
|
|
}
|
|
|
|
// None GPUDetector implementation
|
|
type NoneDetector struct{}
|
|
|
|
func (d *NoneDetector) DetectGPUCount() int {
|
|
return 0
|
|
}
|
|
|
|
func (d *NoneDetector) GetGPUType() GPUType {
|
|
return GPUTypeNone
|
|
}
|
|
|
|
func (d *NoneDetector) GetDevicePaths() []string {
|
|
return nil
|
|
}
|
|
|
|
// GPUDetectorFactory creates appropriate GPU detector based on config
|
|
type GPUDetectorFactory struct{}
|
|
|
|
func (f *GPUDetectorFactory) CreateDetector(cfg *Config) GPUDetector {
|
|
// Check for explicit environment override
|
|
if gpuType := os.Getenv("FETCH_ML_GPU_TYPE"); gpuType != "" {
|
|
switch gpuType {
|
|
case string(GPUTypeNVIDIA):
|
|
return &NVIDIADetector{}
|
|
case string(GPUTypeApple):
|
|
return &AppleDetector{enabled: true}
|
|
case string(GPUTypeNone):
|
|
return &NoneDetector{}
|
|
}
|
|
}
|
|
|
|
// Respect configured vendor when explicitly set.
|
|
if cfg != nil {
|
|
switch GPUType(cfg.GPUVendor) {
|
|
case GPUTypeApple:
|
|
return &AppleDetector{enabled: cfg.AppleGPU.Enabled}
|
|
case GPUTypeNone:
|
|
return &NoneDetector{}
|
|
case GPUTypeNVIDIA:
|
|
return &NVIDIADetector{}
|
|
case "amd":
|
|
// AMD uses similar device exposure patterns in this codebase.
|
|
return &NVIDIADetector{}
|
|
}
|
|
}
|
|
|
|
// Auto-detect based on config
|
|
if cfg != nil {
|
|
if cfg.AppleGPU.Enabled {
|
|
return &AppleDetector{enabled: true}
|
|
}
|
|
if len(cfg.GPUDevices) > 0 {
|
|
return &NVIDIADetector{}
|
|
}
|
|
}
|
|
|
|
// Default to no GPU
|
|
return &NoneDetector{}
|
|
}
|