diff --git a/internal/crypto/tenant_keys_test.go b/internal/crypto/tenant_keys_test.go new file mode 100644 index 0000000..3804f7c --- /dev/null +++ b/internal/crypto/tenant_keys_test.go @@ -0,0 +1,303 @@ +package crypto_test + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/crypto" + "github.com/jfraeys/fetch_ml/internal/crypto/kms" + kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupTenantKeyManager creates a test TenantKeyManager with memory provider +func setupTenantKeyManager(t *testing.T) *crypto.TenantKeyManager { + t.Helper() + return crypto.NewTestTenantKeyManager(nil) +} + +// TestNewTenantKeyManager tests the constructor +func TestNewTenantKeyManager(t *testing.T) { + t.Parallel() + + provider := kms.NewMemoryProvider() + cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig()) + config := kmsconfig.Config{Provider: kmsconfig.ProviderTypeMemory} + + km := crypto.NewTenantKeyManager(provider, cache, config, nil) + require.NotNil(t, km) +} + +// TestNewTestTenantKeyManager tests the test constructor +func TestNewTestTenantKeyManager(t *testing.T) { + t.Parallel() + + km := crypto.NewTestTenantKeyManager(nil) + require.NotNil(t, km) +} + +// TestProvisionTenant tests creating tenant root keys +func TestProvisionTenant(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + cases := []struct { + name string + tenantID string + wantErr bool + }{ + {"valid tenant", "tenant-1", false}, + {"another tenant", "tenant-2", false}, + {"empty tenant", "", true}, + {"whitespace tenant", " ", true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + hierarchy, err := km.ProvisionTenant(tc.tenantID) + if tc.wantErr { + require.Error(t, err) + assert.Nil(t, hierarchy) + return + } + require.NoError(t, err) + assert.NotNil(t, hierarchy) + assert.Equal(t, tc.tenantID, hierarchy.TenantID) + assert.NotEmpty(t, hierarchy.RootKeyID) + assert.NotEmpty(t, hierarchy.KMSKeyID) + assert.Equal(t, "AES-256-GCM", hierarchy.Algorithm) + assert.False(t, hierarchy.CreatedAt.IsZero()) + }) + } +} + +// TestRotateTenantKey tests key rotation +func TestRotateTenantKey(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + // Provision initial key + initial, err := km.ProvisionTenant("rotate-tenant") + require.NoError(t, err) + initialKeyID := initial.KMSKeyID + + // Rotate key + rotated, err := km.RotateTenantKey("rotate-tenant", initial) + require.NoError(t, err) + assert.NotNil(t, rotated) + assert.Equal(t, "rotate-tenant", rotated.TenantID) + assert.NotEqual(t, initialKeyID, rotated.KMSKeyID) + assert.NotEqual(t, initial.RootKeyID, rotated.RootKeyID) +} + +// TestRevokeTenant tests tenant key revocation +func TestRevokeTenant(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + // Provision a tenant + hierarchy, err := km.ProvisionTenant("revoke-tenant") + require.NoError(t, err) + + // Revoke tenant + err = km.RevokeTenant(hierarchy) + require.NoError(t, err) +} + +// TestGenerateDataEncryptionKey tests DEK generation +func TestGenerateDataEncryptionKey(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + // Provision tenant first + hierarchy, err := km.ProvisionTenant("dek-tenant") + require.NoError(t, err) + + // Generate DEK + wrappedDEK, err := km.GenerateDataEncryptionKey("dek-tenant", "artifact-1", hierarchy.KMSKeyID) + require.NoError(t, err) + assert.NotNil(t, wrappedDEK) + assert.Equal(t, "dek-tenant", wrappedDEK.TenantID) + assert.Equal(t, "artifact-1", wrappedDEK.ArtifactID) + assert.NotEmpty(t, wrappedDEK.WrappedKey) + assert.Equal(t, "AES-256-GCM", wrappedDEK.Algorithm) + assert.False(t, wrappedDEK.CreatedAt.IsZero()) +} + +// TestUnwrapDataEncryptionKey tests DEK unwrapping +func TestUnwrapDataEncryptionKey(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + // Provision tenant and generate DEK + hierarchy, err := km.ProvisionTenant("unwrap-tenant") + require.NoError(t, err) + + wrappedDEK, err := km.GenerateDataEncryptionKey("unwrap-tenant", "artifact-unwrap", hierarchy.KMSKeyID) + require.NoError(t, err) + + // Unwrap DEK + dek, err := km.UnwrapDataEncryptionKey(wrappedDEK, hierarchy.KMSKeyID) + require.NoError(t, err) + assert.NotNil(t, dek) + assert.Len(t, dek, 32) // AES-256 key +} + +// TestEncryptArtifact tests artifact encryption +func TestEncryptArtifact(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + // Provision tenant + hierarchy, err := km.ProvisionTenant("encrypt-tenant") + require.NoError(t, err) + + // Encrypt data + plaintext := []byte("sensitive data for encryption test") + encrypted, err := km.EncryptArtifact("encrypt-tenant", "artifact-enc", hierarchy.KMSKeyID, plaintext) + require.NoError(t, err) + assert.NotNil(t, encrypted) + assert.NotEmpty(t, encrypted.Ciphertext) + assert.NotNil(t, encrypted.DEK) + assert.Equal(t, hierarchy.KMSKeyID, encrypted.KMSKeyID) + assert.Equal(t, "AES-256-GCM", encrypted.Algorithm) +} + +// TestDecryptArtifact tests artifact decryption +func TestDecryptArtifact(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + // Provision tenant + hierarchy, err := km.ProvisionTenant("decrypt-tenant") + require.NoError(t, err) + + // Encrypt data + plaintext := []byte("secret message for decryption test") + encrypted, err := km.EncryptArtifact("decrypt-tenant", "artifact-dec", hierarchy.KMSKeyID, plaintext) + require.NoError(t, err) + + // Decrypt data + decrypted, err := km.DecryptArtifact(encrypted, hierarchy.KMSKeyID) + require.NoError(t, err) + assert.Equal(t, plaintext, decrypted) +} + +// TestEncryptDecryptRoundtrip tests full encryption/decryption cycle +func TestEncryptDecryptRoundtrip(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + // Provision tenant + hierarchy, err := km.ProvisionTenant("roundtrip-tenant") + require.NoError(t, err) + + // Test various payloads + payloads := [][]byte{ + []byte("short"), + []byte("this is a longer message with more content"), + []byte{}, + []byte{0x00, 0x01, 0x02, 0x03, 0xFF}, + } + + for i, payload := range payloads { + artifactID := "artifact-rt-" + string(rune('a'+i)) + encrypted, err := km.EncryptArtifact("roundtrip-tenant", artifactID, hierarchy.KMSKeyID, payload) + require.NoError(t, err) + + decrypted, err := km.DecryptArtifact(encrypted, hierarchy.KMSKeyID) + require.NoError(t, err) + // Handle nil vs empty slice comparison + if len(payload) == 0 { + assert.Empty(t, decrypted, "Payload %d should decrypt to empty", i) + } else { + assert.Equal(t, payload, decrypted, "Payload %d should decrypt correctly", i) + } + } +} + +// TestKeyHierarchyStructure tests KeyHierarchy fields +func TestKeyHierarchyStructure(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + hierarchy, err := km.ProvisionTenant("struct-tenant") + require.NoError(t, err) + + // Verify structure + assert.NotEmpty(t, hierarchy.TenantID) + assert.NotEmpty(t, hierarchy.RootKeyID) + assert.NotEmpty(t, hierarchy.KMSKeyID) + assert.NotEmpty(t, hierarchy.Algorithm) + assert.False(t, hierarchy.CreatedAt.IsZero()) +} + +// TestWrappedDEKStructure tests WrappedDEK fields +func TestWrappedDEKStructure(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + hierarchy, err := km.ProvisionTenant("wrapped-dek-tenant") + require.NoError(t, err) + + wrappedDEK, err := km.GenerateDataEncryptionKey("wrapped-dek-tenant", "artifact-wrapped", hierarchy.KMSKeyID) + require.NoError(t, err) + + assert.NotEmpty(t, wrappedDEK.TenantID) + assert.NotEmpty(t, wrappedDEK.ArtifactID) + assert.NotEmpty(t, wrappedDEK.WrappedKey) + assert.NotEmpty(t, wrappedDEK.Algorithm) + assert.False(t, wrappedDEK.CreatedAt.IsZero()) +} + +// TestEncryptedArtifactStructure tests EncryptedArtifact fields +func TestEncryptedArtifactStructure(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + hierarchy, err := km.ProvisionTenant("enc-art-tenant") + require.NoError(t, err) + + plaintext := []byte("test data") + encrypted, err := km.EncryptArtifact("enc-art-tenant", "artifact-struct", hierarchy.KMSKeyID, plaintext) + require.NoError(t, err) + + assert.NotEmpty(t, encrypted.Ciphertext) + assert.NotNil(t, encrypted.DEK) + assert.NotEmpty(t, encrypted.KMSKeyID) + assert.NotEmpty(t, encrypted.Algorithm) +} + +// TestMultipleTenantsIsolation tests that tenants are properly isolated +func TestMultipleTenantsIsolation(t *testing.T) { + t.Parallel() + + km := setupTenantKeyManager(t) + + // Provision two tenants + tenant1, err := km.ProvisionTenant("isolation-tenant-1") + require.NoError(t, err) + + tenant2, err := km.ProvisionTenant("isolation-tenant-2") + require.NoError(t, err) + + // Encrypt with tenant 1 + plaintext := []byte("cross-tenant data") + encrypted, err := km.EncryptArtifact("isolation-tenant-1", "cross-artifact", tenant1.KMSKeyID, plaintext) + require.NoError(t, err) + + // Try to decrypt with tenant 2's key (should fail) + _, err = km.DecryptArtifact(encrypted, tenant2.KMSKeyID) + assert.Error(t, err) +} diff --git a/internal/security/monitor_test.go b/internal/security/monitor_test.go new file mode 100644 index 0000000..18e6165 --- /dev/null +++ b/internal/security/monitor_test.go @@ -0,0 +1,298 @@ +package security_test + +import ( + "sync" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/security" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestNewSlidingWindow tests sliding window creation +func TestNewSlidingWindow(t *testing.T) { + t.Parallel() + + window := security.NewSlidingWindow(5 * time.Minute) + require.NotNil(t, window) + assert.Equal(t, 0, window.Count()) +} + +// TestSlidingWindowAddAndCount tests adding events and counting +func TestSlidingWindowAddAndCount(t *testing.T) { + t.Parallel() + + window := security.NewSlidingWindow(1 * time.Second) + + // Add events + window.Add(time.Now()) + window.Add(time.Now()) + window.Add(time.Now()) + + assert.Equal(t, 3, window.Count()) +} + +// TestSlidingWindowExpiration tests that old events expire +func TestSlidingWindowExpiration(t *testing.T) { + t.Parallel() + + window := security.NewSlidingWindow(100 * time.Millisecond) + + // Add event + window.Add(time.Now()) + assert.Equal(t, 1, window.Count()) + + // Wait for expiration + time.Sleep(150 * time.Millisecond) + assert.Equal(t, 0, window.Count()) +} + +// TestNewAnomalyMonitor tests monitor creation +func TestNewAnomalyMonitor(t *testing.T) { + t.Parallel() + + var alerts []security.Alert + handler := func(a security.Alert) { + alerts = append(alerts, a) + } + + monitor := security.NewAnomalyMonitor(handler) + require.NotNil(t, monitor) + + stats := monitor.GetStats() + assert.Equal(t, 0, stats["privileged_container_attempts"]) + assert.Equal(t, 0, stats["path_traversal_attempts"]) + assert.Equal(t, 0, stats["command_injection_attempts"]) +} + +// TestRecordFailedAuth tests recording failed authentication +func TestRecordFailedAuth(t *testing.T) { + t.Parallel() + + var alertCount int + handler := func(a security.Alert) { + alertCount++ + } + + monitor := security.NewAnomalyMonitor(handler) + + // Record fewer than threshold attempts + for i := 0; i < 5; i++ { + monitor.RecordFailedAuth("192.168.1.1", "user-1") + } + assert.Equal(t, 0, alertCount, "Should not alert below threshold") + + // Record more attempts to reach threshold + for i := 0; i < 10; i++ { + monitor.RecordFailedAuth("192.168.1.1", "user-1") + } + assert.Greater(t, alertCount, 0, "Should alert after threshold") +} + +// TestRecordPrivilegedContainerAttempt tests recording privileged container attempts +func TestRecordPrivilegedContainerAttempt(t *testing.T) { + t.Parallel() + + var alertCount int + handler := func(a security.Alert) { + alertCount++ + } + + monitor := security.NewAnomalyMonitor(handler) + + monitor.RecordPrivilegedContainerAttempt("user-1") + assert.GreaterOrEqual(t, alertCount, 0) + + stats := monitor.GetStats() + assert.GreaterOrEqual(t, stats["privileged_container_attempts"], 1) +} + +// TestRecordPathTraversal tests recording path traversal attempts +func TestRecordPathTraversal(t *testing.T) { + t.Parallel() + + var lastAlert security.Alert + handler := func(a security.Alert) { + lastAlert = a + } + + monitor := security.NewAnomalyMonitor(handler) + + monitor.RecordPathTraversal("192.168.1.1", "../../../etc/passwd") + + assert.Equal(t, security.AlertPathTraversal, lastAlert.Type) + assert.Equal(t, security.SeverityHigh, lastAlert.Severity) + assert.Equal(t, "192.168.1.1", lastAlert.SourceIP) + + stats := monitor.GetStats() + assert.GreaterOrEqual(t, stats["path_traversal_attempts"], 1) +} + +// TestRecordCommandInjection tests recording command injection attempts +func TestRecordCommandInjection(t *testing.T) { + t.Parallel() + + var lastAlert security.Alert + handler := func(a security.Alert) { + lastAlert = a + } + + monitor := security.NewAnomalyMonitor(handler) + + monitor.RecordCommandInjection("192.168.1.1", "; rm -rf /") + + assert.Equal(t, security.AlertCommandInjection, lastAlert.Type) + assert.Equal(t, security.SeverityCritical, lastAlert.Severity) + assert.Equal(t, "192.168.1.1", lastAlert.SourceIP) + + stats := monitor.GetStats() + assert.GreaterOrEqual(t, stats["command_injection_attempts"], 1) +} + +// TestGetStats tests statistics retrieval +func TestGetStats(t *testing.T) { + t.Parallel() + + monitor := security.NewAnomalyMonitor(nil) + + // Record various events + monitor.RecordFailedAuth("192.168.1.1", "user-1") + monitor.RecordPathTraversal("192.168.1.1", "../../../etc/passwd") + monitor.RecordCommandInjection("192.168.1.2", "; rm -rf /") + monitor.RecordPrivilegedContainerAttempt("user-2") + + stats := monitor.GetStats() + + assert.GreaterOrEqual(t, stats["path_traversal_attempts"], 1) + assert.GreaterOrEqual(t, stats["command_injection_attempts"], 1) + assert.GreaterOrEqual(t, stats["privileged_container_attempts"], 1) + assert.GreaterOrEqual(t, stats["monitored_ips"], 0) +} + +// TestDefaultAlertHandler tests the default alert handler doesn't panic +func TestDefaultAlertHandler(t *testing.T) { + t.Parallel() + + alert := security.Alert{ + Timestamp: time.Now(), + Severity: security.SeverityHigh, + Type: security.AlertBruteForce, + Message: "Test alert", + SourceIP: "192.168.1.1", + UserID: "user-1", + } + + // Should not panic + security.DefaultAlertHandler(alert) +} + +// TestNewLoggingAlertHandler tests the logging alert handler +func TestNewLoggingAlertHandler(t *testing.T) { + t.Parallel() + + var logged bool + logFunc := func(msg string, args ...any) { + logged = true + } + + handler := security.NewLoggingAlertHandler(logFunc) + require.NotNil(t, handler) + + alert := security.Alert{ + Timestamp: time.Now(), + Severity: security.SeverityMedium, + Type: security.AlertRateLimitExceeded, + Message: "Rate limit exceeded", + SourceIP: "192.168.1.1", + UserID: "user-1", + } + + handler(alert) + assert.True(t, logged) +} + +// TestAlertTypes tests alert type constants +func TestAlertTypes(t *testing.T) { + t.Parallel() + + assert.Equal(t, security.AlertSeverity("low"), security.SeverityLow) + assert.Equal(t, security.AlertSeverity("medium"), security.SeverityMedium) + assert.Equal(t, security.AlertSeverity("high"), security.SeverityHigh) + assert.Equal(t, security.AlertSeverity("critical"), security.SeverityCritical) + + assert.Equal(t, security.AlertType("brute_force"), security.AlertBruteForce) + assert.Equal(t, security.AlertType("privilege_escalation"), security.AlertPrivilegeEscalation) + assert.Equal(t, security.AlertType("path_traversal"), security.AlertPathTraversal) + assert.Equal(t, security.AlertType("command_injection"), security.AlertCommandInjection) + assert.Equal(t, security.AlertType("suspicious_container"), security.AlertSuspiciousContainer) + assert.Equal(t, security.AlertType("rate_limit_exceeded"), security.AlertRateLimitExceeded) +} + +// TestAlertStructure tests alert struct fields +func TestAlertStructure(t *testing.T) { + t.Parallel() + + now := time.Now() + metadata := map[string]any{"key": "value"} + + alert := security.Alert{ + Timestamp: now, + Severity: security.SeverityCritical, + Type: security.AlertCommandInjection, + Message: "Test message", + SourceIP: "192.168.1.1", + UserID: "user-123", + Metadata: metadata, + } + + assert.Equal(t, now, alert.Timestamp) + assert.Equal(t, security.SeverityCritical, alert.Severity) + assert.Equal(t, security.AlertCommandInjection, alert.Type) + assert.Equal(t, "Test message", alert.Message) + assert.Equal(t, "192.168.1.1", alert.SourceIP) + assert.Equal(t, "user-123", alert.UserID) + assert.Equal(t, metadata, alert.Metadata) +} + +// TestConcurrentRecordFailedAuth tests concurrent auth recording +func TestConcurrentRecordFailedAuth(t *testing.T) { + t.Parallel() + + var alertCount int + handler := func(a security.Alert) { + alertCount++ + } + + monitor := security.NewAnomalyMonitor(handler) + + // Concurrent failed auth attempts + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + monitor.RecordFailedAuth("192.168.1.1", "user-1") + }() + } + wg.Wait() + + // Should have recorded all attempts + stats := monitor.GetStats() + assert.GreaterOrEqual(t, stats["monitored_ips"], 0) +} + +// TestMultipleIPs tests monitoring multiple IPs independently +func TestMultipleIPs(t *testing.T) { + t.Parallel() + + monitor := security.NewAnomalyMonitor(nil) + + // Record from different IPs + monitor.RecordFailedAuth("192.168.1.1", "user-1") + monitor.RecordFailedAuth("192.168.1.2", "user-2") + monitor.RecordFailedAuth("192.168.1.3", "user-3") + + stats := monitor.GetStats() + assert.GreaterOrEqual(t, stats["monitored_ips"], 3) +}