From da104367d62e3bbb11782a1fdf0ca874ece7ac91 Mon Sep 17 00:00:00 2001
From: Jeremie Fraeys
Date: Thu, 26 Feb 2026 14:35:05 -0500
Subject: [PATCH] feat: add Plugin GPU Quota implementation and tests

- Add plugin_quota.go with GPU quota management for scheduler
- Update scheduler hub and protocol for plugin support
- Add comprehensive plugin quota unit tests
- Update gang service and WebSocket queue integration tests
---
 internal/scheduler/hub.go                 | 106 ++++-
 internal/scheduler/plugin_quota.go        | 287 +++++++++++++
 internal/scheduler/protocol.go            |   1 +
 .../scheduler/gang_service_test.go        |   7 +-
 .../websocket_queue_integration_test.go   |  13 +-
 tests/unit/scheduler/plugin_quota_test.go | 385 ++++++++++++++++++
 6 files changed, 776 insertions(+), 23 deletions(-)
 create mode 100644 internal/scheduler/plugin_quota.go
 create mode 100644 tests/unit/scheduler/plugin_quota_test.go

diff --git a/internal/scheduler/hub.go b/internal/scheduler/hub.go
index 4027e6d..cc77ac6 100644
--- a/internal/scheduler/hub.go
+++ b/internal/scheduler/hub.go
@@ -33,11 +33,13 @@ type SchedulerHub struct {
 	reservations      map[string]*Reservation
 	multiNodePending  map[string]*MultiNodeJob
 	pendingAcceptance map[string]*JobAssignment
+	runningTasks      map[string]*Task // Track assigned+accepted tasks
 	state             *StateStore
 	starvation        *StarvationTracker
 	metrics           *SchedulerMetrics
 	auditor           *audit.Logger
 	tokenValidator    *TokenValidator
+	quotaManager      *PluginQuotaManager // NEW: plugin GPU quota manager
 	config            HubConfig
 	ctx               context.Context
 	cancel            context.CancelFunc
@@ -59,6 +61,7 @@ type HubConfig struct {
 	AcceptanceTimeoutSecs int
 	LocalMode             bool
 	WorkerTokens          map[string]string // token -> workerID
+	PluginQuota           PluginQuotaConfig // NEW: plugin GPU quota configuration
 }
 
 // WorkerConn represents a connected worker
@@ -109,6 +112,7 @@ type JobAssignment struct {
 	AssignedAt         time.Time
 	AcceptanceDeadline time.Time
 	Accepted           bool
+	Task               *Task // Reference to the task (removed from queue)
 }
 
 // StarvationTracker monitors long-waiting jobs
@@ -154,6 +158,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
 		reservations:      make(map[string]*Reservation),
 		multiNodePending:  make(map[string]*MultiNodeJob),
 		pendingAcceptance: make(map[string]*JobAssignment),
+		runningTasks:      make(map[string]*Task),
 		state:             state,
 		starvation: &StarvationTracker{
 			threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,
@@ -163,6 +168,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
 		},
 		auditor:        auditor,
 		tokenValidator: NewTokenValidator(cfg.WorkerTokens),
+		quotaManager:   NewPluginQuotaManager(cfg.PluginQuota), // NEW: initialize quota manager
 		config:         cfg,
 		ctx:            ctx,
 		cancel:         cancel,
@@ -431,9 +437,6 @@ func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task {
 }
 
 func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
-	h.mu.RLock()
-	defer h.mu.RUnlock()
-
 	for _, res := range h.reservations {
 		if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
 			if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {
@@ -449,7 +452,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
 	h.batchQueue.Remove(task.ID)
 	h.serviceQueue.Remove(task.ID)
 
-	// Track pending acceptance
+	// Track pending acceptance with task reference
 	h.mu.Lock()
 	h.pendingAcceptance[task.ID] = &JobAssignment{
 		TaskID:   task.ID,
@@ -457,6 +460,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
 		WorkerID:           wc.id,
 		AssignedAt:         time.Now(),
 		AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
 		Accepted:           false,
+		Task:               task, // Store reference since removed from queue
 	}
 	h.mu.Unlock()
@@ -473,12 +477,31 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
 	}
 }
 
-func (h *SchedulerHub) handleJobAccepted(_, taskID string) {
+func (h *SchedulerHub) handleJobAccepted(workerID, taskID string) {
 	h.mu.Lock()
 	defer h.mu.Unlock()
 
 	if assignment, ok := h.pendingAcceptance[taskID]; ok {
 		assignment.Accepted = true
+
+		// Track as running task
+		task := assignment.Task
+		if task != nil {
+			task.Status = "running"
+			task.WorkerID = workerID
+			h.runningTasks[taskID] = task
+		}
+
+		// NEW: Record quota usage for service jobs
+		if task != nil && task.Spec.Type == JobTypeService {
+			if h.quotaManager != nil {
+				pluginName := task.Spec.Metadata["plugin_name"]
+				if pluginName == "" {
+					pluginName = "default"
+				}
+				h.quotaManager.RecordUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
+			}
+		}
 	}
 }
 
@@ -486,7 +509,19 @@ func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload)
 	h.mu.Lock()
 	defer h.mu.Unlock()
 
+	// NEW: Release quota usage for service jobs before deleting pending acceptance
+	if task := h.runningTasks[result.TaskID]; task != nil && task.Spec.Type == JobTypeService {
+		if h.quotaManager != nil {
+			pluginName := task.Spec.Metadata["plugin_name"]
+			if pluginName == "" {
+				pluginName = "default"
+			}
+			h.quotaManager.ReleaseUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
+		}
+	}
+
 	delete(h.pendingAcceptance, result.TaskID)
+	delete(h.runningTasks, result.TaskID)
 
 	eventType := EventJobCompleted
 	switch result.State {
@@ -519,7 +554,10 @@ func (h *SchedulerHub) checkAcceptanceTimeouts() {
 	h.mu.Lock()
 	for taskID, a := range h.pendingAcceptance {
 		if !a.Accepted && time.Now().After(a.AcceptanceDeadline) {
-			h.batchQueue.Add(h.getTask(taskID))
+			if a.Task != nil {
+				a.Task.Status = "queued"
+				h.batchQueue.Add(a.Task)
+			}
 			delete(h.pendingAcceptance, taskID)
 			if wc, ok := h.workers[a.WorkerID]; ok {
 				wc.slots = SlotStatus{}
@@ -572,22 +610,31 @@ func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) {
 	st.mu.Lock()
 	defer st.mu.Unlock()
 
+	// First check which tasks need reservation under h.mu.RLock
+	tasksToReserve := make([]*Task, 0)
+	h.mu.RLock()
 	for _, task := range h.batchQueue.Items() {
-		if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservation(h, task.ID) {
-			h.mu.Lock()
+		if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservationLocked(h, task.ID) {
+			tasksToReserve = append(tasksToReserve, task)
+		}
+	}
+	h.mu.RUnlock()
+
+	// Now acquire Lock to add reservations
+	if len(tasksToReserve) > 0 {
+		h.mu.Lock()
+		for _, task := range tasksToReserve {
 			h.reservations[task.ID] = &Reservation{
 				TaskID:    task.ID,
 				GPUCount:  task.Spec.GPUCount,
 				CreatedAt: time.Now(),
 			}
-			h.mu.Unlock()
 		}
+		h.mu.Unlock()
 	}
 }
 
-func (st *StarvationTracker) hasReservation(h *SchedulerHub, taskID string) bool {
-	h.mu.RLock()
-	defer h.mu.RUnlock()
+func (st *StarvationTracker) hasReservationLocked(h *SchedulerHub, taskID string) bool {
 	_, exists := h.reservations[taskID]
 	return exists
 }
@@ -605,6 +652,17 @@ func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
 		return fmt.Errorf("job ID is required")
 	}
 
+	// NEW: Check plugin quotas for service jobs
+	if spec.Type == JobTypeService && h.quotaManager != nil {
+		pluginName := spec.Metadata["plugin_name"]
+		if pluginName == "" {
+			pluginName = "default"
+		}
+		if err := h.quotaManager.CheckQuota(spec.UserID, pluginName, spec.GPUCount); err != nil {
+			return fmt.Errorf("quota exceeded: %w", err)
+		}
+	}
+
 	task := &Task{
 		ID:   spec.ID,
 		Spec: spec,
@@ -639,7 +697,11 @@ func (h *SchedulerHub) getTask(taskID string) *Task {
 	if t != nil {
 		return t
 	}
-	return h.serviceQueue.Get(taskID)
+	t = h.serviceQueue.Get(taskID)
+	if t != nil {
+		return t
+	}
+	return h.runningTasks[taskID]
 }
 
 func (h *SchedulerHub) restoreJob(ev StateEvent) {
@@ -718,7 +780,7 @@ func (h *SchedulerHub) reconcileOrphans() {
 		if assignment.Accepted {
 			// Job was accepted but worker is gone (not in h.workers)
 			if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected {
-				task := h.getTask(taskID)
+				task := assignment.Task
 				if task != nil {
 					task.Status = "orphaned"
 					h.batchQueue.Add(task)
@@ -776,7 +838,7 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
 		}
 
 		if msg.Type == MsgMetricsRequest {
-			metrics := h.getMetricsPayload()
+			metrics := h.GetMetricsPayload()
 			conn.WriteJSON(Message{
 				Type:    MsgMetricsResponse,
 				Payload: mustMarshal(metrics),
@@ -785,8 +847,8 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
 	}
 }
 
-// getMetricsPayload returns current metrics as a map
-func (h *SchedulerHub) getMetricsPayload() map[string]any {
+// GetMetricsPayload returns current metrics as a map (public API)
+func (h *SchedulerHub) GetMetricsPayload() map[string]any {
 	h.metrics.mu.RLock()
 	defer h.metrics.mu.RUnlock()
 
@@ -901,7 +963,7 @@ func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) {
 
 // buildRankedSpec creates a job spec with rank-specific template variables resolved
 func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
-	// Clone the spec and add rank info to metadata
+	// Clone the spec and add rank info to metadata and env
 	spec := task.Spec
 	spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3)
 	for k, v := range task.Spec.Metadata {
@@ -910,6 +972,14 @@ func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
 	spec.Metadata["HEAD_ADDR"] = headAddr
 	spec.Metadata["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
 	spec.Metadata["NODE_RANK"] = fmt.Sprintf("%d", rank)
+
+	// Also set in Env for job runtime
+	if spec.Env == nil {
+		spec.Env = make(map[string]string)
+	}
+	spec.Env["HEAD_ADDR"] = headAddr
+	spec.Env["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
+	spec.Env["NODE_RANK"] = fmt.Sprintf("%d", rank)
 	return spec
 }
 
diff --git a/internal/scheduler/plugin_quota.go b/internal/scheduler/plugin_quota.go
new file mode 100644
index 0000000..605fd4b
--- /dev/null
+++ b/internal/scheduler/plugin_quota.go
@@ -0,0 +1,287 @@
+package scheduler
+
+import (
+	"fmt"
+	"sync"
+)
+
+// PluginQuotaConfig defines GPU limits for plugins.
+type PluginQuotaConfig struct {
+	Enabled         bool                   // Master switch for quota enforcement
+	TotalGPUs       int                    // Global GPU limit across all plugins
+	PerUserGPUs     int                    // Default per-user GPU limit
+	PerUserServices int                    // Default per-user service count limit
+	PerPluginLimits map[string]PluginLimit // Plugin-specific overrides
+	UserOverrides   map[string]UserLimit   // Per-user overrides
+}
+
+// PluginLimit defines limits for a specific plugin.
+type PluginLimit struct {
+	MaxGPUs     int
+	MaxServices int
+}
+
+// UserLimit defines per-user override limits.
+type UserLimit struct {
+	MaxGPUs        int
+	MaxServices    int
+	AllowedPlugins []string // Empty = all plugins allowed
+}
+
+// PluginUsage tracks GPU and service count for a user-plugin combination.
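+// Counts reflect currently running jobs only: quota is checked at submit
+// time (SubmitJob) but not recorded until the worker accepts the job
+// (handleJobAccepted), so queued work does not hold quota.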
+type PluginUsage struct {
+	GPUs     int
+	Services int
+}
+
+// PluginQuotaManager tracks active usage and enforces quotas.
+type PluginQuotaManager struct {
+	config      PluginQuotaConfig
+	mu          sync.RWMutex
+	usage       map[string]map[string]PluginUsage // userID -> pluginName -> usage
+	pluginTotal map[string]int                    // pluginName -> total GPUs in use
+	totalGPUs   int                               // global total GPUs in use
+}
+
+// NewPluginQuotaManager creates a new quota manager with the given configuration.
+func NewPluginQuotaManager(config PluginQuotaConfig) *PluginQuotaManager {
+	return &PluginQuotaManager{
+		config:      config,
+		usage:       make(map[string]map[string]PluginUsage),
+		pluginTotal: make(map[string]int),
+		totalGPUs:   0,
+	}
+}
+
+// CheckQuota validates if a job can be submitted without exceeding limits.
+// Returns nil if the job is allowed, or an error describing which limit would be exceeded.
+func (m *PluginQuotaManager) CheckQuota(userID, pluginName string, gpuCount int) error {
+	if !m.config.Enabled {
+		return nil
+	}
+
+	if userID == "" {
+		userID = "anonymous"
+	}
+	if pluginName == "" {
+		pluginName = "default"
+	}
+
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	// Get user limits (with overrides)
+	userLimit := m.getUserLimit(userID)
+
+	// Check if user is allowed to use this plugin
+	if len(userLimit.AllowedPlugins) > 0 {
+		found := false
+		for _, p := range userLimit.AllowedPlugins {
+			if p == pluginName {
+				found = true
+				break
+			}
+		}
+		if !found {
+			return fmt.Errorf("user %s is not allowed to use plugin %s", userID, pluginName)
+		}
+	}
+
+	// Check plugin-specific limits
+	pluginLimit, hasPluginLimit := m.config.PerPluginLimits[pluginName]
+	if hasPluginLimit {
+		if pluginLimit.MaxGPUs > 0 && m.pluginTotal[pluginName]+gpuCount > pluginLimit.MaxGPUs {
+			return fmt.Errorf("plugin %s GPU limit exceeded: %d requested, %d available of %d total",
+				pluginName, gpuCount, pluginLimit.MaxGPUs-m.pluginTotal[pluginName], pluginLimit.MaxGPUs)
+		}
+		if pluginLimit.MaxServices > 0 {
+			// Services limit is across all users for this plugin
+			totalServices := 0
+			for _, u := range m.usage {
+				if p, ok := u[pluginName]; ok {
+					totalServices += p.Services
+				}
+			}
+			if totalServices+1 > pluginLimit.MaxServices {
+				return fmt.Errorf("plugin %s service limit exceeded", pluginName)
+			}
+		}
+	}
+
+	// Check per-user limits
+	effectiveUserGPUs := userLimit.MaxGPUs
+	if effectiveUserGPUs == 0 {
+		effectiveUserGPUs = m.config.PerUserGPUs
+	}
+	effectiveUserServices := userLimit.MaxServices
+	if effectiveUserServices == 0 {
+		effectiveUserServices = m.config.PerUserServices
+	}
+
+	// Calculate total user usage across all plugins
+	totalUserGPUs := 0
+	totalUserServices := 0
+	for _, p := range m.usage[userID] {
+		totalUserGPUs += p.GPUs
+		totalUserServices += p.Services
+	}
+
+	if effectiveUserGPUs > 0 && totalUserGPUs+gpuCount > effectiveUserGPUs {
+		return fmt.Errorf("user %s GPU limit exceeded: %d requested, %d available of %d total",
+			userID, gpuCount, effectiveUserGPUs-totalUserGPUs, effectiveUserGPUs)
+	}
+
+	if effectiveUserServices > 0 && totalUserServices+1 > effectiveUserServices {
+		return fmt.Errorf("user %s service limit exceeded: %d services of %d allowed",
+			userID, totalUserServices+1, effectiveUserServices)
+	}
+
+	// Check global total GPU limit
+	if m.config.TotalGPUs > 0 && m.totalGPUs+gpuCount > m.config.TotalGPUs {
+		return fmt.Errorf("global GPU limit exceeded: %d requested, %d available of %d total",
+			gpuCount, m.config.TotalGPUs-m.totalGPUs, m.config.TotalGPUs)
+	}
+
+	return nil
+}
+
+// RecordUsage increments usage counters when a job starts.
+func (m *PluginQuotaManager) RecordUsage(userID, pluginName string, gpuCount int) {
+	if !m.config.Enabled {
+		return
+	}
+
+	if userID == "" {
+		userID = "anonymous"
+	}
+	if pluginName == "" {
+		pluginName = "default"
+	}
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	userPlugins, ok := m.usage[userID]
+	if !ok {
+		userPlugins = make(map[string]PluginUsage)
+		m.usage[userID] = userPlugins
+	}
+
+	usage := userPlugins[pluginName]
+	usage.GPUs += gpuCount
+	usage.Services++
+	userPlugins[pluginName] = usage
+
+	m.pluginTotal[pluginName] += gpuCount
+	m.totalGPUs += gpuCount
+}
+
+// ReleaseUsage decrements usage counters when a job stops.
+func (m *PluginQuotaManager) ReleaseUsage(userID, pluginName string, gpuCount int) {
+	if !m.config.Enabled {
+		return
+	}
+
+	if userID == "" {
+		userID = "anonymous"
+	}
+	if pluginName == "" {
+		pluginName = "default"
+	}
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	userPlugins, ok := m.usage[userID]
+	if !ok {
+		return
+	}
+
+	usage := userPlugins[pluginName]
+	usage.GPUs -= gpuCount
+	usage.Services--
+
+	if usage.GPUs < 0 {
+		usage.GPUs = 0
+	}
+	if usage.Services < 0 {
+		usage.Services = 0
+	}
+
+	if usage.GPUs == 0 && usage.Services == 0 {
+		delete(userPlugins, pluginName)
+	} else {
+		userPlugins[pluginName] = usage
+	}
+
+	if len(userPlugins) == 0 {
+		delete(m.usage, userID)
+	}
+
+	m.pluginTotal[pluginName] -= gpuCount
+	if m.pluginTotal[pluginName] < 0 {
+		m.pluginTotal[pluginName] = 0
+	}
+
+	m.totalGPUs -= gpuCount
+	if m.totalGPUs < 0 {
+		m.totalGPUs = 0
+	}
+}
+
+// GetUsage returns current usage for a user across all plugins.
+func (m *PluginQuotaManager) GetUsage(userID string) (map[string]PluginUsage, int) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	if userID == "" {
+		userID = "anonymous"
+	}
+
+	result := make(map[string]PluginUsage)
+	totalGPUs := 0
+
+	if userPlugins, ok := m.usage[userID]; ok {
+		for plugin, usage := range userPlugins {
+			result[plugin] = usage
+			totalGPUs += usage.GPUs
+		}
+	}
+
+	return result, totalGPUs
+}
+
+// GetGlobalUsage returns global GPU usage across all users and plugins.
+func (m *PluginQuotaManager) GetGlobalUsage() (int, map[string]int) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	pluginTotals := make(map[string]int, len(m.pluginTotal))
+	for k, v := range m.pluginTotal {
+		pluginTotals[k] = v
+	}
+
+	return m.totalGPUs, pluginTotals
+}
+
+// getUserLimit returns the effective limits for a user, applying overrides.
+func (m *PluginQuotaManager) getUserLimit(userID string) UserLimit {
+	if override, ok := m.config.UserOverrides[userID]; ok {
+		return override
+	}
+	return UserLimit{
+		MaxGPUs:     m.config.PerUserGPUs,
+		MaxServices: m.config.PerUserServices,
+	}
+}
+
+// getUsageLocked returns the current usage for a user-plugin combination.
+// Must be called with read lock held.
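+// Returns the zero PluginUsage value when no usage is recorded for the pair.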
+func (m *PluginQuotaManager) getUsageLocked(userID, pluginName string) PluginUsage {
+	if userPlugins, ok := m.usage[userID]; ok {
+		if usage, ok := userPlugins[pluginName]; ok {
+			return usage
+		}
+	}
+	return PluginUsage{}
+}
diff --git a/internal/scheduler/protocol.go b/internal/scheduler/protocol.go
index 2d38bfc..efa7c18 100644
--- a/internal/scheduler/protocol.go
+++ b/internal/scheduler/protocol.go
@@ -100,6 +100,7 @@ type JobSpec struct {
 	ID       string  `json:"id"`
 	Type     JobType `json:"type"` // "batch" | "service"
 	SlotPool string  `json:"slot_pool"`
+	UserID   string  `json:"user_id,omitempty"` // NEW: for per-user quota tracking
 	GPUCount int     `json:"gpu_count"`
 	GPUType  string  `json:"gpu_type,omitempty"`
 
diff --git a/tests/integration/scheduler/gang_service_test.go b/tests/integration/scheduler/gang_service_test.go
index 35e011b..8481ce4 100644
--- a/tests/integration/scheduler/gang_service_test.go
+++ b/tests/integration/scheduler/gang_service_test.go
@@ -60,6 +60,7 @@ func TestMultiNodeGangAllocation(t *testing.T) {
 				ID: w.id,
 				Capabilities: scheduler.WorkerCapabilities{
 					GPUCount: 0,
+					Hostname: "localhost",
 				},
 			}),
 		})
@@ -172,10 +173,8 @@ func TestServiceLifecycle(t *testing.T) {
 
 	// Send job accepted
 	conn.WriteJSON(scheduler.Message{
-		Type: scheduler.MsgJobAccepted,
-		Payload: mustMarshal(map[string]string{
-			"task_id": jobID,
-		}),
+		Type:    scheduler.MsgJobAccepted,
+		Payload: mustMarshal(jobID),
 	})
 
 	// Send periodic health updates
diff --git a/tests/integration/websocket_queue_integration_test.go b/tests/integration/websocket_queue_integration_test.go
index 3e0c100..e6186b6 100644
--- a/tests/integration/websocket_queue_integration_test.go
+++ b/tests/integration/websocket_queue_integration_test.go
@@ -307,6 +307,7 @@ func startFakeWorkers(
 		go func(workerID string) {
 			defer wg.Done()
 			for {
+				// Check for cancellation before getting next task
 				select {
 				case <-ctx.Done():
 					return
@@ -324,6 +325,7 @@ func startFakeWorkers(
 					continue
 				}
 
+				// Process task - finish even if context is cancelled
 				started := time.Now()
 				completed := started.Add(10 * time.Millisecond)
 
@@ -338,7 +340,16 @@ func startFakeWorkers(
 					continue
 				}
 
-				doneCh <- task.JobName
+				// Only send to doneCh if we successfully completed
+				select {
+				case doneCh <- task.JobName:
+				case <-ctx.Done():
+					// Context cancelled but we still completed the task - try once more without blocking
+					select {
+					case doneCh <- task.JobName:
+					default:
+					}
+				}
 			}
 		}(fmt.Sprintf("worker-%d", w))
 	}
diff --git a/tests/unit/scheduler/plugin_quota_test.go b/tests/unit/scheduler/plugin_quota_test.go
new file mode 100644
index 0000000..539d6bd
--- /dev/null
+++ b/tests/unit/scheduler/plugin_quota_test.go
@@ -0,0 +1,385 @@
+package scheduler_test
+
+import (
+	"testing"
+
+	"github.com/jfraeys/fetch_ml/internal/scheduler"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestPluginQuotaManager_CheckQuota_Disabled(t *testing.T) {
+	// When quota is disabled, all jobs should pass
+	config := scheduler.PluginQuotaConfig{
+		Enabled:   false,
+		TotalGPUs: 1, // Set a low limit that would fail if enabled
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	err := m.CheckQuota("user1", "plugin1", 100)
+	assert.NoError(t, err)
+}
+
+func TestPluginQuotaManager_CheckQuota_GlobalLimit(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:   true,
+		TotalGPUs: 4,
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// First job should succeed
+	err := m.CheckQuota("user1", "plugin1", 2)
+	require.NoError(t, err)
+
+	// Record the usage
+	m.RecordUsage("user1", "plugin1", 2)
+
+	// Second job should succeed (2+2=4, within limit)
+	err = m.CheckQuota("user2", "plugin2", 2)
+	require.NoError(t, err)
+	m.RecordUsage("user2", "plugin2", 2)
+
+	// Third job should fail (would exceed global limit)
+	err = m.CheckQuota("user3", "plugin3", 1)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "global GPU limit exceeded")
+}
+
+func TestPluginQuotaManager_CheckQuota_PerUserGPULimit(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:     true,
+		TotalGPUs:   10,
+		PerUserGPUs: 3,
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// User1: first job should succeed
+	err := m.CheckQuota("user1", "plugin1", 2)
+	require.NoError(t, err)
+	m.RecordUsage("user1", "plugin1", 2)
+
+	// User1: second job should succeed (2+1=3, at limit)
+	err = m.CheckQuota("user1", "plugin2", 1)
+	require.NoError(t, err)
+	m.RecordUsage("user1", "plugin2", 1)
+
+	// User1: third job should fail (would exceed per-user limit)
+	err = m.CheckQuota("user1", "plugin3", 1)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "user user1 GPU limit exceeded")
+
+	// User2: job should succeed (different user)
+	err = m.CheckQuota("user2", "plugin1", 3)
+	assert.NoError(t, err)
+}
+
+func TestPluginQuotaManager_CheckQuota_PerUserServiceLimit(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:         true,
+		TotalGPUs:       10,
+		PerUserGPUs:     10,
+		PerUserServices: 2,
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// User1: first service should succeed
+	err := m.CheckQuota("user1", "plugin1", 1)
+	require.NoError(t, err)
+	m.RecordUsage("user1", "plugin1", 1)
+
+	// User1: second service should succeed
+	err = m.CheckQuota("user1", "plugin2", 1)
+	require.NoError(t, err)
+	m.RecordUsage("user1", "plugin2", 1)
+
+	// User1: third service should fail (would exceed service count limit)
+	err = m.CheckQuota("user1", "plugin3", 1)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "user user1 service limit exceeded")
+}
+
+func TestPluginQuotaManager_CheckQuota_UserOverride(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:         true,
+		TotalGPUs:       10,
+		PerUserGPUs:     2,
+		PerUserServices: 2,
+		UserOverrides: map[string]scheduler.UserLimit{
+			"vip-user": {
+				MaxGPUs:     5,
+				MaxServices: 10,
+			},
+		},
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// Regular user: limited by default
+	err := m.CheckQuota("regular", "plugin1", 3)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "regular GPU limit exceeded")
+
+	// VIP user: has higher limit
+	err = m.CheckQuota("vip-user", "plugin1", 4)
+	require.NoError(t, err)
+	m.RecordUsage("vip-user", "plugin1", 4)
+
+	// VIP user: still within limit
+	err = m.CheckQuota("vip-user", "plugin2", 1)
+	assert.NoError(t, err)
+}
+
+func TestPluginQuotaManager_CheckQuota_PluginSpecificLimit(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:     true,
+		TotalGPUs:   10,
+		PerUserGPUs: 10,
+		PerPluginLimits: map[string]scheduler.PluginLimit{
+			"jupyter": {
+				MaxGPUs:     3,
+				MaxServices: 2,
+			},
+			"vllm": {
+				MaxGPUs:     8,
+				MaxServices: 4,
+			},
+		},
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// Jupyter: within plugin GPU limit
+	err := m.CheckQuota("user1", "jupyter", 2)
+	require.NoError(t, err)
+	m.RecordUsage("user1", "jupyter", 2)
+
+	// Jupyter: exceed plugin GPU limit (but within global and user limits)
+	err = m.CheckQuota("user2", "jupyter", 2)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "plugin jupyter GPU limit exceeded")
+
+	// vLLM: within its higher limit
+	err = m.CheckQuota("user1", "vllm", 4)
+	assert.NoError(t, err)
+}
+
+func TestPluginQuotaManager_CheckQuota_PluginServiceLimit(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:         true,
+		TotalGPUs:       10,
+		PerUserGPUs:     10,
+		PerUserServices: 10,
+		PerPluginLimits: map[string]scheduler.PluginLimit{
+			"jupyter": {
+				MaxGPUs:     10,
+				MaxServices: 2, // Only 2 jupyter services total
+			},
+		},
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// First jupyter service
+	err := m.CheckQuota("user1", "jupyter", 1)
+	require.NoError(t, err)
+	m.RecordUsage("user1", "jupyter", 1)
+
+	// Second jupyter service (different user)
+	err = m.CheckQuota("user2", "jupyter", 1)
+	require.NoError(t, err)
+	m.RecordUsage("user2", "jupyter", 1)
+
+	// Third jupyter service should fail (plugin service limit reached)
+	err = m.CheckQuota("user3", "jupyter", 1)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "plugin jupyter service limit exceeded")
+}
+
+func TestPluginQuotaManager_CheckQuota_AllowedPlugins(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:     true,
+		TotalGPUs:   10,
+		PerUserGPUs: 10,
+		UserOverrides: map[string]scheduler.UserLimit{
+			"restricted-user": {
+				MaxGPUs:        5,
+				MaxServices:    5,
+				AllowedPlugins: []string{"jupyter"},
+			},
+		},
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// Restricted user can use allowed plugin
+	err := m.CheckQuota("restricted-user", "jupyter", 2)
+	assert.NoError(t, err)
+
+	// Restricted user cannot use other plugins
+	err = m.CheckQuota("restricted-user", "vllm", 2)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "not allowed to use plugin vllm")
+
+	// Regular user can use any plugin
+	err = m.CheckQuota("regular-user", "vllm", 2)
+	assert.NoError(t, err)
+}
+
+func TestPluginQuotaManager_RecordAndReleaseUsage(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:     true,
+		TotalGPUs:   10,
+		PerUserGPUs: 5,
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// Record usage
+	m.RecordUsage("user1", "jupyter", 2)
+	m.RecordUsage("user1", "vllm", 1)
+	m.RecordUsage("user2", "jupyter", 3)
+
+	// Check usage tracking
+	usage, totalGPUs := m.GetUsage("user1")
+	assert.Equal(t, 2, usage["jupyter"].GPUs)
+	assert.Equal(t, 1, usage["jupyter"].Services)
+	assert.Equal(t, 1, usage["vllm"].GPUs)
+	assert.Equal(t, 1, usage["vllm"].Services)
+	assert.Equal(t, 3, totalGPUs)
+
+	// Check global usage
+	globalGPUs, pluginTotals := m.GetGlobalUsage()
+	assert.Equal(t, 6, globalGPUs)
+	assert.Equal(t, 5, pluginTotals["jupyter"]) // 2+3
+	assert.Equal(t, 1, pluginTotals["vllm"])
+
+	// Release usage
+	m.ReleaseUsage("user1", "jupyter", 2)
+
+	// Verify release
+	usage, totalGPUs = m.GetUsage("user1")
+	assert.Equal(t, 0, usage["jupyter"].GPUs)
+	assert.Equal(t, 0, usage["jupyter"].Services)
+	assert.Equal(t, 1, usage["vllm"].GPUs) // user1 still has vllm
+	assert.Equal(t, 1, totalGPUs)          // only vllm remains for user1
+
+	// Check global usage after release
+	globalGPUs, pluginTotals = m.GetGlobalUsage()
+	assert.Equal(t, 4, globalGPUs)
+	assert.Equal(t, 3, pluginTotals["jupyter"]) // 3 from user2
+	assert.Equal(t, 1, pluginTotals["vllm"])
+}
+
+func TestPluginQuotaManager_RecordUsage_Disabled(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:   false,
+		TotalGPUs: 10,
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// Recording usage when disabled should not crash
+	m.RecordUsage("user1", "plugin1", 5)
+
+	// Usage should be empty (not tracked)
+	usage, totalGPUs := m.GetUsage("user1")
+	assert.Empty(t, usage)
+	assert.Equal(t, 0, totalGPUs)
+}
+
+func TestPluginQuotaManager_ReleaseUsage_NonExistent(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:   true,
+		TotalGPUs: 10,
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// Releasing non-existent usage should not crash or go negative
+	m.ReleaseUsage("nonexistent", "plugin1", 5)
+
+	// Global usage should remain 0
+	globalGPUs, _ := m.GetGlobalUsage()
+	assert.Equal(t, 0, globalGPUs)
+}
+
+func TestPluginQuotaManager_CheckQuota_AnonymousUser(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:         true,
+		TotalGPUs:       10,
+		PerUserGPUs:     2,
+		PerUserServices: 2,
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// Empty userID should be treated as "anonymous"
+	err := m.CheckQuota("", "plugin1", 2)
+	require.NoError(t, err)
+	m.RecordUsage("", "plugin1", 2)
+
+	// Second request from anonymous should fail (at limit)
+	err = m.CheckQuota("", "plugin1", 1)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "user anonymous GPU limit exceeded")
+}
+
+func TestPluginQuotaManager_CheckQuota_DefaultPlugin(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:         true,
+		TotalGPUs:       10,
+		PerUserGPUs:     5,
+		PerUserServices: 5,
+		PerPluginLimits: map[string]scheduler.PluginLimit{
+			"default": {
+				MaxGPUs:     2,
+				MaxServices: 2,
+			},
+		},
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// Empty plugin name should be treated as "default"
+	err := m.CheckQuota("user1", "", 1)
+	require.NoError(t, err)
+	m.RecordUsage("user1", "", 1)
+
+	// Exceed default plugin limit
+	err = m.CheckQuota("user2", "", 2)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "plugin default GPU limit exceeded")
+}
+
+func TestPluginQuotaManager_ConcurrentAccess(t *testing.T) {
+	config := scheduler.PluginQuotaConfig{
+		Enabled:         true,
+		TotalGPUs:       100,
+		PerUserGPUs:     50,
+		PerUserServices: 50,
+	}
+	m := scheduler.NewPluginQuotaManager(config)
+
+	// Concurrently record usage from multiple goroutines
+	done := make(chan bool, 10)
+	for i := 0; i < 10; i++ {
+		go func(idx int) {
+			user := "user1"
+			if idx%2 != 0 {
+				user = "user2"
+			}
+			m.RecordUsage(user, "plugin1", 1)
+			done <- true
+		}(i)
+	}
+
+	// Wait for all goroutines
+	for i := 0; i < 10; i++ {
+		<-done
+	}
+
+	// Verify totals
+	globalGPUs, _ := m.GetGlobalUsage()
+	assert.Equal(t, 10, globalGPUs)
+
+	usage1, _ := m.GetUsage("user1")
+	assert.Equal(t, 5, usage1["plugin1"].GPUs)
+	assert.Equal(t, 5, usage1["plugin1"].Services)
+
+	usage2, _ := m.GetUsage("user2")
+	assert.Equal(t, 5, usage2["plugin1"].GPUs)
+	assert.Equal(t, 5, usage2["plugin1"].Services)
+}
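
Usage sketch (not part of the diff): the quota flow wired in above is CheckQuota at submit time (SubmitJob), RecordUsage once a worker accepts the job (handleJobAccepted), and ReleaseUsage when the job finishes (handleJobResult). The standalone program below exercises the PluginQuotaManager API from this patch directly; it assumes it is built inside the same module (the scheduler package is internal), and every limit value is an illustrative placeholder, not a recommendation.

package main

import (
	"fmt"

	"github.com/jfraeys/fetch_ml/internal/scheduler"
)

func main() {
	// Illustrative limits only; a real deployment would load these from config.
	qm := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       16,
		PerUserGPUs:     4,
		PerUserServices: 2,
		PerPluginLimits: map[string]scheduler.PluginLimit{
			"jupyter": {MaxGPUs: 4, MaxServices: 4},
		},
	})

	// Submit-time gate (SubmitJob calls this for service jobs).
	if err := qm.CheckQuota("alice", "jupyter", 2); err != nil {
		fmt.Println("rejected:", err)
		return
	}

	// Acceptance records the usage (handleJobAccepted)...
	qm.RecordUsage("alice", "jupyter", 2)

	perPlugin, totalGPUs := qm.GetUsage("alice")
	fmt.Printf("alice holds %d GPUs: %+v\n", totalGPUs, perPlugin)

	// ...and completion releases it (handleJobResult).
	qm.ReleaseUsage("alice", "jupyter", 2)
}

Note that the check and the record are separate steps under separate lock acquisitions, so concurrent submissions can transiently overshoot a limit between SubmitJob and acceptance; the unit tests above cover the single-caller sequencing only.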