package worker_test

import (
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/manifest"
	"github.com/jfraeys/fetch_ml/internal/worker"
)

// applyGPUEnvOverrides pins FETCH_ML_GPU_TYPE and FETCH_ML_GPU_COUNT for the
// duration of the calling test. A non-empty value is exported as-is; an empty
// value removes the variable so the test does not inherit an override from the
// surrounding shell or CI job. The pre-test state of both variables is
// restored when the test (or subtest) finishes.
//
// It relies on t.Setenv and therefore must not be called from parallel tests.
func applyGPUEnvOverrides(t *testing.T, gpuType, gpuCount string) {
	t.Helper()

	overrides := []struct{ key, value string }{
		{"FETCH_ML_GPU_TYPE", gpuType},
		{"FETCH_ML_GPU_COUNT", gpuCount},
	}
	for _, o := range overrides {
		// t.Setenv registers the cleanup that restores the original value,
		// including the "was not set at all" case.
		t.Setenv(o.key, o.value)
		if o.value != "" {
			continue
		}
		// An empty string is not the same as "unset"; remove the variable so
		// the detector sees no override at all.
		if err := os.Unsetenv(o.key); err != nil {
			t.Fatalf("unsetenv %s: %v", o.key, err)
		}
	}
}

// TestGPUDetectorEnvOverrides validates that FETCH_ML_GPU_TYPE and
// FETCH_ML_GPU_COUNT are reflected in the detection info when no config is
// supplied.
func TestGPUDetectorEnvOverrides(t *testing.T) {
	tests := []struct {
		name           string
		gpuType        string
		gpuCount       string
		wantType       worker.GPUType
		wantCount      int // only checked when gpuCount is set
		wantMethod     worker.DetectionSource
		wantConfigured string
	}{
		{
			name:           "env type only - nvidia",
			gpuType:        "nvidia",
			wantType:       worker.GPUTypeNVIDIA,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "nvidia",
		},
		{
			name:           "env type only - apple",
			gpuType:        "apple",
			wantType:       worker.GPUTypeApple,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "apple",
		},
		{
			name:           "env type only - none",
			gpuType:        "none",
			wantType:       worker.GPUTypeNone,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "none",
		},
		{
			name:           "both env vars set",
			gpuType:        "nvidia",
			gpuCount:       "4",
			wantType:       worker.GPUTypeNVIDIA,
			wantCount:      4,
			wantMethod:     worker.DetectionSourceEnvBoth,
			wantConfigured: "nvidia",
		},
		{
			name:           "env type amd - shows amd configured vendor",
			gpuType:        "amd",
			wantType:       worker.GPUTypeAMD,
			wantMethod:     worker.DetectionSourceEnvType,
			wantConfigured: "amd",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			applyGPUEnvOverrides(t, tt.gpuType, tt.gpuCount)

			factory := &worker.GPUDetectorFactory{}
			result := factory.CreateDetectorWithInfo(nil)

			if result.Info.GPUType != tt.wantType {
				t.Errorf("GPUType = %v, want %v", result.Info.GPUType, tt.wantType)
			}
			if result.Info.DetectionMethod != tt.wantMethod {
				t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantMethod)
			}
			if result.Info.ConfiguredVendor != tt.wantConfigured {
				t.Errorf("ConfiguredVendor = %v, want %v", result.Info.ConfiguredVendor, tt.wantConfigured)
			}
			// The count is only asserted for cases that export
			// FETCH_ML_GPU_COUNT; the other cases make no claim about it.
			if tt.gpuCount != "" {
				if got := int(result.Info.EnvOverrideCount); got != tt.wantCount {
					t.Errorf("EnvOverrideCount = %v, want %v", got, tt.wantCount)
				}
			}
		})
	}
}

// TestAMDAliasManifestRecord validates that an "amd" vendor config is reported
// as the configured vendor while the GPU type is NVIDIA, and that the
// detection info survives a manifest write/load round trip.
func TestAMDAliasManifestRecord(t *testing.T) {
	// This test exercises the config path, so make sure no env override leaks
	// in from the surrounding environment.
	applyGPUEnvOverrides(t, "", "")

	cfg := &worker.Config{
		GPUVendor: "amd",
	}

	factory := &worker.GPUDetectorFactory{}
	result := factory.CreateDetectorWithInfo(cfg)

	// AMD uses NVIDIA detector implementation
	if result.Info.ConfiguredVendor != "amd" {
		t.Errorf("ConfiguredVendor = %v, want 'amd'", result.Info.ConfiguredVendor)
	}
	if result.Info.GPUType != worker.GPUTypeNVIDIA {
		t.Errorf("GPUType = %v, want %v (NVIDIA implementation for AMD alias)",
			result.Info.GPUType, worker.GPUTypeNVIDIA)
	}

	// Fixed value written to the manifest; the detector result is not queried
	// for a count here.
	const wantGPUCount = 1

	// R.3: Record GPU detection info to manifest
	m := manifest.NewRunManifest("run-test-amd", "task-amd", "job-amd", time.Now())
	m.Environment = &manifest.ExecutionEnvironment{
		GPUDetectionMethod: string(result.Info.DetectionMethod),
		GPUVendor:          result.Info.ConfiguredVendor,
		GPUCount:           wantGPUCount,
	}

	// Write and reload manifest to verify persistence
	dir := t.TempDir()
	if err := m.WriteToDir(dir); err != nil {
		t.Fatalf("WriteToDir failed: %v", err)
	}

	loaded, err := manifest.LoadFromDir(dir)
	if err != nil {
		t.Fatalf("LoadFromDir failed: %v", err)
	}

	// Verify GPU info was persisted
	if loaded.Environment == nil {
		t.Fatal("expected Environment to be written to manifest")
	}
	if loaded.Environment.GPUVendor != result.Info.ConfiguredVendor {
		t.Errorf("GPUVendor mismatch: got %q, want %q",
			loaded.Environment.GPUVendor, result.Info.ConfiguredVendor)
	}
	if loaded.Environment.GPUDetectionMethod != string(result.Info.DetectionMethod) {
		t.Errorf("GPUDetectionMethod mismatch: got %q, want %q",
			loaded.Environment.GPUDetectionMethod, result.Info.DetectionMethod)
	}
	if loaded.Environment.GPUCount != wantGPUCount {
		t.Errorf("GPUCount mismatch: got %v, want %v",
			loaded.Environment.GPUCount, wantGPUCount)
	}

	// Verify the manifest landed under its expected file name. Only existence
	// is checked here; the file contents are not inspected. Any stat failure
	// (not just "does not exist") is reported.
	manifestFile := filepath.Join(dir, "run_manifest.json")
	if _, err := os.Stat(manifestFile); err != nil {
		t.Errorf("expected run_manifest.json to exist: %v", err)
	}
}

// TestGPUDetectorEnvCountOverride validates that FETCH_ML_GPU_COUNT on its own,
// combined with a configured vendor, is reported as an env-count override.
func TestGPUDetectorEnvCountOverride(t *testing.T) {
	// Only the count is overridden; the type variable is explicitly removed so
	// an ambient FETCH_ML_GPU_TYPE cannot change the detection source.
	applyGPUEnvOverrides(t, "", "8")

	cfg := &worker.Config{
		GPUVendor: "nvidia",
	}

	factory := &worker.GPUDetectorFactory{}
	result := factory.CreateDetectorWithInfo(cfg)

	if result.Info.DetectionMethod != worker.DetectionSourceEnvCount {
		t.Errorf("DetectionMethod = %v, want %v",
			result.Info.DetectionMethod, worker.DetectionSourceEnvCount)
	}
	if result.Info.EnvOverrideCount != 8 {
		t.Errorf("EnvOverrideCount = %v, want 8", result.Info.EnvOverrideCount)
	}
}

// TestGPUDetectorDetectionSources validates the detection source reported for
// each combination of env overrides and config.
func TestGPUDetectorDetectionSources(t *testing.T) {
	tests := []struct {
		name       string
		envType    string
		envCount   string
		config     *worker.Config
		wantSource worker.DetectionSource
	}{
		{
			name:       "env type takes precedence over config",
			envType:    "apple",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceEnvType,
		},
		{
			name:       "env count triggers env_count source",
			envCount:   "2",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceEnvCount,
		},
		{
			name:       "config source when no env",
			config:     &worker.Config{GPUVendor: "nvidia"},
			wantSource: worker.DetectionSourceConfig,
		},
		{
			name:       "auto source for GPUDevices",
			config:     &worker.Config{GPUDevices: []string{"/dev/nvidia0"}},
			wantSource: worker.DetectionSourceAuto,
		},
		{
			name:       "auto source for AppleGPU",
			config:     &worker.Config{AppleGPU: worker.AppleGPUConfig{Enabled: true}},
			wantSource: worker.DetectionSourceAuto,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Cases that leave a field empty depend on that variable being
			// absent, so both variables are always pinned.
			applyGPUEnvOverrides(t, tt.envType, tt.envCount)

			factory := &worker.GPUDetectorFactory{}
			result := factory.CreateDetectorWithInfo(tt.config)

			if result.Info.DetectionMethod != tt.wantSource {
				t.Errorf("DetectionMethod = %v, want %v",
					result.Info.DetectionMethod, tt.wantSource)
			}
		})
	}
}

// TestGPUDetectorInfoFields validates that the GPUDetectionInfo fields are
// populated when both env overrides are set and no config is supplied.
func TestGPUDetectorInfoFields(t *testing.T) {
	applyGPUEnvOverrides(t, "nvidia", "4")

	factory := &worker.GPUDetectorFactory{}
	result := factory.CreateDetectorWithInfo(nil)

	// Validate all expected fields
	if result.Info.GPUType == "" {
		t.Error("GPUType field is empty")
	}
	if result.Info.ConfiguredVendor == "" {
		t.Error("ConfiguredVendor field is empty")
	}
	if result.Info.DetectionMethod == "" {
		t.Error("DetectionMethod field is empty")
	}
	if result.Info.EnvOverrideType != "nvidia" {
		t.Errorf("EnvOverrideType = %v, want 'nvidia'", result.Info.EnvOverrideType)
	}
	if result.Info.EnvOverrideCount != 4 {
		t.Errorf("EnvOverrideCount = %v, want 4", result.Info.EnvOverrideCount)
	}
}