package worker_test

import (
	"os"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/worker"
)

// pinGPUDetectorEnv fixes FETCH_ML_GPU_TYPE and FETCH_ML_GPU_COUNT for the
// duration of t. An empty value removes the variable instead of leaving
// whatever the developer's shell or CI happened to export, so each test sees
// exactly the override combination it declares. Prior values are restored by
// t.Setenv's cleanup; because t.Setenv is used, callers must not call
// t.Parallel.
func pinGPUDetectorEnv(t *testing.T, gpuType, gpuCount string) {
	t.Helper()
	for _, kv := range [...]struct{ key, value string }{
		{"FETCH_ML_GPU_TYPE", gpuType},
		{"FETCH_ML_GPU_COUNT", gpuCount},
	} {
		// t.Setenv registers a cleanup that restores the previous value, or
		// unsets the variable if it did not exist before the test started.
		t.Setenv(kv.key, kv.value)
		if kv.value == "" {
			// Setting "" is not the same as absent for code that uses
			// os.LookupEnv, so remove it outright.
			if err := os.Unsetenv(kv.key); err != nil {
				t.Fatalf("unset %s: %v", kv.key, err)
			}
		}
	}
}

// TestGPUDetectorEnvOverrides validates both FETCH_ML_GPU_TYPE and FETCH_ML_GPU_COUNT work
// when no config is supplied (nil config passed to the factory).
func TestGPUDetectorEnvOverrides(t *testing.T) {
	tests := []struct {
		name           string
		gpuType        string
		gpuCount       string
		wantType       worker.GPUType
		wantCount      int // asserted only when gpuCount is non-empty
		wantMethod     worker.DetectionSource
		wantConfigured string
	}{
		{
			name:           "env type only - nvidia",
			gpuType:        "nvidia",
			wantType:       worker.GPUTypeNVIDIA,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "nvidia",
		},
		{
			name:           "env type only - apple",
			gpuType:        "apple",
			wantType:       worker.GPUTypeApple,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "apple",
		},
		{
			name:           "env type only - none",
			gpuType:        "none",
			wantType:       worker.GPUTypeNone,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "none",
		},
		{
			name:           "both env vars set",
			gpuType:        "nvidia",
			gpuCount:       "4",
			wantType:       worker.GPUTypeNVIDIA,
			wantCount:      4,
			wantMethod:     worker.DetectionSourceEnvBoth,
			wantConfigured: "nvidia",
		},
		{
			name:           "env type amd - shows amd configured vendor",
			gpuType:        "amd",
			wantType:       worker.GPUTypeAMD,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "amd",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Pin both variables so ambient values cannot change the
			// expected detection source.
			pinGPUDetectorEnv(t, tt.gpuType, tt.gpuCount)

			factory := &worker.GPUDetectorFactory{}
			result := factory.CreateDetectorWithInfo(nil)

			if result.Info.GPUType != tt.wantType {
				t.Errorf("GPUType = %v, want %v", result.Info.GPUType, tt.wantType)
			}
			if result.Info.DetectionMethod != tt.wantMethod {
				t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantMethod)
			}
			if result.Info.ConfiguredVendor != tt.wantConfigured {
				t.Errorf("ConfiguredVendor = %v, want %v", result.Info.ConfiguredVendor, tt.wantConfigured)
			}
			if tt.gpuCount != "" && result.Info.EnvOverrideCount != tt.wantCount {
				t.Errorf("EnvOverrideCount = %v, want %v", result.Info.EnvOverrideCount, tt.wantCount)
			}
		})
	}
}

// TestGPUDetectorAMDVendorAlias validates AMD config shows proper aliasing:
// the configured vendor is reported as "amd" while the detector type is NVIDIA.
func TestGPUDetectorAMDVendorAlias(t *testing.T) {
	// Clear env overrides so the config is the only input.
	pinGPUDetectorEnv(t, "", "")

	cfg := &worker.Config{
		GPUVendor: "amd",
	}

	factory := &worker.GPUDetectorFactory{}
	result := factory.CreateDetectorWithInfo(cfg)

	// AMD uses NVIDIA detector implementation
	if result.Info.ConfiguredVendor != "amd" {
		t.Errorf("ConfiguredVendor = %v, want 'amd'", result.Info.ConfiguredVendor)
	}
	if result.Info.GPUType != worker.GPUTypeNVIDIA {
		t.Errorf("GPUType = %v, want %v (NVIDIA implementation for AMD alias)", result.Info.GPUType, worker.GPUTypeNVIDIA)
	}
}

// TestGPUDetectorEnvCountOverride validates that FETCH_ML_GPU_COUNT alone
// (no FETCH_ML_GPU_TYPE) is reported as an env_count override on top of a
// configured vendor.
func TestGPUDetectorEnvCountOverride(t *testing.T) {
	pinGPUDetectorEnv(t, "", "8")

	cfg := &worker.Config{
		GPUVendor: "nvidia",
	}

	factory := &worker.GPUDetectorFactory{}
	result := factory.CreateDetectorWithInfo(cfg)

	if result.Info.DetectionMethod != worker.DetectionSourceEnvCount {
		t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, worker.DetectionSourceEnvCount)
	}
	if result.Info.EnvOverrideCount != 8 {
		t.Errorf("EnvOverrideCount = %v, want 8", result.Info.EnvOverrideCount)
	}
}

// TestGPUDetectorDetectionSources validates all detection source types
// for combinations of env overrides and config.
func TestGPUDetectorDetectionSources(t *testing.T) {
	tests := []struct {
		name       string
		envType    string
		envCount   string
		config     *worker.Config
		wantSource worker.DetectionSource
	}{
		{
			name:       "env type takes precedence over config",
			envType:    "apple",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceEnvType,
		},
		{
			name:       "env count triggers env_count source",
			envCount:   "2",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceEnvCount,
		},
		{
			name:       "config source when no env",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceConfig,
		},
		{
			name:       "auto source for GPUDevices",
			config:     &worker.Config{GPUDevices: []string{"/dev/nvidia0"}},
			wantSource: worker.DetectionSourceAuto,
		},
		{
			name:       "auto source for AppleGPU",
			config:     &worker.Config{AppleGPU: worker.AppleGPUConfig{Enabled: true}},
			wantSource: worker.DetectionSourceAuto,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Unset variables are removed explicitly; otherwise the
			// "config"/"auto" cases depend on the caller's environment.
			pinGPUDetectorEnv(t, tt.envType, tt.envCount)

			factory := &worker.GPUDetectorFactory{}
			result := factory.CreateDetectorWithInfo(tt.config)

			if result.Info.DetectionMethod != tt.wantSource {
				t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantSource)
			}
		})
	}
}

// TestGPUDetectorInfoFields validates all GPUDetectionInfo fields are populated
// when both env overrides are set and no config is supplied.
func TestGPUDetectorInfoFields(t *testing.T) {
	pinGPUDetectorEnv(t, "nvidia", "4")

	factory := &worker.GPUDetectorFactory{}
	result := factory.CreateDetectorWithInfo(nil)

	// Validate all expected fields
	if result.Info.GPUType == "" {
		t.Error("GPUType field is empty")
	}
	if result.Info.ConfiguredVendor == "" {
		t.Error("ConfiguredVendor field is empty")
	}
	if result.Info.DetectionMethod == "" {
		t.Error("DetectionMethod field is empty")
	}
	if result.Info.EnvOverrideType != "nvidia" {
		t.Errorf("EnvOverrideType = %v, want 'nvidia'", result.Info.EnvOverrideType)
	}
	if result.Info.EnvOverrideCount != 4 {
		t.Errorf("EnvOverrideCount = %v, want 4", result.Info.EnvOverrideCount)
	}
}