fetch_ml/internal/scheduler/auth.go
Jeremie Fraeys 43e6446587
feat(scheduler): implement multi-tenant job scheduler with gang scheduling
Add new scheduler component for distributed ML workload orchestration:
- Hub-based coordination for multi-worker clusters
- Pacing controller for rate limiting job submissions
- Priority queue with preemption support
- Port allocator for dynamic service discovery
- Protocol handlers for worker-scheduler communication
- Service manager with OS-specific implementations
- Connection management and state persistence
- Template system for service deployment

Includes comprehensive test suite:
- Unit tests for all core components
- Integration tests for distributed scenarios
- Benchmark tests for performance validation
- Mock fixtures for isolated testing

Refs: scheduler-architecture.md
2026-02-26 12:03:23 -05:00

157 lines
4 KiB
Go

package scheduler
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/gorilla/websocket"
)
// GenerateSelfSignedCert creates a self-signed TLS server certificate for the
// scheduler and writes it, together with its ECDSA P-256 private key, as PEM
// files.
//
// certFile receives the certificate and keyFile receives the private key; the
// key file is opened with mode 0600. Missing parent directories of both paths
// are created. The certificate is valid for one year from the time of the
// call, has a random 128-bit serial number, and carries IP SANs for the IPv4
// and IPv6 loopback addresses only.
//
// It returns a wrapped error if key generation, certificate creation, or any
// file operation (including a failed write or close) fails.
func GenerateSelfSignedCert(certFile, keyFile string) error {
	if err := os.MkdirAll(filepath.Dir(certFile), 0755); err != nil {
		return fmt.Errorf("create cert directory: %w", err)
	}
	// The key may live in a different directory than the certificate.
	if err := os.MkdirAll(filepath.Dir(keyFile), 0700); err != nil {
		return fmt.Errorf("create key directory: %w", err)
	}

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return fmt.Errorf("generate key: %w", err)
	}

	// Marshal the key before touching the filesystem so a failure here does
	// not leave a truncated key file behind.
	keyDER, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return fmt.Errorf("marshal private key: %w", err)
	}

	// A fixed serial number makes every regenerated certificate share the same
	// issuer/serial pair, which clients may cache; use a random one instead.
	serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, serialLimit)
	if err != nil {
		return fmt.Errorf("generate serial number: %w", err)
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Organization: []string{"fetch_ml"},
			CommonName:   "fetch_ml_scheduler",
		},
		NotBefore:             now,
		NotAfter:              now.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		// IP SANs for local development.
		IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
	}

	// Template doubles as parent: the certificate is signed by its own key.
	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return fmt.Errorf("create certificate: %w", err)
	}

	if err := writePEMFile(certFile, "cert", 0644, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
		return err
	}
	return writePEMFile(keyFile, "key", 0600, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
}

// writePEMFile truncates or creates path with the given permissions, writes
// block to it in PEM form, and reports any open, write, or close error. label
// ("cert" or "key") only names the file in error messages.
func writePEMFile(path, label string, perm os.FileMode, block *pem.Block) error {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return fmt.Errorf("create %s file: %w", label, err)
	}
	if err := pem.Encode(f, block); err != nil {
		f.Close() // the write error is the one worth reporting
		return fmt.Errorf("write %s file: %w", label, err)
	}
	// Close can surface deferred write errors, so it must be checked.
	if err := f.Close(); err != nil {
		return fmt.Errorf("close %s file: %w", label, err)
	}
	return nil
}
// DialWSS connects to the scheduler's worker endpoint over WSS with
// certificate pinning: only the certificates found in certFile are trusted as
// roots for the TLS handshake.
//
// addr is inserted verbatim between "wss://" and "/ws/worker" (typically
// host:port). certFile is a PEM file containing the scheduler certificate.
// token, when non-empty, is sent as an "Authorization: Bearer <token>" header.
//
// It returns the established connection, or a wrapped error if the
// certificate file cannot be read or parsed or the dial fails. When the server
// answered the handshake with an HTTP response, its status is included in the
// error so that auth rejections can be told apart from transport failures.
func DialWSS(addr, certFile, token string) (*websocket.Conn, error) {
	certPEM, err := os.ReadFile(certFile)
	if err != nil {
		return nil, fmt.Errorf("read cert file: %w", err)
	}

	pinned := x509.NewCertPool()
	if !pinned.AppendCertsFromPEM(certPEM) {
		return nil, fmt.Errorf("parse cert file %q: no valid PEM certificate found", certFile)
	}

	dialer := websocket.Dialer{
		TLSClientConfig: &tls.Config{
			RootCAs:    pinned,
			MinVersion: tls.VersionTLS12,
		},
		HandshakeTimeout: 10 * time.Second,
	}

	header := http.Header{}
	if token != "" {
		header.Set("Authorization", "Bearer "+token)
	}

	url := "wss://" + addr + "/ws/worker"
	conn, resp, err := dialer.Dial(url, header)
	if err != nil {
		// A non-nil response means the server answered but refused the
		// upgrade (e.g. 401); surface the status instead of dropping it.
		if resp != nil {
			return nil, fmt.Errorf("dial scheduler: handshake rejected with %s: %w", resp.Status, err)
		}
		return nil, fmt.Errorf("dial scheduler: %w", err)
	}
	return conn, nil
}
// LocalModeDial connects to the scheduler's worker endpoint without TLS for
// single-node mode. The target is always the IPv4 loopback address
// 127.0.0.1 on the given port.
//
// token, when non-empty, is sent as an "Authorization: Bearer <token>" header.
//
// It returns the established connection, or a wrapped error if the dial
// fails. When the server answered the handshake with an HTTP response, its
// status is included in the error so that auth rejections can be told apart
// from transport failures.
func LocalModeDial(port int, token string) (*websocket.Conn, error) {
	dialer := websocket.Dialer{
		HandshakeTimeout: 5 * time.Second,
	}

	header := http.Header{}
	if token != "" {
		header.Set("Authorization", "Bearer "+token)
	}

	url := fmt.Sprintf("ws://127.0.0.1:%d/ws/worker", port)
	conn, resp, err := dialer.Dial(url, header)
	if err != nil {
		// A non-nil response means the server answered but refused the
		// upgrade (e.g. 401); surface the status instead of dropping it.
		if resp != nil {
			return nil, fmt.Errorf("dial local scheduler: handshake rejected with %s: %w", resp.Status, err)
		}
		return nil, fmt.Errorf("dial local scheduler: %w", err)
	}
	return conn, nil
}
// TokenValidator validates worker authentication tokens against a
// token-to-worker mapping.
type TokenValidator struct {
	tokens map[string]string // token -> workerID
}

// NewTokenValidator returns a TokenValidator backed by tokens, which maps each
// accepted token to the ID of the worker it belongs to. The map is stored as
// given, not copied, so the validator shares it with the caller. A nil map
// yields a validator that rejects every token.
func NewTokenValidator(tokens map[string]string) *TokenValidator {
	return &TokenValidator{tokens: tokens}
}

// Validate looks up token and returns the worker ID it is registered for.
// ok is false, and workerID empty, when the token is unknown, when token is
// the empty string, or when the receiver is nil.
func (tv *TokenValidator) Validate(token string) (workerID string, ok bool) {
	// An empty token means "no credentials supplied"; it must never
	// authenticate, even if the map happens to contain an empty key.
	if tv == nil || token == "" {
		return "", false
	}
	workerID, ok = tv.tokens[token]
	return workerID, ok
}
// ExtractBearerToken extracts the token from an Authorization header value of
// the form "Bearer <token>". The scheme is matched case-insensitively and
// surrounding whitespace is trimmed from the token.
//
// It returns the empty string when the header is empty or does not use the
// Bearer scheme, so that other schemes (e.g. "Basic ...") or bare strings are
// never mistaken for a token.
func ExtractBearerToken(header string) string {
	const scheme = "Bearer "
	if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) {
		return ""
	}
	return strings.TrimSpace(header[len(scheme):])
}
// GenerateWorkerToken creates a cryptographically secure random token for a
// worker. The token is the prefix "wkr_" followed by 32 random bytes in
// padded URL-safe base64 (48 characters in total).
//
// It panics if the system's secure random source fails: the signature has no
// error return, and handing out a token built from an unfilled (all-zero)
// buffer would be predictable.
func GenerateWorkerToken() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic("scheduler: generate worker token: " + err.Error())
	}
	// URL-safe base64 keeps the token compact and safe to embed in URLs and
	// headers.
	return "wkr_" + base64.URLEncoding.EncodeToString(b)
}