Scheduler enhancements: - auth.go: Group membership validation in authentication - hub.go: Task distribution with group affinity - port_allocator.go: Dynamic port allocation with conflict resolution - scheduler_conn.go: Connection pooling and retry logic - service_manager.go: Lifecycle management for scheduler services - service_templates.go: Template-based service configuration - state.go: Persistent state management with recovery Worker improvements: - config.go: Extended configuration for task visibility rules - execution/setup.go: Sandboxed execution environment setup - executor/container.go: Container runtime integration - executor/runner.go: Task runner with visibility enforcement - gpu_detector.go: Robust GPU detection (NVIDIA, AMD, Apple Silicon, CPU fallback) - integrity/validate.go: Data integrity validation - lifecycle/runloop.go: Improved runloop with graceful shutdown - lifecycle/service_manager.go: Service lifecycle coordination - process/isolation.go + isolation_unix.go: Process isolation with namespaces/cgroups - tenant/manager.go: Multi-tenant resource isolation - tenant/middleware.go: Tenant context propagation - worker.go: Core worker with group-scoped task execution
164 lines
4.3 KiB
Go
164 lines
4.3 KiB
Go
package scheduler
|
|
|
|
import (
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/base64"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// GenerateSelfSignedCert creates a self-signed ECDSA P-256 TLS certificate for
// the scheduler and writes the certificate and its private key as PEM files.
//
// The directory containing certFile is created (mode 0750) when missing. The
// certificate is valid for 365 days from the time of the call, is limited to
// server authentication, and carries IP SANs for the IPv4 and IPv6 loopback
// addresses. The key file is opened with mode 0600.
//
// Every call draws a fresh random 128-bit serial number, so a regenerated
// certificate never shares issuer+serial with the one it replaces.
//
// It returns a wrapped error when the directory cannot be created, key or
// serial generation fails, or either file cannot be created, written or
// closed.
func GenerateSelfSignedCert(certFile, keyFile string) error {
	if err := os.MkdirAll(filepath.Dir(certFile), 0750); err != nil {
		return fmt.Errorf("create cert directory: %w", err)
	}

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return fmt.Errorf("generate key: %w", err)
	}

	// A constant serial makes successive certificates indistinguishable by
	// issuer+serial; pick a random one in [0, 2^128) instead.
	serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, serialLimit)
	if err != nil {
		return fmt.Errorf("generate serial number: %w", err)
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Organization: []string{"fetch_ml"},
			CommonName:   "fetch_ml_scheduler",
		},
		NotBefore:             now,
		NotAfter:              now.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}

	// Add IP SANs for local development
	template.IPAddresses = []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}

	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return fmt.Errorf("create certificate: %w", err)
	}

	// Marshal the key before touching the filesystem so a marshalling failure
	// cannot leave a new certificate next to an empty key file.
	keyDER, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return fmt.Errorf("marshal private key: %w", err)
	}

	// #nosec G304 -- certFile is internally controlled TLS cert path
	certOut, err := os.Create(certFile)
	if err != nil {
		return fmt.Errorf("create cert file: %w", err)
	}
	if err := encodePEMAndClose(certOut, "cert", &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
		return err
	}

	// #nosec G304 -- keyFile is internally controlled TLS key path
	keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("create key file: %w", err)
	}
	return encodePEMAndClose(keyOut, "key", &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
}

// encodePEMAndClose PEM-encodes block into f and then closes f, returning the
// first failure; label ("cert" or "key") names the file in error messages.
func encodePEMAndClose(f *os.File, label string, block *pem.Block) error {
	if err := pem.Encode(f, block); err != nil {
		// Best-effort close: the encode failure is the error worth reporting.
		_ = f.Close()
		return fmt.Errorf("encode %s: %w", label, err)
	}
	// A failed Close on a written file can mean the data never reached disk.
	if err := f.Close(); err != nil {
		return fmt.Errorf("close %s file: %w", label, err)
	}
	return nil
}
|
|
|
|
// DialWSS connects to the scheduler via WSS with cert pinning
|
|
func DialWSS(addr, certFile, token string) (*websocket.Conn, error) {
|
|
// #nosec G304 -- certFile is internally controlled TLS cert path
|
|
certPEM, err := os.ReadFile(certFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read cert file: %w", err)
|
|
}
|
|
|
|
pool := x509.NewCertPool()
|
|
if !pool.AppendCertsFromPEM(certPEM) {
|
|
return nil, fmt.Errorf("parse cert file")
|
|
}
|
|
|
|
dialer := websocket.Dialer{
|
|
TLSClientConfig: &tls.Config{
|
|
RootCAs: pool,
|
|
MinVersion: tls.VersionTLS12,
|
|
},
|
|
HandshakeTimeout: 10 * time.Second,
|
|
}
|
|
|
|
header := http.Header{}
|
|
if token != "" {
|
|
header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
url := "wss://" + addr + "/ws/worker"
|
|
conn, _, err := dialer.Dial(url, header)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dial scheduler: %w", err)
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// LocalModeDial connects without TLS for single-node mode (loopback only)
|
|
func LocalModeDial(port int, token string) (*websocket.Conn, error) {
|
|
dialer := websocket.Dialer{
|
|
HandshakeTimeout: 5 * time.Second,
|
|
}
|
|
|
|
header := http.Header{}
|
|
if token != "" {
|
|
header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
url := fmt.Sprintf("ws://127.0.0.1:%d/ws/worker", port)
|
|
conn, _, err := dialer.Dial(url, header)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dial local scheduler: %w", err)
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// TokenValidator checks worker authentication tokens against a fixed lookup
// table and resolves each known token to the worker it belongs to.
type TokenValidator struct {
	// tokens maps an authentication token to its worker ID.
	tokens map[string]string
}

// NewTokenValidator returns a TokenValidator backed by tokens (token ->
// worker ID). The map is stored as given, not copied.
func NewTokenValidator(tokens map[string]string) *TokenValidator {
	v := new(TokenValidator)
	v.tokens = tokens
	return v
}

// Validate looks token up in the validator's table. For a known token it
// returns the associated worker ID and true; otherwise it returns "" and
// false.
func (tv *TokenValidator) Validate(token string) (workerID string, ok bool) {
	if id, found := tv.tokens[token]; found {
		return id, true
	}
	return "", false
}
|
|
|
|
// ExtractBearerToken returns the credential carried by an Authorization
// header value of the form "Bearer <token>".
//
// The scheme name is matched case-insensitively and whitespace around the
// token is trimmed. It returns "" when the header is empty, uses a different
// scheme, or has no scheme at all, so that a non-Bearer header value is never
// mistaken for a token.
func ExtractBearerToken(header string) string {
	const scheme = "Bearer "
	if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) {
		return ""
	}
	return strings.TrimSpace(header[len(scheme):])
}
|
|
|
|
// GenerateWorkerToken creates a cryptographically secure random token for a
// worker: the prefix "wkr_" followed by 32 random bytes in padded URL-safe
// base64.
//
// It panics if the system random source reports an error. The function has
// no error return, and handing out a token built from an unfilled buffer
// would produce a guessable credential.
func GenerateWorkerToken() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Errorf("generate worker token: %w", err))
	}
	// Use URL-safe base64 encoding for compact, URL-friendly tokens
	return "wkr_" + base64.URLEncoding.EncodeToString(b)
}
|