package storage import ( "context" "database/sql" "fmt" "log" "os" "time" ) // Group represents a lab group type Group struct { ID string `json:"id"` Name string `json:"name"` Description string `json:"description,omitempty"` CreatedAt time.Time `json:"created_at"` CreatedBy string `json:"created_by"` } // GroupInvitation represents a pending invitation type GroupInvitation struct { ID string `json:"id"` GroupID string `json:"group_id"` InvitedUserID string `json:"invited_user_id"` InvitedBy string `json:"invited_by"` Status string `json:"status"` CreatedAt time.Time `json:"created_at"` ExpiresAt *time.Time `json:"expires_at,omitempty"` } // CreateGroup creates a new lab group func (db *DB) CreateGroup(name, description, createdBy string) (*Group, error) { id := fmt.Sprintf("group_%d", time.Now().UnixNano()) var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)` } else { query = `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)` } _, err := db.conn.ExecContext(context.Background(), query, id, name, description, createdBy) if err != nil { return nil, fmt.Errorf("failed to create group: %w", err) } // Add creator as admin if db.dbType == DBTypeSQLite { query = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'admin')` } else { query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'admin')` } _, err = db.conn.ExecContext(context.Background(), query, id, createdBy) if err != nil { return nil, fmt.Errorf("failed to add creator as admin: %w", err) } return &Group{ ID: id, Name: name, Description: description, CreatedBy: createdBy, CreatedAt: time.Now(), }, nil } // ListGroupsForUser returns all groups the user is a member of func (db *DB) ListGroupsForUser(userID string) ([]Group, error) { var query string if db.dbType == DBTypeSQLite { query = ` SELECT g.id, g.name, g.description, g.created_at, g.created_by FROM groups g JOIN group_members gm ON gm.group_id = g.id WHERE gm.user_id = ? ORDER BY g.created_at DESC` } else { query = ` SELECT g.id, g.name, g.description, g.created_at, g.created_by FROM groups g JOIN group_members gm ON gm.group_id = g.id WHERE gm.user_id = $1 ORDER BY g.created_at DESC` } rows, err := db.conn.QueryContext(context.Background(), query, userID) if err != nil { return nil, fmt.Errorf("failed to list groups: %w", err) } defer func() { if err := rows.Close(); err != nil { log.Printf("ERROR: failed to close rows: %v", err) } }() var groups []Group for rows.Next() { var g Group var createdAt sql.NullTime err := rows.Scan(&g.ID, &g.Name, &g.Description, &createdAt, &g.CreatedBy) if err != nil { return nil, fmt.Errorf("failed to scan group: %w", err) } if createdAt.Valid { g.CreatedAt = createdAt.Time } groups = append(groups, g) } if err = rows.Err(); err != nil { return nil, fmt.Errorf("error iterating groups: %w", err) } return groups, nil } // IsGroupAdmin checks if user is an admin of the group func (db *DB) IsGroupAdmin(userID, groupID string) (bool, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT COUNT(*) FROM group_members WHERE group_id = ? AND user_id = ? AND role = 'admin'` } else { query = `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND user_id = $2 AND role = 'admin'` } var count int err := db.conn.QueryRowContext(context.Background(), query, groupID, userID).Scan(&count) if err != nil { return false, fmt.Errorf("failed to check admin status: %w", err) } return count > 0, nil } // IsGroupMember checks if user is a member of the group func (db *DB) IsGroupMember(userID, groupID string) (bool, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT COUNT(*) FROM group_members WHERE group_id = ? AND user_id = ?` } else { query = `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND user_id = $2` } var count int err := db.conn.QueryRowContext(context.Background(), query, groupID, userID).Scan(&count) if err != nil { return false, fmt.Errorf("failed to check membership: %w", err) } return count > 0, nil } // CreateGroupInvitation creates a new group invitation func (db *DB) CreateGroupInvitation(groupID, invitedUserID, invitedBy string) (*GroupInvitation, error) { id := fmt.Sprintf("inv_%d", time.Now().UnixNano()) expiresAt := time.Now().Add(7 * 24 * time.Hour) // 7 days var query string if db.dbType == DBTypeSQLite { query = `INSERT INTO group_invitations (id, group_id, invited_user_id, invited_by, expires_at) VALUES (?, ?, ?, ?, ?)` } else { query = `INSERT INTO group_invitations (id, group_id, invited_user_id, invited_by, expires_at) VALUES ($1, $2, $3, $4, $5)` } _, err := db.conn.ExecContext(context.Background(), query, id, groupID, invitedUserID, invitedBy, expiresAt) if err != nil { return nil, fmt.Errorf("failed to create invitation: %w", err) } return &GroupInvitation{ ID: id, GroupID: groupID, InvitedUserID: invitedUserID, InvitedBy: invitedBy, Status: "pending", CreatedAt: time.Now(), ExpiresAt: &expiresAt, }, nil } // GetInvitation retrieves an invitation by ID func (db *DB) GetInvitation(invitationID string) (*GroupInvitation, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE id = ?` } else { query = `SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE id = $1` } var inv GroupInvitation var createdAt, expiresAt sql.NullTime err := db.conn.QueryRowContext(context.Background(), query, invitationID).Scan( &inv.ID, &inv.GroupID, &inv.InvitedUserID, &inv.InvitedBy, &inv.Status, &createdAt, &expiresAt, ) if err != nil { if err == sql.ErrNoRows { return nil, fmt.Errorf("invitation not found") } return nil, fmt.Errorf("failed to get invitation: %w", err) } if createdAt.Valid { inv.CreatedAt = createdAt.Time } if expiresAt.Valid { inv.ExpiresAt = &expiresAt.Time } return &inv, nil } // ListPendingInvitationsForUser returns pending invitations for a user func (db *DB) ListPendingInvitationsForUser(userID string) ([]GroupInvitation, error) { var query string if db.dbType == DBTypeSQLite { query = ` SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE invited_user_id = ? AND status = 'pending' ORDER BY created_at DESC` } else { query = ` SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE invited_user_id = $1 AND status = 'pending' ORDER BY created_at DESC` } rows, err := db.conn.QueryContext(context.Background(), query, userID) if err != nil { return nil, fmt.Errorf("failed to list invitations: %w", err) } defer func() { if err := rows.Close(); err != nil { log.Printf("ERROR: failed to close rows: %v", err) } }() var invitations []GroupInvitation for rows.Next() { var inv GroupInvitation var createdAt, expiresAt sql.NullTime err := rows.Scan(&inv.ID, &inv.GroupID, &inv.InvitedUserID, &inv.InvitedBy, &inv.Status, &createdAt, &expiresAt) if err != nil { return nil, fmt.Errorf("failed to scan invitation: %w", err) } if createdAt.Valid { inv.CreatedAt = createdAt.Time } if expiresAt.Valid { inv.ExpiresAt = &expiresAt.Time } invitations = append(invitations, inv) } if err = rows.Err(); err != nil { return nil, fmt.Errorf("error iterating invitations: %w", err) } return invitations, nil } // AcceptInvitation accepts a group invitation func (db *DB) AcceptInvitation(invitationID, userID string) error { tx, err := db.conn.BeginTx(context.Background(), nil) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } defer func() { if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { log.Printf("ERROR: failed to rollback transaction: %v", err) } }() // Get invitation details var groupID string var query string if db.dbType == DBTypeSQLite { query = `SELECT group_id FROM group_invitations WHERE id = ? AND invited_user_id = ? AND status = 'pending'` } else { query = `SELECT group_id FROM group_invitations WHERE id = $1 AND invited_user_id = $2 AND status = 'pending'` } err = tx.QueryRowContext(context.Background(), query, invitationID, userID).Scan(&groupID) if err != nil { return fmt.Errorf("invitation not found or already processed: %w", err) } // Add user to group if db.dbType == DBTypeSQLite { query = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'member')` } else { query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'member')` } _, err = tx.ExecContext(context.Background(), query, groupID, userID) if err != nil { return fmt.Errorf("failed to add member: %w", err) } // Update invitation status if db.dbType == DBTypeSQLite { query = `UPDATE group_invitations SET status = 'accepted' WHERE id = ?` } else { query = `UPDATE group_invitations SET status = 'accepted' WHERE id = $1` } _, err = tx.ExecContext(context.Background(), query, invitationID) if err != nil { return fmt.Errorf("failed to update invitation: %w", err) } return tx.Commit() } // DeclineInvitation declines a group invitation func (db *DB) DeclineInvitation(invitationID, userID string) error { var query string if db.dbType == DBTypeSQLite { query = `UPDATE group_invitations SET status = 'declined' WHERE id = ? AND invited_user_id = ?` } else { query = `UPDATE group_invitations SET status = 'declined' WHERE id = $1 AND invited_user_id = $2` } _, err := db.conn.ExecContext(context.Background(), query, invitationID, userID) if err != nil { return fmt.Errorf("failed to decline invitation: %w", err) } return nil } // RemoveGroupMember removes a member from a group func (db *DB) RemoveGroupMember(groupID, userID string) error { var query string if db.dbType == DBTypeSQLite { query = `DELETE FROM group_members WHERE group_id = ? AND user_id = ?` } else { query = `DELETE FROM group_members WHERE group_id = $1 AND user_id = $2` } _, err := db.conn.ExecContext(context.Background(), query, groupID, userID) if err != nil { return fmt.Errorf("failed to remove member: %w", err) } return nil } // UserRoleInTaskGroups returns the highest role (admin > member > viewer) the user // holds across all groups associated with the task. Returns empty string if no access. func (db *DB) UserRoleInTaskGroups(userID, taskID string) string { var query string if db.dbType == DBTypeSQLite { query = ` SELECT gm.role FROM group_members gm JOIN task_group_access tga ON tga.group_id = gm.group_id WHERE gm.user_id = ? AND tga.task_id = ? ORDER BY CASE gm.role WHEN 'admin' THEN 1 WHEN 'member' THEN 2 WHEN 'viewer' THEN 3 END LIMIT 1` } else { query = ` SELECT gm.role FROM group_members gm JOIN task_group_access tga ON tga.group_id = gm.group_id WHERE gm.user_id = $1 AND tga.task_id = $2 ORDER BY CASE gm.role WHEN 'admin' THEN 1 WHEN 'member' THEN 2 WHEN 'viewer' THEN 3 END LIMIT 1` } var role string err := db.conn.QueryRowContext(context.Background(), query, userID, taskID).Scan(&role) if err != nil { return "" } return role } // GetOrCreateDefaultLabGroup creates the auto-provisioned lab group if it doesn't exist. // Returns the group ID. If no DEFAULT_LAB_GROUP env var is set, returns empty string. func (db *DB) GetOrCreateDefaultLabGroup(createdBy string) (string, error) { groupName := os.Getenv("DEFAULT_LAB_GROUP") if groupName == "" { return "", nil } // Use transaction to prevent race conditions during concurrent startup tx, err := db.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}) if err != nil { return "", fmt.Errorf("failed to begin transaction: %w", err) } defer func() { if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { log.Printf("ERROR: failed to rollback transaction: %v", err) } }() // Check if group exists var groupID string var query string if db.dbType == DBTypeSQLite { query = `SELECT id FROM groups WHERE name = ?` } else { query = `SELECT id FROM groups WHERE name = $1` } err = tx.QueryRowContext(context.Background(), query, groupName).Scan(&groupID) if err == nil { if commitErr := tx.Commit(); commitErr != nil { return "", fmt.Errorf("failed to commit transaction: %w", commitErr) } return groupID, nil } if err != sql.ErrNoRows { return "", fmt.Errorf("failed to check for default lab group: %w", err) } // Create the group within the transaction id := fmt.Sprintf("group_%d", time.Now().UnixNano()) if db.dbType == DBTypeSQLite { query = `INSERT INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)` } else { query = `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)` } _, err = tx.ExecContext(context.Background(), query, id, groupName, "Auto-provisioned default lab group", createdBy) if err != nil { return "", fmt.Errorf("failed to create default lab group: %w", err) } // Add creator as admin within the transaction if db.dbType == DBTypeSQLite { query = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'admin')` } else { query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'admin')` } _, err = tx.ExecContext(context.Background(), query, id, createdBy) if err != nil { return "", fmt.Errorf("failed to add creator as admin: %w", err) } if err := tx.Commit(); err != nil { return "", fmt.Errorf("failed to commit transaction: %w", err) } return id, nil } // EnsureUserInGroup adds a user to a group if not already a member. // Default role is 'member'. Returns nil if already a member. func (db *DB) EnsureUserInGroup(groupID, userID string, role string) error { if role == "" { role = "member" } var query string if db.dbType == DBTypeSQLite { query = `INSERT OR IGNORE INTO group_members (group_id, user_id, role) VALUES (?, ?, ?)` } else { query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, $3) ON CONFLICT (group_id, user_id) DO NOTHING` } _, err := db.conn.ExecContext(context.Background(), query, groupID, userID, role) if err != nil { return fmt.Errorf("failed to add user to group: %w", err) } return nil } // EnsureAllUsersGroup creates the 'all-users' system group if it doesn't exist. // This group is used for institution visibility. func (db *DB) EnsureAllUsersGroup() (string, error) { const groupID = "all-users" const groupName = "all-users" var query string if db.dbType == DBTypeSQLite { query = `INSERT OR IGNORE INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)` } else { query = `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4) ON CONFLICT (id) DO NOTHING` } _, err := db.conn.ExecContext(context.Background(), query, groupID, groupName, "System group: all authenticated users", "system") if err != nil { return "", fmt.Errorf("failed to ensure all-users group: %w", err) } return groupID, nil } // EnsureUserInAllUsersGroup adds a user to the 'all-users' system group. func (db *DB) EnsureUserInAllUsersGroup(userID string) error { allUsersGroupID, err := db.EnsureAllUsersGroup() if err != nil { return err } return db.EnsureUserInGroup(allUsersGroupID, userID, "member") } // ProvisionUserOnFirstLogin adds a new user to the default groups: // 1. The 'all-users' system group (for institution visibility) // 2. The DEFAULT_LAB_GROUP if configured (for lab visibility) // This should be called when a user first authenticates/logs in. func (db *DB) ProvisionUserOnFirstLogin(userID string) error { // Add to all-users system group first if err := db.EnsureUserInAllUsersGroup(userID); err != nil { return fmt.Errorf("failed to add user to all-users group: %w", err) } // Add to default lab group if configured defaultGroupID, err := db.GetOrCreateDefaultLabGroup("system") if err != nil { return fmt.Errorf("failed to get/create default lab group: %w", err) } if defaultGroupID != "" { if err := db.EnsureUserInGroup(defaultGroupID, userID, "member"); err != nil { return fmt.Errorf("failed to add user to default lab group: %w", err) } } return nil } // DeprovisionUser removes a user from all groups when they are deactivated/deleted. // This prevents deactivated users from retaining institution-visibility access. func (db *DB) DeprovisionUser(userID string) error { var query string if db.dbType == DBTypeSQLite { query = `DELETE FROM group_members WHERE user_id = ?` } else { query = `DELETE FROM group_members WHERE user_id = $1` } _, err := db.conn.ExecContext(context.Background(), query, userID) if err != nil { return fmt.Errorf("failed to remove user from groups: %w", err) } return nil }