279 lines
8 KiB
Go
279 lines
8 KiB
Go
package api
|
|
|
|
import (
|
|
"compress/flate"
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/audit"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/config"
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
"github.com/jfraeys/fetch_ml/internal/storage"
|
|
)
|
|
|
|
// Opcodes of the binary WebSocket protocol.
//
// The opcode is the first byte of every binary frame; handleMessage reads it
// and dispatches the remaining bytes to the matching handler. The constants
// are grouped by feature area rather than by numeric value.
const (
	// Job queue operations.
	OpcodeQueueJob             = 0x01
	OpcodeStatusRequest        = 0x02
	OpcodeCancelJob            = 0x03
	OpcodePrune                = 0x04
	OpcodeQueueJobWithTracking = 0x0C
	OpcodeQueueJobWithSnapshot = 0x17

	// Dataset operations.
	OpcodeDatasetList     = 0x06
	OpcodeDatasetRegister = 0x07
	OpcodeDatasetInfo     = 0x08
	OpcodeDatasetSearch   = 0x09

	// Experiment tracking operations.
	OpcodeLogMetric     = 0x0A
	OpcodeGetExperiment = 0x0B

	// Jupyter service operations.
	OpcodeStartJupyter   = 0x0D
	OpcodeStopJupyter    = 0x0E
	OpcodeListJupyter    = 0x0F
	OpcodeRemoveJupyter  = 0x18
	OpcodeRestoreJupyter = 0x19

	// Request validation.
	OpcodeValidateRequest = 0x16
)
|
|
|
|
// createUpgrader creates a WebSocket upgrader with the given security configuration
|
|
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
|
|
return websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
return true // Allow same-origin requests
|
|
}
|
|
|
|
// Production mode: strict checking against allowed origins
|
|
if securityCfg != nil && securityCfg.ProductionMode {
|
|
for _, allowed := range securityCfg.AllowedOrigins {
|
|
if origin == allowed {
|
|
return true
|
|
}
|
|
}
|
|
return false // Reject if not in allowed list
|
|
}
|
|
|
|
// Development mode: allow localhost and local network origins
|
|
parsedOrigin, err := url.Parse(origin)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
host := parsedOrigin.Host
|
|
return strings.HasSuffix(host, ":8080") ||
|
|
strings.HasPrefix(host, "localhost:") ||
|
|
strings.HasPrefix(host, "127.0.0.1:") ||
|
|
strings.HasPrefix(host, "192.168.") ||
|
|
strings.HasPrefix(host, "10.") ||
|
|
strings.HasPrefix(host, "172.")
|
|
},
|
|
// Performance optimizations
|
|
HandshakeTimeout: 10 * time.Second,
|
|
ReadBufferSize: 16 * 1024,
|
|
WriteBufferSize: 16 * 1024,
|
|
EnableCompression: true,
|
|
}
|
|
}
|
|
|
|
// WSHandler handles WebSocket connections for the API.
|
|
type WSHandler struct {
|
|
authConfig *auth.Config
|
|
logger *logging.Logger
|
|
expManager *experiment.Manager
|
|
dataDir string
|
|
queue queue.Backend
|
|
db *storage.DB
|
|
jupyterServiceMgr *jupyter.ServiceManager
|
|
securityConfig *config.SecurityConfig
|
|
auditLogger *audit.Logger
|
|
upgrader websocket.Upgrader
|
|
}
|
|
|
|
// NewWSHandler creates a new WebSocket handler.
|
|
func NewWSHandler(
|
|
authConfig *auth.Config,
|
|
logger *logging.Logger,
|
|
expManager *experiment.Manager,
|
|
dataDir string,
|
|
taskQueue queue.Backend,
|
|
db *storage.DB,
|
|
jupyterServiceMgr *jupyter.ServiceManager,
|
|
securityConfig *config.SecurityConfig,
|
|
auditLogger *audit.Logger,
|
|
) *WSHandler {
|
|
return &WSHandler{
|
|
authConfig: authConfig,
|
|
logger: logger.Component(logging.EnsureTrace(context.Background()), "ws-handler"),
|
|
expManager: expManager,
|
|
dataDir: dataDir,
|
|
queue: taskQueue,
|
|
db: db,
|
|
jupyterServiceMgr: jupyterServiceMgr,
|
|
securityConfig: securityConfig,
|
|
auditLogger: auditLogger,
|
|
upgrader: createUpgrader(securityConfig),
|
|
}
|
|
}
|
|
|
|
// enableLowLatencyTCP disables Nagle's algorithm to reduce latency for small packets.
|
|
func enableLowLatencyTCP(conn *websocket.Conn, logger *logging.Logger) {
|
|
if conn == nil {
|
|
return
|
|
}
|
|
|
|
if tcpConn, ok := conn.UnderlyingConn().(*net.TCPConn); ok {
|
|
if err := tcpConn.SetNoDelay(true); err != nil {
|
|
logger.Warn("failed to enable tcp no delay", "error", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
// Add security headers
|
|
w.Header().Set("X-Frame-Options", "DENY")
|
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
|
if r.TLS != nil {
|
|
// Only set HSTS if using HTTPS
|
|
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
|
}
|
|
|
|
// Check API key before upgrading WebSocket
|
|
apiKey := auth.ExtractAPIKeyFromRequest(r)
|
|
clientIP := r.RemoteAddr
|
|
|
|
// Validate API key if authentication is enabled
|
|
if h.authConfig != nil && h.authConfig.Enabled {
|
|
prefixLen := len(apiKey)
|
|
if prefixLen > 8 {
|
|
prefixLen = 8
|
|
}
|
|
h.logger.Info(
|
|
"websocket auth attempt",
|
|
"api_key_length",
|
|
len(apiKey),
|
|
"api_key_prefix",
|
|
apiKey[:prefixLen],
|
|
)
|
|
|
|
userID, err := h.authConfig.ValidateAPIKey(apiKey)
|
|
if err != nil {
|
|
h.logger.Warn("websocket authentication failed", "error", err)
|
|
|
|
// Audit log failed authentication
|
|
if h.auditLogger != nil {
|
|
h.auditLogger.LogAuthAttempt(apiKey[:prefixLen], clientIP, false, err.Error())
|
|
}
|
|
|
|
http.Error(w, "Invalid API key", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
h.logger.Info("websocket authentication succeeded")
|
|
|
|
// Audit log successful authentication
|
|
if h.auditLogger != nil && userID != nil {
|
|
h.auditLogger.LogAuthAttempt(userID.Name, clientIP, true, "")
|
|
}
|
|
}
|
|
|
|
conn, err := h.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
h.logger.Error("websocket upgrade failed", "error", err)
|
|
return
|
|
}
|
|
conn.EnableWriteCompression(true)
|
|
if err := conn.SetCompressionLevel(flate.BestSpeed); err != nil {
|
|
h.logger.Warn("failed to set websocket compression level", "error", err)
|
|
}
|
|
enableLowLatencyTCP(conn, h.logger)
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
h.logger.Info("websocket connection established", "remote", r.RemoteAddr)
|
|
|
|
for {
|
|
messageType, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(
|
|
err,
|
|
websocket.CloseGoingAway,
|
|
websocket.CloseAbnormalClosure,
|
|
) {
|
|
h.logger.Error("websocket read error", "error", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
if messageType != websocket.BinaryMessage {
|
|
h.logger.Warn("received non-binary message")
|
|
continue
|
|
}
|
|
|
|
if err := h.handleMessage(conn, message); err != nil {
|
|
h.logger.Error("message handling error", "error", err)
|
|
// Send structured error response so CLI clients can parse it.
|
|
// (Raw fallback bytes cause client-side InvalidPacket errors.)
|
|
_ = h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "message handling error", err.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
|
|
if len(message) < 1 {
|
|
return fmt.Errorf("message too short")
|
|
}
|
|
|
|
opcode := message[0]
|
|
payload := message[1:]
|
|
|
|
switch opcode {
|
|
case OpcodeQueueJob:
|
|
return h.handleQueueJob(conn, payload)
|
|
case OpcodeQueueJobWithTracking:
|
|
return h.handleQueueJobWithTracking(conn, payload)
|
|
case OpcodeQueueJobWithSnapshot:
|
|
return h.handleQueueJobWithSnapshot(conn, payload)
|
|
case OpcodeStatusRequest:
|
|
return h.handleStatusRequest(conn, payload)
|
|
case OpcodeCancelJob:
|
|
return h.handleCancelJob(conn, payload)
|
|
case OpcodePrune:
|
|
return h.handlePrune(conn, payload)
|
|
case OpcodeDatasetList:
|
|
return h.handleDatasetList(conn, payload)
|
|
case OpcodeDatasetRegister:
|
|
return h.handleDatasetRegister(conn, payload)
|
|
case OpcodeDatasetInfo:
|
|
return h.handleDatasetInfo(conn, payload)
|
|
case OpcodeDatasetSearch:
|
|
return h.handleDatasetSearch(conn, payload)
|
|
case OpcodeLogMetric:
|
|
return h.handleLogMetric(conn, payload)
|
|
case OpcodeGetExperiment:
|
|
return h.handleGetExperiment(conn, payload)
|
|
case OpcodeStartJupyter:
|
|
return h.handleStartJupyter(conn, payload)
|
|
case OpcodeStopJupyter:
|
|
return h.handleStopJupyter(conn, payload)
|
|
case OpcodeRemoveJupyter:
|
|
return h.handleRemoveJupyter(conn, payload)
|
|
case OpcodeRestoreJupyter:
|
|
return h.handleRestoreJupyter(conn, payload)
|
|
case OpcodeListJupyter:
|
|
return h.handleListJupyter(conn, payload)
|
|
case OpcodeValidateRequest:
|
|
return h.handleValidateRequest(conn, payload)
|
|
default:
|
|
return fmt.Errorf("unknown opcode: 0x%02x", opcode)
|
|
}
|
|
}
|