diff --git a/internal/security/monitor_test.go b/internal/security/monitor_test.go index 18e6165..00a4b0f 100644 --- a/internal/security/monitor_test.go +++ b/internal/security/monitor_test.go @@ -2,6 +2,7 @@ package security_test import ( "sync" + "sync/atomic" "testing" "time" @@ -70,9 +71,9 @@ func TestNewAnomalyMonitor(t *testing.T) { func TestRecordFailedAuth(t *testing.T) { t.Parallel() - var alertCount int + var alertCount atomic.Int32 handler := func(a security.Alert) { - alertCount++ + alertCount.Add(1) } monitor := security.NewAnomalyMonitor(handler) @@ -81,28 +82,28 @@ func TestRecordFailedAuth(t *testing.T) { for i := 0; i < 5; i++ { monitor.RecordFailedAuth("192.168.1.1", "user-1") } - assert.Equal(t, 0, alertCount, "Should not alert below threshold") + assert.Equal(t, int32(0), alertCount.Load(), "Should not alert below threshold") // Record more attempts to reach threshold for i := 0; i < 10; i++ { monitor.RecordFailedAuth("192.168.1.1", "user-1") } - assert.Greater(t, alertCount, 0, "Should alert after threshold") + assert.Greater(t, alertCount.Load(), int32(0), "Should alert after threshold") } // TestRecordPrivilegedContainerAttempt tests recording privileged container attempts func TestRecordPrivilegedContainerAttempt(t *testing.T) { t.Parallel() - var alertCount int + var alertCount atomic.Int32 handler := func(a security.Alert) { - alertCount++ + alertCount.Add(1) } monitor := security.NewAnomalyMonitor(handler) monitor.RecordPrivilegedContainerAttempt("user-1") - assert.GreaterOrEqual(t, alertCount, 0) + assert.GreaterOrEqual(t, alertCount.Load(), int32(0)) stats := monitor.GetStats() assert.GreaterOrEqual(t, stats["privileged_container_attempts"], 1) @@ -112,18 +113,19 @@ func TestRecordPrivilegedContainerAttempt(t *testing.T) { func TestRecordPathTraversal(t *testing.T) { t.Parallel() - var lastAlert security.Alert + var lastAlert atomic.Value handler := func(a security.Alert) { - lastAlert = a + lastAlert.Store(a) } monitor := security.NewAnomalyMonitor(handler) monitor.RecordPathTraversal("192.168.1.1", "../../../etc/passwd") - assert.Equal(t, security.AlertPathTraversal, lastAlert.Type) - assert.Equal(t, security.SeverityHigh, lastAlert.Severity) - assert.Equal(t, "192.168.1.1", lastAlert.SourceIP) + alert := lastAlert.Load().(security.Alert) + assert.Equal(t, security.AlertPathTraversal, alert.Type) + assert.Equal(t, security.SeverityHigh, alert.Severity) + assert.Equal(t, "192.168.1.1", alert.SourceIP) stats := monitor.GetStats() assert.GreaterOrEqual(t, stats["path_traversal_attempts"], 1) @@ -133,18 +135,19 @@ func TestRecordPathTraversal(t *testing.T) { func TestRecordCommandInjection(t *testing.T) { t.Parallel() - var lastAlert security.Alert + var lastAlert atomic.Value handler := func(a security.Alert) { - lastAlert = a + lastAlert.Store(a) } monitor := security.NewAnomalyMonitor(handler) monitor.RecordCommandInjection("192.168.1.1", "; rm -rf /") - assert.Equal(t, security.AlertCommandInjection, lastAlert.Type) - assert.Equal(t, security.SeverityCritical, lastAlert.Severity) - assert.Equal(t, "192.168.1.1", lastAlert.SourceIP) + alert := lastAlert.Load().(security.Alert) + assert.Equal(t, security.AlertCommandInjection, alert.Type) + assert.Equal(t, security.SeverityCritical, alert.Severity) + assert.Equal(t, "192.168.1.1", alert.SourceIP) stats := monitor.GetStats() assert.GreaterOrEqual(t, stats["command_injection_attempts"], 1) @@ -259,9 +262,9 @@ func TestAlertStructure(t *testing.T) { func TestConcurrentRecordFailedAuth(t *testing.T) { t.Parallel() - var alertCount int + var alertCount atomic.Int32 handler := func(a security.Alert) { - alertCount++ + alertCount.Add(1) } monitor := security.NewAnomalyMonitor(handler)