package jupyter

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
)

const (
	// statusPending is the Status value of a request that has been created
	// but not yet approved or rejected.
	statusPending = "pending"
)

// PackageManager manages package installations in Jupyter workspaces.
//
// It validates installation requests against channel, allowlist and blocklist
// policy and persists requests and install records as JSON files under
// packageCachePath.
type PackageManager struct {
	// logger receives informational and warning messages.
	logger *logging.Logger
	// trustedChannels lists the channels a request may name.
	trustedChannels []string
	// allowedPackages, when non-empty, restricts requests to names mapped to true.
	allowedPackages map[string]bool
	// blockedPackages lists names that are always rejected (compared case-insensitively).
	blockedPackages []string
	// workspacePath is recorded on every request created by this manager.
	workspacePath string
	// packageCachePath is the directory holding request_*.json and installed_*.json files.
	packageCachePath string
}

// PackageConfig defines package management configuration.
//
// Only TrustedChannels, AllowedPackages and BlockedPackages are read by the
// code in this file; the remaining fields are presumably consumed elsewhere
// (NOTE(review): confirm against callers).
type PackageConfig struct {
	// TrustedChannels lists channels requests may use; when empty the manager
	// falls back to a built-in default list.
	TrustedChannels []string `json:"trusted_channels"`
	// AllowedPackages, when non-empty, is the allowlist; only names mapped to true pass.
	AllowedPackages map[string]bool `json:"allowed_packages"`
	// BlockedPackages lists package names that are always rejected.
	BlockedPackages []string      `json:"blocked_packages"`
	RequireApproval bool          `json:"require_approval"`
	AutoApproveSafe bool          `json:"auto_approve_safe"`
	MaxPackages     int           `json:"max_packages"`
	InstallTimeout  time.Duration `json:"install_timeout"`
	AllowCondaForge bool          `json:"allow_conda_forge"`
	AllowPyPI       bool          `json:"allow_pypi"`
	AllowLocal      bool          `json:"allow_local"`
}

// PackageRequest represents a package installation request.
//
// Requests are serialized to JSON in the package cache directory and move
// through the states listed on Status.
type PackageRequest struct {
	PackageName   string    `json:"package_name"`
	Version       string    `json:"version,omitempty"`
	Channel       string    `json:"channel,omitempty"`
	RequestedBy   string    `json:"requested_by"`
	WorkspacePath string    `json:"workspace_path"`
	Timestamp     time.Time `json:"timestamp"`
	Status        string    `json:"status"` // pending, approved, rejected, installed, failed
	// RejectionReason is set when Status is "rejected".
	RejectionReason string `json:"rejection_reason,omitempty"`
	// ApprovalUser and ApprovalTime are set when the request is approved.
	ApprovalUser string `json:"approval_user,omitempty"`
	// NOTE: encoding/json never treats a struct as empty, so omitempty does not
	// suppress a zero ApprovalTime; it is still written to the file.
	ApprovalTime time.Time `json:"approval_time,omitempty"`
}

// PackageInfo contains information about an installed package.
//
// NOTE(review): the install path in this file fills in Name, Version, Channel,
// InstalledAt and InstalledBy only; Size, Dependencies and Metadata are left at
// their zero values unless populated elsewhere.
type PackageInfo struct {
	Name         string            `json:"name"`
	Version      string            `json:"version"`
	Channel      string            `json:"channel"`
	InstalledAt  time.Time         `json:"installed_at"`
	InstalledBy  string            `json:"installed_by"`
	Size         string            `json:"size"`
	Dependencies []string          `json:"dependencies"`
	Metadata     map[string]string `json:"metadata"`
}

// NewPackageManager creates a new package manager
func NewPackageManager( logger *logging.Logger, config *PackageConfig, workspacePath string, ) (*PackageManager, error) { pm := &PackageManager{ logger: logger, trustedChannels: config.TrustedChannels, allowedPackages: config.AllowedPackages, blockedPackages: config.BlockedPackages, workspacePath: workspacePath, packageCachePath: filepath.Join(workspacePath, ".package_cache"), } // Create package cache directory if err := os.MkdirAll(pm.packageCachePath, 0750); err != nil { return nil, fmt.Errorf("failed to create package cache: %w", err) } // Initialize default trusted channels if none provided if len(pm.trustedChannels) == 0 { pm.trustedChannels = []string{ "conda-forge", "defaults", "pytorch", "nvidia", } } return pm, nil } // ValidatePackageRequest validates a package installation request func (pm *PackageManager) ValidatePackageRequest(req *PackageRequest) error { // Check if package is blocked for _, blocked := range pm.blockedPackages { if strings.EqualFold(req.PackageName, blocked) { return fmt.Errorf("package '%s' is blocked for security reasons", req.PackageName) } } // Check if channel is trusted if req.Channel != "" { if !pm.isChannelTrusted(req.Channel) { return fmt.Errorf( "channel '%s' is not trusted. 
Allowed channels: %v", req.Channel, pm.trustedChannels, ) } } else { // Default to conda-forge if no channel specified req.Channel = "conda-forge" } // Check package against allowlist if configured if len(pm.allowedPackages) > 0 { if !pm.allowedPackages[req.PackageName] { return fmt.Errorf("package '%s' is not in the allowlist", req.PackageName) } } // Validate package name format if !pm.isValidPackageName(req.PackageName) { return fmt.Errorf("invalid package name format: '%s'", req.PackageName) } return nil } // isChannelTrusted checks if a channel is in the trusted list func (pm *PackageManager) isChannelTrusted(channel string) bool { for _, trusted := range pm.trustedChannels { if strings.EqualFold(channel, trusted) { return true } } return false } func (pm *PackageManager) isValidPackageName(name string) bool { if name == "" { return false } for _, c := range name { if ('a' > c || c > 'z') && ('A' > c || c > 'Z') && ('0' > c || c > '9') && c != '-' && c != '_' && c != '.' { return false } } return true } // RequestPackage creates a package installation request func (pm *PackageManager) RequestPackage( packageName, version, channel, requestedBy string, ) (*PackageRequest, error) { req := &PackageRequest{ PackageName: strings.ToLower(strings.TrimSpace(packageName)), Version: version, Channel: channel, RequestedBy: requestedBy, WorkspacePath: pm.workspacePath, Timestamp: time.Now(), Status: statusPending, } // Validate the request if err := pm.ValidatePackageRequest(req); err != nil { req.Status = "rejected" req.RejectionReason = err.Error() return req, err } // Save request to cache if err := pm.savePackageRequest(req); err != nil { return nil, fmt.Errorf("failed to save package request: %w", err) } pm.logger.Info("package installation request created", "package", req.PackageName, "version", req.Version, "channel", req.Channel, "requested_by", req.RequestedBy) return req, nil } // ApprovePackageRequest approves a pending package request func (pm *PackageManager) 
ApprovePackageRequest(requestID, approvalUser string) error { req, err := pm.loadPackageRequest(requestID) if err != nil { return fmt.Errorf("failed to load package request: %w", err) } if req.Status != statusPending { return fmt.Errorf("package request is not pending (current status: %s)", req.Status) } req.Status = "approved" req.ApprovalUser = approvalUser req.ApprovalTime = time.Now() // Save updated request if err := pm.savePackageRequest(req); err != nil { return fmt.Errorf("failed to save approved request: %w", err) } pm.logger.Info("package request approved", "package", req.PackageName, "request_id", requestID, "approved_by", approvalUser) return nil } // RejectPackageRequest rejects a pending package request func (pm *PackageManager) RejectPackageRequest(requestID, reason string) error { req, err := pm.loadPackageRequest(requestID) if err != nil { return fmt.Errorf("failed to load package request: %w", err) } if req.Status != statusPending { return fmt.Errorf("package request is not pending (current status: %s)", req.Status) } req.Status = "rejected" req.RejectionReason = reason // Save updated request if err := pm.savePackageRequest(req); err != nil { return fmt.Errorf("failed to save rejected request: %w", err) } pm.logger.Info("package request rejected", "package", req.PackageName, "request_id", requestID, "reason", reason) return nil } // InstallPackage installs an approved package func (pm *PackageManager) InstallPackage(requestID string) error { req, err := pm.loadPackageRequest(requestID) if err != nil { return fmt.Errorf("failed to load package request: %w", err) } if req.Status != "approved" { return fmt.Errorf("package request is not approved (current status: %s)", req.Status) } // Install package using conda installCmd := pm.buildInstallCommand(req) pm.logger.Info("installing package", "package", req.PackageName, "version", req.Version, "channel", req.Channel, "command", installCmd) // Execute installation (this would be implemented with proper 
process execution) // For now, simulate successful installation req.Status = "installed" // Save package info packageInfo := &PackageInfo{ Name: req.PackageName, Version: req.Version, Channel: req.Channel, InstalledAt: time.Now(), InstalledBy: req.RequestedBy, } if err := pm.savePackageInfo(packageInfo); err != nil { pm.logger.Warn("failed to save package info", "error", err) } // Save updated request if err := pm.savePackageRequest(req); err != nil { return fmt.Errorf("failed to save installed request: %w", err) } pm.logger.Info("package installed successfully", "package", req.PackageName, "version", req.Version) return nil } // buildInstallCommand builds the conda install command func (pm *PackageManager) buildInstallCommand(req *PackageRequest) string { cmd := []string{"conda", "install", "-y"} // Add channel if req.Channel != "" { cmd = append(cmd, "-c", req.Channel) } // Add package with version if req.Version != "" { cmd = append(cmd, fmt.Sprintf("%s=%s", req.PackageName, req.Version)) } else { cmd = append(cmd, req.PackageName) } return strings.Join(cmd, " ") } // ListPendingRequests returns all pending package requests func (pm *PackageManager) ListPendingRequests() ([]*PackageRequest, error) { requests, err := pm.loadAllPackageRequests() if err != nil { return nil, err } var pending []*PackageRequest for _, req := range requests { if req.Status == statusPending { pending = append(pending, req) } } return pending, nil } // ListInstalledPackages returns all installed packages in the workspace func (pm *PackageManager) ListInstalledPackages() ([]*PackageInfo, error) { return pm.loadAllPackageInfo() } // GetPackageRequest retrieves a specific package request func (pm *PackageManager) GetPackageRequest(requestID string) (*PackageRequest, error) { return pm.loadPackageRequest(requestID) } // savePackageRequest saves a package request to cache func (pm *PackageManager) savePackageRequest(req *PackageRequest) error { requestFile := 
filepath.Join(pm.packageCachePath, fmt.Sprintf("request_%s.json", req.PackageName)) data, err := json.MarshalIndent(req, "", " ") if err != nil { return err } return os.WriteFile(requestFile, data, 0600) } // loadPackageRequest loads a package request from cache func (pm *PackageManager) loadPackageRequest(requestID string) (*PackageRequest, error) { requestFile := filepath.Join(pm.packageCachePath, fmt.Sprintf("request_%s.json", requestID)) data, err := os.ReadFile(requestFile) if err != nil { return nil, err } var req PackageRequest if err := json.Unmarshal(data, &req); err != nil { return nil, err } return &req, nil } // loadAllPackageRequests loads all package requests from cache func (pm *PackageManager) loadAllPackageRequests() ([]*PackageRequest, error) { files, err := filepath.Glob(filepath.Join(pm.packageCachePath, "request_*.json")) if err != nil { return nil, err } var requests []*PackageRequest for _, file := range files { data, err := os.ReadFile(file) if err != nil { pm.logger.Warn("failed to read request file", "file", file, "error", err) continue } var req PackageRequest if err := json.Unmarshal(data, &req); err != nil { pm.logger.Warn("failed to parse request file", "file", file, "error", err) continue } requests = append(requests, &req) } return requests, nil } // savePackageInfo saves package information to cache func (pm *PackageManager) savePackageInfo(info *PackageInfo) error { infoFile := filepath.Join(pm.packageCachePath, fmt.Sprintf("installed_%s.json", info.Name)) data, err := json.MarshalIndent(info, "", " ") if err != nil { return err } return os.WriteFile(infoFile, data, 0600) } // loadAllPackageInfo loads all installed package information func (pm *PackageManager) loadAllPackageInfo() ([]*PackageInfo, error) { files, err := filepath.Glob(filepath.Join(pm.packageCachePath, "installed_*.json")) if err != nil { return nil, err } var packages []*PackageInfo for _, file := range files { data, err := os.ReadFile(file) if err != nil { 
pm.logger.Warn("failed to read package info file", "file", file, "error", err) continue } var info PackageInfo if err := json.Unmarshal(data, &info); err != nil { pm.logger.Warn("failed to parse package info file", "file", file, "error", err) continue } packages = append(packages, &info) } return packages, nil } // GetDefaultPackageConfig returns default package management configuration func GetDefaultPackageConfig() *PackageConfig { return &PackageConfig{ TrustedChannels: []string{ "conda-forge", "defaults", "pytorch", "nvidia", }, AllowedPackages: make(map[string]bool), // Empty means all packages allowed BlockedPackages: append([]string{}, defaultBlockedPackages...), RequireApproval: false, AutoApproveSafe: true, MaxPackages: 100, InstallTimeout: 5 * time.Minute, AllowCondaForge: true, AllowPyPI: false, AllowLocal: false, } }