fetch_ml/internal/jupyter/package_manager.go

465 lines
13 KiB
Go

package jupyter
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// statusPending is the Status value of a request that is awaiting approval.
const statusPending = "pending"
// PackageManager manages package installations in Jupyter workspaces.
//
// It validates installation requests against channel and package policy and
// stores requests (request_*.json) and install records (installed_*.json) as
// JSON files under packageCachePath.
//
// NOTE(review): there is no locking around the fields or the cache files;
// confirm callers serialize concurrent use.
type PackageManager struct {
// logger receives info/warn events for request and install activity.
logger *logging.Logger
// trustedChannels lists the channels accepted by isChannelTrusted
// (compared case-insensitively).
trustedChannels []string
// allowedPackages, when non-empty, acts as an allowlist: only names mapped
// to true pass validation. An empty map disables the allowlist.
allowedPackages map[string]bool
// blockedPackages names packages that are always rejected (compared
// case-insensitively).
blockedPackages []string
// workspacePath is the workspace directory passed to NewPackageManager.
workspacePath string
// packageCachePath is <workspacePath>/.package_cache.
packageCachePath string
}
// PackageConfig defines package management configuration.
// The json tags allow it to be (de)serialized as JSON; see
// GetDefaultPackageConfig for the default values.
type PackageConfig struct {
// TrustedChannels are the channels requests may use. When empty,
// NewPackageManager falls back to its built-in default channel list.
TrustedChannels []string `json:"trusted_channels"`
// AllowedPackages is an optional allowlist. An empty map allows every
// package that is not blocked; otherwise only names mapped to true are allowed.
AllowedPackages map[string]bool `json:"allowed_packages"`
// BlockedPackages are always rejected, matched case-insensitively.
BlockedPackages []string `json:"blocked_packages"`
// NOTE(review): the fields below are not read by NewPackageManager or by any
// other code visible in this file; confirm where (or whether) they are enforced.
RequireApproval bool `json:"require_approval"`
AutoApproveSafe bool `json:"auto_approve_safe"`
MaxPackages int `json:"max_packages"`
// InstallTimeout is a time.Duration, which encoding/json encodes as an
// integer number of nanoseconds.
InstallTimeout time.Duration `json:"install_timeout"`
AllowCondaForge bool `json:"allow_conda_forge"`
AllowPyPI bool `json:"allow_pypi"`
AllowLocal bool `json:"allow_local"`
}
// PackageRequest represents a package installation request.
//
// Requests are saved as request_<PackageName>.json in the package cache and
// loaded back with the same file-name pattern. The package name therefore
// doubles as the request ID accepted by ApprovePackageRequest,
// RejectPackageRequest, InstallPackage and GetPackageRequest.
type PackageRequest struct {
PackageName string `json:"package_name"`
Version string `json:"version,omitempty"`
// Channel is set to "conda-forge" by ValidatePackageRequest when left empty.
Channel string `json:"channel,omitempty"`
RequestedBy string `json:"requested_by"`
WorkspacePath string `json:"workspace_path"`
Timestamp time.Time `json:"timestamp"`
// Status transitions performed in this file: pending -> approved or
// rejected, and approved -> installed. NOTE(review): "failed" is listed below
// but is not assigned anywhere in the visible code.
Status string `json:"status"` // pending, approved, rejected, installed, failed
RejectionReason string `json:"rejection_reason,omitempty"`
ApprovalUser string `json:"approval_user,omitempty"`
// NOTE: encoding/json ignores omitempty on struct-typed fields such as
// time.Time. An unapproved request therefore still serializes the zero
// time ("0001-01-01T00:00:00Z") here.
ApprovalTime time.Time `json:"approval_time,omitempty"`
}
// PackageInfo contains information about an installed package.
// It is saved as installed_<Name>.json in the package cache.
type PackageInfo struct {
Name string `json:"name"`
Version string `json:"version"`
Channel string `json:"channel"`
InstalledAt time.Time `json:"installed_at"`
// InstalledBy is copied by InstallPackage from the originating request's
// RequestedBy, i.e. it records the requester, not the approver.
InstalledBy string `json:"installed_by"`
// NOTE(review): Size, Dependencies and Metadata are not populated by
// InstallPackage in the visible code. Nil Dependencies/Metadata serialize
// as JSON null.
Size string `json:"size"`
Dependencies []string `json:"dependencies"`
Metadata map[string]string `json:"metadata"`
}
// NewPackageManager creates a PackageManager for workspacePath and makes sure
// its package cache directory (<workspacePath>/.package_cache) exists.
//
// Only the channel, allowlist and blocklist settings of config are used here.
// A nil config is treated as GetDefaultPackageConfig() rather than causing a
// nil-pointer panic. An empty TrustedChannels list falls back to the default
// trusted channels.
//
// It returns an error if the cache directory cannot be created.
func NewPackageManager(
	logger *logging.Logger,
	config *PackageConfig,
	workspacePath string,
) (*PackageManager, error) {
	if config == nil {
		// Without this, the field reads below would dereference nil.
		config = GetDefaultPackageConfig()
	}

	pm := &PackageManager{
		logger:           logger,
		trustedChannels:  config.TrustedChannels,
		allowedPackages:  config.AllowedPackages,
		blockedPackages:  config.BlockedPackages,
		workspacePath:    workspacePath,
		packageCachePath: filepath.Join(workspacePath, ".package_cache"),
	}

	// Create the package cache directory (owner rwx, group r-x, others none).
	if err := os.MkdirAll(pm.packageCachePath, 0750); err != nil {
		return nil, fmt.Errorf("failed to create package cache: %w", err)
	}

	if len(pm.trustedChannels) == 0 {
		// Reuse the canonical default list so the two copies cannot drift apart.
		pm.trustedChannels = GetDefaultPackageConfig().TrustedChannels
	}

	return pm, nil
}
// ValidatePackageRequest checks req against the manager's package policy and
// returns an error describing the first violation found.
//
// Checks, in order:
//   - the package name is not blocklisted (case-insensitive);
//   - the channel is trusted. An empty channel is first set to "conda-forge"
//     (req.Channel is updated in place) and then checked like any other channel;
//   - when an allowlist is configured, the package is enabled in it
//     (case-insensitive, consistent with the blocklist);
//   - the package name and, if set, the version have a safe format.
func (pm *PackageManager) ValidatePackageRequest(req *PackageRequest) error {
	for _, blocked := range pm.blockedPackages {
		if strings.EqualFold(req.PackageName, blocked) {
			return fmt.Errorf("package '%s' is blocked for security reasons", req.PackageName)
		}
	}

	// Apply the default before the trust check. Previously the default skipped
	// the check, so a manager whose trusted list omits conda-forge still used
	// it whenever the caller left the channel blank.
	if req.Channel == "" {
		req.Channel = "conda-forge"
	}
	if !pm.isChannelTrusted(req.Channel) {
		return fmt.Errorf(
			"channel '%s' is not trusted. Allowed channels: %v",
			req.Channel,
			pm.trustedChannels,
		)
	}

	if len(pm.allowedPackages) > 0 && !allowlistContains(pm.allowedPackages, req.PackageName) {
		return fmt.Errorf("package '%s' is not in the allowlist", req.PackageName)
	}

	if !pm.isValidPackageName(req.PackageName) {
		return fmt.Errorf("invalid package name format: '%s'", req.PackageName)
	}

	// The version is spliced into the install command string, so restrict it
	// to a conservative character set.
	if req.Version != "" && !isValidVersionSpec(req.Version) {
		return fmt.Errorf("invalid version format: '%s'", req.Version)
	}

	return nil
}

// allowlistContains reports whether name is enabled (mapped to true) in
// allowed, comparing names case-insensitively like the blocklist check.
func allowlistContains(allowed map[string]bool, name string) bool {
	if allowed[name] {
		return true // fast path: exact key match
	}
	for candidate, ok := range allowed {
		if ok && strings.EqualFold(candidate, name) {
			return true
		}
	}
	return false
}

// isValidVersionSpec reports whether version contains only ASCII letters,
// digits and the characters . * + _ - =. This covers the name=version form
// the install command uses, e.g. "1.26", "1.26.*", "2.1.0+cpu" or
// "1.26=py39_0". Whitespace, quotes and shell metacharacters are rejected.
//
// NOTE(review): range operators such as ">=" or "|" are rejected as well.
// Widen the set only if such specs are needed and the command is never run
// through a shell.
func isValidVersionSpec(version string) bool {
	for _, c := range version {
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
		case strings.ContainsRune(".*+_-=", c):
		default:
			return false
		}
	}
	return true
}
// isChannelTrusted reports whether channel matches, ignoring case, any entry
// of the manager's trusted channel list. An empty list trusts nothing.
func (pm *PackageManager) isChannelTrusted(channel string) bool {
	for i := range pm.trustedChannels {
		if !strings.EqualFold(pm.trustedChannels[i], channel) {
			continue
		}
		return true
	}
	return false
}
// isValidPackageName reports whether name is a syntactically safe package
// name. It must be non-empty, consist only of ASCII letters, digits, '-', '_'
// and '.', and must not start with '-' or '.'.
//
// The leading-character rule matters because buildInstallCommand places the
// name as its own argument on the conda command line. A name such as "-cevil"
// could otherwise be interpreted as an option ("-c evil") and bypass the
// trusted-channel check. A leading '_' stays allowed, because some conda
// packages (e.g. "_libgcc_mutex") use it.
func (pm *PackageManager) isValidPackageName(name string) bool {
	if name == "" || name[0] == '-' || name[0] == '.' {
		return false
	}
	for _, c := range name {
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
		case c == '-', c == '_', c == '.':
		default:
			return false
		}
	}
	return true
}
// RequestPackage builds a pending installation request, validates it with
// ValidatePackageRequest and saves it to the package cache.
//
// The package name is trimmed and lower-cased. Version and channel are
// trimmed too, so stray whitespace neither leaks into the install command nor
// makes a trusted channel fail the trust check.
//
// On validation failure it returns the request marked "rejected" (with
// RejectionReason set) together with the validation error; that request is
// not saved. If saving fails it returns a nil request and the wrapped error.
func (pm *PackageManager) RequestPackage(
	packageName,
	version,
	channel,
	requestedBy string,
) (*PackageRequest, error) {
	req := &PackageRequest{
		PackageName:   strings.ToLower(strings.TrimSpace(packageName)),
		Version:       strings.TrimSpace(version),
		Channel:       strings.TrimSpace(channel),
		RequestedBy:   requestedBy,
		WorkspacePath: pm.workspacePath,
		Timestamp:     time.Now(),
		Status:        statusPending,
	}

	if err := pm.ValidatePackageRequest(req); err != nil {
		// Hand the rejected request back so callers can report the reason.
		req.Status = "rejected"
		req.RejectionReason = err.Error()
		return req, err
	}

	// NOTE(review): requests are stored under the package name, so this
	// overwrites any earlier request (in any status) for the same package.
	if err := pm.savePackageRequest(req); err != nil {
		return nil, fmt.Errorf("failed to save package request: %w", err)
	}

	pm.logger.Info("package installation request created",
		"package", req.PackageName,
		"version", req.Version,
		"channel", req.Channel,
		"requested_by", req.RequestedBy)

	return req, nil
}
// ApprovePackageRequest moves the request identified by requestID from
// pending to approved, records who approved it and when, and saves the
// change. It returns an error if the request cannot be loaded, is not
// pending, or cannot be saved.
func (pm *PackageManager) ApprovePackageRequest(requestID, approvalUser string) error {
	req, loadErr := pm.loadPackageRequest(requestID)
	switch {
	case loadErr != nil:
		return fmt.Errorf("failed to load package request: %w", loadErr)
	case req.Status != statusPending:
		return fmt.Errorf("package request is not pending (current status: %s)", req.Status)
	}

	req.Status, req.ApprovalUser, req.ApprovalTime = "approved", approvalUser, time.Now()

	if saveErr := pm.savePackageRequest(req); saveErr != nil {
		return fmt.Errorf("failed to save approved request: %w", saveErr)
	}

	pm.logger.Info("package request approved",
		"package", req.PackageName,
		"request_id", requestID,
		"approved_by", approvalUser,
	)
	return nil
}
// RejectPackageRequest marks the pending request identified by requestID as
// rejected with the given reason and saves it. It returns an error if the
// request cannot be loaded, is not pending, or cannot be saved.
func (pm *PackageManager) RejectPackageRequest(requestID, reason string) error {
	req, loadErr := pm.loadPackageRequest(requestID)
	if loadErr != nil {
		return fmt.Errorf("failed to load package request: %w", loadErr)
	}
	if current := req.Status; current != statusPending {
		return fmt.Errorf("package request is not pending (current status: %s)", current)
	}

	req.Status, req.RejectionReason = "rejected", reason

	if saveErr := pm.savePackageRequest(req); saveErr != nil {
		return fmt.Errorf("failed to save rejected request: %w", saveErr)
	}

	pm.logger.Info("package request rejected",
		"package", req.PackageName,
		"request_id", requestID,
		"reason", reason,
	)
	return nil
}
// InstallPackage installs the approved request identified by requestID.
//
// Installation is currently simulated. The conda command is built and logged
// but not executed, and the request is marked "installed" unconditionally.
// An install record is written on a best-effort basis; a failure there is
// only logged. The updated request is then saved, and a failure to save it
// is returned as an error.
func (pm *PackageManager) InstallPackage(requestID string) error {
	req, err := pm.loadPackageRequest(requestID)
	if err != nil {
		return fmt.Errorf("failed to load package request: %w", err)
	}
	if req.Status != "approved" {
		return fmt.Errorf("package request is not approved (current status: %s)", req.Status)
	}

	cmdLine := pm.buildInstallCommand(req)
	pm.logger.Info("installing package",
		"package", req.PackageName,
		"version", req.Version,
		"channel", req.Channel,
		"command", cmdLine,
	)

	// TODO: execute cmdLine with a proper process runner. Until then every
	// approved request is treated as successfully installed.
	req.Status = "installed"

	record := &PackageInfo{
		Name:        req.PackageName,
		Version:     req.Version,
		Channel:     req.Channel,
		InstalledAt: time.Now(),
		InstalledBy: req.RequestedBy,
	}
	if infoErr := pm.savePackageInfo(record); infoErr != nil {
		// Best effort: a missing install record does not fail the install.
		pm.logger.Warn("failed to save package info", "error", infoErr)
	}

	if saveErr := pm.savePackageRequest(req); saveErr != nil {
		return fmt.Errorf("failed to save installed request: %w", saveErr)
	}

	pm.logger.Info("package installed successfully",
		"package", req.PackageName,
		"version", req.Version,
	)
	return nil
}
// buildInstallCommand renders the conda install command for req as a single
// space-separated string:
//
//	conda install -y [-c <channel>] <name>[=<version>]
//
// The channel flag is omitted when req.Channel is empty, and the "=<version>"
// suffix is omitted when req.Version is empty.
func (pm *PackageManager) buildInstallCommand(req *PackageRequest) string {
	spec := req.PackageName
	if req.Version != "" {
		spec = fmt.Sprintf("%s=%s", req.PackageName, req.Version)
	}

	args := []string{"conda", "install", "-y"}
	if req.Channel != "" {
		args = append(args, "-c", req.Channel)
	}
	args = append(args, spec)

	return strings.Join(args, " ")
}
// ListPendingRequests returns every cached request whose status is pending.
// The result is nil when there are none. Errors from listing the cache are
// returned unchanged.
func (pm *PackageManager) ListPendingRequests() ([]*PackageRequest, error) {
	all, err := pm.loadAllPackageRequests()
	if err != nil {
		return nil, err
	}

	var pending []*PackageRequest
	for i := range all {
		if all[i].Status != statusPending {
			continue
		}
		pending = append(pending, all[i])
	}
	return pending, nil
}
// ListInstalledPackages returns the install records cached for this
// workspace, exactly as loaded by loadAllPackageInfo.
func (pm *PackageManager) ListInstalledPackages() ([]*PackageInfo, error) {
	installed, err := pm.loadAllPackageInfo()
	return installed, err
}
// GetPackageRequest returns the cached request stored under requestID, which
// is the package name the request was saved with.
func (pm *PackageManager) GetPackageRequest(requestID string) (*PackageRequest, error) {
	req, err := pm.loadPackageRequest(requestID)
	return req, err
}
// savePackageRequest writes req as indented JSON to
// <packageCachePath>/request_<PackageName>.json (mode 0600), replacing any
// existing file with that name.
func (pm *PackageManager) savePackageRequest(req *PackageRequest) error {
	payload, err := json.MarshalIndent(req, "", " ")
	if err != nil {
		return err
	}

	fileName := fmt.Sprintf("request_%s.json", req.PackageName)
	return os.WriteFile(filepath.Join(pm.packageCachePath, fileName), payload, 0600)
}
// loadPackageRequest reads the request stored under requestID from
// <packageCachePath>/request_<requestID>.json.
//
// requestID is interpolated into a file path, so it is rejected when it is
// empty or contains a path separator or a colon. Otherwise an ID such as
// "x/../../secret" would resolve outside the cache directory and expose
// arbitrary .json files, e.g. through GetPackageRequest.
//
// File read errors are returned unwrapped, so os.IsNotExist keeps working for
// callers. Decode errors are wrapped with the file path.
func (pm *PackageManager) loadPackageRequest(requestID string) (*PackageRequest, error) {
	if requestID == "" || strings.ContainsAny(requestID, `/\:`) {
		return nil, fmt.Errorf("invalid package request id %q", requestID)
	}

	requestFile := filepath.Join(pm.packageCachePath, fmt.Sprintf("request_%s.json", requestID))
	data, err := os.ReadFile(requestFile)
	if err != nil {
		return nil, err
	}

	var req PackageRequest
	if err := json.Unmarshal(data, &req); err != nil {
		return nil, fmt.Errorf("decoding %s: %w", requestFile, err)
	}
	return &req, nil
}
// loadAllPackageRequests loads every cached request, i.e. every
// request_*.json file in the package cache directory, in file-name order.
//
// The directory is listed with os.ReadDir instead of filepath.Glob. The cache
// path comes from the workspace path, and glob metacharacters in it ('*',
// '?', '[') would otherwise be treated as a pattern and match the wrong
// directory, or none at all.
//
// A missing cache directory yields no requests. Any other listing error is
// returned. Individual files that cannot be read or decoded are logged and
// skipped, so one corrupt file does not hide the rest.
func (pm *PackageManager) loadAllPackageRequests() ([]*PackageRequest, error) {
	entries, err := os.ReadDir(pm.packageCachePath)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, nil
		}
		return nil, fmt.Errorf("listing package cache %s: %w", pm.packageCachePath, err)
	}

	var requests []*PackageRequest
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || !strings.HasPrefix(name, "request_") || !strings.HasSuffix(name, ".json") {
			continue
		}
		file := filepath.Join(pm.packageCachePath, name)

		data, err := os.ReadFile(file)
		if err != nil {
			pm.logger.Warn("failed to read request file", "file", file, "error", err)
			continue
		}
		var req PackageRequest
		if err := json.Unmarshal(data, &req); err != nil {
			pm.logger.Warn("failed to parse request file", "file", file, "error", err)
			continue
		}
		requests = append(requests, &req)
	}
	return requests, nil
}
// savePackageInfo writes info as indented JSON to
// <packageCachePath>/installed_<Name>.json (mode 0600), replacing any
// existing record for that package name.
func (pm *PackageManager) savePackageInfo(info *PackageInfo) error {
	payload, err := json.MarshalIndent(info, "", " ")
	if err != nil {
		return err
	}

	fileName := fmt.Sprintf("installed_%s.json", info.Name)
	return os.WriteFile(filepath.Join(pm.packageCachePath, fileName), payload, 0600)
}
// loadAllPackageInfo loads every install record, i.e. every installed_*.json
// file in the package cache directory, in file-name order.
//
// The directory is listed with os.ReadDir instead of filepath.Glob. The cache
// path comes from the workspace path, and glob metacharacters in it ('*',
// '?', '[') would otherwise be treated as a pattern and match the wrong
// directory, or none at all.
//
// A missing cache directory yields no records. Any other listing error is
// returned. Individual files that cannot be read or decoded are logged and
// skipped.
func (pm *PackageManager) loadAllPackageInfo() ([]*PackageInfo, error) {
	entries, err := os.ReadDir(pm.packageCachePath)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, nil
		}
		return nil, fmt.Errorf("listing package cache %s: %w", pm.packageCachePath, err)
	}

	var packages []*PackageInfo
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || !strings.HasPrefix(name, "installed_") || !strings.HasSuffix(name, ".json") {
			continue
		}
		file := filepath.Join(pm.packageCachePath, name)

		data, err := os.ReadFile(file)
		if err != nil {
			pm.logger.Warn("failed to read package info file", "file", file, "error", err)
			continue
		}
		var info PackageInfo
		if err := json.Unmarshal(data, &info); err != nil {
			pm.logger.Warn("failed to parse package info file", "file", file, "error", err)
			continue
		}
		packages = append(packages, &info)
	}
	return packages, nil
}
// GetDefaultPackageConfig returns a freshly allocated default configuration:
//   - trusted channels: conda-forge, defaults, pytorch, nvidia;
//   - an empty, non-nil AllowedPackages map, meaning no allowlist;
//   - a private copy of defaultBlockedPackages as the blocklist;
//   - RequireApproval false, AutoApproveSafe true, MaxPackages 100,
//     InstallTimeout 5 minutes, AllowCondaForge true, AllowPyPI and
//     AllowLocal false.
//
// No slice or map is shared between calls, so callers may modify the result.
func GetDefaultPackageConfig() *PackageConfig {
	blocked := make([]string, len(defaultBlockedPackages))
	copy(blocked, defaultBlockedPackages)

	cfg := &PackageConfig{
		TrustedChannels: []string{"conda-forge", "defaults", "pytorch", "nvidia"},
		AllowedPackages: map[string]bool{}, // empty means all packages allowed
		BlockedPackages: blocked,
		MaxPackages:     100,
		InstallTimeout:  5 * time.Minute,
	}
	cfg.RequireApproval = false
	cfg.AutoApproveSafe = true
	cfg.AllowCondaForge = true
	cfg.AllowPyPI = false
	cfg.AllowLocal = false
	return cfg
}