Add new scheduler component for distributed ML workload orchestration: - Hub-based coordination for multi-worker clusters - Pacing controller for rate limiting job submissions - Priority queue with preemption support - Port allocator for dynamic service discovery - Protocol handlers for worker-scheduler communication - Service manager with OS-specific implementations - Connection management and state persistence - Template system for service deployment Includes comprehensive test suite: - Unit tests for all core components - Integration tests for distributed scenarios - Benchmark tests for performance validation - Mock fixtures for isolated testing Refs: scheduler-architecture.md
274 lines
8.1 KiB
Go
274 lines
8.1 KiB
Go
package main
|
|
|
|
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/jfraeys/fetch_ml/internal/audit"
	"github.com/jfraeys/fetch_ml/internal/scheduler"
	"gopkg.in/yaml.v3"
)
|
|
|
|
// Config represents the scheduler configuration
|
|
type Config struct {
|
|
Scheduler SchedulerConfig `yaml:"scheduler"`
|
|
}
|
|
|
|
type SchedulerConfig struct {
|
|
BindAddr string `yaml:"bind_addr"`
|
|
CertFile string `yaml:"cert_file"`
|
|
KeyFile string `yaml:"key_file"`
|
|
AutoGenerateCerts bool `yaml:"auto_generate_certs"`
|
|
StateDir string `yaml:"state_dir"`
|
|
DefaultBatchSlots int `yaml:"default_batch_slots"`
|
|
DefaultServiceSlots int `yaml:"default_service_slots"`
|
|
StarvationThresholdMins float64 `yaml:"starvation_threshold_mins"`
|
|
PriorityAgingRate float64 `yaml:"priority_aging_rate"`
|
|
GangAllocTimeoutSecs int `yaml:"gang_alloc_timeout_secs"`
|
|
AcceptanceTimeoutSecs int `yaml:"acceptance_timeout_secs"`
|
|
MetricsAddr string `yaml:"metrics_addr"`
|
|
WorkerTokens []WorkerToken `yaml:"worker_tokens"`
|
|
}
|
|
|
|
type WorkerToken struct {
|
|
ID string `yaml:"id"`
|
|
Token string `yaml:"token"`
|
|
}
|
|
|
|
func main() {
|
|
var (
|
|
configPath string
|
|
generateToken bool
|
|
initConfig bool
|
|
numTokens int
|
|
)
|
|
flag.StringVar(&configPath, "config", "scheduler.yaml", "Path to scheduler config file")
|
|
flag.BoolVar(&generateToken, "generate-token", false, "Generate a new worker token and exit")
|
|
flag.BoolVar(&initConfig, "init", false, "Initialize a new config file with generated tokens")
|
|
flag.IntVar(&numTokens, "tokens", 3, "Number of tokens to generate (used with -init)")
|
|
flag.Parse()
|
|
|
|
// Handle token generation mode
|
|
if generateToken {
|
|
token := scheduler.GenerateWorkerToken()
|
|
fmt.Println(token)
|
|
os.Exit(0)
|
|
}
|
|
|
|
// Handle init mode
|
|
if initConfig {
|
|
if err := generateConfig(configPath, numTokens); err != nil {
|
|
fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
fmt.Printf("Config generated: %s\n", configPath)
|
|
fmt.Printf("\nGenerated %d worker tokens. Copy the appropriate token to each worker's config.\n", numTokens)
|
|
os.Exit(0)
|
|
}
|
|
|
|
// Load config
|
|
cfg, err := loadConfig(configPath)
|
|
if err != nil {
|
|
slog.Error("failed to load config", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Setup logging
|
|
handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})
|
|
logger := slog.New(handler)
|
|
slog.SetDefault(logger)
|
|
|
|
// Create token map
|
|
tokenMap := make(map[string]string)
|
|
for _, wt := range cfg.Scheduler.WorkerTokens {
|
|
tokenMap[wt.Token] = wt.ID
|
|
}
|
|
|
|
// Auto-generate certs if needed
|
|
if cfg.Scheduler.AutoGenerateCerts && cfg.Scheduler.CertFile != "" {
|
|
if _, err := os.Stat(cfg.Scheduler.CertFile); os.IsNotExist(err) {
|
|
keyFile := cfg.Scheduler.KeyFile
|
|
if keyFile == "" {
|
|
keyFile = cfg.Scheduler.CertFile + ".key"
|
|
}
|
|
logger.Info("generating self-signed certificate", "cert", cfg.Scheduler.CertFile)
|
|
if err := scheduler.GenerateSelfSignedCert(cfg.Scheduler.CertFile, keyFile); err != nil {
|
|
logger.Error("failed to generate certificate", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create hub config
|
|
hubCfg := scheduler.HubConfig{
|
|
BindAddr: cfg.Scheduler.BindAddr,
|
|
CertFile: cfg.Scheduler.CertFile,
|
|
KeyFile: cfg.Scheduler.KeyFile,
|
|
AutoGenerateCerts: cfg.Scheduler.AutoGenerateCerts,
|
|
StateDir: cfg.Scheduler.StateDir,
|
|
DefaultBatchSlots: cfg.Scheduler.DefaultBatchSlots,
|
|
DefaultServiceSlots: cfg.Scheduler.DefaultServiceSlots,
|
|
StarvationThresholdMins: cfg.Scheduler.StarvationThresholdMins,
|
|
PriorityAgingRate: cfg.Scheduler.PriorityAgingRate,
|
|
GangAllocTimeoutSecs: cfg.Scheduler.GangAllocTimeoutSecs,
|
|
AcceptanceTimeoutSecs: cfg.Scheduler.AcceptanceTimeoutSecs,
|
|
WorkerTokens: tokenMap,
|
|
}
|
|
|
|
// Create auditor (optional)
|
|
var auditor *audit.Logger
|
|
|
|
// Create hub
|
|
hub, err := scheduler.NewHub(hubCfg, auditor)
|
|
if err != nil {
|
|
logger.Error("failed to create scheduler hub", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Start hub
|
|
if err := hub.Start(); err != nil {
|
|
logger.Error("failed to start scheduler hub", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
logger.Info("scheduler hub started", "bind_addr", cfg.Scheduler.BindAddr)
|
|
|
|
// Setup HTTP handlers
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/ws/worker", hub.HandleConnection)
|
|
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"status":"ok"}`))
|
|
})
|
|
mux.HandleFunc("/metrics", hub.ServeMetrics)
|
|
|
|
// Setup graceful shutdown
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
// Start server
|
|
go func() {
|
|
if cfg.Scheduler.CertFile != "" {
|
|
logger.Info("starting HTTPS server", "addr", cfg.Scheduler.BindAddr)
|
|
if err := http.ListenAndServeTLS(cfg.Scheduler.BindAddr, cfg.Scheduler.CertFile, cfg.Scheduler.KeyFile, mux); err != nil {
|
|
logger.Error("server error", "error", err)
|
|
}
|
|
} else {
|
|
logger.Info("starting HTTP server", "addr", cfg.Scheduler.BindAddr)
|
|
if err := http.ListenAndServe(cfg.Scheduler.BindAddr, mux); err != nil {
|
|
logger.Error("server error", "error", err)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Wait for shutdown signal
|
|
<-sigChan
|
|
logger.Info("shutting down scheduler...")
|
|
hub.Stop()
|
|
logger.Info("scheduler stopped")
|
|
}
|
|
|
|
func loadConfig(path string) (*Config, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read config file: %w", err)
|
|
}
|
|
|
|
var cfg Config
|
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
|
return nil, fmt.Errorf("parse config: %w", err)
|
|
}
|
|
|
|
// Set defaults
|
|
if cfg.Scheduler.BindAddr == "" {
|
|
cfg.Scheduler.BindAddr = "0.0.0.0:7777"
|
|
}
|
|
if cfg.Scheduler.StateDir == "" {
|
|
cfg.Scheduler.StateDir = "/var/lib/fetch_ml"
|
|
}
|
|
if cfg.Scheduler.DefaultBatchSlots == 0 {
|
|
cfg.Scheduler.DefaultBatchSlots = 3
|
|
}
|
|
if cfg.Scheduler.DefaultServiceSlots == 0 {
|
|
cfg.Scheduler.DefaultServiceSlots = 1
|
|
}
|
|
if cfg.Scheduler.StarvationThresholdMins == 0 {
|
|
cfg.Scheduler.StarvationThresholdMins = 5
|
|
}
|
|
if cfg.Scheduler.PriorityAgingRate == 0 {
|
|
cfg.Scheduler.PriorityAgingRate = 0.1
|
|
}
|
|
if cfg.Scheduler.GangAllocTimeoutSecs == 0 {
|
|
cfg.Scheduler.GangAllocTimeoutSecs = 60
|
|
}
|
|
if cfg.Scheduler.AcceptanceTimeoutSecs == 0 {
|
|
cfg.Scheduler.AcceptanceTimeoutSecs = 30
|
|
}
|
|
|
|
return &cfg, nil
|
|
}
|
|
|
|
// generateConfig creates a new scheduler config file with generated tokens
|
|
func generateConfig(path string, numTokens int) error {
|
|
// Generate tokens
|
|
var tokens []WorkerToken
|
|
for i := 1; i <= numTokens; i++ {
|
|
tokens = append(tokens, WorkerToken{
|
|
ID: fmt.Sprintf("worker-%02d", i),
|
|
Token: scheduler.GenerateWorkerToken(),
|
|
})
|
|
}
|
|
|
|
cfg := Config{
|
|
Scheduler: SchedulerConfig{
|
|
BindAddr: "0.0.0.0:7777",
|
|
AutoGenerateCerts: true,
|
|
CertFile: "/etc/fetch_ml/scheduler.crt",
|
|
KeyFile: "/etc/fetch_ml/scheduler.key",
|
|
StateDir: "/var/lib/fetch_ml",
|
|
DefaultBatchSlots: 3,
|
|
DefaultServiceSlots: 1,
|
|
StarvationThresholdMins: 5,
|
|
PriorityAgingRate: 0.1,
|
|
GangAllocTimeoutSecs: 60,
|
|
AcceptanceTimeoutSecs: 30,
|
|
MetricsAddr: "0.0.0.0:9090",
|
|
WorkerTokens: tokens,
|
|
},
|
|
}
|
|
|
|
data, err := yaml.Marshal(cfg)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal config: %w", err)
|
|
}
|
|
|
|
// Add header comment
|
|
header := `# Scheduler Configuration for fetch_ml
|
|
# Generated by: scheduler -init
|
|
#
|
|
# ⚠️ SECURITY WARNING: This file contains authentication tokens.
|
|
# - Do NOT commit to git
|
|
# - Keep the file permissions secure (chmod 600)
|
|
# - Copy the appropriate token to each worker's config
|
|
#
|
|
`
|
|
fullContent := header + string(data)
|
|
|
|
if err := os.WriteFile(path, []byte(fullContent), 0600); err != nil {
|
|
return fmt.Errorf("write config file: %w", err)
|
|
}
|
|
|
|
// Print tokens to stdout for easy distribution
|
|
fmt.Print("\n=== Generated Worker Tokens ===\n")
|
|
fmt.Print("Copy these to your worker configs:\n\n")
|
|
for _, t := range tokens {
|
|
fmt.Printf("Worker: %s\nToken: %s\n\n", t.ID, t.Token)
|
|
}
|
|
|
|
return nil
|
|
}
|