From 7e5ceec069030522b951af29cda5e6dc65eef33c Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Sun, 8 Mar 2026 12:51:25 -0400 Subject: [PATCH] feat(api): add groups and tokens handlers, refactor routes Add new API endpoints and clean up handler interfaces: - groups/handlers.go: New lab group management API * CRUD operations for lab groups * Member management with role assignment (admin/member/viewer) * Group listing and membership queries - tokens/handlers.go: Token generation and validation endpoints * Create access tokens for public task sharing * Validate tokens for secure access * Token revocation and cleanup - routes.go: Refactor handler registration * Integrate groups handler into WebSocket routes * Remove nil parameters from all handler constructors * Cleaner dependency injection pattern - Handler interface cleanup across all modules: * jobs/handlers.go: Remove unused nil privacyEnforcer parameter * jupyter/handlers.go: Streamline initialization * scheduler/handlers.go: Consistent constructor signature * ws/handler.go: Add groups handler to dependencies --- internal/api/helpers/db_helpers.go | 4 +++ internal/api/middleware.go | 18 ++++++++++ internal/api/middleware/validation.go | 7 ++-- internal/api/plugins/handlers.go | 29 +++++++++------ internal/api/protocol.go | 51 +++++++++++++++++++++------ internal/api/responses/errors.go | 11 ++++-- internal/api/server_gen.go | 8 ++--- internal/api/spec_embed.go | 1 + internal/api/validate/handlers.go | 6 ++-- 9 files changed, 103 insertions(+), 32 deletions(-) diff --git a/internal/api/helpers/db_helpers.go b/internal/api/helpers/db_helpers.go index e9b3b1c..c0f8fdc 100644 --- a/internal/api/helpers/db_helpers.go +++ b/internal/api/helpers/db_helpers.go @@ -9,21 +9,25 @@ import ( // DBContext provides a standard database operation context. // It creates a context with the specified timeout and returns the context and cancel function. +// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management func DBContext(timeout time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), timeout) } // DBContextShort returns a short-lived context for quick DB operations (3 seconds). +// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management func DBContextShort() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), 3*time.Second) } // DBContextMedium returns a medium-lived context for standard DB operations (5 seconds). +// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management func DBContextMedium() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), 5*time.Second) } // DBContextLong returns a long-lived context for complex DB operations (10 seconds). +// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management func DBContextLong() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), 10*time.Second) } diff --git a/internal/api/middleware.go b/internal/api/middleware.go index bea04ca..02d4f3a 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -5,6 +5,7 @@ import ( "strings" "time" + "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/middleware" ) @@ -18,6 +19,7 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler { } handler := s.sec.APIKeyAuth(mux) + handler = s.provisionUserMiddleware(handler) handler = s.sec.RateLimit(handler) handler = middleware.SecurityHeaders(handler) handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler) @@ -33,3 +35,19 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler { handler.ServeHTTP(w, r) }) } + +// provisionUserMiddleware provisions new users on first login +func (s *Server) provisionUserMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Only provision if database is available + if s.db != nil { + if user := auth.GetUserFromContext(r.Context()); user != nil { + if err := s.db.ProvisionUserOnFirstLogin(user.Name); err != nil { + // Log error but don't fail the request - provisioning is best-effort + s.logger.Error("failed to provision user on first login", "user", user.Name, "error", err) + } + } + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/api/middleware/validation.go b/internal/api/middleware/validation.go index 5805e83..b53d97d 100644 --- a/internal/api/middleware/validation.go +++ b/internal/api/middleware/validation.go @@ -75,10 +75,13 @@ func (v *ValidationMiddleware) ValidateRequest(next http.Handler) http.Handler { if err := openapi3filter.ValidateRequest(r.Context(), requestValidationInput); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]any{ + if encodeErr := json.NewEncoder(w).Encode(map[string]any{ "error": "validation failed", "message": err.Error(), - }) + }); encodeErr != nil { + // Log but don't return - we've already sent headers + _ = encodeErr + } return } diff --git a/internal/api/plugins/handlers.go b/internal/api/plugins/handlers.go index f3a7fa1..d2a4952 100644 --- a/internal/api/plugins/handlers.go +++ b/internal/api/plugins/handlers.go @@ -2,6 +2,7 @@ package plugins import ( + "slices" "encoding/json" "net/http" "time" @@ -98,7 +99,9 @@ func (h *Handler) GetV1Plugins(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(plugins) + if err := json.NewEncoder(w).Encode(plugins); err != nil { + h.logger.Warn("failed to encode plugins response", "error", err) + } } // GetV1PluginsPluginName handles GET /v1/plugins/{pluginName} @@ -136,7 +139,9 @@ func (h *Handler) GetV1PluginsPluginName(w http.ResponseWriter, r *http.Request) } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(info) + if err := json.NewEncoder(w).Encode(info); err != nil { + h.logger.Warn("failed to encode plugin info", "error", err) + } } // GetV1PluginsPluginNameConfig handles GET /v1/plugins/{pluginName}/config @@ -160,7 +165,9 @@ func (h *Handler) GetV1PluginsPluginNameConfig(w http.ResponseWriter, r *http.Re } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(cfg) + if err := json.NewEncoder(w).Encode(cfg); err != nil { + h.logger.Warn("failed to encode plugin config", "error", err) + } } // PutV1PluginsPluginNameConfig handles PUT /v1/plugins/{pluginName}/config @@ -195,11 +202,13 @@ func (h *Handler) PutV1PluginsPluginNameConfig(w http.ResponseWriter, r *http.Re Status: "healthy", Config: newConfig, RequiresRestart: false, - Version: "1.0.0", + Version: "1.0.0", // TODO: should this be checked } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(info) + if err := json.NewEncoder(w).Encode(info); err != nil { + h.logger.Warn("failed to encode plugin info", "error", err) + } } // DeleteV1PluginsPluginNameConfig handles DELETE /v1/plugins/{pluginName}/config @@ -255,14 +264,16 @@ func (h *Handler) GetV1PluginsPluginNameHealth(w http.ResponseWriter, r *http.Re status = "stopped" } - response := map[string]interface{}{ + response := map[string]any{ "status": status, "version": "1.0.0", "timestamp": time.Now().UTC(), } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + h.logger.Warn("failed to encode health response", "error", err) + } } // checkPermission checks if the user has the required permission @@ -272,11 +283,9 @@ func (h *Handler) checkPermission(user *auth.User, permission string) bool { } // Admin has all permissions - for _, role := range user.Roles { - if role == "admin" { + if slices.Contains(user.Roles, "admin") { return true } - } // Check specific permission for perm, hasPerm := range user.Permissions { diff --git a/internal/api/protocol.go b/internal/api/protocol.go index 1147969..a313c2e 100644 --- a/internal/api/protocol.go +++ b/internal/api/protocol.go @@ -8,6 +8,15 @@ import ( "time" ) +// safeUint64FromTime safely converts time.Time to uint64 timestamp +func safeUint64FromTime(t time.Time) uint64 { + unix := t.Unix() + if unix < 0 { + return 0 + } + return uint64(unix) +} + var bufferPool = sync.Pool{ New: func() interface{} { buf := make([]byte, 0, 256) @@ -91,7 +100,7 @@ type ResponsePacket struct { func NewSuccessPacket(message string) *ResponsePacket { return &ResponsePacket{ PacketType: PacketTypeSuccess, - Timestamp: uint64(time.Now().Unix()), + Timestamp: safeUint64FromTime(time.Now()), SuccessMessage: message, } } @@ -103,7 +112,7 @@ func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponseP return &ResponsePacket{ PacketType: PacketTypeData, - Timestamp: uint64(time.Now().Unix()), + Timestamp: safeUint64FromTime(time.Now()), SuccessMessage: message, DataType: "status", DataPayload: payloadBytes, @@ -114,7 +123,7 @@ func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponseP func NewErrorPacket(errorCode byte, message string, details string) *ResponsePacket { return &ResponsePacket{ PacketType: PacketTypeError, - Timestamp: uint64(time.Now().Unix()), + Timestamp: safeUint64FromTime(time.Now()), ErrorCode: errorCode, ErrorMessage: message, ErrorDetails: details, @@ -130,7 +139,7 @@ func NewProgressPacket( ) *ResponsePacket { return &ResponsePacket{ PacketType: PacketTypeProgress, - Timestamp: uint64(time.Now().Unix()), + Timestamp: safeUint64FromTime(time.Now()), ProgressType: progressType, ProgressValue: value, ProgressTotal: total, @@ -142,7 +151,7 @@ func NewProgressPacket( func NewStatusPacket(data string) *ResponsePacket { return &ResponsePacket{ PacketType: PacketTypeStatus, - Timestamp: uint64(time.Now().Unix()), + Timestamp: safeUint64FromTime(time.Now()), StatusData: data, } } @@ -151,7 +160,7 @@ func NewStatusPacket(data string) *ResponsePacket { func NewDataPacket(dataType string, payload []byte) *ResponsePacket { return &ResponsePacket{ PacketType: PacketTypeData, - Timestamp: uint64(time.Now().Unix()), + Timestamp: safeUint64FromTime(time.Now()), DataType: dataType, DataPayload: payload, } @@ -161,7 +170,7 @@ func NewDataPacket(dataType string, payload []byte) *ResponsePacket { func NewLogPacket(level byte, message string) *ResponsePacket { return &ResponsePacket{ PacketType: PacketTypeLog, - Timestamp: uint64(time.Now().Unix()), + Timestamp: safeUint64FromTime(time.Now()), LogLevel: level, LogMessage: message, } @@ -236,18 +245,38 @@ func serializePacketToBuffer(p *ResponsePacket, buf []byte) ([]byte, error) { return buf, nil } +// uint16ToBytes extracts high and low bytes from uint16 safely +func uint16ToBytes(v uint16) (high, low byte) { + var b [2]byte + binary.BigEndian.PutUint16(b[:], v) + return b[0], b[1] +} + // appendString writes a string with fixed 16-bit length prefix func appendString(buf []byte, s string) []byte { - length := uint16(len(s)) - buf = append(buf, byte(length>>8), byte(length)) + length := min(len(s), 65535) + // #nosec G115 -- length is bounded by min() to 65535, safe conversion + len16 := uint16(length) + high, low := uint16ToBytes(len16) + buf = append(buf, high, low) buf = append(buf, s...) return buf } +// uint32ToBytes extracts 4 bytes from uint32 safely +func uint32ToBytes(v uint32) [4]byte { + var b [4]byte + binary.BigEndian.PutUint32(b[:], v) + return b +} + // appendBytes writes bytes with fixed 32-bit length prefix func appendBytes(buf []byte, b []byte) []byte { - length := uint32(len(b)) - buf = append(buf, byte(length>>24), byte(length>>16), byte(length>>8), byte(length)) + length := min(len(b), 4294967295) + // #nosec G115 -- length is bounded by min() to max uint32, safe conversion + len32 := uint32(length) + bytes := uint32ToBytes(len32) + buf = append(buf, bytes[:]...) buf = append(buf, b...) return buf } diff --git a/internal/api/responses/errors.go b/internal/api/responses/errors.go index f729c94..782d055 100644 --- a/internal/api/responses/errors.go +++ b/internal/api/responses/errors.go @@ -84,7 +84,10 @@ func WriteError(w http.ResponseWriter, r *http.Request, status int, err error, l w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) - json.NewEncoder(w).Encode(resp) + if encodeErr := json.NewEncoder(w).Encode(resp); encodeErr != nil { + // Already wrote headers, can't do much about encoding errors + _ = encodeErr + } } // WriteErrorMessage writes a sanitized error response with a custom message. @@ -112,7 +115,10 @@ func WriteErrorMessage(w http.ResponseWriter, r *http.Request, status int, messa w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) - json.NewEncoder(w).Encode(resp) + if encodeErr := json.NewEncoder(w).Encode(resp); encodeErr != nil { + // Already wrote headers, can't do much about encoding errors + _ = encodeErr + } } // sanitizeError removes potentially sensitive information from error messages. @@ -137,6 +143,7 @@ func sanitizeError(msg string) string { msg = strings.ReplaceAll(msg, "internal error", "an error occurred") msg = strings.ReplaceAll(msg, "Internal Error", "an error occurred") + // TODO: This needs improvement, why is the length static? is there a better way to do this. // Truncate if too long if len(msg) > 200 { msg = msg[:200] + "..." diff --git a/internal/api/server_gen.go b/internal/api/server_gen.go index 274ae79..738d26f 100644 --- a/internal/api/server_gen.go +++ b/internal/api/server_gen.go @@ -137,10 +137,10 @@ type AuditEvent struct { Error *string `json:"error,omitempty"` // EventHash This event's hash - EventHash *string `json:"event_hash,omitempty"` - EventType *AuditEventEventType `json:"event_type,omitempty"` - IpAddress *string `json:"ip_address,omitempty"` - Metadata *map[string]interface{} `json:"metadata,omitempty"` + EventHash *string `json:"event_hash,omitempty"` + EventType *AuditEventEventType `json:"event_type,omitempty"` + IpAddress *string `json:"ip_address,omitempty"` + Metadata *map[string]any `json:"metadata,omitempty"` // PrevHash Previous event hash in chain PrevHash *string `json:"prev_hash,omitempty"` diff --git a/internal/api/spec_embed.go b/internal/api/spec_embed.go index fb74f2b..87249e8 100644 --- a/internal/api/spec_embed.go +++ b/internal/api/spec_embed.go @@ -19,6 +19,7 @@ func openAPISpecPath() string { // ServeOpenAPISpec serves the OpenAPI specification as YAML func ServeOpenAPISpec(w http.ResponseWriter, _ *http.Request) { specPath := openAPISpecPath() + // #nosec G304 -- specPath is a hardcoded relative path, not from user input data, err := os.ReadFile(specPath) if err != nil { http.Error(w, "Failed to read OpenAPI spec", http.StatusInternalServerError) diff --git a/internal/api/validate/handlers.go b/internal/api/validate/handlers.go index 7e70198..e57baf6 100644 --- a/internal/api/validate/handlers.go +++ b/internal/api/validate/handlers.go @@ -138,7 +138,7 @@ func (h *Handler) HandleGetValidateStatus(conn *websocket.Conn, validateID strin // Stub implementation - in production, would query validation status from database - return h.sendSuccessPacket(conn, map[string]interface{}{ + return h.sendSuccessPacket(conn, map[string]any{ "success": true, "validate_id": validateID, "status": "completed", @@ -152,10 +152,10 @@ func (h *Handler) HandleListValidations(conn *websocket.Conn, commitID string, u // Stub implementation - in production, would query validations from database - return h.sendSuccessPacket(conn, map[string]interface{}{ + return h.sendSuccessPacket(conn, map[string]any{ "success": true, "commit_id": commitID, - "validations": []map[string]interface{}{ + "validations": []map[string]any{ { "validate_id": "val-001", "status": "completed",