// Package tenant provides middleware for cross-tenant access prevention. package tenant import ( "context" "fmt" "net/http" "strings" "time" "github.com/jfraeys/fetch_ml/internal/logging" ) // Context key for storing tenant ID type contextKey string const ( // ContextTenantID is the key for tenant ID in context ContextTenantID contextKey = "tenant_id" // ContextUserID is the key for user ID in context ContextUserID contextKey = "user_id" ) // Middleware provides HTTP middleware for tenant isolation type Middleware struct { tenantManager *Manager logger *logging.Logger } // NewMiddleware creates a new tenant middleware func NewMiddleware(tm *Manager, logger *logging.Logger) *Middleware { return &Middleware{ tenantManager: tm, logger: logger, } } // ExtractTenantID extracts tenant ID from request headers or context func ExtractTenantID(r *http.Request) string { // Check header first tenantID := r.Header.Get("X-Tenant-ID") if tenantID != "" { return tenantID } // Check query parameter tenantID = r.URL.Query().Get("tenant_id") if tenantID != "" { return tenantID } // Check context (set by upstream middleware) if ctxTenantID := r.Context().Value(ContextTenantID); ctxTenantID != nil { if id, ok := ctxTenantID.(string); ok { return id } } return "" } // Handler wraps an HTTP handler with tenant validation func (m *Middleware) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tenantID := ExtractTenantID(r) if tenantID == "" { m.logger.Warn("request without tenant ID", "path", r.URL.Path, "remote_addr", r.RemoteAddr, ) http.Error(w, "Tenant ID required", http.StatusBadRequest) return } // Validate tenant exists and is active tenant, err := m.tenantManager.GetTenant(tenantID) if err != nil { m.logger.Warn("invalid tenant ID", "tenant_id", tenantID, "path", r.URL.Path, "error", err, ) http.Error(w, "Invalid tenant", http.StatusForbidden) return } // Add tenant to context ctx := context.WithValue(r.Context(), ContextTenantID, tenantID) ctx = context.WithValue(ctx, ContextUserID, r.Header.Get("X-User-ID")) // Log access m.logger.Debug("tenant request", "tenant_id", tenantID, "tenant_name", tenant.Name, "path", r.URL.Path, "method", r.Method, ) // Audit log if err := m.tenantManager.auditLog.LogEvent(ctx, AuditEvent{ Type: AuditResourceAccess, TenantID: tenantID, Timestamp: time.Now().UTC(), Success: true, Details: map[string]any{ "path": r.URL.Path, "method": r.Method, }, IPAddress: extractIP(r.RemoteAddr), }); err != nil { m.logger.Warn("failed to log audit event", "error", err) } next.ServeHTTP(w, r.WithContext(ctx)) }) } // ResourceAccessChecker validates access to resources across tenants type ResourceAccessChecker struct { tenantManager *Manager logger *logging.Logger } // NewResourceAccessChecker creates a new resource access checker func NewResourceAccessChecker(tm *Manager, logger *logging.Logger) *ResourceAccessChecker { return &ResourceAccessChecker{ tenantManager: tm, logger: logger, } } // CheckAccess validates if a tenant can access a specific resource func (rac *ResourceAccessChecker) CheckAccess(ctx context.Context, resourceTenantID string) error { requestingTenantID := GetTenantIDFromContext(ctx) if requestingTenantID == "" { return fmt.Errorf("no tenant ID in context") } // Same tenant - always allowed if requestingTenantID == resourceTenantID { return nil } // Cross-tenant access - deny by default rac.logger.Warn("cross-tenant access denied", "requesting_tenant", requestingTenantID, "resource_tenant", resourceTenantID, ) // Audit the denial userID := GetUserIDFromContext(ctx) if err := rac.tenantManager.auditLog.LogEvent(ctx, AuditEvent{ Type: AuditCrossTenantDeny, TenantID: requestingTenantID, UserID: userID, Timestamp: time.Now().UTC(), Success: false, Details: map[string]any{ "target_tenant": resourceTenantID, "reason": "cross-tenant access not permitted", }, }); err != nil { rac.logger.Warn("failed to log audit event", "error", err) } return fmt.Errorf("cross-tenant access denied: cannot access resources belonging to tenant %s", resourceTenantID) } // CheckResourceOwnership validates that a resource belongs to the requesting tenant func (rac *ResourceAccessChecker) CheckResourceOwnership(ctx context.Context, resourceID, resourceTenantID string) error { return rac.CheckAccess(ctx, resourceTenantID) } // GetTenantIDFromContext extracts tenant ID from context func GetTenantIDFromContext(ctx context.Context) string { if tenantID := ctx.Value(ContextTenantID); tenantID != nil { if id, ok := tenantID.(string); ok { return id } } return "" } // GetUserIDFromContext extracts user ID from context func GetUserIDFromContext(ctx context.Context) string { if userID := ctx.Value(ContextUserID); userID != nil { if id, ok := userID.(string); ok { return id } } return "" } // WithTenantContext creates a context with tenant ID for background operations func WithTenantContext(parent context.Context, tenantID, userID string) context.Context { ctx := context.WithValue(parent, ContextTenantID, tenantID) if userID != "" { ctx = context.WithValue(ctx, ContextUserID, userID) } return ctx } // IsolatedPath returns a tenant-isolated path for storing resources func IsolatedPath(basePath, tenantID, resourceType, resourceID string) string { return fmt.Sprintf("%s/%s/%s/%s", basePath, tenantID, resourceType, resourceID) } // ValidateResourcePath ensures a path is within the tenant's isolated workspace func ValidateResourcePath(basePath, tenantID, requestedPath string) error { expectedPrefix := fmt.Sprintf("%s/%s/", basePath, tenantID) if !strings.HasPrefix(requestedPath, expectedPrefix) { return fmt.Errorf("path %s is outside tenant %s workspace", requestedPath, tenantID) } return nil } // extractIP extracts the IP address from RemoteAddr func extractIP(remoteAddr string) string { // Handle "IP:port" format if idx := strings.LastIndex(remoteAddr, ":"); idx != -1 { return remoteAddr[:idx] } return remoteAddr }