diff --git a/tests/integration/kms_integration_test.go b/tests/integration/kms_integration_test.go new file mode 100644 index 0000000..9c11b6b --- /dev/null +++ b/tests/integration/kms_integration_test.go @@ -0,0 +1,241 @@ +package tests_test + +import ( + "context" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/crypto" + "github.com/jfraeys/fetch_ml/internal/crypto/kms" + "github.com/jfraeys/fetch_ml/internal/crypto/kms/providers" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +// TestVaultProvider_Integration tests the Vault provider with a real Vault container. +func TestVaultProvider_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + + // Start Vault container + req := testcontainers.ContainerRequest{ + Image: "hashicorp/vault:1.15", + ExposedPorts: []string{"8200/tcp"}, + Env: map[string]string{ + "VAULT_DEV_ROOT_TOKEN_ID": "test-token", + "VAULT_ADDR": "http://0.0.0.0:8200", + }, + WaitingFor: wait.ForLog("Vault server started").WithStartupTimeout(30 * time.Second), + } + + vaultC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + t.Fatalf("Failed to start Vault container: %v", err) + } + defer vaultC.Terminate(ctx) + + // Get container address + endpoint, err := vaultC.Endpoint(ctx, "http") + if err != nil { + t.Fatalf("Failed to get Vault endpoint: %v", err) + } + + // Create provider config + config := kms.VaultConfig{ + Address: endpoint, + AuthMethod: "token", + Token: "test-token", + TransitMount: "transit", + KeyPrefix: "test-tenant", + Timeout: 10 * time.Second, + } + + // Create provider + provider, err := providers.NewVaultProvider(config) + if err != nil { + t.Fatalf("Failed to create Vault provider: %v", err) + } + defer provider.Close() + + // Test health check + if err := provider.HealthCheck(ctx); err != nil { + t.Logf("Vault health check warning (may need Transit setup): %v", err) + } + + // Test key creation + keyID, err := provider.CreateKey(ctx, "integration-test") + if err != nil { + t.Logf("Key creation may need manual Transit setup: %v", err) + } else { + t.Logf("Created key: %s", keyID) + } +} + +// TestAWSKMSProvider_Integration tests the AWS KMS provider with LocalStack. +func TestAWSKMSProvider_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + + // Start LocalStack container with KMS + req := testcontainers.ContainerRequest{ + Image: "localstack/localstack:latest", + ExposedPorts: []string{"4566/tcp"}, + Env: map[string]string{ + "SERVICES": "kms", + "DEFAULT_REGION": "us-east-1", + "AWS_DEFAULT_REGION": "us-east-1", + }, + WaitingFor: wait.ForLog("Ready").WithStartupTimeout(60 * time.Second), + } + + localstackC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + t.Fatalf("Failed to start LocalStack container: %v", err) + } + defer localstackC.Terminate(ctx) + + // Get container address + endpoint, err := localstackC.Endpoint(ctx, "http") + if err != nil { + t.Fatalf("Failed to get LocalStack endpoint: %v", err) + } + + // Create provider config + config := kms.AWSConfig{ + Region: "us-east-1", + KeyAliasPrefix: "alias/test-fetchml", + Endpoint: endpoint, + } + + // Create provider + provider, err := providers.NewAWSProvider(config) + if err != nil { + t.Fatalf("Failed to create AWS provider: %v", err) + } + defer provider.Close() + + // Note: LocalStack KMS requires credentials even though they're not validated + // Set dummy credentials via environment for the test + t.Setenv("AWS_ACCESS_KEY_ID", "test") + t.Setenv("AWS_SECRET_ACCESS_KEY", "test") + + // Test health check + if err := provider.HealthCheck(ctx); err != nil { + t.Logf("AWS health check (may need retry): %v", err) + } + + // Test key creation + keyID, err := provider.CreateKey(ctx, "integration-test") + if err != nil { + t.Logf("Key creation may need retry: %v", err) + } else { + t.Logf("Created key: %s", keyID) + } +} + +// TestTenantKeyManager_WithMemoryProvider tests the full TenantKeyManager with MemoryProvider. +func TestTenantKeyManager_WithMemoryProvider(t *testing.T) { + // Create memory provider for testing + provider := kms.NewMemoryProvider() + defer provider.Close() + + // Create DEK cache + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + defer cache.Clear() + + // Create config + config := kms.Config{ + Provider: kms.ProviderTypeMemory, + Cache: kms.DefaultCacheConfig(), + } + + // Create TenantKeyManager + tkm := crypto.NewTenantKeyManager(provider, cache, config) + + ctx := context.Background() + ctx = context.WithValue(ctx, "test", true) + + // Test provisioning + hierarchy, err := tkm.ProvisionTenant("test-tenant") + if err != nil { + t.Fatalf("ProvisionTenant failed: %v", err) + } + + if hierarchy.TenantID != "test-tenant" { + t.Errorf("Expected tenant ID 'test-tenant', got '%s'", hierarchy.TenantID) + } + + if hierarchy.KMSKeyID == "" { + t.Error("KMSKeyID should not be empty") + } + + t.Logf("Provisioned tenant with KMSKeyID: %s", hierarchy.KMSKeyID) + + // Test encryption + plaintext := []byte("sensitive data that needs encryption") + encrypted, err := tkm.EncryptArtifact("test-tenant", "artifact-1", hierarchy.KMSKeyID, plaintext) + if err != nil { + t.Fatalf("EncryptArtifact failed: %v", err) + } + + if encrypted.Ciphertext == "" { + t.Error("Ciphertext should not be empty") + } + + if encrypted.KMSKeyID != hierarchy.KMSKeyID { + t.Error("EncryptedArtifact should store KMSKeyID") + } + + t.Logf("Encrypted data successfully") + + // Test decryption + decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID) + if err != nil { + t.Fatalf("DecryptArtifact failed: %v", err) + } + + if string(decrypted) != string(plaintext) { + t.Errorf("Decrypted data doesn't match: got %s, want %s", decrypted, plaintext) + } + + t.Logf("Decrypted data successfully") + + // Test that cache is working + stats := cache.Stats() + if stats.Size == 0 { + t.Error("Cache should have entries after encrypt/decrypt") + } + + t.Logf("Cache stats: %+v", stats) + + // Test key rotation + newHierarchy, err := tkm.RotateTenantKey("test-tenant", hierarchy) + if err != nil { + t.Fatalf("RotateTenantKey failed: %v", err) + } + + if newHierarchy.KMSKeyID == hierarchy.KMSKeyID { + t.Error("Rotated key should have different KMSKeyID") + } + + t.Logf("Rotated key from %s to %s", hierarchy.KMSKeyID, newHierarchy.KMSKeyID) + + // Test tenant revocation + if err := tkm.RevokeTenant(hierarchy); err != nil { + t.Fatalf("RevokeTenant failed: %v", err) + } + + t.Logf("Revoked tenant successfully") +} diff --git a/tests/unit/crypto/kms/cache_test.go b/tests/unit/crypto/kms/cache_test.go new file mode 100644 index 0000000..6f9d922 --- /dev/null +++ b/tests/unit/crypto/kms/cache_test.go @@ -0,0 +1,254 @@ +package kms_test + +import ( + "bytes" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/crypto/kms" +) + +// TestDEKCache_PutAndGet tests basic cache put and get operations. +func TestDEKCache_PutAndGet(t *testing.T) { + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + defer cache.Clear() + + tenantID := "tenant-1" + artifactID := "artifact-1" + dek := []byte("test-dek-data-12345678901234567890123456789012") + + // Put DEK in cache + if err := cache.Put(tenantID, artifactID, "kms-key-1", dek); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Get DEK from cache (KMS available) + retrieved, ok := cache.Get(tenantID, artifactID, "kms-key-1", false) + if !ok { + t.Fatal("Get returned false, expected true") + } + + if !bytes.Equal(retrieved, dek) { + t.Error("Retrieved DEK doesn't match original") + } +} + +// TestDEKCache_GetNonexistent tests getting a non-existent entry. +func TestDEKCache_GetNonexistent(t *testing.T) { + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + defer cache.Clear() + + _, ok := cache.Get("nonexistent", "nonexistent", "kms-key-1", false) + if ok { + t.Error("Get for non-existent key should return false") + } +} + +// TestDEKCache_TTLExpiry tests that entries expire after TTL. +func TestDEKCache_TTLExpiry(t *testing.T) { + // Use very short TTL for testing + config := kms.CacheConfig{ + TTL: 50 * time.Millisecond, + MaxEntries: 100, + GraceWindow: 100 * time.Millisecond, + } + cache := kms.NewDEKCache(config) + defer cache.Clear() + + tenantID := "tenant-1" + artifactID := "artifact-1" + dek := []byte("test-dek-data-12345678901234567890123456789012") + + // Put DEK + cache.Put(tenantID, artifactID, "kms-key-1", dek) + + // Should be available immediately + _, ok := cache.Get(tenantID, artifactID, "kms-key-1", false) + if !ok { + t.Error("DEK should be available immediately after put") + } + + // Wait for TTL to expire + time.Sleep(60 * time.Millisecond) + + // Should not be available after TTL (KMS available) + _, ok = cache.Get(tenantID, artifactID, "kms-key-1", false) + if ok { + t.Error("DEK should not be available after TTL expires") + } + + // Should be available in grace window (KMS unavailable) + _, ok = cache.Get(tenantID, artifactID, "kms-key-1", true) + if !ok { + t.Error("DEK should be available in grace window when KMS is unavailable") + } + + // Wait for grace window to expire + time.Sleep(150 * time.Millisecond) + + // Should not be available after grace window + _, ok = cache.Get(tenantID, artifactID, "kms-key-1", true) + if ok { + t.Error("DEK should not be available after grace window expires") + } +} + +// TestDEKCache_LRUeviction tests LRU eviction when cache is full. +func TestDEKCache_LRUeviction(t *testing.T) { + config := kms.CacheConfig{ + TTL: 1 * time.Hour, // Long TTL so eviction is due to size + MaxEntries: 3, + GraceWindow: 1 * time.Hour, + } + cache := kms.NewDEKCache(config) + defer cache.Clear() + + // Add 3 entries (at capacity) + for i := 0; i < 3; i++ { + dek := []byte("dek-data-12345678901234567890123456789012-" + string(rune('0'+i))) + cache.Put("tenant-1", string(rune('a'+i)), "kms-key-1", dek) + } + + // Access first entry to make it recently used + cache.Get("tenant-1", "a", "kms-key-1", false) + + // Add 4th entry (should evict 'b' as it's the oldest unaccessed) + dek4 := []byte("dek-data-12345678901234567890123456789012-4") + cache.Put("tenant-1", "d", "kms-key-1", dek4) + + // 'a' should still exist (was accessed) + _, ok := cache.Get("tenant-1", "a", "kms-key-1", false) + if !ok { + t.Error("Entry 'a' should still exist after eviction") + } + + // 'b' should be evicted + _, ok = cache.Get("tenant-1", "b", "kms-key-1", false) + if ok { + t.Error("Entry 'b' should have been evicted") + } + + // 'c' and 'd' should exist + _, ok = cache.Get("tenant-1", "c", "kms-key-1", false) + if !ok { + t.Error("Entry 'c' should still exist") + } + _, ok = cache.Get("tenant-1", "d", "kms-key-1", false) + if !ok { + t.Error("Entry 'd' should exist") + } +} + +// TestDEKCache_Flush tests flushing entries for a specific tenant. +func TestDEKCache_Flush(t *testing.T) { + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + defer cache.Clear() + + // Add entries for two tenants + cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1")) + cache.Put("tenant-1", "artifact-2", "kms-key-1", []byte("dek-2")) + cache.Put("tenant-2", "artifact-1", "kms-key-2", []byte("dek-3")) + + // Flush tenant-1 + cache.Flush("tenant-1") + + // tenant-1 entries should be gone + _, ok := cache.Get("tenant-1", "artifact-1", "kms-key-1", false) + if ok { + t.Error("tenant-1 artifact-1 should be flushed") + } + _, ok = cache.Get("tenant-1", "artifact-2", "kms-key-1", false) + if ok { + t.Error("tenant-1 artifact-2 should be flushed") + } + + // tenant-2 entry should still exist + _, ok = cache.Get("tenant-2", "artifact-1", "kms-key-2", false) + if !ok { + t.Error("tenant-2 artifact-1 should still exist") + } +} + +// TestDEKCache_Clear tests clearing all entries. +func TestDEKCache_Clear(t *testing.T) { + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + + // Add entries + cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1")) + cache.Put("tenant-2", "artifact-1", "kms-key-2", []byte("dek-2")) + + // Clear + cache.Clear() + + // All entries should be gone + _, ok := cache.Get("tenant-1", "artifact-1", "kms-key-1", false) + if ok { + t.Error("All entries should be cleared") + } + _, ok = cache.Get("tenant-2", "artifact-1", "kms-key-2", false) + if ok { + t.Error("All entries should be cleared") + } +} + +// TestDEKCache_Stats tests cache statistics. +func TestDEKCache_Stats(t *testing.T) { + config := kms.DefaultCacheConfig() + cache := kms.NewDEKCache(config) + defer cache.Clear() + + stats := cache.Stats() + + if stats.Size != 0 { + t.Errorf("Initial size should be 0, got %d", stats.Size) + } + if stats.MaxSize != config.MaxEntries { + t.Errorf("MaxSize should be %d, got %d", config.MaxEntries, stats.MaxSize) + } + if stats.TTL != config.TTL { + t.Errorf("TTL should be %v, got %v", config.TTL, stats.TTL) + } + if stats.GraceWindow != config.GraceWindow { + t.Errorf("GraceWindow should be %v, got %v", config.GraceWindow, stats.GraceWindow) + } + + // Add entry + cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1")) + + stats = cache.Stats() + if stats.Size != 1 { + t.Errorf("Size should be 1 after put, got %d", stats.Size) + } +} + +// TestDEKCache_EmptyDEK tests that empty DEK is rejected. +func TestDEKCache_EmptyDEK(t *testing.T) { + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + defer cache.Clear() + + err := cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte{}) + if err == nil { + t.Error("Should reject empty DEK") + } +} + +// TestDEKCache_Isolation tests that DEKs are isolated between tenants. +func TestDEKCache_Isolation(t *testing.T) { + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + defer cache.Clear() + + // Same artifact ID, different tenants + cache.Put("tenant-1", "shared-artifact", "kms-key-1", []byte("dek-for-tenant-1")) + cache.Put("tenant-2", "shared-artifact", "kms-key-2", []byte("dek-for-tenant-2")) + + // Each tenant should get their own DEK + d1, ok := cache.Get("tenant-1", "shared-artifact", "kms-key-1", false) + if !ok || string(d1) != "dek-for-tenant-1" { + t.Error("tenant-1 should get their own DEK") + } + + d2, ok := cache.Get("tenant-2", "shared-artifact", "kms-key-2", false) + if !ok || string(d2) != "dek-for-tenant-2" { + t.Error("tenant-2 should get their own DEK") + } +} diff --git a/tests/unit/crypto/kms/protocol_test.go b/tests/unit/crypto/kms/protocol_test.go new file mode 100644 index 0000000..8f2dc4d --- /dev/null +++ b/tests/unit/crypto/kms/protocol_test.go @@ -0,0 +1,330 @@ +package kms_test + +import ( + "bytes" + "context" + "encoding/binary" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/api" + "github.com/jfraeys/fetch_ml/internal/crypto" + "github.com/jfraeys/fetch_ml/internal/crypto/kms" +) + +func TestProtocolSerialization(t *testing.T) { + // Test success packet + successPacket := api.NewSuccessPacket("Operation completed successfully") + data, err := successPacket.Serialize() + if err != nil { + t.Fatalf("Failed to serialize success packet: %v", err) + } + + // Verify packet type + if len(data) < 1 || data[0] != api.PacketTypeSuccess { + t.Errorf("Expected packet type %d, got %d", api.PacketTypeSuccess, data[0]) + } + + // Verify timestamp is present (9 bytes minimum: 1 type + 8 timestamp) + if len(data) < 9 { + t.Errorf("Expected at least 9 bytes, got %d", len(data)) + } + + // Test error packet + errorPacket := api.NewErrorPacket(api.ErrorCodeAuthenticationFailed, "Auth failed", "Invalid API key") + data, err = errorPacket.Serialize() + if err != nil { + t.Fatalf("Failed to serialize error packet: %v", err) + } + + if len(data) < 1 || data[0] != api.PacketTypeError { + t.Errorf("Expected packet type %d, got %d", api.PacketTypeError, data[0]) + } + + // Test progress packet + progressPacket := api.NewProgressPacket(api.ProgressTypePercentage, 75, 100, "Processing...") + data, err = progressPacket.Serialize() + if err != nil { + t.Fatalf("Failed to serialize progress packet: %v", err) + } + + if len(data) < 1 || data[0] != api.PacketTypeProgress { + t.Errorf("Expected packet type %d, got %d", api.PacketTypeProgress, data[0]) + } + + // Test status packet + statusPacket := api.NewStatusPacket(`{"workers":1,"queued":0}`) + data, err = statusPacket.Serialize() + if err != nil { + t.Fatalf("Failed to serialize status packet: %v", err) + } + + if len(data) < 1 || data[0] != api.PacketTypeStatus { + t.Errorf("Expected packet type %d, got %d", api.PacketTypeStatus, data[0]) + } +} + +func TestErrorMessageMapping(t *testing.T) { + tests := map[byte]string{ + api.ErrorCodeUnknownError: "Unknown error occurred", + api.ErrorCodeAuthenticationFailed: "Authentication failed", + api.ErrorCodeJobNotFound: "Job not found", + api.ErrorCodeServerOverloaded: "Server is overloaded", + } + + for code, expected := range tests { + actual := api.GetErrorMessage(code) + if actual != expected { + t.Errorf("Expected error message '%s' for code %d, got '%s'", expected, code, actual) + } + } +} + +func TestLogLevelMapping(t *testing.T) { + tests := map[byte]string{ + api.LogLevelDebug: "DEBUG", + api.LogLevelInfo: "INFO", + api.LogLevelWarn: "WARN", + api.LogLevelError: "ERROR", + } + + for level, expected := range tests { + actual := api.GetLogLevelName(level) + if actual != expected { + t.Errorf("Expected log level '%s' for level %d, got '%s'", expected, level, actual) + } + } +} + +func TestTimestampConsistency(t *testing.T) { + before := time.Now().Unix() + + packet := api.NewSuccessPacket("Test message") + data, err := packet.Serialize() + if err != nil { + t.Fatalf("Failed to serialize: %v", err) + } + + after := time.Now().Unix() + + // Extract timestamp (bytes 1-8, big-endian) + if len(data) < 9 { + t.Fatalf("Packet too short: %d bytes", len(data)) + } + + timestamp := binary.BigEndian.Uint64(data[1:9]) + + if timestamp < uint64(before) || timestamp > uint64(after) { + t.Errorf("Timestamp %d not in expected range [%d, %d]", timestamp, before, after) + } +} + +// TestKMSProtocol_EncryptDecrypt tests the full KMS encryption/decryption protocol. +func TestKMSProtocol_EncryptDecrypt(t *testing.T) { + // Create memory provider for testing + provider := kms.NewMemoryProvider() + defer provider.Close() + + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + defer cache.Clear() + + config := kms.Config{ + Provider: kms.ProviderTypeMemory, + Cache: kms.DefaultCacheConfig(), + } + + tkm := crypto.NewTenantKeyManager(provider, cache, config) + + // Provision tenant + hierarchy, err := tkm.ProvisionTenant("protocol-test-tenant") + if err != nil { + t.Fatalf("ProvisionTenant failed: %v", err) + } + + // Test data - simulate artifact data + plaintext := []byte("sensitive model weights and training data") + + // Encrypt + encrypted, err := tkm.EncryptArtifact("protocol-test-tenant", "model-v1", hierarchy.KMSKeyID, plaintext) + if err != nil { + t.Fatalf("EncryptArtifact failed: %v", err) + } + + // Verify encrypted structure + if encrypted.Ciphertext == "" { + t.Error("Ciphertext should not be empty") + } + if encrypted.DEK == nil { + t.Error("DEK should not be nil") + } + if encrypted.KMSKeyID != hierarchy.KMSKeyID { + t.Error("KMSKeyID should match") + } + if encrypted.Algorithm != "AES-256-GCM" { + t.Errorf("Algorithm should be AES-256-GCM, got %s", encrypted.Algorithm) + } + + // Decrypt + decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID) + if err != nil { + t.Fatalf("DecryptArtifact failed: %v", err) + } + + // Verify round-trip + if !bytes.Equal(decrypted, plaintext) { + t.Errorf("Decrypted data doesn't match: got %s, want %s", decrypted, plaintext) + } +} + +// TestKMSProtocol_MultiTenantIsolation verifies tenants cannot decrypt each other's data. +func TestKMSProtocol_MultiTenantIsolation(t *testing.T) { + provider := kms.NewMemoryProvider() + defer provider.Close() + + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + defer cache.Clear() + + config := kms.Config{ + Provider: kms.ProviderTypeMemory, + Cache: kms.DefaultCacheConfig(), + } + + tkm := crypto.NewTenantKeyManager(provider, cache, config) + + // Provision two tenants + tenant1, err := tkm.ProvisionTenant("tenant-1") + if err != nil { + t.Fatalf("Failed to provision tenant-1: %v", err) + } + + tenant2, err := tkm.ProvisionTenant("tenant-2") + if err != nil { + t.Fatalf("Failed to provision tenant-2: %v", err) + } + + // Encrypt data for tenant-1 + plaintext := []byte("tenant-1 secret data") + encrypted, err := tkm.EncryptArtifact("tenant-1", "artifact-1", tenant1.KMSKeyID, plaintext) + if err != nil { + t.Fatalf("Encrypt failed: %v", err) + } + + // Attempt to decrypt with tenant-2's key - should fail + _, err = tkm.DecryptArtifact(encrypted, tenant2.KMSKeyID) + if err == nil { + t.Error("Tenant-2 should not be able to decrypt tenant-1's data (expected error)") + } + + // Tenant-1 should still be able to decrypt + decrypted, err := tkm.DecryptArtifact(encrypted, tenant1.KMSKeyID) + if err != nil { + t.Fatalf("Tenant-1 decrypt failed: %v", err) + } + + if !bytes.Equal(decrypted, plaintext) { + t.Error("Tenant-1 should decrypt their own data correctly") + } +} + +// TestKMSProtocol_CacheHit verifies cached DEKs work correctly. +func TestKMSProtocol_CacheHit(t *testing.T) { + provider := kms.NewMemoryProvider() + defer provider.Close() + + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + defer cache.Clear() + + config := kms.Config{ + Provider: kms.ProviderTypeMemory, + Cache: kms.DefaultCacheConfig(), + } + + tkm := crypto.NewTenantKeyManager(provider, cache, config) + + hierarchy, _ := tkm.ProvisionTenant("cache-test") + + plaintext := []byte("test data for caching") + + // First encrypt + encrypted, _ := tkm.EncryptArtifact("cache-test", "cached-artifact", hierarchy.KMSKeyID, plaintext) + + // Decrypt multiple times - should hit cache + for i := 0; i < 3; i++ { + decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID) + if err != nil { + t.Fatalf("Decrypt %d failed: %v", i, err) + } + if !bytes.Equal(decrypted, plaintext) { + t.Errorf("Decrypt %d: data mismatch", i) + } + } + + // Verify cache has entries + stats := cache.Stats() + if stats.Size == 0 { + t.Error("Cache should have entries after operations") + } +} + +// TestKMSProtocol_KeyRotation tests key rotation protocol. +func TestKMSProtocol_KeyRotation(t *testing.T) { + provider := kms.NewMemoryProvider() + defer provider.Close() + + cache := kms.NewDEKCache(kms.DefaultCacheConfig()) + defer cache.Clear() + + config := kms.Config{ + Provider: kms.ProviderTypeMemory, + Cache: kms.DefaultCacheConfig(), + } + + tkm := crypto.NewTenantKeyManager(provider, cache, config) + + // Provision tenant + hierarchy, _ := tkm.ProvisionTenant("rotation-test") + oldKeyID := hierarchy.KMSKeyID + + // Rotate key + newHierarchy, err := tkm.RotateTenantKey("rotation-test", hierarchy) + if err != nil { + t.Fatalf("Key rotation failed: %v", err) + } + + if newHierarchy.KMSKeyID == oldKeyID { + t.Error("New key should have different ID after rotation") + } + + // Cache should be flushed after rotation + stats := cache.Stats() + if stats.Size != 0 { + t.Error("Cache should be flushed after key rotation") + } + + // Encrypt with new key + plaintext2 := []byte("data encrypted with new key") + encrypted2, _ := tkm.EncryptArtifact("rotation-test", "post-rotation", newHierarchy.KMSKeyID, plaintext2) + + // Decrypt with new key + decrypted2, err := tkm.DecryptArtifact(encrypted2, newHierarchy.KMSKeyID) + if err != nil { + t.Fatalf("Decrypt with new key failed: %v", err) + } + + if !bytes.Equal(decrypted2, plaintext2) { + t.Error("Data encrypted with new key should decrypt correctly") + } +} + +// TestKMSProvider_HealthCheck tests health check protocol. +func TestKMSProvider_HealthCheck(t *testing.T) { + provider := kms.NewMemoryProvider() + defer provider.Close() + + ctx := context.Background() + + // Memory provider should always be healthy + if err := provider.HealthCheck(ctx); err != nil { + t.Errorf("Memory provider health check failed: %v", err) + } +}