feat: add Plugin GPU Quota implementation and tests
Some checks failed
Build Pipeline / Build Binaries (push) Failing after 1m59s
Build Pipeline / Build Docker Images (push) Has been skipped
Build Pipeline / Sign HIPAA Config (push) Has been skipped
Build Pipeline / Generate SLSA Provenance (push) Has been skipped
Checkout test / test (push) Successful in 5s
CI Pipeline / Test (ubuntu-latest on self-hosted) (push) Failing after 1s
CI Pipeline / Dev Compose Smoke Test (push) Has been skipped
CI Pipeline / Security Scan (push) Has been skipped
CI Pipeline / Test Scripts (push) Has been skipped
CI Pipeline / Test Native Libraries (push) Has been skipped
CI Pipeline / Native Library Build Matrix (push) Has been skipped
Documentation / build-and-publish (push) Failing after 35s
CI Pipeline / Trigger Build Workflow (push) Failing after 0s
Security Scan / Security Analysis (push) Has been cancelled
Security Scan / Native Library Security (push) Has been cancelled
Verification & Maintenance / V.1 - Schema Drift Detection (push) Has been cancelled
Verification & Maintenance / V.4 - Custom Go Vet Analyzers (push) Has been cancelled
Verification & Maintenance / V.7 - Audit Chain Integrity (push) Has been cancelled
Verification & Maintenance / V.6 - Extended Security Scanning (push) Has been cancelled
Verification & Maintenance / V.10 - OpenSSF Scorecard (push) Has been cancelled
Verification & Maintenance / Verification Summary (push) Has been cancelled

- Add plugin_quota.go with GPU quota management for scheduler

- Update scheduler hub and protocol for plugin support

- Add comprehensive plugin quota unit tests

- Update gang service and WebSocket queue integration tests
This commit is contained in:
Jeremie Fraeys 2026-02-26 14:35:05 -05:00
parent ef05f200ba
commit da104367d6
No known key found for this signature in database
6 changed files with 776 additions and 23 deletions

View file

@ -33,11 +33,13 @@ type SchedulerHub struct {
reservations map[string]*Reservation
multiNodePending map[string]*MultiNodeJob
pendingAcceptance map[string]*JobAssignment
runningTasks map[string]*Task // Track assigned+accepted tasks
state *StateStore
starvation *StarvationTracker
metrics *SchedulerMetrics
auditor *audit.Logger
tokenValidator *TokenValidator
quotaManager *PluginQuotaManager // NEW: plugin GPU quota manager
config HubConfig
ctx context.Context
cancel context.CancelFunc
@ -59,6 +61,7 @@ type HubConfig struct {
AcceptanceTimeoutSecs int
LocalMode bool
WorkerTokens map[string]string // token -> workerID
PluginQuota PluginQuotaConfig // NEW: plugin GPU quota configuration
}
// WorkerConn represents a connected worker
@ -109,6 +112,7 @@ type JobAssignment struct {
AssignedAt time.Time
AcceptanceDeadline time.Time
Accepted bool
Task *Task // Reference to the task (removed from queue)
}
// StarvationTracker monitors long-waiting jobs
@ -154,6 +158,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
reservations: make(map[string]*Reservation),
multiNodePending: make(map[string]*MultiNodeJob),
pendingAcceptance: make(map[string]*JobAssignment),
runningTasks: make(map[string]*Task),
state: state,
starvation: &StarvationTracker{
threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,
@ -163,6 +168,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
},
auditor: auditor,
tokenValidator: NewTokenValidator(cfg.WorkerTokens),
quotaManager: NewPluginQuotaManager(cfg.PluginQuota), // NEW: initialize quota manager
config: cfg,
ctx: ctx,
cancel: cancel,
@ -431,9 +437,6 @@ func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task {
}
func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
h.mu.RLock()
defer h.mu.RUnlock()
for _, res := range h.reservations {
if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {
@ -449,7 +452,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
h.batchQueue.Remove(task.ID)
h.serviceQueue.Remove(task.ID)
// Track pending acceptance
// Track pending acceptance with task reference
h.mu.Lock()
h.pendingAcceptance[task.ID] = &JobAssignment{
TaskID: task.ID,
@ -457,6 +460,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
AssignedAt: time.Now(),
AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
Accepted: false,
Task: task, // Store reference since removed from queue
}
h.mu.Unlock()
@ -473,12 +477,31 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
}
}
func (h *SchedulerHub) handleJobAccepted(_, taskID string) {
func (h *SchedulerHub) handleJobAccepted(workerID, taskID string) {
h.mu.Lock()
defer h.mu.Unlock()
if assignment, ok := h.pendingAcceptance[taskID]; ok {
assignment.Accepted = true
// Track as running task
task := assignment.Task
if task != nil {
task.Status = "running"
task.WorkerID = workerID
h.runningTasks[taskID] = task
}
// NEW: Record quota usage for service jobs
if task != nil && task.Spec.Type == JobTypeService {
if h.quotaManager != nil {
pluginName := task.Spec.Metadata["plugin_name"]
if pluginName == "" {
pluginName = "default"
}
h.quotaManager.RecordUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
}
}
}
}
@ -486,7 +509,19 @@ func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload)
h.mu.Lock()
defer h.mu.Unlock()
// NEW: Release quota usage for service jobs before deleting pending acceptance
if task := h.runningTasks[result.TaskID]; task != nil && task.Spec.Type == JobTypeService {
if h.quotaManager != nil {
pluginName := task.Spec.Metadata["plugin_name"]
if pluginName == "" {
pluginName = "default"
}
h.quotaManager.ReleaseUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
}
}
delete(h.pendingAcceptance, result.TaskID)
delete(h.runningTasks, result.TaskID)
eventType := EventJobCompleted
switch result.State {
@ -519,7 +554,10 @@ func (h *SchedulerHub) checkAcceptanceTimeouts() {
h.mu.Lock()
for taskID, a := range h.pendingAcceptance {
if !a.Accepted && time.Now().After(a.AcceptanceDeadline) {
h.batchQueue.Add(h.getTask(taskID))
if a.Task != nil {
a.Task.Status = "queued"
h.batchQueue.Add(a.Task)
}
delete(h.pendingAcceptance, taskID)
if wc, ok := h.workers[a.WorkerID]; ok {
wc.slots = SlotStatus{}
@ -572,22 +610,31 @@ func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) {
st.mu.Lock()
defer st.mu.Unlock()
// First check which tasks need reservation under h.mu.RLock
tasksToReserve := make([]*Task, 0)
h.mu.RLock()
for _, task := range h.batchQueue.Items() {
if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservation(h, task.ID) {
h.mu.Lock()
if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservationLocked(h, task.ID) {
tasksToReserve = append(tasksToReserve, task)
}
}
h.mu.RUnlock()
// Now acquire Lock to add reservations
if len(tasksToReserve) > 0 {
h.mu.Lock()
for _, task := range tasksToReserve {
h.reservations[task.ID] = &Reservation{
TaskID: task.ID,
GPUCount: task.Spec.GPUCount,
CreatedAt: time.Now(),
}
h.mu.Unlock()
}
h.mu.Unlock()
}
}
func (st *StarvationTracker) hasReservation(h *SchedulerHub, taskID string) bool {
h.mu.RLock()
defer h.mu.RUnlock()
func (st *StarvationTracker) hasReservationLocked(h *SchedulerHub, taskID string) bool {
_, exists := h.reservations[taskID]
return exists
}
@ -605,6 +652,17 @@ func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
return fmt.Errorf("job ID is required")
}
// NEW: Check plugin quotas for service jobs
if spec.Type == JobTypeService && h.quotaManager != nil {
pluginName := spec.Metadata["plugin_name"]
if pluginName == "" {
pluginName = "default"
}
if err := h.quotaManager.CheckQuota(spec.UserID, pluginName, spec.GPUCount); err != nil {
return fmt.Errorf("quota exceeded: %w", err)
}
}
task := &Task{
ID: spec.ID,
Spec: spec,
@ -639,7 +697,11 @@ func (h *SchedulerHub) getTask(taskID string) *Task {
if t != nil {
return t
}
return h.serviceQueue.Get(taskID)
t = h.serviceQueue.Get(taskID)
if t != nil {
return t
}
return h.runningTasks[taskID]
}
func (h *SchedulerHub) restoreJob(ev StateEvent) {
@ -718,7 +780,7 @@ func (h *SchedulerHub) reconcileOrphans() {
if assignment.Accepted {
// Job was accepted but worker is gone (not in h.workers)
if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected {
task := h.getTask(taskID)
task := assignment.Task
if task != nil {
task.Status = "orphaned"
h.batchQueue.Add(task)
@ -776,7 +838,7 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
}
if msg.Type == MsgMetricsRequest {
metrics := h.getMetricsPayload()
metrics := h.GetMetricsPayload()
conn.WriteJSON(Message{
Type: MsgMetricsResponse,
Payload: mustMarshal(metrics),
@ -785,8 +847,8 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
}
}
// getMetricsPayload returns current metrics as a map
func (h *SchedulerHub) getMetricsPayload() map[string]any {
// GetMetricsPayload returns current metrics as a map (public API)
func (h *SchedulerHub) GetMetricsPayload() map[string]any {
h.metrics.mu.RLock()
defer h.metrics.mu.RUnlock()
@ -901,7 +963,7 @@ func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) {
// buildRankedSpec creates a job spec with rank-specific template variables resolved
func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
// Clone the spec and add rank info to metadata
// Clone the spec and add rank info to metadata and env
spec := task.Spec
spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3)
for k, v := range task.Spec.Metadata {
@ -910,6 +972,14 @@ func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, wo
spec.Metadata["HEAD_ADDR"] = headAddr
spec.Metadata["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
spec.Metadata["NODE_RANK"] = fmt.Sprintf("%d", rank)
// Also set in Env for job runtime
if spec.Env == nil {
spec.Env = make(map[string]string)
}
spec.Env["HEAD_ADDR"] = headAddr
spec.Env["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
spec.Env["NODE_RANK"] = fmt.Sprintf("%d", rank)
return spec
}

View file

@ -0,0 +1,287 @@
package scheduler
import (
"fmt"
"sync"
)
// PluginQuotaConfig defines GPU limits for plugins.
// Zero-valued limits are treated as unconfigured (not enforced) by CheckQuota.
type PluginQuotaConfig struct {
	Enabled         bool                   // Master switch for quota enforcement
	TotalGPUs       int                    // Global GPU limit across all plugins (0 = unlimited)
	PerUserGPUs     int                    // Default per-user GPU limit (0 = unlimited)
	PerUserServices int                    // Default per-user service count limit (0 = unlimited)
	PerPluginLimits map[string]PluginLimit // Plugin-specific overrides, keyed by plugin name
	UserOverrides   map[string]UserLimit   // Per-user overrides, keyed by user ID
}
// PluginLimit defines limits for a specific plugin, applied across all users
// of that plugin. A zero value disables that particular limit.
type PluginLimit struct {
	MaxGPUs     int // Total GPUs this plugin may hold across all users (0 = unlimited)
	MaxServices int // Total concurrent services of this plugin across all users (0 = unlimited)
}
// UserLimit defines per-user override limits.
// A zero MaxGPUs or MaxServices falls back to the corresponding
// PluginQuotaConfig default when evaluated in CheckQuota.
type UserLimit struct {
	MaxGPUs        int
	MaxServices    int
	AllowedPlugins []string // Empty = all plugins allowed
}
// PluginUsage tracks GPU and service count for a user-plugin combination.
type PluginUsage struct {
	GPUs     int // GPUs currently held by this user for this plugin
	Services int // Active service count for this user and plugin
}
// PluginQuotaManager tracks active usage and enforces quotas.
// config is not mutated after construction; the usage maps and counters
// are guarded by mu.
type PluginQuotaManager struct {
	config PluginQuotaConfig
	mu sync.RWMutex
	usage map[string]map[string]PluginUsage // userID -> pluginName -> usage
	pluginTotal map[string]int // pluginName -> total GPUs in use
	totalGPUs int // global total GPUs in use
}
// NewPluginQuotaManager builds a quota manager for the supplied configuration
// with every usage counter starting from zero.
func NewPluginQuotaManager(config PluginQuotaConfig) *PluginQuotaManager {
	m := &PluginQuotaManager{
		config:      config,
		usage:       make(map[string]map[string]PluginUsage),
		pluginTotal: make(map[string]int),
	}
	return m
}
// CheckQuota reports whether a job requesting gpuCount GPUs can be admitted
// for the given user and plugin without breaching any configured limit.
// A nil return means the job is allowed; otherwise the error names the first
// limit that would be exceeded. Empty identifiers are normalized to
// "anonymous" / "default", and zero-valued limits are treated as unconfigured.
//
// NOTE(review): this check runs under the read lock while RecordUsage runs
// later under the write lock, so two concurrent submissions can both pass the
// same check before either records usage — confirm callers serialize
// check+record if strict enforcement is required.
func (m *PluginQuotaManager) CheckQuota(userID, pluginName string, gpuCount int) error {
	if !m.config.Enabled {
		return nil
	}
	if userID == "" {
		userID = "anonymous"
	}
	if pluginName == "" {
		pluginName = "default"
	}
	m.mu.RLock()
	defer m.mu.RUnlock()

	userLimit := m.getUserLimit(userID)

	// Allow-list check: an empty AllowedPlugins list permits every plugin.
	if len(userLimit.AllowedPlugins) > 0 {
		permitted := false
		for _, allowed := range userLimit.AllowedPlugins {
			if allowed == pluginName {
				permitted = true
				break
			}
		}
		if !permitted {
			return fmt.Errorf("user %s is not allowed to use plugin %s", userID, pluginName)
		}
	}

	// Plugin-wide limits apply across all users of the plugin.
	if pluginLimit, ok := m.config.PerPluginLimits[pluginName]; ok {
		if pluginLimit.MaxGPUs > 0 && m.pluginTotal[pluginName]+gpuCount > pluginLimit.MaxGPUs {
			return fmt.Errorf("plugin %s GPU limit exceeded: %d requested, %d available of %d total",
				pluginName, gpuCount, pluginLimit.MaxGPUs-m.pluginTotal[pluginName], pluginLimit.MaxGPUs)
		}
		if pluginLimit.MaxServices > 0 {
			// The service cap is shared by every user of this plugin.
			pluginServices := 0
			for _, perPlugin := range m.usage {
				if u, ok := perPlugin[pluginName]; ok {
					pluginServices += u.Services
				}
			}
			if pluginServices+1 > pluginLimit.MaxServices {
				return fmt.Errorf("plugin %s service limit exceeded", pluginName)
			}
		}
	}

	// Per-user limits; an override of zero falls back to the global default.
	maxGPUs := userLimit.MaxGPUs
	if maxGPUs == 0 {
		maxGPUs = m.config.PerUserGPUs
	}
	maxServices := userLimit.MaxServices
	if maxServices == 0 {
		maxServices = m.config.PerUserServices
	}
	// Sum this user's usage across every plugin.
	userGPUs, userServices := 0, 0
	for _, u := range m.usage[userID] {
		userGPUs += u.GPUs
		userServices += u.Services
	}
	if maxGPUs > 0 && userGPUs+gpuCount > maxGPUs {
		return fmt.Errorf("user %s GPU limit exceeded: %d requested, %d available of %d total",
			userID, gpuCount, maxGPUs-userGPUs, maxGPUs)
	}
	if maxServices > 0 && userServices+1 > maxServices {
		return fmt.Errorf("user %s service limit exceeded: %d services of %d allowed",
			userID, userServices+1, maxServices)
	}

	// Cluster-wide GPU ceiling.
	if m.config.TotalGPUs > 0 && m.totalGPUs+gpuCount > m.config.TotalGPUs {
		return fmt.Errorf("global GPU limit exceeded: %d requested, %d available of %d total",
			gpuCount, m.config.TotalGPUs-m.totalGPUs, m.config.TotalGPUs)
	}
	return nil
}
// RecordUsage increments usage counters when a job starts running.
// It is a no-op when enforcement is disabled; empty identifiers are
// normalized to "anonymous" / "default".
func (m *PluginQuotaManager) RecordUsage(userID, pluginName string, gpuCount int) {
	if !m.config.Enabled {
		return
	}
	if userID == "" {
		userID = "anonymous"
	}
	if pluginName == "" {
		pluginName = "default"
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	// Lazily create the per-user map on first use.
	perPlugin, exists := m.usage[userID]
	if !exists {
		perPlugin = make(map[string]PluginUsage)
		m.usage[userID] = perPlugin
	}
	entry := perPlugin[pluginName]
	entry.GPUs += gpuCount
	entry.Services++
	perPlugin[pluginName] = entry
	// Keep the aggregate counters in step with per-user accounting.
	m.pluginTotal[pluginName] += gpuCount
	m.totalGPUs += gpuCount
}
// ReleaseUsage decrements usage counters when a job stops.
// It is a no-op when enforcement is disabled or the user has no recorded
// usage. Counters are clamped at zero and fully-drained map entries are
// removed, so a release without a matching record can neither drive the
// accounting negative nor leak map entries.
func (m *PluginQuotaManager) ReleaseUsage(userID, pluginName string, gpuCount int) {
	if !m.config.Enabled {
		return
	}
	if userID == "" {
		userID = "anonymous"
	}
	if pluginName == "" {
		pluginName = "default"
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	userPlugins, ok := m.usage[userID]
	if !ok {
		return
	}
	usage := userPlugins[pluginName]
	usage.GPUs -= gpuCount
	usage.Services--
	if usage.GPUs < 0 {
		usage.GPUs = 0
	}
	if usage.Services < 0 {
		usage.Services = 0
	}
	if usage.GPUs == 0 && usage.Services == 0 {
		delete(userPlugins, pluginName)
	} else {
		userPlugins[pluginName] = usage
	}
	if len(userPlugins) == 0 {
		delete(m.usage, userID)
	}
	// FIX: delete zeroed plugin totals instead of storing 0 — previously the
	// entry stayed in the map forever, growing without bound across distinct
	// plugin names. Readers are unaffected: a missing key reads as 0.
	m.pluginTotal[pluginName] -= gpuCount
	if m.pluginTotal[pluginName] <= 0 {
		delete(m.pluginTotal, pluginName)
	}
	m.totalGPUs -= gpuCount
	if m.totalGPUs < 0 {
		m.totalGPUs = 0
	}
}
// GetUsage returns a copy of the user's per-plugin usage along with the
// user's total GPU count across all plugins. An empty userID is treated
// as "anonymous".
func (m *PluginQuotaManager) GetUsage(userID string) (map[string]PluginUsage, int) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	if userID == "" {
		userID = "anonymous"
	}
	snapshot := make(map[string]PluginUsage)
	gpus := 0
	// Ranging over a nil inner map is safe when the user is unknown.
	for plugin, u := range m.usage[userID] {
		snapshot[plugin] = u
		gpus += u.GPUs
	}
	return snapshot, gpus
}
// GetGlobalUsage returns the total GPUs in use across all users and plugins,
// together with a copy of the per-plugin GPU totals.
func (m *PluginQuotaManager) GetGlobalUsage() (int, map[string]int) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	// Copy so callers cannot mutate internal state.
	perPlugin := make(map[string]int, len(m.pluginTotal))
	for name, gpus := range m.pluginTotal {
		perPlugin[name] = gpus
	}
	return m.totalGPUs, perPlugin
}
// getUserLimit resolves the effective limits for a user: a configured
// per-user override wins, otherwise the global defaults apply.
func (m *PluginQuotaManager) getUserLimit(userID string) UserLimit {
	limit, overridden := m.config.UserOverrides[userID]
	if overridden {
		return limit
	}
	limit = UserLimit{
		MaxGPUs:     m.config.PerUserGPUs,
		MaxServices: m.config.PerUserServices,
	}
	return limit
}
// getUsageLocked returns the recorded usage for a user-plugin pair, or the
// zero value when nothing is recorded. Callers must hold m.mu (read lock).
func (m *PluginQuotaManager) getUsageLocked(userID, pluginName string) PluginUsage {
	userPlugins := m.usage[userID]
	if userPlugins == nil {
		return PluginUsage{}
	}
	// Missing plugin key yields the zero value, matching the explicit fallback.
	return userPlugins[pluginName]
}

View file

@ -100,6 +100,7 @@ type JobSpec struct {
ID string `json:"id"`
Type JobType `json:"type"` // "batch" | "service"
SlotPool string `json:"slot_pool"`
UserID string `json:"user_id,omitempty"` // NEW: for per-user quota tracking
GPUCount int `json:"gpu_count"`
GPUType string `json:"gpu_type,omitempty"`

View file

@ -60,6 +60,7 @@ func TestMultiNodeGangAllocation(t *testing.T) {
ID: w.id,
Capabilities: scheduler.WorkerCapabilities{
GPUCount: 0,
Hostname: "localhost",
},
}),
})
@ -172,10 +173,8 @@ func TestServiceLifecycle(t *testing.T) {
// Send job accepted
conn.WriteJSON(scheduler.Message{
Type: scheduler.MsgJobAccepted,
Payload: mustMarshal(map[string]string{
"task_id": jobID,
}),
Type: scheduler.MsgJobAccepted,
Payload: mustMarshal(jobID),
})
// Send periodic health updates

View file

@ -307,6 +307,7 @@ func startFakeWorkers(
go func(workerID string) {
defer wg.Done()
for {
// Check for cancellation before getting next task
select {
case <-ctx.Done():
return
@ -324,6 +325,7 @@ func startFakeWorkers(
continue
}
// Process task - finish even if context is cancelled
started := time.Now()
completed := started.Add(10 * time.Millisecond)
@ -338,7 +340,16 @@ func startFakeWorkers(
continue
}
doneCh <- task.JobName
// Only send to doneCh if we successfully completed
select {
case doneCh <- task.JobName:
case <-ctx.Done():
// Context cancelled but we still completed the task - try once more without blocking
select {
case doneCh <- task.JobName:
default:
}
}
}
}(fmt.Sprintf("worker-%d", w))
}

View file

@ -0,0 +1,385 @@
package scheduler_test
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/scheduler"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPluginQuotaManager_CheckQuota_Disabled verifies that disabling quota
// enforcement admits any request, even one far past a configured limit.
func TestPluginQuotaManager_CheckQuota_Disabled(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:   false,
		TotalGPUs: 1, // would reject a 100-GPU request if enforcement were on
	}
	m := scheduler.NewPluginQuotaManager(config)
	assert.NoError(t, m.CheckQuota("user1", "plugin1", 100))
}
func TestPluginQuotaManager_CheckQuota_GlobalLimit(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: true,
TotalGPUs: 4,
}
m := scheduler.NewPluginQuotaManager(config)
// First job should succeed
err := m.CheckQuota("user1", "plugin1", 2)
require.NoError(t, err)
// Record the usage
m.RecordUsage("user1", "plugin1", 2)
// Second job should succeed (2+2=4, within limit)
err = m.CheckQuota("user2", "plugin2", 2)
require.NoError(t, err)
m.RecordUsage("user2", "plugin2", 2)
// Third job should fail (would exceed global limit)
err = m.CheckQuota("user3", "plugin3", 1)
assert.Error(t, err)
assert.Contains(t, err.Error(), "global GPU limit exceeded")
}
func TestPluginQuotaManager_CheckQuota_PerUserGPULimit(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: true,
TotalGPUs: 10,
PerUserGPUs: 3,
}
m := scheduler.NewPluginQuotaManager(config)
// User1: first job should succeed
err := m.CheckQuota("user1", "plugin1", 2)
require.NoError(t, err)
m.RecordUsage("user1", "plugin1", 2)
// User1: second job should succeed (2+1=3, at limit)
err = m.CheckQuota("user1", "plugin2", 1)
require.NoError(t, err)
m.RecordUsage("user1", "plugin2", 1)
// User1: third job should fail (would exceed per-user limit)
err = m.CheckQuota("user1", "plugin3", 1)
assert.Error(t, err)
assert.Contains(t, err.Error(), "user user1 GPU limit exceeded")
// User2: job should succeed (different user)
err = m.CheckQuota("user2", "plugin1", 3)
assert.NoError(t, err)
}
func TestPluginQuotaManager_CheckQuota_PerUserServiceLimit(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: true,
TotalGPUs: 10,
PerUserGPUs: 10,
PerUserServices: 2,
}
m := scheduler.NewPluginQuotaManager(config)
// User1: first service should succeed
err := m.CheckQuota("user1", "plugin1", 1)
require.NoError(t, err)
m.RecordUsage("user1", "plugin1", 1)
// User1: second service should succeed
err = m.CheckQuota("user1", "plugin2", 1)
require.NoError(t, err)
m.RecordUsage("user1", "plugin2", 1)
// User1: third service should fail (would exceed service count limit)
err = m.CheckQuota("user1", "plugin3", 1)
assert.Error(t, err)
assert.Contains(t, err.Error(), "user user1 service limit exceeded")
}
func TestPluginQuotaManager_CheckQuota_UserOverride(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: true,
TotalGPUs: 10,
PerUserGPUs: 2,
PerUserServices: 2,
UserOverrides: map[string]scheduler.UserLimit{
"vip-user": {
MaxGPUs: 5,
MaxServices: 10,
},
},
}
m := scheduler.NewPluginQuotaManager(config)
// Regular user: limited by default
err := m.CheckQuota("regular", "plugin1", 3)
assert.Error(t, err)
assert.Contains(t, err.Error(), "regular GPU limit exceeded")
// VIP user: has higher limit
err = m.CheckQuota("vip-user", "plugin1", 4)
require.NoError(t, err)
m.RecordUsage("vip-user", "plugin1", 4)
// VIP user: still within limit
err = m.CheckQuota("vip-user", "plugin2", 1)
assert.NoError(t, err)
}
func TestPluginQuotaManager_CheckQuota_PluginSpecificLimit(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: true,
TotalGPUs: 10,
PerUserGPUs: 10,
PerPluginLimits: map[string]scheduler.PluginLimit{
"jupyter": {
MaxGPUs: 3,
MaxServices: 2,
},
"vllm": {
MaxGPUs: 8,
MaxServices: 4,
},
},
}
m := scheduler.NewPluginQuotaManager(config)
// Jupyter: within plugin GPU limit
err := m.CheckQuota("user1", "jupyter", 2)
require.NoError(t, err)
m.RecordUsage("user1", "jupyter", 2)
// Jupyter: exceed plugin GPU limit (but within global and user limits)
err = m.CheckQuota("user2", "jupyter", 2)
assert.Error(t, err)
assert.Contains(t, err.Error(), "plugin jupyter GPU limit exceeded")
// vLLM: within its higher limit
err = m.CheckQuota("user1", "vllm", 4)
assert.NoError(t, err)
}
func TestPluginQuotaManager_CheckQuota_PluginServiceLimit(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: true,
TotalGPUs: 10,
PerUserGPUs: 10,
PerUserServices: 10,
PerPluginLimits: map[string]scheduler.PluginLimit{
"jupyter": {
MaxGPUs: 10,
MaxServices: 2, // Only 2 jupyter services total
},
},
}
m := scheduler.NewPluginQuotaManager(config)
// First jupyter service
err := m.CheckQuota("user1", "jupyter", 1)
require.NoError(t, err)
m.RecordUsage("user1", "jupyter", 1)
// Second jupyter service (different user)
err = m.CheckQuota("user2", "jupyter", 1)
require.NoError(t, err)
m.RecordUsage("user2", "jupyter", 1)
// Third jupyter service should fail (plugin service limit reached)
err = m.CheckQuota("user3", "jupyter", 1)
assert.Error(t, err)
assert.Contains(t, err.Error(), "plugin jupyter service limit exceeded")
}
func TestPluginQuotaManager_CheckQuota_AllowedPlugins(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: true,
TotalGPUs: 10,
PerUserGPUs: 10,
UserOverrides: map[string]scheduler.UserLimit{
"restricted-user": {
MaxGPUs: 5,
MaxServices: 5,
AllowedPlugins: []string{"jupyter"},
},
},
}
m := scheduler.NewPluginQuotaManager(config)
// Restricted user can use allowed plugin
err := m.CheckQuota("restricted-user", "jupyter", 2)
assert.NoError(t, err)
// Restricted user cannot use other plugins
err = m.CheckQuota("restricted-user", "vllm", 2)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not allowed to use plugin vllm")
// Regular user can use any plugin
err = m.CheckQuota("regular-user", "vllm", 2)
assert.NoError(t, err)
}
// TestPluginQuotaManager_RecordAndReleaseUsage verifies that RecordUsage and
// ReleaseUsage keep per-user, per-plugin, and global counters consistent
// through a record/release cycle spanning two users and two plugins.
func TestPluginQuotaManager_RecordAndReleaseUsage(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled: true,
		TotalGPUs: 10,
		PerUserGPUs: 5,
	}
	m := scheduler.NewPluginQuotaManager(config)
	// Record usage
	m.RecordUsage("user1", "jupyter", 2)
	m.RecordUsage("user1", "vllm", 1)
	m.RecordUsage("user2", "jupyter", 3)
	// Check usage tracking
	usage, totalGPUs := m.GetUsage("user1")
	assert.Equal(t, 2, usage["jupyter"].GPUs)
	assert.Equal(t, 1, usage["jupyter"].Services)
	assert.Equal(t, 1, usage["vllm"].GPUs)
	assert.Equal(t, 1, usage["vllm"].Services)
	assert.Equal(t, 3, totalGPUs)
	// Check global usage
	globalGPUs, pluginTotals := m.GetGlobalUsage()
	assert.Equal(t, 6, globalGPUs)
	assert.Equal(t, 5, pluginTotals["jupyter"]) // 2+3
	assert.Equal(t, 1, pluginTotals["vllm"])
	// Release usage
	m.ReleaseUsage("user1", "jupyter", 2)
	// Verify release: user1's jupyter entry is drained (zero value on lookup)
	usage, totalGPUs = m.GetUsage("user1")
	assert.Equal(t, 0, usage["jupyter"].GPUs)
	assert.Equal(t, 0, usage["jupyter"].Services)
	assert.Equal(t, 1, usage["vllm"].GPUs) // user1 still has vllm
	assert.Equal(t, 1, totalGPUs) // only vllm remains for user1
	// Check global usage after release
	globalGPUs, pluginTotals = m.GetGlobalUsage()
	assert.Equal(t, 4, globalGPUs)
	assert.Equal(t, 3, pluginTotals["jupyter"]) // 3 from user2
	assert.Equal(t, 1, pluginTotals["vllm"])
}
func TestPluginQuotaManager_RecordUsage_Disabled(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: false,
TotalGPUs: 10,
}
m := scheduler.NewPluginQuotaManager(config)
// Recording usage when disabled should not crash
m.RecordUsage("user1", "plugin1", 5)
// Usage should be empty (not tracked)
usage, totalGPUs := m.GetUsage("user1")
assert.Empty(t, usage)
assert.Equal(t, 0, totalGPUs)
}
func TestPluginQuotaManager_ReleaseUsage_NonExistent(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: true,
TotalGPUs: 10,
}
m := scheduler.NewPluginQuotaManager(config)
// Releasing non-existent usage should not crash or go negative
m.ReleaseUsage("nonexistent", "plugin1", 5)
// Global usage should remain 0
globalGPUs, _ := m.GetGlobalUsage()
assert.Equal(t, 0, globalGPUs)
}
func TestPluginQuotaManager_CheckQuota_AnonymousUser(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: true,
TotalGPUs: 10,
PerUserGPUs: 2,
PerUserServices: 2,
}
m := scheduler.NewPluginQuotaManager(config)
// Empty userID should be treated as "anonymous"
err := m.CheckQuota("", "plugin1", 2)
require.NoError(t, err)
m.RecordUsage("", "plugin1", 2)
// Second request from anonymous should fail (at limit)
err = m.CheckQuota("", "plugin1", 1)
assert.Error(t, err)
assert.Contains(t, err.Error(), "user anonymous GPU limit exceeded")
}
func TestPluginQuotaManager_CheckQuota_DefaultPlugin(t *testing.T) {
config := scheduler.PluginQuotaConfig{
Enabled: true,
TotalGPUs: 10,
PerUserGPUs: 5,
PerUserServices: 5,
PerPluginLimits: map[string]scheduler.PluginLimit{
"default": {
MaxGPUs: 2,
MaxServices: 2,
},
},
}
m := scheduler.NewPluginQuotaManager(config)
// Empty plugin name should be treated as "default"
err := m.CheckQuota("user1", "", 1)
require.NoError(t, err)
m.RecordUsage("user1", "", 1)
// Exceed default plugin limit
err = m.CheckQuota("user2", "", 2)
assert.Error(t, err)
assert.Contains(t, err.Error(), "plugin default GPU limit exceeded")
}
// TestPluginQuotaManager_ConcurrentAccess exercises RecordUsage from many
// goroutines (run under -race to detect data races) and verifies the
// aggregate counters afterwards.
func TestPluginQuotaManager_ConcurrentAccess(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       100,
		PerUserGPUs:     50,
		PerUserServices: 50,
	}
	m := scheduler.NewPluginQuotaManager(config)
	// Concurrently record usage: even goroutines act as user1, odd as user2.
	done := make(chan bool, 10)
	for i := 0; i < 10; i++ {
		go func(idx int) {
			// FIX: the initial value was a dead assignment ("user") that was
			// immediately overwritten in both branches (staticcheck SA4006).
			user := "user1"
			if idx%2 != 0 {
				user = "user2"
			}
			m.RecordUsage(user, "plugin1", 1)
			done <- true
		}(i)
	}
	// Wait for all goroutines
	for i := 0; i < 10; i++ {
		<-done
	}
	// Verify totals: 10 GPUs overall, split evenly between the two users.
	globalGPUs, _ := m.GetGlobalUsage()
	assert.Equal(t, 10, globalGPUs)
	usage1, _ := m.GetUsage("user1")
	assert.Equal(t, 5, usage1["plugin1"].GPUs)
	assert.Equal(t, 5, usage1["plugin1"].Services)
	usage2, _ := m.GetUsage("user2")
	assert.Equal(t, 5, usage2["plugin1"].GPUs)
	assert.Equal(t, 5, usage2["plugin1"].Services)
}