fetch_ml/cmd/scheduler/main.go
Jeremie Fraeys 43e6446587
feat(scheduler): implement multi-tenant job scheduler with gang scheduling
Add new scheduler component for distributed ML workload orchestration:
- Hub-based coordination for multi-worker clusters
- Pacing controller for rate limiting job submissions
- Priority queue with preemption support
- Port allocator for dynamic service discovery
- Protocol handlers for worker-scheduler communication
- Service manager with OS-specific implementations
- Connection management and state persistence
- Template system for service deployment

Includes comprehensive test suite:
- Unit tests for all core components
- Integration tests for distributed scenarios
- Benchmark tests for performance validation
- Mock fixtures for isolated testing

Refs: scheduler-architecture.md
2026-02-26 12:03:23 -05:00

274 lines
8.1 KiB
Go

package main
import (
	"context"
	"flag"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/audit"
	"github.com/jfraeys/fetch_ml/internal/scheduler"
	"gopkg.in/yaml.v3"
)
// Scheduler configuration types, mapped one-to-one onto the YAML config file
// that loadConfig reads and generateConfig writes.
type (
	// Config is the root of the YAML document; every setting lives under the
	// top-level "scheduler" key.
	Config struct {
		Scheduler SchedulerConfig `yaml:"scheduler"`
	}

	// SchedulerConfig holds the scheduler settings. loadConfig replaces an
	// empty bind_addr/state_dir and zero numeric values with built-in defaults.
	SchedulerConfig struct {
		// BindAddr is the listen address of the HTTP(S) server that serves the
		// worker endpoint, /health and /metrics.
		BindAddr string `yaml:"bind_addr"`

		// CertFile is the TLS certificate path; when non-empty the server
		// listens with HTTPS instead of plain HTTP.
		CertFile string `yaml:"cert_file"`

		// KeyFile is the TLS private key that pairs with CertFile.
		KeyFile string `yaml:"key_file"`

		// AutoGenerateCerts asks main to create a self-signed certificate at
		// CertFile when that file does not exist yet.
		AutoGenerateCerts bool `yaml:"auto_generate_certs"`

		// StateDir is handed to the scheduler hub unchanged.
		StateDir string `yaml:"state_dir"`

		// DefaultBatchSlots and DefaultServiceSlots are handed to the hub
		// unchanged (presumably per-worker slot counts for batch and service
		// jobs — confirm against scheduler.HubConfig).
		DefaultBatchSlots int `yaml:"default_batch_slots"`

		DefaultServiceSlots int `yaml:"default_service_slots"`

		// StarvationThresholdMins and PriorityAgingRate are handed to the hub
		// unchanged; the former is expressed in minutes per its name.
		StarvationThresholdMins float64 `yaml:"starvation_threshold_mins"`

		PriorityAgingRate float64 `yaml:"priority_aging_rate"`

		// GangAllocTimeoutSecs and AcceptanceTimeoutSecs are handed to the hub
		// unchanged; both are expressed in seconds per their names.
		GangAllocTimeoutSecs int `yaml:"gang_alloc_timeout_secs"`

		AcceptanceTimeoutSecs int `yaml:"acceptance_timeout_secs"`

		// MetricsAddr is written by generateConfig but not read by main in this
		// file; /metrics is currently served on BindAddr.
		MetricsAddr string `yaml:"metrics_addr"`

		// WorkerTokens lists one credential per worker; main turns it into a
		// token -> worker ID map before passing it to the hub.
		WorkerTokens []WorkerToken `yaml:"worker_tokens"`
	}

	// WorkerToken pairs a worker's ID with the secret token it presents.
	WorkerToken struct {
		ID    string `yaml:"id"`
		Token string `yaml:"token"`
	}
)
// main runs the fetch_ml scheduler.
//
// Flags:
//   - -config PATH: scheduler config file (default "scheduler.yaml").
//   - -generate-token: print one new worker token and exit.
//   - -init: write a new config file to -config containing -tokens generated
//     worker tokens, then exit.
//   - -tokens N: number of tokens generated by -init (default 3).
//
// Otherwise it loads the config, starts the scheduler hub and serves
// /ws/worker, /health and /metrics on bind_addr (HTTPS when cert_file is set)
// until SIGINT/SIGTERM arrives or the HTTP server fails. Any startup or server
// failure makes the process exit with status 1.
func main() {
	// Upper bound on how long in-flight HTTP requests may take to finish
	// during shutdown.
	const shutdownTimeout = 10 * time.Second
	// Upper bound on reading request headers; protects against clients that
	// trickle headers to hold connections open.
	const readHeaderTimeout = 10 * time.Second

	var (
		configPath    string
		generateToken bool
		initConfig    bool
		numTokens     int
	)
	flag.StringVar(&configPath, "config", "scheduler.yaml", "Path to scheduler config file")
	flag.BoolVar(&generateToken, "generate-token", false, "Generate a new worker token and exit")
	flag.BoolVar(&initConfig, "init", false, "Initialize a new config file with generated tokens")
	flag.IntVar(&numTokens, "tokens", 3, "Number of tokens to generate (used with -init)")
	flag.Parse()

	// One-shot mode: print a single worker token.
	if generateToken {
		fmt.Println(scheduler.GenerateWorkerToken())
		os.Exit(0)
	}

	// One-shot mode: write a fresh config file.
	if initConfig {
		if err := generateConfig(configPath, numTokens); err != nil {
			fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err)
			os.Exit(1)
		}
		fmt.Printf("Config generated: %s\n", configPath)
		fmt.Printf("\nGenerated %d worker tokens. Copy the appropriate token to each worker's config.\n", numTokens)
		os.Exit(0)
	}

	// Install the JSON logger before loading the config so that config errors
	// use the same structured format as every later log line.
	logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))
	slog.SetDefault(logger)

	cfg, err := loadConfig(configPath)
	if err != nil {
		logger.Error("failed to load config", "error", err)
		os.Exit(1)
	}
	sc := &cfg.Scheduler

	// Resolve the default key path once so that the certificate generator, the
	// hub and the TLS listener all use the same file. ListenAndServeTLS cannot
	// load an empty key path.
	if sc.CertFile != "" && sc.KeyFile == "" {
		sc.KeyFile = sc.CertFile + ".key"
	}

	// Build the token -> worker ID map handed to the hub as
	// HubConfig.WorkerTokens (presumably used to authenticate workers).
	tokenMap := make(map[string]string, len(sc.WorkerTokens))
	for _, wt := range sc.WorkerTokens {
		if wt.Token == "" {
			// Never register an empty token: it could match a worker that
			// presents no credentials at all.
			logger.Warn("ignoring worker token entry with empty token", "worker_id", wt.ID)
			continue
		}
		if prevID, dup := tokenMap[wt.Token]; dup {
			logger.Warn("duplicate worker token, later entry wins", "worker_id", wt.ID, "previous_worker_id", prevID)
		}
		tokenMap[wt.Token] = wt.ID
	}

	// Auto-generate a self-signed certificate if requested and none exists yet.
	if sc.AutoGenerateCerts && sc.CertFile != "" {
		if _, err := os.Stat(sc.CertFile); os.IsNotExist(err) {
			logger.Info("generating self-signed certificate", "cert", sc.CertFile, "key", sc.KeyFile)
			if err := scheduler.GenerateSelfSignedCert(sc.CertFile, sc.KeyFile); err != nil {
				logger.Error("failed to generate certificate", "error", err)
				os.Exit(1)
			}
		}
	}

	hubCfg := scheduler.HubConfig{
		BindAddr:                sc.BindAddr,
		CertFile:                sc.CertFile,
		KeyFile:                 sc.KeyFile,
		AutoGenerateCerts:       sc.AutoGenerateCerts,
		StateDir:                sc.StateDir,
		DefaultBatchSlots:       sc.DefaultBatchSlots,
		DefaultServiceSlots:     sc.DefaultServiceSlots,
		StarvationThresholdMins: sc.StarvationThresholdMins,
		PriorityAgingRate:       sc.PriorityAgingRate,
		GangAllocTimeoutSecs:    sc.GangAllocTimeoutSecs,
		AcceptanceTimeoutSecs:   sc.AcceptanceTimeoutSecs,
		WorkerTokens:            tokenMap,
	}

	// Auditing is not wired up here: NewHub receives a nil *audit.Logger
	// (presumably treated as "auditing disabled" — confirm in scheduler.NewHub).
	var auditor *audit.Logger

	hub, err := scheduler.NewHub(hubCfg, auditor)
	if err != nil {
		logger.Error("failed to create scheduler hub", "error", err)
		os.Exit(1)
	}
	if err := hub.Start(); err != nil {
		logger.Error("failed to start scheduler hub", "error", err)
		os.Exit(1)
	}
	logger.Info("scheduler hub started", "bind_addr", sc.BindAddr)

	mux := http.NewServeMux()
	mux.HandleFunc("/ws/worker", hub.HandleConnection)
	mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		// A write error only means the client went away; nothing to recover.
		_, _ = w.Write([]byte(`{"status":"ok"}`))
	})
	mux.HandleFunc("/metrics", hub.ServeMetrics)

	srv := &http.Server{
		Addr:    sc.BindAddr,
		Handler: mux,
		// Only the header read is bounded. /ws/worker appears to hold
		// long-lived connections, which a whole-request ReadTimeout or a
		// WriteTimeout could cut off.
		ReadHeaderTimeout: readHeaderTimeout,
	}

	// Register for signals before serving so an early SIGTERM is not lost.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	// The server goroutine reports its exit so that a failed listen (e.g.
	// port already in use) stops the process instead of leaving it waiting
	// for a signal forever. Buffered so the goroutine never blocks.
	serveErr := make(chan error, 1)
	go func() {
		var err error
		if sc.CertFile != "" {
			logger.Info("starting HTTPS server", "addr", sc.BindAddr)
			err = srv.ListenAndServeTLS(sc.CertFile, sc.KeyFile)
		} else {
			logger.Info("starting HTTP server", "addr", sc.BindAddr)
			err = srv.ListenAndServe()
		}
		serveErr <- err
	}()

	exitCode := 0
	select {
	case <-sigChan:
	case err := <-serveErr:
		// Shutdown has not been called yet, so any return here is a failure.
		logger.Error("server error", "error", err)
		exitCode = 1
	}

	logger.Info("shutting down scheduler...")
	ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
	if err := srv.Shutdown(ctx); err != nil {
		logger.Warn("http server did not shut down cleanly", "error", err)
	}
	cancel()
	hub.Stop()
	logger.Info("scheduler stopped")
	if exitCode != 0 {
		os.Exit(exitCode)
	}
}
// loadConfig reads the YAML scheduler config at path, fills in defaults for
// unset values and validates the result.
//
// Defaults, applied when the field is zero/empty in the file:
//
//	bind_addr                 0.0.0.0:7777
//	state_dir                 /var/lib/fetch_ml
//	default_batch_slots       3
//	default_service_slots     1
//	starvation_threshold_mins 5
//	priority_aging_rate       0.1
//	gang_alloc_timeout_secs   60
//	acceptance_timeout_secs   30
//	key_file                  <cert_file>.key (only when cert_file is set)
//
// Because zero means "use the default", the numeric settings above cannot be
// set to 0 explicitly. Negative slot counts, thresholds and timeouts are
// rejected with an error.
func loadConfig(path string) (*Config, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read config file: %w", err)
	}

	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}
	s := &cfg.Scheduler

	if s.BindAddr == "" {
		s.BindAddr = "0.0.0.0:7777"
	}
	if s.StateDir == "" {
		s.StateDir = "/var/lib/fetch_ml"
	}
	if s.DefaultBatchSlots == 0 {
		s.DefaultBatchSlots = 3
	}
	if s.DefaultServiceSlots == 0 {
		s.DefaultServiceSlots = 1
	}
	if s.StarvationThresholdMins == 0 {
		s.StarvationThresholdMins = 5
	}
	if s.PriorityAgingRate == 0 {
		s.PriorityAgingRate = 0.1
	}
	if s.GangAllocTimeoutSecs == 0 {
		s.GangAllocTimeoutSecs = 60
	}
	if s.AcceptanceTimeoutSecs == 0 {
		s.AcceptanceTimeoutSecs = 30
	}
	// A TLS listener cannot load an empty key path; pair the key with the
	// certificate by default.
	if s.CertFile != "" && s.KeyFile == "" {
		s.KeyFile = s.CertFile + ".key"
	}

	// The "== 0" checks above let negative values through unchanged; a
	// negative count or duration is never meaningful, so fail loudly.
	nonNegative := []struct {
		name  string
		value float64
	}{
		{"default_batch_slots", float64(s.DefaultBatchSlots)},
		{"default_service_slots", float64(s.DefaultServiceSlots)},
		{"starvation_threshold_mins", s.StarvationThresholdMins},
		{"gang_alloc_timeout_secs", float64(s.GangAllocTimeoutSecs)},
		{"acceptance_timeout_secs", float64(s.AcceptanceTimeoutSecs)},
	}
	for _, f := range nonNegative {
		if f.value < 0 {
			return nil, fmt.Errorf("invalid config: scheduler.%s must not be negative (got %v)", f.name, f.value)
		}
	}

	return &cfg, nil
}
// generateConfig writes a new scheduler config file to path and prints the
// generated worker tokens to stdout for distribution.
//
// The file holds default settings plus numTokens freshly generated worker
// tokens with IDs "worker-01", "worker-02", and so on. It is created with mode
// 0600 and is never overwritten: if path already exists an error is returned.
// This means re-running -init cannot silently replace tokens that were already
// handed out to workers. It also prevents the result from inheriting the old
// file's (possibly wider) permissions.
//
// numTokens must be at least 1.
func generateConfig(path string, numTokens int) error {
	if numTokens < 1 {
		return fmt.Errorf("number of tokens must be at least 1, got %d", numTokens)
	}

	tokens := make([]WorkerToken, 0, numTokens)
	for i := 1; i <= numTokens; i++ {
		tokens = append(tokens, WorkerToken{
			ID:    fmt.Sprintf("worker-%02d", i),
			Token: scheduler.GenerateWorkerToken(),
		})
	}

	cfg := Config{
		Scheduler: SchedulerConfig{
			BindAddr:                "0.0.0.0:7777",
			AutoGenerateCerts:       true,
			CertFile:                "/etc/fetch_ml/scheduler.crt",
			KeyFile:                 "/etc/fetch_ml/scheduler.key",
			StateDir:                "/var/lib/fetch_ml",
			DefaultBatchSlots:       3,
			DefaultServiceSlots:     1,
			StarvationThresholdMins: 5,
			PriorityAgingRate:       0.1,
			GangAllocTimeoutSecs:    60,
			AcceptanceTimeoutSecs:   30,
			MetricsAddr:             "0.0.0.0:9090",
			WorkerTokens:            tokens,
		},
	}

	data, err := yaml.Marshal(cfg)
	if err != nil {
		return fmt.Errorf("marshal config: %w", err)
	}

	// Prepend a warning so whoever opens the file knows it holds secrets.
	header := `# Scheduler Configuration for fetch_ml
# Generated by: scheduler -init
#
# ⚠️ SECURITY WARNING: This file contains authentication tokens.
# - Do NOT commit to git
# - Keep the file permissions secure (chmod 600)
# - Copy the appropriate token to each worker's config
#
`
	fullContent := header + string(data)

	// O_EXCL: refuse to clobber an existing config. os.WriteFile would
	// truncate it and keep its previous permission bits.
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600)
	if err != nil {
		if os.IsExist(err) {
			return fmt.Errorf("config file %q already exists; remove it first to regenerate", path)
		}
		return fmt.Errorf("create config file: %w", err)
	}
	if _, err := f.WriteString(fullContent); err != nil {
		// Best-effort cleanup so a truncated config is not left behind; the
		// write error is the one worth reporting.
		_ = f.Close()
		_ = os.Remove(path)
		return fmt.Errorf("write config file: %w", err)
	}
	if err := f.Close(); err != nil {
		_ = os.Remove(path)
		return fmt.Errorf("write config file: %w", err)
	}

	// Print tokens to stdout for easy distribution.
	fmt.Print("\n=== Generated Worker Tokens ===\n")
	fmt.Print("Copy these to your worker configs:\n\n")
	for _, t := range tokens {
		fmt.Printf("Worker: %s\nToken: %s\n\n", t.ID, t.Token)
	}
	return nil
}