package scheduler import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/base64" "encoding/pem" "fmt" "math/big" "net" "net/http" "os" "path/filepath" "strings" "time" "github.com/gorilla/websocket" ) // GenerateSelfSignedCert creates a self-signed TLS certificate for the scheduler func GenerateSelfSignedCert(certFile, keyFile string) error { if err := os.MkdirAll(filepath.Dir(certFile), 0750); err != nil { return fmt.Errorf("create cert directory: %w", err) } priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return fmt.Errorf("generate key: %w", err) } template := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{ Organization: []string{"fetch_ml"}, CommonName: "fetch_ml_scheduler", }, NotBefore: time.Now(), NotAfter: time.Now().Add(365 * 24 * time.Hour), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, } // Add IP SANs for local development template.IPAddresses = []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback} certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { return fmt.Errorf("create certificate: %w", err) } // #nosec G304 -- certFile is internally controlled TLS cert path certOut, err := os.Create(certFile) if err != nil { return fmt.Errorf("create cert file: %w", err) } defer certOut.Close() if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { return fmt.Errorf("encode cert: %w", err) } // #nosec G304 -- keyFile is internally controlled TLS key path keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return fmt.Errorf("create key file: %w", err) } defer keyOut.Close() keyDER, err := x509.MarshalECPrivateKey(priv) if err != nil { return fmt.Errorf("marshal private key: %w", err) } if err := pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); err != nil { return fmt.Errorf("encode key: %w", err) } return nil } // DialWSS connects to the scheduler via WSS with cert pinning func DialWSS(addr, certFile, token string) (*websocket.Conn, error) { // #nosec G304 -- certFile is internally controlled TLS cert path certPEM, err := os.ReadFile(certFile) if err != nil { return nil, fmt.Errorf("read cert file: %w", err) } pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(certPEM) { return nil, fmt.Errorf("parse cert file") } dialer := websocket.Dialer{ TLSClientConfig: &tls.Config{ RootCAs: pool, MinVersion: tls.VersionTLS12, }, HandshakeTimeout: 10 * time.Second, } header := http.Header{} if token != "" { header.Set("Authorization", "Bearer "+token) } url := "wss://" + addr + "/ws/worker" conn, _, err := dialer.Dial(url, header) if err != nil { return nil, fmt.Errorf("dial scheduler: %w", err) } return conn, nil } // LocalModeDial connects without TLS for single-node mode (loopback only) func LocalModeDial(port int, token string) (*websocket.Conn, error) { dialer := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } header := http.Header{} if token != "" { header.Set("Authorization", "Bearer "+token) } url := fmt.Sprintf("ws://127.0.0.1:%d/ws/worker", port) conn, _, err := dialer.Dial(url, header) if err != nil { return nil, fmt.Errorf("dial local scheduler: %w", err) } return conn, nil } // TokenValidator validates worker authentication tokens type TokenValidator struct { tokens map[string]string // token -> workerID } func NewTokenValidator(tokens map[string]string) *TokenValidator { return &TokenValidator{tokens: tokens} } func (tv *TokenValidator) Validate(token string) (workerID string, ok bool) { workerID, ok = tv.tokens[token] return } // ExtractBearerToken extracts the token from an Authorization header func ExtractBearerToken(header string) string { return strings.TrimPrefix(header, "Bearer ") } // GenerateWorkerToken creates a cryptographically secure random token for a worker func GenerateWorkerToken() string { b := make([]byte, 32) rand.Read(b) // Use URL-safe base64 encoding for compact, URL-friendly tokens return "wkr_" + base64.URLEncoding.EncodeToString(b) }