fetch_ml/internal/api/scheduler/handlers.go
Jeremie Fraeys 57787e1e7b
feat(scheduler): implement capability-based routing and hub v2
Add comprehensive capability routing system to scheduler hub:
- Capability-aware worker matching with requirement/offer negotiation
- Hub v2 protocol with structured message types and heartbeat management
- Worker capability advertisement and dynamic routing decisions
- Orphan recovery for disconnected workers with state reconciliation
- Template-based job scheduling with capability constraints

Add extensive test coverage:
- Unit tests for capability routing logic and heartbeat mechanics
- Unit tests for orphan recovery scenarios
- E2E tests for capability routing across multiple workers
- Hub capabilities integration tests
- Scheduler fixture helpers for test setup

Protocol improvements:
- Define structured protocol messages for hub-worker communication
- Add capability matching algorithm with scoring
- Implement graceful worker disconnection handling
2026-03-12 12:00:05 -04:00

479 lines
15 KiB
Go

// Package scheduler provides HTTP handlers for scheduler management
package scheduler
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/jfraeys/fetch_ml/internal/api/errors"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/logging"
sch "github.com/jfraeys/fetch_ml/internal/scheduler"
)
// APIHandler bundles the dependencies shared by the scheduler HTTP endpoints.
type APIHandler struct {
	// logger receives the operational messages emitted by the handlers.
	logger *logging.Logger
	// hub is the scheduler the handlers query for metrics; most handlers
	// answer 503 when it is nil.
	hub *sch.SchedulerHub
	// streaming holds the event channels of connected SSE clients, keyed by a
	// per-connection client ID.
	// NOTE(review): the status stream handler mutates this map without any
	// visible locking — confirm how concurrent requests are serialized.
	streaming map[string][]chan StreamingEvent
}
// NewHandler builds an APIHandler around the given hub and logger, starting
// with an empty registry of streaming clients. authConfig is accepted for
// signature compatibility but is not used by this constructor.
func NewHandler(hub *sch.SchedulerHub, logger *logging.Logger, authConfig *auth.Config) *APIHandler {
	h := new(APIHandler)
	h.hub = hub
	h.logger = logger
	h.streaming = map[string][]chan StreamingEvent{}
	return h
}
// StreamingEvent is one message pushed to a server-sent-events client.
type StreamingEvent struct {
	// Type names the kind of event.
	Type string `json:"type"`
	// Payload carries the event body as already-encoded JSON, embedded
	// verbatim when the event is marshalled.
	Payload json.RawMessage `json:"payload"`
}
// WorkerInfo is the JSON description of a single worker returned by the
// worker endpoints.
type WorkerInfo struct {
	// ID identifies the worker.
	ID string `json:"id"`
	// ConnectedAt is when the worker connected.
	ConnectedAt time.Time `json:"connected_at"`
	// LastHeartbeat is the time of the most recent heartbeat.
	// NOTE(review): encoding/json's omitempty never omits a struct value such
	// as time.Time, so a zero heartbeat is still serialized.
	LastHeartbeat time.Time `json:"last_heartbeat,omitempty"`
	// Capabilities lists what the worker advertises.
	Capabilities sch.WorkerCapabilities `json:"capabilities"`
	// Slots reports the worker's slot usage.
	Slots sch.SlotStatus `json:"slots"`
	// ActiveTasks holds the IDs of tasks currently on the worker.
	ActiveTasks []string `json:"active_tasks"`
	// Status is one of: active, draining, offline.
	Status string `json:"status"`
}
// SchedulerStatus is the JSON snapshot returned by the scheduler status
// endpoint.
type SchedulerStatus struct {
	// WorkersTotal is the number of connected workers.
	WorkersTotal int `json:"workers_total"`
	// WorkersActive is the number of workers counted as active.
	WorkersActive int `json:"workers_active"`
	// WorkersDraining is the number of workers being drained.
	WorkersDraining int `json:"workers_draining"`
	// BatchQueueDepth is the length of the batch queue.
	BatchQueueDepth int `json:"batch_queue_depth"`
	// ServiceQueueDepth is the length of the service queue.
	ServiceQueueDepth int `json:"service_queue_depth"`
	// TasksRunning is the number of tasks currently executing.
	TasksRunning int `json:"tasks_running"`
	// TasksCompleted24h is the completed-task counter exposed to clients.
	TasksCompleted24h int `json:"tasks_completed_24h"`
	// ReservationsActive is the number of live reservations.
	ReservationsActive int `json:"reservations_active"`
	// Timestamp records when the snapshot was produced.
	Timestamp time.Time `json:"timestamp"`
}
// ReservationInfo is the JSON description of a resource reservation.
type ReservationInfo struct {
	// ID identifies the reservation.
	ID string `json:"id"`
	// UserID names the user who owns the reservation.
	UserID string `json:"user_id"`
	// GPUCount is the number of GPUs reserved.
	GPUCount int `json:"gpu_count"`
	// GPUType optionally restricts the GPU model; omitted from JSON when empty.
	GPUType string `json:"gpu_type,omitempty"`
	// NodeCount is the number of nodes reserved.
	NodeCount int `json:"node_count"`
	// ExpiresAt is when the reservation lapses.
	ExpiresAt time.Time `json:"expires_at"`
	// Status is one of: active, claimed, expired.
	Status string `json:"status"`
}
// CreateReservationRequest is the JSON body accepted when creating a
// reservation.
type CreateReservationRequest struct {
	// GPUCount is the number of GPUs requested; the create handler rejects
	// values that are not positive.
	GPUCount int `json:"gpu_count"`
	// GPUType optionally names the GPU model wanted.
	GPUType string `json:"gpu_type,omitempty"`
	// NodeCount is the number of nodes requested; the create handler uses 1
	// when it is omitted or not positive.
	NodeCount int `json:"node_count,omitempty"`
	// ExpiresMinutes is the reservation lifetime in minutes; the create
	// handler uses 30 when it is omitted or not positive.
	ExpiresMinutes int `json:"expires_minutes,omitempty"`
}
// NewAPIHandler builds an APIHandler from a logger and a scheduler hub,
// starting with an empty registry of streaming clients.
func NewAPIHandler(logger *logging.Logger, hub *sch.SchedulerHub) *APIHandler {
	return &APIHandler{
		streaming: map[string][]chan StreamingEvent{},
		hub:       hub,
		logger:    logger,
	}
}
// GetV1SchedulerStatus handles GET /v1/scheduler/status.
//
// It requires the "scheduler:read" permission (403 otherwise) and a
// configured hub (503 otherwise), then reports a SchedulerStatus assembled
// from the hub's metrics payload. Fields with no matching int metric stay
// zero.
func (h *APIHandler) GetV1SchedulerStatus(w http.ResponseWriter, r *http.Request) {
	if caller := auth.GetUserFromContext(r.Context()); !h.checkPermission(caller, "scheduler:read") {
		errors.WriteHTTPError(w, http.StatusForbidden, errors.CodePermissionDenied, "Insufficient permissions", "")
		return
	}
	if h.hub == nil {
		errors.WriteHTTPError(w, http.StatusServiceUnavailable, errors.CodeServiceUnavailable, "Scheduler not available", "")
		return
	}

	payload := h.hub.GetMetricsPayload()

	// intMetric yields the named metric when it is stored as an int and zero
	// otherwise, which matches the field's default value.
	intMetric := func(key string) int {
		n, _ := payload[key].(int)
		return n
	}

	// Simplified: every connected worker is reported as active.
	connected := intMetric("workers_connected")
	snapshot := SchedulerStatus{
		WorkersTotal:      connected,
		WorkersActive:     connected,
		BatchQueueDepth:   intMetric("queue_depth_batch"),
		ServiceQueueDepth: intMetric("queue_depth_service"),
		TasksCompleted24h: intMetric("jobs_completed"),
		Timestamp:         time.Now().UTC(),
	}

	w.Header().Set("Content-Type", "application/json")
	encodeErr := json.NewEncoder(w).Encode(snapshot)
	if encodeErr != nil {
		h.logger.Warn("failed to encode status", "error", encodeErr)
	}
}
// GetV1SchedulerWorkers handles GET /v1/scheduler/workers.
//
// It requires the "scheduler:read" permission (403 otherwise) and a
// configured hub (503 otherwise). The response is a JSON array with one
// WorkerInfo per entry in the hub's "worker_slots" metric; every worker is
// reported as "active" with empty capabilities and no active tasks. The
// array order follows map iteration and is therefore not stable.
func (h *APIHandler) GetV1SchedulerWorkers(w http.ResponseWriter, r *http.Request) {
	if caller := auth.GetUserFromContext(r.Context()); !h.checkPermission(caller, "scheduler:read") {
		errors.WriteHTTPError(w, http.StatusForbidden, errors.CodePermissionDenied, "Insufficient permissions", "")
		return
	}
	if h.hub == nil {
		errors.WriteHTTPError(w, http.StatusServiceUnavailable, errors.CodeServiceUnavailable, "Scheduler not available", "")
		return
	}

	// A missing or differently typed "worker_slots" entry leaves the map nil,
	// which simply produces an empty list.
	slotsByWorker, _ := h.hub.GetMetricsPayload()["worker_slots"].(map[string]sch.SlotStatus)

	listing := make([]WorkerInfo, 0, len(slotsByWorker))
	for id, slots := range slotsByWorker {
		listing = append(listing, WorkerInfo{
			ID:           id,
			Slots:        slots,
			Capabilities: sch.WorkerCapabilities{},
			Status:       "active",
			ActiveTasks:  []string{},
		})
	}

	w.Header().Set("Content-Type", "application/json")
	encodeErr := json.NewEncoder(w).Encode(listing)
	if encodeErr != nil {
		h.logger.Warn("failed to encode workers", "error", encodeErr)
	}
}
// GetV1SchedulerWorkersWorkerID handles GET /v1/scheduler/workers/{workerId}.
//
// It requires the "scheduler:read" permission (403 otherwise) and a
// configured hub (503 otherwise). An empty workerId path value yields 400 and
// a worker absent from the hub's "worker_slots" metric yields 404. A found
// worker is reported as "active" with empty capabilities and no active tasks.
func (h *APIHandler) GetV1SchedulerWorkersWorkerID(w http.ResponseWriter, r *http.Request) {
	if caller := auth.GetUserFromContext(r.Context()); !h.checkPermission(caller, "scheduler:read") {
		errors.WriteHTTPError(w, http.StatusForbidden, errors.CodePermissionDenied, "Insufficient permissions", "")
		return
	}
	if h.hub == nil {
		errors.WriteHTTPError(w, http.StatusServiceUnavailable, errors.CodeServiceUnavailable, "Scheduler not available", "")
		return
	}

	id := r.PathValue("workerId")
	if len(id) == 0 {
		errors.WriteHTTPError(w, http.StatusBadRequest, errors.CodeInvalidRequest, "Missing worker ID", "")
		return
	}

	// Reading from a nil map is safe, so a missing "worker_slots" metric
	// falls through to the not-found branch.
	slotsByWorker, _ := h.hub.GetMetricsPayload()["worker_slots"].(map[string]sch.SlotStatus)
	slots, found := slotsByWorker[id]
	if !found {
		errors.WriteHTTPError(w, http.StatusNotFound, errors.CodeNotFound, "Worker not found", "")
		return
	}

	w.Header().Set("Content-Type", "application/json")
	encodeErr := json.NewEncoder(w).Encode(WorkerInfo{
		ID:           id,
		Slots:        slots,
		Capabilities: sch.WorkerCapabilities{},
		Status:       "active",
		ActiveTasks:  []string{},
	})
	if encodeErr != nil {
		h.logger.Warn("failed to encode worker", "error", encodeErr)
	}
}
// DeleteV1SchedulerWorkersWorkerID handles
// DELETE /v1/scheduler/workers/{workerId}.
//
// It requires the "scheduler:drain" permission (403 otherwise) and a
// configured hub (503 otherwise); an empty workerId path value yields 400.
// The drain itself is not implemented here: the request is logged and
// answered with 204 No Content.
func (h *APIHandler) DeleteV1SchedulerWorkersWorkerID(w http.ResponseWriter, r *http.Request) {
	caller := auth.GetUserFromContext(r.Context())
	if allowed := h.checkPermission(caller, "scheduler:drain"); !allowed {
		errors.WriteHTTPError(w, http.StatusForbidden, errors.CodePermissionDenied, "Insufficient permissions", "")
		return
	}
	if h.hub == nil {
		errors.WriteHTTPError(w, http.StatusServiceUnavailable, errors.CodeServiceUnavailable, "Scheduler not available", "")
		return
	}

	target := r.PathValue("workerId")
	if len(target) == 0 {
		errors.WriteHTTPError(w, http.StatusBadRequest, errors.CodeInvalidRequest, "Missing worker ID", "")
		return
	}

	// Simplified version: a real drain would signal the worker and wait for
	// its tasks to finish before replying.
	h.logger.Info("draining worker", "worker_id", target, "user", caller.Name)
	w.WriteHeader(http.StatusNoContent)
}
// GetV1SchedulerReservations handles GET /v1/scheduler/reservations.
//
// It requires the "scheduler:read" permission (403 otherwise) and a
// configured hub (503 otherwise). Reservations are not tracked by this
// handler yet, so the response is always an empty JSON array.
func (h *APIHandler) GetV1SchedulerReservations(w http.ResponseWriter, r *http.Request) {
	if caller := auth.GetUserFromContext(r.Context()); !h.checkPermission(caller, "scheduler:read") {
		errors.WriteHTTPError(w, http.StatusForbidden, errors.CodePermissionDenied, "Insufficient permissions", "")
		return
	}
	if h.hub == nil {
		errors.WriteHTTPError(w, http.StatusServiceUnavailable, errors.CodeServiceUnavailable, "Scheduler not available", "")
		return
	}

	// Placeholder until the hub tracks reservations; a non-nil slice keeps
	// the JSON output as [] rather than null.
	listing := make([]ReservationInfo, 0)

	w.Header().Set("Content-Type", "application/json")
	encodeErr := json.NewEncoder(w).Encode(listing)
	if encodeErr != nil {
		h.logger.Warn("failed to encode reservations", "error", encodeErr)
	}
}
// PostV1SchedulerReservations handles POST /v1/scheduler/reservations.
//
// The caller needs the "scheduler:write" permission. The JSON body is decoded
// into a CreateReservationRequest: gpu_count must be positive, while
// node_count and expires_minutes fall back to 1 and 30 when omitted or not
// positive. On success the new reservation is returned with 201 Created. The
// reservation is only logged and echoed back; nothing in this handler stores
// it in the hub.
//
// Responses: 403 without permission, 503 when no hub is configured, 400 for
// an undecodable body, and 422 for a non-positive gpu_count or an
// expires_minutes value too large to represent as a time.Duration.
func (h *APIHandler) PostV1SchedulerReservations(w http.ResponseWriter, r *http.Request) {
	// maxExpiresMinutes is the largest minute count that still fits in a
	// time.Duration (int64 nanoseconds). Anything above it would overflow the
	// multiplication below and could yield an expiry in the past.
	const maxExpiresMinutes = (1<<63 - 1) / int64(time.Minute)

	user := auth.GetUserFromContext(r.Context())
	if !h.checkPermission(user, "scheduler:write") {
		errors.WriteHTTPError(w, http.StatusForbidden, errors.CodePermissionDenied, "Insufficient permissions", "")
		return
	}
	if h.hub == nil {
		errors.WriteHTTPError(w, http.StatusServiceUnavailable, errors.CodeServiceUnavailable, "Scheduler not available", "")
		return
	}

	var req CreateReservationRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		errors.WriteHTTPError(w, http.StatusBadRequest, errors.CodeInvalidRequest, "Invalid request body", "")
		return
	}
	if req.GPUCount <= 0 {
		errors.WriteHTTPError(w, http.StatusUnprocessableEntity, errors.CodeInvalidRequest, "GPU count must be positive", "")
		return
	}
	if req.NodeCount <= 0 {
		req.NodeCount = 1
	}
	if req.ExpiresMinutes <= 0 {
		req.ExpiresMinutes = 30
	}
	if int64(req.ExpiresMinutes) > maxExpiresMinutes {
		errors.WriteHTTPError(w, http.StatusUnprocessableEntity, errors.CodeInvalidRequest, "Expiry is too large", "")
		return
	}

	// Take one clock reading so the ID and the expiry derive from the same
	// instant.
	now := time.Now()
	reservation := ReservationInfo{
		ID:        fmt.Sprintf("res-%d", now.UnixNano()),
		UserID:    user.Name,
		GPUCount:  req.GPUCount,
		GPUType:   req.GPUType,
		NodeCount: req.NodeCount,
		ExpiresAt: now.Add(time.Duration(req.ExpiresMinutes) * time.Minute),
		Status:    "active",
	}
	h.logger.Info("created reservation", "reservation_id", reservation.ID, "user", user.Name)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	if err := json.NewEncoder(w).Encode(reservation); err != nil {
		h.logger.Warn("failed to encode reservation", "error", err)
	}
}
// PatchV1SchedulerJobsJobIDPriority handles
// PATCH /v1/scheduler/jobs/{jobId}/priority.
//
// It requires the "tasks:priority" permission (403 otherwise) and a
// configured hub (503 otherwise). An empty jobId path value or an undecodable
// body yields 400; a priority outside 1..10 yields 422. The queue itself is
// not modified here: the request is logged and the job is echoed back with
// the requested priority and status "queued".
func (h *APIHandler) PatchV1SchedulerJobsJobIDPriority(w http.ResponseWriter, r *http.Request) {
	caller := auth.GetUserFromContext(r.Context())
	if allowed := h.checkPermission(caller, "tasks:priority"); !allowed {
		errors.WriteHTTPError(w, http.StatusForbidden, errors.CodePermissionDenied, "Insufficient permissions", "")
		return
	}
	if h.hub == nil {
		errors.WriteHTTPError(w, http.StatusServiceUnavailable, errors.CodeServiceUnavailable, "Scheduler not available", "")
		return
	}

	job := r.PathValue("jobId")
	if len(job) == 0 {
		errors.WriteHTTPError(w, http.StatusBadRequest, errors.CodeInvalidRequest, "Missing job ID", "")
		return
	}

	var body struct {
		Priority int `json:"priority"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		errors.WriteHTTPError(w, http.StatusBadRequest, errors.CodeInvalidRequest, "Invalid request body", "")
		return
	}
	if inRange := body.Priority >= 1 && body.Priority <= 10; !inRange {
		errors.WriteHTTPError(w, http.StatusUnprocessableEntity, errors.CodeInvalidRequest, "Priority must be between 1 and 10", "")
		return
	}

	// Simplified version: a real update would modify the task in the queue.
	h.logger.Info("updating job priority", "job_id", job, "priority", body.Priority, "user", caller.Name)

	w.Header().Set("Content-Type", "application/json")
	encodeErr := json.NewEncoder(w).Encode(map[string]any{
		"id":       job,
		"priority": body.Priority,
		"status":   "queued",
	})
	if encodeErr != nil {
		h.logger.Warn("failed to encode priority response", "error", encodeErr)
	}
}
// GetV1SchedulerStatusStream handles GET /v1/scheduler/status/stream as a
// server-sent-events stream.
//
// It requires the "scheduler:read" permission (403 otherwise). If the
// ResponseWriter cannot flush, the request is rejected with 503 before any
// stream headers are sent. Otherwise the handler registers an event channel
// for the client, writes a "connected" frame, and then loops until the
// request context is cancelled, emitting a "heartbeat" frame every five
// seconds plus any event received on the client's channel. The channel is
// unregistered and closed when the handler returns.
func (h *APIHandler) GetV1SchedulerStatusStream(w http.ResponseWriter, r *http.Request) {
	user := auth.GetUserFromContext(r.Context())
	if !h.checkPermission(user, "scheduler:read") {
		errors.WriteHTTPError(w, http.StatusForbidden, errors.CodePermissionDenied, "Insufficient permissions", "")
		return
	}

	// Not every ResponseWriter can flush (wrappers frequently drop the
	// capability). Check up front so the client gets an error response
	// instead of the handler panicking on a failed type assertion.
	flusher, ok := w.(http.Flusher)
	if !ok {
		errors.WriteHTTPError(w, http.StatusServiceUnavailable, errors.CodeServiceUnavailable, "Streaming not supported", "")
		return
	}

	// Set SSE headers.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	// Register an event channel for this client and remove it on disconnect.
	// NOTE(review): h.streaming is mutated here without any visible locking,
	// and net/http serves requests concurrently — confirm the map is guarded.
	events := make(chan StreamingEvent, 10)
	clientID := fmt.Sprintf("%s-%d", user.Name, time.Now().UnixNano())
	h.streaming[clientID] = append(h.streaming[clientID], events)
	defer func() {
		delete(h.streaming, clientID)
		close(events)
	}()

	// emit writes one SSE data frame and flushes it. A value that cannot be
	// encoded is logged and skipped rather than sent as an empty frame.
	emit := func(v any) {
		data, err := json.Marshal(v)
		if err != nil {
			h.logger.Warn("failed to encode stream event", "error", err)
			return
		}
		fmt.Fprintf(w, "data: %s\n\n", data)
		flusher.Flush()
	}

	emit(map[string]any{
		"type":      "connected",
		"timestamp": time.Now().UTC(),
	})

	// Keep the connection alive with periodic heartbeats.
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-r.Context().Done():
			return
		case <-ticker.C:
			emit(map[string]any{
				"type":      "heartbeat",
				"timestamp": time.Now().UTC(),
			})
		case event := <-events:
			emit(event)
		}
	}
}
// GetV1SchedulerJobsJobIDStream handles GET /v1/scheduler/jobs/{jobId}/stream
// as a server-sent-events stream.
//
// It requires the "scheduler:read" permission (403 otherwise); an empty jobId
// path value yields 400. If the ResponseWriter cannot flush, the request is
// rejected with 503 before any stream headers are sent. Otherwise the handler
// writes a "connected" frame carrying the job ID and then emits a "heartbeat"
// frame every five seconds until the request context is cancelled.
func (h *APIHandler) GetV1SchedulerJobsJobIDStream(w http.ResponseWriter, r *http.Request) {
	user := auth.GetUserFromContext(r.Context())
	if !h.checkPermission(user, "scheduler:read") {
		errors.WriteHTTPError(w, http.StatusForbidden, errors.CodePermissionDenied, "Insufficient permissions", "")
		return
	}

	jobID := r.PathValue("jobId")
	if jobID == "" {
		errors.WriteHTTPError(w, http.StatusBadRequest, errors.CodeInvalidRequest, "Missing job ID", "")
		return
	}

	// Not every ResponseWriter can flush (wrappers frequently drop the
	// capability). Check up front so the client gets an error response
	// instead of the handler panicking on a failed type assertion.
	flusher, ok := w.(http.Flusher)
	if !ok {
		errors.WriteHTTPError(w, http.StatusServiceUnavailable, errors.CodeServiceUnavailable, "Streaming not supported", "")
		return
	}

	// Set SSE headers.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	// emit writes one SSE data frame and flushes it. A value that cannot be
	// encoded is logged and skipped rather than sent as an empty frame.
	emit := func(v any) {
		data, err := json.Marshal(v)
		if err != nil {
			h.logger.Warn("failed to encode stream event", "error", err)
			return
		}
		fmt.Fprintf(w, "data: %s\n\n", data)
		flusher.Flush()
	}

	emit(map[string]any{
		"type":      "connected",
		"job_id":    jobID,
		"timestamp": time.Now().UTC(),
	})

	// Keep the connection alive with periodic heartbeats.
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-r.Context().Done():
			return
		case <-ticker.C:
			emit(map[string]any{
				"type":      "heartbeat",
				"timestamp": time.Now().UTC(),
			})
		}
	}
}
// checkPermission reports whether user may perform the action named by
// permission. A nil user is always denied, an admin is always allowed, and
// any other user is allowed only if HasPermission grants the permission.
func (h *APIHandler) checkPermission(user *auth.User, permission string) bool {
	// Short-circuit evaluation keeps the original order: nil check, then
	// admin override, then the specific permission lookup.
	return user != nil && (user.Admin || user.HasPermission(permission))
}