// Package middleware provides request/response validation using OpenAPI spec package middleware import ( "bytes" "encoding/json" "io" "net/http" "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" "github.com/getkin/kin-openapi/routers" "github.com/getkin/kin-openapi/routers/gorillamux" ) // ValidationMiddleware validates HTTP requests against OpenAPI spec type ValidationMiddleware struct { router routers.Router } // NewValidationMiddleware creates a new validation middleware from OpenAPI spec func NewValidationMiddleware(specPath string) (*ValidationMiddleware, error) { // Load OpenAPI spec loader := openapi3.NewLoader() doc, err := loader.LoadFromFile(specPath) if err != nil { return nil, err } // Validate the spec itself if err := doc.Validate(loader.Context); err != nil { return nil, err } // Create router for path matching router, err := gorillamux.NewRouter(doc) if err != nil { return nil, err } return &ValidationMiddleware{router: router}, nil } // ValidateRequest validates an incoming HTTP request func (v *ValidationMiddleware) ValidateRequest(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip validation for health endpoints if r.URL.Path == "/health/ok" || r.URL.Path == "/health" || r.URL.Path == "/metrics" { next.ServeHTTP(w, r) return } // Find the route route, pathParams, err := v.router.FindRoute(r) if err != nil { // Route not in spec - allow through (might be unregistered endpoint) next.ServeHTTP(w, r) return } // Read and restore body for validation var bodyBytes []byte if r.Body != nil && r.Body != http.NoBody { bodyBytes, _ = io.ReadAll(r.Body) r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } // Validate request - body is read from r.Body automatically requestValidationInput := &openapi3filter.RequestValidationInput{ Request: r, PathParams: pathParams, Route: route, } if err := openapi3filter.ValidateRequest(r.Context(), requestValidationInput); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(map[string]any{ "error": "validation failed", "message": err.Error(), }) return } next.ServeHTTP(w, r) }) }