Add new scheduler component for distributed ML workload orchestration: - Hub-based coordination for multi-worker clusters - Pacing controller for rate limiting job submissions - Priority queue with preemption support - Port allocator for dynamic service discovery - Protocol handlers for worker-scheduler communication - Service manager with OS-specific implementations - Connection management and state persistence - Template system for service deployment Includes comprehensive test suite: - Unit tests for all core components - Integration tests for distributed scenarios - Benchmark tests for performance validation - Mock fixtures for isolated testing Refs: scheduler-architecture.md
217 lines
4.3 KiB
Go
217 lines
4.3 KiB
Go
package scheduler
|
|
|
|
import (
|
|
"encoding/json"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// SchedulerConn manages the WebSocket connection to the scheduler.
// It handles registration, reconnection with exponential backoff, and
// buffered message delivery. All connection state (conn/closed) is
// guarded by mu; active-task tracking uses a lock-free sync.Map.
type SchedulerConn struct {
	addr     string // scheduler address in "host:port" form
	certFile string // TLS certificate path; empty selects local (non-TLS) mode
	token    string // auth token presented when dialing

	conn         *websocket.Conn    // current connection; nil while disconnected (guarded by mu)
	workerID     string             // unique ID reported during registration
	capabilities WorkerCapabilities // static capabilities reported during registration

	send chan Message // buffered outgoing queue, drained by Run's send loop

	activeTasks sync.Map // task ID (string) -> start time (time.Time)

	mu     sync.RWMutex // guards conn and closed
	closed bool         // set by Close; stops sends and reconnect attempts
}
|
|
|
|
// NewSchedulerConn creates a new scheduler connection
|
|
func NewSchedulerConn(addr, certFile, token, workerID string, caps WorkerCapabilities) *SchedulerConn {
|
|
return &SchedulerConn{
|
|
addr: addr,
|
|
certFile: certFile,
|
|
token: token,
|
|
workerID: workerID,
|
|
capabilities: caps,
|
|
send: make(chan Message, 100),
|
|
}
|
|
}
|
|
|
|
// Connect establishes the WebSocket connection
|
|
func (sc *SchedulerConn) Connect() error {
|
|
var conn *websocket.Conn
|
|
var err error
|
|
|
|
if sc.certFile == "" {
|
|
// Local mode - no TLS, parse port from address
|
|
port := parsePortFromAddr(sc.addr)
|
|
conn, err = LocalModeDial(port, sc.token)
|
|
} else {
|
|
conn, err = DialWSS(sc.addr, sc.certFile, sc.token)
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sc.mu.Lock()
|
|
sc.conn = conn
|
|
sc.closed = false
|
|
sc.mu.Unlock()
|
|
|
|
// Send registration
|
|
sc.Send(Message{
|
|
Type: MsgRegister,
|
|
Payload: mustMarshal(WorkerRegistration{
|
|
ID: sc.workerID,
|
|
Capabilities: sc.capabilities,
|
|
ActiveTasks: sc.collectActiveTasks(),
|
|
}),
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// Send sends a message to the scheduler
|
|
func (sc *SchedulerConn) Send(msg Message) {
|
|
select {
|
|
case sc.send <- msg:
|
|
default:
|
|
// Channel full, drop message
|
|
}
|
|
}
|
|
|
|
// Run starts the send/receive loops
|
|
func (sc *SchedulerConn) Run(onJobAssign func(*JobSpec), onJobCancel func(string), onPrewarmHint func(PrewarmHintPayload)) {
|
|
// Send loop
|
|
go func() {
|
|
for msg := range sc.send {
|
|
sc.mu.RLock()
|
|
conn := sc.conn
|
|
closed := sc.closed
|
|
sc.mu.RUnlock()
|
|
|
|
if closed || conn == nil {
|
|
continue
|
|
}
|
|
|
|
if err := conn.WriteJSON(msg); err != nil {
|
|
// Trigger reconnect
|
|
go sc.reconnect()
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Receive loop
|
|
for {
|
|
sc.mu.RLock()
|
|
conn := sc.conn
|
|
sc.mu.RUnlock()
|
|
|
|
if conn == nil {
|
|
time.Sleep(100 * time.Millisecond)
|
|
continue
|
|
}
|
|
|
|
var msg Message
|
|
if err := conn.ReadJSON(&msg); err != nil {
|
|
// Connection lost, reconnect
|
|
sc.reconnect()
|
|
continue
|
|
}
|
|
|
|
switch msg.Type {
|
|
case MsgJobAssign:
|
|
var spec JobSpec
|
|
json.Unmarshal(msg.Payload, &spec)
|
|
onJobAssign(&spec)
|
|
case MsgJobCancel:
|
|
var taskID string
|
|
json.Unmarshal(msg.Payload, &taskID)
|
|
onJobCancel(taskID)
|
|
case MsgPrewarmHint:
|
|
var hint PrewarmHintPayload
|
|
json.Unmarshal(msg.Payload, &hint)
|
|
onPrewarmHint(hint)
|
|
case MsgNoWork:
|
|
// No action needed - worker will retry
|
|
}
|
|
}
|
|
}
|
|
|
|
// reconnect attempts to reconnect with exponential backoff
|
|
func (sc *SchedulerConn) reconnect() {
|
|
sc.mu.Lock()
|
|
if sc.closed {
|
|
sc.mu.Unlock()
|
|
return
|
|
}
|
|
sc.conn = nil
|
|
sc.mu.Unlock()
|
|
|
|
backoff := 1 * time.Second
|
|
maxBackoff := 30 * time.Second
|
|
|
|
for {
|
|
time.Sleep(backoff)
|
|
|
|
if err := sc.Connect(); err == nil {
|
|
return // Reconnected successfully
|
|
}
|
|
|
|
backoff *= 2
|
|
if backoff > maxBackoff {
|
|
backoff = maxBackoff
|
|
}
|
|
}
|
|
}
|
|
|
|
// Close closes the connection
|
|
func (sc *SchedulerConn) Close() {
|
|
sc.mu.Lock()
|
|
defer sc.mu.Unlock()
|
|
|
|
sc.closed = true
|
|
if sc.conn != nil {
|
|
sc.conn.Close()
|
|
}
|
|
close(sc.send)
|
|
}
|
|
|
|
// TrackTask tracks an active task
|
|
func (sc *SchedulerConn) TrackTask(taskID string) {
|
|
sc.activeTasks.Store(taskID, time.Now())
|
|
}
|
|
|
|
// UntrackTask removes a task from tracking
|
|
func (sc *SchedulerConn) UntrackTask(taskID string) {
|
|
sc.activeTasks.Delete(taskID)
|
|
}
|
|
|
|
func (sc *SchedulerConn) collectActiveTasks() []ActiveTaskReport {
|
|
var reports []ActiveTaskReport
|
|
sc.activeTasks.Range(func(key, value any) bool {
|
|
taskID := key.(string)
|
|
startedAt := value.(time.Time)
|
|
reports = append(reports, ActiveTaskReport{
|
|
TaskID: taskID,
|
|
State: "running",
|
|
StartedAt: startedAt,
|
|
})
|
|
return true
|
|
})
|
|
return reports
|
|
}
|
|
|
|
// parsePortFromAddr extracts the port from a "host:port" address string.
// IPv6 literals such as "[::1]:7777" are handled by splitting on the
// last colon. Returns the default port 7777 when the address has no
// port or the port is not a number in the valid range 1-65535.
func parsePortFromAddr(addr string) int {
	const defaultPort = 7777

	i := strings.LastIndexByte(addr, ':')
	if i < 0 {
		return defaultPort
	}
	port, err := strconv.Atoi(addr[i+1:])
	if err != nil || port < 1 || port > 65535 {
		// Reject non-numeric, zero, negative, or out-of-range ports.
		return defaultPort
	}
	return port
}
|