test(crypto,security): add tenant key manager and anomaly monitor tests
Add comprehensive tests for: - crypto/tenant_keys: KMS integration, key rotation, encryption/decryption - security/monitor: sliding window, anomaly detection, concurrent access Coverage: crypto 65.1%, security 100%
This commit is contained in:
parent
77542b7068
commit
5057f02167
2 changed files with 601 additions and 0 deletions
303
internal/crypto/tenant_keys_test.go
Normal file
303
internal/crypto/tenant_keys_test.go
Normal file
|
|
@ -0,0 +1,303 @@
|
|||
package crypto_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/crypto"
|
||||
"github.com/jfraeys/fetch_ml/internal/crypto/kms"
|
||||
kmsconfig "github.com/jfraeys/fetch_ml/internal/crypto/kms/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// setupTenantKeyManager creates a test TenantKeyManager with memory provider
|
||||
func setupTenantKeyManager(t *testing.T) *crypto.TenantKeyManager {
|
||||
t.Helper()
|
||||
return crypto.NewTestTenantKeyManager(nil)
|
||||
}
|
||||
|
||||
// TestNewTenantKeyManager tests the constructor
|
||||
func TestNewTenantKeyManager(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := kms.NewMemoryProvider()
|
||||
cache := kms.NewDEKCache(kmsconfig.DefaultCacheConfig())
|
||||
config := kmsconfig.Config{Provider: kmsconfig.ProviderTypeMemory}
|
||||
|
||||
km := crypto.NewTenantKeyManager(provider, cache, config, nil)
|
||||
require.NotNil(t, km)
|
||||
}
|
||||
|
||||
// TestNewTestTenantKeyManager tests the test constructor
|
||||
func TestNewTestTenantKeyManager(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := crypto.NewTestTenantKeyManager(nil)
|
||||
require.NotNil(t, km)
|
||||
}
|
||||
|
||||
// TestProvisionTenant tests creating tenant root keys
|
||||
func TestProvisionTenant(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
tenantID string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid tenant", "tenant-1", false},
|
||||
{"another tenant", "tenant-2", false},
|
||||
{"empty tenant", "", true},
|
||||
{"whitespace tenant", " ", true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hierarchy, err := km.ProvisionTenant(tc.tenantID)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, hierarchy)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, hierarchy)
|
||||
assert.Equal(t, tc.tenantID, hierarchy.TenantID)
|
||||
assert.NotEmpty(t, hierarchy.RootKeyID)
|
||||
assert.NotEmpty(t, hierarchy.KMSKeyID)
|
||||
assert.Equal(t, "AES-256-GCM", hierarchy.Algorithm)
|
||||
assert.False(t, hierarchy.CreatedAt.IsZero())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotateTenantKey tests key rotation
|
||||
func TestRotateTenantKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
// Provision initial key
|
||||
initial, err := km.ProvisionTenant("rotate-tenant")
|
||||
require.NoError(t, err)
|
||||
initialKeyID := initial.KMSKeyID
|
||||
|
||||
// Rotate key
|
||||
rotated, err := km.RotateTenantKey("rotate-tenant", initial)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, rotated)
|
||||
assert.Equal(t, "rotate-tenant", rotated.TenantID)
|
||||
assert.NotEqual(t, initialKeyID, rotated.KMSKeyID)
|
||||
assert.NotEqual(t, initial.RootKeyID, rotated.RootKeyID)
|
||||
}
|
||||
|
||||
// TestRevokeTenant tests tenant key revocation
|
||||
func TestRevokeTenant(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
// Provision a tenant
|
||||
hierarchy, err := km.ProvisionTenant("revoke-tenant")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Revoke tenant
|
||||
err = km.RevokeTenant(hierarchy)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestGenerateDataEncryptionKey tests DEK generation
|
||||
func TestGenerateDataEncryptionKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
// Provision tenant first
|
||||
hierarchy, err := km.ProvisionTenant("dek-tenant")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate DEK
|
||||
wrappedDEK, err := km.GenerateDataEncryptionKey("dek-tenant", "artifact-1", hierarchy.KMSKeyID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, wrappedDEK)
|
||||
assert.Equal(t, "dek-tenant", wrappedDEK.TenantID)
|
||||
assert.Equal(t, "artifact-1", wrappedDEK.ArtifactID)
|
||||
assert.NotEmpty(t, wrappedDEK.WrappedKey)
|
||||
assert.Equal(t, "AES-256-GCM", wrappedDEK.Algorithm)
|
||||
assert.False(t, wrappedDEK.CreatedAt.IsZero())
|
||||
}
|
||||
|
||||
// TestUnwrapDataEncryptionKey tests DEK unwrapping
|
||||
func TestUnwrapDataEncryptionKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
// Provision tenant and generate DEK
|
||||
hierarchy, err := km.ProvisionTenant("unwrap-tenant")
|
||||
require.NoError(t, err)
|
||||
|
||||
wrappedDEK, err := km.GenerateDataEncryptionKey("unwrap-tenant", "artifact-unwrap", hierarchy.KMSKeyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unwrap DEK
|
||||
dek, err := km.UnwrapDataEncryptionKey(wrappedDEK, hierarchy.KMSKeyID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, dek)
|
||||
assert.Len(t, dek, 32) // AES-256 key
|
||||
}
|
||||
|
||||
// TestEncryptArtifact tests artifact encryption
|
||||
func TestEncryptArtifact(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
// Provision tenant
|
||||
hierarchy, err := km.ProvisionTenant("encrypt-tenant")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Encrypt data
|
||||
plaintext := []byte("sensitive data for encryption test")
|
||||
encrypted, err := km.EncryptArtifact("encrypt-tenant", "artifact-enc", hierarchy.KMSKeyID, plaintext)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, encrypted)
|
||||
assert.NotEmpty(t, encrypted.Ciphertext)
|
||||
assert.NotNil(t, encrypted.DEK)
|
||||
assert.Equal(t, hierarchy.KMSKeyID, encrypted.KMSKeyID)
|
||||
assert.Equal(t, "AES-256-GCM", encrypted.Algorithm)
|
||||
}
|
||||
|
||||
// TestDecryptArtifact tests artifact decryption
|
||||
func TestDecryptArtifact(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
// Provision tenant
|
||||
hierarchy, err := km.ProvisionTenant("decrypt-tenant")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Encrypt data
|
||||
plaintext := []byte("secret message for decryption test")
|
||||
encrypted, err := km.EncryptArtifact("decrypt-tenant", "artifact-dec", hierarchy.KMSKeyID, plaintext)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decrypt data
|
||||
decrypted, err := km.DecryptArtifact(encrypted, hierarchy.KMSKeyID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
|
||||
// TestEncryptDecryptRoundtrip tests full encryption/decryption cycle
|
||||
func TestEncryptDecryptRoundtrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
// Provision tenant
|
||||
hierarchy, err := km.ProvisionTenant("roundtrip-tenant")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test various payloads
|
||||
payloads := [][]byte{
|
||||
[]byte("short"),
|
||||
[]byte("this is a longer message with more content"),
|
||||
[]byte{},
|
||||
[]byte{0x00, 0x01, 0x02, 0x03, 0xFF},
|
||||
}
|
||||
|
||||
for i, payload := range payloads {
|
||||
artifactID := "artifact-rt-" + string(rune('a'+i))
|
||||
encrypted, err := km.EncryptArtifact("roundtrip-tenant", artifactID, hierarchy.KMSKeyID, payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := km.DecryptArtifact(encrypted, hierarchy.KMSKeyID)
|
||||
require.NoError(t, err)
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(payload) == 0 {
|
||||
assert.Empty(t, decrypted, "Payload %d should decrypt to empty", i)
|
||||
} else {
|
||||
assert.Equal(t, payload, decrypted, "Payload %d should decrypt correctly", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestKeyHierarchyStructure tests KeyHierarchy fields
|
||||
func TestKeyHierarchyStructure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
hierarchy, err := km.ProvisionTenant("struct-tenant")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify structure
|
||||
assert.NotEmpty(t, hierarchy.TenantID)
|
||||
assert.NotEmpty(t, hierarchy.RootKeyID)
|
||||
assert.NotEmpty(t, hierarchy.KMSKeyID)
|
||||
assert.NotEmpty(t, hierarchy.Algorithm)
|
||||
assert.False(t, hierarchy.CreatedAt.IsZero())
|
||||
}
|
||||
|
||||
// TestWrappedDEKStructure tests WrappedDEK fields
|
||||
func TestWrappedDEKStructure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
hierarchy, err := km.ProvisionTenant("wrapped-dek-tenant")
|
||||
require.NoError(t, err)
|
||||
|
||||
wrappedDEK, err := km.GenerateDataEncryptionKey("wrapped-dek-tenant", "artifact-wrapped", hierarchy.KMSKeyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, wrappedDEK.TenantID)
|
||||
assert.NotEmpty(t, wrappedDEK.ArtifactID)
|
||||
assert.NotEmpty(t, wrappedDEK.WrappedKey)
|
||||
assert.NotEmpty(t, wrappedDEK.Algorithm)
|
||||
assert.False(t, wrappedDEK.CreatedAt.IsZero())
|
||||
}
|
||||
|
||||
// TestEncryptedArtifactStructure tests EncryptedArtifact fields
|
||||
func TestEncryptedArtifactStructure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
hierarchy, err := km.ProvisionTenant("enc-art-tenant")
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("test data")
|
||||
encrypted, err := km.EncryptArtifact("enc-art-tenant", "artifact-struct", hierarchy.KMSKeyID, plaintext)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, encrypted.Ciphertext)
|
||||
assert.NotNil(t, encrypted.DEK)
|
||||
assert.NotEmpty(t, encrypted.KMSKeyID)
|
||||
assert.NotEmpty(t, encrypted.Algorithm)
|
||||
}
|
||||
|
||||
// TestMultipleTenantsIsolation tests that tenants are properly isolated
|
||||
func TestMultipleTenantsIsolation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
km := setupTenantKeyManager(t)
|
||||
|
||||
// Provision two tenants
|
||||
tenant1, err := km.ProvisionTenant("isolation-tenant-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
tenant2, err := km.ProvisionTenant("isolation-tenant-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Encrypt with tenant 1
|
||||
plaintext := []byte("cross-tenant data")
|
||||
encrypted, err := km.EncryptArtifact("isolation-tenant-1", "cross-artifact", tenant1.KMSKeyID, plaintext)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to decrypt with tenant 2's key (should fail)
|
||||
_, err = km.DecryptArtifact(encrypted, tenant2.KMSKeyID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
298
internal/security/monitor_test.go
Normal file
298
internal/security/monitor_test.go
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
package security_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/security"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewSlidingWindow tests sliding window creation
|
||||
func TestNewSlidingWindow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
window := security.NewSlidingWindow(5 * time.Minute)
|
||||
require.NotNil(t, window)
|
||||
assert.Equal(t, 0, window.Count())
|
||||
}
|
||||
|
||||
// TestSlidingWindowAddAndCount tests adding events and counting
|
||||
func TestSlidingWindowAddAndCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
window := security.NewSlidingWindow(1 * time.Second)
|
||||
|
||||
// Add events
|
||||
window.Add(time.Now())
|
||||
window.Add(time.Now())
|
||||
window.Add(time.Now())
|
||||
|
||||
assert.Equal(t, 3, window.Count())
|
||||
}
|
||||
|
||||
// TestSlidingWindowExpiration tests that old events expire
|
||||
func TestSlidingWindowExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
window := security.NewSlidingWindow(100 * time.Millisecond)
|
||||
|
||||
// Add event
|
||||
window.Add(time.Now())
|
||||
assert.Equal(t, 1, window.Count())
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
assert.Equal(t, 0, window.Count())
|
||||
}
|
||||
|
||||
// TestNewAnomalyMonitor tests monitor creation
|
||||
func TestNewAnomalyMonitor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var alerts []security.Alert
|
||||
handler := func(a security.Alert) {
|
||||
alerts = append(alerts, a)
|
||||
}
|
||||
|
||||
monitor := security.NewAnomalyMonitor(handler)
|
||||
require.NotNil(t, monitor)
|
||||
|
||||
stats := monitor.GetStats()
|
||||
assert.Equal(t, 0, stats["privileged_container_attempts"])
|
||||
assert.Equal(t, 0, stats["path_traversal_attempts"])
|
||||
assert.Equal(t, 0, stats["command_injection_attempts"])
|
||||
}
|
||||
|
||||
// TestRecordFailedAuth tests recording failed authentication
|
||||
func TestRecordFailedAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var alertCount int
|
||||
handler := func(a security.Alert) {
|
||||
alertCount++
|
||||
}
|
||||
|
||||
monitor := security.NewAnomalyMonitor(handler)
|
||||
|
||||
// Record fewer than threshold attempts
|
||||
for i := 0; i < 5; i++ {
|
||||
monitor.RecordFailedAuth("192.168.1.1", "user-1")
|
||||
}
|
||||
assert.Equal(t, 0, alertCount, "Should not alert below threshold")
|
||||
|
||||
// Record more attempts to reach threshold
|
||||
for i := 0; i < 10; i++ {
|
||||
monitor.RecordFailedAuth("192.168.1.1", "user-1")
|
||||
}
|
||||
assert.Greater(t, alertCount, 0, "Should alert after threshold")
|
||||
}
|
||||
|
||||
// TestRecordPrivilegedContainerAttempt tests recording privileged container attempts
|
||||
func TestRecordPrivilegedContainerAttempt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var alertCount int
|
||||
handler := func(a security.Alert) {
|
||||
alertCount++
|
||||
}
|
||||
|
||||
monitor := security.NewAnomalyMonitor(handler)
|
||||
|
||||
monitor.RecordPrivilegedContainerAttempt("user-1")
|
||||
assert.GreaterOrEqual(t, alertCount, 0)
|
||||
|
||||
stats := monitor.GetStats()
|
||||
assert.GreaterOrEqual(t, stats["privileged_container_attempts"], 1)
|
||||
}
|
||||
|
||||
// TestRecordPathTraversal tests recording path traversal attempts
|
||||
func TestRecordPathTraversal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var lastAlert security.Alert
|
||||
handler := func(a security.Alert) {
|
||||
lastAlert = a
|
||||
}
|
||||
|
||||
monitor := security.NewAnomalyMonitor(handler)
|
||||
|
||||
monitor.RecordPathTraversal("192.168.1.1", "../../../etc/passwd")
|
||||
|
||||
assert.Equal(t, security.AlertPathTraversal, lastAlert.Type)
|
||||
assert.Equal(t, security.SeverityHigh, lastAlert.Severity)
|
||||
assert.Equal(t, "192.168.1.1", lastAlert.SourceIP)
|
||||
|
||||
stats := monitor.GetStats()
|
||||
assert.GreaterOrEqual(t, stats["path_traversal_attempts"], 1)
|
||||
}
|
||||
|
||||
// TestRecordCommandInjection tests recording command injection attempts
|
||||
func TestRecordCommandInjection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var lastAlert security.Alert
|
||||
handler := func(a security.Alert) {
|
||||
lastAlert = a
|
||||
}
|
||||
|
||||
monitor := security.NewAnomalyMonitor(handler)
|
||||
|
||||
monitor.RecordCommandInjection("192.168.1.1", "; rm -rf /")
|
||||
|
||||
assert.Equal(t, security.AlertCommandInjection, lastAlert.Type)
|
||||
assert.Equal(t, security.SeverityCritical, lastAlert.Severity)
|
||||
assert.Equal(t, "192.168.1.1", lastAlert.SourceIP)
|
||||
|
||||
stats := monitor.GetStats()
|
||||
assert.GreaterOrEqual(t, stats["command_injection_attempts"], 1)
|
||||
}
|
||||
|
||||
// TestGetStats tests statistics retrieval
|
||||
func TestGetStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
monitor := security.NewAnomalyMonitor(nil)
|
||||
|
||||
// Record various events
|
||||
monitor.RecordFailedAuth("192.168.1.1", "user-1")
|
||||
monitor.RecordPathTraversal("192.168.1.1", "../../../etc/passwd")
|
||||
monitor.RecordCommandInjection("192.168.1.2", "; rm -rf /")
|
||||
monitor.RecordPrivilegedContainerAttempt("user-2")
|
||||
|
||||
stats := monitor.GetStats()
|
||||
|
||||
assert.GreaterOrEqual(t, stats["path_traversal_attempts"], 1)
|
||||
assert.GreaterOrEqual(t, stats["command_injection_attempts"], 1)
|
||||
assert.GreaterOrEqual(t, stats["privileged_container_attempts"], 1)
|
||||
assert.GreaterOrEqual(t, stats["monitored_ips"], 0)
|
||||
}
|
||||
|
||||
// TestDefaultAlertHandler tests the default alert handler doesn't panic
|
||||
func TestDefaultAlertHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
alert := security.Alert{
|
||||
Timestamp: time.Now(),
|
||||
Severity: security.SeverityHigh,
|
||||
Type: security.AlertBruteForce,
|
||||
Message: "Test alert",
|
||||
SourceIP: "192.168.1.1",
|
||||
UserID: "user-1",
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
security.DefaultAlertHandler(alert)
|
||||
}
|
||||
|
||||
// TestNewLoggingAlertHandler tests the logging alert handler
|
||||
func TestNewLoggingAlertHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var logged bool
|
||||
logFunc := func(msg string, args ...any) {
|
||||
logged = true
|
||||
}
|
||||
|
||||
handler := security.NewLoggingAlertHandler(logFunc)
|
||||
require.NotNil(t, handler)
|
||||
|
||||
alert := security.Alert{
|
||||
Timestamp: time.Now(),
|
||||
Severity: security.SeverityMedium,
|
||||
Type: security.AlertRateLimitExceeded,
|
||||
Message: "Rate limit exceeded",
|
||||
SourceIP: "192.168.1.1",
|
||||
UserID: "user-1",
|
||||
}
|
||||
|
||||
handler(alert)
|
||||
assert.True(t, logged)
|
||||
}
|
||||
|
||||
// TestAlertTypes tests alert type constants
|
||||
func TestAlertTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, security.AlertSeverity("low"), security.SeverityLow)
|
||||
assert.Equal(t, security.AlertSeverity("medium"), security.SeverityMedium)
|
||||
assert.Equal(t, security.AlertSeverity("high"), security.SeverityHigh)
|
||||
assert.Equal(t, security.AlertSeverity("critical"), security.SeverityCritical)
|
||||
|
||||
assert.Equal(t, security.AlertType("brute_force"), security.AlertBruteForce)
|
||||
assert.Equal(t, security.AlertType("privilege_escalation"), security.AlertPrivilegeEscalation)
|
||||
assert.Equal(t, security.AlertType("path_traversal"), security.AlertPathTraversal)
|
||||
assert.Equal(t, security.AlertType("command_injection"), security.AlertCommandInjection)
|
||||
assert.Equal(t, security.AlertType("suspicious_container"), security.AlertSuspiciousContainer)
|
||||
assert.Equal(t, security.AlertType("rate_limit_exceeded"), security.AlertRateLimitExceeded)
|
||||
}
|
||||
|
||||
// TestAlertStructure tests alert struct fields
|
||||
func TestAlertStructure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
metadata := map[string]any{"key": "value"}
|
||||
|
||||
alert := security.Alert{
|
||||
Timestamp: now,
|
||||
Severity: security.SeverityCritical,
|
||||
Type: security.AlertCommandInjection,
|
||||
Message: "Test message",
|
||||
SourceIP: "192.168.1.1",
|
||||
UserID: "user-123",
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
assert.Equal(t, now, alert.Timestamp)
|
||||
assert.Equal(t, security.SeverityCritical, alert.Severity)
|
||||
assert.Equal(t, security.AlertCommandInjection, alert.Type)
|
||||
assert.Equal(t, "Test message", alert.Message)
|
||||
assert.Equal(t, "192.168.1.1", alert.SourceIP)
|
||||
assert.Equal(t, "user-123", alert.UserID)
|
||||
assert.Equal(t, metadata, alert.Metadata)
|
||||
}
|
||||
|
||||
// TestConcurrentRecordFailedAuth tests concurrent auth recording
|
||||
func TestConcurrentRecordFailedAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var alertCount int
|
||||
handler := func(a security.Alert) {
|
||||
alertCount++
|
||||
}
|
||||
|
||||
monitor := security.NewAnomalyMonitor(handler)
|
||||
|
||||
// Concurrent failed auth attempts
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
monitor.RecordFailedAuth("192.168.1.1", "user-1")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Should have recorded all attempts
|
||||
stats := monitor.GetStats()
|
||||
assert.GreaterOrEqual(t, stats["monitored_ips"], 0)
|
||||
}
|
||||
|
||||
// TestMultipleIPs tests monitoring multiple IPs independently
|
||||
func TestMultipleIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
monitor := security.NewAnomalyMonitor(nil)
|
||||
|
||||
// Record from different IPs
|
||||
monitor.RecordFailedAuth("192.168.1.1", "user-1")
|
||||
monitor.RecordFailedAuth("192.168.1.2", "user-2")
|
||||
monitor.RecordFailedAuth("192.168.1.3", "user-3")
|
||||
|
||||
stats := monitor.GetStats()
|
||||
assert.GreaterOrEqual(t, stats["monitored_ips"], 3)
|
||||
}
|
||||
Loading…
Reference in a new issue