Scheduler enhancements: - auth.go: Group membership validation in authentication - hub.go: Task distribution with group affinity - port_allocator.go: Dynamic port allocation with conflict resolution - scheduler_conn.go: Connection pooling and retry logic - service_manager.go: Lifecycle management for scheduler services - service_templates.go: Template-based service configuration - state.go: Persistent state management with recovery Worker improvements: - config.go: Extended configuration for task visibility rules - execution/setup.go: Sandboxed execution environment setup - executor/container.go: Container runtime integration - executor/runner.go: Task runner with visibility enforcement - gpu_detector.go: Robust GPU detection (NVIDIA, AMD, Apple Silicon, CPU fallback) - integrity/validate.go: Data integrity validation - lifecycle/runloop.go: Improved runloop with graceful shutdown - lifecycle/service_manager.go: Service lifecycle coordination - process/isolation.go + isolation_unix.go: Process isolation with namespaces/cgroups - tenant/manager.go: Multi-tenant resource isolation - tenant/middleware.go: Tenant context propagation - worker.go: Core worker with group-scoped task execution
225 lines
6.2 KiB
Go
225 lines
6.2 KiB
Go
// Package tenant provides middleware for cross-tenant access prevention.
|
|
package tenant
|
|
|
|
import (
	"context"
	"fmt"
	"net"
	"net/http"
	"path/filepath"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
)
|
|
|
|
// contextKey is the private string type used for every key this package
// stores in a context.Context. Because context keys are compared by type as
// well as value, these keys never collide with plain string keys.
type contextKey string

// Keys under which request-scoped identity is kept in a context.
const (
	// ContextTenantID is the context key holding the tenant ID.
	ContextTenantID = contextKey("tenant_id")
	// ContextUserID is the context key holding the user ID.
	ContextUserID = contextKey("user_id")
)
|
|
|
|
// Middleware provides HTTP middleware for tenant isolation
|
|
type Middleware struct {
|
|
tenantManager *Manager
|
|
logger *logging.Logger
|
|
}
|
|
|
|
// NewMiddleware creates a new tenant middleware
|
|
func NewMiddleware(tm *Manager, logger *logging.Logger) *Middleware {
|
|
return &Middleware{
|
|
tenantManager: tm,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// ExtractTenantID extracts tenant ID from request headers or context
|
|
func ExtractTenantID(r *http.Request) string {
|
|
// Check header first
|
|
tenantID := r.Header.Get("X-Tenant-ID")
|
|
if tenantID != "" {
|
|
return tenantID
|
|
}
|
|
|
|
// Check query parameter
|
|
tenantID = r.URL.Query().Get("tenant_id")
|
|
if tenantID != "" {
|
|
return tenantID
|
|
}
|
|
|
|
// Check context (set by upstream middleware)
|
|
if ctxTenantID := r.Context().Value(ContextTenantID); ctxTenantID != nil {
|
|
if id, ok := ctxTenantID.(string); ok {
|
|
return id
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// Handler wraps an HTTP handler with tenant validation
|
|
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tenantID := ExtractTenantID(r)
|
|
|
|
if tenantID == "" {
|
|
m.logger.Warn("request without tenant ID",
|
|
"path", r.URL.Path,
|
|
"remote_addr", r.RemoteAddr,
|
|
)
|
|
http.Error(w, "Tenant ID required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate tenant exists and is active
|
|
tenant, err := m.tenantManager.GetTenant(tenantID)
|
|
if err != nil {
|
|
m.logger.Warn("invalid tenant ID",
|
|
"tenant_id", tenantID,
|
|
"path", r.URL.Path,
|
|
"error", err,
|
|
)
|
|
http.Error(w, "Invalid tenant", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Add tenant to context
|
|
ctx := context.WithValue(r.Context(), ContextTenantID, tenantID)
|
|
ctx = context.WithValue(ctx, ContextUserID, r.Header.Get("X-User-ID"))
|
|
|
|
// Log access
|
|
m.logger.Debug("tenant request",
|
|
"tenant_id", tenantID,
|
|
"tenant_name", tenant.Name,
|
|
"path", r.URL.Path,
|
|
"method", r.Method,
|
|
)
|
|
|
|
// Audit log
|
|
if err := m.tenantManager.auditLog.LogEvent(ctx, AuditEvent{
|
|
Type: AuditResourceAccess,
|
|
TenantID: tenantID,
|
|
Timestamp: time.Now().UTC(),
|
|
Success: true,
|
|
Details: map[string]any{
|
|
"path": r.URL.Path,
|
|
"method": r.Method,
|
|
},
|
|
IPAddress: extractIP(r.RemoteAddr),
|
|
}); err != nil {
|
|
m.logger.Warn("failed to log audit event", "error", err)
|
|
}
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// ResourceAccessChecker validates access to resources across tenants
|
|
type ResourceAccessChecker struct {
|
|
tenantManager *Manager
|
|
logger *logging.Logger
|
|
}
|
|
|
|
// NewResourceAccessChecker creates a new resource access checker
|
|
func NewResourceAccessChecker(tm *Manager, logger *logging.Logger) *ResourceAccessChecker {
|
|
return &ResourceAccessChecker{
|
|
tenantManager: tm,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// CheckAccess validates if a tenant can access a specific resource
|
|
func (rac *ResourceAccessChecker) CheckAccess(ctx context.Context, resourceTenantID string) error {
|
|
requestingTenantID := GetTenantIDFromContext(ctx)
|
|
if requestingTenantID == "" {
|
|
return fmt.Errorf("no tenant ID in context")
|
|
}
|
|
|
|
// Same tenant - always allowed
|
|
if requestingTenantID == resourceTenantID {
|
|
return nil
|
|
}
|
|
|
|
// Cross-tenant access - deny by default
|
|
rac.logger.Warn("cross-tenant access denied",
|
|
"requesting_tenant", requestingTenantID,
|
|
"resource_tenant", resourceTenantID,
|
|
)
|
|
|
|
// Audit the denial
|
|
userID := GetUserIDFromContext(ctx)
|
|
if err := rac.tenantManager.auditLog.LogEvent(ctx, AuditEvent{
|
|
Type: AuditCrossTenantDeny,
|
|
TenantID: requestingTenantID,
|
|
UserID: userID,
|
|
Timestamp: time.Now().UTC(),
|
|
Success: false,
|
|
Details: map[string]any{
|
|
"target_tenant": resourceTenantID,
|
|
"reason": "cross-tenant access not permitted",
|
|
},
|
|
}); err != nil {
|
|
rac.logger.Warn("failed to log audit event", "error", err)
|
|
}
|
|
|
|
return fmt.Errorf("cross-tenant access denied: cannot access resources belonging to tenant %s", resourceTenantID)
|
|
}
|
|
|
|
// CheckResourceOwnership validates that a resource belongs to the requesting tenant
|
|
func (rac *ResourceAccessChecker) CheckResourceOwnership(ctx context.Context, resourceID, resourceTenantID string) error {
|
|
return rac.CheckAccess(ctx, resourceTenantID)
|
|
}
|
|
|
|
// GetTenantIDFromContext extracts tenant ID from context
|
|
func GetTenantIDFromContext(ctx context.Context) string {
|
|
if tenantID := ctx.Value(ContextTenantID); tenantID != nil {
|
|
if id, ok := tenantID.(string); ok {
|
|
return id
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// GetUserIDFromContext extracts user ID from context
|
|
func GetUserIDFromContext(ctx context.Context) string {
|
|
if userID := ctx.Value(ContextUserID); userID != nil {
|
|
if id, ok := userID.(string); ok {
|
|
return id
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// WithTenantContext creates a context with tenant ID for background operations
|
|
func WithTenantContext(parent context.Context, tenantID, userID string) context.Context {
|
|
ctx := context.WithValue(parent, ContextTenantID, tenantID)
|
|
if userID != "" {
|
|
ctx = context.WithValue(ctx, ContextUserID, userID)
|
|
}
|
|
return ctx
|
|
}
|
|
|
|
// IsolatedPath builds the storage path of a tenant-owned resource in the form
// basePath/tenantID/resourceType/resourceID. The components are joined with
// "/" exactly as given; no cleaning or validation is applied.
func IsolatedPath(basePath, tenantID, resourceType, resourceID string) string {
	// First the tenant's workspace root, then the resource beneath it.
	tenantRoot := fmt.Sprintf("%s/%s", basePath, tenantID)
	return fmt.Sprintf("%s/%s/%s", tenantRoot, resourceType, resourceID)
}
|
|
|
|
// ValidateResourcePath ensures requestedPath lies inside the workspace of
// tenantID, i.e. at or below basePath/tenantID.
//
// The comparison is made on lexically cleaned paths, so "." and ".." segments
// and duplicate separators cannot be used to step out of the workspace
// (for example "/base/t1/../t2/secret" is rejected). Symbolic links are not
// resolved; this is a purely textual check.
//
// tenantID must be a single, non-empty path element; an empty ID, "." / ".."
// or an ID containing a path separator is rejected because it would move the
// workspace root itself.
//
// It returns nil when the path is inside the workspace and a descriptive
// error otherwise.
func ValidateResourcePath(basePath, tenantID, requestedPath string) error {
	if tenantID == "" || tenantID == "." || tenantID == ".." || strings.ContainsAny(tenantID, `/\`) {
		return fmt.Errorf("invalid tenant ID %q", tenantID)
	}

	workspace := filepath.Clean(basePath + "/" + tenantID)
	resolved := filepath.Clean(requestedPath)

	// Requiring the separator after the workspace keeps sibling tenants whose
	// IDs share a prefix ("t1" vs "t10") apart.
	if resolved == workspace || strings.HasPrefix(resolved, workspace+string(filepath.Separator)) {
		return nil
	}
	return fmt.Errorf("path %s is outside tenant %s workspace", requestedPath, tenantID)
}
|
|
|
|
// extractIP returns the host part of a network address such as
// http.Request.RemoteAddr.
//
// "host:port" and "[ipv6]:port" forms are split with net.SplitHostPort, so
// IPv6 addresses are returned without brackets and without being cut at an
// inner colon. When remoteAddr cannot be split (for example it has no port),
// it is returned unchanged.
func extractIP(remoteAddr string) string {
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		// Not in host:port form; the input is the best value available.
		return remoteAddr
	}
	return host
}
|