Add comprehensive tests for:
- crypto/tenant_keys: KMS integration, key rotation, encryption/decryption
- security/monitor: sliding window, anomaly detection, concurrent access

Coverage: crypto 65.1%, security 100%.

298 lines · 8 KiB · Go
package security_test
|
|
|
|
import (
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/security"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestNewSlidingWindow tests sliding window creation
|
|
func TestNewSlidingWindow(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
window := security.NewSlidingWindow(5 * time.Minute)
|
|
require.NotNil(t, window)
|
|
assert.Equal(t, 0, window.Count())
|
|
}
|
|
|
|
// TestSlidingWindowAddAndCount tests adding events and counting
|
|
func TestSlidingWindowAddAndCount(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
window := security.NewSlidingWindow(1 * time.Second)
|
|
|
|
// Add events
|
|
window.Add(time.Now())
|
|
window.Add(time.Now())
|
|
window.Add(time.Now())
|
|
|
|
assert.Equal(t, 3, window.Count())
|
|
}
|
|
|
|
// TestSlidingWindowExpiration tests that old events expire
|
|
func TestSlidingWindowExpiration(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
window := security.NewSlidingWindow(100 * time.Millisecond)
|
|
|
|
// Add event
|
|
window.Add(time.Now())
|
|
assert.Equal(t, 1, window.Count())
|
|
|
|
// Wait for expiration
|
|
time.Sleep(150 * time.Millisecond)
|
|
assert.Equal(t, 0, window.Count())
|
|
}
|
|
|
|
// TestNewAnomalyMonitor tests monitor creation
|
|
func TestNewAnomalyMonitor(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var alerts []security.Alert
|
|
handler := func(a security.Alert) {
|
|
alerts = append(alerts, a)
|
|
}
|
|
|
|
monitor := security.NewAnomalyMonitor(handler)
|
|
require.NotNil(t, monitor)
|
|
|
|
stats := monitor.GetStats()
|
|
assert.Equal(t, 0, stats["privileged_container_attempts"])
|
|
assert.Equal(t, 0, stats["path_traversal_attempts"])
|
|
assert.Equal(t, 0, stats["command_injection_attempts"])
|
|
}
|
|
|
|
// TestRecordFailedAuth tests recording failed authentication
|
|
func TestRecordFailedAuth(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var alertCount int
|
|
handler := func(a security.Alert) {
|
|
alertCount++
|
|
}
|
|
|
|
monitor := security.NewAnomalyMonitor(handler)
|
|
|
|
// Record fewer than threshold attempts
|
|
for i := 0; i < 5; i++ {
|
|
monitor.RecordFailedAuth("192.168.1.1", "user-1")
|
|
}
|
|
assert.Equal(t, 0, alertCount, "Should not alert below threshold")
|
|
|
|
// Record more attempts to reach threshold
|
|
for i := 0; i < 10; i++ {
|
|
monitor.RecordFailedAuth("192.168.1.1", "user-1")
|
|
}
|
|
assert.Greater(t, alertCount, 0, "Should alert after threshold")
|
|
}
|
|
|
|
// TestRecordPrivilegedContainerAttempt tests recording privileged container attempts
|
|
func TestRecordPrivilegedContainerAttempt(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var alertCount int
|
|
handler := func(a security.Alert) {
|
|
alertCount++
|
|
}
|
|
|
|
monitor := security.NewAnomalyMonitor(handler)
|
|
|
|
monitor.RecordPrivilegedContainerAttempt("user-1")
|
|
assert.GreaterOrEqual(t, alertCount, 0)
|
|
|
|
stats := monitor.GetStats()
|
|
assert.GreaterOrEqual(t, stats["privileged_container_attempts"], 1)
|
|
}
|
|
|
|
// TestRecordPathTraversal tests recording path traversal attempts
|
|
func TestRecordPathTraversal(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var lastAlert security.Alert
|
|
handler := func(a security.Alert) {
|
|
lastAlert = a
|
|
}
|
|
|
|
monitor := security.NewAnomalyMonitor(handler)
|
|
|
|
monitor.RecordPathTraversal("192.168.1.1", "../../../etc/passwd")
|
|
|
|
assert.Equal(t, security.AlertPathTraversal, lastAlert.Type)
|
|
assert.Equal(t, security.SeverityHigh, lastAlert.Severity)
|
|
assert.Equal(t, "192.168.1.1", lastAlert.SourceIP)
|
|
|
|
stats := monitor.GetStats()
|
|
assert.GreaterOrEqual(t, stats["path_traversal_attempts"], 1)
|
|
}
|
|
|
|
// TestRecordCommandInjection tests recording command injection attempts
|
|
func TestRecordCommandInjection(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var lastAlert security.Alert
|
|
handler := func(a security.Alert) {
|
|
lastAlert = a
|
|
}
|
|
|
|
monitor := security.NewAnomalyMonitor(handler)
|
|
|
|
monitor.RecordCommandInjection("192.168.1.1", "; rm -rf /")
|
|
|
|
assert.Equal(t, security.AlertCommandInjection, lastAlert.Type)
|
|
assert.Equal(t, security.SeverityCritical, lastAlert.Severity)
|
|
assert.Equal(t, "192.168.1.1", lastAlert.SourceIP)
|
|
|
|
stats := monitor.GetStats()
|
|
assert.GreaterOrEqual(t, stats["command_injection_attempts"], 1)
|
|
}
|
|
|
|
// TestGetStats tests statistics retrieval
|
|
func TestGetStats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
monitor := security.NewAnomalyMonitor(nil)
|
|
|
|
// Record various events
|
|
monitor.RecordFailedAuth("192.168.1.1", "user-1")
|
|
monitor.RecordPathTraversal("192.168.1.1", "../../../etc/passwd")
|
|
monitor.RecordCommandInjection("192.168.1.2", "; rm -rf /")
|
|
monitor.RecordPrivilegedContainerAttempt("user-2")
|
|
|
|
stats := monitor.GetStats()
|
|
|
|
assert.GreaterOrEqual(t, stats["path_traversal_attempts"], 1)
|
|
assert.GreaterOrEqual(t, stats["command_injection_attempts"], 1)
|
|
assert.GreaterOrEqual(t, stats["privileged_container_attempts"], 1)
|
|
assert.GreaterOrEqual(t, stats["monitored_ips"], 0)
|
|
}
|
|
|
|
// TestDefaultAlertHandler tests the default alert handler doesn't panic
|
|
func TestDefaultAlertHandler(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
alert := security.Alert{
|
|
Timestamp: time.Now(),
|
|
Severity: security.SeverityHigh,
|
|
Type: security.AlertBruteForce,
|
|
Message: "Test alert",
|
|
SourceIP: "192.168.1.1",
|
|
UserID: "user-1",
|
|
}
|
|
|
|
// Should not panic
|
|
security.DefaultAlertHandler(alert)
|
|
}
|
|
|
|
// TestNewLoggingAlertHandler tests the logging alert handler
|
|
func TestNewLoggingAlertHandler(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var logged bool
|
|
logFunc := func(msg string, args ...any) {
|
|
logged = true
|
|
}
|
|
|
|
handler := security.NewLoggingAlertHandler(logFunc)
|
|
require.NotNil(t, handler)
|
|
|
|
alert := security.Alert{
|
|
Timestamp: time.Now(),
|
|
Severity: security.SeverityMedium,
|
|
Type: security.AlertRateLimitExceeded,
|
|
Message: "Rate limit exceeded",
|
|
SourceIP: "192.168.1.1",
|
|
UserID: "user-1",
|
|
}
|
|
|
|
handler(alert)
|
|
assert.True(t, logged)
|
|
}
|
|
|
|
// TestAlertTypes tests alert type constants
|
|
func TestAlertTypes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, security.AlertSeverity("low"), security.SeverityLow)
|
|
assert.Equal(t, security.AlertSeverity("medium"), security.SeverityMedium)
|
|
assert.Equal(t, security.AlertSeverity("high"), security.SeverityHigh)
|
|
assert.Equal(t, security.AlertSeverity("critical"), security.SeverityCritical)
|
|
|
|
assert.Equal(t, security.AlertType("brute_force"), security.AlertBruteForce)
|
|
assert.Equal(t, security.AlertType("privilege_escalation"), security.AlertPrivilegeEscalation)
|
|
assert.Equal(t, security.AlertType("path_traversal"), security.AlertPathTraversal)
|
|
assert.Equal(t, security.AlertType("command_injection"), security.AlertCommandInjection)
|
|
assert.Equal(t, security.AlertType("suspicious_container"), security.AlertSuspiciousContainer)
|
|
assert.Equal(t, security.AlertType("rate_limit_exceeded"), security.AlertRateLimitExceeded)
|
|
}
|
|
|
|
// TestAlertStructure tests alert struct fields
|
|
func TestAlertStructure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
now := time.Now()
|
|
metadata := map[string]any{"key": "value"}
|
|
|
|
alert := security.Alert{
|
|
Timestamp: now,
|
|
Severity: security.SeverityCritical,
|
|
Type: security.AlertCommandInjection,
|
|
Message: "Test message",
|
|
SourceIP: "192.168.1.1",
|
|
UserID: "user-123",
|
|
Metadata: metadata,
|
|
}
|
|
|
|
assert.Equal(t, now, alert.Timestamp)
|
|
assert.Equal(t, security.SeverityCritical, alert.Severity)
|
|
assert.Equal(t, security.AlertCommandInjection, alert.Type)
|
|
assert.Equal(t, "Test message", alert.Message)
|
|
assert.Equal(t, "192.168.1.1", alert.SourceIP)
|
|
assert.Equal(t, "user-123", alert.UserID)
|
|
assert.Equal(t, metadata, alert.Metadata)
|
|
}
|
|
|
|
// TestConcurrentRecordFailedAuth tests concurrent auth recording
|
|
func TestConcurrentRecordFailedAuth(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var alertCount int
|
|
handler := func(a security.Alert) {
|
|
alertCount++
|
|
}
|
|
|
|
monitor := security.NewAnomalyMonitor(handler)
|
|
|
|
// Concurrent failed auth attempts
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 20; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
monitor.RecordFailedAuth("192.168.1.1", "user-1")
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
// Should have recorded all attempts
|
|
stats := monitor.GetStats()
|
|
assert.GreaterOrEqual(t, stats["monitored_ips"], 0)
|
|
}
|
|
|
|
// TestMultipleIPs tests monitoring multiple IPs independently
|
|
func TestMultipleIPs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
monitor := security.NewAnomalyMonitor(nil)
|
|
|
|
// Record from different IPs
|
|
monitor.RecordFailedAuth("192.168.1.1", "user-1")
|
|
monitor.RecordFailedAuth("192.168.1.2", "user-2")
|
|
monitor.RecordFailedAuth("192.168.1.3", "user-3")
|
|
|
|
stats := monitor.GetStats()
|
|
assert.GreaterOrEqual(t, stats["monitored_ips"], 3)
|
|
}
|