refactor: Phase 5 cleanup - Remove original ws_*.go files
Removed original monolithic WebSocket handler files after extracting to focused packages: Deleted: - ws_jobs.go (1,365 lines) → Extracted to api/jobs/handlers.go - ws_jupyter.go (512 lines) → Extracted to api/jupyter/handlers.go - ws_validate.go (523 lines) → Extracted to api/validate/handlers.go - ws_handler.go (379 lines) → Extracted to api/ws/handler.go - ws_datasets.go (174 lines) - Functionality not migrated - ws_tls_auth.go (101 lines) - Functionality not migrated Updated: - routes.go - Changed NewWSHandler → ws.NewHandler Lines deleted: ~3,000+ lines from monolithic files Build status: Compiles successfully
This commit is contained in:
parent
f0ffbb4a3d
commit
d9c5750ed8
7 changed files with 2 additions and 3053 deletions
|
|
@ -3,6 +3,7 @@ package api
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/ws"
|
||||
"github.com/jfraeys/fetch_ml/internal/prommetrics"
|
||||
)
|
||||
|
||||
|
|
@ -49,7 +50,7 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) {
|
|||
|
||||
// Register WebSocket handler with security config and audit logger
|
||||
securityCfg := getSecurityConfig(s.config)
|
||||
wsHandler := NewWSHandler(
|
||||
wsHandler := ws.NewHandler(
|
||||
s.config.BuildAuthConfig(),
|
||||
s.logger,
|
||||
s.expManager,
|
||||
|
|
|
|||
|
|
@ -1,173 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
||||
func (h *WSHandler) handleDatasetList(conn *websocket.Conn, payload []byte) error {
|
||||
user, err := h.authenticate(conn, payload, ProtocolMinDatasetList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.requireDB(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := helpers.DBContextShort()
|
||||
defer cancel()
|
||||
|
||||
datasets, err := h.db.ListDatasets(ctx, 0)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to list datasets", err.Error())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(datasets)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeServerOverloaded,
|
||||
"Failed to serialize response",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("datasets", data))
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleDatasetRegister(conn *websocket.Conn, payload []byte) error {
|
||||
user, err := h.authenticate(conn, payload, ProtocolMinDatasetRegister)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.requirePermission(user, PermDatasetsCreate, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.requireDB(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset := ProtocolAPIKeyHashLen
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if nameLen <= 0 || len(payload) < offset+nameLen+2 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
|
||||
}
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
offset += nameLen
|
||||
|
||||
urlLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if urlLen <= 0 || len(payload) < offset+urlLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url length", "")
|
||||
}
|
||||
urlStr := string(payload[offset : offset+urlLen])
|
||||
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "dataset name required", "")
|
||||
}
|
||||
if u, err := url.Parse(urlStr); err != nil || u.Scheme == "" {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset url", "")
|
||||
}
|
||||
|
||||
ctx, cancel := helpers.DBContextShort()
|
||||
defer cancel()
|
||||
|
||||
if err := h.db.UpsertDataset(ctx, &storage.Dataset{Name: name, URL: urlStr}); err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to register dataset", err.Error())
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewSuccessPacket("Dataset registered"))
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleDatasetInfo(conn *websocket.Conn, payload []byte) error {
|
||||
user, err := h.authenticate(conn, payload, ProtocolMinDatasetInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.requireDB(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset := ProtocolAPIKeyHashLen
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if nameLen <= 0 || len(payload) < offset+nameLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid dataset name length", "")
|
||||
}
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
|
||||
ctx, cancel := helpers.DBContextShort()
|
||||
defer cancel()
|
||||
|
||||
ds, err := h.db.GetDataset(ctx, name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return h.sendErrorPacket(conn, ErrorCodeResourceNotFound, "Dataset not found", "")
|
||||
}
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to get dataset", err.Error())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ds)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeServerOverloaded,
|
||||
"Failed to serialize response",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("dataset", data))
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleDatasetSearch(conn *websocket.Conn, payload []byte) error {
|
||||
user, err := h.authenticate(conn, payload, ProtocolMinDatasetSearch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.requirePermission(user, PermDatasetsRead, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.requireDB(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset := ProtocolAPIKeyHashLen
|
||||
termLen := int(payload[offset])
|
||||
offset++
|
||||
if termLen < 0 || len(payload) < offset+termLen {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid search term length", "")
|
||||
}
|
||||
term := string(payload[offset : offset+termLen])
|
||||
term = strings.TrimSpace(term)
|
||||
|
||||
ctx, cancel := helpers.DBContextShort()
|
||||
defer cancel()
|
||||
|
||||
datasets, err := h.db.SearchDatasets(ctx, term, 0)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to search datasets", err.Error())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(datasets)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeServerOverloaded,
|
||||
"Failed to serialize response",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("datasets", data))
|
||||
}
|
||||
|
|
@ -1,379 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"compress/flate"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
||||
// Opcodes for binary WebSocket protocol. Each client frame begins with one of
// these single-byte codes; the remainder of the frame is the opcode's payload.
// Values are not contiguous — new opcodes were appended over time, so gaps and
// out-of-order values (e.g. 0x17, 0x1A) are intentional. Do not renumber.
const (
	// Job lifecycle
	OpcodeQueueJob      = 0x01
	OpcodeStatusRequest = 0x02
	OpcodeCancelJob     = 0x03
	OpcodePrune         = 0x04

	// Dataset registry
	OpcodeDatasetList     = 0x06
	OpcodeDatasetRegister = 0x07
	OpcodeDatasetInfo     = 0x08
	OpcodeDatasetSearch   = 0x09

	// Experiment tracking and job-queue variants
	OpcodeLogMetric            = 0x0A
	OpcodeGetExperiment        = 0x0B
	OpcodeQueueJobWithTracking = 0x0C
	OpcodeQueueJobWithSnapshot = 0x17
	OpcodeQueueJobWithArgs     = 0x1A
	OpcodeQueueJobWithNote     = 0x1B
	OpcodeAnnotateRun          = 0x1C
	OpcodeSetRunNarrative      = 0x1D

	// Jupyter service management
	OpcodeStartJupyter        = 0x0D
	OpcodeStopJupyter         = 0x0E
	OpcodeRemoveJupyter       = 0x18
	OpcodeRestoreJupyter      = 0x19
	OpcodeListJupyter         = 0x0F
	OpcodeListJupyterPackages = 0x1E

	// Validation
	OpcodeValidateRequest = 0x16

	// Logs opcodes
	OpcodeGetLogs    = 0x20
	OpcodeStreamLogs = 0x21
)
|
||||
|
||||
// createUpgrader creates a WebSocket upgrader with the given security configuration
|
||||
func createUpgrader(securityCfg *config.SecurityConfig) websocket.Upgrader {
|
||||
return websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return true // Allow same-origin requests
|
||||
}
|
||||
|
||||
// Production mode: strict checking against allowed origins
|
||||
if securityCfg != nil && securityCfg.ProductionMode {
|
||||
for _, allowed := range securityCfg.AllowedOrigins {
|
||||
if origin == allowed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false // Reject if not in allowed list
|
||||
}
|
||||
|
||||
// Development mode: allow localhost and local network origins
|
||||
parsedOrigin, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
host := parsedOrigin.Host
|
||||
return strings.HasSuffix(host, ":8080") ||
|
||||
strings.HasPrefix(host, "localhost:") ||
|
||||
strings.HasPrefix(host, "127.0.0.1:") ||
|
||||
strings.HasPrefix(host, "192.168.") ||
|
||||
strings.HasPrefix(host, "10.") ||
|
||||
strings.HasPrefix(host, "172.")
|
||||
},
|
||||
// Performance optimizations
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
ReadBufferSize: 16 * 1024,
|
||||
WriteBufferSize: 16 * 1024,
|
||||
EnableCompression: true,
|
||||
}
|
||||
}
|
||||
|
||||
// WSHandler handles WebSocket connections for the API.
// It owns the upgrader, per-connection read loop, and the binary-protocol
// dispatch for jobs, datasets, experiments, Jupyter services, and logs.
type WSHandler struct {
	authConfig *auth.Config       // API-key validation; nil means auth is bypassed entirely
	logger     *logging.Logger    // component-scoped logger ("ws-handler")
	expManager *experiment.Manager
	dataDir    string             // root directory for on-disk data
	queue      queue.Backend      // task queue; nil disables queue-backed opcodes
	db         *storage.DB        // may be nil; handlers guard via requireDB
	jupyterServiceMgr *jupyter.ServiceManager
	securityConfig    *config.SecurityConfig // drives origin checking in the upgrader
	auditLogger       *audit.Logger          // optional; nil skips audit events
	upgrader          websocket.Upgrader     // built once from securityConfig
}
|
||||
|
||||
// NewWSHandler creates a new WebSocket handler.
|
||||
func NewWSHandler(
|
||||
authConfig *auth.Config,
|
||||
logger *logging.Logger,
|
||||
expManager *experiment.Manager,
|
||||
dataDir string,
|
||||
taskQueue queue.Backend,
|
||||
db *storage.DB,
|
||||
jupyterServiceMgr *jupyter.ServiceManager,
|
||||
securityConfig *config.SecurityConfig,
|
||||
auditLogger *audit.Logger,
|
||||
) *WSHandler {
|
||||
return &WSHandler{
|
||||
authConfig: authConfig,
|
||||
logger: logger.Component(logging.EnsureTrace(context.Background()), "ws-handler"),
|
||||
expManager: expManager,
|
||||
dataDir: dataDir,
|
||||
queue: taskQueue,
|
||||
db: db,
|
||||
jupyterServiceMgr: jupyterServiceMgr,
|
||||
securityConfig: securityConfig,
|
||||
auditLogger: auditLogger,
|
||||
upgrader: createUpgrader(securityConfig),
|
||||
}
|
||||
}
|
||||
|
||||
// enableLowLatencyTCP disables Nagle's algorithm to reduce latency for small packets.
|
||||
func enableLowLatencyTCP(conn *websocket.Conn, logger *logging.Logger) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if tcpConn, ok := conn.UnderlyingConn().(*net.TCPConn); ok {
|
||||
if err := tcpConn.SetNoDelay(true); err != nil {
|
||||
logger.Warn("failed to enable tcp no delay", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP upgrades the HTTP request to a WebSocket and runs the binary-protocol
// read loop until the client disconnects. Authentication (when enabled) happens
// BEFORE the upgrade so a plain HTTP 401 can be returned to unauthenticated clients.
func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Add security headers
	w.Header().Set("X-Frame-Options", "DENY")
	w.Header().Set("X-Content-Type-Options", "nosniff")
	w.Header().Set("X-XSS-Protection", "1; mode=block")
	if r.TLS != nil {
		// Only set HSTS if using HTTPS
		w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
	}

	// Check API key before upgrading WebSocket
	apiKey := auth.ExtractAPIKeyFromRequest(r)
	clientIP := r.RemoteAddr

	// Validate API key if authentication is enabled
	if h.authConfig != nil && h.authConfig.Enabled {
		// Only the first 8 chars of the key are ever logged or audited.
		prefixLen := len(apiKey)
		if prefixLen > 8 {
			prefixLen = 8
		}
		h.logger.Info(
			"websocket auth attempt",
			"api_key_length",
			len(apiKey),
			"api_key_prefix",
			apiKey[:prefixLen],
		)

		userID, err := h.authConfig.ValidateAPIKey(apiKey)
		if err != nil {
			h.logger.Warn("websocket authentication failed", "error", err)

			// Audit log failed authentication
			if h.auditLogger != nil {
				h.auditLogger.LogAuthAttempt(apiKey[:prefixLen], clientIP, false, err.Error())
			}

			http.Error(w, "Invalid API key", http.StatusUnauthorized)
			return
		}
		h.logger.Info("websocket authentication succeeded")

		// Audit log successful authentication
		if h.auditLogger != nil && userID != nil {
			h.auditLogger.LogAuthAttempt(userID.Name, clientIP, true, "")
		}
	}

	conn, err := h.upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade writes its own HTTP error response; just log and bail.
		h.logger.Error("websocket upgrade failed", "error", err)
		return
	}
	// Compression and TCP tuning are best-effort; failures are logged only.
	conn.EnableWriteCompression(true)
	if err := conn.SetCompressionLevel(flate.BestSpeed); err != nil {
		h.logger.Warn("failed to set websocket compression level", "error", err)
	}
	enableLowLatencyTCP(conn, h.logger)
	defer func() {
		_ = conn.Close()
	}()

	h.logger.Info("websocket connection established", "remote", r.RemoteAddr)

	// Read loop: one binary frame per request; any read error ends the connection.
	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			// Normal close codes are expected; only log genuinely unexpected ones.
			if websocket.IsUnexpectedCloseError(
				err,
				websocket.CloseGoingAway,
				websocket.CloseAbnormalClosure,
			) {
				h.logger.Error("websocket read error", "error", err)
			}
			break
		}

		if messageType != websocket.BinaryMessage {
			// Protocol is binary-only; text/ping frames are ignored, not fatal.
			h.logger.Warn("received non-binary message")
			continue
		}

		if err := h.handleMessage(conn, message); err != nil {
			h.logger.Error("message handling error", "error", err)
			// Send structured error response so CLI clients can parse it.
			// (Raw fallback bytes cause client-side InvalidPacket errors.)
			_ = h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "message handling error", err.Error())
		}
	}
}
|
||||
|
||||
// handleMessage decodes the 1-byte opcode that prefixes every binary frame and
// dispatches the remaining payload to the matching handler. Unknown opcodes are
// returned as errors so the read loop can report a structured error packet.
func (h *WSHandler) handleMessage(conn *websocket.Conn, message []byte) error {
	if len(message) < 1 {
		return fmt.Errorf("message too short")
	}

	opcode := message[0]
	payload := message[1:]

	switch opcode {
	case OpcodeQueueJob:
		return h.handleQueueJob(conn, payload)
	case OpcodeQueueJobWithTracking:
		return h.handleQueueJobWithTracking(conn, payload)
	case OpcodeQueueJobWithSnapshot:
		return h.handleQueueJobWithSnapshot(conn, payload)
	case OpcodeQueueJobWithArgs:
		return h.handleQueueJobWithArgs(conn, payload)
	case OpcodeQueueJobWithNote:
		return h.handleQueueJobWithNote(conn, payload)
	case OpcodeAnnotateRun:
		return h.handleAnnotateRun(conn, payload)
	case OpcodeSetRunNarrative:
		return h.handleSetRunNarrative(conn, payload)
	case OpcodeStatusRequest:
		return h.handleStatusRequest(conn, payload)
	case OpcodeCancelJob:
		return h.handleCancelJob(conn, payload)
	case OpcodePrune:
		return h.handlePrune(conn, payload)
	case OpcodeDatasetList:
		return h.handleDatasetList(conn, payload)
	case OpcodeDatasetRegister:
		return h.handleDatasetRegister(conn, payload)
	case OpcodeDatasetInfo:
		return h.handleDatasetInfo(conn, payload)
	case OpcodeDatasetSearch:
		return h.handleDatasetSearch(conn, payload)
	case OpcodeLogMetric:
		return h.handleLogMetric(conn, payload)
	case OpcodeGetExperiment:
		return h.handleGetExperiment(conn, payload)
	case OpcodeStartJupyter:
		return h.handleStartJupyter(conn, payload)
	case OpcodeStopJupyter:
		return h.handleStopJupyter(conn, payload)
	case OpcodeRemoveJupyter:
		return h.handleRemoveJupyter(conn, payload)
	case OpcodeRestoreJupyter:
		return h.handleRestoreJupyter(conn, payload)
	case OpcodeListJupyter:
		return h.handleListJupyter(conn, payload)
	case OpcodeListJupyterPackages:
		return h.handleListJupyterPackages(conn, payload)
	case OpcodeValidateRequest:
		return h.handleValidateRequest(conn, payload)
	case OpcodeGetLogs:
		return h.handleGetLogs(conn, payload)
	case OpcodeStreamLogs:
		return h.handleStreamLogs(conn, payload)
	default:
		return fmt.Errorf("unknown opcode: 0x%02x", opcode)
	}
}
|
||||
|
||||
// AuthHandler is a handler function that receives an authenticated user
// alongside the connection and the raw payload (including the key hash prefix).
type AuthHandler func(conn *websocket.Conn, payload []byte, user *auth.User) error
|
||||
|
||||
// authenticate validates the API key from raw payload and returns the user
|
||||
func (h *WSHandler) authenticate(conn *websocket.Conn, payload []byte, minLen int) (*auth.User, error) {
|
||||
if len(payload) < minLen {
|
||||
return nil, h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil {
|
||||
user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash)
|
||||
if err != nil {
|
||||
h.logger.Error("invalid api key", "error", err)
|
||||
return nil, h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
return &auth.User{
|
||||
Name: "default",
|
||||
Admin: true,
|
||||
Roles: []string{"admin"},
|
||||
Permissions: map[string]bool{
|
||||
"*": true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// authenticateWithHash validates a pre-extracted API key hash
|
||||
func (h *WSHandler) authenticateWithHash(conn *websocket.Conn, apiKeyHash []byte) (*auth.User, error) {
|
||||
if h.authConfig != nil {
|
||||
user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash)
|
||||
if err != nil {
|
||||
h.logger.Error("invalid api key", "error", err)
|
||||
return nil, h.sendErrorPacket(conn, ErrorCodeAuthenticationFailed, "Invalid API key", err.Error())
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
return &auth.User{
|
||||
Name: "default",
|
||||
Admin: true,
|
||||
Roles: []string{"admin"},
|
||||
Permissions: map[string]bool{
|
||||
"*": true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// requirePermission checks if the user has the required permission
|
||||
func (h *WSHandler) requirePermission(
|
||||
user *auth.User,
|
||||
permission string,
|
||||
conn *websocket.Conn,
|
||||
) error {
|
||||
if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission(permission) {
|
||||
h.logger.Error("insufficient permissions", "user", user.Name, "required", permission)
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodePermissionDenied,
|
||||
fmt.Sprintf("Insufficient permissions: %s", permission),
|
||||
"",
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// requireDB checks if the database is configured
|
||||
func (h *WSHandler) requireDB(conn *websocket.Conn) error {
|
||||
if h.db == nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Database not configured", "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,512 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
// JupyterTaskErrorCode returns the error code for a Jupyter task.
|
||||
// This is kept for backward compatibility and delegates to the helper.
|
||||
func JupyterTaskErrorCode(t *queue.Task) byte {
|
||||
mapper := helpers.NewTaskErrorMapper()
|
||||
return byte(mapper.MapJupyterError(t))
|
||||
}
|
||||
|
||||
// jupyterTaskOutput is the JSON envelope worker tasks write into Task.Output.
// Fields are populated depending on the action; RawMessage defers decoding of
// the action-specific payloads until the handler knows which one it needs.
type jupyterTaskOutput struct {
	Type        string          `json:"type"`
	Service     json.RawMessage `json:"service,omitempty"`  // single service (start)
	Services    json.RawMessage `json:"services,omitempty"` // service list (list)
	Packages    json.RawMessage `json:"packages,omitempty"` // package list (list_packages)
	RestorePath string          `json:"restore_path,omitempty"` // destination of a restore
}
|
||||
|
||||
// handleRestoreJupyter enqueues a worker task that restores a previously
// removed Jupyter workspace by name, then blocks (up to 2 minutes) for the
// worker's result and reports it to the client.
// Payload layout after the 16-byte key hash: [name_len:1][name].
func (h *WSHandler) handleRestoreJupyter(conn *websocket.Conn, payload []byte) error {
	// 18 = 16-byte hash + 1 length byte + at least one name byte — presumably;
	// TODO confirm against the client encoder.
	user, err := h.authenticate(conn, payload, 18)
	if err != nil {
		return err
	}
	if err := h.requirePermission(user, PermJupyterManage, conn); err != nil {
		return err
	}

	offset := ProtocolAPIKeyHashLen
	nameLen := int(payload[offset])
	offset++
	if len(payload) < offset+nameLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
	}
	name := string(payload[offset : offset+nameLen])

	// The restore itself runs on a worker; we only enqueue and wait.
	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionRestore,
		jupyterNameKey:       strings.TrimSpace(name),
	}
	jobName := fmt.Sprintf("jupyter-restore-%s", strings.TrimSpace(name))
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter restore", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter restore", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter restore", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to restore Jupyter workspace", strings.TrimSpace(result.Error))
	}

	// Best-effort enrichment: if the worker reported a restore path, include it.
	msg := fmt.Sprintf("Restored Jupyter workspace '%s'", strings.TrimSpace(name))
	out := strings.TrimSpace(result.Output)
	if out != "" {
		var payloadOut jupyterTaskOutput
		if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
			if strings.TrimSpace(payloadOut.RestorePath) != "" {
				msg = fmt.Sprintf("Restored Jupyter workspace '%s' to %s", strings.TrimSpace(name), strings.TrimSpace(payloadOut.RestorePath))
			}
		}
	}

	return h.sendResponsePacket(conn, NewSuccessPacket(msg))
}
|
||||
|
||||
// jupyterServiceView is the subset of a worker-reported Jupyter service used
// to build user-facing status messages (just name and access URL).
type jupyterServiceView struct {
	Name string `json:"name"`
	URL  string `json:"url"`
}
|
||||
|
||||
// Metadata keys and values used to describe Jupyter tasks on the queue.
// Workers read these from Task.Metadata to decide which action to perform.
const (
	jupyterTaskTypeKey   = "task_type"
	jupyterTaskTypeValue = "jupyter" // marks a queue task as a Jupyter action

	// Action selector and its possible values.
	jupyterTaskActionKey  = "jupyter_action"
	jupyterActionStart    = "start"
	jupyterActionStop     = "stop"
	jupyterActionRemove   = "remove"
	jupyterActionRestore  = "restore"
	jupyterActionList     = "list"
	jupyterActionListPkgs = "list_packages"

	// Action arguments.
	jupyterNameKey      = "jupyter_name"
	jupyterWorkspaceKey = "jupyter_workspace"
	jupyterServiceIDKey = "jupyter_service_id"
)
|
||||
|
||||
// handleListJupyterPackages enqueues a worker task that lists the packages
// installed in a named Jupyter environment and relays the JSON array back.
// Protocol: [api_key_hash:16][name_len:1][name:var]
// NOTE(review): this handler uses the older verifyAPIKeyHash/validateWSUser
// auth pattern rather than h.authenticate like the dataset handlers — presumably
// historical; confirm before unifying.
func (h *WSHandler) handleListJupyterPackages(conn *websocket.Conn, payload []byte) error {
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "list jupyter packages payload too short", "")
	}

	apiKeyHash := payload[:16]

	// Hash verification only runs when auth is enabled; user resolution always runs.
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	if user != nil && !user.HasPermission("jupyter:read") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	p := helpers.NewPayloadParser(payload, 16)
	name, err := p.ParseLengthPrefixedString()
	if err != nil {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
	}
	name = strings.TrimSpace(name)
	if name == "" {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "missing jupyter name", "")
	}

	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionListPkgs,
		jupyterNameKey:       name,
	}
	jobName := fmt.Sprintf("jupyter-packages-%s", name)
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter packages list", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter packages list", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter packages list", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter packages", strings.TrimSpace(result.Error))
	}

	// Empty or unparseable worker output degrades to an empty package list.
	out := strings.TrimSpace(result.Output)
	if out == "" {
		return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{})))
	}
	var payloadOut jupyterTaskOutput
	if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
		// NOTE: `payload` here deliberately shadows the request payload above;
		// from this point on only the worker's package list is referenced.
		payload := payloadOut.Packages
		if len(payload) == 0 {
			payload = []byte("[]")
		}
		return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", payload))
	}

	return h.sendResponsePacket(conn, NewDataPacket("jupyter_packages", helpers.MarshalJSONOrEmpty([]any{})))
}
|
||||
|
||||
func (h *WSHandler) enqueueJupyterTask(userName, jobName string, meta map[string]string) (string, error) {
|
||||
if h.queue == nil {
|
||||
return "", fmt.Errorf("task queue not configured")
|
||||
}
|
||||
if err := container.ValidateJobName(jobName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(userName) == "" {
|
||||
return "", fmt.Errorf("missing user")
|
||||
}
|
||||
if meta == nil {
|
||||
meta = make(map[string]string)
|
||||
}
|
||||
meta[jupyterTaskTypeKey] = jupyterTaskTypeValue
|
||||
|
||||
taskID := uuid.New().String()
|
||||
task := &queue.Task{
|
||||
ID: taskID,
|
||||
JobName: jobName,
|
||||
Args: "",
|
||||
Status: "queued",
|
||||
Priority: 100, // high priority; interactive request
|
||||
CreatedAt: time.Now(),
|
||||
UserID: userName,
|
||||
Username: userName,
|
||||
CreatedBy: userName,
|
||||
Metadata: meta,
|
||||
}
|
||||
if err := h.queue.AddTask(task); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (h *WSHandler) waitForTask(taskID string, timeout time.Duration) (*queue.Task, error) {
|
||||
if h.queue == nil {
|
||||
return nil, fmt.Errorf("task queue not configured")
|
||||
}
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("timed out waiting for worker")
|
||||
}
|
||||
t, err := h.queue.GetTask(taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t == nil {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if t.Status == "completed" || t.Status == "failed" || t.Status == "cancelled" {
|
||||
return t, nil
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// handleStartJupyter enqueues a worker task that starts a named Jupyter
// service in the given workspace, waits (up to 2 minutes) for the worker, and
// reports success with the service URL when the worker provides one.
// Protocol:
// [api_key_hash:16][name_len:1][name][workspace_len:2][workspace][password_len:1][password]
func (h *WSHandler) handleStartJupyter(conn *websocket.Conn, payload []byte) error {
	// 21 = 16-byte hash + minimum space for the three length-prefixed fields —
	// presumably; TODO confirm against the client encoder.
	if len(payload) < 21 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "start jupyter payload too short", "")
	}

	apiKeyHash := payload[:16]

	// Verify API key (only when auth is enabled); user resolution always runs.
	if h.authConfig != nil && h.authConfig.Enabled {
		if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
			return h.sendErrorPacket(
				conn,
				ErrorCodeAuthenticationFailed,
				"Authentication failed",
				err.Error(),
			)
		}
	}
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Authentication failed",
			err.Error(),
		)
	}
	if user != nil && !user.HasPermission("jupyter:manage") {
		return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
	}

	// --- Binary field parsing; each bound check covers the field plus the next
	// field's length prefix so a truncated payload fails early. ---
	offset := 16
	nameLen := int(payload[offset])
	offset++

	if len(payload) < offset+nameLen+2 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid name length", "")
	}

	name := string(payload[offset : offset+nameLen])
	offset += nameLen

	workspaceLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
	offset += 2

	if len(payload) < offset+workspaceLen+1 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid workspace length", "")
	}

	workspace := string(payload[offset : offset+workspaceLen])
	offset += workspaceLen

	passwordLen := int(payload[offset])
	offset++

	if len(payload) < offset+passwordLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid password length", "")
	}

	// Password is parsed but not used in StartRequest
	// offset += passwordLen (already advanced during parsing)

	// The actual container start happens on a worker; we enqueue and wait.
	meta := map[string]string{
		jupyterTaskActionKey: jupyterActionStart,
		jupyterNameKey:       strings.TrimSpace(name),
		jupyterWorkspaceKey:  strings.TrimSpace(workspace),
	}
	jobName := fmt.Sprintf("jupyter-%s", strings.TrimSpace(name))
	taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
	if err != nil {
		h.logger.Error("failed to enqueue jupyter task", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter task", "")
	}

	result, err := h.waitForTask(taskID, 2*time.Minute)
	if err != nil {
		h.logger.Error("failed waiting for jupyter start", "error", err)
		return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
	}
	if result.Status != "completed" {
		h.logger.Error("jupyter task failed", "error", result.Error)
		details := strings.TrimSpace(result.Error)
		// Map duplicate-workspace failures to a dedicated error code by
		// sniffing the worker's error text.
		lower := strings.ToLower(details)
		if strings.Contains(lower, "already exists") || strings.Contains(lower, "already in use") {
			return h.sendErrorPacket(conn, ErrorCodeResourceAlreadyExists, "Jupyter workspace already exists", details)
		}
		return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to start Jupyter service", details)
	}

	// Best-effort enrichment: append the service URL when the worker reports one.
	msg := fmt.Sprintf("Started Jupyter service '%s'", strings.TrimSpace(name))
	out := strings.TrimSpace(result.Output)
	if out != "" {
		var payloadOut jupyterTaskOutput
		if err := json.Unmarshal([]byte(out), &payloadOut); err == nil && len(payloadOut.Service) > 0 {
			var svc jupyterServiceView
			if err := json.Unmarshal(payloadOut.Service, &svc); err == nil {
				if strings.TrimSpace(svc.URL) != "" {
					msg = fmt.Sprintf("Started Jupyter service '%s' at %s", strings.TrimSpace(name), strings.TrimSpace(svc.URL))
				}
			}
		}
	}
	return h.sendResponsePacket(conn, NewSuccessPacket(msg))
}
|
||||
|
||||
func (h *WSHandler) handleStopJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][service_id_len:1][service_id:var]
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "stop jupyter payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
user, err := h.validateWSUser(apiKeyHash)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if user != nil && !user.HasPermission("jupyter:manage") {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
}
|
||||
|
||||
p := helpers.NewPayloadParser(payload, 16)
|
||||
serviceID, err := p.ParseLengthPrefixedString()
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
|
||||
}
|
||||
|
||||
meta := map[string]string{
|
||||
jupyterTaskActionKey: jupyterActionStop,
|
||||
jupyterServiceIDKey: strings.TrimSpace(serviceID),
|
||||
}
|
||||
jobName := fmt.Sprintf("jupyter-stop-%s", strings.TrimSpace(serviceID))
|
||||
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to enqueue jupyter stop", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter stop", "")
|
||||
}
|
||||
result, err := h.waitForTask(taskID, 2*time.Minute)
|
||||
if err != nil {
|
||||
h.logger.Error("failed waiting for jupyter stop", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to stop Jupyter service", strings.TrimSpace(result.Error))
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Stopped Jupyter service %s", serviceID)))
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleRemoveJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
// Protocol: [api_key_hash:16][service_id_len:1][service_id:var]
|
||||
if len(payload) < 18 {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "remove jupyter payload too short", "")
|
||||
}
|
||||
|
||||
apiKeyHash := payload[:16]
|
||||
|
||||
p := helpers.NewPayloadParser(payload, 16)
|
||||
serviceID, err := p.ParseLengthPrefixedString()
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid service id length", "")
|
||||
}
|
||||
|
||||
// Optional: purge flag (1 byte). Default false for trash-first behavior.
|
||||
purge := false
|
||||
if p.HasRemaining() {
|
||||
purgeByte, _ := p.ParseByte()
|
||||
purge = purgeByte == 0x01
|
||||
}
|
||||
|
||||
if h.authConfig != nil && h.authConfig.Enabled {
|
||||
if err := h.verifyAPIKeyHash(apiKeyHash); err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
user, err := h.validateWSUser(apiKeyHash)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(
|
||||
conn,
|
||||
ErrorCodeAuthenticationFailed,
|
||||
"Authentication failed",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
if user != nil && !user.HasPermission("jupyter:manage") {
|
||||
return h.sendErrorPacket(conn, ErrorCodePermissionDenied, "Insufficient permissions", "")
|
||||
}
|
||||
|
||||
meta := map[string]string{
|
||||
jupyterTaskActionKey: jupyterActionRemove,
|
||||
jupyterServiceIDKey: strings.TrimSpace(serviceID),
|
||||
"jupyter_purge": fmt.Sprintf("%t", purge),
|
||||
}
|
||||
jobName := fmt.Sprintf("jupyter-remove-%s", strings.TrimSpace(serviceID))
|
||||
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to enqueue jupyter remove", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter remove", "")
|
||||
}
|
||||
|
||||
result, err := h.waitForTask(taskID, 2*time.Minute)
|
||||
if err != nil {
|
||||
h.logger.Error("failed waiting for jupyter remove", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to remove Jupyter service", strings.TrimSpace(result.Error))
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewSuccessPacket(fmt.Sprintf("Removed Jupyter service %s", serviceID)))
|
||||
}
|
||||
|
||||
func (h *WSHandler) handleListJupyter(conn *websocket.Conn, payload []byte) error {
|
||||
user, err := h.authenticate(conn, payload, ProtocolMinDatasetList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.requirePermission(user, PermJupyterRead, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
meta := map[string]string{
|
||||
jupyterTaskActionKey: jupyterActionList,
|
||||
}
|
||||
jobName := fmt.Sprintf("jupyter-list-%s", user.Name)
|
||||
taskID, err := h.enqueueJupyterTask(user.Name, jobName, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to enqueue jupyter list", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeDatabaseError, "Failed to enqueue Jupyter list", "")
|
||||
}
|
||||
result, err := h.waitForTask(taskID, 2*time.Minute)
|
||||
if err != nil {
|
||||
h.logger.Error("failed waiting for jupyter list", "error", err)
|
||||
return h.sendErrorPacket(conn, ErrorCodeTimeout, "Timed out waiting for worker", "")
|
||||
}
|
||||
if result.Status != "completed" {
|
||||
return h.sendErrorPacket(conn, JupyterTaskErrorCode(result), "Failed to list Jupyter services", strings.TrimSpace(result.Error))
|
||||
}
|
||||
|
||||
out := strings.TrimSpace(result.Output)
|
||||
if out == "" {
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{})))
|
||||
}
|
||||
var payloadOut jupyterTaskOutput
|
||||
if err := json.Unmarshal([]byte(out), &payloadOut); err == nil {
|
||||
payload := payloadOut.Services
|
||||
if len(payload) == 0 {
|
||||
payload = []byte("[]")
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", payload))
|
||||
}
|
||||
return h.sendResponsePacket(conn, NewDataPacket("jupyter_services", helpers.MarshalJSONOrEmpty([]any{})))
|
||||
}
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
// SetupTLSConfig creates TLS configuration for WebSocket server
|
||||
func SetupTLSConfig(certFile, keyFile string, host string) (*http.Server, error) {
|
||||
var server *http.Server
|
||||
|
||||
if certFile != "" && keyFile != "" {
|
||||
// Use provided certificates
|
||||
server = &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
},
|
||||
},
|
||||
}
|
||||
} else if host != "" {
|
||||
// Use Let's Encrypt with autocert
|
||||
certManager := &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: autocert.HostWhitelist(host),
|
||||
Cache: autocert.DirCache("/var/www/.cache"),
|
||||
}
|
||||
|
||||
server = &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
|
||||
TLSConfig: certManager.TLSConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// verifyAPIKeyHash verifies the provided binary hash against stored API keys
|
||||
func (h *WSHandler) verifyAPIKeyHash(hash []byte) error {
|
||||
if h.authConfig == nil || !h.authConfig.Enabled {
|
||||
return nil // No auth required
|
||||
}
|
||||
_, err := h.authConfig.ValidateAPIKeyHash(hash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid api key")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendErrorPacket sends an error response packet
|
||||
func (h *WSHandler) sendErrorPacket(
|
||||
conn *websocket.Conn,
|
||||
errorCode byte,
|
||||
message string,
|
||||
details string,
|
||||
) error {
|
||||
packet := NewErrorPacket(errorCode, message, details)
|
||||
return h.sendResponsePacket(conn, packet)
|
||||
}
|
||||
|
||||
// sendResponsePacket sends a structured response packet
|
||||
func (h *WSHandler) sendResponsePacket(conn *websocket.Conn, packet *ResponsePacket) error {
|
||||
data, err := packet.Serialize()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to serialize response packet", "error", err)
|
||||
// Fallback to simple error response
|
||||
return conn.WriteMessage(websocket.BinaryMessage, []byte{0xFF, 0x00})
|
||||
}
|
||||
|
||||
return conn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
|
||||
func (h *WSHandler) validateWSUser(apiKeyHash []byte) (*auth.User, error) {
|
||||
if h.authConfig != nil {
|
||||
user, err := h.authConfig.ValidateAPIKeyHash(apiKeyHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
return &auth.User{
|
||||
Name: "default",
|
||||
Admin: true,
|
||||
Roles: []string{"admin"},
|
||||
Permissions: map[string]bool{
|
||||
"*": true,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,523 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/helpers"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/manifest"
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
)
|
||||
|
||||
// validateCheck is the result of one validation check within a
// validateReport. Expected/Actual carry the compared values when a
// comparison was made; Details carries a free-form explanation.
type validateCheck struct {
	OK       bool   `json:"ok"`
	Expected string `json:"expected,omitempty"`
	Actual   string `json:"actual,omitempty"`
	Details  string `json:"details,omitempty"`
}
|
||||
|
||||
// validateReport is the JSON document produced by handleValidateRequest.
// OK is the overall verdict (ANDed over all checks); Checks maps check
// names to their individual results; Errors and Warnings collect
// human-readable findings; TS is the report timestamp (RFC 3339 with
// nanoseconds, UTC — see handleValidateRequest).
type validateReport struct {
	OK       bool                     `json:"ok"`
	CommitID string                   `json:"commit_id,omitempty"`
	TaskID   string                   `json:"task_id,omitempty"`
	Checks   map[string]validateCheck `json:"checks"`
	Errors   []string                 `json:"errors,omitempty"`
	Warnings []string                 `json:"warnings,omitempty"`
	TS       string                   `json:"ts"`
}
|
||||
|
||||
func shouldRequireRunManifest(task *queue.Task) bool {
|
||||
if task == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.ToLower(strings.TrimSpace(task.Status))
|
||||
switch s {
|
||||
case "running", "completed", "failed":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// runManifestBuckets maps normalized task statuses to the storage bucket
// their run manifest is expected to live in.
var runManifestBuckets = map[string]string{
	"queued":    "pending",
	"pending":   "pending",
	"running":   "running",
	"completed": "finished",
	"finished":  "finished",
	"failed":    "failed",
}

// expectedRunManifestBucketForStatus maps a task status to the storage
// bucket its run manifest should live in. The second return value is false
// for statuses with no defined bucket.
func expectedRunManifestBucketForStatus(status string) (string, bool) {
	bucket, ok := runManifestBuckets[strings.ToLower(strings.TrimSpace(status))]
	return bucket, ok
}
|
||||
|
||||
func findRunManifestDir(basePath string, jobName string) (string, string, bool) {
|
||||
if strings.TrimSpace(basePath) == "" || strings.TrimSpace(jobName) == "" {
|
||||
return "", "", false
|
||||
}
|
||||
jobPaths := storage.NewJobPaths(basePath)
|
||||
typedRoots := []struct {
|
||||
bucket string
|
||||
root string
|
||||
}{
|
||||
{bucket: "running", root: jobPaths.RunningPath()},
|
||||
{bucket: "pending", root: jobPaths.PendingPath()},
|
||||
{bucket: "finished", root: jobPaths.FinishedPath()},
|
||||
{bucket: "failed", root: jobPaths.FailedPath()},
|
||||
}
|
||||
for _, item := range typedRoots {
|
||||
root := item.root
|
||||
dir := filepath.Join(root, jobName)
|
||||
if info, err := os.Stat(dir); err == nil && info.IsDir() {
|
||||
if _, err := os.Stat(manifest.ManifestPath(dir)); err == nil {
|
||||
return dir, item.bucket, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func validateResourcesForTask(task *queue.Task) (validateCheck, []string) {
|
||||
if task == nil {
|
||||
return validateCheck{OK: false, Details: "task is nil"}, []string{"missing task"}
|
||||
}
|
||||
|
||||
if task.CPU < 0 {
|
||||
chk := validateCheck{OK: false, Details: "cpu must be >= 0"}
|
||||
return chk, []string{"invalid cpu request"}
|
||||
}
|
||||
if task.MemoryGB < 0 {
|
||||
chk := validateCheck{OK: false, Details: "memory_gb must be >= 0"}
|
||||
return chk, []string{"invalid memory request"}
|
||||
}
|
||||
if task.GPU < 0 {
|
||||
chk := validateCheck{OK: false, Details: "gpu must be >= 0"}
|
||||
return chk, []string{"invalid gpu request"}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(task.GPUMemory) != "" {
|
||||
s := strings.TrimSpace(task.GPUMemory)
|
||||
if strings.HasSuffix(s, "%") {
|
||||
v := strings.TrimSuffix(s, "%")
|
||||
f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
|
||||
if err != nil || f <= 0 || f > 100 {
|
||||
details := "gpu_memory must be a percentage in (0,100]"
|
||||
chk := validateCheck{OK: false, Details: details}
|
||||
return chk, []string{"invalid gpu_memory"}
|
||||
}
|
||||
} else {
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil || f <= 0 || f > 1 {
|
||||
chk := validateCheck{OK: false, Details: "gpu_memory must be a fraction in (0,1]"}
|
||||
return chk, []string{"invalid gpu_memory"}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if task.GPU == 0 && strings.TrimSpace(task.GPUMemory) != "" {
|
||||
chk := validateCheck{OK: false, Details: "gpu_memory requires gpu > 0"}
|
||||
return chk, []string{"invalid gpu_memory"}
|
||||
}
|
||||
|
||||
return validateCheck{OK: true}, nil
|
||||
}
|
||||
|
||||
// handleValidateRequest verifies the reproducibility/integrity state of a
// commit or a queued task and replies with a JSON validateReport in a
// "validate" data packet.
//
// Checks performed (each recorded under its own key in report.Checks):
// commit-id format, experiment manifest integrity, deps manifest presence
// and hash, and — when a task is targeted — resource requests, run-manifest
// presence/location/lifecycle, provenance (task id, commit id, deps,
// snapshot), expected overall SHA, snapshot checksum, and dataset checksums.
// Findings that are fatal land in Errors; best-effort findings for
// not-yet-running tasks land in Warnings (see shouldRequireRunManifest).
func (h *WSHandler) handleValidateRequest(conn *websocket.Conn, payload []byte) error {
	// Protocol: [api_key_hash:16][target_type:1][id_len:1][id:var]
	// target_type: 0=commit_id (20 bytes), 1=task_id (string)
	// TODO(context): Add a versioned validate protocol once we need more target types/fields.
	if len(payload) < 18 {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "validate request payload too short", "")
	}
	apiKeyHash := payload[:16]
	targetType := payload[16]
	idLen := int(payload[17])
	if idLen < 1 || len(payload) < 18+idLen {
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate id length", "")
	}
	idBytes := payload[18 : 18+idLen]

	// Validate API key and user
	user, err := h.validateWSUser(apiKeyHash)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeAuthenticationFailed,
			"Invalid API key",
			err.Error(),
		)
	}
	if h.authConfig != nil && h.authConfig.Enabled && !user.HasPermission("jobs:read") {
		return h.sendErrorPacket(
			conn,
			ErrorCodePermissionDenied,
			"Insufficient permissions to validate jobs",
			"",
		)
	}
	if h.expManager == nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeServiceUnavailable,
			"Experiment manager not available",
			"",
		)
	}

	// Resolve the validation target: either a raw 20-byte commit id, or a
	// task id whose metadata carries the commit id. Non-admins may only
	// validate tasks they own or created.
	var task *queue.Task
	commitID := ""
	switch targetType {
	case 0:
		if len(idBytes) != 20 {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "commit_id must be 20 bytes", "")
		}
		commitID = fmt.Sprintf("%x", idBytes)
	case 1:
		taskID := string(idBytes)
		if h.queue == nil {
			return h.sendErrorPacket(conn, ErrorCodeServiceUnavailable, "Task queue not available", "")
		}
		t, err := h.queue.GetTask(taskID)
		if err != nil {
			return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Task not found", err.Error())
		}
		task = t
		if h.authConfig != nil &&
			h.authConfig.Enabled &&
			!user.Admin &&
			task.UserID != user.Name &&
			task.CreatedBy != user.Name {
			return h.sendErrorPacket(
				conn,
				ErrorCodePermissionDenied,
				"You can only validate your own jobs",
				"",
			)
		}
		if task.Metadata == nil || task.Metadata["commit_id"] == "" {
			return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "Task missing commit_id", "")
		}
		commitID = task.Metadata["commit_id"]
	default:
		return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "invalid validate target_type", "")
	}

	// Accumulate every finding into one report; r.OK is ANDed across checks.
	r := validateReport{
		OK:     true,
		TS:     time.Now().UTC().Format(time.RFC3339Nano),
		Checks: map[string]validateCheck{},
	}
	if task != nil {
		r.TaskID = task.ID
	}
	if commitID != "" {
		r.CommitID = commitID
	}

	// Validate commit id format
	if ok, errMsg := helpers.ValidateCommitIDFormat(commitID); !ok {
		r.OK = false
		r.Errors = append(r.Errors, errMsg)
	}

	// Experiment manifest integrity
	// TODO(context): Extend report to include per-file diff list on mismatch (bounded output).
	if r.OK {
		if ok, details := helpers.ValidateExperimentManifest(h.expManager, commitID); !ok {
			r.OK = false
			r.Checks["experiment_manifest"] = validateCheck{OK: false, Details: details}
			r.Errors = append(r.Errors, "experiment manifest validation failed")
		} else {
			r.Checks["experiment_manifest"] = validateCheck{OK: true}
		}
	}

	// Deps manifest presence + hash
	// TODO(context): Allow client to declare which dependency manifest is authoritative.
	filesPath := h.expManager.GetFilesPath(commitID)
	depName, depCheck, depErrs := helpers.ValidateDepsManifest(h.expManager, commitID)
	if depErrs != nil {
		r.OK = false
		r.Checks["deps_manifest"] = validateCheck(depCheck)
		r.Errors = append(r.Errors, depErrs...)
	} else {
		r.Checks["deps_manifest"] = validateCheck(depCheck)
	}

	// Compare against expected task metadata if available.
	if task != nil {
		resCheck, resErrs := validateResourcesForTask(task)
		r.Checks["resources"] = resCheck
		if !resCheck.OK {
			r.OK = false
			r.Errors = append(r.Errors, resErrs...)
		}

		// Run manifest checks: best-effort for queued tasks, required for running/completed/failed.
		if err := container.ValidateJobName(task.JobName); err != nil {
			r.OK = false
			r.Errors = append(r.Errors, "invalid job name")
			r.Checks["run_manifest"] = validateCheck{OK: false, Details: "invalid job name"}
		} else if base := strings.TrimSpace(h.expManager.BasePath()); base == "" {
			// No base path: error if a manifest is required, warn otherwise.
			if shouldRequireRunManifest(task) {
				r.OK = false
				r.Errors = append(r.Errors, "missing api base_path; cannot validate run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
			} else {
				r.Warnings = append(r.Warnings, "missing api base_path; cannot validate run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "missing api base_path"}
			}
		} else {
			manifestDir, manifestBucket, found := findRunManifestDir(base, task.JobName)
			if !found {
				if shouldRequireRunManifest(task) {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest not found")
					r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
				} else {
					r.Warnings = append(r.Warnings, "run manifest not found (job may not have started)")
					r.Checks["run_manifest"] = validateCheck{OK: false, Details: "run manifest not found"}
				}
			} else if rm, err := manifest.LoadFromDir(manifestDir); err != nil || rm == nil {
				r.OK = false
				r.Errors = append(r.Errors, "unable to read run manifest")
				r.Checks["run_manifest"] = validateCheck{OK: false, Details: "unable to read run manifest"}
			} else {
				r.Checks["run_manifest"] = validateCheck{OK: true}

				// Manifest found: check it sits in the bucket its task
				// status implies (pending/running/finished/failed).
				expectedBucket, ok := expectedRunManifestBucketForStatus(task.Status)
				if ok {
					if expectedBucket != manifestBucket {
						msg := "run manifest location mismatch"
						chk := validateCheck{OK: false, Expected: expectedBucket, Actual: manifestBucket}
						if shouldRequireRunManifest(task) {
							r.OK = false
							r.Errors = append(r.Errors, msg)
							r.Checks["run_manifest_location"] = chk
						} else {
							r.Warnings = append(r.Warnings, msg)
							r.Checks["run_manifest_location"] = chk
						}
					} else {
						r.Checks["run_manifest_location"] = validateCheck{
							OK:       true,
							Expected: expectedBucket,
							Actual:   manifestBucket,
						}
					}
				}

				// Validate task ID using helper
				taskIDCheck := helpers.ValidateTaskIDMatch(rm, task.ID)
				r.Checks["run_manifest_task_id"] = validateCheck(taskIDCheck)
				if !taskIDCheck.OK {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest task_id mismatch")
				}

				// Validate commit ID using helper
				commitCheck := helpers.ValidateCommitIDMatch(rm.CommitID, task.Metadata["commit_id"])
				r.Checks["run_manifest_commit_id"] = validateCheck(commitCheck)
				if !commitCheck.OK {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest commit_id mismatch")
				}

				// Validate deps provenance using helper
				depWantName := strings.TrimSpace(task.Metadata["deps_manifest_name"])
				depWantSHA := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
				depGotName := strings.TrimSpace(rm.DepsManifestName)
				depGotSHA := strings.TrimSpace(rm.DepsManifestSHA)
				depsCheck := helpers.ValidateDepsProvenance(depWantName, depWantSHA, depGotName, depGotSHA)
				r.Checks["run_manifest_deps"] = validateCheck(depsCheck)
				if !depsCheck.OK {
					r.OK = false
					r.Errors = append(r.Errors, "run manifest deps provenance mismatch")
				}

				// Validate snapshot using helpers
				if strings.TrimSpace(task.SnapshotID) != "" {
					snapWantID := strings.TrimSpace(task.SnapshotID)
					snapWantSHA := strings.TrimSpace(task.Metadata["snapshot_sha256"])
					snapGotID := strings.TrimSpace(rm.SnapshotID)
					snapGotSHA := strings.TrimSpace(rm.SnapshotSHA256)

					snapIDCheck := helpers.ValidateSnapshotID(snapWantID, snapGotID)
					r.Checks["run_manifest_snapshot_id"] = validateCheck(snapIDCheck)
					if !snapIDCheck.OK {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest snapshot_id mismatch")
					}

					snapSHACheck := helpers.ValidateSnapshotSHA(snapWantSHA, snapGotSHA)
					r.Checks["run_manifest_snapshot_sha256"] = validateCheck(snapSHACheck)
					if !snapSHACheck.OK {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest snapshot_sha256 mismatch")
					}
				}

				// Validate lifecycle using helper
				lifecycleOK, details := helpers.ValidateRunManifestLifecycle(rm, task.Status)
				if lifecycleOK {
					r.Checks["run_manifest_lifecycle"] = validateCheck{OK: true}
				} else {
					chk := validateCheck{OK: false, Details: details}
					if shouldRequireRunManifest(task) {
						r.OK = false
						r.Errors = append(r.Errors, "run manifest lifecycle invalid")
						r.Checks["run_manifest_lifecycle"] = chk
					} else {
						r.Warnings = append(r.Warnings, "run manifest lifecycle invalid")
						r.Checks["run_manifest_lifecycle"] = chk
					}
				}
			}
		}

		// Expected experiment-manifest overall SHA recorded at submit time
		// must match the manifest's current overall SHA.
		want := strings.TrimSpace(task.Metadata["experiment_manifest_overall_sha"])
		cur := ""
		if man, err := h.expManager.ReadManifest(commitID); err == nil && man != nil {
			cur = man.OverallSHA
		}
		if want == "" {
			r.OK = false
			r.Errors = append(r.Errors, "missing expected experiment_manifest_overall_sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Actual: cur}
		} else if cur == "" {
			r.OK = false
			r.Errors = append(r.Errors, "unable to read current experiment manifest overall sha")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want}
		} else if want != cur {
			r.OK = false
			r.Errors = append(r.Errors, "experiment manifest overall sha mismatch")
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: false, Expected: want, Actual: cur}
		} else {
			r.Checks["expected_manifest_overall_sha"] = validateCheck{OK: true, Expected: want, Actual: cur}
		}

		// Cross-check deps-manifest provenance recorded in task metadata
		// against the on-disk manifest file's current hash.
		wantDep := strings.TrimSpace(task.Metadata["deps_manifest_name"])
		wantDepSha := strings.TrimSpace(task.Metadata["deps_manifest_sha256"])
		if wantDep == "" || wantDepSha == "" {
			r.OK = false
			r.Errors = append(r.Errors, "missing expected deps manifest provenance")
			r.Checks["expected_deps_manifest"] = validateCheck{OK: false}
		} else if depName != "" {
			sha, _ := helpers.FileSHA256Hex(filepath.Join(filesPath, depName))
			ok := (wantDep == depName && wantDepSha == sha)
			if !ok {
				r.OK = false
				r.Errors = append(r.Errors, "deps manifest provenance mismatch")
				r.Checks["expected_deps_manifest"] = validateCheck{
					OK:       false,
					Expected: wantDep + ":" + wantDepSha,
					Actual:   depName + ":" + sha,
				}
			} else {
				r.Checks["expected_deps_manifest"] = validateCheck{
					OK:       true,
					Expected: wantDep + ":" + wantDepSha,
					Actual:   depName + ":" + sha,
				}
			}
		}

		// Snapshot/dataset checks require dataDir.
		// TODO(context): Support snapshot stores beyond local filesystem (e.g. S3).
		// TODO(context): Validate snapshots by digest.
		if task.SnapshotID != "" {
			if h.dataDir == "" {
				r.OK = false
				r.Errors = append(r.Errors, "api server data_dir not configured; cannot validate snapshot")
				r.Checks["snapshot"] = validateCheck{OK: false, Details: "missing api data_dir"}
			} else {
				wantSnap, nerr := worker.NormalizeSHA256ChecksumHex(task.Metadata["snapshot_sha256"])
				if nerr != nil || wantSnap == "" {
					r.OK = false
					r.Errors = append(r.Errors, "missing/invalid snapshot_sha256")
					r.Checks["snapshot"] = validateCheck{OK: false}
				} else {
					// Recompute the snapshot directory's overall hash and
					// compare against the recorded checksum.
					curSnap, err := worker.DirOverallSHA256Hex(
						filepath.Join(h.dataDir, "snapshots", task.SnapshotID),
					)
					if err != nil {
						r.OK = false
						r.Errors = append(r.Errors, "snapshot hash computation failed")
						r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Details: err.Error()}
					} else if curSnap != wantSnap {
						r.OK = false
						r.Errors = append(r.Errors, "snapshot checksum mismatch")
						r.Checks["snapshot"] = validateCheck{OK: false, Expected: wantSnap, Actual: curSnap}
					} else {
						r.Checks["snapshot"] = validateCheck{OK: true, Expected: wantSnap, Actual: curSnap}
					}
				}
			}
		}

		// Dataset checksum verification: one check entry per declared
		// dataset with a checksum; datasets without one are skipped.
		if len(task.DatasetSpecs) > 0 {
			// TODO(context): Add dataset URI fetch/verification.
			// TODO(context): Currently only validates local materialized datasets.
			for _, ds := range task.DatasetSpecs {
				if ds.Checksum == "" {
					continue
				}
				key := "dataset:" + ds.Name
				if h.dataDir == "" {
					r.OK = false
					r.Errors = append(
						r.Errors,
						"api server data_dir not configured; cannot validate dataset checksums",
					)
					r.Checks[key] = validateCheck{OK: false, Details: "missing api data_dir"}
					continue
				}
				wantDS, nerr := worker.NormalizeSHA256ChecksumHex(ds.Checksum)
				if nerr != nil || wantDS == "" {
					r.OK = false
					r.Errors = append(r.Errors, "invalid dataset checksum format")
					r.Checks[key] = validateCheck{OK: false, Details: "invalid checksum"}
					continue
				}
				curDS, err := worker.DirOverallSHA256Hex(filepath.Join(h.dataDir, ds.Name))
				if err != nil {
					r.OK = false
					r.Errors = append(r.Errors, "dataset checksum computation failed")
					r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Details: err.Error()}
					continue
				}
				if curDS != wantDS {
					r.OK = false
					r.Errors = append(r.Errors, "dataset checksum mismatch")
					r.Checks[key] = validateCheck{OK: false, Expected: wantDS, Actual: curDS}
					continue
				}
				r.Checks[key] = validateCheck{OK: true, Expected: wantDS, Actual: curDS}
			}
		}
	}

	// Serialize the accumulated report and ship it to the client.
	payloadBytes, err := json.Marshal(r)
	if err != nil {
		return h.sendErrorPacket(
			conn,
			ErrorCodeUnknownError,
			"failed to serialize validate report",
			err.Error(),
		)
	}
	return h.sendResponsePacket(conn, NewDataPacket("validate", payloadBytes))
}
|
||||
Loading…
Reference in a new issue