fetch_ml/internal/jupyter/security_enhanced.go

430 lines
13 KiB
Go

package jupyter
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// SecurityManager applies the Jupyter security policy held in an
// EnhancedSecurityConfig and reports security events through a logger.
type SecurityManager struct {
	// logger receives audit entries and alerts from logSecurityEvent and
	// sendSecurityAlert.
	logger *logging.Logger
	// config is the policy consulted by the validation helpers. It is shared
	// by reference with whoever constructed the manager.
	config *EnhancedSecurityConfig
}
// EnhancedSecurityConfig is the complete security policy for Jupyter services.
//
// It is (de)serialized with the JSON keys given in the struct tags; the
// time.Duration fields encode as integer nanoseconds. SecurityManager reads the
// network, package, token, path and audit settings; the remaining fields are
// carried for other components.
// NOTE(review): confirm where the container, resource and exec-path limits are
// actually enforced.
type EnhancedSecurityConfig struct {
	// --- Network security ---
	AllowNetwork   bool     `json:"allow_network"` // master switch for outbound access
	AllowedHosts   []string `json:"allowed_hosts"` // when non-empty, acts as a host allowlist
	BlockedHosts   []string `json:"blocked_hosts"` // always denied; checked before AllowedHosts
	EnableFirewall bool     `json:"enable_firewall"`

	// --- Package security ---
	TrustedChannels []string        `json:"trusted_channels"` // channels a request may name
	BlockedPackages []string        `json:"blocked_packages"`
	AllowedPackages map[string]bool `json:"allowed_packages"` // empty map disables the allowlist
	RequireApproval bool            `json:"require_approval"`
	AutoApproveSafe bool            `json:"auto_approve_safe"`
	MaxPackages     int             `json:"max_packages"`
	InstallTimeout  time.Duration   `json:"install_timeout"`
	AllowCondaForge bool            `json:"allow_conda_forge"`
	AllowPyPI       bool            `json:"allow_pypi"`
	AllowLocal      bool            `json:"allow_local"`

	// --- Container security ---
	ReadOnlyRoot     bool     `json:"read_only_root"`
	DropCapabilities []string `json:"drop_capabilities"`
	RunAsNonRoot     bool     `json:"run_as_non_root"`
	EnableSeccomp    bool     `json:"enable_seccomp"`
	NoNewPrivileges  bool     `json:"no_new_privileges"`

	// --- Authentication security ---
	EnableTokenAuth   bool          `json:"enable_token_auth"`
	TokenLength       int           `json:"token_length"` // random bytes per generated token; minimum accepted token length
	TokenExpiry       time.Duration `json:"token_expiry"`
	RequireHTTPS      bool          `json:"require_https"`
	SessionTimeout    time.Duration `json:"session_timeout"`
	MaxFailedAttempts int           `json:"max_failed_attempts"`
	LockoutDuration   time.Duration `json:"lockout_duration"`

	// --- File system security ---
	AllowedPaths     []string `json:"allowed_paths"` // when non-empty, workspaces must lie inside one of these
	DeniedPaths      []string `json:"denied_paths"`  // workspaces inside these are rejected
	MaxWorkspaceSize string   `json:"max_workspace_size"`
	AllowExecFrom    []string `json:"allow_exec_from"`
	BlockExecFrom    []string `json:"block_exec_from"`

	// --- Resource security ---
	MaxMemoryLimit string `json:"max_memory_limit"`
	MaxCPULimit    string `json:"max_cpu_limit"`
	MaxDiskUsage   string `json:"max_disk_usage"`
	MaxProcesses   int    `json:"max_processes"`

	// --- Logging & monitoring ---
	SecurityLogLevel string `json:"security_log_level"`
	AuditEnabled     bool   `json:"audit_enabled"`    // when false, security events are not logged
	RealTimeAlerts   bool   `json:"real_time_alerts"` // raise alerts for high/critical events
}
// SecurityEvent is one audit record describing a security-relevant action.
// The JSON form omits IPAddress and UserAgent when they are empty.
type SecurityEvent struct {
	Timestamp   time.Time `json:"timestamp"`   // when the event was recorded
	EventType   string    `json:"event_type"`  // category, e.g. "network_access"
	Severity    string    `json:"severity"`    // low, medium, high, critical
	User        string    `json:"user"`        // who triggered the event
	Action      string    `json:"action"`      // what was attempted
	Resource    string    `json:"resource"`    // what the action targeted
	Description string    `json:"description"` // human-readable detail
	IPAddress   string    `json:"ip_address,omitempty"`
	UserAgent   string    `json:"user_agent,omitempty"`
}
// NewSecurityManager returns a SecurityManager that logs through logger and
// enforces config. Both pointers are stored as-is: config is neither copied
// nor validated, and nil values are not rejected here.
func NewSecurityManager(logger *logging.Logger, config *EnhancedSecurityConfig) *SecurityManager {
	sm := new(SecurityManager)
	sm.logger = logger
	sm.config = config
	return sm
}
// ValidatePackageRequest checks a package installation request against the
// package policy in sm.config. It returns nil if the request is acceptable and
// a descriptive error for the first rule it violates otherwise.
//
// Checks, in order:
//   - the request is non-nil;
//   - the package name is well formed (isValidPackageName);
//   - the name is not in BlockedPackages;
//   - when AllowedPackages is non-empty, the name maps to true there;
//   - a non-empty Channel is one of TrustedChannels;
//   - a non-empty Version is well formed (isValidVersion).
//
// Names are compared after normalization (see canonicalPackageName), so
// spelling variants such as "Python_SocketIO" cannot slip past a blocklist
// entry for "python-socketio".
//
// Every call with a non-nil request is recorded via logSecurityEvent.
func (sm *SecurityManager) ValidatePackageRequest(req *PackageRequest) error {
	if req == nil {
		return fmt.Errorf("package request is nil")
	}

	// Deferred so the audit record is written whichever check fails.
	defer sm.logSecurityEvent("package_validation", "medium", req.RequestedBy,
		fmt.Sprintf("validate_package:%s", req.PackageName),
		fmt.Sprintf("Package: %s, Version: %s, Channel: %s", req.PackageName, req.Version, req.Channel))

	// Reject malformed names before they reach any policy comparison.
	if !sm.isValidPackageName(req.PackageName) {
		return fmt.Errorf("package name '%s' contains invalid characters", req.PackageName)
	}

	name := canonicalPackageName(req.PackageName)

	for _, blocked := range sm.config.BlockedPackages {
		if canonicalPackageName(blocked) == name {
			return fmt.Errorf("package '%s' is blocked by security policy", req.PackageName)
		}
	}

	// An empty AllowedPackages map means "no allowlist"; otherwise the package
	// must match an entry whose value is true.
	if len(sm.config.AllowedPackages) > 0 {
		allowed := false
		for pkg, ok := range sm.config.AllowedPackages {
			if ok && canonicalPackageName(pkg) == name {
				allowed = true
				break
			}
		}
		if !allowed {
			return fmt.Errorf("package '%s' is not in the allowed packages list", req.PackageName)
		}
	}

	if req.Channel != "" && !sm.isValidChannel(req.Channel) {
		return fmt.Errorf("channel '%s' is not trusted", req.Channel)
	}

	if req.Version != "" && !sm.isValidVersion(req.Version) {
		return fmt.Errorf("version '%s' is not in valid format", req.Version)
	}

	return nil
}

// packageNameSeparatorsRE matches runs of the separator characters that
// PEP 503 treats as equivalent in package names.
var packageNameSeparatorsRE = regexp.MustCompile(`[-_.]+`)

// canonicalPackageName normalizes a package name for policy comparisons:
// surrounding whitespace is trimmed, the name is lowercased, and every run of
// '-', '_' and '.' is collapsed to a single '-' (PEP 503 normalization).
func canonicalPackageName(name string) string {
	trimmed := strings.TrimSpace(name)
	return strings.ToLower(packageNameSeparatorsRE.ReplaceAllString(trimmed, "-"))
}
// ValidateWorkspaceAccess checks whether user may use workspacePath as a
// workspace under the file-system policy in sm.config. It returns nil if the
// path is acceptable and a descriptive error otherwise.
//
// The path is rejected when any of the following holds:
//   - the raw input contains "..";
//   - AllowedPaths is non-empty and the cleaned path lies inside none of its entries;
//   - the cleaned path lies inside any DeniedPaths entry;
//   - os.Stat on the cleaned path fails (missing or inaccessible).
//
// Containment is decided per path component on cleaned paths (see
// isPathWithin). For example, "./workspace/x" is inside "./workspace", but
// "./workspace-evil" is not, and "/etcetera" is not inside "/etc".
//
// Every call is recorded via logSecurityEvent.
func (sm *SecurityManager) ValidateWorkspaceAccess(workspacePath, user string) error {
	defer sm.logSecurityEvent("workspace_access", "medium", user,
		fmt.Sprintf("access_workspace:%s", workspacePath),
		fmt.Sprintf("Workspace access attempt: %s", workspacePath))

	cleanPath := filepath.Clean(workspacePath)

	// Checked on the raw input: Clean would silently resolve ".." segments.
	if strings.Contains(workspacePath, "..") {
		return fmt.Errorf("path traversal detected in workspace path: %s", workspacePath)
	}

	if len(sm.config.AllowedPaths) > 0 {
		allowed := false
		for _, allowedPath := range sm.config.AllowedPaths {
			if isPathWithin(cleanPath, allowedPath) {
				allowed = true
				break
			}
		}
		if !allowed {
			return fmt.Errorf("workspace path '%s' is not in allowed paths", cleanPath)
		}
	}

	for _, deniedPath := range sm.config.DeniedPaths {
		if isPathWithin(cleanPath, deniedPath) {
			return fmt.Errorf("workspace path '%s' is in denied paths", cleanPath)
		}
	}

	// Any stat failure (not only "does not exist") means the workspace cannot
	// be used, so it must not be reported as valid.
	if _, err := os.Stat(cleanPath); err != nil {
		if os.IsNotExist(err) {
			return fmt.Errorf("workspace path '%s' does not exist", cleanPath)
		}
		return fmt.Errorf("workspace path '%s' is not accessible: %w", cleanPath, err)
	}

	return nil
}

// isPathWithin reports whether path is base itself or nested beneath it.
// filepath.Rel cleans both arguments, so "./workspace" and "workspace" are
// treated alike. A relative path is never inside an absolute base (or vice
// versa), and an empty or blank base matches nothing.
func isPathWithin(path, base string) bool {
	if strings.TrimSpace(base) == "" {
		return false
	}
	rel, err := filepath.Rel(base, path)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
// ValidateNetworkAccess checks whether user may connect to host:port under the
// network policy in sm.config. It returns nil if access is permitted and a
// descriptive error otherwise.
//
// Rules, in order:
//   - network access must be enabled (AllowNetwork);
//   - host must be non-empty;
//   - host must not match any BlockedHosts entry;
//   - if AllowedHosts is non-empty, host must match one of its entries;
//   - a non-empty port must pass isValidPort.
//
// A host matches an entry when it equals the entry or is a subdomain of it
// (see hostMatchesPattern). "evilexample.com" therefore does not match
// "example.com".
//
// Every call is recorded via logSecurityEvent with severity "high".
func (sm *SecurityManager) ValidateNetworkAccess(host, port, user string) error {
	defer sm.logSecurityEvent("network_access", "high", user,
		fmt.Sprintf("network_access:%s:%s", host, port),
		fmt.Sprintf("Network access attempt: %s:%s", host, port))

	if !sm.config.AllowNetwork {
		return fmt.Errorf("network access is disabled by security policy")
	}

	if strings.TrimSpace(host) == "" {
		return fmt.Errorf("host must not be empty")
	}

	for _, blockedHost := range sm.config.BlockedHosts {
		if hostMatchesPattern(host, blockedHost) {
			return fmt.Errorf("host '%s' is blocked by security policy", host)
		}
	}

	if len(sm.config.AllowedHosts) > 0 {
		allowed := false
		for _, allowedHost := range sm.config.AllowedHosts {
			if hostMatchesPattern(host, allowedHost) {
				allowed = true
				break
			}
		}
		if !allowed {
			return fmt.Errorf("host '%s' is not in allowed hosts list", host)
		}
	}

	if port != "" && !sm.isValidPort(port) {
		return fmt.Errorf("port '%s' is not in allowed range", port)
	}

	return nil
}

// hostMatchesPattern reports whether host equals pattern or is a subdomain of
// it, i.e. ends with "." + pattern. The comparison is case-insensitive and
// ignores surrounding whitespace, a trailing root dot, and a leading "." on the
// pattern. Empty patterns match nothing. Matching is purely textual, so CIDR
// ranges such as "10.0.0.0/8" are not interpreted.
func hostMatchesPattern(host, pattern string) bool {
	normalize := func(s string) string {
		return strings.TrimSuffix(strings.ToLower(strings.TrimSpace(s)), ".")
	}
	h := normalize(host)
	p := strings.TrimPrefix(normalize(pattern), ".")
	if p == "" {
		return false
	}
	return h == p || strings.HasSuffix(h, "."+p)
}
// GenerateSecureToken returns a new random token built from
// sm.config.TokenLength bytes read from crypto/rand.
//
// The bytes are encoded as unpadded base64url text. That encoding contains
// only 'A'-'Z', 'a'-'z', '0'-'9', '-' and '_', all of which isValidTokenFormat
// accepts. Its length, ceil(4n/3), is never shorter than TokenLength, so
// generated tokens pass ValidateToken's length and format checks.
//
// It returns an error if TokenLength is not positive or the random source
// fails. The token itself is never logged; only the fact that one was
// generated.
func (sm *SecurityManager) GenerateSecureToken() (string, error) {
	n := sm.config.TokenLength
	if n <= 0 {
		// make() would panic on a negative length, and zero would yield an
		// empty token.
		return "", fmt.Errorf("token length must be positive, got %d", n)
	}

	raw := make([]byte, n)
	if _, err := rand.Read(raw); err != nil {
		return "", fmt.Errorf("failed to generate secure token: %w", err)
	}
	token := base64.RawURLEncoding.EncodeToString(raw)

	sm.logSecurityEvent(
		"token_generation",
		"low",
		"system",
		"generate_token",
		"Secure token generated",
	)
	return token, nil
}
// ValidateToken performs basic checks on a presented token and returns nil if
// they pass. It fails when token auth is disabled, when the token is shorter
// than TokenLength, or when the token has characters outside the format
// accepted by isValidTokenFormat. Every call is recorded via logSecurityEvent.
//
// NOTE(review): this checks only length and alphabet. It does not compare the
// token against any issued token and does not enforce TokenExpiry, so any
// well-formed string is accepted. Real verification still needs implementing.
func (sm *SecurityManager) ValidateToken(token, user string) error {
	defer sm.logSecurityEvent("token_validation", "medium", user,
		"validate_token", "Token validation attempt")

	switch {
	case !sm.config.EnableTokenAuth:
		return fmt.Errorf("token authentication is disabled")
	case len(token) < sm.config.TokenLength:
		return fmt.Errorf("invalid token length")
	case !sm.isValidTokenFormat(token):
		return fmt.Errorf("invalid token format")
	default:
		return nil
	}
}
// GetDefaultSecurityConfig returns a newly built, locked-down policy:
//   - outbound networking disabled;
//   - conda-style trusted channels, with package approval required;
//   - hardened container settings;
//   - token auth over HTTPS;
//   - auditing with real-time alerts.
//
// Every call allocates fresh slices and maps, so callers may mutate the result
// without affecting other callers.
func GetDefaultSecurityConfig() *EnhancedSecurityConfig {
	cfg := &EnhancedSecurityConfig{}

	// Network: no outbound access; loopback is the only allowlisted target.
	cfg.AllowNetwork = false
	cfg.AllowedHosts = []string{"localhost", "127.0.0.1"}
	// NOTE(review): hosts are matched as strings by ValidateNetworkAccess, so
	// the CIDR-style "0.0.0.0/0" entry never matches a real host. Confirm intent.
	cfg.BlockedHosts = []string{"0.0.0.0", "0.0.0.0/0"}
	cfg.EnableFirewall = true

	// Packages: trusted channels only, explicit approval, no PyPI/local installs.
	cfg.TrustedChannels = []string{"conda-forge", "defaults", "pytorch", "nvidia"}
	// defaultBlockedPackages is declared elsewhere in this package.
	cfg.BlockedPackages = append([]string{"aiohttp", "socket", "telnetlib"}, defaultBlockedPackages...)
	cfg.AllowedPackages = make(map[string]bool) // empty map = no explicit allowlist
	cfg.RequireApproval = true
	cfg.AutoApproveSafe = false
	cfg.MaxPackages = 50
	cfg.InstallTimeout = 5 * time.Minute
	cfg.AllowCondaForge = true
	cfg.AllowPyPI = false
	cfg.AllowLocal = false

	// Container hardening.
	cfg.ReadOnlyRoot = true
	cfg.DropCapabilities = []string{"ALL"}
	cfg.RunAsNonRoot = true
	cfg.EnableSeccomp = true
	cfg.NoNewPrivileges = true

	// Authentication: 32 random bytes per token, HTTPS only.
	cfg.EnableTokenAuth = true
	cfg.TokenLength = 32
	cfg.TokenExpiry = 24 * time.Hour
	cfg.RequireHTTPS = true
	cfg.SessionTimeout = 2 * time.Hour
	cfg.MaxFailedAttempts = 5
	cfg.LockoutDuration = 15 * time.Minute

	// File system.
	cfg.AllowedPaths = []string{"./workspace", "./data"}
	cfg.DeniedPaths = []string{"/etc", "/root", "/var", "/sys", "/proc"}
	cfg.MaxWorkspaceSize = "10G"
	cfg.AllowExecFrom = []string{"./workspace", "./data"}
	cfg.BlockExecFrom = []string{"/tmp", "/var/tmp"}

	// Resource ceilings.
	cfg.MaxMemoryLimit = "4G"
	cfg.MaxCPULimit = "2"
	cfg.MaxDiskUsage = "20G"
	cfg.MaxProcesses = 100

	// Logging and monitoring.
	cfg.SecurityLogLevel = "info"
	cfg.AuditEnabled = true
	cfg.RealTimeAlerts = true

	return cfg
}
// Helper methods

// isValidChannel reports whether channel names one of the configured
// TrustedChannels, compared case-insensitively.
func (sm *SecurityManager) isValidChannel(channel string) bool {
	trusted := sm.config.TrustedChannels
	for i := range trusted {
		if strings.EqualFold(channel, trusted[i]) {
			return true
		}
	}
	return false
}
// validPackageNameRE accepts package names made of ASCII letters, digits,
// '_', '.' and '-' that do not begin with '-' or '.'. The leading-character
// rule rejects values such as "-e", "--index-url", "." or "..", which an
// installer command line could interpret as an option or a local path rather
// than a package name. Compiled once at package initialization.
var validPackageNameRE = regexp.MustCompile(`^[A-Za-z0-9_][A-Za-z0-9_.-]*$`)

// isValidPackageName reports whether name is a syntactically acceptable
// package name (see validPackageNameRE). The empty string is rejected.
func (sm *SecurityManager) isValidPackageName(name string) bool {
	return validPackageNameRE.MatchString(name)
}
// validVersionRE matches MAJOR.MINOR[.PATCH] followed by an optional suffix of
// letters, digits, '.' and '-' (e.g. "1.2", "2.0.1", "1.0.0rc1", "1.2.3-beta").
// Compiled once at package initialization instead of on every call.
var validVersionRE = regexp.MustCompile(`^\d+\.\d+(\.\d+)?([a-zA-Z0-9.-]*)?$`)

// isValidVersion reports whether version has the basic dotted form described
// by validVersionRE. Single-component versions ("2") and versions containing
// '+' ("1.2+cpu") are rejected.
func (sm *SecurityManager) isValidVersion(version string) bool {
	return validVersionRE.MatchString(version)
}
// validPortRE matches 1-5 decimal digits without a sign or leading zero.
// Compiled once at package initialization.
var validPortRE = regexp.MustCompile(`^[1-9][0-9]{0,4}$`)

// isValidPort reports whether port is a decimal TCP/UDP port number in the
// range 1-65535, written without sign, whitespace or leading zeros.
func (sm *SecurityManager) isValidPort(port string) bool {
	if !validPortRE.MatchString(port) {
		return false
	}
	// The pattern guarantees at most five ASCII digits, so this cannot overflow.
	n := 0
	for _, d := range port {
		n = n*10 + int(d-'0')
	}
	return n <= 65535
}
// validTokenRE accepts base64url text (A-Z, a-z, 0-9, '-', '_'), optionally
// followed by up to two '=' padding characters. This covers both
// base64.RawURLEncoding and the padded base64.URLEncoding output.
var validTokenRE = regexp.MustCompile(`^[A-Za-z0-9_-]+={0,2}$`)

// isValidTokenFormat reports whether token is non-empty base64url text as
// described by validTokenRE. Accepting trailing padding keeps tokens produced
// with base64.URLEncoding (which pads, e.g. 32 bytes -> 43 chars + "=") valid.
func (sm *SecurityManager) isValidTokenFormat(token string) bool {
	return validTokenRE.MatchString(token)
}
// logSecurityEvent records a security event when auditing is enabled and does
// nothing when AuditEnabled is false.
//
// The event is written as an Info-level "Security Event" log entry with its
// fields as key/value pairs; Resource is always "jupyter". When RealTimeAlerts
// is enabled, events with severity "high" or "critical" are also passed to
// sendSecurityAlert.
func (sm *SecurityManager) logSecurityEvent(eventType, severity, user, action, description string) {
	cfg := sm.config
	if !cfg.AuditEnabled {
		return
	}

	ev := SecurityEvent{
		Timestamp:   time.Now(),
		EventType:   eventType,
		Severity:    severity,
		User:        user,
		Action:      action,
		Resource:    "jupyter",
		Description: description,
	}

	sm.logger.Info("Security Event",
		"event_type", ev.EventType,
		"severity", ev.Severity,
		"user", ev.User,
		"action", ev.Action,
		"resource", ev.Resource,
		"description", ev.Description,
		"timestamp", ev.Timestamp,
	)

	if !cfg.RealTimeAlerts {
		return
	}
	switch ev.Severity {
	case "high", "critical":
		sm.sendSecurityAlert(ev)
	}
}
// sendSecurityAlert raises an alert for an event. logSecurityEvent calls it for
// "high" and "critical" events. For now the alert is only a Warn-level
// "Security Alert" log entry; forwarding to an external monitoring system is
// not implemented here.
func (sm *SecurityManager) sendSecurityAlert(event SecurityEvent) {
	alertType, severity := event.EventType, event.Severity
	user, description := event.User, event.Description
	occurredAt := event.Timestamp

	sm.logger.Warn("Security Alert",
		"alert_type", alertType,
		"severity", severity,
		"user", user,
		"description", description,
		"timestamp", occurredAt,
	)
}
// HashPassword returns the lowercase hex encoding of the SHA-256 digest of
// password (64 characters). The same input always yields the same output.
//
// WARNING: a single unsalted SHA-256 is fast to brute-force and vulnerable to
// precomputed tables, so it is NOT suitable for storing user passwords.
// NOTE(review): migrate stored credentials to a slow, salted KDF (bcrypt,
// scrypt, Argon2id or PBKDF2). The output format is kept unchanged here
// because existing hashes may depend on it.
func (sm *SecurityManager) HashPassword(password string) string {
	h := sha256.New()
	h.Write([]byte(password)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(h.Sum(nil))
}
// Character-class patterns used by ValidatePassword. They are compiled once at
// package initialization rather than on every call, and they are
// Unicode-aware.
var (
	passwordUpperRE   = regexp.MustCompile(`\p{Lu}`)        // uppercase letter
	passwordLowerRE   = regexp.MustCompile(`\p{Ll}`)        // lowercase letter
	passwordDigitRE   = regexp.MustCompile(`\p{Nd}`)        // decimal digit
	passwordSpecialRE = regexp.MustCompile(`[\p{P}\p{S}]`) // punctuation or symbol
)

// ValidatePassword checks that password meets the minimum complexity policy
// and returns nil when it does. The policy requires:
//   - at least 8 characters, counted as Unicode code points rather than bytes;
//   - at least one uppercase letter;
//   - at least one lowercase letter;
//   - at least one decimal digit;
//   - at least one punctuation or symbol character (e.g. '!', '$', '-', '_', '+').
func (sm *SecurityManager) ValidatePassword(password string) error {
	// Count runes so multi-byte characters are not counted more than once.
	if len([]rune(password)) < 8 {
		return fmt.Errorf("password must be at least 8 characters long")
	}

	if !passwordUpperRE.MatchString(password) ||
		!passwordLowerRE.MatchString(password) ||
		!passwordDigitRE.MatchString(password) ||
		!passwordSpecialRE.MatchString(password) {
		return fmt.Errorf("password must contain uppercase, lowercase, digit, and special character")
	}

	return nil
}