package scheduler import ( "fmt" "sync" ) // PluginQuotaConfig defines GPU limits for plugins. type PluginQuotaConfig struct { Enabled bool // Master switch for quota enforcement TotalGPUs int // Global GPU limit across all plugins PerUserGPUs int // Default per-user GPU limit PerUserServices int // Default per-user service count limit PerPluginLimits map[string]PluginLimit // Plugin-specific overrides UserOverrides map[string]UserLimit // Per-user overrides } // PluginLimit defines limits for a specific plugin. type PluginLimit struct { MaxGPUs int MaxServices int } // UserLimit defines per-user override limits. type UserLimit struct { MaxGPUs int MaxServices int AllowedPlugins []string // Empty = all plugins allowed } // PluginUsage tracks GPU and service count for a user-plugin combination. type PluginUsage struct { GPUs int Services int } // PluginQuotaManager tracks active usage and enforces quotas. type PluginQuotaManager struct { config PluginQuotaConfig mu sync.RWMutex usage map[string]map[string]PluginUsage // userID -> pluginName -> usage pluginTotal map[string]int // pluginName -> total GPUs in use totalGPUs int // global total GPUs in use } // NewPluginQuotaManager creates a new quota manager with the given configuration. func NewPluginQuotaManager(config PluginQuotaConfig) *PluginQuotaManager { return &PluginQuotaManager{ config: config, usage: make(map[string]map[string]PluginUsage), pluginTotal: make(map[string]int), totalGPUs: 0, } } // CheckQuota validates if a job can be submitted without exceeding limits. // Returns nil if the job is allowed, or an error describing which limit would be exceeded. func (m *PluginQuotaManager) CheckQuota(userID, pluginName string, gpuCount int) error { if !m.config.Enabled { return nil } if userID == "" { userID = "anonymous" } if pluginName == "" { pluginName = "default" } m.mu.RLock() defer m.mu.RUnlock() // Get user limits (with overrides) userLimit := m.getUserLimit(userID) // Check if user is allowed to use this plugin if len(userLimit.AllowedPlugins) > 0 { found := false for _, p := range userLimit.AllowedPlugins { if p == pluginName { found = true break } } if !found { return fmt.Errorf("user %s is not allowed to use plugin %s", userID, pluginName) } } // Check plugin-specific limits pluginLimit, hasPluginLimit := m.config.PerPluginLimits[pluginName] if hasPluginLimit { if pluginLimit.MaxGPUs > 0 && m.pluginTotal[pluginName]+gpuCount > pluginLimit.MaxGPUs { return fmt.Errorf("plugin %s GPU limit exceeded: %d requested, %d available of %d total", pluginName, gpuCount, pluginLimit.MaxGPUs-m.pluginTotal[pluginName], pluginLimit.MaxGPUs) } if pluginLimit.MaxServices > 0 { // Services limit is across all users for this plugin totalServices := 0 for _, u := range m.usage { if p, ok := u[pluginName]; ok { totalServices += p.Services } } if totalServices+1 > pluginLimit.MaxServices { return fmt.Errorf("plugin %s service limit exceeded", pluginName) } } } // Check per-user limits effectiveUserGPUs := userLimit.MaxGPUs if effectiveUserGPUs == 0 { effectiveUserGPUs = m.config.PerUserGPUs } effectiveUserServices := userLimit.MaxServices if effectiveUserServices == 0 { effectiveUserServices = m.config.PerUserServices } // Calculate total user usage across all plugins totalUserGPUs := 0 totalUserServices := 0 for _, p := range m.usage[userID] { totalUserGPUs += p.GPUs totalUserServices += p.Services } if effectiveUserGPUs > 0 && totalUserGPUs+gpuCount > effectiveUserGPUs { return fmt.Errorf("user %s GPU limit exceeded: %d requested, %d available of %d total", userID, gpuCount, effectiveUserGPUs-totalUserGPUs, effectiveUserGPUs) } if effectiveUserServices > 0 && totalUserServices+1 > effectiveUserServices { return fmt.Errorf("user %s service limit exceeded: %d services of %d allowed", userID, totalUserServices+1, effectiveUserServices) } // Check global total GPU limit if m.config.TotalGPUs > 0 && m.totalGPUs+gpuCount > m.config.TotalGPUs { return fmt.Errorf("global GPU limit exceeded: %d requested, %d available of %d total", gpuCount, m.config.TotalGPUs-m.totalGPUs, m.config.TotalGPUs) } return nil } // RecordUsage increments usage counters when a job starts. func (m *PluginQuotaManager) RecordUsage(userID, pluginName string, gpuCount int) { if !m.config.Enabled { return } if userID == "" { userID = "anonymous" } if pluginName == "" { pluginName = "default" } m.mu.Lock() defer m.mu.Unlock() userPlugins, ok := m.usage[userID] if !ok { userPlugins = make(map[string]PluginUsage) m.usage[userID] = userPlugins } usage := userPlugins[pluginName] usage.GPUs += gpuCount usage.Services++ userPlugins[pluginName] = usage m.pluginTotal[pluginName] += gpuCount m.totalGPUs += gpuCount } // ReleaseUsage decrements usage counters when a job stops. func (m *PluginQuotaManager) ReleaseUsage(userID, pluginName string, gpuCount int) { if !m.config.Enabled { return } if userID == "" { userID = "anonymous" } if pluginName == "" { pluginName = "default" } m.mu.Lock() defer m.mu.Unlock() userPlugins, ok := m.usage[userID] if !ok { return } usage := userPlugins[pluginName] usage.GPUs -= gpuCount usage.Services-- if usage.GPUs < 0 { usage.GPUs = 0 } if usage.Services < 0 { usage.Services = 0 } if usage.GPUs == 0 && usage.Services == 0 { delete(userPlugins, pluginName) } else { userPlugins[pluginName] = usage } if len(userPlugins) == 0 { delete(m.usage, userID) } m.pluginTotal[pluginName] -= gpuCount if m.pluginTotal[pluginName] < 0 { m.pluginTotal[pluginName] = 0 } m.totalGPUs -= gpuCount if m.totalGPUs < 0 { m.totalGPUs = 0 } } // GetUsage returns current usage for a user across all plugins. func (m *PluginQuotaManager) GetUsage(userID string) (map[string]PluginUsage, int) { m.mu.RLock() defer m.mu.RUnlock() if userID == "" { userID = "anonymous" } result := make(map[string]PluginUsage) totalGPUs := 0 if userPlugins, ok := m.usage[userID]; ok { for plugin, usage := range userPlugins { result[plugin] = usage totalGPUs += usage.GPUs } } return result, totalGPUs } // GetGlobalUsage returns global GPU usage across all users and plugins. func (m *PluginQuotaManager) GetGlobalUsage() (int, map[string]int) { m.mu.RLock() defer m.mu.RUnlock() pluginTotals := make(map[string]int, len(m.pluginTotal)) for k, v := range m.pluginTotal { pluginTotals[k] = v } return m.totalGPUs, pluginTotals } // getUserLimit returns the effective limits for a user, applying overrides. func (m *PluginQuotaManager) getUserLimit(userID string) UserLimit { if override, ok := m.config.UserOverrides[userID]; ok { return override } return UserLimit{ MaxGPUs: m.config.PerUserGPUs, MaxServices: m.config.PerUserServices, } } // getUsageLocked returns the current usage for a user-plugin combination. // Must be called with read lock held. func (m *PluginQuotaManager) getUsageLocked(userID, pluginName string) PluginUsage { if userPlugins, ok := m.usage[userID]; ok { if usage, ok := userPlugins[pluginName]; ok { return usage } } return PluginUsage{} }