- Refactor internal/worker and internal/queue packages - Update cmd/tui for monitoring interface - Update test configurations
275 lines
6.5 KiB
Go
275 lines
6.5 KiB
Go
// Package services provides the service clients used by the TUI.
|
|
package services
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
)
|
|
|
|
// WebSocketClient manages real-time updates from the server over a WebSocket
// connection. Decoded messages are fanned out to three buffered channels (job,
// GPU and status updates), which callers read via the Get*Updates methods.
//
// NOTE(review): conn and connected are written and read by the goroutines that
// Connect starts (messageHandler, heartbeat) as well as by callers of Connect,
// Disconnect, IsConnected and Subscribe, with no synchronization. That is a data
// race under -race; consider guarding them with a mutex or atomics.
type WebSocketClient struct {
	// conn is the current connection; nil until the first successful Connect.
	conn *websocket.Conn
	// serverURL is the base server URL; Connect uses only its scheme and host.
	serverURL string
	// apiKey is sent as the X-API-Key header when non-empty.
	apiKey string
	logger *logging.Logger

	// Channels for different update types
	jobUpdates    chan model.JobUpdateMsg
	gpuUpdates    chan model.GPUUpdateMsg
	statusUpdates chan model.StatusMsg

	// Control
	// ctx is cancelled by Disconnect; the background goroutines watch it to exit.
	ctx    context.Context
	cancel context.CancelFunc
	// connected is set by Connect and cleared by Disconnect or a failed read/ping.
	connected bool
}
|
|
|
|
// JobUpdateMsg represents a real-time job status update.
//
// It is decoded from the JSON payload of binary opcode 0x01 frames and then
// converted to model.JobUpdateMsg. Go struct conversion requires the field
// names, types and order to match that type exactly (tags are ignored), so keep
// the two definitions in sync.
type JobUpdateMsg struct {
	JobName string `json:"job_name"`
	Status  string `json:"status"`
	TaskID  string `json:"task_id"`
	// Progress: scale not defined here (presumably 0-100) — confirm with server.
	Progress int `json:"progress"`
}
|
|
|
|
// GPUUpdateMsg represents a real-time GPU status update.
//
// It is decoded from the JSON payload of binary opcode 0x02 frames and then
// converted to model.GPUUpdateMsg. Go struct conversion requires the field
// names, types and order to match that type exactly (tags are ignored), so keep
// the two definitions in sync.
type GPUUpdateMsg struct {
	DeviceID int `json:"device_id"`
	// Utilization: presumably a percentage — confirm with server.
	Utilization int `json:"utilization"`
	// MemoryUsed / MemoryTotal: units not specified here (bytes? MiB?) — confirm.
	MemoryUsed  int64 `json:"memory_used"`
	MemoryTotal int64 `json:"memory_total"`
	// Temperature: units unverified (likely °C) — confirm with server.
	Temperature int `json:"temperature"`
}
|
|
|
|
// NewWebSocketClient creates a new WebSocket client
|
|
func NewWebSocketClient(serverURL, apiKey string, logger *logging.Logger) *WebSocketClient {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
return &WebSocketClient{
|
|
serverURL: serverURL,
|
|
apiKey: apiKey,
|
|
logger: logger,
|
|
jobUpdates: make(chan model.JobUpdateMsg, 100),
|
|
gpuUpdates: make(chan model.GPUUpdateMsg, 100),
|
|
statusUpdates: make(chan model.StatusMsg, 100),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
// Connect establishes the WebSocket connection
|
|
func (c *WebSocketClient) Connect() error {
|
|
// Parse server URL and construct WebSocket URL
|
|
u, err := url.Parse(c.serverURL)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid server URL: %w", err)
|
|
}
|
|
|
|
// Convert http/https to ws/wss
|
|
wsScheme := "ws"
|
|
if u.Scheme == "https" {
|
|
wsScheme = "wss"
|
|
}
|
|
wsURL := fmt.Sprintf("%s://%s/ws", wsScheme, u.Host)
|
|
|
|
// Create dialer with timeout
|
|
dialer := websocket.Dialer{
|
|
HandshakeTimeout: 10 * time.Second,
|
|
Subprotocols: []string{"fetchml-v1"},
|
|
}
|
|
|
|
// Add API key to headers
|
|
headers := http.Header{}
|
|
if c.apiKey != "" {
|
|
headers.Set("X-API-Key", c.apiKey)
|
|
}
|
|
|
|
conn, resp, err := dialer.Dial(wsURL, headers)
|
|
if err != nil {
|
|
if resp != nil {
|
|
return fmt.Errorf("websocket dial failed (status %d): %w", resp.StatusCode, err)
|
|
}
|
|
return fmt.Errorf("websocket dial failed: %w", err)
|
|
}
|
|
|
|
c.conn = conn
|
|
c.connected = true
|
|
c.logger.Info("websocket connected", "url", wsURL)
|
|
|
|
// Start message handler
|
|
go c.messageHandler()
|
|
|
|
// Start heartbeat
|
|
go c.heartbeat()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Disconnect closes the WebSocket connection
|
|
func (c *WebSocketClient) Disconnect() {
|
|
c.cancel()
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
}
|
|
c.connected = false
|
|
}
|
|
|
|
// IsConnected reports whether the client currently considers itself connected.
// The flag is set by a successful Connect. It is cleared by Disconnect and by
// the background goroutines when a read or ping fails, so it can lag behind the
// real connection state.
//
// NOTE(review): connected is read here without synchronization while the
// read-loop and heartbeat goroutines may write it — a data race under -race.
func (c *WebSocketClient) IsConnected() bool {
	return c.connected
}
|
|
|
|
// messageHandler reads messages from the WebSocket
|
|
func (c *WebSocketClient) messageHandler() {
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
if c.conn == nil {
|
|
time.Sleep(100 * time.Millisecond)
|
|
continue
|
|
}
|
|
|
|
// Set read deadline
|
|
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
|
|
// Read message
|
|
messageType, data, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
c.logger.Error("websocket read error", "error", err)
|
|
}
|
|
c.connected = false
|
|
|
|
// Attempt reconnect
|
|
time.Sleep(5 * time.Second)
|
|
if err := c.Connect(); err != nil {
|
|
c.logger.Error("websocket reconnect failed", "error", err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Handle binary vs text messages
|
|
if messageType == websocket.BinaryMessage {
|
|
c.handleBinaryMessage(data)
|
|
} else {
|
|
c.handleTextMessage(data)
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleBinaryMessage handles binary WebSocket messages
|
|
func (c *WebSocketClient) handleBinaryMessage(data []byte) {
|
|
if len(data) < 2 {
|
|
return
|
|
}
|
|
|
|
// Binary protocol: [opcode:1][data...]
|
|
opcode := data[0]
|
|
payload := data[1:]
|
|
|
|
switch opcode {
|
|
case 0x01: // Job update
|
|
var update JobUpdateMsg
|
|
if err := json.Unmarshal(payload, &update); err != nil {
|
|
c.logger.Error("failed to unmarshal job update", "error", err)
|
|
return
|
|
}
|
|
c.jobUpdates <- model.JobUpdateMsg(update)
|
|
|
|
case 0x02: // GPU update
|
|
var update GPUUpdateMsg
|
|
if err := json.Unmarshal(payload, &update); err != nil {
|
|
c.logger.Error("failed to unmarshal GPU update", "error", err)
|
|
return
|
|
}
|
|
c.gpuUpdates <- model.GPUUpdateMsg(update)
|
|
|
|
case 0x03: // Status message
|
|
var status model.StatusMsg
|
|
if err := json.Unmarshal(payload, &status); err != nil {
|
|
c.logger.Error("failed to unmarshal status", "error", err)
|
|
return
|
|
}
|
|
c.statusUpdates <- status
|
|
}
|
|
}
|
|
|
|
// handleTextMessage handles text WebSocket messages (JSON)
|
|
func (c *WebSocketClient) handleTextMessage(data []byte) {
|
|
var msg map[string]interface{}
|
|
if err := json.Unmarshal(data, &msg); err != nil {
|
|
c.logger.Error("failed to unmarshal text message", "error", err)
|
|
return
|
|
}
|
|
|
|
msgType, _ := msg["type"].(string)
|
|
switch msgType {
|
|
case "job_update":
|
|
// Handle JSON job updates
|
|
case "gpu_update":
|
|
// Handle JSON GPU updates
|
|
case "status":
|
|
// Handle status messages
|
|
}
|
|
}
|
|
|
|
// heartbeat sends periodic ping messages
|
|
func (c *WebSocketClient) heartbeat() {
|
|
ticker := time.NewTicker(30 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if c.conn != nil {
|
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
c.logger.Error("websocket ping failed", "error", err)
|
|
c.connected = false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Subscribe subscribes to specific update channels
|
|
func (c *WebSocketClient) Subscribe(channels ...string) error {
|
|
if !c.connected {
|
|
return fmt.Errorf("not connected")
|
|
}
|
|
|
|
subMsg := map[string]interface{}{
|
|
"action": "subscribe",
|
|
"channels": channels,
|
|
}
|
|
|
|
data, _ := json.Marshal(subMsg)
|
|
return c.conn.WriteMessage(websocket.TextMessage, data)
|
|
}
|
|
|
|
// GetJobUpdates returns a receive-only view of the job updates channel, which
// is buffered for 100 messages and fed by handleBinaryMessage (opcode 0x01).
// Nothing in this file closes the channel, so consumers should select on it
// instead of ranging over it to completion.
func (c *WebSocketClient) GetJobUpdates() <-chan model.JobUpdateMsg {
	return c.jobUpdates
}
|
|
|
|
// GetGPUUpdates returns a receive-only view of the GPU updates channel, which
// is buffered for 100 messages and fed by handleBinaryMessage (opcode 0x02).
// Nothing in this file closes the channel, so consumers should select on it
// instead of ranging over it to completion.
func (c *WebSocketClient) GetGPUUpdates() <-chan model.GPUUpdateMsg {
	return c.gpuUpdates
}
|
|
|
|
// GetStatusUpdates returns a receive-only view of the status updates channel,
// which is buffered for 100 messages and fed by handleBinaryMessage (opcode
// 0x03). Nothing in this file closes the channel, so consumers should select on
// it instead of ranging over it to completion.
func (c *WebSocketClient) GetStatusUpdates() <-chan model.StatusMsg {
	return c.statusUpdates
}
|