// Package services provides TUI service clients package services import ( "context" "encoding/json" "fmt" "net/http" "net/url" "time" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/cmd/tui/internal/model" "github.com/jfraeys/fetch_ml/internal/logging" ) // WebSocketClient manages real-time updates from the server type WebSocketClient struct { conn *websocket.Conn serverURL string apiKey string logger *logging.Logger // Channels for different update types jobUpdates chan model.JobUpdateMsg gpuUpdates chan model.GPUUpdateMsg statusUpdates chan model.StatusMsg // Control ctx context.Context cancel context.CancelFunc connected bool } // JobUpdateMsg represents a real-time job status update type JobUpdateMsg struct { JobName string `json:"job_name"` Status string `json:"status"` TaskID string `json:"task_id"` Progress int `json:"progress"` } // GPUUpdateMsg represents a real-time GPU status update type GPUUpdateMsg struct { DeviceID int `json:"device_id"` Utilization int `json:"utilization"` MemoryUsed int64 `json:"memory_used"` MemoryTotal int64 `json:"memory_total"` Temperature int `json:"temperature"` } // NewWebSocketClient creates a new WebSocket client func NewWebSocketClient(serverURL, apiKey string, logger *logging.Logger) *WebSocketClient { ctx, cancel := context.WithCancel(context.Background()) return &WebSocketClient{ serverURL: serverURL, apiKey: apiKey, logger: logger, jobUpdates: make(chan model.JobUpdateMsg, 100), gpuUpdates: make(chan model.GPUUpdateMsg, 100), statusUpdates: make(chan model.StatusMsg, 100), ctx: ctx, cancel: cancel, } } // Connect establishes the WebSocket connection func (c *WebSocketClient) Connect() error { // Parse server URL and construct WebSocket URL u, err := url.Parse(c.serverURL) if err != nil { return fmt.Errorf("invalid server URL: %w", err) } // Convert http/https to ws/wss wsScheme := "ws" if u.Scheme == "https" { wsScheme = "wss" } wsURL := fmt.Sprintf("%s://%s/ws", wsScheme, u.Host) // Create dialer with timeout dialer := websocket.Dialer{ HandshakeTimeout: 10 * time.Second, Subprotocols: []string{"fetchml-v1"}, } // Add API key to headers headers := http.Header{} if c.apiKey != "" { headers.Set("X-API-Key", c.apiKey) } conn, resp, err := dialer.Dial(wsURL, headers) if err != nil { if resp != nil { return fmt.Errorf("websocket dial failed (status %d): %w", resp.StatusCode, err) } return fmt.Errorf("websocket dial failed: %w", err) } c.conn = conn c.connected = true c.logger.Info("websocket connected", "url", wsURL) // Start message handler go c.messageHandler() // Start heartbeat go c.heartbeat() return nil } // Disconnect closes the WebSocket connection func (c *WebSocketClient) Disconnect() { c.cancel() if c.conn != nil { c.conn.Close() } c.connected = false } // IsConnected returns true if connected func (c *WebSocketClient) IsConnected() bool { return c.connected } // messageHandler reads messages from the WebSocket func (c *WebSocketClient) messageHandler() { for { select { case <-c.ctx.Done(): return default: } if c.conn == nil { time.Sleep(100 * time.Millisecond) continue } // Set read deadline c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) // Read message messageType, data, err := c.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { c.logger.Error("websocket read error", "error", err) } c.connected = false // Attempt reconnect time.Sleep(5 * time.Second) if err := c.Connect(); err != nil { c.logger.Error("websocket reconnect failed", "error", err) } continue } // Handle binary vs text messages if messageType == websocket.BinaryMessage { c.handleBinaryMessage(data) } else { c.handleTextMessage(data) } } } // handleBinaryMessage handles binary WebSocket messages func (c *WebSocketClient) handleBinaryMessage(data []byte) { if len(data) < 2 { return } // Binary protocol: [opcode:1][data...] opcode := data[0] payload := data[1:] switch opcode { case 0x01: // Job update var update JobUpdateMsg if err := json.Unmarshal(payload, &update); err != nil { c.logger.Error("failed to unmarshal job update", "error", err) return } c.jobUpdates <- model.JobUpdateMsg(update) case 0x02: // GPU update var update GPUUpdateMsg if err := json.Unmarshal(payload, &update); err != nil { c.logger.Error("failed to unmarshal GPU update", "error", err) return } c.gpuUpdates <- model.GPUUpdateMsg(update) case 0x03: // Status message var status model.StatusMsg if err := json.Unmarshal(payload, &status); err != nil { c.logger.Error("failed to unmarshal status", "error", err) return } c.statusUpdates <- status } } // handleTextMessage handles text WebSocket messages (JSON) func (c *WebSocketClient) handleTextMessage(data []byte) { var msg map[string]interface{} if err := json.Unmarshal(data, &msg); err != nil { c.logger.Error("failed to unmarshal text message", "error", err) return } msgType, _ := msg["type"].(string) switch msgType { case "job_update": // Handle JSON job updates case "gpu_update": // Handle JSON GPU updates case "status": // Handle status messages } } // heartbeat sends periodic ping messages func (c *WebSocketClient) heartbeat() { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case <-c.ctx.Done(): return case <-ticker.C: if c.conn != nil { if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { c.logger.Error("websocket ping failed", "error", err) c.connected = false } } } } } // Subscribe subscribes to specific update channels func (c *WebSocketClient) Subscribe(channels ...string) error { if !c.connected { return fmt.Errorf("not connected") } subMsg := map[string]interface{}{ "action": "subscribe", "channels": channels, } data, _ := json.Marshal(subMsg) return c.conn.WriteMessage(websocket.TextMessage, data) } // GetJobUpdates returns the job updates channel func (c *WebSocketClient) GetJobUpdates() <-chan model.JobUpdateMsg { return c.jobUpdates } // GetGPUUpdates returns the GPU updates channel func (c *WebSocketClient) GetGPUUpdates() <-chan model.GPUUpdateMsg { return c.gpuUpdates } // GetStatusUpdates returns the status updates channel func (c *WebSocketClient) GetStatusUpdates() <-chan model.StatusMsg { return c.statusUpdates }