Compare commits

4 commits: 90ea18555c ... da104367d6

| Author | SHA1 | Date |
|---|---|---|
|  | da104367d6 |  |
|  | ef05f200ba |  |
|  | a653a2d0ed |  |
|  | b3a0c78903 |  |

21 changed files with 1200 additions and 46 deletions
@@ -34,7 +34,7 @@ env:

 jobs:
   test:
-    name: Test
+    name: Test (ubuntu-latest on self-hosted)
     runs-on: self-hosted
     timeout-minutes: 30

@@ -424,11 +424,31 @@ jobs:
          echo "=== Testing ${{ matrix.build_config.name }} build (CGO_ENABLED=${{ matrix.build_config.cgo_enabled }}, tags=${{ matrix.build_config.tags }}) ==="
          CGO_ENABLED=${{ matrix.build_config.cgo_enabled }} go test -tags "${{ matrix.build_config.tags }}" -v ./tests/unit/... || true

-      - name: Run GPU matrix tests - ${{ matrix.build_config.name }}
+      - name: Run plugin quota tests
        run: |
-          echo "=== GPU Golden Test Matrix - ${{ matrix.build_config.name }} ==="
-          CGO_ENABLED=${{ matrix.build_config.cgo_enabled }} go test -tags "${{ matrix.build_config.tags }}" -v ./tests/unit/gpu/ -run TestGoldenGPUStatus || true
-          CGO_ENABLED=${{ matrix.build_config.cgo_enabled }} go test -tags "${{ matrix.build_config.tags }}" -v ./tests/unit/gpu/ -run TestBuildTagMatrix || true
+          echo "=== Running Plugin GPU Quota tests ==="
+          go test -v ./tests/unit/scheduler/... -run TestPluginQuota
+
+      - name: Run service templates tests
+        run: |
+          echo "=== Running Service Templates tests ==="
+          go test -v ./tests/unit/scheduler/... -run TestServiceTemplate
+
+      - name: Run scheduler tests
+        run: |
+          echo "=== Running Scheduler tests ==="
+          go test -v ./tests/unit/scheduler/... -run TestScheduler
+
+      - name: Run vLLM plugin tests
+        run: |
+          echo "=== Running vLLM Plugin tests ==="
+          go test -v ./tests/unit/worker/plugins/... -run TestVLLM
+
+      - name: Run audit tests
+        run: |
+          echo "=== Running Audit Logging tests ==="
+          go test -v ./tests/unit/security/... -run TestAudit
+          go test -v ./tests/integration/audit/...

   build-trigger:
     name: Trigger Build Workflow
@@ -175,24 +175,39 @@ EOF

          echo "All required HIPAA fields have corresponding tests"

-      - name: Run security custom vet rules
+      - name: Validate plugin configuration for ${{ matrix.security_mode }} mode
        run: |
-          echo "=== Running custom vet rules for security ==="
+          echo "=== Validating plugin configuration for ${{ matrix.security_mode }} mode ==="

-          # Check if fetchml-vet tool exists
-          if [ -d "tools/fetchml-vet" ]; then
-            cd tools/fetchml-vet
-            go build -o fetchml-vet ./cmd/fetchml-vet/
-            cd ../..
-
-            # Run the custom vet analyzer
-            ./tools/fetchml-vet/fetchml-vet ./... || {
-              echo "Custom vet found issues - review required"
-              exit 1
-            }
-          else
-            echo "fetchml-vet tool not found - skipping custom vet"
-          fi
+          CONFIG_FILE="${{ matrix.config_file }}"
+
+          # Check plugin configuration based on security mode
+          case "${{ matrix.security_mode }}" in
+            hipaa)
+              echo "Checking HIPAA mode: plugins should be disabled"
+              if grep -A 5 "plugins:" "$CONFIG_FILE" | grep -q "enabled: false"; then
+                echo "✓ Plugins are disabled for HIPAA compliance"
+              else
+                echo "⚠ Warning: Plugins may not be properly disabled in HIPAA mode"
+              fi
+              ;;
+            standard)
+              echo "Checking standard mode: plugins should be enabled with security"
+              if grep -A 10 "plugins:" "$CONFIG_FILE" | grep -q "enabled: true"; then
+                echo "✓ Plugins are enabled in standard mode"
+                # Check for security settings
+                if grep -A 20 "plugins:" "$CONFIG_FILE" | grep -q "require_password: true"; then
+                  echo "✓ Plugin security (password) is enabled"
+                fi
+              fi
+              ;;
+            dev)
+              echo "Checking dev mode: plugins should be enabled (relaxed security)"
+              if grep -A 10 "plugins:" "$CONFIG_FILE" | grep -q "enabled: true"; then
+                echo "✓ Plugins are enabled in dev mode"
+              fi
+              ;;
+          esac

      - name: Security mode test summary
        if: always()
Makefile (9 changed lines)

@@ -457,6 +457,15 @@ help:
 	@echo " make docs-build - Build static documentation (local)"
 	@echo " make docs-build-prod - Build static documentation (prod flags; set DOCS_PROD_BASEURL)"
 	@echo ""
+	@echo "Scheduler & Services:"
+	@echo " make dev-up - Start development environment (API, Worker, Redis, monitoring)"
+	@echo " make dev-down - Stop development environment"
+	@echo " make dev-status - Check development environment status"
+	@echo " make dev-logs - View development environment logs"
+	@echo " make prod-up - Start production environment"
+	@echo " make prod-down - Stop production environment"
+	@echo " make scheduler-config - Edit scheduler configuration example"
+	@echo ""
 	@echo "Utility:"
 	@echo " make size - Show binary sizes"
 	@echo " make self-cleanup - Clean up Docker resources"
@@ -62,7 +62,26 @@ database:

 logging:
   level: "info"
   file: "/logs/fetch_ml.log"
-  audit_log: ""
+  # Audit logging (HIPAA-compliant with tamper-evident chain hashing)
+  audit:
+    enabled: true
+    file: "/var/log/fetch_ml/audit.log"  # Separate file for audit events
+    chain_hashing: true                  # Enable tamper-evident logging
+    retention_days: 2555                 # 7 years for HIPAA compliance
+    log_ip_address: true                 # Include source IP in audit events
+    log_user_agent: true                 # Include user agent in audit events
+    # Sensitive events to always log
+    events:
+      - "authentication_success"
+      - "authentication_failure"
+      - "file_access"
+      - "file_write"
+      - "file_delete"
+      - "job_queued"
+      - "job_started"
+      - "job_completed"
+      - "experiment_created"
+      - "experiment_deleted"

 resources:
   max_workers: 1
@@ -30,3 +30,30 @@ scheduler:
       token: "wkr_PLACEHOLDER_GENERATE_WITH_OPENSSL_RAND_HEX_32"
     - id: "worker-02"
       token: "wkr_PLACEHOLDER_GENERATE_WITH_OPENSSL_RAND_HEX_32"
+
+# Plugin GPU Quota Configuration
+# Controls GPU allocation for plugin-based services (Jupyter, vLLM, etc.)
+plugin_quota:
+  enabled: false        # Enable quota enforcement (default: false)
+  total_gpus: 16        # Global GPU limit across all plugins (0 = unlimited)
+  per_user_gpus: 4      # Default per-user GPU limit (0 = unlimited)
+  per_user_services: 2  # Default per-user service count limit (0 = unlimited)
+
+  # Plugin-specific limits (optional)
+  per_plugin_limits:
+    vllm:
+      max_gpus: 8       # Max GPUs for vLLM across all users
+      max_services: 4   # Max vLLM service instances
+    jupyter:
+      max_gpus: 4       # Max GPUs for Jupyter across all users
+      max_services: 10  # Max Jupyter service instances
+
+  # Per-user overrides (optional)
+  user_overrides:
+    admin:
+      max_gpus: 8       # Admin gets more GPUs
+      max_services: 5   # Admin can run more services
+      allowed_plugins: ["jupyter", "vllm"]  # Restrict which plugins the user can use
+    researcher1:
+      max_gpus: 2       # Limited GPU access
+      max_services: 1   # Single service limit
@@ -48,6 +48,39 @@ queue:
   native:
     data_dir: "data/dev/queue"

+# Plugin Configuration (for local development)
+plugins:
+  # Jupyter Notebook/Lab Service
+  jupyter:
+    enabled: true
+    image: "quay.io/jupyter/base-notebook:latest"
+    default_port: 8888
+    mode: "lab"
+    # Security settings
+    security:
+      trusted_channels:
+        - "conda-forge"
+        - "defaults"
+      blocked_packages: []     # Less restrictive for local dev
+      require_password: false  # No password for local dev
+    # Resource limits
+    max_gpu_per_instance: 1
+    max_memory_per_instance: "4Gi"
+
+  # vLLM Inference Service
+  vllm:
+    enabled: true
+    image: "vllm/vllm-openai:latest"
+    default_port: 8000
+    # Model cache location
+    model_cache: "data/dev/models"
+    # Supported quantization methods: awq, gptq, fp8, squeezellm
+    default_quantization: ""  # No quantization for dev (better quality)
+    # Resource limits
+    max_gpu_per_instance: 1
+    max_model_len: 2048
+    tensor_parallel_size: 1
+
 task_lease_duration: "30m"
 heartbeat_interval: "1m"
 max_retries: 3
@@ -50,7 +50,40 @@ resources:
 metrics:
   enabled: true
   listen_addr: ":9100"
-  metrics_flush_interval: "500ms"
+  metrics_flush_interval: "500ms"
+
+# Plugin Configuration
+plugins:
+  # Jupyter Notebook/Lab Service
+  jupyter:
+    enabled: true
+    image: "quay.io/jupyter/base-notebook:latest"
+    default_port: 8888
+    mode: "lab"
+    # Security settings
+    security:
+      trusted_channels:
+        - "conda-forge"
+        - "defaults"
+      blocked_packages: []     # Dev environment - less restrictive
+      require_password: false  # No password for dev
+    # Resource limits
+    max_gpu_per_instance: 1
+    max_memory_per_instance: "4Gi"
+
+  # vLLM Inference Service
+  vllm:
+    enabled: true
+    image: "vllm/vllm-openai:latest"
+    default_port: 8000
+    # Model cache location
+    model_cache: "/models"
+    # Supported quantization methods: awq, gptq, fp8, squeezellm
+    default_quantization: ""  # No quantization for dev
+    # Resource limits
+    max_gpu_per_instance: 1
+    max_model_len: 2048
+    tensor_parallel_size: 1
+
 task_lease_duration: "30m"
 heartbeat_interval: "1m"
@@ -48,3 +48,42 @@ task_lease_duration: "30m"
 heartbeat_interval: "1m"
 max_retries: 3
 graceful_timeout: "5m"
+
+# Plugin Configuration
+plugins:
+  # Jupyter Notebook/Lab Service
+  jupyter:
+    enabled: true
+    image: "quay.io/jupyter/base-notebook:latest"
+    default_port: 8888
+    # Security settings
+    security:
+      trusted_channels:
+        - "conda-forge"
+        - "defaults"
+        - "pytorch"
+      blocked_packages:
+        - "requests"
+        - "urllib3"
+        - "httpx"
+      require_password: true
+    # Resource limits (enforced by scheduler quota system)
+    max_gpu_per_instance: 1
+    max_memory_per_instance: "8Gi"
+
+  # vLLM Inference Service
+  vllm:
+    enabled: true
+    image: "vllm/vllm-openai:latest"
+    default_port: 8000
+    # Model cache location
+    model_cache: "/models"
+    # Supported quantization methods: awq, gptq, fp8, squeezellm
+    default_quantization: ""  # empty = no quantization
+    # Resource limits
+    max_gpu_per_instance: 4
+    max_model_len: 4096
+    # Environment variables passed to container
+    env:
+      - "HF_HOME=/models"
+      - "VLLM_WORKER_MULTIPROC_METHOD=spawn"
@@ -48,6 +48,46 @@ queue:
   backend: "redis"
   redis_url: "redis://localhost:6379/0"

+# Plugin Configuration
+plugins:
+  # Jupyter Notebook/Lab Service
+  jupyter:
+    enabled: true
+    image: "quay.io/jupyter/base-notebook:latest"
+    default_port: 8888
+    mode: "lab"  # "lab" or "notebook"
+    # Security settings
+    security:
+      trusted_channels:
+        - "conda-forge"
+        - "defaults"
+        - "pytorch"
+        - "nvidia"
+      blocked_packages:
+        - "requests"
+        - "urllib3"
+        - "httpx"
+        - "socket"
+        - "subprocess"
+      require_password: true
+    # Resource limits
+    max_gpu_per_instance: 1
+    max_memory_per_instance: "16Gi"
+
+  # vLLM Inference Service
+  vllm:
+    enabled: true
+    image: "vllm/vllm-openai:latest"
+    default_port: 8000
+    # Model cache location (should be on fast storage)
+    model_cache: "/var/lib/fetchml/models"
+    # Supported quantization methods: awq, gptq, fp8, squeezellm
+    default_quantization: ""
+    # Resource limits
+    max_gpu_per_instance: 2
+    max_model_len: 4096
+    tensor_parallel_size: 1
+
 # Snapshot store (optional)
 snapshot_store:
   enabled: false
@@ -45,3 +45,42 @@ task_lease_duration: "30m"
 heartbeat_interval: "1m"
 max_retries: 3
 graceful_timeout: "5m"
+
+# Plugin Configuration
+plugins:
+  # Jupyter Notebook/Lab Service
+  jupyter:
+    enabled: true
+    image: "quay.io/jupyter/base-notebook:latest"
+    default_port: 8888
+    mode: "lab"
+    # Security settings (strict for secure config)
+    security:
+      trusted_channels:
+        - "conda-forge"
+        - "defaults"
+      blocked_packages:
+        - "requests"
+        - "urllib3"
+        - "httpx"
+        - "socket"
+        - "subprocess"
+        - "os.system"
+      require_password: true
+    # Resource limits
+    max_gpu_per_instance: 1
+    max_memory_per_instance: "8Gi"
+
+  # vLLM Inference Service
+  vllm:
+    enabled: true
+    image: "vllm/vllm-openai:latest"
+    default_port: 8000
+    # Model cache location
+    model_cache: "/models"
+    # Supported quantization methods: awq, gptq, fp8, squeezellm
+    default_quantization: ""
+    # Resource limits
+    max_gpu_per_instance: 1
+    max_model_len: 4096
+    tensor_parallel_size: 1
@@ -45,3 +45,34 @@ podman_memory = "16g"
 [metrics]
 enabled = true
 listen_addr = ":9100"
+
+# Plugin Configuration
+[plugins]
+
+[plugins.jupyter]
+enabled = true
+image = "quay.io/jupyter/base-notebook:latest"
+default_port = 8888
+mode = "lab"
+max_gpu_per_instance = 1
+max_memory_per_instance = "8Gi"
+
+[plugins.jupyter.security]
+require_password = true
+trusted_channels = ["conda-forge", "defaults", "pytorch"]
+blocked_packages = ["requests", "urllib3", "httpx"]
+
+[plugins.vllm]
+enabled = true
+image = "vllm/vllm-openai:latest"
+default_port = 8000
+model_cache = "/models"
+default_quantization = ""  # Options: awq, gptq, fp8, squeezellm
+max_gpu_per_instance = 2
+max_model_len = 4096
+tensor_parallel_size = 1
+
+# Environment variables for vLLM
+[plugins.vllm.env]
+HF_HOME = "/models"
+VLLM_WORKER_MULTIPROC_METHOD = "spawn"
@@ -110,6 +110,36 @@ TLS_KEY_PATH=/app/ssl/key.pem
 | Prometheus | 9090 | - | - |
 | Grafana | 3000 | - | - |
 | Loki | 3100 | - | - |
+| JupyterLab | 8888* | 8888* | - |
+| vLLM | 8000* | 8000* | - |
+
+*Plugin service ports are dynamically allocated from the 8000-9000 range by the scheduler.
+
+## Plugin Services
+
+The deployment configurations include support for interactive ML services:
+
+### Jupyter Notebook/Lab
+- **Image**: `quay.io/jupyter/base-notebook:latest`
+- **Security**: Trusted channels (conda-forge, defaults), blocked packages (HTTP clients)
+- **Resources**: Configurable GPU/memory limits
+- **Access**: Via scheduler-assigned port (8000-9000 range)
+
+### vLLM Inference
+- **Image**: `vllm/vllm-openai:latest`
+- **Features**: OpenAI-compatible API, quantization support (AWQ, GPTQ, FP8)
+- **Model Cache**: Configurable path for model storage
+- **Resources**: Multi-GPU tensor parallelism support
+
+## Scheduler GPU Quotas
+
+The scheduler supports GPU quota management for plugin services:
+- **Global Limit**: Total GPUs across all plugins
+- **Per-User Limits**: GPU and service count per user
+- **Per-Plugin Limits**: vLLM- and Jupyter-specific limits
+- **User Overrides**: Special permissions for admins/researchers
+
+See `configs/scheduler/scheduler.yaml.example` for quota configuration; a sketch of how these settings map onto the scheduler's Go types follows below.
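As an illustrative aside (not part of this diff), the quota settings above correspond to the `PluginQuotaConfig` type added in `internal/scheduler/plugin_quota.go`. The sketch below constructs the same limits in Go; the type, field, and function names come from that file, while the values mirror `scheduler.yaml.example` and the wiring into a real binary is assumed.

```go
package main

import (
	"fmt"

	"github.com/jfraeys/fetch_ml/internal/scheduler"
)

func main() {
	// Mirrors the plugin_quota block in scheduler.yaml.example.
	cfg := scheduler.PluginQuotaConfig{
		Enabled:         true, // the shipped example defaults to false
		TotalGPUs:       16,   // global limit across all plugins
		PerUserGPUs:     4,    // default per-user GPU limit
		PerUserServices: 2,    // default per-user service count
		PerPluginLimits: map[string]scheduler.PluginLimit{
			"vllm":    {MaxGPUs: 8, MaxServices: 4},
			"jupyter": {MaxGPUs: 4, MaxServices: 10},
		},
		UserOverrides: map[string]scheduler.UserLimit{
			"admin":       {MaxGPUs: 8, MaxServices: 5, AllowedPlugins: []string{"jupyter", "vllm"}},
			"researcher1": {MaxGPUs: 2, MaxServices: 1},
		},
	}

	m := scheduler.NewPluginQuotaManager(cfg)
	// researcher1 is capped at 2 GPUs, so a 4-GPU request is rejected.
	if err := m.CheckQuota("researcher1", "vllm", 4); err != nil {
		fmt.Println(err)
	}
}
```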

 ## Monitoring

@@ -122,3 +152,4 @@ TLS_KEY_PATH=/app/ssl/key.pem
 - If you need HTTPS externally, terminate TLS at a reverse proxy.
 - API keys should be managed via environment variables.
 - Database credentials should use secrets management in production.
+- **HIPAA deployments**: Plugins are disabled by default for compliance.
@@ -29,3 +29,30 @@ max_artifact_total_bytes: 1073741824 # 1GB

 # Provenance (disabled in dev for speed)
 provenance_best_effort: false
+
+# Plugin Configuration (development mode)
+plugins:
+  # Jupyter Notebook/Lab Service
+  jupyter:
+    enabled: true
+    image: "quay.io/jupyter/base-notebook:latest"
+    default_port: 8888
+    mode: "lab"
+    security:
+      trusted_channels:
+        - "conda-forge"
+        - "defaults"
+      blocked_packages: []     # No restrictions in dev
+      require_password: false  # No password for dev
+    max_gpu_per_instance: 1
+    max_memory_per_instance: "4Gi"
+
+  # vLLM Inference Service
+  vllm:
+    enabled: true
+    image: "vllm/vllm-openai:latest"
+    default_port: 8000
+    model_cache: "/tmp/models"  # Temp location for dev
+    default_quantization: ""    # No quantization for dev
+    max_gpu_per_instance: 1
+    max_model_len: 2048
@@ -51,3 +51,12 @@ ssh_key: ${SSH_KEY_PATH}

 # Config hash computation enabled (required for audit)
 # This is automatically computed by Validate()
+
+# Plugin Configuration (DISABLED for HIPAA compliance)
+# Jupyter and vLLM services are disabled in HIPAA mode to ensure
+# no unauthorized network access or data processing
+plugins:
+  jupyter:
+    enabled: false  # Disabled: HIPAA requires strict network isolation
+  vllm:
+    enabled: false  # Disabled: external model downloads violate PHI controls
@@ -33,3 +33,32 @@ max_artifact_total_bytes: 536870912 # 512MB

 # Provenance (enabled)
 provenance_best_effort: true
+
+# Plugin Configuration
+plugins:
+  # Jupyter Notebook/Lab Service
+  jupyter:
+    enabled: true
+    image: "quay.io/jupyter/base-notebook:latest"
+    default_port: 8888
+    mode: "lab"
+    security:
+      trusted_channels:
+        - "conda-forge"
+        - "defaults"
+      blocked_packages:
+        - "requests"
+        - "urllib3"
+      require_password: true
+    max_gpu_per_instance: 1
+    max_memory_per_instance: "8Gi"
+
+  # vLLM Inference Service
+  vllm:
+    enabled: true
+    image: "vllm/vllm-openai:latest"
+    default_port: 8000
+    model_cache: "/models"
+    default_quantization: ""
+    max_gpu_per_instance: 1
+    max_model_len: 4096
@@ -33,11 +33,13 @@ type SchedulerHub struct {
 	reservations      map[string]*Reservation
 	multiNodePending  map[string]*MultiNodeJob
 	pendingAcceptance map[string]*JobAssignment
+	runningTasks      map[string]*Task // Track assigned+accepted tasks
 	state             *StateStore
 	starvation        *StarvationTracker
 	metrics           *SchedulerMetrics
 	auditor           *audit.Logger
 	tokenValidator    *TokenValidator
+	quotaManager      *PluginQuotaManager // NEW: plugin GPU quota manager
 	config            HubConfig
 	ctx               context.Context
 	cancel            context.CancelFunc

@@ -59,6 +61,7 @@ type HubConfig struct {
 	AcceptanceTimeoutSecs int
 	LocalMode             bool
 	WorkerTokens          map[string]string // token -> workerID
+	PluginQuota           PluginQuotaConfig // NEW: plugin GPU quota configuration
 }

 // WorkerConn represents a connected worker

@@ -109,6 +112,7 @@ type JobAssignment struct {
 	AssignedAt         time.Time
 	AcceptanceDeadline time.Time
 	Accepted           bool
+	Task               *Task // Reference to the task (removed from queue)
 }

 // StarvationTracker monitors long-waiting jobs

@@ -154,6 +158,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
 		reservations:      make(map[string]*Reservation),
 		multiNodePending:  make(map[string]*MultiNodeJob),
 		pendingAcceptance: make(map[string]*JobAssignment),
+		runningTasks:      make(map[string]*Task),
 		state:             state,
 		starvation: &StarvationTracker{
 			threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,

@@ -163,6 +168,7 @@ func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
 		},
 		auditor:        auditor,
 		tokenValidator: NewTokenValidator(cfg.WorkerTokens),
+		quotaManager:   NewPluginQuotaManager(cfg.PluginQuota), // NEW: initialize quota manager
 		config:         cfg,
 		ctx:            ctx,
 		cancel:         cancel,
@@ -431,9 +437,6 @@ func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task {
 }

 func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
-	h.mu.RLock()
-	defer h.mu.RUnlock()
-
 	for _, res := range h.reservations {
 		if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
 			if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {

@@ -449,7 +452,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
 	h.batchQueue.Remove(task.ID)
 	h.serviceQueue.Remove(task.ID)

-	// Track pending acceptance
+	// Track pending acceptance with task reference
 	h.mu.Lock()
 	h.pendingAcceptance[task.ID] = &JobAssignment{
 		TaskID: task.ID,

@@ -457,6 +460,7 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
 		AssignedAt:         time.Now(),
 		AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
 		Accepted:           false,
+		Task:               task, // Store reference since removed from queue
 	}
 	h.mu.Unlock()

@@ -473,12 +477,31 @@ func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
 	}
 }

-func (h *SchedulerHub) handleJobAccepted(_, taskID string) {
+func (h *SchedulerHub) handleJobAccepted(workerID, taskID string) {
 	h.mu.Lock()
 	defer h.mu.Unlock()

 	if assignment, ok := h.pendingAcceptance[taskID]; ok {
 		assignment.Accepted = true
+
+		// Track as running task
+		task := assignment.Task
+		if task != nil {
+			task.Status = "running"
+			task.WorkerID = workerID
+			h.runningTasks[taskID] = task
+		}
+
+		// NEW: Record quota usage for service jobs
+		if task != nil && task.Spec.Type == JobTypeService {
+			if h.quotaManager != nil {
+				pluginName := task.Spec.Metadata["plugin_name"]
+				if pluginName == "" {
+					pluginName = "default"
+				}
+				h.quotaManager.RecordUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
+			}
+		}
 	}
 }
@@ -486,7 +509,19 @@ func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload) {
 	h.mu.Lock()
 	defer h.mu.Unlock()

+	// NEW: Release quota usage for service jobs before deleting pending acceptance
+	if task := h.runningTasks[result.TaskID]; task != nil && task.Spec.Type == JobTypeService {
+		if h.quotaManager != nil {
+			pluginName := task.Spec.Metadata["plugin_name"]
+			if pluginName == "" {
+				pluginName = "default"
+			}
+			h.quotaManager.ReleaseUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
+		}
+	}
+
 	delete(h.pendingAcceptance, result.TaskID)
+	delete(h.runningTasks, result.TaskID)

 	eventType := EventJobCompleted
 	switch result.State {

@@ -519,7 +554,10 @@ func (h *SchedulerHub) checkAcceptanceTimeouts() {
 	h.mu.Lock()
 	for taskID, a := range h.pendingAcceptance {
 		if !a.Accepted && time.Now().After(a.AcceptanceDeadline) {
-			h.batchQueue.Add(h.getTask(taskID))
+			if a.Task != nil {
+				a.Task.Status = "queued"
+				h.batchQueue.Add(a.Task)
+			}
 			delete(h.pendingAcceptance, taskID)
 			if wc, ok := h.workers[a.WorkerID]; ok {
 				wc.slots = SlotStatus{}

@@ -572,22 +610,31 @@ func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) {
 	st.mu.Lock()
 	defer st.mu.Unlock()

+	// First check which tasks need reservation under h.mu.RLock
+	tasksToReserve := make([]*Task, 0)
+	h.mu.RLock()
 	for _, task := range h.batchQueue.Items() {
-		if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservation(h, task.ID) {
-			h.mu.Lock()
+		if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservationLocked(h, task.ID) {
+			tasksToReserve = append(tasksToReserve, task)
+		}
+	}
+	h.mu.RUnlock()
+
+	// Now acquire Lock to add reservations
+	if len(tasksToReserve) > 0 {
+		h.mu.Lock()
+		for _, task := range tasksToReserve {
 			h.reservations[task.ID] = &Reservation{
 				TaskID:    task.ID,
 				GPUCount:  task.Spec.GPUCount,
 				CreatedAt: time.Now(),
 			}
-			h.mu.Unlock()
 		}
+		h.mu.Unlock()
 	}
 }

-func (st *StarvationTracker) hasReservation(h *SchedulerHub, taskID string) bool {
-	h.mu.RLock()
-	defer h.mu.RUnlock()
+func (st *StarvationTracker) hasReservationLocked(h *SchedulerHub, taskID string) bool {
 	_, exists := h.reservations[taskID]
 	return exists
 }
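The CheckAndReserve rewrite above replaces per-item lock/unlock inside the scan (plus a helper that re-acquired the read lock for every check) with a two-phase pattern: collect candidates under the read lock, then take the write lock once for all mutations. A self-contained sketch of that collect-then-commit pattern, illustrative only; the `registry` type and names below are invented, not from this PR:

```go
package main

import (
	"fmt"
	"sync"
)

// registry demonstrates the collect-then-commit locking pattern used in
// the CheckAndReserve rewrite: scan read-only first, write once afterwards.
type registry struct {
	mu    sync.RWMutex
	items map[string]int
}

func (r *registry) reserveMissing(candidates []string) {
	// Phase 1: read-only scan to find what needs writing.
	toAdd := make([]string, 0)
	r.mu.RLock()
	for _, c := range candidates {
		if _, ok := r.items[c]; !ok {
			toAdd = append(toAdd, c)
		}
	}
	r.mu.RUnlock()

	// Phase 2: a single write-lock section covers all mutations.
	if len(toAdd) > 0 {
		r.mu.Lock()
		for _, c := range toAdd {
			r.items[c] = 1 // re-checking here would close the check-then-act window entirely
		}
		r.mu.Unlock()
	}
}

func main() {
	r := &registry{items: map[string]int{"a": 1}}
	r.reserveMissing([]string{"a", "b"})
	fmt.Println(r.items) // map[a:1 b:1]
}
```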
@@ -605,6 +652,17 @@ func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
 		return fmt.Errorf("job ID is required")
 	}

+	// NEW: Check plugin quotas for service jobs
+	if spec.Type == JobTypeService && h.quotaManager != nil {
+		pluginName := spec.Metadata["plugin_name"]
+		if pluginName == "" {
+			pluginName = "default"
+		}
+		if err := h.quotaManager.CheckQuota(spec.UserID, pluginName, spec.GPUCount); err != nil {
+			return fmt.Errorf("quota exceeded: %w", err)
+		}
+	}
+
 	task := &Task{
 		ID:   spec.ID,
 		Spec: spec,

@@ -639,7 +697,11 @@ func (h *SchedulerHub) getTask(taskID string) *Task {
 	if t != nil {
 		return t
 	}
-	return h.serviceQueue.Get(taskID)
+	t = h.serviceQueue.Get(taskID)
+	if t != nil {
+		return t
+	}
+	return h.runningTasks[taskID]
 }

 func (h *SchedulerHub) restoreJob(ev StateEvent) {

@@ -718,7 +780,7 @@ func (h *SchedulerHub) reconcileOrphans() {
 		if assignment.Accepted {
 			// Job was accepted but worker is gone (not in h.workers)
 			if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected {
-				task := h.getTask(taskID)
+				task := assignment.Task
 				if task != nil {
 					task.Status = "orphaned"
 					h.batchQueue.Add(task)

@@ -776,7 +838,7 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
 	}

 	if msg.Type == MsgMetricsRequest {
-		metrics := h.getMetricsPayload()
+		metrics := h.GetMetricsPayload()
 		conn.WriteJSON(Message{
 			Type:    MsgMetricsResponse,
 			Payload: mustMarshal(metrics),

@@ -785,8 +847,8 @@ func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
 	}
 }

-// getMetricsPayload returns current metrics as a map
-func (h *SchedulerHub) getMetricsPayload() map[string]any {
+// GetMetricsPayload returns current metrics as a map (public API)
+func (h *SchedulerHub) GetMetricsPayload() map[string]any {
 	h.metrics.mu.RLock()
 	defer h.metrics.mu.RUnlock()

@@ -901,7 +963,7 @@ func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) {

 // buildRankedSpec creates a job spec with rank-specific template variables resolved
 func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
-	// Clone the spec and add rank info to metadata
+	// Clone the spec and add rank info to metadata and env
 	spec := task.Spec
 	spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3)
 	for k, v := range task.Spec.Metadata {

@@ -910,6 +972,14 @@ func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
 	spec.Metadata["HEAD_ADDR"] = headAddr
 	spec.Metadata["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
 	spec.Metadata["NODE_RANK"] = fmt.Sprintf("%d", rank)
+
+	// Also set in Env for job runtime
+	if spec.Env == nil {
+		spec.Env = make(map[string]string)
+	}
+	spec.Env["HEAD_ADDR"] = headAddr
+	spec.Env["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
+	spec.Env["NODE_RANK"] = fmt.Sprintf("%d", rank)
 	return spec
 }
internal/scheduler/plugin_quota.go (new file, +287)

@@ -0,0 +1,287 @@
package scheduler

import (
	"fmt"
	"sync"
)

// PluginQuotaConfig defines GPU limits for plugins.
type PluginQuotaConfig struct {
	Enabled         bool                   // Master switch for quota enforcement
	TotalGPUs       int                    // Global GPU limit across all plugins
	PerUserGPUs     int                    // Default per-user GPU limit
	PerUserServices int                    // Default per-user service count limit
	PerPluginLimits map[string]PluginLimit // Plugin-specific overrides
	UserOverrides   map[string]UserLimit   // Per-user overrides
}

// PluginLimit defines limits for a specific plugin.
type PluginLimit struct {
	MaxGPUs     int
	MaxServices int
}

// UserLimit defines per-user override limits.
type UserLimit struct {
	MaxGPUs        int
	MaxServices    int
	AllowedPlugins []string // Empty = all plugins allowed
}

// PluginUsage tracks GPU and service count for a user-plugin combination.
type PluginUsage struct {
	GPUs     int
	Services int
}

// PluginQuotaManager tracks active usage and enforces quotas.
type PluginQuotaManager struct {
	config      PluginQuotaConfig
	mu          sync.RWMutex
	usage       map[string]map[string]PluginUsage // userID -> pluginName -> usage
	pluginTotal map[string]int                    // pluginName -> total GPUs in use
	totalGPUs   int                               // global total GPUs in use
}

// NewPluginQuotaManager creates a new quota manager with the given configuration.
func NewPluginQuotaManager(config PluginQuotaConfig) *PluginQuotaManager {
	return &PluginQuotaManager{
		config:      config,
		usage:       make(map[string]map[string]PluginUsage),
		pluginTotal: make(map[string]int),
		totalGPUs:   0,
	}
}

// CheckQuota validates if a job can be submitted without exceeding limits.
// Returns nil if the job is allowed, or an error describing which limit would be exceeded.
func (m *PluginQuotaManager) CheckQuota(userID, pluginName string, gpuCount int) error {
	if !m.config.Enabled {
		return nil
	}

	if userID == "" {
		userID = "anonymous"
	}
	if pluginName == "" {
		pluginName = "default"
	}

	m.mu.RLock()
	defer m.mu.RUnlock()

	// Get user limits (with overrides)
	userLimit := m.getUserLimit(userID)

	// Check if user is allowed to use this plugin
	if len(userLimit.AllowedPlugins) > 0 {
		found := false
		for _, p := range userLimit.AllowedPlugins {
			if p == pluginName {
				found = true
				break
			}
		}
		if !found {
			return fmt.Errorf("user %s is not allowed to use plugin %s", userID, pluginName)
		}
	}

	// Check plugin-specific limits
	pluginLimit, hasPluginLimit := m.config.PerPluginLimits[pluginName]
	if hasPluginLimit {
		if pluginLimit.MaxGPUs > 0 && m.pluginTotal[pluginName]+gpuCount > pluginLimit.MaxGPUs {
			return fmt.Errorf("plugin %s GPU limit exceeded: %d requested, %d available of %d total",
				pluginName, gpuCount, pluginLimit.MaxGPUs-m.pluginTotal[pluginName], pluginLimit.MaxGPUs)
		}
		if pluginLimit.MaxServices > 0 {
			// Services limit is across all users for this plugin
			totalServices := 0
			for _, u := range m.usage {
				if p, ok := u[pluginName]; ok {
					totalServices += p.Services
				}
			}
			if totalServices+1 > pluginLimit.MaxServices {
				return fmt.Errorf("plugin %s service limit exceeded", pluginName)
			}
		}
	}

	// Check per-user limits
	effectiveUserGPUs := userLimit.MaxGPUs
	if effectiveUserGPUs == 0 {
		effectiveUserGPUs = m.config.PerUserGPUs
	}
	effectiveUserServices := userLimit.MaxServices
	if effectiveUserServices == 0 {
		effectiveUserServices = m.config.PerUserServices
	}

	// Calculate total user usage across all plugins
	totalUserGPUs := 0
	totalUserServices := 0
	for _, p := range m.usage[userID] {
		totalUserGPUs += p.GPUs
		totalUserServices += p.Services
	}

	if effectiveUserGPUs > 0 && totalUserGPUs+gpuCount > effectiveUserGPUs {
		return fmt.Errorf("user %s GPU limit exceeded: %d requested, %d available of %d total",
			userID, gpuCount, effectiveUserGPUs-totalUserGPUs, effectiveUserGPUs)
	}

	if effectiveUserServices > 0 && totalUserServices+1 > effectiveUserServices {
		return fmt.Errorf("user %s service limit exceeded: %d services of %d allowed",
			userID, totalUserServices+1, effectiveUserServices)
	}

	// Check global total GPU limit
	if m.config.TotalGPUs > 0 && m.totalGPUs+gpuCount > m.config.TotalGPUs {
		return fmt.Errorf("global GPU limit exceeded: %d requested, %d available of %d total",
			gpuCount, m.config.TotalGPUs-m.totalGPUs, m.config.TotalGPUs)
	}

	return nil
}

// RecordUsage increments usage counters when a job starts.
func (m *PluginQuotaManager) RecordUsage(userID, pluginName string, gpuCount int) {
	if !m.config.Enabled {
		return
	}

	if userID == "" {
		userID = "anonymous"
	}
	if pluginName == "" {
		pluginName = "default"
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	userPlugins, ok := m.usage[userID]
	if !ok {
		userPlugins = make(map[string]PluginUsage)
		m.usage[userID] = userPlugins
	}

	usage := userPlugins[pluginName]
	usage.GPUs += gpuCount
	usage.Services++
	userPlugins[pluginName] = usage

	m.pluginTotal[pluginName] += gpuCount
	m.totalGPUs += gpuCount
}

// ReleaseUsage decrements usage counters when a job stops.
func (m *PluginQuotaManager) ReleaseUsage(userID, pluginName string, gpuCount int) {
	if !m.config.Enabled {
		return
	}

	if userID == "" {
		userID = "anonymous"
	}
	if pluginName == "" {
		pluginName = "default"
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	userPlugins, ok := m.usage[userID]
	if !ok {
		return
	}

	usage := userPlugins[pluginName]
	usage.GPUs -= gpuCount
	usage.Services--

	if usage.GPUs < 0 {
		usage.GPUs = 0
	}
	if usage.Services < 0 {
		usage.Services = 0
	}

	if usage.GPUs == 0 && usage.Services == 0 {
		delete(userPlugins, pluginName)
	} else {
		userPlugins[pluginName] = usage
	}

	if len(userPlugins) == 0 {
		delete(m.usage, userID)
	}

	m.pluginTotal[pluginName] -= gpuCount
	if m.pluginTotal[pluginName] < 0 {
		m.pluginTotal[pluginName] = 0
	}

	m.totalGPUs -= gpuCount
	if m.totalGPUs < 0 {
		m.totalGPUs = 0
	}
}

// GetUsage returns current usage for a user across all plugins.
func (m *PluginQuotaManager) GetUsage(userID string) (map[string]PluginUsage, int) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if userID == "" {
		userID = "anonymous"
	}

	result := make(map[string]PluginUsage)
	totalGPUs := 0

	if userPlugins, ok := m.usage[userID]; ok {
		for plugin, usage := range userPlugins {
			result[plugin] = usage
			totalGPUs += usage.GPUs
		}
	}

	return result, totalGPUs
}

// GetGlobalUsage returns global GPU usage across all users and plugins.
func (m *PluginQuotaManager) GetGlobalUsage() (int, map[string]int) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	pluginTotals := make(map[string]int, len(m.pluginTotal))
	for k, v := range m.pluginTotal {
		pluginTotals[k] = v
	}

	return m.totalGPUs, pluginTotals
}

// getUserLimit returns the effective limits for a user, applying overrides.
func (m *PluginQuotaManager) getUserLimit(userID string) UserLimit {
	if override, ok := m.config.UserOverrides[userID]; ok {
		return override
	}
	return UserLimit{
		MaxGPUs:     m.config.PerUserGPUs,
		MaxServices: m.config.PerUserServices,
	}
}

// getUsageLocked returns the current usage for a user-plugin combination.
// Must be called with read lock held.
func (m *PluginQuotaManager) getUsageLocked(userID, pluginName string) PluginUsage {
	if userPlugins, ok := m.usage[userID]; ok {
		if usage, ok := userPlugins[pluginName]; ok {
			return usage
		}
	}
	return PluginUsage{}
}
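To make the lifecycle concrete, here is a small illustrative sketch (not part of the diff) of the check → record → release flow this manager expects. The calls match the functions defined above; the user and plugin names are invented.

```go
package main

import (
	"fmt"

	"github.com/jfraeys/fetch_ml/internal/scheduler"
)

func main() {
	m := scheduler.NewPluginQuotaManager(scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   4,
		PerUserGPUs: 2,
	})

	// 1. Admission: reject the job up front if it would exceed a limit.
	if err := m.CheckQuota("alice", "jupyter", 2); err != nil {
		fmt.Println("rejected:", err)
		return
	}

	// 2. The job was accepted by a worker: count it against the quotas.
	m.RecordUsage("alice", "jupyter", 2)

	// A second 2-GPU request from the same user now hits her 2-GPU limit.
	fmt.Println(m.CheckQuota("alice", "jupyter", 2))

	// 3. The job finished: release the GPUs and the service slot.
	m.ReleaseUsage("alice", "jupyter", 2)
}
```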
@@ -100,6 +100,7 @@ type JobSpec struct {
 	ID       string  `json:"id"`
 	Type     JobType `json:"type"` // "batch" | "service"
 	SlotPool string  `json:"slot_pool"`
+	UserID   string  `json:"user_id,omitempty"` // NEW: for per-user quota tracking

 	GPUCount int    `json:"gpu_count"`
 	GPUType  string `json:"gpu_type,omitempty"`
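For illustration (not from the PR), a service-type JobSpec that exercises the new `UserID` field together with the `plugin_name` metadata key the hub reads during quota checks might look like the sketch below. The field names come from the diff; the values, including the `SlotPool` string, are invented.

```go
package main

import (
	"fmt"

	"github.com/jfraeys/fetch_ml/internal/scheduler"
)

func main() {
	// A 1-GPU Jupyter service job for user "alice"; SubmitJob runs this
	// through PluginQuotaManager.CheckQuota before queueing it.
	spec := scheduler.JobSpec{
		ID:       "svc-jupyter-001",
		Type:     scheduler.JobTypeService,
		SlotPool: "default",
		UserID:   "alice",
		GPUCount: 1,
		Metadata: map[string]string{"plugin_name": "jupyter"},
	}
	fmt.Printf("%+v\n", spec)
}
```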
@@ -60,6 +60,7 @@ func TestMultiNodeGangAllocation(t *testing.T) {
 			ID: w.id,
 			Capabilities: scheduler.WorkerCapabilities{
 				GPUCount: 0,
+				Hostname: "localhost",
 			},
 		}),
 	})

@@ -172,10 +173,8 @@ func TestServiceLifecycle(t *testing.T) {

 	// Send job accepted
 	conn.WriteJSON(scheduler.Message{
-		Type: scheduler.MsgJobAccepted,
-		Payload: mustMarshal(map[string]string{
-			"task_id": jobID,
-		}),
+		Type:    scheduler.MsgJobAccepted,
+		Payload: mustMarshal(jobID),
 	})

 	// Send periodic health updates
@@ -307,6 +307,7 @@ func startFakeWorkers(
 		go func(workerID string) {
 			defer wg.Done()
 			for {
+				// Check for cancellation before getting next task
 				select {
 				case <-ctx.Done():
 					return

@@ -324,6 +325,7 @@ func startFakeWorkers(
 					continue
 				}

+				// Process task - finish even if context is cancelled
 				started := time.Now()
 				completed := started.Add(10 * time.Millisecond)

@@ -338,7 +340,16 @@ func startFakeWorkers(
 					continue
 				}

-				doneCh <- task.JobName
+				// Only send to doneCh if we successfully completed
+				select {
+				case doneCh <- task.JobName:
+				case <-ctx.Done():
+					// Context cancelled but we still completed the task - try once more without blocking
+					select {
+					case doneCh <- task.JobName:
+					default:
+					}
+				}
 			}
 		}(fmt.Sprintf("worker-%d", w))
 	}
tests/unit/scheduler/plugin_quota_test.go (new file, +385)

@@ -0,0 +1,385 @@
package scheduler_test

import (
	"testing"

	"github.com/jfraeys/fetch_ml/internal/scheduler"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestPluginQuotaManager_CheckQuota_Disabled(t *testing.T) {
	// When quota is disabled, all jobs should pass
	config := scheduler.PluginQuotaConfig{
		Enabled:   false,
		TotalGPUs: 1, // Set a low limit that would fail if enabled
	}
	m := scheduler.NewPluginQuotaManager(config)

	err := m.CheckQuota("user1", "plugin1", 100)
	assert.NoError(t, err)
}

func TestPluginQuotaManager_CheckQuota_GlobalLimit(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:   true,
		TotalGPUs: 4,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// First job should succeed
	err := m.CheckQuota("user1", "plugin1", 2)
	require.NoError(t, err)

	// Record the usage
	m.RecordUsage("user1", "plugin1", 2)

	// Second job should succeed (2+2=4, within limit)
	err = m.CheckQuota("user2", "plugin2", 2)
	require.NoError(t, err)
	m.RecordUsage("user2", "plugin2", 2)

	// Third job should fail (would exceed global limit)
	err = m.CheckQuota("user3", "plugin3", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "global GPU limit exceeded")
}

func TestPluginQuotaManager_CheckQuota_PerUserGPULimit(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 3,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// User1: first job should succeed
	err := m.CheckQuota("user1", "plugin1", 2)
	require.NoError(t, err)
	m.RecordUsage("user1", "plugin1", 2)

	// User1: second job should succeed (2+1=3, at limit)
	err = m.CheckQuota("user1", "plugin2", 1)
	require.NoError(t, err)
	m.RecordUsage("user1", "plugin2", 1)

	// User1: third job should fail (would exceed per-user limit)
	err = m.CheckQuota("user1", "plugin3", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "user user1 GPU limit exceeded")

	// User2: job should succeed (different user)
	err = m.CheckQuota("user2", "plugin1", 3)
	assert.NoError(t, err)
}

func TestPluginQuotaManager_CheckQuota_PerUserServiceLimit(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     10,
		PerUserServices: 2,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// User1: first service should succeed
	err := m.CheckQuota("user1", "plugin1", 1)
	require.NoError(t, err)
	m.RecordUsage("user1", "plugin1", 1)

	// User1: second service should succeed
	err = m.CheckQuota("user1", "plugin2", 1)
	require.NoError(t, err)
	m.RecordUsage("user1", "plugin2", 1)

	// User1: third service should fail (would exceed service count limit)
	err = m.CheckQuota("user1", "plugin3", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "user user1 service limit exceeded")
}

func TestPluginQuotaManager_CheckQuota_UserOverride(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     2,
		PerUserServices: 2,
		UserOverrides: map[string]scheduler.UserLimit{
			"vip-user": {
				MaxGPUs:     5,
				MaxServices: 10,
			},
		},
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Regular user: limited by default
	err := m.CheckQuota("regular", "plugin1", 3)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "regular GPU limit exceeded")

	// VIP user: has higher limit
	err = m.CheckQuota("vip-user", "plugin1", 4)
	require.NoError(t, err)
	m.RecordUsage("vip-user", "plugin1", 4)

	// VIP user: still within limit
	err = m.CheckQuota("vip-user", "plugin2", 1)
	assert.NoError(t, err)
}

func TestPluginQuotaManager_CheckQuota_PluginSpecificLimit(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 10,
		PerPluginLimits: map[string]scheduler.PluginLimit{
			"jupyter": {
				MaxGPUs:     3,
				MaxServices: 2,
			},
			"vllm": {
				MaxGPUs:     8,
				MaxServices: 4,
			},
		},
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Jupyter: within plugin GPU limit
	err := m.CheckQuota("user1", "jupyter", 2)
	require.NoError(t, err)
	m.RecordUsage("user1", "jupyter", 2)

	// Jupyter: exceed plugin GPU limit (but within global and user limits)
	err = m.CheckQuota("user2", "jupyter", 2)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "plugin jupyter GPU limit exceeded")

	// vLLM: within its higher limit
	err = m.CheckQuota("user1", "vllm", 4)
	assert.NoError(t, err)
}

func TestPluginQuotaManager_CheckQuota_PluginServiceLimit(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     10,
		PerUserServices: 10,
		PerPluginLimits: map[string]scheduler.PluginLimit{
			"jupyter": {
				MaxGPUs:     10,
				MaxServices: 2, // Only 2 jupyter services total
			},
		},
	}
	m := scheduler.NewPluginQuotaManager(config)

	// First jupyter service
	err := m.CheckQuota("user1", "jupyter", 1)
	require.NoError(t, err)
	m.RecordUsage("user1", "jupyter", 1)

	// Second jupyter service (different user)
	err = m.CheckQuota("user2", "jupyter", 1)
	require.NoError(t, err)
	m.RecordUsage("user2", "jupyter", 1)

	// Third jupyter service should fail (plugin service limit reached)
	err = m.CheckQuota("user3", "jupyter", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "plugin jupyter service limit exceeded")
}

func TestPluginQuotaManager_CheckQuota_AllowedPlugins(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 10,
		UserOverrides: map[string]scheduler.UserLimit{
			"restricted-user": {
				MaxGPUs:        5,
				MaxServices:    5,
				AllowedPlugins: []string{"jupyter"},
			},
		},
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Restricted user can use allowed plugin
	err := m.CheckQuota("restricted-user", "jupyter", 2)
	assert.NoError(t, err)

	// Restricted user cannot use other plugins
	err = m.CheckQuota("restricted-user", "vllm", 2)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "not allowed to use plugin vllm")

	// Regular user can use any plugin
	err = m.CheckQuota("regular-user", "vllm", 2)
	assert.NoError(t, err)
}

func TestPluginQuotaManager_RecordAndReleaseUsage(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:     true,
		TotalGPUs:   10,
		PerUserGPUs: 5,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Record usage
	m.RecordUsage("user1", "jupyter", 2)
	m.RecordUsage("user1", "vllm", 1)
	m.RecordUsage("user2", "jupyter", 3)

	// Check usage tracking
	usage, totalGPUs := m.GetUsage("user1")
	assert.Equal(t, 2, usage["jupyter"].GPUs)
	assert.Equal(t, 1, usage["jupyter"].Services)
	assert.Equal(t, 1, usage["vllm"].GPUs)
	assert.Equal(t, 1, usage["vllm"].Services)
	assert.Equal(t, 3, totalGPUs)

	// Check global usage
	globalGPUs, pluginTotals := m.GetGlobalUsage()
	assert.Equal(t, 6, globalGPUs)
	assert.Equal(t, 5, pluginTotals["jupyter"]) // 2+3
	assert.Equal(t, 1, pluginTotals["vllm"])

	// Release usage
	m.ReleaseUsage("user1", "jupyter", 2)

	// Verify release
	usage, totalGPUs = m.GetUsage("user1")
	assert.Equal(t, 0, usage["jupyter"].GPUs)
	assert.Equal(t, 0, usage["jupyter"].Services)
	assert.Equal(t, 1, usage["vllm"].GPUs) // user1 still has vllm
	assert.Equal(t, 1, totalGPUs)          // only vllm remains for user1

	// Check global usage after release
	globalGPUs, pluginTotals = m.GetGlobalUsage()
	assert.Equal(t, 4, globalGPUs)
	assert.Equal(t, 3, pluginTotals["jupyter"]) // 3 from user2
	assert.Equal(t, 1, pluginTotals["vllm"])
}

func TestPluginQuotaManager_RecordUsage_Disabled(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:   false,
		TotalGPUs: 10,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Recording usage when disabled should not crash
	m.RecordUsage("user1", "plugin1", 5)

	// Usage should be empty (not tracked)
	usage, totalGPUs := m.GetUsage("user1")
	assert.Empty(t, usage)
	assert.Equal(t, 0, totalGPUs)
}

func TestPluginQuotaManager_ReleaseUsage_NonExistent(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:   true,
		TotalGPUs: 10,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Releasing non-existent usage should not crash or go negative
	m.ReleaseUsage("nonexistent", "plugin1", 5)

	// Global usage should remain 0
	globalGPUs, _ := m.GetGlobalUsage()
	assert.Equal(t, 0, globalGPUs)
}

func TestPluginQuotaManager_CheckQuota_AnonymousUser(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     2,
		PerUserServices: 2,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Empty userID should be treated as "anonymous"
	err := m.CheckQuota("", "plugin1", 2)
	require.NoError(t, err)
	m.RecordUsage("", "plugin1", 2)

	// Second request from anonymous should fail (at limit)
	err = m.CheckQuota("", "plugin1", 1)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "user anonymous GPU limit exceeded")
}

func TestPluginQuotaManager_CheckQuota_DefaultPlugin(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       10,
		PerUserGPUs:     5,
		PerUserServices: 5,
		PerPluginLimits: map[string]scheduler.PluginLimit{
			"default": {
				MaxGPUs:     2,
				MaxServices: 2,
			},
		},
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Empty plugin name should be treated as "default"
	err := m.CheckQuota("user1", "", 1)
	require.NoError(t, err)
	m.RecordUsage("user1", "", 1)

	// Exceed default plugin limit
	err = m.CheckQuota("user2", "", 2)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "plugin default GPU limit exceeded")
}

func TestPluginQuotaManager_ConcurrentAccess(t *testing.T) {
	config := scheduler.PluginQuotaConfig{
		Enabled:         true,
		TotalGPUs:       100,
		PerUserGPUs:     50,
		PerUserServices: 50,
	}
	m := scheduler.NewPluginQuotaManager(config)

	// Concurrently record usage from multiple goroutines
	done := make(chan bool, 10)
	for i := 0; i < 10; i++ {
		go func(idx int) {
			user := "user"
			if idx%2 == 0 {
				user = "user1"
			} else {
				user = "user2"
			}
			m.RecordUsage(user, "plugin1", 1)
			done <- true
		}(i)
	}

	// Wait for all goroutines
	for i := 0; i < 10; i++ {
		<-done
	}

	// Verify totals
	globalGPUs, _ := m.GetGlobalUsage()
	assert.Equal(t, 10, globalGPUs)

	usage1, _ := m.GetUsage("user1")
	assert.Equal(t, 5, usage1["plugin1"].GPUs)
	assert.Equal(t, 5, usage1["plugin1"].Services)

	usage2, _ := m.GetUsage("user2")
	assert.Equal(t, 5, usage2["plugin1"].GPUs)
	assert.Equal(t, 5, usage2["plugin1"].Services)
}