fetch_ml/internal/storage/db_groups.go
Jeremie Fraeys fbcf4d38e5
feat(storage): add groups, tasks, tokens, and audit database schemas
Add comprehensive database storage layer for new features:

- db_groups.go: Lab group management with members, roles (admin/member/viewer),
  and group-based task visibility queries

- db_tasks.go: Task visibility system (private/lab/institution/open),
  task sharing with expiry, public clone tokens, and optimized
  ListTasksForUser() for access control

- db_tokens.go: Secure token management for public task access and cloning,
  with SHA-256 hashed token storage and automatic cleanup

- db_audit.go: Audit log persistence with checkpoint chains, tamper
  detection, and log rotation support

- schema_sqlite.sql: Updated schema with:
  - groups, group_members tables
  - tasks.visibility enum, task_shares with expiry
  - access_tokens table with hashed tokens
  - audit_logs, audit_checkpoints tables
  - indexes for all foreign keys and query patterns

- db_experiments.go: Add CascadeVisibilityToTasks() for propagating
  visibility changes from experiments to associated tasks
2026-03-08 12:48:42 -04:00

541 lines
17 KiB
Go

package storage
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"time"
)
// Group represents a lab group.
//
// Membership is not stored on the struct itself; it lives in the
// group_members table (queried by ListGroupsForUser, IsGroupMember, etc.).
// IDs generated by CreateGroup and GetOrCreateDefaultLabGroup have the form
// "group_<unix-nanoseconds>"; the system group from EnsureAllUsersGroup uses
// the fixed ID "all-users".
type Group struct {
	ID          string    `json:"id"`
	Name        string    `json:"name"`
	Description string    `json:"description,omitempty"` // may be empty
	CreatedAt   time.Time `json:"created_at"`            // zero value if the stored column is NULL
	CreatedBy   string    `json:"created_by"`            // creator's user ID, or "system" for auto-provisioned groups
}
// GroupInvitation represents an invitation for a user to join a group.
//
// Status values written by this file are "pending" (as reported at creation),
// "accepted" and "declined". ExpiresAt is set to 7 days after creation by
// CreateGroupInvitation and is nil when the stored expires_at is NULL.
type GroupInvitation struct {
	ID            string     `json:"id"`
	GroupID       string     `json:"group_id"`
	InvitedUserID string     `json:"invited_user_id"` // user who may accept/decline
	InvitedBy     string     `json:"invited_by"`      // user who issued the invitation
	Status        string     `json:"status"`
	CreatedAt     time.Time  `json:"created_at"`
	ExpiresAt     *time.Time `json:"expires_at,omitempty"`
}
// CreateGroup creates a new lab group and registers createdBy as its first
// member with the 'admin' role.
//
// The group row and the creator's admin membership are written in a single
// transaction, so a failure part-way through cannot leave behind a group that
// has no admin. The returned Group's CreatedAt is the application clock after
// commit; the stored created_at column is not read back (presumably it is
// filled by a column default — confirm against the schema).
func (db *DB) CreateGroup(name, description, createdBy string) (*Group, error) {
	ctx := context.Background()
	id := fmt.Sprintf("group_%d", time.Now().UnixNano())

	insertGroup := `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)`
	insertAdmin := `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'admin')`
	if db.dbType == DBTypeSQLite {
		insertGroup = `INSERT INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)`
		insertAdmin = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'admin')`
	}

	tx, err := db.conn.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback after a successful Commit returns sql.ErrTxDone, which is expected.
	defer func() {
		if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
			log.Printf("ERROR: failed to rollback transaction: %v", err)
		}
	}()

	if _, err := tx.ExecContext(ctx, insertGroup, id, name, description, createdBy); err != nil {
		return nil, fmt.Errorf("failed to create group: %w", err)
	}
	if _, err := tx.ExecContext(ctx, insertAdmin, id, createdBy); err != nil {
		return nil, fmt.Errorf("failed to add creator as admin: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("failed to commit transaction: %w", err)
	}

	return &Group{
		ID:          id,
		Name:        name,
		Description: description,
		CreatedBy:   createdBy,
		CreatedAt:   time.Now(),
	}, nil
}
// ListGroupsForUser returns every group userID belongs to (in any role),
// ordered by group creation time, newest first. A user with no memberships
// yields a nil slice and a nil error.
//
// A NULL description is returned as "" and a NULL created_at leaves CreatedAt
// at the zero time, so a single incomplete row cannot abort the whole listing.
func (db *DB) ListGroupsForUser(userID string) ([]Group, error) {
	var query string
	if db.dbType == DBTypeSQLite {
		query = `
SELECT g.id, g.name, g.description, g.created_at, g.created_by
FROM groups g
JOIN group_members gm ON gm.group_id = g.id
WHERE gm.user_id = ?
ORDER BY g.created_at DESC`
	} else {
		query = `
SELECT g.id, g.name, g.description, g.created_at, g.created_by
FROM groups g
JOIN group_members gm ON gm.group_id = g.id
WHERE gm.user_id = $1
ORDER BY g.created_at DESC`
	}

	rows, err := db.conn.QueryContext(context.Background(), query, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to list groups: %w", err)
	}
	defer func() {
		if err := rows.Close(); err != nil {
			log.Printf("ERROR: failed to close rows: %v", err)
		}
	}()

	var groups []Group
	for rows.Next() {
		var (
			g           Group
			description sql.NullString
			createdAt   sql.NullTime
		)
		if err := rows.Scan(&g.ID, &g.Name, &description, &createdAt, &g.CreatedBy); err != nil {
			return nil, fmt.Errorf("failed to scan group: %w", err)
		}
		g.Description = description.String // "" when NULL
		if createdAt.Valid {
			g.CreatedAt = createdAt.Time
		}
		groups = append(groups, g)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating groups: %w", err)
	}
	return groups, nil
}
// IsGroupAdmin reports whether userID has a group_members row with role
// 'admin' in groupID. A database failure is returned as an error together
// with false.
func (db *DB) IsGroupAdmin(userID, groupID string) (bool, error) {
	query := `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND user_id = $2 AND role = 'admin'`
	if db.dbType == DBTypeSQLite {
		query = `SELECT COUNT(*) FROM group_members WHERE group_id = ? AND user_id = ? AND role = 'admin'`
	}

	var admins int
	row := db.conn.QueryRowContext(context.Background(), query, groupID, userID)
	if err := row.Scan(&admins); err != nil {
		return false, fmt.Errorf("failed to check admin status: %w", err)
	}
	return admins > 0, nil
}
// IsGroupMember reports whether userID has any group_members row (in any
// role) for groupID. A database failure is returned as an error together with
// false.
func (db *DB) IsGroupMember(userID, groupID string) (bool, error) {
	query := `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND user_id = $2`
	if db.dbType == DBTypeSQLite {
		query = `SELECT COUNT(*) FROM group_members WHERE group_id = ? AND user_id = ?`
	}

	var memberships int
	row := db.conn.QueryRowContext(context.Background(), query, groupID, userID)
	if err := row.Scan(&memberships); err != nil {
		return false, fmt.Errorf("failed to check membership: %w", err)
	}
	return memberships > 0, nil
}
// CreateGroupInvitation stores an invitation from invitedBy for invitedUserID
// to join groupID, expiring 7 days from now, and returns it.
//
// The status column is not written explicitly; the returned struct reports
// "pending", which presumably matches the column default — confirm against
// the schema. This function itself does not check that the group exists,
// that invitedBy may invite, or that the user is already a member.
func (db *DB) CreateGroupInvitation(groupID, invitedUserID, invitedBy string) (*GroupInvitation, error) {
	const invitationTTL = 7 * 24 * time.Hour

	invitationID := fmt.Sprintf("inv_%d", time.Now().UnixNano())
	expiry := time.Now().Add(invitationTTL)

	stmt := `INSERT INTO group_invitations (id, group_id, invited_user_id, invited_by, expires_at) VALUES ($1, $2, $3, $4, $5)`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO group_invitations (id, group_id, invited_user_id, invited_by, expires_at) VALUES (?, ?, ?, ?, ?)`
	}

	if _, err := db.conn.ExecContext(context.Background(), stmt, invitationID, groupID, invitedUserID, invitedBy, expiry); err != nil {
		return nil, fmt.Errorf("failed to create invitation: %w", err)
	}

	invitation := &GroupInvitation{
		ID:            invitationID,
		GroupID:       groupID,
		InvitedUserID: invitedUserID,
		InvitedBy:     invitedBy,
		Status:        "pending",
		CreatedAt:     time.Now(),
		ExpiresAt:     &expiry,
	}
	return invitation, nil
}
// GetInvitation loads a single invitation by ID, whatever its status.
//
// When no row matches it returns an error with the message
// "invitation not found" (it does not wrap sql.ErrNoRows). A NULL created_at
// leaves CreatedAt at the zero time, and a NULL expires_at leaves ExpiresAt nil.
func (db *DB) GetInvitation(invitationID string) (*GroupInvitation, error) {
	query := `SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE id = $1`
	if db.dbType == DBTypeSQLite {
		query = `SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE id = ?`
	}

	var (
		inv                  GroupInvitation
		createdAt, expiresAt sql.NullTime
	)
	row := db.conn.QueryRowContext(context.Background(), query, invitationID)
	switch err := row.Scan(&inv.ID, &inv.GroupID, &inv.InvitedUserID, &inv.InvitedBy, &inv.Status, &createdAt, &expiresAt); {
	case err == sql.ErrNoRows:
		return nil, fmt.Errorf("invitation not found")
	case err != nil:
		return nil, fmt.Errorf("failed to get invitation: %w", err)
	}

	if createdAt.Valid {
		inv.CreatedAt = createdAt.Time
	}
	if expiresAt.Valid {
		inv.ExpiresAt = &expiresAt.Time
	}
	return &inv, nil
}
// ListPendingInvitationsForUser returns the invitations addressed to userID
// that are still actionable: status 'pending' and not past their expires_at.
// Results are ordered newest first; a user with none yields a nil slice.
//
// Invitations whose expires_at is NULL are treated as never expiring. The
// expiry filter runs in Go rather than SQL so the comparison does not depend
// on how each backend stores timestamps.
func (db *DB) ListPendingInvitationsForUser(userID string) ([]GroupInvitation, error) {
	var query string
	if db.dbType == DBTypeSQLite {
		query = `
SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at
FROM group_invitations
WHERE invited_user_id = ? AND status = 'pending'
ORDER BY created_at DESC`
	} else {
		query = `
SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at
FROM group_invitations
WHERE invited_user_id = $1 AND status = 'pending'
ORDER BY created_at DESC`
	}

	rows, err := db.conn.QueryContext(context.Background(), query, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to list invitations: %w", err)
	}
	defer func() {
		if err := rows.Close(); err != nil {
			log.Printf("ERROR: failed to close rows: %v", err)
		}
	}()

	now := time.Now()
	var invitations []GroupInvitation
	for rows.Next() {
		var inv GroupInvitation
		var createdAt, expiresAt sql.NullTime
		if err := rows.Scan(&inv.ID, &inv.GroupID, &inv.InvitedUserID, &inv.InvitedBy, &inv.Status, &createdAt, &expiresAt); err != nil {
			return nil, fmt.Errorf("failed to scan invitation: %w", err)
		}
		if expiresAt.Valid {
			if now.After(expiresAt.Time) {
				// Past its expiry the invitation is no longer actionable, so it
				// should not be surfaced as pending.
				continue
			}
			inv.ExpiresAt = &expiresAt.Time
		}
		if createdAt.Valid {
			inv.CreatedAt = createdAt.Time
		}
		invitations = append(invitations, inv)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating invitations: %w", err)
	}
	return invitations, nil
}
// AcceptInvitation accepts a pending, unexpired invitation addressed to userID
// and adds that user to the invitation's group with the 'member' role.
//
// The lookup, membership insert and status update run in one transaction, so
// either all take effect or none do. If the user already has a membership row
// for the group (e.g. added via EnsureUserInGroup), the existing row and its
// role are kept and the invitation is still marked accepted.
//
// An error is returned when the invitation does not exist, is not addressed to
// userID, is no longer pending, or has expired, or when a database operation
// fails.
func (db *DB) AcceptInvitation(invitationID, userID string) error {
	ctx := context.Background()

	var selectQuery, insertQuery, updateQuery string
	if db.dbType == DBTypeSQLite {
		selectQuery = `SELECT group_id, expires_at FROM group_invitations WHERE id = ? AND invited_user_id = ? AND status = 'pending'`
		insertQuery = `INSERT OR IGNORE INTO group_members (group_id, user_id, role) VALUES (?, ?, 'member')`
		updateQuery = `UPDATE group_invitations SET status = 'accepted' WHERE id = ?`
	} else {
		selectQuery = `SELECT group_id, expires_at FROM group_invitations WHERE id = $1 AND invited_user_id = $2 AND status = 'pending'`
		insertQuery = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'member')
ON CONFLICT (group_id, user_id) DO NOTHING`
		updateQuery = `UPDATE group_invitations SET status = 'accepted' WHERE id = $1`
	}

	tx, err := db.conn.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback after a successful Commit returns sql.ErrTxDone, which is expected.
	defer func() {
		if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
			log.Printf("ERROR: failed to rollback transaction: %v", err)
		}
	}()

	var groupID string
	var expiresAt sql.NullTime
	if err := tx.QueryRowContext(ctx, selectQuery, invitationID, userID).Scan(&groupID, &expiresAt); err != nil {
		return fmt.Errorf("invitation not found or already processed: %w", err)
	}
	// Enforce the expiry recorded at creation. Compared in Go so the check does
	// not depend on backend-specific timestamp storage; NULL means no expiry.
	if expiresAt.Valid && time.Now().After(expiresAt.Time) {
		return fmt.Errorf("invitation %s has expired", invitationID)
	}

	// Conflict-tolerant insert (same pattern as EnsureUserInGroup): an existing
	// membership must not make accepting the invitation fail.
	if _, err := tx.ExecContext(ctx, insertQuery, groupID, userID); err != nil {
		return fmt.Errorf("failed to add member: %w", err)
	}
	if _, err := tx.ExecContext(ctx, updateQuery, invitationID); err != nil {
		return fmt.Errorf("failed to update invitation: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}
// DeclineInvitation marks a pending invitation addressed to userID as
// 'declined'.
//
// Only invitations still in 'pending' status are affected, so an invitation
// that was already accepted (and whose membership exists) cannot be flipped to
// 'declined' afterwards. If no pending invitation matches invitationID and
// userID, an error is returned, mirroring AcceptInvitation.
func (db *DB) DeclineInvitation(invitationID, userID string) error {
	var query string
	if db.dbType == DBTypeSQLite {
		query = `UPDATE group_invitations SET status = 'declined' WHERE id = ? AND invited_user_id = ? AND status = 'pending'`
	} else {
		query = `UPDATE group_invitations SET status = 'declined' WHERE id = $1 AND invited_user_id = $2 AND status = 'pending'`
	}

	res, err := db.conn.ExecContext(context.Background(), query, invitationID, userID)
	if err != nil {
		return fmt.Errorf("failed to decline invitation: %w", err)
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to decline invitation: %w", err)
	}
	if affected == 0 {
		return fmt.Errorf("invitation not found or already processed")
	}
	return nil
}
// RemoveGroupMember deletes userID's membership row for groupID.
//
// Removing a user who is not a member is not an error (the affected row count
// is not checked), and nothing here prevents removing a group's last admin.
func (db *DB) RemoveGroupMember(groupID, userID string) error {
	stmt := `DELETE FROM group_members WHERE group_id = $1 AND user_id = $2`
	if db.dbType == DBTypeSQLite {
		stmt = `DELETE FROM group_members WHERE group_id = ? AND user_id = ?`
	}

	if _, err := db.conn.ExecContext(context.Background(), stmt, groupID, userID); err != nil {
		return fmt.Errorf("failed to remove member: %w", err)
	}
	return nil
}
// UserRoleInTaskGroups returns the highest role (admin > member > viewer) the
// user holds across all groups linked to the task through task_group_access.
// It returns "" if the user has no role there.
//
// The function fails closed: on a database error it also returns "", but the
// error is logged so outages are not mistaken for "no access". Roles other
// than the three known ones rank below 'viewer'.
func (db *DB) UserRoleInTaskGroups(userID, taskID string) string {
	// ELSE 4 keeps unknown roles last on every backend: without it the CASE
	// yields NULL, which sorts first in ascending order on SQLite but last on
	// PostgreSQL.
	var query string
	if db.dbType == DBTypeSQLite {
		query = `
SELECT gm.role FROM group_members gm
JOIN task_group_access tga ON tga.group_id = gm.group_id
WHERE gm.user_id = ? AND tga.task_id = ?
ORDER BY CASE gm.role
WHEN 'admin' THEN 1
WHEN 'member' THEN 2
WHEN 'viewer' THEN 3
ELSE 4
END
LIMIT 1`
	} else {
		query = `
SELECT gm.role FROM group_members gm
JOIN task_group_access tga ON tga.group_id = gm.group_id
WHERE gm.user_id = $1 AND tga.task_id = $2
ORDER BY CASE gm.role
WHEN 'admin' THEN 1
WHEN 'member' THEN 2
WHEN 'viewer' THEN 3
ELSE 4
END
LIMIT 1`
	}

	var role string
	err := db.conn.QueryRowContext(context.Background(), query, userID, taskID).Scan(&role)
	switch {
	case err == nil:
		return role
	case err == sql.ErrNoRows:
		return ""
	default:
		log.Printf("ERROR: failed to look up role of user %q for task %q: %v", userID, taskID, err)
		return ""
	}
}
// GetOrCreateDefaultLabGroup returns the ID of the lab group named by the
// DEFAULT_LAB_GROUP environment variable. If no group with that name exists,
// it creates one with createdBy as its admin. If DEFAULT_LAB_GROUP is unset or
// empty, it returns ("", nil) without touching the database.
//
// The lookup and creation run in one serializable transaction. If creation
// fails, for instance because another process created the group concurrently
// (surfacing as a serialization failure, a busy database, or a unique-name
// violation depending on backend and schema), the transaction is rolled back
// and the group is looked up once more. Concurrent callers therefore converge
// on the same group instead of failing. If the group is still not found, the
// original creation error is returned.
func (db *DB) GetOrCreateDefaultLabGroup(createdBy string) (string, error) {
	groupName := os.Getenv("DEFAULT_LAB_GROUP")
	if groupName == "" {
		return "", nil
	}
	ctx := context.Background()

	selectQuery := `SELECT id FROM groups WHERE name = $1`
	insertGroupQuery := `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)`
	insertAdminQuery := `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'admin')`
	if db.dbType == DBTypeSQLite {
		selectQuery = `SELECT id FROM groups WHERE name = ?`
		insertGroupQuery = `INSERT INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)`
		insertAdminQuery = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'admin')`
	}

	tx, err := db.conn.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return "", fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Safe to call more than once: after Commit or an earlier Rollback it
	// returns sql.ErrTxDone, which is ignored.
	rollback := func() {
		if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
			log.Printf("ERROR: failed to rollback transaction: %v", err)
		}
	}
	defer rollback()

	// Fast path: the group already exists.
	var groupID string
	err = tx.QueryRowContext(ctx, selectQuery, groupName).Scan(&groupID)
	if err == nil {
		if commitErr := tx.Commit(); commitErr != nil {
			return "", fmt.Errorf("failed to commit transaction: %w", commitErr)
		}
		return groupID, nil
	}
	if err != sql.ErrNoRows {
		return "", fmt.Errorf("failed to check for default lab group: %w", err)
	}

	// Create the group and its admin membership within the transaction.
	id := fmt.Sprintf("group_%d", time.Now().UnixNano())
	createErr := func() error {
		if _, err := tx.ExecContext(ctx, insertGroupQuery, id, groupName, "Auto-provisioned default lab group", createdBy); err != nil {
			return fmt.Errorf("failed to create default lab group: %w", err)
		}
		if _, err := tx.ExecContext(ctx, insertAdminQuery, id, createdBy); err != nil {
			return fmt.Errorf("failed to add creator as admin: %w", err)
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("failed to commit transaction: %w", err)
		}
		return nil
	}()
	if createErr == nil {
		return id, nil
	}

	// Release the transaction (and its connection) before re-reading. If the
	// pool is limited to one connection (common for SQLite), querying while
	// tx is still open would block.
	rollback()
	var existingID string
	if err := db.conn.QueryRowContext(ctx, selectQuery, groupName).Scan(&existingID); err == nil {
		return existingID, nil
	}
	return "", createErr
}
// EnsureUserInGroup adds userID to groupID with the given role unless a
// conflicting membership row already exists, in which case nothing changes
// and the existing role is kept. An empty role defaults to "member".
//
// SQLite uses INSERT OR IGNORE; other backends use
// ON CONFLICT (group_id, user_id) DO NOTHING, which assumes a unique
// constraint on that pair.
// NOTE(review): SQLite's OR IGNORE also suppresses NOT NULL/CHECK violations,
// so an invalid role may be silently dropped there but rejected elsewhere.
func (db *DB) EnsureUserInGroup(groupID, userID string, role string) error {
	effectiveRole := role
	if effectiveRole == "" {
		effectiveRole = "member"
	}

	stmt := `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, $3)
ON CONFLICT (group_id, user_id) DO NOTHING`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT OR IGNORE INTO group_members (group_id, user_id, role) VALUES (?, ?, ?)`
	}

	if _, err := db.conn.ExecContext(context.Background(), stmt, groupID, userID, effectiveRole); err != nil {
		return fmt.Errorf("failed to add user to group: %w", err)
	}
	return nil
}
// EnsureAllUsersGroup makes sure the 'all-users' system group (fixed ID and
// name "all-users", created_by "system") exists and returns its ID. This group
// is used for institution visibility.
//
// An existing row is left untouched: SQLite uses INSERT OR IGNORE, and other
// backends tolerate only a conflict on id via ON CONFLICT (id) DO NOTHING.
func (db *DB) EnsureAllUsersGroup() (string, error) {
	const (
		allUsersID          = "all-users"
		allUsersName        = "all-users"
		allUsersDescription = "System group: all authenticated users"
		allUsersCreator     = "system"
	)

	stmt := `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)
ON CONFLICT (id) DO NOTHING`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT OR IGNORE INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)`
	}

	_, err := db.conn.ExecContext(context.Background(), stmt,
		allUsersID, allUsersName, allUsersDescription, allUsersCreator)
	if err != nil {
		return "", fmt.Errorf("failed to ensure all-users group: %w", err)
	}
	return allUsersID, nil
}
// EnsureUserInAllUsersGroup makes sure the 'all-users' system group exists and
// that userID is in it with the 'member' role (an existing membership is kept
// as is). Errors from either step are returned unwrapped.
func (db *DB) EnsureUserInAllUsersGroup(userID string) error {
	groupID, err := db.EnsureAllUsersGroup()
	if err == nil {
		err = db.EnsureUserInGroup(groupID, userID, "member")
	}
	return err
}
// ProvisionUserOnFirstLogin adds a new user to the default groups:
//  1. the 'all-users' system group (for institution visibility), and
//  2. the group named by DEFAULT_LAB_GROUP, if that variable is set (for lab
//     visibility). The group is created with "system" as creator if missing.
//
// Both memberships use the 'member' role and keep any existing membership, so
// calling this more than once for the same user leaves memberships unchanged.
// It should be called when a user first authenticates/logs in.
func (db *DB) ProvisionUserOnFirstLogin(userID string) error {
	if err := db.EnsureUserInAllUsersGroup(userID); err != nil {
		return fmt.Errorf("failed to add user to all-users group: %w", err)
	}

	labGroupID, err := db.GetOrCreateDefaultLabGroup("system")
	switch {
	case err != nil:
		return fmt.Errorf("failed to get/create default lab group: %w", err)
	case labGroupID == "":
		// No default lab group configured.
		return nil
	}

	if err := db.EnsureUserInGroup(labGroupID, userID, "member"); err != nil {
		return fmt.Errorf("failed to add user to default lab group: %w", err)
	}
	return nil
}
// DeprovisionUser deletes every group_members row for userID, removing the user
// from all groups. This includes the 'all-users' system group, so a deactivated
// or deleted user does not retain institution-visibility access.
//
// Only memberships are touched; invitations and other rows referencing the
// user are left as they are. Deleting zero rows is not an error.
func (db *DB) DeprovisionUser(userID string) error {
	stmt := `DELETE FROM group_members WHERE user_id = $1`
	if db.dbType == DBTypeSQLite {
		stmt = `DELETE FROM group_members WHERE user_id = ?`
	}

	if _, err := db.conn.ExecContext(context.Background(), stmt, userID); err != nil {
		return fmt.Errorf("failed to remove user from groups: %w", err)
	}
	return nil
}