diff --git a/internal/api/adapter.go b/internal/api/adapter.go index e33976e..e78fcc6 100644 --- a/internal/api/adapter.go +++ b/internal/api/adapter.go @@ -104,11 +104,7 @@ func (a *HandlerAdapter) DeleteV1JupyterServicesServiceId(ctx echo.Context, serv }) } // TODO: Implement when StopServiceHTTP is available - return ctx.JSON(501, map[string]any{ - "error": "Not implemented", - "code": "NOT_IMPLEMENTED", - "message": "Jupyter service stop not yet implemented via REST API", - }) + return toHTTPHandler(a.jupyterHandler.StopServiceHTTP)(ctx) } // GetV1Queue returns queue status diff --git a/internal/api/factory.go b/internal/api/factory.go index e9313e9..c65c0b7 100644 --- a/internal/api/factory.go +++ b/internal/api/factory.go @@ -44,6 +44,11 @@ func (s *Server) initializeComponents() error { return err } + // Initialize default lab group (if configured) + if err := s.initDefaultLabGroup(); err != nil { + return err + } + // Initialize security s.initSecurity() @@ -185,6 +190,31 @@ func (s *Server) initDatabaseSchema() error { return nil } +// initDefaultLabGroup creates the auto-provisioned default lab group if configured. +// Reads DEFAULT_LAB_GROUP env var and creates the group if it doesn't exist. +// If DEFAULT_LAB_GROUP is not set, this is a silent no-op (intentional - groups are optional). +func (s *Server) initDefaultLabGroup() error { + if s.db == nil { + return nil + } + + groupName := os.Getenv("DEFAULT_LAB_GROUP") + if groupName == "" { + return nil // No default lab group configured + } + + groupID, err := s.db.GetOrCreateDefaultLabGroup("system") + if err != nil { + s.logger.Error("failed to initialize default lab group", "error", err) + return err + } + + if groupID != "" { + s.logger.Info("default lab group initialized", "group", groupName, "id", groupID) + } + return nil +} + // initSecurity initializes security middleware func (s *Server) initSecurity() { authConfig := s.config.BuildAuthConfig() diff --git a/internal/api/groups/handlers.go b/internal/api/groups/handlers.go new file mode 100644 index 0000000..d6a7eac --- /dev/null +++ b/internal/api/groups/handlers.go @@ -0,0 +1,692 @@ +// Package groups provides HTTP handlers for lab group management +package groups + +import ( + "encoding/binary" + "encoding/json" + "net/http" + "strconv" + "time" + + "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/storage" +) + +// Handler provides group management HTTP handlers +type Handler struct { + db *storage.DB + logger *logging.Logger +} + +// NewHandler creates a new groups handler +func NewHandler(db *storage.DB, logger *logging.Logger) *Handler { + return &Handler{ + db: db, + logger: logger, + } +} + +// CreateGroupRequest represents a request to create a new group +type CreateGroupRequest struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` +} + +// CreateGroup handles POST /api/groups +func (h *Handler) CreateGroup(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + // Only admins can create groups + if !user.Admin { + http.Error(w, "forbidden: admin required", http.StatusForbidden) + return + } + + var req CreateGroupRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + if req.Name == "" { + http.Error(w, "name is required", http.StatusBadRequest) + return + } + + group, err := h.db.CreateGroup(req.Name, req.Description, user.Name) + if err != nil { + h.logger.Error("failed to create group", "error", err) + http.Error(w, "failed to create group", http.StatusInternalServerError) + return + } + + h.logger.Info("group created", "group_id", group.ID, "name", group.Name, "created_by", user.Name) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(group); err != nil { + h.logger.Error("failed to encode response", "error", err) + } +} + +// ListGroups handles GET /api/groups +func (h *Handler) ListGroups(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + groups, err := h.db.ListGroupsForUser(user.Name) + if err != nil { + h.logger.Error("failed to list groups", "error", err) + http.Error(w, "failed to list groups", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "groups": groups, + "count": len(groups), + }); err != nil { + h.logger.Error("failed to encode response", "error", err) + } +} + +// CreateInvitation handles POST /api/groups/{id}/invitations +func (h *Handler) CreateInvitation(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + groupID := r.PathValue("id") + if groupID == "" { + http.Error(w, "group id required", http.StatusBadRequest) + return + } + + // Check if user is group admin + isAdmin, err := h.db.IsGroupAdmin(user.Name, groupID) + if err != nil || !isAdmin { + http.Error(w, "forbidden: group admin required", http.StatusForbidden) + return + } + + var req struct { + InvitedUserID string `json:"invited_user_id"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + if req.InvitedUserID == "" { + http.Error(w, "invited_user_id is required", http.StatusBadRequest) + return + } + + invitation, err := h.db.CreateGroupInvitation(groupID, req.InvitedUserID, user.Name) + if err != nil { + h.logger.Error("failed to create invitation", "error", err) + http.Error(w, "failed to create invitation", http.StatusInternalServerError) + return + } + + h.logger.Info("invitation created", "invitation_id", invitation.ID, "group_id", groupID, "invited_user", req.InvitedUserID) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(invitation); err != nil { + h.logger.Error("failed to encode response", "error", err) + } +} + +// ListInvitations handles GET /api/invitations +func (h *Handler) ListInvitations(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + invitations, err := h.db.ListPendingInvitationsForUser(user.Name) + if err != nil { + h.logger.Error("failed to list invitations", "error", err) + http.Error(w, "failed to list invitations", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "invitations": invitations, + "count": len(invitations), + }); err != nil { + h.logger.Error("failed to encode response", "error", err) + } +} + +// AcceptInvitation handles POST /api/invitations/{id}/accept +func (h *Handler) AcceptInvitation(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + invitationID := r.PathValue("id") + if invitationID == "" { + http.Error(w, "invitation id required", http.StatusBadRequest) + return + } + + // Verify invitation belongs to this user and is pending + invitation, err := h.db.GetInvitation(invitationID) + if err != nil { + http.Error(w, "invitation not found", http.StatusNotFound) + return + } + + if invitation.InvitedUserID != user.Name { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + if invitation.Status != "pending" { + http.Error(w, "invitation already processed", http.StatusConflict) + return + } + + // Check if expired (7 days default) + if invitation.ExpiresAt != nil && time.Now().After(*invitation.ExpiresAt) { + http.Error(w, "invitation expired", http.StatusGone) + return + } + + if err := h.db.AcceptInvitation(invitationID, user.Name); err != nil { + h.logger.Error("failed to accept invitation", "error", err) + http.Error(w, "failed to accept invitation", http.StatusInternalServerError) + return + } + + h.logger.Info("invitation accepted", "invitation_id", invitationID, "user", user.Name, "group_id", invitation.GroupID) + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]string{ + "status": "accepted", + "group_id": invitation.GroupID, + }); err != nil { + h.logger.Error("failed to encode response", "error", err) + } +} + +// DeclineInvitation handles POST /api/invitations/{id}/decline +func (h *Handler) DeclineInvitation(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + invitationID := r.PathValue("id") + if invitationID == "" { + http.Error(w, "invitation id required", http.StatusBadRequest) + return + } + + invitation, err := h.db.GetInvitation(invitationID) + if err != nil { + http.Error(w, "invitation not found", http.StatusNotFound) + return + } + + if invitation.InvitedUserID != user.Name { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + if err := h.db.DeclineInvitation(invitationID, user.Name); err != nil { + h.logger.Error("failed to decline invitation", "error", err) + http.Error(w, "failed to decline invitation", http.StatusInternalServerError) + return + } + + h.logger.Info("invitation declined", "invitation_id", invitationID, "user", user.Name) + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]string{ + "status": "declined", + }); err != nil { + h.logger.Error("failed to encode response", "error", err) + } +} + +// RemoveMember handles DELETE /api/groups/{id}/members/{user} +func (h *Handler) RemoveMember(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + groupID := r.PathValue("id") + memberID := r.PathValue("user") + if groupID == "" || memberID == "" { + http.Error(w, "group id and user id required", http.StatusBadRequest) + return + } + + // Check if user is group admin + isAdmin, err := h.db.IsGroupAdmin(user.Name, groupID) + if err != nil || !isAdmin { + http.Error(w, "forbidden: group admin required", http.StatusForbidden) + return + } + + // Cannot remove yourself (use leave group endpoint instead) + if memberID == user.Name { + http.Error(w, "cannot remove yourself; use leave group endpoint", http.StatusBadRequest) + return + } + + if err := h.db.RemoveGroupMember(groupID, memberID); err != nil { + h.logger.Error("failed to remove member", "error", err) + http.Error(w, "failed to remove member", http.StatusInternalServerError) + return + } + + h.logger.Info("member removed", "group_id", groupID, "member", memberID, "removed_by", user.Name) + + w.WriteHeader(http.StatusNoContent) +} + +// ListGroupTasks handles GET /api/groups/{id}/tasks +func (h *Handler) ListGroupTasks(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + groupID := r.PathValue("id") + if groupID == "" { + http.Error(w, "group id required", http.StatusBadRequest) + return + } + + // Check if user is member of group + isMember, err := h.db.IsGroupMember(user.Name, groupID) + if err != nil || !isMember { + http.Error(w, "forbidden: group membership required", http.StatusForbidden) + return + } + + // Parse pagination options + limit := 100 + if l := r.URL.Query().Get("limit"); l != "" { + if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 && parsed <= 100 { + limit = parsed + } + } + cursor := r.URL.Query().Get("cursor") + + opts := storage.ListTasksOptions{ + Limit: limit, + Cursor: cursor, + } + + tasks, nextCursor, err := h.db.ListTasksForGroup(groupID, opts) + if err != nil { + h.logger.Error("failed to list group tasks", "error", err) + http.Error(w, "failed to list tasks", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "tasks": tasks, + "next_cursor": nextCursor, + "count": len(tasks), + }); err != nil { + h.logger.Error("failed to encode response", "error", err) + } +} + +// sendErrorPacket sends an error response packet to the client +func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error { + err := map[string]interface{}{ + "error": true, + "code": code, + "message": message, + "details": details, + } + return conn.WriteJSON(err) +} + +// sendSuccessPacket sends a success response packet +func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error { + return conn.WriteJSON(data) +} + +// HandleCreateGroup handles WebSocket group creation. +// Protocol: [api_key_hash:16][name_len:1][name:var][desc_len:2][desc:var] +func (h *Handler) HandleCreateGroup(conn *websocket.Conn, payload []byte, user *auth.User) error { + if len(payload) < 16+1+2 { + return h.sendErrorPacket(conn, 0x01, "payload too short", "") + } + + offset := 16 + + nameLen := int(payload[offset]) + offset++ + if nameLen <= 0 || len(payload) < offset+nameLen+2 { + return h.sendErrorPacket(conn, 0x01, "invalid name length", "") + } + name := string(payload[offset : offset+nameLen]) + offset += nameLen + + descLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if descLen < 0 || len(payload) < offset+descLen { + return h.sendErrorPacket(conn, 0x01, "invalid description length", "") + } + description := string(payload[offset : offset+descLen]) + + // Only admins can create groups + if !user.Admin { + return h.sendErrorPacket(conn, 0x03, "forbidden: admin required", "") + } + + if name == "" { + return h.sendErrorPacket(conn, 0x01, "name is required", "") + } + + group, err := h.db.CreateGroup(name, description, user.Name) + if err != nil { + h.logger.Error("failed to create group", "error", err) + return h.sendErrorPacket(conn, 0x11, "failed to create group", err.Error()) + } + + h.logger.Info("group created", "group_id", group.ID, "name", group.Name, "created_by", user.Name) + + return h.sendSuccessPacket(conn, map[string]interface{}{ + "success": true, + "group": group, + }) +} + +// HandleListGroups handles WebSocket group listing. +// Protocol: [api_key_hash:16] +func (h *Handler) HandleListGroups(conn *websocket.Conn, payload []byte, user *auth.User) error { + groups, err := h.db.ListGroupsForUser(user.Name) + if err != nil { + h.logger.Error("failed to list groups", "error", err) + return h.sendErrorPacket(conn, 0x11, "failed to list groups", err.Error()) + } + + return h.sendSuccessPacket(conn, map[string]interface{}{ + "success": true, + "groups": groups, + "count": len(groups), + }) +} + +// HandleCreateInvitation handles WebSocket invitation creation. +// Protocol: [api_key_hash:16][group_id_len:2][group_id:var][user_id_len:2][user_id:var] +func (h *Handler) HandleCreateInvitation(conn *websocket.Conn, payload []byte, user *auth.User) error { + if len(payload) < 16+2+2 { + return h.sendErrorPacket(conn, 0x01, "payload too short", "") + } + + offset := 16 + + groupIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if groupIDLen <= 0 || len(payload) < offset+groupIDLen+2 { + return h.sendErrorPacket(conn, 0x01, "invalid group_id length", "") + } + groupID := string(payload[offset : offset+groupIDLen]) + offset += groupIDLen + + userIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if userIDLen <= 0 || len(payload) < offset+userIDLen { + return h.sendErrorPacket(conn, 0x01, "invalid user_id length", "") + } + invitedUserID := string(payload[offset : offset+userIDLen]) + + // Check if user is group admin + isAdmin, err := h.db.IsGroupAdmin(user.Name, groupID) + if err != nil || !isAdmin { + return h.sendErrorPacket(conn, 0x03, "forbidden: group admin required", "") + } + + invitation, err := h.db.CreateGroupInvitation(groupID, invitedUserID, user.Name) + if err != nil { + h.logger.Error("failed to create invitation", "error", err) + return h.sendErrorPacket(conn, 0x11, "failed to create invitation", err.Error()) + } + + h.logger.Info("invitation created", "invitation_id", invitation.ID, "group_id", groupID, "invited_user", invitedUserID) + + return h.sendSuccessPacket(conn, map[string]interface{}{ + "success": true, + "invitation": invitation, + }) +} + +// HandleListInvitations handles WebSocket invitation listing. +// Protocol: [api_key_hash:16] +func (h *Handler) HandleListInvitations(conn *websocket.Conn, payload []byte, user *auth.User) error { + invitations, err := h.db.ListPendingInvitationsForUser(user.Name) + if err != nil { + h.logger.Error("failed to list invitations", "error", err) + return h.sendErrorPacket(conn, 0x11, "failed to list invitations", err.Error()) + } + + return h.sendSuccessPacket(conn, map[string]interface{}{ + "success": true, + "invitations": invitations, + "count": len(invitations), + }) +} + +// HandleAcceptInvitation handles WebSocket invitation acceptance. +// Protocol: [api_key_hash:16][invitation_id_len:2][invitation_id:var] +func (h *Handler) HandleAcceptInvitation(conn *websocket.Conn, payload []byte, user *auth.User) error { + if len(payload) < 16+2 { + return h.sendErrorPacket(conn, 0x01, "payload too short", "") + } + + offset := 16 + + invIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if invIDLen <= 0 || len(payload) < offset+invIDLen { + return h.sendErrorPacket(conn, 0x01, "invalid invitation_id length", "") + } + invitationID := string(payload[offset : offset+invIDLen]) + + // Verify invitation belongs to this user and is pending + invitation, err := h.db.GetInvitation(invitationID) + if err != nil { + return h.sendErrorPacket(conn, 0x04, "invitation not found", "") + } + + if invitation.InvitedUserID != user.Name { + return h.sendErrorPacket(conn, 0x03, "forbidden", "") + } + + if invitation.Status != "pending" { + return h.sendErrorPacket(conn, 0x05, "invitation already processed", "") + } + + // Check if expired (7 days default) + if invitation.ExpiresAt != nil && time.Now().After(*invitation.ExpiresAt) { + return h.sendErrorPacket(conn, 0x14, "invitation expired", "") + } + + if err := h.db.AcceptInvitation(invitationID, user.Name); err != nil { + h.logger.Error("failed to accept invitation", "error", err) + return h.sendErrorPacket(conn, 0x11, "failed to accept invitation", err.Error()) + } + + h.logger.Info("invitation accepted", "invitation_id", invitationID, "user", user.Name, "group_id", invitation.GroupID) + + return h.sendSuccessPacket(conn, map[string]interface{}{ + "success": true, + "status": "accepted", + "group_id": invitation.GroupID, + }) +} + +// HandleDeclineInvitation handles WebSocket invitation decline. +// Protocol: [api_key_hash:16][invitation_id_len:2][invitation_id:var] +func (h *Handler) HandleDeclineInvitation(conn *websocket.Conn, payload []byte, user *auth.User) error { + if len(payload) < 16+2 { + return h.sendErrorPacket(conn, 0x01, "payload too short", "") + } + + offset := 16 + + invIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if invIDLen <= 0 || len(payload) < offset+invIDLen { + return h.sendErrorPacket(conn, 0x01, "invalid invitation_id length", "") + } + invitationID := string(payload[offset : offset+invIDLen]) + + invitation, err := h.db.GetInvitation(invitationID) + if err != nil { + return h.sendErrorPacket(conn, 0x04, "invitation not found", "") + } + + if invitation.InvitedUserID != user.Name { + return h.sendErrorPacket(conn, 0x03, "forbidden", "") + } + + if err := h.db.DeclineInvitation(invitationID, user.Name); err != nil { + h.logger.Error("failed to decline invitation", "error", err) + return h.sendErrorPacket(conn, 0x11, "failed to decline invitation", err.Error()) + } + + h.logger.Info("invitation declined", "invitation_id", invitationID, "user", user.Name) + + return h.sendSuccessPacket(conn, map[string]interface{}{ + "success": true, + "status": "declined", + }) +} + +// HandleRemoveMember handles WebSocket member removal. +// Protocol: [api_key_hash:16][group_id_len:2][group_id:var][user_id_len:2][user_id:var] +func (h *Handler) HandleRemoveMember(conn *websocket.Conn, payload []byte, user *auth.User) error { + if len(payload) < 16+2+2 { + return h.sendErrorPacket(conn, 0x01, "payload too short", "") + } + + offset := 16 + + groupIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if groupIDLen <= 0 || len(payload) < offset+groupIDLen+2 { + return h.sendErrorPacket(conn, 0x01, "invalid group_id length", "") + } + groupID := string(payload[offset : offset+groupIDLen]) + offset += groupIDLen + + userIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if userIDLen <= 0 || len(payload) < offset+userIDLen { + return h.sendErrorPacket(conn, 0x01, "invalid user_id length", "") + } + memberID := string(payload[offset : offset+userIDLen]) + + // Check if user is group admin + isAdmin, err := h.db.IsGroupAdmin(user.Name, groupID) + if err != nil || !isAdmin { + return h.sendErrorPacket(conn, 0x03, "forbidden: group admin required", "") + } + + // Cannot remove yourself + if memberID == user.Name { + return h.sendErrorPacket(conn, 0x01, "cannot remove yourself; use leave group endpoint", "") + } + + if err := h.db.RemoveGroupMember(groupID, memberID); err != nil { + h.logger.Error("failed to remove member", "error", err) + return h.sendErrorPacket(conn, 0x11, "failed to remove member", err.Error()) + } + + h.logger.Info("member removed", "group_id", groupID, "member", memberID, "removed_by", user.Name) + + return h.sendSuccessPacket(conn, map[string]interface{}{ + "success": true, + "message": "Member removed", + }) +} + +// HandleListGroupTasks handles WebSocket group task listing. +// Protocol: [api_key_hash:16][group_id_len:2][group_id:var][limit:1][cursor_len:2][cursor:var] +func (h *Handler) HandleListGroupTasks(conn *websocket.Conn, payload []byte, user *auth.User) error { + if len(payload) < 16+2+1+2 { + return h.sendErrorPacket(conn, 0x01, "payload too short", "") + } + + offset := 16 + + groupIDLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + if groupIDLen <= 0 || len(payload) < offset+groupIDLen+1+2 { + return h.sendErrorPacket(conn, 0x01, "invalid group_id length", "") + } + groupID := string(payload[offset : offset+groupIDLen]) + offset += groupIDLen + + // Check if user is member of group + isMember, err := h.db.IsGroupMember(user.Name, groupID) + if err != nil || !isMember { + return h.sendErrorPacket(conn, 0x03, "forbidden: group membership required", "") + } + + limit := int(payload[offset]) + offset++ + if limit <= 0 || limit > 100 { + limit = 100 + } + + cursorLen := int(binary.BigEndian.Uint16(payload[offset : offset+2])) + offset += 2 + var cursor string + if cursorLen > 0 { + if len(payload) < offset+cursorLen { + return h.sendErrorPacket(conn, 0x01, "invalid cursor length", "") + } + cursor = string(payload[offset : offset+cursorLen]) + } + + opts := storage.ListTasksOptions{ + Limit: limit, + Cursor: cursor, + } + + tasks, nextCursor, err := h.db.ListTasksForGroup(groupID, opts) + if err != nil { + h.logger.Error("failed to list group tasks", "error", err) + return h.sendErrorPacket(conn, 0x11, "failed to list tasks", err.Error()) + } + + return h.sendSuccessPacket(conn, map[string]interface{}{ + "success": true, + "tasks": tasks, + "next_cursor": nextCursor, + "count": len(tasks), + }) +} diff --git a/internal/api/jobs/handlers.go b/internal/api/jobs/handlers.go index fdac92e..bcd564d 100644 --- a/internal/api/jobs/handlers.go +++ b/internal/api/jobs/handlers.go @@ -2,7 +2,6 @@ package jobs import ( - "context" "encoding/binary" "encoding/json" "net/http" @@ -22,14 +21,11 @@ import ( // Handler provides job-related WebSocket handlers type Handler struct { - expManager *experiment.Manager - logger *logging.Logger - queue queue.Backend - db *storage.DB - authConfig *auth.Config - privacyEnforcer interface { // NEW: Privacy enforcement interface - CanAccess(ctx context.Context, user *auth.User, owner string, level string, team string) (bool, error) - } + expManager *experiment.Manager + logger *logging.Logger + queue queue.Backend + db *storage.DB + authConfig *auth.Config } // NewHandler creates a new jobs handler @@ -39,17 +35,13 @@ func NewHandler( queue queue.Backend, db *storage.DB, authConfig *auth.Config, - privacyEnforcer interface { // NEW - can be nil - CanAccess(ctx context.Context, user *auth.User, owner string, level string, team string) (bool, error) - }, ) *Handler { return &Handler{ - expManager: expManager, - logger: logger, - queue: queue, - db: db, - authConfig: authConfig, - privacyEnforcer: privacyEnforcer, + expManager: expManager, + logger: logger, + queue: queue, + db: db, + authConfig: authConfig, } } @@ -75,7 +67,7 @@ const ( // sendErrorPacket sends an error response packet to the client func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, details string) error { - err := map[string]interface{}{ + err := map[string]any{ "error": true, "code": code, "message": message, @@ -85,7 +77,7 @@ func (h *Handler) sendErrorPacket(conn *websocket.Conn, code byte, message, deta } // sendSuccessPacket sends a success response packet -func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]interface{}) error { +func (h *Handler) sendSuccessPacket(conn *websocket.Conn, data map[string]any) error { return conn.WriteJSON(data) } @@ -340,7 +332,7 @@ func (h *Handler) HandleListJobs(conn *websocket.Conn, user *auth.User) error { jobPaths := storage.NewJobPaths(base) - jobs := []map[string]interface{}{} + jobs := []map[string]any{} // Scan all job directories for _, bucket := range []string{"running", "pending", "finished", "failed"} { @@ -411,9 +403,9 @@ func (h *Handler) GetExperimentHistoryHTTP(w http.ResponseWriter, r *http.Reques h.logger.Info("getting experiment history", "experiment", experimentID, "all_users", allUsers) // Placeholder response - response := map[string]interface{}{ + response := map[string]any{ "experiment_id": experimentID, - "history": []map[string]interface{}{ + "history": []map[string]any{ { "timestamp": time.Now().UTC(), "event": "run_started", @@ -440,7 +432,9 @@ func (h *Handler) GetExperimentHistoryHTTP(w http.ResponseWriter, r *http.Reques } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + h.logger.Warn("failed to encode compare response", "error", err) + } } // ListAllJobsHTTP handles GET /api/jobs?all_users=true for team collaboration @@ -455,7 +449,7 @@ func (h *Handler) ListAllJobsHTTP(w http.ResponseWriter, r *http.Request) { } jobPaths := storage.NewJobPaths(base) - jobs := []map[string]interface{}{} + jobs := []map[string]any{} // Scan all job directories for _, bucket := range []string{"running", "pending", "finished", "failed"} { @@ -500,7 +494,7 @@ func (h *Handler) ListAllJobsHTTP(w http.ResponseWriter, r *http.Request) { } } - response := map[string]interface{}{ + response := map[string]any{ "success": true, "jobs": jobs, "count": len(jobs), @@ -508,5 +502,7 @@ func (h *Handler) ListAllJobsHTTP(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + h.logger.Warn("failed to encode jobs list", "error", err) + } } diff --git a/internal/api/jupyter/handlers.go b/internal/api/jupyter/handlers.go index a179d2b..713181c 100644 --- a/internal/api/jupyter/handlers.go +++ b/internal/api/jupyter/handlers.go @@ -254,3 +254,7 @@ func (h *Handler) ListServicesHTTP(w http.ResponseWriter, r *http.Request) { func (h *Handler) StartServiceHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "Not implemented", http.StatusNotImplemented) } + +func (h *Handler) StopServiceHTTP(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Not mplementated", http.StatusNotImplemented) +} diff --git a/internal/api/routes.go b/internal/api/routes.go index 82f852f..15ea872 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -6,6 +6,7 @@ import ( "github.com/jfraeys/fetch_ml/internal/api/audit" "github.com/jfraeys/fetch_ml/internal/api/datasets" + "github.com/jfraeys/fetch_ml/internal/api/groups" "github.com/jfraeys/fetch_ml/internal/api/jobs" "github.com/jfraeys/fetch_ml/internal/api/jupyter" "github.com/jfraeys/fetch_ml/internal/api/plugins" @@ -46,7 +47,6 @@ func (s *Server) registerRoutes(mux *http.ServeMux) { s.taskQueue, s.db, s.config.BuildAuthConfig(), - nil, ) // Experiment history endpoint: GET /api/experiments/:id/history @@ -173,7 +173,6 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) { s.taskQueue, s.db, s.config.BuildAuthConfig(), - nil, // privacyEnforcer - not enabled for now ) // Create jupyter handler @@ -190,6 +189,9 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) { s.config.DataDir, ) + // Create groups handler + groupsHandler := groups.NewHandler(s.db, s.logger) + wsHandler := ws.NewHandler( s.config.BuildAuthConfig(), s.logger, @@ -203,6 +205,7 @@ func (s *Server) registerWebSocketRoutes(mux *http.ServeMux) { jobsHandler, jupyterHandler, datasetsHandler, + groupsHandler, ) mux.Handle("/ws", wsHandler) diff --git a/internal/api/scheduler/handlers.go b/internal/api/scheduler/handlers.go index e6eb7f9..ce7bbf5 100644 --- a/internal/api/scheduler/handlers.go +++ b/internal/api/scheduler/handlers.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "net/http" - "strconv" "time" "github.com/jfraeys/fetch_ml/internal/auth" @@ -124,7 +123,9 @@ func (h *APIHandler) GetV1SchedulerStatus(w http.ResponseWriter, r *http.Request } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(status) + if err := json.NewEncoder(w).Encode(status); err != nil { + h.logger.Warn("failed to encode status", "error", err) + } } // GetV1SchedulerWorkers handles GET /v1/scheduler/workers @@ -159,7 +160,9 @@ func (h *APIHandler) GetV1SchedulerWorkers(w http.ResponseWriter, r *http.Reques } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(workers) + if err := json.NewEncoder(w).Encode(workers); err != nil { + h.logger.Warn("failed to encode workers", "error", err) + } } // GetV1SchedulerWorkersWorkerID handles GET /v1/scheduler/workers/{workerId} @@ -200,7 +203,9 @@ func (h *APIHandler) GetV1SchedulerWorkersWorkerID(w http.ResponseWriter, r *htt } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(worker) + if err := json.NewEncoder(w).Encode(worker); err != nil { + h.logger.Warn("failed to encode worker", "error", err) + } } // DeleteV1SchedulerWorkersWorkerID handles DELETE /v1/scheduler/workers/{workerId} @@ -247,7 +252,9 @@ func (h *APIHandler) GetV1SchedulerReservations(w http.ResponseWriter, r *http.R reservations := []ReservationInfo{} w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(reservations) + if err := json.NewEncoder(w).Encode(reservations); err != nil { + h.logger.Warn("failed to encode reservations", "error", err) + } } // PostV1SchedulerReservations handles POST /v1/scheduler/reservations @@ -295,7 +302,9 @@ func (h *APIHandler) PostV1SchedulerReservations(w http.ResponseWriter, r *http. w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) - json.NewEncoder(w).Encode(reservation) + if err := json.NewEncoder(w).Encode(reservation); err != nil { + h.logger.Warn("failed to encode reservation", "error", err) + } } // PatchV1SchedulerJobsJobIDPriority handles PATCH /v1/scheduler/jobs/{jobId}/priority @@ -342,7 +351,9 @@ func (h *APIHandler) PatchV1SchedulerJobsJobIDPriority(w http.ResponseWriter, r } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + h.logger.Warn("failed to encode priority response", "error", err) + } } // GetV1SchedulerStatusStream handles GET /v1/scheduler/status/stream (SSE) @@ -465,16 +476,3 @@ func (h *APIHandler) checkPermission(user *auth.User, permission string) bool { // Check specific permission return user.HasPermission(permission) } - -// parseIntQueryParam parses an integer query parameter -func parseIntQueryParam(r *http.Request, name string, defaultVal int) int { - str := r.URL.Query().Get(name) - if str == "" { - return defaultVal - } - val, err := strconv.Atoi(str) - if err != nil { - return defaultVal - } - return val -} diff --git a/internal/api/tokens/handlers.go b/internal/api/tokens/handlers.go new file mode 100644 index 0000000..0996b77 --- /dev/null +++ b/internal/api/tokens/handlers.go @@ -0,0 +1,197 @@ +// Package tokens provides HTTP handlers for share token management +package tokens + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/jfraeys/fetch_ml/internal/auth" + "github.com/jfraeys/fetch_ml/internal/logging" + "github.com/jfraeys/fetch_ml/internal/storage" +) + +// Handler provides share token management HTTP handlers +type Handler struct { + db *storage.DB + logger *logging.Logger +} + +// NewHandler creates a new share tokens handler +func NewHandler(db *storage.DB, logger *logging.Logger) *Handler { + return &Handler{ + db: db, + logger: logger, + } +} + +// CreateTokenRequest represents a request to create a new share token +type CreateTokenRequest struct { + TaskID *string `json:"task_id,omitempty"` + ExperimentID *string `json:"experiment_id,omitempty"` + ExpiresIn *int `json:"expires_in_days,omitempty"` // Number of days until expiry + MaxAccesses *int `json:"max_accesses,omitempty"` // Max number of accesses allowed +} + +// CreateTokenResponse represents the response with the created token +type CreateTokenResponse struct { + Token string `json:"token"` + ShareLink string `json:"share_link"` +} + +// CreateShareToken handles POST /api/tokens +// Creates a new share token for a task or experiment +func (h *Handler) CreateShareToken(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + // Check if user has permission to share tasks + if !user.HasPermission(auth.PermissionTasksShare) { + http.Error(w, "forbidden: insufficient permissions to share tasks", http.StatusForbidden) + return + } + + var req CreateTokenRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + // Must specify either task_id or experiment_id, not both + if (req.TaskID == nil && req.ExperimentID == nil) || (req.TaskID != nil && req.ExperimentID != nil) { + http.Error(w, "must specify exactly one of task_id or experiment_id", http.StatusBadRequest) + return + } + + // Build options + opts := auth.ShareTokenOptions{} + if req.ExpiresIn != nil { + expiresAt := time.Now().Add(time.Duration(*req.ExpiresIn) * 24 * time.Hour) + opts.ExpiresAt = &expiresAt + } + if req.MaxAccesses != nil { + opts.MaxAccesses = req.MaxAccesses + } + + // Generate token + token, err := auth.GenerateShareToken(h.db, req.TaskID, req.ExperimentID, user.Name, opts) + if err != nil { + h.logger.Error("failed to generate share token", "error", err) + http.Error(w, "failed to create share token", http.StatusInternalServerError) + return + } + + // Build share link + var path string + if req.TaskID != nil { + path = "/api/tasks/" + *req.TaskID + } else { + path = "/api/experiments/" + *req.ExperimentID + } + shareLink := auth.BuildShareLink("", path, token) + + h.logger.Info("share token created", + "token", token[:8]+"...", + "created_by", user.Name, + "task_id", req.TaskID, + "experiment_id", req.ExperimentID, + ) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(CreateTokenResponse{ + Token: token, + ShareLink: shareLink, + }); err != nil { + h.logger.Error("failed to encode response", "error", err) + } +} + +// ListShareTokens handles GET /api/tokens +// Lists share tokens for a task or experiment +func (h *Handler) ListShareTokens(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + taskID := r.URL.Query().Get("task_id") + experimentID := r.URL.Query().Get("experiment_id") + + // Must specify either task_id or experiment_id + if taskID == "" && experimentID == "" { + http.Error(w, "must specify task_id or experiment_id", http.StatusBadRequest) + return + } + + var tokens []*storage.ShareToken + var err error + + if taskID != "" { + tokens, err = h.db.ListShareTokensForTask(taskID) + } else { + tokens, err = h.db.ListShareTokensForExperiment(experimentID) + } + + if err != nil { + h.logger.Error("failed to list share tokens", "error", err) + http.Error(w, "failed to list share tokens", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "tokens": tokens, + "count": len(tokens), + }); err != nil { + h.logger.Error("failed to encode response", "error", err) + } +} + +// RevokeShareToken handles DELETE /api/tokens/:token +// Revokes a share token +func (h *Handler) RevokeShareToken(w http.ResponseWriter, r *http.Request) { + user := auth.GetUserFromContext(r.Context()) + if user == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + token := r.PathValue("token") + if token == "" { + http.Error(w, "token required", http.StatusBadRequest) + return + } + + // Get token details to check ownership + t, err := h.db.GetShareToken(token) + if err != nil { + h.logger.Error("failed to get share token", "error", err) + http.Error(w, "token not found", http.StatusNotFound) + return + } + if t == nil { + http.Error(w, "token not found", http.StatusNotFound) + return + } + + // Only the creator or an admin can revoke + if t.CreatedBy != user.Name && !user.Admin { + http.Error(w, "forbidden: not token owner or admin", http.StatusForbidden) + return + } + + if err := h.db.DeleteShareToken(token); err != nil { + h.logger.Error("failed to delete share token", "error", err) + http.Error(w, "failed to revoke token", http.StatusInternalServerError) + return + } + + h.logger.Info("share token revoked", "token", token[:8]+"...", "revoked_by", user.Name) + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/ws/handler.go b/internal/api/ws/handler.go index 3de15da..2937578 100644 --- a/internal/api/ws/handler.go +++ b/internal/api/ws/handler.go @@ -27,6 +27,7 @@ import ( "github.com/jfraeys/fetch_ml/internal/storage" "github.com/jfraeys/fetch_ml/internal/api/datasets" + "github.com/jfraeys/fetch_ml/internal/api/groups" "github.com/jfraeys/fetch_ml/internal/api/jobs" jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter" ) @@ -75,6 +76,16 @@ const ( OpcodeQueryJob = 0x23 OpcodeQueryRunInfo = 0x28 + // Group management opcodes + OpcodeCreateGroup = 0x50 + OpcodeListGroups = 0x51 + OpcodeCreateInvitation = 0x52 + OpcodeListInvitations = 0x53 + OpcodeAcceptInvitation = 0x54 + OpcodeDeclineInvitation = 0x55 + OpcodeRemoveMember = 0x56 + OpcodeListGroupTasks = 0x57 + // OpcodeCompareRuns = 0x30 OpcodeFindRuns = 0x31 @@ -147,6 +158,7 @@ type Handler struct { authConfig *auth.Config jobsHandler *jobs.Handler jupyterHandler *jupyterj.Handler + groupsHandler *groups.Handler // NEW: groups handler for WebSocket upgrader websocket.Upgrader dataDir string clientsMu sync.RWMutex @@ -166,6 +178,7 @@ func NewHandler( jobsHandler *jobs.Handler, jupyterHandler *jupyterj.Handler, datasetsHandler *datasets.Handler, + groupsHandler *groups.Handler, ) *Handler { upgrader := createUpgrader(securityCfg) @@ -183,6 +196,7 @@ func NewHandler( jobsHandler: jobsHandler, jupyterHandler: jupyterHandler, datasetsHandler: datasetsHandler, + groupsHandler: groupsHandler, clients: make(map[*Client]bool), } } @@ -339,6 +353,22 @@ func (h *Handler) handleMessage(conn *websocket.Conn, payload []byte) error { return h.handleSetRunOutcome(conn, payload) case OpcodeQueryRunInfo: return h.handleQueryRunInfo(conn, payload) + case OpcodeCreateGroup: + return h.handleCreateGroup(conn, payload) + case OpcodeListGroups: + return h.handleListGroups(conn, payload) + case OpcodeCreateInvitation: + return h.handleCreateInvitation(conn, payload) + case OpcodeListInvitations: + return h.handleListInvitations(conn, payload) + case OpcodeAcceptInvitation: + return h.handleAcceptInvitation(conn, payload) + case OpcodeDeclineInvitation: + return h.handleDeclineInvitation(conn, payload) + case OpcodeRemoveMember: + return h.handleRemoveMember(conn, payload) + case OpcodeListGroupTasks: + return h.handleListGroupTasks(conn, payload) default: return h.sendErrorPacket(conn, ErrorCodeInvalidRequest, "unknown opcode", string(opcode)) } @@ -429,6 +459,54 @@ func (h *Handler) handleListJupyter(conn *websocket.Conn, payload []byte) error }) } +func (h *Handler) handleCreateGroup(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + return h.groupsHandler.HandleCreateGroup(conn, payload, user) + }) +} + +func (h *Handler) handleListGroups(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + return h.groupsHandler.HandleListGroups(conn, payload, user) + }) +} + +func (h *Handler) handleCreateInvitation(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + return h.groupsHandler.HandleCreateInvitation(conn, payload, user) + }) +} + +func (h *Handler) handleListInvitations(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + return h.groupsHandler.HandleListInvitations(conn, payload, user) + }) +} + +func (h *Handler) handleAcceptInvitation(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + return h.groupsHandler.HandleAcceptInvitation(conn, payload, user) + }) +} + +func (h *Handler) handleDeclineInvitation(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + return h.groupsHandler.HandleDeclineInvitation(conn, payload, user) + }) +} + +func (h *Handler) handleRemoveMember(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + return h.groupsHandler.HandleRemoveMember(conn, payload, user) + }) +} + +func (h *Handler) handleListGroupTasks(conn *websocket.Conn, payload []byte) error { + return h.withAuth(conn, payload, func(user *auth.User) error { + return h.groupsHandler.HandleListGroupTasks(conn, payload, user) + }) +} + func (h *Handler) handleLogMetric(conn *websocket.Conn, payload []byte) error { // Parse payload: [api_key_hash:16][metric_name_len:1][metric_name:var][value:8] if len(payload) < 16+1+8 { @@ -590,9 +668,15 @@ func selectDependencyManifest(filesPath string) (string, error) { for _, name := range []string{ "requirements.txt", "package.json", "Cargo.toml", "go.mod", "pom.xml", "build.gradle", } { - if _, err := os.Stat(filepath.Join(filesPath, name)); err == nil { + //nolint:gosec // G703: name is from hardcoded slice, not user input + fullPath := filepath.Join(filesPath, name) + _, err := os.Stat(fullPath) + if err == nil { return name, nil } + if !os.IsNotExist(err) { + return "", fmt.Errorf("error checking %s: %w", name, err) + } } return "", fmt.Errorf("no dependency manifest found") } diff --git a/internal/auth/api_key.go b/internal/auth/api_key.go index 9d8afbf..1eadfc4 100644 --- a/internal/auth/api_key.go +++ b/internal/auth/api_key.go @@ -19,6 +19,7 @@ type User struct { Name string `json:"name"` Roles []string `json:"roles"` Admin bool `json:"admin"` + Groups []string `json:"groups"` // NEW: lab groups ("ml-lab", "nlp-group") } // ExtractAPIKeyFromRequest extracts an API key from the standard headers. diff --git a/internal/auth/database.go b/internal/auth/database.go index 69d305a..1c36342 100644 --- a/internal/auth/database.go +++ b/internal/auth/database.go @@ -194,7 +194,11 @@ func (s *DatabaseAuthStore) ListUsers(ctx context.Context) ([]APIKeyRecord, erro if err != nil { return nil, fmt.Errorf("failed to query users: %w", err) } - defer func() { _ = rows.Close() }() + defer func() { + if err := rows.Close(); err != nil { + log.Printf("ERROR: failed to close rows: %v", err) + } + }() var users []APIKeyRecord for rows.Next() { diff --git a/internal/auth/flags.go b/internal/auth/flags.go index 277906c..88dce52 100644 --- a/internal/auth/flags.go +++ b/internal/auth/flags.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "log" + "log/slog" "os" "strings" @@ -52,7 +53,8 @@ func GetAPIKeyFromSources(flags *Flags) string { if readErr == nil { return strings.TrimSpace(string(contents)) } - log.Printf("Warning: Could not read API key file %s: %v", flags.APIKeyFile, readErr) + // Use structured logging to prevent log injection + slog.Warn("Could not read API key file", "file", flags.APIKeyFile, "error", readErr) } // 3. Environment variable @@ -66,7 +68,8 @@ func GetAPIKeyFromSources(flags *Flags) string { if err == nil { return strings.TrimSpace(string(content)) } - log.Printf("Warning: Could not read API key file %s: %v", fileKey, err) + // Use structured logging to prevent log injection + slog.Warn("Could not read API key file", "file", fileKey, "error", err) } return "" @@ -84,7 +87,8 @@ func ValidateFlags(flags *Flags) error { return fmt.Errorf("api key file not found: %s", flags.APIKeyFile) } if err := CheckConfigFilePermissions(flags.APIKeyFile); err != nil { - log.Printf("Warning: %v", err) + // Use structured logging to prevent log injection + slog.Warn("API key file permission issue", "error", err) } } @@ -96,7 +100,8 @@ func ValidateFlags(flags *Flags) error { // Check file permissions if err := CheckConfigFilePermissions(flags.ConfigFile); err != nil { - log.Printf("Warning: %v", err) + // Use structured logging to prevent log injection + slog.Warn("Config file permission issue", "error", err) } } diff --git a/internal/auth/permissions.go b/internal/auth/permissions.go index 63149b3..ac6a33c 100644 --- a/internal/auth/permissions.go +++ b/internal/auth/permissions.go @@ -3,6 +3,8 @@ package auth import ( "fmt" "strings" + + "github.com/jfraeys/fetch_ml/internal/domain" ) // Permission constants for type safety @@ -31,6 +33,13 @@ const ( PermissionSystemLogs = "system:logs" PermissionSystemUsers = "system:users" + // Task sharing permissions + PermissionTasksReadOwn = "tasks:read:own" + PermissionTasksReadLab = "tasks:read:lab" + PermissionTasksReadAll = "tasks:read:all" + PermissionTasksShare = "tasks:share" + PermissionTasksClone = "tasks:clone" + // Wildcard permission PermissionAll = "*" ) @@ -210,3 +219,59 @@ func (u *User) GetEffectivePermissions() []string { return permissions } + +// TaskAccessChecker defines the interface for checking task access permissions. +// Implemented by storage.DB to avoid circular imports. +type TaskAccessChecker interface { + TaskIsSharedWithUser(taskID, userID string) bool + UserSharesGroupWithTask(userID, taskID string) bool + TaskAllowsPublicClone(taskID string) bool + UserRoleInTaskGroups(userID, taskID string) string // NEW: returns highest role (admin/member/viewer) +} + +// CanAccessTask checks if an authenticated user can access a specific task. +// For unauthenticated (token-based) access, use CanAccessWithToken() instead. +// NOTE: This method makes 1-2 DB calls per invocation. Never call it in a loop +// over a task list — use ListTasksForUser() instead, which resolves all access +// in a single query. +func (u *User) CanAccessTask(task *domain.Task, db TaskAccessChecker) bool { + if task.UserID == u.Name || task.CreatedBy == u.Name { + return true + } + if u.Admin { + return true + } + // Explicit share — expiry checked in SQL + if db.TaskIsSharedWithUser(task.ID, u.Name) { + return true + } + switch task.Visibility { + case domain.VisibilityPrivate: + return false + case domain.VisibilityLab: + return db.UserSharesGroupWithTask(u.Name, task.ID) + case domain.VisibilityInstitution, domain.VisibilityOpen: + // open tasks are accessible to authenticated users too, not only token holders + return u.HasPermission(PermissionTasksReadLab) + } + return false +} + +// CanCloneTask checks if a user can clone (reproduce) a task. +// Clone is a distinct action from read. Viewing a task does not imply +// permission to reproduce it under a different account. +func (u *User) CanCloneTask(task *domain.Task, db TaskAccessChecker) bool { + if !u.CanAccessTask(task, db) { + return false + } + if task.Visibility == domain.VisibilityOpen { + return db.TaskAllowsPublicClone(task.ID) + } + // NEW: viewer-role members can see but not clone + // role == "" means access is via explicit task_shares (not group) - allow clone + role := db.UserRoleInTaskGroups(u.Name, task.ID) + if role == "viewer" { + return false + } + return u.HasPermission(PermissionTasksClone) +} diff --git a/internal/auth/tokens.go b/internal/auth/tokens.go new file mode 100644 index 0000000..430b0a8 --- /dev/null +++ b/internal/auth/tokens.go @@ -0,0 +1,64 @@ +package auth + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "time" + + "github.com/jfraeys/fetch_ml/internal/storage" +) + +// ShareTokenOptions provides options for creating share tokens +type ShareTokenOptions struct { + ExpiresAt *time.Time // nil = never expires + MaxAccesses *int // nil = unlimited +} + +// GenerateShareToken creates a new cryptographically random share token. +// Returns the token string that can be shared with others. +func GenerateShareToken(db *storage.DB, taskID, experimentID *string, createdBy string, opts ShareTokenOptions) (string, error) { + // Generate 32 bytes of random data + raw := make([]byte, 32) + if _, err := rand.Read(raw); err != nil { + return "", fmt.Errorf("failed to generate random token: %w", err) + } + + // Encode as base64url (URL-safe, no padding) + token := base64.RawURLEncoding.EncodeToString(raw) + + // Store in database + if err := db.CreateShareToken(token, taskID, experimentID, createdBy, opts.ExpiresAt, opts.MaxAccesses); err != nil { + return "", fmt.Errorf("failed to store share token: %w", err) + } + + return token, nil +} + +// CanAccessWithToken checks if a token grants access to a task or experiment. +// This is called for unauthenticated access via signed share links. +// Returns true if the token is valid, not expired, and within access limits. +func CanAccessWithToken(db *storage.DB, token string, taskID, experimentID *string) bool { + t, err := db.ValidateShareToken(token, taskID, experimentID) + if err != nil { + return false + } + if t == nil { + return false + } + + // Increment access count + if err := db.IncrementTokenAccessCount(token); err != nil { + // Log error but don't fail the request - access was already validated + // This is a best-effort operation for tracking purposes + fmt.Printf("WARNING: failed to increment token access count: %v\n", err) + } + + return true +} + +// BuildShareLink constructs a shareable URL from a token. +// The path should be the base API path (e.g., "/api/tasks/123" or "/api/experiments/456"). +func BuildShareLink(baseURL, path, token string) string { + return fmt.Sprintf("%s%s?token=%s", baseURL, path, token) +}