// Package main implements the ML task worker package main import ( "flag" "fmt" "log" "os" "os/signal" "strings" "syscall" "github.com/invopop/yaml" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/worker" ) const ( defaultConfigPath = "config-local.yaml" ) func resolveWorkerConfigPath(flags *auth.Flags) string { if flags != nil { p := strings.TrimSpace(flags.ConfigFile) if p != "" { return p } } if _, err := os.Stat("/app/configs/worker.yaml"); err == nil { return "/app/configs/worker.yaml" } return defaultConfigPath } func main() { var ( configPath string initConfig bool mode string schedulerAddr string token string ) flag.StringVar(&configPath, "config", "worker.yaml", "Path to worker config file") flag.BoolVar(&initConfig, "init", false, "Initialize a new worker config file") flag.StringVar(&mode, "mode", "distributed", "Worker mode: standalone or distributed") flag.StringVar(&schedulerAddr, "scheduler", "", "Scheduler address (for distributed mode)") flag.StringVar(&token, "token", "", "Worker token (copy from scheduler -init output)") flag.Parse() // Handle init mode if initConfig { if err := generateWorkerConfig(configPath, mode, schedulerAddr, token); err != nil { fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err) os.Exit(1) } fmt.Printf("Config generated: %s\n", configPath) fmt.Println("\nNext steps:") if mode == "distributed" { fmt.Println("1. Copy the token from your scheduler's -init output") fmt.Println("2. Edit the config to set scheduler.address and scheduler.token") fmt.Println("3. Copy the scheduler's TLS cert to the worker") } os.Exit(0) } // Normal worker startup... // Parse authentication flags authFlags := auth.ParseAuthFlags() if err := auth.ValidateFlags(authFlags); err != nil { log.Fatalf("Authentication flag error: %v", err) } // Get API key from various sources apiKey := auth.GetAPIKeyFromSources(authFlags) // Load configuration resolvedConfig, err := config.ResolveConfigPath(resolveWorkerConfigPath(authFlags)) if err != nil { log.Fatalf("%v", err) } cfg, err := worker.LoadConfig(resolvedConfig) if err != nil { log.Fatalf("Failed to load config: %v", err) } // Validate authentication configuration if err := cfg.Auth.ValidateAuthConfig(); err != nil { log.Fatalf("Invalid authentication configuration: %v", err) } // Validate configuration if err := cfg.Validate(); err != nil { log.Fatalf("Invalid configuration: %v", err) } // Test authentication if enabled if cfg.Auth.Enabled && apiKey != "" { user, err := cfg.Auth.ValidateAPIKey(apiKey) if err != nil { log.Fatalf("Authentication failed: %v", err) } log.Printf("Worker authenticated as user: %s (admin: %v)", user.Name, user.Admin) } else if cfg.Auth.Enabled { log.Fatal("Authentication required but no API key provided") } wrk, err := worker.NewWorker(cfg, apiKey) if err != nil { log.Fatalf("Failed to create worker: %v", err) } sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) go wrk.Start() sig := <-sigChan log.Printf("Received signal: %v", sig) // Use graceful shutdown if err := wrk.Shutdown(); err != nil { log.Printf("Graceful shutdown error: %v", err) wrk.Stop() // Fallback to force stop } else { log.Println("Worker shut down gracefully") } } // generateWorkerConfig creates a new worker config file func generateWorkerConfig(path, mode, schedulerAddr, token string) error { cfg := map[string]any{ "node": map[string]any{ "role": "worker", "id": "", }, "worker": map[string]any{ "mode": mode, "max_workers": 3, }, } if mode == "distributed" { cfg["scheduler"] = map[string]any{ "address": schedulerAddr, "cert": "/etc/fetch_ml/scheduler.crt", "token": token, } } else { cfg["queue"] = map[string]any{ "backend": "redis", "redis_addr": "localhost:6379", "redis_password": "", "redis_db": 0, } } cfg["slots"] = map[string]any{ "service_slots": 1, "ports": map[string]any{ "service_range_start": 8000, "service_range_end": 8099, }, } cfg["gpu"] = map[string]any{ "vendor": "auto", } cfg["prewarm"] = map[string]any{ "enabled": true, } cfg["log"] = map[string]any{ "level": "info", "format": "json", } data, err := yaml.Marshal(cfg) if err != nil { return fmt.Errorf("marshal config: %w", err) } // Add header comment header := fmt.Sprintf(`# Worker Configuration for fetch_ml # Generated by: worker -init # Mode: %s #`, mode) if mode == "distributed" && token == "" { header += ` # ⚠️ SECURITY WARNING: You must add the scheduler token to this config. # Copy the token from the scheduler's -init output and paste it below. # scheduler: # token: "wkr_xxx..." #` } fullContent := header + "\n\n" + string(data) if err := os.WriteFile(path, []byte(fullContent), 0600); err != nil { return fmt.Errorf("write config file: %w", err) } return nil }