package main import ( "flag" "fmt" "log/slog" "net/http" "os" "os/signal" "syscall" "time" "github.com/jfraeys/fetch_ml/internal/audit" "github.com/jfraeys/fetch_ml/internal/scheduler" "gopkg.in/yaml.v3" ) // Config represents the scheduler configuration type Config struct { Scheduler SchedulerConfig `yaml:"scheduler"` } type SchedulerConfig struct { BindAddr string `yaml:"bind_addr"` CertFile string `yaml:"cert_file"` KeyFile string `yaml:"key_file"` AutoGenerateCerts bool `yaml:"auto_generate_certs"` StateDir string `yaml:"state_dir"` DefaultBatchSlots int `yaml:"default_batch_slots"` DefaultServiceSlots int `yaml:"default_service_slots"` StarvationThresholdMins float64 `yaml:"starvation_threshold_mins"` PriorityAgingRate float64 `yaml:"priority_aging_rate"` GangAllocTimeoutSecs int `yaml:"gang_alloc_timeout_secs"` AcceptanceTimeoutSecs int `yaml:"acceptance_timeout_secs"` MetricsAddr string `yaml:"metrics_addr"` WorkerTokens []WorkerToken `yaml:"worker_tokens"` } type WorkerToken struct { ID string `yaml:"id"` Token string `yaml:"token"` } func main() { var ( configPath string generateToken bool initConfig bool numTokens int ) flag.StringVar(&configPath, "config", "scheduler.yaml", "Path to scheduler config file") flag.BoolVar(&generateToken, "generate-token", false, "Generate a new worker token and exit") flag.BoolVar(&initConfig, "init", false, "Initialize a new config file with generated tokens") flag.IntVar(&numTokens, "tokens", 3, "Number of tokens to generate (used with -init)") flag.Parse() // Handle token generation mode if generateToken { token := scheduler.GenerateWorkerToken() fmt.Println(token) os.Exit(0) } // Handle init mode if initConfig { if err := generateConfig(configPath, numTokens); err != nil { fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err) os.Exit(1) } fmt.Printf("Config generated: %s\n", configPath) fmt.Printf("\nGenerated %d worker tokens. Copy the appropriate token to each worker's config.\n", numTokens) os.Exit(0) } // Load config cfg, err := loadConfig(configPath) if err != nil { slog.Error("failed to load config", "error", err) os.Exit(1) } // Setup logging handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}) logger := slog.New(handler) slog.SetDefault(logger) // Create token map tokenMap := make(map[string]string) for _, wt := range cfg.Scheduler.WorkerTokens { tokenMap[wt.Token] = wt.ID } // Auto-generate certs if needed if cfg.Scheduler.AutoGenerateCerts && cfg.Scheduler.CertFile != "" { if _, err := os.Stat(cfg.Scheduler.CertFile); os.IsNotExist(err) { keyFile := cfg.Scheduler.KeyFile if keyFile == "" { keyFile = cfg.Scheduler.CertFile + ".key" } logger.Info("generating self-signed certificate", "cert", cfg.Scheduler.CertFile) if err := scheduler.GenerateSelfSignedCert(cfg.Scheduler.CertFile, keyFile); err != nil { logger.Error("failed to generate certificate", "error", err) os.Exit(1) } } } // Create hub config hubCfg := scheduler.HubConfig{ BindAddr: cfg.Scheduler.BindAddr, CertFile: cfg.Scheduler.CertFile, KeyFile: cfg.Scheduler.KeyFile, AutoGenerateCerts: cfg.Scheduler.AutoGenerateCerts, StateDir: cfg.Scheduler.StateDir, DefaultBatchSlots: cfg.Scheduler.DefaultBatchSlots, DefaultServiceSlots: cfg.Scheduler.DefaultServiceSlots, StarvationThresholdMins: cfg.Scheduler.StarvationThresholdMins, PriorityAgingRate: cfg.Scheduler.PriorityAgingRate, GangAllocTimeoutSecs: cfg.Scheduler.GangAllocTimeoutSecs, AcceptanceTimeoutSecs: cfg.Scheduler.AcceptanceTimeoutSecs, WorkerTokens: tokenMap, } // Create auditor (optional) var auditor *audit.Logger // Create hub hub, err := scheduler.NewHub(hubCfg, auditor) if err != nil { logger.Error("failed to create scheduler hub", "error", err) os.Exit(1) } // Start hub if err := hub.Start(); err != nil { logger.Error("failed to start scheduler hub", "error", err) os.Exit(1) } logger.Info("scheduler hub started", "bind_addr", cfg.Scheduler.BindAddr) // Setup HTTP handlers mux := http.NewServeMux() mux.HandleFunc("/ws/worker", hub.HandleConnection) mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) if _, err := w.Write([]byte(`{"status":"ok"}`)); err != nil { logger.Warn("health endpoint write failed", "error", err) } }) mux.HandleFunc("/metrics", hub.ServeMetrics) // Setup graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) // Start server with proper timeouts server := &http.Server{ Addr: cfg.Scheduler.BindAddr, Handler: mux, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, IdleTimeout: 120 * time.Second, } go func() { if cfg.Scheduler.CertFile != "" { logger.Info("starting HTTPS server", "addr", cfg.Scheduler.BindAddr) if err := server.ListenAndServeTLS(cfg.Scheduler.CertFile, cfg.Scheduler.KeyFile); err != nil && err != http.ErrServerClosed { logger.Error("server error", "error", err) } } else { logger.Info("starting HTTP server", "addr", cfg.Scheduler.BindAddr) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Error("server error", "error", err) } } }() // Wait for shutdown signal <-sigChan logger.Info("shutting down scheduler...") hub.Stop() logger.Info("scheduler stopped") } func loadConfig(path string) (*Config, error) { // #nosec G304 -- Config path is provided by admin, not arbitrary user input data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read config file: %w", err) } var cfg Config if err := yaml.Unmarshal(data, &cfg); err != nil { return nil, fmt.Errorf("parse config: %w", err) } // Set defaults if cfg.Scheduler.BindAddr == "" { cfg.Scheduler.BindAddr = "0.0.0.0:7777" } if cfg.Scheduler.StateDir == "" { cfg.Scheduler.StateDir = "/var/lib/fetch_ml" } if cfg.Scheduler.DefaultBatchSlots == 0 { cfg.Scheduler.DefaultBatchSlots = 3 } if cfg.Scheduler.DefaultServiceSlots == 0 { cfg.Scheduler.DefaultServiceSlots = 1 } if cfg.Scheduler.StarvationThresholdMins == 0 { cfg.Scheduler.StarvationThresholdMins = 5 } if cfg.Scheduler.PriorityAgingRate == 0 { cfg.Scheduler.PriorityAgingRate = 0.1 } if cfg.Scheduler.GangAllocTimeoutSecs == 0 { cfg.Scheduler.GangAllocTimeoutSecs = 60 } if cfg.Scheduler.AcceptanceTimeoutSecs == 0 { cfg.Scheduler.AcceptanceTimeoutSecs = 30 } return &cfg, nil } // generateConfig creates a new scheduler config file with generated tokens func generateConfig(path string, numTokens int) error { // Generate tokens var tokens []WorkerToken for i := 1; i <= numTokens; i++ { tokens = append(tokens, WorkerToken{ ID: fmt.Sprintf("worker-%02d", i), Token: scheduler.GenerateWorkerToken(), }) } cfg := Config{ Scheduler: SchedulerConfig{ BindAddr: "0.0.0.0:7777", AutoGenerateCerts: true, CertFile: "/etc/fetch_ml/scheduler.crt", KeyFile: "/etc/fetch_ml/scheduler.key", StateDir: "/var/lib/fetch_ml", DefaultBatchSlots: 3, DefaultServiceSlots: 1, StarvationThresholdMins: 5, PriorityAgingRate: 0.1, GangAllocTimeoutSecs: 60, AcceptanceTimeoutSecs: 30, MetricsAddr: "0.0.0.0:9090", WorkerTokens: tokens, }, } data, err := yaml.Marshal(cfg) if err != nil { return fmt.Errorf("marshal config: %w", err) } // Add header comment header := `# Scheduler Configuration for fetch_ml # Generated by: scheduler -init # # ⚠️ SECURITY WARNING: This file contains authentication tokens. # - Do NOT commit to git # - Keep the file permissions secure (chmod 600) # - Copy the appropriate token to each worker's config # ` fullContent := header + string(data) if err := os.WriteFile(path, []byte(fullContent), 0600); err != nil { return fmt.Errorf("write config file: %w", err) } // Print tokens to stdout for easy distribution fmt.Print("\n=== Generated Worker Tokens ===\n") fmt.Print("Copy these to your worker configs:\n\n") for _, t := range tokens { fmt.Printf("Worker: %s\nToken: %s\n\n", t.ID, t.Token) } return nil }