430 lines
13 KiB
Go
430 lines
13 KiB
Go
package jupyter
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
)
|
|
|
|
// SecurityManager handles all security-related operations for Jupyter services
|
|
type SecurityManager struct {
|
|
logger *logging.Logger
|
|
config *EnhancedSecurityConfig
|
|
}
|
|
|
|
// EnhancedSecurityConfig is the complete security policy for a Jupyter
// service. It is grouped into network, package, container, authentication,
// file-system, resource and logging settings. Duration fields are
// (un)marshalled by encoding/json as integer nanoseconds.
type EnhancedSecurityConfig struct {
	// --- Network ---
	// AllowNetwork gates all outbound access checked by ValidateNetworkAccess.
	// AllowedHosts, when non-empty, is an allowlist of hosts.
	// BlockedHosts is always checked first.
	AllowNetwork   bool     `json:"allow_network"`
	AllowedHosts   []string `json:"allowed_hosts"`
	BlockedHosts   []string `json:"blocked_hosts"`
	EnableFirewall bool     `json:"enable_firewall"`

	// --- Packages ---
	// TrustedChannels lists acceptable install channels.
	// BlockedPackages is a denylist.
	// AllowedPackages, when non-empty, is an allowlist keyed by package name.
	// The remaining fields are policy knobs not consulted by the validators
	// in this file (NOTE(review): presumably read by the package manager).
	TrustedChannels []string        `json:"trusted_channels"`
	BlockedPackages []string        `json:"blocked_packages"`
	AllowedPackages map[string]bool `json:"allowed_packages"`
	RequireApproval bool            `json:"require_approval"`
	AutoApproveSafe bool            `json:"auto_approve_safe"`
	MaxPackages     int             `json:"max_packages"`
	InstallTimeout  time.Duration   `json:"install_timeout"`
	AllowCondaForge bool            `json:"allow_conda_forge"`
	AllowPyPI       bool            `json:"allow_pypi"`
	AllowLocal      bool            `json:"allow_local"`

	// --- Container hardening ---
	ReadOnlyRoot     bool     `json:"read_only_root"`
	DropCapabilities []string `json:"drop_capabilities"`
	RunAsNonRoot     bool     `json:"run_as_non_root"`
	EnableSeccomp    bool     `json:"enable_seccomp"`
	NoNewPrivileges  bool     `json:"no_new_privileges"`

	// --- Authentication ---
	// TokenLength is the number of random bytes behind each generated token.
	// It is also the minimum accepted token length, in characters.
	EnableTokenAuth   bool          `json:"enable_token_auth"`
	TokenLength       int           `json:"token_length"`
	TokenExpiry       time.Duration `json:"token_expiry"`
	RequireHTTPS      bool          `json:"require_https"`
	SessionTimeout    time.Duration `json:"session_timeout"`
	MaxFailedAttempts int           `json:"max_failed_attempts"`
	LockoutDuration   time.Duration `json:"lockout_duration"`

	// --- File system ---
	// AllowedPaths, when non-empty, restricts workspace locations.
	// DeniedPaths always excludes locations.
	AllowedPaths     []string `json:"allowed_paths"`
	DeniedPaths      []string `json:"denied_paths"`
	MaxWorkspaceSize string   `json:"max_workspace_size"`
	AllowExecFrom    []string `json:"allow_exec_from"`
	BlockExecFrom    []string `json:"block_exec_from"`

	// --- Resource limits (size strings such as "4G") ---
	MaxMemoryLimit string `json:"max_memory_limit"`
	MaxCPULimit    string `json:"max_cpu_limit"`
	MaxDiskUsage   string `json:"max_disk_usage"`
	MaxProcesses   int    `json:"max_processes"`

	// --- Logging and monitoring ---
	// AuditEnabled turns audit logging on.
	// RealTimeAlerts additionally raises alerts for high/critical events.
	SecurityLogLevel string `json:"security_log_level"`
	AuditEnabled     bool   `json:"audit_enabled"`
	RealTimeAlerts   bool   `json:"real_time_alerts"`
}
|
|
|
|
// SecurityEvent is one audit record, as assembled by logSecurityEvent before
// it is written to the logger or raised as an alert.
type SecurityEvent struct {
	Timestamp   time.Time `json:"timestamp"`
	EventType   string    `json:"event_type"`
	Severity    string    `json:"severity"` // one of: low, medium, high, critical
	User        string    `json:"user"`
	Action      string    `json:"action"`
	Resource    string    `json:"resource"`
	Description string    `json:"description"`
	// Optional client details, omitted from JSON when empty.
	IPAddress string `json:"ip_address,omitempty"`
	UserAgent string `json:"user_agent,omitempty"`
}
|
|
|
|
// NewSecurityManager creates a new security manager
|
|
func NewSecurityManager(logger *logging.Logger, config *EnhancedSecurityConfig) *SecurityManager {
|
|
return &SecurityManager{
|
|
logger: logger,
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
// ValidatePackageRequest validates a package installation request
|
|
func (sm *SecurityManager) ValidatePackageRequest(req *PackageRequest) error {
|
|
// Log security event
|
|
defer sm.logSecurityEvent("package_validation", "medium", req.RequestedBy,
|
|
fmt.Sprintf("validate_package:%s", req.PackageName),
|
|
fmt.Sprintf("Package: %s, Version: %s, Channel: %s", req.PackageName, req.Version, req.Channel))
|
|
|
|
// Check if package is blocked
|
|
for _, blocked := range sm.config.BlockedPackages {
|
|
if strings.EqualFold(blocked, req.PackageName) {
|
|
return fmt.Errorf("package '%s' is blocked by security policy", req.PackageName)
|
|
}
|
|
}
|
|
|
|
// Check if package is explicitly allowed (if allowlist exists)
|
|
if len(sm.config.AllowedPackages) > 0 {
|
|
if !sm.config.AllowedPackages[req.PackageName] {
|
|
return fmt.Errorf("package '%s' is not in the allowed packages list", req.PackageName)
|
|
}
|
|
}
|
|
|
|
// Validate channel
|
|
if req.Channel != "" {
|
|
if !sm.isValidChannel(req.Channel) {
|
|
return fmt.Errorf("channel '%s' is not trusted", req.Channel)
|
|
}
|
|
}
|
|
|
|
// Check package name format
|
|
if !sm.isValidPackageName(req.PackageName) {
|
|
return fmt.Errorf("package name '%s' contains invalid characters", req.PackageName)
|
|
}
|
|
|
|
// Check version format if specified
|
|
if req.Version != "" && !sm.isValidVersion(req.Version) {
|
|
return fmt.Errorf("version '%s' is not in valid format", req.Version)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidateWorkspaceAccess validates workspace path access
|
|
func (sm *SecurityManager) ValidateWorkspaceAccess(workspacePath, user string) error {
|
|
defer sm.logSecurityEvent("workspace_access", "medium", user,
|
|
fmt.Sprintf("access_workspace:%s", workspacePath),
|
|
fmt.Sprintf("Workspace access attempt: %s", workspacePath))
|
|
|
|
// Clean path to prevent directory traversal
|
|
cleanPath := filepath.Clean(workspacePath)
|
|
|
|
// Check for path traversal attempts
|
|
if strings.Contains(workspacePath, "..") {
|
|
return fmt.Errorf("path traversal detected in workspace path: %s", workspacePath)
|
|
}
|
|
|
|
// Check if path is in allowed paths
|
|
if len(sm.config.AllowedPaths) > 0 {
|
|
allowed := false
|
|
for _, allowedPath := range sm.config.AllowedPaths {
|
|
if strings.HasPrefix(cleanPath, allowedPath) {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
if !allowed {
|
|
return fmt.Errorf("workspace path '%s' is not in allowed paths", cleanPath)
|
|
}
|
|
}
|
|
|
|
// Check if path is in denied paths
|
|
for _, deniedPath := range sm.config.DeniedPaths {
|
|
if strings.HasPrefix(cleanPath, deniedPath) {
|
|
return fmt.Errorf("workspace path '%s' is in denied paths", cleanPath)
|
|
}
|
|
}
|
|
|
|
// Check if workspace exists and is accessible
|
|
if _, err := os.Stat(cleanPath); os.IsNotExist(err) {
|
|
return fmt.Errorf("workspace path '%s' does not exist", cleanPath)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidateNetworkAccess validates network access requests
|
|
func (sm *SecurityManager) ValidateNetworkAccess(host, port, user string) error {
|
|
defer sm.logSecurityEvent("network_access", "high", user,
|
|
fmt.Sprintf("network_access:%s:%s", host, port),
|
|
fmt.Sprintf("Network access attempt: %s:%s", host, port))
|
|
|
|
if !sm.config.AllowNetwork {
|
|
return fmt.Errorf("network access is disabled by security policy")
|
|
}
|
|
|
|
// Check if host is blocked
|
|
for _, blockedHost := range sm.config.BlockedHosts {
|
|
if strings.EqualFold(blockedHost, host) || strings.HasSuffix(host, blockedHost) {
|
|
return fmt.Errorf("host '%s' is blocked by security policy", host)
|
|
}
|
|
}
|
|
|
|
// Check if host is allowed (if allowlist exists)
|
|
if len(sm.config.AllowedHosts) > 0 {
|
|
allowed := false
|
|
for _, allowedHost := range sm.config.AllowedHosts {
|
|
if strings.EqualFold(allowedHost, host) || strings.HasSuffix(host, allowedHost) {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
if !allowed {
|
|
return fmt.Errorf("host '%s' is not in allowed hosts list", host)
|
|
}
|
|
}
|
|
|
|
// Validate port range
|
|
if port != "" {
|
|
if !sm.isValidPort(port) {
|
|
return fmt.Errorf("port '%s' is not in allowed range", port)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GenerateSecureToken generates a cryptographically secure token
|
|
func (sm *SecurityManager) GenerateSecureToken() (string, error) {
|
|
bytes := make([]byte, sm.config.TokenLength)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
return "", fmt.Errorf("failed to generate secure token: %w", err)
|
|
}
|
|
|
|
token := base64.URLEncoding.EncodeToString(bytes)
|
|
|
|
// Log token generation (without the token itself for security)
|
|
sm.logSecurityEvent(
|
|
"token_generation",
|
|
"low",
|
|
"system",
|
|
"generate_token",
|
|
"Secure token generated",
|
|
)
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// ValidateToken validates a token and checks expiry
|
|
func (sm *SecurityManager) ValidateToken(token, user string) error {
|
|
defer sm.logSecurityEvent("token_validation", "medium", user,
|
|
"validate_token", "Token validation attempt")
|
|
|
|
if !sm.config.EnableTokenAuth {
|
|
return fmt.Errorf("token authentication is disabled")
|
|
}
|
|
|
|
if len(token) < sm.config.TokenLength {
|
|
return fmt.Errorf("invalid token length")
|
|
}
|
|
|
|
// Additional token validation logic would go here
|
|
// For now, just check basic format
|
|
if !sm.isValidTokenFormat(token) {
|
|
return fmt.Errorf("invalid token format")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetDefaultSecurityConfig returns the default enhanced security configuration
|
|
func GetDefaultSecurityConfig() *EnhancedSecurityConfig {
|
|
return &EnhancedSecurityConfig{
|
|
// Network Security
|
|
AllowNetwork: false,
|
|
AllowedHosts: []string{"localhost", "127.0.0.1"},
|
|
BlockedHosts: []string{"0.0.0.0", "0.0.0.0/0"},
|
|
EnableFirewall: true,
|
|
|
|
// Package Security
|
|
TrustedChannels: []string{"conda-forge", "defaults", "pytorch", "nvidia"},
|
|
BlockedPackages: append([]string{"aiohttp", "socket", "telnetlib"}, defaultBlockedPackages...),
|
|
AllowedPackages: make(map[string]bool), // Empty means no explicit allowlist
|
|
RequireApproval: true,
|
|
AutoApproveSafe: false,
|
|
MaxPackages: 50,
|
|
InstallTimeout: 5 * time.Minute,
|
|
AllowCondaForge: true,
|
|
AllowPyPI: false,
|
|
AllowLocal: false,
|
|
|
|
// Container Security
|
|
ReadOnlyRoot: true,
|
|
DropCapabilities: []string{"ALL"},
|
|
RunAsNonRoot: true,
|
|
EnableSeccomp: true,
|
|
NoNewPrivileges: true,
|
|
|
|
// Authentication Security
|
|
EnableTokenAuth: true,
|
|
TokenLength: 32,
|
|
TokenExpiry: 24 * time.Hour,
|
|
RequireHTTPS: true,
|
|
SessionTimeout: 2 * time.Hour,
|
|
MaxFailedAttempts: 5,
|
|
LockoutDuration: 15 * time.Minute,
|
|
|
|
// File System Security
|
|
AllowedPaths: []string{"./workspace", "./data"},
|
|
DeniedPaths: []string{"/etc", "/root", "/var", "/sys", "/proc"},
|
|
MaxWorkspaceSize: "10G",
|
|
AllowExecFrom: []string{"./workspace", "./data"},
|
|
BlockExecFrom: []string{"/tmp", "/var/tmp"},
|
|
|
|
// Resource Security
|
|
MaxMemoryLimit: "4G",
|
|
MaxCPULimit: "2",
|
|
MaxDiskUsage: "20G",
|
|
MaxProcesses: 100,
|
|
|
|
// Logging & Monitoring
|
|
SecurityLogLevel: "info",
|
|
AuditEnabled: true,
|
|
RealTimeAlerts: true,
|
|
}
|
|
}
|
|
|
|
// Helper methods
|
|
|
|
func (sm *SecurityManager) isValidChannel(channel string) bool {
|
|
for _, trusted := range sm.config.TrustedChannels {
|
|
if strings.EqualFold(trusted, channel) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (sm *SecurityManager) isValidPackageName(name string) bool {
|
|
// Package names should only contain alphanumeric characters, underscores, hyphens, and dots
|
|
matched, _ := regexp.MatchString(`^[a-zA-Z0-9_.-]+$`, name)
|
|
return matched
|
|
}
|
|
|
|
func (sm *SecurityManager) isValidVersion(version string) bool {
|
|
// Basic semantic version validation
|
|
matched, _ := regexp.MatchString(`^\d+\.\d+(\.\d+)?([a-zA-Z0-9.-]*)?$`, version)
|
|
return matched
|
|
}
|
|
|
|
func (sm *SecurityManager) isValidPort(port string) bool {
|
|
// Basic port validation (1-65535)
|
|
matched, _ := regexp.MatchString(`^[1-9][0-9]{0,4}$`, port)
|
|
if !matched {
|
|
return false
|
|
}
|
|
|
|
// Additional range check would go here
|
|
return true
|
|
}
|
|
|
|
func (sm *SecurityManager) isValidTokenFormat(token string) bool {
|
|
// Base64 URL encoded token validation
|
|
matched, _ := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, token)
|
|
return matched
|
|
}
|
|
|
|
func (sm *SecurityManager) logSecurityEvent(eventType, severity, user, action, description string) {
|
|
if !sm.config.AuditEnabled {
|
|
return
|
|
}
|
|
|
|
event := SecurityEvent{
|
|
Timestamp: time.Now(),
|
|
EventType: eventType,
|
|
Severity: severity,
|
|
User: user,
|
|
Action: action,
|
|
Resource: "jupyter",
|
|
Description: description,
|
|
}
|
|
|
|
// Log the security event
|
|
sm.logger.Info("Security Event",
|
|
"event_type", event.EventType,
|
|
"severity", event.Severity,
|
|
"user", event.User,
|
|
"action", event.Action,
|
|
"resource", event.Resource,
|
|
"description", event.Description,
|
|
"timestamp", event.Timestamp,
|
|
)
|
|
|
|
// Send real-time alert if enabled and severity is high or critical
|
|
if sm.config.RealTimeAlerts && (event.Severity == "high" || event.Severity == "critical") {
|
|
sm.sendSecurityAlert(event)
|
|
}
|
|
}
|
|
|
|
func (sm *SecurityManager) sendSecurityAlert(event SecurityEvent) {
|
|
// Implementation would send alerts to monitoring systems
|
|
sm.logger.Warn("Security Alert",
|
|
"alert_type", event.EventType,
|
|
"severity", event.Severity,
|
|
"user", event.User,
|
|
"description", event.Description,
|
|
"timestamp", event.Timestamp,
|
|
)
|
|
}
|
|
|
|
// HashPassword securely hashes a password using SHA-256
|
|
func (sm *SecurityManager) HashPassword(password string) string {
|
|
hash := sha256.Sum256([]byte(password))
|
|
return hex.EncodeToString(hash[:])
|
|
}
|
|
|
|
// ValidatePassword validates a password against security requirements
|
|
func (sm *SecurityManager) ValidatePassword(password string) error {
|
|
if len(password) < 8 {
|
|
return fmt.Errorf("password must be at least 8 characters long")
|
|
}
|
|
|
|
hasUpper := regexp.MustCompile(`[A-Z]`).MatchString(password)
|
|
hasLower := regexp.MustCompile(`[a-z]`).MatchString(password)
|
|
hasDigit := regexp.MustCompile(`[0-9]`).MatchString(password)
|
|
hasSpecial := regexp.MustCompile(`[!@#$%^&*(),.?":{}|<>]`).MatchString(password)
|
|
|
|
if !hasUpper || !hasLower || !hasDigit || !hasSpecial {
|
|
return fmt.Errorf("password must contain uppercase, lowercase, digit, and special character")
|
|
}
|
|
|
|
return nil
|
|
}
|