feat(scheduler): implement multi-tenant job scheduler with gang scheduling
Add new scheduler component for distributed ML workload orchestration: - Hub-based coordination for multi-worker clusters - Pacing controller for rate limiting job submissions - Priority queue with preemption support - Port allocator for dynamic service discovery - Protocol handlers for worker-scheduler communication - Service manager with OS-specific implementations - Connection management and state persistence - Template system for service deployment Includes comprehensive test suite: - Unit tests for all core components - Integration tests for distributed scenarios - Benchmark tests for performance validation - Mock fixtures for isolated testing Refs: scheduler-architecture.md
This commit is contained in:
parent
6e0e7d9d2e
commit
43e6446587
24 changed files with 4968 additions and 0 deletions
274
cmd/scheduler/main.go
Normal file
274
cmd/scheduler/main.go
Normal file
|
|
@ -0,0 +1,274 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config represents the scheduler configuration file as a whole; the YAML
// document nests all settings under a top-level "scheduler" key.
type Config struct {
	Scheduler SchedulerConfig `yaml:"scheduler"`
}
|
||||||
|
|
||||||
|
// SchedulerConfig holds all tunable settings for the scheduler, as read from
// the "scheduler" section of the YAML config. Zero values are replaced with
// defaults in loadConfig.
type SchedulerConfig struct {
	BindAddr                string        `yaml:"bind_addr"`                 // listen address; default "0.0.0.0:7777"
	CertFile                string        `yaml:"cert_file"`                 // TLS certificate path; empty disables TLS
	KeyFile                 string        `yaml:"key_file"`                  // TLS private key path
	AutoGenerateCerts       bool          `yaml:"auto_generate_certs"`       // self-sign a cert when none exists on disk
	StateDir                string        `yaml:"state_dir"`                 // default "/var/lib/fetch_ml"
	DefaultBatchSlots       int           `yaml:"default_batch_slots"`       // default 3
	DefaultServiceSlots     int           `yaml:"default_service_slots"`     // default 1
	StarvationThresholdMins float64       `yaml:"starvation_threshold_mins"` // default 5
	PriorityAgingRate       float64       `yaml:"priority_aging_rate"`       // default 0.1
	GangAllocTimeoutSecs    int           `yaml:"gang_alloc_timeout_secs"`   // default 60
	AcceptanceTimeoutSecs   int           `yaml:"acceptance_timeout_secs"`   // default 30
	MetricsAddr             string        `yaml:"metrics_addr"`              // NOTE(review): written by -init but not copied into HubConfig — confirm intended
	WorkerTokens            []WorkerToken `yaml:"worker_tokens"`             // accepted worker credentials
}
|
||||||
|
|
||||||
|
// WorkerToken pairs a worker's identifier with its bearer authentication token.
type WorkerToken struct {
	ID    string `yaml:"id"`    // worker identifier, e.g. "worker-01"
	Token string `yaml:"token"` // secret bearer token the worker presents on connect
}
|
||||||
|
|
||||||
|
// main is the scheduler entry point. Besides running the hub it supports two
// one-shot modes: -generate-token prints a fresh worker token and exits, and
// -init writes a new config file pre-populated with generated tokens.
func main() {
	var (
		configPath    string
		generateToken bool
		initConfig    bool
		numTokens     int
	)
	flag.StringVar(&configPath, "config", "scheduler.yaml", "Path to scheduler config file")
	flag.BoolVar(&generateToken, "generate-token", false, "Generate a new worker token and exit")
	flag.BoolVar(&initConfig, "init", false, "Initialize a new config file with generated tokens")
	flag.IntVar(&numTokens, "tokens", 3, "Number of tokens to generate (used with -init)")
	flag.Parse()

	// Handle token generation mode: print one token and exit.
	if generateToken {
		token := scheduler.GenerateWorkerToken()
		fmt.Println(token)
		os.Exit(0)
	}

	// Handle init mode: write a fresh config containing numTokens tokens.
	if initConfig {
		if err := generateConfig(configPath, numTokens); err != nil {
			fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err)
			os.Exit(1)
		}
		fmt.Printf("Config generated: %s\n", configPath)
		fmt.Printf("\nGenerated %d worker tokens. Copy the appropriate token to each worker's config.\n", numTokens)
		os.Exit(0)
	}

	// Load config (defaults for unset fields are applied inside loadConfig).
	cfg, err := loadConfig(configPath)
	if err != nil {
		slog.Error("failed to load config", "error", err)
		os.Exit(1)
	}

	// Setup structured JSON logging as the process-wide default.
	handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})
	logger := slog.New(handler)
	slog.SetDefault(logger)

	// Invert the configured token list into the token -> workerID lookup the
	// hub uses for authentication.
	tokenMap := make(map[string]string)
	for _, wt := range cfg.Scheduler.WorkerTokens {
		tokenMap[wt.Token] = wt.ID
	}

	// Auto-generate a self-signed cert if requested and none exists on disk.
	if cfg.Scheduler.AutoGenerateCerts && cfg.Scheduler.CertFile != "" {
		if _, err := os.Stat(cfg.Scheduler.CertFile); os.IsNotExist(err) {
			keyFile := cfg.Scheduler.KeyFile
			if keyFile == "" {
				// Default the key path next to the cert when unset.
				keyFile = cfg.Scheduler.CertFile + ".key"
			}
			logger.Info("generating self-signed certificate", "cert", cfg.Scheduler.CertFile)
			if err := scheduler.GenerateSelfSignedCert(cfg.Scheduler.CertFile, keyFile); err != nil {
				logger.Error("failed to generate certificate", "error", err)
				os.Exit(1)
			}
		}
	}

	// Create hub config by mirroring the YAML-level settings field by field.
	hubCfg := scheduler.HubConfig{
		BindAddr:                cfg.Scheduler.BindAddr,
		CertFile:                cfg.Scheduler.CertFile,
		KeyFile:                 cfg.Scheduler.KeyFile,
		AutoGenerateCerts:       cfg.Scheduler.AutoGenerateCerts,
		StateDir:                cfg.Scheduler.StateDir,
		DefaultBatchSlots:       cfg.Scheduler.DefaultBatchSlots,
		DefaultServiceSlots:     cfg.Scheduler.DefaultServiceSlots,
		StarvationThresholdMins: cfg.Scheduler.StarvationThresholdMins,
		PriorityAgingRate:       cfg.Scheduler.PriorityAgingRate,
		GangAllocTimeoutSecs:    cfg.Scheduler.GangAllocTimeoutSecs,
		AcceptanceTimeoutSecs:   cfg.Scheduler.AcceptanceTimeoutSecs,
		WorkerTokens:            tokenMap,
	}

	// Create auditor (optional) — always nil here, so audit logging is
	// effectively disabled in this entry point; presumably wired elsewhere.
	var auditor *audit.Logger

	// Create hub
	hub, err := scheduler.NewHub(hubCfg, auditor)
	if err != nil {
		logger.Error("failed to create scheduler hub", "error", err)
		os.Exit(1)
	}

	// Start hub
	if err := hub.Start(); err != nil {
		logger.Error("failed to start scheduler hub", "error", err)
		os.Exit(1)
	}

	logger.Info("scheduler hub started", "bind_addr", cfg.Scheduler.BindAddr)

	// Setup HTTP handlers for the worker endpoint, health checks, and metrics.
	mux := http.NewServeMux()
	mux.HandleFunc("/ws/worker", hub.HandleConnection)
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"status":"ok"}`))
	})
	mux.HandleFunc("/metrics", hub.ServeMetrics)

	// Setup graceful shutdown on SIGINT/SIGTERM.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	// Start server.
	// NOTE(review): hub.Start() appears to already listen and serve /ws/worker
	// on BindAddr (it calls net.Listen with the same address), so this second
	// listener will likely fail with "address already in use" — confirm which
	// server is authoritative.
	// NOTE(review): no ReadTimeout/WriteTimeout is configured on this server.
	go func() {
		if cfg.Scheduler.CertFile != "" {
			logger.Info("starting HTTPS server", "addr", cfg.Scheduler.BindAddr)
			if err := http.ListenAndServeTLS(cfg.Scheduler.BindAddr, cfg.Scheduler.CertFile, cfg.Scheduler.KeyFile, mux); err != nil {
				logger.Error("server error", "error", err)
			}
		} else {
			logger.Info("starting HTTP server", "addr", cfg.Scheduler.BindAddr)
			if err := http.ListenAndServe(cfg.Scheduler.BindAddr, mux); err != nil {
				logger.Error("server error", "error", err)
			}
		}
	}()

	// Wait for shutdown signal, then stop the hub cleanly.
	<-sigChan
	logger.Info("shutting down scheduler...")
	hub.Stop()
	logger.Info("scheduler stopped")
}
|
||||||
|
|
||||||
|
func loadConfig(path string) (*Config, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set defaults
|
||||||
|
if cfg.Scheduler.BindAddr == "" {
|
||||||
|
cfg.Scheduler.BindAddr = "0.0.0.0:7777"
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.StateDir == "" {
|
||||||
|
cfg.Scheduler.StateDir = "/var/lib/fetch_ml"
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.DefaultBatchSlots == 0 {
|
||||||
|
cfg.Scheduler.DefaultBatchSlots = 3
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.DefaultServiceSlots == 0 {
|
||||||
|
cfg.Scheduler.DefaultServiceSlots = 1
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.StarvationThresholdMins == 0 {
|
||||||
|
cfg.Scheduler.StarvationThresholdMins = 5
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.PriorityAgingRate == 0 {
|
||||||
|
cfg.Scheduler.PriorityAgingRate = 0.1
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.GangAllocTimeoutSecs == 0 {
|
||||||
|
cfg.Scheduler.GangAllocTimeoutSecs = 60
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.AcceptanceTimeoutSecs == 0 {
|
||||||
|
cfg.Scheduler.AcceptanceTimeoutSecs = 30
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateConfig creates a new scheduler config file with generated tokens
|
||||||
|
func generateConfig(path string, numTokens int) error {
|
||||||
|
// Generate tokens
|
||||||
|
var tokens []WorkerToken
|
||||||
|
for i := 1; i <= numTokens; i++ {
|
||||||
|
tokens = append(tokens, WorkerToken{
|
||||||
|
ID: fmt.Sprintf("worker-%02d", i),
|
||||||
|
Token: scheduler.GenerateWorkerToken(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := Config{
|
||||||
|
Scheduler: SchedulerConfig{
|
||||||
|
BindAddr: "0.0.0.0:7777",
|
||||||
|
AutoGenerateCerts: true,
|
||||||
|
CertFile: "/etc/fetch_ml/scheduler.crt",
|
||||||
|
KeyFile: "/etc/fetch_ml/scheduler.key",
|
||||||
|
StateDir: "/var/lib/fetch_ml",
|
||||||
|
DefaultBatchSlots: 3,
|
||||||
|
DefaultServiceSlots: 1,
|
||||||
|
StarvationThresholdMins: 5,
|
||||||
|
PriorityAgingRate: 0.1,
|
||||||
|
GangAllocTimeoutSecs: 60,
|
||||||
|
AcceptanceTimeoutSecs: 30,
|
||||||
|
MetricsAddr: "0.0.0.0:9090",
|
||||||
|
WorkerTokens: tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := yaml.Marshal(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add header comment
|
||||||
|
header := `# Scheduler Configuration for fetch_ml
|
||||||
|
# Generated by: scheduler -init
|
||||||
|
#
|
||||||
|
# ⚠️ SECURITY WARNING: This file contains authentication tokens.
|
||||||
|
# - Do NOT commit to git
|
||||||
|
# - Keep the file permissions secure (chmod 600)
|
||||||
|
# - Copy the appropriate token to each worker's config
|
||||||
|
#
|
||||||
|
`
|
||||||
|
fullContent := header + string(data)
|
||||||
|
|
||||||
|
if err := os.WriteFile(path, []byte(fullContent), 0600); err != nil {
|
||||||
|
return fmt.Errorf("write config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print tokens to stdout for easy distribution
|
||||||
|
fmt.Print("\n=== Generated Worker Tokens ===\n")
|
||||||
|
fmt.Print("Copy these to your worker configs:\n\n")
|
||||||
|
for _, t := range tokens {
|
||||||
|
fmt.Printf("Worker: %s\nToken: %s\n\n", t.ID, t.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
157
internal/scheduler/auth.go
Normal file
157
internal/scheduler/auth.go
Normal file
|
|
@ -0,0 +1,157 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateSelfSignedCert creates a self-signed TLS certificate for the
// scheduler, writing the certificate (PEM) to certFile and the EC private key
// to keyFile (mode 0600). The certificate is valid for one year and carries
// loopback IP SANs for local development.
func GenerateSelfSignedCert(certFile, keyFile string) error {
	if err := os.MkdirAll(filepath.Dir(certFile), 0755); err != nil {
		return fmt.Errorf("create cert directory: %w", err)
	}

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return fmt.Errorf("generate key: %w", err)
	}

	// Use a random 128-bit serial: the original hard-coded serial 1, and
	// reusing the same serial across regenerated certificates can trip strict
	// verifiers and cert-pinning caches.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return fmt.Errorf("generate serial: %w", err)
	}

	template := x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Organization: []string{"fetch_ml"},
			CommonName:   "fetch_ml_scheduler",
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		// Loopback SANs so local clients can verify the cert in development.
		IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
	}

	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return fmt.Errorf("create certificate: %w", err)
	}

	if err := writePEM(certFile, 0644, "CERTIFICATE", certDER); err != nil {
		return err
	}

	keyDER, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return fmt.Errorf("marshal private key: %w", err)
	}
	// 0600: possession of the key grants control of the scheduler endpoint.
	return writePEM(keyFile, 0600, "EC PRIVATE KEY", keyDER)
}

// writePEM writes a single PEM block to path with the given permissions.
// Unlike the original inline code, it propagates pem.Encode and Close errors,
// so a full disk or I/O failure no longer yields a silently truncated file.
func writePEM(path string, perm os.FileMode, blockType string, der []byte) error {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return fmt.Errorf("create %s file: %w", blockType, err)
	}
	if err := pem.Encode(f, &pem.Block{Type: blockType, Bytes: der}); err != nil {
		f.Close()
		return fmt.Errorf("encode %s: %w", blockType, err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("close %s file: %w", blockType, err)
	}
	return nil
}
|
||||||
|
|
||||||
|
// DialWSS connects to the scheduler via WSS with cert pinning
|
||||||
|
func DialWSS(addr, certFile, token string) (*websocket.Conn, error) {
|
||||||
|
certPEM, err := os.ReadFile(certFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read cert file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
if !pool.AppendCertsFromPEM(certPEM) {
|
||||||
|
return nil, fmt.Errorf("parse cert file")
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := websocket.Dialer{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: pool,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
},
|
||||||
|
HandshakeTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
header := http.Header{}
|
||||||
|
if token != "" {
|
||||||
|
header.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
|
||||||
|
url := "wss://" + addr + "/ws/worker"
|
||||||
|
conn, _, err := dialer.Dial(url, header)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dial scheduler: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalModeDial connects without TLS for single-node mode (loopback only)
|
||||||
|
func LocalModeDial(port int, token string) (*websocket.Conn, error) {
|
||||||
|
dialer := websocket.Dialer{
|
||||||
|
HandshakeTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
header := http.Header{}
|
||||||
|
if token != "" {
|
||||||
|
header.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("ws://127.0.0.1:%d/ws/worker", port)
|
||||||
|
conn, _, err := dialer.Dial(url, header)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dial local scheduler: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenValidator validates worker authentication tokens against a static
// token -> workerID table supplied at construction time.
type TokenValidator struct {
	tokens map[string]string // token -> workerID
}

// NewTokenValidator builds a validator over the given token -> workerID map.
// A nil map is permitted and rejects every token.
func NewTokenValidator(tokens map[string]string) *TokenValidator {
	tv := &TokenValidator{tokens: tokens}
	return tv
}

// Validate reports whether token is known and, if so, which worker it
// identifies.
func (tv *TokenValidator) Validate(token string) (workerID string, ok bool) {
	workerID, ok = tv.tokens[token]
	return workerID, ok
}
|
||||||
|
|
||||||
|
// ExtractBearerToken extracts the token from an Authorization header value.
// A value without the "Bearer " prefix is returned unchanged, matching the
// behavior of strings.TrimPrefix.
func ExtractBearerToken(header string) string {
	if token, found := strings.CutPrefix(header, "Bearer "); found {
		return token
	}
	return header
}
|
||||||
|
|
||||||
|
// GenerateWorkerToken creates a cryptographically secure random token for a
// worker: "wkr_" followed by 32 random bytes in URL-safe base64.
func GenerateWorkerToken() string {
	b := make([]byte, 32)
	// The original discarded rand.Read's return values. crypto/rand is
	// documented never to fail on supported platforms, but a silent partial
	// read here would hand out a guessable auth token — treat it as fatal.
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("generate worker token: %v", err))
	}
	// URL-safe base64 keeps tokens compact and safe to embed in URLs/headers.
	return "wkr_" + base64.URLEncoding.EncodeToString(b)
}
|
||||||
930
internal/scheduler/hub.go
Normal file
930
internal/scheduler/hub.go
Normal file
|
|
@ -0,0 +1,930 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||||
|
)
|
||||||
|
|
||||||
|
// upgrader upgrades incoming HTTP requests to WebSocket connections.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// NOTE(review): every Origin is accepted. That is workable for
	// token-authenticated worker connections, but confirm it is intended
	// before exposing this endpoint to browsers (cross-site WebSocket risk).
	CheckOrigin: func(r *http.Request) bool {
		return true // Allow all origins (configurable in production)
	},
}
|
||||||
|
|
||||||
|
// SchedulerHub manages worker connections and job scheduling. The maps below
// are guarded by mu; queue types presumably lock internally — confirm.
type SchedulerHub struct {
	mu                sync.RWMutex
	workers           map[string]*WorkerConn    // all connected workers, by worker ID
	readyWorkers      map[string]*WorkerConn    // idle workers waiting for work
	batchQueue        *PriorityQueue            // queued batch jobs
	serviceQueue      *PriorityQueue            // queued service jobs
	reservations      map[string]*Reservation   // capacity held back for starving jobs
	multiNodePending  map[string]*MultiNodeJob  // gangs still assembling
	pendingAcceptance map[string]*JobAssignment // assignments awaiting worker ack
	state             *StateStore               // durable event log for crash recovery
	starvation        *StarvationTracker
	metrics           *SchedulerMetrics
	auditor           *audit.Logger // optional; may be nil
	tokenValidator    *TokenValidator
	config            HubConfig
	ctx               context.Context
	cancel            context.CancelFunc
	server            *http.Server // set by Start
	listener          net.Listener // set by Start
}
|
||||||
|
|
||||||
|
// HubConfig carries every tunable the hub needs at construction time; it is
// the in-process mirror of the YAML SchedulerConfig.
type HubConfig struct {
	BindAddr                string
	CertFile                string // TLS certificate path; empty means plain HTTP
	KeyFile                 string
	AutoGenerateCerts       bool // self-sign into StateDir when cert/key are unset
	StateDir                string
	DefaultBatchSlots       int
	DefaultServiceSlots     int
	StarvationThresholdMins float64 // minutes a job may wait before it counts as starving
	PriorityAgingRate       float64 // priority aging factor; 0 falls back to 0.1 in NewHub
	GangAllocTimeoutSecs    int     // seconds to wait for a full gang — TODO confirm against checkGangTimeouts
	AcceptanceTimeoutSecs   int     // seconds a worker has to accept an assignment
	LocalMode               bool
	WorkerTokens            map[string]string // token -> workerID
}
|
||||||
|
|
||||||
|
// WorkerConn represents a connected worker and the scheduler-side state tied
// to its WebSocket connection.
type WorkerConn struct {
	workerID     string
	conn         *websocket.Conn
	capabilities WorkerCapabilities  // static capabilities reported at registration
	slots        SlotStatus          // latest slot availability (guarded by mu)
	lease        *Lease
	activeTasks  map[string]struct{} // task IDs believed running on this worker
	send         chan Message        // outbound messages, drained by runWorker's send loop
	hub          *SchedulerHub
	mu           sync.Mutex // guards slots
}
|
||||||
|
|
||||||
|
// Lease tracks which worker currently owns a running task and when that
// ownership claim lapses.
type Lease struct {
	TaskID    string
	WorkerID  string
	ExpiresAt time.Time // after this instant the lease is considered stale
}
|
||||||
|
|
||||||
|
// Reservation holds GPU capacity back for a large job so smaller jobs cannot
// starve it indefinitely (see canAdmit).
type Reservation struct {
	TaskID    string
	GPUCount  int // GPUs withheld from competing candidates
	CreatedAt time.Time
}
|
||||||
|
|
||||||
|
// MultiNodeJob tracks the gang-allocation state of a job that needs several
// workers committed before it can launch.
type MultiNodeJob struct {
	JobID       string
	TotalNodes  int               // workers required before the gang is complete
	Assignments []*NodeAssignment // workers committed so far, one per rank
	CommittedAt time.Time
}
|
||||||
|
|
||||||
|
// NodeAssignment binds one committed worker to its rank within a gang.
type NodeAssignment struct {
	Worker      *WorkerConn
	Rank        int // position within the multi-node job — presumably 0-based; confirm
	CommittedAt time.Time
}
|
||||||
|
|
||||||
|
// JobAssignment tracks a task handed to a worker that has not yet confirmed
// acceptance; assignments past AcceptanceDeadline are reclaimed by the
// acceptance-timeout loop.
type JobAssignment struct {
	TaskID             string
	WorkerID           string
	AssignedAt         time.Time
	AcceptanceDeadline time.Time // worker must accept before this instant
	Accepted           bool
}
|
||||||
|
|
||||||
|
// StarvationTracker monitors jobs that have waited longer than threshold and
// triggers capacity reservations for them.
type StarvationTracker struct {
	mu        sync.RWMutex
	threshold time.Duration // wait time after which a job counts as starving
}
|
||||||
|
|
||||||
|
// SchedulerMetrics aggregates scheduler statistics exposed via /metrics.
// NOTE(review): WorkersConnected is mutated under the hub's own lock in
// runWorker rather than this struct's mu — confirm the locking discipline.
type SchedulerMetrics struct {
	mu                sync.RWMutex
	WorkersConnected  int
	QueueDepthBatch   int
	QueueDepthService int
	JobsCompleted     int
	JobsFailed        int
	JobsCancelled     int
	WorkerSlots       map[string]SlotStatus // per-worker slot usage snapshot
}
|
||||||
|
|
||||||
|
// NewHub creates a new scheduler hub
|
||||||
|
func NewHub(cfg HubConfig, auditor *audit.Logger) (*SchedulerHub, error) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
// Initialize state store
|
||||||
|
statePath := cfg.StateDir + "/scheduler.state"
|
||||||
|
state, err := NewStateStore(statePath)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
return nil, fmt.Errorf("init state store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
agingRate := cfg.PriorityAgingRate
|
||||||
|
if agingRate == 0 {
|
||||||
|
agingRate = 0.1
|
||||||
|
}
|
||||||
|
|
||||||
|
hub := &SchedulerHub{
|
||||||
|
workers: make(map[string]*WorkerConn),
|
||||||
|
readyWorkers: make(map[string]*WorkerConn),
|
||||||
|
batchQueue: NewPriorityQueue(agingRate),
|
||||||
|
serviceQueue: NewPriorityQueue(agingRate),
|
||||||
|
reservations: make(map[string]*Reservation),
|
||||||
|
multiNodePending: make(map[string]*MultiNodeJob),
|
||||||
|
pendingAcceptance: make(map[string]*JobAssignment),
|
||||||
|
state: state,
|
||||||
|
starvation: &StarvationTracker{
|
||||||
|
threshold: time.Duration(cfg.StarvationThresholdMins) * time.Minute,
|
||||||
|
},
|
||||||
|
metrics: &SchedulerMetrics{
|
||||||
|
WorkerSlots: make(map[string]SlotStatus),
|
||||||
|
},
|
||||||
|
auditor: auditor,
|
||||||
|
tokenValidator: NewTokenValidator(cfg.WorkerTokens),
|
||||||
|
config: cfg,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
return hub, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start initializes the scheduler, starts the HTTP server, and replays state
|
||||||
|
func (h *SchedulerHub) Start() error {
|
||||||
|
// Replay state first
|
||||||
|
events, err := h.state.Replay()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("state replay failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ev := range events {
|
||||||
|
switch ev.Type {
|
||||||
|
case EventJobEnqueued:
|
||||||
|
h.restoreJob(ev)
|
||||||
|
case EventJobAssigned:
|
||||||
|
h.restoreAssignment(ev)
|
||||||
|
case EventJobCompleted, EventJobFailed, EventJobCancelled:
|
||||||
|
// terminal — skip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start WSS server (unified protocol)
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/ws/worker", h.HandleConnection)
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", h.config.BindAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to listen: %w", err)
|
||||||
|
}
|
||||||
|
h.listener = listener
|
||||||
|
|
||||||
|
h.server = &http.Server{Handler: mux}
|
||||||
|
|
||||||
|
// Auto-generate self-signed certs if requested
|
||||||
|
if h.config.AutoGenerateCerts && (h.config.CertFile == "" || h.config.KeyFile == "") {
|
||||||
|
certFile := h.config.StateDir + "/scheduler.crt"
|
||||||
|
keyFile := h.config.StateDir + "/scheduler.key"
|
||||||
|
if err := GenerateSelfSignedCert(certFile, keyFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to generate self-signed cert: %w", err)
|
||||||
|
}
|
||||||
|
h.config.CertFile = certFile
|
||||||
|
h.config.KeyFile = keyFile
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start with TLS if certificates are configured
|
||||||
|
if h.config.CertFile != "" && h.config.KeyFile != "" {
|
||||||
|
go h.server.ServeTLS(listener, h.config.CertFile, h.config.KeyFile)
|
||||||
|
} else {
|
||||||
|
go h.server.Serve(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start background tasks
|
||||||
|
go h.checkAcceptanceTimeouts()
|
||||||
|
go h.checkGangTimeouts()
|
||||||
|
go h.checkStarvation()
|
||||||
|
|
||||||
|
// Grace period: workers have 30s to reconnect before assigned jobs are orphaned
|
||||||
|
time.AfterFunc(30*time.Second, h.reconcileOrphans)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr returns the listening address of the scheduler
|
||||||
|
func (h *SchedulerHub) Addr() string {
|
||||||
|
if h.listener == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return h.listener.Addr().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the scheduler
|
||||||
|
func (h *SchedulerHub) Stop() {
|
||||||
|
h.cancel()
|
||||||
|
h.state.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleConnection handles WSS connections from workers and metrics clients
|
||||||
|
func (h *SchedulerHub) HandleConnection(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Validate token
|
||||||
|
token := ExtractBearerToken(r.Header.Get("Authorization"))
|
||||||
|
clientID, ok := h.tokenValidator.Validate(token)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade to WebSocket
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "upgrade failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is a metrics client (special token prefix)
|
||||||
|
if strings.HasPrefix(clientID, "metrics-") {
|
||||||
|
go h.runMetricsClient(clientID, conn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go h.runWorker(clientID, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runWorker owns the lifetime of one worker connection: it registers the
// worker in the hub, pumps outbound messages from wc.send, and dispatches
// inbound messages until the connection drops, then deregisters the worker.
func (h *SchedulerHub) runWorker(workerID string, conn *websocket.Conn) {
	wc := &WorkerConn{
		workerID:    workerID,
		conn:        conn,
		slots:       SlotStatus{},
		activeTasks: make(map[string]struct{}),
		send:        make(chan Message, 10), // small buffer decouples scheduling from socket writes
		hub:         h,
	}

	h.mu.Lock()
	h.workers[workerID] = wc
	h.metrics.WorkersConnected++
	h.mu.Unlock()

	// Deregister and close the socket whatever caused the read loop to stop.
	defer func() {
		h.mu.Lock()
		delete(h.workers, workerID)
		delete(h.readyWorkers, workerID)
		h.metrics.WorkersConnected--
		h.mu.Unlock()
		conn.Close()
	}()

	// Send loop.
	// NOTE(review): wc.send is never closed, so this goroutine blocks forever
	// after disconnect, and WriteJSON errors are dropped — confirm whether a
	// close-on-cleanup (with guarded sends) is intended here.
	go func() {
		for msg := range wc.send {
			conn.WriteJSON(msg)
		}
	}()

	// Receive loop: block on reads and dispatch until the peer disconnects.
	for {
		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			return // Connection closed
		}
		h.handleMessage(wc, msg)
	}
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) handleMessage(wc *WorkerConn, msg Message) {
|
||||||
|
switch msg.Type {
|
||||||
|
case MsgRegister:
|
||||||
|
var reg WorkerRegistration
|
||||||
|
json.Unmarshal(msg.Payload, ®)
|
||||||
|
h.reconcileWorker(reg, wc)
|
||||||
|
case MsgHeartbeat:
|
||||||
|
var hb HeartbeatPayload
|
||||||
|
json.Unmarshal(msg.Payload, &hb)
|
||||||
|
wc.mu.Lock()
|
||||||
|
wc.slots = hb.Slots
|
||||||
|
wc.mu.Unlock()
|
||||||
|
h.updateWorkerMetrics(wc.workerID, hb.Slots)
|
||||||
|
case MsgReadyForWork:
|
||||||
|
var ready ReadyPayload
|
||||||
|
json.Unmarshal(msg.Payload, &ready)
|
||||||
|
wc.mu.Lock()
|
||||||
|
wc.slots = ready.Slots
|
||||||
|
wc.mu.Unlock()
|
||||||
|
h.handleReady(wc, ready.Slots)
|
||||||
|
case MsgJobAccepted:
|
||||||
|
var taskID string
|
||||||
|
json.Unmarshal(msg.Payload, &taskID)
|
||||||
|
h.handleJobAccepted(wc.workerID, taskID)
|
||||||
|
case MsgJobResult:
|
||||||
|
var result JobResultPayload
|
||||||
|
json.Unmarshal(msg.Payload, &result)
|
||||||
|
h.handleJobResult(wc.workerID, result)
|
||||||
|
case MsgServiceHealth:
|
||||||
|
// Service health updates - logged but no action needed for MVP
|
||||||
|
var health ServiceHealthPayload
|
||||||
|
json.Unmarshal(msg.Payload, &health)
|
||||||
|
slog.Debug("service health update", "worker", wc.workerID, "task", health.TaskID, "healthy", health.Healthy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconcileWorker merges a (re)connecting worker's self-reported task list
// with the scheduler's view so that exactly one authoritative owner exists per
// task, then acknowledges the registration.
// NOTE(review): the sends to wc.send can block once its 10-slot buffer fills —
// confirm that is acceptable during a large reconciliation.
func (h *SchedulerHub) reconcileWorker(reg WorkerRegistration, wc *WorkerConn) {
	wc.capabilities = reg.Capabilities

	for _, reported := range reg.ActiveTasks {
		task := h.getTask(reported.TaskID)

		switch {
		case task == nil:
			// Case 1: Scheduler has no record — kill it
			wc.send <- Message{Type: MsgJobCancel,
				Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}

		case task.Status == "orphaned":
			// Case 2: Scheduler thought lost — restore, worker is running it
			h.restoreLease(reported.TaskID, wc.workerID)
			task.Status = "running"

		case task.Status == "queued" || task.Status == "assigned":
			// Case 3: Re-queued while worker was disconnected — cancel on this worker
			wc.send <- Message{Type: MsgJobCancel,
				Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}

		case task.Status == "running" && task.WorkerID != reg.ID:
			// Case 4: Running on two workers — cancel on reconnecting worker
			wc.send <- Message{Type: MsgJobCancel,
				Payload: mustMarshal(map[string]string{"task_id": reported.TaskID})}
		}
	}

	// Send registration acknowledgment
	wc.send <- Message{Type: MsgAck}
}
|
||||||
|
|
||||||
|
// handleReady reacts to a worker announcing free capacity: try to place a
// multi-node (gang) job first, then a single-node task; if nothing fits, park
// the worker in readyWorkers and reply that there is no work.
func (h *SchedulerHub) handleReady(wc *WorkerConn, _ SlotStatus) {
	// Multi-node jobs get first refusal so gangs can assemble.
	// NOTE(review): batchQueue.Items() is read without holding h.mu — confirm
	// PriorityQueue performs its own locking.
	for _, task := range h.batchQueue.Items() {
		if task.Spec.NodeCount > 1 && h.canAdmit(task, wc) {
			if h.handleMultiNodeReady(task, wc) {
				return // Multi-node job is being handled
			}
		}
	}

	// Fall through to regular single-node matching
	if task := h.findMatch(wc); task != nil {
		wc.send <- h.assignTask(task, wc)
		return
	}

	// Nothing fits: let the starvation tracker place reservations, remember
	// this worker as idle, and acknowledge with "no work".
	h.starvation.CheckAndReserve(h)
	h.mu.Lock()
	h.readyWorkers[wc.workerID] = wc
	h.mu.Unlock()
	wc.send <- Message{Type: MsgNoWork}
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) findMatch(wc *WorkerConn) *Task {
|
||||||
|
if wc.slots.ServiceAvailable() > 0 {
|
||||||
|
if task := h.scanFit(h.serviceQueue, wc); task != nil {
|
||||||
|
return task
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if wc.slots.BatchAvailable() > 0 {
|
||||||
|
if task := h.scanFit(h.batchQueue, wc); task != nil {
|
||||||
|
return task
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) scanFit(q *PriorityQueue, wc *WorkerConn) *Task {
|
||||||
|
for _, task := range q.Items() {
|
||||||
|
if task.Spec.NodeCount > 1 {
|
||||||
|
continue // gang allocator handles multi-node
|
||||||
|
}
|
||||||
|
if h.canAdmit(task, wc) {
|
||||||
|
return task
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) canAdmit(candidate *Task, worker *WorkerConn) bool {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, res := range h.reservations {
|
||||||
|
if candidate.Spec.GPUCount > 0 && res.GPUCount > 0 {
|
||||||
|
if worker.capabilities.GPUCount < res.GPUCount+candidate.Spec.GPUCount {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return worker.capabilities.GPUCount >= candidate.Spec.GPUCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// assignTask removes task from both queues, records a pending acceptance with
// a deadline, persists an EventJobAssigned, and returns the MsgJobAssign
// message the caller should send to the worker.
//
// NOTE(review): task.Status/task.WorkerID are not updated here — presumably
// done on acceptance; confirm against the accept path.
func (h *SchedulerHub) assignTask(task *Task, wc *WorkerConn) Message {
	// Remove from queue first (prevent double-assignment)
	h.batchQueue.Remove(task.ID)
	h.serviceQueue.Remove(task.ID)

	// Track pending acceptance; checkAcceptanceTimeouts re-queues the job if
	// the worker does not accept before AcceptanceDeadline.
	h.mu.Lock()
	h.pendingAcceptance[task.ID] = &JobAssignment{
		TaskID:             task.ID,
		WorkerID:           wc.workerID,
		AssignedAt:         time.Now(),
		AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
		Accepted:           false,
	}
	h.mu.Unlock()

	// Persist assignment
	h.state.Append(StateEvent{
		Type:     EventJobAssigned,
		TaskID:   task.ID,
		WorkerID: wc.workerID,
	})

	return Message{
		Type:    MsgJobAssign,
		Payload: mustMarshal(task.Spec),
	}
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) handleJobAccepted(_, taskID string) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
if assignment, ok := h.pendingAcceptance[taskID]; ok {
|
||||||
|
assignment.Accepted = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) handleJobResult(workerID string, result JobResultPayload) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
delete(h.pendingAcceptance, result.TaskID)
|
||||||
|
|
||||||
|
eventType := EventJobCompleted
|
||||||
|
switch result.State {
|
||||||
|
case "failed":
|
||||||
|
eventType = EventJobFailed
|
||||||
|
h.metrics.JobsFailed++
|
||||||
|
case "cancelled":
|
||||||
|
eventType = EventJobCancelled
|
||||||
|
h.metrics.JobsCancelled++
|
||||||
|
default:
|
||||||
|
h.metrics.JobsCompleted++
|
||||||
|
}
|
||||||
|
|
||||||
|
h.state.Append(StateEvent{
|
||||||
|
Type: eventType,
|
||||||
|
TaskID: result.TaskID,
|
||||||
|
WorkerID: workerID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkAcceptanceTimeouts re-queues jobs that weren't accepted
|
||||||
|
func (h *SchedulerHub) checkAcceptanceTimeouts() {
|
||||||
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
h.mu.Lock()
|
||||||
|
for taskID, a := range h.pendingAcceptance {
|
||||||
|
if !a.Accepted && time.Now().After(a.AcceptanceDeadline) {
|
||||||
|
h.batchQueue.Add(h.getTask(taskID))
|
||||||
|
delete(h.pendingAcceptance, taskID)
|
||||||
|
if wc, ok := h.workers[a.WorkerID]; ok {
|
||||||
|
wc.slots = SlotStatus{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkGangTimeouts releases reserved slots for incomplete gangs
// Runs every 10s until the hub context is cancelled. A multi-node job that
// has not gathered all its workers within GangAllocTimeoutSecs is abandoned:
// each committed worker's slots are cleared and it returns to the ready pool.
func (h *SchedulerHub) checkGangTimeouts() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-h.ctx.Done():
			return
		case <-ticker.C:
			h.mu.Lock()
			for jobID, pending := range h.multiNodePending {
				if time.Since(pending.CommittedAt) > time.Duration(h.config.GangAllocTimeoutSecs)*time.Second {
					for _, a := range pending.Assignments {
						// NOTE(review): zeroing the whole SlotStatus wipes
						// Batch/ServiceTotal, not just in-use counters —
						// confirm totals come back via heartbeat.
						a.Worker.slots = SlotStatus{}
						h.readyWorkers[a.Worker.workerID] = a.Worker
					}
					delete(h.multiNodePending, jobID)
				}
			}
			h.mu.Unlock()
		}
	}
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) checkStarvation() {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
h.starvation.CheckAndReserve(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (st *StarvationTracker) CheckAndReserve(h *SchedulerHub) {
|
||||||
|
st.mu.Lock()
|
||||||
|
defer st.mu.Unlock()
|
||||||
|
|
||||||
|
for _, task := range h.batchQueue.Items() {
|
||||||
|
if time.Since(task.SubmittedAt) > st.threshold && !st.hasReservation(h, task.ID) {
|
||||||
|
h.mu.Lock()
|
||||||
|
h.reservations[task.ID] = &Reservation{
|
||||||
|
TaskID: task.ID,
|
||||||
|
GPUCount: task.Spec.GPUCount,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (st *StarvationTracker) hasReservation(h *SchedulerHub, taskID string) bool {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
_, exists := h.reservations[taskID]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helper and public API methods ---
|
||||||
|
|
||||||
|
// GetTask returns a task by ID (public API). It returns nil when the task is
// not present in either queue (e.g. already assigned or finished).
func (h *SchedulerHub) GetTask(taskID string) *Task {
	return h.getTask(taskID)
}
|
||||||
|
|
||||||
|
// SubmitJob submits a new job to the scheduler (public API)
|
||||||
|
func (h *SchedulerHub) SubmitJob(spec JobSpec) error {
|
||||||
|
if spec.ID == "" {
|
||||||
|
return fmt.Errorf("job ID is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
task := &Task{
|
||||||
|
ID: spec.ID,
|
||||||
|
Spec: spec,
|
||||||
|
Status: "queued",
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persist to state store
|
||||||
|
h.state.Append(StateEvent{
|
||||||
|
Type: EventJobEnqueued,
|
||||||
|
TaskID: spec.ID,
|
||||||
|
Payload: mustMarshal(spec),
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add to appropriate queue
|
||||||
|
if spec.Type == JobTypeService {
|
||||||
|
h.serviceQueue.Add(task)
|
||||||
|
} else {
|
||||||
|
h.batchQueue.Add(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send prewarm hint if job has snapshot
|
||||||
|
h.sendPrewarmHint(task)
|
||||||
|
|
||||||
|
slog.Info("job submitted", "task_id", spec.ID, "type", spec.Type)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) getTask(taskID string) *Task {
|
||||||
|
t := h.batchQueue.Get(taskID)
|
||||||
|
if t != nil {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
return h.serviceQueue.Get(taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) restoreJob(ev StateEvent) {
|
||||||
|
// Parse job spec from event payload
|
||||||
|
var spec JobSpec
|
||||||
|
if err := json.Unmarshal(ev.Payload, &spec); err != nil {
|
||||||
|
slog.Error("failed to restore job", "task_id", ev.TaskID, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
task := &Task{
|
||||||
|
ID: ev.TaskID,
|
||||||
|
Spec: spec,
|
||||||
|
Status: "queued",
|
||||||
|
SubmittedAt: ev.Timestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to appropriate queue
|
||||||
|
if spec.Type == JobTypeService {
|
||||||
|
h.serviceQueue.Add(task)
|
||||||
|
} else {
|
||||||
|
h.batchQueue.Add(task)
|
||||||
|
}
|
||||||
|
slog.Info("restored job from state", "task_id", ev.TaskID, "type", spec.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) restoreAssignment(ev StateEvent) {
|
||||||
|
// Parse assignment from event payload
|
||||||
|
var payload struct {
|
||||||
|
WorkerID string `json:"worker_id"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(ev.Payload, &payload); err != nil {
|
||||||
|
slog.Error("failed to restore assignment", "task_id", ev.TaskID, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore pending acceptance state
|
||||||
|
h.pendingAcceptance[ev.TaskID] = &JobAssignment{
|
||||||
|
TaskID: ev.TaskID,
|
||||||
|
WorkerID: payload.WorkerID,
|
||||||
|
AssignedAt: ev.Timestamp,
|
||||||
|
AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
|
||||||
|
Accepted: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("restored assignment from state", "task_id", ev.TaskID, "worker_id", payload.WorkerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) restoreLease(taskID, workerID string) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
if assignment, ok := h.pendingAcceptance[taskID]; ok {
|
||||||
|
assignment.Accepted = true
|
||||||
|
slog.Info("restored lease", "task_id", taskID, "worker_id", workerID)
|
||||||
|
} else {
|
||||||
|
// Create a new lease record
|
||||||
|
h.pendingAcceptance[taskID] = &JobAssignment{
|
||||||
|
TaskID: taskID,
|
||||||
|
WorkerID: workerID,
|
||||||
|
AssignedAt: time.Now(),
|
||||||
|
AcceptanceDeadline: time.Now().Add(time.Duration(h.config.AcceptanceTimeoutSecs) * time.Second),
|
||||||
|
Accepted: true,
|
||||||
|
}
|
||||||
|
slog.Info("created new lease on reconnect", "task_id", taskID, "worker_id", workerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SchedulerHub) reconcileOrphans() {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
// After grace period (30s), any job still assigned to a disconnected worker
|
||||||
|
// is considered orphaned and should be re-queued
|
||||||
|
for taskID, assignment := range h.pendingAcceptance {
|
||||||
|
if assignment.Accepted {
|
||||||
|
// Job was accepted but worker is gone (not in h.workers)
|
||||||
|
if _, stillConnected := h.workers[assignment.WorkerID]; !stillConnected {
|
||||||
|
task := h.getTask(taskID)
|
||||||
|
if task != nil {
|
||||||
|
task.Status = "orphaned"
|
||||||
|
h.batchQueue.Add(task)
|
||||||
|
h.state.Append(StateEvent{
|
||||||
|
Type: EventJobRequeued,
|
||||||
|
TaskID: taskID,
|
||||||
|
WorkerID: assignment.WorkerID,
|
||||||
|
})
|
||||||
|
slog.Info("orphaned job re-queued", "task_id", taskID, "worker_id", assignment.WorkerID)
|
||||||
|
}
|
||||||
|
delete(h.pendingAcceptance, taskID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateWorkerMetrics records the latest slot snapshot for a worker; the
// metrics endpoints read WorkerSlots to report per-worker utilization.
func (h *SchedulerHub) updateWorkerMetrics(workerID string, slots SlotStatus) {
	h.metrics.mu.Lock()
	defer h.metrics.mu.Unlock()
	h.metrics.WorkerSlots[workerID] = slots
}
|
||||||
|
|
||||||
|
// sendPrewarmHint sends a prewarm hint to an idle worker when a job with snapshot is enqueued
|
||||||
|
// TODO: Call this when jobs are enqueued via scheduler API
|
||||||
|
func (h *SchedulerHub) sendPrewarmHint(task *Task) {
|
||||||
|
if task.Spec.SnapshotID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, wc := range h.readyWorkers {
|
||||||
|
if h.canAdmit(task, wc) {
|
||||||
|
wc.send <- Message{
|
||||||
|
Type: MsgPrewarmHint,
|
||||||
|
Payload: mustMarshal(PrewarmHintPayload{
|
||||||
|
TaskID: task.ID,
|
||||||
|
SnapshotID: task.Spec.SnapshotID,
|
||||||
|
SnapshotSHA: task.Spec.SnapshotSHA,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
return // one worker prewarms — not all of them
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runMetricsClient handles metrics clients over WSS
|
||||||
|
func (h *SchedulerHub) runMetricsClient(clientID string, conn *websocket.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var msg Message
|
||||||
|
if err := conn.ReadJSON(&msg); err != nil {
|
||||||
|
return // Connection closed
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Type == MsgMetricsRequest {
|
||||||
|
metrics := h.getMetricsPayload()
|
||||||
|
conn.WriteJSON(Message{
|
||||||
|
Type: MsgMetricsResponse,
|
||||||
|
Payload: mustMarshal(metrics),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMetricsPayload returns current metrics as a map suitable for JSON
// encoding over the WSS metrics channel. Counters are read under the metrics
// read-lock; queue depths are taken live from the queues' own locks.
func (h *SchedulerHub) getMetricsPayload() map[string]any {
	h.metrics.mu.RLock()
	defer h.metrics.mu.RUnlock()

	return map[string]any{
		"workers_connected":   h.metrics.WorkersConnected,
		"queue_depth_batch":   h.batchQueue.Len(),
		"queue_depth_service": h.serviceQueue.Len(),
		"jobs_completed":      h.metrics.JobsCompleted,
		"jobs_failed":         h.metrics.JobsFailed,
		"jobs_cancelled":      h.metrics.JobsCancelled,
		"worker_slots":        h.metrics.WorkerSlots,
	}
}
|
||||||
|
|
||||||
|
// ServeMetrics serves Prometheus-formatted metrics (deprecated, use WSS).
// Emits gauges for connected workers, queue depth per pool, job-result
// counters, and per-worker slot utilization. Map iteration order is random,
// so worker lines appear in nondeterministic order (harmless for Prometheus).
func (h *SchedulerHub) ServeMetrics(w http.ResponseWriter, r *http.Request) {
	h.metrics.mu.RLock()
	defer h.metrics.mu.RUnlock()

	w.Header().Set("Content-Type", "text/plain; version=0.0.4")

	fmt.Fprintf(w, "# HELP fetch_ml_workers_connected Number of connected workers\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_workers_connected gauge\n")
	fmt.Fprintf(w, "fetch_ml_workers_connected %d\n\n", h.metrics.WorkersConnected)

	fmt.Fprintf(w, "# HELP fetch_ml_queue_depth Current queue depth\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_queue_depth gauge\n")
	fmt.Fprintf(w, "fetch_ml_queue_depth{pool=\"batch\"} %d\n", h.batchQueue.Len())
	fmt.Fprintf(w, "fetch_ml_queue_depth{pool=\"service\"} %d\n\n", h.serviceQueue.Len())

	fmt.Fprintf(w, "# HELP fetch_ml_jobs_total Total jobs by result\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_jobs_total counter\n")
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"completed\"} %d\n", h.metrics.JobsCompleted)
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"failed\"} %d\n", h.metrics.JobsFailed)
	fmt.Fprintf(w, "fetch_ml_jobs_total{result=\"cancelled\"} %d\n\n", h.metrics.JobsCancelled)

	fmt.Fprintf(w, "# HELP fetch_ml_slot_utilization Slot utilization by worker\n")
	fmt.Fprintf(w, "# TYPE fetch_ml_slot_utilization gauge\n")
	for workerID, slots := range h.metrics.WorkerSlots {
		// Skip zero-total pools to avoid division by zero.
		if slots.BatchTotal > 0 {
			utilization := float64(slots.BatchInUse) / float64(slots.BatchTotal)
			fmt.Fprintf(w, "fetch_ml_slot_utilization{worker=\"%s\",pool=\"batch\"} %.2f\n", workerID, utilization)
		}
		if slots.ServiceTotal > 0 {
			utilization := float64(slots.ServiceInUse) / float64(slots.ServiceTotal)
			fmt.Fprintf(w, "fetch_ml_slot_utilization{worker=\"%s\",pool=\"service\"} %.2f\n", workerID, utilization)
		}
	}
}
|
||||||
|
|
||||||
|
// tryGangAlloc attempts to allocate a multi-node job to a worker
|
||||||
|
// It tracks partial allocations and dispatches when all nodes are committed
|
||||||
|
func (h *SchedulerHub) tryGangAlloc(task *Task, wc *WorkerConn) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if this worker can run the job
|
||||||
|
if !h.canAdmit(task, wc) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jobID := task.ID
|
||||||
|
pending, ok := h.multiNodePending[jobID]
|
||||||
|
if !ok {
|
||||||
|
// First worker for this job
|
||||||
|
pending = &MultiNodeJob{
|
||||||
|
JobID: jobID,
|
||||||
|
TotalNodes: task.Spec.NodeCount,
|
||||||
|
Assignments: make([]*NodeAssignment, 0, task.Spec.NodeCount),
|
||||||
|
CommittedAt: time.Now(),
|
||||||
|
}
|
||||||
|
h.multiNodePending[jobID] = pending
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add this worker to the pending assignment
|
||||||
|
assignment := &NodeAssignment{
|
||||||
|
Worker: wc,
|
||||||
|
Rank: len(pending.Assignments),
|
||||||
|
CommittedAt: time.Now(),
|
||||||
|
}
|
||||||
|
pending.Assignments = append(pending.Assignments, assignment)
|
||||||
|
|
||||||
|
// Reserve slots on this worker
|
||||||
|
wc.slots.BatchInUse++
|
||||||
|
delete(h.readyWorkers, wc.workerID)
|
||||||
|
|
||||||
|
// Check if we have all nodes
|
||||||
|
if len(pending.Assignments) < task.Spec.NodeCount {
|
||||||
|
// Still waiting for more workers
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// All nodes committed - dispatch simultaneously
|
||||||
|
headAddr := pending.Assignments[0].Worker.capabilities.Hostname
|
||||||
|
for i, a := range pending.Assignments {
|
||||||
|
// Create rank-specific job spec
|
||||||
|
spec := h.buildRankedSpec(task, i, headAddr, task.Spec.NodeCount)
|
||||||
|
msg := Message{
|
||||||
|
Type: MsgJobAssign,
|
||||||
|
Payload: mustMarshal(spec),
|
||||||
|
}
|
||||||
|
a.Worker.send <- msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up pending state
|
||||||
|
delete(h.multiNodePending, jobID)
|
||||||
|
slog.Info("multi-node job dispatched",
|
||||||
|
"job_id", jobID,
|
||||||
|
"nodes", task.Spec.NodeCount,
|
||||||
|
"head_addr", headAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildRankedSpec creates a job spec with rank-specific template variables resolved
|
||||||
|
func (h *SchedulerHub) buildRankedSpec(task *Task, rank int, headAddr string, worldSize int) JobSpec {
|
||||||
|
// Clone the spec and add rank info to metadata
|
||||||
|
spec := task.Spec
|
||||||
|
spec.Metadata = make(map[string]string, len(task.Spec.Metadata)+3)
|
||||||
|
for k, v := range task.Spec.Metadata {
|
||||||
|
spec.Metadata[k] = v
|
||||||
|
}
|
||||||
|
spec.Metadata["HEAD_ADDR"] = headAddr
|
||||||
|
spec.Metadata["WORLD_SIZE"] = fmt.Sprintf("%d", worldSize)
|
||||||
|
spec.Metadata["NODE_RANK"] = fmt.Sprintf("%d", rank)
|
||||||
|
return spec
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMultiNodeReady handles a ready signal for a multi-node job
|
||||||
|
// Returns true if the job was handled (either assigned or queued for gang alloc)
|
||||||
|
func (h *SchedulerHub) handleMultiNodeReady(task *Task, wc *WorkerConn) bool {
|
||||||
|
if task.Spec.NodeCount <= 1 {
|
||||||
|
return false // Not a multi-node job
|
||||||
|
}
|
||||||
|
|
||||||
|
h.tryGangAlloc(task, wc)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// mustMarshal serializes v to JSON, panicking on failure per the Must
// convention. All payloads passed in are plain maps/structs of marshalable
// fields, so an error is a programmer bug; the original `b, _ :=` silently
// put a nil payload on the wire instead of surfacing it.
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic(fmt.Sprintf("mustMarshal: %v", err))
	}
	return b
}
|
||||||
28
internal/scheduler/pacing.go
Normal file
28
internal/scheduler/pacing.go
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
// AdaptivePacingController derives request pacing based on worker capacity.
type AdaptivePacingController struct {
	// DesiredRPSPerWorker is the per-worker request-rate target; the
	// constructor clamps it to at least 1.
	DesiredRPSPerWorker int
}

// NewAdaptivePacingController constructs a controller with sane defaults.
// Non-positive rates are clamped to 1 so pacing never stalls completely.
func NewAdaptivePacingController(desired int) AdaptivePacingController {
	if desired < 1 {
		desired = 1
	}
	return AdaptivePacingController{DesiredRPSPerWorker: desired}
}

// RequestsPerSec returns max(1, maxWorkers * desiredRPSPerWorker).
func (a AdaptivePacingController) RequestsPerSec(maxWorkers int) int {
	workers := maxWorkers
	if workers < 1 {
		workers = 1
	}
	if rps := workers * a.DesiredRPSPerWorker; rps >= 1 {
		return rps
	}
	return 1
}
|
||||||
145
internal/scheduler/port_allocator.go
Normal file
145
internal/scheduler/port_allocator.go
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default port range for service jobs (Jupyter, vLLM, etc.)
// NOTE(review): NewPortAllocator's fallback range is 10000-65535, not these
// constants — confirm which default callers are expected to pass in.
const (
	DefaultServicePortStart = 8000
	DefaultServicePortEnd   = 9000
)
|
||||||
|
|
||||||
|
// PortAllocator manages dynamic port allocation for service jobs
// It tracks which ports are in use and assigns available ports from a configured range
// This is thread-safe for concurrent allocations across multiple workers

type PortAllocator struct {
	mu    sync.Mutex
	start int // inclusive lower bound of the allocatable range
	end   int // inclusive upper bound of the allocatable range
	used  map[int]allocation // port -> current (or recently released) holder
	ttl   time.Duration // How long to keep port reserved after release
}

// allocation records which task holds a port and when the hold began; the
// timestamp drives TTL-based expiry in isExpired/cleanupExpired.
type allocation struct {
	taskID    string
	allocated time.Time
}
|
||||||
|
|
||||||
|
// NewPortAllocator creates a new port allocator for the given range
|
||||||
|
// Default port range is 10000-65535 to avoid well-known ports
|
||||||
|
func NewPortAllocator(start, end int) *PortAllocator {
|
||||||
|
if start <= 0 {
|
||||||
|
start = 10000
|
||||||
|
}
|
||||||
|
if end <= 0 || end > 65535 {
|
||||||
|
end = 65535
|
||||||
|
}
|
||||||
|
if start >= end {
|
||||||
|
start = 10000
|
||||||
|
end = 65535
|
||||||
|
}
|
||||||
|
return &PortAllocator{
|
||||||
|
start: start,
|
||||||
|
end: end,
|
||||||
|
used: make(map[int]allocation),
|
||||||
|
ttl: 30 * time.Second, // Prevent immediate reuse
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate assigns an available port to a task
// Returns error if no ports available or port is already in use
// Scans the range linearly, skipping ports the OS reports as busy. The probe
// briefly binds and releases the port, so a small TOCTOU window remains
// between this call and the task actually listening.
func (pa *PortAllocator) Allocate(taskID string) (int, error) {
	pa.mu.Lock()
	defer pa.mu.Unlock()

	// Clean up expired allocations
	pa.cleanupExpired()

	// Try to find an available port
	for port := pa.start; port <= pa.end; port++ {
		if _, inUse := pa.used[port]; !inUse {
			// Verify port is actually available on the system
			if !pa.isPortAvailable(port) {
				continue
			}
			pa.used[port] = allocation{
				taskID:    taskID,
				allocated: time.Now(),
			}
			return port, nil
		}
	}

	return 0, fmt.Errorf("no ports available in range %d-%d", pa.start, pa.end)
}
|
||||||
|
|
||||||
|
// Release frees a port for reuse (after TTL expires)
|
||||||
|
func (pa *PortAllocator) Release(port int) {
|
||||||
|
pa.mu.Lock()
|
||||||
|
defer pa.mu.Unlock()
|
||||||
|
|
||||||
|
if alloc, exists := pa.used[port]; exists {
|
||||||
|
// Don't delete immediately - mark with release time
|
||||||
|
// so it can't be immediately reallocated
|
||||||
|
pa.used[port] = allocation{
|
||||||
|
taskID: alloc.taskID + ":released",
|
||||||
|
allocated: time.Now().Add(-pa.ttl), // Expired
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllocation returns the task ID for a given port, or empty if not allocated
|
||||||
|
func (pa *PortAllocator) GetAllocation(port int) string {
|
||||||
|
pa.mu.Lock()
|
||||||
|
defer pa.mu.Unlock()
|
||||||
|
|
||||||
|
if alloc, exists := pa.used[port]; exists && !pa.isExpired(alloc) {
|
||||||
|
return alloc.taskID
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupExpired removes expired allocations
|
||||||
|
func (pa *PortAllocator) cleanupExpired() {
|
||||||
|
for port, alloc := range pa.used {
|
||||||
|
if pa.isExpired(alloc) {
|
||||||
|
delete(pa.used, port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isExpired checks if an allocation has expired
|
||||||
|
func (pa *PortAllocator) isExpired(alloc allocation) bool {
|
||||||
|
return time.Since(alloc.allocated) > pa.ttl
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPortAvailable checks if a port is actually available on the system
|
||||||
|
func (pa *PortAllocator) isPortAvailable(port int) bool {
|
||||||
|
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ln.Close()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// AvailableCount returns the number of available ports
|
||||||
|
func (pa *PortAllocator) AvailableCount() int {
|
||||||
|
pa.mu.Lock()
|
||||||
|
defer pa.mu.Unlock()
|
||||||
|
|
||||||
|
pa.cleanupExpired()
|
||||||
|
return (pa.end - pa.start + 1) - len(pa.used)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTTL changes the time-to-live for released ports (for testing).
// Existing allocations are re-evaluated against the new TTL on the next
// cleanup pass, since isExpired reads pa.ttl at check time.
func (pa *PortAllocator) SetTTL(ttl time.Duration) {
	pa.mu.Lock()
	defer pa.mu.Unlock()
	pa.ttl = ttl
}
|
||||||
175
internal/scheduler/priority_queue.go
Normal file
175
internal/scheduler/priority_queue.go
Normal file
|
|
@ -0,0 +1,175 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
	"container/heap"
	"sort"
	"sync"
	"time"
)
|
||||||
|
|
||||||
|
// Task represents a job in the priority queue
type Task struct {
	ID          string    // unique task identifier (mirrors JobSpec.ID)
	Priority    int       // base priority; higher schedules first
	SubmittedAt time.Time // enqueue time; drives aging and FIFO tie-breaks
	Spec        JobSpec
	Status      string // lifecycle state, e.g. "queued"/"assigned"/"running" — TODO confirm full set
	WorkerID    string // worker the task is (or was) assigned to, if any
	Metadata    map[string]string // Additional task metadata (snapshot SHA, etc.)
	index       int    // for heap interface
}
|
||||||
|
|
||||||
|
// EffectivePriority returns the priority with aging applied
|
||||||
|
func (t *Task) EffectivePriority(agingRate float64, now time.Time) float64 {
|
||||||
|
age := now.Sub(t.SubmittedAt).Minutes()
|
||||||
|
return float64(t.Priority) + age*agingRate
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskHeap is the internal heap implementation.
// It satisfies container/heap.Interface; the public PriorityQueue wraps it
// with locking and an ID index. Tasks cache their heap position in Task.index
// so heap.Remove works in O(log n).
type taskHeap struct {
	items     []*Task
	agingRate float64 // fed to EffectivePriority in Less
}

func (h taskHeap) Len() int { return len(h.items) }

// Less orders by effective (aged) priority, higher first. The relative order
// of two tasks is independent of when Less is called: advancing `now` shifts
// both effective priorities by amounts whose difference is fixed.
func (h taskHeap) Less(i, j int) bool {
	// Higher priority first, then older first on ties
	now := time.Now()
	pi := h.items[i].EffectivePriority(h.agingRate, now)
	pj := h.items[j].EffectivePriority(h.agingRate, now)
	if pi != pj {
		return pi > pj
	}
	return h.items[i].SubmittedAt.Before(h.items[j].SubmittedAt)
}

func (h taskHeap) Swap(i, j int) {
	h.items[i], h.items[j] = h.items[j], h.items[i]
	// Keep each task's cached heap position current for heap.Remove/Fix.
	h.items[i].index = i
	h.items[j].index = j
}

// Push appends x to the heap; called only via container/heap, never directly.
func (h *taskHeap) Push(x any) {
	n := len(h.items)
	task := x.(*Task)
	task.index = n
	h.items = append(h.items, task)
}

// Pop removes and returns the last element; called only via container/heap.
func (h *taskHeap) Pop() any {
	old := h.items
	n := len(old)
	task := old[n-1]
	old[n-1] = nil // avoid memory leak
	task.index = -1
	h.items = old[:n-1]
	return task
}
|
||||||
|
|
||||||
|
// PriorityQueue implements a thread-safe priority queue for tasks
type PriorityQueue struct {
	heap      *taskHeap        // heap-ordered storage
	mu        sync.RWMutex     // guards heap and byID together
	byID      map[string]*Task // O(1) lookup/dedup by task ID
	agingRate float64          // mirrors heap.agingRate
}
|
||||||
|
|
||||||
|
// NewPriorityQueue creates a new priority queue
|
||||||
|
func NewPriorityQueue(agingRate float64) *PriorityQueue {
|
||||||
|
if agingRate == 0 {
|
||||||
|
agingRate = 0.1 // default: 0.1 per minute
|
||||||
|
}
|
||||||
|
return &PriorityQueue{
|
||||||
|
heap: &taskHeap{
|
||||||
|
items: make([]*Task, 0),
|
||||||
|
agingRate: agingRate,
|
||||||
|
},
|
||||||
|
byID: make(map[string]*Task),
|
||||||
|
agingRate: agingRate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of items in the queue.
func (pq *PriorityQueue) Len() int {
	pq.mu.RLock()
	defer pq.mu.RUnlock()
	return len(pq.heap.items)
}
|
||||||
|
|
||||||
|
// Add adds a task to the queue
|
||||||
|
func (pq *PriorityQueue) Add(task *Task) {
|
||||||
|
pq.mu.Lock()
|
||||||
|
defer pq.mu.Unlock()
|
||||||
|
|
||||||
|
if _, exists := pq.byID[task.ID]; exists {
|
||||||
|
return // already in queue
|
||||||
|
}
|
||||||
|
|
||||||
|
pq.byID[task.ID] = task
|
||||||
|
heap.Push(pq.heap, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take removes and returns the highest priority task
|
||||||
|
func (pq *PriorityQueue) Take() *Task {
|
||||||
|
pq.mu.Lock()
|
||||||
|
defer pq.mu.Unlock()
|
||||||
|
|
||||||
|
if len(pq.heap.items) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
task := heap.Pop(pq.heap).(*Task)
|
||||||
|
delete(pq.byID, task.ID)
|
||||||
|
return task
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peek returns the highest priority task without removing it
|
||||||
|
func (pq *PriorityQueue) Peek() *Task {
|
||||||
|
pq.mu.RLock()
|
||||||
|
defer pq.mu.RUnlock()
|
||||||
|
|
||||||
|
if len(pq.heap.items) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return pq.heap.items[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Items returns a copy of all items in priority order
|
||||||
|
func (pq *PriorityQueue) Items() []*Task {
|
||||||
|
pq.mu.RLock()
|
||||||
|
defer pq.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]*Task, len(pq.heap.items))
|
||||||
|
copy(result, pq.heap.items)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a task by ID, or nil when absent. The returned pointer aliases
// the queued task; callers must not mutate fields the heap ordering depends on.
func (pq *PriorityQueue) Get(taskID string) *Task {
	pq.mu.RLock()
	defer pq.mu.RUnlock()
	return pq.byID[taskID]
}
|
||||||
|
|
||||||
|
// Remove removes a task from the queue
|
||||||
|
func (pq *PriorityQueue) Remove(taskID string) bool {
|
||||||
|
pq.mu.Lock()
|
||||||
|
defer pq.mu.Unlock()
|
||||||
|
|
||||||
|
task, exists := pq.byID[taskID]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
heap.Remove(pq.heap, task.index)
|
||||||
|
delete(pq.byID, task.ID)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains checks if a task is in the queue
|
||||||
|
func (pq *PriorityQueue) Contains(taskID string) bool {
|
||||||
|
pq.mu.RLock()
|
||||||
|
defer pq.mu.RUnlock()
|
||||||
|
_, exists := pq.byID[taskID]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
137
internal/scheduler/protocol.go
Normal file
137
internal/scheduler/protocol.go
Normal file
|
|
@ -0,0 +1,137 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message is the envelope for every frame exchanged between the scheduler and
// workers (and metrics clients) over the websocket link.
type Message struct {
	Type    MessageType     `json:"type"`
	Payload json.RawMessage `json:"payload,omitempty"` // type-specific body, decoded lazily
	Error   string          `json:"error,omitempty"`   // set when the message reports a failure
}
|
||||||
|
|
||||||
|
// MessageType discriminates the wire messages below; see Message.Type.
type MessageType string

const (
	// Worker → Scheduler
	MsgRegister       MessageType = "register"
	MsgHeartbeat      MessageType = "heartbeat" // slots only, every 10s
	MsgReadyForWork   MessageType = "ready_for_work"
	MsgJobAccepted    MessageType = "job_accepted"
	MsgJobResult      MessageType = "job_result"
	MsgServiceHealth  MessageType = "service_health"
	MsgMetricsRequest MessageType = "metrics_request" // WSS metrics request

	// Scheduler → Worker
	MsgJobAssign       MessageType = "job_assign"
	MsgNoWork          MessageType = "no_work" // nothing available right now
	MsgJobCancel       MessageType = "job_cancel"
	MsgPrewarmHint     MessageType = "prewarm_hint"
	MsgAck             MessageType = "ack"
	MsgMetricsResponse MessageType = "metrics_response" // WSS metrics response
)
|
||||||
|
|
||||||
|
// Heartbeat — liveness and slot status combined, no CPU/mem load.
// HeartbeatPayload accompanies MsgHeartbeat; it identifies the worker and
// reports current slot usage so the scheduler can track capacity.
type HeartbeatPayload struct {
	WorkerID string     `json:"worker_id"`
	Slots    SlotStatus `json:"slots"`
}

// ReadyPayload accompanies MsgReadyForWork: the worker announces free
// capacity. Reason is free-form text describing why the worker is asking.
type ReadyPayload struct {
	WorkerID string     `json:"worker_id"`
	Slots    SlotStatus `json:"slots"`
	Reason   string     `json:"reason"`
}

// JobResultPayload accompanies MsgJobResult and reports the terminal
// outcome of a task: final state, process exit code, optional error text.
type JobResultPayload struct {
	TaskID   string `json:"task_id"`
	State    string `json:"state"`
	ExitCode int    `json:"exit_code"`
	Error    string `json:"error,omitempty"`
}

// PrewarmHintPayload accompanies MsgPrewarmHint: the scheduler suggests a
// snapshot the worker may fetch ahead of an upcoming assignment.
type PrewarmHintPayload struct {
	TaskID      string `json:"task_id"`
	SnapshotID  string `json:"snapshot_id"`
	SnapshotSHA string `json:"snapshot_sha,omitempty"`
}

// WorkerRegistration is sent with MsgRegister when a worker (re)connects.
// ActiveTasks lets the scheduler reconcile tasks that survived a worker
// reconnect.
type WorkerRegistration struct {
	ID           string             `json:"id"`
	Capabilities WorkerCapabilities `json:"capabilities"`
	ActiveTasks  []ActiveTaskReport `json:"active_tasks"`
}

// ActiveTaskReport describes one task already running on a worker at
// registration time.
type ActiveTaskReport struct {
	TaskID    string    `json:"task_id"`
	State     string    `json:"state"`
	StartedAt time.Time `json:"started_at,omitempty"`
}
|
||||||
|
|
||||||
|
// SlotStatus reports a worker's slot capacity and usage for the two
// scheduling pools (batch and service).
type SlotStatus struct {
	BatchTotal   int `json:"batch_total"`
	BatchInUse   int `json:"batch_in_use"`
	ServiceTotal int `json:"service_total"`
	ServiceInUse int `json:"service_in_use"`
}

// BatchAvailable returns the number of free batch slots.
func (s SlotStatus) BatchAvailable() int {
	return s.BatchTotal - s.BatchInUse
}

// ServiceAvailable returns the number of free service slots.
func (s SlotStatus) ServiceAvailable() int {
	return s.ServiceTotal - s.ServiceInUse
}
|
||||||
|
|
||||||
|
// WorkerCapabilities summarizes a worker's hardware for scheduling
// decisions. GPUInfo carries the raw detection result; GPUCount/GPUType
// duplicate its headline values for convenient matching.
type WorkerCapabilities struct {
	GPUInfo  GPUDetectionInfo `json:"gpu_info"`
	GPUCount int              `json:"gpu_count"`
	GPUType  string           `json:"gpu_type"`
	CPUCount int              `json:"cpu_count"`
	MemoryGB float64          `json:"memory_gb"`
	Hostname string           `json:"hostname"`
}

// GPUDetectionInfo is the result of GPU probing on a worker.
type GPUDetectionInfo struct {
	GPUType  string   `json:"gpu_type"`
	Count    int      `json:"count"`
	Devices  []string `json:"devices,omitempty"`  // per-device identifiers, when available
	Driver   string   `json:"driver,omitempty"`   // driver version string, when available
	MemTotal uint64   `json:"mem_total,omitempty"` // total GPU memory; units not specified here — TODO confirm (bytes vs MiB)
}
|
||||||
|
|
||||||
|
// JobSpec describes a unit of work the scheduler assigns to a worker
// (payload of MsgJobAssign). It covers both one-shot batch jobs and
// long-running services.
type JobSpec struct {
	ID       string  `json:"id"`
	Type     JobType `json:"type"` // "batch" | "service"
	SlotPool string  `json:"slot_pool"`

	// Resource requirements.
	GPUCount  int    `json:"gpu_count"`
	GPUType   string `json:"gpu_type,omitempty"`
	NodeCount int    `json:"node_count"`

	// Process to run and its environment variables.
	Command []string          `json:"command"`
	Env     map[string]string `json:"env"`

	// Optional setup/teardown command lists run before/after Command.
	Prolog []string `json:"prolog,omitempty"`
	Epilog []string `json:"epilog,omitempty"`

	// Optional snapshot to use (see also PrewarmHintPayload), health
	// checking for services, and opaque caller metadata.
	SnapshotID  string            `json:"snapshot_id,omitempty"`
	SnapshotSHA string            `json:"snapshot_sha,omitempty"`
	HealthCheck *HealthCheck      `json:"health_check,omitempty"`
	Metadata    map[string]string `json:"metadata,omitempty"`
}

// JobType distinguishes one-shot batch jobs from long-running services.
type JobType string

const (
	JobTypeBatch   JobType = "batch"
	JobTypeService JobType = "service"
)

// HealthCheck configures HTTP liveness/readiness probing for service
// jobs. Endpoints are full URLs; IntervalSecs is the probe period in
// seconds.
type HealthCheck struct {
	LivenessEndpoint  string `json:"liveness"`
	ReadinessEndpoint string `json:"readiness"`
	IntervalSecs      int    `json:"interval_secs"`
}

// ServiceHealthPayload accompanies MsgServiceHealth and reports the
// current health of a running service task.
type ServiceHealthPayload struct {
	TaskID  string `json:"task_id"`
	Healthy bool   `json:"healthy"`
	Message string `json:"message,omitempty"`
}
|
||||||
217
internal/scheduler/scheduler_conn.go
Normal file
217
internal/scheduler/scheduler_conn.go
Normal file
|
|
@ -0,0 +1,217 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SchedulerConn manages the WebSocket connection to the scheduler.
// It owns reconnection with backoff, an outbound message queue drained by
// the send loop in Run, and a registry of tasks currently running on this
// worker (reported to the scheduler on each registration).
type SchedulerConn struct {
	addr         string // scheduler address ("host:port")
	certFile     string // TLS cert path; empty means local (non-TLS) mode
	token        string // auth token passed to the dialer
	conn         *websocket.Conn
	workerID     string
	capabilities WorkerCapabilities
	send         chan Message // outbound queue; buffered, drained by Run's send loop
	activeTasks  sync.Map     // task ID (string) → start time (time.Time)
	mu           sync.RWMutex // guards conn and closed
	closed       bool
}
|
||||||
|
|
||||||
|
// NewSchedulerConn creates a new scheduler connection
|
||||||
|
func NewSchedulerConn(addr, certFile, token, workerID string, caps WorkerCapabilities) *SchedulerConn {
|
||||||
|
return &SchedulerConn{
|
||||||
|
addr: addr,
|
||||||
|
certFile: certFile,
|
||||||
|
token: token,
|
||||||
|
workerID: workerID,
|
||||||
|
capabilities: caps,
|
||||||
|
send: make(chan Message, 100),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes the WebSocket connection
|
||||||
|
func (sc *SchedulerConn) Connect() error {
|
||||||
|
var conn *websocket.Conn
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if sc.certFile == "" {
|
||||||
|
// Local mode - no TLS, parse port from address
|
||||||
|
port := parsePortFromAddr(sc.addr)
|
||||||
|
conn, err = LocalModeDial(port, sc.token)
|
||||||
|
} else {
|
||||||
|
conn, err = DialWSS(sc.addr, sc.certFile, sc.token)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.mu.Lock()
|
||||||
|
sc.conn = conn
|
||||||
|
sc.closed = false
|
||||||
|
sc.mu.Unlock()
|
||||||
|
|
||||||
|
// Send registration
|
||||||
|
sc.Send(Message{
|
||||||
|
Type: MsgRegister,
|
||||||
|
Payload: mustMarshal(WorkerRegistration{
|
||||||
|
ID: sc.workerID,
|
||||||
|
Capabilities: sc.capabilities,
|
||||||
|
ActiveTasks: sc.collectActiveTasks(),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends a message to the scheduler
|
||||||
|
func (sc *SchedulerConn) Send(msg Message) {
|
||||||
|
select {
|
||||||
|
case sc.send <- msg:
|
||||||
|
default:
|
||||||
|
// Channel full, drop message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the send/receive loops
|
||||||
|
func (sc *SchedulerConn) Run(onJobAssign func(*JobSpec), onJobCancel func(string), onPrewarmHint func(PrewarmHintPayload)) {
|
||||||
|
// Send loop
|
||||||
|
go func() {
|
||||||
|
for msg := range sc.send {
|
||||||
|
sc.mu.RLock()
|
||||||
|
conn := sc.conn
|
||||||
|
closed := sc.closed
|
||||||
|
sc.mu.RUnlock()
|
||||||
|
|
||||||
|
if closed || conn == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.WriteJSON(msg); err != nil {
|
||||||
|
// Trigger reconnect
|
||||||
|
go sc.reconnect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Receive loop
|
||||||
|
for {
|
||||||
|
sc.mu.RLock()
|
||||||
|
conn := sc.conn
|
||||||
|
sc.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg Message
|
||||||
|
if err := conn.ReadJSON(&msg); err != nil {
|
||||||
|
// Connection lost, reconnect
|
||||||
|
sc.reconnect()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.Type {
|
||||||
|
case MsgJobAssign:
|
||||||
|
var spec JobSpec
|
||||||
|
json.Unmarshal(msg.Payload, &spec)
|
||||||
|
onJobAssign(&spec)
|
||||||
|
case MsgJobCancel:
|
||||||
|
var taskID string
|
||||||
|
json.Unmarshal(msg.Payload, &taskID)
|
||||||
|
onJobCancel(taskID)
|
||||||
|
case MsgPrewarmHint:
|
||||||
|
var hint PrewarmHintPayload
|
||||||
|
json.Unmarshal(msg.Payload, &hint)
|
||||||
|
onPrewarmHint(hint)
|
||||||
|
case MsgNoWork:
|
||||||
|
// No action needed - worker will retry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconnect attempts to reconnect with exponential backoff
|
||||||
|
func (sc *SchedulerConn) reconnect() {
|
||||||
|
sc.mu.Lock()
|
||||||
|
if sc.closed {
|
||||||
|
sc.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sc.conn = nil
|
||||||
|
sc.mu.Unlock()
|
||||||
|
|
||||||
|
backoff := 1 * time.Second
|
||||||
|
maxBackoff := 30 * time.Second
|
||||||
|
|
||||||
|
for {
|
||||||
|
time.Sleep(backoff)
|
||||||
|
|
||||||
|
if err := sc.Connect(); err == nil {
|
||||||
|
return // Reconnected successfully
|
||||||
|
}
|
||||||
|
|
||||||
|
backoff *= 2
|
||||||
|
if backoff > maxBackoff {
|
||||||
|
backoff = maxBackoff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close marks the connection closed, closes the underlying WebSocket
// (if any), and closes the send channel so the send loop in Run exits.
// Close is terminal: reconnect checks the closed flag and gives up.
func (sc *SchedulerConn) Close() {
	sc.mu.Lock()
	defer sc.mu.Unlock()

	sc.closed = true
	if sc.conn != nil {
		sc.conn.Close()
	}
	// NOTE(review): closing sc.send can panic ("send on closed channel")
	// if a concurrent Send gets past its checks before this lock is
	// taken — confirm callers stop sending before Close, or guard Send
	// with the same lock.
	close(sc.send)
}
|
||||||
|
|
||||||
|
// TrackTask tracks an active task
|
||||||
|
func (sc *SchedulerConn) TrackTask(taskID string) {
|
||||||
|
sc.activeTasks.Store(taskID, time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// UntrackTask removes a task from tracking
|
||||||
|
func (sc *SchedulerConn) UntrackTask(taskID string) {
|
||||||
|
sc.activeTasks.Delete(taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SchedulerConn) collectActiveTasks() []ActiveTaskReport {
|
||||||
|
var reports []ActiveTaskReport
|
||||||
|
sc.activeTasks.Range(func(key, value any) bool {
|
||||||
|
taskID := key.(string)
|
||||||
|
startedAt := value.(time.Time)
|
||||||
|
reports = append(reports, ActiveTaskReport{
|
||||||
|
TaskID: taskID,
|
||||||
|
State: "running",
|
||||||
|
StartedAt: startedAt,
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return reports
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePortFromAddr extracts the port from a "host:port" address string.
// It returns the default port 7777 when no port can be parsed or the
// parsed value is outside the valid TCP range (1–65535).
//
// Fix over the previous version: the port is taken after the LAST colon,
// so bracketed IPv6 addresses like "[::1]:9000" parse correctly instead
// of silently falling back to the default (strings.Split yielded more
// than two parts). Out-of-range values are also rejected now.
func parsePortFromAddr(addr string) int {
	i := strings.LastIndex(addr, ":")
	if i < 0 {
		return 7777
	}
	port, err := strconv.Atoi(addr[i+1:])
	if err != nil || port < 1 || port > 65535 {
		return 7777
	}
	return port
}
|
||||||
367
internal/scheduler/service_manager.go
Normal file
367
internal/scheduler/service_manager.go
Normal file
|
|
@ -0,0 +1,367 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServiceManager handles the lifecycle of service-type jobs.
// Services transition: preparing → serving → stopping → completed/failed.
// Unlike batch jobs, services run indefinitely until explicitly cancelled
// and have health checks for liveness and readiness.

type ServiceManager struct {
	task         *Task
	spec         *JobSpec
	port         int // dynamically allocated port, exported to the child as SERVICE_PORT
	cmd          *exec.Cmd
	cancel       context.CancelFunc // cancels the service context created in Run
	healthy      bool               // last liveness result
	ready        bool               // last readiness result
	lastHealth   time.Time          // time of the last successful health-loop iteration
	stateMachine *StateMachine
}

// StateMachine manages service state transitions.
// It ensures valid transitions and notifies the scheduler of changes.
type StateMachine struct {
	current  string                          // current state name, e.g. "preparing", "serving"
	onChange func(oldState, newState string) // optional transition callback; may be nil
}
|
||||||
|
|
||||||
|
// NewServiceManager creates a new service manager for a task
|
||||||
|
func NewServiceManager(task *Task, spec *JobSpec, port int) *ServiceManager {
|
||||||
|
return &ServiceManager{
|
||||||
|
task: task,
|
||||||
|
spec: spec,
|
||||||
|
port: port,
|
||||||
|
healthy: false,
|
||||||
|
ready: false,
|
||||||
|
stateMachine: &StateMachine{
|
||||||
|
current: "preparing",
|
||||||
|
onChange: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStateChangeCallback sets a callback for state transitions
|
||||||
|
func (sm *ServiceManager) SetStateChangeCallback(cb func(oldState, newState string)) {
|
||||||
|
if sm.stateMachine != nil {
|
||||||
|
sm.stateMachine.onChange = cb
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts and manages the service lifecycle:
// prolog (optional) → start process → wait for readiness (optional) →
// serving → health-check loop. It blocks until the health loop returns
// (context cancelled or liveness failure). Any failure before "serving"
// transitions the service to "failed" and returns a wrapped error.
func (sm *ServiceManager) Run(ctx context.Context) error {
	// Create cancellable context for the service; Stop cancels it.
	svcCtx, cancel := context.WithCancel(ctx)
	sm.cancel = cancel
	defer cancel()

	// Run prolog if configured.
	if len(sm.spec.Prolog) > 0 {
		sm.transition("preparing")
		if err := sm.runProlog(svcCtx); err != nil {
			sm.transition("failed")
			return fmt.Errorf("prolog failed: %w", err)
		}
	}

	// Start the service process.
	if err := sm.startService(svcCtx); err != nil {
		sm.transition("failed")
		return fmt.Errorf("start service failed: %w", err)
	}

	// Wait for readiness (if a readiness endpoint is configured), with a
	// fixed 120s startup budget. On timeout the process is stopped.
	if sm.spec.HealthCheck != nil && sm.spec.HealthCheck.ReadinessEndpoint != "" {
		sm.transition("preparing")
		if err := sm.waitReady(svcCtx, 120*time.Second); err != nil {
			sm.stopService()
			sm.transition("failed")
			return fmt.Errorf("readiness check failed: %w", err)
		}
	}

	// Mark as serving.
	sm.transition("serving")
	sm.ready = true

	// Run health check loop until cancellation or liveness failure.
	return sm.healthLoop(svcCtx)
}
|
||||||
|
|
||||||
|
// Stop gracefully stops the service: it cancels the service context,
// transitions to "stopping", runs the epilog, ensures the process is
// terminated, and transitions to "completed".
//
// The epilog runs with a fresh background context (60s timeout) so that
// cleanup completes even when the job's own context was cancelled.
// Epilog failures are logged but do not fail Stop; it always returns nil.
func (sm *ServiceManager) Stop() error {
	// Cancel service context (stops Run's health loop).
	if sm.cancel != nil {
		sm.cancel()
	}

	// Run epilog with fresh context - must complete even if job cancelled.
	epilogCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	sm.transition("stopping")

	if len(sm.spec.Epilog) > 0 {
		if err := sm.runEpilog(epilogCtx); err != nil {
			slog.Warn("epilog failed", "task", sm.task.ID, "error", err)
		}
	}

	// Ensure service is stopped (SIGTERM, then SIGKILL after 10s).
	sm.stopService()

	sm.transition("completed")
	return nil
}
|
||||||
|
|
||||||
|
// runProlog executes prolog commands before starting the service
|
||||||
|
func (sm *ServiceManager) runProlog(ctx context.Context) error {
|
||||||
|
for _, cmdStr := range sm.spec.Prolog {
|
||||||
|
cmd := sm.buildCommand(ctx, cmdStr)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("prolog command failed: %s, error: %w", cmdStr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// startService starts the main service process
|
||||||
|
func (sm *ServiceManager) startService(ctx context.Context) error {
|
||||||
|
if len(sm.spec.Command) == 0 {
|
||||||
|
return fmt.Errorf("no command specified for service")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := sm.buildCommand(ctx, sm.spec.Command[0], sm.spec.Command[1:]...)
|
||||||
|
|
||||||
|
// Set up process group for clean termination (Unix-specific)
|
||||||
|
setProcessGroup(cmd)
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sm.cmd = cmd
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopService terminates the service process: SIGTERM first, then — if
// the process has not exited within 10 seconds — SIGKILL to the whole
// process group (Unix; plain Kill on Windows). No-op when no process was
// started.
func (sm *ServiceManager) stopService() {
	if sm.cmd == nil || sm.cmd.Process == nil {
		return
	}

	// Try graceful termination first.
	sm.cmd.Process.Signal(syscall.SIGTERM)

	// Wait for graceful shutdown or timeout. Wait runs in a goroutine so
	// we can bound it with a timer.
	done := make(chan error, 1)
	go func() {
		done <- sm.cmd.Wait()
	}()

	select {
	case <-done:
		// Graceful shutdown succeeded.
	case <-time.After(10 * time.Second):
		// Force kill process group (Unix-specific).
		// NOTE(review): the Wait goroutine is left running here until the
		// kill takes effect — confirm the process cannot ignore SIGKILL
		// scenarios (e.g. unkillable D-state) that would leak it.
		killProcessGroup(sm.cmd)
	}
}
|
||||||
|
|
||||||
|
// runEpilog executes epilog commands after service stops
|
||||||
|
func (sm *ServiceManager) runEpilog(ctx context.Context) error {
|
||||||
|
for _, cmdStr := range sm.spec.Epilog {
|
||||||
|
cmd := sm.buildCommand(ctx, cmdStr)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
slog.Warn("epilog command failed", "task", sm.task.ID, "cmd", cmdStr, "error", err)
|
||||||
|
// Continue with other epilog commands even if one fails
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// healthLoop runs health checks periodically.
// Returns nil when the context is cancelled, or an error when a liveness
// check fails (after transitioning the service to "failed"). With no
// health check configured it simply blocks until cancellation.
func (sm *ServiceManager) healthLoop(ctx context.Context) error {
	if sm.spec.HealthCheck == nil {
		// No health check configured - just wait for context cancellation.
		<-ctx.Done()
		return nil
	}

	interval := time.Duration(sm.spec.HealthCheck.IntervalSecs) * time.Second
	if interval < 5*time.Second {
		// Intervals under 5s (including an unset IntervalSecs of 0) are
		// bumped to 15s.
		// NOTE(review): intervals of exactly 5–14s pass through unchanged
		// — confirm whether the intended floor is 5s or 15s; the original
		// comment claimed "minimum 15s".
		interval = 15 * time.Second
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil
		case <-ticker.C:
			// Check liveness; a failure is terminal for the service.
			if !sm.checkLiveness() {
				sm.transition("failed")
				return fmt.Errorf("liveness check failed")
			}
			sm.healthy = true

			// Check readiness (if configured); readiness may flap without
			// failing the service.
			if sm.spec.HealthCheck.ReadinessEndpoint != "" {
				sm.ready = sm.checkReadiness()
			}
			sm.lastHealth = time.Now()
		}
	}
}
|
||||||
|
|
||||||
|
// waitReady waits for the service to become ready
|
||||||
|
func (sm *ServiceManager) waitReady(ctx context.Context, timeout time.Duration) error {
|
||||||
|
if sm.spec.HealthCheck == nil || sm.spec.HealthCheck.ReadinessEndpoint == "" {
|
||||||
|
return nil // No readiness check configured
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
checkInterval := 2 * time.Second
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if sm.checkReadiness() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-time.After(checkInterval):
|
||||||
|
// Continue checking
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("readiness check timed out after %v", timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkLiveness checks if the service process is running
|
||||||
|
func (sm *ServiceManager) checkLiveness() bool {
|
||||||
|
if sm.cmd == nil || sm.cmd.Process == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if process is still running
|
||||||
|
if !isProcessRunning(sm.cmd) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// If liveness endpoint configured, also check HTTP
|
||||||
|
if sm.spec.HealthCheck != nil && sm.spec.HealthCheck.LivenessEndpoint != "" {
|
||||||
|
return sm.checkHTTPEndpoint(sm.spec.HealthCheck.LivenessEndpoint, 2*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkReadiness checks if the service is ready to receive traffic
|
||||||
|
func (sm *ServiceManager) checkReadiness() bool {
|
||||||
|
if sm.spec.HealthCheck == nil || sm.spec.HealthCheck.ReadinessEndpoint == "" {
|
||||||
|
return sm.healthy // Fall back to liveness
|
||||||
|
}
|
||||||
|
|
||||||
|
return sm.checkHTTPEndpoint(sm.spec.HealthCheck.ReadinessEndpoint, 5*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkHTTPEndpoint makes an HTTP GET request to check endpoint health
|
||||||
|
func (sm *ServiceManager) checkHTTPEndpoint(endpoint string, timeout time.Duration) bool {
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Get(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Drain body to allow connection reuse
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
|
|
||||||
|
// 2xx status codes indicate success
|
||||||
|
return resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||||
|
}
|
||||||
|
|
||||||
|
// transition moves the state machine to newState, mirrors the state onto
// the task's Status field, fires the onChange callback (if set), and
// logs the change. Same-state transitions are ignored, so callers may
// call it defensively.
// NOTE(review): sm.task.Status is written without synchronization here —
// confirm no other goroutine reads it concurrently.
func (sm *ServiceManager) transition(newState string) {
	if sm.stateMachine == nil {
		return
	}

	oldState := sm.stateMachine.current
	if oldState == newState {
		return
	}

	sm.stateMachine.current = newState

	// Update task status so external observers see the same state.
	sm.task.Status = newState

	// Notify callback.
	if sm.stateMachine.onChange != nil {
		sm.stateMachine.onChange(oldState, newState)
	}

	slog.Info("service state transition",
		"task", sm.task.ID,
		"from", oldState,
		"to", newState)
}
|
||||||
|
|
||||||
|
// buildCommand creates an exec.Cmd bound to ctx with the environment
// assembled from the job spec plus SERVICE_PORT and TASK_ID.
//
// NOTE(review): the child's environment is built ONLY from sm.spec.Env —
// os.Environ() is not inherited, so variables like PATH and HOME are
// absent unless the spec provides them. Presumably intentional isolation;
// confirm services do not rely on inherited variables.
func (sm *ServiceManager) buildCommand(ctx context.Context, name string, args ...string) *exec.Cmd {
	cmd := exec.CommandContext(ctx, name, args...)

	// Set environment variables from the spec.
	// (Capacity reserves +4 slots but only 2 extras are appended below.)
	env := make([]string, 0, len(sm.spec.Env)+4)
	for k, v := range sm.spec.Env {
		env = append(env, fmt.Sprintf("%s=%s", k, v))
	}

	// Add service-specific variables.
	env = append(env,
		fmt.Sprintf("SERVICE_PORT=%d", sm.port),
		fmt.Sprintf("TASK_ID=%s", sm.task.ID),
	)

	cmd.Env = env
	return cmd
}
|
||||||
|
|
||||||
|
// IsHealthy returns true if the service is healthy (process running)
|
||||||
|
func (sm *ServiceManager) IsHealthy() bool {
|
||||||
|
return sm.healthy
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsReady returns true if the service is ready to receive traffic
|
||||||
|
func (sm *ServiceManager) IsReady() bool {
|
||||||
|
return sm.ready
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetState returns the current service state
|
||||||
|
func (sm *ServiceManager) GetState() string {
|
||||||
|
if sm.stateMachine == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
return sm.stateMachine.current
|
||||||
|
}
|
||||||
34
internal/scheduler/service_manager_unix.go
Normal file
34
internal/scheduler/service_manager_unix.go
Normal file
|
|
@ -0,0 +1,34 @@
|
||||||
|
//go:build !windows
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setProcessGroup sets up process group for clean termination on Unix systems
|
||||||
|
func setProcessGroup(cmd *exec.Cmd) {
|
||||||
|
if cmd.SysProcAttr == nil {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||||
|
}
|
||||||
|
cmd.SysProcAttr.Setpgid = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// killProcessGroup sends SIGKILL to the child's entire process group on
// Unix systems; no-op when no process was started.
func killProcessGroup(cmd *exec.Cmd) {
	if cmd == nil || cmd.Process == nil {
		return
	}
	// A negative PID addresses the whole process group.
	_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
}
|
||||||
|
|
||||||
|
// isProcessRunning checks if a process is still running on Unix systems
|
||||||
|
func isProcessRunning(cmd *exec.Cmd) bool {
|
||||||
|
if cmd == nil || cmd.Process == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Signal 0 is a no-op that just checks if process exists
|
||||||
|
return cmd.Process.Signal(syscall.Signal(0)) == nil
|
||||||
|
}
|
||||||
34
internal/scheduler/service_manager_windows.go
Normal file
34
internal/scheduler/service_manager_windows.go
Normal file
|
|
@ -0,0 +1,34 @@
|
||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setProcessGroup is a no-op on Windows (process groups work differently).
// Unlike the Unix variant, no SysProcAttr is configured here.
func setProcessGroup(cmd *exec.Cmd) {
	// Windows doesn't use Setpgid like Unix.
	// Process cleanup is handled differently via job objects or direct process kill.
}
|
||||||
|
|
||||||
|
// killProcessGroup terminates the child process on Windows. Only the
// process itself is killed — there is no Unix-style group signal here.
func killProcessGroup(cmd *exec.Cmd) {
	if cmd == nil || cmd.Process == nil {
		return
	}
	_ = cmd.Process.Kill()
}
|
||||||
|
|
||||||
|
// isProcessRunning checks if a process is still running on Windows
|
||||||
|
func isProcessRunning(cmd *exec.Cmd) bool {
|
||||||
|
if cmd == nil || cmd.Process == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// On Windows, try to get process exit code - if it fails, process is still running
|
||||||
|
// A simpler approach: try to open the process handle
|
||||||
|
// For now, we just check if Process object exists
|
||||||
|
// A more robust implementation would use Windows API
|
||||||
|
return true
|
||||||
|
}
|
||||||
145
internal/scheduler/service_templates.go
Normal file
145
internal/scheduler/service_templates.go
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
// Package scheduler provides service plugin templates for fetch_ml.
|
||||||
|
// These templates define how long-running services like Jupyter are configured.
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
// ServiceTemplate defines a service job that runs indefinitely until stopped.
// This is used for Jupyter, vLLM, and similar interactive services.
// Command and Env may contain template variables (see the list below the
// type definitions) that are substituted before the job is launched.
type ServiceTemplate struct {
	// JobType identifies this as a service job
	JobType string `json:"job_type"` // Always "service"

	// SlotPool specifies which slot pool to use ("batch" or "service")
	SlotPool string `json:"slot_pool"`

	// GPUCount is the number of GPUs required (can be 0 for CPU-only services)
	GPUCount int `json:"gpu_count"`

	// Command is the service command with template variables
	Command []string `json:"command"`

	// Env defines environment variables with template variables
	Env map[string]string `json:"env"`

	// HealthCheck defines how to verify the service is healthy
	HealthCheck ServiceHealthCheck `json:"health_check"`

	// Mounts defines volume mounts for the service
	Mounts []ServiceMount `json:"mounts,omitempty"`

	// Ports to expose (if not using dynamic allocation)
	Ports []int `json:"ports,omitempty"`
}

// ServiceHealthCheck defines liveness and readiness probes.
type ServiceHealthCheck struct {
	// Liveness endpoint - checks if service is running
	Liveness string `json:"liveness"`

	// Readiness endpoint - checks if service is ready for traffic
	Readiness string `json:"readiness"`

	// Interval between health checks in seconds
	Interval int `json:"interval"`

	// Timeout for each health check in seconds
	// NOTE(review): the scheduler's HealthCheck type has no timeout field
	// — confirm Timeout is actually propagated when templates are
	// converted into JobSpecs.
	Timeout int `json:"timeout"`
}

// ServiceMount defines a volume mount.
type ServiceMount struct {
	Source      string `json:"source"`      // host-side path; may contain template variables
	Destination string `json:"destination"` // path inside the service environment
	ReadOnly    bool   `json:"readonly,omitempty"`
}
|
||||||
|
|
||||||
|
// Template variables available in ServiceTemplate:
// {{SERVICE_PORT}} - Dynamically allocated port for the service
// {{WORKER_ID}} - ID of the worker running the service
// {{TASK_ID}} - Unique task ID for this service instance
// {{SECRET:xxx}} - Secret value from scheduler's secret store

// JupyterLabTemplate is the default JupyterLab service configuration.
// Sysadmins can disable Jupyter by setting service_slots: 0 in worker config,
// or by not registering this template with the scheduler.
// The access token is injected via the {{SECRET:jupyter_token}} variable
// both on the command line and in the environment.
var JupyterLabTemplate = ServiceTemplate{
	JobType:  "service",
	SlotPool: "service", // Uses service slot pool, not batch
	GPUCount: 0,         // Jupyter typically runs CPU-only

	Command: []string{
		"jupyter", "lab",
		"--ip=0.0.0.0",
		"--port={{SERVICE_PORT}}",
		"--no-browser",
		"--allow-root",
		"--NotebookApp.token='{{SECRET:jupyter_token}}'",
		"--NotebookApp.password=''",
	},

	Env: map[string]string{
		"JUPYTER_TOKEN":      "{{SECRET:jupyter_token}}",
		"JUPYTER_CONFIG_DIR": "/workspace/.jupyter",
	},

	// Liveness hits the Jupyter API root; readiness additionally requires
	// the kernels endpoint to answer.
	HealthCheck: ServiceHealthCheck{
		Liveness:  "http://localhost:{{SERVICE_PORT}}/api",
		Readiness: "http://localhost:{{SERVICE_PORT}}/api/kernels",
		Interval:  15,
		Timeout:   5,
	},

	Mounts: []ServiceMount{
		{Source: "{{WORKSPACE}}", Destination: "/workspace"},
	},
}
|
||||||
|
|
||||||
|
// JupyterNotebookTemplate is an alternative using classic Jupyter Notebook.
// It mirrors JupyterLabTemplate (same ports, health endpoints, workspace
// mount) but launches "jupyter notebook" instead of "jupyter lab".
var JupyterNotebookTemplate = ServiceTemplate{
	JobType:  "service",
	SlotPool: "service",
	GPUCount: 0,

	Command: []string{
		"jupyter", "notebook",
		"--ip=0.0.0.0",
		"--port={{SERVICE_PORT}}",
		"--no-browser",
		"--allow-root",
		"--NotebookApp.token='{{SECRET:jupyter_token}}'",
	},

	Env: map[string]string{
		"JUPYTER_TOKEN": "{{SECRET:jupyter_token}}",
	},

	HealthCheck: ServiceHealthCheck{
		Liveness:  "http://localhost:{{SERVICE_PORT}}/api",
		Readiness: "http://localhost:{{SERVICE_PORT}}/api/kernels",
		Interval:  15,
		Timeout:   5,
	},

	Mounts: []ServiceMount{
		{Source: "{{WORKSPACE}}", Destination: "/workspace"},
	},
}
|
||||||
|
|
||||||
|
// VLLMTemplate is an example vLLM inference server template (future)
|
||||||
|
var VLLMTemplate = ServiceTemplate{
|
||||||
|
JobType: "service",
|
||||||
|
SlotPool: "service",
|
||||||
|
GPUCount: 1, // Requires GPU for inference
|
||||||
|
|
||||||
|
Command: []string{
|
||||||
|
"python", "-m", "vllm.entrypoints.openai.api_server",
|
||||||
|
"--model", "{{MODEL_NAME}}",
|
||||||
|
"--port", "{{SERVICE_PORT}}",
|
||||||
|
},
|
||||||
|
|
||||||
|
HealthCheck: ServiceHealthCheck{
|
||||||
|
Liveness: "http://localhost:{{SERVICE_PORT}}/health",
|
||||||
|
Readiness: "http://localhost:{{SERVICE_PORT}}/health",
|
||||||
|
Interval: 30,
|
||||||
|
Timeout: 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
156
internal/scheduler/state.go
Normal file
156
internal/scheduler/state.go
Normal file
|
|
@ -0,0 +1,156 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StateEvent represents a state change event for persistence.
// Events are serialized as one JSON object per line by StateStore.Append
// and read back by StateStore.Replay.
type StateEvent struct {
	Type      StateEventType  `json:"type"`
	Timestamp time.Time       `json:"ts"` // filled in by Append when zero
	TaskID    string          `json:"task_id"`
	WorkerID  string          `json:"worker_id,omitempty"`
	Payload   json.RawMessage `json:"payload,omitempty"` // event-specific data, decoded lazily
}

// StateEventType identifies the kind of state transition recorded in an event.
type StateEventType string

// Job lifecycle event types recorded in the state log.
const (
	EventJobEnqueued  StateEventType = "job_enqueued"
	EventJobAssigned  StateEventType = "job_assigned"
	EventJobAccepted  StateEventType = "job_accepted"
	EventJobCompleted StateEventType = "job_completed"
	EventJobFailed    StateEventType = "job_failed"
	EventJobRequeued  StateEventType = "job_requeued"
	EventJobCancelled StateEventType = "job_cancelled"
)
|
||||||
|
|
||||||
|
// StateStore provides append-only persistence for scheduler state.
// All methods are safe for concurrent use; mu guards both the file handle
// and the path-based reopen performed by Replay and Rotate.
// Contains a sync.Mutex, so values must not be copied — always use *StateStore.
type StateStore struct {
	path string     // on-disk location of the JSONL event log
	mu   sync.Mutex // serializes writes and file-handle swaps
	file *os.File   // open for append; replaced by Replay/Rotate
}
|
||||||
|
|
||||||
|
// NewStateStore creates a new state store at the given path
|
||||||
|
func NewStateStore(path string) (*StateStore, error) {
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||||
|
return nil, fmt.Errorf("create state directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open state file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &StateStore{
|
||||||
|
path: path,
|
||||||
|
file: file,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append writes a state event to the log
|
||||||
|
func (s *StateStore) Append(event StateEvent) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if event.Timestamp.IsZero() {
|
||||||
|
event.Timestamp = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.file.Write(data); err != nil {
|
||||||
|
return fmt.Errorf("write event: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := s.file.WriteString("\n"); err != nil {
|
||||||
|
return fmt.Errorf("write newline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replay reads all events from the state log
|
||||||
|
func (s *StateStore) Replay() ([]StateEvent, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Close and reopen to ensure we read from the beginning
|
||||||
|
if err := s.file.Close(); err != nil {
|
||||||
|
return nil, fmt.Errorf("close state file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.Open(s.path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
// Recreate the file for appending
|
||||||
|
s.file, _ = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("open state file for replay: %w", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
var events []StateEvent
|
||||||
|
scanner := bufio.NewScanner(file)
|
||||||
|
for scanner.Scan() {
|
||||||
|
var event StateEvent
|
||||||
|
if err := json.Unmarshal(scanner.Bytes(), &event); err != nil {
|
||||||
|
// Skip malformed lines but log them
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
events = append(events, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan state file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reopen for appending
|
||||||
|
s.file, err = os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reopen state file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the state store
|
||||||
|
func (s *StateStore) Close() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.file.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rotate rotates the state file (for backup/truncation)
|
||||||
|
func (s *StateStore) Rotate() (string, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
backupPath := s.path + "." + time.Now().Format("20060102_150405") + ".bak"
|
||||||
|
|
||||||
|
if err := s.file.Close(); err != nil {
|
||||||
|
return "", fmt.Errorf("close state file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Rename(s.path, backupPath); err != nil {
|
||||||
|
return "", fmt.Errorf("rotate state file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.OpenFile(s.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create new state file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.file = file
|
||||||
|
return backupPath, nil
|
||||||
|
}
|
||||||
245
internal/scheduler/template.go
Normal file
245
internal/scheduler/template.go
Normal file
|
|
@ -0,0 +1,245 @@
|
||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TemplateResolver handles variable substitution in job specifications
|
||||||
|
// Template variables are resolved at dispatch time on the worker
|
||||||
|
//
|
||||||
|
// Supported variables:
|
||||||
|
// {{HEAD_ADDR}} - Hostname of rank-0 worker (for multi-node)
|
||||||
|
// {{WORLD_SIZE}} - Total node count (for multi-node)
|
||||||
|
// {{NODE_RANK}} - 0-based rank of this worker (for multi-node)
|
||||||
|
// {{GPU_COUNT}} - Number of GPUs available on this worker
|
||||||
|
// {{SERVICE_PORT}} - Port assigned by PortAllocator (for service jobs)
|
||||||
|
// {{HOSTNAME}} - This worker's hostname
|
||||||
|
// {{TASK_ID}} - The task/job ID
|
||||||
|
// {{SECRET:name}} - Secret from worker's secret store
|
||||||
|
|
||||||
|
// TemplateContext provides the values for template substitution.
// Zero values for HeadAddr, WorldSize, and ServicePort mean "unset": Resolve
// leaves the corresponding {{VAR}} token untouched so the worker can detect
// an unresolved placeholder. Hostname is lazily filled in by Resolve if empty.
type TemplateContext struct {
	HeadAddr    string            // Rank-0 worker hostname (multi-node)
	WorldSize   int               // Total nodes (multi-node)
	NodeRank    int               // This worker's rank (multi-node)
	GPUCount    int               // GPUs available
	ServicePort int               // Assigned port (service jobs)
	Hostname    string            // This worker's hostname
	TaskID      string            // Task/job ID
	Secrets     map[string]string // Secret store, keyed by {{SECRET:name}}
}
|
||||||
|
|
||||||
|
var (
	// templatePattern matches {{VAR}} or {{SECRET:name}}.
	// Group 1 captures the variable name (word characters only); group 2
	// captures the optional argument after the colon (anything up to "}").
	// Compiled once at package scope per regexp best practice.
	templatePattern = regexp.MustCompile(`\{\{(\w+)(?::([^}]+))?\}\}`)
)
|
||||||
|
|
||||||
|
// Resolve substitutes template variables in a string
|
||||||
|
// Returns the resolved string and any error encountered
|
||||||
|
func (tc *TemplateContext) Resolve(input string) (string, error) {
|
||||||
|
if !strings.Contains(input, "{{") {
|
||||||
|
return input, nil // No templates to resolve
|
||||||
|
}
|
||||||
|
|
||||||
|
result := templatePattern.ReplaceAllStringFunc(input, func(match string) string {
|
||||||
|
// Extract variable name and optional secret name
|
||||||
|
parts := templatePattern.FindStringSubmatch(match)
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return match // Keep original if malformed
|
||||||
|
}
|
||||||
|
|
||||||
|
varName := parts[1]
|
||||||
|
secretName := ""
|
||||||
|
if len(parts) >= 3 {
|
||||||
|
secretName = parts[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
switch varName {
|
||||||
|
case "HEAD_ADDR":
|
||||||
|
if tc.HeadAddr == "" {
|
||||||
|
return match // Keep unresolved if not set
|
||||||
|
}
|
||||||
|
return tc.HeadAddr
|
||||||
|
case "WORLD_SIZE":
|
||||||
|
if tc.WorldSize == 0 {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d", tc.WorldSize)
|
||||||
|
case "NODE_RANK":
|
||||||
|
return fmt.Sprintf("%d", tc.NodeRank)
|
||||||
|
case "GPU_COUNT":
|
||||||
|
return fmt.Sprintf("%d", tc.GPUCount)
|
||||||
|
case "SERVICE_PORT":
|
||||||
|
if tc.ServicePort == 0 {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d", tc.ServicePort)
|
||||||
|
case "HOSTNAME":
|
||||||
|
if tc.Hostname == "" {
|
||||||
|
tc.Hostname, _ = os.Hostname()
|
||||||
|
}
|
||||||
|
return tc.Hostname
|
||||||
|
case "TASK_ID":
|
||||||
|
return tc.TaskID
|
||||||
|
case "SECRET":
|
||||||
|
if val, ok := tc.Secrets[secretName]; ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return match // Keep unresolved if secret not found
|
||||||
|
default:
|
||||||
|
return match // Unknown variable - keep as-is
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveCommand resolves templates in a command slice
|
||||||
|
func (tc *TemplateContext) ResolveCommand(cmd []string) ([]string, error) {
|
||||||
|
result := make([]string, len(cmd))
|
||||||
|
for i, arg := range cmd {
|
||||||
|
resolved, err := tc.Resolve(arg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve arg %d: %w", i, err)
|
||||||
|
}
|
||||||
|
result[i] = resolved
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveEnv resolves templates in environment variables
|
||||||
|
func (tc *TemplateContext) ResolveEnv(env map[string]string) (map[string]string, error) {
|
||||||
|
result := make(map[string]string, len(env))
|
||||||
|
for k, v := range env {
|
||||||
|
resolved, err := tc.Resolve(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve env %s: %w", k, err)
|
||||||
|
}
|
||||||
|
result[k] = resolved
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveJobSpec resolves all templates in a JobSpec.
// Returns a new JobSpec with all template variables substituted; the input
// spec is never mutated. Command, Env, Prolog, Epilog and the HealthCheck
// endpoints are template-resolved; scalar fields and Metadata are copied
// verbatim.
//
// NOTE(review): this is a field-by-field copy — any field added to JobSpec
// (or HealthCheck beyond Liveness/Readiness/IntervalSecs, e.g. a timeout)
// must be added here too, or it will be silently dropped; confirm against
// the JobSpec definition.
func (tc *TemplateContext) ResolveJobSpec(spec *JobSpec) (*JobSpec, error) {
	// Deep copy the spec (scalar fields; collections handled below).
	resolved := &JobSpec{
		ID:          spec.ID,
		Type:        spec.Type,
		SlotPool:    spec.SlotPool,
		GPUCount:    spec.GPUCount,
		GPUType:     spec.GPUType,
		NodeCount:   spec.NodeCount,
		SnapshotID:  spec.SnapshotID,
		SnapshotSHA: spec.SnapshotSHA,
		Metadata:    make(map[string]string, len(spec.Metadata)),
	}

	// Copy metadata verbatim — metadata values are not template-resolved.
	for k, v := range spec.Metadata {
		resolved.Metadata[k] = v
	}

	// Resolve command
	if len(spec.Command) > 0 {
		cmd, err := tc.ResolveCommand(spec.Command)
		if err != nil {
			return nil, fmt.Errorf("resolve command: %w", err)
		}
		resolved.Command = cmd
	}

	// Resolve environment
	if len(spec.Env) > 0 {
		env, err := tc.ResolveEnv(spec.Env)
		if err != nil {
			return nil, fmt.Errorf("resolve env: %w", err)
		}
		resolved.Env = env
	}

	// Resolve prolog
	if len(spec.Prolog) > 0 {
		prolog, err := tc.ResolveCommand(spec.Prolog)
		if err != nil {
			return nil, fmt.Errorf("resolve prolog: %w", err)
		}
		resolved.Prolog = prolog
	}

	// Resolve epilog
	if len(spec.Epilog) > 0 {
		epilog, err := tc.ResolveCommand(spec.Epilog)
		if err != nil {
			return nil, fmt.Errorf("resolve epilog: %w", err)
		}
		resolved.Epilog = epilog
	}

	// Resolve health check endpoints (fresh struct so the input's HealthCheck
	// is not shared or mutated).
	if spec.HealthCheck != nil {
		hc := &HealthCheck{
			LivenessEndpoint:  spec.HealthCheck.LivenessEndpoint,
			ReadinessEndpoint: spec.HealthCheck.ReadinessEndpoint,
			IntervalSecs:      spec.HealthCheck.IntervalSecs,
		}

		if hc.LivenessEndpoint != "" {
			endpoint, err := tc.Resolve(hc.LivenessEndpoint)
			if err != nil {
				return nil, fmt.Errorf("resolve liveness endpoint: %w", err)
			}
			hc.LivenessEndpoint = endpoint
		}

		if hc.ReadinessEndpoint != "" {
			endpoint, err := tc.Resolve(hc.ReadinessEndpoint)
			if err != nil {
				return nil, fmt.Errorf("resolve readiness endpoint: %w", err)
			}
			hc.ReadinessEndpoint = endpoint
		}

		resolved.HealthCheck = hc
	}

	return resolved, nil
}
|
||||||
|
|
||||||
|
// NewMultiNodeContext creates a template context for a multi-node job
|
||||||
|
func NewMultiNodeContext(taskID, headAddr string, worldSize, nodeRank, gpuCount int) *TemplateContext {
|
||||||
|
hostname, _ := os.Hostname()
|
||||||
|
return &TemplateContext{
|
||||||
|
TaskID: taskID,
|
||||||
|
HeadAddr: headAddr,
|
||||||
|
WorldSize: worldSize,
|
||||||
|
NodeRank: nodeRank,
|
||||||
|
GPUCount: gpuCount,
|
||||||
|
Hostname: hostname,
|
||||||
|
Secrets: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServiceContext creates a template context for a service job
|
||||||
|
func NewServiceContext(taskID string, servicePort, gpuCount int) *TemplateContext {
|
||||||
|
hostname, _ := os.Hostname()
|
||||||
|
return &TemplateContext{
|
||||||
|
TaskID: taskID,
|
||||||
|
ServicePort: servicePort,
|
||||||
|
GPUCount: gpuCount,
|
||||||
|
Hostname: hostname,
|
||||||
|
Secrets: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSecret adds a secret to the context
|
||||||
|
func (tc *TemplateContext) SetSecret(name, value string) {
|
||||||
|
if tc.Secrets == nil {
|
||||||
|
tc.Secrets = make(map[string]string)
|
||||||
|
}
|
||||||
|
tc.Secrets[name] = value
|
||||||
|
}
|
||||||
190
tests/benchmarks/scheduler_bench_test.go
Normal file
190
tests/benchmarks/scheduler_bench_test.go
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
package benchmarks_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
fixtures "github.com/jfraeys/fetch_ml/tests/fixtures"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BenchmarkPriorityQueueAdd measures job enqueue performance
|
||||||
|
func BenchmarkPriorityQueueAdd(b *testing.B) {
|
||||||
|
pq := scheduler.NewPriorityQueue(0.1)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
task := &scheduler.Task{
|
||||||
|
ID: fmt.Sprintf("task-%d", i),
|
||||||
|
Priority: i % 100,
|
||||||
|
}
|
||||||
|
pq.Add(task)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkPriorityQueueTake measures job dequeue performance.
// The queue is filled with exactly b.N tasks outside the timed region so
// every timed iteration removes a real element.
func BenchmarkPriorityQueueTake(b *testing.B) {
	pq := scheduler.NewPriorityQueue(0.1)

	// Pre-populate queue
	for i := 0; i < b.N; i++ {
		task := &scheduler.Task{
			ID:       fmt.Sprintf("task-%d", i),
			Priority: i % 100,
		}
		pq.Add(task)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		pq.Take()
	}
}

// BenchmarkPortAllocator measures port allocation performance.
// NOTE(review): the range holds only 10001 ports; once b.N exceeds that,
// Allocate presumably fails and the ignored error means later iterations
// measure the failure path — confirm PortAllocator semantics.
func BenchmarkPortAllocator(b *testing.B) {
	pa := scheduler.NewPortAllocator(10000, 20000)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		port, _ := pa.Allocate(fmt.Sprintf("service-%d", i))
		_ = port
	}
}
|
||||||
|
|
||||||
|
// BenchmarkStateStoreAppend measures state persistence performance
|
||||||
|
func BenchmarkStateStoreAppend(b *testing.B) {
|
||||||
|
dir := b.TempDir()
|
||||||
|
store, _ := scheduler.NewStateStore(dir + "/bench.state")
|
||||||
|
|
||||||
|
event := scheduler.StateEvent{
|
||||||
|
Type: scheduler.EventJobEnqueued,
|
||||||
|
TaskID: "bench-task",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
event.TaskID = fmt.Sprintf("bench-task-%d", i)
|
||||||
|
store.Append(event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkSchedulerSubmitJob measures job submission throughput against a
// live hub (no workers connected, so jobs only enqueue).
func BenchmarkSchedulerSubmitJob(b *testing.B) {
	// Create scheduler directly for benchmark (no fixture: avoids the
	// fixture's worker-token config, which this benchmark doesn't need).
	cfg := scheduler.HubConfig{
		BindAddr:                "localhost:0",
		DefaultBatchSlots:       4,
		StarvationThresholdMins: 5,
		AcceptanceTimeoutSecs:   5,
	}
	hub, err := scheduler.NewHub(cfg, nil)
	if err != nil {
		b.Fatal(err)
	}
	defer hub.Stop()

	if err := hub.Start(); err != nil {
		b.Fatal(err)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		hub.SubmitJob(scheduler.JobSpec{
			ID:   fmt.Sprintf("bench-job-%d", i),
			Type: scheduler.JobTypeBatch,
		})
	}
}

// BenchmarkWorkerRegistration measures worker registration throughput.
// Each iteration dials, registers, and tears down a fresh websocket, so the
// number includes connection setup cost.
// NOTE(review): worker IDs here are not in DefaultHubConfig's WorkerTokens
// map — confirm the hub accepts unknown tokens, otherwise this measures
// rejected registrations.
func BenchmarkWorkerRegistration(b *testing.B) {
	fixture := fixtures.NewSchedulerTestFixture(b, fixtures.DefaultHubConfig())
	defer fixture.Cleanup()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		workerID := fmt.Sprintf("bench-worker-%d", i)
		worker := fixtures.NewMockWorker(b, fixture.Hub, workerID)
		worker.Register(scheduler.WorkerCapabilities{GPUCount: 0})
		worker.Close()
	}
}

// BenchmarkHeartbeatProcessing measures heartbeat handling throughput for a
// single registered worker sending identical slot reports.
func BenchmarkHeartbeatProcessing(b *testing.B) {
	fixture := fixtures.NewSchedulerTestFixture(b, fixtures.DefaultHubConfig())
	defer fixture.Cleanup()

	worker := fixture.CreateWorker("bench-hb-worker", scheduler.WorkerCapabilities{GPUCount: 0})

	slots := scheduler.SlotStatus{
		BatchTotal: 4,
		BatchInUse: 0,
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		worker.SendHeartbeat(slots)
	}
}

// BenchmarkJobAssignment measures job scheduling latency: submit, signal
// ready, and wait (up to 100ms) for the assignment to arrive.
func BenchmarkJobAssignment(b *testing.B) {
	fixture := fixtures.NewSchedulerTestFixture(b, fixtures.DefaultHubConfig())
	defer fixture.Cleanup()

	// Create worker
	worker := fixture.CreateWorker("bench-assign-worker", scheduler.WorkerCapabilities{GPUCount: 0})

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Submit job
		jobID := fmt.Sprintf("bench-assign-%d", i)
		fixture.SubmitJob(scheduler.JobSpec{
			ID:   jobID,
			Type: scheduler.JobTypeBatch,
		})

		// Signal ready to trigger assignment
		worker.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "polling")

		// Wait for assignment
		worker.RecvTimeout(100 * time.Millisecond)
	}
}

// BenchmarkMultiWorkerScheduling measures scheduling with ten workers all
// contending for each submitted job.
// NOTE(review): only workers[i%10] is drained each iteration — the assignment
// may land on a different worker, in which case RecvTimeout measures its
// 100ms timeout; confirm this is the intended measurement.
func BenchmarkMultiWorkerScheduling(b *testing.B) {
	fixture := fixtures.NewSchedulerTestFixture(b, fixtures.DefaultHubConfig())
	defer fixture.Cleanup()

	// Create multiple workers
	workers := make([]*fixtures.MockWorker, 10)
	for i := 0; i < 10; i++ {
		workers[i] = fixture.CreateWorker(
			fmt.Sprintf("bench-multi-worker-%d", i),
			scheduler.WorkerCapabilities{GPUCount: 0},
		)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Submit job
		jobID := fmt.Sprintf("bench-multi-%d", i)
		fixture.SubmitJob(scheduler.JobSpec{
			ID:   jobID,
			Type: scheduler.JobTypeBatch,
		})

		// All workers signal ready
		for _, w := range workers {
			w.SignalReady(scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0}, "polling")
		}

		// One worker gets the job
		workers[i%10].RecvTimeout(100 * time.Millisecond)
	}
}
|
||||||
118
tests/fixtures/scheduler_fixture.go
vendored
Normal file
118
tests/fixtures/scheduler_fixture.go
vendored
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
// Package fixtures provides shared test utilities and fixtures for scheduler tests
|
||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testStateDir is used for hub state storage in tests.
// Created once per test process by init below.
// NOTE(review): the directory is never removed — it leaks per test run;
// consider b.TempDir/t.TempDir at fixture-construction time instead of a
// package-level init with side effects.
var testStateDir string

func init() {
	var err error
	testStateDir, err = os.MkdirTemp("", "fetchml-test-*")
	if err != nil {
		panic("failed to create test state dir: " + err.Error())
	}
}

// SchedulerTestFixture provides a test fixture for scheduler tests:
// a started hub plus the mock workers registered through CreateWorker,
// torn down together by Cleanup.
type SchedulerTestFixture struct {
	T       testing.TB             // test or benchmark this fixture belongs to
	Hub     *scheduler.SchedulerHub // started hub under test
	Workers map[string]*MockWorker  // workers created via CreateWorker, by ID
}
|
||||||
|
|
||||||
|
// NewSchedulerTestFixture creates a new scheduler test fixture.
// An empty BindAddr defaults to "localhost:0" (OS-assigned port) so parallel
// tests don't collide. The hub is started before returning; failures abort
// the test via require.
func NewSchedulerTestFixture(t testing.TB, cfg scheduler.HubConfig) *SchedulerTestFixture {
	if cfg.BindAddr == "" {
		cfg.BindAddr = "localhost:0"
	}

	hub, err := scheduler.NewHub(cfg, nil)
	require.NoError(t, err)

	// Start scheduler
	err = hub.Start()
	require.NoError(t, err)

	return &SchedulerTestFixture{
		T:       t,
		Hub:     hub,
		Workers: make(map[string]*MockWorker),
	}
}
|
||||||
|
|
||||||
|
// CreateWorker creates and registers a new mock worker
|
||||||
|
func (f *SchedulerTestFixture) CreateWorker(workerID string, caps scheduler.WorkerCapabilities) *MockWorker {
|
||||||
|
worker := NewMockWorker(f.T, f.Hub, workerID)
|
||||||
|
worker.Register(caps)
|
||||||
|
f.Workers[workerID] = worker
|
||||||
|
return worker
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubmitJob submits a job to the scheduler
|
||||||
|
func (f *SchedulerTestFixture) SubmitJob(spec scheduler.JobSpec) {
|
||||||
|
err := f.Hub.SubmitJob(spec)
|
||||||
|
require.NoError(f.T, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTask retrieves a task by ID
|
||||||
|
func (f *SchedulerTestFixture) GetTask(taskID string) *scheduler.Task {
|
||||||
|
return f.Hub.GetTask(taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup stops the scheduler and closes all workers.
// Workers are closed first so their read loops terminate on their own
// connections rather than on the hub's shutdown.
func (f *SchedulerTestFixture) Cleanup() {
	// Close all workers first
	for _, worker := range f.Workers {
		worker.Close()
	}
	// Then stop the hub
	f.Hub.Stop()
}

// DefaultHubConfig returns a default hub configuration for testing.
// WorkerTokens maps bearer tokens to worker IDs; each entry follows the
// "test-token-<workerID>" convention that MockWorker uses when dialing, so
// every worker ID used by the integration/benchmark tests must be listed
// here or its registration will be rejected by the hub's auth.
func DefaultHubConfig() scheduler.HubConfig {
	return scheduler.HubConfig{
		BindAddr:                "localhost:0", // OS-assigned port
		DefaultBatchSlots:       4,
		StarvationThresholdMins: 5,
		AcceptanceTimeoutSecs:   5,
		GangAllocTimeoutSecs:    10,
		StateDir:                testStateDir, // shared per-process temp dir (see init)
		WorkerTokens: map[string]string{
			"test-token-worker-restart-1":       "worker-restart-1",
			"test-token-mode-switch-worker":     "mode-switch-worker",
			"test-token-mode-switch-worker-2":   "mode-switch-worker-2",
			"test-token-e2e-worker-1":           "e2e-worker-1",
			"test-token-e2e-worker-2":           "e2e-worker-2",
			"test-token-worker-death-test":      "worker-death-test",
			"test-token-worker-split-1":         "worker-split-1",
			"test-token-worker-split-2":         "worker-split-2",
			"test-token-worker-split-3":         "worker-split-3",
			"test-token-worker-timeout":         "worker-timeout",
			"test-token-worker-gang":            "worker-gang",
			"test-token-bench-worker":           "bench-worker",
			"test-token-bench-hb-worker":        "bench-hb-worker",
			"test-token-bench-assign-worker":    "bench-assign-worker",
		},
	}
}
|
||||||
|
|
||||||
|
// WaitForTimeout is a helper to wait for a condition with timeout.
// It polls every 10ms until the condition returns true (returns true) or
// the duration elapses (returns false).
//
// The condition is always evaluated at least once, so an already-true
// condition succeeds even with a zero or negative duration (the original
// never entered the loop in that case and wrongly returned false).
func WaitForTimeout(duration time.Duration, condition func() bool) bool {
	deadline := time.Now().Add(duration)
	for {
		if condition() {
			return true
		}
		if !time.Now().Before(deadline) {
			return false
		}
		time.Sleep(10 * time.Millisecond)
	}
}
|
||||||
228
tests/fixtures/scheduler_mock.go
vendored
Normal file
228
tests/fixtures/scheduler_mock.go
vendored
Normal file
|
|
@ -0,0 +1,228 @@
|
||||||
|
// Package fixtures provides shared test utilities for all tests
|
||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockWorker simulates a worker connection for testing.
// Two goroutines pump the websocket: a reader that feeds RecvCh (closed when
// the connection drops) and a writer that drains SendCh. Close tears both
// down; wg tracks their lifetimes.
type MockWorker struct {
	Conn   *websocket.Conn
	ID     string
	RecvCh chan scheduler.Message // messages from the scheduler; closed on disconnect
	SendCh chan scheduler.Message // messages to the scheduler; closed by Close
	wg     sync.WaitGroup         // tracks the two pump goroutines
	mu     sync.RWMutex           // guards closed
	closed bool                   // set by Close to make it idempotent
	T      testing.TB
}

// NewMockWorker creates a new mock worker connected to the scheduler.
// It dials the hub's /ws/worker endpoint with the "test-token-<workerID>"
// bearer token (must match an entry in the hub's WorkerTokens) and starts
// the read/write pump goroutines. Dial failure aborts the test.
func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) *MockWorker {
	addr := hub.Addr()
	require.NotEmpty(t, addr, "hub not started")

	wsURL := "ws://" + addr + "/ws/worker"

	// Add test token to headers
	header := http.Header{}
	header.Set("Authorization", "Bearer test-token-"+workerID)

	conn, _, err := websocket.DefaultDialer.Dial(wsURL, header)
	require.NoError(t, err)

	mw := &MockWorker{
		Conn:   conn,
		ID:     workerID,
		RecvCh: make(chan scheduler.Message, 100),
		SendCh: make(chan scheduler.Message, 100),
		T:      t,
	}

	// Start receive goroutine: pumps incoming JSON messages into RecvCh and
	// closes RecvCh on any read error (i.e. when the connection drops).
	mw.wg.Add(1)
	go func() {
		defer mw.wg.Done()
		for {
			var msg scheduler.Message
			err := conn.ReadJSON(&msg)
			if err != nil {
				close(mw.RecvCh)
				return
			}
			mw.RecvCh <- msg
		}
	}()

	// Start send goroutine: drains SendCh until it is closed by Close, or
	// until a write error indicates the connection is gone.
	mw.wg.Add(1)
	go func() {
		defer mw.wg.Done()
		for msg := range mw.SendCh {
			if err := conn.WriteJSON(msg); err != nil {
				return
			}
		}
	}()

	return mw
}
|
||||||
|
|
||||||
|
// Register sends worker registration message and waits for ack
|
||||||
|
func (mw *MockWorker) Register(capabilities scheduler.WorkerCapabilities) {
|
||||||
|
mw.Send(scheduler.Message{
|
||||||
|
Type: scheduler.MsgRegister,
|
||||||
|
Payload: MustMarshal(scheduler.WorkerRegistration{
|
||||||
|
ID: mw.ID,
|
||||||
|
Capabilities: capabilities,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
msg := mw.RecvTimeout(2 * time.Second)
|
||||||
|
require.Equal(mw.T, scheduler.MsgAck, msg.Type, "expected registration ack")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends a message to the scheduler
|
||||||
|
func (mw *MockWorker) Send(msg scheduler.Message) {
|
||||||
|
select {
|
||||||
|
case mw.SendCh <- msg:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
mw.T.Fatal("timeout sending message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recv receives a message from the scheduler (blocks)
|
||||||
|
func (mw *MockWorker) Recv() scheduler.Message {
|
||||||
|
select {
|
||||||
|
case msg := <-mw.RecvCh:
|
||||||
|
return msg
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
require.Fail(mw.T, "timeout waiting for message")
|
||||||
|
return scheduler.Message{Type: "timeout"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecvTimeout receives a message with a custom timeout
|
||||||
|
func (mw *MockWorker) RecvTimeout(timeout time.Duration) scheduler.Message {
|
||||||
|
select {
|
||||||
|
case msg := <-mw.RecvCh:
|
||||||
|
return msg
|
||||||
|
case <-time.After(timeout):
|
||||||
|
require.Fail(mw.T, "timeout waiting for message")
|
||||||
|
return scheduler.Message{Type: "timeout"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecvNonBlock tries to receive without blocking
|
||||||
|
func (mw *MockWorker) RecvNonBlock() (scheduler.Message, bool) {
|
||||||
|
select {
|
||||||
|
case msg := <-mw.RecvCh:
|
||||||
|
return msg, true
|
||||||
|
default:
|
||||||
|
return scheduler.Message{}, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignalReady sends ready for work message
|
||||||
|
func (mw *MockWorker) SignalReady(slots scheduler.SlotStatus, reason string) {
|
||||||
|
mw.Send(scheduler.Message{
|
||||||
|
Type: scheduler.MsgReadyForWork,
|
||||||
|
Payload: MustMarshal(scheduler.ReadyPayload{
|
||||||
|
WorkerID: mw.ID,
|
||||||
|
Slots: slots,
|
||||||
|
Reason: reason,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendHeartbeat sends a heartbeat message
|
||||||
|
func (mw *MockWorker) SendHeartbeat(slots scheduler.SlotStatus) {
|
||||||
|
mw.Send(scheduler.Message{
|
||||||
|
Type: scheduler.MsgHeartbeat,
|
||||||
|
Payload: MustMarshal(scheduler.HeartbeatPayload{
|
||||||
|
WorkerID: mw.ID,
|
||||||
|
Slots: slots,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptJob accepts a job assignment
|
||||||
|
func (mw *MockWorker) AcceptJob(taskID string) {
|
||||||
|
mw.Send(scheduler.Message{
|
||||||
|
Type: scheduler.MsgJobAccepted,
|
||||||
|
Payload: MustMarshal(scheduler.JobResultPayload{
|
||||||
|
TaskID: taskID,
|
||||||
|
State: "accepted",
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteJob sends job completion with the given exit code.
// NOTE(review): the output string is carried in the payload's Error field —
// presumably JobResultPayload has no dedicated output field; confirm this is
// intentional rather than a field mix-up.
func (mw *MockWorker) CompleteJob(taskID string, exitCode int, output string) {
	mw.Send(scheduler.Message{
		Type: scheduler.MsgJobResult,
		Payload: MustMarshal(scheduler.JobResultPayload{
			TaskID:   taskID,
			State:    "completed",
			ExitCode: exitCode,
			Error:    output,
		}),
	})
}
|
||||||
|
|
||||||
|
// SendHealth sends service health update
|
||||||
|
func (mw *MockWorker) SendHealth(taskID string, healthy bool, message string) {
|
||||||
|
mw.Send(scheduler.Message{
|
||||||
|
Type: scheduler.MsgServiceHealth,
|
||||||
|
Payload: MustMarshal(scheduler.ServiceHealthPayload{
|
||||||
|
TaskID: taskID,
|
||||||
|
Healthy: healthy,
|
||||||
|
Message: message,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the worker connection and waits for the worker's
// background goroutines to finish. It is idempotent: the closed flag,
// guarded by mu, makes every call after the first a no-op.
func (mw *MockWorker) Close() {
	mw.mu.Lock()
	if mw.closed {
		mw.mu.Unlock()
		return
	}
	mw.closed = true
	mw.mu.Unlock()

	// Order matters: closing SendCh signals the send loop to exit,
	// closing the connection unblocks any pending read, and only then
	// do we wait for the goroutines tracked by wg to drain.
	close(mw.SendCh)
	mw.Conn.Close()
	mw.wg.Wait()
}
|
||||||
|
|
||||||
|
// WaitForDisconnect waits for the connection to close
|
||||||
|
func (mw *MockWorker) WaitForDisconnect(timeout time.Duration) bool {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
mw.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return true
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustMarshal marshals a value to JSON, panicking on error.
// The doc comment previously promised a panic but the error was
// silently discarded, so a failed marshal produced an empty payload;
// it now panics as advertised, which is the right behavior for a test
// fixture where a marshal failure is a programming bug.
func MustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return b
}
|
||||||
248
tests/integration/scheduler/distributed_test.go
Normal file
248
tests/integration/scheduler/distributed_test.go
Normal file
|
|
@ -0,0 +1,248 @@
|
||||||
|
package scheduler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestDistributedRoundTrip validates the full worker lifecycle through
// the scheduler over a real WebSocket connection: hub startup,
// token-authenticated dial, registration + ack, a heartbeat, and a
// ready-for-work signal.
func TestDistributedRoundTrip(t *testing.T) {
	// Create scheduler hub with token auth configured; the token maps
	// to the worker ID used throughout the test.
	testToken := "test-token-123"
	hub, err := scheduler.NewHub(scheduler.HubConfig{
		BindAddr:              "localhost:0", // port 0: let the OS pick a free port
		StateDir:              t.TempDir(),
		DefaultBatchSlots:     4,
		AcceptanceTimeoutSecs: 5,
		WorkerTokens: map[string]string{
			testToken: "test-worker",
		},
	}, nil)
	require.NoError(t, err)
	defer hub.Stop()

	// Start scheduler
	err = hub.Start()
	require.NoError(t, err)

	// Get scheduler address - use the actual listening address
	u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
	wsURL := u.String()

	// Create mock worker connection with auth token
	workerID := "test-worker"
	headers := http.Header{}
	headers.Set("Authorization", "Bearer "+testToken)
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
	require.NoError(t, err)
	defer conn.Close()

	// Start receive goroutine: pump inbound messages into recvCh so the
	// test body can select on them with timeouts. The channel is closed
	// when the read loop errors (e.g. on connection close).
	recvCh := make(chan scheduler.Message, 10)
	go func() {
		for {
			var msg scheduler.Message
			err := conn.ReadJSON(&msg)
			if err != nil {
				close(recvCh)
				return
			}
			recvCh <- msg
		}
	}()

	// Register worker
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgRegister,
		Payload: mustMarshal(scheduler.WorkerRegistration{
			ID: workerID,
			Capabilities: scheduler.WorkerCapabilities{
				GPUCount: 0,
				GPUType:  "",
			},
		}),
	})
	require.NoError(t, err)

	// Wait for ack
	select {
	case msg := <-recvCh:
		require.Equal(t, scheduler.MsgAck, msg.Type)
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for registration ack")
	}

	// Send heartbeat to show we're alive
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgHeartbeat,
		Payload: mustMarshal(scheduler.HeartbeatPayload{
			WorkerID: workerID,
			Slots: scheduler.SlotStatus{
				BatchTotal: 4,
				BatchInUse: 0,
			},
		}),
	})
	require.NoError(t, err)

	// Signal ready for work
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgReadyForWork,
		Payload: mustMarshal(scheduler.ReadyPayload{
			WorkerID: workerID,
			Slots: scheduler.SlotStatus{
				BatchTotal: 4,
				BatchInUse: 0,
			},
			Reason: "polling",
		}),
	})
	require.NoError(t, err)

	// Wait a bit and verify connection is still alive.
	// NOTE(review): there is no explicit assertion here — the test
	// passes as long as no earlier step failed; consider asserting on
	// hub-side worker state after the sleep.
	time.Sleep(200 * time.Millisecond)
}
|
||||||
|
|
||||||
|
// TestWorkerRegistration validates worker registration flow
|
||||||
|
func TestWorkerRegistration(t *testing.T) {
|
||||||
|
testToken := "reg-test-token"
|
||||||
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
||||||
|
BindAddr: "localhost:0",
|
||||||
|
StateDir: t.TempDir(),
|
||||||
|
DefaultBatchSlots: 4,
|
||||||
|
WorkerTokens: map[string]string{
|
||||||
|
testToken: "reg-test-worker",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer hub.Stop()
|
||||||
|
|
||||||
|
// Start scheduler
|
||||||
|
err = hub.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
||||||
|
wsURL := u.String()
|
||||||
|
|
||||||
|
// Connect worker with auth token
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("Authorization", "Bearer "+testToken)
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Receive channel
|
||||||
|
recvCh := make(chan scheduler.Message, 10)
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
var msg scheduler.Message
|
||||||
|
err := conn.ReadJSON(&msg)
|
||||||
|
if err != nil {
|
||||||
|
close(recvCh)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
recvCh <- msg
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Register with capabilities
|
||||||
|
workerID := "reg-test-worker"
|
||||||
|
err = conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgRegister,
|
||||||
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
||||||
|
ID: workerID,
|
||||||
|
Capabilities: scheduler.WorkerCapabilities{
|
||||||
|
GPUCount: 2,
|
||||||
|
GPUType: "nvidia",
|
||||||
|
CPUCount: 8,
|
||||||
|
MemoryGB: 32.0,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Expect ack
|
||||||
|
select {
|
||||||
|
case msg := <-recvCh:
|
||||||
|
assert.Equal(t, scheduler.MsgAck, msg.Type)
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for ack")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHeartbeat validates heartbeat slot reporting
|
||||||
|
func TestHeartbeat(t *testing.T) {
|
||||||
|
testToken := "hb-test-token"
|
||||||
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
||||||
|
BindAddr: "localhost:0",
|
||||||
|
StateDir: t.TempDir(),
|
||||||
|
DefaultBatchSlots: 4,
|
||||||
|
WorkerTokens: map[string]string{
|
||||||
|
testToken: "hb-test-worker",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer hub.Stop()
|
||||||
|
|
||||||
|
// Start scheduler
|
||||||
|
err = hub.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
||||||
|
wsURL := u.String()
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("Authorization", "Bearer "+testToken)
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
workerID := "hb-test-worker"
|
||||||
|
|
||||||
|
// Register first
|
||||||
|
err = conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgRegister,
|
||||||
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
||||||
|
ID: workerID,
|
||||||
|
Capabilities: scheduler.WorkerCapabilities{
|
||||||
|
GPUCount: 0,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Send multiple heartbeats
|
||||||
|
slots := []scheduler.SlotStatus{
|
||||||
|
{BatchTotal: 4, BatchInUse: 0},
|
||||||
|
{BatchTotal: 4, BatchInUse: 1},
|
||||||
|
{BatchTotal: 4, BatchInUse: 2},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, slot := range slots {
|
||||||
|
err = conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgHeartbeat,
|
||||||
|
Payload: mustMarshal(scheduler.HeartbeatPayload{
|
||||||
|
WorkerID: workerID,
|
||||||
|
Slots: slot,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection should remain healthy
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mustMarshal JSON-encodes v for test payloads, panicking on error so
// a malformed fixture fails loudly. Previously the error was silently
// discarded, which would have sent an empty payload on failure.
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return b
}
|
||||||
316
tests/integration/scheduler/gang_service_test.go
Normal file
316
tests/integration/scheduler/gang_service_test.go
Normal file
|
|
@ -0,0 +1,316 @@
|
||||||
|
package scheduler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestMultiNodeGangAllocation validates 2-node torchrun scenario
|
||||||
|
func TestMultiNodeGangAllocation(t *testing.T) {
|
||||||
|
// Create scheduler hub with gang timeout and auth tokens
|
||||||
|
tokens := map[string]string{
|
||||||
|
"worker1-token": "worker-1",
|
||||||
|
"worker2-token": "worker-2",
|
||||||
|
}
|
||||||
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
||||||
|
BindAddr: "localhost:0",
|
||||||
|
StateDir: t.TempDir(),
|
||||||
|
DefaultBatchSlots: 4,
|
||||||
|
GangAllocTimeoutSecs: 10,
|
||||||
|
WorkerTokens: tokens,
|
||||||
|
}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer hub.Stop()
|
||||||
|
|
||||||
|
// Start scheduler
|
||||||
|
err = hub.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Get scheduler address
|
||||||
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
||||||
|
wsURL := u.String()
|
||||||
|
|
||||||
|
// Create two worker connections with auth
|
||||||
|
worker1, recv1 := createTestWorkerWithToken(t, wsURL, "worker-1", "worker1-token")
|
||||||
|
worker2, recv2 := createTestWorkerWithToken(t, wsURL, "worker-2", "worker2-token")
|
||||||
|
defer worker1.Close()
|
||||||
|
defer worker2.Close()
|
||||||
|
|
||||||
|
// Register both workers
|
||||||
|
workers := []struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
recv <-chan scheduler.Message
|
||||||
|
id string
|
||||||
|
}{
|
||||||
|
{worker1, recv1, "worker-1"},
|
||||||
|
{worker2, recv2, "worker-2"},
|
||||||
|
}
|
||||||
|
for _, w := range workers {
|
||||||
|
w.conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgRegister,
|
||||||
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
||||||
|
ID: w.id,
|
||||||
|
Capabilities: scheduler.WorkerCapabilities{
|
||||||
|
GPUCount: 0,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
msg := <-w.recv
|
||||||
|
require.Equal(t, scheduler.MsgAck, msg.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Submit multi-node job (2 nodes)
|
||||||
|
jobID := "gang-job-001"
|
||||||
|
err = hub.SubmitJob(scheduler.JobSpec{
|
||||||
|
ID: jobID,
|
||||||
|
Type: scheduler.JobTypeBatch,
|
||||||
|
SlotPool: "batch",
|
||||||
|
NodeCount: 2,
|
||||||
|
Command: []string{"torchrun", "--nnodes=2", "train.py"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Both workers signal ready
|
||||||
|
for _, w := range []struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
id string
|
||||||
|
}{{worker1, "worker-1"}, {worker2, "worker-2"}} {
|
||||||
|
w.conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgReadyForWork,
|
||||||
|
Payload: mustMarshal(scheduler.ReadyPayload{
|
||||||
|
WorkerID: w.id,
|
||||||
|
Slots: scheduler.SlotStatus{BatchTotal: 4, BatchInUse: 0},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both workers should receive job assignments
|
||||||
|
msg1 := <-recv1
|
||||||
|
msg2 := <-recv2
|
||||||
|
|
||||||
|
require.Equal(t, scheduler.MsgJobAssign, msg1.Type)
|
||||||
|
require.Equal(t, scheduler.MsgJobAssign, msg2.Type)
|
||||||
|
|
||||||
|
// Verify both got the same job
|
||||||
|
var spec1, spec2 scheduler.JobSpec
|
||||||
|
json.Unmarshal(msg1.Payload, &spec1)
|
||||||
|
json.Unmarshal(msg2.Payload, &spec2)
|
||||||
|
|
||||||
|
assert.Equal(t, jobID, spec1.ID)
|
||||||
|
assert.Equal(t, jobID, spec2.ID)
|
||||||
|
|
||||||
|
// Verify ranks are different
|
||||||
|
assert.NotEqual(t, spec1.Env["NODE_RANK"], spec2.Env["NODE_RANK"])
|
||||||
|
assert.Equal(t, "2", spec1.Env["WORLD_SIZE"])
|
||||||
|
assert.Equal(t, "2", spec2.Env["WORLD_SIZE"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceLifecycle validates service job start, health checks, and stop
|
||||||
|
func TestServiceLifecycle(t *testing.T) {
|
||||||
|
testToken := "service-test-token"
|
||||||
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
||||||
|
BindAddr: "localhost:0",
|
||||||
|
StateDir: t.TempDir(),
|
||||||
|
DefaultBatchSlots: 4,
|
||||||
|
WorkerTokens: map[string]string{
|
||||||
|
testToken: "service-worker",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer hub.Stop()
|
||||||
|
|
||||||
|
err = hub.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
||||||
|
wsURL := u.String()
|
||||||
|
|
||||||
|
// Create worker with auth
|
||||||
|
conn, recvCh := createTestWorkerWithToken(t, wsURL, "service-worker", testToken)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Register
|
||||||
|
conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgRegister,
|
||||||
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
||||||
|
ID: "service-worker",
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
msg := <-recvCh
|
||||||
|
require.Equal(t, scheduler.MsgAck, msg.Type)
|
||||||
|
|
||||||
|
// Submit service job
|
||||||
|
jobID := "service-001"
|
||||||
|
err = hub.SubmitJob(scheduler.JobSpec{
|
||||||
|
ID: jobID,
|
||||||
|
Type: scheduler.JobTypeService,
|
||||||
|
SlotPool: "service",
|
||||||
|
Command: []string{"python", "-m", "http.server", "8080"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Signal ready
|
||||||
|
conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgReadyForWork,
|
||||||
|
Payload: mustMarshal(scheduler.ReadyPayload{
|
||||||
|
WorkerID: "service-worker",
|
||||||
|
Slots: scheduler.SlotStatus{ServiceTotal: 4, ServiceInUse: 0},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Should receive job assignment
|
||||||
|
assignMsg := <-recvCh
|
||||||
|
require.Equal(t, scheduler.MsgJobAssign, assignMsg.Type)
|
||||||
|
|
||||||
|
// Send job accepted
|
||||||
|
conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgJobAccepted,
|
||||||
|
Payload: mustMarshal(map[string]string{
|
||||||
|
"task_id": jobID,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Send periodic health updates
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgServiceHealth,
|
||||||
|
Payload: mustMarshal(scheduler.ServiceHealthPayload{
|
||||||
|
TaskID: jobID,
|
||||||
|
Healthy: true,
|
||||||
|
Message: "healthy",
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify task exists and is running
|
||||||
|
task := hub.GetTask(jobID)
|
||||||
|
require.NotNil(t, task)
|
||||||
|
assert.Equal(t, "running", task.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStarvationPrevention validates low-priority jobs eventually get scheduled
|
||||||
|
func TestStarvationPrevention(t *testing.T) {
|
||||||
|
testToken := "starvation-test-token"
|
||||||
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
||||||
|
BindAddr: "localhost:0",
|
||||||
|
StateDir: t.TempDir(),
|
||||||
|
DefaultBatchSlots: 2,
|
||||||
|
StarvationThresholdMins: 1, // 1 minute for testing
|
||||||
|
WorkerTokens: map[string]string{
|
||||||
|
testToken: "starvation-worker",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer hub.Stop()
|
||||||
|
|
||||||
|
err = hub.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
||||||
|
wsURL := u.String()
|
||||||
|
|
||||||
|
// Create worker with auth
|
||||||
|
conn, recvCh := createTestWorkerWithToken(t, wsURL, "starvation-worker", testToken)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Register
|
||||||
|
conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgRegister,
|
||||||
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
||||||
|
ID: "starvation-worker",
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
msg := <-recvCh
|
||||||
|
require.Equal(t, scheduler.MsgAck, msg.Type)
|
||||||
|
|
||||||
|
// Submit high-priority job
|
||||||
|
err = hub.SubmitJob(scheduler.JobSpec{
|
||||||
|
ID: "high-priority-job",
|
||||||
|
Type: scheduler.JobTypeBatch,
|
||||||
|
Env: map[string]string{"priority": "100"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Submit low-priority job
|
||||||
|
err = hub.SubmitJob(scheduler.JobSpec{
|
||||||
|
ID: "low-priority-job",
|
||||||
|
Type: scheduler.JobTypeBatch,
|
||||||
|
Env: map[string]string{"priority": "1"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Signal ready - should get high priority job first
|
||||||
|
conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgReadyForWork,
|
||||||
|
Payload: mustMarshal(scheduler.ReadyPayload{
|
||||||
|
WorkerID: "starvation-worker",
|
||||||
|
Slots: scheduler.SlotStatus{BatchTotal: 2, BatchInUse: 0},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
// First assignment should be high priority
|
||||||
|
msg1 := <-recvCh
|
||||||
|
require.Equal(t, scheduler.MsgJobAssign, msg1.Type)
|
||||||
|
|
||||||
|
var spec1 scheduler.JobSpec
|
||||||
|
json.Unmarshal(msg1.Payload, &spec1)
|
||||||
|
assert.Equal(t, "high-priority-job", spec1.ID)
|
||||||
|
|
||||||
|
// Complete first job
|
||||||
|
conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgJobResult,
|
||||||
|
Payload: mustMarshal(scheduler.JobResultPayload{
|
||||||
|
TaskID: "high-priority-job",
|
||||||
|
State: "completed",
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Signal ready again
|
||||||
|
conn.WriteJSON(scheduler.Message{
|
||||||
|
Type: scheduler.MsgReadyForWork,
|
||||||
|
Payload: mustMarshal(scheduler.ReadyPayload{
|
||||||
|
WorkerID: "starvation-worker",
|
||||||
|
Slots: scheduler.SlotStatus{BatchTotal: 2, BatchInUse: 0},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Should get low priority job
|
||||||
|
msg2 := <-recvCh
|
||||||
|
require.Equal(t, scheduler.MsgJobAssign, msg2.Type)
|
||||||
|
|
||||||
|
var spec2 scheduler.JobSpec
|
||||||
|
json.Unmarshal(msg2.Payload, &spec2)
|
||||||
|
assert.Equal(t, "low-priority-job", spec2.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create test worker with token auth
|
||||||
|
func createTestWorkerWithToken(t *testing.T, wsURL, workerID, token string) (*websocket.Conn, <-chan scheduler.Message) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("Authorization", "Bearer "+token)
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recvCh := make(chan scheduler.Message, 10)
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
var msg scheduler.Message
|
||||||
|
err := conn.ReadJSON(&msg)
|
||||||
|
if err != nil {
|
||||||
|
close(recvCh)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
recvCh <- msg
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return conn, recvCh
|
||||||
|
}
|
||||||
188
tests/unit/scheduler/failure_scenarios_test.go
Normal file
188
tests/unit/scheduler/failure_scenarios_test.go
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
package scheduler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
fixtures "github.com/jfraeys/fetch_ml/tests/fixtures"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mustMarshal JSON-encodes v for test payloads, panicking on error so
// a malformed fixture fails loudly. Previously the error was silently
// discarded, which would have sent an empty payload on failure.
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return b
}
|
||||||
|
|
||||||
|
// TestWorkerDeath simulates a worker dying mid-job
|
||||||
|
func TestWorkerDeath_MidJob(t *testing.T) {
|
||||||
|
// Use fixture for hub setup
|
||||||
|
fixture := fixtures.NewSchedulerTestFixture(t, fixtures.DefaultHubConfig())
|
||||||
|
defer fixture.Cleanup()
|
||||||
|
|
||||||
|
// Create mock worker
|
||||||
|
worker := fixture.CreateWorker("worker-death-test", scheduler.WorkerCapabilities{GPUCount: 0})
|
||||||
|
|
||||||
|
// Send heartbeat
|
||||||
|
worker.SendHeartbeat(scheduler.SlotStatus{
|
||||||
|
BatchTotal: 3,
|
||||||
|
BatchInUse: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate worker death
|
||||||
|
worker.Close()
|
||||||
|
|
||||||
|
// Verify disconnect
|
||||||
|
require.True(t, worker.WaitForDisconnect(2*time.Second), "worker should disconnect")
|
||||||
|
t.Log("Worker death simulation completed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSchedulerRestartRecovery simulates scheduler restart
|
||||||
|
func TestSchedulerRestart_Recovery(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
// Create initial state store
|
||||||
|
ss1, err := scheduler.NewStateStore(dir + "/state.json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Record some events
|
||||||
|
events := []scheduler.StateEvent{
|
||||||
|
{Type: scheduler.EventJobEnqueued, TaskID: "task-1", Timestamp: time.Now()},
|
||||||
|
{Type: scheduler.EventJobAssigned, TaskID: "task-1", WorkerID: "worker-1", Timestamp: time.Now()},
|
||||||
|
}
|
||||||
|
for _, e := range events {
|
||||||
|
require.NoError(t, ss1.Append(e))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate restart by creating new state store
|
||||||
|
ss2, err := scheduler.NewStateStore(dir + "/state.json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Replay should recover state
|
||||||
|
replayed, err := ss2.Replay()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, replayed, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSplitBrain_Case1: Worker reconnects with unknown task
|
||||||
|
func TestSplitBrain_UnknownTask(t *testing.T) {
|
||||||
|
fixture := fixtures.NewSchedulerTestFixture(t, fixtures.DefaultHubConfig())
|
||||||
|
defer fixture.Cleanup()
|
||||||
|
|
||||||
|
// Create mock worker
|
||||||
|
worker := fixture.CreateWorker("worker-split-1", scheduler.WorkerCapabilities{GPUCount: 0})
|
||||||
|
|
||||||
|
// Simulate reconnect with unknown task
|
||||||
|
worker.Send(scheduler.Message{
|
||||||
|
Type: scheduler.MsgRegister,
|
||||||
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
||||||
|
ID: "worker-split-1",
|
||||||
|
ActiveTasks: []scheduler.ActiveTaskReport{
|
||||||
|
{TaskID: "unknown-task", State: "running"},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Should receive cancel for unknown task
|
||||||
|
msg := worker.RecvTimeout(2 * time.Second)
|
||||||
|
if msg.Type == scheduler.MsgJobCancel {
|
||||||
|
t.Log("Received expected cancel for unknown task")
|
||||||
|
} else {
|
||||||
|
t.Logf("Received message type: %s (may need to check split-brain handling)", msg.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSplitBrain_Case2: Worker reconnects with orphaned task
|
||||||
|
func TestSplitBrain_OrphanedTask(t *testing.T) {
|
||||||
|
fixture := fixtures.NewSchedulerTestFixture(t, fixtures.DefaultHubConfig())
|
||||||
|
defer fixture.Cleanup()
|
||||||
|
|
||||||
|
// Create mock worker
|
||||||
|
worker := fixture.CreateWorker("worker-split-2", scheduler.WorkerCapabilities{GPUCount: 0})
|
||||||
|
|
||||||
|
// Simulate reconnect with orphaned task
|
||||||
|
worker.Send(scheduler.Message{
|
||||||
|
Type: scheduler.MsgRegister,
|
||||||
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
||||||
|
ID: "worker-split-2",
|
||||||
|
ActiveTasks: []scheduler.ActiveTaskReport{
|
||||||
|
{TaskID: "orphaned-task", State: "running"},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
msg := worker.RecvTimeout(2 * time.Second)
|
||||||
|
t.Logf("Received message type: %s", msg.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSplitBrain_Case3: Worker reconnects with re-queued task
|
||||||
|
func TestSplitBrain_RequeuedTask(t *testing.T) {
|
||||||
|
fixture := fixtures.NewSchedulerTestFixture(t, fixtures.DefaultHubConfig())
|
||||||
|
defer fixture.Cleanup()
|
||||||
|
|
||||||
|
// Create mock worker
|
||||||
|
worker := fixture.CreateWorker("worker-split-3", scheduler.WorkerCapabilities{GPUCount: 0})
|
||||||
|
|
||||||
|
// Simulate reconnect with re-queued task
|
||||||
|
worker.Send(scheduler.Message{
|
||||||
|
Type: scheduler.MsgRegister,
|
||||||
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
||||||
|
ID: "worker-split-3",
|
||||||
|
ActiveTasks: []scheduler.ActiveTaskReport{
|
||||||
|
{TaskID: "requeued-task", State: "queued"},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
msg := worker.RecvTimeout(2 * time.Second)
|
||||||
|
t.Logf("Received message type: %s", msg.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAcceptanceTimeout: Job assigned but never accepted
|
||||||
|
func TestAcceptanceTimeout(t *testing.T) {
|
||||||
|
cfg := fixtures.DefaultHubConfig()
|
||||||
|
cfg.AcceptanceTimeoutSecs = 1
|
||||||
|
fixture := fixtures.NewSchedulerTestFixture(t, cfg)
|
||||||
|
defer fixture.Cleanup()
|
||||||
|
|
||||||
|
// Create mock worker
|
||||||
|
worker := fixture.CreateWorker("worker-timeout", scheduler.WorkerCapabilities{GPUCount: 0})
|
||||||
|
|
||||||
|
// Signal ready but don't accept any job
|
||||||
|
worker.SignalReady(scheduler.SlotStatus{BatchTotal: 3, BatchInUse: 0}, "polling")
|
||||||
|
|
||||||
|
// Wait for potential assignment (but don't accept it)
|
||||||
|
msg, ok := worker.RecvNonBlock()
|
||||||
|
if ok && msg.Type == scheduler.MsgJobAssign {
|
||||||
|
t.Log("Received job assignment, not accepting to test timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for acceptance timeout
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
t.Log("Acceptance timeout test completed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGangTimeout: Multi-node job timeout during gang commit
|
||||||
|
func TestGangTimeout(t *testing.T) {
|
||||||
|
cfg := fixtures.DefaultHubConfig()
|
||||||
|
cfg.GangAllocTimeoutSecs = 1
|
||||||
|
fixture := fixtures.NewSchedulerTestFixture(t, cfg)
|
||||||
|
defer fixture.Cleanup()
|
||||||
|
|
||||||
|
// Create mock worker
|
||||||
|
worker := fixture.CreateWorker("worker-gang", scheduler.WorkerCapabilities{GPUCount: 0})
|
||||||
|
|
||||||
|
// Signal ready
|
||||||
|
worker.SignalReady(scheduler.SlotStatus{BatchTotal: 3, BatchInUse: 0}, "polling")
|
||||||
|
|
||||||
|
// Wait for potential assignment
|
||||||
|
_, _ = worker.RecvNonBlock()
|
||||||
|
|
||||||
|
// Wait for gang timeout
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
t.Log("Gang timeout test completed")
|
||||||
|
}
|
||||||
91
tests/unit/scheduler/port_allocator_test.go
Normal file
91
tests/unit/scheduler/port_allocator_test.go
Normal file
|
|
@ -0,0 +1,91 @@
|
||||||
|
package scheduler_test
|
||||||
|
|
||||||
|
import (
	"fmt"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/scheduler"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
|
||||||
|
|
||||||
|
func TestPortAllocator_BasicAllocation(t *testing.T) {
|
||||||
|
pa := scheduler.NewPortAllocator(10000, 10010)
|
||||||
|
|
||||||
|
// Allocate first port
|
||||||
|
port1, err := pa.Allocate("service-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 10000, port1)
|
||||||
|
|
||||||
|
// Allocate second port
|
||||||
|
port2, err := pa.Allocate("service-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 10001, port2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPortAllocator_Exhaustion(t *testing.T) {
|
||||||
|
pa := scheduler.NewPortAllocator(10000, 10002) // Only 3 ports available
|
||||||
|
|
||||||
|
// Allocate all ports
|
||||||
|
port1, _ := pa.Allocate("service-1")
|
||||||
|
assert.Equal(t, 10000, port1)
|
||||||
|
|
||||||
|
port2, _ := pa.Allocate("service-2")
|
||||||
|
assert.Equal(t, 10001, port2)
|
||||||
|
|
||||||
|
port3, _ := pa.Allocate("service-3")
|
||||||
|
assert.Equal(t, 10002, port3)
|
||||||
|
|
||||||
|
// Should fail - no ports left
|
||||||
|
_, err := pa.Allocate("service-4")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPortAllocator_Release(t *testing.T) {
|
||||||
|
pa := scheduler.NewPortAllocator(10000, 10005)
|
||||||
|
|
||||||
|
// Allocate and release
|
||||||
|
port, _ := pa.Allocate("service-1")
|
||||||
|
assert.Equal(t, 10000, port)
|
||||||
|
|
||||||
|
pa.Release(port)
|
||||||
|
|
||||||
|
// Should be able to allocate again
|
||||||
|
port2, err := pa.Allocate("service-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 10000, port2) // Reuses released port
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPortAllocator_DuplicateServiceID(t *testing.T) {
|
||||||
|
pa := scheduler.NewPortAllocator(10000, 10010)
|
||||||
|
|
||||||
|
// Allocate for service
|
||||||
|
port1, err := pa.Allocate("service-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Allocate again with same service ID - gets new port (current behavior)
|
||||||
|
port2, err := pa.Allocate("service-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEqual(t, port1, port2) // Each call returns new port
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPortAllocator_ConcurrentAccess(t *testing.T) {
|
||||||
|
pa := scheduler.NewPortAllocator(10000, 10100) // 101 ports
|
||||||
|
done := make(chan bool, 10)
|
||||||
|
|
||||||
|
// Concurrent allocations
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
for j := 0; j < 10; j++ {
|
||||||
|
pa.Allocate("service-")
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// All 100 ports should be allocated
|
||||||
|
// (10 goroutines * 10 allocations, but only 100 unique service IDs)
|
||||||
|
}
|
||||||
214
tests/unit/scheduler/priority_queue_test.go
Normal file
214
tests/unit/scheduler/priority_queue_test.go
Normal file
|
|
@ -0,0 +1,214 @@
|
||||||
|
package scheduler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPriorityQueue_BasicOperations(t *testing.T) {
|
||||||
|
q := scheduler.NewPriorityQueue(0.1)
|
||||||
|
|
||||||
|
task1 := &scheduler.Task{
|
||||||
|
ID: "task-1",
|
||||||
|
Priority: 10,
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
Spec: scheduler.JobSpec{ID: "task-1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
task2 := &scheduler.Task{
|
||||||
|
ID: "task-2",
|
||||||
|
Priority: 5,
|
||||||
|
SubmittedAt: time.Now(),
|
||||||
|
Spec: scheduler.JobSpec{ID: "task-2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tasks
|
||||||
|
q.Add(task1)
|
||||||
|
q.Add(task2)
|
||||||
|
|
||||||
|
require.Equal(t, 2, q.Len())
|
||||||
|
|
||||||
|
// Should return highest priority first (task1 with priority 10)
|
||||||
|
first := q.Take()
|
||||||
|
require.NotNil(t, first)
|
||||||
|
assert.Equal(t, "task-1", first.ID)
|
||||||
|
|
||||||
|
// Second task
|
||||||
|
second := q.Take()
|
||||||
|
require.NotNil(t, second)
|
||||||
|
assert.Equal(t, "task-2", second.ID)
|
||||||
|
|
||||||
|
// Queue empty
|
||||||
|
third := q.Take()
|
||||||
|
assert.Nil(t, third)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPriorityQueue_EffectivePriority_WithAging checks the aging formula:
// effective = base priority + (age in minutes * aging factor).
func TestPriorityQueue_EffectivePriority_WithAging(t *testing.T) {
	now := time.Now()

	// Task with lower priority but older submission
	oldTask := &scheduler.Task{
		ID:          "old-task",
		Priority:    5,
		SubmittedAt: now.Add(-10 * time.Minute), // 10 min old
	}

	// Task with higher priority but recent submission
	newTask := &scheduler.Task{
		ID:          "new-task",
		Priority:    10,
		SubmittedAt: now, // Just submitted
	}

	// Calculate effective priorities
	oldEffective := oldTask.EffectivePriority(0.1, now)
	newEffective := newTask.EffectivePriority(0.1, now)

	// Aging lifts the old task from 5 to 5 + (10 min * 0.1) = 6.0, but the
	// new task's 10 + (0 min * 0.1) = 10.0 still outranks it. Aging narrows
	// the gap here without overtaking, which is what the assertions pin.
	assert.Less(t, oldEffective, newEffective)
	assert.InDelta(t, 6.0, oldEffective, 0.1)
	assert.InDelta(t, 10.0, newEffective, 0.1)
}
|
||||||
|
|
||||||
|
func TestPriorityQueue_FIFOOnTie(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
q := scheduler.NewPriorityQueue(0.1)
|
||||||
|
|
||||||
|
// Two tasks with same priority, submitted at different times
|
||||||
|
task1 := &scheduler.Task{
|
||||||
|
ID: "task-1",
|
||||||
|
Priority: 10,
|
||||||
|
SubmittedAt: now.Add(-5 * time.Minute),
|
||||||
|
Spec: scheduler.JobSpec{ID: "task-1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
task2 := &scheduler.Task{
|
||||||
|
ID: "task-2",
|
||||||
|
Priority: 10,
|
||||||
|
SubmittedAt: now.Add(-1 * time.Minute),
|
||||||
|
Spec: scheduler.JobSpec{ID: "task-2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add in reverse order
|
||||||
|
q.Add(task2)
|
||||||
|
q.Add(task1)
|
||||||
|
|
||||||
|
// Should return older task first (FIFO on tie)
|
||||||
|
first := q.Take()
|
||||||
|
require.NotNil(t, first)
|
||||||
|
assert.Equal(t, "task-1", first.ID)
|
||||||
|
|
||||||
|
second := q.Take()
|
||||||
|
require.NotNil(t, second)
|
||||||
|
assert.Equal(t, "task-2", second.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPriorityQueue_Remove(t *testing.T) {
|
||||||
|
q := scheduler.NewPriorityQueue(0.1)
|
||||||
|
|
||||||
|
task1 := &scheduler.Task{ID: "task-1", Priority: 10, Spec: scheduler.JobSpec{ID: "task-1"}}
|
||||||
|
task2 := &scheduler.Task{ID: "task-2", Priority: 5, Spec: scheduler.JobSpec{ID: "task-2"}}
|
||||||
|
task3 := &scheduler.Task{ID: "task-3", Priority: 1, Spec: scheduler.JobSpec{ID: "task-3"}}
|
||||||
|
|
||||||
|
q.Add(task1)
|
||||||
|
q.Add(task2)
|
||||||
|
q.Add(task3)
|
||||||
|
|
||||||
|
// Remove middle task
|
||||||
|
removed := q.Remove("task-2")
|
||||||
|
assert.True(t, removed)
|
||||||
|
assert.Equal(t, 2, q.Len())
|
||||||
|
|
||||||
|
// Try to remove non-existent
|
||||||
|
removed = q.Remove("non-existent")
|
||||||
|
assert.False(t, removed)
|
||||||
|
|
||||||
|
// Verify remaining order
|
||||||
|
first := q.Take()
|
||||||
|
assert.Equal(t, "task-1", first.ID)
|
||||||
|
second := q.Take()
|
||||||
|
assert.Equal(t, "task-3", second.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPriorityQueue_Get(t *testing.T) {
|
||||||
|
q := scheduler.NewPriorityQueue(0.1)
|
||||||
|
|
||||||
|
task1 := &scheduler.Task{ID: "task-1", Priority: 10, Spec: scheduler.JobSpec{ID: "task-1"}}
|
||||||
|
q.Add(task1)
|
||||||
|
|
||||||
|
// Get existing task
|
||||||
|
found := q.Get("task-1")
|
||||||
|
assert.NotNil(t, found)
|
||||||
|
assert.Equal(t, "task-1", found.ID)
|
||||||
|
|
||||||
|
// Get non-existent
|
||||||
|
notFound := q.Get("non-existent")
|
||||||
|
assert.Nil(t, notFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPriorityQueue_Items(t *testing.T) {
|
||||||
|
q := scheduler.NewPriorityQueue(0.1)
|
||||||
|
|
||||||
|
tasks := []*scheduler.Task{
|
||||||
|
{ID: "task-1", Priority: 10, Spec: scheduler.JobSpec{ID: "task-1"}},
|
||||||
|
{ID: "task-2", Priority: 5, Spec: scheduler.JobSpec{ID: "task-2"}},
|
||||||
|
{ID: "task-3", Priority: 1, Spec: scheduler.JobSpec{ID: "task-3"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, task := range tasks {
|
||||||
|
q.Add(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
items := q.Items()
|
||||||
|
require.Len(t, items, 3)
|
||||||
|
|
||||||
|
// Items should be in priority order (highest first)
|
||||||
|
assert.Equal(t, "task-1", items[0].ID)
|
||||||
|
assert.Equal(t, "task-2", items[1].ID)
|
||||||
|
assert.Equal(t, "task-3", items[2].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPriorityQueue_ConcurrentAccess(t *testing.T) {
|
||||||
|
q := scheduler.NewPriorityQueue(0.1)
|
||||||
|
done := make(chan bool, 3)
|
||||||
|
|
||||||
|
// Concurrent adds
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
q.Add(&scheduler.Task{ID: fmt.Sprintf("task-%d", i), Priority: i})
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Concurrent takes
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
q.Take()
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Concurrent peeks
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
q.Peek()
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for all goroutines
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Queue should be in consistent state
|
||||||
|
assert.GreaterOrEqual(t, q.Len(), 0)
|
||||||
|
assert.LessOrEqual(t, q.Len(), 100)
|
||||||
|
}
|
||||||
264
tests/unit/scheduler/service_templates_test.go
Normal file
264
tests/unit/scheduler/service_templates_test.go
Normal file
|
|
@ -0,0 +1,264 @@
|
||||||
|
package scheduler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestJupyterLabTemplate validates the JupyterLab template configuration
|
||||||
|
func TestJupyterLabTemplate(t *testing.T) {
|
||||||
|
template := scheduler.JupyterLabTemplate
|
||||||
|
|
||||||
|
assert.Equal(t, "service", template.JobType)
|
||||||
|
assert.Equal(t, "service", template.SlotPool)
|
||||||
|
assert.Equal(t, 0, template.GPUCount)
|
||||||
|
|
||||||
|
// Verify command includes required flags
|
||||||
|
require.NotEmpty(t, template.Command)
|
||||||
|
assert.Contains(t, template.Command, "jupyter")
|
||||||
|
assert.Contains(t, template.Command, "lab")
|
||||||
|
assert.Contains(t, template.Command, "--ip=0.0.0.0")
|
||||||
|
assert.Contains(t, template.Command, "--port={{SERVICE_PORT}}")
|
||||||
|
assert.Contains(t, template.Command, "--no-browser")
|
||||||
|
|
||||||
|
// Verify health checks
|
||||||
|
assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/api", template.HealthCheck.Liveness)
|
||||||
|
assert.Equal(t, "http://localhost:{{SERVICE_PORT}}/api/kernels", template.HealthCheck.Readiness)
|
||||||
|
assert.Equal(t, 15, template.HealthCheck.Interval)
|
||||||
|
assert.Equal(t, 5, template.HealthCheck.Timeout)
|
||||||
|
|
||||||
|
// Verify mounts
|
||||||
|
require.Len(t, template.Mounts, 1)
|
||||||
|
assert.Equal(t, "{{WORKSPACE}}", template.Mounts[0].Source)
|
||||||
|
assert.Equal(t, "/workspace", template.Mounts[0].Destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestJupyterNotebookTemplate validates the classic notebook template
|
||||||
|
func TestJupyterNotebookTemplate(t *testing.T) {
|
||||||
|
template := scheduler.JupyterNotebookTemplate
|
||||||
|
|
||||||
|
assert.Equal(t, "service", template.JobType)
|
||||||
|
assert.Equal(t, "service", template.SlotPool)
|
||||||
|
assert.Equal(t, 0, template.GPUCount)
|
||||||
|
|
||||||
|
// Verify uses notebook subcommand
|
||||||
|
require.NotEmpty(t, template.Command)
|
||||||
|
assert.Contains(t, template.Command, "notebook")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVLLMTemplate validates the vLLM inference template
|
||||||
|
func TestVLLMTemplate(t *testing.T) {
|
||||||
|
template := scheduler.VLLMTemplate
|
||||||
|
|
||||||
|
assert.Equal(t, "service", template.JobType)
|
||||||
|
assert.Equal(t, "service", template.SlotPool)
|
||||||
|
assert.Equal(t, 1, template.GPUCount) // Requires GPU
|
||||||
|
|
||||||
|
// Verify command
|
||||||
|
require.NotEmpty(t, template.Command)
|
||||||
|
assert.Contains(t, template.Command, "vllm.entrypoints.openai.api_server")
|
||||||
|
assert.Contains(t, template.Command, "{{MODEL_NAME}}")
|
||||||
|
assert.Contains(t, template.Command, "{{SERVICE_PORT}}")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPortAllocatorForServices validates port allocation for service jobs
|
||||||
|
func TestPortAllocatorForServices(t *testing.T) {
|
||||||
|
pa := scheduler.NewPortAllocator(10000, 10010)
|
||||||
|
|
||||||
|
// Allocate a port for Jupyter service
|
||||||
|
port1, err := pa.Allocate("jupyter-task-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, port1 >= 10000 && port1 <= 10010)
|
||||||
|
|
||||||
|
// Verify we can get the task for this port
|
||||||
|
taskID := pa.GetAllocation(port1)
|
||||||
|
assert.Equal(t, "jupyter-task-1", taskID)
|
||||||
|
|
||||||
|
// Allocate another port
|
||||||
|
port2, err := pa.Allocate("jupyter-task-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEqual(t, port1, port2)
|
||||||
|
|
||||||
|
// Release first port
|
||||||
|
pa.Release(port1)
|
||||||
|
|
||||||
|
// Verify port is now available
|
||||||
|
taskID = pa.GetAllocation(port1)
|
||||||
|
assert.Equal(t, "", taskID)
|
||||||
|
|
||||||
|
// Can reallocate the same port
|
||||||
|
port3, err := pa.Allocate("jupyter-task-3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Should get first available (which might be port1)
|
||||||
|
assert.True(t, port3 >= 10000 && port3 <= 10010)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPortAllocatorExhaustion validates behavior when no ports available
|
||||||
|
func TestPortAllocatorExhaustion(t *testing.T) {
|
||||||
|
// Small range for testing
|
||||||
|
pa := scheduler.NewPortAllocator(20000, 20002)
|
||||||
|
|
||||||
|
// Allocate all ports
|
||||||
|
_, err := pa.Allocate("task-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = pa.Allocate("task-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = pa.Allocate("task-3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Fourth allocation should fail
|
||||||
|
_, err = pa.Allocate("task-4")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no ports available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPortAllocatorTTL validates port TTL behavior
|
||||||
|
func TestPortAllocatorTTL(t *testing.T) {
|
||||||
|
pa := scheduler.NewPortAllocator(30000, 30010)
|
||||||
|
|
||||||
|
// Set short TTL for testing
|
||||||
|
pa.SetTTL(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Allocate a port
|
||||||
|
port1, err := pa.Allocate("test-task")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Release it (marks with expired timestamp due to short TTL)
|
||||||
|
pa.Release(port1)
|
||||||
|
|
||||||
|
// Immediately try to allocate - should get different port since released one is "expired"
|
||||||
|
port2, err := pa.Allocate("test-task-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Could be same or different depending on cleanup timing
|
||||||
|
assert.True(t, port2 >= 30000 && port2 <= 30010)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceSlotPoolSeparation validates that service and batch use different pools
|
||||||
|
func TestServiceSlotPoolSeparation(t *testing.T) {
|
||||||
|
// This test validates the conceptual separation
|
||||||
|
// In practice, the scheduler maintains separate queues
|
||||||
|
|
||||||
|
// Use JupyterLabTemplate which has health checks configured
|
||||||
|
serviceJob := scheduler.JupyterLabTemplate
|
||||||
|
|
||||||
|
batchJob := scheduler.JobSpec{
|
||||||
|
ID: "batch-1",
|
||||||
|
SlotPool: "batch",
|
||||||
|
GPUCount: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify different slot pools
|
||||||
|
assert.Equal(t, "service", serviceJob.SlotPool)
|
||||||
|
assert.Equal(t, "batch", batchJob.SlotPool)
|
||||||
|
|
||||||
|
// Service job has health checks
|
||||||
|
assert.NotZero(t, serviceJob.HealthCheck.Interval)
|
||||||
|
|
||||||
|
// Batch job would typically not have health checks
|
||||||
|
// (it runs to completion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHealthCheckValidation validates health check configuration
|
||||||
|
func TestHealthCheckValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
template scheduler.ServiceTemplate
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "JupyterLab - valid",
|
||||||
|
template: scheduler.ServiceTemplate{
|
||||||
|
JobType: "service",
|
||||||
|
SlotPool: "service",
|
||||||
|
HealthCheck: scheduler.ServiceHealthCheck{
|
||||||
|
Liveness: "http://localhost:8888/api",
|
||||||
|
Readiness: "http://localhost:8888/api/kernels",
|
||||||
|
Interval: 15,
|
||||||
|
Timeout: 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing liveness - invalid",
|
||||||
|
template: scheduler.ServiceTemplate{
|
||||||
|
JobType: "service",
|
||||||
|
SlotPool: "service",
|
||||||
|
HealthCheck: scheduler.ServiceHealthCheck{
|
||||||
|
Readiness: "http://localhost:8888/api",
|
||||||
|
Interval: 15,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Zero interval - invalid",
|
||||||
|
template: scheduler.ServiceTemplate{
|
||||||
|
JobType: "service",
|
||||||
|
SlotPool: "service",
|
||||||
|
HealthCheck: scheduler.ServiceHealthCheck{
|
||||||
|
Liveness: "http://localhost:8888/api",
|
||||||
|
Readiness: "http://localhost:8888/api",
|
||||||
|
Interval: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hc := tt.template.HealthCheck
|
||||||
|
isValid := hc.Liveness != "" && hc.Interval > 0 && hc.Timeout > 0
|
||||||
|
assert.Equal(t, tt.valid, isValid)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultPortRange validates the default service port range
|
||||||
|
func TestDefaultPortRange(t *testing.T) {
|
||||||
|
// Default range should be large enough for typical deployments
|
||||||
|
rangeSize := scheduler.DefaultServicePortEnd - scheduler.DefaultServicePortStart
|
||||||
|
assert.True(t, rangeSize >= 1000, "Default port range should be at least 1000 ports")
|
||||||
|
assert.Equal(t, 8000, scheduler.DefaultServicePortStart)
|
||||||
|
assert.Equal(t, 9000, scheduler.DefaultServicePortEnd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTemplateVariableExpansion validates template variables are present
|
||||||
|
func TestTemplateVariableExpansion(t *testing.T) {
|
||||||
|
template := scheduler.JupyterLabTemplate
|
||||||
|
|
||||||
|
// Check command contains template variables
|
||||||
|
hasServicePort := false
|
||||||
|
for _, cmd := range template.Command {
|
||||||
|
if cmd == "--port={{SERVICE_PORT}}" {
|
||||||
|
hasServicePort = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasServicePort, "Command should contain {{SERVICE_PORT}} template variable")
|
||||||
|
|
||||||
|
// Check env contains secret template
|
||||||
|
val, ok := template.Env["JUPYTER_TOKEN"]
|
||||||
|
assert.True(t, ok, "Should have JUPYTER_TOKEN env var")
|
||||||
|
assert.Contains(t, val, "{{SECRET:", "Should use secret template")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkPortAllocation benchmarks port allocation performance
|
||||||
|
func BenchmarkPortAllocation(b *testing.B) {
|
||||||
|
pa := scheduler.NewPortAllocator(40000, 41000)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
port, err := pa.Allocate("bench-task")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
pa.Release(port)
|
||||||
|
}
|
||||||
|
}
|
||||||
67
tests/unit/scheduler/state_store_test.go
Normal file
67
tests/unit/scheduler/state_store_test.go
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
package scheduler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStateStore_BasicOperations(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
ss, err := scheduler.NewStateStore(dir + "/state.json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Append some events
|
||||||
|
events := []scheduler.StateEvent{
|
||||||
|
{Type: scheduler.EventJobEnqueued, TaskID: "task-1", Timestamp: time.Now()},
|
||||||
|
{Type: scheduler.EventJobAssigned, TaskID: "task-1", WorkerID: "worker-1", Timestamp: time.Now()},
|
||||||
|
{Type: scheduler.EventJobCompleted, TaskID: "task-1", WorkerID: "worker-1", Timestamp: time.Now()},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range events {
|
||||||
|
err := ss.Append(e)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replay events
|
||||||
|
replayed, err := ss.Replay()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, replayed, 3)
|
||||||
|
assert.Equal(t, "task-1", replayed[0].TaskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateStore_Persistence(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
// Create store and append events
|
||||||
|
ss1, err := scheduler.NewStateStore(dir + "/state.json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
event := scheduler.StateEvent{
|
||||||
|
Type: scheduler.EventJobEnqueued,
|
||||||
|
TaskID: "persistent-task",
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
err = ss1.Append(event)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create new store instance pointing to same directory
|
||||||
|
ss2, err := scheduler.NewStateStore(dir + "/state.json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
replayed, err := ss2.Replay()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, replayed, 1)
|
||||||
|
assert.Equal(t, "persistent-task", replayed[0].TaskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateStore_ReplayEmpty(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
ss, err := scheduler.NewStateStore(dir + "/state.json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
replayed, err := ss.Replay()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, replayed)
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue