fetch_ml/internal/scheduler/hub.go
Jeremie Fraeys da104367d6
Some checks failed
Build Pipeline / Build Binaries (push) Failing after 1m59s
Build Pipeline / Build Docker Images (push) Has been skipped
Build Pipeline / Sign HIPAA Config (push) Has been skipped
Build Pipeline / Generate SLSA Provenance (push) Has been skipped
Checkout test / test (push) Successful in 5s
CI Pipeline / Test (ubuntu-latest on self-hosted) (push) Failing after 1s
CI Pipeline / Dev Compose Smoke Test (push) Has been skipped
CI Pipeline / Security Scan (push) Has been skipped
CI Pipeline / Test Scripts (push) Has been skipped
CI Pipeline / Test Native Libraries (push) Has been skipped
CI Pipeline / Native Library Build Matrix (push) Has been skipped
Documentation / build-and-publish (push) Failing after 35s
CI Pipeline / Trigger Build Workflow (push) Failing after 0s
Security Scan / Security Analysis (push) Has been cancelled
Security Scan / Native Library Security (push) Has been cancelled
Verification & Maintenance / V.1 - Schema Drift Detection (push) Has been cancelled
Verification & Maintenance / V.4 - Custom Go Vet Analyzers (push) Has been cancelled
Verification & Maintenance / V.7 - Audit Chain Integrity (push) Has been cancelled
Verification & Maintenance / V.6 - Extended Security Scanning (push) Has been cancelled
Verification & Maintenance / V.10 - OpenSSF Scorecard (push) Has been cancelled
Verification & Maintenance / Verification Summary (push) Has been cancelled
feat: add Plugin GPU Quota implementation and tests
- Add plugin_quota.go with GPU quota management for scheduler

- Update scheduler hub and protocol for plugin support

- Add comprehensive plugin quota unit tests

- Update gang service and WebSocket queue integration tests
2026-02-26 14:35:05 -05:00

1000 lines
27 KiB
Go

package scheduler
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"

	"github.com/jfraeys/fetch_ml/internal/audit"
)
// upgrader configures the WebSocket handshake used for all incoming worker
// and metrics connections.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// NOTE(review): accepting every Origin disables cross-origin protection
	// for browser clients — confirm this is restricted in production.
	CheckOrigin: func(r *http.Request) bool {
		return true // Allow all origins (configurable in production)
	},
}
// SchedulerHub manages worker connections and job scheduling
type SchedulerHub struct {
	mu                sync.RWMutex              // guards the maps below
	workers           map[string]*WorkerConn    // all connected workers, keyed by worker ID
	readyWorkers      map[string]*WorkerConn    // workers that asked for work but got none
	batchQueue        *PriorityQueue            // queued batch jobs
	serviceQueue      *PriorityQueue            // queued service jobs
	reservations      map[string]*Reservation   // GPU reservations for starving tasks, by task ID
	multiNodePending  map[string]*MultiNodeJob  // partially committed gang allocations, by job ID
	pendingAcceptance map[string]*JobAssignment // assignments awaiting worker acceptance, by task ID
	runningTasks      map[string]*Task          // Track assigned+accepted tasks
	state             *StateStore               // append-only event log for crash recovery
	starvation        *StarvationTracker        // flags long-queued jobs for reservation
	metrics           *SchedulerMetrics         // counters exposed via the metrics endpoints
	auditor           *audit.Logger             // NOTE(review): not referenced in this file — confirm usage elsewhere
	tokenValidator    *TokenValidator           // validates worker/metrics bearer tokens
	quotaManager      *PluginQuotaManager       // NEW: plugin GPU quota manager
	config            HubConfig
	ctx               context.Context    // cancelled by Stop; ends background loops
	cancel            context.CancelFunc // cancels ctx
	server            *http.Server       // WSS server (set by Start)
	listener          net.Listener       // bound listener (set by Start)
}
// HubConfig holds the scheduler hub's configuration.
type HubConfig struct {
	BindAddr                string  // TCP address the WSS server listens on
	CertFile                string  // TLS certificate path (optional; see AutoGenerateCerts)
	KeyFile                 string  // TLS private key path
	AutoGenerateCerts       bool    // generate a self-signed cert into StateDir when cert/key unset
	StateDir                string  // directory for the state store and generated certs
	DefaultBatchSlots       int     // NOTE(review): not read in this file — confirm consumer
	DefaultServiceSlots     int     // NOTE(review): not read in this file — confirm consumer
	StarvationThresholdMins float64 // queue age (minutes) before a job gets a GPU reservation
	PriorityAgingRate       float64 // queue priority aging rate; 0 falls back to 0.1
	GangAllocTimeoutSecs    int     // max seconds to hold a partial gang allocation
	AcceptanceTimeoutSecs   int     // seconds a worker has to accept an assignment
	LocalMode               bool
	WorkerTokens            map[string]string // token -> workerID
	PluginQuota             PluginQuotaConfig // NEW: plugin GPU quota configuration
}
// WorkerConn represents a connected worker
type WorkerConn struct {
	workerID     string
	conn         *websocket.Conn
	capabilities WorkerCapabilities  // reported at registration (GPU count, hostname, ...)
	slots        SlotStatus          // latest slot usage from heartbeat/ready messages
	lease        *Lease              // NOTE(review): never assigned in this file — confirm usage elsewhere
	activeTasks  map[string]struct{} // task IDs the worker is running
	send         chan Message        // outbound messages, drained by the per-worker send loop
	hub          *SchedulerHub       // back-reference to the owning hub
	mu           sync.Mutex          // guards slots updates
}
// Lease tracks job ownership
type Lease struct {
	TaskID    string
	WorkerID  string
	ExpiresAt time.Time // NOTE(review): expiry is never checked in this file — confirm enforcement elsewhere
}
// Reservation prevents starvation of large jobs by holding back GPU
// capacity for a long-waiting task (enforced in canAdmit).
type Reservation struct {
	TaskID    string
	GPUCount  int       // GPUs held back for the reserved task
	CreatedAt time.Time // when the starvation tracker created the reservation
}
// MultiNodeJob tracks gang allocation state: workers commit one at a time
// and the job dispatches only when all TotalNodes have committed.
type MultiNodeJob struct {
	JobID       string
	TotalNodes  int               // node count required before dispatch
	Assignments []*NodeAssignment // committed workers, in rank order
	CommittedAt time.Time         // first commit time; basis for the gang timeout
}
// NodeAssignment is one worker's committed slot within a gang allocation.
type NodeAssignment struct {
	Worker      *WorkerConn
	Rank        int // rank given to this worker (rank 0 is the head node)
	CommittedAt time.Time
}
// JobAssignment tracks acceptance state
type JobAssignment struct {
	TaskID             string
	WorkerID           string
	AssignedAt         time.Time
	AcceptanceDeadline time.Time // if not accepted by then, the task is re-queued
	Accepted           bool
	Task               *Task // Reference to the task (removed from queue); nil when restored from state
}
// StarvationTracker monitors long-waiting jobs
type StarvationTracker struct {
	mu        sync.RWMutex
	threshold time.Duration // queue age beyond which a job is considered starving
}
// SchedulerMetrics tracks scheduler statistics
type SchedulerMetrics struct {
	mu                sync.RWMutex
	WorkersConnected  int
	QueueDepthBatch   int // NOTE(review): never written here; depth is read live from the queues
	QueueDepthService int // NOTE(review): never written here; depth is read live from the queues
	JobsCompleted     int
	JobsFailed        int
	JobsCancelled     int
	WorkerSlots       map[string]SlotStatus // latest slot status per worker ID
}
// NewHub creates a new scheduler hub.
//
// It opens (or creates) the persistent state store under cfg.StateDir and
// wires up the queues, trackers, and the plugin quota manager. The returned
// hub does not accept connections until Start is called.
func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
	ctx, cancel := context.WithCancel(context.Background())

	// Initialize the state store used for crash recovery / replay.
	// filepath.Join builds an OS-correct path regardless of trailing slashes.
	statePath := filepath.Join(cfg.StateDir, "scheduler.state")
	state, err := NewStateStore(statePath)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("init state store: %w", err)
	}

	// A zero aging rate falls back to a sane default so queued jobs still
	// gain priority over time.
	agingRate := cfg.PriorityAgingRate
	if agingRate == 0 {
		agingRate = 0.1
	}

	hub := &SchedulerHub{
		workers:           make(map[string]*WorkerConn),
		readyWorkers:      make(map[string]*WorkerConn),
		batchQueue:        NewPriorityQueue(agingRate),
		serviceQueue:      NewPriorityQueue(agingRate),
		reservations:      make(map[string]*Reservation),
		multiNodePending:  make(map[string]*MultiNodeJob),
		pendingAcceptance: make(map[string]*JobAssignment),
		runningTasks:      make(map[string]*Task),
		state:             state,
		starvation: &StarvationTracker{
			threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,
		},
		metrics: &SchedulerMetrics{
			WorkerSlots: make(map[string]SlotStatus),
		},
		auditor:        auditor,
		tokenValidator: NewTokenValidator(cfg.WorkerTokens),
		quotaManager:   NewPluginQuotaManager(cfg.PluginQuota), // NEW: initialize quota manager
		config:         cfg,
		ctx:            ctx,
		cancel:         cancel,
	}
	return hub, nil
}
// Start initializes the scheduler: it replays persisted state, binds the
// listener, serves the unified WebSocket protocol (with TLS when
// certificates are configured), and launches background maintenance loops.
func (h *SchedulerHub) Start() error {
	// Replay the event log so queues and assignments survive restarts.
	events, err := h.state.Replay()
	if err != nil {
		return fmt.Errorf("state replay failed: %w", err)
	}
	for _, ev := range events {
		switch ev.Type {
		case EventJobEnqueued:
			h.restoreJob(ev)
		case EventJobAssigned:
			h.restoreAssignment(ev)
		case EventJobCompleted, EventJobFailed, EventJobCancelled:
			// terminal — skip
		}
	}

	// Start WSS server (unified protocol).
	mux := http.NewServeMux()
	mux.HandleFunc("/ws/worker", h.HandleConnection)
	listener, err := net.Listen("tcp", h.config.BindAddr)
	if err != nil {
		return fmt.Errorf("failed to listen: %w", err)
	}
	h.listener = listener
	h.server = &http.Server{Handler: mux}

	// Auto-generate self-signed certs if requested and none are configured.
	if h.config.AutoGenerateCerts && (h.config.CertFile == "" || h.config.KeyFile == "") {
		certFile := filepath.Join(h.config.StateDir, "scheduler.crt")
		keyFile := filepath.Join(h.config.StateDir, "scheduler.key")
		if err := GenerateSelfSignedCert(certFile, keyFile); err != nil {
			return fmt.Errorf("failed to generate self-signed cert: %w", err)
		}
		h.config.CertFile = certFile
		h.config.KeyFile = keyFile
	}

	// Serve with TLS when certificates are configured. Serve errors were
	// previously discarded; log anything other than the normal shutdown
	// sentinel so a dead listener is visible.
	serve := func() error { return h.server.Serve(listener) }
	if h.config.CertFile != "" && h.config.KeyFile != "" {
		certFile, keyFile := h.config.CertFile, h.config.KeyFile
		serve = func() error { return h.server.ServeTLS(listener, certFile, keyFile) }
	}
	go func() {
		if err := serve(); err != nil && err != http.ErrServerClosed {
			slog.Error("scheduler server stopped unexpectedly", "error", err)
		}
	}()

	// Background maintenance loops; all exit when h.ctx is cancelled.
	go h.checkAcceptanceTimeouts()
	go h.checkGangTimeouts()
	go h.checkStarvation()

	// Grace period: workers have 30s to reconnect before assigned jobs are orphaned
	time.AfterFunc(30*time.Second, h.reconcileOrphans)
	return nil
}
// Addr reports the address the scheduler is listening on, or the empty
// string when Start has not been called yet.
func (h *SchedulerHub) Addr() string {
	if l := h.listener; l != nil {
		return l.Addr().String()
	}
	return ""
}
// Stop gracefully shuts down the scheduler: it cancels the background
// loops, closes the WSS server (releasing the listener), and closes the
// state store.
func (h *SchedulerHub) Stop() {
	h.cancel()
	// The server was previously left running after Stop, leaking the
	// listening socket. Start may never have been called, so check for nil.
	if h.server != nil {
		if err := h.server.Close(); err != nil {
			slog.Warn("error closing scheduler server", "error", err)
		}
	}
	h.state.Close()
}
// HandleConnection authenticates and upgrades incoming WSS connections from
// workers and metrics clients, then hands each socket to its message loop.
func (h *SchedulerHub) HandleConnection(w http.ResponseWriter, r *http.Request) {
	// Validate the bearer token; it maps to a worker or metrics client ID.
	token := ExtractBearerToken(r.Header.Get("Authorization"))
	clientID, ok := h.tokenValidator.Validate(token)
	if !ok {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}
	// Upgrade to WebSocket. On failure gorilla's Upgrader has already
	// written an HTTP error response to the client, so writing another one
	// here (as the old code did) would corrupt the reply — just log.
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		slog.Warn("websocket upgrade failed", "client", clientID, "error", err)
		return
	}
	// Metrics clients are distinguished by a token-derived ID prefix.
	if strings.HasPrefix(clientID, "metrics-") {
		go h.runMetricsClient(clientID, conn)
		return
	}
	go h.runWorker(clientID, conn)
}
// runWorker owns the lifecycle of one worker connection: it registers the
// worker, pumps outbound messages from wc.send in a dedicated goroutine,
// and dispatches inbound messages until the socket closes.
func (h *SchedulerHub) runWorker(workerID string, conn *websocket.Conn) {
	wc := &WorkerConn{
		workerID:    workerID,
		conn:        conn,
		slots:       SlotStatus{},
		activeTasks: make(map[string]struct{}),
		send:        make(chan Message, 10),
		hub:         h,
	}
	// Register the connection under the hub lock.
	h.mu.Lock()
	h.workers[workerID] = wc
	h.metrics.WorkersConnected++
	h.mu.Unlock()
	// Deregister on exit; removing from readyWorkers ensures no new work is
	// routed to a dead connection.
	defer func() {
		h.mu.Lock()
		delete(h.workers, workerID)
		delete(h.readyWorkers, workerID)
		h.metrics.WorkersConnected--
		h.mu.Unlock()
		conn.Close()
	}()
	// Send loop: funnels all websocket writes through a single goroutine.
	// NOTE(review): wc.send is never closed, so this goroutine blocks on the
	// channel forever after disconnect — likely a goroutine leak. Closing it
	// is unsafe while stale references (e.g. in multiNodePending) may still
	// send; confirm and fix with an ownership/shutdown protocol.
	go func() {
		for msg := range wc.send {
			conn.WriteJSON(msg)
		}
	}()
	// Receive loop: runs until the connection errors or closes.
	for {
		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			return // Connection closed
		}
		h.handleMessage(wc, msg)
	}
}
// handleMessage dispatches one inbound worker message by type.
//
// Previously every json.Unmarshal error was silently ignored, letting a
// malformed payload fall through with zero-valued fields; now malformed
// payloads are logged and dropped.
func (h *SchedulerHub) handleMessage(wc *WorkerConn, msg Message) {
	// decode unmarshals the payload into v, logging and reporting failure.
	decode := func(v any) bool {
		if err := json.Unmarshal(msg.Payload, v); err != nil {
			slog.Warn("malformed message payload",
				"worker", wc.workerID, "type", msg.Type, "error", err)
			return false
		}
		return true
	}
	switch msg.Type {
	case MsgRegister:
		var reg WorkerRegistration
		if decode(&reg) {
			h.reconcileWorker(reg, wc)
		}
	case MsgHeartbeat:
		var hb HeartbeatPayload
		if !decode(&hb) {
			return
		}
		wc.mu.Lock()
		wc.slots = hb.Slots
		wc.mu.Unlock()
		h.updateWorkerMetrics(wc.workerID, hb.Slots)
	case MsgReadyForWork:
		var ready ReadyPayload
		if !decode(&ready) {
			return
		}
		wc.mu.Lock()
		wc.slots = ready.Slots
		wc.mu.Unlock()
		h.handleReady(wc, ready.Slots)
	case MsgJobAccepted:
		var taskID string
		if decode(&taskID) {
			h.handleJobAccepted(wc.workerID, taskID)
		}
	case MsgJobResult:
		var result JobResultPayload
		if decode(&result) {
			h.handleJobResult(wc.workerID, result)
		}
	case MsgServiceHealth:
		// Service health updates - logged but no action needed for MVP
		var health ServiceHealthPayload
		if decode(&health) {
			slog.Debug("service health update", "worker", wc.workerID, "task", health.TaskID, "healthy", health.Healthy)
		}
	}
}
// reconcileWorker resolves discrepancies between the tasks a (re)connecting
// worker reports as active and the scheduler's own records, cancelling
// whatever the scheduler no longer recognizes or has reassigned.
//
// NOTE(review): wc.capabilities is written and queues are read here without
// holding h.mu — confirm callers serialize access to this worker.
func (h *SchedulerHub) reconcileWorker(reg WorkerRegistration, wc *WorkerConn) {
	wc.capabilities = reg.Capabilities
	for _, reported := range reg.ActiveTasks {
		task := h.getTask(reported.TaskID)
		switch {
		case task == nil:
			// Case 1: Scheduler has no record — kill it
			wc.send <- Message{Type: MsgJobCancel,
				Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}
		case task.Status == "orphaned":
			// Case 2: Scheduler thought lost — restore, worker is running it
			h.restoreLease(reported.TaskID, wc.workerID)
			task.Status = "running"
		case task.Status == "queued" || task.Status == "assigned":
			// Case 3: Re-queued while worker was disconnected — cancel on this worker
			wc.send <- Message{Type: MsgJobCancel,
				Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}
		case task.Status == "running" && task.WorkerID != reg.ID:
			// Case 4: Running on two workers — cancel on reconnecting worker
			wc.send <- Message{Type: MsgJobCancel,
				Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}
		}
	}
	// Send registration acknowledgment
	wc.send <- Message{Type: MsgAck}
}
// handleReady reacts to a worker's ready-for-work signal: multi-node jobs
// get first claim on the worker, then regular single-node matching; when
// nothing fits the worker is parked in readyWorkers and told MsgNoWork.
func (h *SchedulerHub) handleReady(wc *WorkerConn, _ SlotStatus) {
	// Check for multi-node jobs first
	for _, task := range h.batchQueue.Items() {
		if task.Spec.NodeCount > 1 && h.canAdmit(task, wc) {
			if h.handleMultiNodeReady(task, wc) {
				return // Multi-node job is being handled
			}
		}
	}
	// Fall through to regular single-node matching
	if task := h.findMatch(wc); task != nil {
		wc.send <- h.assignTask(task, wc)
		return
	}
	// No work for this worker: opportunistically create reservations for
	// starving jobs so future capacity is held back for them.
	h.starvation.CheckAndReserve(h)
	h.mu.Lock()
	h.readyWorkers[wc.workerID] = wc
	h.mu.Unlock()
	wc.send <- Message{Type: MsgNoWork}
}
// findMatch returns the highest-priority queued single-node task this
// worker can run, preferring service work over batch work, or nil when
// nothing fits its free slots.
func (h *SchedulerHub) findMatch(wc *WorkerConn) *Task {
	type pool struct {
		free  int
		queue *PriorityQueue
	}
	for _, p := range []pool{
		{wc.slots.ServiceAvailable(), h.serviceQueue},
		{wc.slots.BatchAvailable(), h.batchQueue},
	} {
		if p.free <= 0 {
			continue
		}
		if t := h.scanFit(p.queue, wc); t != nil {
			return t
		}
	}
	return nil
}
// scanFit walks q in priority order and returns the first single-node task
// the worker can admit; multi-node tasks are left to the gang allocator.
func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task {
	for _, candidate := range q.Items() {
		if candidate.Spec.NodeCount > 1 || !h.canAdmit(candidate, wc) {
			continue
		}
		return candidate
	}
	return nil
}
// canAdmit reports whether worker has enough GPU capacity for candidate
// while still honoring outstanding GPU reservations made for starving jobs.
//
// A task's own reservation is skipped: previously it counted against the
// task itself, so a reserved (starving) GPU task could only ever be
// admitted on a worker with capacity for the job twice over.
func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
	for reservedTaskID, res := range h.reservations {
		if reservedTaskID == candidate.ID {
			continue // the reservation held for this very task
		}
		if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
			if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {
				return false
			}
		}
	}
	return worker.capabilities.GPUCount >= candidate.Spec.GPUCount
}
// assignTask removes the task from the queues, records it as pending
// acceptance, persists the assignment, and returns the MsgJobAssign message
// to deliver to the worker.
func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
	// Pull the task out of both queues up front so no other worker can be
	// handed the same job.
	h.batchQueue.Remove(task.ID)
	h.serviceQueue.Remove(task.ID)

	// Track the assignment until the worker accepts (or the deadline hits).
	// The Task pointer is kept here because the task no longer lives in a
	// queue.
	now := time.Now()
	h.mu.Lock()
	h.pendingAcceptance[task.ID] = &JobAssignment{
		TaskID:             task.ID,
		WorkerID:           wc.workerID,
		AssignedAt:         now,
		AcceptanceDeadline: now.Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
		Task:               task,
	}
	h.mu.Unlock()

	// Persist the assignment for crash recovery.
	h.state.Append(StateEvent{
		Type:     EventJobAssigned,
		TaskID:   task.ID,
		WorkerID: wc.workerID,
	})
	return Message{
		Type:    MsgJobAssign,
		Payload: mustMarshal(task.Spec),
	}
}
// handleJobAccepted marks a pending assignment as accepted and promotes the
// task into the running set; accepted service jobs also consume per-plugin
// GPU quota.
func (h *SchedulerHub) handleJobAccepted(workerID, taskID string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	assignment, known := h.pendingAcceptance[taskID]
	if !known {
		return
	}
	assignment.Accepted = true

	task := assignment.Task
	if task == nil {
		return // assignment restored from state without an in-memory task
	}
	task.Status = "running"
	task.WorkerID = workerID
	h.runningTasks[taskID] = task

	// NEW: service jobs count against the plugin GPU quota while running.
	if task.Spec.Type == JobTypeService && h.quotaManager != nil {
		plugin := task.Spec.Metadata["plugin_name"]
		if plugin == "" {
			plugin = "default"
		}
		h.quotaManager.RecordUsage(task.Spec.UserID, plugin, task.Spec.GPUCount)
	}
}
// handleJobResult records a terminal job outcome: it releases any plugin
// GPU quota held by the job, clears the tracking maps, bumps the matching
// counter, and appends the terminal event to the state log.
func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload) {
	h.mu.Lock()
	defer h.mu.Unlock()

	// NEW: release quota held by running service jobs before the tracking
	// entries are cleared.
	if task := h.runningTasks[result.TaskID]; task != nil && task.Spec.Type == JobTypeService && h.quotaManager != nil {
		plugin := task.Spec.Metadata["plugin_name"]
		if plugin == "" {
			plugin = "default"
		}
		h.quotaManager.ReleaseUsage(task.Spec.UserID, plugin, task.Spec.GPUCount)
	}
	delete(h.pendingAcceptance, result.TaskID)
	delete(h.runningTasks, result.TaskID)

	// Map the reported state onto an event type and metrics counter;
	// anything unrecognized counts as completed.
	eventType := EventJobCompleted
	switch result.State {
	case "failed":
		eventType = EventJobFailed
		h.metrics.JobsFailed++
	case "cancelled":
		eventType = EventJobCancelled
		h.metrics.JobsCancelled++
	default:
		h.metrics.JobsCompleted++
	}
	h.state.Append(StateEvent{
		Type:     eventType,
		TaskID:   result.TaskID,
		WorkerID: workerID,
	})
}
// checkAcceptanceTimeouts periodically re-queues jobs whose assigned worker
// did not accept them before the deadline.
//
// Fixed: timed-out tasks are now returned to the queue matching their job
// type — previously a timed-out *service* job was pushed onto the batch
// queue, mirroring the routing done in SubmitJob/restoreJob.
func (h *SchedulerHub) checkAcceptanceTimeouts() {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-h.ctx.Done():
			return
		case <-ticker.C:
			now := time.Now()
			h.mu.Lock()
			for taskID, a := range h.pendingAcceptance {
				if a.Accepted || !now.After(a.AcceptanceDeadline) {
					continue
				}
				if a.Task != nil {
					a.Task.Status = "queued"
					if a.Task.Spec.Type == JobTypeService {
						h.serviceQueue.Add(a.Task)
					} else {
						h.batchQueue.Add(a.Task)
					}
				}
				delete(h.pendingAcceptance, taskID)
				// Reset the unresponsive worker's slot view so it can be
				// matched again after it recovers.
				if wc, ok := h.workers[a.WorkerID]; ok {
					wc.slots = SlotStatus{}
				}
			}
			h.mu.Unlock()
		}
	}
}
// checkGangTimeouts releases the slots held by gang allocations that never
// gathered all their nodes within the configured timeout, returning the
// committed workers to the ready pool.
func (h *SchedulerHub) checkGangTimeouts() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-h.ctx.Done():
			return
		case <-ticker.C:
		}
		timeout := time.Duration(h.config.GangAllocTimeoutSecs) * time.Second
		h.mu.Lock()
		for jobID, pending := range h.multiNodePending {
			if time.Since(pending.CommittedAt) <= timeout {
				continue
			}
			// Give every committed worker its slots back and mark it ready.
			for _, a := range pending.Assignments {
				a.Worker.slots = SlotStatus{}
				h.readyWorkers[a.Worker.workerID] = a.Worker
			}
			delete(h.multiNodePending, jobID)
		}
		h.mu.Unlock()
	}
}
// checkStarvation periodically runs the starvation tracker so long-queued
// jobs acquire GPU reservations; it exits when the hub context ends.
func (h *SchedulerHub) checkStarvation() {
	const interval = 30 * time.Second
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			h.starvation.CheckAndReserve(h)
		case <-h.ctx.Done():
			return
		}
	}
}
// CheckAndReserve scans the batch queue for tasks waiting longer than the
// starvation threshold and creates a GPU reservation for each one that does
// not already have one.
//
// Candidates are collected under h.mu.RLock and inserted under h.mu.Lock.
// Because another goroutine can add a reservation between the two lock
// phases, existence is re-checked under the write lock — previously a
// concurrent insert would be overwritten, resetting its CreatedAt.
func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) {
	st.mu.Lock()
	defer st.mu.Unlock()

	// Phase 1: snapshot starving, unreserved tasks under the read lock.
	tasksToReserve := make([]*Task, 0)
	h.mu.RLock()
	for _, task := range h.batchQueue.Items() {
		if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservationLocked(h, task.ID) {
			tasksToReserve = append(tasksToReserve, task)
		}
	}
	h.mu.RUnlock()

	if len(tasksToReserve) == 0 {
		return
	}

	// Phase 2: insert reservations under the write lock, skipping any task
	// that acquired one in the meantime.
	h.mu.Lock()
	for _, task := range tasksToReserve {
		if _, exists := h.reservations[task.ID]; exists {
			continue
		}
		h.reservations[task.ID] = &Reservation{
			TaskID:    task.ID,
			GPUCount:  task.Spec.GPUCount,
			CreatedAt: time.Now(),
		}
	}
	h.mu.Unlock()
}
// hasReservationLocked reports whether a reservation exists for taskID.
// The caller must already hold h.mu (read or write).
func (st *StarvationTracker) hasReservationLocked(h *SchedulerHub, taskID string) bool {
	_, found := h.reservations[taskID]
	return found
}
// GetTask returns the task with the given ID, or nil when the scheduler has
// no record of it (public API wrapper around the internal lookup).
func (h *SchedulerHub) GetTask(taskID string) *Task {
	return h.getTask(taskID)
}
// SubmitJob submits a new job to the scheduler (public API). It validates
// the spec, enforces plugin GPU quotas for service jobs, persists the
// enqueue event, and routes the task to the matching queue.
func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
	if spec.ID == "" {
		return fmt.Errorf("job ID is required")
	}

	// NEW: service jobs are subject to per-plugin GPU quotas.
	if spec.Type == JobTypeService && h.quotaManager != nil {
		plugin := spec.Metadata["plugin_name"]
		if plugin == "" {
			plugin = "default"
		}
		if err := h.quotaManager.CheckQuota(spec.UserID, plugin, spec.GPUCount); err != nil {
			return fmt.Errorf("quota exceeded: %w", err)
		}
	}

	now := time.Now()
	task := &Task{
		ID:          spec.ID,
		Spec:        spec,
		Status:      "queued",
		SubmittedAt: now,
	}

	// Persist first so the job survives a scheduler restart.
	// NOTE(review): Append's result is discarded here as elsewhere in this
	// file — confirm whether persistence failures should abort submission.
	h.state.Append(StateEvent{
		Type:      EventJobEnqueued,
		TaskID:    spec.ID,
		Payload:   mustMarshal(spec),
		Timestamp: now,
	})

	// Route to the queue matching the job type.
	switch spec.Type {
	case JobTypeService:
		h.serviceQueue.Add(task)
	default:
		h.batchQueue.Add(task)
	}

	// Best-effort prewarm hint for jobs that reference a snapshot.
	h.sendPrewarmHint(task)
	slog.Info("job submitted", "task_id", spec.ID, "type", spec.Type)
	return nil
}
// getTask looks a task up by ID, checking the batch queue, then the service
// queue, and finally the running set; nil when the task is unknown.
func (h *SchedulerHub) getTask(taskID string) *Task {
	for _, q := range []*PriorityQueue{h.batchQueue, h.serviceQueue} {
		if found := q.Get(taskID); found != nil {
			return found
		}
	}
	return h.runningTasks[taskID]
}
// restoreJob rebuilds a queued task from a persisted enqueue event during
// state replay; events with undecodable payloads are logged and skipped.
func (h *SchedulerHub) restoreJob(ev StateEvent) {
	var spec JobSpec
	if err := json.Unmarshal(ev.Payload, &spec); err != nil {
		slog.Error("failed to restore job", "task_id", ev.TaskID, "error", err)
		return
	}
	restored := &Task{
		ID:          ev.TaskID,
		Spec:        spec,
		Status:      "queued",
		SubmittedAt: ev.Timestamp,
	}
	// Service and batch jobs live on separate queues.
	switch spec.Type {
	case JobTypeService:
		h.serviceQueue.Add(restored)
	default:
		h.batchQueue.Add(restored)
	}
	slog.Info("restored job from state", "task_id", ev.TaskID, "type", spec.Type)
}
// restoreAssignment rebuilds a pending (not-yet-accepted) assignment from a
// persisted assignment event during state replay. The acceptance deadline
// restarts from now, and the in-memory Task reference is left nil.
func (h *SchedulerHub) restoreAssignment(ev StateEvent) {
	var payload struct {
		WorkerID string `json:"worker_id"`
	}
	if err := json.Unmarshal(ev.Payload, &payload); err != nil {
		slog.Error("failed to restore assignment", "task_id", ev.TaskID, "error", err)
		return
	}
	deadline := time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second)
	h.pendingAcceptance[ev.TaskID] = &JobAssignment{
		TaskID:             ev.TaskID,
		WorkerID:           payload.WorkerID,
		AssignedAt:         ev.Timestamp,
		AcceptanceDeadline: deadline,
	}
	slog.Info("restored assignment from state", "task_id", ev.TaskID, "worker_id", payload.WorkerID)
}
// restoreLease marks a task as accepted by a reconnecting worker that
// reports it as still running, creating the assignment record if none
// survived the restart.
func (h *SchedulerHub) restoreLease(taskID, workerID string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if assignment, known := h.pendingAcceptance[taskID]; known {
		assignment.Accepted = true
		slog.Info("restored lease", "task_id", taskID, "worker_id", workerID)
		return
	}

	// No surviving record — synthesize an already-accepted lease.
	now := time.Now()
	h.pendingAcceptance[taskID] = &JobAssignment{
		TaskID:             taskID,
		WorkerID:           workerID,
		AssignedAt:         now,
		AcceptanceDeadline: now.Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
		Accepted:           true,
	}
	slog.Info("created new lease on reconnect", "task_id", taskID, "worker_id", workerID)
}
// reconcileOrphans runs once after the 30s reconnect grace period: any job
// that was accepted by a worker that never reconnected is marked orphaned
// and returned to a queue.
//
// Fixed: orphans are now re-queued on the queue matching their job type —
// previously orphaned *service* jobs landed on the batch queue, unlike the
// routing in SubmitJob/restoreJob.
func (h *SchedulerHub) reconcileOrphans() {
	h.mu.Lock()
	defer h.mu.Unlock()
	for taskID, assignment := range h.pendingAcceptance {
		if !assignment.Accepted {
			continue // still within the acceptance flow; timeout loop handles it
		}
		// Accepted, but the worker never came back after the grace period.
		if _, stillConnected := h.workers[assignment.WorkerID]; stillConnected {
			continue
		}
		if task := assignment.Task; task != nil {
			task.Status = "orphaned"
			if task.Spec.Type == JobTypeService {
				h.serviceQueue.Add(task)
			} else {
				h.batchQueue.Add(task)
			}
			h.state.Append(StateEvent{
				Type:     EventJobRequeued,
				TaskID:   taskID,
				WorkerID: assignment.WorkerID,
			})
			slog.Info("orphaned job re-queued", "task_id", taskID, "worker_id", assignment.WorkerID)
		}
		delete(h.pendingAcceptance, taskID)
	}
}
// updateWorkerMetrics records the most recent slot status for a worker in
// the metrics snapshot.
func (h *SchedulerHub) updateWorkerMetrics(workerID string, slots SlotStatus) {
	m := h.metrics
	m.mu.Lock()
	m.WorkerSlots[workerID] = slots
	m.mu.Unlock()
}
// sendPrewarmHint asks one idle worker that could run the task to prewarm
// its snapshot. Hints are best-effort: if the chosen worker's send buffer
// is full the hint is dropped — previously the blocking send could stall
// this goroutine (and, transitively, anything needing h.mu) while the read
// lock was held.
// TODO: Call this when jobs are enqueued via scheduler API
func (h *SchedulerHub) sendPrewarmHint(task *Task) {
	if task.Spec.SnapshotID == "" {
		return // nothing to prewarm
	}
	h.mu.RLock()
	defer h.mu.RUnlock()
	for _, wc := range h.readyWorkers {
		if !h.canAdmit(task, wc) {
			continue
		}
		hint := Message{
			Type: MsgPrewarmHint,
			Payload: mustMarshal(PrewarmHintPayload{
				TaskID:      task.ID,
				SnapshotID:  task.Spec.SnapshotID,
				SnapshotSHA: task.Spec.SnapshotSHA,
			}),
		}
		// Non-blocking send: never block while holding h.mu.RLock.
		select {
		case wc.send <- hint:
		default:
		}
		return // one worker prewarms — not all of them
	}
}
// runMetricsClient serves metrics requests over an authenticated WSS
// connection until the client disconnects or a write fails.
func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
	defer conn.Close()
	for {
		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			return // Connection closed
		}
		if msg.Type != MsgMetricsRequest {
			continue
		}
		resp := Message{
			Type:    MsgMetricsResponse,
			Payload: mustMarshal(h.GetMetricsPayload()),
		}
		// A failed write means the connection is broken; stop the loop
		// rather than spinning on a dead socket (the error was previously
		// ignored).
		if err := conn.WriteJSON(resp); err != nil {
			slog.Debug("metrics client write failed", "client", clientID, "error", err)
			return
		}
	}
}
// GetMetricsPayload returns a snapshot of current scheduler metrics as a
// map (public API, consumed by metrics clients over WSS).
func (h *SchedulerHub) GetMetricsPayload() map[string]any {
	h.metrics.mu.RLock()
	defer h.metrics.mu.RUnlock()
	snapshot := map[string]any{
		"workers_connected": h.metrics.WorkersConnected,
		"jobs_completed":    h.metrics.JobsCompleted,
		"jobs_failed":       h.metrics.JobsFailed,
		"jobs_cancelled":    h.metrics.JobsCancelled,
		"worker_slots":      h.metrics.WorkerSlots,
	}
	// Queue depths are read live from the queues, not from the counters.
	snapshot["queue_depth_batch"] = h.batchQueue.Len()
	snapshot["queue_depth_service"] = h.serviceQueue.Len()
	return snapshot
}
// ServeMetrics serves Prometheus-formatted metrics (deprecated, use WSS)
//
// Output follows the Prometheus text exposition format (version 0.0.4):
// connected-worker gauge, queue-depth gauges, job-total counters by result,
// and per-worker slot utilization gauges.
func (h *SchedulerHub) ServeMetrics(w http.ResponseWriter, r *http.Request) {
	h.metrics.mu.RLock()
	defer h.metrics.mu.RUnlock()
	w.Header().Set("Content-Type", "text/plain; version=0.0.4")
	fmt.Fprintf(w, "# HELP fetch_ml_workers_connected Number of connected workers\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_workers_connected gauge\n")
	fmt.Fprintf(w, "fetch_ml_workers_connected %d\n\n", h.metrics.WorkersConnected)
	// Queue depths are read live from the queues, not from counters.
	fmt.Fprintf(w, "# HELP fetch_ml_queue_depth Current queue depth\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_queue_depth gauge\n")
	fmt.Fprintf(w, "fetch_ml_queue_depth{pool=\"batch\"} %d\n", h.batchQueue.Len())
	fmt.Fprintf(w, "fetch_ml_queue_depth{pool=\"service\"} %d\n\n", h.serviceQueue.Len())
	fmt.Fprintf(w, "# HELP fetch_ml_jobs_total Total jobs by result\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_jobs_total counter\n")
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"completed\"} %d\n", h.metrics.JobsCompleted)
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"failed\"} %d\n", h.metrics.JobsFailed)
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"cancelled\"} %d\n\n", h.metrics.JobsCancelled)
	fmt.Fprintf(w, "# HELP fetch_ml_slot_utilization Slot utilization by worker\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_slot_utilization gauge\n")
	// Per-worker utilization; workers with zero total slots are skipped to
	// avoid division by zero.
	for workerID, slots := range h.metrics.WorkerSlots {
		if slots.BatchTotal > 0 {
			utilization := float64(slots.BatchInUse) / float64(slots.BatchTotal)
			fmt.Fprintf(w, "fetch_ml_slot_utilization{worker=\"%s\",pool=\"batch\"} %.2f\n", workerID, utilization)
		}
		if slots.ServiceTotal > 0 {
			utilization := float64(slots.ServiceInUse) / float64(slots.ServiceTotal)
			fmt.Fprintf(w, "fetch_ml_slot_utilization{worker=\"%s\",pool=\"service\"} %.2f\n", workerID, utilization)
		}
	}
}
// tryGangAlloc attempts to allocate a multi-node job to a worker
// It tracks partial allocations and dispatches when all nodes are committed
//
// The first committed worker becomes rank 0 (the head node); its hostname
// is propagated to every rank via the rank-specific spec.
func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) {
	h.mu.Lock()
	defer h.mu.Unlock()
	// Check if this worker can run the job
	if !h.canAdmit(task, wc) {
		return
	}
	jobID := task.ID
	pending, ok := h.multiNodePending[jobID]
	if !ok {
		// First worker for this job — open a pending gang record.
		pending = &MultiNodeJob{
			JobID:       jobID,
			TotalNodes:  task.Spec.NodeCount,
			Assignments: make([]*NodeAssignment, 0, task.Spec.NodeCount),
			CommittedAt: time.Now(),
		}
		h.multiNodePending[jobID] = pending
	}
	// Add this worker to the pending assignment; rank = commit order.
	assignment := &NodeAssignment{
		Worker:      wc,
		Rank:        len(pending.Assignments),
		CommittedAt: time.Now(),
	}
	pending.Assignments = append(pending.Assignments, assignment)
	// Reserve slots on this worker
	// NOTE(review): wc.slots is written here without wc.mu, while heartbeats
	// update it under wc.mu — confirm whether this race is acceptable.
	wc.slots.BatchInUse++
	delete(h.readyWorkers, wc.workerID)
	// Check if we have all nodes
	if len(pending.Assignments) < task.Spec.NodeCount {
		// Still waiting for more workers
		return
	}
	// All nodes committed - dispatch simultaneously
	headAddr := pending.Assignments[0].Worker.capabilities.Hostname
	for i, a := range pending.Assignments {
		// Create rank-specific job spec
		spec := h.buildRankedSpec(task, i, headAddr, task.Spec.NodeCount)
		msg := Message{
			Type:    MsgJobAssign,
			Payload: mustMarshal(spec),
		}
		a.Worker.send <- msg
	}
	// Clean up pending state
	// NOTE(review): the task is not removed from batchQueue here (unlike
	// assignTask) — confirm removal happens elsewhere, or the dispatched
	// job remains visible to single-node matching.
	delete(h.multiNodePending, jobID)
	slog.Info("multi-node job dispatched",
		"job_id", jobID,
		"nodes", task.Spec.NodeCount,
		"head_addr", headAddr)
}
// buildRankedSpec creates a job spec with rank-specific template variables resolved
//
// Both Metadata and Env are cloned before the rank variables (HEAD_ADDR,
// WORLD_SIZE, NODE_RANK) are injected. Go maps are references: the original
// code cloned Metadata but wrote straight into spec.Env, which aliases
// task.Spec.Env — mutating the stored task and leaking one rank's values
// into every other rank's spec.
func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
	spec := task.Spec

	// Clone metadata so the shared task spec is not mutated.
	spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3)
	for k, v := range task.Spec.Metadata {
		spec.Metadata[k] = v
	}

	// Clone env for the same reason (handles a nil source map too).
	spec.Env = make(map[string]string, len(task.Spec.Env)+3)
	for k, v := range task.Spec.Env {
		spec.Env[k] = v
	}

	world := fmt.Sprintf("%d", worldSize)
	nodeRank := fmt.Sprintf("%d", rank)
	spec.Metadata["HEAD_ADDR"] = headAddr
	spec.Metadata["WORLD_SIZE"] = world
	spec.Metadata["NODE_RANK"] = nodeRank
	// Also set in Env for job runtime.
	spec.Env["HEAD_ADDR"] = headAddr
	spec.Env["WORLD_SIZE"] = world
	spec.Env["NODE_RANK"] = nodeRank
	return spec
}
// handleMultiNodeReady routes a ready worker toward a pending multi-node
// (gang-scheduled) job. It reports true when the gang allocator took
// ownership of the task, false for single-node tasks.
func (h *SchedulerHub) handleMultiNodeReady(task *Task, wc *WorkerConn) bool {
	if task.Spec.NodeCount < 2 {
		return false // single-node jobs use the regular matching path
	}
	h.tryGangAlloc(task, wc)
	return true
}
// mustMarshal JSON-encodes v, discarding encoding errors; callers only pass
// values known to be marshalable (maps, specs, payload structs).
func mustMarshal(v any) []byte {
	encoded, _ := json.Marshal(v)
	return encoded
}