test(kms): add comprehensive unit and integration tests

Unit tests for DEK cache:
- Put/Get operations, TTL expiry, LRU eviction
- Tenant isolation, flush/clear, stats, empty DEK rejection

Unit tests for KMS protocol:
- Encrypt/decrypt round-trip with MemoryProvider
- Multi-tenant isolation (wrong key fails MAC verification)
- Cache hit verification, key rotation flow
- Health check protocol

Integration tests with testcontainers:
- VaultProvider with hashicorp/vault:1.15 container
- AWSProvider with localstack/localstack container
- TenantKeyManager end-to-end with MemoryProvider
This commit is contained in:
Jeremie Fraeys 2026-03-03 19:14:31 -05:00
parent e1ec255ad2
commit 16343e6c2a
No known key found for this signature in database
3 changed files with 825 additions and 0 deletions

View file

@ -0,0 +1,241 @@
package tests_test
import (
"context"
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/crypto"
"github.com/jfraeys/fetch_ml/internal/crypto/kms"
"github.com/jfraeys/fetch_ml/internal/crypto/kms/providers"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
// TestVaultProvider_Integration exercises the Vault provider against a real
// Vault dev-mode container started with testcontainers.
//
// Skipped in -short mode. The health check and key creation results are
// logged rather than asserted because the dev container may still need the
// Transit secrets engine to be enabled manually.
func TestVaultProvider_Integration(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test in short mode")
	}
	ctx := context.Background()
	// Start a Vault container in dev mode with a fixed root token.
	req := testcontainers.ContainerRequest{
		Image:        "hashicorp/vault:1.15",
		ExposedPorts: []string{"8200/tcp"},
		Env: map[string]string{
			"VAULT_DEV_ROOT_TOKEN_ID": "test-token",
			"VAULT_ADDR":              "http://0.0.0.0:8200",
		},
		WaitingFor: wait.ForLog("Vault server started").WithStartupTimeout(30 * time.Second),
	}
	vaultC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
		ContainerRequest: req,
		Started:          true,
	})
	if err != nil {
		t.Fatalf("Failed to start Vault container: %v", err)
	}
	// Log (instead of silently dropping) any container teardown error.
	defer func() {
		if terr := vaultC.Terminate(ctx); terr != nil {
			t.Logf("Failed to terminate Vault container: %v", terr)
		}
	}()
	// Resolve the mapped host:port for the container's HTTP listener.
	endpoint, err := vaultC.Endpoint(ctx, "http")
	if err != nil {
		t.Fatalf("Failed to get Vault endpoint: %v", err)
	}
	// Point the provider at the dev container using the root token.
	config := kms.VaultConfig{
		Address:      endpoint,
		AuthMethod:   "token",
		Token:        "test-token",
		TransitMount: "transit",
		KeyPrefix:    "test-tenant",
		Timeout:      10 * time.Second,
	}
	provider, err := providers.NewVaultProvider(config)
	if err != nil {
		t.Fatalf("Failed to create Vault provider: %v", err)
	}
	defer provider.Close()
	// Health check: non-fatal, the Transit engine may not be mounted yet.
	if err := provider.HealthCheck(ctx); err != nil {
		t.Logf("Vault health check warning (may need Transit setup): %v", err)
	}
	// Key creation: also non-fatal for the same reason.
	keyID, err := provider.CreateKey(ctx, "integration-test")
	if err != nil {
		t.Logf("Key creation may need manual Transit setup: %v", err)
	} else {
		t.Logf("Created key: %s", keyID)
	}
}
// TestAWSKMSProvider_Integration exercises the AWS KMS provider against a
// LocalStack container started with testcontainers.
//
// Skipped in -short mode. Dummy AWS credentials are set BEFORE the provider
// is constructed: the AWS SDK resolves credentials when the client
// configuration is loaded, so setting them afterwards (as the previous
// version did) left the already-built client without credentials.
func TestAWSKMSProvider_Integration(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test in short mode")
	}
	ctx := context.Background()
	// LocalStack requires credentials to be present but does not validate
	// their values. They must exist before NewAWSProvider loads the SDK
	// config.
	t.Setenv("AWS_ACCESS_KEY_ID", "test")
	t.Setenv("AWS_SECRET_ACCESS_KEY", "test")
	// Start a LocalStack container with only the KMS service enabled.
	req := testcontainers.ContainerRequest{
		Image:        "localstack/localstack:latest",
		ExposedPorts: []string{"4566/tcp"},
		Env: map[string]string{
			"SERVICES":           "kms",
			"DEFAULT_REGION":     "us-east-1",
			"AWS_DEFAULT_REGION": "us-east-1",
		},
		WaitingFor: wait.ForLog("Ready").WithStartupTimeout(60 * time.Second),
	}
	localstackC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
		ContainerRequest: req,
		Started:          true,
	})
	if err != nil {
		t.Fatalf("Failed to start LocalStack container: %v", err)
	}
	// Log (instead of silently dropping) any container teardown error.
	defer func() {
		if terr := localstackC.Terminate(ctx); terr != nil {
			t.Logf("Failed to terminate LocalStack container: %v", terr)
		}
	}()
	// Resolve the mapped host:port for the LocalStack edge endpoint.
	endpoint, err := localstackC.Endpoint(ctx, "http")
	if err != nil {
		t.Fatalf("Failed to get LocalStack endpoint: %v", err)
	}
	// Point the provider at LocalStack instead of real AWS.
	config := kms.AWSConfig{
		Region:         "us-east-1",
		KeyAliasPrefix: "alias/test-fetchml",
		Endpoint:       endpoint,
	}
	provider, err := providers.NewAWSProvider(config)
	if err != nil {
		t.Fatalf("Failed to create AWS provider: %v", err)
	}
	defer provider.Close()
	// Health check: non-fatal, LocalStack may still be warming up.
	if err := provider.HealthCheck(ctx); err != nil {
		t.Logf("AWS health check (may need retry): %v", err)
	}
	// Key creation: also non-fatal for the same reason.
	keyID, err := provider.CreateKey(ctx, "integration-test")
	if err != nil {
		t.Logf("Key creation may need retry: %v", err)
	} else {
		t.Logf("Created key: %s", keyID)
	}
}
// TestTenantKeyManager_WithMemoryProvider runs the full TenantKeyManager
// lifecycle — provision, encrypt, decrypt, rotate, revoke — against the
// in-memory KMS provider.
//
// The previous version built a context with a plain string key
// (context.WithValue(ctx, "test", true)) and never used it; string context
// keys are flagged by staticcheck (SA1029) and the value was dead, so the
// code is removed.
func TestTenantKeyManager_WithMemoryProvider(t *testing.T) {
	// In-memory provider: no external KMS needed.
	provider := kms.NewMemoryProvider()
	defer provider.Close()
	// DEK cache shared with the manager; cleared on exit.
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()
	config := kms.Config{
		Provider: kms.ProviderTypeMemory,
		Cache:    kms.DefaultCacheConfig(),
	}
	tkm := crypto.NewTenantKeyManager(provider, cache, config)
	// Provision a tenant and verify the returned hierarchy.
	hierarchy, err := tkm.ProvisionTenant("test-tenant")
	if err != nil {
		t.Fatalf("ProvisionTenant failed: %v", err)
	}
	if hierarchy.TenantID != "test-tenant" {
		t.Errorf("Expected tenant ID 'test-tenant', got '%s'", hierarchy.TenantID)
	}
	if hierarchy.KMSKeyID == "" {
		t.Error("KMSKeyID should not be empty")
	}
	t.Logf("Provisioned tenant with KMSKeyID: %s", hierarchy.KMSKeyID)
	// Encrypt an artifact under the tenant's key.
	plaintext := []byte("sensitive data that needs encryption")
	encrypted, err := tkm.EncryptArtifact("test-tenant", "artifact-1", hierarchy.KMSKeyID, plaintext)
	if err != nil {
		t.Fatalf("EncryptArtifact failed: %v", err)
	}
	if encrypted.Ciphertext == "" {
		t.Error("Ciphertext should not be empty")
	}
	if encrypted.KMSKeyID != hierarchy.KMSKeyID {
		t.Error("EncryptedArtifact should store KMSKeyID")
	}
	t.Logf("Encrypted data successfully")
	// Decrypt and verify the round trip is lossless.
	decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID)
	if err != nil {
		t.Fatalf("DecryptArtifact failed: %v", err)
	}
	if string(decrypted) != string(plaintext) {
		t.Errorf("Decrypted data doesn't match: got %s, want %s", decrypted, plaintext)
	}
	t.Logf("Decrypted data successfully")
	// The encrypt/decrypt cycle should have populated the DEK cache.
	stats := cache.Stats()
	if stats.Size == 0 {
		t.Error("Cache should have entries after encrypt/decrypt")
	}
	t.Logf("Cache stats: %+v", stats)
	// Rotate the tenant key and verify a new KMS key was issued.
	newHierarchy, err := tkm.RotateTenantKey("test-tenant", hierarchy)
	if err != nil {
		t.Fatalf("RotateTenantKey failed: %v", err)
	}
	if newHierarchy.KMSKeyID == hierarchy.KMSKeyID {
		t.Error("Rotated key should have different KMSKeyID")
	}
	t.Logf("Rotated key from %s to %s", hierarchy.KMSKeyID, newHierarchy.KMSKeyID)
	// Revoke the tenant's key material.
	if err := tkm.RevokeTenant(hierarchy); err != nil {
		t.Fatalf("RevokeTenant failed: %v", err)
	}
	t.Logf("Revoked tenant successfully")
}

View file

@ -0,0 +1,254 @@
package kms_test
import (
"bytes"
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/crypto/kms"
)
// TestDEKCache_PutAndGet verifies that a DEK stored in the cache can be
// read back intact for the same tenant/artifact/key triple.
func TestDEKCache_PutAndGet(t *testing.T) {
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()

	const (
		tenant   = "tenant-1"
		artifact = "artifact-1"
	)
	want := []byte("test-dek-data-12345678901234567890123456789012")

	// Store the DEK.
	if err := cache.Put(tenant, artifact, "kms-key-1", want); err != nil {
		t.Fatalf("Put failed: %v", err)
	}
	// Read it back with KMS reported as available.
	got, ok := cache.Get(tenant, artifact, "kms-key-1", false)
	if !ok {
		t.Fatal("Get returned false, expected true")
	}
	if !bytes.Equal(got, want) {
		t.Error("Retrieved DEK doesn't match original")
	}
}
// TestDEKCache_GetNonexistent verifies that looking up an entry that was
// never stored reports a cache miss.
func TestDEKCache_GetNonexistent(t *testing.T) {
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()

	if _, ok := cache.Get("nonexistent", "nonexistent", "kms-key-1", false); ok {
		t.Error("Get for non-existent key should return false")
	}
}
// TestDEKCache_TTLExpiry verifies the three-phase lifetime of a cached DEK:
// live (within TTL), grace (expired but still served while KMS is
// unavailable), and fully expired (never served).
//
// Fix over the previous version: the Put error is now checked instead of
// being silently discarded.
func TestDEKCache_TTLExpiry(t *testing.T) {
	// Very short TTL and grace window so the test completes quickly.
	config := kms.CacheConfig{
		TTL:         50 * time.Millisecond,
		MaxEntries:  100,
		GraceWindow: 100 * time.Millisecond,
	}
	cache := kms.NewDEKCache(config)
	defer cache.Clear()
	tenantID := "tenant-1"
	artifactID := "artifact-1"
	dek := []byte("test-dek-data-12345678901234567890123456789012")
	if err := cache.Put(tenantID, artifactID, "kms-key-1", dek); err != nil {
		t.Fatalf("Put failed: %v", err)
	}
	// Phase 1: available immediately after Put.
	_, ok := cache.Get(tenantID, artifactID, "kms-key-1", false)
	if !ok {
		t.Error("DEK should be available immediately after put")
	}
	// Let the TTL elapse.
	time.Sleep(60 * time.Millisecond)
	// Phase 2a: with KMS available, an expired entry is a miss.
	_, ok = cache.Get(tenantID, artifactID, "kms-key-1", false)
	if ok {
		t.Error("DEK should not be available after TTL expires")
	}
	// Phase 2b: with KMS unavailable, the grace window still serves it.
	_, ok = cache.Get(tenantID, artifactID, "kms-key-1", true)
	if !ok {
		t.Error("DEK should be available in grace window when KMS is unavailable")
	}
	// Let the grace window elapse too.
	time.Sleep(150 * time.Millisecond)
	// Phase 3: gone even when KMS is unavailable.
	_, ok = cache.Get(tenantID, artifactID, "kms-key-1", true)
	if ok {
		t.Error("DEK should not be available after grace window expires")
	}
}
// TestDEKCache_LRUeviction verifies that when the cache is at capacity the
// least-recently-used entry is evicted, and that a Get refreshes an entry's
// recency.
//
// Fix over the previous version: Put errors are now checked instead of
// being silently discarded.
func TestDEKCache_LRUeviction(t *testing.T) {
	config := kms.CacheConfig{
		TTL:         1 * time.Hour, // Long TTL so eviction is due to size
		MaxEntries:  3,
		GraceWindow: 1 * time.Hour,
	}
	cache := kms.NewDEKCache(config)
	defer cache.Clear()
	// Fill to capacity with artifacts "a", "b", "c".
	for i := 0; i < 3; i++ {
		dek := []byte("dek-data-12345678901234567890123456789012-" + string(rune('0'+i)))
		if err := cache.Put("tenant-1", string(rune('a'+i)), "kms-key-1", dek); err != nil {
			t.Fatalf("Put failed: %v", err)
		}
	}
	// Touch "a" so it becomes the most recently used entry.
	cache.Get("tenant-1", "a", "kms-key-1", false)
	// Insert a 4th entry; "b" is now the LRU entry and should be evicted.
	dek4 := []byte("dek-data-12345678901234567890123456789012-4")
	if err := cache.Put("tenant-1", "d", "kms-key-1", dek4); err != nil {
		t.Fatalf("Put failed: %v", err)
	}
	// "a" survived because it was accessed.
	_, ok := cache.Get("tenant-1", "a", "kms-key-1", false)
	if !ok {
		t.Error("Entry 'a' should still exist after eviction")
	}
	// "b" was evicted as the oldest unaccessed entry.
	_, ok = cache.Get("tenant-1", "b", "kms-key-1", false)
	if ok {
		t.Error("Entry 'b' should have been evicted")
	}
	// "c" and "d" remain.
	_, ok = cache.Get("tenant-1", "c", "kms-key-1", false)
	if !ok {
		t.Error("Entry 'c' should still exist")
	}
	_, ok = cache.Get("tenant-1", "d", "kms-key-1", false)
	if !ok {
		t.Error("Entry 'd' should exist")
	}
}
// TestDEKCache_Flush verifies that flushing one tenant removes only that
// tenant's entries and leaves other tenants untouched.
func TestDEKCache_Flush(t *testing.T) {
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()

	// Seed entries for two distinct tenants.
	cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1"))
	cache.Put("tenant-1", "artifact-2", "kms-key-1", []byte("dek-2"))
	cache.Put("tenant-2", "artifact-1", "kms-key-2", []byte("dek-3"))

	// Drop everything belonging to tenant-1.
	cache.Flush("tenant-1")

	// Post-flush expectations per entry.
	checks := []struct {
		tenant, artifact, key string
		wantOK                bool
		msg                   string
	}{
		{"tenant-1", "artifact-1", "kms-key-1", false, "tenant-1 artifact-1 should be flushed"},
		{"tenant-1", "artifact-2", "kms-key-1", false, "tenant-1 artifact-2 should be flushed"},
		{"tenant-2", "artifact-1", "kms-key-2", true, "tenant-2 artifact-1 should still exist"},
	}
	for _, c := range checks {
		if _, ok := cache.Get(c.tenant, c.artifact, c.key, false); ok != c.wantOK {
			t.Error(c.msg)
		}
	}
}
// TestDEKCache_Clear verifies that Clear drops every entry regardless of
// tenant.
func TestDEKCache_Clear(t *testing.T) {
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())

	// Seed one entry per tenant, then wipe the cache.
	cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1"))
	cache.Put("tenant-2", "artifact-1", "kms-key-2", []byte("dek-2"))
	cache.Clear()

	// Neither tenant's entry may survive.
	for _, c := range []struct{ tenant, key string }{
		{"tenant-1", "kms-key-1"},
		{"tenant-2", "kms-key-2"},
	} {
		if _, ok := cache.Get(c.tenant, "artifact-1", c.key, false); ok {
			t.Error("All entries should be cleared")
		}
	}
}
// TestDEKCache_Stats verifies that Stats echoes the configured limits and
// tracks the entry count as items are inserted.
func TestDEKCache_Stats(t *testing.T) {
	config := kms.DefaultCacheConfig()
	cache := kms.NewDEKCache(config)
	defer cache.Clear()

	// A fresh cache reports zero entries and mirrors its configuration.
	s := cache.Stats()
	if s.Size != 0 {
		t.Errorf("Initial size should be 0, got %d", s.Size)
	}
	if s.MaxSize != config.MaxEntries {
		t.Errorf("MaxSize should be %d, got %d", config.MaxEntries, s.MaxSize)
	}
	if s.TTL != config.TTL {
		t.Errorf("TTL should be %v, got %v", config.TTL, s.TTL)
	}
	if s.GraceWindow != config.GraceWindow {
		t.Errorf("GraceWindow should be %v, got %v", config.GraceWindow, s.GraceWindow)
	}

	// Size follows insertions.
	cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte("dek-1"))
	if s = cache.Stats(); s.Size != 1 {
		t.Errorf("Size should be 1 after put, got %d", s.Size)
	}
}
// TestDEKCache_EmptyDEK verifies that Put refuses a zero-length DEK.
func TestDEKCache_EmptyDEK(t *testing.T) {
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()

	// Storing an empty key must be rejected with an error.
	if err := cache.Put("tenant-1", "artifact-1", "kms-key-1", []byte{}); err == nil {
		t.Error("Should reject empty DEK")
	}
}
// TestDEKCache_Isolation verifies that two tenants sharing an artifact ID
// still get their own, separate DEKs back.
func TestDEKCache_Isolation(t *testing.T) {
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()

	// Same artifact ID stored under two different tenants and keys.
	cache.Put("tenant-1", "shared-artifact", "kms-key-1", []byte("dek-for-tenant-1"))
	cache.Put("tenant-2", "shared-artifact", "kms-key-2", []byte("dek-for-tenant-2"))

	// tenant-1 must see only its own DEK.
	if got, ok := cache.Get("tenant-1", "shared-artifact", "kms-key-1", false); !ok || string(got) != "dek-for-tenant-1" {
		t.Error("tenant-1 should get their own DEK")
	}
	// tenant-2 likewise.
	if got, ok := cache.Get("tenant-2", "shared-artifact", "kms-key-2", false); !ok || string(got) != "dek-for-tenant-2" {
		t.Error("tenant-2 should get their own DEK")
	}
}

View file

@ -0,0 +1,330 @@
package kms_test
import (
"bytes"
"context"
"encoding/binary"
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/api"
"github.com/jfraeys/fetch_ml/internal/crypto"
"github.com/jfraeys/fetch_ml/internal/crypto/kms"
)
// TestProtocolSerialization serializes one packet of each type and checks
// the leading type byte (and, for the success packet, the minimum length of
// 1 type byte + 8 timestamp bytes).
func TestProtocolSerialization(t *testing.T) {
	// assertType checks the first byte of a serialized packet.
	assertType := func(data []byte, want byte) {
		t.Helper()
		if len(data) < 1 || data[0] != want {
			t.Errorf("Expected packet type %d, got %d", want, data[0])
		}
	}

	// Success packet.
	data, err := api.NewSuccessPacket("Operation completed successfully").Serialize()
	if err != nil {
		t.Fatalf("Failed to serialize success packet: %v", err)
	}
	assertType(data, api.PacketTypeSuccess)
	// 1 type byte + 8 timestamp bytes at minimum.
	if len(data) < 9 {
		t.Errorf("Expected at least 9 bytes, got %d", len(data))
	}

	// Error packet.
	data, err = api.NewErrorPacket(api.ErrorCodeAuthenticationFailed, "Auth failed", "Invalid API key").Serialize()
	if err != nil {
		t.Fatalf("Failed to serialize error packet: %v", err)
	}
	assertType(data, api.PacketTypeError)

	// Progress packet.
	data, err = api.NewProgressPacket(api.ProgressTypePercentage, 75, 100, "Processing...").Serialize()
	if err != nil {
		t.Fatalf("Failed to serialize progress packet: %v", err)
	}
	assertType(data, api.PacketTypeProgress)

	// Status packet.
	data, err = api.NewStatusPacket(`{"workers":1,"queued":0}`).Serialize()
	if err != nil {
		t.Fatalf("Failed to serialize status packet: %v", err)
	}
	assertType(data, api.PacketTypeStatus)
}
// TestErrorMessageMapping verifies the human-readable message returned for
// each known error code.
func TestErrorMessageMapping(t *testing.T) {
	cases := []struct {
		code byte
		want string
	}{
		{api.ErrorCodeUnknownError, "Unknown error occurred"},
		{api.ErrorCodeAuthenticationFailed, "Authentication failed"},
		{api.ErrorCodeJobNotFound, "Job not found"},
		{api.ErrorCodeServerOverloaded, "Server is overloaded"},
	}
	for _, c := range cases {
		if got := api.GetErrorMessage(c.code); got != c.want {
			t.Errorf("Expected error message '%s' for code %d, got '%s'", c.want, c.code, got)
		}
	}
}
// TestLogLevelMapping verifies the name returned for each known log level.
func TestLogLevelMapping(t *testing.T) {
	cases := []struct {
		level byte
		want  string
	}{
		{api.LogLevelDebug, "DEBUG"},
		{api.LogLevelInfo, "INFO"},
		{api.LogLevelWarn, "WARN"},
		{api.LogLevelError, "ERROR"},
	}
	for _, c := range cases {
		if got := api.GetLogLevelName(c.level); got != c.want {
			t.Errorf("Expected log level '%s' for level %d, got '%s'", c.want, c.level, got)
		}
	}
}
// TestTimestampConsistency verifies that the timestamp embedded in a
// serialized packet (bytes 1-8, big-endian Unix seconds) falls between the
// times sampled just before and just after serialization.
func TestTimestampConsistency(t *testing.T) {
	before := time.Now().Unix()
	data, err := api.NewSuccessPacket("Test message").Serialize()
	if err != nil {
		t.Fatalf("Failed to serialize: %v", err)
	}
	after := time.Now().Unix()

	if len(data) < 9 {
		t.Fatalf("Packet too short: %d bytes", len(data))
	}
	// Decode the big-endian timestamp that follows the type byte.
	ts := binary.BigEndian.Uint64(data[1:9])
	if ts < uint64(before) || ts > uint64(after) {
		t.Errorf("Timestamp %d not in expected range [%d, %d]", ts, before, after)
	}
}
// TestKMSProtocol_EncryptDecrypt runs a full encrypt/decrypt round trip
// through the TenantKeyManager backed by the in-memory provider and checks
// the fields of the resulting encrypted artifact.
func TestKMSProtocol_EncryptDecrypt(t *testing.T) {
	provider := kms.NewMemoryProvider()
	defer provider.Close()
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()

	manager := crypto.NewTenantKeyManager(provider, cache, kms.Config{
		Provider: kms.ProviderTypeMemory,
		Cache:    kms.DefaultCacheConfig(),
	})

	// Provision a tenant to encrypt under.
	hierarchy, err := manager.ProvisionTenant("protocol-test-tenant")
	if err != nil {
		t.Fatalf("ProvisionTenant failed: %v", err)
	}

	// Encrypt a sample artifact payload.
	secret := []byte("sensitive model weights and training data")
	enc, err := manager.EncryptArtifact("protocol-test-tenant", "model-v1", hierarchy.KMSKeyID, secret)
	if err != nil {
		t.Fatalf("EncryptArtifact failed: %v", err)
	}

	// The encrypted artifact must carry ciphertext, a wrapped DEK, the KMS
	// key ID it was encrypted under, and the expected algorithm label.
	if enc.Ciphertext == "" {
		t.Error("Ciphertext should not be empty")
	}
	if enc.DEK == nil {
		t.Error("DEK should not be nil")
	}
	if enc.KMSKeyID != hierarchy.KMSKeyID {
		t.Error("KMSKeyID should match")
	}
	if enc.Algorithm != "AES-256-GCM" {
		t.Errorf("Algorithm should be AES-256-GCM, got %s", enc.Algorithm)
	}

	// Decrypt and confirm the round trip is lossless.
	dec, err := manager.DecryptArtifact(enc, hierarchy.KMSKeyID)
	if err != nil {
		t.Fatalf("DecryptArtifact failed: %v", err)
	}
	if !bytes.Equal(dec, secret) {
		t.Errorf("Decrypted data doesn't match: got %s, want %s", dec, secret)
	}
}
// TestKMSProtocol_MultiTenantIsolation verifies that an artifact encrypted
// under one tenant's key cannot be decrypted with another tenant's key,
// while the owner can still decrypt it.
func TestKMSProtocol_MultiTenantIsolation(t *testing.T) {
	provider := kms.NewMemoryProvider()
	defer provider.Close()
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()

	manager := crypto.NewTenantKeyManager(provider, cache, kms.Config{
		Provider: kms.ProviderTypeMemory,
		Cache:    kms.DefaultCacheConfig(),
	})

	// Two independent tenants.
	tenant1, err := manager.ProvisionTenant("tenant-1")
	if err != nil {
		t.Fatalf("Failed to provision tenant-1: %v", err)
	}
	tenant2, err := manager.ProvisionTenant("tenant-2")
	if err != nil {
		t.Fatalf("Failed to provision tenant-2: %v", err)
	}

	// Encrypt under tenant-1's key.
	secret := []byte("tenant-1 secret data")
	enc, err := manager.EncryptArtifact("tenant-1", "artifact-1", tenant1.KMSKeyID, secret)
	if err != nil {
		t.Fatalf("Encrypt failed: %v", err)
	}

	// Decrypting with tenant-2's key must fail.
	if _, err := manager.DecryptArtifact(enc, tenant2.KMSKeyID); err == nil {
		t.Error("Tenant-2 should not be able to decrypt tenant-1's data (expected error)")
	}

	// The owner can still recover the original payload.
	dec, err := manager.DecryptArtifact(enc, tenant1.KMSKeyID)
	if err != nil {
		t.Fatalf("Tenant-1 decrypt failed: %v", err)
	}
	if !bytes.Equal(dec, secret) {
		t.Error("Tenant-1 should decrypt their own data correctly")
	}
}
// TestKMSProtocol_CacheHit verifies that repeated decrypts of the same
// artifact succeed and that the DEK cache ends up populated.
//
// Fix over the previous version: the ProvisionTenant and EncryptArtifact
// errors are now checked instead of being discarded with the blank
// identifier, which could have produced a confusing failure further down.
func TestKMSProtocol_CacheHit(t *testing.T) {
	provider := kms.NewMemoryProvider()
	defer provider.Close()
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()
	config := kms.Config{
		Provider: kms.ProviderTypeMemory,
		Cache:    kms.DefaultCacheConfig(),
	}
	tkm := crypto.NewTenantKeyManager(provider, cache, config)
	hierarchy, err := tkm.ProvisionTenant("cache-test")
	if err != nil {
		t.Fatalf("ProvisionTenant failed: %v", err)
	}
	plaintext := []byte("test data for caching")
	// First encrypt populates the cache with the artifact's DEK.
	encrypted, err := tkm.EncryptArtifact("cache-test", "cached-artifact", hierarchy.KMSKeyID, plaintext)
	if err != nil {
		t.Fatalf("EncryptArtifact failed: %v", err)
	}
	// Repeated decrypts should be served from the cache.
	for i := 0; i < 3; i++ {
		decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID)
		if err != nil {
			t.Fatalf("Decrypt %d failed: %v", i, err)
		}
		if !bytes.Equal(decrypted, plaintext) {
			t.Errorf("Decrypt %d: data mismatch", i)
		}
	}
	// The cache must have at least one entry after the operations above.
	stats := cache.Stats()
	if stats.Size == 0 {
		t.Error("Cache should have entries after operations")
	}
}
// TestKMSProtocol_KeyRotation verifies that rotating a tenant key issues a
// new KMS key ID, flushes the cached DEKs, and that data can be encrypted
// and decrypted under the new key.
//
// Fix over the previous version: the ProvisionTenant and EncryptArtifact
// errors are now checked instead of being discarded with the blank
// identifier.
func TestKMSProtocol_KeyRotation(t *testing.T) {
	provider := kms.NewMemoryProvider()
	defer provider.Close()
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()
	config := kms.Config{
		Provider: kms.ProviderTypeMemory,
		Cache:    kms.DefaultCacheConfig(),
	}
	tkm := crypto.NewTenantKeyManager(provider, cache, config)
	hierarchy, err := tkm.ProvisionTenant("rotation-test")
	if err != nil {
		t.Fatalf("ProvisionTenant failed: %v", err)
	}
	oldKeyID := hierarchy.KMSKeyID
	// Rotation must produce a fresh KMS key.
	newHierarchy, err := tkm.RotateTenantKey("rotation-test", hierarchy)
	if err != nil {
		t.Fatalf("Key rotation failed: %v", err)
	}
	if newHierarchy.KMSKeyID == oldKeyID {
		t.Error("New key should have different ID after rotation")
	}
	// Rotation must invalidate the cached DEKs.
	stats := cache.Stats()
	if stats.Size != 0 {
		t.Error("Cache should be flushed after key rotation")
	}
	// The new key must support a full encrypt/decrypt round trip.
	plaintext2 := []byte("data encrypted with new key")
	encrypted2, err := tkm.EncryptArtifact("rotation-test", "post-rotation", newHierarchy.KMSKeyID, plaintext2)
	if err != nil {
		t.Fatalf("EncryptArtifact failed: %v", err)
	}
	decrypted2, err := tkm.DecryptArtifact(encrypted2, newHierarchy.KMSKeyID)
	if err != nil {
		t.Fatalf("Decrypt with new key failed: %v", err)
	}
	if !bytes.Equal(decrypted2, plaintext2) {
		t.Error("Data encrypted with new key should decrypt correctly")
	}
}
// TestKMSProvider_HealthCheck verifies that the in-memory provider always
// reports itself healthy.
func TestKMSProvider_HealthCheck(t *testing.T) {
	provider := kms.NewMemoryProvider()
	defer provider.Close()

	// The memory provider has no external dependency that could fail.
	err := provider.HealthCheck(context.Background())
	if err != nil {
		t.Errorf("Memory provider health check failed: %v", err)
	}
}