fetch_ml/tests/unit/gpu/gpu_detector_test.go
Jeremie Fraeys f71352202e
test(phase-1-2): naming alignment and partial test completion
Rename and enhance existing tests to align with coverage map:
- TestGPUDetectorAMDVendorAlias -> TestAMDAliasManifestRecord
- TestScanArtifacts_SkipsKnownPathsAndLogs -> TestScanExclusionsRecorded
- Add env var expansion verification to TestHIPAAValidation_InlineCredentials
- Record exclusions in manifest.Artifacts for audit trail
2026-02-23 20:25:07 -05:00

249 lines
7.5 KiB
Go

package worker_test
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// TestGPUDetectorEnvOverrides validates that the FETCH_ML_GPU_TYPE and
// FETCH_ML_GPU_COUNT environment variables drive GPU detection when the
// factory is given a nil config.
//
// Every case sets or explicitly clears both variables, so the outcome does not
// depend on the environment the test binary was started with.
func TestGPUDetectorEnvOverrides(t *testing.T) {
	tests := []struct {
		name           string
		gpuType        string // value for FETCH_ML_GPU_TYPE; empty means unset
		gpuCount       string // value for FETCH_ML_GPU_COUNT; empty means unset
		wantType       worker.GPUType
		wantCount      int // only asserted when gpuCount is non-empty
		wantMethod     worker.DetectionSource
		wantConfigured string
	}{
		{
			name:           "env type only - nvidia",
			gpuType:        "nvidia",
			wantType:       worker.GPUTypeNVIDIA,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "nvidia",
		},
		{
			name:           "env type only - apple",
			gpuType:        "apple",
			wantType:       worker.GPUTypeApple,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "apple",
		},
		{
			name:           "env type only - none",
			gpuType:        "none",
			wantType:       worker.GPUTypeNone,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "none",
		},
		{
			name:           "both env vars set",
			gpuType:        "nvidia",
			gpuCount:       "4",
			wantType:       worker.GPUTypeNVIDIA,
			wantCount:      4,
			wantMethod:     worker.DetectionSourceEnvBoth,
			wantConfigured: "nvidia",
		},
		{
			name:           "env type amd - shows amd configured vendor",
			gpuType:        "amd",
			wantType:       worker.GPUTypeAMD,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "amd",
		},
	}

	// setOrClearEnv sets key to value for the duration of the (sub)test, or
	// removes key entirely when value is empty. t.Setenv registers a cleanup
	// that restores whatever the variable held before the test, which a bare
	// os.Setenv / defer os.Unsetenv pair does not do.
	setOrClearEnv := func(t *testing.T, key, value string) {
		t.Helper()
		t.Setenv(key, value)
		if value != "" {
			return
		}
		if err := os.Unsetenv(key); err != nil {
			t.Fatalf("unsetting %s: %v", key, err)
		}
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			setOrClearEnv(t, "FETCH_ML_GPU_TYPE", tt.gpuType)
			setOrClearEnv(t, "FETCH_ML_GPU_COUNT", tt.gpuCount)

			factory := &worker.GPUDetectorFactory{}
			result := factory.CreateDetectorWithInfo(nil)

			if result.Info.GPUType != tt.wantType {
				t.Errorf("GPUType = %v, want %v", result.Info.GPUType, tt.wantType)
			}
			if result.Info.DetectionMethod != tt.wantMethod {
				t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantMethod)
			}
			if result.Info.ConfiguredVendor != tt.wantConfigured {
				t.Errorf("ConfiguredVendor = %v, want %v", result.Info.ConfiguredVendor, tt.wantConfigured)
			}
			// The explicit conversion keeps this comparison valid whichever
			// integer type EnvOverrideCount is declared with.
			if tt.gpuCount != "" && int(result.Info.EnvOverrideCount) != tt.wantCount {
				t.Errorf("EnvOverrideCount = %v, want %v", result.Info.EnvOverrideCount, tt.wantCount)
			}
		})
	}
}
// TestAMDAliasManifestRecord validates that a config with GPUVendor "amd" is
// reported as the configured vendor while the detector type is the NVIDIA one,
// and that the detection info survives a manifest write/load round trip.
func TestAMDAliasManifestRecord(t *testing.T) {
	// The detection-source tests in this file expect FETCH_ML_GPU_TYPE to win
	// over the config, so clear any ambient overrides before using cfg.
	// t.Setenv registers a cleanup that restores the previous value.
	for _, key := range []string{"FETCH_ML_GPU_TYPE", "FETCH_ML_GPU_COUNT"} {
		t.Setenv(key, "")
		if err := os.Unsetenv(key); err != nil {
			t.Fatalf("unsetting %s: %v", key, err)
		}
	}

	cfg := &worker.Config{
		GPUVendor: "amd",
	}
	factory := &worker.GPUDetectorFactory{}
	result := factory.CreateDetectorWithInfo(cfg)

	// AMD is expected to be an alias: the configured vendor stays "amd" but
	// the reported GPU type is the NVIDIA implementation.
	if result.Info.ConfiguredVendor != "amd" {
		t.Errorf("ConfiguredVendor = %v, want 'amd'", result.Info.ConfiguredVendor)
	}
	if result.Info.GPUType != worker.GPUTypeNVIDIA {
		t.Errorf("GPUType = %v, want %v (NVIDIA implementation for AMD alias)", result.Info.GPUType, worker.GPUTypeNVIDIA)
	}

	// R.3: Record GPU detection info to manifest
	m := manifest.NewRunManifest("run-test-amd", "task-amd", "job-amd", time.Now())
	m.Environment = &manifest.ExecutionEnvironment{
		GPUDetectionMethod: string(result.Info.DetectionMethod),
		GPUVendor:          result.Info.ConfiguredVendor,
		GPUCount:           1, // fixed placeholder; not read from the detector and not asserted below
	}

	// Write and reload manifest to verify persistence
	dir := t.TempDir()
	if err := m.WriteToDir(dir); err != nil {
		t.Fatalf("WriteToDir failed: %v", err)
	}
	loaded, err := manifest.LoadFromDir(dir)
	if err != nil {
		t.Fatalf("LoadFromDir failed: %v", err)
	}

	// Verify GPU info was persisted
	if loaded.Environment == nil {
		t.Fatal("expected Environment to be written to manifest")
	}
	if loaded.Environment.GPUVendor != result.Info.ConfiguredVendor {
		t.Errorf("GPUVendor mismatch: got %q, want %q", loaded.Environment.GPUVendor, result.Info.ConfiguredVendor)
	}
	if loaded.Environment.GPUDetectionMethod != string(result.Info.DetectionMethod) {
		t.Errorf("GPUDetectionMethod mismatch: got %q, want %q", loaded.Environment.GPUDetectionMethod, result.Info.DetectionMethod)
	}

	// Verify the manifest file exists on disk under its expected name. Any
	// stat failure (not only "not exist") is reported; the file's contents
	// are not inspected here.
	manifestFile := filepath.Join(dir, "run_manifest.json")
	if _, err := os.Stat(manifestFile); err != nil {
		t.Errorf("expected run_manifest.json to exist: %v", err)
	}
}
// TestGPUDetectorEnvCountOverride validates that setting only
// FETCH_ML_GPU_COUNT, alongside a config that names a vendor, is reported as
// an env-count detection and that the parsed count is exposed on the result.
func TestGPUDetectorEnvCountOverride(t *testing.T) {
	// Only the count override may be present: clear FETCH_ML_GPU_TYPE in case
	// the ambient environment defines it. t.Setenv registers a cleanup that
	// restores the previous value when the test ends.
	t.Setenv("FETCH_ML_GPU_TYPE", "")
	if err := os.Unsetenv("FETCH_ML_GPU_TYPE"); err != nil {
		t.Fatalf("unsetting FETCH_ML_GPU_TYPE: %v", err)
	}
	t.Setenv("FETCH_ML_GPU_COUNT", "8")

	cfg := &worker.Config{
		GPUVendor: "nvidia",
	}
	factory := &worker.GPUDetectorFactory{}
	result := factory.CreateDetectorWithInfo(cfg)

	if result.Info.DetectionMethod != worker.DetectionSourceEnvCount {
		t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, worker.DetectionSourceEnvCount)
	}
	if result.Info.EnvOverrideCount != 8 {
		t.Errorf("EnvOverrideCount = %v, want 8", result.Info.EnvOverrideCount)
	}
}
// TestGPUDetectorDetectionSources validates which DetectionSource the factory
// reports for each combination of environment overrides and config.
//
// Every case sets or explicitly clears both environment variables, so cases
// that expect "no env" do not depend on the ambient environment.
func TestGPUDetectorDetectionSources(t *testing.T) {
	tests := []struct {
		name       string
		envType    string // value for FETCH_ML_GPU_TYPE; empty means unset
		envCount   string // value for FETCH_ML_GPU_COUNT; empty means unset
		config     *worker.Config
		wantSource worker.DetectionSource
	}{
		{
			name:       "env type takes precedence over config",
			envType:    "apple",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceEnvType,
		},
		{
			name:       "env count triggers env_count source",
			envCount:   "2",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceEnvCount,
		},
		{
			name:       "config source when no env",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceConfig,
		},
		{
			name:       "auto source for GPUDevices",
			config:     &worker.Config{GPUDevices: []string{"/dev/nvidia0"}},
			wantSource: worker.DetectionSourceAuto,
		},
		{
			name:       "auto source for AppleGPU",
			config:     &worker.Config{AppleGPU: worker.AppleGPUConfig{Enabled: true}},
			wantSource: worker.DetectionSourceAuto,
		},
	}

	// setOrClearEnv sets key to value for the duration of the (sub)test, or
	// removes key entirely when value is empty. t.Setenv registers a cleanup
	// that restores whatever the variable held before the test.
	setOrClearEnv := func(t *testing.T, key, value string) {
		t.Helper()
		t.Setenv(key, value)
		if value != "" {
			return
		}
		if err := os.Unsetenv(key); err != nil {
			t.Fatalf("unsetting %s: %v", key, err)
		}
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			setOrClearEnv(t, "FETCH_ML_GPU_TYPE", tt.envType)
			setOrClearEnv(t, "FETCH_ML_GPU_COUNT", tt.envCount)

			factory := &worker.GPUDetectorFactory{}
			result := factory.CreateDetectorWithInfo(tt.config)

			if result.Info.DetectionMethod != tt.wantSource {
				t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantSource)
			}
		})
	}
}
// TestGPUDetectorInfoFields validates all GPUDetectionInfo fields are populated
func TestGPUDetectorInfoFields(t *testing.T) {
os.Setenv("FETCH_ML_GPU_TYPE", "nvidia")
os.Setenv("FETCH_ML_GPU_COUNT", "4")
defer os.Unsetenv("FETCH_ML_GPU_TYPE")
defer os.Unsetenv("FETCH_ML_GPU_COUNT")
factory := &worker.GPUDetectorFactory{}
result := factory.CreateDetectorWithInfo(nil)
// Validate all expected fields
if result.Info.GPUType == "" {
t.Error("GPUType field is empty")
}
if result.Info.ConfiguredVendor == "" {
t.Error("ConfiguredVendor field is empty")
}
if result.Info.DetectionMethod == "" {
t.Error("DetectionMethod field is empty")
}
if result.Info.EnvOverrideType != "nvidia" {
t.Errorf("EnvOverrideType = %v, want 'nvidia'", result.Info.EnvOverrideType)
}
if result.Info.EnvOverrideCount != 4 {
t.Errorf("EnvOverrideCount = %v, want 4", result.Info.EnvOverrideCount)
}
}