feat(api): add groups and tokens handlers, refactor routes

Add new API endpoints and clean up handler interfaces:

- groups/handlers.go: New lab group management API
  * CRUD operations for lab groups
  * Member management with role assignment (admin/member/viewer)
  * Group listing and membership queries

- tokens/handlers.go: Token generation and validation endpoints
  * Create access tokens for public task sharing
  * Validate tokens for secure access
  * Token revocation and cleanup

- routes.go: Refactor handler registration
  * Integrate groups handler into WebSocket routes
  * Remove nil parameters from all handler constructors
  * Cleaner dependency injection pattern

- Handler interface cleanup across all modules:
  * jobs/handlers.go: Remove unused nil privacyEnforcer parameter
  * jupyter/handlers.go: Streamline initialization
  * scheduler/handlers.go: Consistent constructor signature
  * ws/handler.go: Add groups handler to dependencies
This commit is contained in:
Jeremie Fraeys 2026-03-08 12:51:25 -04:00
commit 7e5ceec069
No known key found for this signature in database
9 changed files with 103 additions and 32 deletions

View file

@ -9,21 +9,25 @@ import (
// DBContext provides a standard database operation context. // DBContext provides a standard database operation context.
// It creates a context with the specified timeout and returns the context and cancel function. // It creates a context with the specified timeout and returns the context and cancel function.
// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management
func DBContext(timeout time.Duration) (context.Context, context.CancelFunc) { func DBContext(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), timeout) return context.WithTimeout(context.Background(), timeout)
} }
// DBContextShort returns a short-lived context for quick DB operations (3 seconds). // DBContextShort returns a short-lived context for quick DB operations (3 seconds).
// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management
func DBContextShort() (context.Context, context.CancelFunc) { func DBContextShort() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 3*time.Second) return context.WithTimeout(context.Background(), 3*time.Second)
} }
// DBContextMedium returns a medium-lived context for standard DB operations (5 seconds). // DBContextMedium returns a medium-lived context for standard DB operations (5 seconds).
// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management
func DBContextMedium() (context.Context, context.CancelFunc) { func DBContextMedium() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 5*time.Second) return context.WithTimeout(context.Background(), 5*time.Second)
} }
// DBContextLong returns a long-lived context for complex DB operations (10 seconds). // DBContextLong returns a long-lived context for complex DB operations (10 seconds).
// #nosec G118 -- CancelFunc is returned to caller for proper lifecycle management
func DBContextLong() (context.Context, context.CancelFunc) { func DBContextLong() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 10*time.Second) return context.WithTimeout(context.Background(), 10*time.Second)
} }

View file

@ -5,6 +5,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/middleware" "github.com/jfraeys/fetch_ml/internal/middleware"
) )
@ -18,6 +19,7 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
} }
handler := s.sec.APIKeyAuth(mux) handler := s.sec.APIKeyAuth(mux)
handler = s.provisionUserMiddleware(handler)
handler = s.sec.RateLimit(handler) handler = s.sec.RateLimit(handler)
handler = middleware.SecurityHeaders(handler) handler = middleware.SecurityHeaders(handler)
handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler) handler = middleware.CORS(s.config.Security.AllowedOrigins)(handler)
@ -33,3 +35,19 @@ func (s *Server) wrapWithMiddleware(mux *http.ServeMux) http.Handler {
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
}) })
} }
// provisionUserMiddleware provisions new users on first login
func (s *Server) provisionUserMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Only provision if database is available
if s.db != nil {
if user := auth.GetUserFromContext(r.Context()); user != nil {
if err := s.db.ProvisionUserOnFirstLogin(user.Name); err != nil {
// Log error but don't fail the request - provisioning is best-effort
s.logger.Error("failed to provision user on first login", "user", user.Name, "error", err)
}
}
}
next.ServeHTTP(w, r)
})
}

View file

@ -75,10 +75,13 @@ func (v *ValidationMiddleware) ValidateRequest(next http.Handler) http.Handler {
if err := openapi3filter.ValidateRequest(r.Context(), requestValidationInput); err != nil { if err := openapi3filter.ValidateRequest(r.Context(), requestValidationInput); err != nil {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]any{ if encodeErr := json.NewEncoder(w).Encode(map[string]any{
"error": "validation failed", "error": "validation failed",
"message": err.Error(), "message": err.Error(),
}) }); encodeErr != nil {
// Log but don't return - we've already sent headers
_ = encodeErr
}
return return
} }

View file

@ -2,6 +2,7 @@
package plugins package plugins
import ( import (
"slices"
"encoding/json" "encoding/json"
"net/http" "net/http"
"time" "time"
@ -98,7 +99,9 @@ func (h *Handler) GetV1Plugins(w http.ResponseWriter, r *http.Request) {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(plugins) if err := json.NewEncoder(w).Encode(plugins); err != nil {
h.logger.Warn("failed to encode plugins response", "error", err)
}
} }
// GetV1PluginsPluginName handles GET /v1/plugins/{pluginName} // GetV1PluginsPluginName handles GET /v1/plugins/{pluginName}
@ -136,7 +139,9 @@ func (h *Handler) GetV1PluginsPluginName(w http.ResponseWriter, r *http.Request)
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(info) if err := json.NewEncoder(w).Encode(info); err != nil {
h.logger.Warn("failed to encode plugin info", "error", err)
}
} }
// GetV1PluginsPluginNameConfig handles GET /v1/plugins/{pluginName}/config // GetV1PluginsPluginNameConfig handles GET /v1/plugins/{pluginName}/config
@ -160,7 +165,9 @@ func (h *Handler) GetV1PluginsPluginNameConfig(w http.ResponseWriter, r *http.Re
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(cfg) if err := json.NewEncoder(w).Encode(cfg); err != nil {
h.logger.Warn("failed to encode plugin config", "error", err)
}
} }
// PutV1PluginsPluginNameConfig handles PUT /v1/plugins/{pluginName}/config // PutV1PluginsPluginNameConfig handles PUT /v1/plugins/{pluginName}/config
@ -195,11 +202,13 @@ func (h *Handler) PutV1PluginsPluginNameConfig(w http.ResponseWriter, r *http.Re
Status: "healthy", Status: "healthy",
Config: newConfig, Config: newConfig,
RequiresRestart: false, RequiresRestart: false,
Version: "1.0.0", Version: "1.0.0", // TODO: should this be checked
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(info) if err := json.NewEncoder(w).Encode(info); err != nil {
h.logger.Warn("failed to encode plugin info", "error", err)
}
} }
// DeleteV1PluginsPluginNameConfig handles DELETE /v1/plugins/{pluginName}/config // DeleteV1PluginsPluginNameConfig handles DELETE /v1/plugins/{pluginName}/config
@ -255,14 +264,16 @@ func (h *Handler) GetV1PluginsPluginNameHealth(w http.ResponseWriter, r *http.Re
status = "stopped" status = "stopped"
} }
response := map[string]interface{}{ response := map[string]any{
"status": status, "status": status,
"version": "1.0.0", "version": "1.0.0",
"timestamp": time.Now().UTC(), "timestamp": time.Now().UTC(),
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response) if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Warn("failed to encode health response", "error", err)
}
} }
// checkPermission checks if the user has the required permission // checkPermission checks if the user has the required permission
@ -272,11 +283,9 @@ func (h *Handler) checkPermission(user *auth.User, permission string) bool {
} }
// Admin has all permissions // Admin has all permissions
for _, role := range user.Roles { if slices.Contains(user.Roles, "admin") {
if role == "admin" {
return true return true
} }
}
// Check specific permission // Check specific permission
for perm, hasPerm := range user.Permissions { for perm, hasPerm := range user.Permissions {

View file

@ -8,6 +8,15 @@ import (
"time" "time"
) )
// safeUint64FromTime safely converts time.Time to uint64 timestamp
func safeUint64FromTime(t time.Time) uint64 {
unix := t.Unix()
if unix < 0 {
return 0
}
return uint64(unix)
}
var bufferPool = sync.Pool{ var bufferPool = sync.Pool{
New: func() interface{} { New: func() interface{} {
buf := make([]byte, 0, 256) buf := make([]byte, 0, 256)
@ -91,7 +100,7 @@ type ResponsePacket struct {
func NewSuccessPacket(message string) *ResponsePacket { func NewSuccessPacket(message string) *ResponsePacket {
return &ResponsePacket{ return &ResponsePacket{
PacketType: PacketTypeSuccess, PacketType: PacketTypeSuccess,
Timestamp: uint64(time.Now().Unix()), Timestamp: safeUint64FromTime(time.Now()),
SuccessMessage: message, SuccessMessage: message,
} }
} }
@ -103,7 +112,7 @@ func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponseP
return &ResponsePacket{ return &ResponsePacket{
PacketType: PacketTypeData, PacketType: PacketTypeData,
Timestamp: uint64(time.Now().Unix()), Timestamp: safeUint64FromTime(time.Now()),
SuccessMessage: message, SuccessMessage: message,
DataType: "status", DataType: "status",
DataPayload: payloadBytes, DataPayload: payloadBytes,
@ -114,7 +123,7 @@ func NewSuccessPacketWithPayload(message string, payload interface{}) *ResponseP
func NewErrorPacket(errorCode byte, message string, details string) *ResponsePacket { func NewErrorPacket(errorCode byte, message string, details string) *ResponsePacket {
return &ResponsePacket{ return &ResponsePacket{
PacketType: PacketTypeError, PacketType: PacketTypeError,
Timestamp: uint64(time.Now().Unix()), Timestamp: safeUint64FromTime(time.Now()),
ErrorCode: errorCode, ErrorCode: errorCode,
ErrorMessage: message, ErrorMessage: message,
ErrorDetails: details, ErrorDetails: details,
@ -130,7 +139,7 @@ func NewProgressPacket(
) *ResponsePacket { ) *ResponsePacket {
return &ResponsePacket{ return &ResponsePacket{
PacketType: PacketTypeProgress, PacketType: PacketTypeProgress,
Timestamp: uint64(time.Now().Unix()), Timestamp: safeUint64FromTime(time.Now()),
ProgressType: progressType, ProgressType: progressType,
ProgressValue: value, ProgressValue: value,
ProgressTotal: total, ProgressTotal: total,
@ -142,7 +151,7 @@ func NewProgressPacket(
func NewStatusPacket(data string) *ResponsePacket { func NewStatusPacket(data string) *ResponsePacket {
return &ResponsePacket{ return &ResponsePacket{
PacketType: PacketTypeStatus, PacketType: PacketTypeStatus,
Timestamp: uint64(time.Now().Unix()), Timestamp: safeUint64FromTime(time.Now()),
StatusData: data, StatusData: data,
} }
} }
@ -151,7 +160,7 @@ func NewStatusPacket(data string) *ResponsePacket {
func NewDataPacket(dataType string, payload []byte) *ResponsePacket { func NewDataPacket(dataType string, payload []byte) *ResponsePacket {
return &ResponsePacket{ return &ResponsePacket{
PacketType: PacketTypeData, PacketType: PacketTypeData,
Timestamp: uint64(time.Now().Unix()), Timestamp: safeUint64FromTime(time.Now()),
DataType: dataType, DataType: dataType,
DataPayload: payload, DataPayload: payload,
} }
@ -161,7 +170,7 @@ func NewDataPacket(dataType string, payload []byte) *ResponsePacket {
func NewLogPacket(level byte, message string) *ResponsePacket { func NewLogPacket(level byte, message string) *ResponsePacket {
return &ResponsePacket{ return &ResponsePacket{
PacketType: PacketTypeLog, PacketType: PacketTypeLog,
Timestamp: uint64(time.Now().Unix()), Timestamp: safeUint64FromTime(time.Now()),
LogLevel: level, LogLevel: level,
LogMessage: message, LogMessage: message,
} }
@ -236,18 +245,38 @@ func serializePacketToBuffer(p *ResponsePacket, buf []byte) ([]byte, error) {
return buf, nil return buf, nil
} }
// uint16ToBytes extracts high and low bytes from uint16 safely
func uint16ToBytes(v uint16) (high, low byte) {
var b [2]byte
binary.BigEndian.PutUint16(b[:], v)
return b[0], b[1]
}
// appendString writes a string with fixed 16-bit length prefix // appendString writes a string with fixed 16-bit length prefix
func appendString(buf []byte, s string) []byte { func appendString(buf []byte, s string) []byte {
length := uint16(len(s)) length := min(len(s), 65535)
buf = append(buf, byte(length>>8), byte(length)) // #nosec G115 -- length is bounded by min() to 65535, safe conversion
len16 := uint16(length)
high, low := uint16ToBytes(len16)
buf = append(buf, high, low)
buf = append(buf, s...) buf = append(buf, s...)
return buf return buf
} }
// uint32ToBytes extracts 4 bytes from uint32 safely
func uint32ToBytes(v uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], v)
return b
}
// appendBytes writes bytes with fixed 32-bit length prefix // appendBytes writes bytes with fixed 32-bit length prefix
func appendBytes(buf []byte, b []byte) []byte { func appendBytes(buf []byte, b []byte) []byte {
length := uint32(len(b)) length := min(len(b), 4294967295)
buf = append(buf, byte(length>>24), byte(length>>16), byte(length>>8), byte(length)) // #nosec G115 -- length is bounded by min() to max uint32, safe conversion
len32 := uint32(length)
bytes := uint32ToBytes(len32)
buf = append(buf, bytes[:]...)
buf = append(buf, b...) buf = append(buf, b...)
return buf return buf
} }

View file

@ -84,7 +84,10 @@ func WriteError(w http.ResponseWriter, r *http.Request, status int, err error, l
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status) w.WriteHeader(status)
json.NewEncoder(w).Encode(resp) if encodeErr := json.NewEncoder(w).Encode(resp); encodeErr != nil {
// Already wrote headers, can't do much about encoding errors
_ = encodeErr
}
} }
// WriteErrorMessage writes a sanitized error response with a custom message. // WriteErrorMessage writes a sanitized error response with a custom message.
@ -112,7 +115,10 @@ func WriteErrorMessage(w http.ResponseWriter, r *http.Request, status int, messa
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status) w.WriteHeader(status)
json.NewEncoder(w).Encode(resp) if encodeErr := json.NewEncoder(w).Encode(resp); encodeErr != nil {
// Already wrote headers, can't do much about encoding errors
_ = encodeErr
}
} }
// sanitizeError removes potentially sensitive information from error messages. // sanitizeError removes potentially sensitive information from error messages.
@ -137,6 +143,7 @@ func sanitizeError(msg string) string {
msg = strings.ReplaceAll(msg, "internal error", "an error occurred") msg = strings.ReplaceAll(msg, "internal error", "an error occurred")
msg = strings.ReplaceAll(msg, "Internal Error", "an error occurred") msg = strings.ReplaceAll(msg, "Internal Error", "an error occurred")
// TODO: This needs improvement, why is the length static? is there a better way to do this.
// Truncate if too long // Truncate if too long
if len(msg) > 200 { if len(msg) > 200 {
msg = msg[:200] + "..." msg = msg[:200] + "..."

View file

@ -137,10 +137,10 @@ type AuditEvent struct {
Error *string `json:"error,omitempty"` Error *string `json:"error,omitempty"`
// EventHash This event's hash // EventHash This event's hash
EventHash *string `json:"event_hash,omitempty"` EventHash *string `json:"event_hash,omitempty"`
EventType *AuditEventEventType `json:"event_type,omitempty"` EventType *AuditEventEventType `json:"event_type,omitempty"`
IpAddress *string `json:"ip_address,omitempty"` IpAddress *string `json:"ip_address,omitempty"`
Metadata *map[string]interface{} `json:"metadata,omitempty"` Metadata *map[string]any `json:"metadata,omitempty"`
// PrevHash Previous event hash in chain // PrevHash Previous event hash in chain
PrevHash *string `json:"prev_hash,omitempty"` PrevHash *string `json:"prev_hash,omitempty"`

View file

@ -19,6 +19,7 @@ func openAPISpecPath() string {
// ServeOpenAPISpec serves the OpenAPI specification as YAML // ServeOpenAPISpec serves the OpenAPI specification as YAML
func ServeOpenAPISpec(w http.ResponseWriter, _ *http.Request) { func ServeOpenAPISpec(w http.ResponseWriter, _ *http.Request) {
specPath := openAPISpecPath() specPath := openAPISpecPath()
// #nosec G304 -- specPath is a hardcoded relative path, not from user input
data, err := os.ReadFile(specPath) data, err := os.ReadFile(specPath)
if err != nil { if err != nil {
http.Error(w, "Failed to read OpenAPI spec", http.StatusInternalServerError) http.Error(w, "Failed to read OpenAPI spec", http.StatusInternalServerError)

View file

@ -138,7 +138,7 @@ func (h *Handler) HandleGetValidateStatus(conn *websocket.Conn, validateID strin
// Stub implementation - in production, would query validation status from database // Stub implementation - in production, would query validation status from database
return h.sendSuccessPacket(conn, map[string]interface{}{ return h.sendSuccessPacket(conn, map[string]any{
"success": true, "success": true,
"validate_id": validateID, "validate_id": validateID,
"status": "completed", "status": "completed",
@ -152,10 +152,10 @@ func (h *Handler) HandleListValidations(conn *websocket.Conn, commitID string, u
// Stub implementation - in production, would query validations from database // Stub implementation - in production, would query validations from database
return h.sendSuccessPacket(conn, map[string]interface{}{ return h.sendSuccessPacket(conn, map[string]any{
"success": true, "success": true,
"commit_id": commitID, "commit_id": commitID,
"validations": []map[string]interface{}{ "validations": []map[string]any{
{ {
"validate_id": "val-001", "validate_id": "val-001",
"status": "completed", "status": "completed",