Rename and enhance existing tests to align with coverage map: - TestGPUDetectorAMDVendorAlias -> TestAMDAliasManifestRecord - TestScanArtifacts_SkipsKnownPathsAndLogs -> TestScanExclusionsRecorded - Add env var expansion verification to TestHIPAAValidation_InlineCredentials - Record exclusions in manifest.Artifacts for audit trail
249 lines
7.5 KiB
Go
249 lines
7.5 KiB
Go
package worker_test
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/manifest"
|
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
|
)
|
|
|
|
// TestGPUDetectorEnvOverrides validates both FETCH_ML_GPU_TYPE and FETCH_ML_GPU_COUNT work
|
|
func TestGPUDetectorEnvOverrides(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
gpuType string
|
|
gpuCount string
|
|
wantType worker.GPUType
|
|
wantCount int
|
|
wantMethod worker.DetectionSource
|
|
wantConfigured string
|
|
}{
|
|
{
|
|
name: "env type only - nvidia",
|
|
gpuType: "nvidia",
|
|
wantType: worker.GPUTypeNVIDIA,
|
|
wantMethod: worker.DetectionSourceEnvType,
|
|
wantConfigured: "nvidia",
|
|
},
|
|
{
|
|
name: "env type only - apple",
|
|
gpuType: "apple",
|
|
wantType: worker.GPUTypeApple,
|
|
wantMethod: worker.DetectionSourceEnvType,
|
|
wantConfigured: "apple",
|
|
},
|
|
{
|
|
name: "env type only - none",
|
|
gpuType: "none",
|
|
wantType: worker.GPUTypeNone,
|
|
wantMethod: worker.DetectionSourceEnvType,
|
|
wantConfigured: "none",
|
|
},
|
|
{
|
|
name: "both env vars set",
|
|
gpuType: "nvidia",
|
|
gpuCount: "4",
|
|
wantType: worker.GPUTypeNVIDIA,
|
|
wantMethod: worker.DetectionSourceEnvBoth,
|
|
wantConfigured: "nvidia",
|
|
},
|
|
{
|
|
name: "env type amd - shows amd configured vendor",
|
|
gpuType: "amd",
|
|
wantType: worker.GPUTypeAMD,
|
|
wantMethod: worker.DetectionSourceEnvType,
|
|
wantConfigured: "amd",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Set env vars
|
|
if tt.gpuType != "" {
|
|
os.Setenv("FETCH_ML_GPU_TYPE", tt.gpuType)
|
|
defer os.Unsetenv("FETCH_ML_GPU_TYPE")
|
|
}
|
|
if tt.gpuCount != "" {
|
|
os.Setenv("FETCH_ML_GPU_COUNT", tt.gpuCount)
|
|
defer os.Unsetenv("FETCH_ML_GPU_COUNT")
|
|
}
|
|
|
|
factory := &worker.GPUDetectorFactory{}
|
|
result := factory.CreateDetectorWithInfo(nil)
|
|
|
|
if result.Info.GPUType != tt.wantType {
|
|
t.Errorf("GPUType = %v, want %v", result.Info.GPUType, tt.wantType)
|
|
}
|
|
if result.Info.DetectionMethod != tt.wantMethod {
|
|
t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantMethod)
|
|
}
|
|
if result.Info.ConfiguredVendor != tt.wantConfigured {
|
|
t.Errorf("ConfiguredVendor = %v, want %v", result.Info.ConfiguredVendor, tt.wantConfigured)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAMDAliasManifestRecord validates AMD config shows proper aliasing and records to manifest
|
|
func TestAMDAliasManifestRecord(t *testing.T) {
|
|
cfg := &worker.Config{
|
|
GPUVendor: "amd",
|
|
}
|
|
|
|
factory := &worker.GPUDetectorFactory{}
|
|
result := factory.CreateDetectorWithInfo(cfg)
|
|
|
|
// AMD uses NVIDIA detector implementation
|
|
if result.Info.ConfiguredVendor != "amd" {
|
|
t.Errorf("ConfiguredVendor = %v, want 'amd'", result.Info.ConfiguredVendor)
|
|
}
|
|
if result.Info.GPUType != worker.GPUTypeNVIDIA {
|
|
t.Errorf("GPUType = %v, want %v (NVIDIA implementation for AMD alias)", result.Info.GPUType, worker.GPUTypeNVIDIA)
|
|
}
|
|
|
|
// R.3: Record GPU detection info to manifest
|
|
m := manifest.NewRunManifest("run-test-amd", "task-amd", "job-amd", time.Now())
|
|
m.Environment = &manifest.ExecutionEnvironment{
|
|
GPUDetectionMethod: string(result.Info.DetectionMethod),
|
|
GPUVendor: result.Info.ConfiguredVendor,
|
|
GPUCount: 1, // AMD detection returns count from NVIDIA implementation
|
|
}
|
|
|
|
// Write and reload manifest to verify persistence
|
|
dir := t.TempDir()
|
|
if err := m.WriteToDir(dir); err != nil {
|
|
t.Fatalf("WriteToDir failed: %v", err)
|
|
}
|
|
|
|
loaded, err := manifest.LoadFromDir(dir)
|
|
if err != nil {
|
|
t.Fatalf("LoadFromDir failed: %v", err)
|
|
}
|
|
|
|
// Verify GPU info was persisted
|
|
if loaded.Environment == nil {
|
|
t.Fatal("expected Environment to be written to manifest")
|
|
}
|
|
if loaded.Environment.GPUVendor != result.Info.ConfiguredVendor {
|
|
t.Errorf("GPUVendor mismatch: got %q, want %q", loaded.Environment.GPUVendor, result.Info.ConfiguredVendor)
|
|
}
|
|
if loaded.Environment.GPUDetectionMethod != string(result.Info.DetectionMethod) {
|
|
t.Errorf("GPUDetectionMethod mismatch: got %q, want %q", loaded.Environment.GPUDetectionMethod, result.Info.DetectionMethod)
|
|
}
|
|
|
|
// Verify manifest file includes nonce for security
|
|
manifestFile := filepath.Join(dir, "run_manifest.json")
|
|
if _, err := os.Stat(manifestFile); os.IsNotExist(err) {
|
|
t.Error("expected run_manifest.json to exist")
|
|
}
|
|
}
|
|
|
|
// TestGPUDetectorEnvCountOverride validates FETCH_ML_GPU_COUNT with auto-detect
|
|
func TestGPUDetectorEnvCountOverride(t *testing.T) {
|
|
os.Setenv("FETCH_ML_GPU_COUNT", "8")
|
|
defer os.Unsetenv("FETCH_ML_GPU_COUNT")
|
|
|
|
cfg := &worker.Config{
|
|
GPUVendor: "nvidia",
|
|
}
|
|
|
|
factory := &worker.GPUDetectorFactory{}
|
|
result := factory.CreateDetectorWithInfo(cfg)
|
|
|
|
if result.Info.DetectionMethod != worker.DetectionSourceEnvCount {
|
|
t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, worker.DetectionSourceEnvCount)
|
|
}
|
|
if result.Info.EnvOverrideCount != 8 {
|
|
t.Errorf("EnvOverrideCount = %v, want 8", result.Info.EnvOverrideCount)
|
|
}
|
|
}
|
|
|
|
// TestGPUDetectorDetectionSources validates all detection source types
|
|
func TestGPUDetectorDetectionSources(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
envType string
|
|
envCount string
|
|
config *worker.Config
|
|
wantSource worker.DetectionSource
|
|
}{
|
|
{
|
|
name: "env type takes precedence over config",
|
|
envType: "apple",
|
|
config: &worker.Config{GPUVendor: "nvidia"},
|
|
wantSource: worker.DetectionSourceEnvType,
|
|
},
|
|
{
|
|
name: "env count triggers env_count source",
|
|
envCount: "2",
|
|
config: &worker.Config{GPUVendor: "nvidia"},
|
|
wantSource: worker.DetectionSourceEnvCount,
|
|
},
|
|
{
|
|
name: "config source when no env",
|
|
config: &worker.Config{GPUVendor: "nvidia"},
|
|
wantSource: worker.DetectionSourceConfig,
|
|
},
|
|
{
|
|
name: "auto source for GPUDevices",
|
|
config: &worker.Config{GPUDevices: []string{"/dev/nvidia0"}},
|
|
wantSource: worker.DetectionSourceAuto,
|
|
},
|
|
{
|
|
name: "auto source for AppleGPU",
|
|
config: &worker.Config{AppleGPU: worker.AppleGPUConfig{Enabled: true}},
|
|
wantSource: worker.DetectionSourceAuto,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.envType != "" {
|
|
os.Setenv("FETCH_ML_GPU_TYPE", tt.envType)
|
|
defer os.Unsetenv("FETCH_ML_GPU_TYPE")
|
|
}
|
|
if tt.envCount != "" {
|
|
os.Setenv("FETCH_ML_GPU_COUNT", tt.envCount)
|
|
defer os.Unsetenv("FETCH_ML_GPU_COUNT")
|
|
}
|
|
|
|
factory := &worker.GPUDetectorFactory{}
|
|
result := factory.CreateDetectorWithInfo(tt.config)
|
|
|
|
if result.Info.DetectionMethod != tt.wantSource {
|
|
t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantSource)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestGPUDetectorInfoFields validates all GPUDetectionInfo fields are populated
|
|
func TestGPUDetectorInfoFields(t *testing.T) {
|
|
os.Setenv("FETCH_ML_GPU_TYPE", "nvidia")
|
|
os.Setenv("FETCH_ML_GPU_COUNT", "4")
|
|
defer os.Unsetenv("FETCH_ML_GPU_TYPE")
|
|
defer os.Unsetenv("FETCH_ML_GPU_COUNT")
|
|
|
|
factory := &worker.GPUDetectorFactory{}
|
|
result := factory.CreateDetectorWithInfo(nil)
|
|
|
|
// Validate all expected fields
|
|
if result.Info.GPUType == "" {
|
|
t.Error("GPUType field is empty")
|
|
}
|
|
if result.Info.ConfiguredVendor == "" {
|
|
t.Error("ConfiguredVendor field is empty")
|
|
}
|
|
if result.Info.DetectionMethod == "" {
|
|
t.Error("DetectionMethod field is empty")
|
|
}
|
|
if result.Info.EnvOverrideType != "nvidia" {
|
|
t.Errorf("EnvOverrideType = %v, want 'nvidia'", result.Info.EnvOverrideType)
|
|
}
|
|
if result.Info.EnvOverrideCount != 4 {
|
|
t.Errorf("EnvOverrideCount = %v, want 4", result.Info.EnvOverrideCount)
|
|
}
|
|
}
|