diff --git a/Makefile b/Makefile index 7997442..f9f6efa 100644 --- a/Makefile +++ b/Makefile @@ -177,14 +177,14 @@ verify-build: test: test-infra-up @arch=$$(uname -m | sed 's/x86_64/amd64/'); os=$$(uname -s | tr '[:upper:]' '[:lower:]'); \ [ -f bin/cli/ml-$$os-$$arch ] || $(MAKE) build-cli - @go test -v ./tests/unit/... ./tests/integration/... ./tests/e2e/... 2>&1 | grep -v "redis: connection pool" | tee /tmp/test-all.txt || true + @go test -v ./internal/... ./tests/integration/... ./tests/e2e/... 2>&1 | grep -v "redis: connection pool" | tee /tmp/test-all.txt || true @echo "\n=== Test Summary ===" $(call test_summary,/tmp/test-all.txt) @$(MAKE) test-infra-down @cd cli && zig build test test-unit: - @go test -v -short ./tests/unit/... 2>&1 | tee /tmp/test-unit.txt || true + @go test -v -short ./internal/... 2>&1 | tee /tmp/test-unit.txt || true @echo "\n=== Unit Test Summary ===" $(call test_summary,/tmp/test-unit.txt) @cd cli && zig build test @@ -205,9 +205,10 @@ test-e2e: test-infra-up test-coverage: @mkdir -p coverage - go test -coverprofile=coverage/coverage.out ./... + go test -coverprofile=coverage/coverage.out -coverpkg=./internal/...,./cmd/... ./internal/... ./tests/integration/... ./tests/e2e/... go tool cover -html=coverage/coverage.out -o coverage/coverage.html @echo "$(OK) Coverage report: coverage/coverage.html" + @go tool cover -func=coverage/coverage.out | tail -1 consistency-test: build native-build @echo "Running cross-implementation consistency tests..." @@ -353,7 +354,7 @@ lint-custom: @echo "$(OK) Custom lint complete" verify-audit: - @go test ./tests/unit/audit/... -run TestChainVerifier -v + @go test ./internal/audit/... -run TestChainVerifier -v @echo "$(OK) Audit chain verification passed" verify-audit-chain: diff --git a/tests/unit/crypto/kms/cache_test.go b/tests/unit/crypto/kms/cache_test.go deleted file mode 100644 index 3ab22cd..0000000 --- a/tests/unit/crypto/kms/cache_test.go +++ /dev/null @@ -1,255 +0,0 @@ -package kms_test - -import ( - "bytes" - "testing" - "time" - - "github.com/jfraeys/fetch_ml/internal/crypto/kms" - kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config" -) - -// TestDEKCache_PutAndGet tests basic cache put and get operations. -func TestDEKCache_PutAndGet(t *testing.T) { - cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) - defer cache.Clear() - - tenantID := "tenant-1" - artifactID := "artifact-1" - dek := []byte("test-dek-data-12345678901234567890123456789012") - - // Put DEK in cache - if err := cache.Put(tenantID, artifactID, "kms-key-1", dek); err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Get DEK from cache (KMS available) - retrieved, ok := cache.Get(tenantID, artifactID, "kms-key-1", false) - if !ok { - t.Fatal("Get returned false, expected true") - } - - if !bytes.Equal(retrieved, dek) { - t.Error("Retrieved DEK doesn't match original") - } -} - -// TestDEKCache_GetNonexistent tests getting a non-existent entry. -func TestDEKCache_GetNonexistent(t *testing.T) { - cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) - defer cache.Clear() - - _, ok := cache.Get("nonexistent", "nonexistent", "kms-key-1", false) - if ok { - t.Error("Get for non-existent key should return false") - } -} - -// TestDEKCache_TTLExpiry tests that entries expire after TTL. -func TestDEKCache_TTLExpiry(t *testing.T) { - // Use very short TTL for testing - config := kmsconfig.CacheConfig{ - TTL: 50 * time.Millisecond, - MaxEntries: 100, - GraceWindow: 100 * time.Millisecond, - } - cache := kms.NewDEKCache(config) - defer cache.Clear() - - tenantID := "tenant-1" - artifactID := "artifact-1" - dek := []byte("test-dek-data-12345678901234567890123456789012") - - // Put DEK - cache.Put(tenantID, artifactID, "kms-key-1", dek) - - // Should be available immediately - _, ok := cache.Get(tenantID, artifactID, "kms-key-1", false) - if !ok { - t.Error("DEK should be available immediately after put") - } - - // Wait for TTL to expire - time.Sleep(60 * time.Millisecond) - - // Should not be available after TTL (KMS available) - _, ok = cache.Get(tenantID, artifactID, "kms-key-1", false) - if ok { - t.Error("DEK should not be available after TTL expires") - } - - // Should be available in grace window (KMS unavailable) - _, ok = cache.Get(tenantID, artifactID, "kms-key-1", true) - if !ok { - t.Error("DEK should be available in grace window when KMS is unavailable") - } - - // Wait for grace window to expire - time.Sleep(150 * time.Millisecond) - - // Should not be available after grace window - _, ok = cache.Get(tenantID, artifactID, "kms-key-1", true) - if ok { - t.Error("DEK should not be available after grace window expires") - } -} - -// TestDEKCache_LRUeviction tests LRU eviction when cache is full. -func TestDEKCache_LRUeviction(t *testing.T) { - config := kmsconfig.CacheConfig{ - TTL: 1 * time.Hour, // Long TTL so eviction is due to size - MaxEntries: 3, - GraceWindow: 1 * time.Hour, - } - cache := kms.NewDEKCache(config) - defer cache.Clear() - - // Add 3 entries (at capacity) - for i := 0; i < 3; i++ { - dek := []byte("dek-data-12345678901234567890123456789012-" + string(rune('0'+i))) - cache.Put("tenant-1", string(rune('a'+i)), "kms-key-1", dek) - } - - // Access first entry to make it recently used - cache.Get("tenant-1", "a", "kms-key-1", false) - - // Add 4th entry (should evict 'b' as it's the oldest unaccessed) - dek4 := []byte("dek-data-12345678901234567890123456789012-4") - cache.Put("tenant-1", "d", "kms-key-1", dek4) - - // 'a' should still exist (was accessed) - _, ok := cache.Get("tenant-1", "a", "kms-key-1", false) - if !ok { - t.Error("Entry 'a' should still exist after eviction") - } - - // 'b' should be evicted - _, ok = cache.Get("tenant-1", "b", "kms-key-1", false) - if ok { - t.Error("Entry 'b' should have been evicted") - } - - // 'c' and 'd' should exist - _, ok = cache.Get("tenant-1", "c", "kms-key-1", false) - if !ok { - t.Error("Entry 'c' should still exist") - } - _, ok = cache.Get("tenant-1", "d", "kms-key-1", false) - if !ok { - t.Error("Entry 'd' should exist") - } -} - -// TestDEKCache_Flush tests flushing entries for a specific tenant. -func TestDEKCache_Flush(t *testing.T) { - cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) - defer cache.Clear() - - // Add entries for two tenants - cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1")) - cache.Put("tenant-1", "artifact-2", "kms-key-1", []byte("dek-2")) - cache.Put("tenant-2", "artifact-1", "kms-key-2", []byte("dek-3")) - - // Flush tenant-1 - cache.Flush("tenant-1") - - // tenant-1 entries should be gone - _, ok := cache.Get("tenant-1", "artifact-1", "kms-key-1", false) - if ok { - t.Error("tenant-1 artifact-1 should be flushed") - } - _, ok = cache.Get("tenant-1", "artifact-2", "kms-key-1", false) - if ok { - t.Error("tenant-1 artifact-2 should be flushed") - } - - // tenant-2 entry should still exist - _, ok = cache.Get("tenant-2", "artifact-1", "kms-key-2", false) - if !ok { - t.Error("tenant-2 artifact-1 should still exist") - } -} - -// TestDEKCache_Clear tests clearing all entries. -func TestDEKCache_Clear(t *testing.T) { - cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) - - // Add entries - cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1")) - cache.Put("tenant-2", "artifact-1", "kms-key-2", []byte("dek-2")) - - // Clear - cache.Clear() - - // All entries should be gone - _, ok := cache.Get("tenant-1", "artifact-1", "kms-key-1", false) - if ok { - t.Error("All entries should be cleared") - } - _, ok = cache.Get("tenant-2", "artifact-1", "kms-key-2", false) - if ok { - t.Error("All entries should be cleared") - } -} - -// TestDEKCache_Stats tests cache statistics. -func TestDEKCache_Stats(t *testing.T) { - config := kmsconfig.DefaultCacheConfig() - cache := kms.NewDEKCache(config) - defer cache.Clear() - - stats := cache.Stats() - - if stats.Size != 0 { - t.Errorf("Initial size should be 0, got %d", stats.Size) - } - if stats.MaxSize != config.MaxEntries { - t.Errorf("MaxSize should be %d, got %d", config.MaxEntries, stats.MaxSize) - } - if stats.TTL != config.TTL { - t.Errorf("TTL should be %v, got %v", config.TTL, stats.TTL) - } - if stats.GraceWindow != config.GraceWindow { - t.Errorf("GraceWindow should be %v, got %v", config.GraceWindow, stats.GraceWindow) - } - - // Add entry - cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1")) - - stats = cache.Stats() - if stats.Size != 1 { - t.Errorf("Size should be 1 after put, got %d", stats.Size) - } -} - -// TestDEKCache_EmptyDEK tests that empty DEK is rejected. -func TestDEKCache_EmptyDEK(t *testing.T) { - cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) - defer cache.Clear() - - err := cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte{}) - if err == nil { - t.Error("Should reject empty DEK") - } -} - -// TestDEKCache_Isolation tests that DEKs are isolated between tenants. -func TestDEKCache_Isolation(t *testing.T) { - cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) - defer cache.Clear() - - // Same artifact ID, different tenants - cache.Put("tenant-1", "shared-artifact", "kms-key-1", []byte("dek-for-tenant-1")) - cache.Put("tenant-2", "shared-artifact", "kms-key-2", []byte("dek-for-tenant-2")) - - // Each tenant should get their own DEK - d1, ok := cache.Get("tenant-1", "shared-artifact", "kms-key-1", false) - if !ok || string(d1) != "dek-for-tenant-1" { - t.Error("tenant-1 should get their own DEK") - } - - d2, ok := cache.Get("tenant-2", "shared-artifact", "kms-key-2", false) - if !ok || string(d2) != "dek-for-tenant-2" { - t.Error("tenant-2 should get their own DEK") - } -} diff --git a/tests/unit/crypto/kms/protocol_test.go b/tests/unit/crypto/kms/protocol_test.go deleted file mode 100644 index 7e1a956..0000000 --- a/tests/unit/crypto/kms/protocol_test.go +++ /dev/null @@ -1,335 +0,0 @@ -package kms_test - -import ( - "bytes" - "context" - "encoding/binary" - "testing" - "time" - - "github.com/jfraeys/fetch_ml/internal/api" - "github.com/jfraeys/fetch_ml/internal/crypto" - "github.com/jfraeys/fetch_ml/internal/crypto/kms" - kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config" -) - -func TestProtocolSerialization(t *testing.T) { - // Test success packet - successPacket := api.NewSuccessPacket("Operation completed successfully") - data, err := successPacket.Serialize() - if err != nil { - t.Fatalf("Failed to serialize success packet: %v", err) - } - - // Verify packet type - if len(data) < 1 || data[0] != api.PacketTypeSuccess { - t.Errorf("Expected packet type %d, got %d", api.PacketTypeSuccess, data[0]) - } - - // Verify timestamp is present (9 bytes minimum: 1 type + 8 timestamp) - if len(data) < 9 { - t.Errorf("Expected at least 9 bytes, got %d", len(data)) - } - - // Test error packet - uses string error code from errors package - errorPacket := api.NewErrorPacket("AUTHENTICATION_FAILED", "Auth failed", "Invalid API key") - data, err = errorPacket.Serialize() - if err != nil { - t.Fatalf("Failed to serialize error packet: %v", err) - } - - if len(data) < 1 || data[0] != api.PacketTypeError { - t.Errorf("Expected packet type %d, got %d", api.PacketTypeError, data[0]) - } - - // Test progress packet - progressPacket := api.NewProgressPacket(api.ProgressTypePercentage, 75, 100, "Processing...") - data, err = progressPacket.Serialize() - if err != nil { - t.Fatalf("Failed to serialize progress packet: %v", err) - } - - if len(data) < 1 || data[0] != api.PacketTypeProgress { - t.Errorf("Expected packet type %d, got %d", api.PacketTypeProgress, data[0]) - } - - // Test status packet - statusPacket := api.NewStatusPacket(`{"workers":1,"queued":0}`) - data, err = statusPacket.Serialize() - if err != nil { - t.Fatalf("Failed to serialize status packet: %v", err) - } - - if len(data) < 1 || data[0] != api.PacketTypeStatus { - t.Errorf("Expected packet type %d, got %d", api.PacketTypeStatus, data[0]) - } -} - -func TestByteCodeFromErrorCode(t *testing.T) { - tests := map[string]byte{ - "UNKNOWN_ERROR": api.ErrorCodeUnknownError, - "AUTHENTICATION_FAILED": api.ErrorCodeAuthenticationFailed, - "JOB_NOT_FOUND": api.ErrorCodeJobNotFound, - "SERVER_OVERLOADED": api.ErrorCodeServerOverloaded, - "INVALID_REQUEST": api.ErrorCodeInvalidRequest, - "BAD_REQUEST": api.ErrorCodeInvalidRequest, - "PERMISSION_DENIED": api.ErrorCodePermissionDenied, - "FORBIDDEN": api.ErrorCodePermissionDenied, - } - - for code, expectedByte := range tests { - actual := api.ByteCodeFromErrorCode(code) - if actual != expectedByte { - t.Errorf("Expected byte %d for code '%s', got %d", expectedByte, code, actual) - } - } -} - -func TestLogLevelMapping(t *testing.T) { - tests := map[byte]string{ - api.LogLevelDebug: "DEBUG", - api.LogLevelInfo: "INFO", - api.LogLevelWarn: "WARN", - api.LogLevelError: "ERROR", - } - - for level, expected := range tests { - actual := api.GetLogLevelName(level) - if actual != expected { - t.Errorf("Expected log level '%s' for level %d, got '%s'", expected, level, actual) - } - } -} - -func TestTimestampConsistency(t *testing.T) { - before := time.Now().Unix() - - packet := api.NewSuccessPacket("Test message") - data, err := packet.Serialize() - if err != nil { - t.Fatalf("Failed to serialize: %v", err) - } - - after := time.Now().Unix() - - // Extract timestamp (bytes 1-8, big-endian) - if len(data) < 9 { - t.Fatalf("Packet too short: %d bytes", len(data)) - } - - timestamp := binary.BigEndian.Uint64(data[1:9]) - - if timestamp < uint64(before) || timestamp > uint64(after) { - t.Errorf("Timestamp %d not in expected range [%d, %d]", timestamp, before, after) - } -} - -// TestKMSProtocol_EncryptDecrypt tests the full KMS encryption/decryption protocol. -func TestKMSProtocol_EncryptDecrypt(t *testing.T) { - // Create memory provider for testing - provider := kms.NewMemoryProvider() - defer provider.Close() - - cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) - defer cache.Clear() - - config := kmsconfig.Config{ - Provider: kms.ProviderTypeMemory, - Cache: kmsconfig.DefaultCacheConfig(), - } - - tkm := crypto.NewTenantKeyManager(provider, cache, config, nil) - - // Provision tenant - hierarchy, err := tkm.ProvisionTenant("protocol-test-tenant") - if err != nil { - t.Fatalf("ProvisionTenant failed: %v", err) - } - - // Test data - simulate artifact data - plaintext := []byte("sensitive model weights and training data") - - // Encrypt - encrypted, err := tkm.EncryptArtifact("protocol-test-tenant", "model-v1", hierarchy.KMSKeyID, plaintext) - if err != nil { - t.Fatalf("EncryptArtifact failed: %v", err) - } - - // Verify encrypted structure - if encrypted.Ciphertext == "" { - t.Error("Ciphertext should not be empty") - } - if encrypted.DEK == nil { - t.Error("DEK should not be nil") - } - if encrypted.KMSKeyID != hierarchy.KMSKeyID { - t.Error("KMSKeyID should match") - } - if encrypted.Algorithm != "AES-256-GCM" { - t.Errorf("Algorithm should be AES-256-GCM, got %s", encrypted.Algorithm) - } - - // Decrypt - decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID) - if err != nil { - t.Fatalf("DecryptArtifact failed: %v", err) - } - - // Verify round-trip - if !bytes.Equal(decrypted, plaintext) { - t.Errorf("Decrypted data doesn't match: got %s, want %s", decrypted, plaintext) - } -} - -// TestKMSProtocol_MultiTenantIsolation verifies tenants cannot decrypt each other's data. -func TestKMSProtocol_MultiTenantIsolation(t *testing.T) { - provider := kms.NewMemoryProvider() - defer provider.Close() - - cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) - defer cache.Clear() - - config := kmsconfig.Config{ - Provider: kms.ProviderTypeMemory, - Cache: kmsconfig.DefaultCacheConfig(), - } - - tkm := crypto.NewTenantKeyManager(provider, cache, config, nil) - - // Provision two tenants - tenant1, err := tkm.ProvisionTenant("tenant-1") - if err != nil { - t.Fatalf("Failed to provision tenant-1: %v", err) - } - - tenant2, err := tkm.ProvisionTenant("tenant-2") - if err != nil { - t.Fatalf("Failed to provision tenant-2: %v", err) - } - - // Encrypt data for tenant-1 - plaintext := []byte("tenant-1 secret data") - encrypted, err := tkm.EncryptArtifact("tenant-1", "artifact-1", tenant1.KMSKeyID, plaintext) - if err != nil { - t.Fatalf("Encrypt failed: %v", err) - } - - // Attempt to decrypt with tenant-2's key - should fail - _, err = tkm.DecryptArtifact(encrypted, tenant2.KMSKeyID) - if err == nil { - t.Error("Tenant-2 should not be able to decrypt tenant-1's data (expected error)") - } - - // Tenant-1 should still be able to decrypt - decrypted, err := tkm.DecryptArtifact(encrypted, tenant1.KMSKeyID) - if err != nil { - t.Fatalf("Tenant-1 decrypt failed: %v", err) - } - - if !bytes.Equal(decrypted, plaintext) { - t.Error("Tenant-1 should decrypt their own data correctly") - } -} - -// TestKMSProtocol_CacheHit verifies cached DEKs work correctly. -func TestKMSProtocol_CacheHit(t *testing.T) { - provider := kms.NewMemoryProvider() - defer provider.Close() - - cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) - defer cache.Clear() - - config := kmsconfig.Config{ - Provider: kms.ProviderTypeMemory, - Cache: kmsconfig.DefaultCacheConfig(), - } - - tkm := crypto.NewTenantKeyManager(provider, cache, config, nil) - - hierarchy, _ := tkm.ProvisionTenant("cache-test") - - plaintext := []byte("test data for caching") - - // First encrypt - encrypted, _ := tkm.EncryptArtifact("cache-test", "cached-artifact", hierarchy.KMSKeyID, plaintext) - - // Decrypt multiple times - should hit cache - for i := 0; i < 3; i++ { - decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID) - if err != nil { - t.Fatalf("Decrypt %d failed: %v", i, err) - } - if !bytes.Equal(decrypted, plaintext) { - t.Errorf("Decrypt %d: data mismatch", i) - } - } - - // Verify cache has entries - stats := cache.Stats() - if stats.Size == 0 { - t.Error("Cache should have entries after operations") - } -} - -// TestKMSProtocol_KeyRotation tests key rotation protocol. -func TestKMSProtocol_KeyRotation(t *testing.T) { - provider := kms.NewMemoryProvider() - defer provider.Close() - - cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) - defer cache.Clear() - - config := kmsconfig.Config{ - Provider: kms.ProviderTypeMemory, - Cache: kmsconfig.DefaultCacheConfig(), - } - - tkm := crypto.NewTenantKeyManager(provider, cache, config, nil) - - // Provision tenant - hierarchy, _ := tkm.ProvisionTenant("rotation-test") - oldKeyID := hierarchy.KMSKeyID - - // Rotate key - newHierarchy, err := tkm.RotateTenantKey("rotation-test", hierarchy) - if err != nil { - t.Fatalf("Key rotation failed: %v", err) - } - - if newHierarchy.KMSKeyID == oldKeyID { - t.Error("New key should have different ID after rotation") - } - - // Cache should be flushed after rotation - stats := cache.Stats() - if stats.Size != 0 { - t.Error("Cache should be flushed after key rotation") - } - - // Encrypt with new key - plaintext2 := []byte("data encrypted with new key") - encrypted2, _ := tkm.EncryptArtifact("rotation-test", "post-rotation", newHierarchy.KMSKeyID, plaintext2) - - // Decrypt with new key - decrypted2, err := tkm.DecryptArtifact(encrypted2, newHierarchy.KMSKeyID) - if err != nil { - t.Fatalf("Decrypt with new key failed: %v", err) - } - - if !bytes.Equal(decrypted2, plaintext2) { - t.Error("Data encrypted with new key should decrypt correctly") - } -} - -// TestKMSProvider_HealthCheck tests health check protocol. -func TestKMSProvider_HealthCheck(t *testing.T) { - provider := kms.NewMemoryProvider() - defer provider.Close() - - ctx := context.Background() - - // Memory provider should always be healthy - if err := provider.HealthCheck(ctx); err != nil { - t.Errorf("Memory provider health check failed: %v", err) - } -} diff --git a/tests/unit/resources/manager_test.go b/tests/unit/resources/manager_test.go deleted file mode 100644 index 236a1e0..0000000 --- a/tests/unit/resources/manager_test.go +++ /dev/null @@ -1,166 +0,0 @@ -package resources_test - -import ( - "context" - "testing" - "time" - - "github.com/jfraeys/fetch_ml/internal/queue" - "github.com/jfraeys/fetch_ml/internal/resources" - "github.com/stretchr/testify/require" -) - -func TestManager_CPUAcquireBlocksUntilRelease(t *testing.T) { - m, err := resources.NewManager(resources.Options{TotalCPU: 4, GPUCount: 0, SlotsPerGPU: 1}) - require.NoError(t, err) - - task1 := &queue.Task{CPU: 3} - lease1, err := m.Acquire(context.Background(), task1) - require.NoError(t, err) - require.NotNil(t, lease1) - - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - _, err = m.Acquire(ctx, &queue.Task{CPU: 2}) - require.Error(t, err) - - lease1.Release() - - ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) - defer cancel2() - lease2, err := m.Acquire(ctx2, &queue.Task{CPU: 2}) - require.NoError(t, err) - require.NotNil(t, lease2) - lease2.Release() -} - -func TestManager_GPUSlotsAllowSharing(t *testing.T) { - m, err := resources.NewManager(resources.Options{TotalCPU: 0, GPUCount: 1, SlotsPerGPU: 4}) - require.NoError(t, err) - - leases := make([]*resources.Lease, 0, 4) - for i := 0; i < 4; i++ { - l, err := m.Acquire(context.Background(), &queue.Task{GPU: 1, GPUMemory: "0.25"}) - require.NoError(t, err) - leases = append(leases, l) - } - - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - _, err = m.Acquire(ctx, &queue.Task{GPU: 1, GPUMemory: "0.25"}) - require.Error(t, err) - - for _, l := range leases { - l.Release() - } -} - -func TestManager_MultiGPUExclusiveAllocation(t *testing.T) { - m, err := resources.NewManager(resources.Options{TotalCPU: 0, GPUCount: 2, SlotsPerGPU: 1}) - require.NoError(t, err) - - lease, err := m.Acquire(context.Background(), &queue.Task{GPU: 2}) - require.NoError(t, err) - require.Len(t, lease.GPUs(), 2) - - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - _, err = m.Acquire(ctx, &queue.Task{GPU: 1}) - require.Error(t, err) - - lease.Release() -} - -func TestFormatCUDAVisibleDevices_NoLeaseDisablesGPU(t *testing.T) { - require.Equal(t, "-1", resources.FormatCUDAVisibleDevices(nil)) -} - -func TestManager_GPUSlotsAllowSharing_Concurrent(t *testing.T) { - m, err := resources.NewManager(resources.Options{TotalCPU: 0, GPUCount: 1, SlotsPerGPU: 4}) - require.NoError(t, err) - - started := make(chan struct{}) - release := make(chan struct{}) - - errCh := make(chan error, 4) - leases := make(chan *resources.Lease, 4) - for i := 0; i < 4; i++ { - go func() { - <-started - l, err := m.Acquire(context.Background(), &queue.Task{GPU: 1, GPUMemory: "0.25"}) - if err != nil { - errCh <- err - return - } - leases <- l - <-release - l.Release() - errCh <- nil - }() - } - close(started) - - deadline := time.After(500 * time.Millisecond) - acquired := make([]*resources.Lease, 0, 4) - for len(acquired) < 4 { - select { - case l := <-leases: - acquired = append(acquired, l) - case <-deadline: - t.Fatalf("timed out waiting for leases; got %d", len(acquired)) - } - } - - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - _, err = m.Acquire(ctx, &queue.Task{GPU: 1, GPUMemory: "0.25"}) - require.Error(t, err) - - close(release) - for i := 0; i < 4; i++ { - require.NoError(t, <-errCh) - } -} - -func TestManager_CPUOnlyNotBlockedWhenGPUSaturated(t *testing.T) { - m, err := resources.NewManager(resources.Options{TotalCPU: 4, GPUCount: 1, SlotsPerGPU: 1}) - require.NoError(t, err) - - gpuLease, err := m.Acquire(context.Background(), &queue.Task{GPU: 1}) - require.NoError(t, err) - defer gpuLease.Release() - - done := make(chan error, 1) - go func() { - lease, err := m.Acquire(context.Background(), &queue.Task{CPU: 1}) - if err == nil { - lease.Release() - } - done <- err - }() - - select { - case err := <-done: - require.NoError(t, err) - case <-time.After(200 * time.Millisecond): - t.Fatal("cpu-only acquire unexpectedly blocked by gpu saturation") - } -} - -func TestManager_AcquireMetrics_RecordWaitAndTimeout(t *testing.T) { - m, err := resources.NewManager(resources.Options{TotalCPU: 1, GPUCount: 0, SlotsPerGPU: 1}) - require.NoError(t, err) - - lease, err := m.Acquire(context.Background(), &queue.Task{CPU: 1}) - require.NoError(t, err) - defer lease.Release() - - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - _, err = m.Acquire(ctx, &queue.Task{CPU: 1}) - require.Error(t, err) - - s := m.Snapshot() - require.GreaterOrEqual(t, s.AcquireTotal, int64(2)) - require.GreaterOrEqual(t, s.AcquireTimeoutTotal, int64(1)) -}