feat(storage): add groups, tasks, tokens, and audit database schemas

Add comprehensive database storage layer for new features:

- db_groups.go: Lab group management with members, roles (admin/member/viewer),
  and group-based task visibility queries

- db_tasks.go: Task visibility system (private/lab/institution/open),
  task sharing with expiry, public clone tokens, and optimized
  ListTasksForUser() for access control

- db_tokens.go: Secure token management for public task access and cloning,
  with SHA-256 hashed token storage and automatic cleanup

- db_audit.go: Audit log persistence with checkpoint chains, tamper
  detection, and log rotation support

- schema_sqlite.sql: Updated schema with:
  - groups, group_members tables
  - tasks.visibility enum, task_shares with expiry
  - access_tokens table with hashed tokens
  - audit_logs, audit_checkpoints tables
  - indexes for all foreign keys and query patterns

- db_experiments.go: Add CascadeExperimentVisibility() for propagating
  visibility changes from experiments to associated tasks
This commit is contained in:
Jeremie Fraeys 2026-03-08 12:48:42 -04:00
parent a239f3a14f
commit fbcf4d38e5
No known key found for this signature in database
7 changed files with 1765 additions and 1 deletions

View file

@ -0,0 +1,225 @@
package storage
import (
"context"
"database/sql"
"fmt"
"log"
"time"
)
// AuditLogEntry is one row read from the task_access_log table.
// The pointer fields are nil when the corresponding column is NULL
// (see scanAuditLogEntries).
type AuditLogEntry struct {
// ID is the row's id column.
ID int64 `json:"id"`
// TaskID identifies the task that was accessed.
TaskID string `json:"task_id"`
// UserID is the user_id column; nil when NULL.
UserID *string `json:"user_id,omitempty"`
// Token is the token column; nil when NULL.
Token *string `json:"token,omitempty"`
// Action is the action column exactly as stored.
Action string `json:"action"`
// AccessedAt is the accessed_at column; the audit queries order and prune by it.
AccessedAt time.Time `json:"accessed_at"`
// IPAddress is the ip_address column; nil when NULL.
IPAddress *string `json:"ip_address,omitempty"`
}
// LogTaskAccess records an access event for taskID in task_access_log.
//
// userID, token and ipAddress are optional: a nil pointer is stored as NULL.
// action is required even though it is passed as a pointer. The audit readers
// scan the action column into a plain string, which cannot hold NULL, so a
// nil action is rejected here instead of producing a row that can never be
// read back.
//
// NOTE(review): token is written exactly as received — confirm callers pass a
// hashed value rather than the raw secret.
func (db *DB) LogTaskAccess(taskID string, userID, token, action, ipAddress *string) error {
	if action == nil {
		return fmt.Errorf("failed to log task access: action is required")
	}

	// Non-SQLite backends use numbered $N placeholders; SQLite uses ?.
	query := `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address)
VALUES ($1, $2, $3, $4, $5)`
	if db.dbType == DBTypeSQLite {
		query = `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address)
VALUES (?, ?, ?, ?, ?)`
	}

	if _, err := db.conn.ExecContext(context.Background(), query, taskID, userID, token, action, ipAddress); err != nil {
		return fmt.Errorf("failed to log task access: %w", err)
	}
	return nil
}
// GetAuditLogForTask returns the audit log entries recorded for taskID,
// newest first (ordered by accessed_at descending), capped at limit rows.
func (db *DB) GetAuditLogForTask(taskID string, limit int) ([]*AuditLogEntry, error) {
	// Start from the numbered-placeholder form and swap in the ? form for SQLite.
	stmt := `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE task_id = $1
ORDER BY accessed_at DESC
LIMIT $2`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE task_id = ?
ORDER BY accessed_at DESC
LIMIT ?`
	}

	rows, err := db.conn.QueryContext(context.Background(), stmt, taskID, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to get audit log: %w", err)
	}
	// A close failure is only logged; it does not change the returned result.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			log.Printf("ERROR: failed to close rows: %v", closeErr)
		}
	}()

	return db.scanAuditLogEntries(rows)
}
// GetAuditLogForUser returns the audit log entries whose user_id equals
// userID, newest first, capped at limit rows.
func (db *DB) GetAuditLogForUser(userID string, limit int) ([]*AuditLogEntry, error) {
	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE user_id = ?
ORDER BY accessed_at DESC
LIMIT ?`
	default:
		stmt = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE user_id = $1
ORDER BY accessed_at DESC
LIMIT $2`
	}

	rows, err := db.conn.QueryContext(context.Background(), stmt, userID, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to get audit log: %w", err)
	}
	// Closing is best-effort: a failure is logged and otherwise ignored.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			log.Printf("ERROR: failed to close rows: %v", closeErr)
		}
	}()

	return db.scanAuditLogEntries(rows)
}
// GetAuditLogForToken returns the audit log entries whose token column equals
// token, newest first, capped at limit rows.
func (db *DB) GetAuditLogForToken(token string, limit int) ([]*AuditLogEntry, error) {
	ctx := context.Background()

	// Only the placeholder syntax differs between the two statements.
	stmt := `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE token = $1
ORDER BY accessed_at DESC
LIMIT $2`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
FROM task_access_log
WHERE token = ?
ORDER BY accessed_at DESC
LIMIT ?`
	}

	rows, err := db.conn.QueryContext(ctx, stmt, token, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to get audit log: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			log.Printf("ERROR: failed to close rows: %v", closeErr)
		}
	}()

	return db.scanAuditLogEntries(rows)
}
// scanAuditLogEntries reads every remaining row from rows into AuditLogEntry
// values. It expects the column order id, task_id, user_id, token, action,
// accessed_at, ip_address. The caller remains responsible for closing rows.
// On any scan or iteration error the partial result is discarded.
func (db *DB) scanAuditLogEntries(rows *sql.Rows) ([]*AuditLogEntry, error) {
	var result []*AuditLogEntry
	for rows.Next() {
		var (
			entry               = new(AuditLogEntry)
			user, tok, clientIP sql.NullString
		)
		if err := rows.Scan(&entry.ID, &entry.TaskID, &user, &tok, &entry.Action, &entry.AccessedAt, &clientIP); err != nil {
			return nil, fmt.Errorf("failed to scan audit log entry: %w", err)
		}

		// Nullable columns become nil pointers; the scratch variables are
		// fresh per iteration, so taking their address is safe.
		if user.Valid {
			entry.UserID = &user.String
		}
		if tok.Valid {
			entry.Token = &tok.String
		}
		if clientIP.Valid {
			entry.IPAddress = &clientIP.String
		}
		result = append(result, entry)
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating audit log entries: %w", err)
	}
	return result, nil
}
// DeleteOldAuditLogs removes audit log entries whose accessed_at is more than
// retentionDays days in the past and returns the number of rows deleted.
// This should be called by a nightly job.
//
// A negative retentionDays is rejected. On non-SQLite backends it would move
// the cutoff into the future (NOW() minus a negative interval), which matches
// every row and would wipe the whole audit log. Zero is allowed and deletes
// everything older than the current instant.
func (db *DB) DeleteOldAuditLogs(retentionDays int) (int64, error) {
	if retentionDays < 0 {
		return 0, fmt.Errorf("failed to delete old audit logs: retention days must not be negative (got %d)", retentionDays)
	}

	query := `DELETE FROM task_access_log
WHERE accessed_at < NOW() - INTERVAL '1 day' * $1`
	if db.dbType == DBTypeSQLite {
		query = `DELETE FROM task_access_log
WHERE accessed_at < datetime('now', '-' || ? || ' days')`
	}

	result, err := db.conn.ExecContext(context.Background(), query, retentionDays)
	if err != nil {
		return 0, fmt.Errorf("failed to delete old audit logs: %w", err)
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("failed to get rows affected: %w", err)
	}
	return affected, nil
}
// CountAuditLogs returns the number of rows currently in task_access_log.
func (db *DB) CountAuditLogs() (int64, error) {
	// The statement is the same for every backend, so no dialect switch is needed.
	const stmt = `SELECT COUNT(*) FROM task_access_log`

	var total int64
	if err := db.conn.QueryRowContext(context.Background(), stmt).Scan(&total); err != nil {
		return 0, fmt.Errorf("failed to count audit logs: %w", err)
	}
	return total, nil
}
// GetOldestAuditLogDate returns the smallest accessed_at value in
// task_access_log, or (nil, nil) when MIN yields NULL, as it does for an
// empty table.
//
// NOTE(review): some SQLite drivers hand back aggregate results as text
// rather than a time value; confirm the scan into sql.NullTime works on the
// SQLite backend in use.
func (db *DB) GetOldestAuditLogDate() (*time.Time, error) {
	// The statement is the same for every backend.
	const stmt = `SELECT MIN(accessed_at) FROM task_access_log`

	var oldest sql.NullTime
	if err := db.conn.QueryRowContext(context.Background(), stmt).Scan(&oldest); err != nil {
		return nil, fmt.Errorf("failed to get oldest audit log date: %w", err)
	}
	if oldest.Valid {
		return &oldest.Time, nil
	}
	return nil, nil
}

View file

@ -562,3 +562,170 @@ func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Da
}
return out, nil
}
// AssociateTaskWithExperiment inserts an (experiment_id, task_id) row into
// experiment_tasks. An existing link is left untouched, so repeated calls are
// harmless.
func (db *DB) AssociateTaskWithExperiment(taskID, experimentID string) error {
	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `INSERT OR IGNORE INTO experiment_tasks (experiment_id, task_id) VALUES (?, ?)`
	default:
		stmt = `INSERT INTO experiment_tasks (experiment_id, task_id) VALUES ($1, $2)
ON CONFLICT (experiment_id, task_id) DO NOTHING`
	}

	// Bind order follows the column list: experiment first, then task.
	if _, err := db.conn.ExecContext(context.Background(), stmt, experimentID, taskID); err != nil {
		return fmt.Errorf("failed to associate task with experiment: %w", err)
	}
	return nil
}
// GetExperimentVisibility reports the visibility associated with an
// experiment. The value is read from the most recently created job linked to
// the experiment through experiment_tasks. A NULL visibility, or an experiment
// with no linked jobs, yields "private".
func (db *DB) GetExperimentVisibility(experimentID string) (string, error) {
	stmt := `SELECT COALESCE(j.visibility, 'private')
FROM experiment_tasks et
JOIN jobs j ON j.id = et.task_id
WHERE et.experiment_id = $1
ORDER BY j.created_at DESC
LIMIT 1`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT COALESCE(j.visibility, 'private')
FROM experiment_tasks et
JOIN jobs j ON j.id = et.task_id
WHERE et.experiment_id = ?
ORDER BY j.created_at DESC
LIMIT 1`
	}

	var level string
	err := db.conn.QueryRowContext(context.Background(), stmt, experimentID).Scan(&level)
	switch {
	case err == nil:
		return level, nil
	case err == sql.ErrNoRows:
		// No linked jobs: fall back to the most restrictive level.
		return "private", nil
	default:
		return "", fmt.Errorf("failed to get experiment visibility: %w", err)
	}
}
// CascadeExperimentVisibility sets jobs.visibility to the given value for
// every job linked to experimentID through experiment_tasks. The visibility
// string is written as-is; no validation happens here.
func (db *DB) CascadeExperimentVisibility(experimentID, visibility string) error {
	stmt := `UPDATE jobs
SET visibility = $1
WHERE id IN (
SELECT task_id FROM experiment_tasks WHERE experiment_id = $2
)`
	if db.dbType == DBTypeSQLite {
		stmt = `UPDATE jobs
SET visibility = ?
WHERE id IN (
SELECT task_id FROM experiment_tasks WHERE experiment_id = ?
)`
	}

	// Bind order: the new visibility first, then the experiment filter.
	if _, err := db.conn.ExecContext(context.Background(), stmt, visibility, experimentID); err != nil {
		return fmt.Errorf("failed to cascade visibility: %w", err)
	}
	return nil
}
// ListTasksForExperiment returns one page of the jobs linked to experimentID,
// ordered by created_at and then id, both descending.
//
// opts.Limit is replaced by 100 when it is <= 0 or > 100. opts.Cursor is
// decoded with decodeCursor; its error result is ignored here. The second
// return value is the cursor for the next page, or "" when there are no
// further rows.
func (db *DB) ListTasksForExperiment(experimentID string, opts ListTasksOptions) ([]*Job, string, error) {
	pageSize := opts.Limit
	if pageSize <= 0 || pageSize > 100 {
		pageSize = 100
	}
	cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)

	stmt := `
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
FROM jobs j
JOIN experiment_tasks et ON et.task_id = j.id
WHERE et.experiment_id = $1
AND ($2 = '' OR (j.created_at::text || j.id) < $3)
ORDER BY j.created_at DESC, j.id DESC
LIMIT $4`
	if db.dbType == DBTypeSQLite {
		stmt = `
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
FROM jobs j
JOIN experiment_tasks et ON et.task_id = j.id
WHERE et.experiment_id = ?
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
ORDER BY j.created_at DESC, j.id DESC
LIMIT ?`
	}

	// Ask for one row beyond the page so the presence of a next page can be
	// detected. The bind arguments are the same for both dialects.
	rows, err := db.conn.QueryContext(context.Background(), stmt,
		experimentID, cursorCreatedAt, cursorCreatedAt+cursorID, pageSize+1)
	if err != nil {
		return nil, "", fmt.Errorf("failed to list experiment tasks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var jobs []*Job
	for rows.Next() {
		var (
			job                  = &Job{}
			created, updated     sql.NullTime
			started, ended       sql.NullTime
			rawDatasets, rawMeta []byte
		)
		if scanErr := rows.Scan(
			&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
			&rawDatasets, &rawMeta, &job.WorkerID, &job.Error,
			&created, &updated, &started, &ended,
		); scanErr != nil {
			return nil, "", fmt.Errorf("failed to scan job: %w", scanErr)
		}

		// NULL timestamps leave the corresponding Job field at its zero value.
		if created.Valid {
			job.CreatedAt = created.Time
		}
		if updated.Valid {
			job.UpdatedAt = updated.Time
		}
		if started.Valid {
			job.StartedAt = &started.Time
		}
		if ended.Valid {
			job.EndedAt = &ended.Time
		}

		// JSON decode failures are deliberately ignored; the field stays unset.
		if len(rawDatasets) > 0 {
			_ = json.Unmarshal(rawDatasets, &job.Datasets)
		}
		if len(rawMeta) > 0 {
			_ = json.Unmarshal(rawMeta, &job.Metadata)
		}
		jobs = append(jobs, job)
	}
	if err = rows.Err(); err != nil {
		return nil, "", fmt.Errorf("error iterating jobs: %w", err)
	}

	if len(jobs) <= pageSize {
		return jobs, "", nil
	}
	// The extra row proves another page exists; the cursor points at the last
	// row that is actually returned.
	last := jobs[pageSize-1]
	return jobs[:pageSize], encodeCursor(last.CreatedAt, last.ID), nil
}

View file

@ -0,0 +1,541 @@
package storage
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"time"
)
// Group represents a lab group, i.e. one row of the groups table.
type Group struct {
// ID is the group's identifier. CreateGroup generates it as
// "group_<UnixNano>"; the system group created by EnsureAllUsersGroup
// uses the fixed ID "all-users".
ID string `json:"id"`
// Name is the group's name; GetOrCreateDefaultLabGroup looks groups up by it.
Name string `json:"name"`
// Description is free-form text; omitted from JSON when empty.
Description string `json:"description,omitempty"`
// CreatedAt is the created_at column when read from the database.
CreatedAt time.Time `json:"created_at"`
// CreatedBy is the user ID that created the group ("system" for
// auto-provisioned groups).
CreatedBy string `json:"created_by"`
}
// GroupInvitation represents an invitation for a user to join a group,
// i.e. one row of the group_invitations table.
type GroupInvitation struct {
// ID is generated by CreateGroupInvitation as "inv_<UnixNano>".
ID string `json:"id"`
// GroupID is the group the user is invited to.
GroupID string `json:"group_id"`
// InvitedUserID is the user the invitation is addressed to.
InvitedUserID string `json:"invited_user_id"`
// InvitedBy is the user who issued the invitation.
InvitedBy string `json:"invited_by"`
// Status is "pending", "accepted" or "declined" — the values this file
// reads or writes.
Status string `json:"status"`
// CreatedAt is the created_at column when read from the database.
CreatedAt time.Time `json:"created_at"`
// ExpiresAt is nil when the expires_at column is NULL.
// CreateGroupInvitation sets it to seven days after creation.
ExpiresAt *time.Time `json:"expires_at,omitempty"`
}
// CreateGroup creates a new lab group and registers createdBy as its first
// admin member.
//
// Both inserts run in one transaction, so a failure while adding the admin
// membership cannot leave behind a group that nobody administers. The group
// ID is generated as "group_<UnixNano>". The returned CreatedAt is the local
// clock reading taken when the call started, not the value stored by the
// database.
func (db *DB) CreateGroup(name, description, createdBy string) (*Group, error) {
	ctx := context.Background()
	now := time.Now()
	id := fmt.Sprintf("group_%d", now.UnixNano())

	tx, err := db.conn.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback is a no-op (sql.ErrTxDone) once Commit has succeeded.
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			log.Printf("ERROR: failed to rollback transaction: %v", rbErr)
		}
	}()

	groupInsert := `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)`
	adminInsert := `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'admin')`
	if db.dbType == DBTypeSQLite {
		groupInsert = `INSERT INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)`
		adminInsert = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'admin')`
	}

	if _, err = tx.ExecContext(ctx, groupInsert, id, name, description, createdBy); err != nil {
		return nil, fmt.Errorf("failed to create group: %w", err)
	}
	// Add creator as admin
	if _, err = tx.ExecContext(ctx, adminInsert, id, createdBy); err != nil {
		return nil, fmt.Errorf("failed to add creator as admin: %w", err)
	}
	if err = tx.Commit(); err != nil {
		return nil, fmt.Errorf("failed to commit transaction: %w", err)
	}

	return &Group{
		ID:          id,
		Name:        name,
		Description: description,
		CreatedBy:   createdBy,
		CreatedAt:   now,
	}, nil
}
// ListGroupsForUser returns every group that has a group_members row for
// userID, most recently created first. The result is nil when the user
// belongs to no groups.
func (db *DB) ListGroupsForUser(userID string) ([]Group, error) {
	stmt := `
SELECT g.id, g.name, g.description, g.created_at, g.created_by
FROM groups g
JOIN group_members gm ON gm.group_id = g.id
WHERE gm.user_id = $1
ORDER BY g.created_at DESC`
	if db.dbType == DBTypeSQLite {
		stmt = `
SELECT g.id, g.name, g.description, g.created_at, g.created_by
FROM groups g
JOIN group_members gm ON gm.group_id = g.id
WHERE gm.user_id = ?
ORDER BY g.created_at DESC`
	}

	rows, err := db.conn.QueryContext(context.Background(), stmt, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to list groups: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			log.Printf("ERROR: failed to close rows: %v", closeErr)
		}
	}()

	var result []Group
	for rows.Next() {
		var (
			group   Group
			created sql.NullTime
		)
		if scanErr := rows.Scan(&group.ID, &group.Name, &group.Description, &created, &group.CreatedBy); scanErr != nil {
			return nil, fmt.Errorf("failed to scan group: %w", scanErr)
		}
		// A NULL created_at leaves CreatedAt at the zero time.
		if created.Valid {
			group.CreatedAt = created.Time
		}
		result = append(result, group)
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating groups: %w", err)
	}
	return result, nil
}
// IsGroupAdmin reports whether userID has a group_members row with role
// 'admin' in groupID.
func (db *DB) IsGroupAdmin(userID, groupID string) (bool, error) {
	stmt := `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND user_id = $2 AND role = 'admin'`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT COUNT(*) FROM group_members WHERE group_id = ? AND user_id = ? AND role = 'admin'`
	}

	// Note the bind order: the query filters on group first, then user.
	var matches int
	if err := db.conn.QueryRowContext(context.Background(), stmt, groupID, userID).Scan(&matches); err != nil {
		return false, fmt.Errorf("failed to check admin status: %w", err)
	}
	return matches > 0, nil
}
// IsGroupMember reports whether userID has any group_members row in groupID,
// regardless of role.
func (db *DB) IsGroupMember(userID, groupID string) (bool, error) {
	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `SELECT COUNT(*) FROM group_members WHERE group_id = ? AND user_id = ?`
	default:
		stmt = `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND user_id = $2`
	}

	// Note the bind order: the query filters on group first, then user.
	var matches int
	if err := db.conn.QueryRowContext(context.Background(), stmt, groupID, userID).Scan(&matches); err != nil {
		return false, fmt.Errorf("failed to check membership: %w", err)
	}
	return matches > 0, nil
}
// CreateGroupInvitation stores a new invitation for invitedUserID to join
// groupID, issued by invitedBy. The invitation ID is "inv_<UnixNano>" and the
// expiry is set seven days from now. The insert does not supply a status; the
// returned value reports "pending".
func (db *DB) CreateGroupInvitation(groupID, invitedUserID, invitedBy string) (*GroupInvitation, error) {
	const validity = 7 * 24 * time.Hour // 7 days

	invitationID := fmt.Sprintf("inv_%d", time.Now().UnixNano())
	expiry := time.Now().Add(validity)

	stmt := `INSERT INTO group_invitations (id, group_id, invited_user_id, invited_by, expires_at) VALUES ($1, $2, $3, $4, $5)`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO group_invitations (id, group_id, invited_user_id, invited_by, expires_at) VALUES (?, ?, ?, ?, ?)`
	}

	if _, err := db.conn.ExecContext(context.Background(), stmt, invitationID, groupID, invitedUserID, invitedBy, expiry); err != nil {
		return nil, fmt.Errorf("failed to create invitation: %w", err)
	}

	invitation := &GroupInvitation{
		ID:            invitationID,
		GroupID:       groupID,
		InvitedUserID: invitedUserID,
		InvitedBy:     invitedBy,
		Status:        "pending",
		CreatedAt:     time.Now(),
		ExpiresAt:     &expiry,
	}
	return invitation, nil
}
// GetInvitation loads the invitation with the given ID. A missing row is
// reported as the plain error "invitation not found" (sql.ErrNoRows is not
// wrapped); any other failure wraps the underlying error.
func (db *DB) GetInvitation(invitationID string) (*GroupInvitation, error) {
	stmt := `SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE id = $1`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE id = ?`
	}

	var (
		result           GroupInvitation
		created, expires sql.NullTime
	)
	err := db.conn.QueryRowContext(context.Background(), stmt, invitationID).Scan(
		&result.ID, &result.GroupID, &result.InvitedUserID, &result.InvitedBy, &result.Status, &created, &expires,
	)
	switch {
	case err == sql.ErrNoRows:
		return nil, fmt.Errorf("invitation not found")
	case err != nil:
		return nil, fmt.Errorf("failed to get invitation: %w", err)
	}

	// NULL timestamps leave the zero time / a nil pointer respectively.
	if created.Valid {
		result.CreatedAt = created.Time
	}
	if expires.Valid {
		result.ExpiresAt = &expires.Time
	}
	return &result, nil
}
// ListPendingInvitationsForUser returns the invitations addressed to userID
// whose status is 'pending', newest first. The query does not filter on
// expires_at, so pending invitations past their expiry are included.
func (db *DB) ListPendingInvitationsForUser(userID string) ([]GroupInvitation, error) {
	stmt := `
SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at
FROM group_invitations
WHERE invited_user_id = $1 AND status = 'pending'
ORDER BY created_at DESC`
	if db.dbType == DBTypeSQLite {
		stmt = `
SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at
FROM group_invitations
WHERE invited_user_id = ? AND status = 'pending'
ORDER BY created_at DESC`
	}

	rows, err := db.conn.QueryContext(context.Background(), stmt, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to list invitations: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			log.Printf("ERROR: failed to close rows: %v", closeErr)
		}
	}()

	var pending []GroupInvitation
	for rows.Next() {
		var (
			item             GroupInvitation
			created, expires sql.NullTime
		)
		if scanErr := rows.Scan(&item.ID, &item.GroupID, &item.InvitedUserID, &item.InvitedBy, &item.Status, &created, &expires); scanErr != nil {
			return nil, fmt.Errorf("failed to scan invitation: %w", scanErr)
		}
		if created.Valid {
			item.CreatedAt = created.Time
		}
		if expires.Valid {
			item.ExpiresAt = &expires.Time
		}
		pending = append(pending, item)
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating invitations: %w", err)
	}
	return pending, nil
}
// AcceptInvitation accepts a pending group invitation on behalf of userID.
//
// In one transaction it:
//  1. loads the invitation, which must belong to userID and still be pending;
//  2. rejects it if expires_at has passed (a NULL expiry never expires);
//  3. adds the user to the group with role 'member', leaving any existing
//     membership untouched;
//  4. marks the invitation 'accepted'.
//
// The expiry is compared in Go rather than in SQL so the check behaves the
// same on every backend regardless of how timestamps are stored.
func (db *DB) AcceptInvitation(invitationID, userID string) error {
	ctx := context.Background()

	tx, err := db.conn.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback is a no-op (sql.ErrTxDone) once Commit has succeeded.
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			log.Printf("ERROR: failed to rollback transaction: %v", rbErr)
		}
	}()

	lookup := `SELECT group_id, expires_at FROM group_invitations WHERE id = $1 AND invited_user_id = $2 AND status = 'pending'`
	addMember := `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'member')
ON CONFLICT (group_id, user_id) DO NOTHING`
	markAccepted := `UPDATE group_invitations SET status = 'accepted' WHERE id = $1`
	if db.dbType == DBTypeSQLite {
		lookup = `SELECT group_id, expires_at FROM group_invitations WHERE id = ? AND invited_user_id = ? AND status = 'pending'`
		addMember = `INSERT OR IGNORE INTO group_members (group_id, user_id, role) VALUES (?, ?, 'member')`
		markAccepted = `UPDATE group_invitations SET status = 'accepted' WHERE id = ?`
	}

	// Get invitation details
	var (
		groupID   string
		expiresAt sql.NullTime
	)
	if err = tx.QueryRowContext(ctx, lookup, invitationID, userID).Scan(&groupID, &expiresAt); err != nil {
		return fmt.Errorf("invitation not found or already processed: %w", err)
	}
	if expiresAt.Valid && !time.Now().Before(expiresAt.Time) {
		return fmt.Errorf("invitation has expired")
	}

	// Add user to group. The insert is idempotent so a user who is already a
	// member can still accept instead of leaving the invitation stuck pending.
	if _, err = tx.ExecContext(ctx, addMember, groupID, userID); err != nil {
		return fmt.Errorf("failed to add member: %w", err)
	}

	// Update invitation status
	if _, err = tx.ExecContext(ctx, markAccepted, invitationID); err != nil {
		return fmt.Errorf("failed to update invitation: %w", err)
	}
	return tx.Commit()
}
// DeclineInvitation marks a pending invitation addressed to userID as
// 'declined'.
//
// Only invitations still in the 'pending' state are updated, so an invitation
// that was already accepted cannot be flipped to 'declined' while the
// resulting membership remains. If no row matches (unknown ID, different
// user, or already processed), an error is returned.
func (db *DB) DeclineInvitation(invitationID, userID string) error {
	query := `UPDATE group_invitations SET status = 'declined' WHERE id = $1 AND invited_user_id = $2 AND status = 'pending'`
	if db.dbType == DBTypeSQLite {
		query = `UPDATE group_invitations SET status = 'declined' WHERE id = ? AND invited_user_id = ? AND status = 'pending'`
	}

	result, err := db.conn.ExecContext(context.Background(), query, invitationID, userID)
	if err != nil {
		return fmt.Errorf("failed to decline invitation: %w", err)
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to decline invitation: %w", err)
	}
	if affected == 0 {
		return fmt.Errorf("invitation not found or already processed")
	}
	return nil
}
// RemoveGroupMember deletes the group_members row for (groupID, userID).
// Removing a user who is not a member is not an error; the number of deleted
// rows is not inspected.
func (db *DB) RemoveGroupMember(groupID, userID string) error {
	stmt := `DELETE FROM group_members WHERE group_id = $1 AND user_id = $2`
	if db.dbType == DBTypeSQLite {
		stmt = `DELETE FROM group_members WHERE group_id = ? AND user_id = ?`
	}

	if _, err := db.conn.ExecContext(context.Background(), stmt, groupID, userID); err != nil {
		return fmt.Errorf("failed to remove member: %w", err)
	}
	return nil
}
// UserRoleInTaskGroups returns the highest role (admin > member > viewer) the
// user holds across all groups associated with the task via
// task_group_access. It returns the empty string if the user has no such
// membership.
//
// The function fails closed: any query error also yields "". Errors other
// than "no rows" are logged so that a broken query is not mistaken for a
// missing membership.
//
// The ranking CASE has an explicit ELSE so that an unrecognised role sorts
// after the known ones on every backend. Without it the rank is NULL, which
// SQLite orders first and PostgreSQL orders last in an ascending sort.
func (db *DB) UserRoleInTaskGroups(userID, taskID string) string {
	query := `
SELECT gm.role FROM group_members gm
JOIN task_group_access tga ON tga.group_id = gm.group_id
WHERE gm.user_id = $1 AND tga.task_id = $2
ORDER BY CASE gm.role
WHEN 'admin' THEN 1
WHEN 'member' THEN 2
WHEN 'viewer' THEN 3
ELSE 4
END
LIMIT 1`
	if db.dbType == DBTypeSQLite {
		query = `
SELECT gm.role FROM group_members gm
JOIN task_group_access tga ON tga.group_id = gm.group_id
WHERE gm.user_id = ? AND tga.task_id = ?
ORDER BY CASE gm.role
WHEN 'admin' THEN 1
WHEN 'member' THEN 2
WHEN 'viewer' THEN 3
ELSE 4
END
LIMIT 1`
	}

	var role string
	err := db.conn.QueryRowContext(context.Background(), query, userID, taskID).Scan(&role)
	if err != nil {
		if err != sql.ErrNoRows {
			log.Printf("ERROR: failed to look up user role in task groups: %v", err)
		}
		return ""
	}
	return role
}
// GetOrCreateDefaultLabGroup returns the ID of the group named by the
// DEFAULT_LAB_GROUP environment variable, creating it (with createdBy as its
// admin) when no group of that name exists. When the variable is unset or
// empty it returns "" and a nil error without touching the database.
//
// The lookup and the creation run in a single transaction requested at
// serializable isolation.
func (db *DB) GetOrCreateDefaultLabGroup(createdBy string) (string, error) {
	labName := os.Getenv("DEFAULT_LAB_GROUP")
	if labName == "" {
		return "", nil
	}

	ctx := context.Background()

	// Use transaction to prevent race conditions during concurrent startup
	tx, err := db.conn.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
	if err != nil {
		return "", fmt.Errorf("failed to begin transaction: %w", err)
	}
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			log.Printf("ERROR: failed to rollback transaction: %v", rbErr)
		}
	}()

	lookup := `SELECT id FROM groups WHERE name = $1`
	if db.dbType == DBTypeSQLite {
		lookup = `SELECT id FROM groups WHERE name = ?`
	}

	var existingID string
	switch err = tx.QueryRowContext(ctx, lookup, labName).Scan(&existingID); err {
	case nil:
		// The group is already there; finish the transaction and report it.
		if commitErr := tx.Commit(); commitErr != nil {
			return "", fmt.Errorf("failed to commit transaction: %w", commitErr)
		}
		return existingID, nil
	case sql.ErrNoRows:
		// Not found: continue to creation below.
	default:
		return "", fmt.Errorf("failed to check for default lab group: %w", err)
	}

	// Create the group within the transaction
	newID := fmt.Sprintf("group_%d", time.Now().UnixNano())
	groupInsert := `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)`
	if db.dbType == DBTypeSQLite {
		groupInsert = `INSERT INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)`
	}
	if _, err = tx.ExecContext(ctx, groupInsert, newID, labName, "Auto-provisioned default lab group", createdBy); err != nil {
		return "", fmt.Errorf("failed to create default lab group: %w", err)
	}

	// Add creator as admin within the transaction
	adminInsert := `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'admin')`
	if db.dbType == DBTypeSQLite {
		adminInsert = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'admin')`
	}
	if _, err = tx.ExecContext(ctx, adminInsert, newID, createdBy); err != nil {
		return "", fmt.Errorf("failed to add creator as admin: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return "", fmt.Errorf("failed to commit transaction: %w", err)
	}
	return newID, nil
}
// EnsureUserInGroup inserts a group_members row for (groupID, userID) with
// the given role, defaulting to "member" when role is empty. If the user is
// already a member the insert is skipped and nil is returned; the existing
// role is not changed.
func (db *DB) EnsureUserInGroup(groupID, userID string, role string) error {
	effectiveRole := role
	if effectiveRole == "" {
		effectiveRole = "member"
	}

	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `INSERT OR IGNORE INTO group_members (group_id, user_id, role) VALUES (?, ?, ?)`
	default:
		stmt = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, $3)
ON CONFLICT (group_id, user_id) DO NOTHING`
	}

	if _, err := db.conn.ExecContext(context.Background(), stmt, groupID, userID, effectiveRole); err != nil {
		return fmt.Errorf("failed to add user to group: %w", err)
	}
	return nil
}
// EnsureAllUsersGroup makes sure the 'all-users' system group exists and
// returns its ID. The group is inserted with created_by "system"; an existing
// row is left untouched. This group is used for institution visibility.
func (db *DB) EnsureAllUsersGroup() (string, error) {
	// The same literal serves as both the group's id and its name.
	const allUsers = "all-users"

	stmt := `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)
ON CONFLICT (id) DO NOTHING`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT OR IGNORE INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)`
	}

	if _, err := db.conn.ExecContext(context.Background(), stmt,
		allUsers, allUsers, "System group: all authenticated users", "system"); err != nil {
		return "", fmt.Errorf("failed to ensure all-users group: %w", err)
	}
	return allUsers, nil
}
// EnsureUserInAllUsersGroup makes sure the 'all-users' system group exists
// and that userID belongs to it with role "member". Errors from either step
// are returned unchanged.
func (db *DB) EnsureUserInAllUsersGroup(userID string) error {
	groupID, err := db.EnsureAllUsersGroup()
	if err == nil {
		err = db.EnsureUserInGroup(groupID, userID, "member")
	}
	return err
}
// ProvisionUserOnFirstLogin enrols a user in the default groups:
//  1. the 'all-users' system group (for institution visibility);
//  2. the DEFAULT_LAB_GROUP, if one is configured (for lab visibility).
//
// It is meant to be called when a user first authenticates. The steps are
// separate statements, so a failure in step 2 leaves step 1 in place.
func (db *DB) ProvisionUserOnFirstLogin(userID string) error {
	// The system group comes first.
	err := db.EnsureUserInAllUsersGroup(userID)
	if err != nil {
		return fmt.Errorf("failed to add user to all-users group: %w", err)
	}

	// An empty ID means no default lab group is configured.
	labGroupID, err := db.GetOrCreateDefaultLabGroup("system")
	if err != nil {
		return fmt.Errorf("failed to get/create default lab group: %w", err)
	}
	if labGroupID == "" {
		return nil
	}

	if err = db.EnsureUserInGroup(labGroupID, userID, "member"); err != nil {
		return fmt.Errorf("failed to add user to default lab group: %w", err)
	}
	return nil
}
// DeprovisionUser deletes every group_members row for userID. It is intended
// for deactivated or deleted users, so that they do not keep group-based
// (including institution-visibility) access.
func (db *DB) DeprovisionUser(userID string) error {
	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `DELETE FROM group_members WHERE user_id = ?`
	default:
		stmt = `DELETE FROM group_members WHERE user_id = $1`
	}

	if _, err := db.conn.ExecContext(context.Background(), stmt, userID); err != nil {
		return fmt.Errorf("failed to remove user from groups: %w", err)
	}
	return nil
}

View file

@ -163,6 +163,7 @@ func (db *DB) ListJobs(status string, limit int) ([]*Job, error) {
if db.dbType == DBTypeSQLite {
query += " LIMIT ?"
} else {
//nolint:gosec // G202: This builds a positional parameter $N where N is an internal counter, not user input
query += fmt.Sprintf(" LIMIT $%d", len(args)+1)
}
args = append(args, limit)

View file

@ -0,0 +1,498 @@
package storage
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"time"
)
// TaskIsSharedWithUser reports whether task_shares holds a row for
// (taskID, userID) that has no expiry or whose expiry is still in the future.
// It fails closed: a query error is reported as false.
func (db *DB) TaskIsSharedWithUser(taskID, userID string) bool {
	stmt := `SELECT COUNT(*) FROM task_shares
WHERE task_id = $1 AND user_id = $2
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT COUNT(*) FROM task_shares
WHERE task_id = ? AND user_id = ?
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)`
	}

	var shares int
	err := db.conn.QueryRowContext(context.Background(), stmt, taskID, userID).Scan(&shares)
	return err == nil && shares > 0
}
// UserSharesGroupWithTask reports whether the user currently belongs to at
// least one group linked to the task through task_group_access. Membership
// is read from group_members at call time. It fails closed: a query error is
// reported as false.
func (db *DB) UserSharesGroupWithTask(userID, taskID string) bool {
	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `SELECT COUNT(*) FROM group_members gm
JOIN task_group_access tga ON tga.group_id = gm.group_id
WHERE gm.user_id = ? AND tga.task_id = ?`
	default:
		stmt = `SELECT COUNT(*) FROM group_members gm
JOIN task_group_access tga ON tga.group_id = gm.group_id
WHERE gm.user_id = $1 AND tga.task_id = $2`
	}

	var sharedGroups int
	if err := db.conn.QueryRowContext(context.Background(), stmt, userID, taskID).Scan(&sharedGroups); err != nil {
		return false
	}
	return sharedGroups > 0
}
// TaskAllowsPublicClone reports whether the task may be cloned publicly.
// It is currently a stub: taskID is ignored, no query is made, and the result
// is always false, so public cloning is denied for every task until the real
// allow_public_clone check is implemented.
func (db *DB) TaskAllowsPublicClone(taskID string) bool {
// TODO: Implement proper allow_public_clone check
// For now, return false as a safe default
return false
}
// AssociateTaskWithGroup inserts a (task_id, group_id) row into
// task_group_access, recording that the task is shared with the group.
// An existing link is left untouched, so repeated calls are harmless.
// Called at submit time for lab-visibility tasks.
func (db *DB) AssociateTaskWithGroup(taskID, groupID string) error {
	stmt := `INSERT INTO task_group_access (task_id, group_id) VALUES ($1, $2)
ON CONFLICT (task_id, group_id) DO NOTHING`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT OR IGNORE INTO task_group_access (task_id, group_id) VALUES (?, ?)`
	}

	if _, err := db.conn.ExecContext(context.Background(), stmt, taskID, groupID); err != nil {
		return fmt.Errorf("failed to associate task with group: %w", err)
	}
	return nil
}
// CountOpenTasksForUserToday counts the jobs owned by userID that have
// visibility 'open' and were created on or after the start of the current
// day, as defined by the database (date('now') on SQLite, CURRENT_DATE
// otherwise).
//
// NOTE(review): the two backends may disagree on where "today" starts if
// their time zones differ — confirm this matches the intended quota window.
func (db *DB) CountOpenTasksForUserToday(userID string) (int, error) {
	stmt := `SELECT COUNT(*) FROM jobs
WHERE user_id = $1
AND visibility = 'open'
AND created_at >= CURRENT_DATE`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT COUNT(*) FROM jobs
WHERE user_id = ?
AND visibility = 'open'
AND created_at >= date('now')`
	}

	var openToday int
	if err := db.conn.QueryRowContext(context.Background(), stmt, userID).Scan(&openToday); err != nil {
		return 0, fmt.Errorf("failed to count open tasks: %w", err)
	}
	return openToday, nil
}
// ListTasksOptions carries the pagination inputs accepted by the task listing
// queries.
type ListTasksOptions struct {
	// Limit is the requested page size. The listing functions replace any
	// value <= 0 or > 100 with 100.
	Limit int
	// Cursor is the opaque position returned with the previous page (an
	// encoded created_at/id pair). Leave it empty to request the first page.
	Cursor string
}
// ListTasksForUser returns one page of the tasks userID is allowed to see,
// newest first, together with the cursor for the following page ("" when the
// returned page is the last one).
//
// Admins see every row in jobs. Everyone else gets the UNION of four access
// reasons, each written as its own SELECT so it can use its own index:
//   - jobs owned by the user,
//   - jobs shared with the user through a task_shares row that has not expired,
//   - 'lab' jobs attached via task_group_access to a group the user belongs to,
//   - jobs whose visibility is 'institution' or 'open'.
//
// opts.Limit values <= 0 or > 100 are replaced by 100. opts.Cursor must be
// empty or a cursor previously returned by a listing call; a malformed cursor
// is rejected with an error instead of silently restarting at the first page.
func (db *DB) ListTasksForUser(userID string, isAdmin bool, opts ListTasksOptions) ([]*Job, string, error) {
	// Enforce max limit.
	if opts.Limit <= 0 || opts.Limit > 100 {
		opts.Limit = 100
	}

	if isAdmin {
		// Name the columns explicitly, in the order the row scanner reads them.
		// "SELECT *" returns the table's columns in schema order (and more of
		// them than the 13 scan targets), which makes every Scan fail.
		return db.listTasksPaginated(
			`SELECT id, job_name, args, status, priority, datasets, metadata, worker_id, error, created_at, updated_at, started_at, ended_at FROM jobs`,
			opts,
		)
	}

	cursorCreatedAt, cursorID, err := decodeCursor(opts.Cursor)
	if err != nil {
		return nil, "", fmt.Errorf("invalid cursor: %w", err)
	}
	cursorKey := ""
	if opts.Cursor != "" {
		// The cursor carries an RFC3339 timestamp ("2006-01-02T15:04:05Z"),
		// while the SQL side renders created_at as "2006-01-02 15:04:05".
		// The two are compared as strings, so the cursor must be rewritten in
		// the SQL format; otherwise ' ' < 'T' lets every row of the cursor's
		// day through and pages repeat.
		ts, parseErr := time.Parse(time.RFC3339, cursorCreatedAt)
		if parseErr != nil {
			return nil, "", fmt.Errorf("invalid cursor timestamp: %w", parseErr)
		}
		cursorCreatedAt = ts.UTC().Format("2006-01-02 15:04:05")
		cursorKey = cursorCreatedAt + cursorID
	}

	// Explicit column list instead of j.* keeps all UNION branches identical
	// and matches the positional Scan below.
	cols := `j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at`

	var query string
	if db.dbType == DBTypeSQLite {
		query = fmt.Sprintf(`
			-- Owner
			SELECT %s FROM jobs j
			WHERE j.user_id = ?
			AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
			UNION
			-- Explicit share (expiry checked inline)
			SELECT %s FROM jobs j
			JOIN task_shares ts ON ts.task_id = j.id
			WHERE ts.user_id = ?
			AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP)
			AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
			UNION
			-- Lab group membership (live, not snapshotted)
			SELECT %s FROM jobs j
			JOIN task_group_access tga ON tga.task_id = j.id
			JOIN group_members gm ON gm.group_id = tga.group_id
			WHERE gm.user_id = ?
			AND j.visibility = 'lab'
			AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
			UNION
			-- Institution or open (all authenticated users)
			SELECT %s FROM jobs j
			WHERE j.visibility IN ('institution', 'open')
			AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
			ORDER BY created_at DESC, id DESC
			LIMIT ?
		`, cols, cols, cols, cols)
	} else {
		// PostgreSQL version with $N placeholders. to_char renders the
		// timestamp with second precision in the same layout as the cursor.
		// NOTE(review): for a timestamptz column to_char uses the session time
		// zone - confirm sessions run in UTC.
		query = fmt.Sprintf(`
			-- Owner
			SELECT %s FROM jobs j
			WHERE j.user_id = $1
			AND ($2 = '' OR (to_char(j.created_at, 'YYYY-MM-DD HH24:MI:SS') || j.id) < $3)
			UNION
			-- Explicit share (expiry checked inline)
			SELECT %s FROM jobs j
			JOIN task_shares ts ON ts.task_id = j.id
			WHERE ts.user_id = $4
			AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP)
			AND ($5 = '' OR (to_char(j.created_at, 'YYYY-MM-DD HH24:MI:SS') || j.id) < $6)
			UNION
			-- Lab group membership (live, not snapshotted)
			SELECT %s FROM jobs j
			JOIN task_group_access tga ON tga.task_id = j.id
			JOIN group_members gm ON gm.group_id = tga.group_id
			WHERE gm.user_id = $7
			AND j.visibility = 'lab'
			AND ($8 = '' OR (to_char(j.created_at, 'YYYY-MM-DD HH24:MI:SS') || j.id) < $9)
			UNION
			-- Institution or open (all authenticated users)
			SELECT %s FROM jobs j
			WHERE j.visibility IN ('institution', 'open')
			AND ($10 = '' OR (to_char(j.created_at, 'YYYY-MM-DD HH24:MI:SS') || j.id) < $11)
			ORDER BY created_at DESC, id DESC
			LIMIT $12
		`, cols, cols, cols, cols)
	}

	// Both dialects bind the same twelve values in the same order. Fetch one
	// row more than requested: its presence means another page exists.
	args := []interface{}{
		userID, cursorCreatedAt, cursorKey,
		userID, cursorCreatedAt, cursorKey,
		userID, cursorCreatedAt, cursorKey,
		cursorCreatedAt, cursorKey,
		opts.Limit + 1,
	}
	rows, err := db.conn.QueryContext(context.Background(), query, args...)
	if err != nil {
		return nil, "", fmt.Errorf("failed to list tasks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var jobs []*Job
	for rows.Next() {
		job := &Job{}
		var createdAt, updatedAt, startedAt, endedAt sql.NullTime
		var datasetsJSON, metadataJSON []byte
		if err := rows.Scan(
			&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
			&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
			&createdAt, &updatedAt, &startedAt, &endedAt,
		); err != nil {
			return nil, "", fmt.Errorf("failed to scan job: %w", err)
		}
		if createdAt.Valid {
			job.CreatedAt = createdAt.Time
		}
		if updatedAt.Valid {
			job.UpdatedAt = updatedAt.Time
		}
		if startedAt.Valid {
			job.StartedAt = &startedAt.Time
		}
		if endedAt.Valid {
			job.EndedAt = &endedAt.Time
		}
		// Best effort: a row with malformed JSON is still listed, just
		// without its datasets/metadata.
		if len(datasetsJSON) > 0 {
			_ = json.Unmarshal(datasetsJSON, &job.Datasets)
		}
		if len(metadataJSON) > 0 {
			_ = json.Unmarshal(metadataJSON, &job.Metadata)
		}
		jobs = append(jobs, job)
	}
	if err = rows.Err(); err != nil {
		return nil, "", fmt.Errorf("error iterating jobs: %w", err)
	}

	// The extra row only signals "more available": build the cursor from the
	// last row that is actually returned, then drop the extra row.
	nextCursor := ""
	if len(jobs) > opts.Limit {
		lastJob := jobs[opts.Limit-1]
		nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
		jobs = jobs[:opts.Limit]
	}
	return jobs, nextCursor, nil
}
// ListTasksForGroup returns one page of the jobs associated with groupID in
// task_group_access, newest first, together with the cursor for the following
// page ("" when the returned page is the last one).
//
// opts.Limit values <= 0 or > 100 are replaced by 100. opts.Cursor must be
// empty or a cursor previously returned by a listing call; a malformed cursor
// is rejected with an error instead of silently restarting at the first page.
func (db *DB) ListTasksForGroup(groupID string, opts ListTasksOptions) ([]*Job, string, error) {
	if opts.Limit <= 0 || opts.Limit > 100 {
		opts.Limit = 100
	}

	cursorCreatedAt, cursorID, err := decodeCursor(opts.Cursor)
	if err != nil {
		return nil, "", fmt.Errorf("invalid cursor: %w", err)
	}
	cursorKey := ""
	if opts.Cursor != "" {
		// The cursor carries an RFC3339 timestamp ("2006-01-02T15:04:05Z"),
		// while the SQL side renders created_at as "2006-01-02 15:04:05".
		// The two are compared as strings, so the cursor must be rewritten in
		// the SQL format; otherwise ' ' < 'T' lets every row of the cursor's
		// day through and pages repeat.
		ts, parseErr := time.Parse(time.RFC3339, cursorCreatedAt)
		if parseErr != nil {
			return nil, "", fmt.Errorf("invalid cursor timestamp: %w", parseErr)
		}
		cursorCreatedAt = ts.UTC().Format("2006-01-02 15:04:05")
		cursorKey = cursorCreatedAt + cursorID
	}

	var query string
	if db.dbType == DBTypeSQLite {
		query = `
			SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
			j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
			FROM jobs j
			JOIN task_group_access tga ON tga.task_id = j.id
			WHERE tga.group_id = ?
			AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
			ORDER BY j.created_at DESC, j.id DESC
			LIMIT ?`
	} else {
		// to_char renders the timestamp with second precision in the same
		// layout as the cursor.
		// NOTE(review): for a timestamptz column to_char uses the session time
		// zone - confirm sessions run in UTC.
		query = `
			SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
			j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
			FROM jobs j
			JOIN task_group_access tga ON tga.task_id = j.id
			WHERE tga.group_id = $1
			AND ($2 = '' OR (to_char(j.created_at, 'YYYY-MM-DD HH24:MI:SS') || j.id) < $3)
			ORDER BY j.created_at DESC, j.id DESC
			LIMIT $4`
	}

	// Fetch one row more than requested: its presence means another page exists.
	rows, err := db.conn.QueryContext(context.Background(), query,
		groupID, cursorCreatedAt, cursorKey, opts.Limit+1)
	if err != nil {
		return nil, "", fmt.Errorf("failed to list group tasks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var jobs []*Job
	for rows.Next() {
		job := &Job{}
		var createdAt, updatedAt, startedAt, endedAt sql.NullTime
		var datasetsJSON, metadataJSON []byte
		if err := rows.Scan(
			&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
			&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
			&createdAt, &updatedAt, &startedAt, &endedAt,
		); err != nil {
			return nil, "", fmt.Errorf("failed to scan job: %w", err)
		}
		if createdAt.Valid {
			job.CreatedAt = createdAt.Time
		}
		if updatedAt.Valid {
			job.UpdatedAt = updatedAt.Time
		}
		if startedAt.Valid {
			job.StartedAt = &startedAt.Time
		}
		if endedAt.Valid {
			job.EndedAt = &endedAt.Time
		}
		// Best effort: a row with malformed JSON is still listed, just
		// without its datasets/metadata.
		if len(datasetsJSON) > 0 {
			_ = json.Unmarshal(datasetsJSON, &job.Datasets)
		}
		if len(metadataJSON) > 0 {
			_ = json.Unmarshal(metadataJSON, &job.Metadata)
		}
		jobs = append(jobs, job)
	}
	if err = rows.Err(); err != nil {
		return nil, "", fmt.Errorf("error iterating jobs: %w", err)
	}

	// Build the cursor from the last row that is actually returned, then drop
	// the extra look-ahead row.
	nextCursor := ""
	if len(jobs) > opts.Limit {
		lastJob := jobs[opts.Limit-1]
		nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
		jobs = jobs[:opts.Limit]
	}
	return jobs, nextCursor, nil
}
// listTasksPaginated runs baseQuery with cursor pagination appended and
// returns one page of jobs, newest first, plus the cursor for the following
// page ("" when the returned page is the last one). It backs the admin
// listing, which is not filtered by visibility.
//
// baseQuery is concatenated into the final SQL, so it must be a trusted
// constant and must satisfy all of the following:
//   - it has no WHERE, ORDER BY or LIMIT clause (they are appended here),
//   - it exposes unqualified created_at and id columns,
//   - it selects exactly these columns in this order: id, job_name, args,
//     status, priority, datasets, metadata, worker_id, error, created_at,
//     updated_at, started_at, ended_at (rows are scanned positionally).
//
// opts.Limit values <= 0 or > 100 are replaced by 100. A malformed
// opts.Cursor is rejected with an error instead of silently restarting at the
// first page.
func (db *DB) listTasksPaginated(baseQuery string, opts ListTasksOptions) ([]*Job, string, error) {
	if opts.Limit <= 0 || opts.Limit > 100 {
		opts.Limit = 100
	}

	cursorCreatedAt, cursorID, err := decodeCursor(opts.Cursor)
	if err != nil {
		return nil, "", fmt.Errorf("invalid cursor: %w", err)
	}
	cursorKey := ""
	if opts.Cursor != "" {
		// The cursor carries an RFC3339 timestamp ("2006-01-02T15:04:05Z"),
		// while the SQL side renders created_at as "2006-01-02 15:04:05".
		// The two are compared as strings, so the cursor must be rewritten in
		// the SQL format; otherwise ' ' < 'T' lets every row of the cursor's
		// day through and pages repeat.
		ts, parseErr := time.Parse(time.RFC3339, cursorCreatedAt)
		if parseErr != nil {
			return nil, "", fmt.Errorf("invalid cursor timestamp: %w", parseErr)
		}
		cursorCreatedAt = ts.UTC().Format("2006-01-02 15:04:05")
		cursorKey = cursorCreatedAt + cursorID
	}

	var query string
	if db.dbType == DBTypeSQLite {
		query = baseQuery + `
			WHERE (? = '' OR (datetime(created_at) || id) < ?)
			ORDER BY created_at DESC, id DESC
			LIMIT ?`
	} else {
		// to_char renders the timestamp with second precision in the same
		// layout as the cursor.
		// NOTE(review): for a timestamptz column to_char uses the session time
		// zone - confirm sessions run in UTC.
		query = baseQuery + `
			WHERE ($1 = '' OR (to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') || id) < $2)
			ORDER BY created_at DESC, id DESC
			LIMIT $3`
	}

	// Fetch one row more than requested: its presence means another page exists.
	rows, err := db.conn.QueryContext(context.Background(), query,
		cursorCreatedAt, cursorKey, opts.Limit+1)
	if err != nil {
		return nil, "", fmt.Errorf("failed to list tasks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var jobs []*Job
	for rows.Next() {
		job := &Job{}
		var createdAt, updatedAt, startedAt, endedAt sql.NullTime
		var datasetsJSON, metadataJSON []byte
		if err := rows.Scan(
			&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
			&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
			&createdAt, &updatedAt, &startedAt, &endedAt,
		); err != nil {
			return nil, "", fmt.Errorf("failed to scan job: %w", err)
		}
		if createdAt.Valid {
			job.CreatedAt = createdAt.Time
		}
		if updatedAt.Valid {
			job.UpdatedAt = updatedAt.Time
		}
		if startedAt.Valid {
			job.StartedAt = &startedAt.Time
		}
		if endedAt.Valid {
			job.EndedAt = &endedAt.Time
		}
		// Best effort: a row with malformed JSON is still listed, just
		// without its datasets/metadata.
		if len(datasetsJSON) > 0 {
			_ = json.Unmarshal(datasetsJSON, &job.Datasets)
		}
		if len(metadataJSON) > 0 {
			_ = json.Unmarshal(metadataJSON, &job.Metadata)
		}
		jobs = append(jobs, job)
	}
	if err = rows.Err(); err != nil {
		return nil, "", fmt.Errorf("error iterating jobs: %w", err)
	}

	// Build the cursor from the last row that is actually returned, then drop
	// the extra look-ahead row.
	nextCursor := ""
	if len(jobs) > opts.Limit {
		lastJob := jobs[opts.Limit-1]
		nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
		jobs = jobs[:opts.Limit]
	}
	return jobs, nextCursor, nil
}
// encodeCursor packs a pagination position into an opaque string.
//
// The payload is the UTC creation time in RFC3339 (second precision), a "|"
// separator, and the task id; it is then encoded as unpadded URL-safe base64.
func encodeCursor(createdAt time.Time, id string) string {
	raw := make([]byte, 0, len(time.RFC3339)+1+len(id))
	raw = createdAt.UTC().AppendFormat(raw, time.RFC3339)
	raw = append(raw, '|')
	raw = append(raw, id...)
	return base64.RawURLEncoding.EncodeToString(raw)
}
// decodeCursor unpacks an opaque cursor into its timestamp text and task id.
//
// An empty cursor is valid and yields ("", "", nil). A cursor that is not
// unpadded URL-safe base64, or whose payload has no "|" separator, yields
// empty strings and a non-nil error. Only the first "|" splits the payload, so
// the id part may itself contain "|".
func decodeCursor(cursor string) (createdAt, id string, err error) {
	if len(cursor) == 0 {
		return "", "", nil
	}

	raw, decodeErr := base64.RawURLEncoding.DecodeString(cursor)
	if decodeErr != nil {
		return "", "", decodeErr
	}

	decoded := string(raw)
	sep := strings.IndexByte(decoded, '|')
	if sep < 0 {
		return "", "", fmt.Errorf("invalid cursor format")
	}
	return decoded[:sep], decoded[sep+1:], nil
}

View file

@ -0,0 +1,219 @@
package storage
import (
"context"
"database/sql"
"fmt"
"time"
)
// ShareToken is one row of the share_tokens table: a token that grants access
// to a task or an experiment without an authenticated session.
type ShareToken struct {
	// Token is the token value and the row's lookup key.
	Token string `json:"token"`
	// TaskID is set when the row's task_id column is not NULL.
	TaskID *string `json:"task_id,omitempty"`
	// ExperimentID is set when the row's experiment_id column is not NULL.
	ExperimentID *string `json:"experiment_id,omitempty"`
	// CreatedBy identifies the user who issued the token.
	CreatedBy string `json:"created_by"`
	// CreatedAt is when the token row was created.
	CreatedAt time.Time `json:"created_at"`
	// ExpiresAt is the expiry instant; nil means the token never expires.
	ExpiresAt *time.Time `json:"expires_at,omitempty"`
	// AccessCount is the number of recorded uses so far.
	AccessCount int `json:"access_count"`
	// MaxAccesses caps AccessCount; nil means unlimited.
	MaxAccesses *int `json:"max_accesses,omitempty"`
}
// CreateShareToken inserts a new row into share_tokens.
//
// taskID, experimentID, expiresAt and maxAccesses are passed to the driver as
// given, so a nil pointer is stored as NULL. created_at and access_count are
// not set here and take their column defaults.
func (db *DB) CreateShareToken(token string, taskID, experimentID *string, createdBy string, expiresAt *time.Time, maxAccesses *int) error {
	stmt := `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses)
		VALUES ($1, $2, $3, $4, $5, $6)`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses)
		VALUES (?, ?, ?, ?, ?, ?)`
	}

	if _, execErr := db.conn.ExecContext(context.Background(), stmt,
		token, taskID, experimentID, createdBy, expiresAt, maxAccesses); execErr != nil {
		return fmt.Errorf("failed to create share token: %w", execErr)
	}
	return nil
}
// GetShareToken looks up a share token by its exact value.
//
// It returns (nil, nil) when no row matches, so callers must check the
// returned pointer as well as the error. Nullable columns that are NULL are
// left as nil pointers on the result.
func (db *DB) GetShareToken(token string) (*ShareToken, error) {
	stmt := `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
		FROM share_tokens WHERE token = $1`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
		FROM share_tokens WHERE token = ?`
	}

	var (
		result    ShareToken
		taskCol   sql.NullString
		expCol    sql.NullString
		expiryCol sql.NullTime
		maxCol    sql.NullInt64
	)
	scanErr := db.conn.QueryRowContext(context.Background(), stmt, token).Scan(
		&result.Token, &taskCol, &expCol, &result.CreatedBy, &result.CreatedAt,
		&expiryCol, &result.AccessCount, &maxCol,
	)
	switch {
	case scanErr == sql.ErrNoRows:
		// Unknown token: not an error for the caller.
		return nil, nil
	case scanErr != nil:
		return nil, fmt.Errorf("failed to get share token: %w", scanErr)
	}

	// Copy each non-NULL column into its own variable so the result owns it.
	if taskCol.Valid {
		taskID := taskCol.String
		result.TaskID = &taskID
	}
	if expCol.Valid {
		experimentID := expCol.String
		result.ExperimentID = &experimentID
	}
	if expiryCol.Valid {
		expiresAt := expiryCol.Time
		result.ExpiresAt = &expiresAt
	}
	if maxCol.Valid {
		limit := int(maxCol.Int64)
		result.MaxAccesses = &limit
	}
	return &result, nil
}
// IncrementTokenAccessCount adds one to access_count for the given token.
//
// The number of affected rows is not inspected, so an unknown token is not
// reported as an error.
func (db *DB) IncrementTokenAccessCount(token string) error {
	stmt := `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = $1`
	if db.dbType == DBTypeSQLite {
		stmt = `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = ?`
	}

	if _, execErr := db.conn.ExecContext(context.Background(), stmt, token); execErr != nil {
		return fmt.Errorf("failed to increment token access count: %w", execErr)
	}
	return nil
}
// DeleteShareToken deletes the share_tokens row with the given token value.
//
// The number of affected rows is not inspected, so deleting an unknown token
// succeeds without error.
func (db *DB) DeleteShareToken(token string) error {
	stmt := `DELETE FROM share_tokens WHERE token = $1`
	if db.dbType == DBTypeSQLite {
		stmt = `DELETE FROM share_tokens WHERE token = ?`
	}

	if _, execErr := db.conn.ExecContext(context.Background(), stmt, token); execErr != nil {
		return fmt.Errorf("failed to delete share token: %w", execErr)
	}
	return nil
}
// ListShareTokensForTask lists every share token whose task_id equals taskID,
// most recently created first.
func (db *DB) ListShareTokensForTask(taskID string) ([]*ShareToken, error) {
	if db.dbType != DBTypeSQLite {
		return db.queryShareTokens(
			`SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
		FROM share_tokens WHERE task_id = $1 ORDER BY created_at DESC`,
			taskID,
		)
	}
	return db.queryShareTokens(
		`SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
		FROM share_tokens WHERE task_id = ? ORDER BY created_at DESC`,
		taskID,
	)
}
// ListShareTokensForExperiment lists every share token whose experiment_id
// equals experimentID, most recently created first.
func (db *DB) ListShareTokensForExperiment(experimentID string) ([]*ShareToken, error) {
	if db.dbType != DBTypeSQLite {
		return db.queryShareTokens(
			`SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
		FROM share_tokens WHERE experiment_id = $1 ORDER BY created_at DESC`,
			experimentID,
		)
	}
	return db.queryShareTokens(
		`SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
		FROM share_tokens WHERE experiment_id = ? ORDER BY created_at DESC`,
		experimentID,
	)
}
// queryShareTokens runs a share_tokens SELECT that takes a single bind value
// and maps every row to a ShareToken.
//
// query must select, in order: token, task_id, experiment_id, created_by,
// created_at, expires_at, access_count, max_accesses. NULL columns become nil
// pointers. The result is a nil slice when no row matches.
func (db *DB) queryShareTokens(query string, arg string) ([]*ShareToken, error) {
	rows, err := db.conn.QueryContext(context.Background(), query, arg)
	if err != nil {
		return nil, fmt.Errorf("failed to list share tokens: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var found []*ShareToken
	for rows.Next() {
		var (
			taskCol   sql.NullString
			expCol    sql.NullString
			expiryCol sql.NullTime
			maxCol    sql.NullInt64
		)
		item := &ShareToken{}
		if scanErr := rows.Scan(
			&item.Token, &taskCol, &expCol, &item.CreatedBy, &item.CreatedAt,
			&expiryCol, &item.AccessCount, &maxCol,
		); scanErr != nil {
			return nil, fmt.Errorf("failed to scan share token: %w", scanErr)
		}

		// Copy each non-NULL column into its own variable so the item owns it.
		if taskCol.Valid {
			taskID := taskCol.String
			item.TaskID = &taskID
		}
		if expCol.Valid {
			experimentID := expCol.String
			item.ExperimentID = &experimentID
		}
		if expiryCol.Valid {
			expiresAt := expiryCol.Time
			item.ExpiresAt = &expiresAt
		}
		if maxCol.Valid {
			limit := int(maxCol.Int64)
			item.MaxAccesses = &limit
		}
		found = append(found, item)
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating share tokens: %w", err)
	}
	return found, nil
}
// ValidateShareToken decides whether token currently grants access to the
// requested resource.
//
// It returns the stored token when it is usable, and (nil, nil) when the token
// is unknown, bound to a different task or experiment than the one requested,
// past its expiry, or has already reached max_accesses. A nil taskID or
// experimentID skips the corresponding resource check. Only lookup failures
// are returned as errors.
//
// This method only reads; it does not increment the access count.
func (db *DB) ValidateShareToken(token string, taskID *string, experimentID *string) (*ShareToken, error) {
	stored, err := db.GetShareToken(token)
	if err != nil {
		return nil, err
	}
	if stored == nil {
		return nil, nil
	}

	switch {
	case taskID != nil && (stored.TaskID == nil || *stored.TaskID != *taskID):
		// Token is not bound to the requested task.
		return nil, nil
	case experimentID != nil && (stored.ExperimentID == nil || *stored.ExperimentID != *experimentID):
		// Token is not bound to the requested experiment.
		return nil, nil
	case stored.ExpiresAt != nil && time.Now().After(*stored.ExpiresAt):
		// Token has expired.
		return nil, nil
	case stored.MaxAccesses != nil && stored.AccessCount >= *stored.MaxAccesses:
		// Token has used up its access budget.
		return nil, nil
	}
	return stored, nil
}

View file

@ -11,10 +11,13 @@ CREATE TABLE IF NOT EXISTS jobs (
started_at DATETIME,
ended_at DATETIME,
worker_id TEXT,
user_id TEXT,
error TEXT,
datasets TEXT, -- JSON array
metadata TEXT, -- JSON object
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
visibility TEXT NOT NULL DEFAULT 'lab',
experiment_id TEXT REFERENCES experiments(id)
);
CREATE TABLE IF NOT EXISTS job_metrics (
@ -142,3 +145,113 @@ CREATE TABLE IF NOT EXISTS websocket_metrics (
);
CREATE INDEX IF NOT EXISTS idx_websocket_metrics_name_time ON websocket_metrics(metric_name, recorded_at);
-- Groups and membership for lab-based task sharing.
-- One row per group; name is unique across all groups.
-- created_by has no foreign key and is stored as an opaque user identifier.
CREATE TABLE IF NOT EXISTS groups (
id TEXT PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
created_by TEXT NOT NULL
);
-- Membership of users in groups. The composite primary key allows at most one
-- row, and therefore one role, per (group_id, user_id).
-- Deleting a group cascades to its membership rows (requires SQLite foreign
-- key enforcement to be enabled on the connection).
CREATE TABLE IF NOT EXISTS group_members (
group_id TEXT NOT NULL,
user_id TEXT NOT NULL,
role TEXT DEFAULT 'member', -- expected: 'admin', 'member', 'viewer' (no CHECK constraint enforces this)
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (group_id, user_id),
FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE
);
-- System group for institution visibility (all authenticated users).
-- INSERT OR IGNORE keeps this seed idempotent: re-running the schema leaves an
-- existing 'all-users' row untouched.
INSERT OR IGNORE INTO groups (id, name, description, created_by)
VALUES ('all-users', 'all-users', 'System group: all authenticated users', 'system');
-- Invite-and-accept flow: group admins invite; users accept or decline.
-- There is no uniqueness constraint on (group_id, invited_user_id), so the
-- schema itself allows several invitations for the same user and group.
CREATE TABLE IF NOT EXISTS group_invitations (
id TEXT PRIMARY KEY,
group_id TEXT NOT NULL,
invited_user_id TEXT NOT NULL,
invited_by TEXT NOT NULL,
status TEXT DEFAULT 'pending', -- expected: 'pending', 'accepted', 'declined' (no CHECK constraint enforces this)
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME, -- NULL = 7d default enforced in app layer
FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE
);
-- Experiment/project grouping: share a whole experiment, not individual tasks.
-- Note: experiments table already exists; adding group_id to link with sharing system.
-- NOTE(review): unlike the CREATE ... IF NOT EXISTS statements in this file,
-- this statement is not idempotent. SQLite has no ADD COLUMN IF NOT EXISTS, so
-- it fails with a duplicate-column error when run against a database that
-- already has experiments.group_id - confirm the schema runner handles that.
ALTER TABLE experiments ADD COLUMN group_id TEXT REFERENCES groups(id);
-- Link tasks to experiments (many-to-many). A given task can be linked to a
-- given experiment only once; links disappear when either side is deleted.
CREATE TABLE IF NOT EXISTS experiment_tasks (
experiment_id TEXT NOT NULL,
task_id TEXT NOT NULL,
added_at DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (experiment_id, task_id),
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE,
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
);
-- Per-user explicit shares with optional expiry.
-- The primary key allows one share per (task_id, user_id). Expired rows are
-- not removed by the schema; readers filter on expires_at instead.
CREATE TABLE IF NOT EXISTS task_shares (
task_id TEXT NOT NULL,
user_id TEXT NOT NULL,
granted_by TEXT NOT NULL,
granted_at DATETIME DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME, -- NULL = no expiry; checked at access time
PRIMARY KEY (task_id, user_id),
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
);
-- Group-level task association.
-- Records which group a task is associated with at submit time.
-- Actual membership is always resolved live from group_members.
-- The primary key makes each (task_id, group_id) pair unique, which is what
-- lets the application insert associations idempotently.
CREATE TABLE IF NOT EXISTS task_group_access (
task_id TEXT NOT NULL,
group_id TEXT NOT NULL,
PRIMARY KEY (task_id, group_id),
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE,
FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE
);
-- Share tokens for unauthenticated open access (paper reproducibility links).
-- A token targets either a task or an experiment; the schema does not enforce
-- that exactly one of task_id / experiment_id is set.
-- NOTE(review): the commit description mentions SHA-256 hashed token storage,
-- while the token column comment below describes the raw random value -
-- confirm which form is actually written here.
CREATE TABLE IF NOT EXISTS share_tokens (
token TEXT PRIMARY KEY, -- cryptographically random (32 bytes, base64url)
task_id TEXT, -- NULL if experiment-level
experiment_id TEXT, -- NULL if task-level
created_by TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME, -- NULL = never expires
access_count INTEGER DEFAULT 0, -- incremented by the application on each recorded use
max_accesses INTEGER, -- NULL = unlimited
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE,
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
);
-- Audit log for task access: one row per access event.
-- Rows are deleted together with their task (ON DELETE CASCADE), so the log
-- does not outlive the job it describes.
CREATE TABLE IF NOT EXISTS task_access_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id TEXT NOT NULL,
user_id TEXT, -- NULL for token-based access
token TEXT, -- NULL for session-based access
action TEXT NOT NULL, -- expected: 'view', 'clone', 'execute', 'modify' (no CHECK constraint enforces this)
accessed_at DATETIME DEFAULT CURRENT_TIMESTAMP,
ip_address TEXT,
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
);
-- Indexes for task sharing performance.
-- NOTE(review): idx_jobs_visibility is a prefix of idx_jobs_visibility_owner,
-- so lookups on visibility alone can already use the composite index - the
-- single-column index is likely redundant.
-- NOTE(review): share_tokens is also queried by experiment_id, which has no
-- index here - confirm whether one is needed.
CREATE INDEX IF NOT EXISTS idx_jobs_visibility ON jobs(visibility);
CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id);
CREATE INDEX IF NOT EXISTS idx_jobs_visibility_owner ON jobs(visibility, user_id);
CREATE INDEX IF NOT EXISTS idx_jobs_experiment ON jobs(experiment_id);
CREATE INDEX IF NOT EXISTS idx_task_shares_user ON task_shares(user_id);
CREATE INDEX IF NOT EXISTS idx_task_shares_expires ON task_shares(expires_at);
CREATE INDEX IF NOT EXISTS idx_tga_group ON task_group_access(group_id);
CREATE INDEX IF NOT EXISTS idx_share_tokens_task ON share_tokens(task_id);
CREATE INDEX IF NOT EXISTS idx_task_access_task ON task_access_log(task_id);
CREATE INDEX IF NOT EXISTS idx_task_access_user ON task_access_log(user_id);
-- Partial index: only token-based accesses are indexed by token.
CREATE INDEX IF NOT EXISTS idx_task_access_token ON task_access_log(token) WHERE token IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_invitations_user ON group_invitations(invited_user_id);