package scheduler

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/audit"
)

var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin: func(r *http.Request) bool {
		return true // Allow all origins (configurable in production)
	},
}

// SchedulerHub manages worker connections and job scheduling.
type SchedulerHub struct {
	mu                sync.RWMutex
	workers           map[string]*WorkerConn
	readyWorkers      map[string]*WorkerConn
	batchQueue        *PriorityQueue
	serviceQueue      *PriorityQueue
	reservations      map[string]*Reservation
	multiNodePending  map[string]*MultiNodeJob
	pendingAcceptance map[string]*JobAssignment
	runningTasks      map[string]*Task // Track assigned+accepted tasks
	state             *StateStore
	starvation        *StarvationTracker
	metrics           *SchedulerMetrics
	auditor           *audit.Logger
	tokenValidator    *TokenValidator
	quotaManager      *PluginQuotaManager // plugin GPU quota manager
	config            HubConfig
	ctx               context.Context
	cancel            context.CancelFunc
	server            *http.Server
	listener          net.Listener
}

// HubConfig carries all scheduler hub configuration.
type HubConfig struct {
	BindAddr                string
	CertFile                string
	KeyFile                 string
	AutoGenerateCerts       bool
	StateDir                string
	DefaultBatchSlots       int
	DefaultServiceSlots     int
	StarvationThresholdMins float64
	PriorityAgingRate       float64
	GangAllocTimeoutSecs    int
	AcceptanceTimeoutSecs   int
	LocalMode               bool
	WorkerTokens            map[string]string // token -> workerID
	PluginQuota             PluginQuotaConfig // plugin GPU quota configuration
}

// WorkerConn represents a connected worker.
type WorkerConn struct {
	workerID     string
	conn         *websocket.Conn
	capabilities WorkerCapabilities
	slots        SlotStatus
	lease        *Lease
	activeTasks  map[string]struct{}
	send         chan Message
	hub          *SchedulerHub
	mu           sync.Mutex
}

// Lease tracks job ownership.
type Lease struct {
	TaskID    string
	WorkerID  string
	ExpiresAt time.Time
}

// Reservation prevents starvation of large jobs by holding GPU capacity.
type Reservation struct {
	TaskID    string
	GPUCount  int
	CreatedAt time.Time
}

// MultiNodeJob tracks gang allocation state for a multi-node job.
type MultiNodeJob struct {
	JobID       string
	TotalNodes  int
	Assignments []*NodeAssignment
	CommittedAt time.Time
}

// NodeAssignment is one worker's committed slot in a gang.
type NodeAssignment struct {
	Worker      *WorkerConn
	Rank        int
	CommittedAt time.Time
}

// JobAssignment tracks acceptance state for a dispatched job.
type JobAssignment struct {
	TaskID             string
	WorkerID           string
	AssignedAt         time.Time
	AcceptanceDeadline time.Time
	Accepted           bool
	Task               *Task // Reference to the task (removed from queue)
}

// StarvationTracker monitors long-waiting jobs.
type StarvationTracker struct {
	mu        sync.RWMutex
	threshold time.Duration
}

// SchedulerMetrics tracks scheduler statistics. All fields are guarded by mu.
type SchedulerMetrics struct {
	mu                 sync.RWMutex
	WorkersConnected   int
	QueueDepthBatch    int
	QueueDepthService  int
	JobsCompleted      int
	JobsFailed         int
	JobsCancelled      int
	WorkerSlots        map[string]SlotStatus
}

// NewHub creates a new scheduler hub backed by a persistent state store.
func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
	ctx, cancel := context.WithCancel(context.Background())

	// Initialize state store
	statePath := filepath.Join(cfg.StateDir, "scheduler.state")
	state, err := NewStateStore(statePath)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("init state store: %w", err)
	}

	agingRate := cfg.PriorityAgingRate
	if agingRate == 0 {
		agingRate = 0.1 // default priority aging rate
	}

	hub := &SchedulerHub{
		workers:           make(map[string]*WorkerConn),
		readyWorkers:      make(map[string]*WorkerConn),
		batchQueue:        NewPriorityQueue(agingRate),
		serviceQueue:      NewPriorityQueue(agingRate),
		reservations:      make(map[string]*Reservation),
		multiNodePending:  make(map[string]*MultiNodeJob),
		pendingAcceptance: make(map[string]*JobAssignment),
		runningTasks:      make(map[string]*Task),
		state:             state,
		starvation: &StarvationTracker{
			threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,
		},
		metrics: &SchedulerMetrics{
			WorkerSlots: make(map[string]SlotStatus),
		},
		auditor:        auditor,
		tokenValidator: NewTokenValidator(cfg.WorkerTokens),
		quotaManager:   NewPluginQuotaManager(cfg.PluginQuota),
		config:         cfg,
		ctx:            ctx,
		cancel:         cancel,
	}
	return hub, nil
}

// Start replays persisted state, starts the WSS server, and launches the
// background maintenance loops.
func (h *SchedulerHub) Start() error {
	// Replay state first so queues reflect jobs persisted before a restart.
	events, err := h.state.Replay()
	if err != nil {
		return fmt.Errorf("state replay failed: %w", err)
	}
	for _, ev := range events {
		switch ev.Type {
		case EventJobEnqueued:
			h.restoreJob(ev)
		case EventJobAssigned:
			h.restoreAssignment(ev)
		case EventJobCompleted, EventJobFailed, EventJobCancelled:
			// terminal — skip
		}
	}

	// Start WSS server (unified protocol)
	mux := http.NewServeMux()
	mux.HandleFunc("/ws/worker", h.HandleConnection)

	listener, err := net.Listen("tcp", h.config.BindAddr)
	if err != nil {
		return fmt.Errorf("failed to listen: %w", err)
	}
	h.listener = listener
	h.server = &http.Server{Handler: mux}

	// Auto-generate self-signed certs if requested and none configured.
	if h.config.AutoGenerateCerts && (h.config.CertFile == "" || h.config.KeyFile == "") {
		certFile := filepath.Join(h.config.StateDir, "scheduler.crt")
		keyFile := filepath.Join(h.config.StateDir, "scheduler.key")
		if err := GenerateSelfSignedCert(certFile, keyFile); err != nil {
			return fmt.Errorf("failed to generate self-signed cert: %w", err)
		}
		h.config.CertFile = certFile
		h.config.KeyFile = keyFile
	}

	// Serve with TLS when certificates are configured; log unexpected exits
	// instead of silently discarding the error.
	go func() {
		var serveErr error
		if h.config.CertFile != "" && h.config.KeyFile != "" {
			serveErr = h.server.ServeTLS(listener, h.config.CertFile, h.config.KeyFile)
		} else {
			serveErr = h.server.Serve(listener)
		}
		if serveErr != nil && serveErr != http.ErrServerClosed {
			slog.Error("scheduler server stopped", "error", serveErr)
		}
	}()

	// Start background tasks
	go h.checkAcceptanceTimeouts()
	go h.checkGangTimeouts()
	go h.checkStarvation()

	// Grace period: workers have 30s to reconnect before assigned jobs are orphaned
	time.AfterFunc(30*time.Second, h.reconcileOrphans)

	return nil
}

// Addr returns the listening address of the scheduler, or "" before Start.
func (h *SchedulerHub) Addr() string {
	if h.listener == nil {
		return ""
	}
	return h.listener.Addr().String()
}

// Stop gracefully shuts down the scheduler: cancels background loops,
// closes the HTTP server, and closes the state store.
func (h *SchedulerHub) Stop() {
	h.cancel()
	if h.server != nil {
		// Previously the server was never stopped, leaking the listener.
		h.server.Close()
	}
	h.state.Close()
}

// HandleConnection handles WSS connections from workers and metrics clients.
func (h *SchedulerHub) HandleConnection(w http.ResponseWriter, r *http.Request) {
	// Validate token
	token := ExtractBearerToken(r.Header.Get("Authorization"))
	clientID, ok := h.tokenValidator.Validate(token)
	if !ok {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}

	// Upgrade to WebSocket. On failure Upgrade has already replied to the
	// client with an HTTP error, so writing another response here is wrong.
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		slog.Warn("websocket upgrade failed", "client", clientID, "error", err)
		return
	}

	// Metrics clients are distinguished by a token-ID prefix.
	if strings.HasPrefix(clientID, "metrics-") {
		go h.runMetricsClient(clientID, conn)
		return
	}
	go h.runWorker(clientID, conn)
}

// runWorker owns a single worker connection: registers it, pumps outbound
// messages, and dispatches inbound messages until the connection closes.
func (h *SchedulerHub) runWorker(workerID string, conn *websocket.Conn) {
	wc := &WorkerConn{
		workerID:    workerID,
		conn:        conn,
		slots:       SlotStatus{},
		activeTasks: make(map[string]struct{}),
		send:        make(chan Message, 10),
		hub:         h,
	}

	h.mu.Lock()
	h.workers[workerID] = wc
	h.mu.Unlock()
	// Counter is guarded by metrics.mu so readers (GetMetricsPayload,
	// ServeMetrics) don't race with this write.
	h.metrics.mu.Lock()
	h.metrics.WorkersConnected++
	h.metrics.mu.Unlock()

	defer func() {
		h.mu.Lock()
		delete(h.workers, workerID)
		delete(h.readyWorkers, workerID)
		h.mu.Unlock()
		h.metrics.mu.Lock()
		h.metrics.WorkersConnected--
		h.metrics.mu.Unlock()
		conn.Close()
	}()

	// Send loop.
	// NOTE(review): wc.send is never closed, so this goroutine outlives the
	// connection; closing it is unsafe while other goroutines may still send.
	// Consider a done channel — confirm before changing.
	go func() {
		for msg := range wc.send {
			conn.WriteJSON(msg)
		}
	}()

	// Receive loop
	for {
		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			return // Connection closed
		}
		h.handleMessage(wc, msg)
	}
}

// handleMessage dispatches one inbound worker message by type. Malformed
// payloads are logged and dropped instead of being processed as zero values.
func (h *SchedulerHub) handleMessage(wc *WorkerConn, msg Message) {
	switch msg.Type {
	case MsgRegister:
		var reg WorkerRegistration
		if err := json.Unmarshal(msg.Payload, &reg); err != nil {
			slog.Warn("malformed register payload", "worker", wc.workerID, "error", err)
			return
		}
		h.reconcileWorker(reg, wc)

	case MsgHeartbeat:
		var hb HeartbeatPayload
		if err := json.Unmarshal(msg.Payload, &hb); err != nil {
			slog.Warn("malformed heartbeat payload", "worker", wc.workerID, "error", err)
			return
		}
		wc.mu.Lock()
		wc.slots = hb.Slots
		wc.mu.Unlock()
		h.updateWorkerMetrics(wc.workerID, hb.Slots)

	case MsgReadyForWork:
		var ready ReadyPayload
		if err := json.Unmarshal(msg.Payload, &ready); err != nil {
			slog.Warn("malformed ready payload", "worker", wc.workerID, "error", err)
			return
		}
		wc.mu.Lock()
		wc.slots = ready.Slots
		wc.mu.Unlock()
		h.handleReady(wc, ready.Slots)

	case MsgJobAccepted:
		var taskID string
		if err := json.Unmarshal(msg.Payload, &taskID); err != nil {
			slog.Warn("malformed accept payload", "worker", wc.workerID, "error", err)
			return
		}
		h.handleJobAccepted(wc.workerID, taskID)

	case MsgJobResult:
		var result JobResultPayload
		if err := json.Unmarshal(msg.Payload, &result); err != nil {
			slog.Warn("malformed result payload", "worker", wc.workerID, "error", err)
			return
		}
		h.handleJobResult(wc.workerID, result)

	case MsgServiceHealth:
		// Service health updates - logged but no action needed for MVP
		var health ServiceHealthPayload
		if err := json.Unmarshal(msg.Payload, &health); err != nil {
			slog.Warn("malformed health payload", "worker", wc.workerID, "error", err)
			return
		}
		slog.Debug("service health update", "worker", wc.workerID, "task", health.TaskID, "healthy", health.Healthy)
	}
}

// reconcileWorker resolves the worker's reported active tasks against the
// scheduler's view on (re)registration, cancelling anything inconsistent.
func (h *SchedulerHub) reconcileWorker(reg WorkerRegistration, wc *WorkerConn) {
	wc.capabilities = reg.Capabilities

	for _, reported := range reg.ActiveTasks {
		task := h.getTask(reported.TaskID)
		switch {
		case task == nil:
			// Case 1: Scheduler has no record — kill it
			wc.send <- Message{Type: MsgJobCancel, Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}
		case task.Status == "orphaned":
			// Case 2: Scheduler thought lost — restore, worker is running it
			h.restoreLease(reported.TaskID, wc.workerID)
			task.Status = "running"
		case task.Status == "queued" || task.Status == "assigned":
			// Case 3: Re-queued while worker was disconnected — cancel on this worker
			wc.send <- Message{Type: MsgJobCancel, Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}
		case task.Status == "running" && task.WorkerID != reg.ID:
			// Case 4: Running on two workers — cancel on reconnecting worker
			wc.send <- Message{Type: MsgJobCancel, Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}
		}
	}

	// Send registration acknowledgment
	wc.send <- Message{Type: MsgAck}
}

// handleReady matches a ready worker against queued work, preferring
// multi-node (gang) jobs, then single-node work; otherwise parks the worker.
func (h *SchedulerHub) handleReady(wc *WorkerConn, _ SlotStatus) {
	// Check for multi-node jobs first
	for _, task := range h.batchQueue.Items() {
		if task.Spec.NodeCount > 1 && h.canAdmit(task, wc) {
			if h.handleMultiNodeReady(task, wc) {
				return // Multi-node job is being handled
			}
		}
	}

	// Fall through to regular single-node matching
	if task := h.findMatch(wc); task != nil {
		wc.send <- h.assignTask(task, wc)
		return
	}

	h.starvation.CheckAndReserve(h)

	h.mu.Lock()
	h.readyWorkers[wc.workerID] = wc
	h.mu.Unlock()

	wc.send <- Message{Type: MsgNoWork}
}

// findMatch returns the first queued task the worker can run, checking the
// service queue before the batch queue.
func (h *SchedulerHub) findMatch(wc *WorkerConn) *Task {
	if wc.slots.ServiceAvailable() > 0 {
		if task := h.scanFit(h.serviceQueue, wc); task != nil {
			return task
		}
	}
	if wc.slots.BatchAvailable() > 0 {
		if task := h.scanFit(h.batchQueue, wc); task != nil {
			return task
		}
	}
	return nil
}

// scanFit scans a queue for the first admissible single-node task.
func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task {
	for _, task := range q.Items() {
		if task.Spec.NodeCount > 1 {
			continue // gang allocator handles multi-node
		}
		if h.canAdmit(task, wc) {
			return task
		}
	}
	return nil
}

// canAdmit reports whether a worker has the GPU capacity to run the
// candidate, respecting outstanding starvation reservations.
func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
	for _, res := range h.reservations {
		if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
			if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {
				return false
			}
		}
	}
	return worker.capabilities.GPUCount >= candidate.Spec.GPUCount
}

// canRequeue checks if a task can be re-queued based on wall-clock elapsed time.
// Returns false if the task has exceeded its MaxRuntime budget.
func (h *SchedulerHub) canRequeue(task *Task) bool {
	if task.FirstAssignedAt.IsZero() {
		return true // Never assigned, can always re-queue
	}
	elapsed := time.Since(task.FirstAssignedAt)
	maxRuntime := task.MaxRuntime
	if maxRuntime == 0 {
		maxRuntime = 24 * time.Hour // Default 24h
	}
	if elapsed > maxRuntime {
		// Task exceeded wall-clock budget - fail it
		slog.Info("task exceeded max runtime, failing", "task_id", task.ID, "elapsed", elapsed, "max_runtime", maxRuntime)
		return false
	}
	return true
}

// assignTask removes a task from the queues, records the pending acceptance,
// persists the assignment, and returns the MsgJobAssign message to send.
func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
	// Remove from queue first (prevent double-assignment)
	h.batchQueue.Remove(task.ID)
	h.serviceQueue.Remove(task.ID)

	// Set FirstAssignedAt if this is the first assignment
	if task.FirstAssignedAt.IsZero() {
		task.FirstAssignedAt = time.Now()
	}

	// Cache MaxRuntime from spec
	maxHours := task.Spec.MaxRuntimeHours
	if maxHours <= 0 {
		maxHours = 24 // Default 24h
	}
	if maxHours > 168 {
		maxHours = 168 // Hard cap at 7d
	}
	task.MaxRuntime = time.Duration(maxHours) * time.Hour

	// Calculate remaining time budget
	elapsed := time.Since(task.FirstAssignedAt)
	remaining := task.MaxRuntime - elapsed
	if remaining < 0 {
		remaining = 0
	}

	// Track pending acceptance with task reference
	h.mu.Lock()
	h.pendingAcceptance[task.ID] = &JobAssignment{
		TaskID:             task.ID,
		WorkerID:           wc.workerID,
		AssignedAt:         time.Now(),
		AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
		Accepted:           false,
		Task:               task, // Store reference since removed from queue
	}
	h.mu.Unlock()

	// Persist assignment
	h.state.Append(StateEvent{
		Type:     EventJobAssigned,
		TaskID:   task.ID,
		WorkerID: wc.workerID,
	})

	// Send job assignment with remaining time budget
	payload := JobAssignPayload{
		Spec:          task.Spec,
		RemainingTime: remaining,
	}
	return Message{
		Type:    MsgJobAssign,
		Payload: mustMarshal(payload),
	}
}

// handleJobAccepted marks a pending assignment accepted, promotes the task
// to running, and records plugin quota usage for service jobs.
func (h *SchedulerHub) handleJobAccepted(workerID, taskID string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if assignment, ok := h.pendingAcceptance[taskID]; ok {
		assignment.Accepted = true

		// Track as running task
		task := assignment.Task
		if task != nil {
			task.Status = "running"
			task.WorkerID = workerID
			h.runningTasks[taskID] = task
		}

		// Record quota usage for service jobs
		if task != nil && task.Spec.Type == JobTypeService {
			if h.quotaManager != nil {
				pluginName := task.Spec.Metadata["plugin_name"]
				if pluginName == "" {
					pluginName = "default"
				}
				h.quotaManager.RecordUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
			}
		}
	}
}

// handleJobResult finalizes a task: releases quota, clears tracking state,
// updates counters, and persists the terminal event.
func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload) {
	h.mu.Lock()
	defer h.mu.Unlock()

	// Release quota usage for service jobs before deleting pending acceptance
	if task := h.runningTasks[result.TaskID]; task != nil && task.Spec.Type == JobTypeService {
		if h.quotaManager != nil {
			pluginName := task.Spec.Metadata["plugin_name"]
			if pluginName == "" {
				pluginName = "default"
			}
			h.quotaManager.ReleaseUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
		}
	}

	delete(h.pendingAcceptance, result.TaskID)
	delete(h.runningTasks, result.TaskID)

	// Counters are guarded by metrics.mu (readers use it); h.mu alone does
	// not protect them. Lock order is always h.mu then metrics.mu.
	eventType := EventJobCompleted
	h.metrics.mu.Lock()
	switch result.State {
	case "failed":
		eventType = EventJobFailed
		h.metrics.JobsFailed++
	case "cancelled":
		eventType = EventJobCancelled
		h.metrics.JobsCancelled++
	default:
		h.metrics.JobsCompleted++
	}
	h.metrics.mu.Unlock()

	h.state.Append(StateEvent{
		Type:     eventType,
		TaskID:   result.TaskID,
		WorkerID: workerID,
	})
}

// checkAcceptanceTimeouts re-queues jobs that weren't accepted in time.
// Tasks are returned to the queue matching their type, and tasks that have
// exhausted their wall-clock budget are failed instead of re-queued.
func (h *SchedulerHub) checkAcceptanceTimeouts() {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-h.ctx.Done():
			return
		case <-ticker.C:
			h.mu.Lock()
			for taskID, a := range h.pendingAcceptance {
				if !a.Accepted && time.Now().After(a.AcceptanceDeadline) {
					if a.Task != nil {
						if h.canRequeue(a.Task) {
							a.Task.Status = "queued"
							// Service jobs must return to the service queue,
							// not the batch queue.
							if a.Task.Spec.Type == JobTypeService {
								h.serviceQueue.Add(a.Task)
							} else {
								h.batchQueue.Add(a.Task)
							}
						} else {
							// Wall-clock budget exhausted: fail terminally.
							a.Task.Status = "failed"
							h.state.Append(StateEvent{
								Type:     EventJobFailed,
								TaskID:   taskID,
								WorkerID: a.WorkerID,
							})
							h.metrics.mu.Lock()
							h.metrics.JobsFailed++
							h.metrics.mu.Unlock()
						}
					}
					delete(h.pendingAcceptance, taskID)
					if wc, ok := h.workers[a.WorkerID]; ok {
						wc.slots = SlotStatus{}
					}
				}
			}
			h.mu.Unlock()
		}
	}
}

// checkGangTimeouts releases reserved slots for gangs that never filled.
func (h *SchedulerHub) checkGangTimeouts() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-h.ctx.Done():
			return
		case <-ticker.C:
			h.mu.Lock()
			for jobID, pending := range h.multiNodePending {
				if time.Since(pending.CommittedAt) > time.Duration(h.config.GangAllocTimeoutSecs)*time.Second {
					for _, a := range pending.Assignments {
						a.Worker.slots = SlotStatus{}
						h.readyWorkers[a.Worker.workerID] = a.Worker
					}
					delete(h.multiNodePending, jobID)
				}
			}
			h.mu.Unlock()
		}
	}
}

// checkStarvation periodically reserves capacity for long-waiting jobs.
func (h *SchedulerHub) checkStarvation() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-h.ctx.Done():
			return
		case <-ticker.C:
			h.starvation.CheckAndReserve(h)
		}
	}
}

// CheckAndReserve creates GPU reservations for batch tasks that have waited
// longer than the starvation threshold.
func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) {
	st.mu.Lock()
	defer st.mu.Unlock()

	// First check which tasks need reservation under h.mu.RLock
	tasksToReserve := make([]*Task, 0)
	h.mu.RLock()
	for _, task := range h.batchQueue.Items() {
		if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservationLocked(h, task.ID) {
			tasksToReserve = append(tasksToReserve, task)
		}
	}
	h.mu.RUnlock()

	// Now acquire Lock to add reservations. A concurrent add between the
	// RUnlock and Lock merely overwrites an identical reservation.
	if len(tasksToReserve) > 0 {
		h.mu.Lock()
		for _, task := range tasksToReserve {
			h.reservations[task.ID] = &Reservation{
				TaskID:    task.ID,
				GPUCount:  task.Spec.GPUCount,
				CreatedAt: time.Now(),
			}
		}
		h.mu.Unlock()
	}
}

// hasReservationLocked reports whether a reservation exists for taskID.
// Caller must hold h.mu (read or write).
func (st *StarvationTracker) hasReservationLocked(h *SchedulerHub, taskID string) bool {
	_, exists := h.reservations[taskID]
	return exists
}

// GetTask returns a task by ID (public API).
func (h *SchedulerHub) GetTask(taskID string) *Task {
	return h.getTask(taskID)
}

// SubmitJob submits a new job to the scheduler (public API). Service jobs
// are checked against plugin GPU quotas before being enqueued.
func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
	if spec.ID == "" {
		return fmt.Errorf("job ID is required")
	}

	// Check plugin quotas for service jobs
	if spec.Type == JobTypeService && h.quotaManager != nil {
		pluginName := spec.Metadata["plugin_name"]
		if pluginName == "" {
			pluginName = "default"
		}
		if err := h.quotaManager.CheckQuota(spec.UserID, pluginName, spec.GPUCount); err != nil {
			return fmt.Errorf("quota exceeded: %w", err)
		}
	}

	task := &Task{
		ID:          spec.ID,
		Spec:        spec,
		Status:      "queued",
		SubmittedAt: time.Now(),
	}

	// Persist to state store
	h.state.Append(StateEvent{
		Type:      EventJobEnqueued,
		TaskID:    spec.ID,
		Payload:   mustMarshal(spec),
		Timestamp: time.Now(),
	})

	// Add to appropriate queue
	if spec.Type == JobTypeService {
		h.serviceQueue.Add(task)
	} else {
		h.batchQueue.Add(task)
	}

	// Send prewarm hint if job has snapshot
	h.sendPrewarmHint(task)

	slog.Info("job submitted", "task_id", spec.ID, "type", spec.Type)
	return nil
}

// getTask looks a task up in both queues and the running set.
// NOTE(review): tasks awaiting acceptance (in pendingAcceptance, removed
// from the queues) are not found here — confirm that is intended, since
// reconcileWorker treats a nil result as "no record" and cancels the task.
func (h *SchedulerHub) getTask(taskID string) *Task {
	t := h.batchQueue.Get(taskID)
	if t != nil {
		return t
	}
	t = h.serviceQueue.Get(taskID)
	if t != nil {
		return t
	}
	return h.runningTasks[taskID]
}

// restoreJob rebuilds a queued task from a persisted enqueue event.
func (h *SchedulerHub) restoreJob(ev StateEvent) {
	// Parse job spec from event payload
	var spec JobSpec
	if err := json.Unmarshal(ev.Payload, &spec); err != nil {
		slog.Error("failed to restore job", "task_id", ev.TaskID, "error", err)
		return
	}

	task := &Task{
		ID:          ev.TaskID,
		Spec:        spec,
		Status:      "queued",
		SubmittedAt: ev.Timestamp,
	}

	// Add to appropriate queue
	if spec.Type == JobTypeService {
		h.serviceQueue.Add(task)
	} else {
		h.batchQueue.Add(task)
	}
	slog.Info("restored job from state", "task_id", ev.TaskID, "type", spec.Type)
}

// restoreAssignment rebuilds pending-acceptance state from a persisted
// assignment event. The Task pointer is not recoverable from the event, so
// acceptance-timeout handling must tolerate a nil Task (it does).
func (h *SchedulerHub) restoreAssignment(ev StateEvent) {
	// Parse assignment from event payload
	var payload struct {
		WorkerID string `json:"worker_id"`
	}
	if err := json.Unmarshal(ev.Payload, &payload); err != nil {
		slog.Error("failed to restore assignment", "task_id", ev.TaskID, "error", err)
		return
	}

	// Restore pending acceptance state
	h.pendingAcceptance[ev.TaskID] = &JobAssignment{
		TaskID:             ev.TaskID,
		WorkerID:           payload.WorkerID,
		AssignedAt:         ev.Timestamp,
		AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
		Accepted:           false,
	}
	slog.Info("restored assignment from state", "task_id", ev.TaskID, "worker_id", payload.WorkerID)
}

// restoreLease re-establishes ownership of a task a reconnecting worker
// reports as running, creating the lease record if none survived.
func (h *SchedulerHub) restoreLease(taskID, workerID string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if assignment, ok := h.pendingAcceptance[taskID]; ok {
		assignment.Accepted = true
		slog.Info("restored lease", "task_id", taskID, "worker_id", workerID)
	} else {
		// Create a new lease record
		h.pendingAcceptance[taskID] = &JobAssignment{
			TaskID:             taskID,
			WorkerID:           workerID,
			AssignedAt:         time.Now(),
			AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
			Accepted:           true,
		}
		slog.Info("created new lease on reconnect", "task_id", taskID, "worker_id", workerID)
	}
}

// reconcileOrphans re-queues accepted jobs whose worker never reconnected
// within the post-restart grace period.
func (h *SchedulerHub) reconcileOrphans() {
	h.mu.Lock()
	defer h.mu.Unlock()

	// After grace period (30s), any job still assigned to a disconnected worker
	// is considered orphaned and should be re-queued
	for taskID, assignment := range h.pendingAcceptance {
		if assignment.Accepted {
			// Job was accepted but worker is gone (not in h.workers)
			if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected {
				task := assignment.Task
				if task != nil {
					task.Status = "orphaned"
					// Re-queue into the queue matching the job type.
					if task.Spec.Type == JobTypeService {
						h.serviceQueue.Add(task)
					} else {
						h.batchQueue.Add(task)
					}
					h.state.Append(StateEvent{
						Type:     EventJobRequeued,
						TaskID:   taskID,
						WorkerID: assignment.WorkerID,
					})
					slog.Info("orphaned job re-queued", "task_id", taskID, "worker_id", assignment.WorkerID)
				}
				delete(h.pendingAcceptance, taskID)
			}
		}
	}
}

// updateWorkerMetrics records the latest slot status reported by a worker.
func (h *SchedulerHub) updateWorkerMetrics(workerID string, slots SlotStatus) {
	h.metrics.mu.Lock()
	defer h.metrics.mu.Unlock()
	h.metrics.WorkerSlots[workerID] = slots
}

// sendPrewarmHint sends a prewarm hint to one idle worker when a job with a
// snapshot is enqueued.
// TODO: Call this when jobs are enqueued via scheduler API
func (h *SchedulerHub) sendPrewarmHint(task *Task) {
	if task.Spec.SnapshotID == "" {
		return
	}

	h.mu.RLock()
	defer h.mu.RUnlock()

	for _, wc := range h.readyWorkers {
		if h.canAdmit(task, wc) {
			wc.send <- Message{
				Type: MsgPrewarmHint,
				Payload: mustMarshal(PrewarmHintPayload{
					TaskID:      task.ID,
					SnapshotID:  task.Spec.SnapshotID,
					SnapshotSHA: task.Spec.SnapshotSHA,
				}),
			}
			return // one worker prewarms — not all of them
		}
	}
}

// runMetricsClient handles metrics clients over WSS, answering each
// MsgMetricsRequest with a metrics snapshot.
func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
	defer conn.Close()

	for {
		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			return // Connection closed
		}

		if msg.Type == MsgMetricsRequest {
			metrics := h.GetMetricsPayload()
			if err := conn.WriteJSON(Message{
				Type:    MsgMetricsResponse,
				Payload: mustMarshal(metrics),
			}); err != nil {
				return // Write failed; drop the connection
			}
		}
	}
}

// GetMetricsPayload returns current metrics as a map (public API).
func (h *SchedulerHub) GetMetricsPayload() map[string]any {
	h.metrics.mu.RLock()
	defer h.metrics.mu.RUnlock()

	return map[string]any{
		"workers_connected":   h.metrics.WorkersConnected,
		"queue_depth_batch":   h.batchQueue.Len(),
		"queue_depth_service": h.serviceQueue.Len(),
		"jobs_completed":      h.metrics.JobsCompleted,
		"jobs_failed":         h.metrics.JobsFailed,
		"jobs_cancelled":      h.metrics.JobsCancelled,
		"worker_slots":        h.metrics.WorkerSlots,
	}
}

// ServeMetrics serves Prometheus-formatted metrics (deprecated, use WSS).
func (h *SchedulerHub) ServeMetrics(w http.ResponseWriter, r *http.Request) {
	h.metrics.mu.RLock()
	defer h.metrics.mu.RUnlock()

	w.Header().Set("Content-Type", "text/plain; version=0.0.4")

	fmt.Fprintf(w, "# HELP fetch_ml_workers_connected Number of connected workers\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_workers_connected gauge\n")
	fmt.Fprintf(w, "fetch_ml_workers_connected %d\n\n", h.metrics.WorkersConnected)

	fmt.Fprintf(w, "# HELP fetch_ml_queue_depth Current queue depth\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_queue_depth gauge\n")
	fmt.Fprintf(w, "fetch_ml_queue_depth{pool=\"batch\"} %d\n", h.batchQueue.Len())
	fmt.Fprintf(w, "fetch_ml_queue_depth{pool=\"service\"} %d\n\n", h.serviceQueue.Len())

	fmt.Fprintf(w, "# HELP fetch_ml_jobs_total Total jobs by result\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_jobs_total counter\n")
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"completed\"} %d\n", h.metrics.JobsCompleted)
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"failed\"} %d\n", h.metrics.JobsFailed)
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"cancelled\"} %d\n\n", h.metrics.JobsCancelled)

	fmt.Fprintf(w, "# HELP fetch_ml_slot_utilization Slot utilization by worker\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_slot_utilization gauge\n")
	for workerID, slots := range h.metrics.WorkerSlots {
		if slots.BatchTotal > 0 {
			utilization := float64(slots.BatchInUse) / float64(slots.BatchTotal)
			fmt.Fprintf(w, "fetch_ml_slot_utilization{worker=\"%s\",pool=\"batch\"} %.2f\n", workerID, utilization)
		}
		if slots.ServiceTotal > 0 {
			utilization := float64(slots.ServiceInUse) / float64(slots.ServiceTotal)
			fmt.Fprintf(w, "fetch_ml_slot_utilization{worker=\"%s\",pool=\"service\"} %.2f\n", workerID, utilization)
		}
	}
}

// tryGangAlloc attempts to allocate a multi-node job to a worker
// It tracks partial allocations and dispatches when all nodes are committed
func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) { h.mu.Lock() defer h.mu.Unlock() // Check if this worker can run the job if !h.canAdmit(task, wc) { return } jobID := task.ID pending, ok := h.multiNodePending[jobID] if !ok { // First worker for this job pending = &MultiNodeJob{ JobID: jobID, TotalNodes: task.Spec.NodeCount, Assignments: make([]*NodeAssignment, 0, task.Spec.NodeCount), CommittedAt: time.Now(), } h.multiNodePending[jobID] = pending } // Add this worker to the pending assignment assignment := &NodeAssignment{ Worker: wc, Rank: len(pending.Assignments), CommittedAt: time.Now(), } pending.Assignments = append(pending.Assignments, assignment) // Reserve slots on this worker wc.slots.BatchInUse++ delete(h.readyWorkers, wc.workerID) // Check if we have all nodes if len(pending.Assignments) < task.Spec.NodeCount { // Still waiting for more workers return } // All nodes committed - dispatch simultaneously headAddr := pending.Assignments[0].Worker.capabilities.Hostname for i, a := range pending.Assignments { // Create rank-specific job spec spec := h.buildRankedSpec(task, i, headAddr, task.Spec.NodeCount) msg := Message{ Type: MsgJobAssign, Payload: mustMarshal(spec), } a.Worker.send <- msg } // Clean up pending state delete(h.multiNodePending, jobID) slog.Info("multi-node job dispatched", "job_id", jobID, "nodes", task.Spec.NodeCount, "head_addr", headAddr) } // buildRankedSpec creates a job spec with rank-specific template variables resolved func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec { // Clone the spec and add rank info to metadata and env spec := task.Spec spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3) for k, v := range task.Spec.Metadata { spec.Metadata[k] = v } spec.Metadata["HEAD_ADDR"] = headAddr spec.Metadata["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize) spec.Metadata["NODE_RANK"] = fmt.Sprintf("%d", rank) // Also set in Env for job runtime if spec.Env == nil 
{ spec.Env = make(map[string]string) } spec.Env["HEAD_ADDR"] = headAddr spec.Env["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize) spec.Env["NODE_RANK"] = fmt.Sprintf("%d", rank) return spec } // handleMultiNodeReady handles a ready signal for a multi-node job // Returns true if the job was handled (either assigned or queued for gang alloc) func (h *SchedulerHub) handleMultiNodeReady(task *Task, wc *WorkerConn) bool { if task.Spec.NodeCount <= 1 { return false // Not a multi-node job } h.tryGangAlloc(task, wc) return true } func mustMarshal(v any) []byte { b, _ := json.Marshal(v) return b }