fetch_ml/internal/security/monitor.go
Jeremie Fraeys 4cdb68907e
refactor(utilities): update supporting modules for scheduler integration
Update utility modules:
- File utilities with secure file operations
- Environment pool with resource tracking
- Error types with scheduler error categories
- Logging with audit context support
- Network/SSH with connection pooling
- Privacy/PII handling with tenant boundaries
- Resource manager with scheduler allocation
- Security monitor with audit integration
- Tracking plugins (MLflow, TensorBoard) with auth
- Crypto signing with tenant keys
- Database init with multi-user support
2026-02-26 12:07:15 -05:00

301 lines
8.2 KiB
Go

// Package security provides security monitoring and anomaly detection.
package security
import (
	"fmt"
	"os"
	"sync"
	"time"
)
// AlertSeverity ranks how urgent a security alert is. It is serialized as
// one of the lowercase strings "low", "medium", "high" or "critical".
type AlertSeverity string

// Severity levels, listed from least to most severe.
const (
	// SeverityLow is the lowest severity level.
	SeverityLow = AlertSeverity("low")
	// SeverityMedium is one step above SeverityLow.
	SeverityMedium = AlertSeverity("medium")
	// SeverityHigh is used in this file for brute-force and path-traversal alerts.
	SeverityHigh = AlertSeverity("high")
	// SeverityCritical is used in this file for privilege-escalation and
	// command-injection alerts.
	SeverityCritical = AlertSeverity("critical")
)
// AlertType names the category of suspicious activity an Alert reports.
// The string values appear in serialized alerts, so treat them as stable
// identifiers.
type AlertType string

// Alert categories.
const (
	// AlertBruteForce is raised by RecordFailedAuth once an IP reaches the
	// failed-authentication threshold.
	AlertBruteForce = AlertType("brute_force")
	// AlertPrivilegeEscalation is raised by RecordPrivilegedContainerAttempt.
	AlertPrivilegeEscalation = AlertType("privilege_escalation")
	// AlertPathTraversal is raised by RecordPathTraversal.
	AlertPathTraversal = AlertType("path_traversal")
	// AlertCommandInjection is raised by RecordCommandInjection.
	AlertCommandInjection = AlertType("command_injection")
	// AlertSuspiciousContainer is not emitted by any method in this file.
	AlertSuspiciousContainer = AlertType("suspicious_container")
	// AlertRateLimitExceeded is not emitted by any method in this file.
	AlertRateLimitExceeded = AlertType("rate_limit_exceeded")
)
// Alert describes a single security event, as produced by AnomalyMonitor.
// The JSON tags define its serialized form; optional fields tagged
// omitempty are dropped when empty.
type Alert struct {
	// Timestamp is when the alert was generated.
	Timestamp time.Time `json:"timestamp"`

	// Metadata carries alert-specific details (counters, offending input).
	// Omitted from JSON when nil or empty.
	Metadata map[string]any `json:"metadata,omitempty"`

	// Severity ranks the alert's urgency.
	Severity AlertSeverity `json:"severity"`

	// Type identifies the category of suspicious activity.
	Type AlertType `json:"type"`

	// Message is a short human-readable summary.
	Message string `json:"message"`

	// SourceIP is the originating address, when known.
	SourceIP string `json:"source_ip,omitempty"`

	// UserID identifies the acting user, when known.
	UserID string `json:"user_id,omitempty"`
}
// AlertHandler receives each Alert an AnomalyMonitor generates. The monitor
// calls it synchronously on the goroutine that recorded the event, so a slow
// handler slows down that caller.
type AlertHandler func(Alert)
// SlidingWindow remembers event timestamps so callers can ask how many
// events happened within the most recent window duration. Create one with
// NewSlidingWindow; Add and Count lock mu, so they may be called
// concurrently.
type SlidingWindow struct {
	// events holds recorded timestamps. Add drops entries that fall outside
	// the window relative to the time passed to it.
	events []time.Time

	// window is the lookback duration.
	window time.Duration

	// mu guards events.
	mu sync.RWMutex
}
// NewSlidingWindow returns an empty SlidingWindow whose lookback period is
// the given window duration.
func NewSlidingWindow(window time.Duration) *SlidingWindow {
	sw := new(SlidingWindow)
	sw.window = window
	sw.events = make([]time.Time, 0)
	return sw
}
// Add records an event at time t. In the same step it discards every stored
// event that is not strictly after t minus the window duration. The slice is
// rebuilt rather than filtered in place.
func (w *SlidingWindow) Add(t time.Time) {
	w.mu.Lock()
	defer w.mu.Unlock()

	oldest := t.Add(-w.window)
	kept := make([]time.Time, 0, len(w.events)+1)
	for i := range w.events {
		if !w.events[i].After(oldest) {
			continue // expired relative to t
		}
		kept = append(kept, w.events[i])
	}
	w.events = append(kept, t)
}
// Count reports how many stored events are strictly newer than the current
// time minus the window duration.
//
// Count holds only the read lock and never modifies the slice. Expired
// events are therefore not discarded here; they are pruned by the next
// call to Add.
func (w *SlidingWindow) Count() int {
	w.mu.RLock()
	defer w.mu.RUnlock()

	since := time.Now().Add(-w.window)
	n := 0
	for i := range w.events {
		if w.events[i].After(since) {
			n++
		}
	}
	return n
}
// AnomalyMonitor counts security-relevant events and reports suspicious
// ones to an AlertHandler. The events are failed authentication,
// privileged-container requests, and path-traversal and command-injection
// attempts. Create one with NewAnomalyMonitor; its methods lock mu
// internally, so they may be called from multiple goroutines.
type AnomalyMonitor struct {
	// lastPrivilegedAlert is when the last privilege-escalation alert was
	// emitted; used to rate-limit those alerts. Guarded by mu.
	lastPrivilegedAlert time.Time

	// failedAuthByIP maps a source IP string to its failed-auth window.
	// The map is guarded by mu; each window has its own lock.
	failedAuthByIP map[string]*SlidingWindow

	// alertHandler receives alerts; when nil, alerts are dropped.
	alertHandler AlertHandler

	// Running totals of recorded attempts. Guarded by mu.
	privilegedContainerAttempts int
	pathTraversalAttempts       int
	commandInjectionAttempts    int

	// Tuning parameters set by NewAnomalyMonitor; no method in this file
	// modifies them.
	bruteForceThreshold     int
	bruteForceWindow        time.Duration
	privilegedAlertInterval time.Duration

	mu sync.RWMutex
}
// NewAnomalyMonitor returns a monitor that delivers alerts to alertHandler.
// A nil handler is allowed; alerts are then discarded.
//
// Defaults: a brute-force alert once an IP reaches 10 failed
// authentications within 5 minutes, and at most one privilege-escalation
// alert per minute.
func NewAnomalyMonitor(alertHandler AlertHandler) *AnomalyMonitor {
	const (
		defaultBruteForceThreshold     = 10
		defaultBruteForceWindow        = 5 * time.Minute
		defaultPrivilegedAlertInterval = time.Minute
	)

	m := &AnomalyMonitor{
		failedAuthByIP: make(map[string]*SlidingWindow),
		alertHandler:   alertHandler,
	}
	m.bruteForceThreshold = defaultBruteForceThreshold
	m.bruteForceWindow = defaultBruteForceWindow
	m.privilegedAlertInterval = defaultPrivilegedAlertInterval
	return m
}
// RecordFailedAuth notes one failed authentication from ip (by userID, which
// may be empty). When the ip's count within bruteForceWindow reaches
// bruteForceThreshold, a SeverityHigh AlertBruteForce alert is emitted. This
// happens again on every further failure while the count stays at or above
// the threshold; there is no rate limiting.
//
// NOTE(review): entries in failedAuthByIP are never removed in this file, so
// the map grows with every distinct ip value seen. Consider evicting idle
// windows.
func (m *AnomalyMonitor) RecordFailedAuth(ip, userID string) {
	// Look up or lazily create this ip's window under the monitor lock.
	win := func() *SlidingWindow {
		m.mu.Lock()
		defer m.mu.Unlock()
		if existing, ok := m.failedAuthByIP[ip]; ok {
			return existing
		}
		created := NewSlidingWindow(m.bruteForceWindow)
		m.failedAuthByIP[ip] = created
		return created
	}()

	win.Add(time.Now())
	recent := win.Count()
	if recent < m.bruteForceThreshold {
		return
	}

	m.alert(Alert{
		Severity:  SeverityHigh,
		Type:      AlertBruteForce,
		Message:   fmt.Sprintf("%d+ failed auth attempts from %s", m.bruteForceThreshold, ip),
		Timestamp: time.Now(),
		SourceIP:  ip,
		UserID:    userID,
		Metadata: map[string]any{
			"count":          recent,
			"threshold":      m.bruteForceThreshold,
			"window_seconds": m.bruteForceWindow.Seconds(),
		},
	})
}
// RecordPrivilegedContainerAttempt records a blocked request by userID to
// create a privileged container. Every attempt is counted. A SeverityCritical
// AlertPrivilegeEscalation alert is emitted only if more than
// privilegedAlertInterval has passed since the previous one. The first call
// always alerts, because lastPrivilegedAlert starts at the zero time.
func (m *AnomalyMonitor) RecordPrivilegedContainerAttempt(userID string) {
	m.mu.Lock()
	m.privilegedContainerAttempts++
	// Snapshot the total while holding the lock. Reading the field after
	// Unlock races with concurrent callers incrementing it.
	total := m.privilegedContainerAttempts
	now := time.Now()
	shouldAlert := now.Sub(m.lastPrivilegedAlert) > m.privilegedAlertInterval
	if shouldAlert {
		m.lastPrivilegedAlert = now
	}
	m.mu.Unlock()

	if !shouldAlert {
		return
	}

	// The handler runs outside the lock, so it may call back into m
	// (for example GetStats) without deadlocking.
	m.alert(Alert{
		Severity:  SeverityCritical,
		Type:      AlertPrivilegeEscalation,
		Message:   "Attempted to create privileged container",
		Timestamp: now,
		UserID:    userID,
		Metadata: map[string]any{
			"total_attempts": total,
		},
	})
}
// RecordPathTraversal records a path traversal attempt from ip. It emits a
// SeverityHigh AlertPathTraversal alert on every call, with no rate limiting.
//
// The rejected path is attached to the alert metadata verbatim. It is
// presumably attacker-supplied, so handlers that print metadata should
// escape it.
func (m *AnomalyMonitor) RecordPathTraversal(ip, path string) {
	m.mu.Lock()
	m.pathTraversalAttempts++
	// Snapshot under the lock; reading the counter after Unlock is a data
	// race with concurrent callers.
	total := m.pathTraversalAttempts
	m.mu.Unlock()

	m.alert(Alert{
		Severity:  SeverityHigh,
		Type:      AlertPathTraversal,
		Message:   "Path traversal attempt detected",
		Timestamp: time.Now(),
		SourceIP:  ip,
		Metadata: map[string]any{
			"path":           path,
			"total_attempts": total,
		},
	})
}
// RecordCommandInjection records a command injection attempt from ip. It
// emits a SeverityCritical AlertCommandInjection alert on every call, with
// no rate limiting.
//
// The offending input is attached to the alert metadata verbatim. It is
// presumably attacker-supplied, so handlers that print metadata should
// escape or truncate it.
func (m *AnomalyMonitor) RecordCommandInjection(ip, input string) {
	m.mu.Lock()
	m.commandInjectionAttempts++
	// Snapshot under the lock; reading the counter after Unlock is a data
	// race with concurrent callers.
	total := m.commandInjectionAttempts
	m.mu.Unlock()

	m.alert(Alert{
		Severity:  SeverityCritical,
		Type:      AlertCommandInjection,
		Message:   "Command injection attempt detected",
		Timestamp: time.Now(),
		SourceIP:  ip,
		Metadata: map[string]any{
			"input":          input,
			"total_attempts": total,
		},
	})
}
// GetStats returns a snapshot of the monitor's counters. "monitored_ips" is
// the number of distinct IP keys that have a failed-auth window. The
// returned map is freshly allocated and owned by the caller.
func (m *AnomalyMonitor) GetStats() map[string]int {
	stats := make(map[string]int, 4)

	m.mu.RLock()
	stats["privileged_container_attempts"] = m.privilegedContainerAttempts
	stats["path_traversal_attempts"] = m.pathTraversalAttempts
	stats["command_injection_attempts"] = m.commandInjectionAttempts
	stats["monitored_ips"] = len(m.failedAuthByIP)
	m.mu.RUnlock()

	return stats
}
// alert forwards the given alert to the configured handler on the caller's
// goroutine. It does nothing when no handler was supplied.
func (m *AnomalyMonitor) alert(alert Alert) {
	handler := m.alertHandler
	if handler == nil {
		return
	}
	handler(alert)
}
// DefaultAlertHandler writes a one-line summary of alert to standard error:
//
//	[SECURITY ALERT] <RFC3339 timestamp> | <severity> | <type> | <message>
//
// SourceIP, UserID and Metadata are not included. Write errors are ignored,
// since a handler has no way to report them.
func DefaultAlertHandler(alert Alert) {
	// Alerts go to stderr, as documented, so they do not interleave with
	// data the program writes to stdout.
	_, _ = fmt.Fprintf(os.Stderr, "[SECURITY ALERT] %s | %s | %s | %s\n",
		alert.Timestamp.Format(time.RFC3339),
		alert.Severity,
		alert.Type,
		alert.Message,
	)
}
// LoggingAlertHandler pairs alerts with a structured logging function.
// NOTE(review): NewLoggingAlertHandler returns a plain AlertHandler func and
// does not use this type; confirm whether anything else in the package does.
type LoggingAlertHandler struct {
	logFunc func(msg string, args ...any)
}
// NewLoggingAlertHandler adapts a structured logging function (slog-style:
// a message followed by alternating key/value arguments) into an
// AlertHandler. Each alert is logged under the message "security_alert" with
// the keys severity, type, message, source_ip and user_id. Timestamp and
// Metadata are not forwarded.
func NewLoggingAlertHandler(logFunc func(string, ...any)) AlertHandler {
	handler := func(a Alert) {
		fields := []any{
			"severity", a.Severity,
			"type", a.Type,
			"message", a.Message,
			"source_ip", a.SourceIP,
			"user_id", a.UserID,
		}
		logFunc("security_alert", fields...)
	}
	return handler
}
// Integration Example:
//
// To integrate the anomaly monitor with your application:
//
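// 1. Create a monitor with an alert handler. DefaultAlertHandler prints a
//    one-line summary of each alert; NewLoggingAlertHandler adapts a
//    structured logging function instead:
//     monitor := security.NewAnomalyMonitor(security.DefaultAlertHandler)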
//
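// 2. Wire into authentication middleware. Key failed attempts by client host
//    only: r.RemoteAddr is "IP:port", and the port usually differs per
//    connection, which would give every attempt its own window and hide
//    brute-force activity.
//     func authMiddleware(next http.Handler) http.Handler {
//         return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//             key := r.Header.Get("X-API-Key")
//             if _, err := validateAPIKey(key); err != nil {
//                 host, _, splitErr := net.SplitHostPort(r.RemoteAddr)
//                 if splitErr != nil {
//                     host = r.RemoteAddr
//                 }
//                 monitor.RecordFailedAuth(host, "")
//                 http.Error(w, "Unauthorized", http.StatusUnauthorized)
//                 return
//             }
//             next.ServeHTTP(w, r)
//         })
//     }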
//
// 3. Wire into container creation:
// func createContainer(config ContainerConfig) error {
// if config.Privileged {
// monitor.RecordPrivilegedContainerAttempt(userID)
// return fmt.Errorf("privileged containers not allowed")
// }
// // ... create container
// }
//
// 4. Wire into input validation (ip is the client's address, host only,
//    without the port):
//     func validateJobName(ip, name string) error {
//         if strings.Contains(name, "..") {
//             monitor.RecordPathTraversal(ip, name)
//             return fmt.Errorf("invalid job name")
//         }
//         // ... continue validation
//     }
//
// 5. Periodically check stats:
// stats := monitor.GetStats()
// log.Printf("Security stats: %+v", stats)