feat: add Plugin GPU Quota implementation and tests
Some checks failed
Build Pipeline / Build Binaries (push) Failing after 1m59s
Build Pipeline / Build Docker Images (push) Has been skipped
Build Pipeline / Sign HIPAA Config (push) Has been skipped
Build Pipeline / Generate SLSA Provenance (push) Has been skipped
Checkout test / test (push) Successful in 5s
CI Pipeline / Test (ubuntu-latest on self-hosted) (push) Failing after 1s
CI Pipeline / Dev Compose Smoke Test (push) Has been skipped
CI Pipeline / Security Scan (push) Has been skipped
CI Pipeline / Test Scripts (push) Has been skipped
CI Pipeline / Test Native Libraries (push) Has been skipped
CI Pipeline / Native Library Build Matrix (push) Has been skipped
Documentation / build-and-publish (push) Failing after 35s
CI Pipeline / Trigger Build Workflow (push) Failing after 0s
Security Scan / Security Analysis (push) Has been cancelled
Security Scan / Native Library Security (push) Has been cancelled
Verification & Maintenance / V.1 - Schema Drift Detection (push) Has been cancelled
Verification & Maintenance / V.4 - Custom Go Vet Analyzers (push) Has been cancelled
Verification & Maintenance / V.7 - Audit Chain Integrity (push) Has been cancelled
Verification & Maintenance / V.6 - Extended Security Scanning (push) Has been cancelled
Verification & Maintenance / V.10 - OpenSSF Scorecard (push) Has been cancelled
Verification & Maintenance / Verification Summary (push) Has been cancelled
Some checks failed
Build Pipeline / Build Binaries (push) Failing after 1m59s
Build Pipeline / Build Docker Images (push) Has been skipped
Build Pipeline / Sign HIPAA Config (push) Has been skipped
Build Pipeline / Generate SLSA Provenance (push) Has been skipped
Checkout test / test (push) Successful in 5s
CI Pipeline / Test (ubuntu-latest on self-hosted) (push) Failing after 1s
CI Pipeline / Dev Compose Smoke Test (push) Has been skipped
CI Pipeline / Security Scan (push) Has been skipped
CI Pipeline / Test Scripts (push) Has been skipped
CI Pipeline / Test Native Libraries (push) Has been skipped
CI Pipeline / Native Library Build Matrix (push) Has been skipped
Documentation / build-and-publish (push) Failing after 35s
CI Pipeline / Trigger Build Workflow (push) Failing after 0s
Security Scan / Security Analysis (push) Has been cancelled
Security Scan / Native Library Security (push) Has been cancelled
Verification & Maintenance / V.1 - Schema Drift Detection (push) Has been cancelled
Verification & Maintenance / V.4 - Custom Go Vet Analyzers (push) Has been cancelled
Verification & Maintenance / V.7 - Audit Chain Integrity (push) Has been cancelled
Verification & Maintenance / V.6 - Extended Security Scanning (push) Has been cancelled
Verification & Maintenance / V.10 - OpenSSF Scorecard (push) Has been cancelled
Verification & Maintenance / Verification Summary (push) Has been cancelled
- Add plugin_quota.go with GPU quota management for the scheduler
- Update the scheduler hub and protocol for plugin support
- Add comprehensive plugin quota unit tests
- Update the gang service and WebSocket queue integration tests
This commit is contained in:
parent
ef05f200ba
commit
da104367d6
6 changed files with 776 additions and 23 deletions
|
|
@ -33,11 +33,13 @@ type SchedulerHub struct {
|
||||||
reservations map[string]*Reservation
|
reservations map[string]*Reservation
|
||||||
multiNodePending map[string]*MultiNodeJob
|
multiNodePending map[string]*MultiNodeJob
|
||||||
pendingAcceptance map[string]*JobAssignment
|
pendingAcceptance map[string]*JobAssignment
|
||||||
|
runningTasks map[string]*Task // Track assigned+accepted tasks
|
||||||
state *StateStore
|
state *StateStore
|
||||||
starvation *StarvationTracker
|
starvation *StarvationTracker
|
||||||
metrics *SchedulerMetrics
|
metrics *SchedulerMetrics
|
||||||
auditor *audit.Logger
|
auditor *audit.Logger
|
||||||
tokenValidator *TokenValidator
|
tokenValidator *TokenValidator
|
||||||
|
quotaManager *PluginQuotaManager // NEW: plugin GPU quota manager
|
||||||
config HubConfig
|
config HubConfig
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
@ -59,6 +61,7 @@ type HubConfig struct {
|
||||||
AcceptanceTimeoutSecs int
|
AcceptanceTimeoutSecs int
|
||||||
LocalMode bool
|
LocalMode bool
|
||||||
WorkerTokens map[string]string // token -> workerID
|
WorkerTokens map[string]string // token -> workerID
|
||||||
|
PluginQuota PluginQuotaConfig // NEW: plugin GPU quota configuration
|
||||||
}
|
}
|
||||||
|
|
||||||
// WorkerConn represents a connected worker
|
// WorkerConn represents a connected worker
|
||||||
|
|
@ -109,6 +112,7 @@ type JobAssignment struct {
|
||||||
AssignedAt time.Time
|
AssignedAt time.Time
|
||||||
AcceptanceDeadline time.Time
|
AcceptanceDeadline time.Time
|
||||||
Accepted bool
|
Accepted bool
|
||||||
|
Task *Task // Reference to the task (removed from queue)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StarvationTracker monitors long-waiting jobs
|
// StarvationTracker monitors long-waiting jobs
|
||||||
|
|
@ -154,6 +158,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
|
||||||
reservations: make(map[string]*Reservation),
|
reservations: make(map[string]*Reservation),
|
||||||
multiNodePending: make(map[string]*MultiNodeJob),
|
multiNodePending: make(map[string]*MultiNodeJob),
|
||||||
pendingAcceptance: make(map[string]*JobAssignment),
|
pendingAcceptance: make(map[string]*JobAssignment),
|
||||||
|
runningTasks: make(map[string]*Task),
|
||||||
state: state,
|
state: state,
|
||||||
starvation: &StarvationTracker{
|
starvation: &StarvationTracker{
|
||||||
threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,
|
threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,
|
||||||
|
|
@ -163,6 +168,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
|
||||||
},
|
},
|
||||||
auditor: auditor,
|
auditor: auditor,
|
||||||
tokenValidator: NewTokenValidator(cfg.WorkerTokens),
|
tokenValidator: NewTokenValidator(cfg.WorkerTokens),
|
||||||
|
quotaManager: NewPluginQuotaManager(cfg.PluginQuota), // NEW: initialize quota manager
|
||||||
config: cfg,
|
config: cfg,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
|
|
@ -431,9 +437,6 @@ func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
|
func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
|
||||||
h.mu.RLock()
|
|
||||||
defer h.mu.RUnlock()
|
|
||||||
|
|
||||||
for _, res := range h.reservations {
|
for _, res := range h.reservations {
|
||||||
if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
|
if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
|
||||||
if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {
|
if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {
|
||||||
|
|
@ -449,7 +452,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
|
||||||
h.batchQueue.Remove(task.ID)
|
h.batchQueue.Remove(task.ID)
|
||||||
h.serviceQueue.Remove(task.ID)
|
h.serviceQueue.Remove(task.ID)
|
||||||
|
|
||||||
// Track pending acceptance
|
// Track pending acceptance with task reference
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
h.pendingAcceptance[task.ID] = &JobAssignment{
|
h.pendingAcceptance[task.ID] = &JobAssignment{
|
||||||
TaskID: task.ID,
|
TaskID: task.ID,
|
||||||
|
|
@ -457,6 +460,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
|
||||||
AssignedAt: time.Now(),
|
AssignedAt: time.Now(),
|
||||||
AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
|
AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
|
||||||
Accepted: false,
|
Accepted: false,
|
||||||
|
Task: task, // Store reference since removed from queue
|
||||||
}
|
}
|
||||||
h.mu.Unlock()
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
|
@ -473,12 +477,31 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SchedulerHub) handleJobAccepted(_, taskID string) {
|
func (h *SchedulerHub) handleJobAccepted(workerID, taskID string) {
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
defer h.mu.Unlock()
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
if assignment, ok := h.pendingAcceptance[taskID]; ok {
|
if assignment, ok := h.pendingAcceptance[taskID]; ok {
|
||||||
assignment.Accepted = true
|
assignment.Accepted = true
|
||||||
|
|
||||||
|
// Track as running task
|
||||||
|
task := assignment.Task
|
||||||
|
if task != nil {
|
||||||
|
task.Status = "running"
|
||||||
|
task.WorkerID = workerID
|
||||||
|
h.runningTasks[taskID] = task
|
||||||
|
}
|
||||||
|
|
||||||
|
// NEW: Record quota usage for service jobs
|
||||||
|
if task != nil && task.Spec.Type == JobTypeService {
|
||||||
|
if h.quotaManager != nil {
|
||||||
|
pluginName := task.Spec.Metadata["plugin_name"]
|
||||||
|
if pluginName == "" {
|
||||||
|
pluginName = "default"
|
||||||
|
}
|
||||||
|
h.quotaManager.RecordUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -486,7 +509,19 @@ func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload)
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
defer h.mu.Unlock()
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
// NEW: Release quota usage for service jobs before deleting pending acceptance
|
||||||
|
if task := h.runningTasks[result.TaskID]; task != nil && task.Spec.Type == JobTypeService {
|
||||||
|
if h.quotaManager != nil {
|
||||||
|
pluginName := task.Spec.Metadata["plugin_name"]
|
||||||
|
if pluginName == "" {
|
||||||
|
pluginName = "default"
|
||||||
|
}
|
||||||
|
h.quotaManager.ReleaseUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
delete(h.pendingAcceptance, result.TaskID)
|
delete(h.pendingAcceptance, result.TaskID)
|
||||||
|
delete(h.runningTasks, result.TaskID)
|
||||||
|
|
||||||
eventType := EventJobCompleted
|
eventType := EventJobCompleted
|
||||||
switch result.State {
|
switch result.State {
|
||||||
|
|
@ -519,7 +554,10 @@ func (h *SchedulerHub) checkAcceptanceTimeouts() {
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
for taskID, a := range h.pendingAcceptance {
|
for taskID, a := range h.pendingAcceptance {
|
||||||
if !a.Accepted && time.Now().After(a.AcceptanceDeadline) {
|
if !a.Accepted && time.Now().After(a.AcceptanceDeadline) {
|
||||||
h.batchQueue.Add(h.getTask(taskID))
|
if a.Task != nil {
|
||||||
|
a.Task.Status = "queued"
|
||||||
|
h.batchQueue.Add(a.Task)
|
||||||
|
}
|
||||||
delete(h.pendingAcceptance, taskID)
|
delete(h.pendingAcceptance, taskID)
|
||||||
if wc, ok := h.workers[a.WorkerID]; ok {
|
if wc, ok := h.workers[a.WorkerID]; ok {
|
||||||
wc.slots = SlotStatus{}
|
wc.slots = SlotStatus{}
|
||||||
|
|
@ -572,22 +610,31 @@ func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) {
|
||||||
st.mu.Lock()
|
st.mu.Lock()
|
||||||
defer st.mu.Unlock()
|
defer st.mu.Unlock()
|
||||||
|
|
||||||
|
// First check which tasks need reservation under h.mu.RLock
|
||||||
|
tasksToReserve := make([]*Task, 0)
|
||||||
|
h.mu.RLock()
|
||||||
for _, task := range h.batchQueue.Items() {
|
for _, task := range h.batchQueue.Items() {
|
||||||
if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservation(h, task.ID) {
|
if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservationLocked(h, task.ID) {
|
||||||
h.mu.Lock()
|
tasksToReserve = append(tasksToReserve, task)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.mu.RUnlock()
|
||||||
|
|
||||||
|
// Now acquire Lock to add reservations
|
||||||
|
if len(tasksToReserve) > 0 {
|
||||||
|
h.mu.Lock()
|
||||||
|
for _, task := range tasksToReserve {
|
||||||
h.reservations[task.ID] = &Reservation{
|
h.reservations[task.ID] = &Reservation{
|
||||||
TaskID: task.ID,
|
TaskID: task.ID,
|
||||||
GPUCount: task.Spec.GPUCount,
|
GPUCount: task.Spec.GPUCount,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
h.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *StarvationTracker) hasReservation(h *SchedulerHub, taskID string) bool {
|
func (st *StarvationTracker) hasReservationLocked(h *SchedulerHub, taskID string) bool {
|
||||||
h.mu.RLock()
|
|
||||||
defer h.mu.RUnlock()
|
|
||||||
_, exists := h.reservations[taskID]
|
_, exists := h.reservations[taskID]
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
|
|
@ -605,6 +652,17 @@ func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
|
||||||
return fmt.Errorf("job ID is required")
|
return fmt.Errorf("job ID is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NEW: Check plugin quotas for service jobs
|
||||||
|
if spec.Type == JobTypeService && h.quotaManager != nil {
|
||||||
|
pluginName := spec.Metadata["plugin_name"]
|
||||||
|
if pluginName == "" {
|
||||||
|
pluginName = "default"
|
||||||
|
}
|
||||||
|
if err := h.quotaManager.CheckQuota(spec.UserID, pluginName, spec.GPUCount); err != nil {
|
||||||
|
return fmt.Errorf("quota exceeded: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
task := &Task{
|
task := &Task{
|
||||||
ID: spec.ID,
|
ID: spec.ID,
|
||||||
Spec: spec,
|
Spec: spec,
|
||||||
|
|
@ -639,7 +697,11 @@ func (h *SchedulerHub) getTask(taskID string) *Task {
|
||||||
if t != nil {
|
if t != nil {
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
return h.serviceQueue.Get(taskID)
|
t = h.serviceQueue.Get(taskID)
|
||||||
|
if t != nil {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
return h.runningTasks[taskID]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SchedulerHub) restoreJob(ev StateEvent) {
|
func (h *SchedulerHub) restoreJob(ev StateEvent) {
|
||||||
|
|
@ -718,7 +780,7 @@ func (h *SchedulerHub) reconcileOrphans() {
|
||||||
if assignment.Accepted {
|
if assignment.Accepted {
|
||||||
// Job was accepted but worker is gone (not in h.workers)
|
// Job was accepted but worker is gone (not in h.workers)
|
||||||
if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected {
|
if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected {
|
||||||
task := h.getTask(taskID)
|
task := assignment.Task
|
||||||
if task != nil {
|
if task != nil {
|
||||||
task.Status = "orphaned"
|
task.Status = "orphaned"
|
||||||
h.batchQueue.Add(task)
|
h.batchQueue.Add(task)
|
||||||
|
|
@ -776,7 +838,7 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if msg.Type == MsgMetricsRequest {
|
if msg.Type == MsgMetricsRequest {
|
||||||
metrics := h.getMetricsPayload()
|
metrics := h.GetMetricsPayload()
|
||||||
conn.WriteJSON(Message{
|
conn.WriteJSON(Message{
|
||||||
Type: MsgMetricsResponse,
|
Type: MsgMetricsResponse,
|
||||||
Payload: mustMarshal(metrics),
|
Payload: mustMarshal(metrics),
|
||||||
|
|
@ -785,8 +847,8 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMetricsPayload returns current metrics as a map
|
// GetMetricsPayload returns current metrics as a map (public API)
|
||||||
func (h *SchedulerHub) getMetricsPayload() map[string]any {
|
func (h *SchedulerHub) GetMetricsPayload() map[string]any {
|
||||||
h.metrics.mu.RLock()
|
h.metrics.mu.RLock()
|
||||||
defer h.metrics.mu.RUnlock()
|
defer h.metrics.mu.RUnlock()
|
||||||
|
|
||||||
|
|
@ -901,7 +963,7 @@ func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) {
|
||||||
|
|
||||||
// buildRankedSpec creates a job spec with rank-specific template variables resolved
|
// buildRankedSpec creates a job spec with rank-specific template variables resolved
|
||||||
func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
|
func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
|
||||||
// Clone the spec and add rank info to metadata
|
// Clone the spec and add rank info to metadata and env
|
||||||
spec := task.Spec
|
spec := task.Spec
|
||||||
spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3)
|
spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3)
|
||||||
for k, v := range task.Spec.Metadata {
|
for k, v := range task.Spec.Metadata {
|
||||||
|
|
@ -910,6 +972,14 @@ func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, wo
|
||||||
spec.Metadata["HEAD_ADDR"] = headAddr
|
spec.Metadata["HEAD_ADDR"] = headAddr
|
||||||
spec.Metadata["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
|
spec.Metadata["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
|
||||||
spec.Metadata["NODE_RANK"] = fmt.Sprintf("%d", rank)
|
spec.Metadata["NODE_RANK"] = fmt.Sprintf("%d", rank)
|
||||||
|
|
||||||
|
// Also set in Env for job runtime
|
||||||
|
if spec.Env == nil {
|
||||||
|
spec.Env = make(map[string]string)
|
||||||
|
}
|
||||||
|
spec.Env["HEAD_ADDR"] = headAddr
|
||||||
|
spec.Env["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
|
||||||
|
spec.Env["NODE_RANK"] = fmt.Sprintf("%d", rank)
|
||||||
return spec
|
return spec
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
287
internal/scheduler/plugin_quota.go
Normal file
287
internal/scheduler/plugin_quota.go
Normal file
|
|
@ -0,0 +1,287 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PluginQuotaConfig defines GPU limits for plugins.
|
||||||
|
type PluginQuotaConfig struct {
|
||||||
|
Enabled bool // Master switch for quota enforcement
|
||||||
|
TotalGPUs int // Global GPU limit across all plugins
|
||||||
|
PerUserGPUs int // Default per-user GPU limit
|
||||||
|
PerUserServices int // Default per-user service count limit
|
||||||
|
PerPluginLimits map[string]PluginLimit // Plugin-specific overrides
|
||||||
|
UserOverrides map[string]UserLimit // Per-user overrides
|
||||||
|
}
|
||||||
|
|
||||||
|
// PluginLimit caps the resources one plugin may consume, summed over all users.
type PluginLimit struct {
	MaxGPUs     int // GPU cap for the plugin; only enforced when greater than zero
	MaxServices int // service-count cap for the plugin; only enforced when greater than zero
}
|
||||||
|
|
||||||
|
// UserLimit is a per-user override of the default quota.
type UserLimit struct {
	// MaxGPUs is the user's GPU cap; 0 falls back to the configured per-user default.
	MaxGPUs int
	// MaxServices is the user's service-count cap; 0 falls back to the configured
	// per-user default.
	MaxServices int
	// AllowedPlugins lists the plugins the user may run. Empty = all plugins allowed.
	AllowedPlugins []string
}
|
||||||
|
|
||||||
|
// PluginUsage is the active resource count for one user on one plugin.
type PluginUsage struct {
	GPUs     int // GPUs currently recorded for the user-plugin pair
	Services int // services currently recorded for the user-plugin pair
}
|
||||||
|
|
||||||
|
// PluginQuotaManager tracks active usage and enforces quotas.
|
||||||
|
type PluginQuotaManager struct {
|
||||||
|
config PluginQuotaConfig
|
||||||
|
mu sync.RWMutex
|
||||||
|
usage map[string]map[string]PluginUsage // userID -> pluginName -> usage
|
||||||
|
pluginTotal map[string]int // pluginName -> total GPUs in use
|
||||||
|
totalGPUs int // global total GPUs in use
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPluginQuotaManager creates a new quota manager with the given configuration.
|
||||||
|
func NewPluginQuotaManager(config PluginQuotaConfig) *PluginQuotaManager {
|
||||||
|
return &PluginQuotaManager{
|
||||||
|
config: config,
|
||||||
|
usage: make(map[string]map[string]PluginUsage),
|
||||||
|
pluginTotal: make(map[string]int),
|
||||||
|
totalGPUs: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckQuota validates if a job can be submitted without exceeding limits.
|
||||||
|
// Returns nil if the job is allowed, or an error describing which limit would be exceeded.
|
||||||
|
func (m *PluginQuotaManager) CheckQuota(userID, pluginName string, gpuCount int) error {
|
||||||
|
if !m.config.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if userID == "" {
|
||||||
|
userID = "anonymous"
|
||||||
|
}
|
||||||
|
if pluginName == "" {
|
||||||
|
pluginName = "default"
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
// Get user limits (with overrides)
|
||||||
|
userLimit := m.getUserLimit(userID)
|
||||||
|
|
||||||
|
// Check if user is allowed to use this plugin
|
||||||
|
if len(userLimit.AllowedPlugins) > 0 {
|
||||||
|
found := false
|
||||||
|
for _, p := range userLimit.AllowedPlugins {
|
||||||
|
if p == pluginName {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("user %s is not allowed to use plugin %s", userID, pluginName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check plugin-specific limits
|
||||||
|
pluginLimit, hasPluginLimit := m.config.PerPluginLimits[pluginName]
|
||||||
|
if hasPluginLimit {
|
||||||
|
if pluginLimit.MaxGPUs > 0 && m.pluginTotal[pluginName]+gpuCount > pluginLimit.MaxGPUs {
|
||||||
|
return fmt.Errorf("plugin %s GPU limit exceeded: %d requested, %d available of %d total",
|
||||||
|
pluginName, gpuCount, pluginLimit.MaxGPUs-m.pluginTotal[pluginName], pluginLimit.MaxGPUs)
|
||||||
|
}
|
||||||
|
if pluginLimit.MaxServices > 0 {
|
||||||
|
// Services limit is across all users for this plugin
|
||||||
|
totalServices := 0
|
||||||
|
for _, u := range m.usage {
|
||||||
|
if p, ok := u[pluginName]; ok {
|
||||||
|
totalServices += p.Services
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if totalServices+1 > pluginLimit.MaxServices {
|
||||||
|
return fmt.Errorf("plugin %s service limit exceeded", pluginName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check per-user limits
|
||||||
|
effectiveUserGPUs := userLimit.MaxGPUs
|
||||||
|
if effectiveUserGPUs == 0 {
|
||||||
|
effectiveUserGPUs = m.config.PerUserGPUs
|
||||||
|
}
|
||||||
|
effectiveUserServices := userLimit.MaxServices
|
||||||
|
if effectiveUserServices == 0 {
|
||||||
|
effectiveUserServices = m.config.PerUserServices
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate total user usage across all plugins
|
||||||
|
totalUserGPUs := 0
|
||||||
|
totalUserServices := 0
|
||||||
|
for _, p := range m.usage[userID] {
|
||||||
|
totalUserGPUs += p.GPUs
|
||||||
|
totalUserServices += p.Services
|
||||||
|
}
|
||||||
|
|
||||||
|
if effectiveUserGPUs > 0 && totalUserGPUs+gpuCount > effectiveUserGPUs {
|
||||||
|
return fmt.Errorf("user %s GPU limit exceeded: %d requested, %d available of %d total",
|
||||||
|
userID, gpuCount, effectiveUserGPUs-totalUserGPUs, effectiveUserGPUs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if effectiveUserServices > 0 && totalUserServices+1 > effectiveUserServices {
|
||||||
|
return fmt.Errorf("user %s service limit exceeded: %d services of %d allowed",
|
||||||
|
userID, totalUserServices+1, effectiveUserServices)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check global total GPU limit
|
||||||
|
if m.config.TotalGPUs > 0 && m.totalGPUs+gpuCount > m.config.TotalGPUs {
|
||||||
|
return fmt.Errorf("global GPU limit exceeded: %d requested, %d available of %d total",
|
||||||
|
gpuCount, m.config.TotalGPUs-m.totalGPUs, m.config.TotalGPUs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordUsage increments usage counters when a job starts.
|
||||||
|
func (m *PluginQuotaManager) RecordUsage(userID, pluginName string, gpuCount int) {
|
||||||
|
if !m.config.Enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if userID == "" {
|
||||||
|
userID = "anonymous"
|
||||||
|
}
|
||||||
|
if pluginName == "" {
|
||||||
|
pluginName = "default"
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
userPlugins, ok := m.usage[userID]
|
||||||
|
if !ok {
|
||||||
|
userPlugins = make(map[string]PluginUsage)
|
||||||
|
m.usage[userID] = userPlugins
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := userPlugins[pluginName]
|
||||||
|
usage.GPUs += gpuCount
|
||||||
|
usage.Services++
|
||||||
|
userPlugins[pluginName] = usage
|
||||||
|
|
||||||
|
m.pluginTotal[pluginName] += gpuCount
|
||||||
|
m.totalGPUs += gpuCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReleaseUsage decrements usage counters when a job stops.
|
||||||
|
func (m *PluginQuotaManager) ReleaseUsage(userID, pluginName string, gpuCount int) {
|
||||||
|
if !m.config.Enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if userID == "" {
|
||||||
|
userID = "anonymous"
|
||||||
|
}
|
||||||
|
if pluginName == "" {
|
||||||
|
pluginName = "default"
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
userPlugins, ok := m.usage[userID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := userPlugins[pluginName]
|
||||||
|
usage.GPUs -= gpuCount
|
||||||
|
usage.Services--
|
||||||
|
|
||||||
|
if usage.GPUs < 0 {
|
||||||
|
usage.GPUs = 0
|
||||||
|
}
|
||||||
|
if usage.Services < 0 {
|
||||||
|
usage.Services = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage.GPUs == 0 && usage.Services == 0 {
|
||||||
|
delete(userPlugins, pluginName)
|
||||||
|
} else {
|
||||||
|
userPlugins[pluginName] = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(userPlugins) == 0 {
|
||||||
|
delete(m.usage, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.pluginTotal[pluginName] -= gpuCount
|
||||||
|
if m.pluginTotal[pluginName] < 0 {
|
||||||
|
m.pluginTotal[pluginName] = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
m.totalGPUs -= gpuCount
|
||||||
|
if m.totalGPUs < 0 {
|
||||||
|
m.totalGPUs = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUsage returns current usage for a user across all plugins.
|
||||||
|
func (m *PluginQuotaManager) GetUsage(userID string) (map[string]PluginUsage, int) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if userID == "" {
|
||||||
|
userID = "anonymous"
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[string]PluginUsage)
|
||||||
|
totalGPUs := 0
|
||||||
|
|
||||||
|
if userPlugins, ok := m.usage[userID]; ok {
|
||||||
|
for plugin, usage := range userPlugins {
|
||||||
|
result[plugin] = usage
|
||||||
|
totalGPUs += usage.GPUs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, totalGPUs
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGlobalUsage returns global GPU usage across all users and plugins.
|
||||||
|
func (m *PluginQuotaManager) GetGlobalUsage() (int, map[string]int) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
pluginTotals := make(map[string]int, len(m.pluginTotal))
|
||||||
|
for k, v := range m.pluginTotal {
|
||||||
|
pluginTotals[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.totalGPUs, pluginTotals
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUserLimit returns the effective limits for a user, applying overrides.
|
||||||
|
func (m *PluginQuotaManager) getUserLimit(userID string) UserLimit {
|
||||||
|
if override, ok := m.config.UserOverrides[userID]; ok {
|
||||||
|
return override
|
||||||
|
}
|
||||||
|
return UserLimit{
|
||||||
|
MaxGPUs: m.config.PerUserGPUs,
|
||||||
|
MaxServices: m.config.PerUserServices,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUsageLocked returns the current usage for a user-plugin combination.
|
||||||
|
// Must be called with read lock held.
|
||||||
|
func (m *PluginQuotaManager) getUsageLocked(userID, pluginName string) PluginUsage {
|
||||||
|
if userPlugins, ok := m.usage[userID]; ok {
|
||||||
|
if usage, ok := userPlugins[pluginName]; ok {
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return PluginUsage{}
|
||||||
|
}
|
||||||
|
|
@ -100,6 +100,7 @@ type JobSpec struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Type JobType `json:"type"` // "batch" | "service"
|
Type JobType `json:"type"` // "batch" | "service"
|
||||||
SlotPool string `json:"slot_pool"`
|
SlotPool string `json:"slot_pool"`
|
||||||
|
UserID string `json:"user_id,omitempty"` // NEW: for per-user quota tracking
|
||||||
|
|
||||||
GPUCount int `json:"gpu_count"`
|
GPUCount int `json:"gpu_count"`
|
||||||
GPUType string `json:"gpu_type,omitempty"`
|
GPUType string `json:"gpu_type,omitempty"`
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,7 @@ func TestMultiNodeGangAllocation(t *testing.T) {
|
||||||
ID: w.id,
|
ID: w.id,
|
||||||
Capabilities: scheduler.WorkerCapabilities{
|
Capabilities: scheduler.WorkerCapabilities{
|
||||||
GPUCount: 0,
|
GPUCount: 0,
|
||||||
|
Hostname: "localhost",
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
|
|
@ -172,10 +173,8 @@ func TestServiceLifecycle(t *testing.T) {
|
||||||
|
|
||||||
// Send job accepted
|
// Send job accepted
|
||||||
conn.WriteJSON(scheduler.Message{
|
conn.WriteJSON(scheduler.Message{
|
||||||
Type: scheduler.MsgJobAccepted,
|
Type: scheduler.MsgJobAccepted,
|
||||||
Payload: mustMarshal(map[string]string{
|
Payload: mustMarshal(jobID),
|
||||||
"task_id": jobID,
|
|
||||||
}),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// Send periodic health updates
|
// Send periodic health updates
|
||||||
|
|
|
||||||
|
|
@ -307,6 +307,7 @@ func startFakeWorkers(
|
||||||
go func(workerID string) {
|
go func(workerID string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for {
|
for {
|
||||||
|
// Check for cancellation before getting next task
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
|
|
@ -324,6 +325,7 @@ func startFakeWorkers(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process task - finish even if context is cancelled
|
||||||
started := time.Now()
|
started := time.Now()
|
||||||
completed := started.Add(10 * time.Millisecond)
|
completed := started.Add(10 * time.Millisecond)
|
||||||
|
|
||||||
|
|
@ -338,7 +340,16 @@ func startFakeWorkers(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
doneCh <- task.JobName
|
// Only send to doneCh if we successfully completed
|
||||||
|
select {
|
||||||
|
case doneCh <- task.JobName:
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Context cancelled but we still completed the task - try once more without blocking
|
||||||
|
select {
|
||||||
|
case doneCh <- task.JobName:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}(fmt.Sprintf("worker-%d", w))
|
}(fmt.Sprintf("worker-%d", w))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
385
tests/unit/scheduler/plugin_quota_test.go
Normal file
385
tests/unit/scheduler/plugin_quota_test.go
Normal file
|
|
@ -0,0 +1,385 @@
|
||||||
|
package scheduler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_CheckQuota_Disabled(t *testing.T) {
|
||||||
|
// When quota is disabled, all jobs should pass
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: false,
|
||||||
|
TotalGPUs: 1, // Set a low limit that would fail if enabled
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
err := m.CheckQuota("user1", "plugin1", 100)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_CheckQuota_GlobalLimit(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 4,
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// First job should succeed
|
||||||
|
err := m.CheckQuota("user1", "plugin1", 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Record the usage
|
||||||
|
m.RecordUsage("user1", "plugin1", 2)
|
||||||
|
|
||||||
|
// Second job should succeed (2+2=4, within limit)
|
||||||
|
err = m.CheckQuota("user2", "plugin2", 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("user2", "plugin2", 2)
|
||||||
|
|
||||||
|
// Third job should fail (would exceed global limit)
|
||||||
|
err = m.CheckQuota("user3", "plugin3", 1)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "global GPU limit exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_CheckQuota_PerUserGPULimit(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
PerUserGPUs: 3,
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// User1: first job should succeed
|
||||||
|
err := m.CheckQuota("user1", "plugin1", 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("user1", "plugin1", 2)
|
||||||
|
|
||||||
|
// User1: second job should succeed (2+1=3, at limit)
|
||||||
|
err = m.CheckQuota("user1", "plugin2", 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("user1", "plugin2", 1)
|
||||||
|
|
||||||
|
// User1: third job should fail (would exceed per-user limit)
|
||||||
|
err = m.CheckQuota("user1", "plugin3", 1)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "user user1 GPU limit exceeded")
|
||||||
|
|
||||||
|
// User2: job should succeed (different user)
|
||||||
|
err = m.CheckQuota("user2", "plugin1", 3)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_CheckQuota_PerUserServiceLimit(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
PerUserGPUs: 10,
|
||||||
|
PerUserServices: 2,
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// User1: first service should succeed
|
||||||
|
err := m.CheckQuota("user1", "plugin1", 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("user1", "plugin1", 1)
|
||||||
|
|
||||||
|
// User1: second service should succeed
|
||||||
|
err = m.CheckQuota("user1", "plugin2", 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("user1", "plugin2", 1)
|
||||||
|
|
||||||
|
// User1: third service should fail (would exceed service count limit)
|
||||||
|
err = m.CheckQuota("user1", "plugin3", 1)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "user user1 service limit exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_CheckQuota_UserOverride(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
PerUserGPUs: 2,
|
||||||
|
PerUserServices: 2,
|
||||||
|
UserOverrides: map[string]scheduler.UserLimit{
|
||||||
|
"vip-user": {
|
||||||
|
MaxGPUs: 5,
|
||||||
|
MaxServices: 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// Regular user: limited by default
|
||||||
|
err := m.CheckQuota("regular", "plugin1", 3)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "regular GPU limit exceeded")
|
||||||
|
|
||||||
|
// VIP user: has higher limit
|
||||||
|
err = m.CheckQuota("vip-user", "plugin1", 4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("vip-user", "plugin1", 4)
|
||||||
|
|
||||||
|
// VIP user: still within limit
|
||||||
|
err = m.CheckQuota("vip-user", "plugin2", 1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_CheckQuota_PluginSpecificLimit(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
PerUserGPUs: 10,
|
||||||
|
PerPluginLimits: map[string]scheduler.PluginLimit{
|
||||||
|
"jupyter": {
|
||||||
|
MaxGPUs: 3,
|
||||||
|
MaxServices: 2,
|
||||||
|
},
|
||||||
|
"vllm": {
|
||||||
|
MaxGPUs: 8,
|
||||||
|
MaxServices: 4,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// Jupyter: within plugin GPU limit
|
||||||
|
err := m.CheckQuota("user1", "jupyter", 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("user1", "jupyter", 2)
|
||||||
|
|
||||||
|
// Jupyter: exceed plugin GPU limit (but within global and user limits)
|
||||||
|
err = m.CheckQuota("user2", "jupyter", 2)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "plugin jupyter GPU limit exceeded")
|
||||||
|
|
||||||
|
// vLLM: within its higher limit
|
||||||
|
err = m.CheckQuota("user1", "vllm", 4)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_CheckQuota_PluginServiceLimit(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
PerUserGPUs: 10,
|
||||||
|
PerUserServices: 10,
|
||||||
|
PerPluginLimits: map[string]scheduler.PluginLimit{
|
||||||
|
"jupyter": {
|
||||||
|
MaxGPUs: 10,
|
||||||
|
MaxServices: 2, // Only 2 jupyter services total
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// First jupyter service
|
||||||
|
err := m.CheckQuota("user1", "jupyter", 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("user1", "jupyter", 1)
|
||||||
|
|
||||||
|
// Second jupyter service (different user)
|
||||||
|
err = m.CheckQuota("user2", "jupyter", 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("user2", "jupyter", 1)
|
||||||
|
|
||||||
|
// Third jupyter service should fail (plugin service limit reached)
|
||||||
|
err = m.CheckQuota("user3", "jupyter", 1)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "plugin jupyter service limit exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_CheckQuota_AllowedPlugins(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
PerUserGPUs: 10,
|
||||||
|
UserOverrides: map[string]scheduler.UserLimit{
|
||||||
|
"restricted-user": {
|
||||||
|
MaxGPUs: 5,
|
||||||
|
MaxServices: 5,
|
||||||
|
AllowedPlugins: []string{"jupyter"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// Restricted user can use allowed plugin
|
||||||
|
err := m.CheckQuota("restricted-user", "jupyter", 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Restricted user cannot use other plugins
|
||||||
|
err = m.CheckQuota("restricted-user", "vllm", 2)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "not allowed to use plugin vllm")
|
||||||
|
|
||||||
|
// Regular user can use any plugin
|
||||||
|
err = m.CheckQuota("regular-user", "vllm", 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_RecordAndReleaseUsage(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
PerUserGPUs: 5,
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// Record usage
|
||||||
|
m.RecordUsage("user1", "jupyter", 2)
|
||||||
|
m.RecordUsage("user1", "vllm", 1)
|
||||||
|
m.RecordUsage("user2", "jupyter", 3)
|
||||||
|
|
||||||
|
// Check usage tracking
|
||||||
|
usage, totalGPUs := m.GetUsage("user1")
|
||||||
|
assert.Equal(t, 2, usage["jupyter"].GPUs)
|
||||||
|
assert.Equal(t, 1, usage["jupyter"].Services)
|
||||||
|
assert.Equal(t, 1, usage["vllm"].GPUs)
|
||||||
|
assert.Equal(t, 1, usage["vllm"].Services)
|
||||||
|
assert.Equal(t, 3, totalGPUs)
|
||||||
|
|
||||||
|
// Check global usage
|
||||||
|
globalGPUs, pluginTotals := m.GetGlobalUsage()
|
||||||
|
assert.Equal(t, 6, globalGPUs)
|
||||||
|
assert.Equal(t, 5, pluginTotals["jupyter"]) // 2+3
|
||||||
|
assert.Equal(t, 1, pluginTotals["vllm"])
|
||||||
|
|
||||||
|
// Release usage
|
||||||
|
m.ReleaseUsage("user1", "jupyter", 2)
|
||||||
|
|
||||||
|
// Verify release
|
||||||
|
usage, totalGPUs = m.GetUsage("user1")
|
||||||
|
assert.Equal(t, 0, usage["jupyter"].GPUs)
|
||||||
|
assert.Equal(t, 0, usage["jupyter"].Services)
|
||||||
|
assert.Equal(t, 1, usage["vllm"].GPUs) // user1 still has vllm
|
||||||
|
assert.Equal(t, 1, totalGPUs) // only vllm remains for user1
|
||||||
|
|
||||||
|
// Check global usage after release
|
||||||
|
globalGPUs, pluginTotals = m.GetGlobalUsage()
|
||||||
|
assert.Equal(t, 4, globalGPUs)
|
||||||
|
assert.Equal(t, 3, pluginTotals["jupyter"]) // 3 from user2
|
||||||
|
assert.Equal(t, 1, pluginTotals["vllm"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_RecordUsage_Disabled(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: false,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// Recording usage when disabled should not crash
|
||||||
|
m.RecordUsage("user1", "plugin1", 5)
|
||||||
|
|
||||||
|
// Usage should be empty (not tracked)
|
||||||
|
usage, totalGPUs := m.GetUsage("user1")
|
||||||
|
assert.Empty(t, usage)
|
||||||
|
assert.Equal(t, 0, totalGPUs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_ReleaseUsage_NonExistent(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// Releasing non-existent usage should not crash or go negative
|
||||||
|
m.ReleaseUsage("nonexistent", "plugin1", 5)
|
||||||
|
|
||||||
|
// Global usage should remain 0
|
||||||
|
globalGPUs, _ := m.GetGlobalUsage()
|
||||||
|
assert.Equal(t, 0, globalGPUs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_CheckQuota_AnonymousUser(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
PerUserGPUs: 2,
|
||||||
|
PerUserServices: 2,
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// Empty userID should be treated as "anonymous"
|
||||||
|
err := m.CheckQuota("", "plugin1", 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("", "plugin1", 2)
|
||||||
|
|
||||||
|
// Second request from anonymous should fail (at limit)
|
||||||
|
err = m.CheckQuota("", "plugin1", 1)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "user anonymous GPU limit exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_CheckQuota_DefaultPlugin(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 10,
|
||||||
|
PerUserGPUs: 5,
|
||||||
|
PerUserServices: 5,
|
||||||
|
PerPluginLimits: map[string]scheduler.PluginLimit{
|
||||||
|
"default": {
|
||||||
|
MaxGPUs: 2,
|
||||||
|
MaxServices: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// Empty plugin name should be treated as "default"
|
||||||
|
err := m.CheckQuota("user1", "", 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m.RecordUsage("user1", "", 1)
|
||||||
|
|
||||||
|
// Exceed default plugin limit
|
||||||
|
err = m.CheckQuota("user2", "", 2)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "plugin default GPU limit exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginQuotaManager_ConcurrentAccess(t *testing.T) {
|
||||||
|
config := scheduler.PluginQuotaConfig{
|
||||||
|
Enabled: true,
|
||||||
|
TotalGPUs: 100,
|
||||||
|
PerUserGPUs: 50,
|
||||||
|
PerUserServices: 50,
|
||||||
|
}
|
||||||
|
m := scheduler.NewPluginQuotaManager(config)
|
||||||
|
|
||||||
|
// Concurrently record usage from multiple goroutines
|
||||||
|
done := make(chan bool, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(idx int) {
|
||||||
|
user := "user"
|
||||||
|
if idx%2 == 0 {
|
||||||
|
user = "user1"
|
||||||
|
} else {
|
||||||
|
user = "user2"
|
||||||
|
}
|
||||||
|
m.RecordUsage(user, "plugin1", 1)
|
||||||
|
done <- true
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify totals
|
||||||
|
globalGPUs, _ := m.GetGlobalUsage()
|
||||||
|
assert.Equal(t, 10, globalGPUs)
|
||||||
|
|
||||||
|
usage1, _ := m.GetUsage("user1")
|
||||||
|
assert.Equal(t, 5, usage1["plugin1"].GPUs)
|
||||||
|
assert.Equal(t, 5, usage1["plugin1"].Services)
|
||||||
|
|
||||||
|
usage2, _ := m.GetUsage("user2")
|
||||||
|
assert.Equal(t, 5, usage2["plugin1"].GPUs)
|
||||||
|
assert.Equal(t, 5, usage2["plugin1"].Services)
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue