fetch_ml/internal/api/responses/errors.go
Jeremie Fraeys 7e5ceec069
feat(api): add groups and tokens handlers, refactor routes
Add new API endpoints and clean up handler interfaces:

- groups/handlers.go: New lab group management API
  * CRUD operations for lab groups
  * Member management with role assignment (admin/member/viewer)
  * Group listing and membership queries

- tokens/handlers.go: Token generation and validation endpoints
  * Create access tokens for public task sharing
  * Validate tokens for secure access
  * Token revocation and cleanup

- routes.go: Refactor handler registration
  * Integrate groups handler into WebSocket routes
  * Remove nil parameters from all handler constructors
  * Cleaner dependency injection pattern

- Handler interface cleanup across all modules:
  * jobs/handlers.go: Remove unused nil privacyEnforcer parameter
  * jupyter/handlers.go: Streamline initialization
  * scheduler/handlers.go: Consistent constructor signature
  * ws/handler.go: Add groups handler to dependencies
2026-03-08 12:51:25 -04:00

180 lines
5.3 KiB
Go

// Package responses provides structured API response types with security-conscious error handling.
package responses
import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"net/http"
	"regexp"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
)
// ErrorResponse is the JSON body returned to clients when a request fails.
// It exposes only sanitized text; the unsanitized error stays in server-side
// logs, and TraceID lets support staff find the matching log entry.
type ErrorResponse struct {
	// Error is the sanitized, human-readable message shown to the client.
	Error string `json:"error"`

	// Code is the machine-readable error code (one of the ErrCode* constants).
	Code string `json:"code"`

	// TraceID correlates this response with the server-side log entry.
	TraceID string `json:"trace_id"`
}
// Machine-readable error codes reported in ErrorResponse.Code.
const (
	// Codes for failures attributable to the request.
	ErrCodeBadRequest   = "BAD_REQUEST"
	ErrCodeUnauthorized = "UNAUTHORIZED"
	ErrCodeForbidden    = "FORBIDDEN"
	ErrCodeNotFound     = "NOT_FOUND"
	ErrCodeConflict     = "CONFLICT"
	ErrCodeValidation   = "VALIDATION_ERROR"
	ErrCodeRateLimited  = "RATE_LIMITED"

	// Codes for failures on the server side.
	ErrCodeInternal           = "INTERNAL_ERROR"
	ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE"
)
// statusToCode maps HTTP status codes to the machine-readable codes reported
// in ErrorResponse.Code. Statuses not listed here are resolved by the
// fallback logic in errorCodeFromStatus.
var statusToCode = map[int]string{
	http.StatusBadRequest:          ErrCodeBadRequest,
	http.StatusUnauthorized:        ErrCodeUnauthorized,
	http.StatusForbidden:           ErrCodeForbidden,
	http.StatusNotFound:            ErrCodeNotFound,
	http.StatusConflict:            ErrCodeConflict,
	http.StatusUnprocessableEntity: ErrCodeValidation,
	http.StatusTooManyRequests:     ErrCodeRateLimited,
	http.StatusInternalServerError: ErrCodeInternal,
	http.StatusServiceUnavailable:  ErrCodeServiceUnavailable,
}
// pathPattern matches a '/' and every non-whitespace character following it,
// so filesystem paths such as /var/lib/app/db.sqlite are hidden from clients.
// NOTE(review): it also matches any other slash-containing token (URLs become
// "http:[path]", "and/or" becomes "and[path]"); this is over-eager but errs on
// the side of not leaking.
var pathPattern = regexp.MustCompile(`/[^\s]*`)

// sensitiveKeywords lists lowercase substrings that, when present anywhere in
// an error message, cause the whole message to be withheld from clients.
var sensitiveKeywords = []string{
	"password",
	"secret",
	"token",
	"key",
	"credential",
	"auth",
}
// WriteError writes a sanitized JSON error response built from err.
//
// The full, unsanitized error text is logged (when logger is non-nil) together
// with the trace ID, method, path, status and client IP. The client receives
// only the sanitized message, the machine-readable code for status, and the
// trace ID. A nil err is tolerated and produces the generic message instead of
// panicking.
func WriteError(w http.ResponseWriter, r *http.Request, status int, err error, logger *logging.Logger) {
	var message string
	if err != nil {
		message = err.Error()
	}
	writeSanitizedError(w, r, status, message, logger)
}

// WriteErrorMessage is like WriteError but takes the internal error text as a
// string. The message is logged in full and sanitized before it reaches the
// client.
func WriteErrorMessage(w http.ResponseWriter, r *http.Request, status int, message string, logger *logging.Logger) {
	writeSanitizedError(w, r, status, message, logger)
}

// writeSanitizedError is the shared implementation of WriteError and
// WriteErrorMessage: it resolves a trace ID, logs the raw message, and writes
// the sanitized JSON body with the given status.
func writeSanitizedError(w http.ResponseWriter, r *http.Request, status int, message string, logger *logging.Logger) {
	// Reuse the request's trace ID so the response matches existing log lines;
	// only mint a new one when the context does not carry it.
	traceID := logging.TraceIDFromContext(r.Context())
	if traceID == "" {
		traceID = generateTraceID()
	}

	// Log the unsanitized details internally; only the sanitized form below
	// is sent to the client.
	if logger != nil {
		logger.Error("request failed",
			"trace_id", traceID,
			"method", r.Method,
			"path", r.URL.Path,
			"status", status,
			"error", message,
			"client_ip", getClientIP(r),
		)
	}

	resp := ErrorResponse{
		Error:   sanitizeError(message),
		Code:    errorCodeFromStatus(status),
		TraceID: traceID,
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line and headers are already sent, so an encoding/write
	// failure cannot be reported to the client; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(resp)
}
// internalErrorPattern matches the phrase "internal error" in any letter case,
// so "Internal error" and "INTERNAL ERROR" are rewritten as well.
var internalErrorPattern = regexp.MustCompile(`(?i)internal error`)

// sanitizeError converts an internal error message into text that is safe to
// return to API clients. Rules, applied in order:
//
//   - An empty or whitespace-only message becomes "An error occurred".
//   - Every slash-prefixed token (matched by pathPattern) becomes "[path]".
//   - If the result contains any entry of sensitiveKeywords (case-insensitive
//     substring match), the whole message becomes "An error occurred".
//   - "internal error", in any letter case, becomes "an error occurred".
//   - A result longer than maxClientErrorLen bytes is cut at a UTF-8 rune
//     boundary and suffixed with "...".
func sanitizeError(msg string) string {
	const (
		genericMessage = "An error occurred"
		// maxClientErrorLen bounds the client-facing message size in bytes,
		// keeping error bodies small and limiting how much internal context
		// can leak through a single message.
		maxClientErrorLen = 200
	)

	if strings.TrimSpace(msg) == "" {
		return genericMessage
	}

	// Hide file paths before the keyword scan so a path segment alone does
	// not force the generic message.
	msg = pathPattern.ReplaceAllString(msg, "[path]")

	// Any mention of a sensitive keyword withholds the whole message, since
	// its value may appear anywhere in the text.
	lowerMsg := strings.ToLower(msg)
	for _, keyword := range sensitiveKeywords {
		if strings.Contains(lowerMsg, keyword) {
			return genericMessage
		}
	}

	msg = internalErrorPattern.ReplaceAllString(msg, "an error occurred")

	if len(msg) > maxClientErrorLen {
		// Cut at the last rune start at or before the limit. Slicing at a raw
		// byte offset could split a multi-byte UTF-8 character, producing
		// invalid UTF-8 that the JSON encoder would replace with U+FFFD.
		cut := 0
		for i := range msg {
			if i > maxClientErrorLen {
				break
			}
			cut = i
		}
		msg = msg[:cut] + "..."
	}
	return msg
}
// errorCodeFromStatus returns the machine-readable error code for an HTTP
// status. Statuses present in statusToCode use their mapped code. Any other
// 4xx status (e.g. 405, 408, 413, 415) falls back to ErrCodeBadRequest,
// because the request was at fault. Everything else falls back to
// ErrCodeInternal.
func errorCodeFromStatus(status int) string {
	if code, ok := statusToCode[status]; ok {
		return code
	}
	if status >= http.StatusBadRequest && status < http.StatusInternalServerError {
		return ErrCodeBadRequest
	}
	return ErrCodeInternal
}
// getClientIP reports the client address to record in logs.
//
// Precedence: the first comma-separated entry of X-Forwarded-For, then
// X-Real-IP (both with surrounding whitespace trimmed), then r.RemoteAddr
// as-is. Unlike the header values, RemoteAddr may include a port.
//
// NOTE(review): both headers are read straight from the request with no
// trusted-proxy check, so a client can spoof them unless a proxy in front of
// this service overwrites them. Treat the result as diagnostic only.
func getClientIP(r *http.Request) string {
	forwarded := r.Header.Get("X-Forwarded-For")
	if forwarded != "" {
		first := strings.SplitN(forwarded, ",", 2)[0]
		return strings.TrimSpace(first)
	}

	realIP := r.Header.Get("X-Real-IP")
	if realIP == "" {
		return r.RemoteAddr
	}
	return strings.TrimSpace(realIP)
}
// generateTraceID returns a fresh trace ID for requests whose context does not
// already carry one.
//
// The ID is 16 bytes from crypto/rand, hex-encoded as 32 lowercase
// characters. Unlike a nanosecond timestamp, it does not collide when
// concurrent requests fail within the same clock tick or on platforms with
// coarse timer resolution. If the system random source fails, the function
// falls back to the decimal UnixNano timestamp, so an ID is always returned.
func generateTraceID() string {
	var buf [16]byte
	if _, err := rand.Read(buf[:]); err != nil {
		// Best effort: a possibly non-unique ID is better than none.
		return fmt.Sprintf("%d", time.Now().UnixNano())
	}
	return hex.EncodeToString(buf[:])
}