fetch_ml/internal/worker/tenant/middleware.go
Jeremie Fraeys 0b5e99f720
refactor(scheduler,worker): improve service management and GPU detection
Scheduler enhancements:
- auth.go: Group membership validation in authentication
- hub.go: Task distribution with group affinity
- port_allocator.go: Dynamic port allocation with conflict resolution
- scheduler_conn.go: Connection pooling and retry logic
- service_manager.go: Lifecycle management for scheduler services
- service_templates.go: Template-based service configuration
- state.go: Persistent state management with recovery

Worker improvements:
- config.go: Extended configuration for task visibility rules
- execution/setup.go: Sandboxed execution environment setup
- executor/container.go: Container runtime integration
- executor/runner.go: Task runner with visibility enforcement
- gpu_detector.go: Robust GPU detection (NVIDIA, AMD, Apple Silicon, CPU fallback)
- integrity/validate.go: Data integrity validation
- lifecycle/runloop.go: Improved runloop with graceful shutdown
- lifecycle/service_manager.go: Service lifecycle coordination
- process/isolation.go + isolation_unix.go: Process isolation with namespaces/cgroups
- tenant/manager.go: Multi-tenant resource isolation
- tenant/middleware.go: Tenant context propagation
- worker.go: Core worker with group-scoped task execution
2026-03-08 13:03:15 -04:00

225 lines
6.2 KiB
Go

// Package tenant provides middleware for cross-tenant access prevention.
package tenant
import (
	"context"
	"fmt"
	"net"
	"net/http"
	"path/filepath"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
)
// contextKey is the unexported type used for context keys owned by this
// package. Using a distinct type means these keys cannot collide with plain
// string keys (or keys of other packages) stored in the same context.
type contextKey string

// Context keys under which tenant identity is stored in a context.Context.
const (
	// ContextTenantID is the context key whose value is the tenant ID string.
	ContextTenantID = contextKey("tenant_id")
	// ContextUserID is the context key whose value is the user ID string.
	ContextUserID = contextKey("user_id")
)
// Middleware is an HTTP middleware that rejects requests lacking a valid
// tenant and annotates accepted requests with tenant identity.
type Middleware struct {
	tenantManager *Manager        // looks up tenants and owns the audit log
	logger        *logging.Logger // receives warnings and debug traces
}

// NewMiddleware returns a Middleware that validates tenants through tm and
// logs through logger.
func NewMiddleware(tm *Manager, logger *logging.Logger) *Middleware {
	m := new(Middleware)
	m.tenantManager = tm
	m.logger = logger
	return m
}
// ExtractTenantID returns the tenant ID associated with r, or "" if none is
// present.
//
// Sources are consulted in decreasing order of trust:
//  1. the request context (ContextTenantID), as set by upstream middleware or
//     WithTenantContext;
//  2. the X-Tenant-ID request header;
//  3. the tenant_id query parameter.
//
// The context value is checked first so that a client-supplied header or
// query parameter cannot override a tenant that trusted in-process code has
// already established for this request.
func ExtractTenantID(r *http.Request) string {
	if id := GetTenantIDFromContext(r.Context()); id != "" {
		return id
	}
	if id := r.Header.Get("X-Tenant-ID"); id != "" {
		return id
	}
	return r.URL.Query().Get("tenant_id")
}
// Handler wraps next with tenant validation.
//
// For every request it:
//   - resolves the tenant ID with ExtractTenantID and responds 400
//     "Tenant ID required" when none is present;
//   - looks the tenant up with the tenant manager and responds 403
//     "Invalid tenant" when GetTenant returns an error
//     (NOTE(review): whether inactive tenants are rejected depends on
//     Manager.GetTenant, which is not visible here);
//   - stores the tenant ID and user ID in the request context;
//   - records an AuditResourceAccess event (a failure to record it is logged
//     and does not block the request) and then invokes next.
//
// A user ID already present in the incoming context (for example, one set by
// an upstream auth middleware) takes precedence over the client-supplied
// X-User-ID header, and an empty user ID is not stored.
func (m *Middleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tenantID := ExtractTenantID(r)
		if tenantID == "" {
			m.logger.Warn("request without tenant ID",
				"path", r.URL.Path,
				"remote_addr", r.RemoteAddr,
			)
			http.Error(w, "Tenant ID required", http.StatusBadRequest)
			return
		}

		// Validate that the tenant is known to the manager.
		tenant, err := m.tenantManager.GetTenant(tenantID)
		if err != nil {
			m.logger.Warn("invalid tenant ID",
				"tenant_id", tenantID,
				"path", r.URL.Path,
				"error", err,
			)
			http.Error(w, "Invalid tenant", http.StatusForbidden)
			return
		}

		// Do not let the client header clobber a user ID established upstream.
		userID := GetUserIDFromContext(r.Context())
		if userID == "" {
			userID = r.Header.Get("X-User-ID")
		}
		ctx := WithTenantContext(r.Context(), tenantID, userID)

		m.logger.Debug("tenant request",
			"tenant_id", tenantID,
			"tenant_name", tenant.Name,
			"path", r.URL.Path,
			"method", r.Method,
		)

		// Audit every accepted request; include the user for traceability,
		// matching the cross-tenant denial events.
		if err := m.tenantManager.auditLog.LogEvent(ctx, AuditEvent{
			Type:      AuditResourceAccess,
			TenantID:  tenantID,
			UserID:    userID,
			Timestamp: time.Now().UTC(),
			Success:   true,
			Details: map[string]any{
				"path":   r.URL.Path,
				"method": r.Method,
			},
			IPAddress: extractIP(r.RemoteAddr),
		}); err != nil {
			m.logger.Warn("failed to log audit event", "error", err)
		}

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// ResourceAccessChecker enforces tenant boundaries on individual resources:
// the tenant stored in a request context may only access resources that the
// same tenant owns.
type ResourceAccessChecker struct {
	tenantManager *Manager        // provides the audit log for denials
	logger        *logging.Logger // receives denial warnings
}

// NewResourceAccessChecker returns a checker that audits through tm and logs
// through logger.
func NewResourceAccessChecker(tm *Manager, logger *logging.Logger) *ResourceAccessChecker {
	rac := new(ResourceAccessChecker)
	rac.tenantManager, rac.logger = tm, logger
	return rac
}
// CheckAccess reports whether the tenant stored in ctx may access a resource
// owned by resourceTenantID.
//
// It returns:
//   - an error if ctx carries no tenant ID;
//   - nil if the requesting tenant owns the resource;
//   - otherwise a "cross-tenant access denied" error, after logging a warning
//     and recording an AuditCrossTenantDeny event (an audit-log failure is
//     logged and does not change the result).
//
// Cross-tenant access is never permitted.
func (rac *ResourceAccessChecker) CheckAccess(ctx context.Context, resourceTenantID string) error {
	requester := GetTenantIDFromContext(ctx)
	switch requester {
	case "":
		return fmt.Errorf("no tenant ID in context")
	case resourceTenantID:
		return nil
	}

	rac.logger.Warn("cross-tenant access denied",
		"requesting_tenant", requester,
		"resource_tenant", resourceTenantID,
	)

	denial := AuditEvent{
		Type:      AuditCrossTenantDeny,
		TenantID:  requester,
		UserID:    GetUserIDFromContext(ctx),
		Timestamp: time.Now().UTC(),
		Success:   false,
		Details: map[string]any{
			"target_tenant": resourceTenantID,
			"reason":        "cross-tenant access not permitted",
		},
	}
	if err := rac.tenantManager.auditLog.LogEvent(ctx, denial); err != nil {
		rac.logger.Warn("failed to log audit event", "error", err)
	}

	return fmt.Errorf("cross-tenant access denied: cannot access resources belonging to tenant %s", resourceTenantID)
}
// CheckResourceOwnership reports whether the tenant in ctx may access the
// resource identified by resourceID, which is owned by resourceTenantID.
//
// It applies exactly the same rule as CheckAccess. resourceID is accepted but
// is not currently used in the decision, the log entry or the audit event.
func (rac *ResourceAccessChecker) CheckResourceOwnership(ctx context.Context, resourceID, resourceTenantID string) error {
	if err := rac.CheckAccess(ctx, resourceTenantID); err != nil {
		return err
	}
	return nil
}
// GetTenantIDFromContext returns the tenant ID stored under ContextTenantID,
// or "" when the key is absent or its value is not a string.
func GetTenantIDFromContext(ctx context.Context) string {
	id, _ := ctx.Value(ContextTenantID).(string)
	return id
}
// GetUserIDFromContext returns the user ID stored under ContextUserID, or ""
// when the key is absent or its value is not a string.
func GetUserIDFromContext(ctx context.Context) string {
	id, _ := ctx.Value(ContextUserID).(string)
	return id
}
// WithTenantContext derives a context from parent that carries tenantID (and
// userID, when non-empty), for work that runs outside an HTTP request.
//
// The tenant ID is always set, even if empty. An empty userID leaves any user
// ID inherited from parent in place.
func WithTenantContext(parent context.Context, tenantID, userID string) context.Context {
	withTenant := context.WithValue(parent, ContextTenantID, tenantID)
	if userID == "" {
		return withTenant
	}
	return context.WithValue(withTenant, ContextUserID, userID)
}
// IsolatedPath returns the storage location of a tenant-owned resource, laid
// out as basePath/tenantID/resourceType/resourceID and joined with "/".
//
// The components are concatenated verbatim. Nothing is cleaned or validated,
// so values containing "/" or ".." appear unchanged in the result. Callers
// must sanitize any component derived from untrusted input.
func IsolatedPath(basePath, tenantID, resourceType, resourceID string) string {
	const sep = "/"
	return basePath + sep + tenantID + sep + resourceType + sep + resourceID
}
// ValidateResourcePath returns nil if requestedPath lies inside the tenant's
// isolated workspace, basePath + "/" + tenantID. Otherwise it returns a
// descriptive error.
//
// Both the workspace root and requestedPath are lexically normalized with
// filepath.Clean before comparison. As a result, "..", "." and repeated
// separators cannot be used to step into another tenant's workspace; for
// example "<base>/t1/../t2/x" is rejected for tenant t1. The comparison is
// made on a path-separator boundary, so tenant "t1" does not match "t10". The
// workspace root itself is accepted.
//
// tenantID must be a single, non-empty path element. "", ".", ".." and IDs
// containing '/' or '\' are rejected, because any of them would let the
// computed root point outside basePath or at basePath itself.
//
// The check is purely lexical. Symbolic links are not resolved, so callers
// that follow symlinks should resolve them (filepath.EvalSymlinks) first.
func ValidateResourcePath(basePath, tenantID, requestedPath string) error {
	if tenantID == "" || tenantID == "." || tenantID == ".." || strings.ContainsAny(tenantID, `/\`) {
		return fmt.Errorf("invalid tenant ID %q", tenantID)
	}

	// Build the root exactly as the workspace layout does (base + "/" + tenant),
	// then clean it so a trailing slash on basePath makes no difference.
	root := filepath.Clean(basePath + "/" + tenantID)
	cleaned := filepath.Clean(requestedPath)

	if cleaned != root && !strings.HasPrefix(cleaned, root+string(filepath.Separator)) {
		return fmt.Errorf("path %s is outside tenant %s workspace", requestedPath, tenantID)
	}
	return nil
}
// extractIP returns the host part of an http.Request.RemoteAddr value, which
// is used as the IP address in audit events.
//
// RemoteAddr is normally "host:port", such as "192.0.2.1:1234" or
// "[2001:db8::1]:1234". net.SplitHostPort handles both forms and strips the
// IPv6 brackets. If remoteAddr has no port or cannot be parsed, it is returned
// unchanged rather than guessed at. Splitting at the last colon would turn a
// bare IPv6 address such as "::1" into "::".
func extractIP(remoteAddr string) string {
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		// Best effort: still record whatever the server reported.
		return remoteAddr
	}
	return host
}