package kms_test import ( "bytes" "context" "encoding/binary" "testing" "time" "github.com/jfraeys/fetch_ml/internal/api" "github.com/jfraeys/fetch_ml/internal/crypto" "github.com/jfraeys/fetch_ml/internal/crypto/kms" kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config" ) func TestProtocolSerialization(t *testing.T) { // Test success packet successPacket := api.NewSuccessPacket("Operation completed successfully") data, err := successPacket.Serialize() if err != nil { t.Fatalf("Failed to serialize success packet: %v", err) } // Verify packet type if len(data) < 1 || data[0] != api.PacketTypeSuccess { t.Errorf("Expected packet type %d, got %d", api.PacketTypeSuccess, data[0]) } // Verify timestamp is present (9 bytes minimum: 1 type + 8 timestamp) if len(data) < 9 { t.Errorf("Expected at least 9 bytes, got %d", len(data)) } // Test error packet - uses string error code from errors package errorPacket := api.NewErrorPacket("AUTHENTICATION_FAILED", "Auth failed", "Invalid API key") data, err = errorPacket.Serialize() if err != nil { t.Fatalf("Failed to serialize error packet: %v", err) } if len(data) < 1 || data[0] != api.PacketTypeError { t.Errorf("Expected packet type %d, got %d", api.PacketTypeError, data[0]) } // Test progress packet progressPacket := api.NewProgressPacket(api.ProgressTypePercentage, 75, 100, "Processing...") data, err = progressPacket.Serialize() if err != nil { t.Fatalf("Failed to serialize progress packet: %v", err) } if len(data) < 1 || data[0] != api.PacketTypeProgress { t.Errorf("Expected packet type %d, got %d", api.PacketTypeProgress, data[0]) } // Test status packet statusPacket := api.NewStatusPacket(`{"workers":1,"queued":0}`) data, err = statusPacket.Serialize() if err != nil { t.Fatalf("Failed to serialize status packet: %v", err) } if len(data) < 1 || data[0] != api.PacketTypeStatus { t.Errorf("Expected packet type %d, got %d", api.PacketTypeStatus, data[0]) } } func TestByteCodeFromErrorCode(t *testing.T) { tests := map[string]byte{ "UNKNOWN_ERROR": api.ErrorCodeUnknownError, "AUTHENTICATION_FAILED": api.ErrorCodeAuthenticationFailed, "JOB_NOT_FOUND": api.ErrorCodeJobNotFound, "SERVER_OVERLOADED": api.ErrorCodeServerOverloaded, "INVALID_REQUEST": api.ErrorCodeInvalidRequest, "BAD_REQUEST": api.ErrorCodeInvalidRequest, "PERMISSION_DENIED": api.ErrorCodePermissionDenied, "FORBIDDEN": api.ErrorCodePermissionDenied, } for code, expectedByte := range tests { actual := api.ByteCodeFromErrorCode(code) if actual != expectedByte { t.Errorf("Expected byte %d for code '%s', got %d", expectedByte, code, actual) } } } func TestLogLevelMapping(t *testing.T) { tests := map[byte]string{ api.LogLevelDebug: "DEBUG", api.LogLevelInfo: "INFO", api.LogLevelWarn: "WARN", api.LogLevelError: "ERROR", } for level, expected := range tests { actual := api.GetLogLevelName(level) if actual != expected { t.Errorf("Expected log level '%s' for level %d, got '%s'", expected, level, actual) } } } func TestTimestampConsistency(t *testing.T) { before := time.Now().Unix() packet := api.NewSuccessPacket("Test message") data, err := packet.Serialize() if err != nil { t.Fatalf("Failed to serialize: %v", err) } after := time.Now().Unix() // Extract timestamp (bytes 1-8, big-endian) if len(data) < 9 { t.Fatalf("Packet too short: %d bytes", len(data)) } timestamp := binary.BigEndian.Uint64(data[1:9]) if timestamp < uint64(before) || timestamp > uint64(after) { t.Errorf("Timestamp %d not in expected range [%d, %d]", timestamp, before, after) } } // TestKMSProtocol_EncryptDecrypt tests the full KMS encryption/decryption protocol. func TestKMSProtocol_EncryptDecrypt(t *testing.T) { // Create memory provider for testing provider := kms.NewMemoryProvider() defer provider.Close() cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) defer cache.Clear() config := kmsconfig.Config{ Provider: kms.ProviderTypeMemory, Cache: kmsconfig.DefaultCacheConfig(), } tkm := crypto.NewTenantKeyManager(provider, cache, config, nil) // Provision tenant hierarchy, err := tkm.ProvisionTenant("protocol-test-tenant") if err != nil { t.Fatalf("ProvisionTenant failed: %v", err) } // Test data - simulate artifact data plaintext := []byte("sensitive model weights and training data") // Encrypt encrypted, err := tkm.EncryptArtifact("protocol-test-tenant", "model-v1", hierarchy.KMSKeyID, plaintext) if err != nil { t.Fatalf("EncryptArtifact failed: %v", err) } // Verify encrypted structure if encrypted.Ciphertext == "" { t.Error("Ciphertext should not be empty") } if encrypted.DEK == nil { t.Error("DEK should not be nil") } if encrypted.KMSKeyID != hierarchy.KMSKeyID { t.Error("KMSKeyID should match") } if encrypted.Algorithm != "AES-256-GCM" { t.Errorf("Algorithm should be AES-256-GCM, got %s", encrypted.Algorithm) } // Decrypt decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID) if err != nil { t.Fatalf("DecryptArtifact failed: %v", err) } // Verify round-trip if !bytes.Equal(decrypted, plaintext) { t.Errorf("Decrypted data doesn't match: got %s, want %s", decrypted, plaintext) } } // TestKMSProtocol_MultiTenantIsolation verifies tenants cannot decrypt each other's data. func TestKMSProtocol_MultiTenantIsolation(t *testing.T) { provider := kms.NewMemoryProvider() defer provider.Close() cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) defer cache.Clear() config := kmsconfig.Config{ Provider: kms.ProviderTypeMemory, Cache: kmsconfig.DefaultCacheConfig(), } tkm := crypto.NewTenantKeyManager(provider, cache, config, nil) // Provision two tenants tenant1, err := tkm.ProvisionTenant("tenant-1") if err != nil { t.Fatalf("Failed to provision tenant-1: %v", err) } tenant2, err := tkm.ProvisionTenant("tenant-2") if err != nil { t.Fatalf("Failed to provision tenant-2: %v", err) } // Encrypt data for tenant-1 plaintext := []byte("tenant-1 secret data") encrypted, err := tkm.EncryptArtifact("tenant-1", "artifact-1", tenant1.KMSKeyID, plaintext) if err != nil { t.Fatalf("Encrypt failed: %v", err) } // Attempt to decrypt with tenant-2's key - should fail _, err = tkm.DecryptArtifact(encrypted, tenant2.KMSKeyID) if err == nil { t.Error("Tenant-2 should not be able to decrypt tenant-1's data (expected error)") } // Tenant-1 should still be able to decrypt decrypted, err := tkm.DecryptArtifact(encrypted, tenant1.KMSKeyID) if err != nil { t.Fatalf("Tenant-1 decrypt failed: %v", err) } if !bytes.Equal(decrypted, plaintext) { t.Error("Tenant-1 should decrypt their own data correctly") } } // TestKMSProtocol_CacheHit verifies cached DEKs work correctly. func TestKMSProtocol_CacheHit(t *testing.T) { provider := kms.NewMemoryProvider() defer provider.Close() cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) defer cache.Clear() config := kmsconfig.Config{ Provider: kms.ProviderTypeMemory, Cache: kmsconfig.DefaultCacheConfig(), } tkm := crypto.NewTenantKeyManager(provider, cache, config, nil) hierarchy, _ := tkm.ProvisionTenant("cache-test") plaintext := []byte("test data for caching") // First encrypt encrypted, _ := tkm.EncryptArtifact("cache-test", "cached-artifact", hierarchy.KMSKeyID, plaintext) // Decrypt multiple times - should hit cache for i := 0; i < 3; i++ { decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID) if err != nil { t.Fatalf("Decrypt %d failed: %v", i, err) } if !bytes.Equal(decrypted, plaintext) { t.Errorf("Decrypt %d: data mismatch", i) } } // Verify cache has entries stats := cache.Stats() if stats.Size == 0 { t.Error("Cache should have entries after operations") } } // TestKMSProtocol_KeyRotation tests key rotation protocol. func TestKMSProtocol_KeyRotation(t *testing.T) { provider := kms.NewMemoryProvider() defer provider.Close() cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) defer cache.Clear() config := kmsconfig.Config{ Provider: kms.ProviderTypeMemory, Cache: kmsconfig.DefaultCacheConfig(), } tkm := crypto.NewTenantKeyManager(provider, cache, config, nil) // Provision tenant hierarchy, _ := tkm.ProvisionTenant("rotation-test") oldKeyID := hierarchy.KMSKeyID // Rotate key newHierarchy, err := tkm.RotateTenantKey("rotation-test", hierarchy) if err != nil { t.Fatalf("Key rotation failed: %v", err) } if newHierarchy.KMSKeyID == oldKeyID { t.Error("New key should have different ID after rotation") } // Cache should be flushed after rotation stats := cache.Stats() if stats.Size != 0 { t.Error("Cache should be flushed after key rotation") } // Encrypt with new key plaintext2 := []byte("data encrypted with new key") encrypted2, _ := tkm.EncryptArtifact("rotation-test", "post-rotation", newHierarchy.KMSKeyID, plaintext2) // Decrypt with new key decrypted2, err := tkm.DecryptArtifact(encrypted2, newHierarchy.KMSKeyID) if err != nil { t.Fatalf("Decrypt with new key failed: %v", err) } if !bytes.Equal(decrypted2, plaintext2) { t.Error("Data encrypted with new key should decrypt correctly") } } // TestKMSProvider_HealthCheck tests health check protocol. func TestKMSProvider_HealthCheck(t *testing.T) { provider := kms.NewMemoryProvider() defer provider.Close() ctx := context.Background() // Memory provider should always be healthy if err := provider.HealthCheck(ctx); err != nil { t.Errorf("Memory provider health check failed: %v", err) } }