fetch_ml/internal/api/scheduler/handlers.go
Jeremie Fraeys c52179dcbe
feat(auth): add token-based access and structured logging
Add comprehensive authentication and authorization enhancements:

- tokens.go: New token management system for public task access and cloning
  * SHA-256 hashed token storage for security
  * Token generation, validation, and automatic cleanup
  * Support for public access and clone permissions

- api_key.go: Extend User struct with Groups field
  * Lab group membership (ml-lab, nlp-group)
  * Integration with permission system for group-based access

- flags.go: Security hardening - migrate to structured logging
  * Replace log.Printf with log/slog to prevent log injection attacks
  * Consistent structured output for all auth warnings
  * Safe handling of file paths and errors in logs

- permissions.go: Add task sharing permission constants
  * PermissionTasksReadOwn: Access own tasks
  * PermissionTasksReadLab: Access lab group tasks
  * PermissionTasksReadAll: Admin/institution-wide access
  * PermissionTasksShare: Grant access to other users
  * PermissionTasksClone: Create copies of shared tasks
  * CanAccessTask() method with visibility checks

- database.go: Improve error handling
  * Add structured error logging on row close failures
2026-03-08 12:51:07 -04:00

478 lines
15 KiB
Go

// Package scheduler provides HTTP handlers for scheduler management
package scheduler
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/logging"
sch "github.com/jfraeys/fetch_ml/internal/scheduler"
)
// APIHandler groups the dependencies shared by the scheduler HTTP endpoints.
//
// logger receives the info/warn messages emitted by the handlers, hub is the
// scheduler that handlers query for metrics, and streaming holds the event
// channels registered per streaming client ID.
//
// NOTE(review): streaming is a plain map that request handlers mutate; no
// lock is visible in this file — confirm how concurrent access is serialized.
type APIHandler struct {
	logger    *logging.Logger
	hub       *sch.SchedulerHub
	streaming map[string][]chan StreamingEvent
}
// NewHandler builds an APIHandler wired to the given scheduler hub and logger.
//
// authConfig is accepted for signature compatibility but is not read by this
// constructor. The returned handler starts with an empty streaming registry.
func NewHandler(hub *sch.SchedulerHub, logger *logging.Logger, authConfig *auth.Config) *APIHandler {
	h := new(APIHandler)
	h.hub = hub
	h.logger = logger
	h.streaming = map[string][]chan StreamingEvent{}
	return h
}
// StreamingEvent is one message delivered to a server-sent-events client.
//
// Type names the kind of event; Payload carries an already-encoded JSON
// document that is embedded verbatim when the event is marshalled.
type StreamingEvent struct {
	Type    string          `json:"type"`
	Payload json.RawMessage `json:"payload"`
}
// WorkerInfo is the JSON shape used to describe a single worker in API
// responses: its identifier, connection/heartbeat times, capabilities, slot
// usage, active task IDs and a status label.
//
// Note that encoding/json's omitempty has no effect on struct-typed fields
// such as time.Time, so a zero LastHeartbeat is still emitted.
type WorkerInfo struct {
	ID            string                 `json:"id"`
	ConnectedAt   time.Time              `json:"connected_at"`
	LastHeartbeat time.Time              `json:"last_heartbeat,omitempty"`
	Capabilities  sch.WorkerCapabilities `json:"capabilities"`
	Slots         sch.SlotStatus         `json:"slots"`
	ActiveTasks   []string               `json:"active_tasks"`
	Status        string                 `json:"status"` // active, draining, offline
}
// SchedulerStatus is the JSON shape returned by the scheduler status
// endpoint: worker counts, queue depths, task counters, the number of active
// reservations, and the time the snapshot was taken.
//
// NOTE(review): the status handler in this file fills only the worker
// total/active counts, the two queue depths, TasksCompleted24h and Timestamp;
// the remaining counters are left at zero there.
type SchedulerStatus struct {
	WorkersTotal       int       `json:"workers_total"`
	WorkersActive      int       `json:"workers_active"`
	WorkersDraining    int       `json:"workers_draining"`
	BatchQueueDepth    int       `json:"batch_queue_depth"`
	ServiceQueueDepth  int       `json:"service_queue_depth"`
	TasksRunning       int       `json:"tasks_running"`
	TasksCompleted24h  int       `json:"tasks_completed_24h"`
	ReservationsActive int       `json:"reservations_active"`
	Timestamp          time.Time `json:"timestamp"`
}
// ReservationInfo is the JSON shape used to describe a reservation in API
// responses: who holds it, how many GPUs and nodes it covers, when it
// expires, and its status label. GPUType is omitted from the JSON when empty.
type ReservationInfo struct {
	ID        string    `json:"id"`
	UserID    string    `json:"user_id"`
	GPUCount  int       `json:"gpu_count"`
	GPUType   string    `json:"gpu_type,omitempty"`
	NodeCount int       `json:"node_count"`
	ExpiresAt time.Time `json:"expires_at"`
	Status    string    `json:"status"` // active, claimed, expired
}
// CreateReservationRequest is the JSON body accepted when creating a
// reservation. Only gpu_count is always serialized; the GPU type, node count
// and expiry (in minutes) are optional and omitted from JSON when zero.
type CreateReservationRequest struct {
	GPUCount       int    `json:"gpu_count"`
	GPUType        string `json:"gpu_type,omitempty"`
	NodeCount      int    `json:"node_count,omitempty"`
	ExpiresMinutes int    `json:"expires_minutes,omitempty"`
}
// NewAPIHandler returns an APIHandler that logs through logger and queries
// hub, with an empty registry of streaming clients.
func NewAPIHandler(logger *logging.Logger, hub *sch.SchedulerHub) *APIHandler {
	return &APIHandler{
		streaming: map[string][]chan StreamingEvent{},
		hub:       hub,
		logger:    logger,
	}
}
// GetV1SchedulerStatus handles GET /v1/scheduler/status.
//
// It requires the "scheduler:read" permission (403 otherwise) and a non-nil
// hub (503 otherwise), then copies selected integer entries of the hub's
// metrics payload into a SchedulerStatus and writes it as JSON. Metrics that
// are missing or not of dynamic type int leave the matching field at zero.
func (h *APIHandler) GetV1SchedulerStatus(w http.ResponseWriter, r *http.Request) {
	if caller := auth.GetUserFromContext(r.Context()); !h.checkPermission(caller, "scheduler:read") {
		http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden)
		return
	}
	if h.hub == nil {
		http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable)
		return
	}

	payload := h.hub.GetMetricsPayload()
	// intMetric reads one payload entry, succeeding only for int values.
	intMetric := func(key string) (int, bool) {
		n, ok := payload[key].(int)
		return n, ok
	}

	var snapshot SchedulerStatus
	snapshot.Timestamp = time.Now().UTC()
	if n, ok := intMetric("workers_connected"); ok {
		// Simplified: every connected worker is reported as active.
		snapshot.WorkersTotal = n
		snapshot.WorkersActive = n
	}
	if n, ok := intMetric("queue_depth_batch"); ok {
		snapshot.BatchQueueDepth = n
	}
	if n, ok := intMetric("queue_depth_service"); ok {
		snapshot.ServiceQueueDepth = n
	}
	if n, ok := intMetric("jobs_completed"); ok {
		snapshot.TasksCompleted24h = n
	}

	w.Header().Set("Content-Type", "application/json")
	err := json.NewEncoder(w).Encode(snapshot)
	if err != nil {
		h.logger.Warn("failed to encode status", "error", err)
	}
}
// GetV1SchedulerWorkers handles GET /v1/scheduler/workers.
//
// It requires the "scheduler:read" permission (403 otherwise) and a non-nil
// hub (503 otherwise). One WorkerInfo is emitted per entry of the hub's
// "worker_slots" metric; every worker is reported as "active" with empty
// capabilities and no active tasks. The response is a JSON array whose order
// follows Go map iteration and is therefore unspecified.
func (h *APIHandler) GetV1SchedulerWorkers(w http.ResponseWriter, r *http.Request) {
	if caller := auth.GetUserFromContext(r.Context()); !h.checkPermission(caller, "scheduler:read") {
		http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden)
		return
	}
	if h.hub == nil {
		http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable)
		return
	}

	payload := h.hub.GetMetricsPayload()
	// A missing or differently-typed entry yields a nil map and an empty list.
	slotsByWorker, _ := payload["worker_slots"].(map[string]sch.SlotStatus)

	listing := make([]WorkerInfo, 0, len(slotsByWorker))
	for id, slots := range slotsByWorker {
		listing = append(listing, WorkerInfo{
			ID:           id,
			Slots:        slots,
			Capabilities: sch.WorkerCapabilities{},
			Status:       "active",
			ActiveTasks:  []string{},
		})
	}

	w.Header().Set("Content-Type", "application/json")
	err := json.NewEncoder(w).Encode(listing)
	if err != nil {
		h.logger.Warn("failed to encode workers", "error", err)
	}
}
// GetV1SchedulerWorkersWorkerID handles GET /v1/scheduler/workers/{workerId}.
//
// Checks run in this order: "scheduler:read" permission (403), hub present
// (503), non-empty workerId path value (400), worker present in the hub's
// "worker_slots" metric (404). On success a single WorkerInfo is written as
// JSON, reported as "active" with empty capabilities and no active tasks.
func (h *APIHandler) GetV1SchedulerWorkersWorkerID(w http.ResponseWriter, r *http.Request) {
	if caller := auth.GetUserFromContext(r.Context()); !h.checkPermission(caller, "scheduler:read") {
		http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden)
		return
	}
	if h.hub == nil {
		http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable)
		return
	}

	id := r.PathValue("workerId")
	if id == "" {
		http.Error(w, `{"error":"Missing worker ID","code":"BAD_REQUEST"}`, http.StatusBadRequest)
		return
	}

	// A missing or differently-typed metric yields a nil map, so the lookup
	// below simply reports "not found".
	slotsByWorker, _ := h.hub.GetMetricsPayload()["worker_slots"].(map[string]sch.SlotStatus)
	slots, found := slotsByWorker[id]
	if !found {
		http.Error(w, `{"error":"Worker not found","code":"NOT_FOUND"}`, http.StatusNotFound)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	err := json.NewEncoder(w).Encode(WorkerInfo{
		ID:           id,
		Slots:        slots,
		Capabilities: sch.WorkerCapabilities{},
		Status:       "active",
		ActiveTasks:  []string{},
	})
	if err != nil {
		h.logger.Warn("failed to encode worker", "error", err)
	}
}
// DeleteV1SchedulerWorkersWorkerID handles DELETE /v1/scheduler/workers/{workerId}.
//
// It requires the "scheduler:drain" permission (403 otherwise), a non-nil hub
// (503 otherwise) and a non-empty workerId path value (400 otherwise). The
// request is logged and answered with 204 No Content.
//
// This is a simplified version: no drain signal is sent to the worker and the
// handler does not check that the worker exists.
func (h *APIHandler) DeleteV1SchedulerWorkersWorkerID(w http.ResponseWriter, r *http.Request) {
	caller := auth.GetUserFromContext(r.Context())
	if !h.checkPermission(caller, "scheduler:drain") {
		http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden)
		return
	}
	if h.hub == nil {
		http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable)
		return
	}

	id := r.PathValue("workerId")
	if id == "" {
		http.Error(w, `{"error":"Missing worker ID","code":"BAD_REQUEST"}`, http.StatusBadRequest)
		return
	}

	// caller is non-nil here: checkPermission rejects a nil user.
	h.logger.Info("draining worker", "worker_id", id, "user", caller.Name)
	w.WriteHeader(http.StatusNoContent)
}
// GetV1SchedulerReservations handles GET /v1/scheduler/reservations.
//
// It requires the "scheduler:read" permission (403 otherwise) and a non-nil
// hub (503 otherwise). Reservations are not tracked yet, so the response is
// always an empty JSON array.
func (h *APIHandler) GetV1SchedulerReservations(w http.ResponseWriter, r *http.Request) {
	if caller := auth.GetUserFromContext(r.Context()); !h.checkPermission(caller, "scheduler:read") {
		http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden)
		return
	}
	if h.hub == nil {
		http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// A non-nil empty slice encodes as "[]" rather than "null".
	err := json.NewEncoder(w).Encode([]ReservationInfo{})
	if err != nil {
		h.logger.Warn("failed to encode reservations", "error", err)
	}
}
// PostV1SchedulerReservations handles POST /v1/scheduler/reservations.
//
// It requires the "scheduler:write" permission (403 otherwise) and a non-nil
// hub (503 otherwise). The JSON body is decoded into a
// CreateReservationRequest; a body that is malformed or larger than 1 MiB is
// rejected with 400. Validation failures return 422: gpu_count must be
// positive, and expires_minutes must be small enough to be represented as a
// time.Duration. A non-positive node_count defaults to 1 and a non-positive
// expires_minutes defaults to 30.
//
// On success the reservation is logged and returned as JSON with status 201.
// The reservation is only echoed back; nothing in this handler stores it.
func (h *APIHandler) PostV1SchedulerReservations(w http.ResponseWriter, r *http.Request) {
	const (
		// maxBodyBytes bounds how much of the request body is read.
		maxBodyBytes = 1 << 20
		// maxExpiresMinutes is the largest minute count that still fits in a
		// time.Duration (int64 nanoseconds) without overflowing.
		maxExpiresMinutes = int64((1<<63 - 1) / time.Minute)
	)

	user := auth.GetUserFromContext(r.Context())
	if !h.checkPermission(user, "scheduler:write") {
		http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden)
		return
	}
	if h.hub == nil {
		http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable)
		return
	}

	// Cap the body so an oversized payload cannot be buffered in full.
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)
	var req CreateReservationRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, `{"error":"Invalid request body","code":"BAD_REQUEST"}`, http.StatusBadRequest)
		return
	}

	if req.GPUCount <= 0 {
		http.Error(w, `{"error":"GPU count must be positive","code":"VALIDATION_ERROR"}`, http.StatusUnprocessableEntity)
		return
	}
	// Without this guard the multiplication below wraps around and produces
	// an expiry in the past for very large client-supplied values.
	if int64(req.ExpiresMinutes) > maxExpiresMinutes {
		http.Error(w, `{"error":"Expiry is too large","code":"VALIDATION_ERROR"}`, http.StatusUnprocessableEntity)
		return
	}
	if req.NodeCount <= 0 {
		req.NodeCount = 1
	}
	if req.ExpiresMinutes <= 0 {
		req.ExpiresMinutes = 30
	}

	// Use one clock reading so the ID and the expiry derive from the same instant.
	now := time.Now()
	reservation := ReservationInfo{
		ID:        fmt.Sprintf("res-%d", now.UnixNano()),
		UserID:    user.Name,
		GPUCount:  req.GPUCount,
		GPUType:   req.GPUType,
		NodeCount: req.NodeCount,
		ExpiresAt: now.Add(time.Duration(req.ExpiresMinutes) * time.Minute),
		Status:    "active",
	}

	h.logger.Info("created reservation", "reservation_id", reservation.ID, "user", user.Name)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	if err := json.NewEncoder(w).Encode(reservation); err != nil {
		h.logger.Warn("failed to encode reservation", "error", err)
	}
}
// PatchV1SchedulerJobsJobIDPriority handles PATCH /v1/scheduler/jobs/{jobId}/priority.
//
// Checks run in this order: "tasks:priority" permission (403), hub present
// (503), non-empty jobId path value (400), decodable JSON body of at most
// 1 MiB (400), and a priority between 1 and 10 inclusive (422). A body
// without a "priority" field decodes to 0 and is therefore rejected.
//
// This is a simplified version: the request is logged and echoed back with
// status "queued", but no queued task is modified here.
func (h *APIHandler) PatchV1SchedulerJobsJobIDPriority(w http.ResponseWriter, r *http.Request) {
	// maxBodyBytes bounds how much of the request body is read.
	const maxBodyBytes = 1 << 20

	user := auth.GetUserFromContext(r.Context())
	if !h.checkPermission(user, "tasks:priority") {
		http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden)
		return
	}
	if h.hub == nil {
		http.Error(w, `{"error":"Scheduler not available","code":"SERVICE_UNAVAILABLE"}`, http.StatusServiceUnavailable)
		return
	}

	jobID := r.PathValue("jobId")
	if jobID == "" {
		http.Error(w, `{"error":"Missing job ID","code":"BAD_REQUEST"}`, http.StatusBadRequest)
		return
	}

	// Cap the body so an oversized payload cannot be buffered in full.
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)
	var req struct {
		Priority int `json:"priority"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, `{"error":"Invalid request body","code":"BAD_REQUEST"}`, http.StatusBadRequest)
		return
	}
	if req.Priority < 1 || req.Priority > 10 {
		http.Error(w, `{"error":"Priority must be between 1 and 10","code":"VALIDATION_ERROR"}`, http.StatusUnprocessableEntity)
		return
	}

	h.logger.Info("updating job priority", "job_id", jobID, "priority", req.Priority, "user", user.Name)

	response := map[string]any{
		"id":       jobID,
		"priority": req.Priority,
		"status":   "queued",
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		h.logger.Warn("failed to encode priority response", "error", err)
	}
}
// GetV1SchedulerStatusStream handles GET /v1/scheduler/status/stream (SSE).
//
// It requires the "scheduler:read" permission (403 otherwise) and a
// ResponseWriter that supports flushing (500 otherwise). The handler then
// registers a buffered event channel for this client, sends a "connected"
// frame, and loops until the request context is cancelled, emitting a
// "heartbeat" frame every 5 seconds and forwarding any event received on the
// client's channel. The channel is unregistered and closed on return.
//
// NOTE(review): h.streaming is a plain map mutated here on every connect and
// disconnect with no lock in view; net/http serves requests concurrently, so
// this needs a mutex on APIHandler that all accessors share — confirm and fix
// at the struct level.
func (h *APIHandler) GetV1SchedulerStatusStream(w http.ResponseWriter, r *http.Request) {
	user := auth.GetUserFromContext(r.Context())
	if !h.checkPermission(user, "scheduler:read") {
		http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden)
		return
	}

	// Verify flushing support up front instead of panicking on an unchecked
	// type assertion when the writer has been wrapped by middleware.
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, `{"error":"Streaming not supported","code":"INTERNAL_ERROR"}`, http.StatusInternalServerError)
		return
	}

	// Set SSE headers
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	// Create event channel for this client
	eventChan := make(chan StreamingEvent, 10)
	clientID := fmt.Sprintf("%s-%d", user.Name, time.Now().UnixNano())
	h.streaming[clientID] = append(h.streaming[clientID], eventChan)

	// Clean up on disconnect
	defer func() {
		delete(h.streaming, clientID)
		close(eventChan)
	}()

	// emit writes v as a single SSE data frame and flushes it. It reports
	// false once the connection can no longer be written to.
	emit := func(v any) bool {
		data, err := json.Marshal(v)
		if err != nil {
			// Skip the frame rather than sending an empty "data:" line.
			h.logger.Warn("failed to encode stream event", "error", err)
			return true
		}
		if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
			return false
		}
		flusher.Flush()
		return true
	}

	// Send initial status
	if !emit(map[string]any{
		"type":      "connected",
		"timestamp": time.Now().UTC(),
	}) {
		return
	}

	// Keep connection alive and send periodic updates
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-r.Context().Done():
			return
		case <-ticker.C:
			if !emit(map[string]any{
				"type":      "heartbeat",
				"timestamp": time.Now().UTC(),
			}) {
				return
			}
		case event := <-eventChan:
			if !emit(event) {
				return
			}
		}
	}
}
// GetV1SchedulerJobsJobIDStream handles GET /v1/scheduler/jobs/{jobId}/stream (SSE).
//
// It requires the "scheduler:read" permission (403 otherwise), a non-empty
// jobId path value (400 otherwise) and a ResponseWriter that supports
// flushing (500 otherwise). It sends a "connected" frame carrying the job ID
// and then a "heartbeat" frame every 5 seconds until the request context is
// cancelled or a write fails. No job-specific events are emitted here.
func (h *APIHandler) GetV1SchedulerJobsJobIDStream(w http.ResponseWriter, r *http.Request) {
	user := auth.GetUserFromContext(r.Context())
	if !h.checkPermission(user, "scheduler:read") {
		http.Error(w, `{"error":"Insufficient permissions","code":"FORBIDDEN"}`, http.StatusForbidden)
		return
	}

	jobID := r.PathValue("jobId")
	if jobID == "" {
		http.Error(w, `{"error":"Missing job ID","code":"BAD_REQUEST"}`, http.StatusBadRequest)
		return
	}

	// Verify flushing support up front instead of panicking on an unchecked
	// type assertion when the writer has been wrapped by middleware.
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, `{"error":"Streaming not supported","code":"INTERNAL_ERROR"}`, http.StatusInternalServerError)
		return
	}

	// Set SSE headers
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	// emit writes v as a single SSE data frame and flushes it. It reports
	// false once the connection can no longer be written to.
	emit := func(v any) bool {
		data, err := json.Marshal(v)
		if err != nil {
			// Skip the frame rather than sending an empty "data:" line.
			h.logger.Warn("failed to encode stream event", "error", err)
			return true
		}
		if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
			return false
		}
		flusher.Flush()
		return true
	}

	// Send initial status
	if !emit(map[string]any{
		"type":      "connected",
		"job_id":    jobID,
		"timestamp": time.Now().UTC(),
	}) {
		return
	}

	// Keep connection alive
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-r.Context().Done():
			return
		case <-ticker.C:
			if !emit(map[string]any{
				"type":      "heartbeat",
				"timestamp": time.Now().UTC(),
			}) {
				return
			}
		}
	}
}
// checkPermission reports whether user may perform the action named by
// permission. A nil user is always denied; an admin user is always allowed;
// any other user is allowed only if HasPermission(permission) says so.
func (h *APIHandler) checkPermission(user *auth.User, permission string) bool {
	// Short-circuiting keeps the nil check ahead of any field access.
	return user != nil && (user.Admin || user.HasPermission(permission))
}