- Update E2E tests for consolidated docker-compose.test.yml
- Remove references to obsolete logs-debug.yml
- Enhance test fixtures and utilities
- Improve integration test coverage for KMS, queue, scheduler
- Update unit tests for config constants and worker execution
- Modernize cleanup-status.sh with new Makefile targets
330 lines
9.1 KiB
Go
330 lines
9.1 KiB
Go
package kms_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/api"
|
|
"github.com/jfraeys/fetch_ml/internal/crypto"
|
|
"github.com/jfraeys/fetch_ml/internal/crypto/kms"
|
|
)
|
|
|
|
func TestProtocolSerialization(t *testing.T) {
|
|
// Test success packet
|
|
successPacket := api.NewSuccessPacket("Operation completed successfully")
|
|
data, err := successPacket.Serialize()
|
|
if err != nil {
|
|
t.Fatalf("Failed to serialize success packet: %v", err)
|
|
}
|
|
|
|
// Verify packet type
|
|
if len(data) < 1 || data[0] != api.PacketTypeSuccess {
|
|
t.Errorf("Expected packet type %d, got %d", api.PacketTypeSuccess, data[0])
|
|
}
|
|
|
|
// Verify timestamp is present (9 bytes minimum: 1 type + 8 timestamp)
|
|
if len(data) < 9 {
|
|
t.Errorf("Expected at least 9 bytes, got %d", len(data))
|
|
}
|
|
|
|
// Test error packet
|
|
errorPacket := api.NewErrorPacket(api.ErrorCodeAuthenticationFailed, "Auth failed", "Invalid API key")
|
|
data, err = errorPacket.Serialize()
|
|
if err != nil {
|
|
t.Fatalf("Failed to serialize error packet: %v", err)
|
|
}
|
|
|
|
if len(data) < 1 || data[0] != api.PacketTypeError {
|
|
t.Errorf("Expected packet type %d, got %d", api.PacketTypeError, data[0])
|
|
}
|
|
|
|
// Test progress packet
|
|
progressPacket := api.NewProgressPacket(api.ProgressTypePercentage, 75, 100, "Processing...")
|
|
data, err = progressPacket.Serialize()
|
|
if err != nil {
|
|
t.Fatalf("Failed to serialize progress packet: %v", err)
|
|
}
|
|
|
|
if len(data) < 1 || data[0] != api.PacketTypeProgress {
|
|
t.Errorf("Expected packet type %d, got %d", api.PacketTypeProgress, data[0])
|
|
}
|
|
|
|
// Test status packet
|
|
statusPacket := api.NewStatusPacket(`{"workers":1,"queued":0}`)
|
|
data, err = statusPacket.Serialize()
|
|
if err != nil {
|
|
t.Fatalf("Failed to serialize status packet: %v", err)
|
|
}
|
|
|
|
if len(data) < 1 || data[0] != api.PacketTypeStatus {
|
|
t.Errorf("Expected packet type %d, got %d", api.PacketTypeStatus, data[0])
|
|
}
|
|
}
|
|
|
|
func TestErrorMessageMapping(t *testing.T) {
|
|
tests := map[byte]string{
|
|
api.ErrorCodeUnknownError: "Unknown error occurred",
|
|
api.ErrorCodeAuthenticationFailed: "Authentication failed",
|
|
api.ErrorCodeJobNotFound: "Job not found",
|
|
api.ErrorCodeServerOverloaded: "Server is overloaded",
|
|
}
|
|
|
|
for code, expected := range tests {
|
|
actual := api.GetErrorMessage(code)
|
|
if actual != expected {
|
|
t.Errorf("Expected error message '%s' for code %d, got '%s'", expected, code, actual)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestLogLevelMapping(t *testing.T) {
|
|
tests := map[byte]string{
|
|
api.LogLevelDebug: "DEBUG",
|
|
api.LogLevelInfo: "INFO",
|
|
api.LogLevelWarn: "WARN",
|
|
api.LogLevelError: "ERROR",
|
|
}
|
|
|
|
for level, expected := range tests {
|
|
actual := api.GetLogLevelName(level)
|
|
if actual != expected {
|
|
t.Errorf("Expected log level '%s' for level %d, got '%s'", expected, level, actual)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestTimestampConsistency(t *testing.T) {
|
|
before := time.Now().Unix()
|
|
|
|
packet := api.NewSuccessPacket("Test message")
|
|
data, err := packet.Serialize()
|
|
if err != nil {
|
|
t.Fatalf("Failed to serialize: %v", err)
|
|
}
|
|
|
|
after := time.Now().Unix()
|
|
|
|
// Extract timestamp (bytes 1-8, big-endian)
|
|
if len(data) < 9 {
|
|
t.Fatalf("Packet too short: %d bytes", len(data))
|
|
}
|
|
|
|
timestamp := binary.BigEndian.Uint64(data[1:9])
|
|
|
|
if timestamp < uint64(before) || timestamp > uint64(after) {
|
|
t.Errorf("Timestamp %d not in expected range [%d, %d]", timestamp, before, after)
|
|
}
|
|
}
|
|
|
|
// TestKMSProtocol_EncryptDecrypt tests the full KMS encryption/decryption protocol.
|
|
func TestKMSProtocol_EncryptDecrypt(t *testing.T) {
|
|
// Create memory provider for testing
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
cache := kms.NewDEKCache(kms.DefaultCacheConfig())
|
|
defer cache.Clear()
|
|
|
|
config := kms.Config{
|
|
Provider: kms.ProviderTypeMemory,
|
|
Cache: kms.DefaultCacheConfig(),
|
|
}
|
|
|
|
tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)
|
|
|
|
// Provision tenant
|
|
hierarchy, err := tkm.ProvisionTenant("protocol-test-tenant")
|
|
if err != nil {
|
|
t.Fatalf("ProvisionTenant failed: %v", err)
|
|
}
|
|
|
|
// Test data - simulate artifact data
|
|
plaintext := []byte("sensitive model weights and training data")
|
|
|
|
// Encrypt
|
|
encrypted, err := tkm.EncryptArtifact("protocol-test-tenant", "model-v1", hierarchy.KMSKeyID, plaintext)
|
|
if err != nil {
|
|
t.Fatalf("EncryptArtifact failed: %v", err)
|
|
}
|
|
|
|
// Verify encrypted structure
|
|
if encrypted.Ciphertext == "" {
|
|
t.Error("Ciphertext should not be empty")
|
|
}
|
|
if encrypted.DEK == nil {
|
|
t.Error("DEK should not be nil")
|
|
}
|
|
if encrypted.KMSKeyID != hierarchy.KMSKeyID {
|
|
t.Error("KMSKeyID should match")
|
|
}
|
|
if encrypted.Algorithm != "AES-256-GCM" {
|
|
t.Errorf("Algorithm should be AES-256-GCM, got %s", encrypted.Algorithm)
|
|
}
|
|
|
|
// Decrypt
|
|
decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID)
|
|
if err != nil {
|
|
t.Fatalf("DecryptArtifact failed: %v", err)
|
|
}
|
|
|
|
// Verify round-trip
|
|
if !bytes.Equal(decrypted, plaintext) {
|
|
t.Errorf("Decrypted data doesn't match: got %s, want %s", decrypted, plaintext)
|
|
}
|
|
}
|
|
|
|
// TestKMSProtocol_MultiTenantIsolation verifies tenants cannot decrypt each other's data.
|
|
func TestKMSProtocol_MultiTenantIsolation(t *testing.T) {
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
cache := kms.NewDEKCache(kms.DefaultCacheConfig())
|
|
defer cache.Clear()
|
|
|
|
config := kms.Config{
|
|
Provider: kms.ProviderTypeMemory,
|
|
Cache: kms.DefaultCacheConfig(),
|
|
}
|
|
|
|
tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)
|
|
|
|
// Provision two tenants
|
|
tenant1, err := tkm.ProvisionTenant("tenant-1")
|
|
if err != nil {
|
|
t.Fatalf("Failed to provision tenant-1: %v", err)
|
|
}
|
|
|
|
tenant2, err := tkm.ProvisionTenant("tenant-2")
|
|
if err != nil {
|
|
t.Fatalf("Failed to provision tenant-2: %v", err)
|
|
}
|
|
|
|
// Encrypt data for tenant-1
|
|
plaintext := []byte("tenant-1 secret data")
|
|
encrypted, err := tkm.EncryptArtifact("tenant-1", "artifact-1", tenant1.KMSKeyID, plaintext)
|
|
if err != nil {
|
|
t.Fatalf("Encrypt failed: %v", err)
|
|
}
|
|
|
|
// Attempt to decrypt with tenant-2's key - should fail
|
|
_, err = tkm.DecryptArtifact(encrypted, tenant2.KMSKeyID)
|
|
if err == nil {
|
|
t.Error("Tenant-2 should not be able to decrypt tenant-1's data (expected error)")
|
|
}
|
|
|
|
// Tenant-1 should still be able to decrypt
|
|
decrypted, err := tkm.DecryptArtifact(encrypted, tenant1.KMSKeyID)
|
|
if err != nil {
|
|
t.Fatalf("Tenant-1 decrypt failed: %v", err)
|
|
}
|
|
|
|
if !bytes.Equal(decrypted, plaintext) {
|
|
t.Error("Tenant-1 should decrypt their own data correctly")
|
|
}
|
|
}
|
|
|
|
// TestKMSProtocol_CacheHit verifies cached DEKs work correctly.
|
|
func TestKMSProtocol_CacheHit(t *testing.T) {
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
cache := kms.NewDEKCache(kms.DefaultCacheConfig())
|
|
defer cache.Clear()
|
|
|
|
config := kms.Config{
|
|
Provider: kms.ProviderTypeMemory,
|
|
Cache: kms.DefaultCacheConfig(),
|
|
}
|
|
|
|
tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)
|
|
|
|
hierarchy, _ := tkm.ProvisionTenant("cache-test")
|
|
|
|
plaintext := []byte("test data for caching")
|
|
|
|
// First encrypt
|
|
encrypted, _ := tkm.EncryptArtifact("cache-test", "cached-artifact", hierarchy.KMSKeyID, plaintext)
|
|
|
|
// Decrypt multiple times - should hit cache
|
|
for i := 0; i < 3; i++ {
|
|
decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID)
|
|
if err != nil {
|
|
t.Fatalf("Decrypt %d failed: %v", i, err)
|
|
}
|
|
if !bytes.Equal(decrypted, plaintext) {
|
|
t.Errorf("Decrypt %d: data mismatch", i)
|
|
}
|
|
}
|
|
|
|
// Verify cache has entries
|
|
stats := cache.Stats()
|
|
if stats.Size == 0 {
|
|
t.Error("Cache should have entries after operations")
|
|
}
|
|
}
|
|
|
|
// TestKMSProtocol_KeyRotation tests key rotation protocol.
|
|
func TestKMSProtocol_KeyRotation(t *testing.T) {
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
cache := kms.NewDEKCache(kms.DefaultCacheConfig())
|
|
defer cache.Clear()
|
|
|
|
config := kms.Config{
|
|
Provider: kms.ProviderTypeMemory,
|
|
Cache: kms.DefaultCacheConfig(),
|
|
}
|
|
|
|
tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)
|
|
|
|
// Provision tenant
|
|
hierarchy, _ := tkm.ProvisionTenant("rotation-test")
|
|
oldKeyID := hierarchy.KMSKeyID
|
|
|
|
// Rotate key
|
|
newHierarchy, err := tkm.RotateTenantKey("rotation-test", hierarchy)
|
|
if err != nil {
|
|
t.Fatalf("Key rotation failed: %v", err)
|
|
}
|
|
|
|
if newHierarchy.KMSKeyID == oldKeyID {
|
|
t.Error("New key should have different ID after rotation")
|
|
}
|
|
|
|
// Cache should be flushed after rotation
|
|
stats := cache.Stats()
|
|
if stats.Size != 0 {
|
|
t.Error("Cache should be flushed after key rotation")
|
|
}
|
|
|
|
// Encrypt with new key
|
|
plaintext2 := []byte("data encrypted with new key")
|
|
encrypted2, _ := tkm.EncryptArtifact("rotation-test", "post-rotation", newHierarchy.KMSKeyID, plaintext2)
|
|
|
|
// Decrypt with new key
|
|
decrypted2, err := tkm.DecryptArtifact(encrypted2, newHierarchy.KMSKeyID)
|
|
if err != nil {
|
|
t.Fatalf("Decrypt with new key failed: %v", err)
|
|
}
|
|
|
|
if !bytes.Equal(decrypted2, plaintext2) {
|
|
t.Error("Data encrypted with new key should decrypt correctly")
|
|
}
|
|
}
|
|
|
|
// TestKMSProvider_HealthCheck tests health check protocol.
|
|
func TestKMSProvider_HealthCheck(t *testing.T) {
|
|
provider := kms.NewMemoryProvider()
|
|
defer provider.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
// Memory provider should always be healthy
|
|
if err := provider.HealthCheck(ctx); err != nil {
|
|
t.Errorf("Memory provider health check failed: %v", err)
|
|
}
|
|
}
|