package security_test

import (
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/jfraeys/fetch_ml/internal/security"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// loadAlert returns the alert most recently stored in v by a test alert
// handler. It fails the test immediately — rather than panicking on a nil
// type assertion — when the handler was never invoked or stored a value of
// an unexpected type.
func loadAlert(t *testing.T, v *atomic.Value) security.Alert {
	t.Helper()

	raw := v.Load()
	require.NotNil(t, raw, "alert handler was not invoked")

	alert, ok := raw.(security.Alert)
	require.Truef(t, ok, "alert handler stored %T, want security.Alert", raw)

	return alert
}

// TestNewSlidingWindow tests sliding window creation.
func TestNewSlidingWindow(t *testing.T) {
	t.Parallel()

	window := security.NewSlidingWindow(5 * time.Minute)
	require.NotNil(t, window)
	assert.Equal(t, 0, window.Count())
}

// TestSlidingWindowAddAndCount tests adding events and counting them.
func TestSlidingWindowAddAndCount(t *testing.T) {
	t.Parallel()

	window := security.NewSlidingWindow(1 * time.Second)

	// Add events.
	window.Add(time.Now())
	window.Add(time.Now())
	window.Add(time.Now())

	assert.Equal(t, 3, window.Count())
}

// TestSlidingWindowExpiration tests that events older than the window
// duration are no longer counted.
func TestSlidingWindowExpiration(t *testing.T) {
	t.Parallel()

	window := security.NewSlidingWindow(100 * time.Millisecond)

	// Add event.
	window.Add(time.Now())
	assert.Equal(t, 1, window.Count())

	// Sleep past the 100ms window so the event falls outside it.
	time.Sleep(150 * time.Millisecond)
	assert.Equal(t, 0, window.Count())
}

// TestNewAnomalyMonitor tests monitor creation and its initial statistics.
func TestNewAnomalyMonitor(t *testing.T) {
	t.Parallel()

	// An atomic counter (instead of an unsynchronised slice) keeps the
	// handler race-free should the monitor ever call it from another goroutine.
	var alertCount atomic.Int32
	handler := func(security.Alert) {
		alertCount.Add(1)
	}

	monitor := security.NewAnomalyMonitor(handler)
	require.NotNil(t, monitor)

	// Constructing a monitor must not emit alerts on its own.
	assert.Equal(t, int32(0), alertCount.Load())

	stats := monitor.GetStats()
	assert.Equal(t, 0, stats["privileged_container_attempts"])
	assert.Equal(t, 0, stats["path_traversal_attempts"])
	assert.Equal(t, 0, stats["command_injection_attempts"])
}

// TestRecordFailedAuth tests that failed authentication attempts only alert
// once the brute-force threshold is reached.
func TestRecordFailedAuth(t *testing.T) {
	t.Parallel()

	var alertCount atomic.Int32
	handler := func(security.Alert) {
		alertCount.Add(1)
	}

	monitor := security.NewAnomalyMonitor(handler)

	// Record fewer than threshold attempts.
	for i := 0; i < 5; i++ {
		monitor.RecordFailedAuth("192.168.1.1", "user-1")
	}

	assert.Equal(t, int32(0), alertCount.Load(), "Should not alert below threshold")

	// Record more attempts to reach threshold.
	for i := 0; i < 10; i++ {
		monitor.RecordFailedAuth("192.168.1.1", "user-1")
	}

	assert.Greater(t, alertCount.Load(), int32(0), "Should alert after threshold")
}

// TestRecordPrivilegedContainerAttempt tests recording privileged container
// attempts.
func TestRecordPrivilegedContainerAttempt(t *testing.T) {
	t.Parallel()

	var alertCount atomic.Int32
	handler := func(security.Alert) {
		alertCount.Add(1)
	}

	monitor := security.NewAnomalyMonitor(handler)

	monitor.RecordPrivilegedContainerAttempt("user-1")

	// NOTE(review): the alert threshold for privileged-container attempts is
	// not visible from this test. The previous `alertCount >= 0` assertion
	// could never fail, so the observed count is logged rather than asserted
	// until the expected behaviour is confirmed against the implementation.
	t.Logf("alerts after one privileged container attempt: %d", alertCount.Load())

	stats := monitor.GetStats()
	assert.GreaterOrEqual(t, stats["privileged_container_attempts"], 1)
}

// TestRecordPathTraversal tests that a path traversal attempt raises a
// high-severity alert attributed to the source IP.
func TestRecordPathTraversal(t *testing.T) {
	t.Parallel()

	var lastAlert atomic.Value
	handler := func(a security.Alert) {
		lastAlert.Store(a)
	}

	monitor := security.NewAnomalyMonitor(handler)

	monitor.RecordPathTraversal("192.168.1.1", "../../../etc/passwd")

	alert := loadAlert(t, &lastAlert)
	assert.Equal(t, security.AlertPathTraversal, alert.Type)
	assert.Equal(t, security.SeverityHigh, alert.Severity)
	assert.Equal(t, "192.168.1.1", alert.SourceIP)

	stats := monitor.GetStats()
	assert.GreaterOrEqual(t, stats["path_traversal_attempts"], 1)
}

// TestRecordCommandInjection tests that a command injection attempt raises a
// critical alert attributed to the source IP.
func TestRecordCommandInjection(t *testing.T) {
	t.Parallel()

	var lastAlert atomic.Value
	handler := func(a security.Alert) {
		lastAlert.Store(a)
	}

	monitor := security.NewAnomalyMonitor(handler)

	monitor.RecordCommandInjection("192.168.1.1", "; rm -rf /")

	alert := loadAlert(t, &lastAlert)
	assert.Equal(t, security.AlertCommandInjection, alert.Type)
	assert.Equal(t, security.SeverityCritical, alert.Severity)
	assert.Equal(t, "192.168.1.1", alert.SourceIP)

	stats := monitor.GetStats()
	assert.GreaterOrEqual(t, stats["command_injection_attempts"], 1)
}

// TestGetStats tests statistics retrieval after a mix of recorded events.
func TestGetStats(t *testing.T) {
	t.Parallel()

	monitor := security.NewAnomalyMonitor(nil)

	// Record various events.
	monitor.RecordFailedAuth("192.168.1.1", "user-1")
	monitor.RecordPathTraversal("192.168.1.1", "../../../etc/passwd")
	monitor.RecordCommandInjection("192.168.1.2", "; rm -rf /")
	monitor.RecordPrivilegedContainerAttempt("user-2")

	stats := monitor.GetStats()
	assert.GreaterOrEqual(t, stats["path_traversal_attempts"], 1)
	assert.GreaterOrEqual(t, stats["command_injection_attempts"], 1)
	assert.GreaterOrEqual(t, stats["privileged_container_attempts"], 1)
	// One failed auth was recorded, so at least that source IP is monitored
	// (the same per-IP accounting TestMultipleIPs relies on). The previous
	// `>= 0` check could never fail.
	assert.GreaterOrEqual(t, stats["monitored_ips"], 1)
}

// TestDefaultAlertHandler tests that the default alert handler does not panic.
func TestDefaultAlertHandler(t *testing.T) {
	t.Parallel()

	alert := security.Alert{
		Timestamp: time.Now(),
		Severity:  security.SeverityHigh,
		Type:      security.AlertBruteForce,
		Message:   "Test alert",
		SourceIP:  "192.168.1.1",
		UserID:    "user-1",
	}

	assert.NotPanics(t, func() {
		security.DefaultAlertHandler(alert)
	})
}

// TestNewLoggingAlertHandler tests that the logging alert handler forwards
// alerts to the supplied log function.
func TestNewLoggingAlertHandler(t *testing.T) {
	t.Parallel()

	var logged bool
	logFunc := func(_ string, _ ...any) {
		logged = true
	}

	handler := security.NewLoggingAlertHandler(logFunc)
	require.NotNil(t, handler)

	alert := security.Alert{
		Timestamp: time.Now(),
		Severity:  security.SeverityMedium,
		Type:      security.AlertRateLimitExceeded,
		Message:   "Rate limit exceeded",
		SourceIP:  "192.168.1.1",
		UserID:    "user-1",
	}

	handler(alert)
	assert.True(t, logged)
}

// TestAlertTypes tests the string values of the severity and alert type
// constants.
func TestAlertTypes(t *testing.T) {
	t.Parallel()

	assert.Equal(t, security.AlertSeverity("low"), security.SeverityLow)
	assert.Equal(t, security.AlertSeverity("medium"), security.SeverityMedium)
	assert.Equal(t, security.AlertSeverity("high"), security.SeverityHigh)
	assert.Equal(t, security.AlertSeverity("critical"), security.SeverityCritical)

	assert.Equal(t, security.AlertType("brute_force"), security.AlertBruteForce)
	assert.Equal(t, security.AlertType("privilege_escalation"), security.AlertPrivilegeEscalation)
	assert.Equal(t, security.AlertType("path_traversal"), security.AlertPathTraversal)
	assert.Equal(t, security.AlertType("command_injection"), security.AlertCommandInjection)
	assert.Equal(t, security.AlertType("suspicious_container"), security.AlertSuspiciousContainer)
	assert.Equal(t, security.AlertType("rate_limit_exceeded"), security.AlertRateLimitExceeded)
}

// TestAlertStructure tests that Alert fields round-trip the values assigned
// to them.
func TestAlertStructure(t *testing.T) {
	t.Parallel()

	now := time.Now()
	metadata := map[string]any{"key": "value"}

	alert := security.Alert{
		Timestamp: now,
		Severity:  security.SeverityCritical,
		Type:      security.AlertCommandInjection,
		Message:   "Test message",
		SourceIP:  "192.168.1.1",
		UserID:    "user-123",
		Metadata:  metadata,
	}

	assert.Equal(t, now, alert.Timestamp)
	assert.Equal(t, security.SeverityCritical, alert.Severity)
	assert.Equal(t, security.AlertCommandInjection, alert.Type)
	assert.Equal(t, "Test message", alert.Message)
	assert.Equal(t, "192.168.1.1", alert.SourceIP)
	assert.Equal(t, "user-123", alert.UserID)
	assert.Equal(t, metadata, alert.Metadata)
}

// TestConcurrentRecordFailedAuth tests that failed auth attempts recorded
// concurrently are all accounted for. Run with -race to also detect data
// races inside the monitor.
func TestConcurrentRecordFailedAuth(t *testing.T) {
	t.Parallel()

	var alertCount atomic.Int32
	handler := func(security.Alert) {
		alertCount.Add(1)
	}

	monitor := security.NewAnomalyMonitor(handler)

	// 20 attempts exceeds the 15 that TestRecordFailedAuth shows are enough
	// to trigger a brute-force alert, so a lost update would surface here.
	const attempts = 20

	var wg sync.WaitGroup
	wg.Add(attempts)
	for i := 0; i < attempts; i++ {
		go func() {
			defer wg.Done()
			monitor.RecordFailedAuth("192.168.1.1", "user-1")
		}()
	}
	wg.Wait()

	assert.Greater(t, alertCount.Load(), int32(0), "Should alert once concurrent attempts pass the threshold")

	stats := monitor.GetStats()
	assert.GreaterOrEqual(t, stats["monitored_ips"], 1)
}

// TestMultipleIPs tests that failed auth attempts from different IPs are
// tracked independently.
func TestMultipleIPs(t *testing.T) {
	t.Parallel()

	monitor := security.NewAnomalyMonitor(nil)

	// Record from different IPs.
	monitor.RecordFailedAuth("192.168.1.1", "user-1")
	monitor.RecordFailedAuth("192.168.1.2", "user-2")
	monitor.RecordFailedAuth("192.168.1.3", "user-3")

	stats := monitor.GetStats()
	assert.GreaterOrEqual(t, stats["monitored_ips"], 3)
}