Move unit tests from tests/unit/ to internal/ following Go conventions: - tests/unit/queue/* -> internal/queue/* (dedup, filesystem_fallback, queue_permissions, queue_spec, queue, sqlite_queue tests) - tests/unit/gpu/* -> internal/resources/* (gpu_detector, gpu_golden tests) - tests/unit/resources/* -> internal/resources/* (manager_test.go) Update import paths in test files to reflect new locations. Note: GPU tests consolidated into resources package since GPU detection is part of resource management. Manager tests show significant new test coverage (166 lines).
249 lines
7.4 KiB
Go
249 lines
7.4 KiB
Go
package resources_test
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/manifest"
|
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
|
)
|
|
|
|
// TestGPUDetectorEnvOverrides validates both FETCH_ML_GPU_TYPE and FETCH_ML_GPU_COUNT work
|
|
func TestGPUDetectorEnvOverrides(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
gpuType string
|
|
gpuCount string
|
|
wantType worker.GPUType
|
|
wantCount int
|
|
wantMethod worker.DetectionSource
|
|
wantConfigured string
|
|
}{
|
|
{
|
|
name: "env type only - nvidia",
|
|
gpuType: "nvidia",
|
|
wantType: worker.GPUTypeNVIDIA,
|
|
wantMethod: worker.DetectionSourceEnvType,
|
|
wantConfigured: "nvidia",
|
|
},
|
|
{
|
|
name: "env type only - apple",
|
|
gpuType: "apple",
|
|
wantType: worker.GPUTypeApple,
|
|
wantMethod: worker.DetectionSourceEnvType,
|
|
wantConfigured: "apple",
|
|
},
|
|
{
|
|
name: "env type only - none",
|
|
gpuType: "none",
|
|
wantType: worker.GPUTypeNone,
|
|
wantMethod: worker.DetectionSourceEnvType,
|
|
wantConfigured: "none",
|
|
},
|
|
{
|
|
name: "both env vars set",
|
|
gpuType: "nvidia",
|
|
gpuCount: "4",
|
|
wantType: worker.GPUTypeNVIDIA,
|
|
wantMethod: worker.DetectionSourceEnvBoth,
|
|
wantConfigured: "nvidia",
|
|
},
|
|
{
|
|
name: "env type amd - shows amd configured vendor",
|
|
gpuType: "amd",
|
|
wantType: worker.GPUTypeAMD,
|
|
wantMethod: worker.DetectionSourceEnvType,
|
|
wantConfigured: "amd",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Set env vars
|
|
if tt.gpuType != "" {
|
|
os.Setenv("FETCH_ML_GPU_TYPE", tt.gpuType)
|
|
defer os.Unsetenv("FETCH_ML_GPU_TYPE")
|
|
}
|
|
if tt.gpuCount != "" {
|
|
os.Setenv("FETCH_ML_GPU_COUNT", tt.gpuCount)
|
|
defer os.Unsetenv("FETCH_ML_GPU_COUNT")
|
|
}
|
|
|
|
factory := &worker.GPUDetectorFactory{}
|
|
result := factory.CreateDetectorWithInfo(nil)
|
|
|
|
if result.Info.GPUType != tt.wantType {
|
|
t.Errorf("GPUType = %v, want %v", result.Info.GPUType, tt.wantType)
|
|
}
|
|
if result.Info.DetectionMethod != tt.wantMethod {
|
|
t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantMethod)
|
|
}
|
|
if result.Info.ConfiguredVendor != tt.wantConfigured {
|
|
t.Errorf("ConfiguredVendor = %v, want %v", result.Info.ConfiguredVendor, tt.wantConfigured)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAMDAliasManifestRecord validates AMD config shows proper aliasing and records to manifest
|
|
func TestAMDAliasManifestRecord(t *testing.T) {
|
|
cfg := &worker.Config{
|
|
GPUVendor: "amd",
|
|
}
|
|
|
|
factory := &worker.GPUDetectorFactory{}
|
|
result := factory.CreateDetectorWithInfo(cfg)
|
|
|
|
// AMD uses NVIDIA detector implementation
|
|
if result.Info.ConfiguredVendor != "amd" {
|
|
t.Errorf("ConfiguredVendor = %v, want 'amd'", result.Info.ConfiguredVendor)
|
|
}
|
|
if result.Info.GPUType != "amd" {
|
|
t.Errorf("GPUType = %v, want amd", result.Info.GPUType)
|
|
}
|
|
|
|
// R.3: Record GPU detection info to manifest
|
|
m := manifest.NewRunManifest("run-test-amd", "task-amd", "job-amd", time.Now())
|
|
m.Environment = &manifest.ExecutionEnvironment{
|
|
GPUDetectionMethod: string(result.Info.DetectionMethod),
|
|
GPUVendor: result.Info.ConfiguredVendor,
|
|
GPUCount: 1, // AMD detection returns count from NVIDIA implementation
|
|
}
|
|
|
|
// Write and reload manifest to verify persistence
|
|
dir := t.TempDir()
|
|
if err := m.WriteToDir(dir); err != nil {
|
|
t.Fatalf("WriteToDir failed: %v", err)
|
|
}
|
|
|
|
loaded, err := manifest.LoadFromDir(dir)
|
|
if err != nil {
|
|
t.Fatalf("LoadFromDir failed: %v", err)
|
|
}
|
|
|
|
// Verify GPU info was persisted
|
|
if loaded.Environment == nil {
|
|
t.Fatal("expected Environment to be written to manifest")
|
|
}
|
|
if loaded.Environment.GPUVendor != result.Info.ConfiguredVendor {
|
|
t.Errorf("GPUVendor mismatch: got %q, want %q", loaded.Environment.GPUVendor, result.Info.ConfiguredVendor)
|
|
}
|
|
if loaded.Environment.GPUDetectionMethod != string(result.Info.DetectionMethod) {
|
|
t.Errorf("GPUDetectionMethod mismatch: got %q, want %q", loaded.Environment.GPUDetectionMethod, result.Info.DetectionMethod)
|
|
}
|
|
|
|
// Verify manifest file includes nonce for security
|
|
manifestFile := filepath.Join(dir, "run_manifest.json")
|
|
if _, err := os.Stat(manifestFile); os.IsNotExist(err) {
|
|
t.Error("expected run_manifest.json to exist")
|
|
}
|
|
}
|
|
|
|
// TestGPUDetectorEnvCountOverride validates FETCH_ML_GPU_COUNT with auto-detect
|
|
func TestGPUDetectorEnvCountOverride(t *testing.T) {
|
|
os.Setenv("FETCH_ML_GPU_COUNT", "8")
|
|
defer os.Unsetenv("FETCH_ML_GPU_COUNT")
|
|
|
|
cfg := &worker.Config{
|
|
GPUVendor: "nvidia",
|
|
}
|
|
|
|
factory := &worker.GPUDetectorFactory{}
|
|
result := factory.CreateDetectorWithInfo(cfg)
|
|
|
|
if result.Info.DetectionMethod != worker.DetectionSourceEnvCount {
|
|
t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, worker.DetectionSourceEnvCount)
|
|
}
|
|
if result.Info.EnvOverrideCount != 8 {
|
|
t.Errorf("EnvOverrideCount = %v, want 8", result.Info.EnvOverrideCount)
|
|
}
|
|
}
|
|
|
|
// TestGPUDetectorDetectionSources validates all detection source types
|
|
func TestGPUDetectorDetectionSources(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
envType string
|
|
envCount string
|
|
config *worker.Config
|
|
wantSource worker.DetectionSource
|
|
}{
|
|
{
|
|
name: "env type takes precedence over config",
|
|
envType: "apple",
|
|
config: &worker.Config{GPUVendor: "nvidia"},
|
|
wantSource: worker.DetectionSourceEnvType,
|
|
},
|
|
{
|
|
name: "env count triggers env_count source",
|
|
envCount: "2",
|
|
config: &worker.Config{GPUVendor: "nvidia"},
|
|
wantSource: worker.DetectionSourceEnvCount,
|
|
},
|
|
{
|
|
name: "config source when no env",
|
|
config: &worker.Config{GPUVendor: "nvidia"},
|
|
wantSource: worker.DetectionSourceConfig,
|
|
},
|
|
{
|
|
name: "auto source for GPUDevices",
|
|
config: &worker.Config{GPUDevices: []string{"/dev/nvidia0"}},
|
|
wantSource: worker.DetectionSourceAuto,
|
|
},
|
|
{
|
|
name: "auto source for AppleGPU",
|
|
config: &worker.Config{AppleGPU: worker.AppleGPUConfig{Enabled: true}},
|
|
wantSource: worker.DetectionSourceAuto,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.envType != "" {
|
|
os.Setenv("FETCH_ML_GPU_TYPE", tt.envType)
|
|
defer os.Unsetenv("FETCH_ML_GPU_TYPE")
|
|
}
|
|
if tt.envCount != "" {
|
|
os.Setenv("FETCH_ML_GPU_COUNT", tt.envCount)
|
|
defer os.Unsetenv("FETCH_ML_GPU_COUNT")
|
|
}
|
|
|
|
factory := &worker.GPUDetectorFactory{}
|
|
result := factory.CreateDetectorWithInfo(tt.config)
|
|
|
|
if result.Info.DetectionMethod != tt.wantSource {
|
|
t.Errorf("DetectionMethod = %v, want %v", result.Info.DetectionMethod, tt.wantSource)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestGPUDetectorInfoFields validates all GPUDetectionInfo fields are populated
|
|
func TestGPUDetectorInfoFields(t *testing.T) {
|
|
os.Setenv("FETCH_ML_GPU_TYPE", "nvidia")
|
|
os.Setenv("FETCH_ML_GPU_COUNT", "4")
|
|
defer os.Unsetenv("FETCH_ML_GPU_TYPE")
|
|
defer os.Unsetenv("FETCH_ML_GPU_COUNT")
|
|
|
|
factory := &worker.GPUDetectorFactory{}
|
|
result := factory.CreateDetectorWithInfo(nil)
|
|
|
|
// Validate all expected fields
|
|
if result.Info.GPUType == "" {
|
|
t.Error("GPUType field is empty")
|
|
}
|
|
if result.Info.ConfiguredVendor == "" {
|
|
t.Error("ConfiguredVendor field is empty")
|
|
}
|
|
if result.Info.DetectionMethod == "" {
|
|
t.Error("DetectionMethod field is empty")
|
|
}
|
|
if result.Info.EnvOverrideType != "nvidia" {
|
|
t.Errorf("EnvOverrideType = %v, want 'nvidia'", result.Info.EnvOverrideType)
|
|
}
|
|
if result.Info.EnvOverrideCount != 4 {
|
|
t.Errorf("EnvOverrideCount = %v, want 4", result.Info.EnvOverrideCount)
|
|
}
|
|
}
|