feat: add Plugin GPU Quota implementation and tests
Some checks failed
Build Pipeline / Build Binaries (push) Failing after 1m59s
Build Pipeline / Build Docker Images (push) Has been skipped
Build Pipeline / Sign HIPAA Config (push) Has been skipped
Build Pipeline / Generate SLSA Provenance (push) Has been skipped
Checkout test / test (push) Successful in 5s
CI Pipeline / Test (ubuntu-latest on self-hosted) (push) Failing after 1s
CI Pipeline / Dev Compose Smoke Test (push) Has been skipped
CI Pipeline / Security Scan (push) Has been skipped
CI Pipeline / Test Scripts (push) Has been skipped
CI Pipeline / Test Native Libraries (push) Has been skipped
CI Pipeline / Native Library Build Matrix (push) Has been skipped
Documentation / build-and-publish (push) Failing after 35s
CI Pipeline / Trigger Build Workflow (push) Failing after 0s
Security Scan / Security Analysis (push) Has been cancelled
Security Scan / Native Library Security (push) Has been cancelled
Verification & Maintenance / V.1 - Schema Drift Detection (push) Has been cancelled
Verification & Maintenance / V.4 - Custom Go Vet Analyzers (push) Has been cancelled
Verification & Maintenance / V.7 - Audit Chain Integrity (push) Has been cancelled
Verification & Maintenance / V.6 - Extended Security Scanning (push) Has been cancelled
Verification & Maintenance / V.10 - OpenSSF Scorecard (push) Has been cancelled
Verification & Maintenance / Verification Summary (push) Has been cancelled
Some checks failed
Build Pipeline / Build Binaries (push) Failing after 1m59s
Build Pipeline / Build Docker Images (push) Has been skipped
Build Pipeline / Sign HIPAA Config (push) Has been skipped
Build Pipeline / Generate SLSA Provenance (push) Has been skipped
Checkout test / test (push) Successful in 5s
CI Pipeline / Test (ubuntu-latest on self-hosted) (push) Failing after 1s
CI Pipeline / Dev Compose Smoke Test (push) Has been skipped
CI Pipeline / Security Scan (push) Has been skipped
CI Pipeline / Test Scripts (push) Has been skipped
CI Pipeline / Test Native Libraries (push) Has been skipped
CI Pipeline / Native Library Build Matrix (push) Has been skipped
Documentation / build-and-publish (push) Failing after 35s
CI Pipeline / Trigger Build Workflow (push) Failing after 0s
Security Scan / Security Analysis (push) Has been cancelled
Security Scan / Native Library Security (push) Has been cancelled
Verification & Maintenance / V.1 - Schema Drift Detection (push) Has been cancelled
Verification & Maintenance / V.4 - Custom Go Vet Analyzers (push) Has been cancelled
Verification & Maintenance / V.7 - Audit Chain Integrity (push) Has been cancelled
Verification & Maintenance / V.6 - Extended Security Scanning (push) Has been cancelled
Verification & Maintenance / V.10 - OpenSSF Scorecard (push) Has been cancelled
Verification & Maintenance / Verification Summary (push) Has been cancelled
- Add plugin_quota.go with GPU quota management for scheduler
- Update scheduler hub and protocol for plugin support
- Add comprehensive plugin quota unit tests
- Update gang service and WebSocket queue integration tests
This commit is contained in:
parent
ef05f200ba
commit
da104367d6
6 changed files with 776 additions and 23 deletions
|
|
@ -33,11 +33,13 @@ type SchedulerHub struct {
|
|||
reservations map[string]*Reservation
|
||||
multiNodePending map[string]*MultiNodeJob
|
||||
pendingAcceptance map[string]*JobAssignment
|
||||
runningTasks map[string]*Task // Track assigned+accepted tasks
|
||||
state *StateStore
|
||||
starvation *StarvationTracker
|
||||
metrics *SchedulerMetrics
|
||||
auditor *audit.Logger
|
||||
tokenValidator *TokenValidator
|
||||
quotaManager *PluginQuotaManager // NEW: plugin GPU quota manager
|
||||
config HubConfig
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
|
@ -59,6 +61,7 @@ type HubConfig struct {
|
|||
AcceptanceTimeoutSecs int
|
||||
LocalMode bool
|
||||
WorkerTokens map[string]string // token -> workerID
|
||||
PluginQuota PluginQuotaConfig // NEW: plugin GPU quota configuration
|
||||
}
|
||||
|
||||
// WorkerConn represents a connected worker
|
||||
|
|
@ -109,6 +112,7 @@ type JobAssignment struct {
|
|||
AssignedAt time.Time
|
||||
AcceptanceDeadline time.Time
|
||||
Accepted bool
|
||||
Task *Task // Reference to the task (removed from queue)
|
||||
}
|
||||
|
||||
// StarvationTracker monitors long-waiting jobs
|
||||
|
|
@ -154,6 +158,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
|
|||
reservations: make(map[string]*Reservation),
|
||||
multiNodePending: make(map[string]*MultiNodeJob),
|
||||
pendingAcceptance: make(map[string]*JobAssignment),
|
||||
runningTasks: make(map[string]*Task),
|
||||
state: state,
|
||||
starvation: &StarvationTracker{
|
||||
threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,
|
||||
|
|
@ -163,6 +168,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
|
|||
},
|
||||
auditor: auditor,
|
||||
tokenValidator: NewTokenValidator(cfg.WorkerTokens),
|
||||
quotaManager: NewPluginQuotaManager(cfg.PluginQuota), // NEW: initialize quota manager
|
||||
config: cfg,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
|
|
@ -431,9 +437,6 @@ func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task {
|
|||
}
|
||||
|
||||
func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
for _, res := range h.reservations {
|
||||
if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
|
||||
if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {
|
||||
|
|
@ -449,7 +452,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
|
|||
h.batchQueue.Remove(task.ID)
|
||||
h.serviceQueue.Remove(task.ID)
|
||||
|
||||
// Track pending acceptance
|
||||
// Track pending acceptance with task reference
|
||||
h.mu.Lock()
|
||||
h.pendingAcceptance[task.ID] = &JobAssignment{
|
||||
TaskID: task.ID,
|
||||
|
|
@ -457,6 +460,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
|
|||
AssignedAt: time.Now(),
|
||||
AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
|
||||
Accepted: false,
|
||||
Task: task, // Store reference since removed from queue
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
|
|
@ -473,12 +477,31 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *SchedulerHub) handleJobAccepted(_, taskID string) {
|
||||
func (h *SchedulerHub) handleJobAccepted(workerID, taskID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if assignment, ok := h.pendingAcceptance[taskID]; ok {
|
||||
assignment.Accepted = true
|
||||
|
||||
// Track as running task
|
||||
task := assignment.Task
|
||||
if task != nil {
|
||||
task.Status = "running"
|
||||
task.WorkerID = workerID
|
||||
h.runningTasks[taskID] = task
|
||||
}
|
||||
|
||||
// NEW: Record quota usage for service jobs
|
||||
if task != nil && task.Spec.Type == JobTypeService {
|
||||
if h.quotaManager != nil {
|
||||
pluginName := task.Spec.Metadata["plugin_name"]
|
||||
if pluginName == "" {
|
||||
pluginName = "default"
|
||||
}
|
||||
h.quotaManager.RecordUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -486,7 +509,19 @@ func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload)
|
|||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// NEW: Release quota usage for service jobs before deleting pending acceptance
|
||||
if task := h.runningTasks[result.TaskID]; task != nil && task.Spec.Type == JobTypeService {
|
||||
if h.quotaManager != nil {
|
||||
pluginName := task.Spec.Metadata["plugin_name"]
|
||||
if pluginName == "" {
|
||||
pluginName = "default"
|
||||
}
|
||||
h.quotaManager.ReleaseUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
|
||||
}
|
||||
}
|
||||
|
||||
delete(h.pendingAcceptance, result.TaskID)
|
||||
delete(h.runningTasks, result.TaskID)
|
||||
|
||||
eventType := EventJobCompleted
|
||||
switch result.State {
|
||||
|
|
@ -519,7 +554,10 @@ func (h *SchedulerHub) checkAcceptanceTimeouts() {
|
|||
h.mu.Lock()
|
||||
for taskID, a := range h.pendingAcceptance {
|
||||
if !a.Accepted && time.Now().After(a.AcceptanceDeadline) {
|
||||
h.batchQueue.Add(h.getTask(taskID))
|
||||
if a.Task != nil {
|
||||
a.Task.Status = "queued"
|
||||
h.batchQueue.Add(a.Task)
|
||||
}
|
||||
delete(h.pendingAcceptance, taskID)
|
||||
if wc, ok := h.workers[a.WorkerID]; ok {
|
||||
wc.slots = SlotStatus{}
|
||||
|
|
@ -572,22 +610,31 @@ func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) {
|
|||
st.mu.Lock()
|
||||
defer st.mu.Unlock()
|
||||
|
||||
// First check which tasks need reservation under h.mu.RLock
|
||||
tasksToReserve := make([]*Task, 0)
|
||||
h.mu.RLock()
|
||||
for _, task := range h.batchQueue.Items() {
|
||||
if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservation(h, task.ID) {
|
||||
h.mu.Lock()
|
||||
if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservationLocked(h, task.ID) {
|
||||
tasksToReserve = append(tasksToReserve, task)
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
// Now acquire Lock to add reservations
|
||||
if len(tasksToReserve) > 0 {
|
||||
h.mu.Lock()
|
||||
for _, task := range tasksToReserve {
|
||||
h.reservations[task.ID] = &Reservation{
|
||||
TaskID: task.ID,
|
||||
GPUCount: task.Spec.GPUCount,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (st *StarvationTracker) hasReservation(h *SchedulerHub, taskID string) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
func (st *StarvationTracker) hasReservationLocked(h *SchedulerHub, taskID string) bool {
|
||||
_, exists := h.reservations[taskID]
|
||||
return exists
|
||||
}
|
||||
|
|
@ -605,6 +652,17 @@ func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
|
|||
return fmt.Errorf("job ID is required")
|
||||
}
|
||||
|
||||
// NEW: Check plugin quotas for service jobs
|
||||
if spec.Type == JobTypeService && h.quotaManager != nil {
|
||||
pluginName := spec.Metadata["plugin_name"]
|
||||
if pluginName == "" {
|
||||
pluginName = "default"
|
||||
}
|
||||
if err := h.quotaManager.CheckQuota(spec.UserID, pluginName, spec.GPUCount); err != nil {
|
||||
return fmt.Errorf("quota exceeded: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
task := &Task{
|
||||
ID: spec.ID,
|
||||
Spec: spec,
|
||||
|
|
@ -639,7 +697,11 @@ func (h *SchedulerHub) getTask(taskID string) *Task {
|
|||
if t != nil {
|
||||
return t
|
||||
}
|
||||
return h.serviceQueue.Get(taskID)
|
||||
t = h.serviceQueue.Get(taskID)
|
||||
if t != nil {
|
||||
return t
|
||||
}
|
||||
return h.runningTasks[taskID]
|
||||
}
|
||||
|
||||
func (h *SchedulerHub) restoreJob(ev StateEvent) {
|
||||
|
|
@ -718,7 +780,7 @@ func (h *SchedulerHub) reconcileOrphans() {
|
|||
if assignment.Accepted {
|
||||
// Job was accepted but worker is gone (not in h.workers)
|
||||
if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected {
|
||||
task := h.getTask(taskID)
|
||||
task := assignment.Task
|
||||
if task != nil {
|
||||
task.Status = "orphaned"
|
||||
h.batchQueue.Add(task)
|
||||
|
|
@ -776,7 +838,7 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
|
|||
}
|
||||
|
||||
if msg.Type == MsgMetricsRequest {
|
||||
metrics := h.getMetricsPayload()
|
||||
metrics := h.GetMetricsPayload()
|
||||
conn.WriteJSON(Message{
|
||||
Type: MsgMetricsResponse,
|
||||
Payload: mustMarshal(metrics),
|
||||
|
|
@ -785,8 +847,8 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
|
|||
}
|
||||
}
|
||||
|
||||
// getMetricsPayload returns current metrics as a map
|
||||
func (h *SchedulerHub) getMetricsPayload() map[string]any {
|
||||
// GetMetricsPayload returns current metrics as a map (public API)
|
||||
func (h *SchedulerHub) GetMetricsPayload() map[string]any {
|
||||
h.metrics.mu.RLock()
|
||||
defer h.metrics.mu.RUnlock()
|
||||
|
||||
|
|
@ -901,7 +963,7 @@ func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) {
|
|||
|
||||
// buildRankedSpec creates a job spec with rank-specific template variables resolved
|
||||
func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
|
||||
// Clone the spec and add rank info to metadata
|
||||
// Clone the spec and add rank info to metadata and env
|
||||
spec := task.Spec
|
||||
spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3)
|
||||
for k, v := range task.Spec.Metadata {
|
||||
|
|
@ -910,6 +972,14 @@ func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, wo
|
|||
spec.Metadata["HEAD_ADDR"] = headAddr
|
||||
spec.Metadata["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
|
||||
spec.Metadata["NODE_RANK"] = fmt.Sprintf("%d", rank)
|
||||
|
||||
// Also set in Env for job runtime
|
||||
if spec.Env == nil {
|
||||
spec.Env = make(map[string]string)
|
||||
}
|
||||
spec.Env["HEAD_ADDR"] = headAddr
|
||||
spec.Env["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
|
||||
spec.Env["NODE_RANK"] = fmt.Sprintf("%d", rank)
|
||||
return spec
|
||||
}
|
||||
|
||||
|
|
|
|||
287
internal/scheduler/plugin_quota.go
Normal file
287
internal/scheduler/plugin_quota.go
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
package scheduler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PluginQuotaConfig defines GPU limits for plugins.
// A zero value for any numeric limit disables that particular check in
// CheckQuota, and Enabled == false disables all enforcement.
type PluginQuotaConfig struct {
	Enabled         bool                   // Master switch for quota enforcement
	TotalGPUs       int                    // Global GPU limit across all plugins (0 = unlimited)
	PerUserGPUs     int                    // Default per-user GPU limit (0 = unlimited)
	PerUserServices int                    // Default per-user service count limit (0 = unlimited)
	PerPluginLimits map[string]PluginLimit // Plugin-specific overrides
	UserOverrides   map[string]UserLimit   // Per-user overrides
}
|
||||
|
||||
// PluginLimit defines limits for a specific plugin.
// Both limits are enforced across ALL users of the plugin combined
// (see CheckQuota); a zero value disables that check.
type PluginLimit struct {
	MaxGPUs     int // Total GPUs the plugin may hold across all users (0 = unlimited)
	MaxServices int // Total concurrent services for the plugin across all users (0 = unlimited)
}
|
||||
|
||||
// UserLimit defines per-user override limits.
//
// NOTE(review): CheckQuota treats MaxGPUs/MaxServices == 0 as "fall back to
// the config-wide default", so an override cannot express "unlimited" for a
// single user — confirm this is intended.
type UserLimit struct {
	MaxGPUs        int
	MaxServices    int
	AllowedPlugins []string // Empty = all plugins allowed
}
|
||||
|
||||
// PluginUsage tracks GPU and service count for a user-plugin combination.
type PluginUsage struct {
	GPUs     int // GPUs currently held by this user for this plugin
	Services int // Concurrent services this user is running on this plugin
}
|
||||
|
||||
// PluginQuotaManager tracks active usage and enforces quotas.
// All fields below are guarded by mu. pluginTotal and totalGPUs are derived
// aggregates of usage and must be kept in sync by RecordUsage/ReleaseUsage.
type PluginQuotaManager struct {
	config      PluginQuotaConfig
	mu          sync.RWMutex
	usage       map[string]map[string]PluginUsage // userID -> pluginName -> usage
	pluginTotal map[string]int                    // pluginName -> total GPUs in use
	totalGPUs   int                               // global total GPUs in use
}
|
||||
|
||||
// NewPluginQuotaManager creates a new quota manager with the given configuration.
|
||||
func NewPluginQuotaManager(config PluginQuotaConfig) *PluginQuotaManager {
|
||||
return &PluginQuotaManager{
|
||||
config: config,
|
||||
usage: make(map[string]map[string]PluginUsage),
|
||||
pluginTotal: make(map[string]int),
|
||||
totalGPUs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckQuota validates if a job can be submitted without exceeding limits.
// Returns nil if the job is allowed, or an error describing which limit would be exceeded.
//
// Every call is assumed to represent exactly one new service plus gpuCount
// GPUs. Checks are applied in order: plugin allow-list, plugin-wide limits,
// per-user limits, global GPU limit.
//
// NOTE(review): this only takes the read lock, and usage is recorded later
// by a separate RecordUsage call — two concurrent submissions can both pass
// the check before either records (check-then-act race). Confirm callers
// serialize submission if strict enforcement is required.
func (m *PluginQuotaManager) CheckQuota(userID, pluginName string, gpuCount int) error {
	// Enforcement disabled: admit everything.
	if !m.config.Enabled {
		return nil
	}

	// Normalize empty identifiers to their catch-all buckets (same
	// normalization as RecordUsage/ReleaseUsage, so keys line up).
	if userID == "" {
		userID = "anonymous"
	}
	if pluginName == "" {
		pluginName = "default"
	}

	m.mu.RLock()
	defer m.mu.RUnlock()

	// Get user limits (with overrides)
	userLimit := m.getUserLimit(userID)

	// Check if user is allowed to use this plugin. An empty allow-list
	// means the user may use any plugin.
	if len(userLimit.AllowedPlugins) > 0 {
		found := false
		for _, p := range userLimit.AllowedPlugins {
			if p == pluginName {
				found = true
				break
			}
		}
		if !found {
			return fmt.Errorf("user %s is not allowed to use plugin %s", userID, pluginName)
		}
	}

	// Check plugin-specific limits
	pluginLimit, hasPluginLimit := m.config.PerPluginLimits[pluginName]
	if hasPluginLimit {
		if pluginLimit.MaxGPUs > 0 && m.pluginTotal[pluginName]+gpuCount > pluginLimit.MaxGPUs {
			return fmt.Errorf("plugin %s GPU limit exceeded: %d requested, %d available of %d total",
				pluginName, gpuCount, pluginLimit.MaxGPUs-m.pluginTotal[pluginName], pluginLimit.MaxGPUs)
		}
		if pluginLimit.MaxServices > 0 {
			// Services limit is across all users for this plugin
			totalServices := 0
			for _, u := range m.usage {
				if p, ok := u[pluginName]; ok {
					totalServices += p.Services
				}
			}
			if totalServices+1 > pluginLimit.MaxServices {
				return fmt.Errorf("plugin %s service limit exceeded", pluginName)
			}
		}
	}

	// Check per-user limits. Note: getUserLimit already substitutes the
	// config-wide defaults when there is no override, so these == 0
	// fallbacks only fire when the defaults themselves are zero (i.e.
	// they are effectively no-ops kept for safety).
	effectiveUserGPUs := userLimit.MaxGPUs
	if effectiveUserGPUs == 0 {
		effectiveUserGPUs = m.config.PerUserGPUs
	}
	effectiveUserServices := userLimit.MaxServices
	if effectiveUserServices == 0 {
		effectiveUserServices = m.config.PerUserServices
	}

	// Calculate total user usage across all plugins (ranging over a nil
	// map is a no-op, so an unknown user counts as zero usage).
	totalUserGPUs := 0
	totalUserServices := 0
	for _, p := range m.usage[userID] {
		totalUserGPUs += p.GPUs
		totalUserServices += p.Services
	}

	if effectiveUserGPUs > 0 && totalUserGPUs+gpuCount > effectiveUserGPUs {
		return fmt.Errorf("user %s GPU limit exceeded: %d requested, %d available of %d total",
			userID, gpuCount, effectiveUserGPUs-totalUserGPUs, effectiveUserGPUs)
	}

	if effectiveUserServices > 0 && totalUserServices+1 > effectiveUserServices {
		return fmt.Errorf("user %s service limit exceeded: %d services of %d allowed",
			userID, totalUserServices+1, effectiveUserServices)
	}

	// Check global total GPU limit
	if m.config.TotalGPUs > 0 && m.totalGPUs+gpuCount > m.config.TotalGPUs {
		return fmt.Errorf("global GPU limit exceeded: %d requested, %d available of %d total",
			gpuCount, m.config.TotalGPUs-m.totalGPUs, m.config.TotalGPUs)
	}

	return nil
}
|
||||
|
||||
// RecordUsage increments usage counters when a job starts.
|
||||
func (m *PluginQuotaManager) RecordUsage(userID, pluginName string, gpuCount int) {
|
||||
if !m.config.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
if userID == "" {
|
||||
userID = "anonymous"
|
||||
}
|
||||
if pluginName == "" {
|
||||
pluginName = "default"
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
userPlugins, ok := m.usage[userID]
|
||||
if !ok {
|
||||
userPlugins = make(map[string]PluginUsage)
|
||||
m.usage[userID] = userPlugins
|
||||
}
|
||||
|
||||
usage := userPlugins[pluginName]
|
||||
usage.GPUs += gpuCount
|
||||
usage.Services++
|
||||
userPlugins[pluginName] = usage
|
||||
|
||||
m.pluginTotal[pluginName] += gpuCount
|
||||
m.totalGPUs += gpuCount
|
||||
}
|
||||
|
||||
// ReleaseUsage decrements usage counters when a job stops.
|
||||
func (m *PluginQuotaManager) ReleaseUsage(userID, pluginName string, gpuCount int) {
|
||||
if !m.config.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
if userID == "" {
|
||||
userID = "anonymous"
|
||||
}
|
||||
if pluginName == "" {
|
||||
pluginName = "default"
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
userPlugins, ok := m.usage[userID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
usage := userPlugins[pluginName]
|
||||
usage.GPUs -= gpuCount
|
||||
usage.Services--
|
||||
|
||||
if usage.GPUs < 0 {
|
||||
usage.GPUs = 0
|
||||
}
|
||||
if usage.Services < 0 {
|
||||
usage.Services = 0
|
||||
}
|
||||
|
||||
if usage.GPUs == 0 && usage.Services == 0 {
|
||||
delete(userPlugins, pluginName)
|
||||
} else {
|
||||
userPlugins[pluginName] = usage
|
||||
}
|
||||
|
||||
if len(userPlugins) == 0 {
|
||||
delete(m.usage, userID)
|
||||
}
|
||||
|
||||
m.pluginTotal[pluginName] -= gpuCount
|
||||
if m.pluginTotal[pluginName] < 0 {
|
||||
m.pluginTotal[pluginName] = 0
|
||||
}
|
||||
|
||||
m.totalGPUs -= gpuCount
|
||||
if m.totalGPUs < 0 {
|
||||
m.totalGPUs = 0
|
||||
}
|
||||
}
|
||||
|
||||
// GetUsage returns current usage for a user across all plugins.
|
||||
func (m *PluginQuotaManager) GetUsage(userID string) (map[string]PluginUsage, int) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if userID == "" {
|
||||
userID = "anonymous"
|
||||
}
|
||||
|
||||
result := make(map[string]PluginUsage)
|
||||
totalGPUs := 0
|
||||
|
||||
if userPlugins, ok := m.usage[userID]; ok {
|
||||
for plugin, usage := range userPlugins {
|
||||
result[plugin] = usage
|
||||
totalGPUs += usage.GPUs
|
||||
}
|
||||
}
|
||||
|
||||
return result, totalGPUs
|
||||
}
|
||||
|
||||
// GetGlobalUsage returns global GPU usage across all users and plugins.
|
||||
func (m *PluginQuotaManager) GetGlobalUsage() (int, map[string]int) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
pluginTotals := make(map[string]int, len(m.pluginTotal))
|
||||
for k, v := range m.pluginTotal {
|
||||
pluginTotals[k] = v
|
||||
}
|
||||
|
||||
return m.totalGPUs, pluginTotals
|
||||
}
|
||||
|
||||
// getUserLimit returns the effective limits for a user, applying overrides.
|
||||
func (m *PluginQuotaManager) getUserLimit(userID string) UserLimit {
|
||||
if override, ok := m.config.UserOverrides[userID]; ok {
|
||||
return override
|
||||
}
|
||||
return UserLimit{
|
||||
MaxGPUs: m.config.PerUserGPUs,
|
||||
MaxServices: m.config.PerUserServices,
|
||||
}
|
||||
}
|
||||
|
||||
// getUsageLocked returns the current usage for a user-plugin combination.
// Must be called with read lock held.
//
// NOTE(review): nothing in this file calls this helper; confirm it is used
// elsewhere in the package or remove it.
func (m *PluginQuotaManager) getUsageLocked(userID, pluginName string) PluginUsage {
	if userPlugins, ok := m.usage[userID]; ok {
		if usage, ok := userPlugins[pluginName]; ok {
			return usage
		}
	}
	// Unknown user or plugin: zero usage.
	return PluginUsage{}
}
|
||||
|
|
@ -100,6 +100,7 @@ type JobSpec struct {
|
|||
ID string `json:"id"`
|
||||
Type JobType `json:"type"` // "batch" | "service"
|
||||
SlotPool string `json:"slot_pool"`
|
||||
UserID string `json:"user_id,omitempty"` // NEW: for per-user quota tracking
|
||||
|
||||
GPUCount int `json:"gpu_count"`
|
||||
GPUType string `json:"gpu_type,omitempty"`
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ func TestMultiNodeGangAllocation(t *testing.T) {
|
|||
ID: w.id,
|
||||
Capabilities: scheduler.WorkerCapabilities{
|
||||
GPUCount: 0,
|
||||
Hostname: "localhost",
|
||||
},
|
||||
}),
|
||||
})
|
||||
|
|
@ -172,10 +173,8 @@ func TestServiceLifecycle(t *testing.T) {
|
|||
|
||||
// Send job accepted
|
||||
conn.WriteJSON(scheduler.Message{
|
||||
Type: scheduler.MsgJobAccepted,
|
||||
Payload: mustMarshal(map[string]string{
|
||||
"task_id": jobID,
|
||||
}),
|
||||
Type: scheduler.MsgJobAccepted,
|
||||
Payload: mustMarshal(jobID),
|
||||
})
|
||||
|
||||
// Send periodic health updates
|
||||
|
|
|
|||
|
|
@ -307,6 +307,7 @@ func startFakeWorkers(
|
|||
go func(workerID string) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
// Check for cancellation before getting next task
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
|
@ -324,6 +325,7 @@ func startFakeWorkers(
|
|||
continue
|
||||
}
|
||||
|
||||
// Process task - finish even if context is cancelled
|
||||
started := time.Now()
|
||||
completed := started.Add(10 * time.Millisecond)
|
||||
|
||||
|
|
@ -338,7 +340,16 @@ func startFakeWorkers(
|
|||
continue
|
||||
}
|
||||
|
||||
doneCh <- task.JobName
|
||||
// Only send to doneCh if we successfully completed
|
||||
select {
|
||||
case doneCh <- task.JobName:
|
||||
case <-ctx.Done():
|
||||
// Context cancelled but we still completed the task - try once more without blocking
|
||||
select {
|
||||
case doneCh <- task.JobName:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}(fmt.Sprintf("worker-%d", w))
|
||||
}
|
||||
|
|
|
|||
385
tests/unit/scheduler/plugin_quota_test.go
Normal file
385
tests/unit/scheduler/plugin_quota_test.go
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
package scheduler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPluginQuotaManager_CheckQuota_Disabled verifies that CheckQuota always
// admits when Enabled is false, even for requests far beyond configured limits.
func TestPluginQuotaManager_CheckQuota_Disabled(t *testing.T) {
	// When quota is disabled, all jobs should pass
	config := scheduler.PluginQuotaConfig{
		Enabled:   false,
		TotalGPUs: 1, // Set a low limit that would fail if enabled
	}
	m := scheduler.NewPluginQuotaManager(config)

	err := m.CheckQuota("user1", "plugin1", 100)
	assert.NoError(t, err)
}
|
||||
|
||||
// TestPluginQuotaManager_CheckQuota_GlobalLimit exercises the global
// TotalGPUs cap: admission succeeds up to the cap and fails once recorded
// usage plus the new request would exceed it.
func TestPluginQuotaManager_CheckQuota_GlobalLimit(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:   true,
		TotalGPUs: 4,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// First job should succeed
	err := m.CheckQuota("user1", "plugin1", 2)
	require.NoError(t, err)

	// Record the usage
	m.RecordUsage("user1", "plugin1", 2)

	// Second job should succeed (2+2=4, within limit)
	err = m.CheckQuota("user2", "plugin2", 2)
	require.NoError(t, err)
	m.RecordUsage("user2", "plugin2", 2)

	// Third job should fail (would exceed global limit)
	err = m.CheckQuota("user3", "plugin3", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "global GPU limit exceeded")
}
|
||||
|
||||
// TestPluginQuotaManager_CheckQuota_PerUserGPULimit verifies the default
// per-user GPU limit: a user is capped at PerUserGPUs across all plugins,
// while other users have independent budgets.
func TestPluginQuotaManager_CheckQuota_PerUserGPULimit(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 3,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// User1: first job should succeed
	err := m.CheckQuota("user1", "plugin1", 2)
	require.NoError(t, err)
	m.RecordUsage("user1", "plugin1", 2)

	// User1: second job should succeed (2+1=3, at limit)
	err = m.CheckQuota("user1", "plugin2", 1)
	require.NoError(t, err)
	m.RecordUsage("user1", "plugin2", 1)

	// User1: third job should fail (would exceed per-user limit)
	err = m.CheckQuota("user1", "plugin3", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "user user1 GPU limit exceeded")

	// User2: job should succeed (different user)
	err = m.CheckQuota("user2", "plugin1", 3)
	assert.NoError(t, err)
}
|
||||
|
||||
// TestPluginQuotaManager_CheckQuota_PerUserServiceLimit verifies the default
// per-user service-count limit: each Check/Record pair counts as one service,
// and the (PerUserServices+1)-th service is rejected.
func TestPluginQuotaManager_CheckQuota_PerUserServiceLimit(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     10,
		PerUserServices: 2,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// User1: first service should succeed
	err := m.CheckQuota("user1", "plugin1", 1)
	require.NoError(t, err)
	m.RecordUsage("user1", "plugin1", 1)

	// User1: second service should succeed
	err = m.CheckQuota("user1", "plugin2", 1)
	require.NoError(t, err)
	m.RecordUsage("user1", "plugin2", 1)

	// User1: third service should fail (would exceed service count limit)
	err = m.CheckQuota("user1", "plugin3", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "user user1 service limit exceeded")
}
|
||||
|
||||
// TestPluginQuotaManager_CheckQuota_UserOverride verifies that UserOverrides
// replace the config-wide defaults: a regular user is bound by PerUserGPUs
// while an overridden user gets their own higher budget.
func TestPluginQuotaManager_CheckQuota_UserOverride(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     2,
		PerUserServices: 2,
		UserOverrides: map[string]scheduler.UserLimit{
			"vip-user": {
				MaxGPUs:     5,
				MaxServices: 10,
			},
		},
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Regular user: limited by default
	err := m.CheckQuota("regular", "plugin1", 3)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "regular GPU limit exceeded")

	// VIP user: has higher limit
	err = m.CheckQuota("vip-user", "plugin1", 4)
	require.NoError(t, err)
	m.RecordUsage("vip-user", "plugin1", 4)

	// VIP user: still within limit
	err = m.CheckQuota("vip-user", "plugin2", 1)
	assert.NoError(t, err)
}
|
||||
|
||||
// TestPluginQuotaManager_CheckQuota_PluginSpecificLimit verifies
// PerPluginLimits GPU caps: a plugin's MaxGPUs is shared across users, and
// different plugins enforce their own independent caps.
func TestPluginQuotaManager_CheckQuota_PluginSpecificLimit(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 10,
		PerPluginLimits: map[string]scheduler.PluginLimit{
			"jupyter": {
				MaxGPUs:     3,
				MaxServices: 2,
			},
			"vllm": {
				MaxGPUs:     8,
				MaxServices: 4,
			},
		},
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Jupyter: within plugin GPU limit
	err := m.CheckQuota("user1", "jupyter", 2)
	require.NoError(t, err)
	m.RecordUsage("user1", "jupyter", 2)

	// Jupyter: exceed plugin GPU limit (but within global and user limits)
	err = m.CheckQuota("user2", "jupyter", 2)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "plugin jupyter GPU limit exceeded")

	// vLLM: within its higher limit
	err = m.CheckQuota("user1", "vllm", 4)
	assert.NoError(t, err)
}
|
||||
|
||||
// TestPluginQuotaManager_CheckQuota_PluginServiceLimit verifies a plugin's
// MaxServices cap is counted across ALL users: a third service is rejected
// even when each individual user is well under their own limits.
func TestPluginQuotaManager_CheckQuota_PluginServiceLimit(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     10,
		PerUserServices: 10,
		PerPluginLimits: map[string]scheduler.PluginLimit{
			"jupyter": {
				MaxGPUs:     10,
				MaxServices: 2, // Only 2 jupyter services total
			},
		},
	}
	m := scheduler.NewPluginQuotaManager(config)

	// First jupyter service
	err := m.CheckQuota("user1", "jupyter", 1)
	require.NoError(t, err)
	m.RecordUsage("user1", "jupyter", 1)

	// Second jupyter service (different user)
	err = m.CheckQuota("user2", "jupyter", 1)
	require.NoError(t, err)
	m.RecordUsage("user2", "jupyter", 1)

	// Third jupyter service should fail (plugin service limit reached)
	err = m.CheckQuota("user3", "jupyter", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "plugin jupyter service limit exceeded")
}
|
||||
|
||||
// TestPluginQuotaManager_CheckQuota_AllowedPlugins verifies the per-user
// plugin allow-list: a restricted user may only use listed plugins, while
// users without an AllowedPlugins override may use any plugin.
func TestPluginQuotaManager_CheckQuota_AllowedPlugins(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 10,
		UserOverrides: map[string]scheduler.UserLimit{
			"restricted-user": {
				MaxGPUs:        5,
				MaxServices:    5,
				AllowedPlugins: []string{"jupyter"},
			},
		},
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Restricted user can use allowed plugin
	err := m.CheckQuota("restricted-user", "jupyter", 2)
	assert.NoError(t, err)

	// Restricted user cannot use other plugins
	err = m.CheckQuota("restricted-user", "vllm", 2)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "not allowed to use plugin vllm")

	// Regular user can use any plugin
	err = m.CheckQuota("regular-user", "vllm", 2)
	assert.NoError(t, err)
}
|
||||
|
||||
func TestPluginQuotaManager_RecordAndReleaseUsage(t *testing.T) {
|
||||
config := scheduler.PluginQuotaConfig{
|
||||
Enabled: true,
|
||||
TotalGPUs: 10,
|
||||
PerUserGPUs: 5,
|
||||
}
|
||||
m := scheduler.NewPluginQuotaManager(config)
|
||||
|
||||
// Record usage
|
||||
m.RecordUsage("user1", "jupyter", 2)
|
||||
m.RecordUsage("user1", "vllm", 1)
|
||||
m.RecordUsage("user2", "jupyter", 3)
|
||||
|
||||
// Check usage tracking
|
||||
usage, totalGPUs := m.GetUsage("user1")
|
||||
assert.Equal(t, 2, usage["jupyter"].GPUs)
|
||||
assert.Equal(t, 1, usage["jupyter"].Services)
|
||||
assert.Equal(t, 1, usage["vllm"].GPUs)
|
||||
assert.Equal(t, 1, usage["vllm"].Services)
|
||||
assert.Equal(t, 3, totalGPUs)
|
||||
|
||||
// Check global usage
|
||||
globalGPUs, pluginTotals := m.GetGlobalUsage()
|
||||
assert.Equal(t, 6, globalGPUs)
|
||||
assert.Equal(t, 5, pluginTotals["jupyter"]) // 2+3
|
||||
assert.Equal(t, 1, pluginTotals["vllm"])
|
||||
|
||||
// Release usage
|
||||
m.ReleaseUsage("user1", "jupyter", 2)
|
||||
|
||||
// Verify release
|
||||
usage, totalGPUs = m.GetUsage("user1")
|
||||
assert.Equal(t, 0, usage["jupyter"].GPUs)
|
||||
assert.Equal(t, 0, usage["jupyter"].Services)
|
||||
assert.Equal(t, 1, usage["vllm"].GPUs) // user1 still has vllm
|
||||
assert.Equal(t, 1, totalGPUs) // only vllm remains for user1
|
||||
|
||||
// Check global usage after release
|
||||
globalGPUs, pluginTotals = m.GetGlobalUsage()
|
||||
assert.Equal(t, 4, globalGPUs)
|
||||
assert.Equal(t, 3, pluginTotals["jupyter"]) // 3 from user2
|
||||
assert.Equal(t, 1, pluginTotals["vllm"])
|
||||
}
|
||||
|
||||
func TestPluginQuotaManager_RecordUsage_Disabled(t *testing.T) {
|
||||
config := scheduler.PluginQuotaConfig{
|
||||
Enabled: false,
|
||||
TotalGPUs: 10,
|
||||
}
|
||||
m := scheduler.NewPluginQuotaManager(config)
|
||||
|
||||
// Recording usage when disabled should not crash
|
||||
m.RecordUsage("user1", "plugin1", 5)
|
||||
|
||||
// Usage should be empty (not tracked)
|
||||
usage, totalGPUs := m.GetUsage("user1")
|
||||
assert.Empty(t, usage)
|
||||
assert.Equal(t, 0, totalGPUs)
|
||||
}
|
||||
|
||||
func TestPluginQuotaManager_ReleaseUsage_NonExistent(t *testing.T) {
|
||||
config := scheduler.PluginQuotaConfig{
|
||||
Enabled: true,
|
||||
TotalGPUs: 10,
|
||||
}
|
||||
m := scheduler.NewPluginQuotaManager(config)
|
||||
|
||||
// Releasing non-existent usage should not crash or go negative
|
||||
m.ReleaseUsage("nonexistent", "plugin1", 5)
|
||||
|
||||
// Global usage should remain 0
|
||||
globalGPUs, _ := m.GetGlobalUsage()
|
||||
assert.Equal(t, 0, globalGPUs)
|
||||
}
|
||||
|
||||
func TestPluginQuotaManager_CheckQuota_AnonymousUser(t *testing.T) {
|
||||
config := scheduler.PluginQuotaConfig{
|
||||
Enabled: true,
|
||||
TotalGPUs: 10,
|
||||
PerUserGPUs: 2,
|
||||
PerUserServices: 2,
|
||||
}
|
||||
m := scheduler.NewPluginQuotaManager(config)
|
||||
|
||||
// Empty userID should be treated as "anonymous"
|
||||
err := m.CheckQuota("", "plugin1", 2)
|
||||
require.NoError(t, err)
|
||||
m.RecordUsage("", "plugin1", 2)
|
||||
|
||||
// Second request from anonymous should fail (at limit)
|
||||
err = m.CheckQuota("", "plugin1", 1)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user anonymous GPU limit exceeded")
|
||||
}
|
||||
|
||||
func TestPluginQuotaManager_CheckQuota_DefaultPlugin(t *testing.T) {
|
||||
config := scheduler.PluginQuotaConfig{
|
||||
Enabled: true,
|
||||
TotalGPUs: 10,
|
||||
PerUserGPUs: 5,
|
||||
PerUserServices: 5,
|
||||
PerPluginLimits: map[string]scheduler.PluginLimit{
|
||||
"default": {
|
||||
MaxGPUs: 2,
|
||||
MaxServices: 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
m := scheduler.NewPluginQuotaManager(config)
|
||||
|
||||
// Empty plugin name should be treated as "default"
|
||||
err := m.CheckQuota("user1", "", 1)
|
||||
require.NoError(t, err)
|
||||
m.RecordUsage("user1", "", 1)
|
||||
|
||||
// Exceed default plugin limit
|
||||
err = m.CheckQuota("user2", "", 2)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "plugin default GPU limit exceeded")
|
||||
}
|
||||
|
||||
func TestPluginQuotaManager_ConcurrentAccess(t *testing.T) {
|
||||
config := scheduler.PluginQuotaConfig{
|
||||
Enabled: true,
|
||||
TotalGPUs: 100,
|
||||
PerUserGPUs: 50,
|
||||
PerUserServices: 50,
|
||||
}
|
||||
m := scheduler.NewPluginQuotaManager(config)
|
||||
|
||||
// Concurrently record usage from multiple goroutines
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(idx int) {
|
||||
user := "user"
|
||||
if idx%2 == 0 {
|
||||
user = "user1"
|
||||
} else {
|
||||
user = "user2"
|
||||
}
|
||||
m.RecordUsage(user, "plugin1", 1)
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify totals
|
||||
globalGPUs, _ := m.GetGlobalUsage()
|
||||
assert.Equal(t, 10, globalGPUs)
|
||||
|
||||
usage1, _ := m.GetUsage("user1")
|
||||
assert.Equal(t, 5, usage1["plugin1"].GPUs)
|
||||
assert.Equal(t, 5, usage1["plugin1"].Services)
|
||||
|
||||
usage2, _ := m.GetUsage("user2")
|
||||
assert.Equal(t, 5, usage2["plugin1"].GPUs)
|
||||
assert.Equal(t, 5, usage2["plugin1"].Services)
|
||||
}
|
||||
Loading…
Reference in a new issue