fetch_ml/tests/unit/gpu/gpu_detector_test.go
Jeremie Fraeys 3b194ff2e8
Some checks failed
Build CLI with Embedded SQLite / build (arm64, aarch64-linux) (push) Waiting to run
Build CLI with Embedded SQLite / build (x86_64, x86_64-linux) (push) Waiting to run
Build CLI with Embedded SQLite / build-macos (arm64) (push) Waiting to run
Build CLI with Embedded SQLite / build-macos (x86_64) (push) Waiting to run
Security Scan / Security Analysis (push) Waiting to run
Security Scan / Native Library Security (push) Waiting to run
Checkout test / test (push) Successful in 6s
CI/CD Pipeline / Test (push) Failing after 1s
CI/CD Pipeline / Dev Compose Smoke Test (push) Has been skipped
CI/CD Pipeline / Build (push) Has been skipped
CI/CD Pipeline / Test Scripts (push) Has been skipped
CI/CD Pipeline / Test Native Libraries (push) Has been skipped
CI/CD Pipeline / GPU Golden Test Matrix (push) Has been skipped
Documentation / build-and-publish (push) Failing after 39s
CI/CD Pipeline / Docker Build (push) Has been skipped
feat: GPU detection transparency and artifact scanner improvements
- Surface GPUDetectionInfo from parseGPUCountFromConfig for detection metadata
- Document FETCH_ML_TOTAL_CPU and FETCH_ML_GPU_SLOTS_PER_GPU env vars
- Add debug logging for all env var overrides to stderr
- Track config-layer auto-detection in GPUDetectionInfo.ConfigLayerAutoDetected
- Add --include-all flag to artifact scanner (includeAll parameter)
- Add AMD production mode enforcement (error in non-local mode)
- Add GPU detector unit tests for env overrides and AMD aliasing
2026-02-23 12:29:34 -05:00

210 lines
6.1 KiB
Go

package worker_test
import (
"os"
"testing"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// TestGPUDetectorEnvOverrides validates that the FETCH_ML_GPU_TYPE and
// FETCH_ML_GPU_COUNT overrides are reflected in the detection info returned by
// GPUDetectorFactory.CreateDetectorWithInfo when no config is supplied.
//
// Every case starts from a clean environment: both variables are unset (and
// their original values restored on cleanup) before the case's own overrides
// are applied. Without this, an ambient value on a developer machine or CI
// runner would change the expected detection method.
func TestGPUDetectorEnvOverrides(t *testing.T) {
	tests := []struct {
		name           string
		gpuType        string
		gpuCount       string
		wantType       worker.GPUType
		wantCount      int // expected EnvOverrideCount; only checked when gpuCount is set
		wantMethod     worker.DetectionSource
		wantConfigured string
	}{
		{
			name:           "env type only - nvidia",
			gpuType:        "nvidia",
			wantType:       worker.GPUTypeNVIDIA,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "nvidia",
		},
		{
			name:           "env type only - apple",
			gpuType:        "apple",
			wantType:       worker.GPUTypeApple,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "apple",
		},
		{
			name:           "env type only - none",
			gpuType:        "none",
			wantType:       worker.GPUTypeNone,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "none",
		},
		{
			name:           "both env vars set",
			gpuType:        "nvidia",
			gpuCount:       "4",
			wantType:       worker.GPUTypeNVIDIA,
			wantCount:      4,
			wantMethod:     worker.DetectionSourceEnvBoth,
			wantConfigured: "nvidia",
		},
		{
			name:           "env type amd - shows amd configured vendor",
			gpuType:        "amd",
			wantType:       worker.GPUTypeAMD,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "amd",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Clear both overrides so only this case's values are visible.
			// t.Setenv registers a cleanup that restores whatever value (or
			// absence) the variable had before this subtest started.
			for _, key := range []string{"FETCH_ML_GPU_TYPE", "FETCH_ML_GPU_COUNT"} {
				t.Setenv(key, "")
				if err := os.Unsetenv(key); err != nil {
					t.Fatalf("unset %s: %v", key, err)
				}
			}
			if tt.gpuType != "" {
				t.Setenv("FETCH_ML_GPU_TYPE", tt.gpuType)
			}
			if tt.gpuCount != "" {
				t.Setenv("FETCH_ML_GPU_COUNT", tt.gpuCount)
			}

			factory := &worker.GPUDetectorFactory{}
			result := factory.CreateDetectorWithInfo(nil)

			if result.Info.GPUType != tt.wantType {
				t.Errorf("GPUType = %v, want %v", result.Info.GPUType, tt.wantType)
			}
			if result.Info.DetectionMethod != tt.wantMethod {
				t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantMethod)
			}
			if result.Info.ConfiguredVendor != tt.wantConfigured {
				t.Errorf("ConfiguredVendor = %v, want %v", result.Info.ConfiguredVendor, tt.wantConfigured)
			}
			if tt.gpuCount != "" && result.Info.EnvOverrideCount != tt.wantCount {
				t.Errorf("EnvOverrideCount = %v, want %v", result.Info.EnvOverrideCount, tt.wantCount)
			}
		})
	}
}
// TestGPUDetectorAMDVendorAlias validates that a config with GPUVendor "amd"
// keeps "amd" as the configured vendor while reporting the NVIDIA GPU type
// (the AMD vendor is served by the NVIDIA detector implementation).
//
// The FETCH_ML_GPU_* overrides are cleared first, because an environment type
// override takes precedence over the config vendor and would mask the alias.
func TestGPUDetectorAMDVendorAlias(t *testing.T) {
	// t.Setenv restores each variable's original value (or absence) on cleanup.
	for _, key := range []string{"FETCH_ML_GPU_TYPE", "FETCH_ML_GPU_COUNT"} {
		t.Setenv(key, "")
		if err := os.Unsetenv(key); err != nil {
			t.Fatalf("unset %s: %v", key, err)
		}
	}

	cfg := &worker.Config{
		GPUVendor: "amd",
	}
	factory := &worker.GPUDetectorFactory{}
	result := factory.CreateDetectorWithInfo(cfg)

	// The configured vendor is preserved for transparency...
	if result.Info.ConfiguredVendor != "amd" {
		t.Errorf("ConfiguredVendor = %v, want 'amd'", result.Info.ConfiguredVendor)
	}
	// ...while detection itself runs through the NVIDIA implementation.
	if result.Info.GPUType != worker.GPUTypeNVIDIA {
		t.Errorf("GPUType = %v, want %v (NVIDIA implementation for AMD alias)", result.Info.GPUType, worker.GPUTypeNVIDIA)
	}
}
// TestGPUDetectorEnvCountOverride validates that FETCH_ML_GPU_COUNT alone
// (with the GPU type coming from config) yields the env_count detection
// source and records the overridden count.
//
// FETCH_ML_GPU_TYPE is explicitly unset so an ambient value cannot turn the
// expected env_count source into the "both" source.
func TestGPUDetectorEnvCountOverride(t *testing.T) {
	// t.Setenv restores the original value (or absence) on cleanup.
	t.Setenv("FETCH_ML_GPU_TYPE", "")
	if err := os.Unsetenv("FETCH_ML_GPU_TYPE"); err != nil {
		t.Fatalf("unset FETCH_ML_GPU_TYPE: %v", err)
	}
	t.Setenv("FETCH_ML_GPU_COUNT", "8")

	cfg := &worker.Config{
		GPUVendor: "nvidia",
	}
	factory := &worker.GPUDetectorFactory{}
	result := factory.CreateDetectorWithInfo(cfg)

	if result.Info.DetectionMethod != worker.DetectionSourceEnvCount {
		t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, worker.DetectionSourceEnvCount)
	}
	if result.Info.EnvOverrideCount != 8 {
		t.Errorf("EnvOverrideCount = %v, want 8", result.Info.EnvOverrideCount)
	}
}
// TestGPUDetectorDetectionSources validates which DetectionSource is reported
// for each combination of environment overrides and config:
//   - an env GPU type wins over the config vendor,
//   - an env GPU count alone yields the env_count source,
//   - a config vendor with no env yields the config source,
//   - GPUDevices or AppleGPU config without a vendor yields the auto source.
//
// Each case clears both FETCH_ML_GPU_* variables first so that ambient values
// cannot leak into the "no env" cases.
func TestGPUDetectorDetectionSources(t *testing.T) {
	tests := []struct {
		name       string
		envType    string
		envCount   string
		config     *worker.Config
		wantSource worker.DetectionSource
	}{
		{
			name:       "env type takes precedence over config",
			envType:    "apple",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceEnvType,
		},
		{
			name:       "env count triggers env_count source",
			envCount:   "2",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceEnvCount,
		},
		{
			name:       "config source when no env",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceConfig,
		},
		{
			name:       "auto source for GPUDevices",
			config:     &worker.Config{GPUDevices: []string{"/dev/nvidia0"}},
			wantSource: worker.DetectionSourceAuto,
		},
		{
			name:       "auto source for AppleGPU",
			config:     &worker.Config{AppleGPU: worker.AppleGPUConfig{Enabled: true}},
			wantSource: worker.DetectionSourceAuto,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Start from an environment with no overrides; t.Setenv restores
			// each variable's original value (or absence) on cleanup.
			for _, key := range []string{"FETCH_ML_GPU_TYPE", "FETCH_ML_GPU_COUNT"} {
				t.Setenv(key, "")
				if err := os.Unsetenv(key); err != nil {
					t.Fatalf("unset %s: %v", key, err)
				}
			}
			if tt.envType != "" {
				t.Setenv("FETCH_ML_GPU_TYPE", tt.envType)
			}
			if tt.envCount != "" {
				t.Setenv("FETCH_ML_GPU_COUNT", tt.envCount)
			}

			factory := &worker.GPUDetectorFactory{}
			result := factory.CreateDetectorWithInfo(tt.config)

			if result.Info.DetectionMethod != tt.wantSource {
				t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantSource)
			}
		})
	}
}
// TestGPUDetectorInfoFields validates all GPUDetectionInfo fields are populated
func TestGPUDetectorInfoFields(t *testing.T) {
os.Setenv("FETCH_ML_GPU_TYPE", "nvidia")
os.Setenv("FETCH_ML_GPU_COUNT", "4")
defer os.Unsetenv("FETCH_ML_GPU_TYPE")
defer os.Unsetenv("FETCH_ML_GPU_COUNT")
factory := &worker.GPUDetectorFactory{}
result := factory.CreateDetectorWithInfo(nil)
// Validate all expected fields
if result.Info.GPUType == "" {
t.Error("GPUType field is empty")
}
if result.Info.ConfiguredVendor == "" {
t.Error("ConfiguredVendor field is empty")
}
if result.Info.DetectionMethod == "" {
t.Error("DetectionMethod field is empty")
}
if result.Info.EnvOverrideType != "nvidia" {
t.Errorf("EnvOverrideType = %v, want 'nvidia'", result.Info.EnvOverrideType)
}
if result.Info.EnvOverrideCount != 4 {
t.Errorf("EnvOverrideCount = %v, want 4", result.Info.EnvOverrideCount)
}
}