feat(auth): add token-based access and structured logging
Add comprehensive authentication and authorization enhancements: - tokens.go: New token management system for public task access and cloning * SHA-256 hashed token storage for security * Token generation, validation, and automatic cleanup * Support for public access and clone permissions - api_key.go: Extend User struct with Groups field * Lab group membership (ml-lab, nlp-group) * Integration with permission system for group-based access - flags.go: Security hardening - migrate to structured logging * Replace log.Printf with log/slog to prevent log injection attacks * Consistent structured output for all auth warnings * Safe handling of file paths and errors in logs - permissions.go: Add task sharing permission constants * PermissionTasksReadOwn: Access own tasks * PermissionTasksReadLab: Access lab group tasks * PermissionTasksReadAll: Admin/institution-wide access * PermissionTasksShare: Grant access to other users * PermissionTasksClone: Create copies of shared tasks * CanAccessTask() method with visibility checks - database.go: Improve error handling * Add structured error logging on row close failures
This commit is contained in:
parent
fbcf4d38e5
commit
c52179dcbe
14 changed files with 1199 additions and 60 deletions
|
|
@ -104,11 +104,7 @@ func (a *HandlerAdapter) DeleteV1JupyterServicesServiceId(ctx echo.Context, serv
|
|||
})
|
||||
}
|
||||
// TODO: Implement when StopServiceHTTP is available
|
||||
return ctx.JSON(501, map[string]any{
|
||||
"error": "Not implemented",
|
||||
"code": "NOT_IMPLEMENTED",
|
||||
"message": "Jupyter service stop not yet implemented via REST API",
|
||||
})
|
||||
return toHTTPHandler(a.jupyterHandler.StopServiceHTTP)(ctx)
|
||||
}
|
||||
|
||||
// GetV1Queue returns queue status
|
||||
|
|
|
|||
|
|
@ -44,6 +44,11 @@ func (s *Server) initializeComponents() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Initialize default lab group (if configured)
|
||||
if err := s.initDefaultLabGroup(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize security
|
||||
s.initSecurity()
|
||||
|
||||
|
|
@ -185,6 +190,31 @@ func (s *Server) initDatabaseSchema() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// initDefaultLabGroup creates the auto-provisioned default lab group if configured.
|
||||
// Reads DEFAULT_LAB_GROUP env var and creates the group if it doesn't exist.
|
||||
// If DEFAULT_LAB_GROUP is not set, this is a silent no-op (intentional - groups are optional).
|
||||
func (s *Server) initDefaultLabGroup() error {
|
||||
if s.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
groupName := os.Getenv("DEFAULT_LAB_GROUP")
|
||||
if groupName == "" {
|
||||
return nil // No default lab group configured
|
||||
}
|
||||
|
||||
groupID, err := s.db.GetOrCreateDefaultLabGroup("system")
|
||||
if err != nil {
|
||||
s.logger.Error("failed to initialize default lab group", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if groupID != "" {
|
||||
s.logger.Info("default lab group initialized", "group", groupName, "id", groupID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// initSecurity initializes security middleware
|
||||
func (s *Server) initSecurity() {
|
||||
authConfig := s.config.BuildAuthConfig()
|
||||
|
|
|
|||
692
internal/api/groups/handlers.go
Normal file
692
internal/api/groups/handlers.go
Normal file
|
|
@ -0,0 +1,692 @@
|
|||
// Package groups provides HTTP handlers for lab group management
|
||||
package groups
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
||||
// Handler provides group management HTTP handlers
|
||||
type Handler struct {
|
||||
db *storage.DB
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewHandler creates a new groups handler
|
||||
func NewHandler(db *storage.DB, logger *logging.Logger) *Handler {
|
||||
return &Handler{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateGroupRequest represents a request to create a new group
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// CreateGroup handles POST /api/groups
|
||||
func (h *Handler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Only admins can create groups
|
||||
if !user.Admin {
|
||||
http.Error(w, "forbidden: admin required", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateGroupRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
http.Error(w, "name is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.db.CreateGroup(req.Name, req.Description, user.Name)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to create group", "error", err)
|
||||
http.Error(w, "failed to create group", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("group created", "group_id", group.ID, "name", group.Name, "created_by", user.Name)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
if err := json.NewEncoder(w).Encode(group); err != nil {
|
||||
h.logger.Error("failed to encode response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ListGroups handles GET /api/groups
|
||||
func (h *Handler) ListGroups(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.db.ListGroupsForUser(user.Name)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to list groups", "error", err)
|
||||
http.Error(w, "failed to list groups", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"groups": groups,
|
||||
"count": len(groups),
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to encode response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateInvitation handles POST /api/groups/{id}/invitations
|
||||
func (h *Handler) CreateInvitation(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
groupID := r.PathValue("id")
|
||||
if groupID == "" {
|
||||
http.Error(w, "group id required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is group admin
|
||||
isAdmin, err := h.db.IsGroupAdmin(user.Name, groupID)
|
||||
if err != nil || !isAdmin {
|
||||
http.Error(w, "forbidden: group admin required", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
InvitedUserID string `json:"invited_user_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.InvitedUserID == "" {
|
||||
http.Error(w, "invited_user_id is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
invitation, err := h.db.CreateGroupInvitation(groupID, req.InvitedUserID, user.Name)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to create invitation", "error", err)
|
||||
http.Error(w, "failed to create invitation", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("invitation created", "invitation_id", invitation.ID, "group_id", groupID, "invited_user", req.InvitedUserID)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
if err := json.NewEncoder(w).Encode(invitation); err != nil {
|
||||
h.logger.Error("failed to encode response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ListInvitations handles GET /api/invitations
|
||||
func (h *Handler) ListInvitations(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
invitations, err := h.db.ListPendingInvitationsForUser(user.Name)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to list invitations", "error", err)
|
||||
http.Error(w, "failed to list invitations", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"invitations": invitations,
|
||||
"count": len(invitations),
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to encode response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// AcceptInvitation handles POST /api/invitations/{id}/accept
|
||||
func (h *Handler) AcceptInvitation(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
invitationID := r.PathValue("id")
|
||||
if invitationID == "" {
|
||||
http.Error(w, "invitation id required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify invitation belongs to this user and is pending
|
||||
invitation, err := h.db.GetInvitation(invitationID)
|
||||
if err != nil {
|
||||
http.Error(w, "invitation not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if invitation.InvitedUserID != user.Name {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if invitation.Status != "pending" {
|
||||
http.Error(w, "invitation already processed", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if expired (7 days default)
|
||||
if invitation.ExpiresAt != nil && time.Now().After(*invitation.ExpiresAt) {
|
||||
http.Error(w, "invitation expired", http.StatusGone)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.AcceptInvitation(invitationID, user.Name); err != nil {
|
||||
h.logger.Error("failed to accept invitation", "error", err)
|
||||
http.Error(w, "failed to accept invitation", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("invitation accepted", "invitation_id", invitationID, "user", user.Name, "group_id", invitation.GroupID)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "accepted",
|
||||
"group_id": invitation.GroupID,
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to encode response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// DeclineInvitation handles POST /api/invitations/{id}/decline
|
||||
func (h *Handler) DeclineInvitation(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
invitationID := r.PathValue("id")
|
||||
if invitationID == "" {
|
||||
http.Error(w, "invitation id required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
invitation, err := h.db.GetInvitation(invitationID)
|
||||
if err != nil {
|
||||
http.Error(w, "invitation not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if invitation.InvitedUserID != user.Name {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeclineInvitation(invitationID, user.Name); err != nil {
|
||||
h.logger.Error("failed to decline invitation", "error", err)
|
||||
http.Error(w, "failed to decline invitation", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("invitation declined", "invitation_id", invitationID, "user", user.Name)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "declined",
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to encode response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveMember handles DELETE /api/groups/{id}/members/{user}
|
||||
func (h *Handler) RemoveMember(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
groupID := r.PathValue("id")
|
||||
memberID := r.PathValue("user")
|
||||
if groupID == "" || memberID == "" {
|
||||
http.Error(w, "group id and user id required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is group admin
|
||||
isAdmin, err := h.db.IsGroupAdmin(user.Name, groupID)
|
||||
if err != nil || !isAdmin {
|
||||
http.Error(w, "forbidden: group admin required", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Cannot remove yourself (use leave group endpoint instead)
|
||||
if memberID == user.Name {
|
||||
http.Error(w, "cannot remove yourself; use leave group endpoint", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.RemoveGroupMember(groupID, memberID); err != nil {
|
||||
h.logger.Error("failed to remove member", "error", err)
|
||||
http.Error(w, "failed to remove member", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("member removed", "group_id", groupID, "member", memberID, "removed_by", user.Name)
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ListGroupTasks handles GET /api/groups/{id}/tasks
|
||||
func (h *Handler) ListGroupTasks(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
groupID := r.PathValue("id")
|
||||
if groupID == "" {
|
||||
http.Error(w, "group id required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is member of group
|
||||
isMember, err := h.db.IsGroupMember(user.Name, groupID)
|
||||
if err != nil || !isMember {
|
||||
http.Error(w, "forbidden: group membership required", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse pagination options
|
||||
limit := 100
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 && parsed <= 100 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
cursor := r.URL.Query().Get("cursor")
|
||||
|
||||
opts := storage.ListTasksOptions{
|
||||
Limit: limit,
|
||||
Cursor: cursor,
|
||||
}
|
||||
|
||||
tasks, nextCursor, err := h.db.ListTasksForGroup(groupID, opts)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to list group tasks", "error", err)
|
||||
http.Error(w, "failed to list tasks", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"tasks": tasks,
|
||||
"next_cursor": nextCursor,
|
||||
"count": len(tasks),
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to encode response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// sendErrorPacket sends an error response packet to the client
|
||||
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
|
||||
err := map[string]interface{}{
|
||||
"error": true,
|
||||
"code": code,
|
||||
"message": message,
|
||||
"details": details,
|
||||
}
|
||||
return conn.WriteJSON(err)
|
||||
}
|
||||
|
||||
// sendSuccessPacket sends a success response packet
|
||||
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error {
|
||||
return conn.WriteJSON(data)
|
||||
}
|
||||
|
||||
// HandleCreateGroup handles WebSocket group creation.
|
||||
// Protocol: [api_key_hash:16][name_len:1][name:var][desc_len:2][desc:var]
|
||||
func (h *Handler) HandleCreateGroup(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+1+2 {
|
||||
return h.sendErrorPacket(conn, 0x01, "payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
nameLen := int(payload[offset])
|
||||
offset++
|
||||
if nameLen <= 0 || len(payload) < offset+nameLen+2 {
|
||||
return h.sendErrorPacket(conn, 0x01, "invalid name length", "")
|
||||
}
|
||||
name := string(payload[offset : offset+nameLen])
|
||||
offset += nameLen
|
||||
|
||||
descLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if descLen < 0 || len(payload) < offset+descLen {
|
||||
return h.sendErrorPacket(conn, 0x01, "invalid description length", "")
|
||||
}
|
||||
description := string(payload[offset : offset+descLen])
|
||||
|
||||
// Only admins can create groups
|
||||
if !user.Admin {
|
||||
return h.sendErrorPacket(conn, 0x03, "forbidden: admin required", "")
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return h.sendErrorPacket(conn, 0x01, "name is required", "")
|
||||
}
|
||||
|
||||
group, err := h.db.CreateGroup(name, description, user.Name)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to create group", "error", err)
|
||||
return h.sendErrorPacket(conn, 0x11, "failed to create group", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("group created", "group_id", group.ID, "name", group.Name, "created_by", user.Name)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"group": group,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleListGroups handles WebSocket group listing.
|
||||
// Protocol: [api_key_hash:16]
|
||||
func (h *Handler) HandleListGroups(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
groups, err := h.db.ListGroupsForUser(user.Name)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to list groups", "error", err)
|
||||
return h.sendErrorPacket(conn, 0x11, "failed to list groups", err.Error())
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"groups": groups,
|
||||
"count": len(groups),
|
||||
})
|
||||
}
|
||||
|
||||
// HandleCreateInvitation handles WebSocket invitation creation.
|
||||
// Protocol: [api_key_hash:16][group_id_len:2][group_id:var][user_id_len:2][user_id:var]
|
||||
func (h *Handler) HandleCreateInvitation(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+2+2 {
|
||||
return h.sendErrorPacket(conn, 0x01, "payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
groupIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if groupIDLen <= 0 || len(payload) < offset+groupIDLen+2 {
|
||||
return h.sendErrorPacket(conn, 0x01, "invalid group_id length", "")
|
||||
}
|
||||
groupID := string(payload[offset : offset+groupIDLen])
|
||||
offset += groupIDLen
|
||||
|
||||
userIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if userIDLen <= 0 || len(payload) < offset+userIDLen {
|
||||
return h.sendErrorPacket(conn, 0x01, "invalid user_id length", "")
|
||||
}
|
||||
invitedUserID := string(payload[offset : offset+userIDLen])
|
||||
|
||||
// Check if user is group admin
|
||||
isAdmin, err := h.db.IsGroupAdmin(user.Name, groupID)
|
||||
if err != nil || !isAdmin {
|
||||
return h.sendErrorPacket(conn, 0x03, "forbidden: group admin required", "")
|
||||
}
|
||||
|
||||
invitation, err := h.db.CreateGroupInvitation(groupID, invitedUserID, user.Name)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to create invitation", "error", err)
|
||||
return h.sendErrorPacket(conn, 0x11, "failed to create invitation", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("invitation created", "invitation_id", invitation.ID, "group_id", groupID, "invited_user", invitedUserID)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"invitation": invitation,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleListInvitations handles WebSocket invitation listing.
|
||||
// Protocol: [api_key_hash:16]
|
||||
func (h *Handler) HandleListInvitations(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
invitations, err := h.db.ListPendingInvitationsForUser(user.Name)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to list invitations", "error", err)
|
||||
return h.sendErrorPacket(conn, 0x11, "failed to list invitations", err.Error())
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"invitations": invitations,
|
||||
"count": len(invitations),
|
||||
})
|
||||
}
|
||||
|
||||
// HandleAcceptInvitation handles WebSocket invitation acceptance.
|
||||
// Protocol: [api_key_hash:16][invitation_id_len:2][invitation_id:var]
|
||||
func (h *Handler) HandleAcceptInvitation(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+2 {
|
||||
return h.sendErrorPacket(conn, 0x01, "payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
invIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if invIDLen <= 0 || len(payload) < offset+invIDLen {
|
||||
return h.sendErrorPacket(conn, 0x01, "invalid invitation_id length", "")
|
||||
}
|
||||
invitationID := string(payload[offset : offset+invIDLen])
|
||||
|
||||
// Verify invitation belongs to this user and is pending
|
||||
invitation, err := h.db.GetInvitation(invitationID)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, 0x04, "invitation not found", "")
|
||||
}
|
||||
|
||||
if invitation.InvitedUserID != user.Name {
|
||||
return h.sendErrorPacket(conn, 0x03, "forbidden", "")
|
||||
}
|
||||
|
||||
if invitation.Status != "pending" {
|
||||
return h.sendErrorPacket(conn, 0x05, "invitation already processed", "")
|
||||
}
|
||||
|
||||
// Check if expired (7 days default)
|
||||
if invitation.ExpiresAt != nil && time.Now().After(*invitation.ExpiresAt) {
|
||||
return h.sendErrorPacket(conn, 0x14, "invitation expired", "")
|
||||
}
|
||||
|
||||
if err := h.db.AcceptInvitation(invitationID, user.Name); err != nil {
|
||||
h.logger.Error("failed to accept invitation", "error", err)
|
||||
return h.sendErrorPacket(conn, 0x11, "failed to accept invitation", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("invitation accepted", "invitation_id", invitationID, "user", user.Name, "group_id", invitation.GroupID)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"status": "accepted",
|
||||
"group_id": invitation.GroupID,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleDeclineInvitation handles WebSocket invitation decline.
|
||||
// Protocol: [api_key_hash:16][invitation_id_len:2][invitation_id:var]
|
||||
func (h *Handler) HandleDeclineInvitation(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+2 {
|
||||
return h.sendErrorPacket(conn, 0x01, "payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
invIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if invIDLen <= 0 || len(payload) < offset+invIDLen {
|
||||
return h.sendErrorPacket(conn, 0x01, "invalid invitation_id length", "")
|
||||
}
|
||||
invitationID := string(payload[offset : offset+invIDLen])
|
||||
|
||||
invitation, err := h.db.GetInvitation(invitationID)
|
||||
if err != nil {
|
||||
return h.sendErrorPacket(conn, 0x04, "invitation not found", "")
|
||||
}
|
||||
|
||||
if invitation.InvitedUserID != user.Name {
|
||||
return h.sendErrorPacket(conn, 0x03, "forbidden", "")
|
||||
}
|
||||
|
||||
if err := h.db.DeclineInvitation(invitationID, user.Name); err != nil {
|
||||
h.logger.Error("failed to decline invitation", "error", err)
|
||||
return h.sendErrorPacket(conn, 0x11, "failed to decline invitation", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("invitation declined", "invitation_id", invitationID, "user", user.Name)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"status": "declined",
|
||||
})
|
||||
}
|
||||
|
||||
// HandleRemoveMember handles WebSocket member removal.
|
||||
// Protocol: [api_key_hash:16][group_id_len:2][group_id:var][user_id_len:2][user_id:var]
|
||||
func (h *Handler) HandleRemoveMember(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+2+2 {
|
||||
return h.sendErrorPacket(conn, 0x01, "payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
groupIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if groupIDLen <= 0 || len(payload) < offset+groupIDLen+2 {
|
||||
return h.sendErrorPacket(conn, 0x01, "invalid group_id length", "")
|
||||
}
|
||||
groupID := string(payload[offset : offset+groupIDLen])
|
||||
offset += groupIDLen
|
||||
|
||||
userIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if userIDLen <= 0 || len(payload) < offset+userIDLen {
|
||||
return h.sendErrorPacket(conn, 0x01, "invalid user_id length", "")
|
||||
}
|
||||
memberID := string(payload[offset : offset+userIDLen])
|
||||
|
||||
// Check if user is group admin
|
||||
isAdmin, err := h.db.IsGroupAdmin(user.Name, groupID)
|
||||
if err != nil || !isAdmin {
|
||||
return h.sendErrorPacket(conn, 0x03, "forbidden: group admin required", "")
|
||||
}
|
||||
|
||||
// Cannot remove yourself
|
||||
if memberID == user.Name {
|
||||
return h.sendErrorPacket(conn, 0x01, "cannot remove yourself; use leave group endpoint", "")
|
||||
}
|
||||
|
||||
if err := h.db.RemoveGroupMember(groupID, memberID); err != nil {
|
||||
h.logger.Error("failed to remove member", "error", err)
|
||||
return h.sendErrorPacket(conn, 0x11, "failed to remove member", err.Error())
|
||||
}
|
||||
|
||||
h.logger.Info("member removed", "group_id", groupID, "member", memberID, "removed_by", user.Name)
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Member removed",
|
||||
})
|
||||
}
|
||||
|
||||
// HandleListGroupTasks handles WebSocket group task listing.
|
||||
// Protocol: [api_key_hash:16][group_id_len:2][group_id:var][limit:1][cursor_len:2][cursor:var]
|
||||
func (h *Handler) HandleListGroupTasks(conn *websocket.Conn, payload []byte, user *auth.User) error {
|
||||
if len(payload) < 16+2+1+2 {
|
||||
return h.sendErrorPacket(conn, 0x01, "payload too short", "")
|
||||
}
|
||||
|
||||
offset := 16
|
||||
|
||||
groupIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
if groupIDLen <= 0 || len(payload) < offset+groupIDLen+1+2 {
|
||||
return h.sendErrorPacket(conn, 0x01, "invalid group_id length", "")
|
||||
}
|
||||
groupID := string(payload[offset : offset+groupIDLen])
|
||||
offset += groupIDLen
|
||||
|
||||
// Check if user is member of group
|
||||
isMember, err := h.db.IsGroupMember(user.Name, groupID)
|
||||
if err != nil || !isMember {
|
||||
return h.sendErrorPacket(conn, 0x03, "forbidden: group membership required", "")
|
||||
}
|
||||
|
||||
limit := int(payload[offset])
|
||||
offset++
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
cursorLen := int(binary.BigEndian.Uint16(payload[offset : offset+2]))
|
||||
offset += 2
|
||||
var cursor string
|
||||
if cursorLen > 0 {
|
||||
if len(payload) < offset+cursorLen {
|
||||
return h.sendErrorPacket(conn, 0x01, "invalid cursor length", "")
|
||||
}
|
||||
cursor = string(payload[offset : offset+cursorLen])
|
||||
}
|
||||
|
||||
opts := storage.ListTasksOptions{
|
||||
Limit: limit,
|
||||
Cursor: cursor,
|
||||
}
|
||||
|
||||
tasks, nextCursor, err := h.db.ListTasksForGroup(groupID, opts)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to list group tasks", "error", err)
|
||||
return h.sendErrorPacket(conn, 0x11, "failed to list tasks", err.Error())
|
||||
}
|
||||
|
||||
return h.sendSuccessPacket(conn, map[string]interface{}{
|
||||
"success": true,
|
||||
"tasks": tasks,
|
||||
"next_cursor": nextCursor,
|
||||
"count": len(tasks),
|
||||
})
|
||||
}
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
package jobs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
|
@ -22,14 +21,11 @@ import (
|
|||
|
||||
// Handler provides job-related WebSocket handlers
|
||||
type Handler struct {
|
||||
expManager *experiment.Manager
|
||||
logger *logging.Logger
|
||||
queue queue.Backend
|
||||
db *storage.DB
|
||||
authConfig *auth.Config
|
||||
privacyEnforcer interface { // NEW: Privacy enforcement interface
|
||||
CanAccess(ctx context.Context, user *auth.User, owner string, level string, team string) (bool, error)
|
||||
}
|
||||
expManager *experiment.Manager
|
||||
logger *logging.Logger
|
||||
queue queue.Backend
|
||||
db *storage.DB
|
||||
authConfig *auth.Config
|
||||
}
|
||||
|
||||
// NewHandler creates a new jobs handler
|
||||
|
|
@ -39,17 +35,13 @@ func NewHandler(
|
|||
queue queue.Backend,
|
||||
db *storage.DB,
|
||||
authConfig *auth.Config,
|
||||
privacyEnforcer interface { // NEW - can be nil
|
||||
CanAccess(ctx context.Context, user *auth.User, owner string, level string, team string) (bool, error)
|
||||
},
|
||||
) *Handler {
|
||||
return &Handler{
|
||||
expManager: expManager,
|
||||
logger: logger,
|
||||
queue: queue,
|
||||
db: db,
|
||||
authConfig: authConfig,
|
||||
privacyEnforcer: privacyEnforcer,
|
||||
expManager: expManager,
|
||||
logger: logger,
|
||||
queue: queue,
|
||||
db: db,
|
||||
authConfig: authConfig,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -75,7 +67,7 @@ const (
|
|||
|
||||
// sendErrorPacket sends an error response packet to the client
|
||||
func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error {
|
||||
err := map[string]interface{}{
|
||||
err := map[string]any{
|
||||
"error": true,
|
||||
"code": code,
|
||||
"message": message,
|
||||
|
|
@ -85,7 +77,7 @@ func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, deta
|
|||
}
|
||||
|
||||
// sendSuccessPacket sends a success response packet
|
||||
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error {
|
||||
func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]any) error {
|
||||
return conn.WriteJSON(data)
|
||||
}
|
||||
|
||||
|
|
@ -340,7 +332,7 @@ func (h *Handler) HandleListJobs(conn *websocket.Conn, user *auth.User) error {
|
|||
|
||||
jobPaths := storage.NewJobPaths(base)
|
||||
|
||||
jobs := []map[string]interface{}{}
|
||||
jobs := []map[string]any{}
|
||||
|
||||
// Scan all job directories
|
||||
for _, bucket := range []string{"running", "pending", "finished", "failed"} {
|
||||
|
|
@ -411,9 +403,9 @@ func (h *Handler) GetExperimentHistoryHTTP(w http.ResponseWriter, r *http.Reques
|
|||
h.logger.Info("getting experiment history", "experiment", experimentID, "all_users", allUsers)
|
||||
|
||||
// Placeholder response
|
||||
response := map[string]interface{}{
|
||||
response := map[string]any{
|
||||
"experiment_id": experimentID,
|
||||
"history": []map[string]interface{}{
|
||||
"history": []map[string]any{
|
||||
{
|
||||
"timestamp": time.Now().UTC(),
|
||||
"event": "run_started",
|
||||
|
|
@ -440,7 +432,9 @@ func (h *Handler) GetExperimentHistoryHTTP(w http.ResponseWriter, r *http.Reques
|
|||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
h.logger.Warn("failed to encode compare response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ListAllJobsHTTP handles GET /api/jobs?all_users=true for team collaboration
|
||||
|
|
@ -455,7 +449,7 @@ func (h *Handler) ListAllJobsHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
jobPaths := storage.NewJobPaths(base)
|
||||
jobs := []map[string]interface{}{}
|
||||
jobs := []map[string]any{}
|
||||
|
||||
// Scan all job directories
|
||||
for _, bucket := range []string{"running", "pending", "finished", "failed"} {
|
||||
|
|
@ -500,7 +494,7 @@ func (h *Handler) ListAllJobsHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
response := map[string]any{
|
||||
"success": true,
|
||||
"jobs": jobs,
|
||||
"count": len(jobs),
|
||||
|
|
@ -508,5 +502,7 @@ func (h *Handler) ListAllJobsHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
h.logger.Warn("failed to encode jobs list", "error", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -254,3 +254,7 @@ func (h *Handler) ListServicesHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *Handler) StartServiceHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
func (h *Handler) StopServiceHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Not mplementated", http.StatusNotImplemented)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/audit"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/datasets"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/groups"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/jobs"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/jupyter"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/plugins"
|
||||
|
|
@ -46,7 +47,6 @@ func (s *Server) registerRoutes(mux *http.ServeMux) {
|
|||
s.taskQueue,
|
||||
s.db,
|
||||
s.config.BuildAuthConfig(),
|
||||
nil,
|
||||
)
|
||||
|
||||
// Experiment history endpoint: GET /api/experiments/:id/history
|
||||
|
|
@ -173,7 +173,6 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) {
|
|||
s.taskQueue,
|
||||
s.db,
|
||||
s.config.BuildAuthConfig(),
|
||||
nil, // privacyEnforcer - not enabled for now
|
||||
)
|
||||
|
||||
// Create jupyter handler
|
||||
|
|
@ -190,6 +189,9 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) {
|
|||
s.config.DataDir,
|
||||
)
|
||||
|
||||
// Create groups handler
|
||||
groupsHandler := groups.NewHandler(s.db, s.logger)
|
||||
|
||||
wsHandler := ws.NewHandler(
|
||||
s.config.BuildAuthConfig(),
|
||||
s.logger,
|
||||
|
|
@ -203,6 +205,7 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) {
|
|||
jobsHandler,
|
||||
jupyterHandler,
|
||||
datasetsHandler,
|
||||
groupsHandler,
|
||||
)
|
||||
|
||||
mux.Handle("/ws", wsHandler)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
|
|
@ -124,7 +123,9 @@ func (h *APIHandler) GetV1SchedulerStatus(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
if err := json.NewEncoder(w).Encode(status); err != nil {
|
||||
h.logger.Warn("failed to encode status", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetV1SchedulerWorkers handles GET /v1/scheduler/workers
|
||||
|
|
@ -159,7 +160,9 @@ func (h *APIHandler) GetV1SchedulerWorkers(w http.ResponseWriter, r *http.Reques
|
|||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(workers)
|
||||
if err := json.NewEncoder(w).Encode(workers); err != nil {
|
||||
h.logger.Warn("failed to encode workers", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetV1SchedulerWorkersWorkerID handles GET /v1/scheduler/workers/{workerId}
|
||||
|
|
@ -200,7 +203,9 @@ func (h *APIHandler) GetV1SchedulerWorkersWorkerID(w http.ResponseWriter, r *htt
|
|||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(worker)
|
||||
if err := json.NewEncoder(w).Encode(worker); err != nil {
|
||||
h.logger.Warn("failed to encode worker", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteV1SchedulerWorkersWorkerID handles DELETE /v1/scheduler/workers/{workerId}
|
||||
|
|
@ -247,7 +252,9 @@ func (h *APIHandler) GetV1SchedulerReservations(w http.ResponseWriter, r *http.R
|
|||
reservations := []ReservationInfo{}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(reservations)
|
||||
if err := json.NewEncoder(w).Encode(reservations); err != nil {
|
||||
h.logger.Warn("failed to encode reservations", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// PostV1SchedulerReservations handles POST /v1/scheduler/reservations
|
||||
|
|
@ -295,7 +302,9 @@ func (h *APIHandler) PostV1SchedulerReservations(w http.ResponseWriter, r *http.
|
|||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(reservation)
|
||||
if err := json.NewEncoder(w).Encode(reservation); err != nil {
|
||||
h.logger.Warn("failed to encode reservation", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// PatchV1SchedulerJobsJobIDPriority handles PATCH /v1/scheduler/jobs/{jobId}/priority
|
||||
|
|
@ -342,7 +351,9 @@ func (h *APIHandler) PatchV1SchedulerJobsJobIDPriority(w http.ResponseWriter, r
|
|||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
h.logger.Warn("failed to encode priority response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetV1SchedulerStatusStream handles GET /v1/scheduler/status/stream (SSE)
|
||||
|
|
@ -465,16 +476,3 @@ func (h *APIHandler) checkPermission(user *auth.User, permission string) bool {
|
|||
// Check specific permission
|
||||
return user.HasPermission(permission)
|
||||
}
|
||||
|
||||
// parseIntQueryParam parses an integer query parameter
|
||||
func parseIntQueryParam(r *http.Request, name string, defaultVal int) int {
|
||||
str := r.URL.Query().Get(name)
|
||||
if str == "" {
|
||||
return defaultVal
|
||||
}
|
||||
val, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
|
|
|||
197
internal/api/tokens/handlers.go
Normal file
197
internal/api/tokens/handlers.go
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
// Package tokens provides HTTP handlers for share token management
|
||||
package tokens
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
||||
// Handler provides share token management HTTP handlers
|
||||
type Handler struct {
|
||||
db *storage.DB
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewHandler creates a new share tokens handler
|
||||
func NewHandler(db *storage.DB, logger *logging.Logger) *Handler {
|
||||
return &Handler{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTokenRequest represents a request to create a new share token
|
||||
type CreateTokenRequest struct {
|
||||
TaskID *string `json:"task_id,omitempty"`
|
||||
ExperimentID *string `json:"experiment_id,omitempty"`
|
||||
ExpiresIn *int `json:"expires_in_days,omitempty"` // Number of days until expiry
|
||||
MaxAccesses *int `json:"max_accesses,omitempty"` // Max number of accesses allowed
|
||||
}
|
||||
|
||||
// CreateTokenResponse represents the response with the created token
|
||||
type CreateTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
ShareLink string `json:"share_link"`
|
||||
}
|
||||
|
||||
// CreateShareToken handles POST /api/tokens
|
||||
// Creates a new share token for a task or experiment
|
||||
func (h *Handler) CreateShareToken(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user has permission to share tasks
|
||||
if !user.HasPermission(auth.PermissionTasksShare) {
|
||||
http.Error(w, "forbidden: insufficient permissions to share tasks", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateTokenRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Must specify either task_id or experiment_id, not both
|
||||
if (req.TaskID == nil && req.ExperimentID == nil) || (req.TaskID != nil && req.ExperimentID != nil) {
|
||||
http.Error(w, "must specify exactly one of task_id or experiment_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Build options
|
||||
opts := auth.ShareTokenOptions{}
|
||||
if req.ExpiresIn != nil {
|
||||
expiresAt := time.Now().Add(time.Duration(*req.ExpiresIn) * 24 * time.Hour)
|
||||
opts.ExpiresAt = &expiresAt
|
||||
}
|
||||
if req.MaxAccesses != nil {
|
||||
opts.MaxAccesses = req.MaxAccesses
|
||||
}
|
||||
|
||||
// Generate token
|
||||
token, err := auth.GenerateShareToken(h.db, req.TaskID, req.ExperimentID, user.Name, opts)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to generate share token", "error", err)
|
||||
http.Error(w, "failed to create share token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Build share link
|
||||
var path string
|
||||
if req.TaskID != nil {
|
||||
path = "/api/tasks/" + *req.TaskID
|
||||
} else {
|
||||
path = "/api/experiments/" + *req.ExperimentID
|
||||
}
|
||||
shareLink := auth.BuildShareLink("", path, token)
|
||||
|
||||
h.logger.Info("share token created",
|
||||
"token", token[:8]+"...",
|
||||
"created_by", user.Name,
|
||||
"task_id", req.TaskID,
|
||||
"experiment_id", req.ExperimentID,
|
||||
)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
if err := json.NewEncoder(w).Encode(CreateTokenResponse{
|
||||
Token: token,
|
||||
ShareLink: shareLink,
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to encode response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ListShareTokens handles GET /api/tokens
|
||||
// Lists share tokens for a task or experiment
|
||||
func (h *Handler) ListShareTokens(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
taskID := r.URL.Query().Get("task_id")
|
||||
experimentID := r.URL.Query().Get("experiment_id")
|
||||
|
||||
// Must specify either task_id or experiment_id
|
||||
if taskID == "" && experimentID == "" {
|
||||
http.Error(w, "must specify task_id or experiment_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var tokens []*storage.ShareToken
|
||||
var err error
|
||||
|
||||
if taskID != "" {
|
||||
tokens, err = h.db.ListShareTokensForTask(taskID)
|
||||
} else {
|
||||
tokens, err = h.db.ListShareTokensForExperiment(experimentID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("failed to list share tokens", "error", err)
|
||||
http.Error(w, "failed to list share tokens", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"tokens": tokens,
|
||||
"count": len(tokens),
|
||||
}); err != nil {
|
||||
h.logger.Error("failed to encode response", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RevokeShareToken handles DELETE /api/tokens/:token
|
||||
// Revokes a share token
|
||||
func (h *Handler) RevokeShareToken(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.GetUserFromContext(r.Context())
|
||||
if user == nil {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token := r.PathValue("token")
|
||||
if token == "" {
|
||||
http.Error(w, "token required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get token details to check ownership
|
||||
t, err := h.db.GetShareToken(token)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to get share token", "error", err)
|
||||
http.Error(w, "token not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if t == nil {
|
||||
http.Error(w, "token not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Only the creator or an admin can revoke
|
||||
if t.CreatedBy != user.Name && !user.Admin {
|
||||
http.Error(w, "forbidden: not token owner or admin", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteShareToken(token); err != nil {
|
||||
h.logger.Error("failed to delete share token", "error", err)
|
||||
http.Error(w, "failed to revoke token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("share token revoked", "token", token[:8]+"...", "revoked_by", user.Name)
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/api/datasets"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/groups"
|
||||
"github.com/jfraeys/fetch_ml/internal/api/jobs"
|
||||
jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter"
|
||||
)
|
||||
|
|
@ -75,6 +76,16 @@ const (
|
|||
OpcodeQueryJob = 0x23
|
||||
OpcodeQueryRunInfo = 0x28
|
||||
|
||||
// Group management opcodes
|
||||
OpcodeCreateGroup = 0x50
|
||||
OpcodeListGroups = 0x51
|
||||
OpcodeCreateInvitation = 0x52
|
||||
OpcodeListInvitations = 0x53
|
||||
OpcodeAcceptInvitation = 0x54
|
||||
OpcodeDeclineInvitation = 0x55
|
||||
OpcodeRemoveMember = 0x56
|
||||
OpcodeListGroupTasks = 0x57
|
||||
|
||||
//
|
||||
OpcodeCompareRuns = 0x30
|
||||
OpcodeFindRuns = 0x31
|
||||
|
|
@ -147,6 +158,7 @@ type Handler struct {
|
|||
authConfig *auth.Config
|
||||
jobsHandler *jobs.Handler
|
||||
jupyterHandler *jupyterj.Handler
|
||||
groupsHandler *groups.Handler // NEW: groups handler for WebSocket
|
||||
upgrader websocket.Upgrader
|
||||
dataDir string
|
||||
clientsMu sync.RWMutex
|
||||
|
|
@ -166,6 +178,7 @@ func NewHandler(
|
|||
jobsHandler *jobs.Handler,
|
||||
jupyterHandler *jupyterj.Handler,
|
||||
datasetsHandler *datasets.Handler,
|
||||
groupsHandler *groups.Handler,
|
||||
) *Handler {
|
||||
upgrader := createUpgrader(securityCfg)
|
||||
|
||||
|
|
@ -183,6 +196,7 @@ func NewHandler(
|
|||
jobsHandler: jobsHandler,
|
||||
jupyterHandler: jupyterHandler,
|
||||
datasetsHandler: datasetsHandler,
|
||||
groupsHandler: groupsHandler,
|
||||
clients: make(map[*Client]bool),
|
||||
}
|
||||
}
|
||||
|
|
@ -339,6 +353,22 @@ func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error {
|
|||
return h.handleSetRunOutcome(conn, payload)
|
||||
case OpcodeQueryRunInfo:
|
||||
return h.handleQueryRunInfo(conn, payload)
|
||||
case OpcodeCreateGroup:
|
||||
return h.handleCreateGroup(conn, payload)
|
||||
case OpcodeListGroups:
|
||||
return h.handleListGroups(conn, payload)
|
||||
case OpcodeCreateInvitation:
|
||||
return h.handleCreateInvitation(conn, payload)
|
||||
case OpcodeListInvitations:
|
||||
return h.handleListInvitations(conn, payload)
|
||||
case OpcodeAcceptInvitation:
|
||||
return h.handleAcceptInvitation(conn, payload)
|
||||
case OpcodeDeclineInvitation:
|
||||
return h.handleDeclineInvitation(conn, payload)
|
||||
case OpcodeRemoveMember:
|
||||
return h.handleRemoveMember(conn, payload)
|
||||
case OpcodeListGroupTasks:
|
||||
return h.handleListGroupTasks(conn, payload)
|
||||
default:
|
||||
return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode))
|
||||
}
|
||||
|
|
@ -429,6 +459,54 @@ func (h *Handler) handleListJupyter(conn *websocket.Conn, payload []byte) error
|
|||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleCreateGroup(conn *websocket.Conn, payload []byte) error {
|
||||
return h.withAuth(conn, payload, func(user *auth.User) error {
|
||||
return h.groupsHandler.HandleCreateGroup(conn, payload, user)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleListGroups(conn *websocket.Conn, payload []byte) error {
|
||||
return h.withAuth(conn, payload, func(user *auth.User) error {
|
||||
return h.groupsHandler.HandleListGroups(conn, payload, user)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleCreateInvitation(conn *websocket.Conn, payload []byte) error {
|
||||
return h.withAuth(conn, payload, func(user *auth.User) error {
|
||||
return h.groupsHandler.HandleCreateInvitation(conn, payload, user)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleListInvitations(conn *websocket.Conn, payload []byte) error {
|
||||
return h.withAuth(conn, payload, func(user *auth.User) error {
|
||||
return h.groupsHandler.HandleListInvitations(conn, payload, user)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleAcceptInvitation(conn *websocket.Conn, payload []byte) error {
|
||||
return h.withAuth(conn, payload, func(user *auth.User) error {
|
||||
return h.groupsHandler.HandleAcceptInvitation(conn, payload, user)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleDeclineInvitation(conn *websocket.Conn, payload []byte) error {
|
||||
return h.withAuth(conn, payload, func(user *auth.User) error {
|
||||
return h.groupsHandler.HandleDeclineInvitation(conn, payload, user)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleRemoveMember(conn *websocket.Conn, payload []byte) error {
|
||||
return h.withAuth(conn, payload, func(user *auth.User) error {
|
||||
return h.groupsHandler.HandleRemoveMember(conn, payload, user)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleListGroupTasks(conn *websocket.Conn, payload []byte) error {
|
||||
return h.withAuth(conn, payload, func(user *auth.User) error {
|
||||
return h.groupsHandler.HandleListGroupTasks(conn, payload, user)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error {
|
||||
// Parse payload: [api_key_hash:16][metric_name_len:1][metric_name:var][value:8]
|
||||
if len(payload) < 16+1+8 {
|
||||
|
|
@ -590,9 +668,15 @@ func selectDependencyManifest(filesPath string) (string, error) {
|
|||
for _, name := range []string{
|
||||
"requirements.txt", "package.json", "Cargo.toml", "go.mod", "pom.xml", "build.gradle",
|
||||
} {
|
||||
if _, err := os.Stat(filepath.Join(filesPath, name)); err == nil {
|
||||
//nolint:gosec // G703: name is from hardcoded slice, not user input
|
||||
fullPath := filepath.Join(filesPath, name)
|
||||
_, err := os.Stat(fullPath)
|
||||
if err == nil {
|
||||
return name, nil
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("error checking %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no dependency manifest found")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ type User struct {
|
|||
Name string `json:"name"`
|
||||
Roles []string `json:"roles"`
|
||||
Admin bool `json:"admin"`
|
||||
Groups []string `json:"groups"` // NEW: lab groups ("ml-lab", "nlp-group")
|
||||
}
|
||||
|
||||
// ExtractAPIKeyFromRequest extracts an API key from the standard headers.
|
||||
|
|
|
|||
|
|
@ -194,7 +194,11 @@ func (s *DatabaseAuthStore) ListUsers(ctx context.Context) ([]APIKeyRecord, erro
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query users: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
log.Printf("ERROR: failed to close rows: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var users []APIKeyRecord
|
||||
for rows.Next() {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
|
|
@ -52,7 +53,8 @@ func GetAPIKeyFromSources(flags *Flags) string {
|
|||
if readErr == nil {
|
||||
return strings.TrimSpace(string(contents))
|
||||
}
|
||||
log.Printf("Warning: Could not read API key file %s: %v", flags.APIKeyFile, readErr)
|
||||
// Use structured logging to prevent log injection
|
||||
slog.Warn("Could not read API key file", "file", flags.APIKeyFile, "error", readErr)
|
||||
}
|
||||
|
||||
// 3. Environment variable
|
||||
|
|
@ -66,7 +68,8 @@ func GetAPIKeyFromSources(flags *Flags) string {
|
|||
if err == nil {
|
||||
return strings.TrimSpace(string(content))
|
||||
}
|
||||
log.Printf("Warning: Could not read API key file %s: %v", fileKey, err)
|
||||
// Use structured logging to prevent log injection
|
||||
slog.Warn("Could not read API key file", "file", fileKey, "error", err)
|
||||
}
|
||||
|
||||
return ""
|
||||
|
|
@ -84,7 +87,8 @@ func ValidateFlags(flags *Flags) error {
|
|||
return fmt.Errorf("api key file not found: %s", flags.APIKeyFile)
|
||||
}
|
||||
if err := CheckConfigFilePermissions(flags.APIKeyFile); err != nil {
|
||||
log.Printf("Warning: %v", err)
|
||||
// Use structured logging to prevent log injection
|
||||
slog.Warn("API key file permission issue", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -96,7 +100,8 @@ func ValidateFlags(flags *Flags) error {
|
|||
|
||||
// Check file permissions
|
||||
if err := CheckConfigFilePermissions(flags.ConfigFile); err != nil {
|
||||
log.Printf("Warning: %v", err)
|
||||
// Use structured logging to prevent log injection
|
||||
slog.Warn("Config file permission issue", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ package auth
|
|||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
)
|
||||
|
||||
// Permission constants for type safety
|
||||
|
|
@ -31,6 +33,13 @@ const (
|
|||
PermissionSystemLogs = "system:logs"
|
||||
PermissionSystemUsers = "system:users"
|
||||
|
||||
// Task sharing permissions
|
||||
PermissionTasksReadOwn = "tasks:read:own"
|
||||
PermissionTasksReadLab = "tasks:read:lab"
|
||||
PermissionTasksReadAll = "tasks:read:all"
|
||||
PermissionTasksShare = "tasks:share"
|
||||
PermissionTasksClone = "tasks:clone"
|
||||
|
||||
// Wildcard permission
|
||||
PermissionAll = "*"
|
||||
)
|
||||
|
|
@ -210,3 +219,59 @@ func (u *User) GetEffectivePermissions() []string {
|
|||
|
||||
return permissions
|
||||
}
|
||||
|
||||
// TaskAccessChecker defines the interface for checking task access permissions.
|
||||
// Implemented by storage.DB to avoid circular imports.
|
||||
type TaskAccessChecker interface {
|
||||
TaskIsSharedWithUser(taskID, userID string) bool
|
||||
UserSharesGroupWithTask(userID, taskID string) bool
|
||||
TaskAllowsPublicClone(taskID string) bool
|
||||
UserRoleInTaskGroups(userID, taskID string) string // NEW: returns highest role (admin/member/viewer)
|
||||
}
|
||||
|
||||
// CanAccessTask checks if an authenticated user can access a specific task.
|
||||
// For unauthenticated (token-based) access, use CanAccessWithToken() instead.
|
||||
// NOTE: This method makes 1-2 DB calls per invocation. Never call it in a loop
|
||||
// over a task list — use ListTasksForUser() instead, which resolves all access
|
||||
// in a single query.
|
||||
func (u *User) CanAccessTask(task *domain.Task, db TaskAccessChecker) bool {
|
||||
if task.UserID == u.Name || task.CreatedBy == u.Name {
|
||||
return true
|
||||
}
|
||||
if u.Admin {
|
||||
return true
|
||||
}
|
||||
// Explicit share — expiry checked in SQL
|
||||
if db.TaskIsSharedWithUser(task.ID, u.Name) {
|
||||
return true
|
||||
}
|
||||
switch task.Visibility {
|
||||
case domain.VisibilityPrivate:
|
||||
return false
|
||||
case domain.VisibilityLab:
|
||||
return db.UserSharesGroupWithTask(u.Name, task.ID)
|
||||
case domain.VisibilityInstitution, domain.VisibilityOpen:
|
||||
// open tasks are accessible to authenticated users too, not only token holders
|
||||
return u.HasPermission(PermissionTasksReadLab)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CanCloneTask checks if a user can clone (reproduce) a task.
|
||||
// Clone is a distinct action from read. Viewing a task does not imply
|
||||
// permission to reproduce it under a different account.
|
||||
func (u *User) CanCloneTask(task *domain.Task, db TaskAccessChecker) bool {
|
||||
if !u.CanAccessTask(task, db) {
|
||||
return false
|
||||
}
|
||||
if task.Visibility == domain.VisibilityOpen {
|
||||
return db.TaskAllowsPublicClone(task.ID)
|
||||
}
|
||||
// NEW: viewer-role members can see but not clone
|
||||
// role == "" means access is via explicit task_shares (not group) - allow clone
|
||||
role := db.UserRoleInTaskGroups(u.Name, task.ID)
|
||||
if role == "viewer" {
|
||||
return false
|
||||
}
|
||||
return u.HasPermission(PermissionTasksClone)
|
||||
}
|
||||
|
|
|
|||
64
internal/auth/tokens.go
Normal file
64
internal/auth/tokens.go
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/storage"
|
||||
)
|
||||
|
||||
// ShareTokenOptions provides options for creating share tokens
|
||||
type ShareTokenOptions struct {
|
||||
ExpiresAt *time.Time // nil = never expires
|
||||
MaxAccesses *int // nil = unlimited
|
||||
}
|
||||
|
||||
// GenerateShareToken creates a new cryptographically random share token.
|
||||
// Returns the token string that can be shared with others.
|
||||
func GenerateShareToken(db *storage.DB, taskID, experimentID *string, createdBy string, opts ShareTokenOptions) (string, error) {
|
||||
// Generate 32 bytes of random data
|
||||
raw := make([]byte, 32)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||
}
|
||||
|
||||
// Encode as base64url (URL-safe, no padding)
|
||||
token := base64.RawURLEncoding.EncodeToString(raw)
|
||||
|
||||
// Store in database
|
||||
if err := db.CreateShareToken(token, taskID, experimentID, createdBy, opts.ExpiresAt, opts.MaxAccesses); err != nil {
|
||||
return "", fmt.Errorf("failed to store share token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// CanAccessWithToken checks if a token grants access to a task or experiment.
|
||||
// This is called for unauthenticated access via signed share links.
|
||||
// Returns true if the token is valid, not expired, and within access limits.
|
||||
func CanAccessWithToken(db *storage.DB, token string, taskID, experimentID *string) bool {
|
||||
t, err := db.ValidateShareToken(token, taskID, experimentID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if t == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Increment access count
|
||||
if err := db.IncrementTokenAccessCount(token); err != nil {
|
||||
// Log error but don't fail the request - access was already validated
|
||||
// This is a best-effort operation for tracking purposes
|
||||
fmt.Printf("WARNING: failed to increment token access count: %v\n", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// BuildShareLink constructs a shareable URL from a token.
|
||||
// The path should be the base API path (e.g., "/api/tasks/123" or "/api/experiments/456").
|
||||
func BuildShareLink(baseURL, path, token string) string {
|
||||
return fmt.Sprintf("%s%s?token=%s", baseURL, path, token)
|
||||
}
|
||||
Loading…
Reference in a new issue