- Update task domain model - Improve scheduler hub and priority queue - Enhance protocol definitions - Update manifest schema and run handling
1052 lines
28 KiB
Go
1052 lines
28 KiB
Go
package scheduler
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"

	"github.com/jfraeys/fetch_ml/internal/audit"
)
|
|
|
|
// upgrader converts incoming HTTP requests into WebSocket connections for
// both workers and metrics clients.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// NOTE(review): accepting every Origin is safe only on trusted
	// networks; production deployments should restrict this.
	CheckOrigin: func(r *http.Request) bool {
		return true // Allow all origins (configurable in production)
	},
}
|
|
|
|
// SchedulerHub manages worker connections and job scheduling
type SchedulerHub struct {
	mu                sync.RWMutex              // guards the worker/assignment maps below
	workers           map[string]*WorkerConn    // all connected workers, keyed by worker ID
	readyWorkers      map[string]*WorkerConn    // workers that signalled ready and got no work
	batchQueue        *PriorityQueue            // queued batch jobs
	serviceQueue      *PriorityQueue            // queued service jobs
	reservations      map[string]*Reservation   // GPU reservations held for starving jobs
	multiNodePending  map[string]*MultiNodeJob  // partially committed gang allocations
	pendingAcceptance map[string]*JobAssignment // assignments awaiting worker acknowledgment
	runningTasks      map[string]*Task          // Track assigned+accepted tasks
	state             *StateStore               // append-only event log for crash recovery
	starvation        *StarvationTracker        // detects jobs waiting past the threshold
	metrics           *SchedulerMetrics         // aggregate counters and per-worker slot view
	auditor           *audit.Logger             // audit trail sink
	tokenValidator    *TokenValidator           // maps bearer tokens to client IDs
	quotaManager      *PluginQuotaManager       // NEW: plugin GPU quota manager
	config            HubConfig                 // immutable after NewHub (cert paths set in Start)
	ctx               context.Context           // cancelled by Stop; stops background loops
	cancel            context.CancelFunc
	server            *http.Server // WSS endpoint; created by Start
	listener          net.Listener // created by Start
}
|
|
|
|
// HubConfig holds the scheduler hub's tunable settings.
type HubConfig struct {
	BindAddr                string // TCP address the WSS server listens on
	CertFile                string // TLS certificate path; empty disables TLS unless auto-generated
	KeyFile                 string // TLS private key path
	AutoGenerateCerts       bool   // generate a self-signed cert pair under StateDir when paths are unset
	StateDir                string // directory for the state file and generated certs
	DefaultBatchSlots       int
	DefaultServiceSlots     int
	StarvationThresholdMins float64           // queue wait (minutes) before a job earns a reservation
	PriorityAgingRate       float64           // priority aging rate; 0 falls back to 0.1 in NewHub
	GangAllocTimeoutSecs    int               // seconds before an incomplete gang allocation is released
	AcceptanceTimeoutSecs   int               // seconds a worker has to accept an assignment
	LocalMode               bool              // NOTE(review): not referenced in this file — confirm usage elsewhere
	WorkerTokens            map[string]string // token -> workerID
	PluginQuota             PluginQuotaConfig // NEW: plugin GPU quota configuration
}
|
|
|
|
// WorkerConn represents a connected worker
type WorkerConn struct {
	workerID     string
	conn         *websocket.Conn
	capabilities WorkerCapabilities // reported at registration (GPU count, hostname, ...)
	slots        SlotStatus         // latest slot availability from heartbeat/ready messages
	lease        *Lease             // NOTE(review): not referenced in this file — confirm usage elsewhere
	activeTasks  map[string]struct{}
	send         chan Message  // outbound messages; drained by the per-connection send loop
	hub          *SchedulerHub // back-reference to the owning hub
	mu           sync.Mutex    // guards slots
}
|
|
|
|
// Lease tracks job ownership
type Lease struct {
	TaskID    string
	WorkerID  string // worker currently holding the lease
	ExpiresAt time.Time
}
|
|
|
|
// Reservation prevents starvation of large jobs
type Reservation struct {
	TaskID    string
	GPUCount  int // GPUs held back from smaller jobs (see canAdmit)
	CreatedAt time.Time
}
|
|
|
|
// MultiNodeJob tracks gang allocation state
type MultiNodeJob struct {
	JobID       string
	TotalNodes  int               // number of workers required before dispatch
	Assignments []*NodeAssignment // workers committed so far, in rank order
	CommittedAt time.Time         // when the first worker committed; used for gang timeout
}
|
|
|
|
// NodeAssignment is one worker's slot within a forming gang allocation.
type NodeAssignment struct {
	Worker      *WorkerConn
	Rank        int // distributed rank; rank 0 is the head node
	CommittedAt time.Time
}
|
|
|
|
// JobAssignment tracks acceptance state
type JobAssignment struct {
	TaskID             string
	WorkerID           string
	AssignedAt         time.Time
	AcceptanceDeadline time.Time // assignment is re-queued if not accepted by this time
	Accepted           bool      // set true once the worker acknowledges the job
	Task               *Task     // Reference to the task (removed from queue)
}
|
|
|
|
// StarvationTracker monitors long-waiting jobs
type StarvationTracker struct {
	mu        sync.RWMutex
	threshold time.Duration // queue wait beyond which a job gets a GPU reservation
}
|
|
|
|
// SchedulerMetrics tracks scheduler statistics
type SchedulerMetrics struct {
	mu                sync.RWMutex
	WorkersConnected  int
	QueueDepthBatch   int // NOTE(review): not updated in this file; queue depth is read live from the queues
	QueueDepthService int
	JobsCompleted     int
	JobsFailed        int
	JobsCancelled     int
	WorkerSlots       map[string]SlotStatus // latest slot status per worker ID
}
|
|
|
|
// NewHub creates a new scheduler hub
|
|
func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
// Initialize state store
|
|
statePath := cfg.StateDir + "/scheduler.state"
|
|
state, err := NewStateStore(statePath)
|
|
if err != nil {
|
|
cancel()
|
|
return nil, fmt.Errorf("init state store: %w", err)
|
|
}
|
|
|
|
agingRate := cfg.PriorityAgingRate
|
|
if agingRate == 0 {
|
|
agingRate = 0.1
|
|
}
|
|
|
|
hub := &SchedulerHub{
|
|
workers: make(map[string]*WorkerConn),
|
|
readyWorkers: make(map[string]*WorkerConn),
|
|
batchQueue: NewPriorityQueue(agingRate),
|
|
serviceQueue: NewPriorityQueue(agingRate),
|
|
reservations: make(map[string]*Reservation),
|
|
multiNodePending: make(map[string]*MultiNodeJob),
|
|
pendingAcceptance: make(map[string]*JobAssignment),
|
|
runningTasks: make(map[string]*Task),
|
|
state: state,
|
|
starvation: &StarvationTracker{
|
|
threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,
|
|
},
|
|
metrics: &SchedulerMetrics{
|
|
WorkerSlots: make(map[string]SlotStatus),
|
|
},
|
|
auditor: auditor,
|
|
tokenValidator: NewTokenValidator(cfg.WorkerTokens),
|
|
quotaManager: NewPluginQuotaManager(cfg.PluginQuota), // NEW: initialize quota manager
|
|
config: cfg,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
return hub, nil
|
|
}
|
|
|
|
// Start initializes the scheduler, starts the HTTP server, and replays state
|
|
func (h *SchedulerHub) Start() error {
|
|
// Replay state first
|
|
events, err := h.state.Replay()
|
|
if err != nil {
|
|
return fmt.Errorf("state replay failed: %w", err)
|
|
}
|
|
|
|
for _, ev := range events {
|
|
switch ev.Type {
|
|
case EventJobEnqueued:
|
|
h.restoreJob(ev)
|
|
case EventJobAssigned:
|
|
h.restoreAssignment(ev)
|
|
case EventJobCompleted, EventJobFailed, EventJobCancelled:
|
|
// terminal — skip
|
|
}
|
|
}
|
|
|
|
// Start WSS server (unified protocol)
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/ws/worker", h.HandleConnection)
|
|
|
|
listener, err := net.Listen("tcp", h.config.BindAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen: %w", err)
|
|
}
|
|
h.listener = listener
|
|
|
|
h.server = &http.Server{Handler: mux}
|
|
|
|
// Auto-generate self-signed certs if requested
|
|
if h.config.AutoGenerateCerts && (h.config.CertFile == "" || h.config.KeyFile == "") {
|
|
certFile := h.config.StateDir + "/scheduler.crt"
|
|
keyFile := h.config.StateDir + "/scheduler.key"
|
|
if err := GenerateSelfSignedCert(certFile, keyFile); err != nil {
|
|
return fmt.Errorf("failed to generate self-signed cert: %w", err)
|
|
}
|
|
h.config.CertFile = certFile
|
|
h.config.KeyFile = keyFile
|
|
}
|
|
|
|
// Start with TLS if certificates are configured
|
|
if h.config.CertFile != "" && h.config.KeyFile != "" {
|
|
go h.server.ServeTLS(listener, h.config.CertFile, h.config.KeyFile)
|
|
} else {
|
|
go h.server.Serve(listener)
|
|
}
|
|
|
|
// Start background tasks
|
|
go h.checkAcceptanceTimeouts()
|
|
go h.checkGangTimeouts()
|
|
go h.checkStarvation()
|
|
|
|
// Grace period: workers have 30s to reconnect before assigned jobs are orphaned
|
|
time.AfterFunc(30*time.Second, h.reconcileOrphans)
|
|
|
|
return nil
|
|
}
|
|
|
|
// Addr returns the listening address of the scheduler
|
|
func (h *SchedulerHub) Addr() string {
|
|
if h.listener == nil {
|
|
return ""
|
|
}
|
|
return h.listener.Addr().String()
|
|
}
|
|
|
|
// Stop gracefully shuts down the scheduler
|
|
func (h *SchedulerHub) Stop() {
|
|
h.cancel()
|
|
h.state.Close()
|
|
}
|
|
|
|
// HandleConnection handles WSS connections from workers and metrics clients
|
|
func (h *SchedulerHub) HandleConnection(w http.ResponseWriter, r *http.Request) {
|
|
// Validate token
|
|
token := ExtractBearerToken(r.Header.Get("Authorization"))
|
|
clientID, ok := h.tokenValidator.Validate(token)
|
|
if !ok {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Upgrade to WebSocket
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
http.Error(w, "upgrade failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Check if this is a metrics client (special token prefix)
|
|
if strings.HasPrefix(clientID, "metrics-") {
|
|
go h.runMetricsClient(clientID, conn)
|
|
return
|
|
}
|
|
|
|
go h.runWorker(clientID, conn)
|
|
}
|
|
|
|
// runWorker owns the lifecycle of one worker connection: it registers the
// worker with the hub, runs a send loop draining wc.send to the socket, and
// dispatches every inbound message until the connection closes, at which
// point the worker is deregistered.
func (h *SchedulerHub) runWorker(workerID string, conn *websocket.Conn) {
	wc := &WorkerConn{
		workerID:    workerID,
		conn:        conn,
		slots:       SlotStatus{},
		activeTasks: make(map[string]struct{}),
		// Small buffer decouples schedulers sending messages from the
		// socket write loop below.
		send: make(chan Message, 10),
		hub:  h,
	}

	// Register under the hub lock so other goroutines can see this worker.
	h.mu.Lock()
	h.workers[workerID] = wc
	h.metrics.WorkersConnected++
	h.mu.Unlock()

	// Deregister and close the socket on any exit from the receive loop.
	defer func() {
		h.mu.Lock()
		delete(h.workers, workerID)
		delete(h.readyWorkers, workerID)
		h.metrics.WorkersConnected--
		h.mu.Unlock()
		conn.Close()
	}()

	// Send loop.
	// NOTE(review): wc.send is never closed, so this goroutine blocks on
	// range forever after the worker disconnects — one leaked goroutine per
	// connection. Closing it safely requires coordinating with every sender
	// (handleReady, reconcileWorker, sendPrewarmHint, tryGangAlloc); flagged
	// rather than changed here. WriteJSON errors are also dropped — TODO
	// confirm whether a failed write should tear down the connection.
	go func() {
		for msg := range wc.send {
			conn.WriteJSON(msg)
		}
	}()

	// Receive loop: block on the socket, hand each frame to handleMessage.
	for {
		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			return // Connection closed
		}
		h.handleMessage(wc, msg)
	}
}
|
|
|
|
func (h *SchedulerHub) handleMessage(wc *WorkerConn, msg Message) {
|
|
switch msg.Type {
|
|
case MsgRegister:
|
|
var reg WorkerRegistration
|
|
json.Unmarshal(msg.Payload, ®)
|
|
h.reconcileWorker(reg, wc)
|
|
case MsgHeartbeat:
|
|
var hb HeartbeatPayload
|
|
json.Unmarshal(msg.Payload, &hb)
|
|
wc.mu.Lock()
|
|
wc.slots = hb.Slots
|
|
wc.mu.Unlock()
|
|
h.updateWorkerMetrics(wc.workerID, hb.Slots)
|
|
case MsgReadyForWork:
|
|
var ready ReadyPayload
|
|
json.Unmarshal(msg.Payload, &ready)
|
|
wc.mu.Lock()
|
|
wc.slots = ready.Slots
|
|
wc.mu.Unlock()
|
|
h.handleReady(wc, ready.Slots)
|
|
case MsgJobAccepted:
|
|
var taskID string
|
|
json.Unmarshal(msg.Payload, &taskID)
|
|
h.handleJobAccepted(wc.workerID, taskID)
|
|
case MsgJobResult:
|
|
var result JobResultPayload
|
|
json.Unmarshal(msg.Payload, &result)
|
|
h.handleJobResult(wc.workerID, result)
|
|
case MsgServiceHealth:
|
|
// Service health updates - logged but no action needed for MVP
|
|
var health ServiceHealthPayload
|
|
json.Unmarshal(msg.Payload, &health)
|
|
slog.Debug("service health update", "worker", wc.workerID, "task", health.TaskID, "healthy", health.Healthy)
|
|
}
|
|
}
|
|
|
|
// reconcileWorker resolves differences between the tasks a reconnecting
// worker reports as active and the scheduler's own view, then acknowledges
// the registration. Unknown, re-queued, or duplicated tasks are cancelled
// on the worker; tasks the scheduler marked orphaned are restored.
//
// NOTE(review): getTask reads the queues and runningTasks without holding
// h.mu, and wc.capabilities is assigned without wc.mu — confirm this is
// only ever called from the worker's own receive goroutine before anything
// else can touch wc.
func (h *SchedulerHub) reconcileWorker(reg WorkerRegistration, wc *WorkerConn) {
	wc.capabilities = reg.Capabilities

	for _, reported := range reg.ActiveTasks {
		task := h.getTask(reported.TaskID)

		switch {
		case task == nil:
			// Case 1: Scheduler has no record — kill it
			wc.send <- Message{Type: MsgJobCancel,
				Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}

		case task.Status == "orphaned":
			// Case 2: Scheduler thought lost — restore, worker is running it
			h.restoreLease(reported.TaskID, wc.workerID)
			task.Status = "running"

		case task.Status == "queued" || task.Status == "assigned":
			// Case 3: Re-queued while worker was disconnected — cancel on this worker
			wc.send <- Message{Type: MsgJobCancel,
				Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}

		case task.Status == "running" && task.WorkerID != reg.ID:
			// Case 4: Running on two workers — cancel on reconnecting worker
			wc.send <- Message{Type: MsgJobCancel,
				Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}
		}
	}

	// Send registration acknowledgment
	wc.send <- Message{Type: MsgAck}
}
|
|
|
|
// handleReady reacts to a worker's ready-for-work signal: multi-node jobs
// get first chance (gang allocation), then regular single-node matching.
// If nothing fits, the worker is parked in readyWorkers and told MsgNoWork.
//
// NOTE(review): the batchQueue iteration and canAdmit's read of
// h.reservations happen without h.mu held; tryGangAlloc (called via
// handleMultiNodeReady) takes h.mu itself, so the lock must NOT be held
// here — confirm queue Items() is internally synchronized.
func (h *SchedulerHub) handleReady(wc *WorkerConn, _ SlotStatus) {
	// Check for multi-node jobs first
	for _, task := range h.batchQueue.Items() {
		if task.Spec.NodeCount > 1 && h.canAdmit(task, wc) {
			if h.handleMultiNodeReady(task, wc) {
				return // Multi-node job is being handled
			}
		}
	}

	// Fall through to regular single-node matching
	if task := h.findMatch(wc); task != nil {
		wc.send <- h.assignTask(task, wc)
		return
	}

	// No work: refresh starvation reservations, park the worker.
	h.starvation.CheckAndReserve(h)
	h.mu.Lock()
	h.readyWorkers[wc.workerID] = wc
	h.mu.Unlock()
	wc.send <- Message{Type: MsgNoWork}
}
|
|
|
|
func (h *SchedulerHub) findMatch(wc *WorkerConn) *Task {
|
|
if wc.slots.ServiceAvailable() > 0 {
|
|
if task := h.scanFit(h.serviceQueue, wc); task != nil {
|
|
return task
|
|
}
|
|
}
|
|
if wc.slots.BatchAvailable() > 0 {
|
|
if task := h.scanFit(h.batchQueue, wc); task != nil {
|
|
return task
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task {
|
|
for _, task := range q.Items() {
|
|
if task.Spec.NodeCount > 1 {
|
|
continue // gang allocator handles multi-node
|
|
}
|
|
if h.canAdmit(task, wc) {
|
|
return task
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
|
|
for _, res := range h.reservations {
|
|
if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
|
|
if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return worker.capabilities.GPUCount >= candidate.Spec.GPUCount
|
|
}
|
|
|
|
// canRequeue checks if a task can be re-queued based on wall-clock elapsed time.
|
|
// Returns false if the task has exceeded its MaxRuntime budget.
|
|
func (h *SchedulerHub) canRequeue(task *Task) bool {
|
|
if task.FirstAssignedAt.IsZero() {
|
|
return true // Never assigned, can always re-queue
|
|
}
|
|
|
|
elapsed := time.Since(task.FirstAssignedAt)
|
|
maxRuntime := task.MaxRuntime
|
|
if maxRuntime == 0 {
|
|
maxRuntime = 24 * time.Hour // Default 24h
|
|
}
|
|
|
|
if elapsed > maxRuntime {
|
|
// Task exceeded wall-clock budget - fail it
|
|
slog.Info("task exceeded max runtime, failing",
|
|
"task_id", task.ID,
|
|
"elapsed", elapsed,
|
|
"max_runtime", maxRuntime)
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// assignTask removes task from both queues, records it as pending
// acceptance by wc, persists the assignment, and returns the MsgJobAssign
// message to send to the worker (including the task's remaining wall-clock
// budget). The statement order here is load-bearing: queue removal must
// precede the pendingAcceptance insert to prevent double-assignment.
//
// NOTE(review): between Remove and the h.mu.Lock below the task lives in
// neither a queue nor pendingAcceptance; a crash in that window would drop
// it until state replay. Also, h.state.Append's result is not checked —
// confirm StateStore surfaces persistence failures elsewhere.
func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
	// Remove from queue first (prevent double-assignment)
	h.batchQueue.Remove(task.ID)
	h.serviceQueue.Remove(task.ID)

	// Set FirstAssignedAt if this is the first assignment
	if task.FirstAssignedAt.IsZero() {
		task.FirstAssignedAt = time.Now()
	}

	// Cache MaxRuntime from spec
	maxHours := task.Spec.MaxRuntimeHours
	if maxHours <= 0 {
		maxHours = 24 // Default 24h
	}
	if maxHours > 168 {
		maxHours = 168 // Hard cap at 7d
	}
	task.MaxRuntime = time.Duration(maxHours) * time.Hour

	// Calculate remaining time budget
	elapsed := time.Since(task.FirstAssignedAt)
	remaining := task.MaxRuntime - elapsed
	if remaining < 0 {
		remaining = 0
	}

	// Track pending acceptance with task reference
	h.mu.Lock()
	h.pendingAcceptance[task.ID] = &JobAssignment{
		TaskID:             task.ID,
		WorkerID:           wc.workerID,
		AssignedAt:         time.Now(),
		AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
		Accepted:           false,
		Task:               task, // Store reference since removed from queue
	}
	h.mu.Unlock()

	// Persist assignment
	h.state.Append(StateEvent{
		Type:     EventJobAssigned,
		TaskID:   task.ID,
		WorkerID: wc.workerID,
	})

	// Send job assignment with remaining time budget
	payload := JobAssignPayload{
		Spec:          task.Spec,
		RemainingTime: remaining,
	}

	return Message{
		Type:    MsgJobAssign,
		Payload: mustMarshal(payload),
	}
}
|
|
|
|
func (h *SchedulerHub) handleJobAccepted(workerID, taskID string) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
if assignment, ok := h.pendingAcceptance[taskID]; ok {
|
|
assignment.Accepted = true
|
|
|
|
// Track as running task
|
|
task := assignment.Task
|
|
if task != nil {
|
|
task.Status = "running"
|
|
task.WorkerID = workerID
|
|
h.runningTasks[taskID] = task
|
|
}
|
|
|
|
// NEW: Record quota usage for service jobs
|
|
if task != nil && task.Spec.Type == JobTypeService {
|
|
if h.quotaManager != nil {
|
|
pluginName := task.Spec.Metadata["plugin_name"]
|
|
if pluginName == "" {
|
|
pluginName = "default"
|
|
}
|
|
h.quotaManager.RecordUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
// NEW: Release quota usage for service jobs before deleting pending acceptance
|
|
if task := h.runningTasks[result.TaskID]; task != nil && task.Spec.Type == JobTypeService {
|
|
if h.quotaManager != nil {
|
|
pluginName := task.Spec.Metadata["plugin_name"]
|
|
if pluginName == "" {
|
|
pluginName = "default"
|
|
}
|
|
h.quotaManager.ReleaseUsage(task.Spec.UserID, pluginName, task.Spec.GPUCount)
|
|
}
|
|
}
|
|
|
|
delete(h.pendingAcceptance, result.TaskID)
|
|
delete(h.runningTasks, result.TaskID)
|
|
|
|
eventType := EventJobCompleted
|
|
switch result.State {
|
|
case "failed":
|
|
eventType = EventJobFailed
|
|
h.metrics.JobsFailed++
|
|
case "cancelled":
|
|
eventType = EventJobCancelled
|
|
h.metrics.JobsCancelled++
|
|
default:
|
|
h.metrics.JobsCompleted++
|
|
}
|
|
|
|
h.state.Append(StateEvent{
|
|
Type: eventType,
|
|
TaskID: result.TaskID,
|
|
WorkerID: workerID,
|
|
})
|
|
}
|
|
|
|
// checkAcceptanceTimeouts re-queues jobs that weren't accepted
|
|
func (h *SchedulerHub) checkAcceptanceTimeouts() {
|
|
ticker := time.NewTicker(5 * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-h.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
h.mu.Lock()
|
|
for taskID, a := range h.pendingAcceptance {
|
|
if !a.Accepted && time.Now().After(a.AcceptanceDeadline) {
|
|
if a.Task != nil {
|
|
a.Task.Status = "queued"
|
|
h.batchQueue.Add(a.Task)
|
|
}
|
|
delete(h.pendingAcceptance, taskID)
|
|
if wc, ok := h.workers[a.WorkerID]; ok {
|
|
wc.slots = SlotStatus{}
|
|
}
|
|
}
|
|
}
|
|
h.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
// checkGangTimeouts releases reserved slots for incomplete gangs
|
|
func (h *SchedulerHub) checkGangTimeouts() {
|
|
ticker := time.NewTicker(10 * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-h.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
h.mu.Lock()
|
|
for jobID, pending := range h.multiNodePending {
|
|
if time.Since(pending.CommittedAt) > time.Duration(h.config.GangAllocTimeoutSecs)*time.Second {
|
|
for _, a := range pending.Assignments {
|
|
a.Worker.slots = SlotStatus{}
|
|
h.readyWorkers[a.Worker.workerID] = a.Worker
|
|
}
|
|
delete(h.multiNodePending, jobID)
|
|
}
|
|
}
|
|
h.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *SchedulerHub) checkStarvation() {
|
|
ticker := time.NewTicker(30 * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-h.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
h.starvation.CheckAndReserve(h)
|
|
}
|
|
}
|
|
}
|
|
|
|
// CheckAndReserve scans the batch queue for tasks waiting longer than the
// starvation threshold and creates GPU reservations for them, which
// canAdmit then honors when admitting smaller jobs.
//
// NOTE(review): the read (RLock) and write (Lock) phases are separate;
// another goroutine could insert a reservation in the unlock window, in
// which case the write below overwrites it with an equivalent entry —
// harmless today, but confirm if Reservation ever gains mutable state.
func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) {
	st.mu.Lock()
	defer st.mu.Unlock()

	// First check which tasks need reservation under h.mu.RLock
	tasksToReserve := make([]*Task, 0)
	h.mu.RLock()
	for _, task := range h.batchQueue.Items() {
		if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservationLocked(h, task.ID) {
			tasksToReserve = append(tasksToReserve, task)
		}
	}
	h.mu.RUnlock()

	// Now acquire Lock to add reservations
	if len(tasksToReserve) > 0 {
		h.mu.Lock()
		for _, task := range tasksToReserve {
			h.reservations[task.ID] = &Reservation{
				TaskID:    task.ID,
				GPUCount:  task.Spec.GPUCount,
				CreatedAt: time.Now(),
			}
		}
		h.mu.Unlock()
	}
}
|
|
|
|
func (st *StarvationTracker) hasReservationLocked(h *SchedulerHub, taskID string) bool {
|
|
_, exists := h.reservations[taskID]
|
|
return exists
|
|
}
|
|
|
|
// Helper methods (stubs to be implemented)
|
|
|
|
// GetTask returns a task by ID (public API).
// It delegates to the unexported getTask, which checks both queues and the
// running-task map; nil is returned when the ID is unknown.
func (h *SchedulerHub) GetTask(taskID string) *Task {
	return h.getTask(taskID)
}
|
|
|
|
// SubmitJob submits a new job to the scheduler (public API)
|
|
func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
|
|
if spec.ID == "" {
|
|
return fmt.Errorf("job ID is required")
|
|
}
|
|
|
|
// NEW: Check plugin quotas for service jobs
|
|
if spec.Type == JobTypeService && h.quotaManager != nil {
|
|
pluginName := spec.Metadata["plugin_name"]
|
|
if pluginName == "" {
|
|
pluginName = "default"
|
|
}
|
|
if err := h.quotaManager.CheckQuota(spec.UserID, pluginName, spec.GPUCount); err != nil {
|
|
return fmt.Errorf("quota exceeded: %w", err)
|
|
}
|
|
}
|
|
|
|
task := &Task{
|
|
ID: spec.ID,
|
|
Spec: spec,
|
|
Status: "queued",
|
|
SubmittedAt: time.Now(),
|
|
}
|
|
|
|
// Persist to state store
|
|
h.state.Append(StateEvent{
|
|
Type: EventJobEnqueued,
|
|
TaskID: spec.ID,
|
|
Payload: mustMarshal(spec),
|
|
Timestamp: time.Now(),
|
|
})
|
|
|
|
// Add to appropriate queue
|
|
if spec.Type == JobTypeService {
|
|
h.serviceQueue.Add(task)
|
|
} else {
|
|
h.batchQueue.Add(task)
|
|
}
|
|
|
|
// Send prewarm hint if job has snapshot
|
|
h.sendPrewarmHint(task)
|
|
|
|
slog.Info("job submitted", "task_id", spec.ID, "type", spec.Type)
|
|
return nil
|
|
}
|
|
|
|
func (h *SchedulerHub) getTask(taskID string) *Task {
|
|
t := h.batchQueue.Get(taskID)
|
|
if t != nil {
|
|
return t
|
|
}
|
|
t = h.serviceQueue.Get(taskID)
|
|
if t != nil {
|
|
return t
|
|
}
|
|
return h.runningTasks[taskID]
|
|
}
|
|
|
|
func (h *SchedulerHub) restoreJob(ev StateEvent) {
|
|
// Parse job spec from event payload
|
|
var spec JobSpec
|
|
if err := json.Unmarshal(ev.Payload, &spec); err != nil {
|
|
slog.Error("failed to restore job", "task_id", ev.TaskID, "error", err)
|
|
return
|
|
}
|
|
|
|
task := &Task{
|
|
ID: ev.TaskID,
|
|
Spec: spec,
|
|
Status: "queued",
|
|
SubmittedAt: ev.Timestamp,
|
|
}
|
|
|
|
// Add to appropriate queue
|
|
if spec.Type == JobTypeService {
|
|
h.serviceQueue.Add(task)
|
|
} else {
|
|
h.batchQueue.Add(task)
|
|
}
|
|
slog.Info("restored job from state", "task_id", ev.TaskID, "type", spec.Type)
|
|
}
|
|
|
|
func (h *SchedulerHub) restoreAssignment(ev StateEvent) {
|
|
// Parse assignment from event payload
|
|
var payload struct {
|
|
WorkerID string `json:"worker_id"`
|
|
}
|
|
if err := json.Unmarshal(ev.Payload, &payload); err != nil {
|
|
slog.Error("failed to restore assignment", "task_id", ev.TaskID, "error", err)
|
|
return
|
|
}
|
|
|
|
// Restore pending acceptance state
|
|
h.pendingAcceptance[ev.TaskID] = &JobAssignment{
|
|
TaskID: ev.TaskID,
|
|
WorkerID: payload.WorkerID,
|
|
AssignedAt: ev.Timestamp,
|
|
AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
|
|
Accepted: false,
|
|
}
|
|
|
|
slog.Info("restored assignment from state", "task_id", ev.TaskID, "worker_id", payload.WorkerID)
|
|
}
|
|
|
|
func (h *SchedulerHub) restoreLease(taskID, workerID string) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
if assignment, ok := h.pendingAcceptance[taskID]; ok {
|
|
assignment.Accepted = true
|
|
slog.Info("restored lease", "task_id", taskID, "worker_id", workerID)
|
|
} else {
|
|
// Create a new lease record
|
|
h.pendingAcceptance[taskID] = &JobAssignment{
|
|
TaskID: taskID,
|
|
WorkerID: workerID,
|
|
AssignedAt: time.Now(),
|
|
AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
|
|
Accepted: true,
|
|
}
|
|
slog.Info("created new lease on reconnect", "task_id", taskID, "worker_id", workerID)
|
|
}
|
|
}
|
|
|
|
func (h *SchedulerHub) reconcileOrphans() {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
// After grace period (30s), any job still assigned to a disconnected worker
|
|
// is considered orphaned and should be re-queued
|
|
for taskID, assignment := range h.pendingAcceptance {
|
|
if assignment.Accepted {
|
|
// Job was accepted but worker is gone (not in h.workers)
|
|
if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected {
|
|
task := assignment.Task
|
|
if task != nil {
|
|
task.Status = "orphaned"
|
|
h.batchQueue.Add(task)
|
|
h.state.Append(StateEvent{
|
|
Type: EventJobRequeued,
|
|
TaskID: taskID,
|
|
WorkerID: assignment.WorkerID,
|
|
})
|
|
slog.Info("orphaned job re-queued", "task_id", taskID, "worker_id", assignment.WorkerID)
|
|
}
|
|
delete(h.pendingAcceptance, taskID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// updateWorkerMetrics records the latest slot status reported by a worker's
// heartbeat into the metrics map, under the metrics lock.
func (h *SchedulerHub) updateWorkerMetrics(workerID string, slots SlotStatus) {
	h.metrics.mu.Lock()
	defer h.metrics.mu.Unlock()
	h.metrics.WorkerSlots[workerID] = slots
}
|
|
|
|
// sendPrewarmHint sends a prewarm hint to an idle worker when a job with snapshot is enqueued
|
|
// TODO: Call this when jobs are enqueued via scheduler API
|
|
func (h *SchedulerHub) sendPrewarmHint(task *Task) {
|
|
if task.Spec.SnapshotID == "" {
|
|
return
|
|
}
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
for _, wc := range h.readyWorkers {
|
|
if h.canAdmit(task, wc) {
|
|
wc.send <- Message{
|
|
Type: MsgPrewarmHint,
|
|
Payload: mustMarshal(PrewarmHintPayload{
|
|
TaskID: task.ID,
|
|
SnapshotID: task.Spec.SnapshotID,
|
|
SnapshotSHA: task.Spec.SnapshotSHA,
|
|
}),
|
|
}
|
|
return // one worker prewarms — not all of them
|
|
}
|
|
}
|
|
}
|
|
|
|
// runMetricsClient handles metrics clients over WSS
|
|
func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
|
|
defer conn.Close()
|
|
|
|
for {
|
|
var msg Message
|
|
if err := conn.ReadJSON(&msg); err != nil {
|
|
return // Connection closed
|
|
}
|
|
|
|
if msg.Type == MsgMetricsRequest {
|
|
metrics := h.GetMetricsPayload()
|
|
conn.WriteJSON(Message{
|
|
Type: MsgMetricsResponse,
|
|
Payload: mustMarshal(metrics),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetMetricsPayload returns current metrics as a map (public API)
|
|
func (h *SchedulerHub) GetMetricsPayload() map[string]any {
|
|
h.metrics.mu.RLock()
|
|
defer h.metrics.mu.RUnlock()
|
|
|
|
return map[string]any{
|
|
"workers_connected": h.metrics.WorkersConnected,
|
|
"queue_depth_batch": h.batchQueue.Len(),
|
|
"queue_depth_service": h.serviceQueue.Len(),
|
|
"jobs_completed": h.metrics.JobsCompleted,
|
|
"jobs_failed": h.metrics.JobsFailed,
|
|
"jobs_cancelled": h.metrics.JobsCancelled,
|
|
"worker_slots": h.metrics.WorkerSlots,
|
|
}
|
|
}
|
|
|
|
// ServeMetrics serves Prometheus-formatted metrics over HTTP.
//
// Deprecated: use the WSS metrics protocol (MsgMetricsRequest) instead.
//
// Exposes worker count, live queue depths, job outcome counters, and
// per-worker slot utilization, all under the metrics read lock.
func (h *SchedulerHub) ServeMetrics(w http.ResponseWriter, r *http.Request) {
	h.metrics.mu.RLock()
	defer h.metrics.mu.RUnlock()

	// Prometheus text exposition format, version 0.0.4.
	w.Header().Set("Content-Type", "text/plain; version=0.0.4")

	fmt.Fprintf(w, "# HELP fetch_ml_workers_connected Number of connected workers\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_workers_connected gauge\n")
	fmt.Fprintf(w, "fetch_ml_workers_connected %d\n\n", h.metrics.WorkersConnected)

	// Queue depths are read live from the queues, not from cached counters.
	fmt.Fprintf(w, "# HELP fetch_ml_queue_depth Current queue depth\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_queue_depth gauge\n")
	fmt.Fprintf(w, "fetch_ml_queue_depth{pool=\"batch\"} %d\n", h.batchQueue.Len())
	fmt.Fprintf(w, "fetch_ml_queue_depth{pool=\"service\"} %d\n\n", h.serviceQueue.Len())

	fmt.Fprintf(w, "# HELP fetch_ml_jobs_total Total jobs by result\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_jobs_total counter\n")
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"completed\"} %d\n", h.metrics.JobsCompleted)
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"failed\"} %d\n", h.metrics.JobsFailed)
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"cancelled\"} %d\n\n", h.metrics.JobsCancelled)

	// Per-worker utilization; pools with zero total slots are omitted to
	// avoid division by zero.
	fmt.Fprintf(w, "# HELP fetch_ml_slot_utilization Slot utilization by worker\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_slot_utilization gauge\n")
	for workerID, slots := range h.metrics.WorkerSlots {
		if slots.BatchTotal > 0 {
			utilization := float64(slots.BatchInUse) / float64(slots.BatchTotal)
			fmt.Fprintf(w, "fetch_ml_slot_utilization{worker=\"%s\",pool=\"batch\"} %.2f\n", workerID, utilization)
		}
		if slots.ServiceTotal > 0 {
			utilization := float64(slots.ServiceInUse) / float64(slots.ServiceTotal)
			fmt.Fprintf(w, "fetch_ml_slot_utilization{worker=\"%s\",pool=\"service\"} %.2f\n", workerID, utilization)
		}
	}
}
|
|
|
|
// tryGangAlloc attempts to allocate a multi-node job to a worker
|
|
// It tracks partial allocations and dispatches when all nodes are committed
|
|
func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
// Check if this worker can run the job
|
|
if !h.canAdmit(task, wc) {
|
|
return
|
|
}
|
|
|
|
jobID := task.ID
|
|
pending, ok := h.multiNodePending[jobID]
|
|
if !ok {
|
|
// First worker for this job
|
|
pending = &MultiNodeJob{
|
|
JobID: jobID,
|
|
TotalNodes: task.Spec.NodeCount,
|
|
Assignments: make([]*NodeAssignment, 0, task.Spec.NodeCount),
|
|
CommittedAt: time.Now(),
|
|
}
|
|
h.multiNodePending[jobID] = pending
|
|
}
|
|
|
|
// Add this worker to the pending assignment
|
|
assignment := &NodeAssignment{
|
|
Worker: wc,
|
|
Rank: len(pending.Assignments),
|
|
CommittedAt: time.Now(),
|
|
}
|
|
pending.Assignments = append(pending.Assignments, assignment)
|
|
|
|
// Reserve slots on this worker
|
|
wc.slots.BatchInUse++
|
|
delete(h.readyWorkers, wc.workerID)
|
|
|
|
// Check if we have all nodes
|
|
if len(pending.Assignments) < task.Spec.NodeCount {
|
|
// Still waiting for more workers
|
|
return
|
|
}
|
|
|
|
// All nodes committed - dispatch simultaneously
|
|
headAddr := pending.Assignments[0].Worker.capabilities.Hostname
|
|
for i, a := range pending.Assignments {
|
|
// Create rank-specific job spec
|
|
spec := h.buildRankedSpec(task, i, headAddr, task.Spec.NodeCount)
|
|
msg := Message{
|
|
Type: MsgJobAssign,
|
|
Payload: mustMarshal(spec),
|
|
}
|
|
a.Worker.send <- msg
|
|
}
|
|
|
|
// Clean up pending state
|
|
delete(h.multiNodePending, jobID)
|
|
slog.Info("multi-node job dispatched",
|
|
"job_id", jobID,
|
|
"nodes", task.Spec.NodeCount,
|
|
"head_addr", headAddr)
|
|
}
|
|
|
|
// buildRankedSpec creates a job spec with rank-specific template variables resolved
|
|
func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
|
|
// Clone the spec and add rank info to metadata and env
|
|
spec := task.Spec
|
|
spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3)
|
|
for k, v := range task.Spec.Metadata {
|
|
spec.Metadata[k] = v
|
|
}
|
|
spec.Metadata["HEAD_ADDR"] = headAddr
|
|
spec.Metadata["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
|
|
spec.Metadata["NODE_RANK"] = fmt.Sprintf("%d", rank)
|
|
|
|
// Also set in Env for job runtime
|
|
if spec.Env == nil {
|
|
spec.Env = make(map[string]string)
|
|
}
|
|
spec.Env["HEAD_ADDR"] = headAddr
|
|
spec.Env["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
|
|
spec.Env["NODE_RANK"] = fmt.Sprintf("%d", rank)
|
|
return spec
|
|
}
|
|
|
|
// handleMultiNodeReady handles a ready signal for a multi-node job
|
|
// Returns true if the job was handled (either assigned or queued for gang alloc)
|
|
func (h *SchedulerHub) handleMultiNodeReady(task *Task, wc *WorkerConn) bool {
|
|
if task.Spec.NodeCount <= 1 {
|
|
return false // Not a multi-node job
|
|
}
|
|
|
|
h.tryGangAlloc(task, wc)
|
|
return true
|
|
}
|
|
|
|
// mustMarshal serializes v to JSON, panicking on failure.
// Fixed: the marshal error was silently discarded, so an unmarshalable
// value produced a nil payload on the wire. Per the Must- convention,
// failure here is a programmer bug (all payload types in this package are
// plain structs/maps) and now panics loudly instead.
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic(fmt.Sprintf("mustMarshal: %v", err))
	}
	return b
}
|