feat(api): refactor websocket handlers; add health and prometheus middleware

This commit is contained in:
Jeremie Fraeys 2026-01-05 12:31:07 -05:00
parent 6ff5324e74
commit add4a90e62
28 changed files with 4179 additions and 961 deletions

View file

@ -5,7 +5,7 @@ WebSocket API server for the ML CLI tool...
## Usage
```bash
./bin/api-server --config configs/config-dev.yaml --listen :9100
./bin/api-server --config configs/api/dev.yaml
```
## Endpoints

View file

@ -9,7 +9,7 @@ import (
)
func main() {
configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path")
configFile := flag.String("config", "configs/api/dev.yaml", "Configuration file path")
apiKey := flag.String("api-key", "", "API key for authentication")
flag.Parse()

15
go.mod
View file

@ -3,8 +3,8 @@ module github.com/jfraeys/fetch_ml
go 1.25.0
// Fetch ML - Secure Machine Learning Platform
// Copyright (c) 2024 Fetch ML
// Licensed under the MIT License
// Copyright (c) 2026 Fetch ML
// Licensed under the FetchML Source-Available Research & Audit License (SARAL). See LICENSE.
require (
github.com/BurntSushi/toml v1.5.0
@ -17,6 +17,7 @@ require (
github.com/gorilla/websocket v1.5.3
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.32
github.com/minio/minio-go/v7 v7.0.97
github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.11.1
@ -43,24 +44,34 @@ require (
github.com/danieljoos/wincred v1.2.3 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-ini/ini v1.67.0 // indirect
github.com/godbus/dbus/v5 v5.2.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.11 // indirect
github.com/klauspost/crc32 v1.3.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/minio/crc64nvme v1.1.0 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/philhofer/fwd v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.4 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.6.0 // indirect
github.com/sahilm/fuzzy v0.1.1 // indirect
github.com/tinylib/msgp v1.3.0 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect

21
go.sum
View file

@ -48,10 +48,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
@ -66,6 +70,11 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU=
github.com/klauspost/cpuid/v2 v2.2.11/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM=
github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@ -84,6 +93,12 @@ github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byF
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/minio/crc64nvme v1.1.0 h1:e/tAguZ+4cw32D+IO/8GSf5UVr9y+3eJcxZI2WOO/7Q=
github.com/minio/crc64nvme v1.1.0/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.97 h1:lqhREPyfgHTB/ciX8k2r8k0D93WaFqxbJX36UZq5occ=
github.com/minio/minio-go/v7 v7.0.97/go.mod h1:re5VXuo0pwEtoNLsNuSr0RrLfT/MBtohwdaSmPPSRSk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@ -98,6 +113,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
@ -114,6 +131,8 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -122,6 +141,8 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww=
github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=

View file

@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
@ -34,7 +35,6 @@ func NewHandlers(
// RegisterHandlers registers all HTTP handlers with the mux
func (h *Handlers) RegisterHandlers(mux *http.ServeMux) {
// Health check endpoints
mux.HandleFunc("/health", h.handleHealth)
mux.HandleFunc("/db-status", h.handleDBStatus)
// Jupyter service endpoints
@ -57,7 +57,7 @@ func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) {
// This would need the DB instance passed to handlers
// For now, return a basic response
response := map[string]interface{}{
response := map[string]any{
"status": "unknown",
"message": "Database status check not implemented",
}
@ -72,13 +72,30 @@ func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) {
// handleJupyterServices handles Jupyter service management requests
func (h *Handlers) handleJupyterServices(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
user := auth.GetUserFromContext(r.Context())
if user == nil {
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
return
}
switch r.Method {
case http.MethodGet:
if !user.HasPermission("jupyter:read") {
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
return
}
h.listJupyterServices(w, r)
case http.MethodPost:
if !user.HasPermission("jupyter:manage") {
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
return
}
h.startJupyterService(w, r)
case http.MethodDelete:
if !user.HasPermission("jupyter:manage") {
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
return
}
h.stopJupyterService(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@ -140,7 +157,10 @@ func (h *Handlers) stopJupyterService(w http.ResponseWriter, r *http.Request) {
}
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(map[string]string{"status": "stopped", "id": serviceID}); err != nil {
if err := json.NewEncoder(w).Encode(map[string]string{
"status": "stopped",
"id": serviceID,
}); err != nil {
h.logger.Error("failed to encode response", "error", err)
}
}
@ -148,6 +168,15 @@ func (h *Handlers) stopJupyterService(w http.ResponseWriter, r *http.Request) {
// handleJupyterExperimentLink handles linking Jupyter workspaces with experiments
func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
user := auth.GetUserFromContext(r.Context())
if user == nil {
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
return
}
if !user.HasPermission("jupyter:manage") {
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
return
}
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@ -176,7 +205,11 @@ func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Re
}
// Link workspace with experiment using service manager
if err := h.jupyterServiceMgr.LinkWorkspaceWithExperiment(req.Workspace, req.ExperimentID, req.ServiceID); err != nil {
if err := h.jupyterServiceMgr.LinkWorkspaceWithExperiment(
req.Workspace,
req.ExperimentID,
req.ServiceID,
); err != nil {
http.Error(w, fmt.Sprintf("Failed to link workspace: %v", err), http.StatusInternalServerError)
return
}
@ -184,7 +217,11 @@ func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Re
// Get workspace metadata to return
metadata, err := h.jupyterServiceMgr.GetWorkspaceMetadata(req.Workspace)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to get workspace metadata: %v", err), http.StatusInternalServerError)
http.Error(
w,
fmt.Sprintf("Failed to get workspace metadata: %v", err),
http.StatusInternalServerError,
)
return
}
@ -205,6 +242,15 @@ func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Re
// handleJupyterExperimentSync handles synchronization between Jupyter workspaces and experiments
func (h *Handlers) handleJupyterExperimentSync(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
user := auth.GetUserFromContext(r.Context())
if user == nil {
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
return
}
if !user.HasPermission("jupyter:manage") {
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
return
}
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@ -245,7 +291,11 @@ func (h *Handlers) handleJupyterExperimentSync(w http.ResponseWriter, r *http.Re
// Get updated metadata
metadata, err := h.jupyterServiceMgr.GetWorkspaceMetadata(req.Workspace)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to get workspace metadata: %v", err), http.StatusInternalServerError)
http.Error(
w,
fmt.Sprintf("Failed to get workspace metadata: %v", err),
http.StatusInternalServerError,
)
return
}

82
internal/api/health.go Normal file
View file

@ -0,0 +1,82 @@
package api
import (
"encoding/json"
"net/http"
"time"
)
// HealthStatus represents the health status of the service as reported
// by the /health endpoints.
type HealthStatus struct {
	// Status is a short state label: "healthy", "alive", or "ready".
	Status string `json:"status"`
	// Timestamp is the UTC time at which the status was produced.
	Timestamp time.Time `json:"timestamp"`
	// Checks maps component names (e.g. "queue", "experiments") to their
	// state; omitted from the JSON output when empty.
	Checks map[string]string `json:"checks,omitempty"`
}
// HealthHandler handles /health requests on behalf of a Server.
type HealthHandler struct {
	// server is the API server whose components are inspected by the
	// readiness probe.
	server *Server
}
// NewHealthHandler creates a new health check handler bound to s.
func NewHealthHandler(s *Server) *HealthHandler {
	handler := &HealthHandler{server: s}
	return handler
}
// Health performs a basic health check; it unconditionally reports
// "healthy" with the current UTC time.
func (h *HealthHandler) Health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	body := HealthStatus{
		Status:    "healthy",
		Timestamp: time.Now().UTC(),
		Checks:    map[string]string{},
	}
	_ = json.NewEncoder(w).Encode(body)
}
// Liveness performs a liveness probe: if this handler runs at all, the
// process is alive, so it always reports "alive".
func (h *HealthHandler) Liveness(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_ = json.NewEncoder(w).Encode(HealthStatus{
		Status:    "alive",
		Timestamp: time.Now().UTC(),
	})
}
// Readiness performs a readiness probe. It reports which server
// components are wired up and always answers 200 "ready".
//
// NOTE(review): these are presence checks only — the queue and experiment
// manager are reported "ok" merely for being non-nil; Redis is never
// pinged. Confirm this is sufficient for the deployment's readiness gate.
func (h *HealthHandler) Readiness(w http.ResponseWriter, r *http.Request) {
	checks := make(map[string]string)
	if h.server.taskQueue != nil {
		checks["queue"] = "ok"
	}
	if h.server.expManager != nil {
		checks["experiments"] = "ok"
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_ = json.NewEncoder(w).Encode(HealthStatus{
		Status:    "ready",
		Timestamp: time.Now().UTC(),
		Checks:    checks,
	})
}
// RegisterRoutes registers the health check endpoints (/health,
// /health/live, /health/ready) on mux.
//
// The doc comment previously named the function "RegisterHealthRoutes";
// Go doc comments must begin with the declared identifier.
func (h *HealthHandler) RegisterRoutes(mux *http.ServeMux) {
	mux.HandleFunc("/health", h.Health)
	mux.HandleFunc("/health/live", h.Liveness)
	mux.HandleFunc("/health/ready", h.Readiness)
}

View file

@ -0,0 +1,63 @@
package api
import (
"bufio"
"fmt"
"net"
"net/http"
"time"
)
// wrapWithMetrics wraps a handler with Prometheus metrics tracking.
// When s.promMetrics is nil (metrics disabled) requests pass through
// untouched; otherwise each request's count and latency are recorded.
func (s *Server) wrapWithMetrics(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if s.promMetrics == nil {
			next.ServeHTTP(w, r)
			return
		}
		start := time.Now()
		// Track HTTP request by method and raw URL path.
		// NOTE(review): using r.URL.Path directly as a metric label means
		// every distinct path creates a new label value — confirm paths are
		// bounded (or normalize them) to avoid cardinality blowup.
		method := r.Method
		endpoint := r.URL.Path
		// Wrap response writer to capture status code; defaults to 200 in
		// case the handler never calls WriteHeader explicitly.
		ww := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		// Serve request
		next.ServeHTTP(ww, r)
		// Record metrics.
		// NOTE(review): http.StatusText returns "" for non-standard codes,
		// which would yield an empty status label — confirm acceptable.
		duration := time.Since(start)
		statusStr := http.StatusText(ww.statusCode)
		s.promMetrics.IncHTTPRequests(method, endpoint, statusStr)
		s.promMetrics.ObserveHTTPDuration(method, endpoint, duration)
	})
}
// responseWriter wraps http.ResponseWriter to capture status code
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
func (rw *responseWriter) Flush() {
if f, ok := rw.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("websocket: response does not implement http.Hijacker")
}
return h.Hijack()
}

View file

@ -0,0 +1,20 @@
package api
// MonitoringConfig holds monitoring-related configuration, grouping the
// Prometheus metrics and health-check sections of the YAML config.
type MonitoringConfig struct {
	Prometheus   PrometheusConfig   `yaml:"prometheus"`
	HealthChecks HealthChecksConfig `yaml:"health_checks"`
}
// PrometheusConfig holds Prometheus metrics configuration.
type PrometheusConfig struct {
	// Enabled turns metrics collection and the metrics endpoint on or off.
	Enabled bool `yaml:"enabled"`
	// Port for metrics exposure.
	// NOTE(review): the visible server code registers the metrics path on
	// the main mux rather than a separate listener — confirm whether Port
	// is actually honored anywhere.
	Port int `yaml:"port"`
	// Path is the HTTP path for the metrics endpoint; when empty the
	// server falls back to "/metrics".
	Path string `yaml:"path"`
}
// HealthChecksConfig holds health check configuration.
type HealthChecksConfig struct {
	// Enabled turns registration of the /health endpoints on or off.
	Enabled bool `yaml:"enabled"`
	// Interval is a duration string (e.g. "30s").
	// NOTE(review): not parsed anywhere in the visible code — confirm use.
	Interval string `yaml:"interval"`
}

View file

@ -4,9 +4,17 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"sync"
"time"
)
// bufferPool recycles byte-slice buffers used while serializing response
// packets. Each pooled entry starts with len 0 and cap 256.
//
// NOTE(review): the visible Serialize code grows the buffer locally but
// never stores the grown slice back into the pooled pointer before Put,
// so larger buffers are not actually retained by the pool — confirm and
// fix in Serialize if pooling is meant to help large packets.
var bufferPool = sync.Pool{
	New: func() any {
		buf := make([]byte, 0, 256)
		return &buf
	},
}
// Response packet types
const (
PacketTypeSuccess = 0x00
@ -126,7 +134,12 @@ func NewErrorPacket(errorCode byte, message string, details string) *ResponsePac
}
// NewProgressPacket creates a progress response packet
func NewProgressPacket(progressType byte, value uint32, total uint32, message string) *ResponsePacket {
func NewProgressPacket(
progressType byte,
value uint32,
total uint32,
message string,
) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeProgress,
Timestamp: uint64(time.Now().Unix()),
@ -168,48 +181,65 @@ func NewLogPacket(level byte, message string) *ResponsePacket {
// Serialize converts the packet to binary format
func (p *ResponsePacket) Serialize() ([]byte, error) {
var buf []byte
// For small packets, avoid pool overhead
if p.estimatedSize() <= 1024 {
buf := make([]byte, 0, p.estimatedSize())
return serializePacketToBuffer(p, buf)
}
// Use pool for larger packets
bufPtr := bufferPool.Get().(*[]byte)
defer func() {
*bufPtr = (*bufPtr)[:0]
bufferPool.Put(bufPtr)
}()
buf := *bufPtr
// Ensure buffer has enough capacity
if cap(buf) < p.estimatedSize() {
buf = make([]byte, 0, p.estimatedSize())
} else {
buf = buf[:0]
}
return serializePacketToBuffer(p, buf)
}
func serializePacketToBuffer(p *ResponsePacket, buf []byte) ([]byte, error) {
// Packet type
buf = append(buf, p.PacketType)
// Timestamp (8 bytes, big-endian)
timestampBytes := make([]byte, 8)
binary.BigEndian.PutUint64(timestampBytes, p.Timestamp)
buf = append(buf, timestampBytes...)
var timestampBytes [8]byte
binary.BigEndian.PutUint64(timestampBytes[:], p.Timestamp)
buf = append(buf, timestampBytes[:]...)
// Packet-specific data
switch p.PacketType {
case PacketTypeSuccess:
buf = append(buf, serializeString(p.SuccessMessage)...)
buf = appendString(buf, p.SuccessMessage)
case PacketTypeError:
buf = append(buf, p.ErrorCode)
buf = append(buf, serializeString(p.ErrorMessage)...)
buf = append(buf, serializeString(p.ErrorDetails)...)
buf = appendString(buf, p.ErrorMessage)
buf = appendString(buf, p.ErrorDetails)
case PacketTypeProgress:
buf = append(buf, p.ProgressType)
valueBytes := make([]byte, 4)
binary.BigEndian.PutUint32(valueBytes, p.ProgressValue)
buf = append(buf, valueBytes...)
totalBytes := make([]byte, 4)
binary.BigEndian.PutUint32(totalBytes, p.ProgressTotal)
buf = append(buf, totalBytes...)
buf = append(buf, serializeString(p.ProgressMessage)...)
buf = appendUint32(buf, p.ProgressValue)
buf = appendUint32(buf, p.ProgressTotal)
buf = appendString(buf, p.ProgressMessage)
case PacketTypeStatus:
buf = append(buf, serializeString(p.StatusData)...)
buf = appendString(buf, p.StatusData)
case PacketTypeData:
buf = append(buf, serializeString(p.DataType)...)
buf = append(buf, serializeBytes(p.DataPayload)...)
buf = appendString(buf, p.DataType)
buf = appendBytes(buf, p.DataPayload)
case PacketTypeLog:
buf = append(buf, p.LogLevel)
buf = append(buf, serializeString(p.LogMessage)...)
buf = appendString(buf, p.LogMessage)
default:
return nil, fmt.Errorf("unknown packet type: %d", p.PacketType)
@ -218,22 +248,49 @@ func (p *ResponsePacket) Serialize() ([]byte, error) {
return buf, nil
}
// serializeString writes a string with 2-byte length prefix
func serializeString(s string) []byte {
length := uint16(len(s))
buf := make([]byte, 2+len(s))
binary.BigEndian.PutUint16(buf[:2], length)
copy(buf[2:], s)
return buf
// appendString appends s to buf, preceded by its length encoded as an
// unsigned varint, and returns the extended buffer.
func appendString(buf []byte, s string) []byte {
	var hdr [binary.MaxVarintLen64]byte
	written := binary.PutUvarint(hdr[:], uint64(len(s)))
	out := append(buf, hdr[:written]...)
	out = append(out, s...)
	return out
}
// serializeBytes writes bytes with 4-byte length prefix
func serializeBytes(b []byte) []byte {
length := uint32(len(b))
buf := make([]byte, 4+len(b))
binary.BigEndian.PutUint32(buf[:4], length)
copy(buf[4:], b)
return buf
// appendBytes appends b to buf, preceded by its length encoded as an
// unsigned varint, and returns the extended buffer.
func appendBytes(buf []byte, b []byte) []byte {
	var hdr [binary.MaxVarintLen64]byte
	written := binary.PutUvarint(hdr[:], uint64(len(b)))
	out := append(buf, hdr[:written]...)
	out = append(out, b...)
	return out
}
// appendUint32 appends value to buf in big-endian byte order and returns
// the extended buffer.
func appendUint32(buf []byte, value uint32) []byte {
	var scratch [4]byte
	binary.BigEndian.PutUint32(scratch[:], value)
	return append(buf, scratch[:]...)
}
// estimatedSize returns an upper bound on the serialized size of the
// packet, used to pre-size the output buffer. Varint length prefixes are
// costed at their maximum width, so the estimate may overshoot slightly.
func (p *ResponsePacket) estimatedSize() int {
	const header = 1 + 8 // packet-type byte + 8-byte timestamp
	switch p.PacketType {
	case PacketTypeSuccess:
		return header + binary.MaxVarintLen64 + len(p.SuccessMessage)
	case PacketTypeError:
		return header + 1 + 2*binary.MaxVarintLen64 + len(p.ErrorMessage) + len(p.ErrorDetails)
	case PacketTypeProgress:
		return header + 1 + 4 + 4 + binary.MaxVarintLen64 + len(p.ProgressMessage)
	case PacketTypeStatus:
		return header + binary.MaxVarintLen64 + len(p.StatusData)
	case PacketTypeData:
		return header + 2*binary.MaxVarintLen64 + len(p.DataType) + len(p.DataPayload)
	case PacketTypeLog:
		return header + 1 + binary.MaxVarintLen64 + len(p.LogMessage)
	default:
		return header
	}
}
// GetErrorMessage returns a human-readable error message for an error code

View file

@ -1,155 +0,0 @@
package api
import (
"encoding/json"
"time"
)
// Simplified protocol using JSON instead of binary serialization

// Response represents a simplified API response envelope. Exactly one of
// Data or Error is normally populated, keyed by Type.
type Response struct {
	Type      string     `json:"type"`
	Timestamp int64      `json:"timestamp"`
	Data      any        `json:"data,omitempty"` // was interface{}; `any` matches the rest of the codebase
	Error     *ErrorInfo `json:"error,omitempty"`
}

// ErrorInfo represents error information carried by a TypeError response.
type ErrorInfo struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Details string `json:"details,omitempty"`
}

// ProgressInfo represents progress information carried by a TypeProgress
// response.
type ProgressInfo struct {
	Type    string `json:"type"`
	Value   uint32 `json:"value"`
	Total   uint32 `json:"total"`
	Message string `json:"message"`
}

// LogInfo represents log information carried by a TypeLog response.
type LogInfo struct {
	Level   string `json:"level"`
	Message string `json:"message"`
}

// Response types
const (
	TypeSuccess  = "success"
	TypeError    = "error"
	TypeProgress = "progress"
	TypeStatus   = "status"
	TypeData     = "data"
	TypeLog      = "log"
)

// Error codes. Values 6-15 are intentionally unassigned; server-side
// failures start at 16.
const (
	ErrUnknown          = 0
	ErrInvalidRequest   = 1
	ErrAuthFailed       = 2
	ErrPermissionDenied = 3
	ErrNotFound         = 4
	ErrExists           = 5
	ErrServerOverload   = 16
	ErrDatabase         = 17
	ErrNetwork          = 18
	ErrStorage          = 19
	ErrTimeout          = 20
)
// NewSuccessResponse creates a success response whose Data field carries
// the given message.
func NewSuccessResponse(message string) *Response {
	resp := &Response{Type: TypeSuccess, Data: message}
	resp.Timestamp = time.Now().Unix()
	return resp
}
// NewSuccessResponseWithData creates a success response carrying both a
// message and a payload, wrapped in a map under the keys "message" and
// "payload". Note that the envelope Type is TypeData, not TypeSuccess,
// matching the existing wire behavior.
func NewSuccessResponseWithData(message string, data any) *Response {
	return &Response{
		Type:      TypeData,
		Timestamp: time.Now().Unix(),
		Data: map[string]any{ // was map[string]interface{}; `any` matches the rest of the codebase
			"message": message,
			"payload": data,
		},
	}
}
// NewErrorResponse creates an error response carrying the given error
// code, message, and optional details.
func NewErrorResponse(code int, message, details string) *Response {
	info := &ErrorInfo{Code: code, Message: message, Details: details}
	return &Response{
		Type:      TypeError,
		Timestamp: time.Now().Unix(),
		Error:     info,
	}
}
// NewProgressResponse creates a progress response whose Data field holds
// a ProgressInfo describing the operation's current value out of total.
func NewProgressResponse(progressType string, value, total uint32, message string) *Response {
	progress := ProgressInfo{
		Type:    progressType,
		Value:   value,
		Total:   total,
		Message: message,
	}
	return &Response{Type: TypeProgress, Timestamp: time.Now().Unix(), Data: progress}
}
// NewStatusResponse creates a status response whose Data field carries
// the given status string.
func NewStatusResponse(data string) *Response {
	resp := Response{Type: TypeStatus, Timestamp: time.Now().Unix(), Data: data}
	return &resp
}
// NewDataResponse creates a data response whose payload is wrapped in a
// map under the keys "type" and "payload".
func NewDataResponse(dataType string, payload any) *Response {
	return &Response{
		Type:      TypeData,
		Timestamp: time.Now().Unix(),
		Data: map[string]any{ // was map[string]interface{}; `any` matches the rest of the codebase
			"type":    dataType,
			"payload": payload,
		},
	}
}
// NewLogResponse creates a log response whose Data field holds a LogInfo
// with the given level and message.
func NewLogResponse(level, message string) *Response {
	entry := LogInfo{Level: level, Message: message}
	return &Response{
		Type:      TypeLog,
		Timestamp: time.Now().Unix(),
		Data:      entry,
	}
}
// ToJSON encodes the response as JSON bytes.
func (r *Response) ToJSON() ([]byte, error) {
	encoded, err := json.Marshal(r)
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
// FromJSON decodes a Response from JSON bytes.
//
// On decode failure it now returns a nil response with the error; the
// previous version returned a partially-populated value alongside a
// non-nil error, which callers must never use.
func FromJSON(data []byte) (*Response, error) {
	var response Response
	if err := json.Unmarshal(data, &response); err != nil {
		return nil, err
	}
	return &response, nil
}

View file

@ -5,13 +5,17 @@ import (
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/middleware"
"github.com/jfraeys/fetch_ml/internal/prommetrics"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
)
@ -22,12 +26,14 @@ type Server struct {
httpServer *http.Server
logger *logging.Logger
expManager *experiment.Manager
taskQueue *queue.TaskQueue
taskQueue queue.Backend
db *storage.DB
handlers *Handlers
sec *middleware.SecurityMiddleware
cleanupFuncs []func()
jupyterServiceMgr *jupyter.ServiceManager
auditLogger *audit.Logger
promMetrics *prommetrics.Metrics // Prometheus metrics
}
// NewServer creates a new API server
@ -80,6 +86,11 @@ func (s *Server) initializeComponents() error {
return err
}
// Initialize database schema (if DB enabled)
if err := s.initDatabaseSchema(); err != nil {
return err
}
// Initialize security
s.initSecurity()
@ -87,11 +98,29 @@ func (s *Server) initializeComponents() error {
s.initJupyterServiceManager()
// Initialize handlers
s.handlers = NewHandlers(s.expManager, s.jupyterServiceMgr, s.logger)
s.handlers = NewHandlers(s.expManager, nil, s.logger)
return nil
}
// initDatabaseSchema applies the schema matching the configured database
// type. It is a no-op when no database is configured.
func (s *Server) initDatabaseSchema() error {
	if s.db == nil {
		// No database configured; nothing to initialize.
		return nil
	}
	ddl, err := storage.SchemaForDBType(s.config.Database.Type)
	if err != nil {
		return err
	}
	if err = s.db.Initialize(ddl); err != nil {
		return err
	}
	s.logger.Info("database schema initialized", "type", s.config.Database.Type)
	return nil
}
// setupLogger creates and configures the logger
func (s *Server) setupLogger() *logging.Logger {
logger := logging.NewLoggerFromConfig(s.config.Logging)
@ -112,26 +141,38 @@ func (s *Server) initExperimentManager() error {
// initTaskQueue initializes the task queue
func (s *Server) initTaskQueue() error {
queueCfg := queue.Config{
RedisAddr: s.config.Redis.Addr,
RedisPassword: s.config.Redis.Password,
RedisDB: s.config.Redis.DB,
backend := strings.ToLower(strings.TrimSpace(s.config.Queue.Backend))
if backend == "" {
backend = "redis"
}
redisAddr := strings.TrimSpace(s.config.Redis.Addr)
if redisAddr == "" {
redisAddr = "localhost:6379"
}
if strings.TrimSpace(s.config.Redis.URL) != "" {
redisAddr = strings.TrimSpace(s.config.Redis.URL)
}
if queueCfg.RedisAddr == "" {
queueCfg.RedisAddr = "localhost:6379"
}
if s.config.Redis.URL != "" {
queueCfg.RedisAddr = s.config.Redis.URL
backendCfg := queue.BackendConfig{
Backend: queue.QueueBackend(backend),
RedisAddr: redisAddr,
RedisPassword: s.config.Redis.Password,
RedisDB: s.config.Redis.DB,
SQLitePath: s.config.Queue.SQLitePath,
MetricsFlushInterval: 0,
}
taskQueue, err := queue.NewTaskQueue(queueCfg)
taskQueue, err := queue.NewBackend(backendCfg)
if err != nil {
return err
}
s.taskQueue = taskQueue
s.logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr)
if backend == "sqlite" {
s.logger.Info("task queue initialized", "backend", backend, "sqlite_path", s.config.Queue.SQLitePath)
} else {
s.logger.Info("task queue initialized", "backend", backend, "redis_addr", redisAddr)
}
// Add cleanup function
s.cleanupFuncs = append(s.cleanupFuncs, func() {
@ -220,8 +261,67 @@ func (s *Server) initJupyterServiceManager() {
func (s *Server) setupHTTPServer() {
mux := http.NewServeMux()
// Register WebSocket handler
wsHandler := NewWSHandler(s.config.BuildAuthConfig(), s.logger, s.expManager, s.taskQueue)
// Initialize Prometheus metrics (if enabled)
if s.config.Monitoring.Prometheus.Enabled {
s.promMetrics = prommetrics.New()
s.logger.Info("prometheus metrics initialized")
// Register metrics endpoint
metricsPath := s.config.Monitoring.Prometheus.Path
if metricsPath == "" {
metricsPath = "/metrics"
}
mux.Handle(metricsPath, s.promMetrics.Handler())
s.logger.Info("metrics endpoint registered", "path", metricsPath)
}
// Initialize health check handler
if s.config.Monitoring.HealthChecks.Enabled {
healthHandler := NewHealthHandler(s)
healthHandler.RegisterRoutes(mux)
mux.HandleFunc("/health/ok", s.handlers.handleHealth)
s.logger.Info("health check endpoints registered")
}
// Initialize audit logger
var auditLogger *audit.Logger
if s.config.Security.AuditLogging.Enabled && s.config.Security.AuditLogging.LogPath != "" {
al, err := audit.NewLogger(
s.config.Security.AuditLogging.Enabled,
s.config.Security.AuditLogging.LogPath,
s.logger,
)
if err != nil {
s.logger.Warn("failed to initialize audit logger", "error", err)
} else {
auditLogger = al
s.auditLogger = al
// Add cleanup function
s.cleanupFuncs = append(s.cleanupFuncs, func() {
s.logger.Info("closing audit logger...")
if err := auditLogger.Close(); err != nil {
s.logger.Error("failed to close audit logger", "error", err)
}
})
}
}
// Register WebSocket handler with security config and audit logger
securityCfg := getSecurityConfig(s.config)
wsHandler := NewWSHandler(
s.config.BuildAuthConfig(),
s.logger,
s.expManager,
s.config.DataDir,
s.taskQueue,
s.db,
s.jupyterServiceMgr,
securityCfg,
auditLogger,
)
// Wrap WebSocket handler with metrics
mux.Handle("/ws", wsHandler)
// Register HTTP handlers
@ -229,6 +329,7 @@ func (s *Server) setupHTTPServer() {
// Wrap with middleware
finalHandler := s.wrapWithMiddleware(mux)
finalHandler = s.wrapWithMetrics(finalHandler)
s.httpServer = &http.Server{
Addr: s.config.Server.Address,
@ -239,10 +340,25 @@ func (s *Server) setupHTTPServer() {
}
}
// getSecurityConfig converts the server-local security section into the
// shared config.SecurityConfig consumed by the WebSocket handler.
func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig {
	sec := cfg.Security
	out := &config.SecurityConfig{
		AllowedOrigins:     sec.AllowedOrigins,
		ProductionMode:     sec.ProductionMode,
		APIKeyRotationDays: sec.APIKeyRotationDays,
		IPWhitelist:        sec.IPWhitelist,
	}
	out.AuditLogging = config.AuditLoggingConfig{
		Enabled: sec.AuditLogging.Enabled,
		LogPath: sec.AuditLogging.LogPath,
	}
	return out
}
// wrapWithMiddleware wraps the handler with security middleware
func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/ws" {
// Skip auth for WebSocket and health endpoints
if r.URL.Path == "/ws" || strings.HasPrefix(r.URL.Path, "/health") {
mux.ServeHTTP(w, r)
return
}
@ -250,7 +366,7 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
handler := s.sec.APIKeyAuth(mux)
handler = s.sec.RateLimit(handler)
handler = middleware.SecurityHeaders(handler)
handler = middleware.CORS(handler)
handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler)
handler = middleware.RequestTimeout(30 * time.Second)(handler)
handler = middleware.AuditLogger(handler)
if len(s.config.Security.IPWhitelist) > 0 {

View file

@ -1,26 +1,37 @@
package api
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/jfraeys/fetch_ml/internal/logging"
"gopkg.in/yaml.v3"
)
// QueueConfig selects and configures the task-queue backend.
// Validate() defaults Backend to "redis" and only accepts "redis" or "sqlite".
type QueueConfig struct {
	Backend    string `yaml:"backend"`     // "redis" (default) or "sqlite"
	SQLitePath string `yaml:"sqlite_path"` // SQLite DB path; defaulted under DataDir when empty
}
// ServerConfig holds all server configuration loaded from YAML.
// NOTE: the diff residue previously left both the old and new field sets in
// this struct, which declared duplicate fields and cannot compile; this is
// the consolidated (new) field set.
type ServerConfig struct {
	BasePath   string                `yaml:"base_path"` // root for experiment storage
	DataDir    string                `yaml:"data_dir"`  // root for server data (queue DB, etc.)
	Auth       auth.Config           `yaml:"auth"`
	Server     ServerSection         `yaml:"server"`
	Security   SecurityConfig        `yaml:"security"`
	Monitoring MonitoringConfig      `yaml:"monitoring"`
	Queue      QueueConfig           `yaml:"queue"`
	Redis      RedisConfig           `yaml:"redis"`
	Database   DatabaseConfig        `yaml:"database"`
	Logging    logging.Config        `yaml:"logging"`
	Resources  config.ResourceConfig `yaml:"resources"`
}
// ServerSection holds server-specific configuration
@ -38,9 +49,19 @@ type TLSConfig struct {
// SecurityConfig holds security-related configuration
type SecurityConfig struct {
RateLimit RateLimitConfig `yaml:"rate_limit"`
IPWhitelist []string `yaml:"ip_whitelist"`
FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
ProductionMode bool `yaml:"production_mode"`
AllowedOrigins []string `yaml:"allowed_origins"`
APIKeyRotationDays int `yaml:"api_key_rotation_days"`
AuditLogging AuditLog `yaml:"audit_logging"`
RateLimit RateLimitConfig `yaml:"rate_limit"`
IPWhitelist []string `yaml:"ip_whitelist"`
FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
}
// AuditLog holds audit logging configuration. The server only constructs an
// audit.Logger when Enabled is true and LogPath is non-empty.
type AuditLog struct {
	Enabled bool   `yaml:"enabled"`  // turn audit logging on/off
	LogPath string `yaml:"log_path"` // destination file for audit records
}
// RateLimitConfig holds rate limiting configuration
@ -108,9 +129,7 @@ func loadConfigFromFile(path string) (*ServerConfig, error) {
// secureFileRead safely reads a file
func secureFileRead(path string) ([]byte, error) {
// This would use the fileutil.SecureFileRead function
// For now, implement basic file reading
return os.ReadFile(path)
return fileutil.SecureFileRead(path)
}
// EnsureLogDirectory creates the log directory if needed
@ -144,6 +163,27 @@ func (c *ServerConfig) Validate() error {
if c.BasePath == "" {
c.BasePath = "/tmp/ml-experiments"
}
if c.DataDir == "" {
c.DataDir = config.DefaultDataDir
}
backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend))
if backend == "" {
backend = "redis"
c.Queue.Backend = backend
}
if backend != "redis" && backend != "sqlite" {
return fmt.Errorf("queue.backend must be one of 'redis' or 'sqlite'")
}
if backend == "sqlite" {
if strings.TrimSpace(c.Queue.SQLitePath) == "" {
c.Queue.SQLitePath = filepath.Join(c.DataDir, "queue.db")
}
c.Queue.SQLitePath = config.ExpandPath(c.Queue.SQLitePath)
if !filepath.IsAbs(c.Queue.SQLitePath) {
c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
}
}
return nil
}

View file

@ -1,652 +0,0 @@
package api
import (
	"crypto/subtle"
	"crypto/tls"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"math"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/auth"
	"github.com/jfraeys/fetch_ml/internal/experiment"
	"github.com/jfraeys/fetch_ml/internal/logging"
	"github.com/jfraeys/fetch_ml/internal/queue"
	"github.com/jfraeys/fetch_ml/internal/telemetry"
	"golang.org/x/crypto/acme/autocert"
)
// Opcodes for binary WebSocket protocol. Each client request is a single
// binary message whose first byte is one of these opcodes; the remainder of
// the message is the opcode-specific payload (see the handle* methods).
const (
	OpcodeQueueJob      = 0x01 // enqueue a job for execution
	OpcodeStatusRequest = 0x02 // report user info and queue state
	OpcodeCancelJob     = 0x03 // cancel a job by name
	OpcodePrune         = 0x04 // prune old experiments
	OpcodeLogMetric     = 0x0A // record a metric value for an experiment
	OpcodeGetExperiment = 0x0B // fetch experiment metadata and metrics
)
// upgrader configures the HTTP→WebSocket upgrade shared by all connections.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		// Allow localhost and homelab origins for development
		origin := r.Header.Get("Origin")
		if origin == "" {
			return true // Allow same-origin requests (no Origin header)
		}
		// Parse origin URL
		parsedOrigin, err := url.Parse(origin)
		if err != nil {
			return false
		}
		// Allow localhost and local network origins.
		// NOTE(review): these are raw string-prefix checks, not IP-range
		// checks — "10." and "172." also match public hosts such as
		// "172.217.0.1" or hostnames beginning with "10."; confirm this
		// permissiveness is acceptable outside development.
		host := parsedOrigin.Host
		return strings.HasSuffix(host, ":8080") ||
			strings.HasPrefix(host, "localhost:") ||
			strings.HasPrefix(host, "127.0.0.1:") ||
			strings.HasPrefix(host, "192.168.") ||
			strings.HasPrefix(host, "10.") ||
			strings.HasPrefix(host, "172.")
	},
	// Performance optimizations
	HandshakeTimeout:  10 * time.Second,
	ReadBufferSize:    4096,
	WriteBufferSize:   4096,
	EnableCompression: true,
}
// WSHandler handles WebSocket connections for the API.
// All collaborators are injected via NewWSHandler; a nil authConfig disables
// authentication and a nil queue makes job enqueueing a logged no-op.
type WSHandler struct {
	authConfig *auth.Config        // API-key auth settings; nil or disabled skips auth checks
	logger     *logging.Logger     // structured logger for all handler events
	expManager *experiment.Manager // experiment directory/metadata/metric storage
	queue      *queue.TaskQueue    // task queue; nil when queueing is unavailable
}
// NewWSHandler creates a new WebSocket handler wired to the given
// authentication config, logger, experiment manager and task queue.
func NewWSHandler(
	authConfig *auth.Config,
	logger *logging.Logger,
	expManager *experiment.Manager,
	taskQueue *queue.TaskQueue,
) *WSHandler {
	// Pure dependency injection — no validation or side effects here.
	handler := &WSHandler{
		authConfig: authConfig,
		logger:     logger,
		expManager: expManager,
		queue:      taskQueue,
	}
	return handler
}
// ServeHTTP authenticates the request, upgrades it to a WebSocket connection
// and serves binary-protocol messages until the peer disconnects.
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Check API key before upgrading WebSocket
	apiKey := auth.ExtractAPIKeyFromRequest(r)
	// Validate API key if authentication is enabled
	if h.authConfig != nil && h.authConfig.Enabled {
		// SECURITY FIX: never log API-key material — the previous code
		// logged the first 8 characters of the raw key. Only the length
		// is recorded for troubleshooting.
		h.logger.Info("websocket auth attempt", "api_key_length", len(apiKey))
		if _, err := h.authConfig.ValidateAPIKey(apiKey); err != nil {
			h.logger.Warn("websocket authentication failed", "error", err)
			http.Error(w, "Invalid API key", http.StatusUnauthorized)
			return
		}
		h.logger.Info("websocket authentication succeeded")
	}
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Error("websocket upgrade failed", "error", err)
		return
	}
	defer func() {
		_ = conn.Close()
	}()
	h.logger.Info("websocket connection established", "remote", r.RemoteAddr)
	// Read loop: one binary message per request. Exits on any read error.
	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			// Normal close codes are expected; only log abnormal ones.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				h.logger.Error("websocket read error", "error", err)
			}
			break
		}
		if messageType != websocket.BinaryMessage {
			h.logger.Warn("received non-binary message")
			continue
		}
		if err := h.handleMessage(conn, message); err != nil {
			h.logger.Error("message handling error", "error", err)
			// Send error response (best effort; write errors are ignored)
			_ = conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) // Error opcode
		}
	}
}
// handleMessage dispatches one binary protocol message to the handler for
// its opcode. The first byte selects the operation; the rest is the payload.
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
	if len(message) == 0 {
		return fmt.Errorf("message too short")
	}
	op, body := message[0], message[1:]
	switch op {
	case OpcodeQueueJob:
		return h.handleQueueJob(conn, body)
	case OpcodeStatusRequest:
		return h.handleStatusRequest(conn, body)
	case OpcodeCancelJob:
		return h.handleCancelJob(conn, body)
	case OpcodePrune:
		return h.handlePrune(conn, body)
	case OpcodeLogMetric:
		return h.handleLogMetric(conn, body)
	case OpcodeGetExperiment:
		return h.handleGetExperiment(conn, body)
	default:
		return fmt.Errorf("unknown opcode: 0x%02x", op)
	}
}
// handleQueueJob creates an experiment for the given commit and enqueues a
// job for it, replying with a success packet or a structured error.
// Protocol: [api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 130 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
	}
	apiKeyHash := string(payload[:64])
	commitID := string(payload[64:128])
	priority := int64(payload[128])
	jobNameLen := int(payload[129])
	if len(payload) < 130+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[130 : 130+jobNameLen])
	h.logger.Info("queue job request",
		"job", jobName,
		"priority", priority,
		"commit_id", commitID,
	)
	// Resolve the caller: validate the API key when auth is configured,
	// otherwise fall back to an implicit all-permissions admin user.
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKey(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// Auth disabled - use default admin user
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}
	// Guard clause: require jobs:create when auth is enabled.
	// FIX: the previous inverted if/else also logged "job queued" here,
	// before the job was actually queued, duplicating the log line emitted
	// after experiment creation below.
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:create") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "")
	}
	// Create experiment directory and metadata (timed via telemetry).
	if _, err := telemetry.ExecWithMetrics(h.logger, "experiment.create", 50*time.Millisecond, func() (string, error) {
		return "", h.expManager.CreateExperiment(commitID)
	}); err != nil {
		h.logger.Error("failed to create experiment directory", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
	}
	// Write user metadata asynchronously so the response is not delayed;
	// failures are logged but do not fail the request.
	go func() {
		meta := &experiment.Metadata{
			CommitID:  commitID,
			JobName:   jobName,
			User:      user.Name,
			Timestamp: time.Now().Unix(),
		}
		if _, err := telemetry.ExecWithMetrics(
			h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
				return "", h.expManager.WriteMetadata(meta)
			}); err != nil {
			h.logger.Error("failed to save experiment metadata", "error", err)
		}
	}()
	h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)
	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
	// Enqueue task if queue is available
	if h.queue != nil {
		taskID := uuid.New().String()
		task := &queue.Task{
			ID:        taskID,
			JobName:   jobName,
			Args:      "",
			Status:    "queued",
			Priority:  priority,
			CreatedAt: time.Now(),
			UserID:    user.Name,
			Username:  user.Name,
			CreatedBy: user.Name,
			Metadata: map[string]string{
				"commit_id": commitID, // Reduced redundant metadata
			},
		}
		if _, err := telemetry.ExecWithMetrics(h.logger, "queue.add_task", 20*time.Millisecond, func() (string, error) {
			return "", h.queue.AddTask(task)
		}); err != nil {
			h.logger.Error("failed to enqueue task", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
		}
		h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
	} else {
		h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
	}
	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
// handleStatusRequest returns the caller's identity plus a summary and full
// listing of the queue tasks visible to that caller. The response is raw
// JSON rather than a ResponsePacket, for CLI compatibility.
// Protocol: [api_key_hash:64]
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:64]
	if len(payload) < 64 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "")
	}
	apiKeyHash := string(payload[0:64])
	// Only the first 16 characters of the (already hashed) key are logged.
	h.logger.Info("status request received", "api_key_hash", apiKeyHash[:16]+"...")
	// Validate API key and get user information
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKey(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// Auth disabled - use default admin user
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}
	// Check user permissions for viewing jobs
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to view jobs", "")
	}
	// Get tasks with user filtering; tasks stays nil when no queue is set.
	var tasks []*queue.Task
	if h.queue != nil {
		allTasks, err := h.queue.GetAllTasks()
		if err != nil {
			h.logger.Error("failed to get tasks", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
		}
		// Filter tasks based on user permissions
		for _, task := range allTasks {
			// If auth is disabled or user is admin, all tasks are visible
			if h.authConfig == nil || !h.authConfig.Enabled || user.Admin {
				tasks = append(tasks, task)
				continue
			}
			// Users can only see their own tasks
			if task.UserID == user.Name || task.CreatedBy == user.Name {
				tasks = append(tasks, task)
			}
		}
	}
	// Build status response as raw JSON for CLI compatibility
	h.logger.Info("building status response")
	status := map[string]interface{}{
		"user": map[string]interface{}{
			"name":  user.Name,
			"admin": user.Admin,
			"roles": user.Roles,
		},
		"tasks": map[string]interface{}{
			"total":     len(tasks),
			"queued":    countTasksByStatus(tasks, "queued"),
			"running":   countTasksByStatus(tasks, "running"),
			"failed":    countTasksByStatus(tasks, "failed"),
			"completed": countTasksByStatus(tasks, "completed"),
		},
		"queue": tasks,
	}
	h.logger.Info("serializing JSON response")
	jsonData, err := json.Marshal(status)
	if err != nil {
		h.logger.Error("failed to marshal JSON", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	h.logger.Info("sending websocket JSON response", "len", len(jsonData))
	return conn.WriteMessage(websocket.BinaryMessage, jsonData)
}
// countTasksByStatus returns how many of the given tasks carry the
// specified status string.
func countTasksByStatus(tasks []*queue.Task, status string) int {
	var n int
	for _, t := range tasks {
		if t.Status == status {
			n++
		}
	}
	return n
}
// handleCancelJob cancels a queued job by name after authenticating the
// caller and verifying ownership (non-admins may only cancel their own jobs).
// Protocol: [api_key_hash:64][job_name_len:1][job_name:var]
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 65 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "")
	}
	// Parse 64-byte hex API key hash followed by the length-prefixed name.
	apiKeyHash := string(payload[0:64])
	jobNameLen := int(payload[64])
	if len(payload) < 65+jobNameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
	}
	jobName := string(payload[65 : 65+jobNameLen])
	h.logger.Info("cancel job request", "job", jobName)
	// Resolve the caller: validate the API key when auth is configured,
	// otherwise fall back to an implicit all-permissions admin user.
	var user *auth.User
	var err error
	if h.authConfig != nil {
		user, err = h.authConfig.ValidateAPIKey(apiKeyHash)
		if err != nil {
			h.logger.Error("invalid api key", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
		}
	} else {
		// Auth disabled - use default admin user
		user = &auth.User{
			Name:  "default",
			Admin: true,
			Roles: []string{"admin"},
			Permissions: map[string]bool{
				"*": true,
			},
		}
	}
	// Check user permissions for canceling jobs
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") {
		h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to cancel jobs", "")
	}
	// Find the task and verify ownership
	if h.queue != nil {
		task, err := h.queue.GetTaskByName(jobName)
		if err != nil {
			h.logger.Error("task not found", "job", jobName, "error", err)
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
		}
		// BUG FIX: guard h.authConfig against nil before dereferencing —
		// the original read h.authConfig.Enabled directly and panicked
		// when auth was not configured but a queue was present.
		if h.authConfig != nil && h.authConfig.Enabled && !user.Admin && task.UserID != user.Name && task.CreatedBy != user.Name {
			h.logger.Error("unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID)
			return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", "")
		}
		// Cancel the task
		if err := h.queue.CancelTask(task.ID); err != nil {
			h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
			return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
		}
		h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
	} else {
		h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
	}
	packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))
	packetData, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize packet", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
	}
	return conn.WriteMessage(websocket.BinaryMessage, packetData)
}
// handlePrune deletes old experiments either by keeping the newest N or by
// removing those older than a number of days, then reports how many were
// pruned.
// Protocol: [api_key_hash:64][prune_type:1][value:4]
//   prune_type 0: keep the newest `value` experiments
//   prune_type 1: remove experiments older than `value` days
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:64][prune_type:1][value:4]
	if len(payload) < 69 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
	}
	// Parse 64-byte hex API key hash
	apiKeyHash := string(payload[0:64])
	pruneType := payload[64]
	value := binary.BigEndian.Uint32(payload[65:69])
	h.logger.Info("prune request", "type", pruneType, "value", value)
	// Verify API key (hash membership only; no per-user permission check here)
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			h.logger.Error("api key verification failed", "error", err)
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}
	// Convert prune parameters; exactly one of the two is non-zero.
	var keepCount int
	var olderThanDays int
	switch pruneType {
	case 0:
		// keep N newest experiments
		keepCount = int(value)
		olderThanDays = 0
	case 1:
		// remove experiments older than N days
		keepCount = 0
		olderThanDays = int(value)
	default:
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, fmt.Sprintf("invalid prune type: %d", pruneType), "")
	}
	// Perform pruning
	pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays)
	if err != nil {
		h.logger.Error("prune failed", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
	}
	h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned)
	// Send structured success response
	packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned)))
	return h.sendResponsePacket(conn, packet)
}
// handleLogMetric records a single (name, value, step) metric sample for an
// experiment identified by commit ID.
// Protocol: [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
// The value is a big-endian IEEE-754 float64 bit pattern.
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
	if len(payload) < 141 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "")
	}
	apiKeyHash := string(payload[:64])
	commitID := string(payload[64:128])
	step := int(binary.BigEndian.Uint32(payload[128:132]))
	valueBits := binary.BigEndian.Uint64(payload[132:140])
	value := math.Float64frombits(valueBits)
	nameLen := int(payload[140])
	if len(payload) < 141+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
	}
	name := string(payload[141 : 141+nameLen])
	// Verify API key (hash membership only)
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}
	if err := h.expManager.LogMetric(commitID, name, value, step); err != nil {
		h.logger.Error("failed to log metric", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error())
	}
	return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged"))
}
// handleGetExperiment returns the stored metadata and all recorded metrics
// for the experiment identified by commit ID.
// Protocol: [api_key_hash:64][commit_id:64]
func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:64][commit_id:64]
	if len(payload) < 128 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "")
	}
	apiKeyHash := string(payload[:64])
	commitID := string(payload[64:128])
	// Verify API key (hash membership only)
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
		}
	}
	meta, err := h.expManager.ReadMetadata(commitID)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error())
	}
	metrics, err := h.expManager.GetMetrics(commitID)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error())
	}
	// Bundle both into a single success payload.
	response := map[string]interface{}{
		"metadata": meta,
		"metrics":  metrics,
	}
	return h.sendResponsePacket(conn, NewSuccessPacketWithPayload("Experiment details", response))
}
// SetupTLSConfig creates TLS configuration for the WebSocket server.
// It prefers explicit cert/key files; otherwise, when a host is given, it
// uses Let's Encrypt via autocert.
// FIX: the original returned (nil, nil) when neither certificates nor a
// host were provided, handing callers a nil *http.Server with no error;
// that case now returns an explicit error.
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
	switch {
	case certFile != "" && keyFile != "":
		// Use provided certificates
		return &http.Server{
			ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
			TLSConfig: &tls.Config{
				MinVersion: tls.VersionTLS12,
				CipherSuites: []uint16{
					tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
					tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
					tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
				},
			},
		}, nil
	case host != "":
		// Use Let's Encrypt with autocert
		certManager := &autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(host),
			Cache:      autocert.DirCache("/var/www/.cache"),
		}
		return &http.Server{
			ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
			TLSConfig:         certManager.TLSConfig(),
		}, nil
	default:
		return nil, fmt.Errorf("no TLS configuration: provide cert/key files or a host for autocert")
	}
}
// verifyAPIKeyHash verifies the provided hex hash against the stored API
// keys. It returns nil when authentication is disabled or a key matches.
func (h *WSHandler) verifyAPIKeyHash(hexHash string) error {
	if h.authConfig == nil || !h.authConfig.Enabled {
		return nil // No auth required
	}
	// A SHA-256 digest rendered as hex is exactly 64 characters.
	if len(hexHash) != 64 {
		return fmt.Errorf("invalid api key hash length")
	}
	// SECURITY FIX: compare in constant time instead of with == so an
	// attacker cannot use byte-by-byte timing differences to recover a
	// stored hash. The double conversion keeps this correct whether
	// entry.Hash is a byte slice or a string-like type.
	candidate := []byte(hexHash)
	for _, entry := range h.authConfig.APIKeys {
		if subtle.ConstantTimeCompare([]byte(string(entry.Hash)), candidate) == 1 {
			return nil // Valid API key found
		}
	}
	return fmt.Errorf("invalid api key")
}
// sendErrorPacket builds an error response for the given code/message pair
// and writes it to the connection.
func (h *WSHandler) sendErrorPacket(conn *websocket.Conn, errorCode byte, message string, details string) error {
	return h.sendResponsePacket(conn, NewErrorPacket(errorCode, message, details))
}
// sendResponsePacket serializes a structured response packet and writes it
// as a binary WebSocket message. If serialization itself fails, a minimal
// two-byte error frame (0xFF, 0x00) is sent instead so the client is never
// left waiting.
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
	data, err := packet.Serialize()
	if err != nil {
		h.logger.Error("failed to serialize response packet", "error", err)
		// Fallback to simple error response
		return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
	}
	return conn.WriteMessage(websocket.BinaryMessage, data)
}
// sendErrorResponse removed (unused)

208
internal/api/ws_datasets.go Normal file
View file

@ -0,0 +1,208 @@
package api
import (
	"context"
	"database/sql"
	"encoding/binary"
	"encoding/json"
	"errors"
	"net/url"
	"strings"
	"time"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/storage"
)
// handleDatasetList returns all registered datasets as a JSON data packet.
// Protocol: [api_key_hash:16]
// NOTE(review): this protocol uses a 16-byte key hash while the job opcodes
// use a 64-char hex hash — confirm the current verifyAPIKeyHash accepts the
// 16-byte form.
func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16]
	if len(payload) < 16 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset list payload too short", "")
	}
	apiKeyHash := payload[:16]
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	if h.db == nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
	}
	// Bound the DB call so a slow backend cannot hang the connection.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	// The 0 limit is passed straight to storage; presumably "no limit" —
	// confirm against ListDatasets.
	datasets, err := h.db.ListDatasets(ctx, 0)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to list datasets", err.Error())
	}
	data, err := json.Marshal(datasets)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Failed to serialize response",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("datasets", data))
}
// handleDatasetRegister upserts a dataset record (name -> URL) after minimal
// server-side validation.
// Protocol: [api_key_hash:16][name_len:1][name:var][url_len:2][url:var]
func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][name_len:1][name:var][url_len:2][url:var]
	if len(payload) < 16+1+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset register payload too short", "")
	}
	apiKeyHash := payload[:16]
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	if h.db == nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
	}
	// Walk the payload: 1-byte name length, then the name bytes.
	offset := 16
	nameLen := int(payload[offset])
	offset++
	if nameLen <= 0 || len(payload) < offset+nameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
	}
	name := string(payload[offset : offset+nameLen])
	offset += nameLen
	// Then a big-endian 2-byte URL length and the URL bytes.
	urlLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	if urlLen <= 0 || len(payload) < offset+urlLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url length", "")
	}
	urlStr := string(payload[offset : offset+urlLen])
	// Minimal validation (server-side authoritative): name non-empty and url parseable.
	if strings.TrimSpace(name) == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "")
	}
	if u, err := url.Parse(urlStr); err != nil || u.Scheme == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "")
	}
	// Bound the DB call so a slow backend cannot hang the connection.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	if err := h.db.UpsertDataset(ctx, &storage.Dataset{Name: name, URL: urlStr}); err != nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to register dataset", err.Error())
	}
	return h.sendResponsePacket(conn, NewSuccessPacket("Dataset registered"))
}
// handleDatasetInfo looks up a single dataset by name and returns it as a
// JSON data packet, or a not-found error.
// Protocol: [api_key_hash:16][name_len:1][name:var]
func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 16+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "")
	}
	apiKeyHash := payload[:16]
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	if h.db == nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
	}
	// Walk the payload: 1-byte name length, then the name bytes.
	offset := 16
	nameLen := int(payload[offset])
	offset++
	if nameLen <= 0 || len(payload) < offset+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
	}
	name := string(payload[offset : offset+nameLen])
	// Bound the DB call so a slow backend cannot hang the connection.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	ds, err := h.db.GetDataset(ctx, name)
	if err != nil {
		// FIX: use errors.Is rather than ==, so a wrapped sql.ErrNoRows
		// from the storage layer is still recognized as "not found".
		if errors.Is(err, sql.ErrNoRows) {
			return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Dataset not found", "")
		}
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to get dataset", err.Error())
	}
	data, err := json.Marshal(ds)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Failed to serialize response",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("dataset", data))
}
// handleDatasetSearch returns datasets matching a search term as a JSON
// data packet. An empty term is forwarded to the storage layer as-is.
// Protocol: [api_key_hash:16][term_len:1][term:var]
func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 16+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "")
	}
	apiKeyHash := payload[:16]
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	if h.db == nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
	}
	offset := 16
	termLen := int(payload[offset])
	offset++
	// termLen comes from a single byte, so it is always in [0, 255]; the
	// previous "termLen < 0" guard was unreachable and has been removed.
	if len(payload) < offset+termLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid search term length", "")
	}
	term := strings.TrimSpace(string(payload[offset : offset+termLen]))
	// Bound the DB call so a slow backend cannot hang the connection.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	// The 0 limit is passed straight to storage; presumably "no limit" —
	// confirm against SearchDatasets.
	datasets, err := h.db.SearchDatasets(ctx, term, 0)
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to search datasets", err.Error())
	}
	data, err := json.Marshal(datasets)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServerOverloaded,
			"Failed to serialize response",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("datasets", data))
}

279
internal/api/ws_handler.go Normal file
View file

@ -0,0 +1,279 @@
package api
import (
"compress/flate"
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/audit"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/jupyter"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
)
// Opcodes for binary WebSocket protocol. Each client request is a single
// binary message whose first byte is one of these opcodes.
const (
	// Job lifecycle
	OpcodeQueueJob      = 0x01
	OpcodeStatusRequest = 0x02
	OpcodeCancelJob     = 0x03
	OpcodePrune         = 0x04
	// Dataset registry
	OpcodeDatasetList     = 0x06
	OpcodeDatasetRegister = 0x07
	OpcodeDatasetInfo     = 0x08
	OpcodeDatasetSearch   = 0x09
	// Experiment tracking
	OpcodeLogMetric            = 0x0A
	OpcodeGetExperiment        = 0x0B
	OpcodeQueueJobWithTracking = 0x0C
	OpcodeQueueJobWithSnapshot = 0x17
	// Jupyter service management
	OpcodeStartJupyter   = 0x0D
	OpcodeStopJupyter    = 0x0E
	OpcodeRemoveJupyter  = 0x18
	OpcodeRestoreJupyter = 0x19
	OpcodeListJupyter    = 0x0F
	// Misc
	OpcodeValidateRequest = 0x16
)
// createUpgrader creates a WebSocket upgrader with the given security
// configuration. In production mode, origins must exactly match the
// configured allow-list; in development mode, localhost, loopback and
// RFC 1918 private addresses are accepted.
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
	return websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			origin := r.Header.Get("Origin")
			if origin == "" {
				return true // Allow same-origin requests (no Origin header)
			}
			// Production mode: strict checking against allowed origins
			if securityCfg != nil && securityCfg.ProductionMode {
				for _, allowed := range securityCfg.AllowedOrigins {
					if origin == allowed {
						return true
					}
				}
				return false // Reject if not in allowed list
			}
			// Development mode. SECURITY FIX: the previous string-prefix
			// checks ("10.", "172.") also matched public hosts such as
			// "172.217.0.1" or hostnames beginning with "10."; parse the
			// host as an IP and check real private/loopback ranges.
			parsedOrigin, err := url.Parse(origin)
			if err != nil {
				return false
			}
			host := parsedOrigin.Hostname() // host without port
			if host == "localhost" {
				return true
			}
			if ip := net.ParseIP(host); ip != nil {
				return ip.IsLoopback() || ip.IsPrivate()
			}
			// Retain the legacy allowance for any origin served on :8080.
			return strings.HasSuffix(parsedOrigin.Host, ":8080")
		},
		// Performance optimizations
		HandshakeTimeout:  10 * time.Second,
		ReadBufferSize:    16 * 1024,
		WriteBufferSize:   16 * 1024,
		EnableCompression: true,
	}
}
// WSHandler handles WebSocket connections for the API.
// It authenticates clients before the protocol upgrade, then reads binary
// frames and dispatches them to opcode-specific handlers.
type WSHandler struct {
	authConfig        *auth.Config            // API-key auth settings; nil or disabled means no auth enforced
	logger            *logging.Logger         // component-scoped logger ("ws-handler")
	expManager        *experiment.Manager     // experiment tracking/validation backend
	dataDir           string                  // data directory root (not referenced in this chunk — TODO confirm usage)
	queue             queue.Backend           // task queue used to hand work to workers
	db                *storage.DB             // dataset/metadata store; may be nil ("Database not configured")
	jupyterServiceMgr *jupyter.ServiceManager // Jupyter service lifecycle manager
	securityConfig    *config.SecurityConfig  // origin allow-list and production-mode flag
	auditLogger       *audit.Logger           // optional audit trail; nil-checked before use
	upgrader          websocket.Upgrader      // preconfigured HTTP→WebSocket upgrader
}
// NewWSHandler wires the API dependencies into a ready-to-serve WebSocket
// handler. The logger is scoped to the "ws-handler" component and the
// upgrader is derived from the supplied security configuration.
func NewWSHandler(
	authConfig *auth.Config,
	logger *logging.Logger,
	expManager *experiment.Manager,
	dataDir string,
	taskQueue queue.Backend,
	db *storage.DB,
	jupyterServiceMgr *jupyter.ServiceManager,
	securityConfig *config.SecurityConfig,
	auditLogger *audit.Logger,
) *WSHandler {
	componentLogger := logger.Component(logging.EnsureTrace(context.Background()), "ws-handler")
	handler := &WSHandler{
		authConfig:        authConfig,
		logger:            componentLogger,
		expManager:        expManager,
		dataDir:           dataDir,
		queue:             taskQueue,
		db:                db,
		jupyterServiceMgr: jupyterServiceMgr,
		securityConfig:    securityConfig,
		auditLogger:       auditLogger,
		upgrader:          createUpgrader(securityConfig),
	}
	return handler
}
// enableLowLatencyTCP turns off Nagle's algorithm on the underlying TCP
// connection so small frames are flushed immediately. Failures are logged
// and otherwise ignored; non-TCP transports are left untouched.
func enableLowLatencyTCP(conn *websocket.Conn, logger *logging.Logger) {
	if conn == nil {
		return
	}
	tcpConn, ok := conn.UnderlyingConn().(*net.TCPConn)
	if !ok {
		return
	}
	if err := tcpConn.SetNoDelay(true); err != nil {
		logger.Warn("failed to enable tcp no delay", "error", err)
	}
}
// ServeHTTP upgrades an authenticated HTTP request to a WebSocket connection
// and runs the binary message loop until the client disconnects.
//
// Flow: set security headers → validate the API key (when auth is enabled) →
// upgrade → enable compression and TCP_NODELAY → read binary frames and hand
// each one to handleMessage, reporting handler errors as structured packets.
// NOTE(review): no read deadline or conn.SetReadLimit is applied, so a client
// can hold the connection open and send arbitrarily large frames — confirm
// upstream limits exist or consider adding them.
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Defense-in-depth browser headers (relevant only if this endpoint is
	// ever loaded in a browser context).
	w.Header().Set("X-Frame-Options", "DENY")
	w.Header().Set("X-Content-Type-Options", "nosniff")
	w.Header().Set("X-XSS-Protection", "1; mode=block")
	if r.TLS != nil {
		// Only set HSTS if using HTTPS.
		w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
	}
	// Authenticate before upgrading: once upgraded we can no longer return a
	// plain HTTP 401.
	apiKey := auth.ExtractAPIKeyFromRequest(r)
	clientIP := r.RemoteAddr
	// Validate API key if authentication is enabled.
	if h.authConfig != nil && h.authConfig.Enabled {
		// Log only a short key prefix (max 8 chars), never the full secret.
		prefixLen := len(apiKey)
		if prefixLen > 8 {
			prefixLen = 8
		}
		h.logger.Info(
			"websocket auth attempt",
			"api_key_length",
			len(apiKey),
			"api_key_prefix",
			apiKey[:prefixLen],
		)
		userID, err := h.authConfig.ValidateAPIKey(apiKey)
		if err != nil {
			h.logger.Warn("websocket authentication failed", "error", err)
			// Audit log failed authentication (only the key prefix is recorded).
			if h.auditLogger != nil {
				h.auditLogger.LogAuthAttempt(apiKey[:prefixLen], clientIP, false, err.Error())
			}
			http.Error(w, "Invalid API key", http.StatusUnauthorized)
			return
		}
		h.logger.Info("websocket authentication succeeded")
		// Audit log successful authentication under the resolved user name.
		if h.auditLogger != nil && userID != nil {
			h.auditLogger.LogAuthAttempt(userID.Name, clientIP, true, "")
		}
	}
	conn, err := h.upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Error("websocket upgrade failed", "error", err)
		return
	}
	// Favor latency over ratio: fastest compression level plus TCP_NODELAY.
	conn.EnableWriteCompression(true)
	if err := conn.SetCompressionLevel(flate.BestSpeed); err != nil {
		h.logger.Warn("failed to set websocket compression level", "error", err)
	}
	enableLowLatencyTCP(conn, h.logger)
	defer func() {
		_ = conn.Close()
	}()
	h.logger.Info("websocket connection established", "remote", r.RemoteAddr)
	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			// Normal closes (going away / abnormal closure) end the loop quietly.
			if websocket.IsUnexpectedCloseError(
				err,
				websocket.CloseGoingAway,
				websocket.CloseAbnormalClosure,
			) {
				h.logger.Error("websocket read error", "error", err)
			}
			break
		}
		// The protocol is binary-only; other frame types are ignored.
		if messageType != websocket.BinaryMessage {
			h.logger.Warn("received non-binary message")
			continue
		}
		if err := h.handleMessage(conn, message); err != nil {
			h.logger.Error("message handling error", "error", err)
			// Send structured error response so CLI clients can parse it.
			// (Raw fallback bytes cause client-side InvalidPacket errors.)
			_ = h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "message handling error", err.Error())
		}
	}
}
// handleMessage decodes the one-byte opcode at the front of a binary frame
// and dispatches the remaining bytes to the matching opcode handler. Unknown
// opcodes and empty frames are reported as errors.
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
	if len(message) == 0 {
		return fmt.Errorf("message too short")
	}
	op, body := message[0], message[1:]
	switch op {
	// Job lifecycle.
	case OpcodeQueueJob:
		return h.handleQueueJob(conn, body)
	case OpcodeQueueJobWithTracking:
		return h.handleQueueJobWithTracking(conn, body)
	case OpcodeQueueJobWithSnapshot:
		return h.handleQueueJobWithSnapshot(conn, body)
	case OpcodeStatusRequest:
		return h.handleStatusRequest(conn, body)
	case OpcodeCancelJob:
		return h.handleCancelJob(conn, body)
	case OpcodePrune:
		return h.handlePrune(conn, body)
	// Dataset management.
	case OpcodeDatasetList:
		return h.handleDatasetList(conn, body)
	case OpcodeDatasetRegister:
		return h.handleDatasetRegister(conn, body)
	case OpcodeDatasetInfo:
		return h.handleDatasetInfo(conn, body)
	case OpcodeDatasetSearch:
		return h.handleDatasetSearch(conn, body)
	// Experiment tracking.
	case OpcodeLogMetric:
		return h.handleLogMetric(conn, body)
	case OpcodeGetExperiment:
		return h.handleGetExperiment(conn, body)
	// Jupyter service management.
	case OpcodeStartJupyter:
		return h.handleStartJupyter(conn, body)
	case OpcodeStopJupyter:
		return h.handleStopJupyter(conn, body)
	case OpcodeRemoveJupyter:
		return h.handleRemoveJupyter(conn, body)
	case OpcodeRestoreJupyter:
		return h.handleRestoreJupyter(conn, body)
	case OpcodeListJupyter:
		return h.handleListJupyter(conn, body)
	// Validation.
	case OpcodeValidateRequest:
		return h.handleValidateRequest(conn, body)
	default:
		return fmt.Errorf("unknown opcode: 0x%02x", op)
	}
}

1268
internal/api/ws_jobs.go Normal file

File diff suppressed because it is too large Load diff

478
internal/api/ws_jupyter.go Normal file
View file

@ -0,0 +1,478 @@
package api
import (
"encoding/binary"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// jupyterTaskOutput is the JSON envelope parsed out of a completed Jupyter
// task's Output field. Which field is populated depends on the action:
// Service for start, Services for list, RestorePath for restore.
type jupyterTaskOutput struct {
	Type        string          `json:"type"`                   // action type reported by the worker
	Service     json.RawMessage `json:"service,omitempty"`      // single service (start action)
	Services    json.RawMessage `json:"services,omitempty"`     // service array (list action)
	RestorePath string          `json:"restore_path,omitempty"` // destination path (restore action)
}
// handleRestoreJupyter restores a trashed Jupyter workspace by enqueueing a
// restore task for a worker and waiting (up to 2 minutes) for the result.
// On success it replies with a success packet that includes the restore path
// when the worker reported one.
func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][name_len:1][name:var]
	// Minimum 18 bytes = 16-byte hash + 1-byte length + at least 1 name byte.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "restore jupyter payload too short", "")
	}
	apiKeyHash := payload[:16]
	// Authenticate, then resolve the user and check permissions.
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	if user != nil && !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}
	// Parse the length-prefixed workspace name.
	offset := 16
	nameLen := int(payload[offset])
	offset++
	if len(payload) < offset+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
	}
	name := string(payload[offset : offset+nameLen])
	// Hand the restore to a worker via the task queue and block on the result.
	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionRestore,
		jupyterNameKey:       strings.TrimSpace(name),
	}
	jobName := fmt.Sprintf("jupyter-restore-%s", strings.TrimSpace(name))
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter restore", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter restore", "")
	}
	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter restore", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error))
	}
	// Enrich the success message with the restore path if the worker supplied
	// a JSON output envelope; malformed output falls back to the plain message.
	msg := fmt.Sprintf("Restored Jupyter workspace '%s'", strings.TrimSpace(name))
	out := strings.TrimSpace(result.Output)
	if out != "" {
		var payloadOut jupyterTaskOutput
		if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
			if strings.TrimSpace(payloadOut.RestorePath) != "" {
				msg = fmt.Sprintf("Restored Jupyter workspace '%s' to %s", strings.TrimSpace(name), strings.TrimSpace(payloadOut.RestorePath))
			}
		}
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(msg))
}
// jupyterServiceView is the minimal projection of a worker-reported service
// used to build user-facing messages (currently just name and URL).
type jupyterServiceView struct {
	Name string `json:"name"`
	URL  string `json:"url"`
}
// Metadata keys/values attached to queued Jupyter tasks (Task.Metadata).
// These describe which Jupyter action a task represents and its target;
// presumably consumed by the worker side — confirm against the worker code.
const (
	jupyterTaskTypeKey   = "task_type" // marks a queued task as a Jupyter task
	jupyterTaskTypeValue = "jupyter"
	jupyterTaskActionKey = "jupyter_action" // which action below to perform
	jupyterActionStart   = "start"
	jupyterActionStop    = "stop"
	jupyterActionRemove  = "remove"
	jupyterActionRestore = "restore"
	jupyterActionList    = "list"
	jupyterNameKey       = "jupyter_name"       // service/workspace name (start, restore)
	jupyterWorkspaceKey  = "jupyter_workspace"  // workspace path (start)
	jupyterServiceIDKey  = "jupyter_service_id" // target service (stop, remove)
)
// enqueueJupyterTask validates its inputs, tags the metadata as a Jupyter
// task and pushes a high-priority task onto the queue. It returns the
// generated task ID, which callers pass to waitForTask.
func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) {
	if h.queue == nil {
		return "", fmt.Errorf("task queue not configured")
	}
	if err := container.ValidateJobName(jobName); err != nil {
		return "", err
	}
	if strings.TrimSpace(userName) == "" {
		return "", fmt.Errorf("missing user")
	}
	if meta == nil {
		meta = map[string]string{}
	}
	meta[jupyterTaskTypeKey] = jupyterTaskTypeValue
	id := uuid.New().String()
	if err := h.queue.AddTask(&queue.Task{
		ID:        id,
		JobName:   jobName,
		Args:      "",
		Status:    "queued",
		Priority:  100, // interactive request: jump ahead of batch work
		CreatedAt: time.Now(),
		UserID:    userName,
		Username:  userName,
		CreatedBy: userName,
		Metadata:  meta,
	}); err != nil {
		return "", err
	}
	return id, nil
}
// waitForTask polls the queue for taskID until the task reaches a terminal
// state (completed, failed or cancelled) or the timeout elapses. It polls
// every 200ms; a missing task is treated as "not terminal yet" and retried.
func (h *WSHandler) waitForTask(taskID string, timeout time.Duration) (*queue.Task, error) {
	if h.queue == nil {
		return nil, fmt.Errorf("task queue not configured")
	}
	const pollInterval = 200 * time.Millisecond
	deadline := time.Now().Add(timeout)
	for !time.Now().After(deadline) {
		task, err := h.queue.GetTask(taskID)
		if err != nil {
			return nil, err
		}
		if task != nil {
			switch task.Status {
			case "completed", "failed", "cancelled":
				return task, nil
			}
		}
		time.Sleep(pollInterval)
	}
	return nil, fmt.Errorf("timed out waiting for worker")
}
// handleStartJupyter starts a Jupyter service by enqueueing a start task for
// a worker and waiting (up to 2 minutes) for the result. On success it sends
// a success packet, including the service URL when the worker reported one.
func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) error {
	// Protocol:
	// [api_key_hash:16][name_len:1][name:var][workspace_len:2 BE][workspace:var][password_len:1][password:var]
	if len(payload) < 21 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "start jupyter payload too short", "")
	}
	apiKeyHash := payload[:16]
	// Verify API key
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	if user != nil && !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}
	// Parse the three length-prefixed fields, bounds-checking each step.
	offset := 16
	nameLen := int(payload[offset])
	offset++
	if len(payload) < offset+nameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
	}
	name := string(payload[offset : offset+nameLen])
	offset += nameLen
	workspaceLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2
	if len(payload) < offset+workspaceLen+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace length", "")
	}
	workspace := string(payload[offset : offset+workspaceLen])
	offset += workspaceLen
	passwordLen := int(payload[offset])
	offset++
	if len(payload) < offset+passwordLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid password length", "")
	}
	// Password is bounds-checked but its bytes are never read here.
	// NOTE(review): the password field is silently dropped — confirm whether
	// the worker is expected to receive it via another channel.
	// Hand the start action to a worker via the task queue.
	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionStart,
		jupyterNameKey:       strings.TrimSpace(name),
		jupyterWorkspaceKey:  strings.TrimSpace(workspace),
	}
	jobName := fmt.Sprintf("jupyter-%s", strings.TrimSpace(name))
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter task", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter task", "")
	}
	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter start", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		h.logger.Error("jupyter task failed", "error", result.Error)
		details := strings.TrimSpace(result.Error)
		// Map duplicate-service failures onto a distinct error code so the
		// client can distinguish them from generic failures.
		lower := strings.ToLower(details)
		if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") {
			return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details)
		}
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to start Jupyter service", details)
	}
	// Enrich the success message with the service URL if the worker's JSON
	// output parses; any parse failure falls back to the plain message.
	msg := fmt.Sprintf("Started Jupyter service '%s'", strings.TrimSpace(name))
	out := strings.TrimSpace(result.Output)
	if out != "" {
		var payloadOut jupyterTaskOutput
		if err := json.Unmarshal([]byte(out), &payloadOut); err == nil && len(payloadOut.Service) > 0 {
			var svc jupyterServiceView
			if err := json.Unmarshal(payloadOut.Service, &svc); err == nil {
				if strings.TrimSpace(svc.URL) != "" {
					msg = fmt.Sprintf("Started Jupyter service '%s' at %s", strings.TrimSpace(name), strings.TrimSpace(svc.URL))
				}
			}
		}
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(msg))
}
// handleStopJupyter stops a Jupyter service by enqueueing a stop task for a
// worker and waiting (up to 2 minutes) for the result.
//
// Wire format: [api_key_hash:16][service_id_len:1][service_id:var]
func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) error {
	// 16-byte hash + 1-byte length + at least 1 id byte.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "stop jupyter payload too short", "")
	}
	hash := payload[:16]
	// Authenticate, resolve the user and check permissions.
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(hash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(hash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	if user != nil && !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}
	// Parse the length-prefixed service ID.
	idLen := int(payload[16])
	body := payload[17:]
	if len(body) < idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
	}
	serviceID := string(body[:idLen])
	trimmedID := strings.TrimSpace(serviceID)
	// Hand the stop action to a worker and block on the outcome.
	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionStop,
		jupyterServiceIDKey:  trimmedID,
	}
	taskID, err := h.enqueueJupyterTask(user.Name, "jupyter-stop-"+trimmedID, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter stop", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter stop", "")
	}
	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter stop", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to stop Jupyter service", strings.TrimSpace(result.Error))
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID)))
}
// handleRemoveJupyter removes a Jupyter service (trash-first by default) by
// enqueueing a remove task for a worker and waiting (up to 2 minutes) for
// the result.
//
// Wire format: [api_key_hash:16][service_id_len:1][service_id:var][purge:1 optional]
// The trailing purge byte (0x01) requests permanent deletion; when absent the
// default is false, preserving the trash-first behavior.
//
// Fix: this handler previously parsed the untrusted payload before verifying
// the API key — the only Jupyter handler to do so. Authentication and the
// permission check now run first, matching handleStopJupyter and siblings.
func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "remove jupyter payload too short", "")
	}
	apiKeyHash := payload[:16]
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	if user != nil && !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}
	// Parse the length-prefixed service ID plus optional purge flag.
	offset := 16
	idLen := int(payload[offset])
	offset++
	if len(payload) < offset+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
	}
	serviceID := string(payload[offset : offset+idLen])
	offset += idLen
	purge := len(payload) > offset && payload[offset] == 0x01
	// Hand the remove action to a worker and block on the outcome.
	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionRemove,
		jupyterServiceIDKey:  strings.TrimSpace(serviceID),
		"jupyter_purge":      fmt.Sprintf("%t", purge),
	}
	jobName := fmt.Sprintf("jupyter-remove-%s", strings.TrimSpace(serviceID))
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter remove", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter remove", "")
	}
	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter remove", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to remove Jupyter service", strings.TrimSpace(result.Error))
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID)))
}
// handleListJupyter lists Jupyter services by enqueueing a list task for a
// worker and relaying the resulting JSON array to the client.
//
// Wire format: [api_key_hash:16]
// The reply is always a "jupyter_services" data packet carrying a JSON array
// (possibly empty) so clients can render a stable table.
//
// Fix: the local variable holding the worker's services previously shadowed
// the `payload` parameter; it is renamed to `services` for clarity.
func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 16 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter payload too short", "")
	}
	apiKeyHash := payload[:16]
	// Authenticate, resolve the user and check read permission.
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	if user != nil && !user.HasPermission("jupyter:read") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}
	// Hand the list action to a worker and block on the outcome.
	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionList,
	}
	jobName := fmt.Sprintf("jupyter-list-%s", user.Name)
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter list", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter list", "")
	}
	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter list", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to list Jupyter services", strings.TrimSpace(result.Error))
	}
	out := strings.TrimSpace(result.Output)
	if out == "" {
		return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", []byte("[]")))
	}
	var taskOut jupyterTaskOutput
	if err := json.Unmarshal([]byte(out), &taskOut); err != nil {
		// Unexpected worker output: fall back to an empty array.
		return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", []byte("[]")))
	}
	// Always return an array payload (even if empty) for a stable client table.
	services := taskOut.Services
	if len(services) == 0 {
		services = []byte("[]")
	}
	return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", services))
}

100
internal/api/ws_tls_auth.go Normal file
View file

@ -0,0 +1,100 @@
package api
import (
"crypto/tls"
"fmt"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/auth"
"golang.org/x/crypto/acme/autocert"
)
// SetupTLSConfig creates an *http.Server preconfigured for TLS.
//
// Precedence: explicit certFile/keyFile win; otherwise, when host is set,
// certificates come from Let's Encrypt via autocert (cached under
// /var/www/.cache). When neither is configured it returns (nil, nil) —
// callers should treat a nil server as "TLS not configured".
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
	switch {
	case certFile != "" && keyFile != "":
		// Operator-supplied static certificates with a hardened TLS config.
		return &http.Server{
			ReadHeaderTimeout: 10 * time.Second, // mitigate Slowloris attacks
			TLSConfig: &tls.Config{
				MinVersion: tls.VersionTLS12,
				CipherSuites: []uint16{
					tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
					tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
					tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
				},
			},
		}, nil
	case host != "":
		// Automatic certificates from Let's Encrypt, limited to the given host.
		certManager := &autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(host),
			Cache:      autocert.DirCache("/var/www/.cache"),
		}
		return &http.Server{
			ReadHeaderTimeout: 10 * time.Second, // mitigate Slowloris attacks
			TLSConfig:         certManager.TLSConfig(),
		}, nil
	default:
		// No TLS material configured.
		return nil, nil
	}
}
// verifyAPIKeyHash checks the provided binary API-key hash against the stored
// keys. It is a no-op (returns nil) when authentication is unconfigured or
// disabled, and deliberately returns a generic error on mismatch.
func (h *WSHandler) verifyAPIKeyHash(hash []byte) error {
	if h.authConfig == nil || !h.authConfig.Enabled {
		return nil
	}
	if _, err := h.authConfig.ValidateAPIKeyHash(hash); err != nil {
		return fmt.Errorf("invalid api key")
	}
	return nil
}
// sendErrorPacket builds a structured error response packet with the given
// code, message and details, and writes it to the client.
func (h *WSHandler) sendErrorPacket(
	conn *websocket.Conn,
	errorCode byte,
	message string,
	details string,
) error {
	return h.sendResponsePacket(conn, NewErrorPacket(errorCode, message, details))
}
// sendResponsePacket serializes packet and writes it as a single binary
// frame. If serialization itself fails, a minimal two-byte marker is written
// as a last resort so the client observes a failure rather than silence.
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
	data, err := packet.Serialize()
	if err == nil {
		return conn.WriteMessage(websocket.BinaryMessage, data)
	}
	h.logger.Error("failed to serialize response packet", "error", err)
	return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
}
// validateWSUser resolves the user that owns the given 16-byte API-key hash.
// When authentication is disabled or unconfigured it returns a permissive
// default admin user, mirroring the no-op behavior of verifyAPIKeyHash.
//
// Fix: the previous version validated the hash whenever authConfig was
// non-nil — even with authentication disabled — which made every
// hash-authenticated handler fail when auth was configured but switched off.
// The Enabled flag is now checked, consistent with verifyAPIKeyHash.
func (h *WSHandler) validateWSUser(apiKeyHash []byte) (*auth.User, error) {
	if h.authConfig != nil && h.authConfig.Enabled {
		user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash)
		if err != nil {
			return nil, err
		}
		return user, nil
	}
	return &auth.User{
		Name:  "default",
		Admin: true,
		Roles: []string{"admin"},
		Permissions: map[string]bool{
			"*": true,
		},
	}, nil
}

642
internal/api/ws_validate.go Normal file
View file

@ -0,0 +1,642 @@
package api
import (
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/config"
"github.com/jfraeys/fetch_ml/internal/container"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// validateCheck is the result of one named integrity check inside a
// validateReport. Expected/Actual are populated only for comparison checks;
// Details carries a human-readable failure reason.
type validateCheck struct {
	OK       bool   `json:"ok"`
	Expected string `json:"expected,omitempty"` // expected value (comparison checks)
	Actual   string `json:"actual,omitempty"`   // observed value
	Details  string `json:"details,omitempty"`  // failure reason
}
// validateReport aggregates the outcome of validating a commit or task:
// named checks keyed by check name, fatal errors (OK=false), and
// non-fatal warnings. TS is the report's UTC timestamp (RFC 3339 nano).
type validateReport struct {
	OK       bool                     `json:"ok"`
	CommitID string                   `json:"commit_id,omitempty"` // 40-hex commit under validation
	TaskID   string                   `json:"task_id,omitempty"`   // task, when validating by task
	Checks   map[string]validateCheck `json:"checks"`
	Errors   []string                 `json:"errors,omitempty"`   // fatal problems
	Warnings []string                 `json:"warnings,omitempty"` // best-effort notes
	TS       string                   `json:"ts"`                 // report timestamp (UTC)
}
// shouldRequireRunManifest reports whether a task's status obliges a run
// manifest to exist on disk — true once the job has started executing
// (running, completed or failed); queued/unknown states only warrant a warning.
func shouldRequireRunManifest(task *queue.Task) bool {
	if task == nil {
		return false
	}
	switch strings.ToLower(strings.TrimSpace(task.Status)) {
	case "running", "completed", "failed":
		return true
	}
	return false
}
// expectedRunManifestBucketForStatus maps a task status (case- and
// whitespace-insensitive) to the on-disk job bucket where its run manifest
// should live. The second result is false for unrecognized statuses.
func expectedRunManifestBucketForStatus(status string) (string, bool) {
	statusToBucket := map[string]string{
		"queued":    "pending",
		"pending":   "pending",
		"running":   "running",
		"completed": "finished",
		"finished":  "finished",
		"failed":    "failed",
	}
	bucket, ok := statusToBucket[strings.ToLower(strings.TrimSpace(status))]
	return bucket, ok
}
// findRunManifestDir searches the typed job buckets under basePath for a
// directory named jobName that contains a run manifest. It returns the
// directory, the bucket name and whether a manifest was found.
// Search order is significant: running, then pending, finished, failed.
func findRunManifestDir(basePath string, jobName string) (string, string, bool) {
	if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" {
		return "", "", false
	}
	jp := config.NewJobPaths(basePath)
	buckets := []struct {
		name string
		root string
	}{
		{"running", jp.RunningPath()},
		{"pending", jp.PendingPath()},
		{"finished", jp.FinishedPath()},
		{"failed", jp.FailedPath()},
	}
	for _, b := range buckets {
		dir := filepath.Join(b.root, jobName)
		info, err := os.Stat(dir)
		if err != nil || !info.IsDir() {
			continue
		}
		if _, err := os.Stat(manifest.ManifestPath(dir)); err != nil {
			continue
		}
		return dir, b.name, true
	}
	return "", "", false
}
// validateResourcesForTask sanity-checks a task's resource requests (CPU,
// memory, GPU count and GPU-memory share). It returns the check result plus
// error strings suitable for inclusion in a validation report; GPU memory may
// be a percentage ("50%", in (0,100]) or a fraction (in (0,1]) and requires
// at least one GPU.
func validateResourcesForTask(task *queue.Task) (validateCheck, []string) {
	fail := func(details, errMsg string) (validateCheck, []string) {
		return validateCheck{OK: false, Details: details}, []string{errMsg}
	}
	if task == nil {
		return fail("task is nil", "missing task")
	}
	if task.CPU < 0 {
		return fail("cpu must be >= 0", "invalid cpu request")
	}
	if task.MemoryGB < 0 {
		return fail("memory_gb must be >= 0", "invalid memory request")
	}
	if task.GPU < 0 {
		return fail("gpu must be >= 0", "invalid gpu request")
	}
	gpuMem := strings.TrimSpace(task.GPUMemory)
	if gpuMem != "" {
		if pct, isPercent := strings.CutSuffix(gpuMem, "%"); isPercent {
			f, err := strconv.ParseFloat(strings.TrimSpace(pct), 64)
			if err != nil || f <= 0 || f > 100 {
				return fail("gpu_memory must be a percentage in (0,100]", "invalid gpu_memory")
			}
		} else {
			f, err := strconv.ParseFloat(gpuMem, 64)
			if err != nil || f <= 0 || f > 1 {
				return fail("gpu_memory must be a fraction in (0,1]", "invalid gpu_memory")
			}
		}
		if task.GPU == 0 {
			return fail("gpu_memory requires gpu > 0", "invalid gpu_memory")
		}
	}
	return validateCheck{OK: true}, nil
}
func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
// Protocol: [api_key_hash:16][target_type:1][id_len:1][id:var]
// target_type: 0=commit_id (20 bytes), 1=task_id (string)
// TODO(context): Add a versioned validate protocol once we need more target types/fields.
if len(payload) < 18 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "validate request payload too short", "")
}
apiKeyHash := payload[:16]
targetType := payload[16]
idLen := int(payload[17])
if idLen < 1 || len(payload) < 18+idLen {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate id length", "")
}
idBytes := payload[18 : 18+idLen]
// Validate API key and user
user, err := h.validateWSUser(apiKeyHash)
if err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeAuthenticationFailed,
"Invalid API key",
err.Error(),
)
}
if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") {
return h.sendErrorPacket(
conn,
ErrorCodePermissionDenied,
"Insufficient permissions to validate jobs",
"",
)
}
if h.expManager == nil {
return h.sendErrorPacket(
conn,
ErrorCodeServiceUnavailable,
"Experiment manager not available",
"",
)
}
var task *queue.Task
commitID := ""
switch targetType {
case 0:
if len(idBytes) != 20 {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id must be 20 bytes", "")
}
commitID = fmt.Sprintf("%x", idBytes)
case 1:
taskID := string(idBytes)
if h.queue == nil {
return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Task queue not available", "")
}
t, err := h.queue.GetTask(taskID)
if err != nil {
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", err.Error())
}
task = t
if h.authConfig != nil &&
h.authConfig.Enabled &&
!user.Admin &&
task.UserID != user.Name &&
task.CreatedBy != user.Name {
return h.sendErrorPacket(
conn,
ErrorCodePermissionDenied,
"You can only validate your own jobs",
"",
)
}
if task.Metadata == nil || task.Metadata["commit_id"] == "" {
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "Task missing commit_id", "")
}
commitID = task.Metadata["commit_id"]
default:
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate target_type", "")
}
r := validateReport{
OK: true,
TS: time.Now().UTC().Format(time.RFC3339Nano),
Checks: map[string]validateCheck{},
}
if task != nil {
r.TaskID = task.ID
}
if commitID != "" {
r.CommitID = commitID
}
// Validate commit id format
if len(commitID) != 40 {
r.OK = false
r.Errors = append(r.Errors, "invalid commit_id length")
} else if _, err := hex.DecodeString(commitID); err != nil {
r.OK = false
r.Errors = append(r.Errors, "invalid commit_id hex")
}
// Experiment manifest integrity
// TODO(context): Extend report to include per-file diff list on mismatch (bounded output).
if r.OK {
if err := h.expManager.ValidateManifest(commitID); err != nil {
r.OK = false
r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: err.Error()}
r.Errors = append(r.Errors, "experiment manifest validation failed")
} else {
r.Checks["experiment_manifest"] = validateCheck{OK: true}
}
}
// Deps manifest presence + hash
// TODO(context): Allow client to declare which dependency manifest is authoritative.
filesPath := h.expManager.GetFilesPath(commitID)
depName, depErr := worker.SelectDependencyManifest(filesPath)
if depErr != nil {
r.OK = false
r.Checks["deps_manifest"] = validateCheck{
OK: false,
Details: depErr.Error(),
}
r.Errors = append(r.Errors, "deps manifest missing")
} else {
sha, err := fileSHA256Hex(filepath.Join(filesPath, depName))
if err != nil {
r.OK = false
r.Checks["deps_manifest"] = validateCheck{
OK: false,
Details: err.Error(),
}
r.Errors = append(r.Errors, "deps manifest hash failed")
} else {
r.Checks["deps_manifest"] = validateCheck{
OK: true,
Actual: depName + ":" + sha,
}
}
}
// Compare against expected task metadata if available.
if task != nil {
resCheck, resErrs := validateResourcesForTask(task)
r.Checks["resources"] = resCheck
if !resCheck.OK {
r.OK = false
r.Errors = append(r.Errors, resErrs...)
}
// Run manifest checks: best-effort for queued tasks, required for running/completed/failed.
if err := container.ValidateJobName(task.JobName); err != nil {
r.OK = false
r.Errors = append(r.Errors, "invalid job name")
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "invalid job name"}
} else if base := strings.TrimSpace(h.expManager.BasePath()); base == "" {
if shouldRequireRunManifest(task) {
r.OK = false
r.Errors = append(r.Errors, "missing api base_path; cannot validate run manifest")
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
} else {
r.Warnings = append(r.Warnings, "missing api base_path; cannot validate run manifest")
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
}
} else {
manifestDir, manifestBucket, found := findRunManifestDir(base, task.JobName)
if !found {
if shouldRequireRunManifest(task) {
r.OK = false
r.Errors = append(r.Errors, "run manifest not found")
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
} else {
r.Warnings = append(r.Warnings, "run manifest not found (job may not have started)")
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
}
} else if rm, err := manifest.LoadFromDir(manifestDir); err != nil || rm == nil {
r.OK = false
r.Errors = append(r.Errors, "unable to read run manifest")
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "unable to read run manifest"}
} else {
r.Checks["run_manifest"] = validateCheck{OK: true}
expectedBucket, ok := expectedRunManifestBucketForStatus(task.Status)
if ok {
if expectedBucket != manifestBucket {
msg := "run manifest location mismatch"
chk := validateCheck{OK: false, Expected: expectedBucket, Actual: manifestBucket}
if shouldRequireRunManifest(task) {
r.OK = false
r.Errors = append(r.Errors, msg)
r.Checks["run_manifest_location"] = chk
} else {
r.Warnings = append(r.Warnings, msg)
r.Checks["run_manifest_location"] = chk
}
} else {
r.Checks["run_manifest_location"] = validateCheck{
OK: true,
Expected: expectedBucket,
Actual: manifestBucket,
}
}
}
if strings.TrimSpace(rm.TaskID) == "" {
r.OK = false
r.Errors = append(r.Errors, "run manifest missing task_id")
r.Checks["run_manifest_task_id"] = validateCheck{OK: false, Expected: task.ID}
} else if rm.TaskID != task.ID {
r.OK = false
r.Errors = append(r.Errors, "run manifest task_id mismatch")
r.Checks["run_manifest_task_id"] = validateCheck{
OK: false,
Expected: task.ID,
Actual: rm.TaskID,
}
} else {
r.Checks["run_manifest_task_id"] = validateCheck{
OK: true,
Expected: task.ID,
Actual: rm.TaskID,
}
}
commitWant := strings.TrimSpace(task.Metadata["commit_id"])
commitGot := strings.TrimSpace(rm.CommitID)
if commitWant != "" && commitGot != "" && commitWant != commitGot {
r.OK = false
r.Errors = append(r.Errors, "run manifest commit_id mismatch")
r.Checks["run_manifest_commit_id"] = validateCheck{
OK: false,
Expected: commitWant,
Actual: commitGot,
}
} else if commitWant != "" {
r.Checks["run_manifest_commit_id"] = validateCheck{
OK: true,
Expected: commitWant,
Actual: commitGot,
}
}
depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"])
depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
depGotName := strings.TrimSpace(rm.DepsManifestName)
depGotSHA := strings.TrimSpace(rm.DepsManifestSHA)
if depWantName != "" && depWantSHA != "" && depGotName != "" && depGotSHA != "" {
expectedDep := depWantName + ":" + depWantSHA
actualDep := depGotName + ":" + depGotSHA
if depWantName != depGotName || depWantSHA != depGotSHA {
r.OK = false
r.Errors = append(r.Errors, "run manifest deps provenance mismatch")
r.Checks["run_manifest_deps"] = validateCheck{
OK: false,
Expected: expectedDep,
Actual: actualDep,
}
} else {
r.Checks["run_manifest_deps"] = validateCheck{
OK: true,
Expected: expectedDep,
Actual: actualDep,
}
}
}
if strings.TrimSpace(task.SnapshotID) != "" {
snapWantID := strings.TrimSpace(task.SnapshotID)
snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"])
snapGotID := strings.TrimSpace(rm.SnapshotID)
snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256)
if snapWantID != "" && snapGotID != "" && snapWantID != snapGotID {
r.OK = false
r.Errors = append(r.Errors, "run manifest snapshot_id mismatch")
r.Checks["run_manifest_snapshot_id"] = validateCheck{
OK: false,
Expected: snapWantID,
Actual: snapGotID,
}
} else {
r.Checks["run_manifest_snapshot_id"] = validateCheck{
OK: true,
Expected: snapWantID,
Actual: snapGotID,
}
}
if snapWantSHA != "" && snapGotSHA != "" && snapWantSHA != snapGotSHA {
r.OK = false
r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch")
r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
OK: false,
Expected: snapWantSHA,
Actual: snapGotSHA,
}
} else if snapWantSHA != "" {
r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
OK: true,
Expected: snapWantSHA,
Actual: snapGotSHA,
}
}
}
statusLower := strings.ToLower(strings.TrimSpace(task.Status))
lifecycleOK := true
details := ""
switch statusLower {
case "running":
if rm.StartedAt.IsZero() {
lifecycleOK = false
details = "missing started_at for running task"
}
if !rm.EndedAt.IsZero() {
lifecycleOK = false
if details == "" {
details = "ended_at must be empty for running task"
}
}
if rm.ExitCode != nil {
lifecycleOK = false
if details == "" {
details = "exit_code must be empty for running task"
}
}
case "completed", "failed":
if rm.StartedAt.IsZero() {
lifecycleOK = false
details = "missing started_at for completed/failed task"
}
if rm.EndedAt.IsZero() {
lifecycleOK = false
if details == "" {
details = "missing ended_at for completed/failed task"
}
}
if rm.ExitCode == nil {
lifecycleOK = false
if details == "" {
details = "missing exit_code for completed/failed task"
}
}
if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && rm.EndedAt.Before(rm.StartedAt) {
lifecycleOK = false
if details == "" {
details = "ended_at is before started_at"
}
}
case "queued", "pending":
// queued/pending tasks may not have started yet.
if !rm.EndedAt.IsZero() || rm.ExitCode != nil {
lifecycleOK = false
details = "queued/pending task should not have ended_at/exit_code"
}
}
if lifecycleOK {
r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true}
} else {
chk := validateCheck{OK: false, Details: details}
if shouldRequireRunManifest(task) {
r.OK = false
r.Errors = append(r.Errors, "run manifest lifecycle invalid")
r.Checks["run_manifest_lifecycle"] = chk
} else {
r.Warnings = append(r.Warnings, "run manifest lifecycle invalid")
r.Checks["run_manifest_lifecycle"] = chk
}
}
}
}
want := strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"])
cur := ""
if man, err := h.expManager.ReadManifest(commitID); err == nil && man != nil {
cur = man.OverallSHA
}
if want == "" {
r.OK = false
r.Errors = append(r.Errors, "missing expected experiment_manifest_overall_sha")
r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Actual: cur}
} else if cur == "" {
r.OK = false
r.Errors = append(r.Errors, "unable to read current experiment manifest overall sha")
r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want}
} else if want != cur {
r.OK = false
r.Errors = append(r.Errors, "experiment manifest overall sha mismatch")
r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want, Actual: cur}
} else {
r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: true, Expected: want, Actual: cur}
}
wantDep := strings.TrimSpace(task.Metadata["deps_manifest_name"])
wantDepSha := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
if wantDep == "" || wantDepSha == "" {
r.OK = false
r.Errors = append(r.Errors, "missing expected deps manifest provenance")
r.Checks["expected_deps_manifest"] = validateCheck{OK: false}
} else if depName != "" {
sha, _ := fileSHA256Hex(filepath.Join(filesPath, depName))
ok := (wantDep == depName && wantDepSha == sha)
if !ok {
r.OK = false
r.Errors = append(r.Errors, "deps manifest provenance mismatch")
r.Checks["expected_deps_manifest"] = validateCheck{
OK: false,
Expected: wantDep + ":" + wantDepSha,
Actual: depName + ":" + sha,
}
} else {
r.Checks["expected_deps_manifest"] = validateCheck{
OK: true,
Expected: wantDep + ":" + wantDepSha,
Actual: depName + ":" + sha,
}
}
}
// Snapshot/dataset checks require dataDir.
// TODO(context): Support snapshot stores beyond local filesystem (e.g. S3).
// TODO(context): Validate snapshots by digest.
if task.SnapshotID != "" {
if h.dataDir == "" {
r.OK = false
r.Errors = append(r.Errors, "api server data_dir not configured; cannot validate snapshot")
r.Checks["snapshot"] = validateCheck{OK: false, Details: "missing api data_dir"}
} else {
wantSnap, nerr := worker.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
if nerr != nil || wantSnap == "" {
r.OK = false
r.Errors = append(r.Errors, "missing/invalid snapshot_sha256")
r.Checks["snapshot"] = validateCheck{OK: false}
} else {
curSnap, err := worker.DirOverallSHA256Hex(
filepath.Join(h.dataDir, "snapshots", task.SnapshotID),
)
if err != nil {
r.OK = false
r.Errors = append(r.Errors, "snapshot hash computation failed")
r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Details: err.Error()}
} else if curSnap != wantSnap {
r.OK = false
r.Errors = append(r.Errors, "snapshot checksum mismatch")
r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Actual: curSnap}
} else {
r.Checks["snapshot"] = validateCheck{OK: true, Expected: wantSnap, Actual: curSnap}
}
}
}
}
if len(task.DatasetSpecs) > 0 {
// TODO(context): Add dataset URI fetch/verification.
// TODO(context): Currently only validates local materialized datasets.
for _, ds := range task.DatasetSpecs {
if ds.Checksum == "" {
continue
}
key := "dataset:" + ds.Name
if h.dataDir == "" {
r.OK = false
r.Errors = append(
r.Errors,
"api server data_dir not configured; cannot validate dataset checksums",
)
r.Checks[key] = validateCheck{OK: false, Details: "missing api data_dir"}
continue
}
wantDS, nerr := worker.NormalizeSHA256ChecksumHex(ds.Checksum)
if nerr != nil || wantDS == "" {
r.OK = false
r.Errors = append(r.Errors, "invalid dataset checksum format")
r.Checks[key] = validateCheck{OK: false, Details: "invalid checksum"}
continue
}
curDS, err := worker.DirOverallSHA256Hex(filepath.Join(h.dataDir, ds.Name))
if err != nil {
r.OK = false
r.Errors = append(r.Errors, "dataset checksum computation failed")
r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Details: err.Error()}
continue
}
if curDS != wantDS {
r.OK = false
r.Errors = append(r.Errors, "dataset checksum mismatch")
r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Actual: curDS}
continue
}
r.Checks[key] = validateCheck{OK: true, Expected: wantDS, Actual: curDS}
}
}
}
payloadBytes, err := json.Marshal(r)
if err != nil {
return h.sendErrorPacket(
conn,
ErrorCodeUnknownError,
"failed to serialize validate report",
err.Error(),
)
}
return h.sendResponsePacket(conn, NewDataPacket("validate", payloadBytes))
}

View file

@ -133,6 +133,62 @@ func (c *Config) ValidateAPIKey(key string) (*User, error) {
return nil, fmt.Errorf("invalid API key")
}
// ValidateAPIKeyHash validates a pre-hashed API key and returns user information.
//
// hash must be exactly 16 bytes — the first 128 bits of the SHA-256 digest of
// the raw key. Stored entries hold the full hex-encoded SHA-256 digest and are
// matched against their 16-byte prefix.
func (c *Config) ValidateAPIKeyHash(hash []byte) (*User, error) {
	if !c.Enabled {
		// Auth disabled - return default admin user for development
		return &User{Name: "default", Admin: true}, nil
	}
	if len(hash) != 16 {
		return nil, fmt.Errorf("invalid api key hash length: %d", len(hash))
	}
	for username, entry := range c.APIKeys {
		stored := strings.TrimSpace(string(entry.Hash))
		if stored == "" {
			continue
		}
		storedBytes, err := hex.DecodeString(stored)
		if err != nil {
			continue
		}
		if len(storedBytes) != sha256.Size {
			continue
		}
		// Constant-time prefix comparison: the previous
		// string(storedBytes[:16]) == string(hash) compare short-circuits on
		// the first differing byte, leaking match length through timing.
		var diff byte
		for i := 0; i < 16; i++ {
			diff |= storedBytes[i] ^ hash[i]
		}
		if diff != 0 {
			continue
		}
		// Build user with role and permission inheritance.
		user := &User{
			Name:        string(username),
			Admin:       entry.Admin,
			Roles:       entry.Roles,
			Permissions: make(map[string]bool),
		}
		// Copy explicit permissions first; they take precedence over roles.
		for perm, value := range entry.Permissions {
			user.Permissions[perm] = value
		}
		// Add role-based permissions only where no explicit entry exists.
		rolePerms := getRolePermissions(entry.Roles)
		for perm, value := range rolePerms {
			if _, exists := user.Permissions[perm]; !exists {
				user.Permissions[perm] = value
			}
		}
		// Admin gets the wildcard permission.
		if entry.Admin {
			user.Permissions["*"] = true
		}
		return user, nil
	}
	return nil, fmt.Errorf("invalid api key")
}
// AuthMiddleware creates HTTP middleware for API key authentication
func (c *Config) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -177,6 +233,10 @@ func GetUserFromContext(ctx context.Context) *User {
return nil
}
// WithUserContext returns a child context that carries user under
// userContextKey so downstream handlers can retrieve it later.
func WithUserContext(ctx context.Context, user *User) context.Context {
	enriched := context.WithValue(ctx, userContextKey, user)
	return enriched
}
// RequireAdmin creates middleware that requires admin privileges
func RequireAdmin(next http.Handler) http.Handler {
return RequirePermission("system:admin")(next)

View file

@ -64,7 +64,10 @@ func (s *DatabaseAuthStore) init() error {
CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
CREATE INDEX IF NOT EXISTS idx_api_keys_user ON api_keys(user_id);
CREATE INDEX IF NOT EXISTS idx_api_keys_active ON api_keys(revoked_at, COALESCE(expires_at, '9999-12-31'));
CREATE INDEX IF NOT EXISTS idx_api_keys_active ON api_keys(
revoked_at,
COALESCE(expires_at, '9999-12-31')
);
`
_, err := s.db.ExecContext(ctx, query)
@ -158,7 +161,16 @@ func (s *DatabaseAuthStore) CreateAPIKey(
revoked_at = NULL
`
_, err = s.db.ExecContext(ctx, query, userID, keyHash, admin, rolesJSON, permissionsJSON, expiresAt)
_, err = s.db.ExecContext(
ctx,
query,
userID,
keyHash,
admin,
rolesJSON,
permissionsJSON,
expiresAt,
)
return err
}

View file

@ -213,7 +213,15 @@ func (h *HybridAuthStore) migrateFileToDatabase(ctx context.Context) error {
for username, entry := range h.fileStore.APIKeys {
userID := string(username)
err := h.dbStore.CreateAPIKey(ctx, userID, string(entry.Hash), entry.Admin, entry.Roles, entry.Permissions, nil)
err := h.dbStore.CreateAPIKey(
ctx,
userID,
string(entry.Hash),
entry.Admin,
entry.Roles,
entry.Permissions,
nil,
)
if err != nil {
log.Printf("Failed to migrate key for user %s: %v", userID, err)
continue

View file

@ -59,23 +59,43 @@ var PermissionGroups = map[string]PermissionGroup{
Description: "Complete system access",
},
"job_management": {
Name: "Job Management",
Permissions: []string{PermissionJobsCreate, PermissionJobsRead, PermissionJobsUpdate, PermissionJobsDelete},
Name: "Job Management",
Permissions: []string{
PermissionJobsCreate,
PermissionJobsRead,
PermissionJobsUpdate,
PermissionJobsDelete,
},
Description: "Create, read, update, and delete ML jobs",
},
"data_access": {
Name: "Data Access",
Permissions: []string{PermissionDataRead, PermissionDataCreate, PermissionDataUpdate, PermissionDataDelete},
Name: "Data Access",
Permissions: []string{
PermissionDataRead,
PermissionDataCreate,
PermissionDataUpdate,
PermissionDataDelete,
},
Description: "Access and manage datasets",
},
"readonly": {
Name: "Read Only",
Permissions: []string{PermissionJobsRead, PermissionDataRead, PermissionModelsRead, PermissionSystemMetrics},
Name: "Read Only",
Permissions: []string{
PermissionJobsRead,
PermissionDataRead,
PermissionModelsRead,
PermissionSystemMetrics,
},
Description: "View-only access to system resources",
},
"system_admin": {
Name: "System Administration",
Permissions: []string{PermissionSystemConfig, PermissionSystemLogs, PermissionSystemUsers, PermissionSystemMetrics},
Name: "System Administration",
Permissions: []string{
PermissionSystemConfig,
PermissionSystemLogs,
PermissionSystemUsers,
PermissionSystemMetrics,
},
Description: "System configuration and user management",
},
}

View file

@ -25,6 +25,7 @@ const (
RedisTaskStatusPrefix = "ml:status:"
RedisDatasetPrefix = "ml:dataset:"
RedisWorkerHeartbeat = "ml:workers:heartbeat"
RedisWorkerPrewarmKey = "ml:workers:prewarm:"
)
// Task status constants

View file

@ -0,0 +1,83 @@
package config
import (
"fmt"
"time"
)
// SecurityConfig holds security-related configuration for the API server.
// Values are populated from YAML; see Validate for the enforced invariants.
type SecurityConfig struct {
	// AllowedOrigins lists the allowed origins for WebSocket connections.
	// Empty list defaults to localhost-only in production mode.
	// NOTE(review): Validate rejects an empty list when ProductionMode is
	// set — confirm where (or whether) the localhost-only default applies.
	AllowedOrigins []string `yaml:"allowed_origins"`

	// ProductionMode enables strict security checks.
	ProductionMode bool `yaml:"production_mode"`

	// APIKeyRotationDays is the number of days before API keys should be
	// rotated. Validate only rejects negative values; 0 is accepted
	// (presumably meaning "no rotation policy" — confirm with callers).
	APIKeyRotationDays int `yaml:"api_key_rotation_days"`

	// AuditLogging configuration; LogPath is required when Enabled.
	AuditLogging AuditLoggingConfig `yaml:"audit_logging"`

	// IPWhitelist for additional connection filtering.
	IPWhitelist []string `yaml:"ip_whitelist"`
}
// AuditLoggingConfig holds audit logging configuration.
type AuditLoggingConfig struct {
	Enabled bool   `yaml:"enabled"`  // toggles audit logging
	LogPath string `yaml:"log_path"` // destination path; required when Enabled (see SecurityConfig.Validate)
}

// MonitoringConfig holds monitoring-related configuration.
type MonitoringConfig struct {
	Prometheus   PrometheusConfig   `yaml:"prometheus"`
	HealthChecks HealthChecksConfig `yaml:"health_checks"`
}

// PrometheusConfig holds Prometheus metrics configuration.
type PrometheusConfig struct {
	Enabled bool   `yaml:"enabled"`
	Port    int    `yaml:"port"` // must be 1-65535 when Enabled (see MonitoringConfig.Validate)
	Path    string `yaml:"path"` // metrics endpoint path; defaulted to "/metrics" by Validate
}

// HealthChecksConfig holds health check configuration.
type HealthChecksConfig struct {
	Enabled  bool          `yaml:"enabled"`
	Interval time.Duration `yaml:"interval"` // NOTE(review): not validated anywhere in this file — confirm intent
}
// Validate validates the security configuration.
//
// Rules enforced:
//   - production_mode requires at least one allowed_origin
//   - api_key_rotation_days must not be negative (0 is accepted)
//   - audit logging, when enabled, requires a log_path
func (s *SecurityConfig) Validate() error {
	if s.ProductionMode && len(s.AllowedOrigins) == 0 {
		return fmt.Errorf("production_mode requires at least one allowed_origin")
	}
	if s.APIKeyRotationDays < 0 {
		// Bug fix: the condition only rejects negative values, so the message
		// must say "non-negative", not "positive" (0 passes this check).
		return fmt.Errorf("api_key_rotation_days must be non-negative")
	}
	if s.AuditLogging.Enabled && s.AuditLogging.LogPath == "" {
		return fmt.Errorf("audit_logging enabled but log_path not set")
	}
	return nil
}
// Validate validates the monitoring configuration.
//
// Note: as a side effect it fills in the default Prometheus path
// ("/metrics") when metrics are enabled and no path was configured.
func (m *MonitoringConfig) Validate() error {
	if !m.Prometheus.Enabled {
		return nil
	}
	if port := m.Prometheus.Port; port <= 0 || port > 65535 {
		return fmt.Errorf("prometheus port must be between 1 and 65535")
	}
	if m.Prometheus.Path == "" {
		m.Prometheus.Path = "/metrics" // Default
	}
	return nil
}

View file

@ -15,13 +15,21 @@ func int64Max(a, b int64) int64 {
// Metrics tracks various performance counters and statistics.
type Metrics struct {
TasksProcessed atomic.Int64
TasksFailed atomic.Int64
DataFetchTime atomic.Int64 // Total nanoseconds
ExecutionTime atomic.Int64
DataTransferred atomic.Int64 // Total bytes
ActiveTasks atomic.Int64
QueuedTasks atomic.Int64
TasksProcessed atomic.Int64
TasksFailed atomic.Int64
DataFetchTime atomic.Int64 // Total nanoseconds
ExecutionTime atomic.Int64
DataTransferred atomic.Int64 // Total bytes
ActiveTasks atomic.Int64
QueuedTasks atomic.Int64
PrewarmEnvHit atomic.Int64
PrewarmEnvMiss atomic.Int64
PrewarmEnvBuilt atomic.Int64
PrewarmEnvTime atomic.Int64 // Total nanoseconds
PrewarmSnapshotHit atomic.Int64
PrewarmSnapshotMiss atomic.Int64
PrewarmSnapshotBuilt atomic.Int64
PrewarmSnapshotTime atomic.Int64 // Total nanoseconds
}
// RecordTaskSuccess records successful task completion with duration.
@ -53,6 +61,32 @@ func (m *Metrics) RecordDataTransfer(bytes int64, duration time.Duration) {
m.DataFetchTime.Add(duration.Nanoseconds())
}
// RecordPrewarmEnvHit increments the prewarmed-environment hit counter.
func (m *Metrics) RecordPrewarmEnvHit() {
	m.PrewarmEnvHit.Add(1)
}

// RecordPrewarmEnvMiss increments the prewarmed-environment miss counter.
func (m *Metrics) RecordPrewarmEnvMiss() {
	m.PrewarmEnvMiss.Add(1)
}

// RecordPrewarmEnvBuilt increments the environments-built counter and adds
// the build duration (in nanoseconds) to the cumulative PrewarmEnvTime.
func (m *Metrics) RecordPrewarmEnvBuilt(duration time.Duration) {
	m.PrewarmEnvBuilt.Add(1)
	m.PrewarmEnvTime.Add(duration.Nanoseconds())
}

// RecordPrewarmSnapshotHit increments the prewarmed-snapshot hit counter.
func (m *Metrics) RecordPrewarmSnapshotHit() {
	m.PrewarmSnapshotHit.Add(1)
}

// RecordPrewarmSnapshotMiss increments the prewarmed-snapshot miss counter.
func (m *Metrics) RecordPrewarmSnapshotMiss() {
	m.PrewarmSnapshotMiss.Add(1)
}

// RecordPrewarmSnapshotBuilt increments the snapshots-built counter and adds
// the build duration (in nanoseconds) to the cumulative PrewarmSnapshotTime.
func (m *Metrics) RecordPrewarmSnapshotBuilt(duration time.Duration) {
	m.PrewarmSnapshotBuilt.Add(1)
	m.PrewarmSnapshotTime.Add(duration.Nanoseconds())
}
// SetQueuedTasks sets the number of queued tasks.
func (m *Metrics) SetQueuedTasks(count int64) {
m.QueuedTasks.Store(count)
@ -66,13 +100,21 @@ func (m *Metrics) GetStats() map[string]any {
dataFetchTime := m.DataFetchTime.Load()
return map[string]any{
"tasks_processed": processed,
"tasks_failed": failed,
"active_tasks": m.ActiveTasks.Load(),
"queued_tasks": m.QueuedTasks.Load(),
"success_rate": float64(processed-failed) / float64(int64Max(processed, 1)),
"avg_exec_time": time.Duration(m.ExecutionTime.Load() / int64Max(processed, 1)),
"data_transferred_gb": float64(dataTransferred) / (1024 * 1024 * 1024),
"avg_fetch_time": time.Duration(dataFetchTime / int64Max(processed, 1)),
"tasks_processed": processed,
"tasks_failed": failed,
"active_tasks": m.ActiveTasks.Load(),
"queued_tasks": m.QueuedTasks.Load(),
"success_rate": float64(processed-failed) / float64(int64Max(processed, 1)),
"avg_exec_time": time.Duration(m.ExecutionTime.Load() / int64Max(processed, 1)),
"data_transferred_gb": float64(dataTransferred) / (1024 * 1024 * 1024),
"avg_fetch_time": time.Duration(dataFetchTime / int64Max(processed, 1)),
"prewarm_env_hit": m.PrewarmEnvHit.Load(),
"prewarm_env_miss": m.PrewarmEnvMiss.Load(),
"prewarm_env_built": m.PrewarmEnvBuilt.Load(),
"prewarm_env_time": time.Duration(m.PrewarmEnvTime.Load()),
"prewarm_snapshot_hit": m.PrewarmSnapshotHit.Load(),
"prewarm_snapshot_miss": m.PrewarmSnapshotMiss.Load(),
"prewarm_snapshot_built": m.PrewarmSnapshotBuilt.Load(),
"prewarm_snapshot_time": time.Duration(m.PrewarmSnapshotTime.Load()),
}
}

View file

@ -3,8 +3,11 @@ package middleware
import (
"context"
"fmt"
"log"
"net"
"net/http"
"net/netip"
"strings"
"time"
@ -26,7 +29,11 @@ type RateLimitOptions struct {
}
// NewSecurityMiddleware creates a new security middleware instance.
func NewSecurityMiddleware(authConfig *auth.Config, jwtSecret string, rlOpts *RateLimitOptions) *SecurityMiddleware {
func NewSecurityMiddleware(
authConfig *auth.Config,
jwtSecret string,
rlOpts *RateLimitOptions,
) *SecurityMiddleware {
sm := &SecurityMiddleware{
authConfig: authConfig,
jwtSecret: []byte(jwtSecret),
@ -62,21 +69,23 @@ func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler {
// APIKeyAuth provides API key authentication middleware.
func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiKey := auth.ExtractAPIKeyFromRequest(r)
// Validate API key using auth config
if sm.authConfig == nil {
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
// If authentication is not configured or disabled, allow all requests.
// This keeps local/dev environments functional without requiring API keys.
if sm.authConfig == nil || !sm.authConfig.Enabled {
next.ServeHTTP(w, r)
return
}
_, err := sm.authConfig.ValidateAPIKey(apiKey)
apiKey := auth.ExtractAPIKeyFromRequest(r)
// Validate API key using auth config
user, err := sm.authConfig.ValidateAPIKey(apiKey)
if err != nil {
http.Error(w, "Invalid API key", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
ctx := auth.WithUserContext(r.Context(), user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
@ -103,21 +112,54 @@ func SecurityHeaders(next http.Handler) http.Handler {
// IPWhitelist provides IP whitelist middleware.
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler {
parsedAddrs := make([]netip.Addr, 0, len(allowedIPs))
parsedPrefixes := make([]netip.Prefix, 0, len(allowedIPs))
for _, raw := range allowedIPs {
val := strings.TrimSpace(raw)
if val == "" {
continue
}
if strings.Contains(val, "/") {
p, err := netip.ParsePrefix(val)
if err != nil {
log.Printf("SECURITY: invalid ip whitelist cidr ignored: %q: %v", val, err)
continue
}
parsedPrefixes = append(parsedPrefixes, p)
continue
}
a, err := netip.ParseAddr(val)
if err != nil {
log.Printf("SECURITY: invalid ip whitelist addr ignored: %q: %v", val, err)
continue
}
parsedAddrs = append(parsedAddrs, a)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
if len(parsedAddrs) == 0 && len(parsedPrefixes) == 0 {
http.Error(w, "IP not whitelisted", http.StatusForbidden)
return
}
clientIPStr := getClientIP(r)
addr, err := parseClientIP(clientIPStr)
if err != nil {
http.Error(w, "IP not whitelisted", http.StatusForbidden)
return
}
// Check if client IP is in whitelist
allowed := false
for _, ip := range allowedIPs {
if strings.Contains(ip, "/") {
// CIDR notation - would need proper IP net parsing
if strings.HasPrefix(clientIP, strings.Split(ip, "/")[0]) {
allowed = true
break
}
} else {
if clientIP == ip {
for _, a := range parsedAddrs {
if a == addr {
allowed = true
break
}
}
if !allowed {
for _, p := range parsedPrefixes {
if p.Contains(addr) {
allowed = true
break
}
@ -134,41 +176,46 @@ func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler
}
}
// CORS middleware with restrictive defaults
func CORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
// Only allow specific origins in production
allowedOrigins := []string{
"https://ml-experiments.example.com",
"https://app.example.com",
// CORS middleware with configured allowed origins
func CORS(allowedOrigins []string) func(http.Handler) http.Handler {
allowed := make([]string, 0, len(allowedOrigins))
for _, o := range allowedOrigins {
v := strings.TrimSpace(o)
if v != "" {
allowed = append(allowed, v)
}
}
isAllowed := false
for _, allowed := range allowedOrigins {
if origin == allowed {
isAllowed = true
break
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin != "" {
isAllowed := false
for _, a := range allowed {
if a == "*" || origin == a {
isAllowed = true
break
}
}
if isAllowed {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin")
}
}
}
if isAllowed {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Max-Age", "86400")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Max-Age", "86400")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusNoContent)
return
}
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
next.ServeHTTP(w, r)
})
}
}
// RequestTimeout provides request timeout middleware.
@ -253,12 +300,28 @@ func getClientIP(r *http.Request) string {
}
// Fall back to RemoteAddr
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil {
return host
}
return r.RemoteAddr
}
// parseClientIP normalizes a client address string into a netip.Addr.
// Accepted forms include "1.2.3.4", "1.2.3.4:80", "::1", "[::1]" and
// "[::1]:443"; surrounding whitespace is ignored.
func parseClientIP(raw string) (netip.Addr, error) {
	candidate := strings.TrimSpace(raw)
	if candidate == "" {
		return netip.Addr{}, fmt.Errorf("empty ip")
	}
	// Prefer host:port parsing — it also unwraps bracketed IPv6 literals.
	if host, _, splitErr := net.SplitHostPort(candidate); splitErr == nil {
		candidate = host
	}
	// Strip any leftover IPv6 brackets before the final parse.
	candidate = strings.TrimSuffix(strings.TrimPrefix(candidate, "["), "]")
	return netip.ParseAddr(candidate)
}
// Response writer wrapper to capture status code
type responseWriter struct {
http.ResponseWriter
@ -273,5 +336,11 @@ func (rw *responseWriter) WriteHeader(code int) {
func logSecurityEvent(event map[string]interface{}) {
// Implementation would send to security monitoring system
// For now, just log (in production, use proper logging)
log.Printf("SECURITY AUDIT: %s %s %s %v", event["client_ip"], event["method"], event["path"], event["status"])
log.Printf(
"SECURITY AUDIT: %s %s %s %v",
event["client_ip"],
event["method"],
event["path"],
event["status"],
)
}

View file

@ -0,0 +1,295 @@
package prommetrics
import (
"net/http"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Metrics holds all Prometheus metrics for the application.
// All collectors are registered on a private registry (see initMetrics),
// not on the global default registerer; expose them via Handler.
type Metrics struct {
	// WebSocket metrics
	wsConnections *prometheus.GaugeVec     // labeled by connection status
	wsMessages    *prometheus.CounterVec   // labeled by opcode and status
	wsDuration    *prometheus.HistogramVec // labeled by opcode
	wsErrors      *prometheus.CounterVec   // labeled by error type

	// Job queue metrics
	jobsQueued    prometheus.Counter
	jobsCompleted *prometheus.CounterVec // labeled by terminal status
	jobsActive    prometheus.Gauge
	jobDuration   *prometheus.HistogramVec // labeled by terminal status
	queueLength   prometheus.Gauge

	// Jupyter metrics
	jupyterServices *prometheus.GaugeVec   // labeled by service status
	jupyterOps      *prometheus.CounterVec // labeled by operation and status

	// HTTP metrics
	httpRequests *prometheus.CounterVec   // labeled by method, endpoint, status
	httpDuration *prometheus.HistogramVec // labeled by method and endpoint

	// Prewarm metrics
	prewarmSnapshotHit   prometheus.Counter
	prewarmSnapshotMiss  prometheus.Counter
	prewarmSnapshotBuilt prometheus.Counter
	prewarmSnapshotTime  prometheus.Histogram

	// registry is the private registry backing Handler.
	registry *prometheus.Registry
}
// New creates a new Prometheus Metrics instance backed by its own
// private registry, with every collector initialized and registered.
func New() *Metrics {
	metrics := &Metrics{registry: prometheus.NewRegistry()}
	metrics.initMetrics()
	return metrics
}
// initMetrics builds every collector and registers it on m.registry.
// Called once from New; MustRegister panics on duplicate registration,
// so this must not run twice on the same Metrics.
func (m *Metrics) initMetrics() {
	// WebSocket metrics
	m.wsConnections = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "fetchml_websocket_connections",
			Help: "Number of active WebSocket connections",
		},
		[]string{"status"},
	)
	m.wsMessages = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "fetchml_websocket_messages_total",
			Help: "Total number of WebSocket messages",
		},
		[]string{"opcode", "status"},
	)
	m.wsDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "fetchml_websocket_duration_seconds",
			Help:    "WebSocket message processing duration",
			Buckets: prometheus.DefBuckets,
		},
		[]string{"opcode"},
	)
	m.wsErrors = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "fetchml_websocket_errors_total",
			Help: "Total number of WebSocket errors",
		},
		[]string{"type"},
	)
	// Job queue metrics
	m.jobsQueued = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: "fetchml_jobs_queued_total",
			Help: "Total number of jobs queued",
		},
	)
	m.jobsCompleted = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "fetchml_jobs_completed_total",
			Help: "Total number of completed jobs",
		},
		[]string{"status"},
	)
	m.jobsActive = prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "fetchml_jobs_active",
			Help: "Number of currently active jobs",
		},
	)
	// Job durations span seconds to an hour, hence the custom buckets.
	m.jobDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "fetchml_job_duration_seconds",
			Help:    "Job execution duration",
			Buckets: []float64{1, 5, 10, 30, 60, 300, 600, 1800, 3600},
		},
		[]string{"status"},
	)
	m.queueLength = prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "fetchml_queue_length",
			Help: "Current job queue length",
		},
	)
	// Jupyter metrics
	m.jupyterServices = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "fetchml_jupyter_services",
			Help: "Number of Jupyter services",
		},
		[]string{"status"},
	)
	m.jupyterOps = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "fetchml_jupyter_operations_total",
			Help: "Total number of Jupyter operations",
		},
		[]string{"operation", "status"},
	)
	// HTTP metrics
	m.httpRequests = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "fetchml_http_requests_total",
			Help: "Total number of HTTP requests",
		},
		[]string{"method", "endpoint", "status"},
	)
	m.httpDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "fetchml_http_duration_seconds",
			Help:    "HTTP request duration",
			Buckets: prometheus.DefBuckets,
		},
		[]string{"method", "endpoint"},
	)
	// Prewarm metrics
	m.prewarmSnapshotHit = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: "fetchml_prewarm_snapshot_hit_total",
			Help: "Total number of prewarmed snapshot hits (snapshots found in .prewarm/)",
		},
	)
	m.prewarmSnapshotMiss = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: "fetchml_prewarm_snapshot_miss_total",
			Help: "Total number of prewarmed snapshot misses (snapshots not found in .prewarm/)",
		},
	)
	m.prewarmSnapshotBuilt = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: "fetchml_prewarm_snapshot_built_total",
			Help: "Total number of snapshots prewarmed into .prewarm/",
		},
	)
	m.prewarmSnapshotTime = prometheus.NewHistogram(
		prometheus.HistogramOpts{
			Name:    "fetchml_prewarm_snapshot_duration_seconds",
			Help:    "Time spent prewarming snapshots",
			Buckets: []float64{0.1, 0.5, 1, 2, 5, 10, 30, 60, 120},
		},
	)
	// Register all metrics on the private registry.
	m.registry.MustRegister(
		m.wsConnections,
		m.wsMessages,
		m.wsDuration,
		m.wsErrors,
		m.jobsQueued,
		m.jobsCompleted,
		m.jobsActive,
		m.jobDuration,
		m.queueLength,
		m.jupyterServices,
		m.jupyterOps,
		m.httpRequests,
		m.httpDuration,
		m.prewarmSnapshotHit,
		m.prewarmSnapshotMiss,
		m.prewarmSnapshotBuilt,
		m.prewarmSnapshotTime,
	)
}
// Handler returns an http.Handler that exposes every metric registered
// on this instance's private registry in the Prometheus text format.
func (m *Metrics) Handler() http.Handler {
	opts := promhttp.HandlerOpts{}
	return promhttp.HandlerFor(m.registry, opts)
}
// WebSocket metrics methods

// IncWSConnections records one additional WebSocket connection in the
// gauge bucket identified by status.
func (m *Metrics) IncWSConnections(status string) {
	gauge := m.wsConnections.WithLabelValues(status)
	gauge.Inc()
}
// DecWSConnections removes one WebSocket connection from the gauge
// bucket identified by status.
func (m *Metrics) DecWSConnections(status string) {
	gauge := m.wsConnections.WithLabelValues(status)
	gauge.Dec()
}
// IncWSMessages counts one WebSocket message, labeled by its opcode and
// processing status.
func (m *Metrics) IncWSMessages(opcode, status string) {
	counter := m.wsMessages.WithLabelValues(opcode, status)
	counter.Inc()
}
// ObserveWSDuration records, in seconds, how long handling a WebSocket
// message with the given opcode took.
func (m *Metrics) ObserveWSDuration(opcode string, duration time.Duration) {
	seconds := duration.Seconds()
	m.wsDuration.WithLabelValues(opcode).Observe(seconds)
}
// IncWSErrors counts one WebSocket error of the given type.
func (m *Metrics) IncWSErrors(errType string) {
	counter := m.wsErrors.WithLabelValues(errType)
	counter.Inc()
}
// Job queue metrics methods

// IncJobsQueued counts one newly queued job.
func (m *Metrics) IncJobsQueued() {
	m.jobsQueued.Inc()
}
// IncJobsCompleted counts one finished job under the given terminal
// status label.
func (m *Metrics) IncJobsCompleted(status string) {
	counter := m.jobsCompleted.WithLabelValues(status)
	counter.Inc()
}
// SetJobsActive overwrites the gauge of currently running jobs with
// count.
func (m *Metrics) SetJobsActive(count float64) {
	m.jobsActive.Set(count)
}
// ObserveJobDuration records a job's execution time in seconds under
// the given terminal status label.
func (m *Metrics) ObserveJobDuration(status string, duration time.Duration) {
	seconds := duration.Seconds()
	m.jobDuration.WithLabelValues(status).Observe(seconds)
}
// SetQueueLength overwrites the gauge tracking the current job queue
// depth with length.
func (m *Metrics) SetQueueLength(length float64) {
	m.queueLength.Set(length)
}
// Jupyter metrics methods

// SetJupyterServices overwrites the gauge of Jupyter services in the
// given status bucket with count.
func (m *Metrics) SetJupyterServices(status string, count float64) {
	gauge := m.jupyterServices.WithLabelValues(status)
	gauge.Set(count)
}
// IncJupyterOps counts one Jupyter operation, labeled by the operation
// name and its outcome status.
func (m *Metrics) IncJupyterOps(operation, status string) {
	counter := m.jupyterOps.WithLabelValues(operation, status)
	counter.Inc()
}
// HTTP metrics methods

// IncHTTPRequests counts one HTTP request, labeled by method, endpoint,
// and response status.
func (m *Metrics) IncHTTPRequests(method, endpoint, status string) {
	counter := m.httpRequests.WithLabelValues(method, endpoint, status)
	counter.Inc()
}
// ObserveHTTPDuration records an HTTP request's handling time in
// seconds, labeled by method and endpoint.
func (m *Metrics) ObserveHTTPDuration(method, endpoint string, duration time.Duration) {
	seconds := duration.Seconds()
	m.httpDuration.WithLabelValues(method, endpoint).Observe(seconds)
}
// Prewarm metrics methods

// IncPrewarmSnapshotHit counts one prewarmed-snapshot cache hit.
func (m *Metrics) IncPrewarmSnapshotHit() {
	m.prewarmSnapshotHit.Inc()
}
// IncPrewarmSnapshotMiss counts one prewarmed-snapshot cache miss.
func (m *Metrics) IncPrewarmSnapshotMiss() {
	m.prewarmSnapshotMiss.Inc()
}
// IncPrewarmSnapshotBuilt counts one snapshot newly built into the
// prewarm cache.
func (m *Metrics) IncPrewarmSnapshotBuilt() {
	m.prewarmSnapshotBuilt.Inc()
}
// ObservePrewarmSnapshotDuration records, in seconds, how long a
// snapshot prewarm took.
func (m *Metrics) ObservePrewarmSnapshotDuration(duration time.Duration) {
	seconds := duration.Seconds()
	m.prewarmSnapshotTime.Observe(seconds)
}