fetch_ml/internal/scheduler/scheduler_conn.go
Jeremie Fraeys 0b5e99f720
refactor(scheduler,worker): improve service management and GPU detection
Scheduler enhancements:
- auth.go: Group membership validation in authentication
- hub.go: Task distribution with group affinity
- port_allocator.go: Dynamic port allocation with conflict resolution
- scheduler_conn.go: Connection pooling and retry logic
- service_manager.go: Lifecycle management for scheduler services
- service_templates.go: Template-based service configuration
- state.go: Persistent state management with recovery

Worker improvements:
- config.go: Extended configuration for task visibility rules
- execution/setup.go: Sandboxed execution environment setup
- executor/container.go: Container runtime integration
- executor/runner.go: Task runner with visibility enforcement
- gpu_detector.go: Robust GPU detection (NVIDIA, AMD, Apple Silicon, CPU fallback)
- integrity/validate.go: Data integrity validation
- lifecycle/runloop.go: Improved runloop with graceful shutdown
- lifecycle/service_manager.go: Service lifecycle coordination
- process/isolation.go + isolation_unix.go: Process isolation with namespaces/cgroups
- tenant/manager.go: Multi-tenant resource isolation
- tenant/middleware.go: Tenant context propagation
- worker.go: Core worker with group-scoped task execution
2026-03-08 13:03:15 -04:00

229 lines
4.7 KiB
Go

package scheduler
import (
"encoding/json"
"log/slog"
"strconv"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
// SchedulerConn manages the WebSocket connection to the scheduler.
// All mutable connection state (conn, closed) is guarded by mu; the
// outgoing message queue is the buffered send channel.
type SchedulerConn struct {
	addr     string // scheduler address, "host:port"
	certFile string // TLS cert path; empty string selects local (non-TLS) mode
	token    string // auth token passed to the dial helpers
	conn     *websocket.Conn
	workerID string
	capabilities WorkerCapabilities
	send         chan Message // buffered outgoing queue, drained by Run's send loop
	activeTasks  sync.Map     // taskID (string) -> start time (time.Time)
	mu           sync.RWMutex // guards conn and closed
	closed       bool         // set by Close; suppresses reconnect attempts
}
// NewSchedulerConn creates a scheduler connection handle with a buffered
// outgoing queue. Connect must be called before the connection is usable.
func NewSchedulerConn(addr, certFile, token, workerID string, caps WorkerCapabilities) *SchedulerConn {
	sc := &SchedulerConn{
		addr:         addr,
		certFile:     certFile,
		token:        token,
		workerID:     workerID,
		capabilities: caps,
		send:         make(chan Message, 100),
	}
	return sc
}
// Connect establishes the WebSocket connection to the scheduler and sends
// the worker registration message (ID, capabilities, and any tasks that
// are still active, so the scheduler can resynchronize after a reconnect).
func (sc *SchedulerConn) Connect() error {
	var (
		conn *websocket.Conn
		err  error
	)
	// An empty cert file selects local mode: plain WS on the port parsed
	// out of the address. Otherwise dial with TLS.
	if sc.certFile == "" {
		conn, err = LocalModeDial(parsePortFromAddr(sc.addr), sc.token)
	} else {
		conn, err = DialWSS(sc.addr, sc.certFile, sc.token)
	}
	if err != nil {
		return err
	}

	sc.mu.Lock()
	sc.conn = conn
	sc.closed = false
	sc.mu.Unlock()

	// Announce ourselves to the scheduler.
	reg := WorkerRegistration{
		ID:           sc.workerID,
		Capabilities: sc.capabilities,
		ActiveTasks:  sc.collectActiveTasks(),
	}
	sc.Send(Message{Type: MsgRegister, Payload: mustMarshal(reg)})
	return nil
}
// Send enqueues a message for delivery to the scheduler. It never blocks:
// if the outgoing buffer is full the message is dropped (with a warning,
// so drops are at least visible).
//
// The closed check under the read lock fixes a crash in the original:
// a Send racing with Close could write to the already-closed channel and
// panic. Close takes the write lock, so it cannot close the channel while
// a Send holds the read lock.
func (sc *SchedulerConn) Send(msg Message) {
	sc.mu.RLock()
	defer sc.mu.RUnlock()
	if sc.closed {
		return
	}
	select {
	case sc.send <- msg:
	default:
		// Buffer full; drop rather than block the caller.
		slog.Warn("scheduler send buffer full, dropping message", "type", msg.Type)
	}
}
// Run starts the send and receive loops and blocks until Close is called.
// The callbacks are invoked synchronously from the receive loop, so they
// should return quickly or hand work off to their own goroutines.
//
// Fix over the original: the receive loop now observes the closed flag and
// returns after Close, instead of sleeping/reconnecting forever (goroutine
// leak on shutdown).
func (sc *SchedulerConn) Run(onJobAssign func(*JobSpec), onJobCancel func(string), onPrewarmHint func(PrewarmHintPayload)) {
	// Send loop: drains the buffered channel and writes frames out.
	// It exits when Close closes the send channel.
	go func() {
		for msg := range sc.send {
			sc.mu.RLock()
			conn := sc.conn
			closed := sc.closed
			sc.mu.RUnlock()
			if closed || conn == nil {
				// Disconnected: drop; worker state is re-announced via
				// the registration message on reconnect.
				continue
			}
			if err := conn.WriteJSON(msg); err != nil {
				// Socket is dead; reconnect in the background so this
				// loop keeps draining the channel meanwhile.
				go sc.reconnect()
			}
		}
	}()

	// Receive loop.
	for {
		sc.mu.RLock()
		conn := sc.conn
		closed := sc.closed
		sc.mu.RUnlock()
		if closed {
			return // Close was called: stop the receive loop.
		}
		if conn == nil {
			// A reconnect is in flight; poll until it installs a new conn.
			time.Sleep(100 * time.Millisecond)
			continue
		}
		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			// Connection lost; reconnect and re-check state.
			sc.reconnect()
			continue
		}
		sc.dispatch(msg, onJobAssign, onJobCancel, onPrewarmHint)
	}
}

// dispatch decodes one scheduler message and routes it to the matching
// callback. Malformed payloads are logged and skipped.
func (sc *SchedulerConn) dispatch(msg Message, onJobAssign func(*JobSpec), onJobCancel func(string), onPrewarmHint func(PrewarmHintPayload)) {
	switch msg.Type {
	case MsgJobAssign:
		var spec JobSpec
		if err := json.Unmarshal(msg.Payload, &spec); err != nil {
			slog.Error("failed to unmarshal job assign", "error", err)
			return
		}
		onJobAssign(&spec)
	case MsgJobCancel:
		var taskID string
		if err := json.Unmarshal(msg.Payload, &taskID); err != nil {
			slog.Error("failed to unmarshal job cancel", "error", err)
			return
		}
		onJobCancel(taskID)
	case MsgPrewarmHint:
		var hint PrewarmHintPayload
		if err := json.Unmarshal(msg.Payload, &hint); err != nil {
			slog.Error("failed to unmarshal prewarm hint", "error", err)
			return
		}
		onPrewarmHint(hint)
	case MsgNoWork:
		// Nothing to do; the worker retries on its own schedule.
	}
}
// reconnect re-dials the scheduler with exponential backoff (1s doubling
// up to a 30s cap). It returns as soon as a dial succeeds or Close is
// called — the original loop never checked the closed flag and kept
// dialing forever after shutdown.
//
// NOTE(review): both the send and receive loops can trigger this, so two
// backoff loops may run concurrently; a later successful Connect simply
// replaces the conn. Confirm whether a single-flight guard is wanted.
func (sc *SchedulerConn) reconnect() {
	sc.mu.Lock()
	if sc.closed {
		sc.mu.Unlock()
		return
	}
	sc.conn = nil
	sc.mu.Unlock()

	backoff := time.Second
	const maxBackoff = 30 * time.Second
	for {
		time.Sleep(backoff)

		// Bail out if Close was called while we slept.
		sc.mu.RLock()
		closed := sc.closed
		sc.mu.RUnlock()
		if closed {
			return
		}

		if err := sc.Connect(); err == nil {
			return // Reconnected successfully.
		}
		backoff = min(backoff*2, maxBackoff)
	}
}
// Close shuts the connection down and stops the send loop started by Run.
// It is idempotent: the original version closed the send channel on every
// call, so a second Close panicked with "close of closed channel".
func (sc *SchedulerConn) Close() {
	sc.mu.Lock()
	defer sc.mu.Unlock()
	if sc.closed {
		return // Already closed; make repeated calls safe.
	}
	sc.closed = true
	if sc.conn != nil {
		if err := sc.conn.Close(); err != nil {
			slog.Error("failed to close connection", "error", err)
		}
		sc.conn = nil // Prevent further writes on the dead socket.
	}
	close(sc.send) // Terminates the send loop's range.
}
// TrackTask records taskID as an active task, stamped with the current
// time so it can be reported back to the scheduler on reconnect.
func (sc *SchedulerConn) TrackTask(taskID string) {
	startedAt := time.Now()
	sc.activeTasks.Store(taskID, startedAt)
}
// UntrackTask forgets a previously tracked task (no-op if unknown).
func (sc *SchedulerConn) UntrackTask(taskID string) {
	sc.activeTasks.Delete(taskID)
}
// collectActiveTasks snapshots the currently tracked tasks as reports for
// the registration message. Returns nil when no tasks are active.
func (sc *SchedulerConn) collectActiveTasks() []ActiveTaskReport {
	var out []ActiveTaskReport
	sc.activeTasks.Range(func(k, v any) bool {
		out = append(out, ActiveTaskReport{
			TaskID:    k.(string),
			State:     "running",
			StartedAt: v.(time.Time),
		})
		return true // keep iterating over every entry
	})
	return out
}
// parsePortFromAddr extracts the port from a "host:port" address string.
// Returns the default port 7777 when no valid port can be found.
//
// Fix over the original two-way split: taking everything after the LAST
// colon also handles bracketed IPv6 addresses such as "[::1]:9090" (which
// the original rejected because it split into more than two parts), and
// the 1..65535 range check rejects nonsense like "host:0".
func parsePortFromAddr(addr string) int {
	const defaultPort = 7777
	i := strings.LastIndex(addr, ":")
	if i < 0 || i == len(addr)-1 {
		return defaultPort // no colon, or nothing after it
	}
	port, err := strconv.Atoi(addr[i+1:])
	if err != nil || port < 1 || port > 65535 {
		return defaultPort
	}
	return port
}