From add4a90e626a1c50a0b70b40ed81087c40d57f15 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Mon, 5 Jan 2026 12:31:07 -0500 Subject: [PATCH] feat(api): refactor websocket handlers; add health and prometheus middleware --- cmd/api-server/README.md | 2 +- cmd/api-server/main.go | 2 +- go.mod | 15 +- go.sum | 21 + internal/api/handlers.go | 62 +- internal/api/health.go | 82 ++ internal/api/metrics_middleware.go | 63 ++ internal/api/monitoring_config.go | 20 + internal/api/protocol.go | 127 ++- internal/api/protocol_simplified.go | 155 ---- internal/api/server.go | 150 +++- internal/api/server_config.go | 68 +- internal/api/ws.go | 652 -------------- internal/api/ws_datasets.go | 208 +++++ internal/api/ws_handler.go | 279 ++++++ internal/api/ws_jobs.go | 1268 +++++++++++++++++++++++++++ internal/api/ws_jupyter.go | 478 ++++++++++ internal/api/ws_tls_auth.go | 100 +++ internal/api/ws_validate.go | 642 ++++++++++++++ internal/auth/api_key.go | 60 ++ internal/auth/database.go | 16 +- internal/auth/hybrid.go | 10 +- internal/auth/permissions.go | 36 +- internal/config/constants.go | 1 + internal/config/security.go | 83 ++ internal/metrics/metrics.go | 72 +- internal/middleware/security.go | 173 ++-- internal/prommetrics/prometheus.go | 295 +++++++ 28 files changed, 4179 insertions(+), 961 deletions(-) create mode 100644 internal/api/health.go create mode 100644 internal/api/metrics_middleware.go create mode 100644 internal/api/monitoring_config.go delete mode 100644 internal/api/protocol_simplified.go delete mode 100644 internal/api/ws.go create mode 100644 internal/api/ws_datasets.go create mode 100644 internal/api/ws_handler.go create mode 100644 internal/api/ws_jobs.go create mode 100644 internal/api/ws_jupyter.go create mode 100644 internal/api/ws_tls_auth.go create mode 100644 internal/api/ws_validate.go create mode 100644 internal/config/security.go create mode 100644 internal/prommetrics/prometheus.go diff --git a/cmd/api-server/README.md 
b/cmd/api-server/README.md index 513b330..bd43686 100644 --- a/cmd/api-server/README.md +++ b/cmd/api-server/README.md @@ -5,7 +5,7 @@ WebSocket API server for the ML CLI tool... ## Usage ```bash -./bin/api-server --config configs/config-dev.yaml --listen :9100 +./bin/api-server --config configs/api/dev.yaml ``` ## Endpoints diff --git a/cmd/api-server/main.go b/cmd/api-server/main.go index 247525c..590ad1a 100644 --- a/cmd/api-server/main.go +++ b/cmd/api-server/main.go @@ -9,7 +9,7 @@ import ( ) func main() { - configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path") + configFile := flag.String("config", "configs/api/dev.yaml", "Configuration file path") apiKey := flag.String("api-key", "", "API key for authentication") flag.Parse() diff --git a/go.mod b/go.mod index ae94fd1..81bd5ef 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/jfraeys/fetch_ml go 1.25.0 // Fetch ML - Secure Machine Learning Platform -// Copyright (c) 2024 Fetch ML -// Licensed under the MIT License +// Copyright (c) 2026 Fetch ML +// Licensed under the FetchML Source-Available Research & Audit License (SARAL). See LICENSE. 
require ( github.com/BurntSushi/toml v1.5.0 @@ -17,6 +17,7 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.32 + github.com/minio/minio-go/v7 v7.0.97 github.com/prometheus/client_golang v1.23.2 github.com/redis/go-redis/v9 v9.17.2 github.com/stretchr/testify v1.11.1 @@ -43,24 +44,34 @@ require ( github.com/danieljoos/wincred v1.2.3 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-ini/ini v1.67.0 // indirect github.com/godbus/dbus/v5 v5.2.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.11 // indirect + github.com/klauspost/crc32 v1.3.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-runewidth v0.0.19 // indirect + github.com/minio/crc64nvme v1.1.0 // indirect + github.com/minio/md5-simd v1.1.2 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/philhofer/fwd v1.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.4 // indirect github.com/prometheus/procfs v0.19.2 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/rs/xid v1.6.0 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect + github.com/tinylib/msgp v1.3.0 // indirect github.com/xeipuuv/gojsonpointer 
v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect diff --git a/go.sum b/go.sum index cbf4015..8b9cfdf 100644 --- a/go.sum +++ b/go.sum @@ -48,10 +48,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= +github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= @@ -66,6 +70,11 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.18.0 
h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU= +github.com/klauspost/cpuid/v2 v2.2.11/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM= +github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -84,6 +93,12 @@ github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byF github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/minio/crc64nvme v1.1.0 h1:e/tAguZ+4cw32D+IO/8GSf5UVr9y+3eJcxZI2WOO/7Q= +github.com/minio/crc64nvme v1.1.0/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg= +github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= +github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= +github.com/minio/minio-go/v7 v7.0.97 h1:lqhREPyfgHTB/ciX8k2r8k0D93WaFqxbJX36UZq5occ= +github.com/minio/minio-go/v7 v7.0.97/go.mod h1:re5VXuo0pwEtoNLsNuSr0RrLfT/MBtohwdaSmPPSRSk= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/cancelreader v0.2.2 
h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= @@ -98,6 +113,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= +github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= +github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= @@ -114,6 +131,8 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -122,6 +141,8 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tinylib/msgp 
v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww= +github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 4d55373..55bef74 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" + "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/logging" @@ -34,7 +35,6 @@ func NewHandlers( // RegisterHandlers registers all HTTP handlers with the mux func (h *Handlers) RegisterHandlers(mux *http.ServeMux) { // Health check endpoints - mux.HandleFunc("/health", h.handleHealth) mux.HandleFunc("/db-status", h.handleDBStatus) // Jupyter service endpoints @@ -57,7 +57,7 @@ func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) { // This would need the DB instance passed to handlers // For now, return a basic response - response := map[string]interface{}{ + response := map[string]any{ "status": "unknown", "message": "Database status check not implemented", } @@ -72,13 +72,30 @@ func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) { // handleJupyterServices handles Jupyter service management requests func (h *Handlers) handleJupyterServices(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized) + return + } switch r.Method { case 
http.MethodGet: + if !user.HasPermission("jupyter:read") { + http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) + return + } h.listJupyterServices(w, r) case http.MethodPost: + if !user.HasPermission("jupyter:manage") { + http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) + return + } h.startJupyterService(w, r) case http.MethodDelete: + if !user.HasPermission("jupyter:manage") { + http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) + return + } h.stopJupyterService(w, r) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -140,7 +157,10 @@ func (h *Handlers) stopJupyterService(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]string{"status": "stopped", "id": serviceID}); err != nil { + if err := json.NewEncoder(w).Encode(map[string]string{ + "status": "stopped", + "id": serviceID, + }); err != nil { h.logger.Error("failed to encode response", "error", err) } } @@ -148,6 +168,15 @@ func (h *Handlers) stopJupyterService(w http.ResponseWriter, r *http.Request) { // handleJupyterExperimentLink handles linking Jupyter workspaces with experiments func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized) + return + } + if !user.HasPermission("jupyter:manage") { + http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) + return + } if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -176,7 +205,11 @@ func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Re } // Link workspace with experiment using service manager - if err := h.jupyterServiceMgr.LinkWorkspaceWithExperiment(req.Workspace, 
req.ExperimentID, req.ServiceID); err != nil { + if err := h.jupyterServiceMgr.LinkWorkspaceWithExperiment( + req.Workspace, + req.ExperimentID, + req.ServiceID, + ); err != nil { http.Error(w, fmt.Sprintf("Failed to link workspace: %v", err), http.StatusInternalServerError) return } @@ -184,7 +217,11 @@ func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Re // Get workspace metadata to return metadata, err := h.jupyterServiceMgr.GetWorkspaceMetadata(req.Workspace) if err != nil { - http.Error(w, fmt.Sprintf("Failed to get workspace metadata: %v", err), http.StatusInternalServerError) + http.Error( + w, + fmt.Sprintf("Failed to get workspace metadata: %v", err), + http.StatusInternalServerError, + ) return } @@ -205,6 +242,15 @@ func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Re // handleJupyterExperimentSync handles synchronization between Jupyter workspaces and experiments func (h *Handlers) handleJupyterExperimentSync(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized) + return + } + if !user.HasPermission("jupyter:manage") { + http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden) + return + } if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -245,7 +291,11 @@ func (h *Handlers) handleJupyterExperimentSync(w http.ResponseWriter, r *http.Re // Get updated metadata metadata, err := h.jupyterServiceMgr.GetWorkspaceMetadata(req.Workspace) if err != nil { - http.Error(w, fmt.Sprintf("Failed to get workspace metadata: %v", err), http.StatusInternalServerError) + http.Error( + w, + fmt.Sprintf("Failed to get workspace metadata: %v", err), + http.StatusInternalServerError, + ) return } diff --git a/internal/api/health.go b/internal/api/health.go new file mode 100644 
index 0000000..282a339 --- /dev/null +++ b/internal/api/health.go @@ -0,0 +1,82 @@ +package api + +import ( + "encoding/json" + "net/http" + "time" +) + +// HealthStatus represents the health status of the service +type HealthStatus struct { + Status string `json:"status"` + Timestamp time.Time `json:"timestamp"` + Checks map[string]string `json:"checks,omitempty"` +} + +// HealthHandler handles /health requests +type HealthHandler struct { + server *Server +} + +// NewHealthHandler creates a new health check handler +func NewHealthHandler(s *Server) *HealthHandler { + return &HealthHandler{server: s} +} + +// Health performs a basic health check +func (h *HealthHandler) Health(w http.ResponseWriter, r *http.Request) { + status := HealthStatus{ + Status: "healthy", + Timestamp: time.Now().UTC(), + Checks: make(map[string]string), + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(status) +} + +// Liveness performs a liveness probe (is the service running?) +func (h *HealthHandler) Liveness(w http.ResponseWriter, r *http.Request) { + // Simple liveness check - if we can respond, we're alive + status := HealthStatus{ + Status: "alive", + Timestamp: time.Now().UTC(), + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(status) +} + +// Readiness performs a readiness probe (is the service ready to accept traffic?) 
+func (h *HealthHandler) Readiness(w http.ResponseWriter, r *http.Request) { + status := HealthStatus{ + Status: "ready", + Timestamp: time.Now().UTC(), + Checks: make(map[string]string), + } + + // Check Redis connection (if queue is configured) + if h.server.taskQueue != nil { + // Simple check - if queue exists, assume it's ready + status.Checks["queue"] = "ok" + } + + // Check experiment manager + if h.server.expManager != nil { + status.Checks["experiments"] = "ok" + } + + // If all checks pass, we're ready + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(status) +} + +// RegisterHealthRoutes registers health check routes +func (h *HealthHandler) RegisterRoutes(mux *http.ServeMux) { + mux.HandleFunc("/health", h.Health) + mux.HandleFunc("/health/live", h.Liveness) + mux.HandleFunc("/health/ready", h.Readiness) +} diff --git a/internal/api/metrics_middleware.go b/internal/api/metrics_middleware.go new file mode 100644 index 0000000..3be08e6 --- /dev/null +++ b/internal/api/metrics_middleware.go @@ -0,0 +1,63 @@ +package api + +import ( + "bufio" + "fmt" + "net" + "net/http" + "time" +) + +// wrapWithMetrics wraps a handler with Prometheus metrics tracking +func (s *Server) wrapWithMetrics(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if s.promMetrics == nil { + next.ServeHTTP(w, r) + return + } + + start := time.Now() + + // Track HTTP request + method := r.Method + endpoint := r.URL.Path + + // Wrap response writer to capture status code + ww := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + // Serve request + next.ServeHTTP(ww, r) + + // Record metrics + duration := time.Since(start) + statusStr := http.StatusText(ww.statusCode) + + s.promMetrics.IncHTTPRequests(method, endpoint, statusStr) + s.promMetrics.ObserveHTTPDuration(method, endpoint, duration) + }) +} + +// responseWriter wraps http.ResponseWriter to 
capture status code +type responseWriter struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} + +func (rw *responseWriter) Flush() { + if f, ok := rw.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := rw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("websocket: response does not implement http.Hijacker") + } + return h.Hijack() +} diff --git a/internal/api/monitoring_config.go b/internal/api/monitoring_config.go new file mode 100644 index 0000000..b334619 --- /dev/null +++ b/internal/api/monitoring_config.go @@ -0,0 +1,20 @@ +package api + +// MonitoringConfig holds monitoring-related configuration +type MonitoringConfig struct { + Prometheus PrometheusConfig `yaml:"prometheus"` + HealthChecks HealthChecksConfig `yaml:"health_checks"` +} + +// PrometheusConfig holds Prometheus metrics configuration +type PrometheusConfig struct { + Enabled bool `yaml:"enabled"` + Port int `yaml:"port"` + Path string `yaml:"path"` +} + +// HealthChecksConfig holds health check configuration +type HealthChecksConfig struct { + Enabled bool `yaml:"enabled"` + Interval string `yaml:"interval"` +} diff --git a/internal/api/protocol.go b/internal/api/protocol.go index 378808f..464e794 100644 --- a/internal/api/protocol.go +++ b/internal/api/protocol.go @@ -4,9 +4,17 @@ import ( "encoding/binary" "encoding/json" "fmt" + "sync" "time" ) +var bufferPool = sync.Pool{ + New: func() interface{} { + buf := make([]byte, 0, 256) + return &buf + }, +} + // Response packet types const ( PacketTypeSuccess = 0x00 @@ -126,7 +134,12 @@ func NewErrorPacket(errorCode byte, message string, details string) *ResponsePac } // NewProgressPacket creates a progress response packet -func NewProgressPacket(progressType byte, value uint32, total uint32, message string) *ResponsePacket 
{ +func NewProgressPacket( + progressType byte, + value uint32, + total uint32, + message string, +) *ResponsePacket { return &ResponsePacket{ PacketType: PacketTypeProgress, Timestamp: uint64(time.Now().Unix()), @@ -168,48 +181,65 @@ func NewLogPacket(level byte, message string) *ResponsePacket { // Serialize converts the packet to binary format func (p *ResponsePacket) Serialize() ([]byte, error) { - var buf []byte + // For small packets, avoid pool overhead + if p.estimatedSize() <= 1024 { + buf := make([]byte, 0, p.estimatedSize()) + return serializePacketToBuffer(p, buf) + } + // Use pool for larger packets + bufPtr := bufferPool.Get().(*[]byte) + defer func() { + *bufPtr = (*bufPtr)[:0] + bufferPool.Put(bufPtr) + }() + buf := *bufPtr + + // Ensure buffer has enough capacity + if cap(buf) < p.estimatedSize() { + buf = make([]byte, 0, p.estimatedSize()) + } else { + buf = buf[:0] + } + + return serializePacketToBuffer(p, buf) +} + +func serializePacketToBuffer(p *ResponsePacket, buf []byte) ([]byte, error) { // Packet type buf = append(buf, p.PacketType) // Timestamp (8 bytes, big-endian) - timestampBytes := make([]byte, 8) - binary.BigEndian.PutUint64(timestampBytes, p.Timestamp) - buf = append(buf, timestampBytes...) + var timestampBytes [8]byte + binary.BigEndian.PutUint64(timestampBytes[:], p.Timestamp) + buf = append(buf, timestampBytes[:]...) // Packet-specific data switch p.PacketType { case PacketTypeSuccess: - buf = append(buf, serializeString(p.SuccessMessage)...) + buf = appendString(buf, p.SuccessMessage) case PacketTypeError: buf = append(buf, p.ErrorCode) - buf = append(buf, serializeString(p.ErrorMessage)...) - buf = append(buf, serializeString(p.ErrorDetails)...) + buf = appendString(buf, p.ErrorMessage) + buf = appendString(buf, p.ErrorDetails) case PacketTypeProgress: buf = append(buf, p.ProgressType) - valueBytes := make([]byte, 4) - binary.BigEndian.PutUint32(valueBytes, p.ProgressValue) - buf = append(buf, valueBytes...) 
- - totalBytes := make([]byte, 4) - binary.BigEndian.PutUint32(totalBytes, p.ProgressTotal) - buf = append(buf, totalBytes...) - - buf = append(buf, serializeString(p.ProgressMessage)...) + buf = appendUint32(buf, p.ProgressValue) + buf = appendUint32(buf, p.ProgressTotal) + buf = appendString(buf, p.ProgressMessage) case PacketTypeStatus: - buf = append(buf, serializeString(p.StatusData)...) + buf = appendString(buf, p.StatusData) case PacketTypeData: - buf = append(buf, serializeString(p.DataType)...) - buf = append(buf, serializeBytes(p.DataPayload)...) + buf = appendString(buf, p.DataType) + buf = appendBytes(buf, p.DataPayload) case PacketTypeLog: buf = append(buf, p.LogLevel) - buf = append(buf, serializeString(p.LogMessage)...) + buf = appendString(buf, p.LogMessage) default: return nil, fmt.Errorf("unknown packet type: %d", p.PacketType) @@ -218,22 +248,49 @@ func (p *ResponsePacket) Serialize() ([]byte, error) { return buf, nil } -// serializeString writes a string with 2-byte length prefix -func serializeString(s string) []byte { - length := uint16(len(s)) - buf := make([]byte, 2+len(s)) - binary.BigEndian.PutUint16(buf[:2], length) - copy(buf[2:], s) - return buf +// appendString writes a string with varint length prefix +func appendString(buf []byte, s string) []byte { + length := uint64(len(s)) + var tmp [binary.MaxVarintLen64]byte + n := binary.PutUvarint(tmp[:], length) + buf = append(buf, tmp[:n]...) + return append(buf, s...) } -// serializeBytes writes bytes with 4-byte length prefix -func serializeBytes(b []byte) []byte { - length := uint32(len(b)) - buf := make([]byte, 4+len(b)) - binary.BigEndian.PutUint32(buf[:4], length) - copy(buf[4:], b) - return buf +// appendBytes writes bytes with varint length prefix +func appendBytes(buf []byte, b []byte) []byte { + length := uint64(len(b)) + var tmp [binary.MaxVarintLen64]byte + n := binary.PutUvarint(tmp[:], length) + buf = append(buf, tmp[:n]...) + return append(buf, b...) 
+} + +func appendUint32(buf []byte, value uint32) []byte { + var tmp [4]byte + binary.BigEndian.PutUint32(tmp[:], value) + return append(buf, tmp[:]...) +} + +func (p *ResponsePacket) estimatedSize() int { + base := 1 + 8 // packet type + timestamp + + switch p.PacketType { + case PacketTypeSuccess: + return base + binary.MaxVarintLen64 + len(p.SuccessMessage) + case PacketTypeError: + return base + 1 + 2*binary.MaxVarintLen64 + len(p.ErrorMessage) + len(p.ErrorDetails) + case PacketTypeProgress: + return base + 1 + 4 + 4 + binary.MaxVarintLen64 + len(p.ProgressMessage) + case PacketTypeStatus: + return base + binary.MaxVarintLen64 + len(p.StatusData) + case PacketTypeData: + return base + binary.MaxVarintLen64 + len(p.DataType) + binary.MaxVarintLen64 + len(p.DataPayload) + case PacketTypeLog: + return base + 1 + binary.MaxVarintLen64 + len(p.LogMessage) + default: + return base + } } // GetErrorMessage returns a human-readable error message for an error code diff --git a/internal/api/protocol_simplified.go b/internal/api/protocol_simplified.go deleted file mode 100644 index 8e809c9..0000000 --- a/internal/api/protocol_simplified.go +++ /dev/null @@ -1,155 +0,0 @@ -package api - -import ( - "encoding/json" - "time" -) - -// Simplified protocol using JSON instead of binary serialization - -// Response represents a simplified API response -type Response struct { - Type string `json:"type"` - Timestamp int64 `json:"timestamp"` - Data interface{} `json:"data,omitempty"` - Error *ErrorInfo `json:"error,omitempty"` -} - -// ErrorInfo represents error information -type ErrorInfo struct { - Code int `json:"code"` - Message string `json:"message"` - Details string `json:"details,omitempty"` -} - -// ProgressInfo represents progress information -type ProgressInfo struct { - Type string `json:"type"` - Value uint32 `json:"value"` - Total uint32 `json:"total"` - Message string `json:"message"` -} - -// LogInfo represents log information -type LogInfo struct { - Level string 
`json:"level"` - Message string `json:"message"` -} - -// Response types -const ( - TypeSuccess = "success" - TypeError = "error" - TypeProgress = "progress" - TypeStatus = "status" - TypeData = "data" - TypeLog = "log" -) - -// Error codes -const ( - ErrUnknown = 0 - ErrInvalidRequest = 1 - ErrAuthFailed = 2 - ErrPermissionDenied = 3 - ErrNotFound = 4 - ErrExists = 5 - ErrServerOverload = 16 - ErrDatabase = 17 - ErrNetwork = 18 - ErrStorage = 19 - ErrTimeout = 20 -) - -// NewSuccessResponse creates a success response -func NewSuccessResponse(message string) *Response { - return &Response{ - Type: TypeSuccess, - Timestamp: time.Now().Unix(), - Data: message, - } -} - -// NewSuccessResponseWithData creates a success response with data -func NewSuccessResponseWithData(message string, data interface{}) *Response { - return &Response{ - Type: TypeData, - Timestamp: time.Now().Unix(), - Data: map[string]interface{}{ - "message": message, - "payload": data, - }, - } -} - -// NewErrorResponse creates an error response -func NewErrorResponse(code int, message, details string) *Response { - return &Response{ - Type: TypeError, - Timestamp: time.Now().Unix(), - Error: &ErrorInfo{ - Code: code, - Message: message, - Details: details, - }, - } -} - -// NewProgressResponse creates a progress response -func NewProgressResponse(progressType string, value, total uint32, message string) *Response { - return &Response{ - Type: TypeProgress, - Timestamp: time.Now().Unix(), - Data: ProgressInfo{ - Type: progressType, - Value: value, - Total: total, - Message: message, - }, - } -} - -// NewStatusResponse creates a status response -func NewStatusResponse(data string) *Response { - return &Response{ - Type: TypeStatus, - Timestamp: time.Now().Unix(), - Data: data, - } -} - -// NewDataResponse creates a data response -func NewDataResponse(dataType string, payload interface{}) *Response { - return &Response{ - Type: TypeData, - Timestamp: time.Now().Unix(), - Data: map[string]interface{}{ 
- "type": dataType, - "payload": payload, - }, - } -} - -// NewLogResponse creates a log response -func NewLogResponse(level, message string) *Response { - return &Response{ - Type: TypeLog, - Timestamp: time.Now().Unix(), - Data: LogInfo{ - Level: level, - Message: message, - }, - } -} - -// ToJSON converts the response to JSON bytes -func (r *Response) ToJSON() ([]byte, error) { - return json.Marshal(r) -} - -// FromJSON creates a response from JSON bytes -func FromJSON(data []byte) (*Response, error) { - var response Response - err := json.Unmarshal(data, &response) - return &response, err -} diff --git a/internal/api/server.go b/internal/api/server.go index 78c6aad..8fbdb6b 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -5,13 +5,17 @@ import ( "net/http" "os" "os/signal" + "strings" "syscall" "time" + "github.com/jfraeys/fetch_ml/internal/audit" + "github.com/jfraeys/fetch_ml/internal/config" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/jupyter" "github.com/jfraeys/fetch_ml/internal/logging" "github.com/jfraeys/fetch_ml/internal/middleware" + "github.com/jfraeys/fetch_ml/internal/prommetrics" "github.com/jfraeys/fetch_ml/internal/queue" "github.com/jfraeys/fetch_ml/internal/storage" ) @@ -22,12 +26,14 @@ type Server struct { httpServer *http.Server logger *logging.Logger expManager *experiment.Manager - taskQueue *queue.TaskQueue + taskQueue queue.Backend db *storage.DB handlers *Handlers sec *middleware.SecurityMiddleware cleanupFuncs []func() jupyterServiceMgr *jupyter.ServiceManager + auditLogger *audit.Logger + promMetrics *prommetrics.Metrics // Prometheus metrics } // NewServer creates a new API server @@ -80,6 +86,11 @@ func (s *Server) initializeComponents() error { return err } + // Initialize database schema (if DB enabled) + if err := s.initDatabaseSchema(); err != nil { + return err + } + // Initialize security s.initSecurity() @@ -87,11 +98,29 @@ func (s *Server) initializeComponents() 
error { s.initJupyterServiceManager() // Initialize handlers - s.handlers = NewHandlers(s.expManager, s.jupyterServiceMgr, s.logger) + s.handlers = NewHandlers(s.expManager, nil, s.logger) return nil } +func (s *Server) initDatabaseSchema() error { + if s.db == nil { + return nil + } + + schema, err := storage.SchemaForDBType(s.config.Database.Type) + if err != nil { + return err + } + + if err := s.db.Initialize(schema); err != nil { + return err + } + + s.logger.Info("database schema initialized", "type", s.config.Database.Type) + return nil +} + // setupLogger creates and configures the logger func (s *Server) setupLogger() *logging.Logger { logger := logging.NewLoggerFromConfig(s.config.Logging) @@ -112,26 +141,38 @@ func (s *Server) initExperimentManager() error { // initTaskQueue initializes the task queue func (s *Server) initTaskQueue() error { - queueCfg := queue.Config{ - RedisAddr: s.config.Redis.Addr, - RedisPassword: s.config.Redis.Password, - RedisDB: s.config.Redis.DB, + backend := strings.ToLower(strings.TrimSpace(s.config.Queue.Backend)) + if backend == "" { + backend = "redis" + } + redisAddr := strings.TrimSpace(s.config.Redis.Addr) + if redisAddr == "" { + redisAddr = "localhost:6379" + } + if strings.TrimSpace(s.config.Redis.URL) != "" { + redisAddr = strings.TrimSpace(s.config.Redis.URL) } - if queueCfg.RedisAddr == "" { - queueCfg.RedisAddr = "localhost:6379" - } - if s.config.Redis.URL != "" { - queueCfg.RedisAddr = s.config.Redis.URL + backendCfg := queue.BackendConfig{ + Backend: queue.QueueBackend(backend), + RedisAddr: redisAddr, + RedisPassword: s.config.Redis.Password, + RedisDB: s.config.Redis.DB, + SQLitePath: s.config.Queue.SQLitePath, + MetricsFlushInterval: 0, } - taskQueue, err := queue.NewTaskQueue(queueCfg) + taskQueue, err := queue.NewBackend(backendCfg) if err != nil { return err } s.taskQueue = taskQueue - s.logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr) + if backend == "sqlite" { + 
s.logger.Info("task queue initialized", "backend", backend, "sqlite_path", s.config.Queue.SQLitePath) + } else { + s.logger.Info("task queue initialized", "backend", backend, "redis_addr", redisAddr) + } // Add cleanup function s.cleanupFuncs = append(s.cleanupFuncs, func() { @@ -220,8 +261,67 @@ func (s *Server) initJupyterServiceManager() { func (s *Server) setupHTTPServer() { mux := http.NewServeMux() - // Register WebSocket handler - wsHandler := NewWSHandler(s.config.BuildAuthConfig(), s.logger, s.expManager, s.taskQueue) + // Initialize Prometheus metrics (if enabled) + if s.config.Monitoring.Prometheus.Enabled { + s.promMetrics = prommetrics.New() + s.logger.Info("prometheus metrics initialized") + + // Register metrics endpoint + metricsPath := s.config.Monitoring.Prometheus.Path + if metricsPath == "" { + metricsPath = "/metrics" + } + mux.Handle(metricsPath, s.promMetrics.Handler()) + s.logger.Info("metrics endpoint registered", "path", metricsPath) + } + + // Initialize health check handler + if s.config.Monitoring.HealthChecks.Enabled { + healthHandler := NewHealthHandler(s) + healthHandler.RegisterRoutes(mux) + mux.HandleFunc("/health/ok", s.handlers.handleHealth) + s.logger.Info("health check endpoints registered") + } + + // Initialize audit logger + var auditLogger *audit.Logger + if s.config.Security.AuditLogging.Enabled && s.config.Security.AuditLogging.LogPath != "" { + al, err := audit.NewLogger( + s.config.Security.AuditLogging.Enabled, + s.config.Security.AuditLogging.LogPath, + s.logger, + ) + if err != nil { + s.logger.Warn("failed to initialize audit logger", "error", err) + } else { + auditLogger = al + s.auditLogger = al + + // Add cleanup function + s.cleanupFuncs = append(s.cleanupFuncs, func() { + s.logger.Info("closing audit logger...") + if err := auditLogger.Close(); err != nil { + s.logger.Error("failed to close audit logger", "error", err) + } + }) + } + } + + // Register WebSocket handler with security config and audit logger + 
securityCfg := getSecurityConfig(s.config) + wsHandler := NewWSHandler( + s.config.BuildAuthConfig(), + s.logger, + s.expManager, + s.config.DataDir, + s.taskQueue, + s.db, + s.jupyterServiceMgr, + securityCfg, + auditLogger, + ) + + // Wrap WebSocket handler with metrics mux.Handle("/ws", wsHandler) // Register HTTP handlers @@ -229,6 +329,7 @@ func (s *Server) setupHTTPServer() { // Wrap with middleware finalHandler := s.wrapWithMiddleware(mux) + finalHandler = s.wrapWithMetrics(finalHandler) s.httpServer = &http.Server{ Addr: s.config.Server.Address, @@ -239,10 +340,25 @@ func (s *Server) setupHTTPServer() { } } +// getSecurityConfig extracts security config from server config +func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig { + return &config.SecurityConfig{ + AllowedOrigins: cfg.Security.AllowedOrigins, + ProductionMode: cfg.Security.ProductionMode, + APIKeyRotationDays: cfg.Security.APIKeyRotationDays, + AuditLogging: config.AuditLoggingConfig{ + Enabled: cfg.Security.AuditLogging.Enabled, + LogPath: cfg.Security.AuditLogging.LogPath, + }, + IPWhitelist: cfg.Security.IPWhitelist, + } +} + // wrapWithMiddleware wraps the handler with security middleware func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/ws" { + // Skip auth for WebSocket and health endpoints + if r.URL.Path == "/ws" || strings.HasPrefix(r.URL.Path, "/health") { mux.ServeHTTP(w, r) return } @@ -250,7 +366,7 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler { handler := s.sec.APIKeyAuth(mux) handler = s.sec.RateLimit(handler) handler = middleware.SecurityHeaders(handler) - handler = middleware.CORS(handler) + handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler) handler = middleware.RequestTimeout(30 * time.Second)(handler) handler = middleware.AuditLogger(handler) if len(s.config.Security.IPWhitelist) > 0 { diff --git 
a/internal/api/server_config.go b/internal/api/server_config.go index 056aa1a..3879b83 100644 --- a/internal/api/server_config.go +++ b/internal/api/server_config.go @@ -1,26 +1,37 @@ package api import ( + "fmt" "log" "os" "path/filepath" + "strings" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/config" + "github.com/jfraeys/fetch_ml/internal/fileutil" "github.com/jfraeys/fetch_ml/internal/logging" "gopkg.in/yaml.v3" ) +type QueueConfig struct { + Backend string `yaml:"backend"` + SQLitePath string `yaml:"sqlite_path"` +} + // ServerConfig holds all server configuration type ServerConfig struct { - BasePath string `yaml:"base_path"` - Auth auth.Config `yaml:"auth"` - Server ServerSection `yaml:"server"` - Security SecurityConfig `yaml:"security"` - Redis RedisConfig `yaml:"redis"` - Database DatabaseConfig `yaml:"database"` - Logging logging.Config `yaml:"logging"` - Resources config.ResourceConfig `yaml:"resources"` + BasePath string `yaml:"base_path"` + DataDir string `yaml:"data_dir"` + Auth auth.Config `yaml:"auth"` + Server ServerSection `yaml:"server"` + Security SecurityConfig `yaml:"security"` + Monitoring MonitoringConfig `yaml:"monitoring"` + Queue QueueConfig `yaml:"queue"` + Redis RedisConfig `yaml:"redis"` + Database DatabaseConfig `yaml:"database"` + Logging logging.Config `yaml:"logging"` + Resources config.ResourceConfig `yaml:"resources"` } // ServerSection holds server-specific configuration @@ -38,9 +49,19 @@ type TLSConfig struct { // SecurityConfig holds security-related configuration type SecurityConfig struct { - RateLimit RateLimitConfig `yaml:"rate_limit"` - IPWhitelist []string `yaml:"ip_whitelist"` - FailedLockout LockoutConfig `yaml:"failed_login_lockout"` + ProductionMode bool `yaml:"production_mode"` + AllowedOrigins []string `yaml:"allowed_origins"` + APIKeyRotationDays int `yaml:"api_key_rotation_days"` + AuditLogging AuditLog `yaml:"audit_logging"` + RateLimit RateLimitConfig `yaml:"rate_limit"` 
+ IPWhitelist []string `yaml:"ip_whitelist"` + FailedLockout LockoutConfig `yaml:"failed_login_lockout"` +} + +// AuditLog holds audit logging configuration +type AuditLog struct { + Enabled bool `yaml:"enabled"` + LogPath string `yaml:"log_path"` } // RateLimitConfig holds rate limiting configuration @@ -108,9 +129,7 @@ func loadConfigFromFile(path string) (*ServerConfig, error) { // secureFileRead safely reads a file func secureFileRead(path string) ([]byte, error) { - // This would use the fileutil.SecureFileRead function - // For now, implement basic file reading - return os.ReadFile(path) + return fileutil.SecureFileRead(path) } // EnsureLogDirectory creates the log directory if needed @@ -144,6 +163,27 @@ func (c *ServerConfig) Validate() error { if c.BasePath == "" { c.BasePath = "/tmp/ml-experiments" } + if c.DataDir == "" { + c.DataDir = config.DefaultDataDir + } + + backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend)) + if backend == "" { + backend = "redis" + c.Queue.Backend = backend + } + if backend != "redis" && backend != "sqlite" { + return fmt.Errorf("queue.backend must be one of 'redis' or 'sqlite'") + } + if backend == "sqlite" { + if strings.TrimSpace(c.Queue.SQLitePath) == "" { + c.Queue.SQLitePath = filepath.Join(c.DataDir, "queue.db") + } + c.Queue.SQLitePath = config.ExpandPath(c.Queue.SQLitePath) + if !filepath.IsAbs(c.Queue.SQLitePath) { + c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath) + } + } return nil } diff --git a/internal/api/ws.go b/internal/api/ws.go deleted file mode 100644 index 72edff2..0000000 --- a/internal/api/ws.go +++ /dev/null @@ -1,652 +0,0 @@ -package api - -import ( - "crypto/tls" - "encoding/binary" - "encoding/json" - "fmt" - "math" - "net/http" - "net/url" - "strings" - "time" - - "github.com/google/uuid" - "github.com/gorilla/websocket" - "github.com/jfraeys/fetch_ml/internal/auth" - "github.com/jfraeys/fetch_ml/internal/experiment" - 
"github.com/jfraeys/fetch_ml/internal/logging" - "github.com/jfraeys/fetch_ml/internal/queue" - "github.com/jfraeys/fetch_ml/internal/telemetry" - "golang.org/x/crypto/acme/autocert" -) - -// Opcodes for binary WebSocket protocol -const ( - OpcodeQueueJob = 0x01 - OpcodeStatusRequest = 0x02 - OpcodeCancelJob = 0x03 - OpcodePrune = 0x04 - OpcodeLogMetric = 0x0A - OpcodeGetExperiment = 0x0B -) - -var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - // Allow localhost and homelab origins for development - origin := r.Header.Get("Origin") - if origin == "" { - return true // Allow same-origin requests - } - - // Parse origin URL - parsedOrigin, err := url.Parse(origin) - if err != nil { - return false - } - - // Allow localhost and local network origins - host := parsedOrigin.Host - return strings.HasSuffix(host, ":8080") || - strings.HasPrefix(host, "localhost:") || - strings.HasPrefix(host, "127.0.0.1:") || - strings.HasPrefix(host, "192.168.") || - strings.HasPrefix(host, "10.") || - strings.HasPrefix(host, "172.") - }, - // Performance optimizations - HandshakeTimeout: 10 * time.Second, - ReadBufferSize: 4096, - WriteBufferSize: 4096, - EnableCompression: true, -} - -// WSHandler handles WebSocket connections for the API. -type WSHandler struct { - authConfig *auth.Config - logger *logging.Logger - expManager *experiment.Manager - queue *queue.TaskQueue -} - -// NewWSHandler creates a new WebSocket handler. 
-func NewWSHandler( - authConfig *auth.Config, - logger *logging.Logger, - expManager *experiment.Manager, - taskQueue *queue.TaskQueue, -) *WSHandler { - return &WSHandler{ - authConfig: authConfig, - logger: logger, - expManager: expManager, - queue: taskQueue, - } -} - -func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Check API key before upgrading WebSocket - apiKey := auth.ExtractAPIKeyFromRequest(r) - - // Validate API key if authentication is enabled - if h.authConfig != nil && h.authConfig.Enabled { - prefixLen := len(apiKey) - if prefixLen > 8 { - prefixLen = 8 - } - h.logger.Info("websocket auth attempt", "api_key_length", len(apiKey), "api_key_prefix", apiKey[:prefixLen]) - if _, err := h.authConfig.ValidateAPIKey(apiKey); err != nil { - h.logger.Warn("websocket authentication failed", "error", err) - http.Error(w, "Invalid API key", http.StatusUnauthorized) - return - } - h.logger.Info("websocket authentication succeeded") - } - - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - h.logger.Error("websocket upgrade failed", "error", err) - return - } - defer func() { - _ = conn.Close() - }() - - h.logger.Info("websocket connection established", "remote", r.RemoteAddr) - - for { - messageType, message, err := conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - h.logger.Error("websocket read error", "error", err) - } - break - } - - if messageType != websocket.BinaryMessage { - h.logger.Warn("received non-binary message") - continue - } - - if err := h.handleMessage(conn, message); err != nil { - h.logger.Error("message handling error", "error", err) - // Send error response - _ = conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) // Error opcode - } - } -} - -func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error { - if len(message) < 1 { - return fmt.Errorf("message too short") - } - - opcode 
:= message[0] - payload := message[1:] - - switch opcode { - case OpcodeQueueJob: - return h.handleQueueJob(conn, payload) - case OpcodeStatusRequest: - return h.handleStatusRequest(conn, payload) - case OpcodeCancelJob: - return h.handleCancelJob(conn, payload) - case OpcodePrune: - return h.handlePrune(conn, payload) - case OpcodeLogMetric: - return h.handleLogMetric(conn, payload) - case OpcodeGetExperiment: - return h.handleGetExperiment(conn, payload) - default: - return fmt.Errorf("unknown opcode: 0x%02x", opcode) - } -} - -func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var] - if len(payload) < 130 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "") - } - - apiKeyHash := string(payload[:64]) - commitID := string(payload[64:128]) - priority := int64(payload[128]) - jobNameLen := int(payload[129]) - - if len(payload) < 130+jobNameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") - } - - jobName := string(payload[130 : 130+jobNameLen]) - - h.logger.Info("queue job request", - "job", jobName, - "priority", priority, - "commit_id", commitID, - ) - - // Validate API key and get user information - var user *auth.User - var err error - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKey(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - // Auth disabled - use default admin user - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } - } - - // Check user permissions - if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { - h.logger.Info("job queued", "job", jobName, "path", 
h.expManager.GetExperimentPath(commitID), "user", user.Name) - } else { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "") - } - - // Create experiment directory and metadata (optimized) - if _, err := telemetry.ExecWithMetrics(h.logger, "experiment.create", 50*time.Millisecond, func() (string, error) { - return "", h.expManager.CreateExperiment(commitID) - }); err != nil { - h.logger.Error("failed to create experiment directory", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error()) - } - - // Add user info to experiment metadata (deferred for performance) - go func() { - meta := &experiment.Metadata{ - CommitID: commitID, - JobName: jobName, - User: user.Name, - Timestamp: time.Now().Unix(), - } - if _, err := telemetry.ExecWithMetrics( - h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { - return "", h.expManager.WriteMetadata(meta) - }); err != nil { - h.logger.Error("failed to save experiment metadata", "error", err) - } - }() - - h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name) - - packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) - - // Enqueue task if queue is available - if h.queue != nil { - taskID := uuid.New().String() - task := &queue.Task{ - ID: taskID, - JobName: jobName, - Args: "", - Status: "queued", - Priority: priority, - CreatedAt: time.Now(), - UserID: user.Name, - Username: user.Name, - CreatedBy: user.Name, - Metadata: map[string]string{ - "commit_id": commitID, // Reduced redundant metadata - }, - } - - if _, err := telemetry.ExecWithMetrics(h.logger, "queue.add_task", 20*time.Millisecond, func() (string, error) { - return "", h.queue.AddTask(task) - }); err != nil { - h.logger.Error("failed to 
enqueue task", "error", err) - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error()) - } - h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name) - } else { - h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) - } - - packetData, err := packet.Serialize() - if err != nil { - h.logger.Error("failed to serialize packet", "error", err) - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response") - } - return conn.WriteMessage(websocket.BinaryMessage, packetData) -} - -func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:64] - if len(payload) < 64 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "") - } - - apiKeyHash := string(payload[0:64]) - h.logger.Info("status request received", "api_key_hash", apiKeyHash[:16]+"...") - - // Validate API key and get user information - var user *auth.User - var err error - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKey(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - // Auth disabled - use default admin user - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } - } - - // Check user permissions for viewing jobs - if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read") - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to view jobs", "") - } - // Get tasks with user filtering - var tasks []*queue.Task - if h.queue != nil { - allTasks, err := h.queue.GetAllTasks() - if err != nil { - 
h.logger.Error("failed to get tasks", "error", err) - return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error()) - } - - // Filter tasks based on user permissions - for _, task := range allTasks { - // If auth is disabled or admin can see all tasks - if h.authConfig == nil || !h.authConfig.Enabled || user.Admin { - tasks = append(tasks, task) - continue - } - - // Users can only see their own tasks - if task.UserID == user.Name || task.CreatedBy == user.Name { - tasks = append(tasks, task) - } - } - } - - // Build status response as raw JSON for CLI compatibility - h.logger.Info("building status response") - status := map[string]interface{}{ - "user": map[string]interface{}{ - "name": user.Name, - "admin": user.Admin, - "roles": user.Roles, - }, - "tasks": map[string]interface{}{ - "total": len(tasks), - "queued": countTasksByStatus(tasks, "queued"), - "running": countTasksByStatus(tasks, "running"), - "failed": countTasksByStatus(tasks, "failed"), - "completed": countTasksByStatus(tasks, "completed"), - }, - "queue": tasks, - } - - h.logger.Info("serializing JSON response") - jsonData, err := json.Marshal(status) - if err != nil { - h.logger.Error("failed to marshal JSON", "error", err) - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response") - } - - h.logger.Info("sending websocket JSON response", "len", len(jsonData)) - return conn.WriteMessage(websocket.BinaryMessage, jsonData) -} - -// countTasksByStatus counts tasks by their status -func countTasksByStatus(tasks []*queue.Task, status string) int { - count := 0 - for _, task := range tasks { - if task.Status == status { - count++ - } - } - return count -} - -func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:64][job_name_len:1][job_name:var] - if len(payload) < 65 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "") - } 
- - // Parse 64-byte hex API key hash - apiKeyHash := string(payload[0:64]) - jobNameLen := int(payload[64]) - - if len(payload) < 65+jobNameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") - } - - jobName := string(payload[65 : 65+jobNameLen]) - - h.logger.Info("cancel job request", "job", jobName) - - // Validate API key and get user information - var user *auth.User - var err error - if h.authConfig != nil { - user, err = h.authConfig.ValidateAPIKey(apiKeyHash) - if err != nil { - h.logger.Error("invalid api key", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) - } - } else { - // Auth disabled - use default admin user - user = &auth.User{ - Name: "default", - Admin: true, - Roles: []string{"admin"}, - Permissions: map[string]bool{ - "*": true, - }, - } - } - - // Check user permissions for canceling jobs - if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") { - h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update") - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to cancel jobs", "") - } - // Find the task and verify ownership - if h.queue != nil { - task, err := h.queue.GetTaskByName(jobName) - if err != nil { - h.logger.Error("task not found", "job", jobName, "error", err) - return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error()) - } - - // Check if user can cancel this task (admin or owner) - if h.authConfig.Enabled && !user.Admin && task.UserID != user.Name && task.CreatedBy != user.Name { - h.logger.Error("unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID) - return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", "") - } - // Cancel the task - if err := h.queue.CancelTask(task.ID); err != nil { - h.logger.Error("failed to cancel 
task", "job", jobName, "task_id", task.ID, "error", err) - return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error()) - } - - h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name) - } else { - h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName) - } - - packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName)) - packetData, err := packet.Serialize() - if err != nil { - h.logger.Error("failed to serialize packet", "error", err) - return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response") - } - return conn.WriteMessage(websocket.BinaryMessage, packetData) -} - -func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:64][prune_type:1][value:4] - if len(payload) < 69 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "") - } - - // Parse 64-byte hex API key hash - apiKeyHash := string(payload[0:64]) - pruneType := payload[64] - value := binary.BigEndian.Uint32(payload[65:69]) - - h.logger.Info("prune request", "type", pruneType, "value", value) - - // Verify API key - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - h.logger.Error("api key verification failed", "error", err) - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error()) - } - } - - // Convert prune parameters - var keepCount int - var olderThanDays int - - switch pruneType { - case 0: - // keep N - keepCount = int(value) - olderThanDays = 0 - case 1: - // older than days - keepCount = 0 - olderThanDays = int(value) - default: - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, fmt.Sprintf("invalid prune type: %d", pruneType), "") - } - - // Perform pruning - pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays) - if err != nil { - 
h.logger.Error("prune failed", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error()) - } - - h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned) - - // Send structured success response - packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned))) - return h.sendResponsePacket(conn, packet) -} - -func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var] - if len(payload) < 141 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "") - } - - apiKeyHash := string(payload[:64]) - commitID := string(payload[64:128]) - step := int(binary.BigEndian.Uint32(payload[128:132])) - valueBits := binary.BigEndian.Uint64(payload[132:140]) - value := math.Float64frombits(valueBits) - nameLen := int(payload[140]) - - if len(payload) < 141+nameLen { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "") - } - - name := string(payload[141 : 141+nameLen]) - - // Verify API key - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error()) - } - } - - if err := h.expManager.LogMetric(commitID, name, value, step); err != nil { - h.logger.Error("failed to log metric", "error", err) - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error()) - } - - return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged")) -} - -func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error { - // Protocol: [api_key_hash:64][commit_id:64] - if len(payload) < 128 { - return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "") - } - - apiKeyHash := string(payload[:64]) - commitID := 
string(payload[64:128]) - - // Verify API key - if h.authConfig != nil && h.authConfig.Enabled { - if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { - return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error()) - } - } - - meta, err := h.expManager.ReadMetadata(commitID) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error()) - } - - metrics, err := h.expManager.GetMetrics(commitID) - if err != nil { - return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error()) - } - - response := map[string]interface{}{ - "metadata": meta, - "metrics": metrics, - } - - return h.sendResponsePacket(conn, NewSuccessPacketWithPayload("Experiment details", response)) -} - -// SetupTLSConfig creates TLS configuration for WebSocket server -func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) { - var server *http.Server - - if certFile != "" && keyFile != "" { - // Use provided certificates - server = &http.Server{ - ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks - TLSConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - }, - }, - } - } else if host != "" { - // Use Let's Encrypt with autocert - certManager := &autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(host), - Cache: autocert.DirCache("/var/www/.cache"), - } - - server = &http.Server{ - ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks - TLSConfig: certManager.TLSConfig(), - } - } - - return server, nil -} - -// verifyAPIKeyHash verifies the provided hex hash against stored API keys -func (h *WSHandler) verifyAPIKeyHash(hexHash string) error { - if h.authConfig == nil || !h.authConfig.Enabled { - return nil // No auth required - } - - // For 
now, just check if it's a valid 64-char hex string - if len(hexHash) != 64 { - return fmt.Errorf("invalid api key hash length") - } - - // Check against stored API keys - for username, entry := range h.authConfig.APIKeys { - if string(entry.Hash) == hexHash { - _ = username // Username found but not needed for verification - return nil // Valid API key found - } - } - - return fmt.Errorf("invalid api key") -} - -// sendErrorPacket sends an error response packet -func (h *WSHandler) sendErrorPacket(conn *websocket.Conn, errorCode byte, message string, details string) error { - packet := NewErrorPacket(errorCode, message, details) - return h.sendResponsePacket(conn, packet) -} - -// sendResponsePacket sends a structured response packet -func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error { - data, err := packet.Serialize() - if err != nil { - h.logger.Error("failed to serialize response packet", "error", err) - // Fallback to simple error response - return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) - } - - return conn.WriteMessage(websocket.BinaryMessage, data) -} - -// sendErrorResponse removed (unused) diff --git a/internal/api/ws_datasets.go b/internal/api/ws_datasets.go new file mode 100644 index 0000000..db2e50e --- /dev/null +++ b/internal/api/ws_datasets.go @@ -0,0 +1,208 @@ +package api + +import ( + "context" + "database/sql" + "encoding/binary" + "encoding/json" + "net/url" + "strings" + "time" + + "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/storage" +) + +func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16] + if len(payload) < 16 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset list payload too short", "") + } + apiKeyHash := payload[:16] + + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + 
ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + if h.db == nil { + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "") + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + datasets, err := h.db.ListDatasets(ctx, 0) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to list datasets", err.Error()) + } + + data, err := json.Marshal(datasets) + if err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeServerOverloaded, + "Failed to serialize response", + err.Error(), + ) + } + return h.sendResponsePacket(conn, NewDataPacket("datasets", data)) +} + +func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][name_len:1][name:var][url_len:2][url:var] + if len(payload) < 16+1+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset register payload too short", "") + } + apiKeyHash := payload[:16] + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + if h.db == nil { + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "") + } + + offset := 16 + nameLen := int(payload[offset]) + offset++ + if nameLen <= 0 || len(payload) < offset+nameLen+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "") + } + name := string(payload[offset : offset+nameLen]) + offset += nameLen + + urlLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if urlLen <= 0 || len(payload) < offset+urlLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url length", "") + } + urlStr := string(payload[offset : offset+urlLen]) + + // Minimal validation (server-side 
authoritative): name non-empty and url parseable. + if strings.TrimSpace(name) == "" { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "") + } + if u, err := url.Parse(urlStr); err != nil || u.Scheme == "" { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "") + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + if err := h.db.UpsertDataset(ctx, &storage.Dataset{Name: name, URL: urlStr}); err != nil { + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to register dataset", err.Error()) + } + return h.sendResponsePacket(conn, NewSuccessPacket("Dataset registered")) +} + +func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][name_len:1][name:var] + if len(payload) < 16+1 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "") + } + apiKeyHash := payload[:16] + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + if h.db == nil { + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "") + } + + offset := 16 + nameLen := int(payload[offset]) + offset++ + if nameLen <= 0 || len(payload) < offset+nameLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "") + } + name := string(payload[offset : offset+nameLen]) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + ds, err := h.db.GetDataset(ctx, name) + if err != nil { + if err == sql.ErrNoRows { + return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Dataset not found", "") + } + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to get dataset", err.Error()) + } + + data, err := 
json.Marshal(ds) + if err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeServerOverloaded, + "Failed to serialize response", + err.Error(), + ) + } + return h.sendResponsePacket(conn, NewDataPacket("dataset", data)) +} + +func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][term_len:1][term:var] + if len(payload) < 16+1 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "") + } + apiKeyHash := payload[:16] + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + if h.db == nil { + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "") + } + + offset := 16 + termLen := int(payload[offset]) + offset++ + if termLen < 0 || len(payload) < offset+termLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid search term length", "") + } + term := string(payload[offset : offset+termLen]) + term = strings.TrimSpace(term) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + datasets, err := h.db.SearchDatasets(ctx, term, 0) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to search datasets", err.Error()) + } + + data, err := json.Marshal(datasets) + if err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeServerOverloaded, + "Failed to serialize response", + err.Error(), + ) + } + return h.sendResponsePacket(conn, NewDataPacket("datasets", data)) +} diff --git a/internal/api/ws_handler.go b/internal/api/ws_handler.go new file mode 100644 index 0000000..0df5d39 --- /dev/null +++ b/internal/api/ws_handler.go @@ -0,0 +1,279 @@ +package api + +import ( + "compress/flate" + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "time" + + 
"github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/audit" + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/config" + "github.com/jfraeys/fetch_ml/internal/experiment" + "github.com/jfraeys/fetch_ml/internal/jupyter" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/storage" +) + +// Opcodes for binary WebSocket protocol +const ( + OpcodeQueueJob = 0x01 + OpcodeStatusRequest = 0x02 + OpcodeCancelJob = 0x03 + OpcodePrune = 0x04 + OpcodeDatasetList = 0x06 + OpcodeDatasetRegister = 0x07 + OpcodeDatasetInfo = 0x08 + OpcodeDatasetSearch = 0x09 + OpcodeLogMetric = 0x0A + OpcodeGetExperiment = 0x0B + OpcodeQueueJobWithTracking = 0x0C + OpcodeQueueJobWithSnapshot = 0x17 + OpcodeStartJupyter = 0x0D + OpcodeStopJupyter = 0x0E + OpcodeRemoveJupyter = 0x18 + OpcodeRestoreJupyter = 0x19 + OpcodeListJupyter = 0x0F + OpcodeValidateRequest = 0x16 +) + +// createUpgrader creates a WebSocket upgrader with the given security configuration +func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader { + return websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return true // Allow same-origin requests + } + + // Production mode: strict checking against allowed origins + if securityCfg != nil && securityCfg.ProductionMode { + for _, allowed := range securityCfg.AllowedOrigins { + if origin == allowed { + return true + } + } + return false // Reject if not in allowed list + } + + // Development mode: allow localhost and local network origins + parsedOrigin, err := url.Parse(origin) + if err != nil { + return false + } + + host := parsedOrigin.Host + return strings.HasSuffix(host, ":8080") || + strings.HasPrefix(host, "localhost:") || + strings.HasPrefix(host, "127.0.0.1:") || + strings.HasPrefix(host, "192.168.") || + strings.HasPrefix(host, "10.") || + 
strings.HasPrefix(host, "172.") + }, + // Performance optimizations + HandshakeTimeout: 10 * time.Second, + ReadBufferSize: 16 * 1024, + WriteBufferSize: 16 * 1024, + EnableCompression: true, + } +} + +// WSHandler handles WebSocket connections for the API. +type WSHandler struct { + authConfig *auth.Config + logger *logging.Logger + expManager *experiment.Manager + dataDir string + queue queue.Backend + db *storage.DB + jupyterServiceMgr *jupyter.ServiceManager + securityConfig *config.SecurityConfig + auditLogger *audit.Logger + upgrader websocket.Upgrader +} + +// NewWSHandler creates a new WebSocket handler. +func NewWSHandler( + authConfig *auth.Config, + logger *logging.Logger, + expManager *experiment.Manager, + dataDir string, + taskQueue queue.Backend, + db *storage.DB, + jupyterServiceMgr *jupyter.ServiceManager, + securityConfig *config.SecurityConfig, + auditLogger *audit.Logger, +) *WSHandler { + return &WSHandler{ + authConfig: authConfig, + logger: logger.Component(logging.EnsureTrace(context.Background()), "ws-handler"), + expManager: expManager, + dataDir: dataDir, + queue: taskQueue, + db: db, + jupyterServiceMgr: jupyterServiceMgr, + securityConfig: securityConfig, + auditLogger: auditLogger, + upgrader: createUpgrader(securityConfig), + } +} + +// enableLowLatencyTCP disables Nagle's algorithm to reduce latency for small packets. 
+func enableLowLatencyTCP(conn *websocket.Conn, logger *logging.Logger) { + if conn == nil { + return + } + + if tcpConn, ok := conn.UnderlyingConn().(*net.TCPConn); ok { + if err := tcpConn.SetNoDelay(true); err != nil { + logger.Warn("failed to enable tcp no delay", "error", err) + } + } +} + +func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Add security headers + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-XSS-Protection", "1; mode=block") + if r.TLS != nil { + // Only set HSTS if using HTTPS + w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") + } + + // Check API key before upgrading WebSocket + apiKey := auth.ExtractAPIKeyFromRequest(r) + clientIP := r.RemoteAddr + + // Validate API key if authentication is enabled + if h.authConfig != nil && h.authConfig.Enabled { + prefixLen := len(apiKey) + if prefixLen > 8 { + prefixLen = 8 + } + h.logger.Info( + "websocket auth attempt", + "api_key_length", + len(apiKey), + "api_key_prefix", + apiKey[:prefixLen], + ) + + userID, err := h.authConfig.ValidateAPIKey(apiKey) + if err != nil { + h.logger.Warn("websocket authentication failed", "error", err) + + // Audit log failed authentication + if h.auditLogger != nil { + h.auditLogger.LogAuthAttempt(apiKey[:prefixLen], clientIP, false, err.Error()) + } + + http.Error(w, "Invalid API key", http.StatusUnauthorized) + return + } + h.logger.Info("websocket authentication succeeded") + + // Audit log successful authentication + if h.auditLogger != nil && userID != nil { + h.auditLogger.LogAuthAttempt(userID.Name, clientIP, true, "") + } + } + + conn, err := h.upgrader.Upgrade(w, r, nil) + if err != nil { + h.logger.Error("websocket upgrade failed", "error", err) + return + } + conn.EnableWriteCompression(true) + if err := conn.SetCompressionLevel(flate.BestSpeed); err != nil { + h.logger.Warn("failed to set websocket compression level", "error", 
err) + } + enableLowLatencyTCP(conn, h.logger) + defer func() { + _ = conn.Close() + }() + + h.logger.Info("websocket connection established", "remote", r.RemoteAddr) + + for { + messageType, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError( + err, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + ) { + h.logger.Error("websocket read error", "error", err) + } + break + } + + if messageType != websocket.BinaryMessage { + h.logger.Warn("received non-binary message") + continue + } + + if err := h.handleMessage(conn, message); err != nil { + h.logger.Error("message handling error", "error", err) + // Send structured error response so CLI clients can parse it. + // (Raw fallback bytes cause client-side InvalidPacket errors.) + _ = h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "message handling error", err.Error()) + } + } +} + +func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error { + if len(message) < 1 { + return fmt.Errorf("message too short") + } + + opcode := message[0] + payload := message[1:] + + switch opcode { + case OpcodeQueueJob: + return h.handleQueueJob(conn, payload) + case OpcodeQueueJobWithTracking: + return h.handleQueueJobWithTracking(conn, payload) + case OpcodeQueueJobWithSnapshot: + return h.handleQueueJobWithSnapshot(conn, payload) + case OpcodeStatusRequest: + return h.handleStatusRequest(conn, payload) + case OpcodeCancelJob: + return h.handleCancelJob(conn, payload) + case OpcodePrune: + return h.handlePrune(conn, payload) + case OpcodeDatasetList: + return h.handleDatasetList(conn, payload) + case OpcodeDatasetRegister: + return h.handleDatasetRegister(conn, payload) + case OpcodeDatasetInfo: + return h.handleDatasetInfo(conn, payload) + case OpcodeDatasetSearch: + return h.handleDatasetSearch(conn, payload) + case OpcodeLogMetric: + return h.handleLogMetric(conn, payload) + case OpcodeGetExperiment: + return h.handleGetExperiment(conn, payload) + case 
OpcodeStartJupyter: + return h.handleStartJupyter(conn, payload) + case OpcodeStopJupyter: + return h.handleStopJupyter(conn, payload) + case OpcodeRemoveJupyter: + return h.handleRemoveJupyter(conn, payload) + case OpcodeRestoreJupyter: + return h.handleRestoreJupyter(conn, payload) + case OpcodeListJupyter: + return h.handleListJupyter(conn, payload) + case OpcodeValidateRequest: + return h.handleValidateRequest(conn, payload) + default: + return fmt.Errorf("unknown opcode: 0x%02x", opcode) + } +} diff --git a/internal/api/ws_jobs.go b/internal/api/ws_jobs.go new file mode 100644 index 0000000..defe033 --- /dev/null +++ b/internal/api/ws_jobs.go @@ -0,0 +1,1268 @@ +package api + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "math" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/experiment" + "github.com/jfraeys/fetch_ml/internal/fileutil" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/storage" + "github.com/jfraeys/fetch_ml/internal/telemetry" + "github.com/jfraeys/fetch_ml/internal/worker" +) + +func fileSHA256Hex(path string) (string, error) { + f, err := os.Open(filepath.Clean(path)) + if err != nil { + return "", err + } + defer func() { _ = f.Close() }() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +func expectedProvenanceForCommit( + expMgr *experiment.Manager, + commitID string, +) (map[string]string, error) { + out := map[string]string{} + manifest, err := expMgr.ReadManifest(commitID) + if err != nil { + return nil, err + } + if manifest == nil || manifest.OverallSHA == "" { + return nil, fmt.Errorf("missing manifest overall_sha") + } + out["experiment_manifest_overall_sha"] = manifest.OverallSHA + + filesPath 
:= expMgr.GetFilesPath(commitID) + depName, err := worker.SelectDependencyManifest(filesPath) + if err == nil && strings.TrimSpace(depName) != "" { + depPath := filepath.Join(filesPath, depName) + sha, err := fileSHA256Hex(depPath) + if err == nil && strings.TrimSpace(sha) != "" { + out["deps_manifest_name"] = depName + out["deps_manifest_sha256"] = sha + } + } + return out, nil +} + +func ensureMinimalExperimentFiles(expMgr *experiment.Manager, commitID string) error { + if expMgr == nil { + return fmt.Errorf("missing experiment manager") + } + commitID = strings.TrimSpace(commitID) + if commitID == "" { + return fmt.Errorf("missing commit id") + } + filesPath := expMgr.GetFilesPath(commitID) + if err := os.MkdirAll(filesPath, 0750); err != nil { + return err + } + + trainPath := filepath.Join(filesPath, "train.py") + if _, err := os.Stat(trainPath); os.IsNotExist(err) { + if err := fileutil.SecureFileWrite(trainPath, []byte("print('ok')\n"), 0640); err != nil { + return err + } + } + + reqPath := filepath.Join(filesPath, "requirements.txt") + if _, err := os.Stat(reqPath); os.IsNotExist(err) { + if err := fileutil.SecureFileWrite(reqPath, []byte("numpy==1.0.0\n"), 0640); err != nil { + return err + } + } + + return nil +} + +func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] + if len(payload) < 38 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "") + } + + apiKeyHash := payload[:16] + commitID := payload[16:36] + priority := int64(payload[36]) + jobNameLen := int(payload[37]) + + if len(payload) < 38+jobNameLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") + } + + jobName := string(payload[38 : 38+jobNameLen]) + + resources, resErr := parseOptionalResourceRequest(payload[38+jobNameLen:]) + if resErr != nil { + return h.sendErrorPacket( + conn, + 
ErrorCodeInvalidRequest, + "invalid resource request", + resErr.Error(), + ) + } + + h.logger.Info("queue job request", + "job", jobName, + "priority", priority, + "commit_id", fmt.Sprintf("%x", commitID), + ) + + // Validate API key and get user information + var user *auth.User + var err error + if h.authConfig != nil { + user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", "error", err) + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + } + } else { + // Auth disabled - use default admin user + user = &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + } + } + + // Check user permissions + if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { + h.logger.Info( + "job queued", + "job", jobName, + "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)), + "user", user.Name, + ) + } else { + h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") + return h.sendErrorPacket( + conn, + ErrorCodePermissionDenied, + "Insufficient permissions to create jobs", + "", + ) + } + + // Create experiment directory and metadata (optimized) + if _, err := telemetry.ExecWithMetrics( + h.logger, + "experiment.create", + 50*time.Millisecond, + func() (string, error) { + return "", h.expManager.CreateExperiment(fmt.Sprintf("%x", commitID)) + }, + ); err != nil { + h.logger.Error("failed to create experiment directory", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to create experiment directory", + err.Error(), + ) + } + + meta := &experiment.Metadata{ + CommitID: fmt.Sprintf("%x", commitID), + JobName: jobName, + User: user.Name, + Timestamp: time.Now().Unix(), + } + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { + 
return "", h.expManager.WriteMetadata(meta) + }); err != nil { + h.logger.Error("failed to save experiment metadata", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to save experiment metadata", + err.Error(), + ) + } + + // Generate and write content integrity manifest + commitIDStr := fmt.Sprintf("%x", commitID) + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) { + return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr) + }); err != nil { + h.logger.Error("failed to ensure minimal experiment files", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to initialize experiment files", + err.Error(), + ) + } + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.generate_manifest", 100*time.Millisecond, func() (string, error) { + manifest, err := h.expManager.GenerateManifest(commitIDStr) + if err != nil { + return "", fmt.Errorf("failed to generate manifest: %w", err) + } + if err := h.expManager.WriteManifest(manifest); err != nil { + return "", fmt.Errorf("failed to write manifest: %w", err) + } + return "", nil + }); err != nil { + h.logger.Error("failed to generate/write manifest", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to generate content integrity manifest", + err.Error(), + ) + } + + // Add user info to experiment metadata (deferred for performance) + go func() { + if h.db != nil { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + exp := &storage.Experiment{ + ID: fmt.Sprintf("%x", commitID), + Name: jobName, + Status: "pending", + UserID: user.Name, + } + if _, err := telemetry.ExecWithMetrics( + h.logger, + "db.experiments.upsert", + 50*time.Millisecond, + func() (string, error) { + return "", h.db.UpsertExperiment(ctx, exp) + }, + ); err != nil { + h.logger.Error("failed to upsert experiment 
row", "error", err) + } + } + + }() + + h.logger.Info( + "job queued", + "job", jobName, + "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)), + "user", user.Name, + ) + + return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, nil, resources) +} + +func (h *WSHandler) handleQueueJobWithSnapshot(conn *websocket.Conn, payload []byte) error { + if len(payload) < 40 { + return h.sendErrorPacket( + conn, + ErrorCodeInvalidRequest, + "queue job with snapshot payload too short", + "", + ) + } + + apiKeyHash := payload[:16] + commitID := payload[16:36] + priority := int64(payload[36]) + jobNameLen := int(payload[37]) + + if len(payload) < 38+jobNameLen+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") + } + jobName := string(payload[38 : 38+jobNameLen]) + offset := 38 + jobNameLen + + snapIDLen := int(payload[offset]) + offset++ + if snapIDLen < 1 || len(payload) < offset+snapIDLen+1 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot id length", "") + } + snapshotID := string(payload[offset : offset+snapIDLen]) + offset += snapIDLen + + snapSHALen := int(payload[offset]) + offset++ + if snapSHALen < 1 || len(payload) < offset+snapSHALen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid snapshot sha length", "") + } + snapshotSHA := string(payload[offset : offset+snapSHALen]) + offset += snapSHALen + + resources, resErr := parseOptionalResourceRequest(payload[offset:]) + if resErr != nil { + return h.sendErrorPacket( + conn, + ErrorCodeInvalidRequest, + "invalid resource request", + resErr.Error(), + ) + } + + h.logger.Info("queue job with snapshot request", + "job", jobName, + "priority", priority, + "commit_id", fmt.Sprintf("%x", commitID), + "snapshot_id", snapshotID, + ) + + var user *auth.User + var err error + if h.authConfig != nil { + user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", 
"error", err) + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + } + } else { + user = &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + } + } + + if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { + h.logger.Info( + "job queued", + "job", jobName, + "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)), + "user", user.Name, + ) + } else { + h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") + return h.sendErrorPacket( + conn, + ErrorCodePermissionDenied, + "Insufficient permissions to create jobs", + "", + ) + } + + if _, err := telemetry.ExecWithMetrics( + h.logger, + "experiment.create", + 50*time.Millisecond, + func() (string, error) { + return "", h.expManager.CreateExperiment(fmt.Sprintf("%x", commitID)) + }, + ); err != nil { + h.logger.Error("failed to create experiment directory", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to create experiment directory", + err.Error(), + ) + } + + meta := &experiment.Metadata{ + CommitID: fmt.Sprintf("%x", commitID), + JobName: jobName, + User: user.Name, + Timestamp: time.Now().Unix(), + } + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { + return "", h.expManager.WriteMetadata(meta) + }); err != nil { + h.logger.Error("failed to save experiment metadata", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to save experiment metadata", + err.Error(), + ) + } + + commitIDStr := fmt.Sprintf("%x", commitID) + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.ensure_minimal_files", 50*time.Millisecond, func() (string, error) { + return "", ensureMinimalExperimentFiles(h.expManager, commitIDStr) + }); err != nil { + h.logger.Error("failed to ensure 
minimal experiment files", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to initialize experiment files", + err.Error(), + ) + } + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.generate_manifest", 100*time.Millisecond, func() (string, error) { + manifest, err := h.expManager.GenerateManifest(commitIDStr) + if err != nil { + return "", fmt.Errorf("failed to generate manifest: %w", err) + } + if err := h.expManager.WriteManifest(manifest); err != nil { + return "", fmt.Errorf("failed to write manifest: %w", err) + } + return "", nil + }); err != nil { + h.logger.Error("failed to generate/write manifest", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to generate content integrity manifest", + err.Error(), + ) + } + + go func() { + if h.db != nil { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + exp := &storage.Experiment{ + ID: fmt.Sprintf("%x", commitID), + Name: jobName, + Status: "pending", + UserID: user.Name, + } + if _, err := telemetry.ExecWithMetrics( + h.logger, + "db.experiments.upsert", + 50*time.Millisecond, + func() (string, error) { + return "", h.db.UpsertExperiment(ctx, exp) + }, + ); err != nil { + h.logger.Error("failed to upsert experiment row", "error", err) + } + } + }() + + h.logger.Info( + "job queued", + "job", jobName, + "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)), + "user", user.Name, + ) + + return h.enqueueTaskAndRespondWithSnapshot( + conn, + user, + jobName, + priority, + commitID, + nil, + resources, + snapshotID, + snapshotSHA, + ) +} + +// handleQueueJobWithTracking queues a job with optional tracking configuration. 
+// Protocol: [api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var] +// [tracking_json_len:2][tracking_json:var] +func (h *WSHandler) handleQueueJobWithTracking(conn *websocket.Conn, payload []byte) error { + if len(payload) < 38+2 { // minimum with zero-length tracking JSON + return h.sendErrorPacket( + conn, + ErrorCodeInvalidRequest, + "queue job with tracking payload too short", + "", + ) + } + + apiKeyHash := payload[:16] + commitID := payload[16:36] + priority := int64(payload[36]) + jobNameLen := int(payload[37]) + + // Ensure we have job name and two bytes for tracking length + if len(payload) < 38+jobNameLen+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") + } + + jobName := string(payload[38 : 38+jobNameLen]) + offset := 38 + jobNameLen + trackingLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + + if trackingLen < 0 || len(payload) < offset+trackingLen { + return h.sendErrorPacket( + conn, + ErrorCodeInvalidRequest, + "invalid tracking json length", + "", + ) + } + + var trackingCfg *queue.TrackingConfig + if trackingLen > 0 { + var cfg queue.TrackingConfig + if err := json.Unmarshal(payload[offset:offset+trackingLen], &cfg); err != nil { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid tracking json", err.Error()) + } + trackingCfg = &cfg + } + + offset += trackingLen + resources, resErr := parseOptionalResourceRequest(payload[offset:]) + if resErr != nil { + return h.sendErrorPacket( + conn, + ErrorCodeInvalidRequest, + "invalid resource request", + resErr.Error(), + ) + } + + h.logger.Info("queue job with tracking request", + "job", jobName, + "priority", priority, + "commit_id", fmt.Sprintf("%x", commitID), + ) + + // Validate API key and get user information + var user *auth.User + var err error + if h.authConfig != nil { + user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", "error", 
err) + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + } + } else { + // Auth disabled - use default admin user + user = &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + } + } + + // Check user permissions + if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") { + h.logger.Info( + "job queued (with tracking)", + "job", jobName, + "path", h.expManager.GetExperimentPath(fmt.Sprintf("%x", commitID)), + "user", user.Name, + ) + } else { + h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create") + return h.sendErrorPacket( + conn, + ErrorCodePermissionDenied, + "Insufficient permissions to create jobs", + "", + ) + } + + // Create experiment directory and metadata (optimized) + if _, err := telemetry.ExecWithMetrics( + h.logger, + "experiment.create", + 50*time.Millisecond, + func() (string, error) { + return "", h.expManager.CreateExperiment(fmt.Sprintf("%x", commitID)) + }, + ); err != nil { + h.logger.Error("failed to create experiment directory", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to create experiment directory", + err.Error(), + ) + } + + meta := &experiment.Metadata{ + CommitID: fmt.Sprintf("%x", commitID), + JobName: jobName, + User: user.Name, + Timestamp: time.Now().Unix(), + } + if _, err := telemetry.ExecWithMetrics( + h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) { + return "", h.expManager.WriteMetadata(meta) + }); err != nil { + h.logger.Error("failed to save experiment metadata", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to save experiment metadata", + err.Error(), + ) + } + + // Generate and write content integrity manifest + commitIDStr := fmt.Sprintf("%x", commitID) + if _, err := telemetry.ExecWithMetrics( + h.logger, 
"experiment.generate_manifest", 100*time.Millisecond, func() (string, error) { + manifest, err := h.expManager.GenerateManifest(commitIDStr) + if err != nil { + return "", fmt.Errorf("failed to generate manifest: %w", err) + } + if err := h.expManager.WriteManifest(manifest); err != nil { + return "", fmt.Errorf("failed to write manifest: %w", err) + } + return "", nil + }); err != nil { + h.logger.Error("failed to generate/write manifest", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to generate content integrity manifest", + err.Error(), + ) + } + + // Add user info to experiment metadata (deferred for performance) + go func() { + if h.db != nil { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + exp := &storage.Experiment{ + ID: fmt.Sprintf("%x", commitID), + Name: jobName, + Status: "pending", + UserID: user.Name, + } + if _, err := telemetry.ExecWithMetrics( + h.logger, + "db.experiments.upsert", + 50*time.Millisecond, + func() (string, error) { + return "", h.db.UpsertExperiment(ctx, exp) + }, + ); err != nil { + h.logger.Error("failed to upsert experiment row", "error", err) + } + } + + }() + + return h.enqueueTaskAndRespond(conn, user, jobName, priority, commitID, trackingCfg, resources) +} + +// enqueueTaskAndRespond enqueues a task and sends a success response. 
+func (h *WSHandler) enqueueTaskAndRespond( + conn *websocket.Conn, + user *auth.User, + jobName string, + priority int64, + commitID []byte, + tracking *queue.TrackingConfig, + resources *resourceRequest, +) error { + packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) + + commitIDStr := fmt.Sprintf("%x", commitID) + prov, provErr := expectedProvenanceForCommit(h.expManager, commitIDStr) + if provErr != nil { + h.logger.Error("failed to compute expected provenance; refusing to enqueue", + "commit_id", commitIDStr, + "error", provErr) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to compute expected provenance", + provErr.Error(), + ) + } + + // Enqueue task if queue is available + if h.queue != nil { + taskID := uuid.New().String() + task := &queue.Task{ + ID: taskID, + JobName: jobName, + Args: "", + Status: "queued", + Priority: priority, + CreatedAt: time.Now(), + UserID: user.Name, + Username: user.Name, + CreatedBy: user.Name, + Metadata: map[string]string{ + "commit_id": commitIDStr, + }, + Tracking: tracking, + } + for k, v := range prov { + if v != "" { + task.Metadata[k] = v + } + } + if resources != nil { + task.CPU = resources.CPU + task.MemoryGB = resources.MemoryGB + task.GPU = resources.GPU + task.GPUMemory = resources.GPUMemory + } + + if _, err := telemetry.ExecWithMetrics( + h.logger, + "queue.add_task", + 20*time.Millisecond, + func() (string, error) { + return "", h.queue.AddTask(task) + }, + ); err != nil { + h.logger.Error("failed to enqueue task", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeDatabaseError, + "Failed to enqueue task", + err.Error(), + ) + } + h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name) + } else { + h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) + } + + packetData, err := packet.Serialize() + if err != nil { + h.logger.Error("failed to serialize packet", "error", err) + return 
h.sendErrorPacket( + conn, + ErrorCodeServerOverloaded, + "Internal error", + "Failed to serialize response", + ) + } + return conn.WriteMessage(websocket.BinaryMessage, packetData) +} + +func (h *WSHandler) enqueueTaskAndRespondWithSnapshot( + conn *websocket.Conn, + user *auth.User, + jobName string, + priority int64, + commitID []byte, + tracking *queue.TrackingConfig, + resources *resourceRequest, + snapshotID string, + snapshotSHA string, +) error { + packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName)) + + commitIDStr := fmt.Sprintf("%x", commitID) + prov, provErr := expectedProvenanceForCommit(h.expManager, commitIDStr) + if provErr != nil { + h.logger.Error("failed to compute expected provenance; refusing to enqueue", + "commit_id", commitIDStr, + "error", provErr) + return h.sendErrorPacket( + conn, + ErrorCodeStorageError, + "Failed to compute expected provenance", + provErr.Error(), + ) + } + + if h.queue != nil { + taskID := uuid.New().String() + task := &queue.Task{ + ID: taskID, + JobName: jobName, + Args: "", + Status: "queued", + Priority: priority, + CreatedAt: time.Now(), + UserID: user.Name, + Username: user.Name, + CreatedBy: user.Name, + SnapshotID: strings.TrimSpace(snapshotID), + Metadata: map[string]string{ + "commit_id": commitIDStr, + "snapshot_sha256": strings.TrimSpace(snapshotSHA), + }, + Tracking: tracking, + } + for k, v := range prov { + if v != "" { + task.Metadata[k] = v + } + } + if resources != nil { + task.CPU = resources.CPU + task.MemoryGB = resources.MemoryGB + task.GPU = resources.GPU + task.GPUMemory = resources.GPUMemory + } + + if _, err := telemetry.ExecWithMetrics( + h.logger, + "queue.add_task", + 20*time.Millisecond, + func() (string, error) { + return "", h.queue.AddTask(task) + }, + ); err != nil { + h.logger.Error("failed to enqueue task", "error", err) + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error()) + } + h.logger.Info("task enqueued", 
"task_id", taskID, "job", jobName, "user", user.Name) + } else { + h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName) + } + + packetData, err := packet.Serialize() + if err != nil { + h.logger.Error("failed to serialize packet", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeServerOverloaded, + "Internal error", + "Failed to serialize response", + ) + } + return conn.WriteMessage(websocket.BinaryMessage, packetData) +} + +type resourceRequest struct { + CPU int + MemoryGB int + GPU int + GPUMemory string +} + +// parseOptionalResourceRequest parses an optional tail encoding: +// [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var] +// If payload is empty, returns nil. +func parseOptionalResourceRequest(payload []byte) (*resourceRequest, error) { + if len(payload) == 0 { + return nil, nil + } + if len(payload) < 4 { + return nil, fmt.Errorf("resource payload too short") + } + cpu := int(payload[0]) + mem := int(payload[1]) + gpu := int(payload[2]) + gpuMemLen := int(payload[3]) + if gpuMemLen < 0 || len(payload) < 4+gpuMemLen { + return nil, fmt.Errorf("invalid gpu memory length") + } + gpuMem := "" + if gpuMemLen > 0 { + gpuMem = string(payload[4 : 4+gpuMemLen]) + } + return &resourceRequest{CPU: cpu, MemoryGB: mem, GPU: gpu, GPUMemory: gpuMem}, nil +} + +func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16] + if len(payload) < 16 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "") + } + + apiKeyHash := payload[:16] + h.logger.Info("status request received", "api_key_hash", fmt.Sprintf("%x", apiKeyHash)) + + // Validate API key and get user information + var user *auth.User + var err error + if h.authConfig != nil { + user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", "error", err) + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid 
API key", err.Error()) + } + } else { + // Auth disabled - use default admin user + user = &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + } + } + + // Check user permissions for viewing jobs + if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") { + h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read") + return h.sendErrorPacket( + conn, + ErrorCodePermissionDenied, + "Insufficient permissions to view jobs", + "", + ) + } + // Get tasks with user filtering + var tasks []*queue.Task + if h.queue != nil { + allTasks, err := h.queue.GetAllTasks() + if err != nil { + h.logger.Error("failed to get tasks", "error", err) + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error()) + } + + // Filter tasks based on user permissions + for _, task := range allTasks { + // If auth is disabled or admin can see all tasks + if h.authConfig == nil || !h.authConfig.Enabled || user.Admin { + tasks = append(tasks, task) + continue + } + + // Users can only see their own tasks + if task.UserID == user.Name || task.CreatedBy == user.Name { + tasks = append(tasks, task) + } + } + } + + // Build status response as raw JSON for CLI compatibility + h.logger.Info("building status response") + status := map[string]any{ + "user": map[string]any{ + "name": user.Name, + "admin": user.Admin, + "roles": user.Roles, + }, + "tasks": map[string]any{ + "total": len(tasks), + "queued": countTasksByStatus(tasks, "queued"), + "running": countTasksByStatus(tasks, "running"), + "failed": countTasksByStatus(tasks, "failed"), + "completed": countTasksByStatus(tasks, "completed"), + }, + "queue": tasks, + } + if h.queue != nil { + if states, err := h.queue.GetAllWorkerPrewarmStates(); err == nil { + sort.Slice(states, func(i, j int) bool { + if states[i].WorkerID != states[j].WorkerID { + return states[i].WorkerID < 
states[j].WorkerID + } + return states[i].TaskID < states[j].TaskID + }) + status["prewarm"] = states + } + } + + h.logger.Info("serializing JSON response") + jsonData, err := json.Marshal(status) + if err != nil { + h.logger.Error("failed to marshal JSON", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeServerOverloaded, + "Internal error", + "Failed to serialize response", + ) + } + + h.logger.Info("sending websocket JSON response", "len", len(jsonData)) + + // Send as binary protocol packet + packet := NewDataPacket("status", jsonData) + return h.sendResponsePacket(conn, packet) +} + +// countTasksByStatus counts tasks by their status +func countTasksByStatus(tasks []*queue.Task, status string) int { + count := 0 + for _, task := range tasks { + if task.Status == status { + count++ + } + } + return count +} + +func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][job_name_len:1][job_name:var] + if len(payload) < 18 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "") + } + + // Parse 16-byte binary API key hash + apiKeyHash := payload[:16] + jobNameLen := int(payload[16]) + + if len(payload) < 17+jobNameLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "") + } + + jobName := string(payload[17 : 17+jobNameLen]) + + h.logger.Info("cancel job request", "job", jobName) + + // Validate API key and get user information + var user *auth.User + var err error + if h.authConfig != nil { + user, err = h.authConfig.ValidateAPIKeyHash(apiKeyHash) + if err != nil { + h.logger.Error("invalid api key", "error", err) + return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error()) + } + } else { + // Auth disabled - use default admin user + user = &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + } + } + + // Check 
user permissions for canceling jobs + if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") { + h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update") + return h.sendErrorPacket( + conn, + ErrorCodePermissionDenied, + "Insufficient permissions to cancel jobs", + "", + ) + } + // Find the task and verify ownership + if h.queue != nil { + task, err := h.queue.GetTaskByName(jobName) + if err != nil { + h.logger.Error("task not found", "job", jobName, "error", err) + return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error()) + } + + // Check if user can cancel this task (admin or owner) + if h.authConfig != nil && + h.authConfig.Enabled && + !user.Admin && + task.UserID != user.Name && + task.CreatedBy != user.Name { + h.logger.Error( + "unauthorized job cancellation attempt", + "user", user.Name, + "job", jobName, + "task_owner", task.UserID, + ) + return h.sendErrorPacket( + conn, + ErrorCodePermissionDenied, + "You can only cancel your own jobs", + "", + ) + } + // Cancel the task + if err := h.queue.CancelTask(task.ID); err != nil { + h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err) + return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error()) + } + + h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name) + } else { + h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName) + } + + packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName)) + packetData, err := packet.Serialize() + if err != nil { + h.logger.Error("failed to serialize packet", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeServerOverloaded, + "Internal error", + "Failed to serialize response", + ) + } + return conn.WriteMessage(websocket.BinaryMessage, packetData) +} + +func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error 
{ + // Protocol: [api_key_hash:16][prune_type:1][value:4] + if len(payload) < 21 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "") + } + + // Parse 16-byte binary API key hash + apiKeyHash := payload[:16] + pruneType := payload[16] + value := binary.BigEndian.Uint32(payload[17:21]) + + h.logger.Info("prune request", "type", pruneType, "value", value) + + // Verify API key + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + h.logger.Error("api key verification failed", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + + // Convert prune parameters + var keepCount int + var olderThanDays int + + switch pruneType { + case 0: + // keep N + keepCount = int(value) + olderThanDays = 0 + case 1: + // older than days + keepCount = 0 + olderThanDays = int(value) + default: + return h.sendErrorPacket( + conn, + ErrorCodeInvalidRequest, + fmt.Sprintf("invalid prune type: %d", pruneType), + "", + ) + } + + // Perform pruning + pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays) + if err != nil { + h.logger.Error("prune failed", "error", err) + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error()) + } + if h.queue != nil { + _ = h.queue.SignalPrewarmGC() + } + + h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned) + + // Send structured success response + packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned))) + return h.sendResponsePacket(conn, packet) +} + +func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][commit_id:20][step:4][value:8][name_len:1][name:var] + if len(payload) < 51 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "") + } + + apiKeyHash := payload[:16] + commitID := 
payload[16:36] + step := int(binary.BigEndian.Uint32(payload[36:40])) + valueBits := binary.BigEndian.Uint64(payload[40:48]) + value := math.Float64frombits(valueBits) + nameLen := int(payload[48]) + + if len(payload) < 49+nameLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "") + } + + name := string(payload[49 : 49+nameLen]) + + // Verify API key + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + h.logger.Error("api key verification failed", "error", err) + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + + if err := h.expManager.LogMetric(fmt.Sprintf("%x", commitID), name, value, step); err != nil { + h.logger.Error("failed to log metric", "error", err) + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error()) + } + + return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged")) +} + +func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][commit_id:20] + if len(payload) < 36 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "") + } + + apiKeyHash := payload[:16] + commitID := payload[16:36] + + // Verify API key + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + + meta, err := h.expManager.ReadMetadata(fmt.Sprintf("%x", commitID)) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error()) + } + + metrics, err := h.expManager.GetMetrics(fmt.Sprintf("%x", commitID)) + if err != nil { + return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error()) + } + + var dbMeta 
*storage.ExperimentWithMetadata + if h.db != nil { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + m, err := h.db.GetExperimentWithMetadata(ctx, fmt.Sprintf("%x", commitID)) + if err == nil { + dbMeta = m + } + } + + response := map[string]interface{}{ + "metadata": meta, + "metrics": metrics, + } + if dbMeta != nil { + response["reproducibility"] = dbMeta + } + + responseData, err := json.Marshal(response) + if err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeServerOverloaded, + "Failed to serialize response", + err.Error(), + ) + } + + return h.sendResponsePacket(conn, NewDataPacket("experiment", responseData)) +} diff --git a/internal/api/ws_jupyter.go b/internal/api/ws_jupyter.go new file mode 100644 index 0000000..72795cd --- /dev/null +++ b/internal/api/ws_jupyter.go @@ -0,0 +1,478 @@ +package api + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/queue" +) + +type jupyterTaskOutput struct { + Type string `json:"type"` + Service json.RawMessage `json:"service,omitempty"` + Services json.RawMessage `json:"services,omitempty"` + RestorePath string `json:"restore_path,omitempty"` +} + +func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][name_len:1][name:var] + if len(payload) < 18 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "restore jupyter payload too short", "") + } + + apiKeyHash := payload[:16] + + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + user, err := h.validateWSUser(apiKeyHash) + if err != nil { + return h.sendErrorPacket( + conn, + 
ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + if user != nil && !user.HasPermission("jupyter:manage") { + return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") + } + + offset := 16 + nameLen := int(payload[offset]) + offset++ + if len(payload) < offset+nameLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "") + } + name := string(payload[offset : offset+nameLen]) + + meta := map[string]string{ + jupyterTaskActionKey: jupyterActionRestore, + jupyterNameKey: strings.TrimSpace(name), + } + jobName := fmt.Sprintf("jupyter-restore-%s", strings.TrimSpace(name)) + taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) + if err != nil { + h.logger.Error("failed to enqueue jupyter restore", "error", err) + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter restore", "") + } + + result, err := h.waitForTask(taskID, 2*time.Minute) + if err != nil { + h.logger.Error("failed waiting for jupyter restore", "error", err) + return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") + } + if result.Status != "completed" { + return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error)) + } + + msg := fmt.Sprintf("Restored Jupyter workspace '%s'", strings.TrimSpace(name)) + out := strings.TrimSpace(result.Output) + if out != "" { + var payloadOut jupyterTaskOutput + if err := json.Unmarshal([]byte(out), &payloadOut); err == nil { + if strings.TrimSpace(payloadOut.RestorePath) != "" { + msg = fmt.Sprintf("Restored Jupyter workspace '%s' to %s", strings.TrimSpace(name), strings.TrimSpace(payloadOut.RestorePath)) + } + } + } + + return h.sendResponsePacket(conn, NewSuccessPacket(msg)) +} + +type jupyterServiceView struct { + Name string `json:"name"` + URL string `json:"url"` +} + +const ( + jupyterTaskTypeKey = "task_type" + jupyterTaskTypeValue = 
"jupyter" + + jupyterTaskActionKey = "jupyter_action" + jupyterActionStart = "start" + jupyterActionStop = "stop" + jupyterActionRemove = "remove" + jupyterActionRestore = "restore" + jupyterActionList = "list" + + jupyterNameKey = "jupyter_name" + jupyterWorkspaceKey = "jupyter_workspace" + jupyterServiceIDKey = "jupyter_service_id" +) + +func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) { + if h.queue == nil { + return "", fmt.Errorf("task queue not configured") + } + if err := container.ValidateJobName(jobName); err != nil { + return "", err + } + if strings.TrimSpace(userName) == "" { + return "", fmt.Errorf("missing user") + } + if meta == nil { + meta = make(map[string]string) + } + meta[jupyterTaskTypeKey] = jupyterTaskTypeValue + + taskID := uuid.New().String() + task := &queue.Task{ + ID: taskID, + JobName: jobName, + Args: "", + Status: "queued", + Priority: 100, // high priority; interactive request + CreatedAt: time.Now(), + UserID: userName, + Username: userName, + CreatedBy: userName, + Metadata: meta, + } + if err := h.queue.AddTask(task); err != nil { + return "", err + } + return taskID, nil +} + +func (h *WSHandler) waitForTask(taskID string, timeout time.Duration) (*queue.Task, error) { + if h.queue == nil { + return nil, fmt.Errorf("task queue not configured") + } + deadline := time.Now().Add(timeout) + for { + if time.Now().After(deadline) { + return nil, fmt.Errorf("timed out waiting for worker") + } + t, err := h.queue.GetTask(taskID) + if err != nil { + return nil, err + } + if t == nil { + time.Sleep(200 * time.Millisecond) + continue + } + if t.Status == "completed" || t.Status == "failed" || t.Status == "cancelled" { + return t, nil + } + time.Sleep(200 * time.Millisecond) + } +} + +func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) error { + // Protocol: + // [api_key_hash:16][name][workspace][password] + if len(payload) < 21 { + return 
h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "start jupyter payload too short", "") + } + + apiKeyHash := payload[:16] + + // Verify API key + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + user, err := h.validateWSUser(apiKeyHash) + if err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + if user != nil && !user.HasPermission("jupyter:manage") { + return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") + } + + offset := 16 + nameLen := int(payload[offset]) + offset++ + + if len(payload) < offset+nameLen+2 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "") + } + + name := string(payload[offset : offset+nameLen]) + offset += nameLen + + workspaceLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + + if len(payload) < offset+workspaceLen+1 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace length", "") + } + + workspace := string(payload[offset : offset+workspaceLen]) + offset += workspaceLen + + passwordLen := int(payload[offset]) + offset++ + + if len(payload) < offset+passwordLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid password length", "") + } + + // Password is parsed but not used in StartRequest + // offset += passwordLen (already advanced during parsing) + + meta := map[string]string{ + jupyterTaskActionKey: jupyterActionStart, + jupyterNameKey: strings.TrimSpace(name), + jupyterWorkspaceKey: strings.TrimSpace(workspace), + } + jobName := fmt.Sprintf("jupyter-%s", strings.TrimSpace(name)) + taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) + if err != nil { + h.logger.Error("failed to enqueue jupyter task", "error", err) + return 
h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter task", "") + } + + result, err := h.waitForTask(taskID, 2*time.Minute) + if err != nil { + h.logger.Error("failed waiting for jupyter start", "error", err) + return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") + } + if result.Status != "completed" { + h.logger.Error("jupyter task failed", "error", result.Error) + details := strings.TrimSpace(result.Error) + lower := strings.ToLower(details) + if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") { + return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details) + } + return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to start Jupyter service", details) + } + + msg := fmt.Sprintf("Started Jupyter service '%s'", strings.TrimSpace(name)) + out := strings.TrimSpace(result.Output) + if out != "" { + var payloadOut jupyterTaskOutput + if err := json.Unmarshal([]byte(out), &payloadOut); err == nil && len(payloadOut.Service) > 0 { + var svc jupyterServiceView + if err := json.Unmarshal(payloadOut.Service, &svc); err == nil { + if strings.TrimSpace(svc.URL) != "" { + msg = fmt.Sprintf("Started Jupyter service '%s' at %s", strings.TrimSpace(name), strings.TrimSpace(svc.URL)) + } + } + } + } + return h.sendResponsePacket(conn, NewSuccessPacket(msg)) +} + +func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][service_id_len:1][service_id:var] + if len(payload) < 18 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "stop jupyter payload too short", "") + } + + apiKeyHash := payload[:16] + + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + user, err := 
h.validateWSUser(apiKeyHash) + if err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + if user != nil && !user.HasPermission("jupyter:manage") { + return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") + } + + offset := 16 + idLen := int(payload[offset]) + offset++ + + if len(payload) < offset+idLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "") + } + + serviceID := string(payload[offset : offset+idLen]) + + meta := map[string]string{ + jupyterTaskActionKey: jupyterActionStop, + jupyterServiceIDKey: strings.TrimSpace(serviceID), + } + jobName := fmt.Sprintf("jupyter-stop-%s", strings.TrimSpace(serviceID)) + taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) + if err != nil { + h.logger.Error("failed to enqueue jupyter stop", "error", err) + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter stop", "") + } + result, err := h.waitForTask(taskID, 2*time.Minute) + if err != nil { + h.logger.Error("failed waiting for jupyter stop", "error", err) + return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") + } + if result.Status != "completed" { + return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to stop Jupyter service", strings.TrimSpace(result.Error)) + } + return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID))) +} + +func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16][service_id_len:1][service_id:var] + if len(payload) < 18 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "remove jupyter payload too short", "") + } + + apiKeyHash := payload[:16] + + offset := 16 + idLen := int(payload[offset]) + offset++ + if len(payload) < offset+idLen { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, 
"invalid service id length", "") + } + serviceID := string(payload[offset : offset+idLen]) + offset += idLen + + // Optional: purge flag (1 byte). Default false for trash-first behavior. + purge := false + if len(payload) > offset { + purge = payload[offset] == 0x01 + } + + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + user, err := h.validateWSUser(apiKeyHash) + if err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + if user != nil && !user.HasPermission("jupyter:manage") { + return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") + } + + meta := map[string]string{ + jupyterTaskActionKey: jupyterActionRemove, + jupyterServiceIDKey: strings.TrimSpace(serviceID), + "jupyter_purge": fmt.Sprintf("%t", purge), + } + jobName := fmt.Sprintf("jupyter-remove-%s", strings.TrimSpace(serviceID)) + taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) + if err != nil { + h.logger.Error("failed to enqueue jupyter remove", "error", err) + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter remove", "") + } + + result, err := h.waitForTask(taskID, 2*time.Minute) + if err != nil { + h.logger.Error("failed waiting for jupyter remove", "error", err) + return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") + } + if result.Status != "completed" { + return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to remove Jupyter service", strings.TrimSpace(result.Error)) + } + return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID))) +} + +func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error { + // Protocol: [api_key_hash:16] + if len(payload) < 
16 { + return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter payload too short", "") + } + + apiKeyHash := payload[:16] + + if h.authConfig != nil && h.authConfig.Enabled { + if err := h.verifyAPIKeyHash(apiKeyHash); err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + } + user, err := h.validateWSUser(apiKeyHash) + if err != nil { + return h.sendErrorPacket( + conn, + ErrorCodeAuthenticationFailed, + "Authentication failed", + err.Error(), + ) + } + if user != nil && !user.HasPermission("jupyter:read") { + return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "") + } + + meta := map[string]string{ + jupyterTaskActionKey: jupyterActionList, + } + jobName := fmt.Sprintf("jupyter-list-%s", user.Name) + taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta) + if err != nil { + h.logger.Error("failed to enqueue jupyter list", "error", err) + return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter list", "") + } + result, err := h.waitForTask(taskID, 2*time.Minute) + if err != nil { + h.logger.Error("failed waiting for jupyter list", "error", err) + return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "") + } + if result.Status != "completed" { + return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to list Jupyter services", strings.TrimSpace(result.Error)) + } + + out := strings.TrimSpace(result.Output) + if out == "" { + empty, _ := json.Marshal([]any{}) + return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", empty)) + } + var payloadOut jupyterTaskOutput + if err := json.Unmarshal([]byte(out), &payloadOut); err == nil { + // Always return an array payload (even if empty) so clients can render a stable table. 
+ payload := payloadOut.Services + if len(payload) == 0 { + payload = []byte("[]") + } + return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", payload)) + } + // Fallback: return empty array on unexpected output. + return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", []byte("[]"))) +} diff --git a/internal/api/ws_tls_auth.go b/internal/api/ws_tls_auth.go new file mode 100644 index 0000000..41730be --- /dev/null +++ b/internal/api/ws_tls_auth.go @@ -0,0 +1,100 @@ +package api + +import ( + "crypto/tls" + "fmt" + "net/http" + "time" + + "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/auth" + "golang.org/x/crypto/acme/autocert" +) + +// SetupTLSConfig creates TLS configuration for WebSocket server +func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) { + var server *http.Server + + if certFile != "" && keyFile != "" { + // Use provided certificates + server = &http.Server{ + ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + }, + }, + } + } else if host != "" { + // Use Let's Encrypt with autocert + certManager := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(host), + Cache: autocert.DirCache("/var/www/.cache"), + } + + server = &http.Server{ + ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks + TLSConfig: certManager.TLSConfig(), + } + } + + return server, nil +} + +// verifyAPIKeyHash verifies the provided binary hash against stored API keys +func (h *WSHandler) verifyAPIKeyHash(hash []byte) error { + if h.authConfig == nil || !h.authConfig.Enabled { + return nil // No auth required + } + _, err := h.authConfig.ValidateAPIKeyHash(hash) + if err != nil { + return fmt.Errorf("invalid api key") + } + 
return nil +} + +// sendErrorPacket sends an error response packet +func (h *WSHandler) sendErrorPacket( + conn *websocket.Conn, + errorCode byte, + message string, + details string, +) error { + packet := NewErrorPacket(errorCode, message, details) + return h.sendResponsePacket(conn, packet) +} + +// sendResponsePacket sends a structured response packet +func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error { + data, err := packet.Serialize() + if err != nil { + h.logger.Error("failed to serialize response packet", "error", err) + // Fallback to simple error response + return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) + } + + return conn.WriteMessage(websocket.BinaryMessage, data) +} + +func (h *WSHandler) validateWSUser(apiKeyHash []byte) (*auth.User, error) { + if h.authConfig != nil { + user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash) + if err != nil { + return nil, err + } + return user, nil + } + + return &auth.User{ + Name: "default", + Admin: true, + Roles: []string{"admin"}, + Permissions: map[string]bool{ + "*": true, + }, + }, nil +} diff --git a/internal/api/ws_validate.go b/internal/api/ws_validate.go new file mode 100644 index 0000000..63af9e9 --- /dev/null +++ b/internal/api/ws_validate.go @@ -0,0 +1,642 @@ +package api + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/config" + "github.com/jfraeys/fetch_ml/internal/container" + "github.com/jfraeys/fetch_ml/internal/manifest" + "github.com/jfraeys/fetch_ml/internal/queue" + "github.com/jfraeys/fetch_ml/internal/worker" +) + +type validateCheck struct { + OK bool `json:"ok"` + Expected string `json:"expected,omitempty"` + Actual string `json:"actual,omitempty"` + Details string `json:"details,omitempty"` +} + +type validateReport struct { + OK bool `json:"ok"` + CommitID string 
`json:"commit_id,omitempty"`
	TaskID   string                   `json:"task_id,omitempty"`
	Checks   map[string]validateCheck `json:"checks"`
	Errors   []string                 `json:"errors,omitempty"`
	Warnings []string                 `json:"warnings,omitempty"`
	TS       string                   `json:"ts"`
}

// shouldRequireRunManifest reports whether the task's status means a run
// manifest must exist (the job has at least started executing).
func shouldRequireRunManifest(task *queue.Task) bool {
	if task == nil {
		return false
	}
	switch strings.ToLower(strings.TrimSpace(task.Status)) {
	case "running", "completed", "failed":
		return true
	}
	return false
}

// expectedRunManifestBucketForStatus maps a task status onto the directory
// bucket its run manifest is expected to live in. The boolean is false for
// statuses with no defined bucket.
func expectedRunManifestBucketForStatus(status string) (string, bool) {
	normalized := strings.ToLower(strings.TrimSpace(status))
	bucketByStatus := map[string]string{
		"queued":    "pending",
		"pending":   "pending",
		"running":   "running",
		"completed": "finished",
		"finished":  "finished",
		"failed":    "failed",
	}
	bucket, known := bucketByStatus[normalized]
	return bucket, known
}

// findRunManifestDir searches the typed job directories for a directory named
// jobName that contains a run manifest, returning the directory, its bucket
// name, and whether anything was found.
func findRunManifestDir(basePath string, jobName string) (string, string, bool) {
	if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" {
		return "", "", false
	}
	jp := config.NewJobPaths(basePath)
	// Order matters: if a job directory somehow exists in more than one
	// bucket, the earlier bucket wins (running, pending, finished, failed).
	buckets := []string{"running", "pending", "finished", "failed"}
	roots := []string{jp.RunningPath(), jp.PendingPath(), jp.FinishedPath(), jp.FailedPath()}
	for i, root := range roots {
		candidate := filepath.Join(root, jobName)
		info, statErr := os.Stat(candidate)
		if statErr != nil || !info.IsDir() {
			continue
		}
		if _, mErr := os.Stat(manifest.ManifestPath(candidate)); mErr != nil {
			continue
		}
		return candidate, buckets[i], true
	}
	return "", "", false
}

// validateResourcesForTask sanity-checks a task's resource requests (CPU,
// memory, GPU count, and GPU memory share). On failure it returns the failed
// check plus the report-level error strings to append.
func validateResourcesForTask(task *queue.Task) (validateCheck, []string) {
	if task == nil {
		return validateCheck{OK: false, Details: "task is nil"}, []string{"missing task"}
	}

	// fail builds the uniform (check, errors) failure pair.
	fail := func(details, errMsg string) (validateCheck, []string) {
		return validateCheck{OK: false, Details: details}, []string{errMsg}
	}

	if task.CPU < 0 {
		return fail("cpu must be >= 0", "invalid cpu request")
	}
	if task.MemoryGB < 0 {
		return fail("memory_gb must be >= 0", "invalid memory request")
	}
	if task.GPU < 0 {
		return fail("gpu must be >= 0", "invalid gpu request")
	}

	gpuMem := strings.TrimSpace(task.GPUMemory)
	if gpuMem != "" {
		if pct, isPercent := strings.CutSuffix(gpuMem, "%"); isPercent {
			// Percentage form: must parse and fall in (0, 100].
			value, parseErr := strconv.ParseFloat(strings.TrimSpace(pct), 64)
			if parseErr != nil || value <= 0 || value > 100 {
				return fail("gpu_memory must be a percentage in (0,100]", "invalid gpu_memory")
			}
		} else {
			// Fraction form: must parse and fall in (0, 1].
			value, parseErr := strconv.ParseFloat(gpuMem, 64)
			if parseErr != nil || value <= 0 || value > 1 {
				return fail("gpu_memory must be a fraction in (0,1]", "invalid gpu_memory")
			}
		}
	}

	// A GPU memory share is meaningless without at least one GPU.
	if task.GPU == 0 && gpuMem != "" {
		return fail("gpu_memory requires gpu > 0", "invalid gpu_memory")
	}

	return validateCheck{OK: true}, nil
}

func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][target_type:1][id_len:1][id:var]
	// target_type: 0=commit_id (20 bytes), 1=task_id (string)
	// TODO(context): Add a versioned validate protocol once we need more target types/fields.
	// --- wire parsing: fixed 18-byte header, then variable-length id ---
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "validate request payload too short", "")
	}
	apiKeyHash := payload[:16]
	targetType := payload[16]
	idLen := int(payload[17])
	if idLen < 1 || len(payload) < 18+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate id length", "")
	}
	idBytes := payload[18 : 18+idLen]

	// Validate API key and user
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Invalid API key",
			err.Error(),
		)
	}
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") {
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to validate jobs",
			"",
		)
	}
	if h.expManager == nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServiceUnavailable,
			"Experiment manager not available",
			"",
		)
	}

	// --- resolve the validation target to a commit id (and maybe a task) ---
	var task *queue.Task
	commitID := ""
	switch targetType {
	case 0:
		// Raw 20-byte commit id; hex-encode for the report.
		if len(idBytes) != 20 {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id must be 20 bytes", "")
		}
		commitID = fmt.Sprintf("%x", idBytes)
	case 1:
		// Task id: look up the task, enforce ownership, derive commit id
		// from the task metadata.
		taskID := string(idBytes)
		if h.queue == nil {
			return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Task queue not available", "")
		}
		t, err := h.queue.GetTask(taskID)
		if err != nil {
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", err.Error())
		}
		task = t
		if h.authConfig != nil &&
			h.authConfig.Enabled &&
			!user.Admin &&
			task.UserID != user.Name &&
			task.CreatedBy != user.Name {
			return h.sendErrorPacket(
				conn,
				ErrorCodePermissionDenied,
				"You can only validate your own jobs",
				"",
			)
		}
		if task.Metadata == nil || task.Metadata["commit_id"] == "" {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "Task missing commit_id", "")
		}
		commitID = task.Metadata["commit_id"]
	default:
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate target_type", "")
	}

	// Report accumulator: checks/errors/warnings are appended below; any
	// hard failure flips r.OK to false.
	r := validateReport{
		OK:     true,
		TS:     time.Now().UTC().Format(time.RFC3339Nano),
		Checks: map[string]validateCheck{},
	}
	if task != nil {
		r.TaskID = task.ID
	}
	if commitID != "" {
		r.CommitID = commitID
	}

	// Validate commit id format (40 hex chars = 20 bytes).
	if len(commitID) != 40 {
		r.OK = false
		r.Errors = append(r.Errors, "invalid commit_id length")
	} else if _, err := hex.DecodeString(commitID); err != nil {
		r.OK = false
		r.Errors = append(r.Errors, "invalid commit_id hex")
	}

	// Experiment manifest integrity (only attempted while the report is
	// still clean, i.e. the commit id itself was valid).
	// TODO(context): Extend report to include per-file diff list on mismatch (bounded output).
	if r.OK {
		if err := h.expManager.ValidateManifest(commitID); err != nil {
			r.OK = false
			r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: err.Error()}
			r.Errors = append(r.Errors, "experiment manifest validation failed")
		} else {
			r.Checks["experiment_manifest"] = validateCheck{OK: true}
		}
	}

	// Deps manifest presence + hash
	// TODO(context): Allow client to declare which dependency manifest is authoritative.
	filesPath := h.expManager.GetFilesPath(commitID)
	depName, depErr := worker.SelectDependencyManifest(filesPath)
	if depErr != nil {
		r.OK = false
		r.Checks["deps_manifest"] = validateCheck{
			OK:      false,
			Details: depErr.Error(),
		}
		r.Errors = append(r.Errors, "deps manifest missing")
	} else {
		sha, err := fileSHA256Hex(filepath.Join(filesPath, depName))
		if err != nil {
			r.OK = false
			r.Checks["deps_manifest"] = validateCheck{
				OK:      false,
				Details: err.Error(),
			}
			r.Errors = append(r.Errors, "deps manifest hash failed")
		} else {
			r.Checks["deps_manifest"] = validateCheck{
				OK:     true,
				Actual: depName + ":" + sha,
			}
		}
	}

	// Compare against expected task metadata if available.
	if task != nil {
		resCheck, resErrs := validateResourcesForTask(task)
		r.Checks["resources"] = resCheck
		if !resCheck.OK {
			r.OK = false
			r.Errors = append(r.Errors, resErrs...)
		}

		// Run manifest checks: best-effort for queued tasks, required for running/completed/failed.
		if err := container.ValidateJobName(task.JobName); err != nil {
			r.OK = false
			r.Errors = append(r.Errors, "invalid job name")
			r.Checks["run_manifest"] = validateCheck{OK: false, Details: "invalid job name"}
		} else if base := strings.TrimSpace(h.expManager.BasePath()); base == "" {
			// No base path configured: hard error only when the status says
			// a manifest must exist; otherwise just a warning.
			if shouldRequireRunManifest(task) {
				r.OK = false
				r.Errors = append(r.Errors, "missing api base_path; cannot validate run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
			} else {
				r.Warnings = append(r.Warnings, "missing api base_path; cannot validate run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
			}
		} else {
			manifestDir, manifestBucket, found := findRunManifestDir(base, task.JobName)
			if !found {
				if shouldRequireRunManifest(task) {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest not found")
					r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
				} else {
					r.Warnings = append(r.Warnings, "run manifest not found (job may not have started)")
					r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
				}
			} else if rm, err := manifest.LoadFromDir(manifestDir); err != nil || rm == nil {
				r.OK = false
				r.Errors = append(r.Errors, "unable to read run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "unable to read run manifest"}
			} else {
				r.Checks["run_manifest"] = validateCheck{OK: true}

				// Bucket (directory) location must match the task status.
				expectedBucket, ok := expectedRunManifestBucketForStatus(task.Status)
				if ok {
					if expectedBucket != manifestBucket {
						msg := "run manifest location mismatch"
						chk := validateCheck{OK: false, Expected: expectedBucket, Actual: manifestBucket}
						if shouldRequireRunManifest(task) {
							r.OK = false
							r.Errors = append(r.Errors, msg)
							r.Checks["run_manifest_location"] = chk
						} else {
							r.Warnings = append(r.Warnings, msg)
							r.Checks["run_manifest_location"] = chk
						}
					} else {
						r.Checks["run_manifest_location"] = validateCheck{
							OK:       true,
							Expected: expectedBucket,
							Actual:   manifestBucket,
						}
					}
				}

				// Manifest task id must exist and match the queried task.
				if strings.TrimSpace(rm.TaskID) == "" {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest missing task_id")
					r.Checks["run_manifest_task_id"] = validateCheck{OK: false, Expected: task.ID}
				} else if rm.TaskID != task.ID {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest task_id mismatch")
					r.Checks["run_manifest_task_id"] = validateCheck{
						OK:       false,
						Expected: task.ID,
						Actual:   rm.TaskID,
					}
				} else {
					r.Checks["run_manifest_task_id"] = validateCheck{
						OK:       true,
						Expected: task.ID,
						Actual:   rm.TaskID,
					}
				}

				// Commit id recorded in the run manifest: compared only when
				// both sides are non-empty.
				commitWant := strings.TrimSpace(task.Metadata["commit_id"])
				commitGot := strings.TrimSpace(rm.CommitID)
				if commitWant != "" && commitGot != "" && commitWant != commitGot {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest commit_id mismatch")
					r.Checks["run_manifest_commit_id"] = validateCheck{
						OK:       false,
						Expected: commitWant,
						Actual:   commitGot,
					}
				} else if commitWant != "" {
					r.Checks["run_manifest_commit_id"] = validateCheck{
						OK:       true,
						Expected: commitWant,
						Actual:   commitGot,
					}
				}

				// Deps provenance (name:sha) — checked only when all four
				// components are present.
				depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"])
				depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
				depGotName := strings.TrimSpace(rm.DepsManifestName)
				depGotSHA := strings.TrimSpace(rm.DepsManifestSHA)
				if depWantName != "" && depWantSHA != "" && depGotName != "" && depGotSHA != "" {
					expectedDep := depWantName + ":" + depWantSHA
					actualDep := depGotName + ":" + depGotSHA
					if depWantName != depGotName || depWantSHA != depGotSHA {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest deps provenance mismatch")
						r.Checks["run_manifest_deps"] = validateCheck{
							OK:       false,
							Expected: expectedDep,
							Actual:   actualDep,
						}
					} else {
						r.Checks["run_manifest_deps"] = validateCheck{
							OK:       true,
							Expected: expectedDep,
							Actual:   actualDep,
						}
					}
				}

				// Snapshot id and sha recorded in the run manifest.
				if strings.TrimSpace(task.SnapshotID) != "" {
					snapWantID := strings.TrimSpace(task.SnapshotID)
					snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"])
					snapGotID := strings.TrimSpace(rm.SnapshotID)
					snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256)
					if snapWantID != "" && snapGotID != "" && snapWantID != snapGotID {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest snapshot_id mismatch")
						r.Checks["run_manifest_snapshot_id"] = validateCheck{
							OK:       false,
							Expected: snapWantID,
							Actual:   snapGotID,
						}
					} else {
						r.Checks["run_manifest_snapshot_id"] = validateCheck{
							OK:       true,
							Expected: snapWantID,
							Actual:   snapGotID,
						}
					}
					if snapWantSHA != "" && snapGotSHA != "" && snapWantSHA != snapGotSHA {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch")
						r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
							OK:       false,
							Expected: snapWantSHA,
							Actual:   snapGotSHA,
						}
					} else if snapWantSHA != "" {
						r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
							OK:       true,
							Expected: snapWantSHA,
							Actual:   snapGotSHA,
						}
					}
				}

				// Lifecycle fields (started_at/ended_at/exit_code) must be
				// consistent with the task status; `details` keeps only the
				// first problem found.
				statusLower := strings.ToLower(strings.TrimSpace(task.Status))
				lifecycleOK := true
				details := ""

				switch statusLower {
				case "running":
					if rm.StartedAt.IsZero() {
						lifecycleOK = false
						details = "missing started_at for running task"
					}
					if !rm.EndedAt.IsZero() {
						lifecycleOK = false
						if details == "" {
							details = "ended_at must be empty for running task"
						}
					}
					if rm.ExitCode != nil {
						lifecycleOK = false
						if details == "" {
							details = "exit_code must be empty for running task"
						}
					}
				case "completed", "failed":
					if rm.StartedAt.IsZero() {
						lifecycleOK = false
						details = "missing started_at for completed/failed task"
					}
					if rm.EndedAt.IsZero() {
						lifecycleOK = false
						if details == "" {
							details = "missing ended_at for completed/failed task"
						}
					}
					if rm.ExitCode == nil {
						lifecycleOK = false
						if details == "" {
							details = "missing exit_code for completed/failed task"
						}
					}
					if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && rm.EndedAt.Before(rm.StartedAt) {
						lifecycleOK = false
						if details == "" {
							details = "ended_at is before started_at"
						}
					}
				case "queued", "pending":
					// queued/pending tasks may not have started yet.
					if !rm.EndedAt.IsZero() || rm.ExitCode != nil {
						lifecycleOK = false
						details = "queued/pending task should not have ended_at/exit_code"
					}
				}

				if lifecycleOK {
					r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true}
				} else {
					chk := validateCheck{OK: false, Details: details}
					if shouldRequireRunManifest(task) {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest lifecycle invalid")
						r.Checks["run_manifest_lifecycle"] = chk
					} else {
						r.Warnings = append(r.Warnings, "run manifest lifecycle invalid")
						r.Checks["run_manifest_lifecycle"] = chk
					}
				}
			}
		}

		// Expected (task metadata) vs current experiment manifest overall sha.
		want := strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"])
		cur := ""
		if man, err := h.expManager.ReadManifest(commitID); err == nil && man != nil {
			cur = man.OverallSHA
		}
		if want == "" {
			r.OK = false
			r.Errors = append(r.Errors, "missing expected experiment_manifest_overall_sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Actual: cur}
		} else if cur == "" {
			r.OK = false
			r.Errors = append(r.Errors, "unable to read current experiment manifest overall sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want}
		} else if want != cur {
			r.OK = false
			r.Errors = append(r.Errors, "experiment manifest overall sha mismatch")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want, Actual: cur}
		} else {
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: true, Expected: want, Actual: cur}
		}

		// Expected deps provenance vs the manifest selected on disk above.
		// Hash error is deliberately ignored here (sha stays "" and the
		// comparison fails) — NOTE(review): confirm this best-effort is intended.
		wantDep := strings.TrimSpace(task.Metadata["deps_manifest_name"])
		wantDepSha := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
		if wantDep == "" || wantDepSha == "" {
			r.OK = false
			r.Errors = append(r.Errors, "missing expected deps manifest provenance")
			r.Checks["expected_deps_manifest"] = validateCheck{OK: false}
		} else if depName != "" {
			sha, _ := fileSHA256Hex(filepath.Join(filesPath, depName))
			ok := (wantDep == depName && wantDepSha == sha)
			if !ok {
				r.OK = false
				r.Errors = append(r.Errors, "deps manifest provenance mismatch")
				r.Checks["expected_deps_manifest"] = validateCheck{
					OK:       false,
					Expected: wantDep + ":" + wantDepSha,
					Actual:   depName + ":" + sha,
				}
			} else {
				r.Checks["expected_deps_manifest"] = validateCheck{
					OK:       true,
					Expected: wantDep + ":" + wantDepSha,
					Actual:   depName + ":" + sha,
				}
			}
		}

		// Snapshot/dataset checks require dataDir.
		// TODO(context): Support snapshot stores beyond local filesystem (e.g. S3).
		// TODO(context): Validate snapshots by digest.
		if task.SnapshotID != "" {
			if h.dataDir == "" {
				r.OK = false
				r.Errors = append(r.Errors, "api server data_dir not configured; cannot validate snapshot")
				r.Checks["snapshot"] = validateCheck{OK: false, Details: "missing api data_dir"}
			} else {
				wantSnap, nerr := worker.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
				if nerr != nil || wantSnap == "" {
					r.OK = false
					r.Errors = append(r.Errors, "missing/invalid snapshot_sha256")
					r.Checks["snapshot"] = validateCheck{OK: false}
				} else {
					curSnap, err := worker.DirOverallSHA256Hex(
						filepath.Join(h.dataDir, "snapshots", task.SnapshotID),
					)
					if err != nil {
						r.OK = false
						r.Errors = append(r.Errors, "snapshot hash computation failed")
						r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Details: err.Error()}
					} else if curSnap != wantSnap {
						r.OK = false
						r.Errors = append(r.Errors, "snapshot checksum mismatch")
						r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Actual: curSnap}
					} else {
						r.Checks["snapshot"] = validateCheck{OK: true, Expected: wantSnap, Actual: curSnap}
					}
				}
			}
		}

		// Per-dataset checksum verification; datasets without a declared
		// checksum are skipped.
		if len(task.DatasetSpecs) > 0 {
			// TODO(context): Add dataset URI fetch/verification.
			// TODO(context): Currently only validates local materialized datasets.
			for _, ds := range task.DatasetSpecs {
				if ds.Checksum == "" {
					continue
				}
				key := "dataset:" + ds.Name
				if h.dataDir == "" {
					r.OK = false
					r.Errors = append(
						r.Errors,
						"api server data_dir not configured; cannot validate dataset checksums",
					)
					r.Checks[key] = validateCheck{OK: false, Details: "missing api data_dir"}
					continue
				}
				wantDS, nerr := worker.NormalizeSHA256ChecksumHex(ds.Checksum)
				if nerr != nil || wantDS == "" {
					r.OK = false
					r.Errors = append(r.Errors, "invalid dataset checksum format")
					r.Checks[key] = validateCheck{OK: false, Details: "invalid checksum"}
					continue
				}
				curDS, err := worker.DirOverallSHA256Hex(filepath.Join(h.dataDir, ds.Name))
				if err != nil {
					r.OK = false
					r.Errors = append(r.Errors, "dataset checksum computation failed")
					r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Details: err.Error()}
					continue
				}
				if curDS != wantDS {
					r.OK = false
					r.Errors = append(r.Errors, "dataset checksum mismatch")
					r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Actual: curDS}
					continue
				}
				r.Checks[key] = validateCheck{OK: true, Expected: wantDS, Actual: curDS}
			}
		}
	}

	payloadBytes, err := json.Marshal(r)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeUnknownError,
			"failed to serialize validate report",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("validate", payloadBytes))
}

// ----- internal/auth/api_key.go (patch context: tail of ValidateAPIKey) -----
	return nil, fmt.Errorf("invalid API key")
}

// ValidateAPIKeyHash validates a pre-hashed API key and returns user information
func (c *Config) ValidateAPIKeyHash(hash []byte) (*User, error) {
	if !c.Enabled {
		// Auth disabled - return default admin user for development
		return &User{Name: "default", Admin: true}, nil
	}
	// The wire protocol carries a 16-byte (128-bit) truncated digest.
	if len(hash) != 16 {
		return nil, fmt.Errorf("invalid api key hash length: %d", len(hash))
	}

	for username, entry := range c.APIKeys {
		stored := strings.TrimSpace(string(entry.Hash))
		if stored == "" {
			continue
		}
		storedBytes, err := hex.DecodeString(stored)
		if err != nil {
			continue
		}
		if len(storedBytes) != sha256.Size {
			continue
		}
		// NOTE(review): this compares only the first 16 bytes of the stored
		// SHA-256 digest against the client-supplied hash, and uses a
		// variable-time string comparison. The truncation appears to be the
		// wire-protocol design (confirm); the comparison should move to
		// crypto/subtle.ConstantTimeCompare — not changed here because the
		// file's import block is outside this patch hunk.
		if string(storedBytes[:16]) == string(hash) {
			// Build user with role and permission inheritance
			user := &User{
				Name:        string(username),
				Admin:       entry.Admin,
				Roles:       entry.Roles,
				Permissions: make(map[string]bool),
			}

			// Copy explicit permissions
			for perm, value := range entry.Permissions {
				user.Permissions[perm] = value
			}

			// Add role-based permissions (explicit entries win on conflict)
			rolePerms := getRolePermissions(entry.Roles)
			for perm, value := range rolePerms {
				if _, exists := user.Permissions[perm]; !exists {
					user.Permissions[perm] = value
				}
			}

			// Admin gets all permissions
			if entry.Admin {
				user.Permissions["*"] = true
			}

			return user, nil
		}
	}

	return nil, fmt.Errorf("invalid api key")
}

// AuthMiddleware creates HTTP middleware for API key authentication
func (c *Config) AuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ----- patch context gap (unchanged middleware body elided); the following
// ----- lines are the tail of GetUserFromContext -----
	return nil
}

// WithUserContext returns a copy of ctx carrying user under the package's
// private context key (read back by GetUserFromContext).
func WithUserContext(ctx context.Context, user *User) context.Context {
	return context.WithValue(ctx, userContextKey, user)
}

// RequireAdmin creates middleware that requires admin privileges
func RequireAdmin(next http.Handler) http.Handler {
	return RequirePermission("system:admin")(next)
// ----- internal/auth/database.go (patch context: inside the DDL string of
// ----- DatabaseAuthStore.init; the statement is cut at the chunk boundary) -----
	CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
	CREATE INDEX IF NOT
EXISTS idx_api_keys_user ON api_keys(user_id); - CREATE INDEX IF NOT EXISTS idx_api_keys_active ON api_keys(revoked_at, COALESCE(expires_at, '9999-12-31')); + CREATE INDEX IF NOT EXISTS idx_api_keys_active ON api_keys( + revoked_at, + COALESCE(expires_at, '9999-12-31') + ); ` _, err := s.db.ExecContext(ctx, query) @@ -158,7 +161,16 @@ func (s *DatabaseAuthStore) CreateAPIKey( revoked_at = NULL ` - _, err = s.db.ExecContext(ctx, query, userID, keyHash, admin, rolesJSON, permissionsJSON, expiresAt) + _, err = s.db.ExecContext( + ctx, + query, + userID, + keyHash, + admin, + rolesJSON, + permissionsJSON, + expiresAt, + ) return err } diff --git a/internal/auth/hybrid.go b/internal/auth/hybrid.go index 8d1ef0a..3933176 100644 --- a/internal/auth/hybrid.go +++ b/internal/auth/hybrid.go @@ -213,7 +213,15 @@ func (h *HybridAuthStore) migrateFileToDatabase(ctx context.Context) error { for username, entry := range h.fileStore.APIKeys { userID := string(username) - err := h.dbStore.CreateAPIKey(ctx, userID, string(entry.Hash), entry.Admin, entry.Roles, entry.Permissions, nil) + err := h.dbStore.CreateAPIKey( + ctx, + userID, + string(entry.Hash), + entry.Admin, + entry.Roles, + entry.Permissions, + nil, + ) if err != nil { log.Printf("Failed to migrate key for user %s: %v", userID, err) continue diff --git a/internal/auth/permissions.go b/internal/auth/permissions.go index 89c11da..a2d917f 100644 --- a/internal/auth/permissions.go +++ b/internal/auth/permissions.go @@ -59,23 +59,43 @@ var PermissionGroups = map[string]PermissionGroup{ Description: "Complete system access", }, "job_management": { - Name: "Job Management", - Permissions: []string{PermissionJobsCreate, PermissionJobsRead, PermissionJobsUpdate, PermissionJobsDelete}, + Name: "Job Management", + Permissions: []string{ + PermissionJobsCreate, + PermissionJobsRead, + PermissionJobsUpdate, + PermissionJobsDelete, + }, Description: "Create, read, update, and delete ML jobs", }, "data_access": { - Name: "Data Access", 
- Permissions: []string{PermissionDataRead, PermissionDataCreate, PermissionDataUpdate, PermissionDataDelete}, + Name: "Data Access", + Permissions: []string{ + PermissionDataRead, + PermissionDataCreate, + PermissionDataUpdate, + PermissionDataDelete, + }, Description: "Access and manage datasets", }, "readonly": { - Name: "Read Only", - Permissions: []string{PermissionJobsRead, PermissionDataRead, PermissionModelsRead, PermissionSystemMetrics}, + Name: "Read Only", + Permissions: []string{ + PermissionJobsRead, + PermissionDataRead, + PermissionModelsRead, + PermissionSystemMetrics, + }, Description: "View-only access to system resources", }, "system_admin": { - Name: "System Administration", - Permissions: []string{PermissionSystemConfig, PermissionSystemLogs, PermissionSystemUsers, PermissionSystemMetrics}, + Name: "System Administration", + Permissions: []string{ + PermissionSystemConfig, + PermissionSystemLogs, + PermissionSystemUsers, + PermissionSystemMetrics, + }, Description: "System configuration and user management", }, } diff --git a/internal/config/constants.go b/internal/config/constants.go index a224d14..e6663f2 100644 --- a/internal/config/constants.go +++ b/internal/config/constants.go @@ -25,6 +25,7 @@ const ( RedisTaskStatusPrefix = "ml:status:" RedisDatasetPrefix = "ml:dataset:" RedisWorkerHeartbeat = "ml:workers:heartbeat" + RedisWorkerPrewarmKey = "ml:workers:prewarm:" ) // Task status constants diff --git a/internal/config/security.go b/internal/config/security.go new file mode 100644 index 0000000..1adc50b --- /dev/null +++ b/internal/config/security.go @@ -0,0 +1,83 @@ +package config + +import ( + "fmt" + "time" +) + +// SecurityConfig holds security-related configuration +type SecurityConfig struct { + // AllowedOrigins lists the allowed origins for WebSocket connections + // Empty list defaults to localhost-only in production mode + AllowedOrigins []string `yaml:"allowed_origins"` + + // ProductionMode enables strict security checks + 
ProductionMode bool `yaml:"production_mode"` + + // APIKeyRotationDays is the number of days before API keys should be rotated + APIKeyRotationDays int `yaml:"api_key_rotation_days"` + + // AuditLogging configuration + AuditLogging AuditLoggingConfig `yaml:"audit_logging"` + + // IPWhitelist for additional connection filtering + IPWhitelist []string `yaml:"ip_whitelist"` +} + +// AuditLoggingConfig holds audit logging configuration +type AuditLoggingConfig struct { + Enabled bool `yaml:"enabled"` + LogPath string `yaml:"log_path"` +} + +// MonitoringConfig holds monitoring-related configuration +type MonitoringConfig struct { + Prometheus PrometheusConfig `yaml:"prometheus"` + HealthChecks HealthChecksConfig `yaml:"health_checks"` +} + +// PrometheusConfig holds Prometheus metrics configuration +type PrometheusConfig struct { + Enabled bool `yaml:"enabled"` + Port int `yaml:"port"` + Path string `yaml:"path"` +} + +// HealthChecksConfig holds health check configuration +type HealthChecksConfig struct { + Enabled bool `yaml:"enabled"` + Interval time.Duration `yaml:"interval"` +} + +// Validate validates the security configuration +func (s *SecurityConfig) Validate() error { + if s.ProductionMode { + if len(s.AllowedOrigins) == 0 { + return fmt.Errorf("production_mode requires at least one allowed_origin") + } + } + + if s.APIKeyRotationDays < 0 { + return fmt.Errorf("api_key_rotation_days must be positive") + } + + if s.AuditLogging.Enabled && s.AuditLogging.LogPath == "" { + return fmt.Errorf("audit_logging enabled but log_path not set") + } + + return nil +} + +// Validate validates the monitoring configuration +func (m *MonitoringConfig) Validate() error { + if m.Prometheus.Enabled { + if m.Prometheus.Port <= 0 || m.Prometheus.Port > 65535 { + return fmt.Errorf("prometheus port must be between 1 and 65535") + } + if m.Prometheus.Path == "" { + m.Prometheus.Path = "/metrics" // Default + } + } + + return nil +} diff --git a/internal/metrics/metrics.go 
b/internal/metrics/metrics.go index 9374718..4be4d20 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -15,13 +15,21 @@ func int64Max(a, b int64) int64 { // Metrics tracks various performance counters and statistics. type Metrics struct { - TasksProcessed atomic.Int64 - TasksFailed atomic.Int64 - DataFetchTime atomic.Int64 // Total nanoseconds - ExecutionTime atomic.Int64 - DataTransferred atomic.Int64 // Total bytes - ActiveTasks atomic.Int64 - QueuedTasks atomic.Int64 + TasksProcessed atomic.Int64 + TasksFailed atomic.Int64 + DataFetchTime atomic.Int64 // Total nanoseconds + ExecutionTime atomic.Int64 + DataTransferred atomic.Int64 // Total bytes + ActiveTasks atomic.Int64 + QueuedTasks atomic.Int64 + PrewarmEnvHit atomic.Int64 + PrewarmEnvMiss atomic.Int64 + PrewarmEnvBuilt atomic.Int64 + PrewarmEnvTime atomic.Int64 // Total nanoseconds + PrewarmSnapshotHit atomic.Int64 + PrewarmSnapshotMiss atomic.Int64 + PrewarmSnapshotBuilt atomic.Int64 + PrewarmSnapshotTime atomic.Int64 // Total nanoseconds } // RecordTaskSuccess records successful task completion with duration. @@ -53,6 +61,32 @@ func (m *Metrics) RecordDataTransfer(bytes int64, duration time.Duration) { m.DataFetchTime.Add(duration.Nanoseconds()) } +func (m *Metrics) RecordPrewarmEnvHit() { + m.PrewarmEnvHit.Add(1) +} + +func (m *Metrics) RecordPrewarmEnvMiss() { + m.PrewarmEnvMiss.Add(1) +} + +func (m *Metrics) RecordPrewarmEnvBuilt(duration time.Duration) { + m.PrewarmEnvBuilt.Add(1) + m.PrewarmEnvTime.Add(duration.Nanoseconds()) +} + +func (m *Metrics) RecordPrewarmSnapshotHit() { + m.PrewarmSnapshotHit.Add(1) +} + +func (m *Metrics) RecordPrewarmSnapshotMiss() { + m.PrewarmSnapshotMiss.Add(1) +} + +func (m *Metrics) RecordPrewarmSnapshotBuilt(duration time.Duration) { + m.PrewarmSnapshotBuilt.Add(1) + m.PrewarmSnapshotTime.Add(duration.Nanoseconds()) +} + // SetQueuedTasks sets the number of queued tasks. 
func (m *Metrics) SetQueuedTasks(count int64) { m.QueuedTasks.Store(count) @@ -66,13 +100,21 @@ func (m *Metrics) GetStats() map[string]any { dataFetchTime := m.DataFetchTime.Load() return map[string]any{ - "tasks_processed": processed, - "tasks_failed": failed, - "active_tasks": m.ActiveTasks.Load(), - "queued_tasks": m.QueuedTasks.Load(), - "success_rate": float64(processed-failed) / float64(int64Max(processed, 1)), - "avg_exec_time": time.Duration(m.ExecutionTime.Load() / int64Max(processed, 1)), - "data_transferred_gb": float64(dataTransferred) / (1024 * 1024 * 1024), - "avg_fetch_time": time.Duration(dataFetchTime / int64Max(processed, 1)), + "tasks_processed": processed, + "tasks_failed": failed, + "active_tasks": m.ActiveTasks.Load(), + "queued_tasks": m.QueuedTasks.Load(), + "success_rate": float64(processed-failed) / float64(int64Max(processed, 1)), + "avg_exec_time": time.Duration(m.ExecutionTime.Load() / int64Max(processed, 1)), + "data_transferred_gb": float64(dataTransferred) / (1024 * 1024 * 1024), + "avg_fetch_time": time.Duration(dataFetchTime / int64Max(processed, 1)), + "prewarm_env_hit": m.PrewarmEnvHit.Load(), + "prewarm_env_miss": m.PrewarmEnvMiss.Load(), + "prewarm_env_built": m.PrewarmEnvBuilt.Load(), + "prewarm_env_time": time.Duration(m.PrewarmEnvTime.Load()), + "prewarm_snapshot_hit": m.PrewarmSnapshotHit.Load(), + "prewarm_snapshot_miss": m.PrewarmSnapshotMiss.Load(), + "prewarm_snapshot_built": m.PrewarmSnapshotBuilt.Load(), + "prewarm_snapshot_time": time.Duration(m.PrewarmSnapshotTime.Load()), } } diff --git a/internal/middleware/security.go b/internal/middleware/security.go index c899b9e..22b8ca2 100644 --- a/internal/middleware/security.go +++ b/internal/middleware/security.go @@ -3,8 +3,11 @@ package middleware import ( "context" + "fmt" "log" + "net" "net/http" + "net/netip" "strings" "time" @@ -26,7 +29,11 @@ type RateLimitOptions struct { } // NewSecurityMiddleware creates a new security middleware instance. 
-func NewSecurityMiddleware(authConfig *auth.Config, jwtSecret string, rlOpts *RateLimitOptions) *SecurityMiddleware { +func NewSecurityMiddleware( + authConfig *auth.Config, + jwtSecret string, + rlOpts *RateLimitOptions, +) *SecurityMiddleware { sm := &SecurityMiddleware{ authConfig: authConfig, jwtSecret: []byte(jwtSecret), @@ -62,21 +69,23 @@ func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler { // APIKeyAuth provides API key authentication middleware. func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - apiKey := auth.ExtractAPIKeyFromRequest(r) - - // Validate API key using auth config - if sm.authConfig == nil { - http.Error(w, "Authentication not configured", http.StatusInternalServerError) + // If authentication is not configured or disabled, allow all requests. + // This keeps local/dev environments functional without requiring API keys. + if sm.authConfig == nil || !sm.authConfig.Enabled { + next.ServeHTTP(w, r) return } - _, err := sm.authConfig.ValidateAPIKey(apiKey) + apiKey := auth.ExtractAPIKeyFromRequest(r) + + // Validate API key using auth config + user, err := sm.authConfig.ValidateAPIKey(apiKey) if err != nil { http.Error(w, "Invalid API key", http.StatusUnauthorized) return } - - next.ServeHTTP(w, r) + ctx := auth.WithUserContext(r.Context(), user) + next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -103,21 +112,54 @@ func SecurityHeaders(next http.Handler) http.Handler { // IPWhitelist provides IP whitelist middleware. 
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler { + parsedAddrs := make([]netip.Addr, 0, len(allowedIPs)) + parsedPrefixes := make([]netip.Prefix, 0, len(allowedIPs)) + for _, raw := range allowedIPs { + val := strings.TrimSpace(raw) + if val == "" { + continue + } + if strings.Contains(val, "/") { + p, err := netip.ParsePrefix(val) + if err != nil { + log.Printf("SECURITY: invalid ip whitelist cidr ignored: %q: %v", val, err) + continue + } + parsedPrefixes = append(parsedPrefixes, p) + continue + } + a, err := netip.ParseAddr(val) + if err != nil { + log.Printf("SECURITY: invalid ip whitelist addr ignored: %q: %v", val, err) + continue + } + parsedAddrs = append(parsedAddrs, a) + } + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - clientIP := getClientIP(r) + if len(parsedAddrs) == 0 && len(parsedPrefixes) == 0 { + http.Error(w, "IP not whitelisted", http.StatusForbidden) + return + } + + clientIPStr := getClientIP(r) + addr, err := parseClientIP(clientIPStr) + if err != nil { + http.Error(w, "IP not whitelisted", http.StatusForbidden) + return + } - // Check if client IP is in whitelist allowed := false - for _, ip := range allowedIPs { - if strings.Contains(ip, "/") { - // CIDR notation - would need proper IP net parsing - if strings.HasPrefix(clientIP, strings.Split(ip, "/")[0]) { - allowed = true - break - } - } else { - if clientIP == ip { + for _, a := range parsedAddrs { + if a == addr { + allowed = true + break + } + } + if !allowed { + for _, p := range parsedPrefixes { + if p.Contains(addr) { allowed = true break } @@ -134,41 +176,46 @@ func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler } } -// CORS middleware with restrictive defaults -func CORS(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get("Origin") - - // Only allow 
specific origins in production - allowedOrigins := []string{ - "https://ml-experiments.example.com", - "https://app.example.com", +// CORS middleware with configured allowed origins +func CORS(allowedOrigins []string) func(http.Handler) http.Handler { + allowed := make([]string, 0, len(allowedOrigins)) + for _, o := range allowedOrigins { + v := strings.TrimSpace(o) + if v != "" { + allowed = append(allowed, v) } + } - isAllowed := false - for _, allowed := range allowedOrigins { - if origin == allowed { - isAllowed = true - break + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin != "" { + isAllowed := false + for _, a := range allowed { + if a == "*" || origin == a { + isAllowed = true + break + } + } + if isAllowed { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Vary", "Origin") + } } - } - if isAllowed { - w.Header().Set("Access-Control-Allow-Origin", origin) - } + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key") + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Max-Age", "86400") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key") - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Max-Age", "86400") + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusNoContent) + return + } - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusNoContent) - return - } - - next.ServeHTTP(w, r) - }) + next.ServeHTTP(w, r) + }) + } } // RequestTimeout provides request timeout middleware. 
@@ -253,12 +300,28 @@ func getClientIP(r *http.Request) string { } // Fall back to RemoteAddr - if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 { - return r.RemoteAddr[:idx] + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err == nil { + return host } return r.RemoteAddr } +func parseClientIP(raw string) (netip.Addr, error) { + s := strings.TrimSpace(raw) + if s == "" { + return netip.Addr{}, fmt.Errorf("empty ip") + } + // Try host:port parsing first (covers IPv4 and bracketed IPv6) + if host, _, err := net.SplitHostPort(s); err == nil { + s = host + } + // Trim brackets for IPv6 literals + s = strings.TrimPrefix(s, "[") + s = strings.TrimSuffix(s, "]") + return netip.ParseAddr(s) +} + // Response writer wrapper to capture status code type responseWriter struct { http.ResponseWriter @@ -273,5 +336,11 @@ func (rw *responseWriter) WriteHeader(code int) { func logSecurityEvent(event map[string]interface{}) { // Implementation would send to security monitoring system // For now, just log (in production, use proper logging) - log.Printf("SECURITY AUDIT: %s %s %s %v", event["client_ip"], event["method"], event["path"], event["status"]) + log.Printf( + "SECURITY AUDIT: %s %s %s %v", + event["client_ip"], + event["method"], + event["path"], + event["status"], + ) } diff --git a/internal/prommetrics/prometheus.go b/internal/prommetrics/prometheus.go new file mode 100644 index 0000000..f2e9548 --- /dev/null +++ b/internal/prommetrics/prometheus.go @@ -0,0 +1,295 @@ +package prommetrics + +import ( + "net/http" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +// Metrics holds all Prometheus metrics for the application +type Metrics struct { + // WebSocket metrics + wsConnections *prometheus.GaugeVec + wsMessages *prometheus.CounterVec + wsDuration *prometheus.HistogramVec + wsErrors *prometheus.CounterVec + + // Job queue metrics + jobsQueued prometheus.Counter + jobsCompleted 
*prometheus.CounterVec + jobsActive prometheus.Gauge + jobDuration *prometheus.HistogramVec + queueLength prometheus.Gauge + + // Jupyter metrics + jupyterServices *prometheus.GaugeVec + jupyterOps *prometheus.CounterVec + + // HTTP metrics + httpRequests *prometheus.CounterVec + httpDuration *prometheus.HistogramVec + + // Prewarm metrics + prewarmSnapshotHit prometheus.Counter + prewarmSnapshotMiss prometheus.Counter + prewarmSnapshotBuilt prometheus.Counter + prewarmSnapshotTime prometheus.Histogram + + registry *prometheus.Registry +} + +// New creates a new Prometheus Metrics instance +func New() *Metrics { + m := &Metrics{ + registry: prometheus.NewRegistry(), + } + + m.initMetrics() + return m +} + +// initMetrics initializes all Prometheus metrics +func (m *Metrics) initMetrics() { + // WebSocket metrics + m.wsConnections = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "fetchml_websocket_connections", + Help: "Number of active WebSocket connections", + }, + []string{"status"}, + ) + + m.wsMessages = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "fetchml_websocket_messages_total", + Help: "Total number of WebSocket messages", + }, + []string{"opcode", "status"}, + ) + + m.wsDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "fetchml_websocket_duration_seconds", + Help: "WebSocket message processing duration", + Buckets: prometheus.DefBuckets, + }, + []string{"opcode"}, + ) + + m.wsErrors = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "fetchml_websocket_errors_total", + Help: "Total number of WebSocket errors", + }, + []string{"type"}, + ) + + // Job queue metrics + m.jobsQueued = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "fetchml_jobs_queued_total", + Help: "Total number of jobs queued", + }, + ) + + m.jobsCompleted = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "fetchml_jobs_completed_total", + Help: "Total number of completed jobs", + }, + []string{"status"}, + ) + 
+ m.jobsActive = prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "fetchml_jobs_active", + Help: "Number of currently active jobs", + }, + ) + + m.jobDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "fetchml_job_duration_seconds", + Help: "Job execution duration", + Buckets: []float64{1, 5, 10, 30, 60, 300, 600, 1800, 3600}, + }, + []string{"status"}, + ) + + m.queueLength = prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "fetchml_queue_length", + Help: "Current job queue length", + }, + ) + + // Jupyter metrics + m.jupyterServices = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "fetchml_jupyter_services", + Help: "Number of Jupyter services", + }, + []string{"status"}, + ) + + m.jupyterOps = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "fetchml_jupyter_operations_total", + Help: "Total number of Jupyter operations", + }, + []string{"operation", "status"}, + ) + + // HTTP metrics + m.httpRequests = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "fetchml_http_requests_total", + Help: "Total number of HTTP requests", + }, + []string{"method", "endpoint", "status"}, + ) + + m.httpDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "fetchml_http_duration_seconds", + Help: "HTTP request duration", + Buckets: prometheus.DefBuckets, + }, + []string{"method", "endpoint"}, + ) + + // Prewarm metrics + m.prewarmSnapshotHit = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "fetchml_prewarm_snapshot_hit_total", + Help: "Total number of prewarmed snapshot hits (snapshots found in .prewarm/)", + }, + ) + + m.prewarmSnapshotMiss = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "fetchml_prewarm_snapshot_miss_total", + Help: "Total number of prewarmed snapshot misses (snapshots not found in .prewarm/)", + }, + ) + + m.prewarmSnapshotBuilt = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "fetchml_prewarm_snapshot_built_total", + Help: "Total number of 
snapshots prewarmed into .prewarm/", + }, + ) + + m.prewarmSnapshotTime = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Name: "fetchml_prewarm_snapshot_duration_seconds", + Help: "Time spent prewarming snapshots", + Buckets: []float64{0.1, 0.5, 1, 2, 5, 10, 30, 60, 120}, + }, + ) + + // Register all metrics + m.registry.MustRegister( + m.wsConnections, + m.wsMessages, + m.wsDuration, + m.wsErrors, + m.jobsQueued, + m.jobsCompleted, + m.jobsActive, + m.jobDuration, + m.queueLength, + m.jupyterServices, + m.jupyterOps, + m.httpRequests, + m.httpDuration, + m.prewarmSnapshotHit, + m.prewarmSnapshotMiss, + m.prewarmSnapshotBuilt, + m.prewarmSnapshotTime, + ) +} + +// Handler returns the Prometheus HTTP handler +func (m *Metrics) Handler() http.Handler { + return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{}) +} + +// WebSocket metrics methods +func (m *Metrics) IncWSConnections(status string) { + m.wsConnections.WithLabelValues(status).Inc() +} + +func (m *Metrics) DecWSConnections(status string) { + m.wsConnections.WithLabelValues(status).Dec() +} + +func (m *Metrics) IncWSMessages(opcode, status string) { + m.wsMessages.WithLabelValues(opcode, status).Inc() +} + +func (m *Metrics) ObserveWSDuration(opcode string, duration time.Duration) { + m.wsDuration.WithLabelValues(opcode).Observe(duration.Seconds()) +} + +func (m *Metrics) IncWSErrors(errType string) { + m.wsErrors.WithLabelValues(errType).Inc() +} + +// Job queue metrics methods +func (m *Metrics) IncJobsQueued() { + m.jobsQueued.Inc() +} + +func (m *Metrics) IncJobsCompleted(status string) { + m.jobsCompleted.WithLabelValues(status).Inc() +} + +func (m *Metrics) SetJobsActive(count float64) { + m.jobsActive.Set(count) +} + +func (m *Metrics) ObserveJobDuration(status string, duration time.Duration) { + m.jobDuration.WithLabelValues(status).Observe(duration.Seconds()) +} + +func (m *Metrics) SetQueueLength(length float64) { + m.queueLength.Set(length) +} + +// Jupyter metrics methods +func (m 
*Metrics) SetJupyterServices(status string, count float64) { + m.jupyterServices.WithLabelValues(status).Set(count) +} + +func (m *Metrics) IncJupyterOps(operation, status string) { + m.jupyterOps.WithLabelValues(operation, status).Inc() +} + +// HTTP metrics methods +func (m *Metrics) IncHTTPRequests(method, endpoint, status string) { + m.httpRequests.WithLabelValues(method, endpoint, status).Inc() +} + +func (m *Metrics) ObserveHTTPDuration(method, endpoint string, duration time.Duration) { + m.httpDuration.WithLabelValues(method, endpoint).Observe(duration.Seconds()) +} + +// Prewarm metrics methods +func (m *Metrics) IncPrewarmSnapshotHit() { + m.prewarmSnapshotHit.Inc() +} + +func (m *Metrics) IncPrewarmSnapshotMiss() { + m.prewarmSnapshotMiss.Inc() +} + +func (m *Metrics) IncPrewarmSnapshotBuilt() { + m.prewarmSnapshotBuilt.Inc() +} + +func (m *Metrics) ObservePrewarmSnapshotDuration(duration time.Duration) { + m.prewarmSnapshotTime.Observe(duration.Seconds()) +}