Update worker system for scheduler integration: - Worker server with scheduler registration - Configuration with scheduler endpoint support - Artifact handling with integrity verification - Container executor with supply chain validation - Local executor enhancements - GPU detection improvements (cross-platform) - Error handling with execution context - Factory pattern for executor instantiation - Hash integrity with native library support
208 lines
5 KiB
Go
208 lines
5 KiB
Go
// Package main implements the ML task worker
|
|
package main
|
|
|
|
import (
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"syscall"
|
|
|
|
"github.com/invopop/yaml"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/config"
|
|
"github.com/jfraeys/fetch_ml/internal/worker"
|
|
)
|
|
|
|
// defaultConfigPath is the last-resort worker configuration path returned by
// resolveWorkerConfigPath when the auth flags name no config file and
// /app/configs/worker.yaml cannot be stat'ed.
const defaultConfigPath = "config-local.yaml"
|
|
|
|
func resolveWorkerConfigPath(flags *auth.Flags) string {
|
|
if flags != nil {
|
|
p := strings.TrimSpace(flags.ConfigFile)
|
|
if p != "" {
|
|
return p
|
|
}
|
|
}
|
|
if _, err := os.Stat("/app/configs/worker.yaml"); err == nil {
|
|
return "/app/configs/worker.yaml"
|
|
}
|
|
return defaultConfigPath
|
|
}
|
|
|
|
func main() {
|
|
var (
|
|
configPath string
|
|
initConfig bool
|
|
mode string
|
|
schedulerAddr string
|
|
token string
|
|
)
|
|
flag.StringVar(&configPath, "config", "worker.yaml", "Path to worker config file")
|
|
flag.BoolVar(&initConfig, "init", false, "Initialize a new worker config file")
|
|
flag.StringVar(&mode, "mode", "distributed", "Worker mode: standalone or distributed")
|
|
flag.StringVar(&schedulerAddr, "scheduler", "", "Scheduler address (for distributed mode)")
|
|
flag.StringVar(&token, "token", "", "Worker token (copy from scheduler -init output)")
|
|
flag.Parse()
|
|
|
|
// Handle init mode
|
|
if initConfig {
|
|
if err := generateWorkerConfig(configPath, mode, schedulerAddr, token); err != nil {
|
|
fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
fmt.Printf("Config generated: %s\n", configPath)
|
|
fmt.Println("\nNext steps:")
|
|
if mode == "distributed" {
|
|
fmt.Println("1. Copy the token from your scheduler's -init output")
|
|
fmt.Println("2. Edit the config to set scheduler.address and scheduler.token")
|
|
fmt.Println("3. Copy the scheduler's TLS cert to the worker")
|
|
}
|
|
os.Exit(0)
|
|
}
|
|
|
|
// Normal worker startup...
|
|
|
|
// Parse authentication flags
|
|
authFlags := auth.ParseAuthFlags()
|
|
if err := auth.ValidateFlags(authFlags); err != nil {
|
|
log.Fatalf("Authentication flag error: %v", err)
|
|
}
|
|
|
|
// Get API key from various sources
|
|
apiKey := auth.GetAPIKeyFromSources(authFlags)
|
|
|
|
// Load configuration
|
|
resolvedConfig, err := config.ResolveConfigPath(resolveWorkerConfigPath(authFlags))
|
|
if err != nil {
|
|
log.Fatalf("%v", err)
|
|
}
|
|
|
|
cfg, err := worker.LoadConfig(resolvedConfig)
|
|
if err != nil {
|
|
log.Fatalf("Failed to load config: %v", err)
|
|
}
|
|
|
|
// Validate authentication configuration
|
|
if err := cfg.Auth.ValidateAuthConfig(); err != nil {
|
|
log.Fatalf("Invalid authentication configuration: %v", err)
|
|
}
|
|
|
|
// Validate configuration
|
|
if err := cfg.Validate(); err != nil {
|
|
log.Fatalf("Invalid configuration: %v", err)
|
|
}
|
|
|
|
// Test authentication if enabled
|
|
if cfg.Auth.Enabled && apiKey != "" {
|
|
user, err := cfg.Auth.ValidateAPIKey(apiKey)
|
|
if err != nil {
|
|
log.Fatalf("Authentication failed: %v", err)
|
|
}
|
|
log.Printf("Worker authenticated as user: %s (admin: %v)", user.Name, user.Admin)
|
|
} else if cfg.Auth.Enabled {
|
|
log.Fatal("Authentication required but no API key provided")
|
|
}
|
|
|
|
wrk, err := worker.NewWorker(cfg, apiKey)
|
|
if err != nil {
|
|
log.Fatalf("Failed to create worker: %v", err)
|
|
}
|
|
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
go wrk.Start()
|
|
|
|
sig := <-sigChan
|
|
log.Printf("Received signal: %v", sig)
|
|
|
|
// Use graceful shutdown
|
|
if err := wrk.Shutdown(); err != nil {
|
|
log.Printf("Graceful shutdown error: %v", err)
|
|
wrk.Stop() // Fallback to force stop
|
|
} else {
|
|
log.Println("Worker shut down gracefully")
|
|
}
|
|
}
|
|
|
|
// generateWorkerConfig creates a new worker config file
|
|
func generateWorkerConfig(path, mode, schedulerAddr, token string) error {
|
|
cfg := map[string]any{
|
|
"node": map[string]any{
|
|
"role": "worker",
|
|
"id": "",
|
|
},
|
|
"worker": map[string]any{
|
|
"mode": mode,
|
|
"max_workers": 3,
|
|
},
|
|
}
|
|
|
|
if mode == "distributed" {
|
|
cfg["scheduler"] = map[string]any{
|
|
"address": schedulerAddr,
|
|
"cert": "/etc/fetch_ml/scheduler.crt",
|
|
"token": token,
|
|
}
|
|
} else {
|
|
cfg["queue"] = map[string]any{
|
|
"backend": "redis",
|
|
"redis_addr": "localhost:6379",
|
|
"redis_password": "",
|
|
"redis_db": 0,
|
|
}
|
|
}
|
|
|
|
cfg["slots"] = map[string]any{
|
|
"service_slots": 1,
|
|
"ports": map[string]any{
|
|
"service_range_start": 8000,
|
|
"service_range_end": 8099,
|
|
},
|
|
}
|
|
|
|
cfg["gpu"] = map[string]any{
|
|
"vendor": "auto",
|
|
}
|
|
|
|
cfg["prewarm"] = map[string]any{
|
|
"enabled": true,
|
|
}
|
|
|
|
cfg["log"] = map[string]any{
|
|
"level": "info",
|
|
"format": "json",
|
|
}
|
|
|
|
data, err := yaml.Marshal(cfg)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal config: %w", err)
|
|
}
|
|
|
|
// Add header comment
|
|
header := fmt.Sprintf(`# Worker Configuration for fetch_ml
|
|
# Generated by: worker -init
|
|
# Mode: %s
|
|
#`, mode)
|
|
|
|
if mode == "distributed" && token == "" {
|
|
header += `
|
|
# ⚠️ SECURITY WARNING: You must add the scheduler token to this config.
|
|
# Copy the token from the scheduler's -init output and paste it below.
|
|
# scheduler:
|
|
# token: "wkr_xxx..."
|
|
#`
|
|
}
|
|
|
|
fullContent := header + "\n\n" + string(data)
|
|
|
|
if err := os.WriteFile(path, []byte(fullContent), 0600); err != nil {
|
|
return fmt.Errorf("write config file: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|