feat(api): add groups and tokens handlers, refactor routes

Add new API endpoints and clean up handler interfaces:

- groups/handlers.go: New lab group management API
  * CRUD operations for lab groups
  * Member management with role assignment (admin/member/viewer)
  * Group listing and membership queries

- tokens/handlers.go: Token generation and validation endpoints
  * Create access tokens for public task sharing
  * Validate tokens for secure access
  * Token revocation and cleanup

- routes.go: Refactor handler registration
  * Integrate groups handler into WebSocket routes
  * Remove nil parameters from all handler constructors
  * Cleaner dependency injection pattern

- Handler interface cleanup across all modules:
  * jobs/handlers.go: Remove unused nil privacyEnforcer parameter
  * jupyter/handlers.go: Streamline initialization
  * scheduler/handlers.go: Consistent constructor signature
  * ws/handler.go: Add groups handler to dependencies
This commit is contained in:
Jeremie Fraeys 2026-03-08 12:51:25 -04:00
parent c52179dcbe
commit 7e5ceec069
No known key found for this signature in database
9 changed files with 103 additions and 32 deletions

View file

@ -9,21 +9,25 @@ import (
// DBContext provides a standard database operation context.
// It creates a context with the specified timeout and returns the context and cancel function.
// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management
func DBContext(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), timeout)
}
// DBContextShort returns a short-lived context for quick DB operations (3 seconds).
// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management
func DBContextShort() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 3*time.Second)
}
// DBContextMedium returns a medium-lived context for standard DB operations (5 seconds).
// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management
func DBContextMedium() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 5*time.Second)
}
// DBContextLong returns a long-lived context for complex DB operations (10 seconds).
// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management
func DBContextLong() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 10*time.Second)
}

View file

@ -5,6 +5,7 @@ import (
"strings"
"time"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/middleware"
)
@ -18,6 +19,7 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
}
handler := s.sec.APIKeyAuth(mux)
handler = s.provisionUserMiddleware(handler)
handler = s.sec.RateLimit(handler)
handler = middleware.SecurityHeaders(handler)
handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler)
@ -33,3 +35,19 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
handler.ServeHTTP(w, r)
})
}
// provisionUserMiddleware provisions new users on first login
func (s *Server) provisionUserMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Only provision if database is available
if s.db != nil {
if user := auth.GetUserFromContext(r.Context()); user != nil {
if err := s.db.ProvisionUserOnFirstLogin(user.Name); err != nil {
// Log error but don't fail the request - provisioning is best-effort
s.logger.Error("failed to provision user on first login", "user", user.Name, "error", err)
}
}
}
next.ServeHTTP(w, r)
})
}

View file

@ -75,10 +75,13 @@ func (v *ValidationMiddleware) ValidateRequest(next http.Handler) http.Handler {
if err := openapi3filter.ValidateRequest(r.Context(), requestValidationInput); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]any{
if encodeErr := json.NewEncoder(w).Encode(map[string]any{
"error": "validation failed",
"message": err.Error(),
})
}); encodeErr != nil {
// Log but don't return - we've already sent headers
_ = encodeErr
}
return
}

View file

@ -2,6 +2,7 @@
package plugins
import (
"slices"
"encoding/json"
"net/http"
"time"
@ -98,7 +99,9 @@ func (h *Handler) GetV1Plugins(w http.ResponseWriter, r *http.Request) {
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(plugins)
if err := json.NewEncoder(w).Encode(plugins); err != nil {
h.logger.Warn("failed to encode plugins response", "error", err)
}
}
// GetV1PluginsPluginName handles GET /v1/plugins/{pluginName}
@ -136,7 +139,9 @@ func (h *Handler) GetV1PluginsPluginName(w http.ResponseWriter, r *http.Request)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(info)
if err := json.NewEncoder(w).Encode(info); err != nil {
h.logger.Warn("failed to encode plugin info", "error", err)
}
}
// GetV1PluginsPluginNameConfig handles GET /v1/plugins/{pluginName}/config
@ -160,7 +165,9 @@ func (h *Handler) GetV1PluginsPluginNameConfig(w http.ResponseWriter, r *http.Re
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(cfg)
if err := json.NewEncoder(w).Encode(cfg); err != nil {
h.logger.Warn("failed to encode plugin config", "error", err)
}
}
// PutV1PluginsPluginNameConfig handles PUT /v1/plugins/{pluginName}/config
@ -195,11 +202,13 @@ func (h *Handler) PutV1PluginsPluginNameConfig(w http.ResponseWriter, r *http.Re
Status: "healthy",
Config: newConfig,
RequiresRestart: false,
Version: "1.0.0",
Version: "1.0.0", // TODO: should this be checked
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(info)
if err := json.NewEncoder(w).Encode(info); err != nil {
h.logger.Warn("failed to encode plugin info", "error", err)
}
}
// DeleteV1PluginsPluginNameConfig handles DELETE /v1/plugins/{pluginName}/config
@ -255,14 +264,16 @@ func (h *Handler) GetV1PluginsPluginNameHealth(w http.ResponseWriter, r *http.Re
status = "stopped"
}
response := map[string]interface{}{
response := map[string]any{
"status": status,
"version": "1.0.0",
"timestamp": time.Now().UTC(),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Warn("failed to encode health response", "error", err)
}
}
// checkPermission checks if the user has the required permission
@ -272,11 +283,9 @@ func (h *Handler) checkPermission(user *auth.User, permission string) bool {
}
// Admin has all permissions
for _, role := range user.Roles {
if role == "admin" {
if slices.Contains(user.Roles, "admin") {
return true
}
}
// Check specific permission
for perm, hasPerm := range user.Permissions {

View file

@ -8,6 +8,15 @@ import (
"time"
)
// safeUint64FromTime safely converts time.Time to uint64 timestamp
func safeUint64FromTime(t time.Time) uint64 {
unix := t.Unix()
if unix < 0 {
return 0
}
return uint64(unix)
}
var bufferPool = sync.Pool{
New: func() interface{} {
buf := make([]byte, 0, 256)
@ -91,7 +100,7 @@ type ResponsePacket struct {
func NewSuccessPacket(message string) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeSuccess,
Timestamp: uint64(time.Now().Unix()),
Timestamp: safeUint64FromTime(time.Now()),
SuccessMessage: message,
}
}
@ -103,7 +112,7 @@ func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponseP
return &ResponsePacket{
PacketType: PacketTypeData,
Timestamp: uint64(time.Now().Unix()),
Timestamp: safeUint64FromTime(time.Now()),
SuccessMessage: message,
DataType: "status",
DataPayload: payloadBytes,
@ -114,7 +123,7 @@ func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponseP
func NewErrorPacket(errorCode byte, message string, details string) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeError,
Timestamp: uint64(time.Now().Unix()),
Timestamp: safeUint64FromTime(time.Now()),
ErrorCode: errorCode,
ErrorMessage: message,
ErrorDetails: details,
@ -130,7 +139,7 @@ func NewProgressPacket(
) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeProgress,
Timestamp: uint64(time.Now().Unix()),
Timestamp: safeUint64FromTime(time.Now()),
ProgressType: progressType,
ProgressValue: value,
ProgressTotal: total,
@ -142,7 +151,7 @@ func NewProgressPacket(
func NewStatusPacket(data string) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeStatus,
Timestamp: uint64(time.Now().Unix()),
Timestamp: safeUint64FromTime(time.Now()),
StatusData: data,
}
}
@ -151,7 +160,7 @@ func NewStatusPacket(data string) *ResponsePacket {
func NewDataPacket(dataType string, payload []byte) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeData,
Timestamp: uint64(time.Now().Unix()),
Timestamp: safeUint64FromTime(time.Now()),
DataType: dataType,
DataPayload: payload,
}
@ -161,7 +170,7 @@ func NewDataPacket(dataType string, payload []byte) *ResponsePacket {
func NewLogPacket(level byte, message string) *ResponsePacket {
return &ResponsePacket{
PacketType: PacketTypeLog,
Timestamp: uint64(time.Now().Unix()),
Timestamp: safeUint64FromTime(time.Now()),
LogLevel: level,
LogMessage: message,
}
@ -236,18 +245,38 @@ func serializePacketToBuffer(p *ResponsePacket, buf []byte) ([]byte, error) {
return buf, nil
}
// uint16ToBytes extracts high and low bytes from uint16 safely
func uint16ToBytes(v uint16) (high, low byte) {
var b [2]byte
binary.BigEndian.PutUint16(b[:], v)
return b[0], b[1]
}
// appendString writes a string with fixed 16-bit length prefix
func appendString(buf []byte, s string) []byte {
length := uint16(len(s))
buf = append(buf, byte(length>>8), byte(length))
length := min(len(s), 65535)
// #nosec G115 -- length is bounded by min() to 65535, safe conversion
len16 := uint16(length)
high, low := uint16ToBytes(len16)
buf = append(buf, high, low)
buf = append(buf, s...)
return buf
}
// uint32ToBytes extracts 4 bytes from uint32 safely
func uint32ToBytes(v uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], v)
return b
}
// appendBytes writes bytes with fixed 32-bit length prefix
func appendBytes(buf []byte, b []byte) []byte {
length := uint32(len(b))
buf = append(buf, byte(length>>24), byte(length>>16), byte(length>>8), byte(length))
length := min(len(b), 4294967295)
// #nosec G115 -- length is bounded by min() to max uint32, safe conversion
len32 := uint32(length)
bytes := uint32ToBytes(len32)
buf = append(buf, bytes[:]...)
buf = append(buf, b...)
return buf
}

View file

@ -84,7 +84,10 @@ func WriteError(w http.ResponseWriter, r *http.Request, status int, err error, l
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(resp)
if encodeErr := json.NewEncoder(w).Encode(resp); encodeErr != nil {
// Already wrote headers, can't do much about encoding errors
_ = encodeErr
}
}
// WriteErrorMessage writes a sanitized error response with a custom message.
@ -112,7 +115,10 @@ func WriteErrorMessage(w http.ResponseWriter, r *http.Request, status int, messa
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(resp)
if encodeErr := json.NewEncoder(w).Encode(resp); encodeErr != nil {
// Already wrote headers, can't do much about encoding errors
_ = encodeErr
}
}
// sanitizeError removes potentially sensitive information from error messages.
@ -137,6 +143,7 @@ func sanitizeError(msg string) string {
msg = strings.ReplaceAll(msg, "internal error", "an error occurred")
msg = strings.ReplaceAll(msg, "Internal Error", "an error occurred")
// TODO: This needs improvement, why is the length static? is there a better way to do this.
// Truncate if too long
if len(msg) > 200 {
msg = msg[:200] + "..."

View file

@ -137,10 +137,10 @@ type AuditEvent struct {
Error *string `json:"error,omitempty"`
// EventHash This event's hash
EventHash *string `json:"event_hash,omitempty"`
EventType *AuditEventEventType `json:"event_type,omitempty"`
IpAddress *string `json:"ip_address,omitempty"`
Metadata *map[string]interface{} `json:"metadata,omitempty"`
EventHash *string `json:"event_hash,omitempty"`
EventType *AuditEventEventType `json:"event_type,omitempty"`
IpAddress *string `json:"ip_address,omitempty"`
Metadata *map[string]any `json:"metadata,omitempty"`
// PrevHash Previous event hash in chain
PrevHash *string `json:"prev_hash,omitempty"`

View file

@ -19,6 +19,7 @@ func openAPISpecPath() string {
// ServeOpenAPISpec serves the OpenAPI specification as YAML
func ServeOpenAPISpec(w http.ResponseWriter, _ *http.Request) {
specPath := openAPISpecPath()
// #nosec G304 -- specPath is a hardcoded relative path, not from user input
data, err := os.ReadFile(specPath)
if err != nil {
http.Error(w, "Failed to read OpenAPI spec", http.StatusInternalServerError)

View file

@ -138,7 +138,7 @@ func (h *Handler) HandleGetValidateStatus(conn *websocket.Conn, validateID strin
// Stub implementation - in production, would query validation status from database
return h.sendSuccessPacket(conn, map[string]interface{}{
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"validate_id": validateID,
"status": "completed",
@ -152,10 +152,10 @@ func (h *Handler) HandleListValidations(conn *websocket.Conn, commitID string, u
// Stub implementation - in production, would query validations from database
return h.sendSuccessPacket(conn, map[string]interface{}{
return h.sendSuccessPacket(conn, map[string]any{
"success": true,
"commit_id": commitID,
"validations": []map[string]interface{}{
"validations": []map[string]any{
{
"validate_id": "val-001",
"status": "completed",