fetch_ml/internal/api/ws_handler.go

279 lines
8 KiB
Go

package api
import (
"compress/flate"
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
)
// Opcodes for the binary WebSocket protocol.
//
// handleMessage reads the opcode from the first byte of every binary message
// and uses it to select the handler. The constants are listed in ascending
// numeric order; 0x05 and 0x10-0x15 are not assigned in this block.
const (
	// Job queue, status and maintenance requests.
	OpcodeQueueJob      = 0x01
	OpcodeStatusRequest = 0x02
	OpcodeCancelJob     = 0x03
	OpcodePrune         = 0x04

	// Dataset requests.
	OpcodeDatasetList     = 0x06
	OpcodeDatasetRegister = 0x07
	OpcodeDatasetInfo     = 0x08
	OpcodeDatasetSearch   = 0x09

	// Experiment tracking requests.
	OpcodeLogMetric            = 0x0A
	OpcodeGetExperiment        = 0x0B
	OpcodeQueueJobWithTracking = 0x0C

	// Jupyter service requests (first range).
	OpcodeStartJupyter = 0x0D
	OpcodeStopJupyter  = 0x0E
	OpcodeListJupyter  = 0x0F

	// Later additions to the protocol.
	OpcodeValidateRequest      = 0x16
	OpcodeQueueJobWithSnapshot = 0x17
	OpcodeRemoveJupyter        = 0x18
	OpcodeRestoreJupyter       = 0x19
)
// createUpgrader creates a WebSocket upgrader with the given security
// configuration.
//
// The returned upgrader's CheckOrigin applies the following policy:
//   - a request without an Origin header is accepted;
//   - when securityCfg is non-nil and ProductionMode is set, the Origin must
//     be byte-for-byte equal to one of securityCfg.AllowedOrigins;
//   - otherwise (development mode) the Origin is accepted only if
//     isDevOriginAllowed reports true for it.
//
// securityCfg may be nil, in which case the development-mode policy applies.
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
	return websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			origin := r.Header.Get("Origin")
			if origin == "" {
				// No Origin header: treated as allowed.
				return true
			}

			// Production mode: strict checking against the allowed origins.
			if securityCfg != nil && securityCfg.ProductionMode {
				for _, allowed := range securityCfg.AllowedOrigins {
					if origin == allowed {
						return true
					}
				}
				return false // Not in the allowed list.
			}

			// Development mode: localhost and local-network origins only.
			return isDevOriginAllowed(origin)
		},
		// Performance-related settings.
		HandshakeTimeout:  10 * time.Second,
		ReadBufferSize:    16 * 1024,
		WriteBufferSize:   16 * 1024,
		EnableCompression: true,
	}
}

// isDevOriginAllowed reports whether origin is acceptable in development mode.
//
// The origin must parse as a URL and satisfy at least one of:
//   - its port is 8080;
//   - its hostname is "localhost" (case-insensitive);
//   - its hostname is a loopback IP address;
//   - its hostname is an IPv4 address in 10.0.0.0/8, 172.16.0.0/12 or
//     192.168.0.0/16.
//
// The hostname is parsed as an IP address rather than matched by string
// prefix, so DNS names that merely start with "10.", "172." or "192.168."
// (for example "10.example.com") are rejected.
func isDevOriginAllowed(origin string) bool {
	parsed, err := url.Parse(origin)
	if err != nil {
		return false
	}

	// NOTE(review): any host on port 8080 is accepted, as before. Confirm
	// this is still wanted; it is the most permissive rule in this check.
	if parsed.Port() == "8080" {
		return true
	}

	hostname := parsed.Hostname()
	if strings.EqualFold(hostname, "localhost") {
		return true
	}

	ip := net.ParseIP(hostname)
	if ip == nil {
		// Not an IP literal and not localhost.
		return false
	}
	if ip.IsLoopback() {
		return true
	}

	ip4 := ip.To4()
	if ip4 == nil {
		return false
	}
	switch {
	case ip4[0] == 10:
		return true
	case ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31:
		return true
	case ip4[0] == 192 && ip4[1] == 168:
		return true
	}
	return false
}
// WSHandler handles WebSocket connections for the API.
//
// It serves HTTP requests through ServeHTTP, which optionally validates an API
// key, upgrades the request to a WebSocket connection and dispatches each
// binary message by opcode. Instances are built with NewWSHandler.
type WSHandler struct {
// authConfig controls API-key validation in ServeHTTP; validation is skipped
// when it is nil or its Enabled flag is false.
authConfig *auth.Config
// logger is the component logger derived in NewWSHandler with the
// "ws-handler" component name.
logger *logging.Logger
// expManager is not used in the code visible here; presumably consumed by the
// experiment-related opcode handlers (confirm against those handlers).
expManager *experiment.Manager
// dataDir is stored as passed to NewWSHandler; its consumers are outside this
// section of the file.
dataDir string
// queue is the task queue backend passed to NewWSHandler as taskQueue.
queue queue.Backend
// db is the storage handle passed to NewWSHandler; its consumers are outside
// this section of the file.
db *storage.DB
// jupyterServiceMgr is presumably used by the Jupyter opcode handlers (confirm
// against those handlers).
jupyterServiceMgr *jupyter.ServiceManager
// securityConfig is the configuration the upgrader's origin check was built
// from.
securityConfig *config.SecurityConfig
// auditLogger records authentication attempts in ServeHTTP; it may be nil, in
// which case no audit entries are written.
auditLogger *audit.Logger
// upgrader performs the HTTP-to-WebSocket upgrade; created by createUpgrader
// from securityConfig.
upgrader websocket.Upgrader
}
// NewWSHandler wires the supplied dependencies into a ready-to-use WSHandler.
//
// The handler's logger is derived from logger as the "ws-handler" component
// using a fresh trace-enabled background context, and its upgrader is built
// from securityConfig via createUpgrader. All other arguments are stored
// unchanged.
func NewWSHandler(
	authConfig *auth.Config,
	logger *logging.Logger,
	expManager *experiment.Manager,
	dataDir string,
	taskQueue queue.Backend,
	db *storage.DB,
	jupyterServiceMgr *jupyter.ServiceManager,
	securityConfig *config.SecurityConfig,
	auditLogger *audit.Logger,
) *WSHandler {
	// Derive the component logger first; the struct literal below only stores
	// values.
	traceCtx := logging.EnsureTrace(context.Background())
	componentLogger := logger.Component(traceCtx, "ws-handler")

	handler := &WSHandler{
		authConfig:        authConfig,
		logger:            componentLogger,
		expManager:        expManager,
		dataDir:           dataDir,
		queue:             taskQueue,
		db:                db,
		jupyterServiceMgr: jupyterServiceMgr,
		securityConfig:    securityConfig,
		auditLogger:       auditLogger,
	}
	handler.upgrader = createUpgrader(securityConfig)
	return handler
}
// enableLowLatencyTCP turns on TCP_NODELAY (i.e. disables Nagle's algorithm)
// for the socket underneath conn so that small packets are sent without delay.
//
// It is a no-op when conn is nil or when the underlying connection is not a
// *net.TCPConn. A failure to set the option is logged as a warning and
// otherwise ignored.
func enableLowLatencyTCP(conn *websocket.Conn, logger *logging.Logger) {
	if conn == nil {
		return
	}

	rawTCP, isTCP := conn.UnderlyingConn().(*net.TCPConn)
	if !isTCP {
		// Some other transport; nothing to tune.
		return
	}

	err := rawTCP.SetNoDelay(true)
	if err != nil {
		logger.Warn("failed to enable tcp no delay", "error", err)
	}
}
// ServeHTTP implements http.Handler for the WebSocket endpoint.
//
// It sets security response headers, validates the request's API key when
// authentication is enabled, upgrades the request to a WebSocket connection
// and then reads messages until the connection ends, passing every binary
// message to handleMessage.
//
// Behavior worth knowing:
//   - An invalid API key is answered with 401 before any upgrade is attempted,
//     and the attempt is written to the audit log when one is configured.
//   - Non-binary messages are logged and skipped.
//   - When handleMessage returns an error, it is logged and reported to the
//     client as a structured error packet; the connection stays open.
//   - Close codes 1000 (normal), 1001 (going away) and 1006 (abnormal) end the
//     loop silently; any other close error is logged at error level.
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Add security headers.
	w.Header().Set("X-Frame-Options", "DENY")
	w.Header().Set("X-Content-Type-Options", "nosniff")
	w.Header().Set("X-XSS-Protection", "1; mode=block")
	if r.TLS != nil {
		// Only set HSTS if using HTTPS.
		w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
	}

	// Check the API key before upgrading so a rejected client receives a
	// regular HTTP 401 response.
	apiKey := auth.ExtractAPIKeyFromRequest(r)
	// r.RemoteAddr is handed to the audit log unchanged.
	clientIP := r.RemoteAddr

	// Validate the API key if authentication is enabled.
	if h.authConfig != nil && h.authConfig.Enabled {
		// Clamp the prefix so short or empty keys cannot cause an
		// out-of-range slice.
		// NOTE(review): up to 8 characters of the key reach the logs and the
		// audit trail; confirm that is acceptable for this deployment.
		prefixLen := len(apiKey)
		if prefixLen > 8 {
			prefixLen = 8
		}
		keyPrefix := apiKey[:prefixLen]

		h.logger.Info(
			"websocket auth attempt",
			"api_key_length",
			len(apiKey),
			"api_key_prefix",
			keyPrefix,
		)

		userID, err := h.authConfig.ValidateAPIKey(apiKey)
		if err != nil {
			h.logger.Warn("websocket authentication failed", "error", err)
			// Audit log failed authentication.
			if h.auditLogger != nil {
				h.auditLogger.LogAuthAttempt(keyPrefix, clientIP, false, err.Error())
			}
			http.Error(w, "Invalid API key", http.StatusUnauthorized)
			return
		}

		h.logger.Info("websocket authentication succeeded")
		// Audit log successful authentication.
		if h.auditLogger != nil && userID != nil {
			h.auditLogger.LogAuthAttempt(userID.Name, clientIP, true, "")
		}
	}

	conn, err := h.upgrader.Upgrade(w, r, nil)
	if err != nil {
		// No response is written here; only the failure is logged.
		h.logger.Error("websocket upgrade failed", "error", err)
		return
	}
	// Register the close immediately so the connection is released on every
	// exit path below. The close error is ignored: the connection is being
	// torn down and there is nothing further to do with it.
	defer func() {
		_ = conn.Close()
	}()

	conn.EnableWriteCompression(true)
	if err := conn.SetCompressionLevel(flate.BestSpeed); err != nil {
		h.logger.Warn("failed to set websocket compression level", "error", err)
	}
	enableLowLatencyTCP(conn, h.logger)

	h.logger.Info("websocket connection established", "remote", r.RemoteAddr)

	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			// A clean close by the client (1000) is an expected way for the
			// session to end, so it is not reported as an error.
			if websocket.IsUnexpectedCloseError(
				err,
				websocket.CloseNormalClosure,
				websocket.CloseGoingAway,
				websocket.CloseAbnormalClosure,
			) {
				h.logger.Error("websocket read error", "error", err)
			}
			return
		}

		if messageType != websocket.BinaryMessage {
			h.logger.Warn("received non-binary message")
			continue
		}

		if err := h.handleMessage(conn, message); err != nil {
			h.logger.Error("message handling error", "error", err)
			// Send a structured error response so CLI clients can parse it.
			// (Raw fallback bytes cause client-side InvalidPacket errors.)
			// A failure to deliver the error packet is deliberately ignored;
			// a broken connection surfaces on the next ReadMessage.
			_ = h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "message handling error", err.Error())
		}
	}
}
// handleMessage routes one binary message to the handler for its opcode.
//
// The first byte of message is the opcode; the remaining bytes are passed to
// the selected handler as its payload, together with conn. The handler's
// error is returned unchanged. An empty message or an opcode without a case
// below produces an error.
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
	if len(message) == 0 {
		return fmt.Errorf("message too short")
	}

	op, body := message[0], message[1:]

	switch op {
	// Job queue, status and maintenance.
	case OpcodeQueueJob:
		return h.handleQueueJob(conn, body)
	case OpcodeStatusRequest:
		return h.handleStatusRequest(conn, body)
	case OpcodeCancelJob:
		return h.handleCancelJob(conn, body)
	case OpcodePrune:
		return h.handlePrune(conn, body)

	// Datasets.
	case OpcodeDatasetList:
		return h.handleDatasetList(conn, body)
	case OpcodeDatasetRegister:
		return h.handleDatasetRegister(conn, body)
	case OpcodeDatasetInfo:
		return h.handleDatasetInfo(conn, body)
	case OpcodeDatasetSearch:
		return h.handleDatasetSearch(conn, body)

	// Experiment tracking.
	case OpcodeLogMetric:
		return h.handleLogMetric(conn, body)
	case OpcodeGetExperiment:
		return h.handleGetExperiment(conn, body)
	case OpcodeQueueJobWithTracking:
		return h.handleQueueJobWithTracking(conn, body)

	// Jupyter services.
	case OpcodeStartJupyter:
		return h.handleStartJupyter(conn, body)
	case OpcodeStopJupyter:
		return h.handleStopJupyter(conn, body)
	case OpcodeListJupyter:
		return h.handleListJupyter(conn, body)

	// Later protocol additions.
	case OpcodeValidateRequest:
		return h.handleValidateRequest(conn, body)
	case OpcodeQueueJobWithSnapshot:
		return h.handleQueueJobWithSnapshot(conn, body)
	case OpcodeRemoveJupyter:
		return h.handleRemoveJupyter(conn, body)
	case OpcodeRestoreJupyter:
		return h.handleRestoreJupyter(conn, body)
	}

	return fmt.Errorf("unknown opcode: 0x%02x", op)
}