diff --git a/cmd/scheduler/main.go b/cmd/scheduler/main.go new file mode 100644 index 0000000..796af57 --- /dev/null +++ b/cmd/scheduler/main.go @@ -0,0 +1,274 @@ +package main + +import ( + "flag" + "fmt" + "log/slog" + "net/http" + "os" + "os/signal" + "syscall" + + "github.com/jfraeys/fetch_ml/internal/audit" + "github.com/jfraeys/fetch_ml/internal/scheduler" + "gopkg.in/yaml.v3" +) + +// Config represents the scheduler configuration +type Config struct { + Scheduler SchedulerConfig `yaml:"scheduler"` +} + +type SchedulerConfig struct { + BindAddr string `yaml:"bind_addr"` + CertFile string `yaml:"cert_file"` + KeyFile string `yaml:"key_file"` + AutoGenerateCerts bool `yaml:"auto_generate_certs"` + StateDir string `yaml:"state_dir"` + DefaultBatchSlots int `yaml:"default_batch_slots"` + DefaultServiceSlots int `yaml:"default_service_slots"` + StarvationThresholdMins float64 `yaml:"starvation_threshold_mins"` + PriorityAgingRate float64 `yaml:"priority_aging_rate"` + GangAllocTimeoutSecs int `yaml:"gang_alloc_timeout_secs"` + AcceptanceTimeoutSecs int `yaml:"acceptance_timeout_secs"` + MetricsAddr string `yaml:"metrics_addr"` + WorkerTokens []WorkerToken `yaml:"worker_tokens"` +} + +type WorkerToken struct { + ID string `yaml:"id"` + Token string `yaml:"token"` +} + +func main() { + var ( + configPath string + generateToken bool + initConfig bool + numTokens int + ) + flag.StringVar(&configPath, "config", "scheduler.yaml", "Path to scheduler config file") + flag.BoolVar(&generateToken, "generate-token", false, "Generate a new worker token and exit") + flag.BoolVar(&initConfig, "init", false, "Initialize a new config file with generated tokens") + flag.IntVar(&numTokens, "tokens", 3, "Number of tokens to generate (used with -init)") + flag.Parse() + + // Handle token generation mode + if generateToken { + token := scheduler.GenerateWorkerToken() + fmt.Println(token) + os.Exit(0) + } + + // Handle init mode + if initConfig { + if err := generateConfig(configPath, 
numTokens); err != nil { + fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err) + os.Exit(1) + } + fmt.Printf("Config generated: %s\n", configPath) + fmt.Printf("\nGenerated %d worker tokens. Copy the appropriate token to each worker's config.\n", numTokens) + os.Exit(0) + } + + // Load config + cfg, err := loadConfig(configPath) + if err != nil { + slog.Error("failed to load config", "error", err) + os.Exit(1) + } + + // Setup logging + handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}) + logger := slog.New(handler) + slog.SetDefault(logger) + + // Create token map + tokenMap := make(map[string]string) + for _, wt := range cfg.Scheduler.WorkerTokens { + tokenMap[wt.Token] = wt.ID + } + + // Auto-generate certs if needed + if cfg.Scheduler.AutoGenerateCerts && cfg.Scheduler.CertFile != "" { + if _, err := os.Stat(cfg.Scheduler.CertFile); os.IsNotExist(err) { + keyFile := cfg.Scheduler.KeyFile + if keyFile == "" { + keyFile = cfg.Scheduler.CertFile + ".key" + } + logger.Info("generating self-signed certificate", "cert", cfg.Scheduler.CertFile) + if err := scheduler.GenerateSelfSignedCert(cfg.Scheduler.CertFile, keyFile); err != nil { + logger.Error("failed to generate certificate", "error", err) + os.Exit(1) + } + } + } + + // Create hub config + hubCfg := scheduler.HubConfig{ + BindAddr: cfg.Scheduler.BindAddr, + CertFile: cfg.Scheduler.CertFile, + KeyFile: cfg.Scheduler.KeyFile, + AutoGenerateCerts: cfg.Scheduler.AutoGenerateCerts, + StateDir: cfg.Scheduler.StateDir, + DefaultBatchSlots: cfg.Scheduler.DefaultBatchSlots, + DefaultServiceSlots: cfg.Scheduler.DefaultServiceSlots, + StarvationThresholdMins: cfg.Scheduler.StarvationThresholdMins, + PriorityAgingRate: cfg.Scheduler.PriorityAgingRate, + GangAllocTimeoutSecs: cfg.Scheduler.GangAllocTimeoutSecs, + AcceptanceTimeoutSecs: cfg.Scheduler.AcceptanceTimeoutSecs, + WorkerTokens: tokenMap, + } + + // Create auditor (optional) + var auditor *audit.Logger + + // Create 
hub + hub, err := scheduler.NewHub(hubCfg, auditor) + if err != nil { + logger.Error("failed to create scheduler hub", "error", err) + os.Exit(1) + } + + // Start hub + if err := hub.Start(); err != nil { + logger.Error("failed to start scheduler hub", "error", err) + os.Exit(1) + } + + logger.Info("scheduler hub started", "bind_addr", cfg.Scheduler.BindAddr) + + // Setup HTTP handlers + mux := http.NewServeMux() + mux.HandleFunc("/ws/worker", hub.HandleConnection) + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + }) + mux.HandleFunc("/metrics", hub.ServeMetrics) + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Start server + go func() { + if cfg.Scheduler.CertFile != "" { + logger.Info("starting HTTPS server", "addr", cfg.Scheduler.BindAddr) + if err := http.ListenAndServeTLS(cfg.Scheduler.BindAddr, cfg.Scheduler.CertFile, cfg.Scheduler.KeyFile, mux); err != nil { + logger.Error("server error", "error", err) + } + } else { + logger.Info("starting HTTP server", "addr", cfg.Scheduler.BindAddr) + if err := http.ListenAndServe(cfg.Scheduler.BindAddr, mux); err != nil { + logger.Error("server error", "error", err) + } + } + }() + + // Wait for shutdown signal + <-sigChan + logger.Info("shutting down scheduler...") + hub.Stop() + logger.Info("scheduler stopped") +} + +func loadConfig(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read config file: %w", err) + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + + // Set defaults + if cfg.Scheduler.BindAddr == "" { + cfg.Scheduler.BindAddr = "0.0.0.0:7777" + } + if cfg.Scheduler.StateDir == "" { + cfg.Scheduler.StateDir = "/var/lib/fetch_ml" + } + if cfg.Scheduler.DefaultBatchSlots == 0 { + 
cfg.Scheduler.DefaultBatchSlots = 3 + } + if cfg.Scheduler.DefaultServiceSlots == 0 { + cfg.Scheduler.DefaultServiceSlots = 1 + } + if cfg.Scheduler.StarvationThresholdMins == 0 { + cfg.Scheduler.StarvationThresholdMins = 5 + } + if cfg.Scheduler.PriorityAgingRate == 0 { + cfg.Scheduler.PriorityAgingRate = 0.1 + } + if cfg.Scheduler.GangAllocTimeoutSecs == 0 { + cfg.Scheduler.GangAllocTimeoutSecs = 60 + } + if cfg.Scheduler.AcceptanceTimeoutSecs == 0 { + cfg.Scheduler.AcceptanceTimeoutSecs = 30 + } + + return &cfg, nil +} + +// generateConfig creates a new scheduler config file with generated tokens +func generateConfig(path string, numTokens int) error { + // Generate tokens + var tokens []WorkerToken + for i := 1; i <= numTokens; i++ { + tokens = append(tokens, WorkerToken{ + ID: fmt.Sprintf("worker-%02d", i), + Token: scheduler.GenerateWorkerToken(), + }) + } + + cfg := Config{ + Scheduler: SchedulerConfig{ + BindAddr: "0.0.0.0:7777", + AutoGenerateCerts: true, + CertFile: "/etc/fetch_ml/scheduler.crt", + KeyFile: "/etc/fetch_ml/scheduler.key", + StateDir: "/var/lib/fetch_ml", + DefaultBatchSlots: 3, + DefaultServiceSlots: 1, + StarvationThresholdMins: 5, + PriorityAgingRate: 0.1, + GangAllocTimeoutSecs: 60, + AcceptanceTimeoutSecs: 30, + MetricsAddr: "0.0.0.0:9090", + WorkerTokens: tokens, + }, + } + + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("marshal config: %w", err) + } + + // Add header comment + header := `# Scheduler Configuration for fetch_ml +# Generated by: scheduler -init +# +# ⚠️ SECURITY WARNING: This file contains authentication tokens. 
// GenerateSelfSignedCert creates a self-signed ECDSA P-256 TLS certificate for
// the scheduler and writes it, PEM-encoded, to certFile. The matching private
// key is written to keyFile with mode 0600.
//
// Parent directories of both files are created if missing. The certificate is
// valid for one year, carries a random serial number, and lists the IPv4 and
// IPv6 loopback addresses as IP SANs. Any failure to encode, write, or close
// either file is returned to the caller.
func GenerateSelfSignedCert(certFile, keyFile string) error {
	if err := os.MkdirAll(filepath.Dir(certFile), 0755); err != nil {
		return fmt.Errorf("create cert directory: %w", err)
	}
	// The key may live in a different directory than the certificate.
	if err := os.MkdirAll(filepath.Dir(keyFile), 0755); err != nil {
		return fmt.Errorf("create key directory: %w", err)
	}

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return fmt.Errorf("generate key: %w", err)
	}

	// A random serial keeps regenerated certificates distinct; clients that
	// cache certificates by issuer+serial reject a reused serial.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 127))
	if err != nil {
		return fmt.Errorf("generate serial number: %w", err)
	}
	serial.Add(serial, big.NewInt(1)) // serial numbers must be positive

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Organization: []string{"fetch_ml"},
			CommonName:   "fetch_ml_scheduler",
		},
		NotBefore:             now,
		NotAfter:              now.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		// IP SANs for local development
		IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
	}

	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return fmt.Errorf("create certificate: %w", err)
	}

	// Marshal the key before touching disk so a failure leaves no partial output.
	keyDER, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return fmt.Errorf("marshal private key: %w", err)
	}

	if err := writePEMFile(certFile, 0644, "CERTIFICATE", certDER); err != nil {
		return fmt.Errorf("write cert file: %w", err)
	}
	if err := writePEMFile(keyFile, 0600, "EC PRIVATE KEY", keyDER); err != nil {
		return fmt.Errorf("write key file: %w", err)
	}
	return nil
}

// writePEMFile writes a single PEM block to path, truncating any existing
// file, and reports encode and close errors.
func writePEMFile(path string, perm os.FileMode, blockType string, der []byte) error {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	if err := pem.Encode(f, &pem.Block{Type: blockType, Bytes: der}); err != nil {
		f.Close()
		return err
	}
	// Close can surface a deferred write error; do not discard it.
	return f.Close()
}
// TokenValidator validates worker authentication tokens.
type TokenValidator struct {
	tokens map[string]string // token -> workerID
}

// NewTokenValidator returns a validator backed by the given token -> workerID
// map. The map is used as-is (not copied).
func NewTokenValidator(tokens map[string]string) *TokenValidator {
	return &TokenValidator{tokens: tokens}
}

// Validate looks up token and returns the worker ID it belongs to. ok is false
// when the token is unknown or empty.
func (tv *TokenValidator) Validate(token string) (workerID string, ok bool) {
	// An empty token must never authenticate, even if a misconfigured entry
	// with an empty token string ended up in the map.
	if token == "" {
		return "", false
	}
	workerID, ok = tv.tokens[token]
	return workerID, ok
}

// ExtractBearerToken extracts the token from an Authorization header of the
// form "Bearer <token>". The scheme is matched case-insensitively. It returns
// "" when the header does not use the Bearer scheme, so other schemes and raw
// header values are never treated as tokens.
func ExtractBearerToken(header string) string {
	const scheme = "Bearer "
	if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) {
		return ""
	}
	return strings.TrimSpace(header[len(scheme):])
}

// GenerateWorkerToken creates a cryptographically secure random token for a
// worker: the prefix "wkr_" followed by 32 random bytes in URL-safe base64.
// It panics if the system random source fails, because returning a
// predictable token would be a silent security failure.
func GenerateWorkerToken() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic("scheduler: reading random bytes for worker token: " + err.Error())
	}
	// Use URL-safe base64 encoding for compact, URL-friendly tokens
	return "wkr_" + base64.URLEncoding.EncodeToString(b)
}
*audit.Logger + tokenValidator *TokenValidator + config HubConfig + ctx context.Context + cancel context.CancelFunc + server *http.Server + listener net.Listener +} + +type HubConfig struct { + BindAddr string + CertFile string + KeyFile string + AutoGenerateCerts bool + StateDir string + DefaultBatchSlots int + DefaultServiceSlots int + StarvationThresholdMins float64 + PriorityAgingRate float64 + GangAllocTimeoutSecs int + AcceptanceTimeoutSecs int + LocalMode bool + WorkerTokens map[string]string // token -> workerID +} + +// WorkerConn represents a connected worker +type WorkerConn struct { + workerID string + conn *websocket.Conn + capabilities WorkerCapabilities + slots SlotStatus + lease *Lease + activeTasks map[string]struct{} + send chan Message + hub *SchedulerHub + mu sync.Mutex +} + +// Lease tracks job ownership +type Lease struct { + TaskID string + WorkerID string + ExpiresAt time.Time +} + +// Reservation prevents starvation of large jobs +type Reservation struct { + TaskID string + GPUCount int + CreatedAt time.Time +} + +// MultiNodeJob tracks gang allocation state +type MultiNodeJob struct { + JobID string + TotalNodes int + Assignments []*NodeAssignment + CommittedAt time.Time +} + +type NodeAssignment struct { + Worker *WorkerConn + Rank int + CommittedAt time.Time +} + +// JobAssignment tracks acceptance state +type JobAssignment struct { + TaskID string + WorkerID string + AssignedAt time.Time + AcceptanceDeadline time.Time + Accepted bool +} + +// StarvationTracker monitors long-waiting jobs +type StarvationTracker struct { + mu sync.RWMutex + threshold time.Duration +} + +// SchedulerMetrics tracks scheduler statistics +type SchedulerMetrics struct { + mu sync.RWMutex + WorkersConnected int + QueueDepthBatch int + QueueDepthService int + JobsCompleted int + JobsFailed int + JobsCancelled int + WorkerSlots map[string]SlotStatus +} + +// NewHub creates a new scheduler hub +func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, 
error) { + ctx, cancel := context.WithCancel(context.Background()) + + // Initialize state store + statePath := cfg.StateDir + "/scheduler.state" + state, err := NewStateStore(statePath) + if err != nil { + cancel() + return nil, fmt.Errorf("init state store: %w", err) + } + + agingRate := cfg.PriorityAgingRate + if agingRate == 0 { + agingRate = 0.1 + } + + hub := &SchedulerHub{ + workers: make(map[string]*WorkerConn), + readyWorkers: make(map[string]*WorkerConn), + batchQueue: NewPriorityQueue(agingRate), + serviceQueue: NewPriorityQueue(agingRate), + reservations: make(map[string]*Reservation), + multiNodePending: make(map[string]*MultiNodeJob), + pendingAcceptance: make(map[string]*JobAssignment), + state: state, + starvation: &StarvationTracker{ + threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute, + }, + metrics: &SchedulerMetrics{ + WorkerSlots: make(map[string]SlotStatus), + }, + auditor: auditor, + tokenValidator: NewTokenValidator(cfg.WorkerTokens), + config: cfg, + ctx: ctx, + cancel: cancel, + } + + return hub, nil +} + +// Start initializes the scheduler, starts the HTTP server, and replays state +func (h *SchedulerHub) Start() error { + // Replay state first + events, err := h.state.Replay() + if err != nil { + return fmt.Errorf("state replay failed: %w", err) + } + + for _, ev := range events { + switch ev.Type { + case EventJobEnqueued: + h.restoreJob(ev) + case EventJobAssigned: + h.restoreAssignment(ev) + case EventJobCompleted, EventJobFailed, EventJobCancelled: + // terminal — skip + } + } + + // Start WSS server (unified protocol) + mux := http.NewServeMux() + mux.HandleFunc("/ws/worker", h.HandleConnection) + + listener, err := net.Listen("tcp", h.config.BindAddr) + if err != nil { + return fmt.Errorf("failed to listen: %w", err) + } + h.listener = listener + + h.server = &http.Server{Handler: mux} + + // Auto-generate self-signed certs if requested + if h.config.AutoGenerateCerts && (h.config.CertFile == "" || 
h.config.KeyFile == "") { + certFile := h.config.StateDir + "/scheduler.crt" + keyFile := h.config.StateDir + "/scheduler.key" + if err := GenerateSelfSignedCert(certFile, keyFile); err != nil { + return fmt.Errorf("failed to generate self-signed cert: %w", err) + } + h.config.CertFile = certFile + h.config.KeyFile = keyFile + } + + // Start with TLS if certificates are configured + if h.config.CertFile != "" && h.config.KeyFile != "" { + go h.server.ServeTLS(listener, h.config.CertFile, h.config.KeyFile) + } else { + go h.server.Serve(listener) + } + + // Start background tasks + go h.checkAcceptanceTimeouts() + go h.checkGangTimeouts() + go h.checkStarvation() + + // Grace period: workers have 30s to reconnect before assigned jobs are orphaned + time.AfterFunc(30*time.Second, h.reconcileOrphans) + + return nil +} + +// Addr returns the listening address of the scheduler +func (h *SchedulerHub) Addr() string { + if h.listener == nil { + return "" + } + return h.listener.Addr().String() +} + +// Stop gracefully shuts down the scheduler +func (h *SchedulerHub) Stop() { + h.cancel() + h.state.Close() +} + +// HandleConnection handles WSS connections from workers and metrics clients +func (h *SchedulerHub) HandleConnection(w http.ResponseWriter, r *http.Request) { + // Validate token + token := ExtractBearerToken(r.Header.Get("Authorization")) + clientID, ok := h.tokenValidator.Validate(token) + if !ok { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + // Upgrade to WebSocket + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + http.Error(w, "upgrade failed", http.StatusInternalServerError) + return + } + + // Check if this is a metrics client (special token prefix) + if strings.HasPrefix(clientID, "metrics-") { + go h.runMetricsClient(clientID, conn) + return + } + + go h.runWorker(clientID, conn) +} + +func (h *SchedulerHub) runWorker(workerID string, conn *websocket.Conn) { + wc := &WorkerConn{ + workerID: workerID, + conn: conn, 
+ slots: SlotStatus{}, + activeTasks: make(map[string]struct{}), + send: make(chan Message, 10), + hub: h, + } + + h.mu.Lock() + h.workers[workerID] = wc + h.metrics.WorkersConnected++ + h.mu.Unlock() + + defer func() { + h.mu.Lock() + delete(h.workers, workerID) + delete(h.readyWorkers, workerID) + h.metrics.WorkersConnected-- + h.mu.Unlock() + conn.Close() + }() + + // Send loop + go func() { + for msg := range wc.send { + conn.WriteJSON(msg) + } + }() + + // Receive loop + for { + var msg Message + if err := conn.ReadJSON(&msg); err != nil { + return // Connection closed + } + h.handleMessage(wc, msg) + } +} + +func (h *SchedulerHub) handleMessage(wc *WorkerConn, msg Message) { + switch msg.Type { + case MsgRegister: + var reg WorkerRegistration + json.Unmarshal(msg.Payload, ®) + h.reconcileWorker(reg, wc) + case MsgHeartbeat: + var hb HeartbeatPayload + json.Unmarshal(msg.Payload, &hb) + wc.mu.Lock() + wc.slots = hb.Slots + wc.mu.Unlock() + h.updateWorkerMetrics(wc.workerID, hb.Slots) + case MsgReadyForWork: + var ready ReadyPayload + json.Unmarshal(msg.Payload, &ready) + wc.mu.Lock() + wc.slots = ready.Slots + wc.mu.Unlock() + h.handleReady(wc, ready.Slots) + case MsgJobAccepted: + var taskID string + json.Unmarshal(msg.Payload, &taskID) + h.handleJobAccepted(wc.workerID, taskID) + case MsgJobResult: + var result JobResultPayload + json.Unmarshal(msg.Payload, &result) + h.handleJobResult(wc.workerID, result) + case MsgServiceHealth: + // Service health updates - logged but no action needed for MVP + var health ServiceHealthPayload + json.Unmarshal(msg.Payload, &health) + slog.Debug("service health update", "worker", wc.workerID, "task", health.TaskID, "healthy", health.Healthy) + } +} + +func (h *SchedulerHub) reconcileWorker(reg WorkerRegistration, wc *WorkerConn) { + wc.capabilities = reg.Capabilities + + for _, reported := range reg.ActiveTasks { + task := h.getTask(reported.TaskID) + + switch { + case task == nil: + // Case 1: Scheduler has no record — 
kill it + wc.send <- Message{Type: MsgJobCancel, + Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})} + + case task.Status == "orphaned": + // Case 2: Scheduler thought lost — restore, worker is running it + h.restoreLease(reported.TaskID, wc.workerID) + task.Status = "running" + + case task.Status == "queued" || task.Status == "assigned": + // Case 3: Re-queued while worker was disconnected — cancel on this worker + wc.send <- Message{Type: MsgJobCancel, + Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})} + + case task.Status == "running" && task.WorkerID != reg.ID: + // Case 4: Running on two workers — cancel on reconnecting worker + wc.send <- Message{Type: MsgJobCancel, + Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})} + } + } + + // Send registration acknowledgment + wc.send <- Message{Type: MsgAck} +} + +func (h *SchedulerHub) handleReady(wc *WorkerConn, _ SlotStatus) { + // Check for multi-node jobs first + for _, task := range h.batchQueue.Items() { + if task.Spec.NodeCount > 1 && h.canAdmit(task, wc) { + if h.handleMultiNodeReady(task, wc) { + return // Multi-node job is being handled + } + } + } + + // Fall through to regular single-node matching + if task := h.findMatch(wc); task != nil { + wc.send <- h.assignTask(task, wc) + return + } + + h.starvation.CheckAndReserve(h) + h.mu.Lock() + h.readyWorkers[wc.workerID] = wc + h.mu.Unlock() + wc.send <- Message{Type: MsgNoWork} +} + +func (h *SchedulerHub) findMatch(wc *WorkerConn) *Task { + if wc.slots.ServiceAvailable() > 0 { + if task := h.scanFit(h.serviceQueue, wc); task != nil { + return task + } + } + if wc.slots.BatchAvailable() > 0 { + if task := h.scanFit(h.batchQueue, wc); task != nil { + return task + } + } + return nil +} + +func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task { + for _, task := range q.Items() { + if task.Spec.NodeCount > 1 { + continue // gang allocator handles multi-node + } + if h.canAdmit(task, wc) { 
+ return task + } + } + return nil +} + +func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool { + h.mu.RLock() + defer h.mu.RUnlock() + + for _, res := range h.reservations { + if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 { + if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount { + return false + } + } + } + return worker.capabilities.GPUCount >= candidate.Spec.GPUCount +} + +func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message { + // Remove from queue first (prevent double-assignment) + h.batchQueue.Remove(task.ID) + h.serviceQueue.Remove(task.ID) + + // Track pending acceptance + h.mu.Lock() + h.pendingAcceptance[task.ID] = &JobAssignment{ + TaskID: task.ID, + WorkerID: wc.workerID, + AssignedAt: time.Now(), + AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second), + Accepted: false, + } + h.mu.Unlock() + + // Persist assignment + h.state.Append(StateEvent{ + Type: EventJobAssigned, + TaskID: task.ID, + WorkerID: wc.workerID, + }) + + return Message{ + Type: MsgJobAssign, + Payload: mustMarshal(task.Spec), + } +} + +func (h *SchedulerHub) handleJobAccepted(_, taskID string) { + h.mu.Lock() + defer h.mu.Unlock() + + if assignment, ok := h.pendingAcceptance[taskID]; ok { + assignment.Accepted = true + } +} + +func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.pendingAcceptance, result.TaskID) + + eventType := EventJobCompleted + switch result.State { + case "failed": + eventType = EventJobFailed + h.metrics.JobsFailed++ + case "cancelled": + eventType = EventJobCancelled + h.metrics.JobsCancelled++ + default: + h.metrics.JobsCompleted++ + } + + h.state.Append(StateEvent{ + Type: eventType, + TaskID: result.TaskID, + WorkerID: workerID, + }) +} + +// checkAcceptanceTimeouts re-queues jobs that weren't accepted +func (h *SchedulerHub) checkAcceptanceTimeouts() { + ticker := 
time.NewTicker(5 * time.Second) + defer ticker.Stop() + for { + select { + case <-h.ctx.Done(): + return + case <-ticker.C: + h.mu.Lock() + for taskID, a := range h.pendingAcceptance { + if !a.Accepted && time.Now().After(a.AcceptanceDeadline) { + h.batchQueue.Add(h.getTask(taskID)) + delete(h.pendingAcceptance, taskID) + if wc, ok := h.workers[a.WorkerID]; ok { + wc.slots = SlotStatus{} + } + } + } + h.mu.Unlock() + } + } +} + +// checkGangTimeouts releases reserved slots for incomplete gangs +func (h *SchedulerHub) checkGangTimeouts() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + for { + select { + case <-h.ctx.Done(): + return + case <-ticker.C: + h.mu.Lock() + for jobID, pending := range h.multiNodePending { + if time.Since(pending.CommittedAt) > time.Duration(h.config.GangAllocTimeoutSecs)*time.Second { + for _, a := range pending.Assignments { + a.Worker.slots = SlotStatus{} + h.readyWorkers[a.Worker.workerID] = a.Worker + } + delete(h.multiNodePending, jobID) + } + } + h.mu.Unlock() + } + } +} + +func (h *SchedulerHub) checkStarvation() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-h.ctx.Done(): + return + case <-ticker.C: + h.starvation.CheckAndReserve(h) + } + } +} + +func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) { + st.mu.Lock() + defer st.mu.Unlock() + + for _, task := range h.batchQueue.Items() { + if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservation(h, task.ID) { + h.mu.Lock() + h.reservations[task.ID] = &Reservation{ + TaskID: task.ID, + GPUCount: task.Spec.GPUCount, + CreatedAt: time.Now(), + } + h.mu.Unlock() + } + } +} + +func (st *StarvationTracker) hasReservation(h *SchedulerHub, taskID string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + _, exists := h.reservations[taskID] + return exists +} + +// Helper methods (stubs to be implemented) + +// GetTask returns a task by ID (public API) +func (h *SchedulerHub) GetTask(taskID string) *Task 
{ + return h.getTask(taskID) +} + +// SubmitJob submits a new job to the scheduler (public API) +func (h *SchedulerHub) SubmitJob(spec JobSpec) error { + if spec.ID == "" { + return fmt.Errorf("job ID is required") + } + + task := &Task{ + ID: spec.ID, + Spec: spec, + Status: "queued", + SubmittedAt: time.Now(), + } + + // Persist to state store + h.state.Append(StateEvent{ + Type: EventJobEnqueued, + TaskID: spec.ID, + Payload: mustMarshal(spec), + Timestamp: time.Now(), + }) + + // Add to appropriate queue + if spec.Type == JobTypeService { + h.serviceQueue.Add(task) + } else { + h.batchQueue.Add(task) + } + + // Send prewarm hint if job has snapshot + h.sendPrewarmHint(task) + + slog.Info("job submitted", "task_id", spec.ID, "type", spec.Type) + return nil +} + +func (h *SchedulerHub) getTask(taskID string) *Task { + t := h.batchQueue.Get(taskID) + if t != nil { + return t + } + return h.serviceQueue.Get(taskID) +} + +func (h *SchedulerHub) restoreJob(ev StateEvent) { + // Parse job spec from event payload + var spec JobSpec + if err := json.Unmarshal(ev.Payload, &spec); err != nil { + slog.Error("failed to restore job", "task_id", ev.TaskID, "error", err) + return + } + + task := &Task{ + ID: ev.TaskID, + Spec: spec, + Status: "queued", + SubmittedAt: ev.Timestamp, + } + + // Add to appropriate queue + if spec.Type == JobTypeService { + h.serviceQueue.Add(task) + } else { + h.batchQueue.Add(task) + } + slog.Info("restored job from state", "task_id", ev.TaskID, "type", spec.Type) +} + +func (h *SchedulerHub) restoreAssignment(ev StateEvent) { + // Parse assignment from event payload + var payload struct { + WorkerID string `json:"worker_id"` + } + if err := json.Unmarshal(ev.Payload, &payload); err != nil { + slog.Error("failed to restore assignment", "task_id", ev.TaskID, "error", err) + return + } + + // Restore pending acceptance state + h.pendingAcceptance[ev.TaskID] = &JobAssignment{ + TaskID: ev.TaskID, + WorkerID: payload.WorkerID, + AssignedAt: 
ev.Timestamp, + AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second), + Accepted: false, + } + + slog.Info("restored assignment from state", "task_id", ev.TaskID, "worker_id", payload.WorkerID) +} + +func (h *SchedulerHub) restoreLease(taskID, workerID string) { + h.mu.Lock() + defer h.mu.Unlock() + + if assignment, ok := h.pendingAcceptance[taskID]; ok { + assignment.Accepted = true + slog.Info("restored lease", "task_id", taskID, "worker_id", workerID) + } else { + // Create a new lease record + h.pendingAcceptance[taskID] = &JobAssignment{ + TaskID: taskID, + WorkerID: workerID, + AssignedAt: time.Now(), + AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second), + Accepted: true, + } + slog.Info("created new lease on reconnect", "task_id", taskID, "worker_id", workerID) + } +} + +func (h *SchedulerHub) reconcileOrphans() { + h.mu.Lock() + defer h.mu.Unlock() + + // After grace period (30s), any job still assigned to a disconnected worker + // is considered orphaned and should be re-queued + for taskID, assignment := range h.pendingAcceptance { + if assignment.Accepted { + // Job was accepted but worker is gone (not in h.workers) + if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected { + task := h.getTask(taskID) + if task != nil { + task.Status = "orphaned" + h.batchQueue.Add(task) + h.state.Append(StateEvent{ + Type: EventJobRequeued, + TaskID: taskID, + WorkerID: assignment.WorkerID, + }) + slog.Info("orphaned job re-queued", "task_id", taskID, "worker_id", assignment.WorkerID) + } + delete(h.pendingAcceptance, taskID) + } + } + } +} + +func (h *SchedulerHub) updateWorkerMetrics(workerID string, slots SlotStatus) { + h.metrics.mu.Lock() + defer h.metrics.mu.Unlock() + h.metrics.WorkerSlots[workerID] = slots +} + +// sendPrewarmHint sends a prewarm hint to an idle worker when a job with snapshot is enqueued +// TODO: Call this when jobs are enqueued via 
scheduler API +func (h *SchedulerHub) sendPrewarmHint(task *Task) { + if task.Spec.SnapshotID == "" { + return + } + h.mu.RLock() + defer h.mu.RUnlock() + + for _, wc := range h.readyWorkers { + if h.canAdmit(task, wc) { + wc.send <- Message{ + Type: MsgPrewarmHint, + Payload: mustMarshal(PrewarmHintPayload{ + TaskID: task.ID, + SnapshotID: task.Spec.SnapshotID, + SnapshotSHA: task.Spec.SnapshotSHA, + }), + } + return // one worker prewarms — not all of them + } + } +} + +// runMetricsClient handles metrics clients over WSS +func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) { + defer conn.Close() + + for { + var msg Message + if err := conn.ReadJSON(&msg); err != nil { + return // Connection closed + } + + if msg.Type == MsgMetricsRequest { + metrics := h.getMetricsPayload() + conn.WriteJSON(Message{ + Type: MsgMetricsResponse, + Payload: mustMarshal(metrics), + }) + } + } +} + +// getMetricsPayload returns current metrics as a map +func (h *SchedulerHub) getMetricsPayload() map[string]any { + h.metrics.mu.RLock() + defer h.metrics.mu.RUnlock() + + return map[string]any{ + "workers_connected": h.metrics.WorkersConnected, + "queue_depth_batch": h.batchQueue.Len(), + "queue_depth_service": h.serviceQueue.Len(), + "jobs_completed": h.metrics.JobsCompleted, + "jobs_failed": h.metrics.JobsFailed, + "jobs_cancelled": h.metrics.JobsCancelled, + "worker_slots": h.metrics.WorkerSlots, + } +} + +// ServeMetrics serves Prometheus-formatted metrics (deprecated, use WSS) +func (h *SchedulerHub) ServeMetrics(w http.ResponseWriter, r *http.Request) { + h.metrics.mu.RLock() + defer h.metrics.mu.RUnlock() + + w.Header().Set("Content-Type", "text/plain; version=0.0.4") + + fmt.Fprintf(w, "# HELP fetch_ml_workers_connected Number of connected workers\n") + fmt.Fprintf(w, "# TYPE fetch_ml_workers_connected gauge\n") + fmt.Fprintf(w, "fetch_ml_workers_connected %d\n\n", h.metrics.WorkersConnected) + + fmt.Fprintf(w, "# HELP fetch_ml_queue_depth Current 
queue depth\n") + fmt.Fprintf(w, "# TYPE fetch_ml_queue_depth gauge\n") + fmt.Fprintf(w, "fetch_ml_queue_depth{pool=\"batch\"} %d\n", h.batchQueue.Len()) + fmt.Fprintf(w, "fetch_ml_queue_depth{pool=\"service\"} %d\n\n", h.serviceQueue.Len()) + + fmt.Fprintf(w, "# HELP fetch_ml_jobs_total Total jobs by result\n") + fmt.Fprintf(w, "# TYPE fetch_ml_jobs_total counter\n") + fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"completed\"} %d\n", h.metrics.JobsCompleted) + fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"failed\"} %d\n", h.metrics.JobsFailed) + fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"cancelled\"} %d\n\n", h.metrics.JobsCancelled) + + fmt.Fprintf(w, "# HELP fetch_ml_slot_utilization Slot utilization by worker\n") + fmt.Fprintf(w, "# TYPE fetch_ml_slot_utilization gauge\n") + for workerID, slots := range h.metrics.WorkerSlots { + if slots.BatchTotal > 0 { + utilization := float64(slots.BatchInUse) / float64(slots.BatchTotal) + fmt.Fprintf(w, "fetch_ml_slot_utilization{worker=\"%s\",pool=\"batch\"} %.2f\n", workerID, utilization) + } + if slots.ServiceTotal > 0 { + utilization := float64(slots.ServiceInUse) / float64(slots.ServiceTotal) + fmt.Fprintf(w, "fetch_ml_slot_utilization{worker=\"%s\",pool=\"service\"} %.2f\n", workerID, utilization) + } + } +} + +// tryGangAlloc attempts to allocate a multi-node job to a worker +// It tracks partial allocations and dispatches when all nodes are committed +func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) { + h.mu.Lock() + defer h.mu.Unlock() + + // Check if this worker can run the job + if !h.canAdmit(task, wc) { + return + } + + jobID := task.ID + pending, ok := h.multiNodePending[jobID] + if !ok { + // First worker for this job + pending = &MultiNodeJob{ + JobID: jobID, + TotalNodes: task.Spec.NodeCount, + Assignments: make([]*NodeAssignment, 0, task.Spec.NodeCount), + CommittedAt: time.Now(), + } + h.multiNodePending[jobID] = pending + } + + // Add this worker to the pending assignment + assignment := 
// mustMarshal encodes v as JSON for use as a Message payload.
//
// Despite the name it never panics. If encoding fails the error is logged
// (it used to be dropped silently) and nil is returned, which is the same
// value callers received before, so a bad payload cannot take the hub down.
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		slog.Error("failed to marshal message payload",
			"payload_type", fmt.Sprintf("%T", v),
			"error", err)
		return nil
	}
	return b
}
// AdaptivePacingController derives request pacing based on worker capacity.
type AdaptivePacingController struct {
	// DesiredRPSPerWorker is the target request rate for a single worker.
	// NewAdaptivePacingController guarantees it is at least 1; a hand-built
	// value below 1 makes RequestsPerSec return its floor of 1.
	DesiredRPSPerWorker int
}

// NewAdaptivePacingController constructs a controller with sane defaults.
// A desired rate below 1 is raised to 1.
func NewAdaptivePacingController(desired int) AdaptivePacingController {
	if desired < 1 {
		desired = 1
	}
	return AdaptivePacingController{DesiredRPSPerWorker: desired}
}

// RequestsPerSec returns max(1, maxWorkers * desiredRPSPerWorker).
//
// maxWorkers below 1 is treated as 1. If the product does not fit in an int
// the result saturates at the largest int instead of wrapping around (a
// wrapped, negative product used to be clamped to 1, i.e. the slowest pace
// for the largest fleet).
func (a AdaptivePacingController) RequestsPerSec(maxWorkers int) int {
	const maxInt = int(^uint(0) >> 1)

	if maxWorkers < 1 {
		maxWorkers = 1
	}

	perWorker := a.DesiredRPSPerWorker
	if perWorker < 1 {
		// Zero-value or hand-built controller: the product cannot reach 1.
		return 1
	}

	// Both factors are >= 1 here, so this division-based check is exact.
	if maxWorkers > maxInt/perWorker {
		return maxInt
	}
	return maxWorkers * perWorker
}
const (
	// DefaultServicePortStart is the first port of the default service range.
	DefaultServicePortStart = 8000
	// DefaultServicePortEnd is the last port of the default service range.
	DefaultServicePortEnd = 9000
)

// PortAllocator manages dynamic port allocation for service jobs.
// It tracks which ports are in use and assigns available ports from a
// configured inclusive range. All methods take the internal mutex, so a single
// allocator may be shared between goroutines.
type PortAllocator struct {
	mu    sync.Mutex
	start int                // first port of the range (inclusive)
	end   int                // last port of the range (inclusive)
	used  map[int]allocation // ports that are in use or cooling down
	ttl   time.Duration      // how long a port stays reserved after Release
}

// allocation records who holds a port and, once released, since when.
type allocation struct {
	taskID     string
	allocated  time.Time
	releasedAt time.Time // zero while the port is still in use
}

// NewPortAllocator creates a new port allocator for the given range.
//
// A non-positive start becomes 10000; an end that is non-positive or above
// 65535 becomes 65535; if start >= end after that, the whole range falls back
// to 10000-65535. Released ports are held back for 30 seconds by default.
func NewPortAllocator(start, end int) *PortAllocator {
	if start <= 0 {
		start = 10000
	}
	if end <= 0 || end > 65535 {
		end = 65535
	}
	if start >= end {
		start = 10000
		end = 65535
	}
	return &PortAllocator{
		start: start,
		end:   end,
		used:  make(map[int]allocation),
		ttl:   30 * time.Second, // Prevent immediate reuse
	}
}

// Allocate assigns an available port to a task.
//
// Ports are tried in ascending order. A port is skipped when it is tracked by
// this allocator (in use or cooling down) or when it cannot currently be bound
// on the host. An error is returned when the whole range is exhausted.
func (pa *PortAllocator) Allocate(taskID string) (int, error) {
	pa.mu.Lock()
	defer pa.mu.Unlock()

	// Drop released ports whose cool-down has elapsed.
	pa.cleanupExpired()

	for port := pa.start; port <= pa.end; port++ {
		if _, inUse := pa.used[port]; inUse {
			continue
		}
		// Verify port is actually available on the system
		if !pa.isPortAvailable(port) {
			continue
		}
		pa.used[port] = allocation{
			taskID:    taskID,
			allocated: time.Now(),
		}
		return port, nil
	}

	return 0, fmt.Errorf("no ports available in range %d-%d", pa.start, pa.end)
}

// Release frees a port for reuse once the TTL has elapsed.
//
// The entry is kept and stamped with the release time so the port cannot be
// handed out again straight away. Releasing an unknown port, or one that is
// already cooling down, is a no-op (the cool-down is not restarted).
func (pa *PortAllocator) Release(port int) {
	pa.mu.Lock()
	defer pa.mu.Unlock()

	alloc, exists := pa.used[port]
	if !exists || !alloc.releasedAt.IsZero() {
		return
	}
	alloc.releasedAt = time.Now()
	pa.used[port] = alloc
}

// GetAllocation returns the task ID currently holding the port, or the empty
// string if the port is free or has been released.
func (pa *PortAllocator) GetAllocation(port int) string {
	pa.mu.Lock()
	defer pa.mu.Unlock()

	if alloc, exists := pa.used[port]; exists && alloc.releasedAt.IsZero() {
		return alloc.taskID
	}
	return ""
}

// cleanupExpired removes released ports whose cool-down has elapsed.
// The caller must hold pa.mu.
func (pa *PortAllocator) cleanupExpired() {
	for port, alloc := range pa.used {
		if pa.isExpired(alloc) {
			delete(pa.used, port)
		}
	}
}

// isExpired reports whether a released port's cool-down is over.
// An allocation that has not been released never expires: the TTL only
// measures time since Release, not since Allocate.
func (pa *PortAllocator) isExpired(alloc allocation) bool {
	if alloc.releasedAt.IsZero() {
		return false
	}
	return time.Since(alloc.releasedAt) > pa.ttl
}

// isPortAvailable checks if a port is actually available on the system by
// briefly binding a TCP listener to it on all interfaces.
func (pa *PortAllocator) isPortAvailable(port int) bool {
	ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
		return false
	}
	// The listener was only a probe; a close error changes nothing for us.
	_ = ln.Close()
	return true
}

// AvailableCount returns the number of ports in the range that are neither in
// use nor cooling down. It does not probe the host.
func (pa *PortAllocator) AvailableCount() int {
	pa.mu.Lock()
	defer pa.mu.Unlock()

	pa.cleanupExpired()
	return (pa.end - pa.start + 1) - len(pa.used)
}

// SetTTL changes the time-to-live for released ports (for testing).
func (pa *PortAllocator) SetTTL(ttl time.Duration) {
	pa.mu.Lock()
	defer pa.mu.Unlock()
	pa.ttl = ttl
}
+ index int // for heap interface +} + +// EffectivePriority returns the priority with aging applied +func (t *Task) EffectivePriority(agingRate float64, now time.Time) float64 { + age := now.Sub(t.SubmittedAt).Minutes() + return float64(t.Priority) + age*agingRate +} + +// taskHeap is the internal heap implementation +type taskHeap struct { + items []*Task + agingRate float64 +} + +func (h taskHeap) Len() int { return len(h.items) } + +func (h taskHeap) Less(i, j int) bool { + // Higher priority first, then older first on ties + now := time.Now() + pi := h.items[i].EffectivePriority(h.agingRate, now) + pj := h.items[j].EffectivePriority(h.agingRate, now) + if pi != pj { + return pi > pj + } + return h.items[i].SubmittedAt.Before(h.items[j].SubmittedAt) +} + +func (h taskHeap) Swap(i, j int) { + h.items[i], h.items[j] = h.items[j], h.items[i] + h.items[i].index = i + h.items[j].index = j +} + +func (h *taskHeap) Push(x any) { + n := len(h.items) + task := x.(*Task) + task.index = n + h.items = append(h.items, task) +} + +func (h *taskHeap) Pop() any { + old := h.items + n := len(old) + task := old[n-1] + old[n-1] = nil // avoid memory leak + task.index = -1 + h.items = old[:n-1] + return task +} + +// PriorityQueue implements a thread-safe priority queue for tasks +type PriorityQueue struct { + heap *taskHeap + mu sync.RWMutex + byID map[string]*Task + agingRate float64 +} + +// NewPriorityQueue creates a new priority queue +func NewPriorityQueue(agingRate float64) *PriorityQueue { + if agingRate == 0 { + agingRate = 0.1 // default: 0.1 per minute + } + return &PriorityQueue{ + heap: &taskHeap{ + items: make([]*Task, 0), + agingRate: agingRate, + }, + byID: make(map[string]*Task), + agingRate: agingRate, + } +} + +// Len returns the number of items in the queue +func (pq *PriorityQueue) Len() int { + pq.mu.RLock() + defer pq.mu.RUnlock() + return len(pq.heap.items) +} + +// Add adds a task to the queue +func (pq *PriorityQueue) Add(task *Task) { + pq.mu.Lock() + 
defer pq.mu.Unlock() + + if _, exists := pq.byID[task.ID]; exists { + return // already in queue + } + + pq.byID[task.ID] = task + heap.Push(pq.heap, task) +} + +// Take removes and returns the highest priority task +func (pq *PriorityQueue) Take() *Task { + pq.mu.Lock() + defer pq.mu.Unlock() + + if len(pq.heap.items) == 0 { + return nil + } + + task := heap.Pop(pq.heap).(*Task) + delete(pq.byID, task.ID) + return task +} + +// Peek returns the highest priority task without removing it +func (pq *PriorityQueue) Peek() *Task { + pq.mu.RLock() + defer pq.mu.RUnlock() + + if len(pq.heap.items) == 0 { + return nil + } + return pq.heap.items[0] +} + +// Items returns a copy of all items in priority order +func (pq *PriorityQueue) Items() []*Task { + pq.mu.RLock() + defer pq.mu.RUnlock() + + result := make([]*Task, len(pq.heap.items)) + copy(result, pq.heap.items) + return result +} + +// Get returns a task by ID +func (pq *PriorityQueue) Get(taskID string) *Task { + pq.mu.RLock() + defer pq.mu.RUnlock() + return pq.byID[taskID] +} + +// Remove removes a task from the queue +func (pq *PriorityQueue) Remove(taskID string) bool { + pq.mu.Lock() + defer pq.mu.Unlock() + + task, exists := pq.byID[taskID] + if !exists { + return false + } + + heap.Remove(pq.heap, task.index) + delete(pq.byID, task.ID) + return true +} + +// Contains checks if a task is in the queue +func (pq *PriorityQueue) Contains(taskID string) bool { + pq.mu.RLock() + defer pq.mu.RUnlock() + _, exists := pq.byID[taskID] + return exists +} diff --git a/internal/scheduler/protocol.go b/internal/scheduler/protocol.go new file mode 100644 index 0000000..2d38bfc --- /dev/null +++ b/internal/scheduler/protocol.go @@ -0,0 +1,137 @@ +package scheduler + +import ( + "encoding/json" + "time" +) + +type Message struct { + Type MessageType `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` + Error string `json:"error,omitempty"` +} + +type MessageType string + +const ( + // Worker → Scheduler + 
MsgRegister MessageType = "register" + MsgHeartbeat MessageType = "heartbeat" // slots only, every 10s + MsgReadyForWork MessageType = "ready_for_work" + MsgJobAccepted MessageType = "job_accepted" + MsgJobResult MessageType = "job_result" + MsgServiceHealth MessageType = "service_health" + MsgMetricsRequest MessageType = "metrics_request" // WSS metrics request + + // Scheduler → Worker + MsgJobAssign MessageType = "job_assign" + MsgNoWork MessageType = "no_work" // nothing available right now + MsgJobCancel MessageType = "job_cancel" + MsgPrewarmHint MessageType = "prewarm_hint" + MsgAck MessageType = "ack" + MsgMetricsResponse MessageType = "metrics_response" // WSS metrics response +) + +// Heartbeat — liveness and slot status combined, no CPU/mem load +type HeartbeatPayload struct { + WorkerID string `json:"worker_id"` + Slots SlotStatus `json:"slots"` +} + +type ReadyPayload struct { + WorkerID string `json:"worker_id"` + Slots SlotStatus `json:"slots"` + Reason string `json:"reason"` +} + +type JobResultPayload struct { + TaskID string `json:"task_id"` + State string `json:"state"` + ExitCode int `json:"exit_code"` + Error string `json:"error,omitempty"` +} + +type PrewarmHintPayload struct { + TaskID string `json:"task_id"` + SnapshotID string `json:"snapshot_id"` + SnapshotSHA string `json:"snapshot_sha,omitempty"` +} + +type WorkerRegistration struct { + ID string `json:"id"` + Capabilities WorkerCapabilities `json:"capabilities"` + ActiveTasks []ActiveTaskReport `json:"active_tasks"` +} + +type ActiveTaskReport struct { + TaskID string `json:"task_id"` + State string `json:"state"` + StartedAt time.Time `json:"started_at,omitempty"` +} + +type SlotStatus struct { + BatchTotal int `json:"batch_total"` + BatchInUse int `json:"batch_in_use"` + ServiceTotal int `json:"service_total"` + ServiceInUse int `json:"service_in_use"` +} + +func (s SlotStatus) BatchAvailable() int { return s.BatchTotal - s.BatchInUse } +func (s SlotStatus) ServiceAvailable() int { 
return s.ServiceTotal - s.ServiceInUse } + +type WorkerCapabilities struct { + GPUInfo GPUDetectionInfo `json:"gpu_info"` + GPUCount int `json:"gpu_count"` + GPUType string `json:"gpu_type"` + CPUCount int `json:"cpu_count"` + MemoryGB float64 `json:"memory_gb"` + Hostname string `json:"hostname"` +} + +type GPUDetectionInfo struct { + GPUType string `json:"gpu_type"` + Count int `json:"count"` + Devices []string `json:"devices,omitempty"` + Driver string `json:"driver,omitempty"` + MemTotal uint64 `json:"mem_total,omitempty"` +} + +type JobSpec struct { + ID string `json:"id"` + Type JobType `json:"type"` // "batch" | "service" + SlotPool string `json:"slot_pool"` + + GPUCount int `json:"gpu_count"` + GPUType string `json:"gpu_type,omitempty"` + NodeCount int `json:"node_count"` + + Command []string `json:"command"` + Env map[string]string `json:"env"` + + Prolog []string `json:"prolog,omitempty"` + Epilog []string `json:"epilog,omitempty"` + + SnapshotID string `json:"snapshot_id,omitempty"` + SnapshotSHA string `json:"snapshot_sha,omitempty"` + HealthCheck *HealthCheck `json:"health_check,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +type JobType string + +const ( + JobTypeBatch JobType = "batch" + JobTypeService JobType = "service" +) + +type HealthCheck struct { + LivenessEndpoint string `json:"liveness"` + ReadinessEndpoint string `json:"readiness"` + IntervalSecs int `json:"interval_secs"` +} + +type ServiceHealthPayload struct { + TaskID string `json:"task_id"` + Healthy bool `json:"healthy"` + Message string `json:"message,omitempty"` +} diff --git a/internal/scheduler/scheduler_conn.go b/internal/scheduler/scheduler_conn.go new file mode 100644 index 0000000..19b4d2f --- /dev/null +++ b/internal/scheduler/scheduler_conn.go @@ -0,0 +1,217 @@ +package scheduler + +import ( + "encoding/json" + "strconv" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// SchedulerConn manages the WebSocket connection to the 
scheduler +type SchedulerConn struct { + addr string + certFile string + token string + conn *websocket.Conn + workerID string + capabilities WorkerCapabilities + send chan Message + activeTasks sync.Map + mu sync.RWMutex + closed bool +} + +// NewSchedulerConn creates a new scheduler connection +func NewSchedulerConn(addr, certFile, token, workerID string, caps WorkerCapabilities) *SchedulerConn { + return &SchedulerConn{ + addr: addr, + certFile: certFile, + token: token, + workerID: workerID, + capabilities: caps, + send: make(chan Message, 100), + } +} + +// Connect establishes the WebSocket connection +func (sc *SchedulerConn) Connect() error { + var conn *websocket.Conn + var err error + + if sc.certFile == "" { + // Local mode - no TLS, parse port from address + port := parsePortFromAddr(sc.addr) + conn, err = LocalModeDial(port, sc.token) + } else { + conn, err = DialWSS(sc.addr, sc.certFile, sc.token) + } + + if err != nil { + return err + } + + sc.mu.Lock() + sc.conn = conn + sc.closed = false + sc.mu.Unlock() + + // Send registration + sc.Send(Message{ + Type: MsgRegister, + Payload: mustMarshal(WorkerRegistration{ + ID: sc.workerID, + Capabilities: sc.capabilities, + ActiveTasks: sc.collectActiveTasks(), + }), + }) + + return nil +} + +// Send sends a message to the scheduler +func (sc *SchedulerConn) Send(msg Message) { + select { + case sc.send <- msg: + default: + // Channel full, drop message + } +} + +// Run starts the send/receive loops +func (sc *SchedulerConn) Run(onJobAssign func(*JobSpec), onJobCancel func(string), onPrewarmHint func(PrewarmHintPayload)) { + // Send loop + go func() { + for msg := range sc.send { + sc.mu.RLock() + conn := sc.conn + closed := sc.closed + sc.mu.RUnlock() + + if closed || conn == nil { + continue + } + + if err := conn.WriteJSON(msg); err != nil { + // Trigger reconnect + go sc.reconnect() + } + } + }() + + // Receive loop + for { + sc.mu.RLock() + conn := sc.conn + sc.mu.RUnlock() + + if conn == nil { + 
// parsePortFromAddr extracts the port from a "host:port" address string.
//
// The port is whatever follows the last colon. The host part may be empty
// (":8080"), a name or IPv4 address, or a bracketed IPv6 literal
// ("[::1]:8080"); an unbracketed host that itself contains a colon is
// rejected as ambiguous. The default port 7777 is returned when the address
// has no port, the port is not a number, or it lies outside 1-65535.
func parsePortFromAddr(addr string) int {
	const defaultPort = 7777

	sep := strings.LastIndex(addr, ":")
	if sep < 0 {
		return defaultPort
	}

	host := addr[:sep]
	bracketed := strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]")
	if strings.Contains(host, ":") && !bracketed {
		// e.g. a bare IPv6 address or a URL with a scheme: not host:port.
		return defaultPort
	}

	port, err := strconv.Atoi(addr[sep+1:])
	if err != nil || port < 1 || port > 65535 {
		return defaultPort
	}
	return port
}
the service + svcCtx, cancel := context.WithCancel(ctx) + sm.cancel = cancel + defer cancel() + + // Run prolog if configured + if len(sm.spec.Prolog) > 0 { + sm.transition("preparing") + if err := sm.runProlog(svcCtx); err != nil { + sm.transition("failed") + return fmt.Errorf("prolog failed: %w", err) + } + } + + // Start the service process + if err := sm.startService(svcCtx); err != nil { + sm.transition("failed") + return fmt.Errorf("start service failed: %w", err) + } + + // Wait for readiness (if health check configured) + if sm.spec.HealthCheck != nil && sm.spec.HealthCheck.ReadinessEndpoint != "" { + sm.transition("preparing") + if err := sm.waitReady(svcCtx, 120*time.Second); err != nil { + sm.stopService() + sm.transition("failed") + return fmt.Errorf("readiness check failed: %w", err) + } + } + + // Mark as serving + sm.transition("serving") + sm.ready = true + + // Run health check loop + return sm.healthLoop(svcCtx) +} + +// Stop gracefully stops the service +// It runs epilog with a fresh context (ignores job cancellation) +func (sm *ServiceManager) Stop() error { + // Cancel service context + if sm.cancel != nil { + sm.cancel() + } + + // Run epilog with fresh context - must complete even if job cancelled + epilogCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + sm.transition("stopping") + + if len(sm.spec.Epilog) > 0 { + if err := sm.runEpilog(epilogCtx); err != nil { + slog.Warn("epilog failed", "task", sm.task.ID, "error", err) + } + } + + // Ensure service is stopped + sm.stopService() + + sm.transition("completed") + return nil +} + +// runProlog executes prolog commands before starting the service +func (sm *ServiceManager) runProlog(ctx context.Context) error { + for _, cmdStr := range sm.spec.Prolog { + cmd := sm.buildCommand(ctx, cmdStr) + if err := cmd.Run(); err != nil { + return fmt.Errorf("prolog command failed: %s, error: %w", cmdStr, err) + } + } + return nil +} + +// startService starts the 
main service process +func (sm *ServiceManager) startService(ctx context.Context) error { + if len(sm.spec.Command) == 0 { + return fmt.Errorf("no command specified for service") + } + + cmd := sm.buildCommand(ctx, sm.spec.Command[0], sm.spec.Command[1:]...) + + // Set up process group for clean termination (Unix-specific) + setProcessGroup(cmd) + + if err := cmd.Start(); err != nil { + return err + } + + sm.cmd = cmd + return nil +} + +// stopService stops the service process +func (sm *ServiceManager) stopService() { + if sm.cmd == nil || sm.cmd.Process == nil { + return + } + + // Try graceful termination first + sm.cmd.Process.Signal(syscall.SIGTERM) + + // Wait for graceful shutdown or timeout + done := make(chan error, 1) + go func() { + done <- sm.cmd.Wait() + }() + + select { + case <-done: + // Graceful shutdown succeeded + case <-time.After(10 * time.Second): + // Force kill process group (Unix-specific) + killProcessGroup(sm.cmd) + } +} + +// runEpilog executes epilog commands after service stops +func (sm *ServiceManager) runEpilog(ctx context.Context) error { + for _, cmdStr := range sm.spec.Epilog { + cmd := sm.buildCommand(ctx, cmdStr) + if err := cmd.Run(); err != nil { + slog.Warn("epilog command failed", "task", sm.task.ID, "cmd", cmdStr, "error", err) + // Continue with other epilog commands even if one fails + } + } + return nil +} + +// healthLoop runs health checks periodically +// Returns when context is cancelled or health check fails +func (sm *ServiceManager) healthLoop(ctx context.Context) error { + if sm.spec.HealthCheck == nil { + // No health check configured - just wait for context cancellation + <-ctx.Done() + return nil + } + + interval := time.Duration(sm.spec.HealthCheck.IntervalSecs) * time.Second + if interval < 5*time.Second { + interval = 15 * time.Second // Minimum 15s between checks + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + // 
Check liveness + if !sm.checkLiveness() { + sm.transition("failed") + return fmt.Errorf("liveness check failed") + } + sm.healthy = true + + // Check readiness (if configured) + if sm.spec.HealthCheck.ReadinessEndpoint != "" { + sm.ready = sm.checkReadiness() + } + sm.lastHealth = time.Now() + } + } +} + +// waitReady waits for the service to become ready +func (sm *ServiceManager) waitReady(ctx context.Context, timeout time.Duration) error { + if sm.spec.HealthCheck == nil || sm.spec.HealthCheck.ReadinessEndpoint == "" { + return nil // No readiness check configured + } + + deadline := time.Now().Add(timeout) + checkInterval := 2 * time.Second + + for time.Now().Before(deadline) { + if sm.checkReadiness() { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(checkInterval): + // Continue checking + } + } + + return fmt.Errorf("readiness check timed out after %v", timeout) +} + +// checkLiveness checks if the service process is running +func (sm *ServiceManager) checkLiveness() bool { + if sm.cmd == nil || sm.cmd.Process == nil { + return false + } + + // Check if process is still running + if !isProcessRunning(sm.cmd) { + return false + } + + // If liveness endpoint configured, also check HTTP + if sm.spec.HealthCheck != nil && sm.spec.HealthCheck.LivenessEndpoint != "" { + return sm.checkHTTPEndpoint(sm.spec.HealthCheck.LivenessEndpoint, 2*time.Second) + } + + return true +} + +// checkReadiness checks if the service is ready to receive traffic +func (sm *ServiceManager) checkReadiness() bool { + if sm.spec.HealthCheck == nil || sm.spec.HealthCheck.ReadinessEndpoint == "" { + return sm.healthy // Fall back to liveness + } + + return sm.checkHTTPEndpoint(sm.spec.HealthCheck.ReadinessEndpoint, 5*time.Second) +} + +// checkHTTPEndpoint makes an HTTP GET request to check endpoint health +func (sm *ServiceManager) checkHTTPEndpoint(endpoint string, timeout time.Duration) bool { + client := &http.Client{ + Timeout: timeout, + } + + 
resp, err := client.Get(endpoint) + if err != nil { + return false + } + defer resp.Body.Close() + + // Drain body to allow connection reuse + io.Copy(io.Discard, resp.Body) + + // 2xx status codes indicate success + return resp.StatusCode >= 200 && resp.StatusCode < 300 +} + +// transition changes the service state +func (sm *ServiceManager) transition(newState string) { + if sm.stateMachine == nil { + return + } + + oldState := sm.stateMachine.current + if oldState == newState { + return + } + + sm.stateMachine.current = newState + + // Update task status + sm.task.Status = newState + + // Notify callback + if sm.stateMachine.onChange != nil { + sm.stateMachine.onChange(oldState, newState) + } + + slog.Info("service state transition", + "task", sm.task.ID, + "from", oldState, + "to", newState) +} + +// buildCommand creates an exec.Cmd with environment variables +func (sm *ServiceManager) buildCommand(ctx context.Context, name string, args ...string) *exec.Cmd { + cmd := exec.CommandContext(ctx, name, args...) 
+ + // Set environment variables + env := make([]string, 0, len(sm.spec.Env)+4) + for k, v := range sm.spec.Env { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + + // Add service-specific variables + env = append(env, + fmt.Sprintf("SERVICE_PORT=%d", sm.port), + fmt.Sprintf("TASK_ID=%s", sm.task.ID), + ) + + cmd.Env = env + return cmd +} + +// IsHealthy returns true if the service is healthy (process running) +func (sm *ServiceManager) IsHealthy() bool { + return sm.healthy +} + +// IsReady returns true if the service is ready to receive traffic +func (sm *ServiceManager) IsReady() bool { + return sm.ready +} + +// GetState returns the current service state +func (sm *ServiceManager) GetState() string { + if sm.stateMachine == nil { + return "unknown" + } + return sm.stateMachine.current +} diff --git a/internal/scheduler/service_manager_unix.go b/internal/scheduler/service_manager_unix.go new file mode 100644 index 0000000..f5fe351 --- /dev/null +++ b/internal/scheduler/service_manager_unix.go @@ -0,0 +1,34 @@ +//go:build !windows +// +build !windows + +package scheduler + +import ( + "os/exec" + "syscall" +) + +// setProcessGroup sets up process group for clean termination on Unix systems +func setProcessGroup(cmd *exec.Cmd) { + if cmd.SysProcAttr == nil { + cmd.SysProcAttr = &syscall.SysProcAttr{} + } + cmd.SysProcAttr.Setpgid = true +} + +// killProcessGroup kills the entire process group on Unix systems +func killProcessGroup(cmd *exec.Cmd) { + if cmd != nil && cmd.Process != nil { + // Negative PID kills the entire process group + _ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) + } +} + +// isProcessRunning checks if a process is still running on Unix systems +func isProcessRunning(cmd *exec.Cmd) bool { + if cmd == nil || cmd.Process == nil { + return false + } + // Signal 0 is a no-op that just checks if process exists + return cmd.Process.Signal(syscall.Signal(0)) == nil +} diff --git a/internal/scheduler/service_manager_windows.go 
// isProcessRunning reports whether the process started by cmd is still running
// on Windows.
//
// It returns false when cmd is nil, when the process was never started, or when
// the process has already been waited on: os/exec fills in cmd.ProcessState once
// Wait (or Run) has returned. A process that has exited but has not yet been
// waited on is still reported as running, because no Windows API is consulted.
//
// ProcessState is written by Wait, so this must not be called concurrently with
// cmd.Wait or cmd.Run without external synchronisation.
func isProcessRunning(cmd *exec.Cmd) bool {
	if cmd == nil || cmd.Process == nil {
		return false
	}
	// A non-nil ProcessState means the process has exited and been reaped.
	if cmd.ProcessState != nil {
		return false
	}
	return true
}
type ServiceTemplate struct {
	// JobType identifies this as a service job. Every template defined in this
	// file sets it to "service".
	JobType string `json:"job_type"` // Always "service"

	// SlotPool specifies which slot pool to use ("batch" or "service").
	SlotPool string `json:"slot_pool"`

	// GPUCount is the number of GPUs required (can be 0 for CPU-only services).
	GPUCount int `json:"gpu_count"`

	// Command is the service command line, one element per argument. Elements
	// may contain {{VAR}} template placeholders such as {{SERVICE_PORT}}.
	Command []string `json:"command"`

	// Env defines environment variables; values may contain template
	// placeholders.
	Env map[string]string `json:"env"`

	// HealthCheck defines how to verify the service is healthy.
	HealthCheck ServiceHealthCheck `json:"health_check"`

	// Mounts defines volume mounts for the service. Omitted from JSON when empty.
	Mounts []ServiceMount `json:"mounts,omitempty"`

	// Ports to expose (if not using dynamic allocation). Omitted from JSON when
	// empty.
	// NOTE(review): none of the templates in this file set Ports; they rely on
	// {{SERVICE_PORT}} instead — confirm whether any consumer reads this field.
	Ports []int `json:"ports,omitempty"`
}
var JupyterLabTemplate = ServiceTemplate{
	JobType:  "service",
	SlotPool: "service", // Uses service slot pool, not batch
	GPUCount: 0,         // Jupyter typically runs CPU-only

	Command: []string{
		"jupyter", "lab",
		"--ip=0.0.0.0", // listen on all interfaces
		"--port={{SERVICE_PORT}}",
		"--no-browser",
		"--allow-root",
		// NOTE(review): the single quotes below are part of the argument strings.
		// If the command is executed without a shell they reach jupyter
		// verbatim — confirm that is intended.
		"--NotebookApp.token='{{SECRET:jupyter_token}}'",
		"--NotebookApp.password=''",
	},

	Env: map[string]string{
		// Same secret as the --NotebookApp.token argument above.
		"JUPYTER_TOKEN":      "{{SECRET:jupyter_token}}",
		"JUPYTER_CONFIG_DIR": "/workspace/.jupyter",
	},

	HealthCheck: ServiceHealthCheck{
		Liveness:  "http://localhost:{{SERVICE_PORT}}/api",
		Readiness: "http://localhost:{{SERVICE_PORT}}/api/kernels",
		Interval:  15, // seconds between checks
		Timeout:   5,  // seconds per check
	},

	Mounts: []ServiceMount{
		// NOTE(review): {{WORKSPACE}} is not among the template variables
		// documented in this package, and TemplateContext.Resolve leaves unknown
		// placeholders untouched — confirm what substitutes it.
		{Source: "{{WORKSPACE}}", Destination: "/workspace"},
	},
}
// VLLMTemplate is an example vLLM inference server template (future).
// It defines no Env and no Mounts, and uses the same URL for the liveness and
// readiness probes.
var VLLMTemplate = ServiceTemplate{
	JobType:  "service",
	SlotPool: "service",
	GPUCount: 1, // Requires GPU for inference

	Command: []string{
		"python", "-m", "vllm.entrypoints.openai.api_server",
		// NOTE(review): {{MODEL_NAME}} is not among the template variables
		// documented in this package, and TemplateContext.Resolve leaves unknown
		// placeholders untouched — confirm what substitutes it.
		"--model", "{{MODEL_NAME}}",
		"--port", "{{SERVICE_PORT}}",
	},

	HealthCheck: ServiceHealthCheck{
		Liveness:  "http://localhost:{{SERVICE_PORT}}/health",
		Readiness: "http://localhost:{{SERVICE_PORT}}/health",
		Interval:  30, // seconds between checks
		Timeout:   10, // seconds per check
	},
}
StateEventType = "job_completed" + EventJobFailed StateEventType = "job_failed" + EventJobRequeued StateEventType = "job_requeued" + EventJobCancelled StateEventType = "job_cancelled" +) + +// StateStore provides append-only persistence for scheduler state +type StateStore struct { + path string + mu sync.Mutex + file *os.File +} + +// NewStateStore creates a new state store at the given path +func NewStateStore(path string) (*StateStore, error) { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return nil, fmt.Errorf("create state directory: %w", err) + } + + file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, fmt.Errorf("open state file: %w", err) + } + + return &StateStore{ + path: path, + file: file, + }, nil +} + +// Append writes a state event to the log +func (s *StateStore) Append(event StateEvent) error { + s.mu.Lock() + defer s.mu.Unlock() + + if event.Timestamp.IsZero() { + event.Timestamp = time.Now() + } + + data, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("marshal event: %w", err) + } + + if _, err := s.file.Write(data); err != nil { + return fmt.Errorf("write event: %w", err) + } + if _, err := s.file.WriteString("\n"); err != nil { + return fmt.Errorf("write newline: %w", err) + } + + return nil +} + +// Replay reads all events from the state log +func (s *StateStore) Replay() ([]StateEvent, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Close and reopen to ensure we read from the beginning + if err := s.file.Close(); err != nil { + return nil, fmt.Errorf("close state file: %w", err) + } + + file, err := os.Open(s.path) + if err != nil { + if os.IsNotExist(err) { + // Recreate the file for appending + s.file, _ = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + return nil, nil + } + return nil, fmt.Errorf("open state file for replay: %w", err) + } + defer file.Close() + + var events []StateEvent + scanner := bufio.NewScanner(file) + for 
scanner.Scan() { + var event StateEvent + if err := json.Unmarshal(scanner.Bytes(), &event); err != nil { + // Skip malformed lines but log them + continue + } + events = append(events, event) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("scan state file: %w", err) + } + + // Reopen for appending + s.file, err = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, fmt.Errorf("reopen state file: %w", err) + } + + return events, nil +} + +// Close closes the state store +func (s *StateStore) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + return s.file.Close() +} + +// Rotate rotates the state file (for backup/truncation) +func (s *StateStore) Rotate() (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + backupPath := s.path + "." + time.Now().Format("20060102_150405") + ".bak" + + if err := s.file.Close(); err != nil { + return "", fmt.Errorf("close state file: %w", err) + } + + if err := os.Rename(s.path, backupPath); err != nil { + return "", fmt.Errorf("rotate state file: %w", err) + } + + file, err := os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return "", fmt.Errorf("create new state file: %w", err) + } + + s.file = file + return backupPath, nil +} diff --git a/internal/scheduler/template.go b/internal/scheduler/template.go new file mode 100644 index 0000000..67dd06e --- /dev/null +++ b/internal/scheduler/template.go @@ -0,0 +1,245 @@ +package scheduler + +import ( + "fmt" + "os" + "regexp" + "strings" +) + +// TemplateResolver handles variable substitution in job specifications +// Template variables are resolved at dispatch time on the worker +// +// Supported variables: +// {{HEAD_ADDR}} - Hostname of rank-0 worker (for multi-node) +// {{WORLD_SIZE}} - Total node count (for multi-node) +// {{NODE_RANK}} - 0-based rank of this worker (for multi-node) +// {{GPU_COUNT}} - Number of GPUs available on this worker +// {{SERVICE_PORT}} - Port 
// TemplateContext provides the values for template substitution.
//
// Resolve treats an empty HeadAddr, a zero WorldSize and a zero ServicePort as
// "not set" and leaves the corresponding placeholder in the output. NodeRank,
// GPUCount and TaskID are always substituted, even when they hold their zero
// value.
type TemplateContext struct {
	HeadAddr    string            // Rank-0 worker hostname (multi-node); {{HEAD_ADDR}}
	WorldSize   int               // Total nodes (multi-node); {{WORLD_SIZE}}
	NodeRank    int               // This worker's rank (multi-node); {{NODE_RANK}}
	GPUCount    int               // GPUs available; {{GPU_COUNT}}
	ServicePort int               // Assigned port (service jobs); {{SERVICE_PORT}}
	Hostname    string            // This worker's hostname; filled from os.Hostname by Resolve when empty
	TaskID      string            // Task/job ID; {{TASK_ID}}
	Secrets     map[string]string // Secret store, looked up by name for {{SECRET:name}}
}
tc.TaskID + case "SECRET": + if val, ok := tc.Secrets[secretName]; ok { + return val + } + return match // Keep unresolved if secret not found + default: + return match // Unknown variable - keep as-is + } + }) + + return result, nil +} + +// ResolveCommand resolves templates in a command slice +func (tc *TemplateContext) ResolveCommand(cmd []string) ([]string, error) { + result := make([]string, len(cmd)) + for i, arg := range cmd { + resolved, err := tc.Resolve(arg) + if err != nil { + return nil, fmt.Errorf("resolve arg %d: %w", i, err) + } + result[i] = resolved + } + return result, nil +} + +// ResolveEnv resolves templates in environment variables +func (tc *TemplateContext) ResolveEnv(env map[string]string) (map[string]string, error) { + result := make(map[string]string, len(env)) + for k, v := range env { + resolved, err := tc.Resolve(v) + if err != nil { + return nil, fmt.Errorf("resolve env %s: %w", k, err) + } + result[k] = resolved + } + return result, nil +} + +// ResolveJobSpec resolves all templates in a JobSpec +// Returns a new JobSpec with all template variables substituted +func (tc *TemplateContext) ResolveJobSpec(spec *JobSpec) (*JobSpec, error) { + // Deep copy the spec + resolved := &JobSpec{ + ID: spec.ID, + Type: spec.Type, + SlotPool: spec.SlotPool, + GPUCount: spec.GPUCount, + GPUType: spec.GPUType, + NodeCount: spec.NodeCount, + SnapshotID: spec.SnapshotID, + SnapshotSHA: spec.SnapshotSHA, + Metadata: make(map[string]string, len(spec.Metadata)), + } + + // Copy metadata + for k, v := range spec.Metadata { + resolved.Metadata[k] = v + } + + // Resolve command + if len(spec.Command) > 0 { + cmd, err := tc.ResolveCommand(spec.Command) + if err != nil { + return nil, fmt.Errorf("resolve command: %w", err) + } + resolved.Command = cmd + } + + // Resolve environment + if len(spec.Env) > 0 { + env, err := tc.ResolveEnv(spec.Env) + if err != nil { + return nil, fmt.Errorf("resolve env: %w", err) + } + resolved.Env = env + } + + // Resolve prolog 
+ if len(spec.Prolog) > 0 { + prolog, err := tc.ResolveCommand(spec.Prolog) + if err != nil { + return nil, fmt.Errorf("resolve prolog: %w", err) + } + resolved.Prolog = prolog + } + + // Resolve epilog + if len(spec.Epilog) > 0 { + epilog, err := tc.ResolveCommand(spec.Epilog) + if err != nil { + return nil, fmt.Errorf("resolve epilog: %w", err) + } + resolved.Epilog = epilog + } + + // Resolve health check endpoints + if spec.HealthCheck != nil { + hc := &HealthCheck{ + LivenessEndpoint: spec.HealthCheck.LivenessEndpoint, + ReadinessEndpoint: spec.HealthCheck.ReadinessEndpoint, + IntervalSecs: spec.HealthCheck.IntervalSecs, + } + + if hc.LivenessEndpoint != "" { + endpoint, err := tc.Resolve(hc.LivenessEndpoint) + if err != nil { + return nil, fmt.Errorf("resolve liveness endpoint: %w", err) + } + hc.LivenessEndpoint = endpoint + } + + if hc.ReadinessEndpoint != "" { + endpoint, err := tc.Resolve(hc.ReadinessEndpoint) + if err != nil { + return nil, fmt.Errorf("resolve readiness endpoint: %w", err) + } + hc.ReadinessEndpoint = endpoint + } + + resolved.HealthCheck = hc + } + + return resolved, nil +} + +// NewMultiNodeContext creates a template context for a multi-node job +func NewMultiNodeContext(taskID, headAddr string, worldSize, nodeRank, gpuCount int) *TemplateContext { + hostname, _ := os.Hostname() + return &TemplateContext{ + TaskID: taskID, + HeadAddr: headAddr, + WorldSize: worldSize, + NodeRank: nodeRank, + GPUCount: gpuCount, + Hostname: hostname, + Secrets: make(map[string]string), + } +} + +// NewServiceContext creates a template context for a service job +func NewServiceContext(taskID string, servicePort, gpuCount int) *TemplateContext { + hostname, _ := os.Hostname() + return &TemplateContext{ + TaskID: taskID, + ServicePort: servicePort, + GPUCount: gpuCount, + Hostname: hostname, + Secrets: make(map[string]string), + } +} + +// SetSecret adds a secret to the context +func (tc *TemplateContext) SetSecret(name, value string) { + if tc.Secrets == 
nil { + tc.Secrets = make(map[string]string) + } + tc.Secrets[name] = value +} diff --git a/tests/benchmarks/scheduler_bench_test.go b/tests/benchmarks/scheduler_bench_test.go new file mode 100644 index 0000000..fe96d88 --- /dev/null +++ b/tests/benchmarks/scheduler_bench_test.go @@ -0,0 +1,190 @@ +package benchmarks_test + +import ( + "fmt" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/scheduler" + fixtures "github.com/jfraeys/fetch_ml/tests/fixtures" +) + +// BenchmarkPriorityQueueAdd measures job enqueue performance +func BenchmarkPriorityQueueAdd(b *testing.B) { + pq := scheduler.NewPriorityQueue(0.1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + task := &scheduler.Task{ + ID: fmt.Sprintf("task-%d", i), + Priority: i % 100, + } + pq.Add(task) + } +} + +// BenchmarkPriorityQueueTake measures job dequeue performance +func BenchmarkPriorityQueueTake(b *testing.B) { + pq := scheduler.NewPriorityQueue(0.1) + + // Pre-populate queue + for i := 0; i < b.N; i++ { + task := &scheduler.Task{ + ID: fmt.Sprintf("task-%d", i), + Priority: i % 100, + } + pq.Add(task) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Take() + } +} + +// BenchmarkPortAllocator measures port allocation performance +func BenchmarkPortAllocator(b *testing.B) { + pa := scheduler.NewPortAllocator(10000, 20000) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + port, _ := pa.Allocate(fmt.Sprintf("service-%d", i)) + _ = port + } +} + +// BenchmarkStateStoreAppend measures state persistence performance +func BenchmarkStateStoreAppend(b *testing.B) { + dir := b.TempDir() + store, _ := scheduler.NewStateStore(dir + "/bench.state") + + event := scheduler.StateEvent{ + Type: scheduler.EventJobEnqueued, + TaskID: "bench-task", + Timestamp: time.Now(), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + event.TaskID = fmt.Sprintf("bench-task-%d", i) + store.Append(event) + } +} + +// BenchmarkSchedulerSubmitJob measures job submission throughput +func 
BenchmarkSchedulerSubmitJob(b *testing.B) { + // Create scheduler directly for benchmark + cfg := scheduler.HubConfig{ + BindAddr: "localhost:0", + DefaultBatchSlots: 4, + StarvationThresholdMins: 5, + AcceptanceTimeoutSecs: 5, + } + hub, err := scheduler.NewHub(cfg, nil) + if err != nil { + b.Fatal(err) + } + defer hub.Stop() + + if err := hub.Start(); err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + hub.SubmitJob(scheduler.JobSpec{ + ID: fmt.Sprintf("bench-job-%d", i), + Type: scheduler.JobTypeBatch, + }) + } +} + +// BenchmarkWorkerRegistration measures worker registration throughput +func BenchmarkWorkerRegistration(b *testing.B) { + fixture := fixtures.NewSchedulerTestFixture(b, fixtures.DefaultHubConfig()) + defer fixture.Cleanup() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + workerID := fmt.Sprintf("bench-worker-%d", i) + worker := fixtures.NewMockWorker(b, fixture.Hub, workerID) + worker.Register(scheduler.WorkerCapabilities{GPUCount: 0}) + worker.Close() + } +} + +// BenchmarkHeartbeatProcessing measures heartbeat handling throughput +func BenchmarkHeartbeatProcessing(b *testing.B) { + fixture := fixtures.NewSchedulerTestFixture(b, fixtures.DefaultHubConfig()) + defer fixture.Cleanup() + + worker := fixture.CreateWorker("bench-hb-worker", scheduler.WorkerCapabilities{GPUCount: 0}) + + slots := scheduler.SlotStatus{ + BatchTotal: 4, + BatchInUse: 0, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + worker.SendHeartbeat(slots) + } +} + +// BenchmarkJobAssignment measures job scheduling latency +func BenchmarkJobAssignment(b *testing.B) { + fixture := fixtures.NewSchedulerTestFixture(b, fixtures.DefaultHubConfig()) + defer fixture.Cleanup() + + // Create worker + worker := fixture.CreateWorker("bench-assign-worker", scheduler.WorkerCapabilities{GPUCount: 0}) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Submit job + jobID := fmt.Sprintf("bench-assign-%d", i) + fixture.SubmitJob(scheduler.JobSpec{ + ID: jobID, + 
Type: scheduler.JobTypeBatch, + }) + + // Signal ready to trigger assignment + worker.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "polling") + + // Wait for assignment + worker.RecvTimeout(100 * time.Millisecond) + } +} + +// BenchmarkMultiWorkerScheduling measures scheduling with multiple workers +func BenchmarkMultiWorkerScheduling(b *testing.B) { + fixture := fixtures.NewSchedulerTestFixture(b, fixtures.DefaultHubConfig()) + defer fixture.Cleanup() + + // Create multiple workers + workers := make([]*fixtures.MockWorker, 10) + for i := 0; i < 10; i++ { + workers[i] = fixture.CreateWorker( + fmt.Sprintf("bench-multi-worker-%d", i), + scheduler.WorkerCapabilities{GPUCount: 0}, + ) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Submit job + jobID := fmt.Sprintf("bench-multi-%d", i) + fixture.SubmitJob(scheduler.JobSpec{ + ID: jobID, + Type: scheduler.JobTypeBatch, + }) + + // All workers signal ready + for _, w := range workers { + w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "polling") + } + + // One worker gets the job + workers[i%10].RecvTimeout(100 * time.Millisecond) + } +} diff --git a/tests/fixtures/scheduler_fixture.go b/tests/fixtures/scheduler_fixture.go new file mode 100644 index 0000000..4f6a4b9 --- /dev/null +++ b/tests/fixtures/scheduler_fixture.go @@ -0,0 +1,118 @@ +// Package fixtures provides shared test utilities and fixtures for scheduler tests +package tests + +import ( + "os" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/scheduler" + "github.com/stretchr/testify/require" +) + +// testStateDir is used for hub state storage in tests +var testStateDir string + +func init() { + var err error + testStateDir, err = os.MkdirTemp("", "fetchml-test-*") + if err != nil { + panic("failed to create test state dir: " + err.Error()) + } +} + +// SchedulerTestFixture provides a test fixture for scheduler tests +type SchedulerTestFixture struct { + T testing.TB + Hub *scheduler.SchedulerHub + 
Workers map[string]*MockWorker +} + +// NewSchedulerTestFixture creates a new scheduler test fixture +func NewSchedulerTestFixture(t testing.TB, cfg scheduler.HubConfig) *SchedulerTestFixture { + if cfg.BindAddr == "" { + cfg.BindAddr = "localhost:0" + } + + hub, err := scheduler.NewHub(cfg, nil) + require.NoError(t, err) + + // Start scheduler + err = hub.Start() + require.NoError(t, err) + + return &SchedulerTestFixture{ + T: t, + Hub: hub, + Workers: make(map[string]*MockWorker), + } +} + +// CreateWorker creates and registers a new mock worker +func (f *SchedulerTestFixture) CreateWorker(workerID string, caps scheduler.WorkerCapabilities) *MockWorker { + worker := NewMockWorker(f.T, f.Hub, workerID) + worker.Register(caps) + f.Workers[workerID] = worker + return worker +} + +// SubmitJob submits a job to the scheduler +func (f *SchedulerTestFixture) SubmitJob(spec scheduler.JobSpec) { + err := f.Hub.SubmitJob(spec) + require.NoError(f.T, err) +} + +// GetTask retrieves a task by ID +func (f *SchedulerTestFixture) GetTask(taskID string) *scheduler.Task { + return f.Hub.GetTask(taskID) +} + +// Cleanup stops the scheduler and closes all workers +func (f *SchedulerTestFixture) Cleanup() { + // Close all workers first + for _, worker := range f.Workers { + worker.Close() + } + // Then stop the hub + f.Hub.Stop() +} + +// DefaultHubConfig returns a default hub configuration for testing +func DefaultHubConfig() scheduler.HubConfig { + return scheduler.HubConfig{ + BindAddr: "localhost:0", + DefaultBatchSlots: 4, + StarvationThresholdMins: 5, + AcceptanceTimeoutSecs: 5, + GangAllocTimeoutSecs: 10, + StateDir: testStateDir, + WorkerTokens: map[string]string{ + "test-token-worker-restart-1": "worker-restart-1", + "test-token-mode-switch-worker": "mode-switch-worker", + "test-token-mode-switch-worker-2": "mode-switch-worker-2", + "test-token-e2e-worker-1": "e2e-worker-1", + "test-token-e2e-worker-2": "e2e-worker-2", + "test-token-worker-death-test": "worker-death-test", 
// WaitForTimeout polls condition roughly every 10ms until it returns true or
// duration has elapsed. It returns true as soon as condition holds and false
// once the deadline passes without it holding.
//
// condition is always evaluated at least once, even for a zero or negative
// duration, and is evaluated one final time when the deadline is reached, so a
// condition that becomes true during the last sleep is not reported as a
// timeout. The last sleep is shortened so the call does not overshoot the
// deadline by a full poll interval.
func WaitForTimeout(duration time.Duration, condition func() bool) bool {
	const pollInterval = 10 * time.Millisecond

	deadline := time.Now().Add(duration)
	for {
		if condition() {
			return true
		}

		remaining := time.Until(deadline)
		if remaining <= 0 {
			return false
		}

		// Never sleep past the deadline.
		if remaining > pollInterval {
			remaining = pollInterval
		}
		time.Sleep(remaining)
	}
}
RecvCh: make(chan scheduler.Message, 100), + SendCh: make(chan scheduler.Message, 100), + T: t, + } + + // Start receive goroutine + mw.wg.Add(1) + go func() { + defer mw.wg.Done() + for { + var msg scheduler.Message + err := conn.ReadJSON(&msg) + if err != nil { + close(mw.RecvCh) + return + } + mw.RecvCh <- msg + } + }() + + // Start send goroutine + mw.wg.Add(1) + go func() { + defer mw.wg.Done() + for msg := range mw.SendCh { + if err := conn.WriteJSON(msg); err != nil { + return + } + } + }() + + return mw +} + +// Register sends worker registration message and waits for ack +func (mw *MockWorker) Register(capabilities scheduler.WorkerCapabilities) { + mw.Send(scheduler.Message{ + Type: scheduler.MsgRegister, + Payload: MustMarshal(scheduler.WorkerRegistration{ + ID: mw.ID, + Capabilities: capabilities, + }), + }) + + msg := mw.RecvTimeout(2 * time.Second) + require.Equal(mw.T, scheduler.MsgAck, msg.Type, "expected registration ack") +} + +// Send sends a message to the scheduler +func (mw *MockWorker) Send(msg scheduler.Message) { + select { + case mw.SendCh <- msg: + case <-time.After(time.Second): + mw.T.Fatal("timeout sending message") + } +} + +// Recv receives a message from the scheduler (blocks) +func (mw *MockWorker) Recv() scheduler.Message { + select { + case msg := <-mw.RecvCh: + return msg + case <-time.After(5 * time.Second): + require.Fail(mw.T, "timeout waiting for message") + return scheduler.Message{Type: "timeout"} + } +} + +// RecvTimeout receives a message with a custom timeout +func (mw *MockWorker) RecvTimeout(timeout time.Duration) scheduler.Message { + select { + case msg := <-mw.RecvCh: + return msg + case <-time.After(timeout): + require.Fail(mw.T, "timeout waiting for message") + return scheduler.Message{Type: "timeout"} + } +} + +// RecvNonBlock tries to receive without blocking +func (mw *MockWorker) RecvNonBlock() (scheduler.Message, bool) { + select { + case msg := <-mw.RecvCh: + return msg, true + default: + return 
scheduler.Message{}, false + } +} + +// SignalReady sends ready for work message +func (mw *MockWorker) SignalReady(slots scheduler.SlotStatus, reason string) { + mw.Send(scheduler.Message{ + Type: scheduler.MsgReadyForWork, + Payload: MustMarshal(scheduler.ReadyPayload{ + WorkerID: mw.ID, + Slots: slots, + Reason: reason, + }), + }) +} + +// SendHeartbeat sends a heartbeat message +func (mw *MockWorker) SendHeartbeat(slots scheduler.SlotStatus) { + mw.Send(scheduler.Message{ + Type: scheduler.MsgHeartbeat, + Payload: MustMarshal(scheduler.HeartbeatPayload{ + WorkerID: mw.ID, + Slots: slots, + }), + }) +} + +// AcceptJob accepts a job assignment +func (mw *MockWorker) AcceptJob(taskID string) { + mw.Send(scheduler.Message{ + Type: scheduler.MsgJobAccepted, + Payload: MustMarshal(scheduler.JobResultPayload{ + TaskID: taskID, + State: "accepted", + }), + }) +} + +// CompleteJob sends job completion +func (mw *MockWorker) CompleteJob(taskID string, exitCode int, output string) { + mw.Send(scheduler.Message{ + Type: scheduler.MsgJobResult, + Payload: MustMarshal(scheduler.JobResultPayload{ + TaskID: taskID, + State: "completed", + ExitCode: exitCode, + Error: output, + }), + }) +} + +// SendHealth sends service health update +func (mw *MockWorker) SendHealth(taskID string, healthy bool, message string) { + mw.Send(scheduler.Message{ + Type: scheduler.MsgServiceHealth, + Payload: MustMarshal(scheduler.ServiceHealthPayload{ + TaskID: taskID, + Healthy: healthy, + Message: message, + }), + }) +} + +// Close closes the worker connection +func (mw *MockWorker) Close() { + mw.mu.Lock() + if mw.closed { + mw.mu.Unlock() + return + } + mw.closed = true + mw.mu.Unlock() + + close(mw.SendCh) + mw.Conn.Close() + mw.wg.Wait() +} + +// WaitForDisconnect waits for the connection to close +func (mw *MockWorker) WaitForDisconnect(timeout time.Duration) bool { + done := make(chan struct{}) + go func() { + mw.wg.Wait() + close(done) + }() + + select { + case <-done: + return true + case 
// MustMarshal marshals a value to JSON, panicking on error.
//
// It is meant for test fixtures whose payloads are always encodable; a marshal
// failure indicates a bug in the test, so it panics with the underlying error
// rather than returning a nil payload that would be sent to the scheduler.
func MustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic("MustMarshal: " + err.Error())
	}
	return b
}
Payload: mustMarshal(scheduler.WorkerRegistration{ + ID: workerID, + Capabilities: scheduler.WorkerCapabilities{ + GPUCount: 0, + GPUType: "", + }, + }), + }) + require.NoError(t, err) + + // Wait for ack + select { + case msg := <-recvCh: + require.Equal(t, scheduler.MsgAck, msg.Type) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for registration ack") + } + + // Send heartbeat to show we're alive + err = conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgHeartbeat, + Payload: mustMarshal(scheduler.HeartbeatPayload{ + WorkerID: workerID, + Slots: scheduler.SlotStatus{ + BatchTotal: 4, + BatchInUse: 0, + }, + }), + }) + require.NoError(t, err) + + // Signal ready for work + err = conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgReadyForWork, + Payload: mustMarshal(scheduler.ReadyPayload{ + WorkerID: workerID, + Slots: scheduler.SlotStatus{ + BatchTotal: 4, + BatchInUse: 0, + }, + Reason: "polling", + }), + }) + require.NoError(t, err) + + // Wait a bit and verify connection is still alive + time.Sleep(200 * time.Millisecond) +} + +// TestWorkerRegistration validates worker registration flow +func TestWorkerRegistration(t *testing.T) { + testToken := "reg-test-token" + hub, err := scheduler.NewHub(scheduler.HubConfig{ + BindAddr: "localhost:0", + StateDir: t.TempDir(), + DefaultBatchSlots: 4, + WorkerTokens: map[string]string{ + testToken: "reg-test-worker", + }, + }, nil) + require.NoError(t, err) + defer hub.Stop() + + // Start scheduler + err = hub.Start() + require.NoError(t, err) + + u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"} + wsURL := u.String() + + // Connect worker with auth token + headers := http.Header{} + headers.Set("Authorization", "Bearer "+testToken) + conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers) + require.NoError(t, err) + defer conn.Close() + + // Receive channel + recvCh := make(chan scheduler.Message, 10) + go func() { + for { + var msg scheduler.Message + err := 
conn.ReadJSON(&msg) + if err != nil { + close(recvCh) + return + } + recvCh <- msg + } + }() + + // Register with capabilities + workerID := "reg-test-worker" + err = conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgRegister, + Payload: mustMarshal(scheduler.WorkerRegistration{ + ID: workerID, + Capabilities: scheduler.WorkerCapabilities{ + GPUCount: 2, + GPUType: "nvidia", + CPUCount: 8, + MemoryGB: 32.0, + }, + }), + }) + require.NoError(t, err) + + // Expect ack + select { + case msg := <-recvCh: + assert.Equal(t, scheduler.MsgAck, msg.Type) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ack") + } +} + +// TestHeartbeat validates heartbeat slot reporting +func TestHeartbeat(t *testing.T) { + testToken := "hb-test-token" + hub, err := scheduler.NewHub(scheduler.HubConfig{ + BindAddr: "localhost:0", + StateDir: t.TempDir(), + DefaultBatchSlots: 4, + WorkerTokens: map[string]string{ + testToken: "hb-test-worker", + }, + }, nil) + require.NoError(t, err) + defer hub.Stop() + + // Start scheduler + err = hub.Start() + require.NoError(t, err) + + u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"} + wsURL := u.String() + + headers := http.Header{} + headers.Set("Authorization", "Bearer "+testToken) + conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers) + require.NoError(t, err) + defer conn.Close() + + workerID := "hb-test-worker" + + // Register first + err = conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgRegister, + Payload: mustMarshal(scheduler.WorkerRegistration{ + ID: workerID, + Capabilities: scheduler.WorkerCapabilities{ + GPUCount: 0, + }, + }), + }) + require.NoError(t, err) + + // Send multiple heartbeats + slots := []scheduler.SlotStatus{ + {BatchTotal: 4, BatchInUse: 0}, + {BatchTotal: 4, BatchInUse: 1}, + {BatchTotal: 4, BatchInUse: 2}, + } + + for _, slot := range slots { + err = conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgHeartbeat, + Payload: mustMarshal(scheduler.HeartbeatPayload{ + 
// mustMarshal encodes v as JSON for use as a message payload in tests.
//
// It panics when encoding fails: a silently dropped error would send a nil
// payload to the scheduler and make the resulting test failure hard to trace.
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic("mustMarshal: " + err.Error())
	}
	return b
}
w.conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgRegister, + Payload: mustMarshal(scheduler.WorkerRegistration{ + ID: w.id, + Capabilities: scheduler.WorkerCapabilities{ + GPUCount: 0, + }, + }), + }) + msg := <-w.recv + require.Equal(t, scheduler.MsgAck, msg.Type) + } + + // Submit multi-node job (2 nodes) + jobID := "gang-job-001" + err = hub.SubmitJob(scheduler.JobSpec{ + ID: jobID, + Type: scheduler.JobTypeBatch, + SlotPool: "batch", + NodeCount: 2, + Command: []string{"torchrun", "--nnodes=2", "train.py"}, + }) + require.NoError(t, err) + + // Both workers signal ready + for _, w := range []struct { + conn *websocket.Conn + id string + }{{worker1, "worker-1"}, {worker2, "worker-2"}} { + w.conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgReadyForWork, + Payload: mustMarshal(scheduler.ReadyPayload{ + WorkerID: w.id, + Slots: scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, + }), + }) + } + + // Both workers should receive job assignments + msg1 := <-recv1 + msg2 := <-recv2 + + require.Equal(t, scheduler.MsgJobAssign, msg1.Type) + require.Equal(t, scheduler.MsgJobAssign, msg2.Type) + + // Verify both got the same job + var spec1, spec2 scheduler.JobSpec + json.Unmarshal(msg1.Payload, &spec1) + json.Unmarshal(msg2.Payload, &spec2) + + assert.Equal(t, jobID, spec1.ID) + assert.Equal(t, jobID, spec2.ID) + + // Verify ranks are different + assert.NotEqual(t, spec1.Env["NODE_RANK"], spec2.Env["NODE_RANK"]) + assert.Equal(t, "2", spec1.Env["WORLD_SIZE"]) + assert.Equal(t, "2", spec2.Env["WORLD_SIZE"]) +} + +// TestServiceLifecycle validates service job start, health checks, and stop +func TestServiceLifecycle(t *testing.T) { + testToken := "service-test-token" + hub, err := scheduler.NewHub(scheduler.HubConfig{ + BindAddr: "localhost:0", + StateDir: t.TempDir(), + DefaultBatchSlots: 4, + WorkerTokens: map[string]string{ + testToken: "service-worker", + }, + }, nil) + require.NoError(t, err) + defer hub.Stop() + + err = hub.Start() + require.NoError(t, 
err) + + u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"} + wsURL := u.String() + + // Create worker with auth + conn, recvCh := createTestWorkerWithToken(t, wsURL, "service-worker", testToken) + defer conn.Close() + + // Register + conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgRegister, + Payload: mustMarshal(scheduler.WorkerRegistration{ + ID: "service-worker", + }), + }) + msg := <-recvCh + require.Equal(t, scheduler.MsgAck, msg.Type) + + // Submit service job + jobID := "service-001" + err = hub.SubmitJob(scheduler.JobSpec{ + ID: jobID, + Type: scheduler.JobTypeService, + SlotPool: "service", + Command: []string{"python", "-m", "http.server", "8080"}, + }) + require.NoError(t, err) + + // Signal ready + conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgReadyForWork, + Payload: mustMarshal(scheduler.ReadyPayload{ + WorkerID: "service-worker", + Slots: scheduler.SlotStatus{ServiceTotal: 4, ServiceInUse: 0}, + }), + }) + + // Should receive job assignment + assignMsg := <-recvCh + require.Equal(t, scheduler.MsgJobAssign, assignMsg.Type) + + // Send job accepted + conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgJobAccepted, + Payload: mustMarshal(map[string]string{ + "task_id": jobID, + }), + }) + + // Send periodic health updates + for i := 0; i < 3; i++ { + conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgServiceHealth, + Payload: mustMarshal(scheduler.ServiceHealthPayload{ + TaskID: jobID, + Healthy: true, + Message: "healthy", + }), + }) + time.Sleep(50 * time.Millisecond) + } + + // Verify task exists and is running + task := hub.GetTask(jobID) + require.NotNil(t, task) + assert.Equal(t, "running", task.Status) +} + +// TestStarvationPrevention validates low-priority jobs eventually get scheduled +func TestStarvationPrevention(t *testing.T) { + testToken := "starvation-test-token" + hub, err := scheduler.NewHub(scheduler.HubConfig{ + BindAddr: "localhost:0", + StateDir: t.TempDir(), + DefaultBatchSlots: 2, + 
StarvationThresholdMins: 1, // 1 minute for testing + WorkerTokens: map[string]string{ + testToken: "starvation-worker", + }, + }, nil) + require.NoError(t, err) + defer hub.Stop() + + err = hub.Start() + require.NoError(t, err) + + u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"} + wsURL := u.String() + + // Create worker with auth + conn, recvCh := createTestWorkerWithToken(t, wsURL, "starvation-worker", testToken) + defer conn.Close() + + // Register + conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgRegister, + Payload: mustMarshal(scheduler.WorkerRegistration{ + ID: "starvation-worker", + }), + }) + msg := <-recvCh + require.Equal(t, scheduler.MsgAck, msg.Type) + + // Submit high-priority job + err = hub.SubmitJob(scheduler.JobSpec{ + ID: "high-priority-job", + Type: scheduler.JobTypeBatch, + Env: map[string]string{"priority": "100"}, + }) + require.NoError(t, err) + + // Submit low-priority job + err = hub.SubmitJob(scheduler.JobSpec{ + ID: "low-priority-job", + Type: scheduler.JobTypeBatch, + Env: map[string]string{"priority": "1"}, + }) + require.NoError(t, err) + + // Signal ready - should get high priority job first + conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgReadyForWork, + Payload: mustMarshal(scheduler.ReadyPayload{ + WorkerID: "starvation-worker", + Slots: scheduler.SlotStatus{BatchTotal: 2, BatchInUse: 0}, + }), + }) + + // First assignment should be high priority + msg1 := <-recvCh + require.Equal(t, scheduler.MsgJobAssign, msg1.Type) + + var spec1 scheduler.JobSpec + json.Unmarshal(msg1.Payload, &spec1) + assert.Equal(t, "high-priority-job", spec1.ID) + + // Complete first job + conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgJobResult, + Payload: mustMarshal(scheduler.JobResultPayload{ + TaskID: "high-priority-job", + State: "completed", + }), + }) + + // Signal ready again + conn.WriteJSON(scheduler.Message{ + Type: scheduler.MsgReadyForWork, + Payload: mustMarshal(scheduler.ReadyPayload{ + WorkerID: 
"starvation-worker", + Slots: scheduler.SlotStatus{BatchTotal: 2, BatchInUse: 0}, + }), + }) + + // Should get low priority job + msg2 := <-recvCh + require.Equal(t, scheduler.MsgJobAssign, msg2.Type) + + var spec2 scheduler.JobSpec + json.Unmarshal(msg2.Payload, &spec2) + assert.Equal(t, "low-priority-job", spec2.ID) +} + +// Helper function to create test worker with token auth +func createTestWorkerWithToken(t *testing.T, wsURL, workerID, token string) (*websocket.Conn, <-chan scheduler.Message) { + headers := http.Header{} + headers.Set("Authorization", "Bearer "+token) + conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers) + require.NoError(t, err) + + recvCh := make(chan scheduler.Message, 10) + go func() { + for { + var msg scheduler.Message + err := conn.ReadJSON(&msg) + if err != nil { + close(recvCh) + return + } + recvCh <- msg + } + }() + + return conn, recvCh +} diff --git a/tests/unit/scheduler/failure_scenarios_test.go b/tests/unit/scheduler/failure_scenarios_test.go new file mode 100644 index 0000000..9865a09 --- /dev/null +++ b/tests/unit/scheduler/failure_scenarios_test.go @@ -0,0 +1,188 @@ +package scheduler_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/scheduler" + fixtures "github.com/jfraeys/fetch_ml/tests/fixtures" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mustMarshal(v any) []byte { + b, _ := json.Marshal(v) + return b +} + +// TestWorkerDeath simulates a worker dying mid-job +func TestWorkerDeath_MidJob(t *testing.T) { + // Use fixture for hub setup + fixture := fixtures.NewSchedulerTestFixture(t, fixtures.DefaultHubConfig()) + defer fixture.Cleanup() + + // Create mock worker + worker := fixture.CreateWorker("worker-death-test", scheduler.WorkerCapabilities{GPUCount: 0}) + + // Send heartbeat + worker.SendHeartbeat(scheduler.SlotStatus{ + BatchTotal: 3, + BatchInUse: 1, + }) + + // Simulate worker death + worker.Close() + + // Verify 
disconnect + require.True(t, worker.WaitForDisconnect(2*time.Second), "worker should disconnect") + t.Log("Worker death simulation completed") +} + +// TestSchedulerRestartRecovery simulates scheduler restart +func TestSchedulerRestart_Recovery(t *testing.T) { + dir := t.TempDir() + + // Create initial state store + ss1, err := scheduler.NewStateStore(dir + "/state.json") + require.NoError(t, err) + + // Record some events + events := []scheduler.StateEvent{ + {Type: scheduler.EventJobEnqueued, TaskID: "task-1", Timestamp: time.Now()}, + {Type: scheduler.EventJobAssigned, TaskID: "task-1", WorkerID: "worker-1", Timestamp: time.Now()}, + } + for _, e := range events { + require.NoError(t, ss1.Append(e)) + } + + // Simulate restart by creating new state store + ss2, err := scheduler.NewStateStore(dir + "/state.json") + require.NoError(t, err) + + // Replay should recover state + replayed, err := ss2.Replay() + require.NoError(t, err) + assert.Len(t, replayed, 2) +} + +// TestSplitBrain_Case1: Worker reconnects with unknown task +func TestSplitBrain_UnknownTask(t *testing.T) { + fixture := fixtures.NewSchedulerTestFixture(t, fixtures.DefaultHubConfig()) + defer fixture.Cleanup() + + // Create mock worker + worker := fixture.CreateWorker("worker-split-1", scheduler.WorkerCapabilities{GPUCount: 0}) + + // Simulate reconnect with unknown task + worker.Send(scheduler.Message{ + Type: scheduler.MsgRegister, + Payload: mustMarshal(scheduler.WorkerRegistration{ + ID: "worker-split-1", + ActiveTasks: []scheduler.ActiveTaskReport{ + {TaskID: "unknown-task", State: "running"}, + }, + }), + }) + + // Should receive cancel for unknown task + msg := worker.RecvTimeout(2 * time.Second) + if msg.Type == scheduler.MsgJobCancel { + t.Log("Received expected cancel for unknown task") + } else { + t.Logf("Received message type: %s (may need to check split-brain handling)", msg.Type) + } +} + +// TestSplitBrain_Case2: Worker reconnects with orphaned task +func 
TestSplitBrain_OrphanedTask(t *testing.T) { + fixture := fixtures.NewSchedulerTestFixture(t, fixtures.DefaultHubConfig()) + defer fixture.Cleanup() + + // Create mock worker + worker := fixture.CreateWorker("worker-split-2", scheduler.WorkerCapabilities{GPUCount: 0}) + + // Simulate reconnect with orphaned task + worker.Send(scheduler.Message{ + Type: scheduler.MsgRegister, + Payload: mustMarshal(scheduler.WorkerRegistration{ + ID: "worker-split-2", + ActiveTasks: []scheduler.ActiveTaskReport{ + {TaskID: "orphaned-task", State: "running"}, + }, + }), + }) + + msg := worker.RecvTimeout(2 * time.Second) + t.Logf("Received message type: %s", msg.Type) +} + +// TestSplitBrain_Case3: Worker reconnects with re-queued task +func TestSplitBrain_RequeuedTask(t *testing.T) { + fixture := fixtures.NewSchedulerTestFixture(t, fixtures.DefaultHubConfig()) + defer fixture.Cleanup() + + // Create mock worker + worker := fixture.CreateWorker("worker-split-3", scheduler.WorkerCapabilities{GPUCount: 0}) + + // Simulate reconnect with re-queued task + worker.Send(scheduler.Message{ + Type: scheduler.MsgRegister, + Payload: mustMarshal(scheduler.WorkerRegistration{ + ID: "worker-split-3", + ActiveTasks: []scheduler.ActiveTaskReport{ + {TaskID: "requeued-task", State: "queued"}, + }, + }), + }) + + msg := worker.RecvTimeout(2 * time.Second) + t.Logf("Received message type: %s", msg.Type) +} + +// TestAcceptanceTimeout: Job assigned but never accepted +func TestAcceptanceTimeout(t *testing.T) { + cfg := fixtures.DefaultHubConfig() + cfg.AcceptanceTimeoutSecs = 1 + fixture := fixtures.NewSchedulerTestFixture(t, cfg) + defer fixture.Cleanup() + + // Create mock worker + worker := fixture.CreateWorker("worker-timeout", scheduler.WorkerCapabilities{GPUCount: 0}) + + // Signal ready but don't accept any job + worker.SignalReady(scheduler.SlotStatus{BatchTotal: 3, BatchInUse: 0}, "polling") + + // Wait for potential assignment (but don't accept it) + msg, ok := worker.RecvNonBlock() + if ok && 
msg.Type == scheduler.MsgJobAssign { + t.Log("Received job assignment, not accepting to test timeout") + } + + // Wait for acceptance timeout + time.Sleep(2 * time.Second) + + t.Log("Acceptance timeout test completed") +} + +// TestGangTimeout: Multi-node job timeout during gang commit +func TestGangTimeout(t *testing.T) { + cfg := fixtures.DefaultHubConfig() + cfg.GangAllocTimeoutSecs = 1 + fixture := fixtures.NewSchedulerTestFixture(t, cfg) + defer fixture.Cleanup() + + // Create mock worker + worker := fixture.CreateWorker("worker-gang", scheduler.WorkerCapabilities{GPUCount: 0}) + + // Signal ready + worker.SignalReady(scheduler.SlotStatus{BatchTotal: 3, BatchInUse: 0}, "polling") + + // Wait for potential assignment + _, _ = worker.RecvNonBlock() + + // Wait for gang timeout + time.Sleep(2 * time.Second) + + t.Log("Gang timeout test completed") +} diff --git a/tests/unit/scheduler/port_allocator_test.go b/tests/unit/scheduler/port_allocator_test.go new file mode 100644 index 0000000..0fc193b --- /dev/null +++ b/tests/unit/scheduler/port_allocator_test.go @@ -0,0 +1,91 @@ +package scheduler_test + +import ( + "testing" + + "github.com/jfraeys/fetch_ml/internal/scheduler" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPortAllocator_BasicAllocation(t *testing.T) { + pa := scheduler.NewPortAllocator(10000, 10010) + + // Allocate first port + port1, err := pa.Allocate("service-1") + require.NoError(t, err) + assert.Equal(t, 10000, port1) + + // Allocate second port + port2, err := pa.Allocate("service-2") + require.NoError(t, err) + assert.Equal(t, 10001, port2) +} + +func TestPortAllocator_Exhaustion(t *testing.T) { + pa := scheduler.NewPortAllocator(10000, 10002) // Only 3 ports available + + // Allocate all ports + port1, _ := pa.Allocate("service-1") + assert.Equal(t, 10000, port1) + + port2, _ := pa.Allocate("service-2") + assert.Equal(t, 10001, port2) + + port3, _ := pa.Allocate("service-3") + assert.Equal(t, 
10002, port3) + + // Should fail - no ports left + _, err := pa.Allocate("service-4") + assert.Error(t, err) +} + +func TestPortAllocator_Release(t *testing.T) { + pa := scheduler.NewPortAllocator(10000, 10005) + + // Allocate and release + port, _ := pa.Allocate("service-1") + assert.Equal(t, 10000, port) + + pa.Release(port) + + // Should be able to allocate again + port2, err := pa.Allocate("service-2") + require.NoError(t, err) + assert.Equal(t, 10000, port2) // Reuses released port +} + +func TestPortAllocator_DuplicateServiceID(t *testing.T) { + pa := scheduler.NewPortAllocator(10000, 10010) + + // Allocate for service + port1, err := pa.Allocate("service-1") + require.NoError(t, err) + + // Allocate again with same service ID - gets new port (current behavior) + port2, err := pa.Allocate("service-1") + require.NoError(t, err) + assert.NotEqual(t, port1, port2) // Each call returns new port +} + +func TestPortAllocator_ConcurrentAccess(t *testing.T) { + pa := scheduler.NewPortAllocator(10000, 10100) // 101 ports + done := make(chan bool, 10) + + // Concurrent allocations + for i := 0; i < 10; i++ { + go func(id int) { + for j := 0; j < 10; j++ { + pa.Allocate("service-") + } + done <- true + }(i) + } + + for i := 0; i < 10; i++ { + <-done + } + + // All 100 ports should be allocated + // (10 goroutines * 10 allocations, but only 100 unique service IDs) +} diff --git a/tests/unit/scheduler/priority_queue_test.go b/tests/unit/scheduler/priority_queue_test.go new file mode 100644 index 0000000..8680ee7 --- /dev/null +++ b/tests/unit/scheduler/priority_queue_test.go @@ -0,0 +1,214 @@ +package scheduler_test + +import ( + "fmt" + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/scheduler" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPriorityQueue_BasicOperations(t *testing.T) { + q := scheduler.NewPriorityQueue(0.1) + + task1 := &scheduler.Task{ + ID: "task-1", + Priority: 10, + SubmittedAt: time.Now(), 
+ Spec: scheduler.JobSpec{ID: "task-1"}, + } + + task2 := &scheduler.Task{ + ID: "task-2", + Priority: 5, + SubmittedAt: time.Now(), + Spec: scheduler.JobSpec{ID: "task-2"}, + } + + // Add tasks + q.Add(task1) + q.Add(task2) + + require.Equal(t, 2, q.Len()) + + // Should return highest priority first (task1 with priority 10) + first := q.Take() + require.NotNil(t, first) + assert.Equal(t, "task-1", first.ID) + + // Second task + second := q.Take() + require.NotNil(t, second) + assert.Equal(t, "task-2", second.ID) + + // Queue empty + third := q.Take() + assert.Nil(t, third) +} + +func TestPriorityQueue_EffectivePriority_WithAging(t *testing.T) { + now := time.Now() + + // Task with lower priority but older submission + oldTask := &scheduler.Task{ + ID: "old-task", + Priority: 5, + SubmittedAt: now.Add(-10 * time.Minute), // 10 min old + } + + // Task with higher priority but recent submission + newTask := &scheduler.Task{ + ID: "new-task", + Priority: 10, + SubmittedAt: now, // Just submitted + } + + // Calculate effective priorities + oldEffective := oldTask.EffectivePriority(0.1, now) + newEffective := newTask.EffectivePriority(0.1, now) + + // Old task should have higher effective priority due to aging + // 5 + (10 min * 0.1) = 6.0 + // 10 + (0 min * 0.1) = 10.0 + assert.Less(t, oldEffective, newEffective) + assert.InDelta(t, 6.0, oldEffective, 0.1) + assert.InDelta(t, 10.0, newEffective, 0.1) +} + +func TestPriorityQueue_FIFOOnTie(t *testing.T) { + now := time.Now() + q := scheduler.NewPriorityQueue(0.1) + + // Two tasks with same priority, submitted at different times + task1 := &scheduler.Task{ + ID: "task-1", + Priority: 10, + SubmittedAt: now.Add(-5 * time.Minute), + Spec: scheduler.JobSpec{ID: "task-1"}, + } + + task2 := &scheduler.Task{ + ID: "task-2", + Priority: 10, + SubmittedAt: now.Add(-1 * time.Minute), + Spec: scheduler.JobSpec{ID: "task-2"}, + } + + // Add in reverse order + q.Add(task2) + q.Add(task1) + + // Should return older task first (FIFO 
on tie) + first := q.Take() + require.NotNil(t, first) + assert.Equal(t, "task-1", first.ID) + + second := q.Take() + require.NotNil(t, second) + assert.Equal(t, "task-2", second.ID) +} + +func TestPriorityQueue_Remove(t *testing.T) { + q := scheduler.NewPriorityQueue(0.1) + + task1 := &scheduler.Task{ID: "task-1", Priority: 10, Spec: scheduler.JobSpec{ID: "task-1"}} + task2 := &scheduler.Task{ID: "task-2", Priority: 5, Spec: scheduler.JobSpec{ID: "task-2"}} + task3 := &scheduler.Task{ID: "task-3", Priority: 1, Spec: scheduler.JobSpec{ID: "task-3"}} + + q.Add(task1) + q.Add(task2) + q.Add(task3) + + // Remove middle task + removed := q.Remove("task-2") + assert.True(t, removed) + assert.Equal(t, 2, q.Len()) + + // Try to remove non-existent + removed = q.Remove("non-existent") + assert.False(t, removed) + + // Verify remaining order + first := q.Take() + assert.Equal(t, "task-1", first.ID) + second := q.Take() + assert.Equal(t, "task-3", second.ID) +} + +func TestPriorityQueue_Get(t *testing.T) { + q := scheduler.NewPriorityQueue(0.1) + + task1 := &scheduler.Task{ID: "task-1", Priority: 10, Spec: scheduler.JobSpec{ID: "task-1"}} + q.Add(task1) + + // Get existing task + found := q.Get("task-1") + assert.NotNil(t, found) + assert.Equal(t, "task-1", found.ID) + + // Get non-existent + notFound := q.Get("non-existent") + assert.Nil(t, notFound) +} + +func TestPriorityQueue_Items(t *testing.T) { + q := scheduler.NewPriorityQueue(0.1) + + tasks := []*scheduler.Task{ + {ID: "task-1", Priority: 10, Spec: scheduler.JobSpec{ID: "task-1"}}, + {ID: "task-2", Priority: 5, Spec: scheduler.JobSpec{ID: "task-2"}}, + {ID: "task-3", Priority: 1, Spec: scheduler.JobSpec{ID: "task-3"}}, + } + + for _, task := range tasks { + q.Add(task) + } + + items := q.Items() + require.Len(t, items, 3) + + // Items should be in priority order (highest first) + assert.Equal(t, "task-1", items[0].ID) + assert.Equal(t, "task-2", items[1].ID) + assert.Equal(t, "task-3", items[2].ID) +} + +func 
TestPriorityQueue_ConcurrentAccess(t *testing.T) { + q := scheduler.NewPriorityQueue(0.1) + done := make(chan bool, 3) + + // Concurrent adds + go func() { + for i := 0; i < 100; i++ { + q.Add(&scheduler.Task{ID: fmt.Sprintf("task-%d", i), Priority: i}) + } + done <- true + }() + + // Concurrent takes + go func() { + for i := 0; i < 50; i++ { + q.Take() + } + done <- true + }() + + // Concurrent peeks + go func() { + for i := 0; i < 100; i++ { + q.Peek() + } + done <- true + }() + + // Wait for all goroutines + for i := 0; i < 3; i++ { + <-done + } + + // Queue should be in consistent state + assert.GreaterOrEqual(t, q.Len(), 0) + assert.LessOrEqual(t, q.Len(), 100) +} diff --git a/tests/unit/scheduler/service_templates_test.go b/tests/unit/scheduler/service_templates_test.go new file mode 100644 index 0000000..20530d5 --- /dev/null +++ b/tests/unit/scheduler/service_templates_test.go @@ -0,0 +1,264 @@ +package scheduler_test + +import ( + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/scheduler" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestJupyterLabTemplate validates the JupyterLab template configuration +func TestJupyterLabTemplate(t *testing.T) { + template := scheduler.JupyterLabTemplate + + assert.Equal(t, "service", template.JobType) + assert.Equal(t, "service", template.SlotPool) + assert.Equal(t, 0, template.GPUCount) + + // Verify command includes required flags + require.NotEmpty(t, template.Command) + assert.Contains(t, template.Command, "jupyter") + assert.Contains(t, template.Command, "lab") + assert.Contains(t, template.Command, "--ip=0.0.0.0") + assert.Contains(t, template.Command, "--port={{SERVICE_PORT}}") + assert.Contains(t, template.Command, "--no-browser") + + // Verify health checks + assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/api", template.HealthCheck.Liveness) + assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/api/kernels", template.HealthCheck.Readiness) + 
assert.Equal(t, 15, template.HealthCheck.Interval) + assert.Equal(t, 5, template.HealthCheck.Timeout) + + // Verify mounts + require.Len(t, template.Mounts, 1) + assert.Equal(t, "{{WORKSPACE}}", template.Mounts[0].Source) + assert.Equal(t, "/workspace", template.Mounts[0].Destination) +} + +// TestJupyterNotebookTemplate validates the classic notebook template +func TestJupyterNotebookTemplate(t *testing.T) { + template := scheduler.JupyterNotebookTemplate + + assert.Equal(t, "service", template.JobType) + assert.Equal(t, "service", template.SlotPool) + assert.Equal(t, 0, template.GPUCount) + + // Verify uses notebook subcommand + require.NotEmpty(t, template.Command) + assert.Contains(t, template.Command, "notebook") +} + +// TestVLLMTemplate validates the vLLM inference template +func TestVLLMTemplate(t *testing.T) { + template := scheduler.VLLMTemplate + + assert.Equal(t, "service", template.JobType) + assert.Equal(t, "service", template.SlotPool) + assert.Equal(t, 1, template.GPUCount) // Requires GPU + + // Verify command + require.NotEmpty(t, template.Command) + assert.Contains(t, template.Command, "vllm.entrypoints.openai.api_server") + assert.Contains(t, template.Command, "{{MODEL_NAME}}") + assert.Contains(t, template.Command, "{{SERVICE_PORT}}") +} + +// TestPortAllocatorForServices validates port allocation for service jobs +func TestPortAllocatorForServices(t *testing.T) { + pa := scheduler.NewPortAllocator(10000, 10010) + + // Allocate a port for Jupyter service + port1, err := pa.Allocate("jupyter-task-1") + require.NoError(t, err) + assert.True(t, port1 >= 10000 && port1 <= 10010) + + // Verify we can get the task for this port + taskID := pa.GetAllocation(port1) + assert.Equal(t, "jupyter-task-1", taskID) + + // Allocate another port + port2, err := pa.Allocate("jupyter-task-2") + require.NoError(t, err) + assert.NotEqual(t, port1, port2) + + // Release first port + pa.Release(port1) + + // Verify port is now available + taskID = 
pa.GetAllocation(port1) + assert.Equal(t, "", taskID) + + // Can reallocate the same port + port3, err := pa.Allocate("jupyter-task-3") + require.NoError(t, err) + // Should get first available (which might be port1) + assert.True(t, port3 >= 10000 && port3 <= 10010) +} + +// TestPortAllocatorExhaustion validates behavior when no ports available +func TestPortAllocatorExhaustion(t *testing.T) { + // Small range for testing + pa := scheduler.NewPortAllocator(20000, 20002) + + // Allocate all ports + _, err := pa.Allocate("task-1") + require.NoError(t, err) + _, err = pa.Allocate("task-2") + require.NoError(t, err) + _, err = pa.Allocate("task-3") + require.NoError(t, err) + + // Fourth allocation should fail + _, err = pa.Allocate("task-4") + assert.Error(t, err) + assert.Contains(t, err.Error(), "no ports available") +} + +// TestPortAllocatorTTL validates port TTL behavior +func TestPortAllocatorTTL(t *testing.T) { + pa := scheduler.NewPortAllocator(30000, 30010) + + // Set short TTL for testing + pa.SetTTL(50 * time.Millisecond) + + // Allocate a port + port1, err := pa.Allocate("test-task") + require.NoError(t, err) + + // Release it (marks with expired timestamp due to short TTL) + pa.Release(port1) + + // Immediately try to allocate - should get different port since released one is "expired" + port2, err := pa.Allocate("test-task-2") + require.NoError(t, err) + + // Could be same or different depending on cleanup timing + assert.True(t, port2 >= 30000 && port2 <= 30010) +} + +// TestServiceSlotPoolSeparation validates that service and batch use different pools +func TestServiceSlotPoolSeparation(t *testing.T) { + // This test validates the conceptual separation + // In practice, the scheduler maintains separate queues + + // Use JupyterLabTemplate which has health checks configured + serviceJob := scheduler.JupyterLabTemplate + + batchJob := scheduler.JobSpec{ + ID: "batch-1", + SlotPool: "batch", + GPUCount: 1, + } + + // Verify different slot pools + 
assert.Equal(t, "service", serviceJob.SlotPool) + assert.Equal(t, "batch", batchJob.SlotPool) + + // Service job has health checks + assert.NotZero(t, serviceJob.HealthCheck.Interval) + + // Batch job would typically not have health checks + // (it runs to completion) +} + +// TestHealthCheckValidation validates health check configuration +func TestHealthCheckValidation(t *testing.T) { + tests := []struct { + name string + template scheduler.ServiceTemplate + valid bool + }{ + { + name: "JupyterLab - valid", + template: scheduler.ServiceTemplate{ + JobType: "service", + SlotPool: "service", + HealthCheck: scheduler.ServiceHealthCheck{ + Liveness: "http://localhost:8888/api", + Readiness: "http://localhost:8888/api/kernels", + Interval: 15, + Timeout: 5, + }, + }, + valid: true, + }, + { + name: "Missing liveness - invalid", + template: scheduler.ServiceTemplate{ + JobType: "service", + SlotPool: "service", + HealthCheck: scheduler.ServiceHealthCheck{ + Readiness: "http://localhost:8888/api", + Interval: 15, + }, + }, + valid: false, + }, + { + name: "Zero interval - invalid", + template: scheduler.ServiceTemplate{ + JobType: "service", + SlotPool: "service", + HealthCheck: scheduler.ServiceHealthCheck{ + Liveness: "http://localhost:8888/api", + Readiness: "http://localhost:8888/api", + Interval: 0, + }, + }, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hc := tt.template.HealthCheck + isValid := hc.Liveness != "" && hc.Interval > 0 && hc.Timeout > 0 + assert.Equal(t, tt.valid, isValid) + }) + } +} + +// TestDefaultPortRange validates the default service port range +func TestDefaultPortRange(t *testing.T) { + // Default range should be large enough for typical deployments + rangeSize := scheduler.DefaultServicePortEnd - scheduler.DefaultServicePortStart + assert.True(t, rangeSize >= 1000, "Default port range should be at least 1000 ports") + assert.Equal(t, 8000, scheduler.DefaultServicePortStart) + assert.Equal(t, 
9000, scheduler.DefaultServicePortEnd) +} + +// TestTemplateVariableExpansion validates template variables are present +func TestTemplateVariableExpansion(t *testing.T) { + template := scheduler.JupyterLabTemplate + + // Check command contains template variables + hasServicePort := false + for _, cmd := range template.Command { + if cmd == "--port={{SERVICE_PORT}}" { + hasServicePort = true + break + } + } + assert.True(t, hasServicePort, "Command should contain {{SERVICE_PORT}} template variable") + + // Check env contains secret template + val, ok := template.Env["JUPYTER_TOKEN"] + assert.True(t, ok, "Should have JUPYTER_TOKEN env var") + assert.Contains(t, val, "{{SECRET:", "Should use secret template") +} + +// BenchmarkPortAllocation benchmarks port allocation performance +func BenchmarkPortAllocation(b *testing.B) { + pa := scheduler.NewPortAllocator(40000, 41000) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + port, err := pa.Allocate("bench-task") + if err != nil { + b.Fatal(err) + } + pa.Release(port) + } +} diff --git a/tests/unit/scheduler/state_store_test.go b/tests/unit/scheduler/state_store_test.go new file mode 100644 index 0000000..fe64556 --- /dev/null +++ b/tests/unit/scheduler/state_store_test.go @@ -0,0 +1,67 @@ +package scheduler_test + +import ( + "testing" + "time" + + "github.com/jfraeys/fetch_ml/internal/scheduler" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStateStore_BasicOperations(t *testing.T) { + dir := t.TempDir() + ss, err := scheduler.NewStateStore(dir + "/state.json") + require.NoError(t, err) + + // Append some events + events := []scheduler.StateEvent{ + {Type: scheduler.EventJobEnqueued, TaskID: "task-1", Timestamp: time.Now()}, + {Type: scheduler.EventJobAssigned, TaskID: "task-1", WorkerID: "worker-1", Timestamp: time.Now()}, + {Type: scheduler.EventJobCompleted, TaskID: "task-1", WorkerID: "worker-1", Timestamp: time.Now()}, + } + + for _, e := range events { + err := 
ss.Append(e) + require.NoError(t, err) + } + + // Replay events + replayed, err := ss.Replay() + require.NoError(t, err) + assert.Len(t, replayed, 3) + assert.Equal(t, "task-1", replayed[0].TaskID) +} + +func TestStateStore_Persistence(t *testing.T) { + dir := t.TempDir() + + // Create store and append events + ss1, err := scheduler.NewStateStore(dir + "/state.json") + require.NoError(t, err) + event := scheduler.StateEvent{ + Type: scheduler.EventJobEnqueued, + TaskID: "persistent-task", + Timestamp: time.Now(), + } + err = ss1.Append(event) + require.NoError(t, err) + + // Create new store instance pointing to same directory + ss2, err := scheduler.NewStateStore(dir + "/state.json") + require.NoError(t, err) + replayed, err := ss2.Replay() + require.NoError(t, err) + assert.Len(t, replayed, 1) + assert.Equal(t, "persistent-task", replayed[0].TaskID) +} + +func TestStateStore_ReplayEmpty(t *testing.T) { + dir := t.TempDir() + ss, err := scheduler.NewStateStore(dir + "/state.json") + require.NoError(t, err) + + replayed, err := ss.Replay() + require.NoError(t, err) + assert.Empty(t, replayed) +}