feat(api): refactor websocket handlers; add health and prometheus middleware
This commit is contained in:
parent
6ff5324e74
commit
add4a90e62
28 changed files with 4179 additions and 961 deletions
|
|
@ -5,7 +5,7 @@ WebSocket API server for the ML CLI tool...
|
|||
## Usage
|
||||
|
||||
```bash
|
||||
./bin/api-server --config configs/config-dev.yaml --listen :9100
|
||||
./bin/api-server --config configs/api/dev.yaml
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
configFile := flag.String("config", "configs/config-local.yaml", "Configuration file path")
|
||||
configFile := flag.String("config", "configs/api/dev.yaml", "Configuration file path")
|
||||
apiKey := flag.String("api-key", "", "API key for authentication")
|
||||
flag.Parse()
|
||||
|
||||
|
|
|
|||
15
go.mod
15
go.mod
|
|
@ -3,8 +3,8 @@ module github.com/jfraeys/fetch_ml
|
|||
go 1.25.0
|
||||
|
||||
// Fetch ML - Secure Machine Learning Platform
|
||||
// Copyright (c) 2024 Fetch ML
|
||||
// Licensed under the MIT License
|
||||
// Copyright (c) 2026 Fetch ML
|
||||
// Licensed under the FetchML Source-Available Research & Audit License (SARAL). See LICENSE.
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.5.0
|
||||
|
|
@ -17,6 +17,7 @@ require (
|
|||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/mattn/go-sqlite3 v1.14.32
|
||||
github.com/minio/minio-go/v7 v7.0.97
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
|
|
@ -43,24 +44,34 @@ require (
|
|||
github.com/danieljoos/wincred v1.2.3 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/go-ini/ini v1.67.0 // indirect
|
||||
github.com/godbus/dbus/v5 v5.2.0 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.11 // indirect
|
||||
github.com/klauspost/crc32 v1.3.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/minio/crc64nvme v1.1.0 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/philhofer/fwd v1.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.67.4 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/xid v1.6.0 // indirect
|
||||
github.com/sahilm/fuzzy v0.1.1 // indirect
|
||||
github.com/tinylib/msgp v1.3.0 // indirect
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
|
|
|
|||
21
go.sum
21
go.sum
|
|
@ -48,10 +48,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
|
||||
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
|
||||
|
|
@ -66,6 +70,11 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
|
|||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU=
|
||||
github.com/klauspost/cpuid/v2 v2.2.11/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM=
|
||||
github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
|
|
@ -84,6 +93,12 @@ github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byF
|
|||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/minio/crc64nvme v1.1.0 h1:e/tAguZ+4cw32D+IO/8GSf5UVr9y+3eJcxZI2WOO/7Q=
|
||||
github.com/minio/crc64nvme v1.1.0/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.97 h1:lqhREPyfgHTB/ciX8k2r8k0D93WaFqxbJX36UZq5occ=
|
||||
github.com/minio/minio-go/v7 v7.0.97/go.mod h1:re5VXuo0pwEtoNLsNuSr0RrLfT/MBtohwdaSmPPSRSk=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
|
|
@ -98,6 +113,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
|||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
|
||||
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
|
||||
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
|
|
@ -114,6 +131,8 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
|||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
|
||||
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
|
|
@ -122,6 +141,8 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/
|
|||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww=
|
||||
github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
|
|
@ -34,7 +35,6 @@ func NewHandlers(
|
|||
// RegisterHandlers registers all HTTP handlers with the mux
|
||||
func (h *Handlers) RegisterHandlers(mux *http.ServeMux) {
|
||||
// Health check endpoints
|
||||
mux.HandleFunc("/health", h.handleHealth)
|
||||
mux.HandleFunc("/db-status", h.handleDBStatus)
|
||||
|
||||
// Jupyter service endpoints
|
||||
|
|
@ -57,7 +57,7 @@ func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) {
|
|||
|
||||
// This would need the DB instance passed to handlers
|
||||
// For now, return a basic response
|
||||
response := map[string]interface{}{
|
||||
response := map[string]any{
|
||||
"status": "unknown",
|
||||
"message": "Database status check not implemented",
|
||||
}
|
||||
|
|
@ -72,13 +72,30 @@ func (h *Handlers) handleDBStatus(w http.ResponseWriter, _ *http.Request) {
|
|||
// handleJupyterServices handles Jupyter service management requests
|
||||
func (h *Handlers) handleJupyterServices(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
if !user.HasPermission("jupyter:read") {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
h.listJupyterServices(w, r)
|
||||
case http.MethodPost:
|
||||
if !user.HasPermission("jupyter:manage") {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
h.startJupyterService(w, r)
|
||||
case http.MethodDelete:
|
||||
if !user.HasPermission("jupyter:manage") {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
h.stopJupyterService(w, r)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
|
|
@ -140,7 +157,10 @@ func (h *Handlers) stopJupyterService(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{"status": "stopped", "id": serviceID}); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "stopped",
|
||||
"id": serviceID,
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to encode response", "error", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -148,6 +168,15 @@ func (h *Handlers) stopJupyterService(w http.ResponseWriter, r *http.Request) {
|
|||
// handleJupyterExperimentLink handles linking Jupyter workspaces with experiments
|
||||
func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if !user.HasPermission("jupyter:manage") {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
|
|
@ -176,7 +205,11 @@ func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Re
|
|||
}
|
||||
|
||||
// Link workspace with experiment using service manager
|
||||
if err := h.jupyterServiceMgr.LinkWorkspaceWithExperiment(req.Workspace, req.ExperimentID, req.ServiceID); err != nil {
|
||||
if err := h.jupyterServiceMgr.LinkWorkspaceWithExperiment(
|
||||
req.Workspace,
|
||||
req.ExperimentID,
|
||||
req.ServiceID,
|
||||
); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to link workspace: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -184,7 +217,11 @@ func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Re
|
|||
// Get workspace metadata to return
|
||||
metadata, err := h.jupyterServiceMgr.GetWorkspaceMetadata(req.Workspace)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to get workspace metadata: %v", err), http.StatusInternalServerError)
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Failed to get workspace metadata: %v", err),
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -205,6 +242,15 @@ func (h *Handlers) handleJupyterExperimentLink(w http.ResponseWriter, r *http.Re
|
|||
// handleJupyterExperimentSync handles synchronization between Jupyter workspaces and experiments
|
||||
func (h *Handlers) handleJupyterExperimentSync(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if !user.HasPermission("jupyter:manage") {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
|
|
@ -245,7 +291,11 @@ func (h *Handlers) handleJupyterExperimentSync(w http.ResponseWriter, r *http.Re
|
|||
// Get updated metadata
|
||||
metadata, err := h.jupyterServiceMgr.GetWorkspaceMetadata(req.Workspace)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to get workspace metadata: %v", err), http.StatusInternalServerError)
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Failed to get workspace metadata: %v", err),
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
82
internal/api/health.go
Normal file
82
internal/api/health.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthStatus represents the health status of the service
|
||||
type HealthStatus struct {
|
||||
Status string `json:"status"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Checks map[string]string `json:"checks,omitempty"`
|
||||
}
|
||||
|
||||
// HealthHandler handles /health requests
|
||||
type HealthHandler struct {
|
||||
server *Server
|
||||
}
|
||||
|
||||
// NewHealthHandler creates a new health check handler
|
||||
func NewHealthHandler(s *Server) *HealthHandler {
|
||||
return &HealthHandler{server: s}
|
||||
}
|
||||
|
||||
// Health performs a basic health check
|
||||
func (h *HealthHandler) Health(w http.ResponseWriter, r *http.Request) {
|
||||
status := HealthStatus{
|
||||
Status: "healthy",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Checks: make(map[string]string),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
|
||||
// Liveness performs a liveness probe (is the service running?)
|
||||
func (h *HealthHandler) Liveness(w http.ResponseWriter, r *http.Request) {
|
||||
// Simple liveness check - if we can respond, we're alive
|
||||
status := HealthStatus{
|
||||
Status: "alive",
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
|
||||
// Readiness performs a readiness probe (is the service ready to accept traffic?)
|
||||
func (h *HealthHandler) Readiness(w http.ResponseWriter, r *http.Request) {
|
||||
status := HealthStatus{
|
||||
Status: "ready",
|
||||
Timestamp: time.Now().UTC(),
|
||||
Checks: make(map[string]string),
|
||||
}
|
||||
|
||||
// Check Redis connection (if queue is configured)
|
||||
if h.server.taskQueue != nil {
|
||||
// Simple check - if queue exists, assume it's ready
|
||||
status.Checks["queue"] = "ok"
|
||||
}
|
||||
|
||||
// Check experiment manager
|
||||
if h.server.expManager != nil {
|
||||
status.Checks["experiments"] = "ok"
|
||||
}
|
||||
|
||||
// If all checks pass, we're ready
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
|
||||
// RegisterHealthRoutes registers health check routes
|
||||
func (h *HealthHandler) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/health", h.Health)
|
||||
mux.HandleFunc("/health/live", h.Liveness)
|
||||
mux.HandleFunc("/health/ready", h.Readiness)
|
||||
}
|
||||
63
internal/api/metrics_middleware.go
Normal file
63
internal/api/metrics_middleware.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
package api
|
||||
|
||||
import (
	"bufio"
	"fmt"
	"net"
	"net/http"
	"strconv"
	"time"
)
|
||||
|
||||
// wrapWithMetrics wraps a handler with Prometheus metrics tracking
|
||||
func (s *Server) wrapWithMetrics(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if s.promMetrics == nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// Track HTTP request
|
||||
method := r.Method
|
||||
endpoint := r.URL.Path
|
||||
|
||||
// Wrap response writer to capture status code
|
||||
ww := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
// Serve request
|
||||
next.ServeHTTP(ww, r)
|
||||
|
||||
// Record metrics
|
||||
duration := time.Since(start)
|
||||
statusStr := http.StatusText(ww.statusCode)
|
||||
|
||||
s.promMetrics.IncHTTPRequests(method, endpoint, statusStr)
|
||||
s.promMetrics.ObserveHTTPDuration(method, endpoint, duration)
|
||||
})
|
||||
}
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Flush() {
|
||||
if f, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
h, ok := rw.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("websocket: response does not implement http.Hijacker")
|
||||
}
|
||||
return h.Hijack()
|
||||
}
|
||||
20
internal/api/monitoring_config.go
Normal file
20
internal/api/monitoring_config.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package api
|
||||
|
||||
// MonitoringConfig holds monitoring-related configuration.
type MonitoringConfig struct {
	Prometheus   PrometheusConfig   `yaml:"prometheus"`    // Prometheus metrics exporter settings
	HealthChecks HealthChecksConfig `yaml:"health_checks"` // health probe settings
}

// PrometheusConfig holds Prometheus metrics configuration.
type PrometheusConfig struct {
	Enabled bool   `yaml:"enabled"` // expose metrics when true
	Port    int    `yaml:"port"`    // TCP port the metrics endpoint listens on
	Path    string `yaml:"path"`    // HTTP path for the metrics endpoint (e.g. "/metrics") — TODO confirm default
}

// HealthChecksConfig holds health check configuration.
type HealthChecksConfig struct {
	Enabled  bool   `yaml:"enabled"`  // enable periodic health checks when true
	Interval string `yaml:"interval"` // presumably a duration string like "30s" — verify against the parser
}
|
||||
|
|
@ -4,9 +4,17 @@ import (
|
|||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// bufferPool recycles small scratch byte buffers used during packet
// serialization, reducing per-packet allocations. Buffers start at 256
// bytes of capacity and are stored as *[]byte to avoid boxing the slice
// header on every Put.
var bufferPool = sync.Pool{
	// `any` instead of `interface{}` for consistency with the rest of
	// this change set (go.mod declares go 1.25).
	New: func() any {
		buf := make([]byte, 0, 256)
		return &buf
	},
}
||||
|
||||
// Response packet types
|
||||
const (
|
||||
PacketTypeSuccess = 0x00
|
||||
|
|
@ -126,7 +134,12 @@ func NewErrorPacket(errorCode byte, message string, details string) *ResponsePac
|
|||
}
|
||||
|
||||
// NewProgressPacket creates a progress response packet
|
||||
func NewProgressPacket(progressType byte, value uint32, total uint32, message string) *ResponsePacket {
|
||||
func NewProgressPacket(
|
||||
progressType byte,
|
||||
value uint32,
|
||||
total uint32,
|
||||
message string,
|
||||
) *ResponsePacket {
|
||||
return &ResponsePacket{
|
||||
PacketType: PacketTypeProgress,
|
||||
Timestamp: uint64(time.Now().Unix()),
|
||||
|
|
@ -168,48 +181,65 @@ func NewLogPacket(level byte, message string) *ResponsePacket {
|
|||
|
||||
// Serialize converts the packet to binary format
|
||||
func (p *ResponsePacket) Serialize() ([]byte, error) {
|
||||
var buf []byte
|
||||
// For small packets, avoid pool overhead
|
||||
if p.estimatedSize() <= 1024 {
|
||||
buf := make([]byte, 0, p.estimatedSize())
|
||||
return serializePacketToBuffer(p, buf)
|
||||
}
|
||||
|
||||
// Use pool for larger packets
|
||||
bufPtr := bufferPool.Get().(*[]byte)
|
||||
defer func() {
|
||||
*bufPtr = (*bufPtr)[:0]
|
||||
bufferPool.Put(bufPtr)
|
||||
}()
|
||||
buf := *bufPtr
|
||||
|
||||
// Ensure buffer has enough capacity
|
||||
if cap(buf) < p.estimatedSize() {
|
||||
buf = make([]byte, 0, p.estimatedSize())
|
||||
} else {
|
||||
buf = buf[:0]
|
||||
}
|
||||
|
||||
return serializePacketToBuffer(p, buf)
|
||||
}
|
||||
|
||||
func serializePacketToBuffer(p *ResponsePacket, buf []byte) ([]byte, error) {
|
||||
// Packet type
|
||||
buf = append(buf, p.PacketType)
|
||||
|
||||
// Timestamp (8 bytes, big-endian)
|
||||
timestampBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(timestampBytes, p.Timestamp)
|
||||
buf = append(buf, timestampBytes...)
|
||||
var timestampBytes [8]byte
|
||||
binary.BigEndian.PutUint64(timestampBytes[:], p.Timestamp)
|
||||
buf = append(buf, timestampBytes[:]...)
|
||||
|
||||
// Packet-specific data
|
||||
switch p.PacketType {
|
||||
case PacketTypeSuccess:
|
||||
buf = append(buf, serializeString(p.SuccessMessage)...)
|
||||
buf = appendString(buf, p.SuccessMessage)
|
||||
|
||||
case PacketTypeError:
|
||||
buf = append(buf, p.ErrorCode)
|
||||
buf = append(buf, serializeString(p.ErrorMessage)...)
|
||||
buf = append(buf, serializeString(p.ErrorDetails)...)
|
||||
buf = appendString(buf, p.ErrorMessage)
|
||||
buf = appendString(buf, p.ErrorDetails)
|
||||
|
||||
case PacketTypeProgress:
|
||||
buf = append(buf, p.ProgressType)
|
||||
valueBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(valueBytes, p.ProgressValue)
|
||||
buf = append(buf, valueBytes...)
|
||||
|
||||
totalBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(totalBytes, p.ProgressTotal)
|
||||
buf = append(buf, totalBytes...)
|
||||
|
||||
buf = append(buf, serializeString(p.ProgressMessage)...)
|
||||
buf = appendUint32(buf, p.ProgressValue)
|
||||
buf = appendUint32(buf, p.ProgressTotal)
|
||||
buf = appendString(buf, p.ProgressMessage)
|
||||
|
||||
case PacketTypeStatus:
|
||||
buf = append(buf, serializeString(p.StatusData)...)
|
||||
buf = appendString(buf, p.StatusData)
|
||||
|
||||
case PacketTypeData:
|
||||
buf = append(buf, serializeString(p.DataType)...)
|
||||
buf = append(buf, serializeBytes(p.DataPayload)...)
|
||||
buf = appendString(buf, p.DataType)
|
||||
buf = appendBytes(buf, p.DataPayload)
|
||||
|
||||
case PacketTypeLog:
|
||||
buf = append(buf, p.LogLevel)
|
||||
buf = append(buf, serializeString(p.LogMessage)...)
|
||||
buf = appendString(buf, p.LogMessage)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown packet type: %d", p.PacketType)
|
||||
|
|
@ -218,22 +248,49 @@ func (p *ResponsePacket) Serialize() ([]byte, error) {
|
|||
return buf, nil
|
||||
}
|
||||
|
||||
// serializeString writes a string with 2-byte length prefix
|
||||
func serializeString(s string) []byte {
|
||||
length := uint16(len(s))
|
||||
buf := make([]byte, 2+len(s))
|
||||
binary.BigEndian.PutUint16(buf[:2], length)
|
||||
copy(buf[2:], s)
|
||||
return buf
|
||||
// appendString writes a string with varint length prefix
|
||||
func appendString(buf []byte, s string) []byte {
|
||||
length := uint64(len(s))
|
||||
var tmp [binary.MaxVarintLen64]byte
|
||||
n := binary.PutUvarint(tmp[:], length)
|
||||
buf = append(buf, tmp[:n]...)
|
||||
return append(buf, s...)
|
||||
}
|
||||
|
||||
// serializeBytes writes bytes with 4-byte length prefix
|
||||
func serializeBytes(b []byte) []byte {
|
||||
length := uint32(len(b))
|
||||
buf := make([]byte, 4+len(b))
|
||||
binary.BigEndian.PutUint32(buf[:4], length)
|
||||
copy(buf[4:], b)
|
||||
return buf
|
||||
// appendBytes writes bytes with varint length prefix
|
||||
func appendBytes(buf []byte, b []byte) []byte {
|
||||
length := uint64(len(b))
|
||||
var tmp [binary.MaxVarintLen64]byte
|
||||
n := binary.PutUvarint(tmp[:], length)
|
||||
buf = append(buf, tmp[:n]...)
|
||||
return append(buf, b...)
|
||||
}
|
||||
|
||||
func appendUint32(buf []byte, value uint32) []byte {
|
||||
var tmp [4]byte
|
||||
binary.BigEndian.PutUint32(tmp[:], value)
|
||||
return append(buf, tmp[:]...)
|
||||
}
|
||||
|
||||
func (p *ResponsePacket) estimatedSize() int {
|
||||
base := 1 + 8 // packet type + timestamp
|
||||
|
||||
switch p.PacketType {
|
||||
case PacketTypeSuccess:
|
||||
return base + binary.MaxVarintLen64 + len(p.SuccessMessage)
|
||||
case PacketTypeError:
|
||||
return base + 1 + 2*binary.MaxVarintLen64 + len(p.ErrorMessage) + len(p.ErrorDetails)
|
||||
case PacketTypeProgress:
|
||||
return base + 1 + 4 + 4 + binary.MaxVarintLen64 + len(p.ProgressMessage)
|
||||
case PacketTypeStatus:
|
||||
return base + binary.MaxVarintLen64 + len(p.StatusData)
|
||||
case PacketTypeData:
|
||||
return base + binary.MaxVarintLen64 + len(p.DataType) + binary.MaxVarintLen64 + len(p.DataPayload)
|
||||
case PacketTypeLog:
|
||||
return base + 1 + binary.MaxVarintLen64 + len(p.LogMessage)
|
||||
default:
|
||||
return base
|
||||
}
|
||||
}
|
||||
|
||||
// GetErrorMessage returns a human-readable error message for an error code
|
||||
|
|
|
|||
|
|
@ -1,155 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Simplified protocol using JSON instead of binary serialization.

// Response represents a simplified API response envelope. Depending on
// Type, either Data or Error carries the payload.
type Response struct {
	Type      string     `json:"type"`
	Timestamp int64      `json:"timestamp"` // Unix seconds at creation time
	Data      any        `json:"data,omitempty"`
	Error     *ErrorInfo `json:"error,omitempty"`
}

// ErrorInfo represents error information carried by an error response.
type ErrorInfo struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Details string `json:"details,omitempty"`
}

// ProgressInfo represents progress information for a long-running operation.
type ProgressInfo struct {
	Type    string `json:"type"`
	Value   uint32 `json:"value"`
	Total   uint32 `json:"total"`
	Message string `json:"message"`
}

// LogInfo represents a log record forwarded to the client.
type LogInfo struct {
	Level   string `json:"level"`
	Message string `json:"message"`
}

// Response types.
const (
	TypeSuccess  = "success"
	TypeError    = "error"
	TypeProgress = "progress"
	TypeStatus   = "status"
	TypeData     = "data"
	TypeLog      = "log"
)

// Error codes.
const (
	ErrUnknown          = 0
	ErrInvalidRequest   = 1
	ErrAuthFailed       = 2
	ErrPermissionDenied = 3
	ErrNotFound         = 4
	ErrExists           = 5
	ErrServerOverload   = 16
	ErrDatabase         = 17
	ErrNetwork          = 18
	ErrStorage          = 19
	ErrTimeout          = 20
)

// NewSuccessResponse creates a success response carrying message as its data.
func NewSuccessResponse(message string) *Response {
	return &Response{
		Type:      TypeSuccess,
		Timestamp: time.Now().Unix(),
		Data:      message,
	}
}

// NewSuccessResponseWithData creates a response carrying both a message and
// a data payload. NOTE(review): it deliberately uses TypeData (not
// TypeSuccess) as the envelope type — presumably so clients route it
// through the data path; confirm against consumers before changing.
func NewSuccessResponseWithData(message string, data any) *Response {
	return &Response{
		Type:      TypeData,
		Timestamp: time.Now().Unix(),
		Data: map[string]any{
			"message": message,
			"payload": data,
		},
	}
}

// NewErrorResponse creates an error response with the given code, message
// and optional details.
func NewErrorResponse(code int, message, details string) *Response {
	return &Response{
		Type:      TypeError,
		Timestamp: time.Now().Unix(),
		Error: &ErrorInfo{
			Code:    code,
			Message: message,
			Details: details,
		},
	}
}

// NewProgressResponse creates a progress response reporting value out of
// total for the given progress type.
func NewProgressResponse(progressType string, value, total uint32, message string) *Response {
	return &Response{
		Type:      TypeProgress,
		Timestamp: time.Now().Unix(),
		Data: ProgressInfo{
			Type:    progressType,
			Value:   value,
			Total:   total,
			Message: message,
		},
	}
}

// NewStatusResponse creates a status response carrying data as its payload.
func NewStatusResponse(data string) *Response {
	return &Response{
		Type:      TypeStatus,
		Timestamp: time.Now().Unix(),
		Data:      data,
	}
}

// NewDataResponse creates a data response tagging payload with dataType.
func NewDataResponse(dataType string, payload any) *Response {
	return &Response{
		Type:      TypeData,
		Timestamp: time.Now().Unix(),
		Data: map[string]any{
			"type":    dataType,
			"payload": payload,
		},
	}
}

// NewLogResponse creates a log response at the given level.
func NewLogResponse(level, message string) *Response {
	return &Response{
		Type:      TypeLog,
		Timestamp: time.Now().Unix(),
		Data: LogInfo{
			Level:   level,
			Message: message,
		},
	}
}

// ToJSON converts the response to JSON bytes.
func (r *Response) ToJSON() ([]byte, error) {
	return json.Marshal(r)
}

// FromJSON parses a response from JSON bytes.
// BUG FIX: the original returned a partially populated *Response alongside
// a non-nil error; on failure callers now receive a nil response, per the
// Go convention that the value is meaningless when err != nil.
func FromJSON(data []byte) (*Response, error) {
	var response Response
	if err := json.Unmarshal(data, &response); err != nil {
		return nil, err
	}
	return &response, nil
}
|
||||
|
|
@ -5,13 +5,17 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/middleware"
|
||||
"github.com/jfraeys/fetch_ml/internal/prommetrics"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
|
@ -22,12 +26,14 @@ type Server struct {
|
|||
httpServer *http.Server
|
||||
logger *logging.Logger
|
||||
expManager *experiment.Manager
|
||||
taskQueue *queue.TaskQueue
|
||||
taskQueue queue.Backend
|
||||
db *storage.DB
|
||||
handlers *Handlers
|
||||
sec *middleware.SecurityMiddleware
|
||||
cleanupFuncs []func()
|
||||
jupyterServiceMgr *jupyter.ServiceManager
|
||||
auditLogger *audit.Logger
|
||||
promMetrics *prommetrics.Metrics // Prometheus metrics
|
||||
}
|
||||
|
||||
// NewServer creates a new API server
|
||||
|
|
@ -80,6 +86,11 @@ func (s *Server) initializeComponents() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Initialize database schema (if DB enabled)
|
||||
if err := s.initDatabaseSchema(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize security
|
||||
s.initSecurity()
|
||||
|
||||
|
|
@ -87,11 +98,29 @@ func (s *Server) initializeComponents() error {
|
|||
s.initJupyterServiceManager()
|
||||
|
||||
// Initialize handlers
|
||||
s.handlers = NewHandlers(s.expManager, s.jupyterServiceMgr, s.logger)
|
||||
s.handlers = NewHandlers(s.expManager, nil, s.logger)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) initDatabaseSchema() error {
|
||||
if s.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
schema, err := storage.SchemaForDBType(s.config.Database.Type)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.db.Initialize(schema); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("database schema initialized", "type", s.config.Database.Type)
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupLogger creates and configures the logger
|
||||
func (s *Server) setupLogger() *logging.Logger {
|
||||
logger := logging.NewLoggerFromConfig(s.config.Logging)
|
||||
|
|
@ -112,26 +141,38 @@ func (s *Server) initExperimentManager() error {
|
|||
|
||||
// initTaskQueue initializes the task queue
|
||||
func (s *Server) initTaskQueue() error {
|
||||
queueCfg := queue.Config{
|
||||
RedisAddr: s.config.Redis.Addr,
|
||||
RedisPassword: s.config.Redis.Password,
|
||||
RedisDB: s.config.Redis.DB,
|
||||
backend := strings.ToLower(strings.TrimSpace(s.config.Queue.Backend))
|
||||
if backend == "" {
|
||||
backend = "redis"
|
||||
}
|
||||
redisAddr := strings.TrimSpace(s.config.Redis.Addr)
|
||||
if redisAddr == "" {
|
||||
redisAddr = "localhost:6379"
|
||||
}
|
||||
if strings.TrimSpace(s.config.Redis.URL) != "" {
|
||||
redisAddr = strings.TrimSpace(s.config.Redis.URL)
|
||||
}
|
||||
|
||||
if queueCfg.RedisAddr == "" {
|
||||
queueCfg.RedisAddr = "localhost:6379"
|
||||
}
|
||||
if s.config.Redis.URL != "" {
|
||||
queueCfg.RedisAddr = s.config.Redis.URL
|
||||
backendCfg := queue.BackendConfig{
|
||||
Backend: queue.QueueBackend(backend),
|
||||
RedisAddr: redisAddr,
|
||||
RedisPassword: s.config.Redis.Password,
|
||||
RedisDB: s.config.Redis.DB,
|
||||
SQLitePath: s.config.Queue.SQLitePath,
|
||||
MetricsFlushInterval: 0,
|
||||
}
|
||||
|
||||
taskQueue, err := queue.NewTaskQueue(queueCfg)
|
||||
taskQueue, err := queue.NewBackend(backendCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.taskQueue = taskQueue
|
||||
s.logger.Info("task queue initialized", "redis_addr", queueCfg.RedisAddr)
|
||||
if backend == "sqlite" {
|
||||
s.logger.Info("task queue initialized", "backend", backend, "sqlite_path", s.config.Queue.SQLitePath)
|
||||
} else {
|
||||
s.logger.Info("task queue initialized", "backend", backend, "redis_addr", redisAddr)
|
||||
}
|
||||
|
||||
// Add cleanup function
|
||||
s.cleanupFuncs = append(s.cleanupFuncs, func() {
|
||||
|
|
@ -220,8 +261,67 @@ func (s *Server) initJupyterServiceManager() {
|
|||
func (s *Server) setupHTTPServer() {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Register WebSocket handler
|
||||
wsHandler := NewWSHandler(s.config.BuildAuthConfig(), s.logger, s.expManager, s.taskQueue)
|
||||
// Initialize Prometheus metrics (if enabled)
|
||||
if s.config.Monitoring.Prometheus.Enabled {
|
||||
s.promMetrics = prommetrics.New()
|
||||
s.logger.Info("prometheus metrics initialized")
|
||||
|
||||
// Register metrics endpoint
|
||||
metricsPath := s.config.Monitoring.Prometheus.Path
|
||||
if metricsPath == "" {
|
||||
metricsPath = "/metrics"
|
||||
}
|
||||
mux.Handle(metricsPath, s.promMetrics.Handler())
|
||||
s.logger.Info("metrics endpoint registered", "path", metricsPath)
|
||||
}
|
||||
|
||||
// Initialize health check handler
|
||||
if s.config.Monitoring.HealthChecks.Enabled {
|
||||
healthHandler := NewHealthHandler(s)
|
||||
healthHandler.RegisterRoutes(mux)
|
||||
mux.HandleFunc("/health/ok", s.handlers.handleHealth)
|
||||
s.logger.Info("health check endpoints registered")
|
||||
}
|
||||
|
||||
// Initialize audit logger
|
||||
var auditLogger *audit.Logger
|
||||
if s.config.Security.AuditLogging.Enabled && s.config.Security.AuditLogging.LogPath != "" {
|
||||
al, err := audit.NewLogger(
|
||||
s.config.Security.AuditLogging.Enabled,
|
||||
s.config.Security.AuditLogging.LogPath,
|
||||
s.logger,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Warn("failed to initialize audit logger", "error", err)
|
||||
} else {
|
||||
auditLogger = al
|
||||
s.auditLogger = al
|
||||
|
||||
// Add cleanup function
|
||||
s.cleanupFuncs = append(s.cleanupFuncs, func() {
|
||||
s.logger.Info("closing audit logger...")
|
||||
if err := auditLogger.Close(); err != nil {
|
||||
s.logger.Error("failed to close audit logger", "error", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Register WebSocket handler with security config and audit logger
|
||||
securityCfg := getSecurityConfig(s.config)
|
||||
wsHandler := NewWSHandler(
|
||||
s.config.BuildAuthConfig(),
|
||||
s.logger,
|
||||
s.expManager,
|
||||
s.config.DataDir,
|
||||
s.taskQueue,
|
||||
s.db,
|
||||
s.jupyterServiceMgr,
|
||||
securityCfg,
|
||||
auditLogger,
|
||||
)
|
||||
|
||||
// Wrap WebSocket handler with metrics
|
||||
mux.Handle("/ws", wsHandler)
|
||||
|
||||
// Register HTTP handlers
|
||||
|
|
@ -229,6 +329,7 @@ func (s *Server) setupHTTPServer() {
|
|||
|
||||
// Wrap with middleware
|
||||
finalHandler := s.wrapWithMiddleware(mux)
|
||||
finalHandler = s.wrapWithMetrics(finalHandler)
|
||||
|
||||
s.httpServer = &http.Server{
|
||||
Addr: s.config.Server.Address,
|
||||
|
|
@ -239,10 +340,25 @@ func (s *Server) setupHTTPServer() {
|
|||
}
|
||||
}
|
||||
|
||||
// getSecurityConfig extracts security config from server config
|
||||
func getSecurityConfig(cfg *ServerConfig) *config.SecurityConfig {
|
||||
return &config.SecurityConfig{
|
||||
AllowedOrigins: cfg.Security.AllowedOrigins,
|
||||
ProductionMode: cfg.Security.ProductionMode,
|
||||
APIKeyRotationDays: cfg.Security.APIKeyRotationDays,
|
||||
AuditLogging: config.AuditLoggingConfig{
|
||||
Enabled: cfg.Security.AuditLogging.Enabled,
|
||||
LogPath: cfg.Security.AuditLogging.LogPath,
|
||||
},
|
||||
IPWhitelist: cfg.Security.IPWhitelist,
|
||||
}
|
||||
}
|
||||
|
||||
// wrapWithMiddleware wraps the handler with security middleware
|
||||
func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/ws" {
|
||||
// Skip auth for WebSocket and health endpoints
|
||||
if r.URL.Path == "/ws" || strings.HasPrefix(r.URL.Path, "/health") {
|
||||
mux.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
|
@ -250,7 +366,7 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
|
|||
handler := s.sec.APIKeyAuth(mux)
|
||||
handler = s.sec.RateLimit(handler)
|
||||
handler = middleware.SecurityHeaders(handler)
|
||||
handler = middleware.CORS(handler)
|
||||
handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler)
|
||||
handler = middleware.RequestTimeout(30 * time.Second)(handler)
|
||||
handler = middleware.AuditLogger(handler)
|
||||
if len(s.config.Security.IPWhitelist) > 0 {
|
||||
|
|
|
|||
|
|
@ -1,26 +1,37 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type QueueConfig struct {
|
||||
Backend string `yaml:"backend"`
|
||||
SQLitePath string `yaml:"sqlite_path"`
|
||||
}
|
||||
|
||||
// ServerConfig holds all server configuration
|
||||
type ServerConfig struct {
|
||||
BasePath string `yaml:"base_path"`
|
||||
Auth auth.Config `yaml:"auth"`
|
||||
Server ServerSection `yaml:"server"`
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Redis RedisConfig `yaml:"redis"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Logging logging.Config `yaml:"logging"`
|
||||
Resources config.ResourceConfig `yaml:"resources"`
|
||||
BasePath string `yaml:"base_path"`
|
||||
DataDir string `yaml:"data_dir"`
|
||||
Auth auth.Config `yaml:"auth"`
|
||||
Server ServerSection `yaml:"server"`
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Monitoring MonitoringConfig `yaml:"monitoring"`
|
||||
Queue QueueConfig `yaml:"queue"`
|
||||
Redis RedisConfig `yaml:"redis"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Logging logging.Config `yaml:"logging"`
|
||||
Resources config.ResourceConfig `yaml:"resources"`
|
||||
}
|
||||
|
||||
// ServerSection holds server-specific configuration
|
||||
|
|
@ -38,9 +49,19 @@ type TLSConfig struct {
|
|||
|
||||
// SecurityConfig holds security-related configuration
|
||||
type SecurityConfig struct {
|
||||
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
||||
IPWhitelist []string `yaml:"ip_whitelist"`
|
||||
FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
|
||||
ProductionMode bool `yaml:"production_mode"`
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
APIKeyRotationDays int `yaml:"api_key_rotation_days"`
|
||||
AuditLogging AuditLog `yaml:"audit_logging"`
|
||||
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
||||
IPWhitelist []string `yaml:"ip_whitelist"`
|
||||
FailedLockout LockoutConfig `yaml:"failed_login_lockout"`
|
||||
}
|
||||
|
||||
// AuditLog holds audit logging configuration
|
||||
type AuditLog struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
LogPath string `yaml:"log_path"`
|
||||
}
|
||||
|
||||
// RateLimitConfig holds rate limiting configuration
|
||||
|
|
@ -108,9 +129,7 @@ func loadConfigFromFile(path string) (*ServerConfig, error) {
|
|||
|
||||
// secureFileRead safely reads a file
|
||||
func secureFileRead(path string) ([]byte, error) {
|
||||
// This would use the fileutil.SecureFileRead function
|
||||
// For now, implement basic file reading
|
||||
return os.ReadFile(path)
|
||||
return fileutil.SecureFileRead(path)
|
||||
}
|
||||
|
||||
// EnsureLogDirectory creates the log directory if needed
|
||||
|
|
@ -144,6 +163,27 @@ func (c *ServerConfig) Validate() error {
|
|||
if c.BasePath == "" {
|
||||
c.BasePath = "/tmp/ml-experiments"
|
||||
}
|
||||
if c.DataDir == "" {
|
||||
c.DataDir = config.DefaultDataDir
|
||||
}
|
||||
|
||||
backend := strings.ToLower(strings.TrimSpace(c.Queue.Backend))
|
||||
if backend == "" {
|
||||
backend = "redis"
|
||||
c.Queue.Backend = backend
|
||||
}
|
||||
if backend != "redis" && backend != "sqlite" {
|
||||
return fmt.Errorf("queue.backend must be one of 'redis' or 'sqlite'")
|
||||
}
|
||||
if backend == "sqlite" {
|
||||
if strings.TrimSpace(c.Queue.SQLitePath) == "" {
|
||||
c.Queue.SQLitePath = filepath.Join(c.DataDir, "queue.db")
|
||||
}
|
||||
c.Queue.SQLitePath = config.ExpandPath(c.Queue.SQLitePath)
|
||||
if !filepath.IsAbs(c.Queue.SQLitePath) {
|
||||
c.Queue.SQLitePath = filepath.Join(config.DefaultLocalDataDir, c.Queue.SQLitePath)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,652 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/telemetry"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
// Opcodes for binary WebSocket protocol
|
||||
const (
|
||||
OpcodeQueueJob = 0x01
|
||||
OpcodeStatusRequest = 0x02
|
||||
OpcodeCancelJob = 0x03
|
||||
OpcodePrune = 0x04
|
||||
OpcodeLogMetric = 0x0A
|
||||
OpcodeGetExperiment = 0x0B
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
// Allow localhost and homelab origins for development
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return true // Allow same-origin requests
|
||||
}
|
||||
|
||||
// Parse origin URL
|
||||
parsedOrigin, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Allow localhost and local network origins
|
||||
host := parsedOrigin.Host
|
||||
return strings.HasSuffix(host, ":8080") ||
|
||||
strings.HasPrefix(host, "localhost:") ||
|
||||
strings.HasPrefix(host, "127.0.0.1:") ||
|
||||
strings.HasPrefix(host, "192.168.") ||
|
||||
strings.HasPrefix(host, "10.") ||
|
||||
strings.HasPrefix(host, "172.")
|
||||
},
|
||||
// Performance optimizations
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
EnableCompression: true,
|
||||
}
|
||||
|
||||
// WSHandler handles WebSocket connections for the API.
|
||||
type WSHandler struct {
|
||||
authConfig *auth.Config
|
||||
logger *logging.Logger
|
||||
expManager *experiment.Manager
|
||||
queue *queue.TaskQueue
|
||||
}
|
||||
|
||||
// NewWSHandler creates a new WebSocket handler.
|
||||
func NewWSHandler(
|
||||
authConfig *auth.Config,
|
||||
logger *logging.Logger,
|
||||
expManager *experiment.Manager,
|
||||
taskQueue *queue.TaskQueue,
|
||||
) *WSHandler {
|
||||
return &WSHandler{
|
||||
authConfig: authConfig,
|
||||
logger: logger,
|
||||
expManager: expManager,
|
||||
queue: taskQueue,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Check API key before upgrading WebSocket
|
||||
apiKey := auth.ExtractAPIKeyFromRequest(r)
|
||||
|
||||
// Validate API key if authentication is enabled
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
prefixLen := len(apiKey)
|
||||
if prefixLen > 8 {
|
||||
prefixLen = 8
|
||||
}
|
||||
h.logger.Info("websocket auth attempt", "api_key_length", len(apiKey), "api_key_prefix", apiKey[:prefixLen])
|
||||
if _, err := h.authConfig.ValidateAPIKey(apiKey); err != nil {
|
||||
h.logger.Warn("websocket authentication failed", "error", err)
|
||||
http.Error(w, "Invalid API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h.logger.Info("websocket authentication succeeded")
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
h.logger.Info("websocket connection established", "remote", r.RemoteAddr)
|
||||
|
||||
for {
|
||||
messageType, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
h.logger.Error("websocket read error", "error", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if messageType != websocket.BinaryMessage {
|
||||
h.logger.Warn("received non-binary message")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := h.handleMessage(conn, message); err != nil {
|
||||
h.logger.Error("message handling error", "error", err)
|
||||
// Send error response
|
||||
_ = conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00}) // Error opcode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
|
||||
if len(message) < 1 {
|
||||
return fmt.Errorf("message too short")
|
||||
}
|
||||
|
||||
opcode := message[0]
|
||||
payload := message[1:]
|
||||
|
||||
switch opcode {
|
||||
case OpcodeQueueJob:
|
||||
return h.handleQueueJob(conn, payload)
|
||||
case OpcodeStatusRequest:
|
||||
return h.handleStatusRequest(conn, payload)
|
||||
case OpcodeCancelJob:
|
||||
return h.handleCancelJob(conn, payload)
|
||||
case OpcodePrune:
|
||||
return h.handlePrune(conn, payload)
|
||||
case OpcodeLogMetric:
|
||||
return h.handleLogMetric(conn, payload)
|
||||
case OpcodeGetExperiment:
|
||||
return h.handleGetExperiment(conn, payload)
|
||||
default:
|
||||
return fmt.Errorf("unknown opcode: 0x%02x", opcode)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleQueueJob(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
|
||||
if len(payload) < 130 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "queue job payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := string(payload[:64])
|
||||
commitID := string(payload[64:128])
|
||||
priority := int64(payload[128])
|
||||
jobNameLen := int(payload[129])
|
||||
|
||||
if len(payload) < 130+jobNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
||||
}
|
||||
|
||||
jobName := string(payload[130 : 130+jobNameLen])
|
||||
|
||||
h.logger.Info("queue job request",
|
||||
"job", jobName,
|
||||
"priority", priority,
|
||||
"commit_id", commitID,
|
||||
)
|
||||
|
||||
// Validate API key and get user information
|
||||
var user *auth.User
|
||||
var err error
|
||||
if h.authConfig != nil {
|
||||
user, err = h.authConfig.ValidateAPIKey(apiKeyHash)
|
||||
if err != nil {
|
||||
h.logger.Error("invalid api key", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
||||
}
|
||||
} else {
|
||||
// Auth disabled - use default admin user
|
||||
user = &auth.User{
|
||||
Name: "default",
|
||||
Admin: true,
|
||||
Roles: []string{"admin"},
|
||||
Permissions: map[string]bool{
|
||||
"*": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Check user permissions
|
||||
if h.authConfig == nil || !h.authConfig.Enabled || user.HasPermission("jobs:create") {
|
||||
h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)
|
||||
} else {
|
||||
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:create")
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to create jobs", "")
|
||||
}
|
||||
|
||||
// Create experiment directory and metadata (optimized)
|
||||
if _, err := telemetry.ExecWithMetrics(h.logger, "experiment.create", 50*time.Millisecond, func() (string, error) {
|
||||
return "", h.expManager.CreateExperiment(commitID)
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to create experiment directory", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to create experiment directory", err.Error())
|
||||
}
|
||||
|
||||
// Add user info to experiment metadata (deferred for performance)
|
||||
go func() {
|
||||
meta := &experiment.Metadata{
|
||||
CommitID: commitID,
|
||||
JobName: jobName,
|
||||
User: user.Name,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
if _, err := telemetry.ExecWithMetrics(
|
||||
h.logger, "experiment.write_metadata", 50*time.Millisecond, func() (string, error) {
|
||||
return "", h.expManager.WriteMetadata(meta)
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to save experiment metadata", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
h.logger.Info("job queued", "job", jobName, "path", h.expManager.GetExperimentPath(commitID), "user", user.Name)
|
||||
|
||||
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' queued successfully", jobName))
|
||||
|
||||
// Enqueue task if queue is available
|
||||
if h.queue != nil {
|
||||
taskID := uuid.New().String()
|
||||
task := &queue.Task{
|
||||
ID: taskID,
|
||||
JobName: jobName,
|
||||
Args: "",
|
||||
Status: "queued",
|
||||
Priority: priority,
|
||||
CreatedAt: time.Now(),
|
||||
UserID: user.Name,
|
||||
Username: user.Name,
|
||||
CreatedBy: user.Name,
|
||||
Metadata: map[string]string{
|
||||
"commit_id": commitID, // Reduced redundant metadata
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := telemetry.ExecWithMetrics(h.logger, "queue.add_task", 20*time.Millisecond, func() (string, error) {
|
||||
return "", h.queue.AddTask(task)
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to enqueue task", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue task", err.Error())
|
||||
}
|
||||
h.logger.Info("task enqueued", "task_id", taskID, "job", jobName, "user", user.Name)
|
||||
} else {
|
||||
h.logger.Warn("task queue not initialized, job not enqueued", "job", jobName)
|
||||
}
|
||||
|
||||
packetData, err := packet.Serialize()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to serialize packet", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, packetData)
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleStatusRequest(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:64]
|
||||
if len(payload) < 64 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "status request payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := string(payload[0:64])
|
||||
h.logger.Info("status request received", "api_key_hash", apiKeyHash[:16]+"...")
|
||||
|
||||
// Validate API key and get user information
|
||||
var user *auth.User
|
||||
var err error
|
||||
if h.authConfig != nil {
|
||||
user, err = h.authConfig.ValidateAPIKey(apiKeyHash)
|
||||
if err != nil {
|
||||
h.logger.Error("invalid api key", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
||||
}
|
||||
} else {
|
||||
// Auth disabled - use default admin user
|
||||
user = &auth.User{
|
||||
Name: "default",
|
||||
Admin: true,
|
||||
Roles: []string{"admin"},
|
||||
Permissions: map[string]bool{
|
||||
"*": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Check user permissions for viewing jobs
|
||||
if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") {
|
||||
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:read")
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to view jobs", "")
|
||||
}
|
||||
// Get tasks with user filtering
|
||||
var tasks []*queue.Task
|
||||
if h.queue != nil {
|
||||
allTasks, err := h.queue.GetAllTasks()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to get tasks", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to retrieve tasks", err.Error())
|
||||
}
|
||||
|
||||
// Filter tasks based on user permissions
|
||||
for _, task := range allTasks {
|
||||
// If auth is disabled or admin can see all tasks
|
||||
if h.authConfig == nil || !h.authConfig.Enabled || user.Admin {
|
||||
tasks = append(tasks, task)
|
||||
continue
|
||||
}
|
||||
|
||||
// Users can only see their own tasks
|
||||
if task.UserID == user.Name || task.CreatedBy == user.Name {
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build status response as raw JSON for CLI compatibility
|
||||
h.logger.Info("building status response")
|
||||
status := map[string]interface{}{
|
||||
"user": map[string]interface{}{
|
||||
"name": user.Name,
|
||||
"admin": user.Admin,
|
||||
"roles": user.Roles,
|
||||
},
|
||||
"tasks": map[string]interface{}{
|
||||
"total": len(tasks),
|
||||
"queued": countTasksByStatus(tasks, "queued"),
|
||||
"running": countTasksByStatus(tasks, "running"),
|
||||
"failed": countTasksByStatus(tasks, "failed"),
|
||||
"completed": countTasksByStatus(tasks, "completed"),
|
||||
},
|
||||
"queue": tasks,
|
||||
}
|
||||
|
||||
h.logger.Info("serializing JSON response")
|
||||
jsonData, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to marshal JSON", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
|
||||
}
|
||||
|
||||
h.logger.Info("sending websocket JSON response", "len", len(jsonData))
|
||||
return conn.WriteMessage(websocket.BinaryMessage, jsonData)
|
||||
}
|
||||
|
||||
// countTasksByStatus counts tasks by their status
|
||||
func countTasksByStatus(tasks []*queue.Task, status string) int {
|
||||
count := 0
|
||||
for _, task := range tasks {
|
||||
if task.Status == status {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleCancelJob(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:64][job_name_len:1][job_name:var]
|
||||
if len(payload) < 65 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "cancel job payload too short", "")
|
||||
}
|
||||
|
||||
// Parse 64-byte hex API key hash
|
||||
apiKeyHash := string(payload[0:64])
|
||||
jobNameLen := int(payload[64])
|
||||
|
||||
if len(payload) < 65+jobNameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid job name length", "")
|
||||
}
|
||||
|
||||
jobName := string(payload[65 : 65+jobNameLen])
|
||||
|
||||
h.logger.Info("cancel job request", "job", jobName)
|
||||
|
||||
// Validate API key and get user information
|
||||
var user *auth.User
|
||||
var err error
|
||||
if h.authConfig != nil {
|
||||
user, err = h.authConfig.ValidateAPIKey(apiKeyHash)
|
||||
if err != nil {
|
||||
h.logger.Error("invalid api key", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
||||
}
|
||||
} else {
|
||||
// Auth disabled - use default admin user
|
||||
user = &auth.User{
|
||||
Name: "default",
|
||||
Admin: true,
|
||||
Roles: []string{"admin"},
|
||||
Permissions: map[string]bool{
|
||||
"*": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Check user permissions for canceling jobs
|
||||
if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:update") {
|
||||
h.logger.Error("insufficient permissions", "user", user.Name, "required", "jobs:update")
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions to cancel jobs", "")
|
||||
}
|
||||
// Find the task and verify ownership
|
||||
if h.queue != nil {
|
||||
task, err := h.queue.GetTaskByName(jobName)
|
||||
if err != nil {
|
||||
h.logger.Error("task not found", "job", jobName, "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", err.Error())
|
||||
}
|
||||
|
||||
// Check if user can cancel this task (admin or owner)
|
||||
if h.authConfig.Enabled && !user.Admin && task.UserID != user.Name && task.CreatedBy != user.Name {
|
||||
h.logger.Error("unauthorized job cancellation attempt", "user", user.Name, "job", jobName, "task_owner", task.UserID)
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "You can only cancel your own jobs", "")
|
||||
}
|
||||
// Cancel the task
|
||||
if err := h.queue.CancelTask(task.ID); err != nil {
|
||||
h.logger.Error("failed to cancel task", "job", jobName, "task_id", task.ID, "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeJobExecutionFailed, "Failed to cancel job", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("job cancelled", "job", jobName, "task_id", task.ID, "user", user.Name)
|
||||
} else {
|
||||
h.logger.Warn("task queue not initialized, cannot cancel job", "job", jobName)
|
||||
}
|
||||
|
||||
packet := NewSuccessPacket(fmt.Sprintf("Job '%s' cancelled successfully", jobName))
|
||||
packetData, err := packet.Serialize()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to serialize packet", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Internal error", "Failed to serialize response")
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, packetData)
|
||||
}
|
||||
|
||||
func (h *WSHandler) handlePrune(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:64][prune_type:1][value:4]
|
||||
if len(payload) < 69 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "prune payload too short", "")
|
||||
}
|
||||
|
||||
// Parse 64-byte hex API key hash
|
||||
apiKeyHash := string(payload[0:64])
|
||||
pruneType := payload[64]
|
||||
value := binary.BigEndian.Uint32(payload[65:69])
|
||||
|
||||
h.logger.Info("prune request", "type", pruneType, "value", value)
|
||||
|
||||
// Verify API key
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
h.logger.Error("api key verification failed", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Convert prune parameters
|
||||
var keepCount int
|
||||
var olderThanDays int
|
||||
|
||||
switch pruneType {
|
||||
case 0:
|
||||
// keep N
|
||||
keepCount = int(value)
|
||||
olderThanDays = 0
|
||||
case 1:
|
||||
// older than days
|
||||
keepCount = 0
|
||||
olderThanDays = int(value)
|
||||
default:
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, fmt.Sprintf("invalid prune type: %d", pruneType), "")
|
||||
}
|
||||
|
||||
// Perform pruning
|
||||
pruned, err := h.expManager.PruneExperiments(keepCount, olderThanDays)
|
||||
if err != nil {
|
||||
h.logger.Error("prune failed", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Prune operation failed", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("prune completed", "count", len(pruned), "experiments", pruned)
|
||||
|
||||
// Send structured success response
|
||||
packet := NewSuccessPacket(fmt.Sprintf("Pruned %d experiments", len(pruned)))
|
||||
return h.sendResponsePacket(conn, packet)
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
|
||||
if len(payload) < 141 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "log metric payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := string(payload[:64])
|
||||
commitID := string(payload[64:128])
|
||||
step := int(binary.BigEndian.Uint32(payload[128:132]))
|
||||
valueBits := binary.BigEndian.Uint64(payload[132:140])
|
||||
value := math.Float64frombits(valueBits)
|
||||
nameLen := int(payload[140])
|
||||
|
||||
if len(payload) < 141+nameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid metric name length", "")
|
||||
}
|
||||
|
||||
name := string(payload[141 : 141+nameLen])
|
||||
|
||||
// Verify API key
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.expManager.LogMetric(commitID, name, value, step); err != nil {
|
||||
h.logger.Error("failed to log metric", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to log metric", err.Error())
|
||||
}
|
||||
|
||||
return h.sendResponsePacket(conn, NewSuccessPacket("Metric logged"))
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleGetExperiment(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:64][commit_id:64]
|
||||
if len(payload) < 128 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "get experiment payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := string(payload[:64])
|
||||
commitID := string(payload[64:128])
|
||||
|
||||
// Verify API key
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Authentication failed", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
meta, err := h.expManager.ReadMetadata(commitID)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Experiment not found", err.Error())
|
||||
}
|
||||
|
||||
metrics, err := h.expManager.GetMetrics(commitID)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeStorageError, "Failed to read metrics", err.Error())
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"metadata": meta,
|
||||
"metrics": metrics,
|
||||
}
|
||||
|
||||
return h.sendResponsePacket(conn, NewSuccessPacketWithPayload("Experiment details", response))
|
||||
}
|
||||
|
||||
// SetupTLSConfig creates TLS configuration for WebSocket server
|
||||
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
|
||||
var server *http.Server
|
||||
|
||||
if certFile != "" && keyFile != "" {
|
||||
// Use provided certificates
|
||||
server = &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
},
|
||||
},
|
||||
}
|
||||
} else if host != "" {
|
||||
// Use Let's Encrypt with autocert
|
||||
certManager := &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: autocert.HostWhitelist(host),
|
||||
Cache: autocert.DirCache("/var/www/.cache"),
|
||||
}
|
||||
|
||||
server = &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
|
||||
TLSConfig: certManager.TLSConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// verifyAPIKeyHash verifies the provided hex hash against stored API keys
|
||||
func (h *WSHandler) verifyAPIKeyHash(hexHash string) error {
|
||||
if h.authConfig == nil || !h.authConfig.Enabled {
|
||||
return nil // No auth required
|
||||
}
|
||||
|
||||
// For now, just check if it's a valid 64-char hex string
|
||||
if len(hexHash) != 64 {
|
||||
return fmt.Errorf("invalid api key hash length")
|
||||
}
|
||||
|
||||
// Check against stored API keys
|
||||
for username, entry := range h.authConfig.APIKeys {
|
||||
if string(entry.Hash) == hexHash {
|
||||
_ = username // Username found but not needed for verification
|
||||
return nil // Valid API key found
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid api key")
|
||||
}
|
||||
|
||||
// sendErrorPacket sends an error response packet
|
||||
func (h *WSHandler) sendErrorPacket(conn *websocket.Conn, errorCode byte, message string, details string) error {
|
||||
packet := NewErrorPacket(errorCode, message, details)
|
||||
return h.sendResponsePacket(conn, packet)
|
||||
}
|
||||
|
||||
// sendResponsePacket sends a structured response packet
|
||||
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
|
||||
data, err := packet.Serialize()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to serialize response packet", "error", err)
|
||||
// Fallback to simple error response
|
||||
return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
|
||||
}
|
||||
|
||||
return conn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
|
||||
// sendErrorResponse removed (unused)
|
||||
208
internal/api/ws_datasets.go
Normal file
208
internal/api/ws_datasets.go
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
||||
func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16]
|
||||
if len(payload) < 16 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset list payload too short", "")
|
||||
}
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
if h.db == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
datasets, err := h.db.ListDatasets(ctx, 0)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to list datasets", err.Error())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(datasets)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeServerOverloaded,
|
||||
"Failed to serialize response",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("datasets", data))
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][name_len:1][name:var][url_len:2][url:var]
|
||||
if len(payload) < 16+1+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset register payload too short", "")
|
||||
}
|
||||
apiKeyHash := payload[:16]
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
if h.db == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if nameLen <= 0 || len(payload) < offset+nameLen+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
|
||||
}
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
offset += nameLen
|
||||
|
||||
urlLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if urlLen <= 0 || len(payload) < offset+urlLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url length", "")
|
||||
}
|
||||
urlStr := string(payload[offset : offset+urlLen])
|
||||
|
||||
// Minimal validation (server-side authoritative): name non-empty and url parseable.
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "")
|
||||
}
|
||||
if u, err := url.Parse(urlStr); err != nil || u.Scheme == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := h.db.UpsertDataset(ctx, &storage.Dataset{Name: name, URL: urlStr}); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to register dataset", err.Error())
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewSuccessPacket("Dataset registered"))
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][name_len:1][name:var]
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset info payload too short", "")
|
||||
}
|
||||
apiKeyHash := payload[:16]
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
if h.db == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if nameLen <= 0 || len(payload) < offset+nameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
|
||||
}
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ds, err := h.db.GetDataset(ctx, name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Dataset not found", "")
|
||||
}
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to get dataset", err.Error())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ds)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeServerOverloaded,
|
||||
"Failed to serialize response",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("dataset", data))
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][term_len:1][term:var]
|
||||
if len(payload) < 16+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset search payload too short", "")
|
||||
}
|
||||
apiKeyHash := payload[:16]
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
if h.db == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
termLen := int(payload[offset])
|
||||
offset++
|
||||
if termLen < 0 || len(payload) < offset+termLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid search term length", "")
|
||||
}
|
||||
term := string(payload[offset : offset+termLen])
|
||||
term = strings.TrimSpace(term)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
datasets, err := h.db.SearchDatasets(ctx, term, 0)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to search datasets", err.Error())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(datasets)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeServerOverloaded,
|
||||
"Failed to serialize response",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("datasets", data))
|
||||
}
|
||||
279
internal/api/ws_handler.go
Normal file
279
internal/api/ws_handler.go
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"compress/flate"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
||||
// Opcodes for binary WebSocket protocol
//
// Each frame begins with one of these bytes; handleMessage dispatches on it.
// 0x05 and 0x10–0x15 are unassigned in this file.
const (
	OpcodeQueueJob             = 0x01 // handleQueueJob
	OpcodeStatusRequest        = 0x02 // handleStatusRequest
	OpcodeCancelJob            = 0x03 // handleCancelJob
	OpcodePrune                = 0x04 // handlePrune
	OpcodeDatasetList          = 0x06 // handleDatasetList
	OpcodeDatasetRegister      = 0x07 // handleDatasetRegister
	OpcodeDatasetInfo          = 0x08 // handleDatasetInfo
	OpcodeDatasetSearch        = 0x09 // handleDatasetSearch
	OpcodeLogMetric            = 0x0A // handleLogMetric
	OpcodeGetExperiment        = 0x0B // handleGetExperiment
	OpcodeQueueJobWithTracking = 0x0C // handleQueueJobWithTracking
	OpcodeQueueJobWithSnapshot = 0x17 // handleQueueJobWithSnapshot (out of sequence; added later)
	OpcodeStartJupyter         = 0x0D // handleStartJupyter
	OpcodeStopJupyter          = 0x0E // handleStopJupyter
	OpcodeRemoveJupyter        = 0x18 // handleRemoveJupyter
	OpcodeRestoreJupyter       = 0x19 // handleRestoreJupyter
	OpcodeListJupyter          = 0x0F // handleListJupyter
	OpcodeValidateRequest      = 0x16 // handleValidateRequest
)
|
||||
|
||||
// createUpgrader creates a WebSocket upgrader with the given security configuration.
//
// Origin policy:
//   - empty Origin header: accepted (same-origin and non-browser clients);
//   - production mode: exact match against securityCfg.AllowedOrigins;
//   - otherwise (development): heuristic allow-list of localhost and
//     private-network-looking hosts.
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
	return websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			origin := r.Header.Get("Origin")
			if origin == "" {
				return true // Allow same-origin requests
			}

			// Production mode: strict checking against allowed origins
			if securityCfg != nil && securityCfg.ProductionMode {
				for _, allowed := range securityCfg.AllowedOrigins {
					if origin == allowed {
						return true
					}
				}
				return false // Reject if not in allowed list
			}

			// Development mode: allow localhost and local network origins
			parsedOrigin, err := url.Parse(origin)
			if err != nil {
				return false
			}

			// NOTE(review): these prefix checks are loose — "172." also
			// matches public addresses outside 172.16.0.0/12, and the
			// ":8080" suffix admits ANY host on that port. Tolerable for
			// dev mode only; confirm ProductionMode is set in deployed
			// configs so this branch is never reached in production.
			host := parsedOrigin.Host
			return strings.HasSuffix(host, ":8080") ||
				strings.HasPrefix(host, "localhost:") ||
				strings.HasPrefix(host, "127.0.0.1:") ||
				strings.HasPrefix(host, "192.168.") ||
				strings.HasPrefix(host, "10.") ||
				strings.HasPrefix(host, "172.")
		},
		// Performance optimizations
		HandshakeTimeout:  10 * time.Second,
		ReadBufferSize:    16 * 1024,
		WriteBufferSize:   16 * 1024,
		EnableCompression: true,
	}
}
|
||||
|
||||
// WSHandler handles WebSocket connections for the API.
//
// All request handlers hang off this struct; dependencies are injected via
// NewWSHandler. db and auditLogger may be nil — handlers nil-check them and
// degrade gracefully (database errors / skipped audit entries).
type WSHandler struct {
	authConfig        *auth.Config            // API-key auth settings; nil or disabled means no auth enforced
	logger            *logging.Logger         // component-scoped structured logger
	expManager        *experiment.Manager     // experiment metadata and metric storage
	dataDir           string                  // server-side data directory path
	queue             queue.Backend           // task queue used by job and Jupyter handlers
	db                *storage.DB             // dataset store; nil when no database is configured
	jupyterServiceMgr *jupyter.ServiceManager // Jupyter service lifecycle manager
	securityConfig    *config.SecurityConfig  // origin / production-mode policy
	auditLogger       *audit.Logger           // optional audit trail; may be nil
	upgrader          websocket.Upgrader      // built from securityConfig by createUpgrader
}
|
||||
|
||||
// NewWSHandler creates a new WebSocket handler.
//
// All collaborators are injected. db and auditLogger may be nil; handlers
// detect this at request time and respond with "not configured" errors or
// skip audit logging, respectively. The upgrader is derived once here from
// securityConfig.
func NewWSHandler(
	authConfig *auth.Config,
	logger *logging.Logger,
	expManager *experiment.Manager,
	dataDir string,
	taskQueue queue.Backend,
	db *storage.DB,
	jupyterServiceMgr *jupyter.ServiceManager,
	securityConfig *config.SecurityConfig,
	auditLogger *audit.Logger,
) *WSHandler {
	return &WSHandler{
		authConfig: authConfig,
		// Scope the logger to this component under a fresh trace context.
		logger:            logger.Component(logging.EnsureTrace(context.Background()), "ws-handler"),
		expManager:        expManager,
		dataDir:           dataDir,
		queue:             taskQueue,
		db:                db,
		jupyterServiceMgr: jupyterServiceMgr,
		securityConfig:    securityConfig,
		auditLogger:       auditLogger,
		upgrader:          createUpgrader(securityConfig),
	}
}
|
||||
|
||||
// enableLowLatencyTCP disables Nagle's algorithm to reduce latency for small packets.
|
||||
func enableLowLatencyTCP(conn *websocket.Conn, logger *logging.Logger) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if tcpConn, ok := conn.UnderlyingConn().(*net.TCPConn); ok {
|
||||
if err := tcpConn.SetNoDelay(true); err != nil {
|
||||
logger.Warn("failed to enable tcp no delay", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP authenticates the request, upgrades it to a WebSocket, and runs
// the binary message loop until the client disconnects.
//
// Ordering is deliberate: security headers and API-key validation happen
// BEFORE the upgrade, because headers and HTTP status codes cannot be sent
// once the connection has been hijacked.
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Add security headers
	w.Header().Set("X-Frame-Options", "DENY")
	w.Header().Set("X-Content-Type-Options", "nosniff")
	w.Header().Set("X-XSS-Protection", "1; mode=block")
	if r.TLS != nil {
		// Only set HSTS if using HTTPS
		w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
	}

	// Check API key before upgrading WebSocket
	apiKey := auth.ExtractAPIKeyFromRequest(r)
	clientIP := r.RemoteAddr

	// Validate API key if authentication is enabled
	if h.authConfig != nil && h.authConfig.Enabled {
		// Only a short prefix of the key is ever logged, never the secret.
		prefixLen := len(apiKey)
		if prefixLen > 8 {
			prefixLen = 8
		}
		h.logger.Info(
			"websocket auth attempt",
			"api_key_length",
			len(apiKey),
			"api_key_prefix",
			apiKey[:prefixLen],
		)

		userID, err := h.authConfig.ValidateAPIKey(apiKey)
		if err != nil {
			h.logger.Warn("websocket authentication failed", "error", err)

			// Audit log failed authentication
			if h.auditLogger != nil {
				h.auditLogger.LogAuthAttempt(apiKey[:prefixLen], clientIP, false, err.Error())
			}

			http.Error(w, "Invalid API key", http.StatusUnauthorized)
			return
		}
		h.logger.Info("websocket authentication succeeded")

		// Audit log successful authentication
		if h.auditLogger != nil && userID != nil {
			h.auditLogger.LogAuthAttempt(userID.Name, clientIP, true, "")
		}
	}

	conn, err := h.upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Error("websocket upgrade failed", "error", err)
		return
	}
	// Compression trades CPU for bandwidth; BestSpeed keeps latency low.
	conn.EnableWriteCompression(true)
	if err := conn.SetCompressionLevel(flate.BestSpeed); err != nil {
		h.logger.Warn("failed to set websocket compression level", "error", err)
	}
	enableLowLatencyTCP(conn, h.logger)
	defer func() {
		_ = conn.Close()
	}()

	h.logger.Info("websocket connection established", "remote", r.RemoteAddr)

	// Message loop: one binary frame per request; any read error ends it.
	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			// Normal close codes are expected; only log unexpected ones.
			if websocket.IsUnexpectedCloseError(
				err,
				websocket.CloseGoingAway,
				websocket.CloseAbnormalClosure,
			) {
				h.logger.Error("websocket read error", "error", err)
			}
			break
		}

		if messageType != websocket.BinaryMessage {
			h.logger.Warn("received non-binary message")
			continue
		}

		if err := h.handleMessage(conn, message); err != nil {
			h.logger.Error("message handling error", "error", err)
			// Send structured error response so CLI clients can parse it.
			// (Raw fallback bytes cause client-side InvalidPacket errors.)
			_ = h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "message handling error", err.Error())
		}
	}
}
|
||||
|
||||
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
|
||||
if len(message) < 1 {
|
||||
return fmt.Errorf("message too short")
|
||||
}
|
||||
|
||||
opcode := message[0]
|
||||
payload := message[1:]
|
||||
|
||||
switch opcode {
|
||||
case OpcodeQueueJob:
|
||||
return h.handleQueueJob(conn, payload)
|
||||
case OpcodeQueueJobWithTracking:
|
||||
return h.handleQueueJobWithTracking(conn, payload)
|
||||
case OpcodeQueueJobWithSnapshot:
|
||||
return h.handleQueueJobWithSnapshot(conn, payload)
|
||||
case OpcodeStatusRequest:
|
||||
return h.handleStatusRequest(conn, payload)
|
||||
case OpcodeCancelJob:
|
||||
return h.handleCancelJob(conn, payload)
|
||||
case OpcodePrune:
|
||||
return h.handlePrune(conn, payload)
|
||||
case OpcodeDatasetList:
|
||||
return h.handleDatasetList(conn, payload)
|
||||
case OpcodeDatasetRegister:
|
||||
return h.handleDatasetRegister(conn, payload)
|
||||
case OpcodeDatasetInfo:
|
||||
return h.handleDatasetInfo(conn, payload)
|
||||
case OpcodeDatasetSearch:
|
||||
return h.handleDatasetSearch(conn, payload)
|
||||
case OpcodeLogMetric:
|
||||
return h.handleLogMetric(conn, payload)
|
||||
case OpcodeGetExperiment:
|
||||
return h.handleGetExperiment(conn, payload)
|
||||
case OpcodeStartJupyter:
|
||||
return h.handleStartJupyter(conn, payload)
|
||||
case OpcodeStopJupyter:
|
||||
return h.handleStopJupyter(conn, payload)
|
||||
case OpcodeRemoveJupyter:
|
||||
return h.handleRemoveJupyter(conn, payload)
|
||||
case OpcodeRestoreJupyter:
|
||||
return h.handleRestoreJupyter(conn, payload)
|
||||
case OpcodeListJupyter:
|
||||
return h.handleListJupyter(conn, payload)
|
||||
case OpcodeValidateRequest:
|
||||
return h.handleValidateRequest(conn, payload)
|
||||
default:
|
||||
return fmt.Errorf("unknown opcode: 0x%02x", opcode)
|
||||
}
|
||||
}
|
||||
1268
internal/api/ws_jobs.go
Normal file
1268
internal/api/ws_jobs.go
Normal file
File diff suppressed because it is too large
Load diff
478
internal/api/ws_jupyter.go
Normal file
478
internal/api/ws_jupyter.go
Normal file
|
|
@ -0,0 +1,478 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// jupyterTaskOutput is the JSON document carried in a completed jupyter
// task's Output field. Service/Services are kept raw so each handler can
// decode only the portion it needs.
type jupyterTaskOutput struct {
	Type        string          `json:"type"`
	Service     json.RawMessage `json:"service,omitempty"`      // single service record (decoded by handleStartJupyter)
	Services    json.RawMessage `json:"services,omitempty"`     // presumably a service list for the list action — confirm with worker code
	RestorePath string          `json:"restore_path,omitempty"` // destination path reported by the restore action
}
|
||||
|
||||
func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][name_len:1][name:var]
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "restore jupyter payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
user, err := h.validateWSUser(apiKeyHash)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if user != nil && !user.HasPermission("jupyter:manage") {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if len(payload) < offset+nameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
|
||||
}
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
|
||||
meta := map[string]string{
|
||||
jupyterTaskActionKey: jupyterActionRestore,
|
||||
jupyterNameKey: strings.TrimSpace(name),
|
||||
}
|
||||
jobName := fmt.Sprintf("jupyter-restore-%s", strings.TrimSpace(name))
|
||||
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to enqueue jupyter restore", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter restore", "")
|
||||
}
|
||||
|
||||
result, err := h.waitForTask(taskID, 2*time.Minute)
|
||||
if err != nil {
|
||||
h.logger.Error("failed waiting for jupyter restore", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error))
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Restored Jupyter workspace '%s'", strings.TrimSpace(name))
|
||||
out := strings.TrimSpace(result.Output)
|
||||
if out != "" {
|
||||
var payloadOut jupyterTaskOutput
|
||||
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
|
||||
if strings.TrimSpace(payloadOut.RestorePath) != "" {
|
||||
msg = fmt.Sprintf("Restored Jupyter workspace '%s' to %s", strings.TrimSpace(name), strings.TrimSpace(payloadOut.RestorePath))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return h.sendResponsePacket(conn, NewSuccessPacket(msg))
|
||||
}
|
||||
|
||||
// jupyterServiceView is the minimal projection of a worker-reported Jupyter
// service used when composing user-facing success messages.
type jupyterServiceView struct {
	Name string `json:"name"`
	URL  string `json:"url"` // access URL surfaced to the client when non-empty
}
|
||||
|
||||
// Metadata keys and values used to route Jupyter work through the generic
// task queue. Handlers tag each task with task_type=jupyter plus an action
// and its parameters; workers are expected to act on these (confirm against
// the worker implementation).
const (
	jupyterTaskTypeKey   = "task_type"
	jupyterTaskTypeValue = "jupyter"

	// Action selector and its possible values.
	jupyterTaskActionKey = "jupyter_action"
	jupyterActionStart   = "start"
	jupyterActionStop    = "stop"
	jupyterActionRemove  = "remove"
	jupyterActionRestore = "restore"
	jupyterActionList    = "list"

	// Per-action parameters.
	jupyterNameKey      = "jupyter_name"
	jupyterWorkspaceKey = "jupyter_workspace"
	jupyterServiceIDKey = "jupyter_service_id"
)
|
||||
|
||||
func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) {
|
||||
if h.queue == nil {
|
||||
return "", fmt.Errorf("task queue not configured")
|
||||
}
|
||||
if err := container.ValidateJobName(jobName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(userName) == "" {
|
||||
return "", fmt.Errorf("missing user")
|
||||
}
|
||||
if meta == nil {
|
||||
meta = make(map[string]string)
|
||||
}
|
||||
meta[jupyterTaskTypeKey] = jupyterTaskTypeValue
|
||||
|
||||
taskID := uuid.New().String()
|
||||
task := &queue.Task{
|
||||
ID: taskID,
|
||||
JobName: jobName,
|
||||
Args: "",
|
||||
Status: "queued",
|
||||
Priority: 100, // high priority; interactive request
|
||||
CreatedAt: time.Now(),
|
||||
UserID: userName,
|
||||
Username: userName,
|
||||
CreatedBy: userName,
|
||||
Metadata: meta,
|
||||
}
|
||||
if err := h.queue.AddTask(task); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (h *WSHandler) waitForTask(taskID string, timeout time.Duration) (*queue.Task, error) {
|
||||
if h.queue == nil {
|
||||
return nil, fmt.Errorf("task queue not configured")
|
||||
}
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("timed out waiting for worker")
|
||||
}
|
||||
t, err := h.queue.GetTask(taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t == nil {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if t.Status == "completed" || t.Status == "failed" || t.Status == "cancelled" {
|
||||
return t, nil
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol:
|
||||
// [api_key_hash:16][name][workspace][password]
|
||||
if len(payload) < 21 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "start jupyter payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
// Verify API key
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
user, err := h.validateWSUser(apiKeyHash)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if user != nil && !user.HasPermission("jupyter:manage") {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
|
||||
if len(payload) < offset+nameLen+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
|
||||
}
|
||||
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
offset += nameLen
|
||||
|
||||
workspaceLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
|
||||
if len(payload) < offset+workspaceLen+1 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace length", "")
|
||||
}
|
||||
|
||||
workspace := string(payload[offset : offset+workspaceLen])
|
||||
offset += workspaceLen
|
||||
|
||||
passwordLen := int(payload[offset])
|
||||
offset++
|
||||
|
||||
if len(payload) < offset+passwordLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid password length", "")
|
||||
}
|
||||
|
||||
// Password is parsed but not used in StartRequest
|
||||
// offset += passwordLen (already advanced during parsing)
|
||||
|
||||
meta := map[string]string{
|
||||
jupyterTaskActionKey: jupyterActionStart,
|
||||
jupyterNameKey: strings.TrimSpace(name),
|
||||
jupyterWorkspaceKey: strings.TrimSpace(workspace),
|
||||
}
|
||||
jobName := fmt.Sprintf("jupyter-%s", strings.TrimSpace(name))
|
||||
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to enqueue jupyter task", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter task", "")
|
||||
}
|
||||
|
||||
result, err := h.waitForTask(taskID, 2*time.Minute)
|
||||
if err != nil {
|
||||
h.logger.Error("failed waiting for jupyter start", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
h.logger.Error("jupyter task failed", "error", result.Error)
|
||||
details := strings.TrimSpace(result.Error)
|
||||
lower := strings.ToLower(details)
|
||||
if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details)
|
||||
}
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to start Jupyter service", details)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Started Jupyter service '%s'", strings.TrimSpace(name))
|
||||
out := strings.TrimSpace(result.Output)
|
||||
if out != "" {
|
||||
var payloadOut jupyterTaskOutput
|
||||
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil && len(payloadOut.Service) > 0 {
|
||||
var svc jupyterServiceView
|
||||
if err := json.Unmarshal(payloadOut.Service, &svc); err == nil {
|
||||
if strings.TrimSpace(svc.URL) != "" {
|
||||
msg = fmt.Sprintf("Started Jupyter service '%s' at %s", strings.TrimSpace(name), strings.TrimSpace(svc.URL))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewSuccessPacket(msg))
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][service_id_len:1][service_id:var]
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "stop jupyter payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
user, err := h.validateWSUser(apiKeyHash)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if user != nil && !user.HasPermission("jupyter:manage") {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
idLen := int(payload[offset])
|
||||
offset++
|
||||
|
||||
if len(payload) < offset+idLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
|
||||
}
|
||||
|
||||
serviceID := string(payload[offset : offset+idLen])
|
||||
|
||||
meta := map[string]string{
|
||||
jupyterTaskActionKey: jupyterActionStop,
|
||||
jupyterServiceIDKey: strings.TrimSpace(serviceID),
|
||||
}
|
||||
jobName := fmt.Sprintf("jupyter-stop-%s", strings.TrimSpace(serviceID))
|
||||
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to enqueue jupyter stop", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter stop", "")
|
||||
}
|
||||
result, err := h.waitForTask(taskID, 2*time.Minute)
|
||||
if err != nil {
|
||||
h.logger.Error("failed waiting for jupyter stop", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to stop Jupyter service", strings.TrimSpace(result.Error))
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID)))
|
||||
}
|
||||
|
||||
// handleRemoveJupyter processes a binary remove-Jupyter request: it parses the
// wire payload (including an optional trailing purge flag), authenticates the
// caller, enqueues a remove task for a worker and waits for the result.
func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][service_id_len:1][service_id:var]
	// Minimum size: 16-byte key hash + 1 length byte + at least 1 id byte.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "remove jupyter payload too short", "")
	}

	apiKeyHash := payload[:16]

	// Decode the length-prefixed service id that follows the key hash.
	offset := 16
	idLen := int(payload[offset])
	offset++
	if len(payload) < offset+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
	}
	serviceID := string(payload[offset : offset+idLen])
	offset += idLen

	// Optional: purge flag (1 byte). Default false for trash-first behavior.
	purge := false
	if len(payload) > offset {
		purge = payload[offset] == 0x01
	}

	// Verify the API-key hash when authentication is enabled.
	// NOTE(review): unlike the stop/list handlers, parsing happens before
	// authentication here — confirm this ordering is intentional.
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	// Resolve the caller for permission checks and task attribution.
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	if user != nil && !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	// Describe the remove action for the worker via task metadata.
	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionRemove,
		jupyterServiceIDKey:  strings.TrimSpace(serviceID),
		"jupyter_purge":      fmt.Sprintf("%t", purge),
	}
	jobName := fmt.Sprintf("jupyter-remove-%s", strings.TrimSpace(serviceID))
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter remove", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter remove", "")
	}

	// Block until the worker reports, bounded by a 2-minute timeout.
	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter remove", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to remove Jupyter service", strings.TrimSpace(result.Error))
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID)))
}
|
||||
|
||||
func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16]
|
||||
if len(payload) < 16 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
user, err := h.validateWSUser(apiKeyHash)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if user != nil && !user.HasPermission("jupyter:read") {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
}
|
||||
|
||||
meta := map[string]string{
|
||||
jupyterTaskActionKey: jupyterActionList,
|
||||
}
|
||||
jobName := fmt.Sprintf("jupyter-list-%s", user.Name)
|
||||
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to enqueue jupyter list", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter list", "")
|
||||
}
|
||||
result, err := h.waitForTask(taskID, 2*time.Minute)
|
||||
if err != nil {
|
||||
h.logger.Error("failed waiting for jupyter list", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServerOverloaded, "Failed to list Jupyter services", strings.TrimSpace(result.Error))
|
||||
}
|
||||
|
||||
out := strings.TrimSpace(result.Output)
|
||||
if out == "" {
|
||||
empty, _ := json.Marshal([]any{})
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", empty))
|
||||
}
|
||||
var payloadOut jupyterTaskOutput
|
||||
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
|
||||
// Always return an array payload (even if empty) so clients can render a stable table.
|
||||
payload := payloadOut.Services
|
||||
if len(payload) == 0 {
|
||||
payload = []byte("[]")
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", payload))
|
||||
}
|
||||
// Fallback: return empty array on unexpected output.
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", []byte("[]")))
|
||||
}
|
||||
100
internal/api/ws_tls_auth.go
Normal file
100
internal/api/ws_tls_auth.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
// SetupTLSConfig creates TLS configuration for WebSocket server
|
||||
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
|
||||
var server *http.Server
|
||||
|
||||
if certFile != "" && keyFile != "" {
|
||||
// Use provided certificates
|
||||
server = &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
},
|
||||
},
|
||||
}
|
||||
} else if host != "" {
|
||||
// Use Let's Encrypt with autocert
|
||||
certManager := &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: autocert.HostWhitelist(host),
|
||||
Cache: autocert.DirCache("/var/www/.cache"),
|
||||
}
|
||||
|
||||
server = &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
|
||||
TLSConfig: certManager.TLSConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// verifyAPIKeyHash verifies the provided binary hash against stored API keys
|
||||
func (h *WSHandler) verifyAPIKeyHash(hash []byte) error {
|
||||
if h.authConfig == nil || !h.authConfig.Enabled {
|
||||
return nil // No auth required
|
||||
}
|
||||
_, err := h.authConfig.ValidateAPIKeyHash(hash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid api key")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendErrorPacket sends an error response packet
|
||||
func (h *WSHandler) sendErrorPacket(
|
||||
conn *websocket.Conn,
|
||||
errorCode byte,
|
||||
message string,
|
||||
details string,
|
||||
) error {
|
||||
packet := NewErrorPacket(errorCode, message, details)
|
||||
return h.sendResponsePacket(conn, packet)
|
||||
}
|
||||
|
||||
// sendResponsePacket sends a structured response packet
|
||||
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
|
||||
data, err := packet.Serialize()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to serialize response packet", "error", err)
|
||||
// Fallback to simple error response
|
||||
return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
|
||||
}
|
||||
|
||||
return conn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
|
||||
func (h *WSHandler) validateWSUser(apiKeyHash []byte) (*auth.User, error) {
|
||||
if h.authConfig != nil {
|
||||
user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
return &auth.User{
|
||||
Name: "default",
|
||||
Admin: true,
|
||||
Roles: []string{"admin"},
|
||||
Permissions: map[string]bool{
|
||||
"*": true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
642
internal/api/ws_validate.go
Normal file
642
internal/api/ws_validate.go
Normal file
|
|
@ -0,0 +1,642 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
)
|
||||
|
||||
// validateCheck records the outcome of a single validation check inside a
// validateReport. Expected and Actual carry the compared values when a
// comparison was made; Details holds a human-readable explanation, typically
// on failure.
type validateCheck struct {
	OK       bool   `json:"ok"`
	Expected string `json:"expected,omitempty"`
	Actual   string `json:"actual,omitempty"`
	Details  string `json:"details,omitempty"`
}
|
||||
|
||||
// validateReport is the JSON document produced by the validate handler.
// OK is the overall verdict (false if any required check failed); Checks maps
// check names to individual results; Errors and Warnings collect
// human-readable messages; TS is the report timestamp.
type validateReport struct {
	OK       bool                     `json:"ok"`
	CommitID string                   `json:"commit_id,omitempty"`
	TaskID   string                   `json:"task_id,omitempty"`
	Checks   map[string]validateCheck `json:"checks"`
	Errors   []string                 `json:"errors,omitempty"`
	Warnings []string                 `json:"warnings,omitempty"`
	TS       string                   `json:"ts"`
}
|
||||
|
||||
func shouldRequireRunManifest(task *queue.Task) bool {
|
||||
if task == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.ToLower(strings.TrimSpace(task.Status))
|
||||
switch s {
|
||||
case "running", "completed", "failed":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// expectedRunManifestBucketForStatus maps a task status (case-insensitive,
// surrounding whitespace ignored) to the bucket directory its run manifest is
// expected to live in. The second return value is false for statuses with no
// defined bucket.
func expectedRunManifestBucketForStatus(status string) (string, bool) {
	buckets := map[string]string{
		"queued":    "pending",
		"pending":   "pending",
		"running":   "running",
		"completed": "finished",
		"finished":  "finished",
		"failed":    "failed",
	}
	bucket, ok := buckets[strings.ToLower(strings.TrimSpace(status))]
	return bucket, ok
}
|
||||
|
||||
// findRunManifestDir searches the per-status job directories under basePath
// for a directory named jobName that contains a run manifest file. It returns
// the directory, the bucket it was found in ("running", "pending", "finished"
// or "failed"), and whether anything was found. The "running" bucket is
// checked first, so an active job wins over copies in other buckets.
func findRunManifestDir(basePath string, jobName string) (string, string, bool) {
	// Nothing to search without both a base path and a job name.
	if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" {
		return "", "", false
	}
	jobPaths := config.NewJobPaths(basePath)
	// Candidate roots, one per lifecycle bucket, in priority order.
	typedRoots := []struct {
		bucket string
		root   string
	}{
		{bucket: "running", root: jobPaths.RunningPath()},
		{bucket: "pending", root: jobPaths.PendingPath()},
		{bucket: "finished", root: jobPaths.FinishedPath()},
		{bucket: "failed", root: jobPaths.FailedPath()},
	}
	for _, item := range typedRoots {
		root := item.root
		dir := filepath.Join(root, jobName)
		// Only accept a directory that actually contains a manifest file.
		if info, err := os.Stat(dir); err == nil && info.IsDir() {
			if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
				return dir, item.bucket, true
			}
		}
	}
	return "", "", false
}
|
||||
|
||||
func validateResourcesForTask(task *queue.Task) (validateCheck, []string) {
|
||||
if task == nil {
|
||||
return validateCheck{OK: false, Details: "task is nil"}, []string{"missing task"}
|
||||
}
|
||||
|
||||
if task.CPU < 0 {
|
||||
chk := validateCheck{OK: false, Details: "cpu must be >= 0"}
|
||||
return chk, []string{"invalid cpu request"}
|
||||
}
|
||||
if task.MemoryGB < 0 {
|
||||
chk := validateCheck{OK: false, Details: "memory_gb must be >= 0"}
|
||||
return chk, []string{"invalid memory request"}
|
||||
}
|
||||
if task.GPU < 0 {
|
||||
chk := validateCheck{OK: false, Details: "gpu must be >= 0"}
|
||||
return chk, []string{"invalid gpu request"}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(task.GPUMemory) != "" {
|
||||
s := strings.TrimSpace(task.GPUMemory)
|
||||
if strings.HasSuffix(s, "%") {
|
||||
v := strings.TrimSuffix(s, "%")
|
||||
f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
|
||||
if err != nil || f <= 0 || f > 100 {
|
||||
details := "gpu_memory must be a percentage in (0,100]"
|
||||
chk := validateCheck{OK: false, Details: details}
|
||||
return chk, []string{"invalid gpu_memory"}
|
||||
}
|
||||
} else {
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil || f <= 0 || f > 1 {
|
||||
chk := validateCheck{OK: false, Details: "gpu_memory must be a fraction in (0,1]"}
|
||||
return chk, []string{"invalid gpu_memory"}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if task.GPU == 0 && strings.TrimSpace(task.GPUMemory) != "" {
|
||||
chk := validateCheck{OK: false, Details: "gpu_memory requires gpu > 0"}
|
||||
return chk, []string{"invalid gpu_memory"}
|
||||
}
|
||||
|
||||
return validateCheck{OK: true}, nil
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][target_type:1][id_len:1][id:var]
|
||||
// target_type: 0=commit_id (20 bytes), 1=task_id (string)
|
||||
// TODO(context): Add a versioned validate protocol once we need more target types/fields.
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "validate request payload too short", "")
|
||||
}
|
||||
apiKeyHash := payload[:16]
|
||||
targetType := payload[16]
|
||||
idLen := int(payload[17])
|
||||
if idLen < 1 || len(payload) < 18+idLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate id length", "")
|
||||
}
|
||||
idBytes := payload[18 : 18+idLen]
|
||||
|
||||
// Validate API key and user
|
||||
user, err := h.validateWSUser(apiKeyHash)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Invalid API key",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodePermissionDenied,
|
||||
"Insufficient permissions to validate jobs",
|
||||
"",
|
||||
)
|
||||
}
|
||||
if h.expManager == nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeServiceUnavailable,
|
||||
"Experiment manager not available",
|
||||
"",
|
||||
)
|
||||
}
|
||||
|
||||
var task *queue.Task
|
||||
commitID := ""
|
||||
switch targetType {
|
||||
case 0:
|
||||
if len(idBytes) != 20 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id must be 20 bytes", "")
|
||||
}
|
||||
commitID = fmt.Sprintf("%x", idBytes)
|
||||
case 1:
|
||||
taskID := string(idBytes)
|
||||
if h.queue == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Task queue not available", "")
|
||||
}
|
||||
t, err := h.queue.GetTask(taskID)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", err.Error())
|
||||
}
|
||||
task = t
|
||||
if h.authConfig != nil &&
|
||||
h.authConfig.Enabled &&
|
||||
!user.Admin &&
|
||||
task.UserID != user.Name &&
|
||||
task.CreatedBy != user.Name {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodePermissionDenied,
|
||||
"You can only validate your own jobs",
|
||||
"",
|
||||
)
|
||||
}
|
||||
if task.Metadata == nil || task.Metadata["commit_id"] == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "Task missing commit_id", "")
|
||||
}
|
||||
commitID = task.Metadata["commit_id"]
|
||||
default:
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate target_type", "")
|
||||
}
|
||||
|
||||
r := validateReport{
|
||||
OK: true,
|
||||
TS: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Checks: map[string]validateCheck{},
|
||||
}
|
||||
if task != nil {
|
||||
r.TaskID = task.ID
|
||||
}
|
||||
if commitID != "" {
|
||||
r.CommitID = commitID
|
||||
}
|
||||
|
||||
// Validate commit id format
|
||||
if len(commitID) != 40 {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "invalid commit_id length")
|
||||
} else if _, err := hex.DecodeString(commitID); err != nil {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "invalid commit_id hex")
|
||||
}
|
||||
|
||||
// Experiment manifest integrity
|
||||
// TODO(context): Extend report to include per-file diff list on mismatch (bounded output).
|
||||
if r.OK {
|
||||
if err := h.expManager.ValidateManifest(commitID); err != nil {
|
||||
r.OK = false
|
||||
r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: err.Error()}
|
||||
r.Errors = append(r.Errors, "experiment manifest validation failed")
|
||||
} else {
|
||||
r.Checks["experiment_manifest"] = validateCheck{OK: true}
|
||||
}
|
||||
}
|
||||
|
||||
// Deps manifest presence + hash
|
||||
// TODO(context): Allow client to declare which dependency manifest is authoritative.
|
||||
filesPath := h.expManager.GetFilesPath(commitID)
|
||||
depName, depErr := worker.SelectDependencyManifest(filesPath)
|
||||
if depErr != nil {
|
||||
r.OK = false
|
||||
r.Checks["deps_manifest"] = validateCheck{
|
||||
OK: false,
|
||||
Details: depErr.Error(),
|
||||
}
|
||||
r.Errors = append(r.Errors, "deps manifest missing")
|
||||
} else {
|
||||
sha, err := fileSHA256Hex(filepath.Join(filesPath, depName))
|
||||
if err != nil {
|
||||
r.OK = false
|
||||
r.Checks["deps_manifest"] = validateCheck{
|
||||
OK: false,
|
||||
Details: err.Error(),
|
||||
}
|
||||
r.Errors = append(r.Errors, "deps manifest hash failed")
|
||||
} else {
|
||||
r.Checks["deps_manifest"] = validateCheck{
|
||||
OK: true,
|
||||
Actual: depName + ":" + sha,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compare against expected task metadata if available.
|
||||
if task != nil {
|
||||
resCheck, resErrs := validateResourcesForTask(task)
|
||||
r.Checks["resources"] = resCheck
|
||||
if !resCheck.OK {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, resErrs...)
|
||||
}
|
||||
|
||||
// Run manifest checks: best-effort for queued tasks, required for running/completed/failed.
|
||||
if err := container.ValidateJobName(task.JobName); err != nil {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "invalid job name")
|
||||
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "invalid job name"}
|
||||
} else if base := strings.TrimSpace(h.expManager.BasePath()); base == "" {
|
||||
if shouldRequireRunManifest(task) {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "missing api base_path; cannot validate run manifest")
|
||||
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
|
||||
} else {
|
||||
r.Warnings = append(r.Warnings, "missing api base_path; cannot validate run manifest")
|
||||
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
|
||||
}
|
||||
} else {
|
||||
manifestDir, manifestBucket, found := findRunManifestDir(base, task.JobName)
|
||||
if !found {
|
||||
if shouldRequireRunManifest(task) {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest not found")
|
||||
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
|
||||
} else {
|
||||
r.Warnings = append(r.Warnings, "run manifest not found (job may not have started)")
|
||||
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
|
||||
}
|
||||
} else if rm, err := manifest.LoadFromDir(manifestDir); err != nil || rm == nil {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "unable to read run manifest")
|
||||
r.Checks["run_manifest"] = validateCheck{OK: false, Details: "unable to read run manifest"}
|
||||
} else {
|
||||
r.Checks["run_manifest"] = validateCheck{OK: true}
|
||||
|
||||
expectedBucket, ok := expectedRunManifestBucketForStatus(task.Status)
|
||||
if ok {
|
||||
if expectedBucket != manifestBucket {
|
||||
msg := "run manifest location mismatch"
|
||||
chk := validateCheck{OK: false, Expected: expectedBucket, Actual: manifestBucket}
|
||||
if shouldRequireRunManifest(task) {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, msg)
|
||||
r.Checks["run_manifest_location"] = chk
|
||||
} else {
|
||||
r.Warnings = append(r.Warnings, msg)
|
||||
r.Checks["run_manifest_location"] = chk
|
||||
}
|
||||
} else {
|
||||
r.Checks["run_manifest_location"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: expectedBucket,
|
||||
Actual: manifestBucket,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(rm.TaskID) == "" {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest missing task_id")
|
||||
r.Checks["run_manifest_task_id"] = validateCheck{OK: false, Expected: task.ID}
|
||||
} else if rm.TaskID != task.ID {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest task_id mismatch")
|
||||
r.Checks["run_manifest_task_id"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: task.ID,
|
||||
Actual: rm.TaskID,
|
||||
}
|
||||
} else {
|
||||
r.Checks["run_manifest_task_id"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: task.ID,
|
||||
Actual: rm.TaskID,
|
||||
}
|
||||
}
|
||||
|
||||
commitWant := strings.TrimSpace(task.Metadata["commit_id"])
|
||||
commitGot := strings.TrimSpace(rm.CommitID)
|
||||
if commitWant != "" && commitGot != "" && commitWant != commitGot {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest commit_id mismatch")
|
||||
r.Checks["run_manifest_commit_id"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: commitWant,
|
||||
Actual: commitGot,
|
||||
}
|
||||
} else if commitWant != "" {
|
||||
r.Checks["run_manifest_commit_id"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: commitWant,
|
||||
Actual: commitGot,
|
||||
}
|
||||
}
|
||||
|
||||
depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"])
|
||||
depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
|
||||
depGotName := strings.TrimSpace(rm.DepsManifestName)
|
||||
depGotSHA := strings.TrimSpace(rm.DepsManifestSHA)
|
||||
if depWantName != "" && depWantSHA != "" && depGotName != "" && depGotSHA != "" {
|
||||
expectedDep := depWantName + ":" + depWantSHA
|
||||
actualDep := depGotName + ":" + depGotSHA
|
||||
if depWantName != depGotName || depWantSHA != depGotSHA {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest deps provenance mismatch")
|
||||
r.Checks["run_manifest_deps"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: expectedDep,
|
||||
Actual: actualDep,
|
||||
}
|
||||
} else {
|
||||
r.Checks["run_manifest_deps"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: expectedDep,
|
||||
Actual: actualDep,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(task.SnapshotID) != "" {
|
||||
snapWantID := strings.TrimSpace(task.SnapshotID)
|
||||
snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"])
|
||||
snapGotID := strings.TrimSpace(rm.SnapshotID)
|
||||
snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256)
|
||||
if snapWantID != "" && snapGotID != "" && snapWantID != snapGotID {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest snapshot_id mismatch")
|
||||
r.Checks["run_manifest_snapshot_id"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: snapWantID,
|
||||
Actual: snapGotID,
|
||||
}
|
||||
} else {
|
||||
r.Checks["run_manifest_snapshot_id"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: snapWantID,
|
||||
Actual: snapGotID,
|
||||
}
|
||||
}
|
||||
if snapWantSHA != "" && snapGotSHA != "" && snapWantSHA != snapGotSHA {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch")
|
||||
r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: snapWantSHA,
|
||||
Actual: snapGotSHA,
|
||||
}
|
||||
} else if snapWantSHA != "" {
|
||||
r.Checks["run_manifest_snapshot_sha256"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: snapWantSHA,
|
||||
Actual: snapGotSHA,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
statusLower := strings.ToLower(strings.TrimSpace(task.Status))
|
||||
lifecycleOK := true
|
||||
details := ""
|
||||
|
||||
switch statusLower {
|
||||
case "running":
|
||||
if rm.StartedAt.IsZero() {
|
||||
lifecycleOK = false
|
||||
details = "missing started_at for running task"
|
||||
}
|
||||
if !rm.EndedAt.IsZero() {
|
||||
lifecycleOK = false
|
||||
if details == "" {
|
||||
details = "ended_at must be empty for running task"
|
||||
}
|
||||
}
|
||||
if rm.ExitCode != nil {
|
||||
lifecycleOK = false
|
||||
if details == "" {
|
||||
details = "exit_code must be empty for running task"
|
||||
}
|
||||
}
|
||||
case "completed", "failed":
|
||||
if rm.StartedAt.IsZero() {
|
||||
lifecycleOK = false
|
||||
details = "missing started_at for completed/failed task"
|
||||
}
|
||||
if rm.EndedAt.IsZero() {
|
||||
lifecycleOK = false
|
||||
if details == "" {
|
||||
details = "missing ended_at for completed/failed task"
|
||||
}
|
||||
}
|
||||
if rm.ExitCode == nil {
|
||||
lifecycleOK = false
|
||||
if details == "" {
|
||||
details = "missing exit_code for completed/failed task"
|
||||
}
|
||||
}
|
||||
if !rm.StartedAt.IsZero() && !rm.EndedAt.IsZero() && rm.EndedAt.Before(rm.StartedAt) {
|
||||
lifecycleOK = false
|
||||
if details == "" {
|
||||
details = "ended_at is before started_at"
|
||||
}
|
||||
}
|
||||
case "queued", "pending":
|
||||
// queued/pending tasks may not have started yet.
|
||||
if !rm.EndedAt.IsZero() || rm.ExitCode != nil {
|
||||
lifecycleOK = false
|
||||
details = "queued/pending task should not have ended_at/exit_code"
|
||||
}
|
||||
}
|
||||
|
||||
if lifecycleOK {
|
||||
r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true}
|
||||
} else {
|
||||
chk := validateCheck{OK: false, Details: details}
|
||||
if shouldRequireRunManifest(task) {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "run manifest lifecycle invalid")
|
||||
r.Checks["run_manifest_lifecycle"] = chk
|
||||
} else {
|
||||
r.Warnings = append(r.Warnings, "run manifest lifecycle invalid")
|
||||
r.Checks["run_manifest_lifecycle"] = chk
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
want := strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"])
|
||||
cur := ""
|
||||
if man, err := h.expManager.ReadManifest(commitID); err == nil && man != nil {
|
||||
cur = man.OverallSHA
|
||||
}
|
||||
if want == "" {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "missing expected experiment_manifest_overall_sha")
|
||||
r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Actual: cur}
|
||||
} else if cur == "" {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "unable to read current experiment manifest overall sha")
|
||||
r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want}
|
||||
} else if want != cur {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "experiment manifest overall sha mismatch")
|
||||
r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want, Actual: cur}
|
||||
} else {
|
||||
r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: true, Expected: want, Actual: cur}
|
||||
}
|
||||
|
||||
wantDep := strings.TrimSpace(task.Metadata["deps_manifest_name"])
|
||||
wantDepSha := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
|
||||
if wantDep == "" || wantDepSha == "" {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "missing expected deps manifest provenance")
|
||||
r.Checks["expected_deps_manifest"] = validateCheck{OK: false}
|
||||
} else if depName != "" {
|
||||
sha, _ := fileSHA256Hex(filepath.Join(filesPath, depName))
|
||||
ok := (wantDep == depName && wantDepSha == sha)
|
||||
if !ok {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "deps manifest provenance mismatch")
|
||||
r.Checks["expected_deps_manifest"] = validateCheck{
|
||||
OK: false,
|
||||
Expected: wantDep + ":" + wantDepSha,
|
||||
Actual: depName + ":" + sha,
|
||||
}
|
||||
} else {
|
||||
r.Checks["expected_deps_manifest"] = validateCheck{
|
||||
OK: true,
|
||||
Expected: wantDep + ":" + wantDepSha,
|
||||
Actual: depName + ":" + sha,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Snapshot/dataset checks require dataDir.
|
||||
// TODO(context): Support snapshot stores beyond local filesystem (e.g. S3).
|
||||
// TODO(context): Validate snapshots by digest.
|
||||
if task.SnapshotID != "" {
|
||||
if h.dataDir == "" {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "api server data_dir not configured; cannot validate snapshot")
|
||||
r.Checks["snapshot"] = validateCheck{OK: false, Details: "missing api data_dir"}
|
||||
} else {
|
||||
wantSnap, nerr := worker.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
|
||||
if nerr != nil || wantSnap == "" {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "missing/invalid snapshot_sha256")
|
||||
r.Checks["snapshot"] = validateCheck{OK: false}
|
||||
} else {
|
||||
curSnap, err := worker.DirOverallSHA256Hex(
|
||||
filepath.Join(h.dataDir, "snapshots", task.SnapshotID),
|
||||
)
|
||||
if err != nil {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "snapshot hash computation failed")
|
||||
r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Details: err.Error()}
|
||||
} else if curSnap != wantSnap {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "snapshot checksum mismatch")
|
||||
r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Actual: curSnap}
|
||||
} else {
|
||||
r.Checks["snapshot"] = validateCheck{OK: true, Expected: wantSnap, Actual: curSnap}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(task.DatasetSpecs) > 0 {
|
||||
// TODO(context): Add dataset URI fetch/verification.
|
||||
// TODO(context): Currently only validates local materialized datasets.
|
||||
for _, ds := range task.DatasetSpecs {
|
||||
if ds.Checksum == "" {
|
||||
continue
|
||||
}
|
||||
key := "dataset:" + ds.Name
|
||||
if h.dataDir == "" {
|
||||
r.OK = false
|
||||
r.Errors = append(
|
||||
r.Errors,
|
||||
"api server data_dir not configured; cannot validate dataset checksums",
|
||||
)
|
||||
r.Checks[key] = validateCheck{OK: false, Details: "missing api data_dir"}
|
||||
continue
|
||||
}
|
||||
wantDS, nerr := worker.NormalizeSHA256ChecksumHex(ds.Checksum)
|
||||
if nerr != nil || wantDS == "" {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "invalid dataset checksum format")
|
||||
r.Checks[key] = validateCheck{OK: false, Details: "invalid checksum"}
|
||||
continue
|
||||
}
|
||||
curDS, err := worker.DirOverallSHA256Hex(filepath.Join(h.dataDir, ds.Name))
|
||||
if err != nil {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "dataset checksum computation failed")
|
||||
r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Details: err.Error()}
|
||||
continue
|
||||
}
|
||||
if curDS != wantDS {
|
||||
r.OK = false
|
||||
r.Errors = append(r.Errors, "dataset checksum mismatch")
|
||||
r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Actual: curDS}
|
||||
continue
|
||||
}
|
||||
r.Checks[key] = validateCheck{OK: true, Expected: wantDS, Actual: curDS}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeUnknownError,
|
||||
"failed to serialize validate report",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("validate", payloadBytes))
|
||||
}
|
||||
|
|
@ -133,6 +133,62 @@ func (c *Config) ValidateAPIKey(key string) (*User, error) {
|
|||
return nil, fmt.Errorf("invalid API key")
|
||||
}
|
||||
|
||||
// ValidateAPIKeyHash validates a pre-hashed API key and returns user information.
//
// The caller supplies the first 16 bytes of the SHA-256 digest of the raw
// API key; it is compared against the 16-byte prefix of each stored
// (hex-encoded, full SHA-256) key hash. Entries with empty, non-hex, or
// wrong-length stored hashes are silently skipped.
//
// When authentication is disabled, a default admin user is returned so
// development setups keep working without configured keys.
//
// NOTE(review): the prefix comparison below is not constant-time; if timing
// side channels matter for this deployment, consider
// crypto/subtle.ConstantTimeCompare — confirm before relying on it.
func (c *Config) ValidateAPIKeyHash(hash []byte) (*User, error) {
	if !c.Enabled {
		// Auth disabled - return default admin user for development
		return &User{Name: "default", Admin: true}, nil
	}
	if len(hash) != 16 {
		return nil, fmt.Errorf("invalid api key hash length: %d", len(hash))
	}

	for username, entry := range c.APIKeys {
		stored := strings.TrimSpace(string(entry.Hash))
		if stored == "" {
			continue
		}
		storedBytes, err := hex.DecodeString(stored)
		if err != nil {
			// Not hex-encoded; skip this entry rather than fail the lookup.
			continue
		}
		if len(storedBytes) != sha256.Size {
			// Stored value is not a full SHA-256 digest; skip.
			continue
		}
		// Match on the 16-byte prefix of the stored digest only.
		if string(storedBytes[:16]) == string(hash) {
			// Build user with role and permission inheritance
			user := &User{
				Name:        string(username),
				Admin:       entry.Admin,
				Roles:       entry.Roles,
				Permissions: make(map[string]bool),
			}

			// Copy explicit permissions
			for perm, value := range entry.Permissions {
				user.Permissions[perm] = value
			}

			// Add role-based permissions; explicit permissions take precedence.
			rolePerms := getRolePermissions(entry.Roles)
			for perm, value := range rolePerms {
				if _, exists := user.Permissions[perm]; !exists {
					user.Permissions[perm] = value
				}
			}

			// Admin gets all permissions
			if entry.Admin {
				user.Permissions["*"] = true
			}

			return user, nil
		}
	}

	return nil, fmt.Errorf("invalid api key")
}
|
||||
|
||||
// AuthMiddleware creates HTTP middleware for API key authentication
|
||||
func (c *Config) AuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -177,6 +233,10 @@ func GetUserFromContext(ctx context.Context) *User {
|
|||
return nil
|
||||
}
|
||||
|
||||
func WithUserContext(ctx context.Context, user *User) context.Context {
|
||||
return context.WithValue(ctx, userContextKey, user)
|
||||
}
|
||||
|
||||
// RequireAdmin creates middleware that requires admin privileges
|
||||
func RequireAdmin(next http.Handler) http.Handler {
|
||||
return RequirePermission("system:admin")(next)
|
||||
|
|
|
|||
|
|
@ -64,7 +64,10 @@ func (s *DatabaseAuthStore) init() error {
|
|||
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_user ON api_keys(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_active ON api_keys(revoked_at, COALESCE(expires_at, '9999-12-31'));
|
||||
CREATE INDEX IF NOT EXISTS idx_api_keys_active ON api_keys(
|
||||
revoked_at,
|
||||
COALESCE(expires_at, '9999-12-31')
|
||||
);
|
||||
`
|
||||
|
||||
_, err := s.db.ExecContext(ctx, query)
|
||||
|
|
@ -158,7 +161,16 @@ func (s *DatabaseAuthStore) CreateAPIKey(
|
|||
revoked_at = NULL
|
||||
`
|
||||
|
||||
_, err = s.db.ExecContext(ctx, query, userID, keyHash, admin, rolesJSON, permissionsJSON, expiresAt)
|
||||
_, err = s.db.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
userID,
|
||||
keyHash,
|
||||
admin,
|
||||
rolesJSON,
|
||||
permissionsJSON,
|
||||
expiresAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -213,7 +213,15 @@ func (h *HybridAuthStore) migrateFileToDatabase(ctx context.Context) error {
|
|||
|
||||
for username, entry := range h.fileStore.APIKeys {
|
||||
userID := string(username)
|
||||
err := h.dbStore.CreateAPIKey(ctx, userID, string(entry.Hash), entry.Admin, entry.Roles, entry.Permissions, nil)
|
||||
err := h.dbStore.CreateAPIKey(
|
||||
ctx,
|
||||
userID,
|
||||
string(entry.Hash),
|
||||
entry.Admin,
|
||||
entry.Roles,
|
||||
entry.Permissions,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Failed to migrate key for user %s: %v", userID, err)
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -59,23 +59,43 @@ var PermissionGroups = map[string]PermissionGroup{
|
|||
Description: "Complete system access",
|
||||
},
|
||||
"job_management": {
|
||||
Name: "Job Management",
|
||||
Permissions: []string{PermissionJobsCreate, PermissionJobsRead, PermissionJobsUpdate, PermissionJobsDelete},
|
||||
Name: "Job Management",
|
||||
Permissions: []string{
|
||||
PermissionJobsCreate,
|
||||
PermissionJobsRead,
|
||||
PermissionJobsUpdate,
|
||||
PermissionJobsDelete,
|
||||
},
|
||||
Description: "Create, read, update, and delete ML jobs",
|
||||
},
|
||||
"data_access": {
|
||||
Name: "Data Access",
|
||||
Permissions: []string{PermissionDataRead, PermissionDataCreate, PermissionDataUpdate, PermissionDataDelete},
|
||||
Name: "Data Access",
|
||||
Permissions: []string{
|
||||
PermissionDataRead,
|
||||
PermissionDataCreate,
|
||||
PermissionDataUpdate,
|
||||
PermissionDataDelete,
|
||||
},
|
||||
Description: "Access and manage datasets",
|
||||
},
|
||||
"readonly": {
|
||||
Name: "Read Only",
|
||||
Permissions: []string{PermissionJobsRead, PermissionDataRead, PermissionModelsRead, PermissionSystemMetrics},
|
||||
Name: "Read Only",
|
||||
Permissions: []string{
|
||||
PermissionJobsRead,
|
||||
PermissionDataRead,
|
||||
PermissionModelsRead,
|
||||
PermissionSystemMetrics,
|
||||
},
|
||||
Description: "View-only access to system resources",
|
||||
},
|
||||
"system_admin": {
|
||||
Name: "System Administration",
|
||||
Permissions: []string{PermissionSystemConfig, PermissionSystemLogs, PermissionSystemUsers, PermissionSystemMetrics},
|
||||
Name: "System Administration",
|
||||
Permissions: []string{
|
||||
PermissionSystemConfig,
|
||||
PermissionSystemLogs,
|
||||
PermissionSystemUsers,
|
||||
PermissionSystemMetrics,
|
||||
},
|
||||
Description: "System configuration and user management",
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ const (
|
|||
RedisTaskStatusPrefix = "ml:status:"
|
||||
RedisDatasetPrefix = "ml:dataset:"
|
||||
RedisWorkerHeartbeat = "ml:workers:heartbeat"
|
||||
RedisWorkerPrewarmKey = "ml:workers:prewarm:"
|
||||
)
|
||||
|
||||
// Task status constants
|
||||
|
|
|
|||
83
internal/config/security.go
Normal file
83
internal/config/security.go
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityConfig holds security-related configuration.
type SecurityConfig struct {
	// AllowedOrigins lists the allowed origins for WebSocket connections.
	// Empty list defaults to localhost-only in production mode.
	AllowedOrigins []string `yaml:"allowed_origins"`

	// ProductionMode enables strict security checks.
	ProductionMode bool `yaml:"production_mode"`

	// APIKeyRotationDays is the number of days before API keys should be
	// rotated. Zero passes validation (rotation requirement disabled).
	APIKeyRotationDays int `yaml:"api_key_rotation_days"`

	// AuditLogging configuration.
	AuditLogging AuditLoggingConfig `yaml:"audit_logging"`

	// IPWhitelist for additional connection filtering.
	IPWhitelist []string `yaml:"ip_whitelist"`
}

// AuditLoggingConfig holds audit logging configuration.
type AuditLoggingConfig struct {
	Enabled bool   `yaml:"enabled"`
	LogPath string `yaml:"log_path"`
}

// MonitoringConfig holds monitoring-related configuration.
type MonitoringConfig struct {
	Prometheus   PrometheusConfig   `yaml:"prometheus"`
	HealthChecks HealthChecksConfig `yaml:"health_checks"`
}

// PrometheusConfig holds Prometheus metrics configuration.
type PrometheusConfig struct {
	Enabled bool   `yaml:"enabled"`
	Port    int    `yaml:"port"`
	Path    string `yaml:"path"`
}

// HealthChecksConfig holds health check configuration.
type HealthChecksConfig struct {
	Enabled  bool          `yaml:"enabled"`
	Interval time.Duration `yaml:"interval"`
}

// Validate validates the security configuration.
//
// It returns an error when production mode is enabled without any allowed
// origin, when the API key rotation interval is negative, or when audit
// logging is enabled without a log path.
func (s *SecurityConfig) Validate() error {
	if s.ProductionMode && len(s.AllowedOrigins) == 0 {
		return fmt.Errorf("production_mode requires at least one allowed_origin")
	}

	// Zero is permitted (rotation disabled); only negative values are invalid.
	// The message matches the check: previously it said "must be positive"
	// while the guard allowed zero.
	if s.APIKeyRotationDays < 0 {
		return fmt.Errorf("api_key_rotation_days must be non-negative")
	}

	if s.AuditLogging.Enabled && s.AuditLogging.LogPath == "" {
		return fmt.Errorf("audit_logging enabled but log_path not set")
	}

	return nil
}

// Validate validates the monitoring configuration and applies defaults.
//
// When Prometheus is enabled it checks the port is a valid TCP port and,
// if no scrape path is configured, mutates the receiver to set the default
// "/metrics" path.
func (m *MonitoringConfig) Validate() error {
	if m.Prometheus.Enabled {
		if m.Prometheus.Port <= 0 || m.Prometheus.Port > 65535 {
			return fmt.Errorf("prometheus port must be between 1 and 65535")
		}
		if m.Prometheus.Path == "" {
			m.Prometheus.Path = "/metrics" // Default
		}
	}

	return nil
}
|
||||
|
|
@ -15,13 +15,21 @@ func int64Max(a, b int64) int64 {
|
|||
|
||||
// Metrics tracks various performance counters and statistics.
|
||||
type Metrics struct {
|
||||
TasksProcessed atomic.Int64
|
||||
TasksFailed atomic.Int64
|
||||
DataFetchTime atomic.Int64 // Total nanoseconds
|
||||
ExecutionTime atomic.Int64
|
||||
DataTransferred atomic.Int64 // Total bytes
|
||||
ActiveTasks atomic.Int64
|
||||
QueuedTasks atomic.Int64
|
||||
TasksProcessed atomic.Int64
|
||||
TasksFailed atomic.Int64
|
||||
DataFetchTime atomic.Int64 // Total nanoseconds
|
||||
ExecutionTime atomic.Int64
|
||||
DataTransferred atomic.Int64 // Total bytes
|
||||
ActiveTasks atomic.Int64
|
||||
QueuedTasks atomic.Int64
|
||||
PrewarmEnvHit atomic.Int64
|
||||
PrewarmEnvMiss atomic.Int64
|
||||
PrewarmEnvBuilt atomic.Int64
|
||||
PrewarmEnvTime atomic.Int64 // Total nanoseconds
|
||||
PrewarmSnapshotHit atomic.Int64
|
||||
PrewarmSnapshotMiss atomic.Int64
|
||||
PrewarmSnapshotBuilt atomic.Int64
|
||||
PrewarmSnapshotTime atomic.Int64 // Total nanoseconds
|
||||
}
|
||||
|
||||
// RecordTaskSuccess records successful task completion with duration.
|
||||
|
|
@ -53,6 +61,32 @@ func (m *Metrics) RecordDataTransfer(bytes int64, duration time.Duration) {
|
|||
m.DataFetchTime.Add(duration.Nanoseconds())
|
||||
}
|
||||
|
||||
// RecordPrewarmEnvHit counts a prewarmed-environment cache hit.
func (m *Metrics) RecordPrewarmEnvHit() {
	m.PrewarmEnvHit.Add(1)
}

// RecordPrewarmEnvMiss counts a prewarmed-environment cache miss.
func (m *Metrics) RecordPrewarmEnvMiss() {
	m.PrewarmEnvMiss.Add(1)
}

// RecordPrewarmEnvBuilt counts an environment built during prewarming and
// accumulates the build duration (nanoseconds) into PrewarmEnvTime.
func (m *Metrics) RecordPrewarmEnvBuilt(duration time.Duration) {
	m.PrewarmEnvBuilt.Add(1)
	m.PrewarmEnvTime.Add(duration.Nanoseconds())
}

// RecordPrewarmSnapshotHit counts a prewarmed-snapshot cache hit.
func (m *Metrics) RecordPrewarmSnapshotHit() {
	m.PrewarmSnapshotHit.Add(1)
}

// RecordPrewarmSnapshotMiss counts a prewarmed-snapshot cache miss.
func (m *Metrics) RecordPrewarmSnapshotMiss() {
	m.PrewarmSnapshotMiss.Add(1)
}

// RecordPrewarmSnapshotBuilt counts a snapshot built during prewarming and
// accumulates the build duration (nanoseconds) into PrewarmSnapshotTime.
func (m *Metrics) RecordPrewarmSnapshotBuilt(duration time.Duration) {
	m.PrewarmSnapshotBuilt.Add(1)
	m.PrewarmSnapshotTime.Add(duration.Nanoseconds())
}
|
||||
|
||||
// SetQueuedTasks sets the number of queued tasks.
|
||||
func (m *Metrics) SetQueuedTasks(count int64) {
|
||||
m.QueuedTasks.Store(count)
|
||||
|
|
@ -66,13 +100,21 @@ func (m *Metrics) GetStats() map[string]any {
|
|||
dataFetchTime := m.DataFetchTime.Load()
|
||||
|
||||
return map[string]any{
|
||||
"tasks_processed": processed,
|
||||
"tasks_failed": failed,
|
||||
"active_tasks": m.ActiveTasks.Load(),
|
||||
"queued_tasks": m.QueuedTasks.Load(),
|
||||
"success_rate": float64(processed-failed) / float64(int64Max(processed, 1)),
|
||||
"avg_exec_time": time.Duration(m.ExecutionTime.Load() / int64Max(processed, 1)),
|
||||
"data_transferred_gb": float64(dataTransferred) / (1024 * 1024 * 1024),
|
||||
"avg_fetch_time": time.Duration(dataFetchTime / int64Max(processed, 1)),
|
||||
"tasks_processed": processed,
|
||||
"tasks_failed": failed,
|
||||
"active_tasks": m.ActiveTasks.Load(),
|
||||
"queued_tasks": m.QueuedTasks.Load(),
|
||||
"success_rate": float64(processed-failed) / float64(int64Max(processed, 1)),
|
||||
"avg_exec_time": time.Duration(m.ExecutionTime.Load() / int64Max(processed, 1)),
|
||||
"data_transferred_gb": float64(dataTransferred) / (1024 * 1024 * 1024),
|
||||
"avg_fetch_time": time.Duration(dataFetchTime / int64Max(processed, 1)),
|
||||
"prewarm_env_hit": m.PrewarmEnvHit.Load(),
|
||||
"prewarm_env_miss": m.PrewarmEnvMiss.Load(),
|
||||
"prewarm_env_built": m.PrewarmEnvBuilt.Load(),
|
||||
"prewarm_env_time": time.Duration(m.PrewarmEnvTime.Load()),
|
||||
"prewarm_snapshot_hit": m.PrewarmSnapshotHit.Load(),
|
||||
"prewarm_snapshot_miss": m.PrewarmSnapshotMiss.Load(),
|
||||
"prewarm_snapshot_built": m.PrewarmSnapshotBuilt.Load(),
|
||||
"prewarm_snapshot_time": time.Duration(m.PrewarmSnapshotTime.Load()),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,8 +3,11 @@ package middleware
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -26,7 +29,11 @@ type RateLimitOptions struct {
|
|||
}
|
||||
|
||||
// NewSecurityMiddleware creates a new security middleware instance.
|
||||
func NewSecurityMiddleware(authConfig *auth.Config, jwtSecret string, rlOpts *RateLimitOptions) *SecurityMiddleware {
|
||||
func NewSecurityMiddleware(
|
||||
authConfig *auth.Config,
|
||||
jwtSecret string,
|
||||
rlOpts *RateLimitOptions,
|
||||
) *SecurityMiddleware {
|
||||
sm := &SecurityMiddleware{
|
||||
authConfig: authConfig,
|
||||
jwtSecret: []byte(jwtSecret),
|
||||
|
|
@ -62,21 +69,23 @@ func (sm *SecurityMiddleware) RateLimit(next http.Handler) http.Handler {
|
|||
// APIKeyAuth provides API key authentication middleware.
|
||||
func (sm *SecurityMiddleware) APIKeyAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
apiKey := auth.ExtractAPIKeyFromRequest(r)
|
||||
|
||||
// Validate API key using auth config
|
||||
if sm.authConfig == nil {
|
||||
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
|
||||
// If authentication is not configured or disabled, allow all requests.
|
||||
// This keeps local/dev environments functional without requiring API keys.
|
||||
if sm.authConfig == nil || !sm.authConfig.Enabled {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
_, err := sm.authConfig.ValidateAPIKey(apiKey)
|
||||
apiKey := auth.ExtractAPIKeyFromRequest(r)
|
||||
|
||||
// Validate API key using auth config
|
||||
user, err := sm.authConfig.ValidateAPIKey(apiKey)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
ctx := auth.WithUserContext(r.Context(), user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -103,21 +112,54 @@ func SecurityHeaders(next http.Handler) http.Handler {
|
|||
|
||||
// IPWhitelist provides IP whitelist middleware.
|
||||
func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler) http.Handler {
|
||||
parsedAddrs := make([]netip.Addr, 0, len(allowedIPs))
|
||||
parsedPrefixes := make([]netip.Prefix, 0, len(allowedIPs))
|
||||
for _, raw := range allowedIPs {
|
||||
val := strings.TrimSpace(raw)
|
||||
if val == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(val, "/") {
|
||||
p, err := netip.ParsePrefix(val)
|
||||
if err != nil {
|
||||
log.Printf("SECURITY: invalid ip whitelist cidr ignored: %q: %v", val, err)
|
||||
continue
|
||||
}
|
||||
parsedPrefixes = append(parsedPrefixes, p)
|
||||
continue
|
||||
}
|
||||
a, err := netip.ParseAddr(val)
|
||||
if err != nil {
|
||||
log.Printf("SECURITY: invalid ip whitelist addr ignored: %q: %v", val, err)
|
||||
continue
|
||||
}
|
||||
parsedAddrs = append(parsedAddrs, a)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := getClientIP(r)
|
||||
if len(parsedAddrs) == 0 && len(parsedPrefixes) == 0 {
|
||||
http.Error(w, "IP not whitelisted", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
clientIPStr := getClientIP(r)
|
||||
addr, err := parseClientIP(clientIPStr)
|
||||
if err != nil {
|
||||
http.Error(w, "IP not whitelisted", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if client IP is in whitelist
|
||||
allowed := false
|
||||
for _, ip := range allowedIPs {
|
||||
if strings.Contains(ip, "/") {
|
||||
// CIDR notation - would need proper IP net parsing
|
||||
if strings.HasPrefix(clientIP, strings.Split(ip, "/")[0]) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if clientIP == ip {
|
||||
for _, a := range parsedAddrs {
|
||||
if a == addr {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
for _, p := range parsedPrefixes {
|
||||
if p.Contains(addr) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
|
|
@ -134,41 +176,46 @@ func (sm *SecurityMiddleware) IPWhitelist(allowedIPs []string) func(http.Handler
|
|||
}
|
||||
}
|
||||
|
||||
// CORS middleware with restrictive defaults
|
||||
func CORS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
// Only allow specific origins in production
|
||||
allowedOrigins := []string{
|
||||
"https://ml-experiments.example.com",
|
||||
"https://app.example.com",
|
||||
// CORS middleware with configured allowed origins
|
||||
func CORS(allowedOrigins []string) func(http.Handler) http.Handler {
|
||||
allowed := make([]string, 0, len(allowedOrigins))
|
||||
for _, o := range allowedOrigins {
|
||||
v := strings.TrimSpace(o)
|
||||
if v != "" {
|
||||
allowed = append(allowed, v)
|
||||
}
|
||||
}
|
||||
|
||||
isAllowed := false
|
||||
for _, allowed := range allowedOrigins {
|
||||
if origin == allowed {
|
||||
isAllowed = true
|
||||
break
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
isAllowed := false
|
||||
for _, a := range allowed {
|
||||
if a == "*" || origin == a {
|
||||
isAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isAllowed {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Vary", "Origin")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isAllowed {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequestTimeout provides request timeout middleware.
|
||||
|
|
@ -253,12 +300,28 @@ func getClientIP(r *http.Request) string {
|
|||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
return r.RemoteAddr[:idx]
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// parseClientIP extracts a netip.Addr from a raw client address string.
// It accepts bare IPs, host:port combinations (IPv4 or bracketed IPv6),
// and bracketed IPv6 literals without a port.
func parseClientIP(raw string) (netip.Addr, error) {
	candidate := strings.TrimSpace(raw)
	if candidate == "" {
		return netip.Addr{}, fmt.Errorf("empty ip")
	}
	// Strip an optional port; SplitHostPort handles "1.2.3.4:80" and "[::1]:80".
	if host, _, splitErr := net.SplitHostPort(candidate); splitErr == nil {
		candidate = host
	}
	// Unwrap a bare bracketed IPv6 literal such as "[::1]".
	candidate = strings.TrimSuffix(strings.TrimPrefix(candidate, "["), "]")
	return netip.ParseAddr(candidate)
}
|
||||
|
||||
// Response writer wrapper to capture status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
|
|
@ -273,5 +336,11 @@ func (rw *responseWriter) WriteHeader(code int) {
|
|||
// logSecurityEvent emits an audit line for a security-relevant request.
// Implementation would send to a security monitoring system; for now it
// only writes to the standard logger (in production, use proper logging).
func logSecurityEvent(event map[string]interface{}) {
	clientIP := event["client_ip"]
	method := event["method"]
	path := event["path"]
	status := event["status"]
	log.Printf("SECURITY AUDIT: %s %s %s %v", clientIP, method, path, status)
}
|
||||
|
|
|
|||
295
internal/prommetrics/prometheus.go
Normal file
295
internal/prommetrics/prometheus.go
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
package prommetrics
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// Metrics holds all Prometheus metrics for the application.
// Collectors are registered on a private registry (populated by initMetrics)
// rather than the package-global default, so they are exposed via Handler.
type Metrics struct {
	// WebSocket metrics
	wsConnections *prometheus.GaugeVec     // labelled by connection status
	wsMessages    *prometheus.CounterVec   // labelled by opcode and status
	wsDuration    *prometheus.HistogramVec // labelled by opcode
	wsErrors      *prometheus.CounterVec   // labelled by error type

	// Job queue metrics
	jobsQueued    prometheus.Counter
	jobsCompleted *prometheus.CounterVec // labelled by terminal status
	jobsActive    prometheus.Gauge
	jobDuration   *prometheus.HistogramVec // labelled by terminal status
	queueLength   prometheus.Gauge

	// Jupyter metrics
	jupyterServices *prometheus.GaugeVec   // labelled by service status
	jupyterOps      *prometheus.CounterVec // labelled by operation and status

	// HTTP metrics
	httpRequests *prometheus.CounterVec   // labelled by method, endpoint, status
	httpDuration *prometheus.HistogramVec // labelled by method and endpoint

	// Prewarm metrics
	prewarmSnapshotHit   prometheus.Counter
	prewarmSnapshotMiss  prometheus.Counter
	prewarmSnapshotBuilt prometheus.Counter
	prewarmSnapshotTime  prometheus.Histogram

	registry *prometheus.Registry // private registry backing Handler
}
|
||||
|
||||
// New creates a new Prometheus Metrics instance
|
||||
func New() *Metrics {
|
||||
m := &Metrics{
|
||||
registry: prometheus.NewRegistry(),
|
||||
}
|
||||
|
||||
m.initMetrics()
|
||||
return m
|
||||
}
|
||||
|
||||
// initMetrics initializes all Prometheus metrics and registers them on the
// instance's private registry. Call exactly once per Metrics value:
// MustRegister panics on duplicate registration.
func (m *Metrics) initMetrics() {
	// WebSocket metrics
	m.wsConnections = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "fetchml_websocket_connections",
			Help: "Number of active WebSocket connections",
		},
		[]string{"status"},
	)

	m.wsMessages = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "fetchml_websocket_messages_total",
			Help: "Total number of WebSocket messages",
		},
		[]string{"opcode", "status"},
	)

	m.wsDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "fetchml_websocket_duration_seconds",
			Help:    "WebSocket message processing duration",
			Buckets: prometheus.DefBuckets,
		},
		[]string{"opcode"},
	)

	m.wsErrors = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "fetchml_websocket_errors_total",
			Help: "Total number of WebSocket errors",
		},
		[]string{"type"},
	)

	// Job queue metrics
	m.jobsQueued = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: "fetchml_jobs_queued_total",
			Help: "Total number of jobs queued",
		},
	)

	m.jobsCompleted = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "fetchml_jobs_completed_total",
			Help: "Total number of completed jobs",
		},
		[]string{"status"},
	)

	m.jobsActive = prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "fetchml_jobs_active",
			Help: "Number of currently active jobs",
		},
	)

	// Job duration buckets span 1s to 1h to cover ML job runtimes.
	m.jobDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "fetchml_job_duration_seconds",
			Help:    "Job execution duration",
			Buckets: []float64{1, 5, 10, 30, 60, 300, 600, 1800, 3600},
		},
		[]string{"status"},
	)

	m.queueLength = prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "fetchml_queue_length",
			Help: "Current job queue length",
		},
	)

	// Jupyter metrics
	m.jupyterServices = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "fetchml_jupyter_services",
			Help: "Number of Jupyter services",
		},
		[]string{"status"},
	)

	m.jupyterOps = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "fetchml_jupyter_operations_total",
			Help: "Total number of Jupyter operations",
		},
		[]string{"operation", "status"},
	)

	// HTTP metrics
	m.httpRequests = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "fetchml_http_requests_total",
			Help: "Total number of HTTP requests",
		},
		[]string{"method", "endpoint", "status"},
	)

	m.httpDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:    "fetchml_http_duration_seconds",
			Help:    "HTTP request duration",
			Buckets: prometheus.DefBuckets,
		},
		[]string{"method", "endpoint"},
	)

	// Prewarm metrics
	m.prewarmSnapshotHit = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: "fetchml_prewarm_snapshot_hit_total",
			Help: "Total number of prewarmed snapshot hits (snapshots found in .prewarm/)",
		},
	)

	m.prewarmSnapshotMiss = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: "fetchml_prewarm_snapshot_miss_total",
			Help: "Total number of prewarmed snapshot misses (snapshots not found in .prewarm/)",
		},
	)

	m.prewarmSnapshotBuilt = prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: "fetchml_prewarm_snapshot_built_total",
			Help: "Total number of snapshots prewarmed into .prewarm/",
		},
	)

	// Prewarm duration buckets span 100ms to 2min.
	m.prewarmSnapshotTime = prometheus.NewHistogram(
		prometheus.HistogramOpts{
			Name:    "fetchml_prewarm_snapshot_duration_seconds",
			Help:    "Time spent prewarming snapshots",
			Buckets: []float64{0.1, 0.5, 1, 2, 5, 10, 30, 60, 120},
		},
	)

	// Register all metrics
	m.registry.MustRegister(
		m.wsConnections,
		m.wsMessages,
		m.wsDuration,
		m.wsErrors,
		m.jobsQueued,
		m.jobsCompleted,
		m.jobsActive,
		m.jobDuration,
		m.queueLength,
		m.jupyterServices,
		m.jupyterOps,
		m.httpRequests,
		m.httpDuration,
		m.prewarmSnapshotHit,
		m.prewarmSnapshotMiss,
		m.prewarmSnapshotBuilt,
		m.prewarmSnapshotTime,
	)
}
|
||||
|
||||
// Handler returns the Prometheus HTTP handler serving this instance's
// private registry.
func (m *Metrics) Handler() http.Handler {
	return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{})
}

// WebSocket metrics methods

// IncWSConnections increments the connection gauge for the given status.
func (m *Metrics) IncWSConnections(status string) {
	m.wsConnections.WithLabelValues(status).Inc()
}

// DecWSConnections decrements the connection gauge for the given status.
func (m *Metrics) DecWSConnections(status string) {
	m.wsConnections.WithLabelValues(status).Dec()
}

// IncWSMessages counts one WebSocket message for the opcode/status pair.
func (m *Metrics) IncWSMessages(opcode, status string) {
	m.wsMessages.WithLabelValues(opcode, status).Inc()
}

// ObserveWSDuration records the processing duration of one message.
func (m *Metrics) ObserveWSDuration(opcode string, duration time.Duration) {
	m.wsDuration.WithLabelValues(opcode).Observe(duration.Seconds())
}

// IncWSErrors counts one WebSocket error of the given type.
func (m *Metrics) IncWSErrors(errType string) {
	m.wsErrors.WithLabelValues(errType).Inc()
}

// Job queue metrics methods

// IncJobsQueued counts one newly queued job.
func (m *Metrics) IncJobsQueued() {
	m.jobsQueued.Inc()
}

// IncJobsCompleted counts one finished job with its terminal status.
func (m *Metrics) IncJobsCompleted(status string) {
	m.jobsCompleted.WithLabelValues(status).Inc()
}

// SetJobsActive sets the current number of active jobs.
func (m *Metrics) SetJobsActive(count float64) {
	m.jobsActive.Set(count)
}

// ObserveJobDuration records how long a job ran, by terminal status.
func (m *Metrics) ObserveJobDuration(status string, duration time.Duration) {
	m.jobDuration.WithLabelValues(status).Observe(duration.Seconds())
}

// SetQueueLength sets the current job queue length.
func (m *Metrics) SetQueueLength(length float64) {
	m.queueLength.Set(length)
}

// Jupyter metrics methods

// SetJupyterServices sets the number of Jupyter services in a given status.
func (m *Metrics) SetJupyterServices(status string, count float64) {
	m.jupyterServices.WithLabelValues(status).Set(count)
}

// IncJupyterOps counts one Jupyter operation with its outcome.
func (m *Metrics) IncJupyterOps(operation, status string) {
	m.jupyterOps.WithLabelValues(operation, status).Inc()
}

// HTTP metrics methods

// IncHTTPRequests counts one HTTP request by method, endpoint, and status.
func (m *Metrics) IncHTTPRequests(method, endpoint, status string) {
	m.httpRequests.WithLabelValues(method, endpoint, status).Inc()
}

// ObserveHTTPDuration records the latency of one HTTP request.
func (m *Metrics) ObserveHTTPDuration(method, endpoint string, duration time.Duration) {
	m.httpDuration.WithLabelValues(method, endpoint).Observe(duration.Seconds())
}

// Prewarm metrics methods

// IncPrewarmSnapshotHit counts a snapshot found already prewarmed.
func (m *Metrics) IncPrewarmSnapshotHit() {
	m.prewarmSnapshotHit.Inc()
}

// IncPrewarmSnapshotMiss counts a snapshot that was not prewarmed.
func (m *Metrics) IncPrewarmSnapshotMiss() {
	m.prewarmSnapshotMiss.Inc()
}

// IncPrewarmSnapshotBuilt counts a snapshot prewarmed into the cache.
func (m *Metrics) IncPrewarmSnapshotBuilt() {
	m.prewarmSnapshotBuilt.Inc()
}

// ObservePrewarmSnapshotDuration records time spent prewarming one snapshot.
func (m *Metrics) ObservePrewarmSnapshotDuration(duration time.Duration) {
	m.prewarmSnapshotTime.Observe(duration.Seconds())
}
|
||||
Loading…
Reference in a new issue