fetch_ml/cmd/tui/internal/services/websocket.go
Jeremie Fraeys 23e5f3d1dc
refactor(api): internal refactoring for TUI and worker modules
- Refactor internal/worker and internal/queue packages
- Update cmd/tui for monitoring interface
- Update test configurations
2026-02-20 15:51:23 -05:00

275 lines
6.5 KiB
Go

// Package services provides TUI service clients
package services
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// WebSocketClient receives real-time updates from the fetch_ml server over a
// WebSocket connection and delivers them on typed, buffered channels.
//
// NOTE(review): conn and connected are read and written by the goroutines
// started in Connect as well as by the exported methods, with no locking.
// Confirm callers tolerate this or add synchronisation.
type WebSocketClient struct {
	// Connection and credentials used to (re)establish it.
	conn      *websocket.Conn
	serverURL string
	apiKey    string
	logger    *logging.Logger

	// Decoded updates, one channel per message kind.
	jobUpdates    chan model.JobUpdateMsg
	gpuUpdates    chan model.GPUUpdateMsg
	statusUpdates chan model.StatusMsg

	// Lifetime control: cancelling ctx signals the background goroutines to stop.
	ctx       context.Context
	cancel    context.CancelFunc
	connected bool
}
// JobUpdateMsg is the JSON payload carried by a binary job-update frame
// (opcode 0x01). handleBinaryMessage converts it directly to
// model.JobUpdateMsg, so the field names, types and order here must stay
// identical to that type.
type JobUpdateMsg struct {
	JobName  string `json:"job_name"` // name of the job
	Status   string `json:"status"`   // status string as reported by the server
	TaskID   string `json:"task_id"`  // identifier of the task
	Progress int    `json:"progress"` // presumably a percentage — confirm with the server
}
// GPUUpdateMsg is the JSON payload carried by a binary GPU-update frame
// (opcode 0x02). handleBinaryMessage converts it directly to
// model.GPUUpdateMsg, so the field names, types and order here must stay
// identical to that type.
type GPUUpdateMsg struct {
	DeviceID    int   `json:"device_id"`    // GPU index on the worker
	Utilization int   `json:"utilization"`  // presumably percent — confirm
	MemoryUsed  int64 `json:"memory_used"`  // units not visible here (bytes?) — confirm
	MemoryTotal int64 `json:"memory_total"` // same units as MemoryUsed
	Temperature int   `json:"temperature"`  // presumably degrees Celsius — confirm
}
// NewWebSocketClient builds an unconnected client for serverURL. apiKey is
// used when the connection is opened, and logger receives diagnostics. Each
// update channel is buffered for 100 messages. Call Connect to start receiving.
func NewWebSocketClient(serverURL, apiKey string, logger *logging.Logger) *WebSocketClient {
	const updateBuffer = 100

	client := &WebSocketClient{
		serverURL:     serverURL,
		apiKey:        apiKey,
		logger:        logger,
		jobUpdates:    make(chan model.JobUpdateMsg, updateBuffer),
		gpuUpdates:    make(chan model.GPUUpdateMsg, updateBuffer),
		statusUpdates: make(chan model.StatusMsg, updateBuffer),
	}
	// The context outlives individual connections; Disconnect cancels it.
	client.ctx, client.cancel = context.WithCancel(context.Background())
	return client
}
// Connect dials the server's /ws endpoint and starts the goroutines that
// service the new connection: messageHandler (reads) and heartbeat (pings).
//
// The WebSocket URL is built from the host of serverURL:
//   - https and wss map to wss://.
//   - Every other scheme maps to ws://.
//   - The path of serverURL is not used.
//
// When apiKey is non-empty it is sent in the X-API-Key header.
//
// Calling Connect while already connected is a no-op. The dial honours the
// client's context, so it fails once Disconnect has cancelled it.
func (c *WebSocketClient) Connect() error {
	if c.connected && c.conn != nil {
		// Dialing again would leak the live connection and start duplicate
		// reader/heartbeat goroutines.
		return nil
	}

	u, err := url.Parse(c.serverURL)
	if err != nil {
		return fmt.Errorf("invalid server URL: %w", err)
	}
	if u.Host == "" {
		// e.g. "localhost:8080" parses with scheme "localhost" and no host.
		return fmt.Errorf("invalid server URL %q: missing host", c.serverURL)
	}

	// Keep TLS when the caller already asked for it (https or wss).
	wsScheme := "ws"
	if u.Scheme == "https" || u.Scheme == "wss" {
		wsScheme = "wss"
	}
	wsURL := fmt.Sprintf("%s://%s/ws", wsScheme, u.Host)

	dialer := websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
		Subprotocols:     []string{"fetchml-v1"},
	}

	headers := http.Header{}
	if c.apiKey != "" {
		headers.Set("X-API-Key", c.apiKey)
	}

	conn, resp, err := dialer.DialContext(c.ctx, wsURL, headers)
	if err != nil {
		if resp != nil {
			return fmt.Errorf("websocket dial failed (status %d): %w", resp.StatusCode, err)
		}
		return fmt.Errorf("websocket dial failed: %w", err)
	}

	// The read loop gives each read a 60s deadline. Extend it whenever the
	// server answers a heartbeat ping so an idle but healthy connection is not
	// torn down just because no data frame arrived.
	conn.SetPongHandler(func(string) error {
		return conn.SetReadDeadline(time.Now().Add(60 * time.Second))
	})

	c.conn = conn
	c.connected = true
	c.logger.Info("websocket connected", "url", wsURL)

	go c.messageHandler()
	go c.heartbeat()
	return nil
}
// Disconnect cancels the client's context, which the background goroutines
// watch to stop. It then closes the current connection, if any, and marks the
// client as disconnected.
//
// Before closing the socket it sends a best-effort WebSocket close frame
// (normal closure), so the server sees an orderly shutdown instead of an
// abnormal TCP drop. The context is not re-created here.
func (c *WebSocketClient) Disconnect() {
	c.cancel()
	if c.conn != nil {
		// Best effort: the connection is closed below whether or not the
		// close frame could be written, so its error carries no information.
		_ = c.conn.WriteControl(
			websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
			time.Now().Add(time.Second),
		)
		if err := c.conn.Close(); err != nil {
			c.logger.Error("websocket close failed", "error", err)
		}
	}
	c.connected = false
}
// IsConnected reports the client's connection flag. Connect sets it after a
// successful dial. A read error, a failed ping or Disconnect clears it.
func (c *WebSocketClient) IsConnected() bool {
	return c.connected
}
// messageHandler is the read loop for the connection that was current when
// this goroutine started. Binary frames go to handleBinaryMessage; all other
// frames go to handleTextMessage.
//
// gorilla/websocket read errors are permanent, so after a failed read the
// connection is closed and never read again. If the client has not been
// disconnected and no newer Connect call has replaced the connection,
// reconnectAfter retries Connect. Connect starts a fresh messageHandler for
// the new connection, so this goroutine exits rather than becoming a second,
// concurrent reader.
func (c *WebSocketClient) messageHandler() {
	conn := c.conn
	if conn == nil {
		return
	}

	for {
		if c.ctx.Err() != nil {
			return
		}

		// A failure here means the connection is already broken, and
		// ReadMessage below will report that error, so it is not checked.
		_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))

		messageType, data, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				c.logger.Error("websocket read error", "error", err)
			}
			// Release the socket; reading it again would only repeat the error.
			_ = conn.Close()
			c.reconnectAfter(conn)
			return
		}

		if messageType == websocket.BinaryMessage {
			c.handleBinaryMessage(data)
		} else {
			c.handleTextMessage(data)
		}
	}
}

// reconnectAfter replaces failed by calling Connect every 5 seconds until one
// attempt succeeds, the client's context is cancelled, or another Connect call
// has already installed a different connection.
func (c *WebSocketClient) reconnectAfter(failed *websocket.Conn) {
	const retryDelay = 5 * time.Second

	if c.conn != failed {
		return // superseded: the newer connection has its own handler
	}
	c.connected = false

	for {
		timer := time.NewTimer(retryDelay)
		select {
		case <-c.ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}

		if c.conn != failed {
			return
		}
		err := c.Connect()
		if err == nil {
			return
		}
		c.logger.Error("websocket reconnect failed", "error", err)
	}
}
// handleBinaryMessage decodes a binary frame laid out as
// [opcode:1][JSON payload...] and forwards the result to the matching channel:
//
//	0x01  JobUpdateMsg    -> jobUpdates
//	0x02  GPUUpdateMsg    -> gpuUpdates
//	0x03  model.StatusMsg -> statusUpdates
//
// Frames shorter than two bytes and unknown opcodes are ignored. Payloads that
// fail to decode are logged and dropped.
//
// When a channel's buffer is full the send waits for a reader, but gives up
// once the client's context is cancelled. Without that, a stalled consumer
// would block the read loop forever and Disconnect could not stop it.
func (c *WebSocketClient) handleBinaryMessage(data []byte) {
	if len(data) < 2 {
		return
	}
	opcode, payload := data[0], data[1:]

	switch opcode {
	case 0x01: // Job update
		var update JobUpdateMsg
		if err := json.Unmarshal(payload, &update); err != nil {
			c.logger.Error("failed to unmarshal job update", "error", err)
			return
		}
		select {
		case c.jobUpdates <- model.JobUpdateMsg(update):
		case <-c.ctx.Done():
		}
	case 0x02: // GPU update
		var update GPUUpdateMsg
		if err := json.Unmarshal(payload, &update); err != nil {
			c.logger.Error("failed to unmarshal GPU update", "error", err)
			return
		}
		select {
		case c.gpuUpdates <- model.GPUUpdateMsg(update):
		case <-c.ctx.Done():
		}
	case 0x03: // Status message
		var status model.StatusMsg
		if err := json.Unmarshal(payload, &status); err != nil {
			c.logger.Error("failed to unmarshal status", "error", err)
			return
		}
		select {
		case c.statusUpdates <- status:
		case <-c.ctx.Done():
		}
	}
}
// handleTextMessage decodes a text frame as JSON and inspects its "type" field.
// Frames that json.Unmarshal rejects are logged and dropped. The recognised
// types "job_update", "gpu_update" and "status" are accepted but not yet
// forwarded to any channel. Anything else is ignored.
func (c *WebSocketClient) handleTextMessage(data []byte) {
	envelope := make(map[string]interface{})
	if err := json.Unmarshal(data, &envelope); err != nil {
		c.logger.Error("failed to unmarshal text message", "error", err)
		return
	}

	switch kind, _ := envelope["type"].(string); kind {
	case "job_update", "gpu_update", "status":
		// TODO: decode these JSON payloads and forward them the way the
		// binary protocol does; until then they are dropped here.
	default:
		// Missing or unknown type: nothing to do.
	}
}
// heartbeat sends a ping every 30 seconds on the connection that was current
// when this goroutine started.
//
// Pings go through WriteControl, which gorilla/websocket allows concurrently
// with the reader and with Subscribe's WriteMessage. A plain WriteMessage here
// could race with another writer. Each ping gets a 10-second write deadline so
// a stuck socket cannot block the loop indefinitely. A failed ping is logged
// and clears the connected flag.
//
// The goroutine exits when the client's context is cancelled, or when a newer
// Connect has replaced the connection; that Connect starts its own heartbeat.
func (c *WebSocketClient) heartbeat() {
	const (
		pingInterval = 30 * time.Second
		pingTimeout  = 10 * time.Second
	)

	conn := c.conn
	ticker := time.NewTicker(pingInterval)
	defer ticker.Stop()

	for {
		select {
		case <-c.ctx.Done():
			return
		case <-ticker.C:
		}

		if conn == nil || c.conn != conn {
			return
		}
		if err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(pingTimeout)); err != nil {
			c.logger.Error("websocket ping failed", "error", err)
			c.connected = false
		}
	}
}
// Subscribe asks the server to push updates for the named channels. It sends
// {"action":"subscribe","channels":[...]} as a text frame.
//
// It returns an error if the client is not connected, or if encoding or
// sending the request fails. Write errors are wrapped, so errors.Is/As still
// see the underlying cause.
//
// NOTE(review): gorilla/websocket allows only one concurrent caller of
// WriteMessage per connection, so do not call Subscribe from several
// goroutines at once.
func (c *WebSocketClient) Subscribe(channels ...string) error {
	if !c.connected || c.conn == nil {
		return fmt.Errorf("not connected")
	}

	request := struct {
		Action   string   `json:"action"`
		Channels []string `json:"channels"`
	}{Action: "subscribe", Channels: channels}

	data, err := json.Marshal(request)
	if err != nil {
		return fmt.Errorf("encoding subscribe request: %w", err)
	}
	if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
		return fmt.Errorf("sending subscribe request: %w", err)
	}
	return nil
}
// GetJobUpdates exposes, as a receive-only channel, the job updates decoded
// from binary frames with opcode 0x01.
func (c *WebSocketClient) GetJobUpdates() <-chan model.JobUpdateMsg {
	return c.jobUpdates
}
// GetGPUUpdates exposes, as a receive-only channel, the GPU updates decoded
// from binary frames with opcode 0x02.
func (c *WebSocketClient) GetGPUUpdates() <-chan model.GPUUpdateMsg {
	return c.gpuUpdates
}
// GetStatusUpdates exposes, as a receive-only channel, the status messages
// decoded from binary frames with opcode 0x03.
func (c *WebSocketClient) GetStatusUpdates() <-chan model.StatusMsg {
	return c.statusUpdates
}