Compare commits

...

4 commits

Author SHA1 Message Date
Jeremie Fraeys
da104367d6
feat: add Plugin GPU Quota implementation and tests
Some checks failed
Build Pipeline / Build Binaries (push) Failing after 1m59s
Build Pipeline / Build Docker Images (push) Has been skipped
Build Pipeline / Sign HIPAA Config (push) Has been skipped
Build Pipeline / Generate SLSA Provenance (push) Has been skipped
Checkout test / test (push) Successful in 5s
CI Pipeline / Test (ubuntu-latest on self-hosted) (push) Failing after 1s
CI Pipeline / Dev Compose Smoke Test (push) Has been skipped
CI Pipeline / Security Scan (push) Has been skipped
CI Pipeline / Test Scripts (push) Has been skipped
CI Pipeline / Test Native Libraries (push) Has been skipped
CI Pipeline / Native Library Build Matrix (push) Has been skipped
Documentation / build-and-publish (push) Failing after 35s
CI Pipeline / Trigger Build Workflow (push) Failing after 0s
Security Scan / Security Analysis (push) Has been cancelled
Security Scan / Native Library Security (push) Has been cancelled
Verification & Maintenance / V.1 - Schema Drift Detection (push) Has been cancelled
Verification & Maintenance / V.4 - Custom Go Vet Analyzers (push) Has been cancelled
Verification & Maintenance / V.7 - Audit Chain Integrity (push) Has been cancelled
Verification & Maintenance / V.6 - Extended Security Scanning (push) Has been cancelled
Verification & Maintenance / V.10 - OpenSSF Scorecard (push) Has been cancelled
Verification & Maintenance / Verification Summary (push) Has been cancelled
- Add plugin_quota.go with GPU quota management for scheduler

- Update scheduler hub and protocol for plugin support

- Add comprehensive plugin quota unit tests

- Update gang service and WebSocket queue integration tests
2026-02-26 14:35:05 -05:00
Jeremie Fraeys
ef05f200ba
build: add Scheduler & Services section to Makefile help
- Add new help section for scheduler and service targets

- Document dev-up, dev-down, prod-up, prod-down targets
2026-02-26 14:34:58 -05:00
Jeremie Fraeys
a653a2d0ed
ci: add plugin, quota, and scheduler tests to workflows
- Add plugin quota, service templates, scheduler tests to ci.yml

- Add vLLM plugin and audit logging test steps

- Add plugin configuration validation to security-modes-test.yml:

  - Verify HIPAA mode disables plugins

  - Verify standard mode enables plugins with security

  - Verify dev mode enables plugins with relaxed security
2026-02-26 14:34:49 -05:00
Jeremie Fraeys
b3a0c78903
config: add Plugin GPU Quota, plugins, and audit logging to configs
- Add Plugin GPU Quota config section to scheduler.yaml.example

- Add audit logging config to homelab-secure.yaml (HIPAA-compliant)

- Add Jupyter and vLLM plugin configs to all worker configs:

  - Security settings (passwords, trusted channels, blocked packages)

  - Resource limits (GPU, memory, CPU)

  - Model cache paths and quantization options for vLLM

- Disable plugins in HIPAA deployment mode for compliance

- Update deployments README with plugin services and GPU quotas
2026-02-26 14:34:42 -05:00
21 changed files with 1200 additions and 46 deletions

View file

@ -34,7 +34,7 @@ env:
jobs:
test:
name: Test
name: Test (ubuntu-latest on self-hosted)
runs-on: self-hosted
timeout-minutes: 30
@ -424,11 +424,31 @@ jobs:
echo "=== Testing ${{ matrix.build_config.name }} build (CGO_ENABLED=${{ matrix.build_config.cgo_enabled }}, tags=${{ matrix.build_config.tags }}) ==="
CGO_ENABLED=${{ matrix.build_config.cgo_enabled }} go test -tags "${{ matrix.build_config.tags }}" -v ./tests/unit/... || true
- name: Run GPU matrix tests - ${{ matrix.build_config.name }}
- name: Run plugin quota tests
run: |
echo "=== GPU Golden Test Matrix - ${{ matrix.build_config.name }} ==="
CGO_ENABLED=${{ matrix.build_config.cgo_enabled }} go test -tags "${{ matrix.build_config.tags }}" -v ./tests/unit/gpu/ -run TestGoldenGPUStatus || true
CGO_ENABLED=${{ matrix.build_config.cgo_enabled }} go test -tags "${{ matrix.build_config.tags }}" -v ./tests/unit/gpu/ -run TestBuildTagMatrix || true
echo "=== Running Plugin GPU Quota tests ==="
go test -v ./tests/unit/scheduler/... -run TestPluginQuota
- name: Run service templates tests
run: |
echo "=== Running Service Templates tests ==="
go test -v ./tests/unit/scheduler/... -run TestServiceTemplate
- name: Run scheduler tests
run: |
echo "=== Running Scheduler tests ==="
go test -v ./tests/unit/scheduler/... -run TestScheduler
- name: Run vLLM plugin tests
run: |
echo "=== Running vLLM Plugin tests ==="
go test -v ./tests/unit/worker/plugins/... -run TestVLLM
- name: Run audit tests
run: |
echo "=== Running Audit Logging tests ==="
go test -v ./tests/unit/security/... -run TestAudit
go test -v ./tests/integration/audit/...
build-trigger:
name: Trigger Build Workflow

View file

@ -175,24 +175,39 @@ EOF
echo "All required HIPAA fields have corresponding tests"
- name: Run security custom vet rules
- name: Validate plugin configuration for ${{ matrix.security_mode }} mode
run: |
echo "=== Running custom vet rules for security ==="
echo "=== Validating plugin configuration for ${{ matrix.security_mode }} mode ==="
# Check if fetchml-vet tool exists
if [ -d "tools/fetchml-vet" ]; then
cd tools/fetchml-vet
go build -o fetchml-vet ./cmd/fetchml-vet/
cd ../..
# Run the custom vet analyzer
./tools/fetchml-vet/fetchml-vet ./... || {
echo "Custom vet found issues - review required"
exit 1
}
else
echo "fetchml-vet tool not found - skipping custom vet"
fi
CONFIG_FILE="${{ matrix.config_file }}"
# Check plugin configuration based on security mode
case "${{ matrix.security_mode }}" in
hipaa)
echo "Checking HIPAA mode: plugins should be disabled"
if grep -A 5 "plugins:" "$CONFIG_FILE" | grep -q "enabled: false"; then
echo "✓ Plugins are disabled for HIPAA compliance"
else
echo "⚠ Warning: Plugins may not be properly disabled in HIPAA mode"
fi
;;
standard)
echo "Checking standard mode: plugins should be enabled with security"
if grep -A 10 "plugins:" "$CONFIG_FILE" | grep -q "enabled: true"; then
echo "✓ Plugins are enabled in standard mode"
# Check for security settings
if grep -A 20 "plugins:" "$CONFIG_FILE" | grep -q "require_password: true"; then
echo "✓ Plugin security (password) is enabled"
fi
fi
;;
dev)
echo "Checking dev mode: plugins should be enabled (relaxed security)"
if grep -A 10 "plugins:" "$CONFIG_FILE" | grep -q "enabled: true"; then
echo "✓ Plugins are enabled in dev mode"
fi
;;
esac
- name: Security mode test summary
if: always()

View file

@ -457,6 +457,15 @@ help:
@echo " make docs-build - Build static documentation (local)"
@echo " make docs-build-prod - Build static documentation (prod flags; set DOCS_PROD_BASEURL)"
@echo ""
@echo "Scheduler & Services:"
@echo " make dev-up - Start development environment (API, Worker, Redis, monitoring)"
@echo " make dev-down - Stop development environment"
@echo " make dev-status - Check development environment status"
@echo " make dev-logs - View development environment logs"
@echo " make prod-up - Start production environment"
@echo " make prod-down - Stop production environment"
@echo " make scheduler-config - Edit scheduler configuration example"
@echo ""
@echo "Utility:"
@echo " make size - Show binary sizes"
@echo " make self-cleanup - Clean up Docker resources"

View file

@ -62,7 +62,26 @@ database:
logging:
level: "info"
file: "/logs/fetch_ml.log"
audit_log: ""
# Audit logging (HIPAA-compliant with tamper-evident chain hashing)
audit:
enabled: true
file: "/var/log/fetch_ml/audit.log" # Separate file for audit events
chain_hashing: true # Enable tamper-evident logging
retention_days: 2555 # 7 years for HIPAA compliance
log_ip_address: true # Include source IP in audit events
log_user_agent: true # Include user agent in audit events
# Sensitive events to always log
events:
- "authentication_success"
- "authentication_failure"
- "file_access"
- "file_write"
- "file_delete"
- "job_queued"
- "job_started"
- "job_completed"
- "experiment_created"
- "experiment_deleted"
resources:
max_workers: 1

View file

@ -30,3 +30,30 @@ scheduler:
token: "wkr_PLACEHOLDER_GENERATE_WITH_OPENSSL_RAND_HEX_32"
- id: "worker-02"
token: "wkr_PLACEHOLDER_GENERATE_WITH_OPENSSL_RAND_HEX_32"
# Plugin GPU Quota Configuration
# Controls GPU allocation for plugin-based services (Jupyter, vLLM, etc.)
plugin_quota:
enabled: false # Enable quota enforcement (default: false)
total_gpus: 16 # Global GPU limit across all plugins (0 = unlimited)
per_user_gpus: 4 # Default per-user GPU limit (0 = unlimited)
per_user_services: 2 # Default per-user service count limit (0 = unlimited)
# Plugin-specific limits (optional)
per_plugin_limits:
vllm:
max_gpus: 8 # Max GPUs for vLLM across all users
max_services: 4 # Max vLLM service instances
jupyter:
max_gpus: 4 # Max GPUs for Jupyter across all users
max_services: 10 # Max Jupyter service instances
# Per-user overrides (optional)
user_overrides:
admin:
max_gpus: 8 # Admin gets more GPUs
max_services: 5 # Admin can run more services
allowed_plugins: ["jupyter", "vllm"] # Restrict which plugins user can use
researcher1:
max_gpus: 2 # Limited GPU access
max_services: 1 # Single service limit

View file

@ -48,6 +48,39 @@ queue:
native:
data_dir: "data/dev/queue"
# Plugin Configuration (for local development)
plugins:
# Jupyter Notebook/Lab Service
jupyter:
enabled: true
image: "quay.io/jupyter/base-notebook:latest"
default_port: 8888
mode: "lab"
# Security settings
security:
trusted_channels:
- "conda-forge"
- "defaults"
blocked_packages: [] # Less restrictive for local dev
require_password: false # No password for local dev
# Resource limits
max_gpu_per_instance: 1
max_memory_per_instance: "4Gi"
# vLLM Inference Service
vllm:
enabled: true
image: "vllm/vllm-openai:latest"
default_port: 8000
# Model cache location
model_cache: "data/dev/models"
# Supported quantization methods: awq, gptq, fp8, squeezellm
default_quantization: "" # No quantization for dev (better quality)
# Resource limits
max_gpu_per_instance: 1
max_model_len: 2048
tensor_parallel_size: 1
task_lease_duration: "30m"
heartbeat_interval: "1m"
max_retries: 3

View file

@ -50,7 +50,40 @@ resources:
metrics:
enabled: true
listen_addr: ":9100"
metrics_flush_interval: "500ms"
metrics_flush_interval: "500ms"
# Plugin Configuration
plugins:
# Jupyter Notebook/Lab Service
jupyter:
enabled: true
image: "quay.io/jupyter/base-notebook:latest"
default_port: 8888
mode: "lab"
# Security settings
security:
trusted_channels:
- "conda-forge"
- "defaults"
blocked_packages: [] # Dev environment - less restrictive
require_password: false # No password for dev
# Resource limits
max_gpu_per_instance: 1
max_memory_per_instance: "4Gi"
# vLLM Inference Service
vllm:
enabled: true
image: "vllm/vllm-openai:latest"
default_port: 8000
# Model cache location
model_cache: "/models"
# Supported quantization methods: awq, gptq, fp8, squeezellm
default_quantization: "" # No quantization for dev
# Resource limits
max_gpu_per_instance: 1
max_model_len: 2048
tensor_parallel_size: 1
task_lease_duration: "30m"
heartbeat_interval: "1m"

View file

@ -48,3 +48,42 @@ task_lease_duration: "30m"
heartbeat_interval: "1m"
max_retries: 3
graceful_timeout: "5m"
# Plugin Configuration
plugins:
# Jupyter Notebook/Lab Service
jupyter:
enabled: true
image: "quay.io/jupyter/base-notebook:latest"
default_port: 8888
# Security settings
security:
trusted_channels:
- "conda-forge"
- "defaults"
- "pytorch"
blocked_packages:
- "requests"
- "urllib3"
- "httpx"
require_password: true
# Resource limits (enforced by scheduler quota system)
max_gpu_per_instance: 1
max_memory_per_instance: "8Gi"
# vLLM Inference Service
vllm:
enabled: true
image: "vllm/vllm-openai:latest"
default_port: 8000
# Model cache location
model_cache: "/models"
# Supported quantization methods: awq, gptq, fp8, squeezellm
default_quantization: "" # empty = no quantization
# Resource limits
max_gpu_per_instance: 4
max_model_len: 4096
# Environment variables passed to container
env:
- "HF_HOME=/models"
- "VLLM_WORKER_MULTIPROC_METHOD=spawn"

View file

@ -48,6 +48,46 @@ queue:
backend: "redis"
redis_url: "redis://localhost:6379/0"
# Plugin Configuration
plugins:
# Jupyter Notebook/Lab Service
jupyter:
enabled: true
image: "quay.io/jupyter/base-notebook:latest"
default_port: 8888
mode: "lab" # "lab" or "notebook"
# Security settings
security:
trusted_channels:
- "conda-forge"
- "defaults"
- "pytorch"
- "nvidia"
blocked_packages:
- "requests"
- "urllib3"
- "httpx"
- "socket"
- "subprocess"
require_password: true
# Resource limits
max_gpu_per_instance: 1
max_memory_per_instance: "16Gi"
# vLLM Inference Service
vllm:
enabled: true
image: "vllm/vllm-openai:latest"
default_port: 8000
# Model cache location (should be on fast storage)
model_cache: "/var/lib/fetchml/models"
# Supported quantization methods: awq, gptq, fp8, squeezellm
default_quantization: ""
# Resource limits
max_gpu_per_instance: 2
max_model_len: 4096
tensor_parallel_size: 1
# Snapshot store (optional)
snapshot_store:
enabled: false

View file

@ -45,3 +45,42 @@ task_lease_duration: "30m"
heartbeat_interval: "1m"
max_retries: 3
graceful_timeout: "5m"
# Plugin Configuration
plugins:
# Jupyter Notebook/Lab Service
jupyter:
enabled: true
image: "quay.io/jupyter/base-notebook:latest"
default_port: 8888
mode: "lab"
# Security settings (strict for secure config)
security:
trusted_channels:
- "conda-forge"
- "defaults"
blocked_packages:
- "requests"
- "urllib3"
- "httpx"
- "socket"
- "subprocess"
- "os.system"
require_password: true
# Resource limits
max_gpu_per_instance: 1
max_memory_per_instance: "8Gi"
# vLLM Inference Service
vllm:
enabled: true
image: "vllm/vllm-openai:latest"
default_port: 8000
# Model cache location
model_cache: "/models"
# Supported quantization methods: awq, gptq, fp8, squeezellm
default_quantization: ""
# Resource limits
max_gpu_per_instance: 1
max_model_len: 4096
tensor_parallel_size: 1

View file

@ -45,3 +45,34 @@ podman_memory = "16g"
[metrics]
enabled = true
listen_addr = ":9100"
# Plugin Configuration
[plugins]
[plugins.jupyter]
enabled = true
image = "quay.io/jupyter/base-notebook:latest"
default_port = 8888
mode = "lab"
max_gpu_per_instance = 1
max_memory_per_instance = "8Gi"
[plugins.jupyter.security]
require_password = true
trusted_channels = ["conda-forge", "defaults", "pytorch"]
blocked_packages = ["requests", "urllib3", "httpx"]
[plugins.vllm]
enabled = true
image = "vllm/vllm-openai:latest"
default_port = 8000
model_cache = "/models"
default_quantization = "" # Options: awq, gptq, fp8, squeezellm
max_gpu_per_instance = 2
max_model_len = 4096
tensor_parallel_size = 1
# Environment variables for vLLM
[plugins.vllm.env]
HF_HOME = "/models"
VLLM_WORKER_MULTIPROC_METHOD = "spawn"

View file

@ -110,6 +110,36 @@ TLS_KEY_PATH=/app/ssl/key.pem
| Prometheus | 9090 | - | - |
| Grafana | 3000 | - | - |
| Loki | 3100 | - | - |
| JupyterLab | 8888* | 8888* | - |
| vLLM | 8000* | 8000* | - |
*Plugin service ports are dynamically allocated from the 8000-9000 range by the scheduler.
## Plugin Services
The deployment configurations include support for interactive ML services:
### Jupyter Notebook/Lab
- **Image**: `quay.io/jupyter/base-notebook:latest`
- **Security**: Trusted channels (conda-forge, defaults), blocked packages (http clients)
- **Resources**: Configurable GPU/memory limits
- **Access**: Via scheduler-assigned port (8000-9000 range)
### vLLM Inference
- **Image**: `vllm/vllm-openai:latest`
- **Features**: OpenAI-compatible API, quantization support (AWQ, GPTQ, FP8)
- **Model Cache**: Configurable path for model storage
- **Resources**: Multi-GPU tensor parallelism support
## Scheduler GPU Quotas
The scheduler supports GPU quota management for plugin services:
- **Global Limit**: Total GPUs across all plugins
- **Per-User Limits**: GPU and service count per user
- **Per-Plugin Limits**: vLLM and Jupyter-specific limits
- **User Overrides**: Special permissions for admins/researchers
See `configs/scheduler/scheduler.yaml.example` for quota configuration.
## Monitoring
@ -122,3 +152,4 @@ TLS_KEY_PATH=/app/ssl/key.pem
- If you need HTTPS externally, terminate TLS at a reverse proxy.
- API keys should be managed via environment variables
- Database credentials should use secrets management in production
- **HIPAA deployments**: Plugins are disabled by default for compliance

View file

@ -29,3 +29,30 @@ max_artifact_total_bytes: 1073741824 # 1GB
# Provenance (disabled in dev for speed)
provenance_best_effort: false
# Plugin Configuration (development mode)
plugins:
# Jupyter Notebook/Lab Service
jupyter:
enabled: true
image: "quay.io/jupyter/base-notebook:latest"
default_port: 8888
mode: "lab"
security:
trusted_channels:
- "conda-forge"
- "defaults"
blocked_packages: [] # No restrictions in dev
require_password: false # No password for dev
max_gpu_per_instance: 1
max_memory_per_instance: "4Gi"
# vLLM Inference Service
vllm:
enabled: true
image: "vllm/vllm-openai:latest"
default_port: 8000
model_cache: "/tmp/models" # Temp location for dev
default_quantization: "" # No quantization for dev
max_gpu_per_instance: 1
max_model_len: 2048

View file

@ -51,3 +51,12 @@ ssh_key: ${SSH_KEY_PATH}
# Config hash computation enabled (required for audit)
# This is automatically computed by Validate()
# Plugin Configuration (DISABLED for HIPAA compliance)
# Jupyter and vLLM services are disabled in HIPAA mode to ensure
# no unauthorized network access or data processing
plugins:
jupyter:
enabled: false # Disabled: HIPAA requires strict network isolation
vllm:
enabled: false # Disabled: External model downloads violate PHI controls

View file

@ -33,3 +33,32 @@ max_artifact_total_bytes: 536870912 # 512MB
# Provenance (enabled)
provenance_best_effort: true
# Plugin Configuration
plugins:
# Jupyter Notebook/Lab Service
jupyter:
enabled: true
image: "quay.io/jupyter/base-notebook:latest"
default_port: 8888
mode: "lab"
security:
trusted_channels:
- "conda-forge"
- "defaults"
blocked_packages:
- "requests"
- "urllib3"
require_password: true
max_gpu_per_instance: 1
max_memory_per_instance: "8Gi"
# vLLM Inference Service
vllm:
enabled: true
image: "vllm/vllm-openai:latest"
default_port: 8000
model_cache: "/models"
default_quantization: ""
max_gpu_per_instance: 1
max_model_len: 4096

View file

@ -33,11 +33,13 @@ type SchedulerHub struct {
reservations map[string]*Reservation
multiNodePending map[string]*MultiNodeJob
pendingAcceptance map[string]*JobAssignment
runningTasks map[string]*Task // Track assigned+accepted tasks
state *StateStore
starvation *StarvationTracker
metrics *SchedulerMetrics
auditor *audit.Logger
tokenValidator *TokenValidator
quotaManager *PluginQuotaManager // NEW: plugin GPU quota manager
config HubConfig
ctx context.Context
cancel context.CancelFunc
@ -59,6 +61,7 @@ type HubConfig struct {
AcceptanceTimeoutSecs int
LocalMode bool
WorkerTokens map[string]string // token -> workerID
PluginQuota PluginQuotaConfig // NEW: plugin GPU quota configuration
}
// WorkerConn represents a connected worker
@ -109,6 +112,7 @@ type JobAssignment struct {
AssignedAt time.Time
AcceptanceDeadline time.Time
Accepted bool
Task *Task // Reference to the task (removed from queue)
}
// StarvationTracker monitors long-waiting jobs
@ -154,6 +158,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
reservations: make(map[string]*Reservation),
multiNodePending: make(map[string]*MultiNodeJob),
pendingAcceptance: make(map[string]*JobAssignment),
runningTasks: make(map[string]*Task),
state: state,
starvation: &StarvationTracker{
threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,
@ -163,6 +168,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
},
auditor: auditor,
tokenValidator: NewTokenValidator(cfg.WorkerTokens),
quotaManager: NewPluginQuotaManager(cfg.PluginQuota), // NEW: initialize quota manager
config: cfg,
ctx: ctx,
cancel: cancel,
@ -431,9 +437,6 @@ func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task {
}
func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
h.mu.RLock()
defer h.mu.RUnlock()
for _, res := range h.reservations {
if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {
@ -449,7 +452,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
h.batchQueue.Remove(task.ID)
h.serviceQueue.Remove(task.ID)
// Track pending acceptance
// Track pending acceptance with task reference
h.mu.Lock()
h.pendingAcceptance[task.ID] = &JobAssignment{
TaskID: task.ID,
@ -457,6 +460,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
AssignedAt: time.Now(),
AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
Accepted: false,
Task: task, // Store reference since removed from queue
}
h.mu.Unlock()
@ -473,12 +477,31 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
}
}
func (h *SchedulerHub) handleJobAccepted(_, taskID string) {
func (h *SchedulerHub) handleJobAccepted(workerID, taskID string) {
h.mu.Lock()
defer h.mu.Unlock()
if assignment, ok := h.pendingAcceptance[taskID]; ok {
assignment.Accepted = true
// Track as running task
task := assignment.Task
if task != nil {
task.Status = "running"
task.WorkerID = workerID
h.runningTasks[taskID] = task
}
// NEW: Record quota usage for service jobs
if task != nil && task.Spec.Type == JobTypeService {
if h.quotaManager != nil {
pluginName := task.Spec.Metadata["plugin_name"]
if pluginName == "" {
pluginName = "default"
}
h.quotaManager.RecordUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
}
}
}
}
@ -486,7 +509,19 @@ func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload)
h.mu.Lock()
defer h.mu.Unlock()
// NEW: Release quota usage for service jobs before deleting pending acceptance
if task := h.runningTasks[result.TaskID]; task != nil && task.Spec.Type == JobTypeService {
if h.quotaManager != nil {
pluginName := task.Spec.Metadata["plugin_name"]
if pluginName == "" {
pluginName = "default"
}
h.quotaManager.ReleaseUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
}
}
delete(h.pendingAcceptance, result.TaskID)
delete(h.runningTasks, result.TaskID)
eventType := EventJobCompleted
switch result.State {
@ -519,7 +554,10 @@ func (h *SchedulerHub) checkAcceptanceTimeouts() {
h.mu.Lock()
for taskID, a := range h.pendingAcceptance {
if !a.Accepted && time.Now().After(a.AcceptanceDeadline) {
h.batchQueue.Add(h.getTask(taskID))
if a.Task != nil {
a.Task.Status = "queued"
h.batchQueue.Add(a.Task)
}
delete(h.pendingAcceptance, taskID)
if wc, ok := h.workers[a.WorkerID]; ok {
wc.slots = SlotStatus{}
@ -572,22 +610,31 @@ func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) {
st.mu.Lock()
defer st.mu.Unlock()
// First check which tasks need reservation under h.mu.RLock
tasksToReserve := make([]*Task, 0)
h.mu.RLock()
for _, task := range h.batchQueue.Items() {
if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservation(h, task.ID) {
h.mu.Lock()
if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservationLocked(h, task.ID) {
tasksToReserve = append(tasksToReserve, task)
}
}
h.mu.RUnlock()
// Now acquire Lock to add reservations
if len(tasksToReserve) > 0 {
h.mu.Lock()
for _, task := range tasksToReserve {
h.reservations[task.ID] = &Reservation{
TaskID: task.ID,
GPUCount: task.Spec.GPUCount,
CreatedAt: time.Now(),
}
h.mu.Unlock()
}
h.mu.Unlock()
}
}
func (st *StarvationTracker) hasReservation(h *SchedulerHub, taskID string) bool {
h.mu.RLock()
defer h.mu.RUnlock()
func (st *StarvationTracker) hasReservationLocked(h *SchedulerHub, taskID string) bool {
_, exists := h.reservations[taskID]
return exists
}
@ -605,6 +652,17 @@ func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
return fmt.Errorf("job ID is required")
}
// NEW: Check plugin quotas for service jobs
if spec.Type == JobTypeService && h.quotaManager != nil {
pluginName := spec.Metadata["plugin_name"]
if pluginName == "" {
pluginName = "default"
}
if err := h.quotaManager.CheckQuota(spec.UserID, pluginName, spec.GPUCount); err != nil {
return fmt.Errorf("quota exceeded: %w", err)
}
}
task := &Task{
ID: spec.ID,
Spec: spec,
@ -639,7 +697,11 @@ func (h *SchedulerHub) getTask(taskID string) *Task {
if t != nil {
return t
}
return h.serviceQueue.Get(taskID)
t = h.serviceQueue.Get(taskID)
if t != nil {
return t
}
return h.runningTasks[taskID]
}
func (h *SchedulerHub) restoreJob(ev StateEvent) {
@ -718,7 +780,7 @@ func (h *SchedulerHub) reconcileOrphans() {
if assignment.Accepted {
// Job was accepted but worker is gone (not in h.workers)
if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected {
task := h.getTask(taskID)
task := assignment.Task
if task != nil {
task.Status = "orphaned"
h.batchQueue.Add(task)
@ -776,7 +838,7 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
}
if msg.Type == MsgMetricsRequest {
metrics := h.getMetricsPayload()
metrics := h.GetMetricsPayload()
conn.WriteJSON(Message{
Type: MsgMetricsResponse,
Payload: mustMarshal(metrics),
@ -785,8 +847,8 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
}
}
// getMetricsPayload returns current metrics as a map
func (h *SchedulerHub) getMetricsPayload() map[string]any {
// GetMetricsPayload returns current metrics as a map (public API)
func (h *SchedulerHub) GetMetricsPayload() map[string]any {
h.metrics.mu.RLock()
defer h.metrics.mu.RUnlock()
@ -901,7 +963,7 @@ func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) {
// buildRankedSpec creates a job spec with rank-specific template variables resolved
func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
// Clone the spec and add rank info to metadata
// Clone the spec and add rank info to metadata and env
spec := task.Spec
spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3)
for k, v := range task.Spec.Metadata {
@ -910,6 +972,14 @@ func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, wo
spec.Metadata["HEAD_ADDR"] = headAddr
spec.Metadata["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
spec.Metadata["NODE_RANK"] = fmt.Sprintf("%d", rank)
// Also set in Env for job runtime
if spec.Env == nil {
spec.Env = make(map[string]string)
}
spec.Env["HEAD_ADDR"] = headAddr
spec.Env["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
spec.Env["NODE_RANK"] = fmt.Sprintf("%d", rank)
return spec
}

View file

@ -0,0 +1,287 @@
package scheduler
import (
"fmt"
"sync"
)
// PluginQuotaConfig defines GPU limits for plugins.
type PluginQuotaConfig struct {
Enabled bool // Master switch for quota enforcement
TotalGPUs int // Global GPU limit across all plugins
PerUserGPUs int // Default per-user GPU limit
PerUserServices int // Default per-user service count limit
PerPluginLimits map[string]PluginLimit // Plugin-specific overrides
UserOverrides map[string]UserLimit // Per-user overrides
}
// PluginLimit defines limits for a specific plugin.
type PluginLimit struct {
MaxGPUs int
MaxServices int
}
// UserLimit defines per-user override limits.
type UserLimit struct {
MaxGPUs int
MaxServices int
AllowedPlugins []string // Empty = all plugins allowed
}
// PluginUsage tracks GPU and service count for a user-plugin combination.
type PluginUsage struct {
GPUs int
Services int
}
// PluginQuotaManager tracks active usage and enforces quotas.
type PluginQuotaManager struct {
config PluginQuotaConfig
mu sync.RWMutex
usage map[string]map[string]PluginUsage // userID -> pluginName -> usage
pluginTotal map[string]int // pluginName -> total GPUs in use
totalGPUs int // global total GPUs in use
}
// NewPluginQuotaManager creates a new quota manager with the given configuration.
func NewPluginQuotaManager(config PluginQuotaConfig) *PluginQuotaManager {
return &PluginQuotaManager{
config: config,
usage: make(map[string]map[string]PluginUsage),
pluginTotal: make(map[string]int),
totalGPUs: 0,
}
}
// CheckQuota validates if a job can be submitted without exceeding limits.
// Returns nil if the job is allowed, or an error describing which limit would be exceeded.
func (m *PluginQuotaManager) CheckQuota(userID, pluginName string, gpuCount int) error {
if !m.config.Enabled {
return nil
}
if userID == "" {
userID = "anonymous"
}
if pluginName == "" {
pluginName = "default"
}
m.mu.RLock()
defer m.mu.RUnlock()
// Get user limits (with overrides)
userLimit := m.getUserLimit(userID)
// Check if user is allowed to use this plugin
if len(userLimit.AllowedPlugins) > 0 {
found := false
for _, p := range userLimit.AllowedPlugins {
if p == pluginName {
found = true
break
}
}
if !found {
return fmt.Errorf("user %s is not allowed to use plugin %s", userID, pluginName)
}
}
// Check plugin-specific limits
pluginLimit, hasPluginLimit := m.config.PerPluginLimits[pluginName]
if hasPluginLimit {
if pluginLimit.MaxGPUs > 0 && m.pluginTotal[pluginName]+gpuCount > pluginLimit.MaxGPUs {
return fmt.Errorf("plugin %s GPU limit exceeded: %d requested, %d available of %d total",
pluginName, gpuCount, pluginLimit.MaxGPUs-m.pluginTotal[pluginName], pluginLimit.MaxGPUs)
}
if pluginLimit.MaxServices > 0 {
// Services limit is across all users for this plugin
totalServices := 0
for _, u := range m.usage {
if p, ok := u[pluginName]; ok {
totalServices += p.Services
}
}
if totalServices+1 > pluginLimit.MaxServices {
return fmt.Errorf("plugin %s service limit exceeded", pluginName)
}
}
}
// Check per-user limits
effectiveUserGPUs := userLimit.MaxGPUs
if effectiveUserGPUs == 0 {
effectiveUserGPUs = m.config.PerUserGPUs
}
effectiveUserServices := userLimit.MaxServices
if effectiveUserServices == 0 {
effectiveUserServices = m.config.PerUserServices
}
// Calculate total user usage across all plugins
totalUserGPUs := 0
totalUserServices := 0
for _, p := range m.usage[userID] {
totalUserGPUs += p.GPUs
totalUserServices += p.Services
}
if effectiveUserGPUs > 0 && totalUserGPUs+gpuCount > effectiveUserGPUs {
return fmt.Errorf("user %s GPU limit exceeded: %d requested, %d available of %d total",
userID, gpuCount, effectiveUserGPUs-totalUserGPUs, effectiveUserGPUs)
}
if effectiveUserServices > 0 && totalUserServices+1 > effectiveUserServices {
return fmt.Errorf("user %s service limit exceeded: %d services of %d allowed",
userID, totalUserServices+1, effectiveUserServices)
}
// Check global total GPU limit
if m.config.TotalGPUs > 0 && m.totalGPUs+gpuCount > m.config.TotalGPUs {
return fmt.Errorf("global GPU limit exceeded: %d requested, %d available of %d total",
gpuCount, m.config.TotalGPUs-m.totalGPUs, m.config.TotalGPUs)
}
return nil
}
// RecordUsage increments usage counters when a job starts.
func (m *PluginQuotaManager) RecordUsage(userID, pluginName string, gpuCount int) {
if !m.config.Enabled {
return
}
if userID == "" {
userID = "anonymous"
}
if pluginName == "" {
pluginName = "default"
}
m.mu.Lock()
defer m.mu.Unlock()
userPlugins, ok := m.usage[userID]
if !ok {
userPlugins = make(map[string]PluginUsage)
m.usage[userID] = userPlugins
}
usage := userPlugins[pluginName]
usage.GPUs += gpuCount
usage.Services++
userPlugins[pluginName] = usage
m.pluginTotal[pluginName] += gpuCount
m.totalGPUs += gpuCount
}
// ReleaseUsage decrements usage counters when a job stops.
func (m *PluginQuotaManager) ReleaseUsage(userID, pluginName string, gpuCount int) {
if !m.config.Enabled {
return
}
if userID == "" {
userID = "anonymous"
}
if pluginName == "" {
pluginName = "default"
}
m.mu.Lock()
defer m.mu.Unlock()
userPlugins, ok := m.usage[userID]
if !ok {
return
}
usage := userPlugins[pluginName]
usage.GPUs -= gpuCount
usage.Services--
if usage.GPUs < 0 {
usage.GPUs = 0
}
if usage.Services < 0 {
usage.Services = 0
}
if usage.GPUs == 0 && usage.Services == 0 {
delete(userPlugins, pluginName)
} else {
userPlugins[pluginName] = usage
}
if len(userPlugins) == 0 {
delete(m.usage, userID)
}
m.pluginTotal[pluginName] -= gpuCount
if m.pluginTotal[pluginName] < 0 {
m.pluginTotal[pluginName] = 0
}
m.totalGPUs -= gpuCount
if m.totalGPUs < 0 {
m.totalGPUs = 0
}
}
// GetUsage returns a snapshot of the per-plugin usage recorded for the
// given user, plus that user's GPU total summed across plugins. A blank
// userID is looked up as "anonymous". The returned map is a copy and is
// safe for the caller to mutate.
func (m *PluginQuotaManager) GetUsage(userID string) (map[string]PluginUsage, int) {
	if userID == "" {
		userID = "anonymous"
	}
	m.mu.RLock()
	defer m.mu.RUnlock()
	snapshot := make(map[string]PluginUsage)
	gpuSum := 0
	// Ranging over a nil map (unknown user) is a no-op, so no existence
	// check is needed here.
	for name, u := range m.usage[userID] {
		snapshot[name] = u
		gpuSum += u.GPUs
	}
	return snapshot, gpuSum
}
// GetGlobalUsage returns the cluster-wide GPU count currently in use
// together with a copy of the per-plugin GPU totals. The returned map
// is detached from internal state and safe for the caller to mutate.
func (m *PluginQuotaManager) GetGlobalUsage() (int, map[string]int) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	totals := make(map[string]int, len(m.pluginTotal))
	for name, gpus := range m.pluginTotal {
		totals[name] = gpus
	}
	return m.totalGPUs, totals
}
// getUserLimit returns the effective limits for a user: the per-user
// override when one is configured, otherwise the global defaults.
func (m *PluginQuotaManager) getUserLimit(userID string) UserLimit {
	limit, overridden := m.config.UserOverrides[userID]
	if !overridden {
		limit = UserLimit{
			MaxGPUs:     m.config.PerUserGPUs,
			MaxServices: m.config.PerUserServices,
		}
	}
	return limit
}
// getUsageLocked returns the recorded usage for a user-plugin pair, or
// a zero PluginUsage when either key is absent.
// Must be called with read lock held.
func (m *PluginQuotaManager) getUsageLocked(userID, pluginName string) PluginUsage {
	// Indexing a nil inner map yields the zero value, so a single
	// lookup chain covers both "unknown user" and "unknown plugin".
	plugins := m.usage[userID]
	return plugins[pluginName]
}

View file

@ -100,6 +100,7 @@ type JobSpec struct {
ID string `json:"id"`
Type JobType `json:"type"` // "batch" | "service"
SlotPool string `json:"slot_pool"`
UserID string `json:"user_id,omitempty"` // NEW: for per-user quota tracking
GPUCount int `json:"gpu_count"`
GPUType string `json:"gpu_type,omitempty"`

View file

@ -60,6 +60,7 @@ func TestMultiNodeGangAllocation(t *testing.T) {
ID: w.id,
Capabilities: scheduler.WorkerCapabilities{
GPUCount: 0,
Hostname: "localhost",
},
}),
})
@ -172,10 +173,8 @@ func TestServiceLifecycle(t *testing.T) {
// Send job accepted
conn.WriteJSON(scheduler.Message{
Type: scheduler.MsgJobAccepted,
Payload: mustMarshal(map[string]string{
"task_id": jobID,
}),
Type: scheduler.MsgJobAccepted,
Payload: mustMarshal(jobID),
})
// Send periodic health updates

View file

@ -307,6 +307,7 @@ func startFakeWorkers(
go func(workerID string) {
defer wg.Done()
for {
// Check for cancellation before getting next task
select {
case <-ctx.Done():
return
@ -324,6 +325,7 @@ func startFakeWorkers(
continue
}
// Process task - finish even if context is cancelled
started := time.Now()
completed := started.Add(10 * time.Millisecond)
@ -338,7 +340,16 @@ func startFakeWorkers(
continue
}
doneCh <- task.JobName
// Only send to doneCh if we successfully completed
select {
case doneCh <- task.JobName:
case <-ctx.Done():
// Context cancelled but we still completed the task - try once more without blocking
select {
case doneCh <- task.JobName:
default:
}
}
}
}(fmt.Sprintf("worker-%d", w))
}

View file

@ -0,0 +1,385 @@
package scheduler_test
import (
"testing"
"github.com/jfraeys/fetch_ml/internal/scheduler"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// A disabled quota manager must admit any request, even one that would
// blow past the configured limits if enforcement were on.
func TestPluginQuotaManager_CheckQuota_Disabled(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:   false,
		TotalGPUs: 1, // would reject a 100-GPU request if enforcement were active
	})
	assert.NoError(t, mgr.CheckQuota("user1", "plugin1", 100))
}
// Admissions must stop once recorded usage reaches the global GPU pool.
func TestPluginQuotaManager_CheckQuota_GlobalLimit(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:   true,
		TotalGPUs: 4,
	})
	// Two 2-GPU jobs from different users exactly fill the pool.
	require.NoError(t, mgr.CheckQuota("user1", "plugin1", 2))
	mgr.RecordUsage("user1", "plugin1", 2)
	require.NoError(t, mgr.CheckQuota("user2", "plugin2", 2))
	mgr.RecordUsage("user2", "plugin2", 2)
	// Even a single additional GPU must now be rejected.
	err := mgr.CheckQuota("user3", "plugin3", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "global GPU limit exceeded")
}
// Per-user GPU caps apply independently of the (larger) global pool.
func TestPluginQuotaManager_CheckQuota_PerUserGPULimit(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 3,
	})
	// user1 climbs to exactly the 3-GPU cap across two plugins.
	require.NoError(t, mgr.CheckQuota("user1", "plugin1", 2))
	mgr.RecordUsage("user1", "plugin1", 2)
	require.NoError(t, mgr.CheckQuota("user1", "plugin2", 1))
	mgr.RecordUsage("user1", "plugin2", 1)
	// One more GPU for user1 crosses the cap.
	err := mgr.CheckQuota("user1", "plugin3", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "user user1 GPU limit exceeded")
	// The cap is per user: user2 still has full headroom.
	assert.NoError(t, mgr.CheckQuota("user2", "plugin1", 3))
}
// The per-user service count is enforced even when GPU headroom remains.
func TestPluginQuotaManager_CheckQuota_PerUserServiceLimit(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     10,
		PerUserServices: 2,
	})
	// Two services are admitted...
	require.NoError(t, mgr.CheckQuota("user1", "plugin1", 1))
	mgr.RecordUsage("user1", "plugin1", 1)
	require.NoError(t, mgr.CheckQuota("user1", "plugin2", 1))
	mgr.RecordUsage("user1", "plugin2", 1)
	// ...but a third pushes the service count over the cap.
	err := mgr.CheckQuota("user1", "plugin3", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "user user1 service limit exceeded")
}
// Per-user overrides replace the default limits for the named user only.
func TestPluginQuotaManager_CheckQuota_UserOverride(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     2,
		PerUserServices: 2,
		UserOverrides: map[string]scheduler.UserLimit{
			"vip-user": {MaxGPUs: 5, MaxServices: 10},
		},
	})
	// A non-overridden user is held to the 2-GPU default.
	err := mgr.CheckQuota("regular", "plugin1", 3)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "regular GPU limit exceeded")
	// The VIP's override admits 4 GPUs and still leaves room for one more.
	require.NoError(t, mgr.CheckQuota("vip-user", "plugin1", 4))
	mgr.RecordUsage("vip-user", "plugin1", 4)
	assert.NoError(t, mgr.CheckQuota("vip-user", "plugin2", 1))
}
// Per-plugin GPU caps are enforced across users, independent of user caps.
func TestPluginQuotaManager_CheckQuota_PluginSpecificLimit(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 10,
		PerPluginLimits: map[string]scheduler.PluginLimit{
			"jupyter": {MaxGPUs: 3, MaxServices: 2},
			"vllm":    {MaxGPUs: 8, MaxServices: 4},
		},
	})
	// jupyter admits 2 GPUs...
	require.NoError(t, mgr.CheckQuota("user1", "jupyter", 2))
	mgr.RecordUsage("user1", "jupyter", 2)
	// ...but 2 more would exceed its 3-GPU cap, even for another user
	// who is well under the global and per-user limits.
	err := mgr.CheckQuota("user2", "jupyter", 2)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "plugin jupyter GPU limit exceeded")
	// vllm's higher cap still has room.
	assert.NoError(t, mgr.CheckQuota("user1", "vllm", 4))
}
// A plugin's service cap is shared across all users of that plugin.
func TestPluginQuotaManager_CheckQuota_PluginServiceLimit(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     10,
		PerUserServices: 10,
		PerPluginLimits: map[string]scheduler.PluginLimit{
			"jupyter": {MaxGPUs: 10, MaxServices: 2}, // only 2 jupyter services total
		},
	})
	// Two jupyter services from two different users fill the cap.
	require.NoError(t, mgr.CheckQuota("user1", "jupyter", 1))
	mgr.RecordUsage("user1", "jupyter", 1)
	require.NoError(t, mgr.CheckQuota("user2", "jupyter", 1))
	mgr.RecordUsage("user2", "jupyter", 1)
	// A third jupyter service is rejected regardless of who asks.
	err := mgr.CheckQuota("user3", "jupyter", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "plugin jupyter service limit exceeded")
}
// An AllowedPlugins override restricts which plugins that user may run.
func TestPluginQuotaManager_CheckQuota_AllowedPlugins(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 10,
		UserOverrides: map[string]scheduler.UserLimit{
			"restricted-user": {
				MaxGPUs:        5,
				MaxServices:    5,
				AllowedPlugins: []string{"jupyter"},
			},
		},
	})
	// The whitelisted plugin is usable...
	assert.NoError(t, mgr.CheckQuota("restricted-user", "jupyter", 2))
	// ...anything else is rejected for the restricted user...
	err := mgr.CheckQuota("restricted-user", "vllm", 2)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "not allowed to use plugin vllm")
	// ...while users without a whitelist may run any plugin.
	assert.NoError(t, mgr.CheckQuota("regular-user", "vllm", 2))
}
// Record/Release must keep per-user, per-plugin, and global counters in sync.
func TestPluginQuotaManager_RecordAndReleaseUsage(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 5,
	})
	mgr.RecordUsage("user1", "jupyter", 2)
	mgr.RecordUsage("user1", "vllm", 1)
	mgr.RecordUsage("user2", "jupyter", 3)
	// user1's view: 2 jupyter GPUs + 1 vllm GPU, one service of each.
	usage, totalGPUs := mgr.GetUsage("user1")
	assert.Equal(t, 2, usage["jupyter"].GPUs)
	assert.Equal(t, 1, usage["jupyter"].Services)
	assert.Equal(t, 1, usage["vllm"].GPUs)
	assert.Equal(t, 1, usage["vllm"].Services)
	assert.Equal(t, 3, totalGPUs)
	// Global view: 6 GPUs in use, jupyter holding 5 (2+3) and vllm 1.
	globalGPUs, pluginTotals := mgr.GetGlobalUsage()
	assert.Equal(t, 6, globalGPUs)
	assert.Equal(t, 5, pluginTotals["jupyter"])
	assert.Equal(t, 1, pluginTotals["vllm"])
	// Releasing user1's jupyter job should leave only their vllm usage.
	mgr.ReleaseUsage("user1", "jupyter", 2)
	usage, totalGPUs = mgr.GetUsage("user1")
	assert.Equal(t, 0, usage["jupyter"].GPUs)
	assert.Equal(t, 0, usage["jupyter"].Services)
	assert.Equal(t, 1, usage["vllm"].GPUs)
	assert.Equal(t, 1, totalGPUs)
	// Globals drop by the released amount; user2's jupyter GPUs remain.
	globalGPUs, pluginTotals = mgr.GetGlobalUsage()
	assert.Equal(t, 4, globalGPUs)
	assert.Equal(t, 3, pluginTotals["jupyter"])
	assert.Equal(t, 1, pluginTotals["vllm"])
}
// Recording while disabled must be a silent no-op, not a panic or a count.
func TestPluginQuotaManager_RecordUsage_Disabled(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:   false,
		TotalGPUs: 10,
	})
	mgr.RecordUsage("user1", "plugin1", 5)
	// Nothing should have been tracked.
	usage, totalGPUs := mgr.GetUsage("user1")
	assert.Empty(t, usage)
	assert.Equal(t, 0, totalGPUs)
}
// Releasing usage that was never recorded must not panic or go negative.
func TestPluginQuotaManager_ReleaseUsage_NonExistent(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:   true,
		TotalGPUs: 10,
	})
	mgr.ReleaseUsage("nonexistent", "plugin1", 5)
	// Global accounting must still read zero.
	globalGPUs, _ := mgr.GetGlobalUsage()
	assert.Equal(t, 0, globalGPUs)
}
// Requests with no user ID are pooled under the "anonymous" account and
// share a single per-user quota.
func TestPluginQuotaManager_CheckQuota_AnonymousUser(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     2,
		PerUserServices: 2,
	})
	// The first anonymous request fills the 2-GPU per-user cap.
	require.NoError(t, mgr.CheckQuota("", "plugin1", 2))
	mgr.RecordUsage("", "plugin1", 2)
	// The next anonymous request shares that cap and is rejected.
	err := mgr.CheckQuota("", "plugin1", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "user anonymous GPU limit exceeded")
}
// Requests with no plugin name fall under the "default" plugin's limits.
func TestPluginQuotaManager_CheckQuota_DefaultPlugin(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     5,
		PerUserServices: 5,
		PerPluginLimits: map[string]scheduler.PluginLimit{
			"default": {MaxGPUs: 2, MaxServices: 2},
		},
	})
	// One GPU requested under the blank name counts against "default".
	require.NoError(t, mgr.CheckQuota("user1", "", 1))
	mgr.RecordUsage("user1", "", 1)
	// Two more would exceed default's 2-GPU cap, even from another user.
	err := mgr.CheckQuota("user2", "", 2)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "plugin default GPU limit exceeded")
}
// Counters must stay consistent under concurrent RecordUsage calls.
// Run with -race to catch locking regressions inside the manager.
func TestPluginQuotaManager_ConcurrentAccess(t *testing.T) {
	mgr := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       100,
		PerUserGPUs:     50,
		PerUserServices: 50,
	})
	// Ten goroutines record one GPU each, alternating between two users.
	done := make(chan bool, 10)
	for i := 0; i < 10; i++ {
		go func(idx int) {
			// Even indices go to user1, odd to user2. (The original
			// initialized user to a dead value that both branches
			// overwrote; a single guarded assignment is equivalent.)
			user := "user2"
			if idx%2 == 0 {
				user = "user1"
			}
			mgr.RecordUsage(user, "plugin1", 1)
			done <- true
		}(i)
	}
	// Wait for all goroutines to finish recording.
	for i := 0; i < 10; i++ {
		<-done
	}
	// All ten records must land: 10 GPUs globally, 5 per user.
	globalGPUs, _ := mgr.GetGlobalUsage()
	assert.Equal(t, 10, globalGPUs)
	usage1, _ := mgr.GetUsage("user1")
	assert.Equal(t, 5, usage1["plugin1"].GPUs)
	assert.Equal(t, 5, usage1["plugin1"].Services)
	usage2, _ := mgr.GetUsage("user2")
	assert.Equal(t, 5, usage2["plugin1"].GPUs)
	assert.Equal(t, 5, usage2["plugin1"].Services)
}