fetch_ml/tests/integration/kms_integration_test.go
Jeremie Fraeys 5f53104fcd
test: modernize test suite for streamlined infrastructure
- Update E2E tests for consolidated docker-compose.test.yml
- Remove references to obsolete logs-debug.yml
- Enhance test fixtures and utilities
- Improve integration test coverage for KMS, queue, scheduler
- Update unit tests for config constants and worker execution
- Modernize cleanup-status.sh with new Makefile targets
2026-03-04 13:24:24 -05:00

297 lines
8.1 KiB
Go

package tests_test
import (
"context"
"os"
"os/exec"
"testing"
"time"
"github.com/jfraeys/fetch_ml/internal/crypto"
"github.com/jfraeys/fetch_ml/internal/crypto/kms"
"github.com/jfraeys/fetch_ml/internal/crypto/kms/providers"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
// TestVaultProvider_Integration exercises the Vault KMS provider against a
// hashicorp/vault:1.15 container started through testcontainers.
//
// The test is skipped in -short mode and whenever Docker does not look usable
// (binary missing, "docker ps" failing, or no socket found while DOCKER_HOST is
// unset). Health-check and key-creation failures are only logged: this test
// does not configure the Transit engine on the container, so those calls are
// treated as best-effort.
func TestVaultProvider_Integration(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test in short mode")
	}

	// Skip if Docker is not available.
	if _, err := exec.LookPath("docker"); err != nil {
		t.Skip("Docker not available, skipping container-based test")
	}
	if err := exec.Command("docker", "ps").Run(); err != nil {
		t.Skip("Docker daemon not running, skipping container-based test")
	}

	// Testcontainers requires Docker socket access. When DOCKER_HOST is unset,
	// probe the default socket locations and skip if none of them exists.
	if os.Getenv("DOCKER_HOST") == "" {
		dockerSocketPaths := []string{
			"/var/run/docker.sock",
			"/run/docker.sock",
		}
		// Only probe the per-user socket when HOME is known; with an empty
		// HOME the candidate would collapse to a path directly under "/".
		if home := os.Getenv("HOME"); home != "" {
			dockerSocketPaths = append(dockerSocketPaths, home+"/.docker/run/docker.sock")
		}

		socketFound := false
		for _, path := range dockerSocketPaths {
			if _, err := os.Stat(path); err == nil {
				socketFound = true
				break
			}
		}
		if !socketFound {
			t.Skip("Docker socket not found, skipping container-based test")
		}
	}

	ctx := context.Background()

	// Start Vault container and wait for its startup log line.
	req := testcontainers.ContainerRequest{
		Image:        "hashicorp/vault:1.15",
		ExposedPorts: []string{"8200/tcp"},
		Env: map[string]string{
			"VAULT_DEV_ROOT_TOKEN_ID": "test-token",
			"VAULT_ADDR":              "http://0.0.0.0:8200",
		},
		WaitingFor: wait.ForLog("Vault server started").WithStartupTimeout(30 * time.Second),
	}
	vaultC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
		ContainerRequest: req,
		Started:          true,
	})
	if err != nil {
		t.Fatalf("Failed to start Vault container: %v", err)
	}
	// Surface teardown problems instead of dropping them; a container that
	// fails to terminate is worth seeing in the test log.
	defer func() {
		if err := vaultC.Terminate(ctx); err != nil {
			t.Logf("Failed to terminate Vault container: %v", err)
		}
	}()

	// Get container address.
	endpoint, err := vaultC.Endpoint(ctx, "http")
	if err != nil {
		t.Fatalf("Failed to get Vault endpoint: %v", err)
	}

	// Create provider config; the token matches VAULT_DEV_ROOT_TOKEN_ID above.
	config := kms.VaultConfig{
		Address:      endpoint,
		AuthMethod:   "token",
		Token:        "test-token",
		TransitMount: "transit",
		KeyPrefix:    "test-tenant",
		Timeout:      10 * time.Second,
	}

	// Create provider.
	provider, err := providers.NewVaultProvider(config)
	if err != nil {
		t.Fatalf("Failed to create Vault provider: %v", err)
	}
	defer provider.Close()

	// Health check is best-effort: log, do not fail.
	if err := provider.HealthCheck(ctx); err != nil {
		t.Logf("Vault health check warning (may need Transit setup): %v", err)
	}

	// Key creation is best-effort as well.
	keyID, err := provider.CreateKey(ctx, "integration-test")
	if err != nil {
		t.Logf("Key creation may need manual Transit setup: %v", err)
	} else {
		t.Logf("Created key: %s", keyID)
	}
}
// TestAWSKMSProvider_Integration exercises the AWS KMS provider against a
// LocalStack container (KMS service only) started through testcontainers.
//
// The test is skipped in -short mode and whenever Docker does not look usable
// (binary missing, "docker ps" failing, or no socket found while DOCKER_HOST is
// unset). Health-check and key-creation failures are only logged, not treated
// as test failures.
func TestAWSKMSProvider_Integration(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test in short mode")
	}

	// Skip if Docker is not available.
	if _, err := exec.LookPath("docker"); err != nil {
		t.Skip("Docker not available, skipping container-based test")
	}
	if err := exec.Command("docker", "ps").Run(); err != nil {
		t.Skip("Docker daemon not running, skipping container-based test")
	}

	// Testcontainers requires Docker socket access. When DOCKER_HOST is unset,
	// probe the default socket locations and skip if none of them exists.
	if os.Getenv("DOCKER_HOST") == "" {
		dockerSocketPaths := []string{
			"/var/run/docker.sock",
			"/run/docker.sock",
		}
		// Only probe the per-user socket when HOME is known; with an empty
		// HOME the candidate would collapse to a path directly under "/".
		if home := os.Getenv("HOME"); home != "" {
			dockerSocketPaths = append(dockerSocketPaths, home+"/.docker/run/docker.sock")
		}

		socketFound := false
		for _, path := range dockerSocketPaths {
			if _, err := os.Stat(path); err == nil {
				socketFound = true
				break
			}
		}
		if !socketFound {
			t.Skip("Docker socket not found, skipping container-based test")
		}
	}

	ctx := context.Background()

	// Start LocalStack container with KMS and wait for its readiness log line.
	// NOTE(review): the image tag is ":latest", so runs are not pinned to a
	// LocalStack version — consider pinning.
	req := testcontainers.ContainerRequest{
		Image:        "localstack/localstack:latest",
		ExposedPorts: []string{"4566/tcp"},
		Env: map[string]string{
			"SERVICES":           "kms",
			"DEFAULT_REGION":     "us-east-1",
			"AWS_DEFAULT_REGION": "us-east-1",
		},
		WaitingFor: wait.ForLog("Ready").WithStartupTimeout(60 * time.Second),
	}
	localstackC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
		ContainerRequest: req,
		Started:          true,
	})
	if err != nil {
		t.Fatalf("Failed to start LocalStack container: %v", err)
	}
	// Surface teardown problems instead of dropping them.
	defer func() {
		if err := localstackC.Terminate(ctx); err != nil {
			t.Logf("Failed to terminate LocalStack container: %v", err)
		}
	}()

	// Get container address.
	endpoint, err := localstackC.Endpoint(ctx, "http")
	if err != nil {
		t.Fatalf("Failed to get LocalStack endpoint: %v", err)
	}

	// LocalStack KMS requires credentials even though they're not validated.
	// Set the dummy credentials BEFORE the provider is constructed: if the
	// provider resolves environment credentials at construction time
	// (presumably via the AWS SDK's default config loading — verify against
	// NewAWSProvider), values set afterwards would never be seen.
	t.Setenv("AWS_ACCESS_KEY_ID", "test")
	t.Setenv("AWS_SECRET_ACCESS_KEY", "test")

	// Create provider config pointing at the LocalStack endpoint.
	config := kms.AWSConfig{
		Region:         "us-east-1",
		KeyAliasPrefix: "alias/test-fetchml",
		Endpoint:       endpoint,
	}

	// Create provider.
	provider, err := providers.NewAWSProvider(config)
	if err != nil {
		t.Fatalf("Failed to create AWS provider: %v", err)
	}
	defer provider.Close()

	// Health check is best-effort: log, do not fail.
	if err := provider.HealthCheck(ctx); err != nil {
		t.Logf("AWS health check (may need retry): %v", err)
	}

	// Key creation is best-effort as well.
	keyID, err := provider.CreateKey(ctx, "integration-test")
	if err != nil {
		t.Logf("Key creation may need retry: %v", err)
	} else {
		t.Logf("Created key: %s", keyID)
	}
}
// TestTenantKeyManager_WithMemoryProvider drives a TenantKeyManager backed by
// the in-memory KMS provider through one tenant lifecycle: provision, encrypt
// an artifact, decrypt it, check that the DEK cache was populated, rotate the
// tenant key, and finally revoke the tenant. It needs no external services.
func TestTenantKeyManager_WithMemoryProvider(t *testing.T) {
	// Create memory provider for testing.
	provider := kms.NewMemoryProvider()
	defer provider.Close()

	// Create DEK cache.
	cache := kms.NewDEKCache(kms.DefaultCacheConfig())
	defer cache.Clear()

	// Create config.
	config := kms.Config{
		Provider: kms.ProviderTypeMemory,
		Cache:    kms.DefaultCacheConfig(),
	}

	// Create TenantKeyManager. None of the manager methods used below take a
	// context, so no context is constructed here.
	tkm := crypto.NewTenantKeyManager(provider, cache, config, nil)

	// Test provisioning.
	hierarchy, err := tkm.ProvisionTenant("test-tenant")
	if err != nil {
		t.Fatalf("ProvisionTenant failed: %v", err)
	}
	if hierarchy.TenantID != "test-tenant" {
		t.Errorf("Expected tenant ID 'test-tenant', got '%s'", hierarchy.TenantID)
	}
	if hierarchy.KMSKeyID == "" {
		t.Error("KMSKeyID should not be empty")
	}
	t.Logf("Provisioned tenant with KMSKeyID: %s", hierarchy.KMSKeyID)

	// Test encryption.
	plaintext := []byte("sensitive data that needs encryption")
	encrypted, err := tkm.EncryptArtifact("test-tenant", "artifact-1", hierarchy.KMSKeyID, plaintext)
	if err != nil {
		t.Fatalf("EncryptArtifact failed: %v", err)
	}
	if encrypted.Ciphertext == "" {
		t.Error("Ciphertext should not be empty")
	}
	if encrypted.KMSKeyID != hierarchy.KMSKeyID {
		t.Error("EncryptedArtifact should store KMSKeyID")
	}
	t.Logf("Encrypted data successfully")

	// Test decryption: the round trip must reproduce the plaintext.
	decrypted, err := tkm.DecryptArtifact(encrypted, hierarchy.KMSKeyID)
	if err != nil {
		t.Fatalf("DecryptArtifact failed: %v", err)
	}
	if string(decrypted) != string(plaintext) {
		t.Errorf("Decrypted data doesn't match: got %s, want %s", decrypted, plaintext)
	}
	t.Logf("Decrypted data successfully")

	// Test that cache is working.
	stats := cache.Stats()
	if stats.Size == 0 {
		t.Error("Cache should have entries after encrypt/decrypt")
	}
	t.Logf("Cache stats: %+v", stats)

	// Test key rotation: the rotated hierarchy must carry a new KMS key ID.
	newHierarchy, err := tkm.RotateTenantKey("test-tenant", hierarchy)
	if err != nil {
		t.Fatalf("RotateTenantKey failed: %v", err)
	}
	if newHierarchy.KMSKeyID == hierarchy.KMSKeyID {
		t.Error("Rotated key should have different KMSKeyID")
	}
	t.Logf("Rotated key from %s to %s", hierarchy.KMSKeyID, newHierarchy.KMSKeyID)

	// Test tenant revocation.
	// NOTE(review): this revokes the pre-rotation hierarchy, not newHierarchy —
	// confirm that revoking the old key is the intended scenario.
	if err := tkm.RevokeTenant(hierarchy); err != nil {
		t.Fatalf("RevokeTenant failed: %v", err)
	}
	t.Logf("Revoked tenant successfully")
}