// Package scheduler provides HTTP handlers for scheduler management package scheduler import ( "encoding/json" "fmt" "net/http" "strconv" "time" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/logging" sch "github.com/jfraeys/fetch_ml/internal/scheduler" ) // APIHandler provides scheduler-related HTTP API handlers type APIHandler struct { logger *logging.Logger hub *sch.SchedulerHub streaming map[string][]chan StreamingEvent } // NewHandler creates a new scheduler API handler func NewHandler(hub *sch.SchedulerHub, logger *logging.Logger, authConfig *auth.Config) *APIHandler { return &APIHandler{ logger: logger, hub: hub, streaming: make(map[string][]chan StreamingEvent), } } // StreamingEvent represents an event for SSE streaming type StreamingEvent struct { Type string `json:"type"` Payload json.RawMessage `json:"payload"` } // WorkerInfo represents worker information for API responses type WorkerInfo struct { ID string `json:"id"` ConnectedAt time.Time `json:"connected_at"` LastHeartbeat time.Time `json:"last_heartbeat,omitempty"` Capabilities sch.WorkerCapabilities `json:"capabilities"` Slots sch.SlotStatus `json:"slots"` ActiveTasks []string `json:"active_tasks"` Status string `json:"status"` // active, draining, offline } // SchedulerStatus represents scheduler status for API responses type SchedulerStatus struct { WorkersTotal int `json:"workers_total"` WorkersActive int `json:"workers_active"` WorkersDraining int `json:"workers_draining"` BatchQueueDepth int `json:"batch_queue_depth"` ServiceQueueDepth int `json:"service_queue_depth"` TasksRunning int `json:"tasks_running"` TasksCompleted24h int `json:"tasks_completed_24h"` ReservationsActive int `json:"reservations_active"` Timestamp time.Time `json:"timestamp"` } // ReservationInfo represents reservation information for API responses type ReservationInfo struct { ID string `json:"id"` UserID string `json:"user_id"` GPUCount int `json:"gpu_count"` GPUType string `json:"gpu_type,omitempty"` NodeCount int `json:"node_count"` ExpiresAt time.Time `json:"expires_at"` Status string `json:"status"` // active, claimed, expired } // CreateReservationRequest represents a request to create a reservation type CreateReservationRequest struct { GPUCount int `json:"gpu_count"` GPUType string `json:"gpu_type,omitempty"` NodeCount int `json:"node_count,omitempty"` ExpiresMinutes int `json:"expires_minutes,omitempty"` } // NewAPIHandler creates a new scheduler API handler func NewAPIHandler(logger *logging.Logger, hub *sch.SchedulerHub) *APIHandler { handler := &APIHandler{ logger: logger, hub: hub, streaming: make(map[string][]chan StreamingEvent), } return handler } // GetV1SchedulerStatus handles GET /v1/scheduler/status func (h *APIHandler) GetV1SchedulerStatus(w http.ResponseWriter, r *http.Request) { user := auth.GetUserFromContext(r.Context()) if !h.checkPermission(user, "scheduler:read") { http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden) return } if h.hub == nil { http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable) return } // Get metrics from hub metrics := h.hub.GetMetricsPayload() status := SchedulerStatus{ Timestamp: time.Now().UTC(), } // Extract values from metrics payload if v, ok := metrics["workers_connected"].(int); ok { status.WorkersTotal = v status.WorkersActive = v // Simplified - all connected are active } if v, ok := metrics["queue_depth_batch"].(int); ok { status.BatchQueueDepth = v } if v, ok := metrics["queue_depth_service"].(int); ok { status.ServiceQueueDepth = v } if v, ok := metrics["jobs_completed"].(int); ok { status.TasksCompleted24h = v } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(status) } // GetV1SchedulerWorkers handles GET /v1/scheduler/workers func (h *APIHandler) GetV1SchedulerWorkers(w http.ResponseWriter, r *http.Request) { user := auth.GetUserFromContext(r.Context()) if !h.checkPermission(user, "scheduler:read") { http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden) return } if h.hub == nil { http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable) return } // Get worker slots from metrics metrics := h.hub.GetMetricsPayload() slots, _ := metrics["worker_slots"].(map[string]sch.SlotStatus) workers := []WorkerInfo{} // Build worker list from slots data for workerID, slotStatus := range slots { worker := WorkerInfo{ ID: workerID, Slots: slotStatus, Capabilities: sch.WorkerCapabilities{}, Status: "active", ActiveTasks: []string{}, } workers = append(workers, worker) } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(workers) } // GetV1SchedulerWorkersWorkerID handles GET /v1/scheduler/workers/{workerId} func (h *APIHandler) GetV1SchedulerWorkersWorkerID(w http.ResponseWriter, r *http.Request) { user := auth.GetUserFromContext(r.Context()) if !h.checkPermission(user, "scheduler:read") { http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden) return } if h.hub == nil { http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable) return } workerID := r.PathValue("workerId") if workerID == "" { http.Error(w, `{"error":"Missing worker ID","code":"BAD_REQUEST"}`, http.StatusBadRequest) return } // Get worker slots from metrics metrics := h.hub.GetMetricsPayload() slots, _ := metrics["worker_slots"].(map[string]sch.SlotStatus) slotStatus, ok := slots[workerID] if !ok { http.Error(w, `{"error":"Worker not found","code":"NOT_FOUND"}`, http.StatusNotFound) return } worker := WorkerInfo{ ID: workerID, Slots: slotStatus, Capabilities: sch.WorkerCapabilities{}, Status: "active", ActiveTasks: []string{}, } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(worker) } // DeleteV1SchedulerWorkersWorkerID handles DELETE /v1/scheduler/workers/{workerId} func (h *APIHandler) DeleteV1SchedulerWorkersWorkerID(w http.ResponseWriter, r *http.Request) { user := auth.GetUserFromContext(r.Context()) if !h.checkPermission(user, "scheduler:drain") { http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden) return } if h.hub == nil { http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable) return } workerID := r.PathValue("workerId") if workerID == "" { http.Error(w, `{"error":"Missing worker ID","code":"BAD_REQUEST"}`, http.StatusBadRequest) return } h.logger.Info("draining worker", "worker_id", workerID, "user", user.Name) // Note: Actual drain implementation would involve signaling the worker // and waiting for tasks to complete. This is a simplified version. w.WriteHeader(http.StatusNoContent) } // GetV1SchedulerReservations handles GET /v1/scheduler/reservations func (h *APIHandler) GetV1SchedulerReservations(w http.ResponseWriter, r *http.Request) { user := auth.GetUserFromContext(r.Context()) if !h.checkPermission(user, "scheduler:read") { http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden) return } if h.hub == nil { http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable) return } // Return empty list for now - reservations would be tracked in hub reservations := []ReservationInfo{} w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(reservations) } // PostV1SchedulerReservations handles POST /v1/scheduler/reservations func (h *APIHandler) PostV1SchedulerReservations(w http.ResponseWriter, r *http.Request) { user := auth.GetUserFromContext(r.Context()) if !h.checkPermission(user, "scheduler:write") { http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden) return } if h.hub == nil { http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable) return } var req CreateReservationRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, `{"error":"Invalid request body","code":"BAD_REQUEST"}`, http.StatusBadRequest) return } if req.GPUCount <= 0 { http.Error(w, `{"error":"GPU count must be positive","code":"VALIDATION_ERROR"}`, http.StatusUnprocessableEntity) return } if req.NodeCount <= 0 { req.NodeCount = 1 } if req.ExpiresMinutes <= 0 { req.ExpiresMinutes = 30 } reservation := ReservationInfo{ ID: fmt.Sprintf("res-%d", time.Now().UnixNano()), UserID: user.Name, GPUCount: req.GPUCount, GPUType: req.GPUType, NodeCount: req.NodeCount, ExpiresAt: time.Now().Add(time.Duration(req.ExpiresMinutes) * time.Minute), Status: "active", } h.logger.Info("created reservation", "reservation_id", reservation.ID, "user", user.Name) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) json.NewEncoder(w).Encode(reservation) } // PatchV1SchedulerJobsJobIDPriority handles PATCH /v1/scheduler/jobs/{jobId}/priority func (h *APIHandler) PatchV1SchedulerJobsJobIDPriority(w http.ResponseWriter, r *http.Request) { user := auth.GetUserFromContext(r.Context()) if !h.checkPermission(user, "tasks:priority") { http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden) return } if h.hub == nil { http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable) return } jobID := r.PathValue("jobId") if jobID == "" { http.Error(w, `{"error":"Missing job ID","code":"BAD_REQUEST"}`, http.StatusBadRequest) return } var req struct { Priority int `json:"priority"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, `{"error":"Invalid request body","code":"BAD_REQUEST"}`, http.StatusBadRequest) return } if req.Priority < 1 || req.Priority > 10 { http.Error(w, `{"error":"Priority must be between 1 and 10","code":"VALIDATION_ERROR"}`, http.StatusUnprocessableEntity) return } h.logger.Info("updating job priority", "job_id", jobID, "priority", req.Priority, "user", user.Name) // Note: Actual priority update would modify the task in the queue // This is a simplified version response := map[string]interface{}{ "id": jobID, "priority": req.Priority, "status": "queued", } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(response) } // GetV1SchedulerStatusStream handles GET /v1/scheduler/status/stream (SSE) func (h *APIHandler) GetV1SchedulerStatusStream(w http.ResponseWriter, r *http.Request) { user := auth.GetUserFromContext(r.Context()) if !h.checkPermission(user, "scheduler:read") { http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden) return } // Set SSE headers w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") // Create event channel for this client eventChan := make(chan StreamingEvent, 10) clientID := fmt.Sprintf("%s-%d", user.Name, time.Now().UnixNano()) h.streaming[clientID] = append(h.streaming[clientID], eventChan) // Clean up on disconnect defer func() { delete(h.streaming, clientID) close(eventChan) }() // Send initial status status := map[string]interface{}{ "type": "connected", "timestamp": time.Now().UTC(), } data, _ := json.Marshal(status) fmt.Fprintf(w, "data: %s\n\n", data) w.(http.Flusher).Flush() // Keep connection alive and send periodic updates ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() for { select { case <-r.Context().Done(): return case <-ticker.C: heartbeat := map[string]interface{}{ "type": "heartbeat", "timestamp": time.Now().UTC(), } data, _ := json.Marshal(heartbeat) fmt.Fprintf(w, "data: %s\n\n", data) w.(http.Flusher).Flush() case event := <-eventChan: data, _ := json.Marshal(event) fmt.Fprintf(w, "data: %s\n\n", data) w.(http.Flusher).Flush() } } } // GetV1SchedulerJobsJobIDStream handles GET /v1/scheduler/jobs/{jobId}/stream (SSE) func (h *APIHandler) GetV1SchedulerJobsJobIDStream(w http.ResponseWriter, r *http.Request) { user := auth.GetUserFromContext(r.Context()) if !h.checkPermission(user, "scheduler:read") { http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden) return } jobID := r.PathValue("jobId") if jobID == "" { http.Error(w, `{"error":"Missing job ID","code":"BAD_REQUEST"}`, http.StatusBadRequest) return } // Set SSE headers w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") // Send initial status response := map[string]interface{}{ "type": "connected", "job_id": jobID, "timestamp": time.Now().UTC(), } data, _ := json.Marshal(response) fmt.Fprintf(w, "data: %s\n\n", data) w.(http.Flusher).Flush() // Keep connection alive ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() for { select { case <-r.Context().Done(): return case <-ticker.C: heartbeat := map[string]interface{}{ "type": "heartbeat", "timestamp": time.Now().UTC(), } data, _ := json.Marshal(heartbeat) fmt.Fprintf(w, "data: %s\n\n", data) w.(http.Flusher).Flush() } } } // checkPermission checks if the user has the required permission func (h *APIHandler) checkPermission(user *auth.User, permission string) bool { if user == nil { return false } // Admin has all permissions if user.Admin { return true } // Check specific permission return user.HasPermission(permission) } // parseIntQueryParam parses an integer query parameter func parseIntQueryParam(r *http.Request, name string, defaultVal int) int { str := r.URL.Query().Get(name) if str == "" { return defaultVal } val, err := strconv.Atoi(str) if err != nil { return defaultVal } return val }