fetch_ml/internal/scheduler/scheduler_conn.go
Jeremie Fraeys 43e6446587
feat(scheduler): implement multi-tenant job scheduler with gang scheduling
Add new scheduler component for distributed ML workload orchestration:
- Hub-based coordination for multi-worker clusters
- Pacing controller for rate limiting job submissions
- Priority queue with preemption support
- Port allocator for dynamic service discovery
- Protocol handlers for worker-scheduler communication
- Service manager with OS-specific implementations
- Connection management and state persistence
- Template system for service deployment

Includes comprehensive test suite:
- Unit tests for all core components
- Integration tests for distributed scenarios
- Benchmark tests for performance validation
- Mock fixtures for isolated testing

Refs: scheduler-architecture.md
2026-02-26 12:03:23 -05:00

217 lines
4.3 KiB
Go

package scheduler
import (
"encoding/json"
"strconv"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
// SchedulerConn manages the WebSocket connection to the scheduler
type SchedulerConn struct {
	addr     string // scheduler address in "host:port" form
	certFile string // TLS certificate path; empty selects local (plaintext) mode in Connect
	token    string // auth token presented when dialing

	conn *websocket.Conn // current connection; nil while disconnected (guarded by mu)

	workerID     string             // this worker's identifier, sent at registration
	capabilities WorkerCapabilities // advertised capabilities, sent at registration

	send chan Message // buffered outgoing queue, drained by the send loop in Run

	activeTasks sync.Map // task ID (string) -> start time (time.Time); reported on (re)registration

	mu     sync.RWMutex // guards conn and closed
	closed bool         // set by Close; marks the connection permanently shut down
}
// NewSchedulerConn creates a new, not-yet-connected scheduler
// connection. Call Connect (and then Run) to start communicating.
func NewSchedulerConn(addr, certFile, token, workerID string, caps WorkerCapabilities) *SchedulerConn {
	c := &SchedulerConn{
		addr:         addr,
		certFile:     certFile,
		token:        token,
		workerID:     workerID,
		capabilities: caps,
		// Buffered so bursts of outgoing messages don't block callers;
		// Send drops on overflow.
		send: make(chan Message, 100),
	}
	return c
}
// Connect establishes the WebSocket connection.
//
// When certFile is non-empty a TLS WebSocket connection is dialed;
// otherwise local (plaintext) mode is used with the port parsed out of
// the address. On success the connection is stored and a registration
// message (worker ID, capabilities, currently active tasks) is queued
// for the send loop.
func (sc *SchedulerConn) Connect() error {
	var (
		conn *websocket.Conn
		err  error
	)
	if sc.certFile != "" {
		conn, err = DialWSS(sc.addr, sc.certFile, sc.token)
	} else {
		// Local mode - no TLS; only the port component of addr is used.
		conn, err = LocalModeDial(parsePortFromAddr(sc.addr), sc.token)
	}
	if err != nil {
		return err
	}

	sc.mu.Lock()
	sc.conn = conn
	sc.closed = false
	sc.mu.Unlock()

	// Queue the registration; the send loop in Run delivers it.
	reg := WorkerRegistration{
		ID:           sc.workerID,
		Capabilities: sc.capabilities,
		ActiveTasks:  sc.collectActiveTasks(),
	}
	sc.Send(Message{Type: MsgRegister, Payload: mustMarshal(reg)})
	return nil
}
// Send queues a message for delivery to the scheduler.
//
// The send is non-blocking: if the outgoing buffer is full the message
// is dropped. After Close has been called, Send is a no-op — without
// the closed check, a send could race with Close closing sc.send and
// panic (a select default does NOT protect against sending on a
// closed channel). Close holds the write lock while closing, so
// holding the read lock here makes the two mutually exclusive.
func (sc *SchedulerConn) Send(msg Message) {
	sc.mu.RLock()
	defer sc.mu.RUnlock()
	if sc.closed {
		// Connection is shut down and sc.send is (being) closed.
		return
	}
	select {
	case sc.send <- msg:
	default:
		// Channel full: drop rather than block the caller.
	}
}
// Run starts the send loop in a goroutine and then blocks in the
// receive loop, dispatching scheduler messages to the supplied
// callbacks. It returns once Close has been called.
//
// Fixes over the original: the receive loop now exits when the
// connection is closed (instead of spinning forever on a dead conn),
// and unmarshal errors are no longer ignored — a malformed payload is
// skipped rather than dispatched as a zero value.
func (sc *SchedulerConn) Run(onJobAssign func(*JobSpec), onJobCancel func(string), onPrewarmHint func(PrewarmHintPayload)) {
	// Send loop: drains the outgoing channel until Close closes it.
	go func() {
		for msg := range sc.send {
			sc.mu.RLock()
			conn := sc.conn
			closed := sc.closed
			sc.mu.RUnlock()
			if closed || conn == nil {
				// Disconnected: drop the message.
				continue
			}
			if err := conn.WriteJSON(msg); err != nil {
				// Write failure: trigger a background reconnect.
				go sc.reconnect()
			}
		}
	}()

	// Receive loop.
	for {
		sc.mu.RLock()
		conn := sc.conn
		closed := sc.closed
		sc.mu.RUnlock()
		if closed {
			// Close was called; stop the loop.
			return
		}
		if conn == nil {
			// Mid-reconnect; poll until a connection exists.
			time.Sleep(100 * time.Millisecond)
			continue
		}
		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			// Connection lost; reconnect (blocks until done or closed).
			sc.reconnect()
			continue
		}
		switch msg.Type {
		case MsgJobAssign:
			var spec JobSpec
			if err := json.Unmarshal(msg.Payload, &spec); err == nil {
				onJobAssign(&spec)
			}
		case MsgJobCancel:
			var taskID string
			if err := json.Unmarshal(msg.Payload, &taskID); err == nil {
				onJobCancel(taskID)
			}
		case MsgPrewarmHint:
			var hint PrewarmHintPayload
			if err := json.Unmarshal(msg.Payload, &hint); err == nil {
				onPrewarmHint(hint)
			}
		case MsgNoWork:
			// No action needed - worker will retry
		}
	}
}
// reconnect drops the current connection and retries Connect with
// exponential backoff (1s doubling up to 30s). It returns once a
// connection is re-established, or as soon as Close has been called —
// the original version kept retrying forever after shutdown.
func (sc *SchedulerConn) reconnect() {
	sc.mu.Lock()
	if sc.closed {
		sc.mu.Unlock()
		return
	}
	sc.conn = nil
	sc.mu.Unlock()

	backoff := 1 * time.Second
	const maxBackoff = 30 * time.Second
	for {
		time.Sleep(backoff)

		// Abort if Close was called during the backoff sleep.
		sc.mu.RLock()
		closed := sc.closed
		sc.mu.RUnlock()
		if closed {
			return
		}

		if err := sc.Connect(); err == nil {
			return // reconnected successfully
		}
		backoff *= 2
		if backoff > maxBackoff {
			backoff = maxBackoff
		}
	}
}
// Close shuts the connection down permanently: it marks the connection
// closed, closes the underlying WebSocket (if any), and closes the
// outgoing channel so the send loop in Run terminates.
//
// Close is idempotent — the original panicked on a second call by
// closing sc.send twice. Holding the write lock for the whole body
// also keeps the channel close mutually exclusive with Send.
func (sc *SchedulerConn) Close() {
	sc.mu.Lock()
	defer sc.mu.Unlock()
	if sc.closed {
		// Already closed; closing sc.send again would panic.
		return
	}
	sc.closed = true
	if sc.conn != nil {
		sc.conn.Close()
	}
	close(sc.send)
}
// TrackTask records taskID as an active task, stamped with the time it
// started; the stamp is reported back in collectActiveTasks.
func (sc *SchedulerConn) TrackTask(taskID string) {
	startedAt := time.Now()
	sc.activeTasks.Store(taskID, startedAt)
}
// UntrackTask forgets a previously tracked task. Removing an unknown
// ID is a harmless no-op.
func (sc *SchedulerConn) UntrackTask(taskID string) {
	sc.activeTasks.Delete(taskID)
}
// collectActiveTasks snapshots every tracked task as a "running"
// report; used to tell the scheduler what is in flight when
// (re)registering.
func (sc *SchedulerConn) collectActiveTasks() []ActiveTaskReport {
	var snapshot []ActiveTaskReport
	sc.activeTasks.Range(func(k, v any) bool {
		snapshot = append(snapshot, ActiveTaskReport{
			TaskID:    k.(string),
			State:     "running",
			StartedAt: v.(time.Time),
		})
		return true // keep iterating over the whole map
	})
	return snapshot
}
// parsePortFromAddr extracts the port from a "host:port" address
// string. Returns the default port 7777 when no valid port is present.
//
// Uses the LAST colon so bracketed IPv6 addresses such as
// "[::1]:9000" parse correctly — the original strings.Split on ":"
// produced >2 parts for IPv6 and always fell back to the default.
// The port is also validated to the usable range 1-65535.
func parsePortFromAddr(addr string) int {
	const defaultPort = 7777
	i := strings.LastIndex(addr, ":")
	if i < 0 || i == len(addr)-1 {
		// No colon, or nothing after it.
		return defaultPort
	}
	port, err := strconv.Atoi(addr[i+1:])
	if err != nil || port < 1 || port > 65535 {
		return defaultPort
	}
	return port
}