package scheduler

import (
	"encoding/json"
	"errors"
	"net"
	"strconv"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

// defaultSchedulerPort is returned by parsePortFromAddr when the address
// carries no usable port.
const defaultSchedulerPort = 7777

// errSchedulerConnClosed is returned by Connect once Close has been called.
var errSchedulerConnClosed = errors.New("scheduler connection closed")

// SchedulerConn manages the WebSocket connection to the scheduler.
//
// Outgoing messages are queued on a bounded channel and written by the send
// loop started in Run; incoming messages are read and dispatched by Run.
// Close is terminal: afterwards Connect fails, Send is a no-op, and the loops
// started by Run exit. Use NewSchedulerConn to construct one.
type SchedulerConn struct {
	addr         string
	certFile     string
	token        string
	conn         *websocket.Conn
	workerID     string
	capabilities WorkerCapabilities
	send         chan Message
	done         chan struct{} // closed by Close; interrupts reconnect backoff waits
	activeTasks  sync.Map      // task ID (string) -> start time (time.Time)

	// mu guards conn, closed, reconnecting and reconnectGen, and is held (read)
	// by Send so Close cannot close the send queue mid-enqueue.
	mu sync.RWMutex

	closed       bool
	reconnecting bool   // a reconnect loop currently owns re-dialing; implies conn == nil
	reconnectGen uint64 // identifies the most recently started reconnect loop
}

// NewSchedulerConn creates a new, not-yet-connected scheduler connection.
// Call Connect to dial and Run to start processing messages.
func NewSchedulerConn(addr, certFile, token, workerID string, caps WorkerCapabilities) *SchedulerConn {
	return &SchedulerConn{
		addr:         addr,
		certFile:     certFile,
		token:        token,
		workerID:     workerID,
		capabilities: caps,
		send:         make(chan Message, 100),
		done:         make(chan struct{}),
	}
}

// Connect establishes the WebSocket connection and queues a registration
// message carrying the worker ID, its capabilities and the currently tracked
// tasks.
//
// An empty certFile selects local mode (no TLS; only the port of addr is
// used); otherwise DialWSS is used with certFile. A previously held
// connection is closed and replaced. Connect returns the dial error on
// failure, or an error if Close has already been called.
func (sc *SchedulerConn) Connect() error {
	if sc.isClosed() {
		return errSchedulerConnClosed
	}

	conn, err := sc.dial()
	if err != nil {
		return err
	}

	sc.mu.Lock()
	if sc.closed {
		// Close raced with the dial; never resurrect a closed SchedulerConn.
		sc.mu.Unlock()
		conn.Close()
		return errSchedulerConnClosed
	}
	prev := sc.conn
	sc.conn = conn
	// A live connection supersedes any reconnect loop still in progress.
	sc.reconnecting = false
	sc.mu.Unlock()

	if prev != nil {
		// Don't leak the replaced socket. Loops still holding it will see an
		// error and ignore it because it is no longer the current conn.
		prev.Close()
	}

	// Send registration
	sc.Send(Message{
		Type: MsgRegister,
		Payload: mustMarshal(WorkerRegistration{
			ID:           sc.workerID,
			Capabilities: sc.capabilities,
			ActiveTasks:  sc.collectActiveTasks(),
		}),
	})
	return nil
}

// dial opens a new WebSocket connection according to the configured mode.
func (sc *SchedulerConn) dial() (*websocket.Conn, error) {
	if sc.certFile == "" {
		// Local mode - no TLS, parse port from address.
		return LocalModeDial(parsePortFromAddr(sc.addr), sc.token)
	}
	return DialWSS(sc.addr, sc.certFile, sc.token)
}

// isClosed reports whether Close has been called.
func (sc *SchedulerConn) isClosed() bool {
	sc.mu.RLock()
	defer sc.mu.RUnlock()
	return sc.closed
}

// Send queues msg for delivery by the send loop started in Run. It never
// blocks: the message is dropped if the queue is full or Close has been
// called.
func (sc *SchedulerConn) Send(msg Message) {
	// Holding the read lock keeps Close from closing sc.send while we enqueue;
	// sending on a closed channel would panic even inside select/default.
	sc.mu.RLock()
	defer sc.mu.RUnlock()
	if sc.closed {
		return
	}
	select {
	case sc.send <- msg:
	default:
		// Channel full, drop message.
	}
}

// Run starts the send loop in a background goroutine and runs the receive
// loop on the calling goroutine, dispatching scheduler messages to the given
// callbacks (nil callbacks are skipped). Read and write failures trigger a
// reconnect. Run returns once Close has been called.
func (sc *SchedulerConn) Run(onJobAssign func(*JobSpec), onJobCancel func(string), onPrewarmHint func(PrewarmHintPayload)) {
	go sc.sendLoop()

	for {
		sc.mu.RLock()
		conn, closed := sc.conn, sc.closed
		sc.mu.RUnlock()

		if closed {
			return
		}
		if conn == nil {
			// Not connected yet, or a reconnect is in progress.
			time.Sleep(100 * time.Millisecond)
			continue
		}

		var msg Message
		if err := conn.ReadJSON(&msg); err != nil {
			// Connection lost; reconnect (no-op if conn was already replaced).
			sc.reconnectFrom(conn)
			continue
		}
		sc.dispatch(msg, onJobAssign, onJobCancel, onPrewarmHint)
	}
}

// sendLoop writes queued messages to the current connection until Close
// closes the queue. Messages dequeued while disconnected are dropped.
func (sc *SchedulerConn) sendLoop() {
	for msg := range sc.send {
		sc.mu.RLock()
		conn, closed := sc.conn, sc.closed
		sc.mu.RUnlock()
		if closed || conn == nil {
			continue
		}
		if err := conn.WriteJSON(msg); err != nil {
			// Reconnect in the background so the queue keeps draining.
			go sc.reconnectFrom(conn)
		}
	}
}

// dispatch decodes msg's payload according to its type and invokes the
// matching callback. Messages whose payload fails to decode are dropped
// instead of being delivered as zero values.
func (sc *SchedulerConn) dispatch(msg Message, onJobAssign func(*JobSpec), onJobCancel func(string), onPrewarmHint func(PrewarmHintPayload)) {
	switch msg.Type {
	case MsgJobAssign:
		var spec JobSpec
		if err := json.Unmarshal(msg.Payload, &spec); err != nil || onJobAssign == nil {
			return
		}
		onJobAssign(&spec)
	case MsgJobCancel:
		var taskID string
		if err := json.Unmarshal(msg.Payload, &taskID); err != nil || onJobCancel == nil {
			return
		}
		onJobCancel(taskID)
	case MsgPrewarmHint:
		var hint PrewarmHintPayload
		if err := json.Unmarshal(msg.Payload, &hint); err != nil || onPrewarmHint == nil {
			return
		}
		onPrewarmHint(hint)
	case MsgNoWork:
		// No action needed - worker will retry.
	}
}

// reconnect drops the current connection (if any) and re-dials with
// exponential backoff; see reconnectFrom.
func (sc *SchedulerConn) reconnect() {
	sc.mu.RLock()
	current := sc.conn
	sc.mu.RUnlock()
	sc.reconnectFrom(current)
}

// reconnectFrom replaces the connection failed with a new one, re-dialing
// with exponential backoff (1s doubling up to 30s). It does nothing if failed
// is no longer the current connection (a stale error), if a reconnect is
// already in progress, or if Close has been called. The loop stops when a
// connection is established (by it or by another Connect call), when a newer
// reconnect loop takes over, or when Close is called.
func (sc *SchedulerConn) reconnectFrom(failed *websocket.Conn) {
	sc.mu.Lock()
	if sc.closed || sc.reconnecting || sc.conn != failed {
		sc.mu.Unlock()
		return
	}
	sc.reconnecting = true
	sc.reconnectGen++
	gen := sc.reconnectGen
	sc.conn = nil
	sc.mu.Unlock()

	if failed != nil {
		// Release the broken socket instead of leaking it.
		failed.Close()
	}

	const maxBackoff = 30 * time.Second
	backoff := 1 * time.Second
	for {
		select {
		case <-time.After(backoff):
		case <-sc.done: // nil if constructed without NewSchedulerConn; then only the timer fires
			return
		}
		if !sc.ownsReconnect(gen) {
			return
		}
		err := sc.Connect()
		if err == nil || errors.Is(err, errSchedulerConnClosed) {
			return // Reconnected successfully, or shut down.
		}
		backoff *= 2
		if backoff > maxBackoff {
			backoff = maxBackoff
		}
	}
}

// ownsReconnect reports whether the reconnect loop identified by gen should
// keep re-dialing: Close has not been called, no connection has been
// installed since, and no newer reconnect loop has started.
func (sc *SchedulerConn) ownsReconnect(gen uint64) bool {
	sc.mu.RLock()
	defer sc.mu.RUnlock()
	return !sc.closed && sc.reconnecting && sc.reconnectGen == gen
}

// Close permanently shuts the connection down: it closes the socket, stops
// the send loop by closing the queue, and wakes any reconnect backoff wait.
// It is safe to call more than once.
func (sc *SchedulerConn) Close() {
	sc.mu.Lock()
	defer sc.mu.Unlock()
	if sc.closed {
		return
	}
	sc.closed = true
	if sc.conn != nil {
		sc.conn.Close()
		sc.conn = nil
	}
	close(sc.send)
	if sc.done != nil {
		close(sc.done)
	}
}

// TrackTask records taskID as active, with the current time as its start.
func (sc *SchedulerConn) TrackTask(taskID string) {
	sc.activeTasks.Store(taskID, time.Now())
}

// UntrackTask removes a task from tracking.
func (sc *SchedulerConn) UntrackTask(taskID string) {
	sc.activeTasks.Delete(taskID)
}

// collectActiveTasks snapshots the tracked tasks as "running" reports for the
// registration message. It returns nil when no tasks are tracked.
func (sc *SchedulerConn) collectActiveTasks() []ActiveTaskReport {
	var reports []ActiveTaskReport
	sc.activeTasks.Range(func(key, value any) bool {
		taskID := key.(string)
		startedAt := value.(time.Time)
		reports = append(reports, ActiveTaskReport{
			TaskID:    taskID,
			State:     "running",
			StartedAt: startedAt,
		})
		return true
	})
	return reports
}

// parsePortFromAddr extracts the port from a "host:port" address, including
// bracketed IPv6 forms such as "[::1]:7777". It returns the default port 7777
// if addr has no port, the port is not numeric, or it is outside 1-65535.
func parsePortFromAddr(addr string) int {
	_, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		return defaultSchedulerPort
	}
	port, err := strconv.Atoi(portStr)
	if err != nil || port < 1 || port > 65535 {
		return defaultSchedulerPort
	}
	return port
}