package jupyter

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/jfraeys/fetch_ml/internal/logging"
)

// Validation patterns are compiled once at package scope instead of on every
// call.
var (
	// secPackageNameRe accepts alphanumerics, underscores, hyphens and dots.
	secPackageNameRe = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
	// secVersionRe accepts MAJOR.MINOR[.PATCH] followed by an optional suffix.
	secVersionRe = regexp.MustCompile(`^\d+\.\d+(\.\d+)?([a-zA-Z0-9.-]*)?$`)
	// secPortRe accepts 1-5 digits without a leading zero; the numeric upper
	// bound is enforced separately in isValidPort.
	secPortRe = regexp.MustCompile(`^[1-9][0-9]{0,4}$`)
	// secTokenRe accepts the URL-safe base64 alphabet plus up to two trailing
	// '=' padding characters, which base64.URLEncoding emits.
	secTokenRe = regexp.MustCompile(`^[a-zA-Z0-9_-]+={0,2}$`)
)

// secPasswordSpecialChars is the set of characters ValidatePassword counts as
// "special".
const secPasswordSpecialChars = `!@#$%^&*(),.?":{}|<>`

// SecurityManager handles all security-related operations for Jupyter services.
type SecurityManager struct {
	logger *logging.Logger
	config *EnhancedSecurityConfig
}

// EnhancedSecurityConfig provides comprehensive security settings.
type EnhancedSecurityConfig struct {
	// Network Security
	AllowNetwork   bool     `json:"allow_network"`
	AllowedHosts   []string `json:"allowed_hosts"`
	BlockedHosts   []string `json:"blocked_hosts"`
	EnableFirewall bool     `json:"enable_firewall"`

	// Package Security
	TrustedChannels []string        `json:"trusted_channels"`
	BlockedPackages []string        `json:"blocked_packages"`
	AllowedPackages map[string]bool `json:"allowed_packages"`
	RequireApproval bool            `json:"require_approval"`
	AutoApproveSafe bool            `json:"auto_approve_safe"`
	MaxPackages     int             `json:"max_packages"`
	InstallTimeout  time.Duration   `json:"install_timeout"`
	AllowCondaForge bool            `json:"allow_conda_forge"`
	AllowPyPI       bool            `json:"allow_pypi"`
	AllowLocal      bool            `json:"allow_local"`

	// Container Security
	ReadOnlyRoot     bool     `json:"read_only_root"`
	DropCapabilities []string `json:"drop_capabilities"`
	RunAsNonRoot     bool     `json:"run_as_non_root"`
	EnableSeccomp    bool     `json:"enable_seccomp"`
	NoNewPrivileges  bool     `json:"no_new_privileges"`

	// Authentication Security
	EnableTokenAuth   bool          `json:"enable_token_auth"`
	TokenLength       int           `json:"token_length"`
	TokenExpiry       time.Duration `json:"token_expiry"`
	RequireHTTPS      bool          `json:"require_https"`
	SessionTimeout    time.Duration `json:"session_timeout"`
	MaxFailedAttempts int           `json:"max_failed_attempts"`
	LockoutDuration   time.Duration `json:"lockout_duration"`

	// File System Security
	AllowedPaths     []string `json:"allowed_paths"`
	DeniedPaths      []string `json:"denied_paths"`
	MaxWorkspaceSize string   `json:"max_workspace_size"`
	AllowExecFrom    []string `json:"allow_exec_from"`
	BlockExecFrom    []string `json:"block_exec_from"`

	// Resource Security
	MaxMemoryLimit string `json:"max_memory_limit"`
	MaxCPULimit    string `json:"max_cpu_limit"`
	MaxDiskUsage   string `json:"max_disk_usage"`
	MaxProcesses   int    `json:"max_processes"`

	// Logging & Monitoring
	SecurityLogLevel string `json:"security_log_level"`
	AuditEnabled     bool   `json:"audit_enabled"`
	RealTimeAlerts   bool   `json:"real_time_alerts"`
}

// SecurityEvent represents a security-related event.
type SecurityEvent struct {
	Timestamp   time.Time `json:"timestamp"`
	EventType   string    `json:"event_type"`
	Severity    string    `json:"severity"` // low, medium, high, critical
	User        string    `json:"user"`
	Action      string    `json:"action"`
	Resource    string    `json:"resource"`
	Description string    `json:"description"`
	IPAddress   string    `json:"ip_address,omitempty"`
	UserAgent   string    `json:"user_agent,omitempty"`
}

// NewSecurityManager creates a new security manager.
//
// A nil config is replaced with GetDefaultSecurityConfig() so that the
// returned manager never dereferences a nil configuration.
func NewSecurityManager(logger *logging.Logger, config *EnhancedSecurityConfig) *SecurityManager {
	if config == nil {
		config = GetDefaultSecurityConfig()
	}
	return &SecurityManager{
		logger: logger,
		config: config,
	}
}

// ValidatePackageRequest validates a package installation request.
//
// Checks run in this order: blocklist (case-insensitive), allowlist (only
// when AllowedPackages is non-empty), channel trust (only when a channel is
// given), package-name format, and version format (only when a version is
// given). The first failing check determines the returned error. A nil
// request is rejected with an error.
func (sm *SecurityManager) ValidatePackageRequest(req *PackageRequest) error {
	if req == nil {
		return fmt.Errorf("package request is nil")
	}

	// The arguments are evaluated here, but the event is emitted when the
	// function returns, regardless of the validation outcome.
	defer sm.logSecurityEvent("package_validation", "medium", req.RequestedBy,
		fmt.Sprintf("validate_package:%s", req.PackageName),
		fmt.Sprintf("Package: %s, Version: %s, Channel: %s", req.PackageName, req.Version, req.Channel))

	// Check if package is blocked.
	for _, blocked := range sm.config.BlockedPackages {
		if strings.EqualFold(blocked, req.PackageName) {
			return fmt.Errorf("package '%s' is blocked by security policy", req.PackageName)
		}
	}

	// Check if package is explicitly allowed (if allowlist exists).
	if len(sm.config.AllowedPackages) > 0 && !sm.isPackageAllowed(req.PackageName) {
		return fmt.Errorf("package '%s' is not in the allowed packages list", req.PackageName)
	}

	// Validate channel.
	if req.Channel != "" && !sm.isValidChannel(req.Channel) {
		return fmt.Errorf("channel '%s' is not trusted", req.Channel)
	}

	// Check package name format.
	if !sm.isValidPackageName(req.PackageName) {
		return fmt.Errorf("package name '%s' contains invalid characters", req.PackageName)
	}

	// Check version format if specified.
	if req.Version != "" && !sm.isValidVersion(req.Version) {
		return fmt.Errorf("version '%s' is not in valid format", req.Version)
	}

	return nil
}

// ValidateWorkspaceAccess validates workspace path access.
//
// The path is rejected when it contains "..", when an allowlist is configured
// and the path is not inside any allowed root, when it is inside a denied
// root, or when os.Stat on it fails. Allowed and denied roots are compared on
// path-component boundaries after both sides are made absolute, so a denied
// "/etc" does not match "/etcfoo" and a relative root such as "./workspace"
// matches paths beneath it.
//
// NOTE(review): symbolic links are not resolved before the comparison.
func (sm *SecurityManager) ValidateWorkspaceAccess(workspacePath, user string) error {
	defer sm.logSecurityEvent("workspace_access", "medium", user,
		fmt.Sprintf("access_workspace:%s", workspacePath),
		fmt.Sprintf("Workspace access attempt: %s", workspacePath))

	// Deliberately conservative: any ".." in the raw input is refused, even
	// inside a file name, before the path is normalized.
	if strings.Contains(workspacePath, "..") {
		return fmt.Errorf("path traversal detected in workspace path: %s", workspacePath)
	}

	cleanPath := filepath.Clean(workspacePath)

	// Fail closed if the path cannot be made absolute; the allow/deny
	// comparison below would otherwise be meaningless.
	absPath, err := filepath.Abs(cleanPath)
	if err != nil {
		return fmt.Errorf("resolving workspace path '%s': %w", cleanPath, err)
	}

	// Check if path is in allowed paths.
	if len(sm.config.AllowedPaths) > 0 {
		allowed := false
		for _, allowedPath := range sm.config.AllowedPaths {
			if sm.isPathWithin(absPath, allowedPath) {
				allowed = true
				break
			}
		}
		if !allowed {
			return fmt.Errorf("workspace path '%s' is not in allowed paths", cleanPath)
		}
	}

	// Check if path is in denied paths.
	for _, deniedPath := range sm.config.DeniedPaths {
		if sm.isPathWithin(absPath, deniedPath) {
			return fmt.Errorf("workspace path '%s' is in denied paths", cleanPath)
		}
	}

	// Check if workspace exists and is accessible. Any stat failure is an
	// error, not only "does not exist".
	if _, err := os.Stat(cleanPath); err != nil {
		if os.IsNotExist(err) {
			return fmt.Errorf("workspace path '%s' does not exist", cleanPath)
		}
		return fmt.Errorf("workspace path '%s' is not accessible: %w", cleanPath, err)
	}

	return nil
}

// ValidateNetworkAccess validates network access requests.
//
// Access is refused when networking is disabled, when the host matches a
// blocked entry, when an allowlist is configured and the host matches none of
// its entries, or when a non-empty port is not a decimal number in 1-65535.
// Host entries match the exact host name or any of its subdomains.
func (sm *SecurityManager) ValidateNetworkAccess(host, port, user string) error {
	defer sm.logSecurityEvent("network_access", "high", user,
		fmt.Sprintf("network_access:%s:%s", host, port),
		fmt.Sprintf("Network access attempt: %s:%s", host, port))

	if !sm.config.AllowNetwork {
		return fmt.Errorf("network access is disabled by security policy")
	}

	// Check if host is blocked.
	for _, blockedHost := range sm.config.BlockedHosts {
		if sm.hostMatches(blockedHost, host) {
			return fmt.Errorf("host '%s' is blocked by security policy", host)
		}
	}

	// Check if host is allowed (if allowlist exists).
	if len(sm.config.AllowedHosts) > 0 {
		allowed := false
		for _, allowedHost := range sm.config.AllowedHosts {
			if sm.hostMatches(allowedHost, host) {
				allowed = true
				break
			}
		}
		if !allowed {
			return fmt.Errorf("host '%s' is not in allowed hosts list", host)
		}
	}

	// Validate port range.
	if port != "" && !sm.isValidPort(port) {
		return fmt.Errorf("port '%s' is not in allowed range", port)
	}

	return nil
}

// GenerateSecureToken generates a cryptographically secure token.
//
// It reads config.TokenLength random bytes from crypto/rand and returns them
// encoded with base64.URLEncoding (padded). A non-positive TokenLength is an
// error rather than an empty token.
func (sm *SecurityManager) GenerateSecureToken() (string, error) {
	if sm.config.TokenLength <= 0 {
		return "", fmt.Errorf("invalid token length %d: must be positive", sm.config.TokenLength)
	}

	raw := make([]byte, sm.config.TokenLength)
	if _, err := rand.Read(raw); err != nil {
		return "", fmt.Errorf("failed to generate secure token: %w", err)
	}

	token := base64.URLEncoding.EncodeToString(raw)

	// Log token generation (without the token itself for security).
	sm.logSecurityEvent(
		"token_generation",
		"low",
		"system",
		"generate_token",
		"Secure token generated",
	)

	return token, nil
}

// ValidateToken performs basic structural validation of a token.
//
// It checks that token authentication is enabled, that the token is at least
// config.TokenLength characters long, and that it uses the URL-safe base64
// alphabet (optionally padded). It does not check expiry or look the token up
// anywhere.
func (sm *SecurityManager) ValidateToken(token, user string) error {
	defer sm.logSecurityEvent("token_validation", "medium", user, "validate_token", "Token validation attempt")

	if !sm.config.EnableTokenAuth {
		return fmt.Errorf("token authentication is disabled")
	}

	if len(token) < sm.config.TokenLength {
		return fmt.Errorf("invalid token length")
	}

	// Additional token validation logic would go here.
	// For now, just check basic format.
	if !sm.isValidTokenFormat(token) {
		return fmt.Errorf("invalid token format")
	}

	return nil
}

// GetDefaultSecurityConfig returns the default enhanced security configuration.
func GetDefaultSecurityConfig() *EnhancedSecurityConfig {
	return &EnhancedSecurityConfig{
		// Network Security
		AllowNetwork:   false,
		AllowedHosts:   []string{"localhost", "127.0.0.1"},
		BlockedHosts:   []string{"0.0.0.0", "0.0.0.0/0"},
		EnableFirewall: true,

		// Package Security
		TrustedChannels: []string{"conda-forge", "defaults", "pytorch", "nvidia"},
		BlockedPackages: append([]string{"aiohttp", "socket", "telnetlib"}, defaultBlockedPackages...),
		AllowedPackages: make(map[string]bool), // Empty means no explicit allowlist
		RequireApproval: true,
		AutoApproveSafe: false,
		MaxPackages:     50,
		InstallTimeout:  5 * time.Minute,
		AllowCondaForge: true,
		AllowPyPI:       false,
		AllowLocal:      false,

		// Container Security
		ReadOnlyRoot:     true,
		DropCapabilities: []string{"ALL"},
		RunAsNonRoot:     true,
		EnableSeccomp:    true,
		NoNewPrivileges:  true,

		// Authentication Security
		EnableTokenAuth:   true,
		TokenLength:       32,
		TokenExpiry:       24 * time.Hour,
		RequireHTTPS:      true,
		SessionTimeout:    2 * time.Hour,
		MaxFailedAttempts: 5,
		LockoutDuration:   15 * time.Minute,

		// File System Security
		AllowedPaths:     []string{"./workspace", "./data"},
		DeniedPaths:      []string{"/etc", "/root", "/var", "/sys", "/proc"},
		MaxWorkspaceSize: "10G",
		AllowExecFrom:    []string{"./workspace", "./data"},
		BlockExecFrom:    []string{"/tmp", "/var/tmp"},

		// Resource Security
		MaxMemoryLimit: "4G",
		MaxCPULimit:    "2",
		MaxDiskUsage:   "20G",
		MaxProcesses:   100,

		// Logging & Monitoring
		SecurityLogLevel: "info",
		AuditEnabled:     true,
		RealTimeAlerts:   true,
	}
}

// Helper methods

// isValidChannel reports whether channel equals one of the configured
// TrustedChannels, ignoring case.
func (sm *SecurityManager) isValidChannel(channel string) bool {
	for _, trusted := range sm.config.TrustedChannels {
		if strings.EqualFold(trusted, channel) {
			return true
		}
	}
	return false
}

// isPackageAllowed reports whether name is enabled in AllowedPackages.
//
// An exact key is authoritative, including an explicit false. Otherwise the
// lookup falls back to a case-insensitive match so the allowlist treats case
// the same way the blocklist does.
func (sm *SecurityManager) isPackageAllowed(name string) bool {
	if allowed, found := sm.config.AllowedPackages[name]; found {
		return allowed
	}
	for candidate, allowed := range sm.config.AllowedPackages {
		if allowed && strings.EqualFold(candidate, name) {
			return true
		}
	}
	return false
}

// isValidPackageName reports whether name consists only of alphanumeric
// characters, underscores, hyphens, and dots (and is non-empty).
func (sm *SecurityManager) isValidPackageName(name string) bool {
	return secPackageNameRe.MatchString(name)
}

// isValidVersion performs basic semantic-version validation:
// MAJOR.MINOR[.PATCH] with an optional alphanumeric/dot/hyphen suffix.
func (sm *SecurityManager) isValidVersion(version string) bool {
	return secVersionRe.MatchString(version)
}

// isValidPort reports whether port is a decimal number in 1-65535 with no
// sign, whitespace, or leading zero.
func (sm *SecurityManager) isValidPort(port string) bool {
	if !secPortRe.MatchString(port) {
		return false
	}
	// The pattern admits up to five digits, so 65536-99999 must be excluded
	// numerically.
	n, err := strconv.Atoi(port)
	return err == nil && n <= 65535
}

// isValidTokenFormat reports whether token looks like URL-safe base64,
// with or without trailing '=' padding.
func (sm *SecurityManager) isValidTokenFormat(token string) bool {
	return secTokenRe.MatchString(token)
}

// hostMatches reports whether host equals pattern or is a subdomain of it.
//
// Comparison is case-insensitive and ignores surrounding whitespace and a
// trailing dot on host. A pattern with a leading dot (".example.com") matches
// subdomains only. Empty patterns and empty hosts never match.
func (sm *SecurityManager) hostMatches(pattern, host string) bool {
	pattern = strings.ToLower(strings.TrimSpace(pattern))
	host = strings.TrimSuffix(strings.ToLower(strings.TrimSpace(host)), ".")
	if pattern == "" || host == "" {
		return false
	}
	if host == pattern {
		return true
	}
	// Require a label boundary so "notexample.com" does not match
	// "example.com".
	if !strings.HasPrefix(pattern, ".") {
		pattern = "." + pattern
	}
	return strings.HasSuffix(host, pattern)
}

// isPathWithin reports whether absPath equals root or lies beneath it.
//
// absPath must already be absolute. root may be relative; it is made absolute
// first, falling back to its cleaned form if that fails. Empty roots never
// match.
func (sm *SecurityManager) isPathWithin(absPath, root string) bool {
	if root == "" {
		return false
	}
	absRoot, err := filepath.Abs(root)
	if err != nil {
		absRoot = filepath.Clean(root)
	}
	rel, err := filepath.Rel(absRoot, absPath)
	if err != nil {
		return false
	}
	// A relative path that climbs out of root starts with a ".." component.
	if rel == ".." {
		return false
	}
	return !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}

// logSecurityEvent records a security event through the manager's logger when
// auditing is enabled, and forwards it to sendSecurityAlert when real-time
// alerts are enabled and the severity is "high" or "critical".
func (sm *SecurityManager) logSecurityEvent(eventType, severity, user, action, description string) {
	if !sm.config.AuditEnabled {
		return
	}

	event := SecurityEvent{
		Timestamp:   time.Now(),
		EventType:   eventType,
		Severity:    severity,
		User:        user,
		Action:      action,
		Resource:    "jupyter",
		Description: description,
	}

	// Log the security event.
	sm.logger.Info("Security Event",
		"event_type", event.EventType,
		"severity", event.Severity,
		"user", event.User,
		"action", event.Action,
		"resource", event.Resource,
		"description", event.Description,
		"timestamp", event.Timestamp,
	)

	// Send real-time alert if enabled and severity is high or critical.
	if sm.config.RealTimeAlerts && (event.Severity == "high" || event.Severity == "critical") {
		sm.sendSecurityAlert(event)
	}
}

// sendSecurityAlert emits a warning-level log entry for the event.
// Delivery to external monitoring systems is not implemented here.
func (sm *SecurityManager) sendSecurityAlert(event SecurityEvent) {
	sm.logger.Warn("Security Alert",
		"alert_type", event.EventType,
		"severity", event.Severity,
		"user", event.User,
		"description", event.Description,
		"timestamp", event.Timestamp,
	)
}

// HashPassword returns the hex-encoded SHA-256 digest of password.
//
// SECURITY NOTE(review): this is a single unsalted SHA-256 pass, which is not
// suitable for password storage; a salted, deliberately slow KDF should
// replace it. The output format is left unchanged here so existing digests
// keep comparing equal.
func (sm *SecurityManager) HashPassword(password string) string {
	digest := sha256.Sum256([]byte(password))
	return hex.EncodeToString(digest[:])
}

// ValidatePassword validates a password against security requirements.
//
// The password must be at least 8 characters long and contain at least one
// ASCII uppercase letter, one ASCII lowercase letter, one ASCII digit, and one
// of the characters !@#$%^&*(),.?":{}|<>.
func (sm *SecurityManager) ValidatePassword(password string) error {
	// Count characters, not bytes, so multi-byte runes do not satisfy the
	// minimum length early.
	if utf8.RuneCountInString(password) < 8 {
		return fmt.Errorf("password must be at least 8 characters long")
	}

	var hasUpper, hasLower, hasDigit, hasSpecial bool
	for _, r := range password {
		switch {
		case r >= 'A' && r <= 'Z':
			hasUpper = true
		case r >= 'a' && r <= 'z':
			hasLower = true
		case r >= '0' && r <= '9':
			hasDigit = true
		case strings.ContainsRune(secPasswordSpecialChars, r):
			hasSpecial = true
		}
	}

	if !hasUpper || !hasLower || !hasDigit || !hasSpecial {
		return fmt.Errorf("password must contain uppercase, lowercase, digit, and special character")
	}

	return nil
}