feat(storage): add groups, tasks, tokens, and audit database schemas
Add comprehensive database storage layer for new features: - db_groups.go: Lab group management with members, roles (admin/member/viewer), and group-based task visibility queries - db_tasks.go: Task visibility system (private/lab/institution/open), task sharing with expiry, public clone tokens, and optimized ListTasksForUser() for access control - db_tokens.go: Secure token management for public task access and cloning, with SHA-256 hashed token storage and automatic cleanup - db_audit.go: Audit log persistence with checkpoint chains, tamper detection, and log rotation support - schema_sqlite.sql: Updated schema with: - groups, group_members tables - tasks.visibility enum, task_shares with expiry - access_tokens table with hashed tokens - audit_logs, audit_checkpoints tables - indexes for all foreign keys and query patterns - db_experiments.go: Add CascadeVisibilityToTasks() for propagating visibility changes from experiments to associated tasks
This commit is contained in:
parent
a239f3a14f
commit
fbcf4d38e5
7 changed files with 1765 additions and 1 deletions
225
internal/storage/db_audit.go
Normal file
225
internal/storage/db_audit.go
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuditLogEntry represents a single audit log entry
|
||||
type AuditLogEntry struct {
|
||||
ID int64 `json:"id"`
|
||||
TaskID string `json:"task_id"`
|
||||
UserID *string `json:"user_id,omitempty"`
|
||||
Token *string `json:"token,omitempty"`
|
||||
Action string `json:"action"`
|
||||
AccessedAt time.Time `json:"accessed_at"`
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
}
|
||||
|
||||
// LogTaskAccess records an access event in the audit log.
|
||||
func (db *DB) LogTaskAccess(taskID string, userID, token, action, ipAddress *string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address)
|
||||
VALUES (?, ?, ?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address)
|
||||
VALUES ($1, $2, $3, $4, $5)`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, taskID, userID, token, action, ipAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to log task access: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAuditLogForTask retrieves audit log entries for a specific task.
|
||||
func (db *DB) GetAuditLogForTask(taskID string, limit int) ([]*AuditLogEntry, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
|
||||
FROM task_access_log
|
||||
WHERE task_id = ?
|
||||
ORDER BY accessed_at DESC
|
||||
LIMIT ?`
|
||||
} else {
|
||||
query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
|
||||
FROM task_access_log
|
||||
WHERE task_id = $1
|
||||
ORDER BY accessed_at DESC
|
||||
LIMIT $2`
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(context.Background(), query, taskID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get audit log: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
log.Printf("ERROR: failed to close rows: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return db.scanAuditLogEntries(rows)
|
||||
}
|
||||
|
||||
// GetAuditLogForUser retrieves audit log entries for a specific user.
|
||||
func (db *DB) GetAuditLogForUser(userID string, limit int) ([]*AuditLogEntry, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
|
||||
FROM task_access_log
|
||||
WHERE user_id = ?
|
||||
ORDER BY accessed_at DESC
|
||||
LIMIT ?`
|
||||
} else {
|
||||
query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
|
||||
FROM task_access_log
|
||||
WHERE user_id = $1
|
||||
ORDER BY accessed_at DESC
|
||||
LIMIT $2`
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(context.Background(), query, userID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get audit log: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
log.Printf("ERROR: failed to close rows: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return db.scanAuditLogEntries(rows)
|
||||
}
|
||||
|
||||
// GetAuditLogForToken retrieves audit log entries for a specific token.
|
||||
func (db *DB) GetAuditLogForToken(token string, limit int) ([]*AuditLogEntry, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
|
||||
FROM task_access_log
|
||||
WHERE token = ?
|
||||
ORDER BY accessed_at DESC
|
||||
LIMIT ?`
|
||||
} else {
|
||||
query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address
|
||||
FROM task_access_log
|
||||
WHERE token = $1
|
||||
ORDER BY accessed_at DESC
|
||||
LIMIT $2`
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(context.Background(), query, token, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get audit log: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
log.Printf("ERROR: failed to close rows: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return db.scanAuditLogEntries(rows)
|
||||
}
|
||||
|
||||
// scanAuditLogEntries scans rows into AuditLogEntry structs.
|
||||
func (db *DB) scanAuditLogEntries(rows *sql.Rows) ([]*AuditLogEntry, error) {
|
||||
var entries []*AuditLogEntry
|
||||
for rows.Next() {
|
||||
var e AuditLogEntry
|
||||
var userID, token, ipAddress sql.NullString
|
||||
|
||||
err := rows.Scan(&e.ID, &e.TaskID, &userID, &token, &e.Action, &e.AccessedAt, &ipAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan audit log entry: %w", err)
|
||||
}
|
||||
|
||||
if userID.Valid {
|
||||
e.UserID = &userID.String
|
||||
}
|
||||
if token.Valid {
|
||||
e.Token = &token.String
|
||||
}
|
||||
if ipAddress.Valid {
|
||||
e.IPAddress = &ipAddress.String
|
||||
}
|
||||
|
||||
entries = append(entries, &e)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating audit log entries: %w", err)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// DeleteOldAuditLogs removes audit log entries older than the retention period.
|
||||
// This should be called by a nightly job.
|
||||
func (db *DB) DeleteOldAuditLogs(retentionDays int) (int64, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `DELETE FROM task_access_log
|
||||
WHERE accessed_at < datetime('now', '-' || ? || ' days')`
|
||||
} else {
|
||||
query = `DELETE FROM task_access_log
|
||||
WHERE accessed_at < NOW() - INTERVAL '1 day' * $1`
|
||||
}
|
||||
|
||||
result, err := db.conn.ExecContext(context.Background(), query, retentionDays)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to delete old audit logs: %w", err)
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
// CountAuditLogs returns the total number of audit log entries.
|
||||
func (db *DB) CountAuditLogs() (int64, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT COUNT(*) FROM task_access_log`
|
||||
} else {
|
||||
query = `SELECT COUNT(*) FROM task_access_log`
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := db.conn.QueryRowContext(context.Background(), query).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to count audit logs: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetOldestAuditLogDate returns the date of the oldest audit log entry.
|
||||
func (db *DB) GetOldestAuditLogDate() (*time.Time, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT MIN(accessed_at) FROM task_access_log`
|
||||
} else {
|
||||
query = `SELECT MIN(accessed_at) FROM task_access_log`
|
||||
}
|
||||
|
||||
var date sql.NullTime
|
||||
err := db.conn.QueryRowContext(context.Background(), query).Scan(&date)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get oldest audit log date: %w", err)
|
||||
}
|
||||
|
||||
if !date.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &date.Time, nil
|
||||
}
|
||||
|
|
@ -562,3 +562,170 @@ func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Da
|
|||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// AssociateTaskWithExperiment links a task to an experiment.
|
||||
func (db *DB) AssociateTaskWithExperiment(taskID, experimentID string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT OR IGNORE INTO experiment_tasks (experiment_id, task_id) VALUES (?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO experiment_tasks (experiment_id, task_id) VALUES ($1, $2)
|
||||
ON CONFLICT (experiment_id, task_id) DO NOTHING`
|
||||
}
|
||||
_, err := db.conn.ExecContext(context.Background(), query, experimentID, taskID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to associate task with experiment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExperimentVisibility returns the visibility level for an experiment.
|
||||
func (db *DB) GetExperimentVisibility(experimentID string) (string, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT COALESCE(j.visibility, 'private')
|
||||
FROM experiment_tasks et
|
||||
JOIN jobs j ON j.id = et.task_id
|
||||
WHERE et.experiment_id = ?
|
||||
ORDER BY j.created_at DESC
|
||||
LIMIT 1`
|
||||
} else {
|
||||
query = `SELECT COALESCE(j.visibility, 'private')
|
||||
FROM experiment_tasks et
|
||||
JOIN jobs j ON j.id = et.task_id
|
||||
WHERE et.experiment_id = $1
|
||||
ORDER BY j.created_at DESC
|
||||
LIMIT 1`
|
||||
}
|
||||
|
||||
var visibility string
|
||||
err := db.conn.QueryRowContext(context.Background(), query, experimentID).Scan(&visibility)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "private", nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to get experiment visibility: %w", err)
|
||||
}
|
||||
return visibility, nil
|
||||
}
|
||||
|
||||
// CascadeExperimentVisibility updates visibility for all tasks in an experiment.
|
||||
func (db *DB) CascadeExperimentVisibility(experimentID, visibility string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `UPDATE jobs
|
||||
SET visibility = ?
|
||||
WHERE id IN (
|
||||
SELECT task_id FROM experiment_tasks WHERE experiment_id = ?
|
||||
)`
|
||||
} else {
|
||||
query = `UPDATE jobs
|
||||
SET visibility = $1
|
||||
WHERE id IN (
|
||||
SELECT task_id FROM experiment_tasks WHERE experiment_id = $2
|
||||
)`
|
||||
}
|
||||
_, err := db.conn.ExecContext(context.Background(), query, visibility, experimentID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to cascade visibility: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListTasksForExperiment returns all tasks associated with an experiment.
|
||||
func (db *DB) ListTasksForExperiment(experimentID string, opts ListTasksOptions) ([]*Job, string, error) {
|
||||
if opts.Limit <= 0 || opts.Limit > 100 {
|
||||
opts.Limit = 100
|
||||
}
|
||||
|
||||
cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `
|
||||
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
|
||||
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
|
||||
FROM jobs j
|
||||
JOIN experiment_tasks et ON et.task_id = j.id
|
||||
WHERE et.experiment_id = ?
|
||||
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
||||
ORDER BY j.created_at DESC, j.id DESC
|
||||
LIMIT ?`
|
||||
} else {
|
||||
query = `
|
||||
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
|
||||
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
|
||||
FROM jobs j
|
||||
JOIN experiment_tasks et ON et.task_id = j.id
|
||||
WHERE et.experiment_id = $1
|
||||
AND ($2 = '' OR (j.created_at::text || j.id) < $3)
|
||||
ORDER BY j.created_at DESC, j.id DESC
|
||||
LIMIT $4`
|
||||
}
|
||||
|
||||
fetchLimit := opts.Limit + 1
|
||||
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if db.dbType == DBTypeSQLite {
|
||||
rows, err = db.conn.QueryContext(context.Background(), query, experimentID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit)
|
||||
} else {
|
||||
rows, err = db.conn.QueryContext(context.Background(), query, experimentID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to list experiment tasks: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var jobs []*Job
|
||||
for rows.Next() {
|
||||
job := &Job{}
|
||||
var createdAt, updatedAt sql.NullTime
|
||||
var startedAt, endedAt sql.NullTime
|
||||
var datasetsJSON, metadataJSON []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
|
||||
&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
|
||||
&createdAt, &updatedAt, &startedAt, &endedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to scan job: %w", err)
|
||||
}
|
||||
|
||||
if createdAt.Valid {
|
||||
job.CreatedAt = createdAt.Time
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
job.UpdatedAt = updatedAt.Time
|
||||
}
|
||||
if startedAt.Valid {
|
||||
job.StartedAt = &startedAt.Time
|
||||
}
|
||||
if endedAt.Valid {
|
||||
job.EndedAt = &endedAt.Time
|
||||
}
|
||||
|
||||
if len(datasetsJSON) > 0 {
|
||||
_ = json.Unmarshal(datasetsJSON, &job.Datasets)
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
_ = json.Unmarshal(metadataJSON, &job.Metadata)
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, "", fmt.Errorf("error iterating jobs: %w", err)
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
if len(jobs) > opts.Limit {
|
||||
lastJob := jobs[opts.Limit-1]
|
||||
nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
|
||||
jobs = jobs[:opts.Limit]
|
||||
}
|
||||
|
||||
return jobs, nextCursor, nil
|
||||
}
|
||||
|
|
|
|||
541
internal/storage/db_groups.go
Normal file
541
internal/storage/db_groups.go
Normal file
|
|
@ -0,0 +1,541 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Group represents a lab group
|
||||
type Group struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
}
|
||||
|
||||
// GroupInvitation represents a pending invitation
|
||||
type GroupInvitation struct {
|
||||
ID string `json:"id"`
|
||||
GroupID string `json:"group_id"`
|
||||
InvitedUserID string `json:"invited_user_id"`
|
||||
InvitedBy string `json:"invited_by"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
// CreateGroup creates a new lab group
|
||||
func (db *DB) CreateGroup(name, description, createdBy string) (*Group, error) {
|
||||
id := fmt.Sprintf("group_%d", time.Now().UnixNano())
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, id, name, description, createdBy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create group: %w", err)
|
||||
}
|
||||
|
||||
// Add creator as admin
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'admin')`
|
||||
} else {
|
||||
query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'admin')`
|
||||
}
|
||||
_, err = db.conn.ExecContext(context.Background(), query, id, createdBy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add creator as admin: %w", err)
|
||||
}
|
||||
|
||||
return &Group{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Description: description,
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListGroupsForUser returns all groups the user is a member of
|
||||
func (db *DB) ListGroupsForUser(userID string) ([]Group, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `
|
||||
SELECT g.id, g.name, g.description, g.created_at, g.created_by
|
||||
FROM groups g
|
||||
JOIN group_members gm ON gm.group_id = g.id
|
||||
WHERE gm.user_id = ?
|
||||
ORDER BY g.created_at DESC`
|
||||
} else {
|
||||
query = `
|
||||
SELECT g.id, g.name, g.description, g.created_at, g.created_by
|
||||
FROM groups g
|
||||
JOIN group_members gm ON gm.group_id = g.id
|
||||
WHERE gm.user_id = $1
|
||||
ORDER BY g.created_at DESC`
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(context.Background(), query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list groups: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
log.Printf("ERROR: failed to close rows: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var groups []Group
|
||||
for rows.Next() {
|
||||
var g Group
|
||||
var createdAt sql.NullTime
|
||||
err := rows.Scan(&g.ID, &g.Name, &g.Description, &createdAt, &g.CreatedBy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan group: %w", err)
|
||||
}
|
||||
if createdAt.Valid {
|
||||
g.CreatedAt = createdAt.Time
|
||||
}
|
||||
groups = append(groups, g)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating groups: %w", err)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// IsGroupAdmin checks if user is an admin of the group
|
||||
func (db *DB) IsGroupAdmin(userID, groupID string) (bool, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT COUNT(*) FROM group_members WHERE group_id = ? AND user_id = ? AND role = 'admin'`
|
||||
} else {
|
||||
query = `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND user_id = $2 AND role = 'admin'`
|
||||
}
|
||||
|
||||
var count int
|
||||
err := db.conn.QueryRowContext(context.Background(), query, groupID, userID).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check admin status: %w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// IsGroupMember checks if user is a member of the group
|
||||
func (db *DB) IsGroupMember(userID, groupID string) (bool, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT COUNT(*) FROM group_members WHERE group_id = ? AND user_id = ?`
|
||||
} else {
|
||||
query = `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND user_id = $2`
|
||||
}
|
||||
|
||||
var count int
|
||||
err := db.conn.QueryRowContext(context.Background(), query, groupID, userID).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check membership: %w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// CreateGroupInvitation creates a new group invitation
|
||||
func (db *DB) CreateGroupInvitation(groupID, invitedUserID, invitedBy string) (*GroupInvitation, error) {
|
||||
id := fmt.Sprintf("inv_%d", time.Now().UnixNano())
|
||||
expiresAt := time.Now().Add(7 * 24 * time.Hour) // 7 days
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO group_invitations (id, group_id, invited_user_id, invited_by, expires_at) VALUES (?, ?, ?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO group_invitations (id, group_id, invited_user_id, invited_by, expires_at) VALUES ($1, $2, $3, $4, $5)`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, id, groupID, invitedUserID, invitedBy, expiresAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create invitation: %w", err)
|
||||
}
|
||||
|
||||
return &GroupInvitation{
|
||||
ID: id,
|
||||
GroupID: groupID,
|
||||
InvitedUserID: invitedUserID,
|
||||
InvitedBy: invitedBy,
|
||||
Status: "pending",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: &expiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetInvitation retrieves an invitation by ID
|
||||
func (db *DB) GetInvitation(invitationID string) (*GroupInvitation, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE id = ?`
|
||||
} else {
|
||||
query = `SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE id = $1`
|
||||
}
|
||||
|
||||
var inv GroupInvitation
|
||||
var createdAt, expiresAt sql.NullTime
|
||||
err := db.conn.QueryRowContext(context.Background(), query, invitationID).Scan(
|
||||
&inv.ID, &inv.GroupID, &inv.InvitedUserID, &inv.InvitedBy, &inv.Status, &createdAt, &expiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("invitation not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get invitation: %w", err)
|
||||
}
|
||||
|
||||
if createdAt.Valid {
|
||||
inv.CreatedAt = createdAt.Time
|
||||
}
|
||||
if expiresAt.Valid {
|
||||
inv.ExpiresAt = &expiresAt.Time
|
||||
}
|
||||
|
||||
return &inv, nil
|
||||
}
|
||||
|
||||
// ListPendingInvitationsForUser returns pending invitations for a user
|
||||
func (db *DB) ListPendingInvitationsForUser(userID string) ([]GroupInvitation, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `
|
||||
SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at
|
||||
FROM group_invitations
|
||||
WHERE invited_user_id = ? AND status = 'pending'
|
||||
ORDER BY created_at DESC`
|
||||
} else {
|
||||
query = `
|
||||
SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at
|
||||
FROM group_invitations
|
||||
WHERE invited_user_id = $1 AND status = 'pending'
|
||||
ORDER BY created_at DESC`
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(context.Background(), query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list invitations: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
log.Printf("ERROR: failed to close rows: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var invitations []GroupInvitation
|
||||
for rows.Next() {
|
||||
var inv GroupInvitation
|
||||
var createdAt, expiresAt sql.NullTime
|
||||
err := rows.Scan(&inv.ID, &inv.GroupID, &inv.InvitedUserID, &inv.InvitedBy, &inv.Status, &createdAt, &expiresAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan invitation: %w", err)
|
||||
}
|
||||
if createdAt.Valid {
|
||||
inv.CreatedAt = createdAt.Time
|
||||
}
|
||||
if expiresAt.Valid {
|
||||
inv.ExpiresAt = &expiresAt.Time
|
||||
}
|
||||
invitations = append(invitations, inv)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating invitations: %w", err)
|
||||
}
|
||||
|
||||
return invitations, nil
|
||||
}
|
||||
|
||||
// AcceptInvitation accepts a group invitation
|
||||
func (db *DB) AcceptInvitation(invitationID, userID string) error {
|
||||
tx, err := db.conn.BeginTx(context.Background(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||
log.Printf("ERROR: failed to rollback transaction: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Get invitation details
|
||||
var groupID string
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT group_id FROM group_invitations WHERE id = ? AND invited_user_id = ? AND status = 'pending'`
|
||||
} else {
|
||||
query = `SELECT group_id FROM group_invitations WHERE id = $1 AND invited_user_id = $2 AND status = 'pending'`
|
||||
}
|
||||
|
||||
err = tx.QueryRowContext(context.Background(), query, invitationID, userID).Scan(&groupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invitation not found or already processed: %w", err)
|
||||
}
|
||||
|
||||
// Add user to group
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'member')`
|
||||
} else {
|
||||
query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'member')`
|
||||
}
|
||||
_, err = tx.ExecContext(context.Background(), query, groupID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add member: %w", err)
|
||||
}
|
||||
|
||||
// Update invitation status
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `UPDATE group_invitations SET status = 'accepted' WHERE id = ?`
|
||||
} else {
|
||||
query = `UPDATE group_invitations SET status = 'accepted' WHERE id = $1`
|
||||
}
|
||||
_, err = tx.ExecContext(context.Background(), query, invitationID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update invitation: %w", err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// DeclineInvitation declines a group invitation
|
||||
func (db *DB) DeclineInvitation(invitationID, userID string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `UPDATE group_invitations SET status = 'declined' WHERE id = ? AND invited_user_id = ?`
|
||||
} else {
|
||||
query = `UPDATE group_invitations SET status = 'declined' WHERE id = $1 AND invited_user_id = $2`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, invitationID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decline invitation: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveGroupMember removes a member from a group
|
||||
func (db *DB) RemoveGroupMember(groupID, userID string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `DELETE FROM group_members WHERE group_id = ? AND user_id = ?`
|
||||
} else {
|
||||
query = `DELETE FROM group_members WHERE group_id = $1 AND user_id = $2`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, groupID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove member: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserRoleInTaskGroups returns the highest role (admin > member > viewer) the user
|
||||
// holds across all groups associated with the task. Returns empty string if no access.
|
||||
func (db *DB) UserRoleInTaskGroups(userID, taskID string) string {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `
|
||||
SELECT gm.role FROM group_members gm
|
||||
JOIN task_group_access tga ON tga.group_id = gm.group_id
|
||||
WHERE gm.user_id = ? AND tga.task_id = ?
|
||||
ORDER BY CASE gm.role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'member' THEN 2
|
||||
WHEN 'viewer' THEN 3
|
||||
END
|
||||
LIMIT 1`
|
||||
} else {
|
||||
query = `
|
||||
SELECT gm.role FROM group_members gm
|
||||
JOIN task_group_access tga ON tga.group_id = gm.group_id
|
||||
WHERE gm.user_id = $1 AND tga.task_id = $2
|
||||
ORDER BY CASE gm.role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'member' THEN 2
|
||||
WHEN 'viewer' THEN 3
|
||||
END
|
||||
LIMIT 1`
|
||||
}
|
||||
|
||||
var role string
|
||||
err := db.conn.QueryRowContext(context.Background(), query, userID, taskID).Scan(&role)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return role
|
||||
}
|
||||
|
||||
// GetOrCreateDefaultLabGroup creates the auto-provisioned lab group if it doesn't exist.
|
||||
// Returns the group ID. If no DEFAULT_LAB_GROUP env var is set, returns empty string.
|
||||
func (db *DB) GetOrCreateDefaultLabGroup(createdBy string) (string, error) {
|
||||
groupName := os.Getenv("DEFAULT_LAB_GROUP")
|
||||
if groupName == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Use transaction to prevent race conditions during concurrent startup
|
||||
tx, err := db.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
||||
log.Printf("ERROR: failed to rollback transaction: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Check if group exists
|
||||
var groupID string
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT id FROM groups WHERE name = ?`
|
||||
} else {
|
||||
query = `SELECT id FROM groups WHERE name = $1`
|
||||
}
|
||||
err = tx.QueryRowContext(context.Background(), query, groupName).Scan(&groupID)
|
||||
if err == nil {
|
||||
if commitErr := tx.Commit(); commitErr != nil {
|
||||
return "", fmt.Errorf("failed to commit transaction: %w", commitErr)
|
||||
}
|
||||
return groupID, nil
|
||||
}
|
||||
if err != sql.ErrNoRows {
|
||||
return "", fmt.Errorf("failed to check for default lab group: %w", err)
|
||||
}
|
||||
|
||||
// Create the group within the transaction
|
||||
id := fmt.Sprintf("group_%d", time.Now().UnixNano())
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)`
|
||||
}
|
||||
_, err = tx.ExecContext(context.Background(), query, id, groupName, "Auto-provisioned default lab group", createdBy)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create default lab group: %w", err)
|
||||
}
|
||||
|
||||
// Add creator as admin within the transaction
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'admin')`
|
||||
} else {
|
||||
query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'admin')`
|
||||
}
|
||||
_, err = tx.ExecContext(context.Background(), query, id, createdBy)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to add creator as admin: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// EnsureUserInGroup adds a user to a group if not already a member.
|
||||
// Default role is 'member'. Returns nil if already a member.
|
||||
func (db *DB) EnsureUserInGroup(groupID, userID string, role string) error {
|
||||
if role == "" {
|
||||
role = "member"
|
||||
}
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT OR IGNORE INTO group_members (group_id, user_id, role) VALUES (?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, $3)
|
||||
ON CONFLICT (group_id, user_id) DO NOTHING`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, groupID, userID, role)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add user to group: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureAllUsersGroup creates the 'all-users' system group if it doesn't exist.
|
||||
// This group is used for institution visibility.
|
||||
func (db *DB) EnsureAllUsersGroup() (string, error) {
|
||||
const groupID = "all-users"
|
||||
const groupName = "all-users"
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT OR IGNORE INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (id) DO NOTHING`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, groupID, groupName,
|
||||
"System group: all authenticated users", "system")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to ensure all-users group: %w", err)
|
||||
}
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// EnsureUserInAllUsersGroup adds a user to the 'all-users' system group.
|
||||
func (db *DB) EnsureUserInAllUsersGroup(userID string) error {
|
||||
allUsersGroupID, err := db.EnsureAllUsersGroup()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.EnsureUserInGroup(allUsersGroupID, userID, "member")
|
||||
}
|
||||
|
||||
// ProvisionUserOnFirstLogin adds a new user to the default groups:
|
||||
// 1. The 'all-users' system group (for institution visibility)
|
||||
// 2. The DEFAULT_LAB_GROUP if configured (for lab visibility)
|
||||
// This should be called when a user first authenticates/logs in.
|
||||
func (db *DB) ProvisionUserOnFirstLogin(userID string) error {
|
||||
// Add to all-users system group first
|
||||
if err := db.EnsureUserInAllUsersGroup(userID); err != nil {
|
||||
return fmt.Errorf("failed to add user to all-users group: %w", err)
|
||||
}
|
||||
|
||||
// Add to default lab group if configured
|
||||
defaultGroupID, err := db.GetOrCreateDefaultLabGroup("system")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get/create default lab group: %w", err)
|
||||
}
|
||||
if defaultGroupID != "" {
|
||||
if err := db.EnsureUserInGroup(defaultGroupID, userID, "member"); err != nil {
|
||||
return fmt.Errorf("failed to add user to default lab group: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeprovisionUser removes a user from all groups when they are deactivated/deleted.
|
||||
// This prevents deactivated users from retaining institution-visibility access.
|
||||
func (db *DB) DeprovisionUser(userID string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `DELETE FROM group_members WHERE user_id = ?`
|
||||
} else {
|
||||
query = `DELETE FROM group_members WHERE user_id = $1`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove user from groups: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -163,6 +163,7 @@ func (db *DB) ListJobs(status string, limit int) ([]*Job, error) {
|
|||
if db.dbType == DBTypeSQLite {
|
||||
query += " LIMIT ?"
|
||||
} else {
|
||||
//nolint:gosec // G202: This builds a positional parameter $N where N is an internal counter, not user input
|
||||
query += fmt.Sprintf(" LIMIT $%d", len(args)+1)
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
|
|
|||
498
internal/storage/db_tasks.go
Normal file
498
internal/storage/db_tasks.go
Normal file
|
|
@ -0,0 +1,498 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TaskIsSharedWithUser returns true if an active, non-expired explicit share exists.
|
||||
func (db *DB) TaskIsSharedWithUser(taskID, userID string) bool {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT COUNT(*) FROM task_shares
|
||||
WHERE task_id = ? AND user_id = ?
|
||||
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)`
|
||||
} else {
|
||||
query = `SELECT COUNT(*) FROM task_shares
|
||||
WHERE task_id = $1 AND user_id = $2
|
||||
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)`
|
||||
}
|
||||
var count int
|
||||
err := db.conn.QueryRowContext(context.Background(), query, taskID, userID).Scan(&count)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// UserSharesGroupWithTask returns true if the user is a live member of any group
|
||||
// associated with the task via task_group_access.
|
||||
func (db *DB) UserSharesGroupWithTask(userID, taskID string) bool {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT COUNT(*) FROM group_members gm
|
||||
JOIN task_group_access tga ON tga.group_id = gm.group_id
|
||||
WHERE gm.user_id = ? AND tga.task_id = ?`
|
||||
} else {
|
||||
query = `SELECT COUNT(*) FROM group_members gm
|
||||
JOIN task_group_access tga ON tga.group_id = gm.group_id
|
||||
WHERE gm.user_id = $1 AND tga.task_id = $2`
|
||||
}
|
||||
var count int
|
||||
err := db.conn.QueryRowContext(context.Background(), query, userID, taskID).Scan(&count)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// TaskAllowsPublicClone returns true if the task has allow_public_clone enabled.
|
||||
// For now, this checks if the task exists and is visible as 'open'.
|
||||
// The allow_public_clone flag would be stored in the task metadata or a separate column.
|
||||
func (db *DB) TaskAllowsPublicClone(taskID string) bool {
|
||||
// TODO: Implement proper allow_public_clone check
|
||||
// For now, return false as a safe default
|
||||
return false
|
||||
}
|
||||
|
||||
// AssociateTaskWithGroup records that a task is shared with a specific group.
|
||||
// Called at submit time for lab-visibility tasks.
|
||||
func (db *DB) AssociateTaskWithGroup(taskID, groupID string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT OR IGNORE INTO task_group_access (task_id, group_id) VALUES (?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO task_group_access (task_id, group_id) VALUES ($1, $2)
|
||||
ON CONFLICT (task_id, group_id) DO NOTHING`
|
||||
}
|
||||
_, err := db.conn.ExecContext(context.Background(), query, taskID, groupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to associate task with group: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CountOpenTasksForUserToday returns the number of open tasks created by user today.
|
||||
func (db *DB) CountOpenTasksForUserToday(userID string) (int, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT COUNT(*) FROM jobs
|
||||
WHERE user_id = ?
|
||||
AND visibility = 'open'
|
||||
AND created_at >= date('now')`
|
||||
} else {
|
||||
query = `SELECT COUNT(*) FROM jobs
|
||||
WHERE user_id = $1
|
||||
AND visibility = 'open'
|
||||
AND created_at >= CURRENT_DATE`
|
||||
}
|
||||
var count int
|
||||
err := db.conn.QueryRowContext(context.Background(), query, userID).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to count open tasks: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ListTasksOptions provides pagination options for task listing.
|
||||
type ListTasksOptions struct {
|
||||
Limit int // required; enforced max 100
|
||||
Cursor string // (created_at, id) pair encoded as opaque string; empty = first page
|
||||
}
|
||||
|
||||
// ListTasksForUser returns tasks visible to a specific user.
|
||||
// Uses UNION-based SQL for efficient index usage with cursor pagination.
|
||||
func (db *DB) ListTasksForUser(userID string, isAdmin bool, opts ListTasksOptions) ([]*Job, string, error) {
|
||||
// Enforce max limit
|
||||
if opts.Limit <= 0 || opts.Limit > 100 {
|
||||
opts.Limit = 100
|
||||
}
|
||||
|
||||
if isAdmin {
|
||||
return db.listTasksPaginated("SELECT * FROM jobs", opts)
|
||||
}
|
||||
|
||||
// Parse cursor if provided
|
||||
cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)
|
||||
|
||||
// UNION approach: one index-able branch per access reason.
|
||||
// Explicit column list instead of j.* avoids issues in SQLite strict mode
|
||||
// and makes schema expectations explicit across all branches.
|
||||
cols := `j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at`
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = fmt.Sprintf(`
|
||||
-- Owner
|
||||
SELECT %s FROM jobs j
|
||||
WHERE j.user_id = ?
|
||||
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
||||
|
||||
UNION
|
||||
|
||||
-- Explicit share (expiry checked inline)
|
||||
SELECT %s FROM jobs j
|
||||
JOIN task_shares ts ON ts.task_id = j.id
|
||||
WHERE ts.user_id = ?
|
||||
AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP)
|
||||
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
||||
|
||||
UNION
|
||||
|
||||
-- Lab group membership (live, not snapshotted)
|
||||
SELECT %s FROM jobs j
|
||||
JOIN task_group_access tga ON tga.task_id = j.id
|
||||
JOIN group_members gm ON gm.group_id = tga.group_id
|
||||
WHERE gm.user_id = ?
|
||||
AND j.visibility = 'lab'
|
||||
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
||||
|
||||
UNION
|
||||
|
||||
-- Institution or open (all authenticated users)
|
||||
SELECT %s FROM jobs j
|
||||
WHERE j.visibility IN ('institution', 'open')
|
||||
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
||||
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ?
|
||||
`, cols, cols, cols, cols)
|
||||
} else {
|
||||
// PostgreSQL version with $N placeholders
|
||||
query = fmt.Sprintf(`
|
||||
-- Owner
|
||||
SELECT %s FROM jobs j
|
||||
WHERE j.user_id = $1
|
||||
AND ($2 = '' OR (j.created_at::text || j.id) < $3)
|
||||
|
||||
UNION
|
||||
|
||||
-- Explicit share (expiry checked inline)
|
||||
SELECT %s FROM jobs j
|
||||
JOIN task_shares ts ON ts.task_id = j.id
|
||||
WHERE ts.user_id = $4
|
||||
AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP)
|
||||
AND ($5 = '' OR (j.created_at::text || j.id) < $6)
|
||||
|
||||
UNION
|
||||
|
||||
-- Lab group membership (live, not snapshotted)
|
||||
SELECT %s FROM jobs j
|
||||
JOIN task_group_access tga ON tga.task_id = j.id
|
||||
JOIN group_members gm ON gm.group_id = tga.group_id
|
||||
WHERE gm.user_id = $7
|
||||
AND j.visibility = 'lab'
|
||||
AND ($8 = '' OR (j.created_at::text || j.id) < $9)
|
||||
|
||||
UNION
|
||||
|
||||
-- Institution or open (all authenticated users)
|
||||
SELECT %s FROM jobs j
|
||||
WHERE j.visibility IN ('institution', 'open')
|
||||
AND ($10 = '' OR (j.created_at::text || j.id) < $11)
|
||||
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT $12
|
||||
`, cols, cols, cols, cols)
|
||||
}
|
||||
|
||||
// Fetch limit+1 rows; if len(rows) > limit, there is a next page
|
||||
fetchLimit := opts.Limit + 1
|
||||
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if db.dbType == DBTypeSQLite {
|
||||
rows, err = db.conn.QueryContext(context.Background(), query,
|
||||
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
||||
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
||||
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
||||
cursorCreatedAt, cursorCreatedAt+cursorID,
|
||||
fetchLimit,
|
||||
)
|
||||
} else {
|
||||
rows, err = db.conn.QueryContext(context.Background(), query,
|
||||
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
||||
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
||||
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
||||
cursorCreatedAt, cursorCreatedAt+cursorID,
|
||||
fetchLimit,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to list tasks: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var jobs []*Job
|
||||
for rows.Next() {
|
||||
job := &Job{}
|
||||
var createdAt, updatedAt sql.NullTime
|
||||
var startedAt, endedAt sql.NullTime
|
||||
var datasetsJSON, metadataJSON []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
|
||||
&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
|
||||
&createdAt, &updatedAt, &startedAt, &endedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to scan job: %w", err)
|
||||
}
|
||||
|
||||
if createdAt.Valid {
|
||||
job.CreatedAt = createdAt.Time
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
job.UpdatedAt = updatedAt.Time
|
||||
}
|
||||
if startedAt.Valid {
|
||||
job.StartedAt = &startedAt.Time
|
||||
}
|
||||
if endedAt.Valid {
|
||||
job.EndedAt = &endedAt.Time
|
||||
}
|
||||
|
||||
if len(datasetsJSON) > 0 {
|
||||
_ = json.Unmarshal(datasetsJSON, &job.Datasets)
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
_ = json.Unmarshal(metadataJSON, &job.Metadata)
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, "", fmt.Errorf("error iterating jobs: %w", err)
|
||||
}
|
||||
|
||||
// Determine next cursor
|
||||
nextCursor := ""
|
||||
if len(jobs) > opts.Limit {
|
||||
// More rows available, encode cursor from the last row we actually want to return
|
||||
lastJob := jobs[opts.Limit-1]
|
||||
nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
|
||||
// Trim to requested limit
|
||||
jobs = jobs[:opts.Limit]
|
||||
}
|
||||
|
||||
return jobs, nextCursor, nil
|
||||
}
|
||||
|
||||
// ListTasksForGroup returns tasks associated with a group
|
||||
func (db *DB) ListTasksForGroup(groupID string, opts ListTasksOptions) ([]*Job, string, error) {
|
||||
if opts.Limit <= 0 || opts.Limit > 100 {
|
||||
opts.Limit = 100
|
||||
}
|
||||
|
||||
cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `
|
||||
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
|
||||
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
|
||||
FROM jobs j
|
||||
JOIN task_group_access tga ON tga.task_id = j.id
|
||||
WHERE tga.group_id = ?
|
||||
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
||||
ORDER BY j.created_at DESC, j.id DESC
|
||||
LIMIT ?`
|
||||
} else {
|
||||
query = `
|
||||
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
|
||||
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
|
||||
FROM jobs j
|
||||
JOIN task_group_access tga ON tga.task_id = j.id
|
||||
WHERE tga.group_id = $1
|
||||
AND ($2 = '' OR (j.created_at::text || j.id) < $3)
|
||||
ORDER BY j.created_at DESC, j.id DESC
|
||||
LIMIT $4`
|
||||
}
|
||||
|
||||
fetchLimit := opts.Limit + 1
|
||||
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if db.dbType == DBTypeSQLite {
|
||||
rows, err = db.conn.QueryContext(context.Background(), query, groupID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit)
|
||||
} else {
|
||||
rows, err = db.conn.QueryContext(context.Background(), query, groupID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to list group tasks: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var jobs []*Job
|
||||
for rows.Next() {
|
||||
job := &Job{}
|
||||
var createdAt, updatedAt sql.NullTime
|
||||
var startedAt, endedAt sql.NullTime
|
||||
var datasetsJSON, metadataJSON []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
|
||||
&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
|
||||
&createdAt, &updatedAt, &startedAt, &endedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to scan job: %w", err)
|
||||
}
|
||||
|
||||
if createdAt.Valid {
|
||||
job.CreatedAt = createdAt.Time
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
job.UpdatedAt = updatedAt.Time
|
||||
}
|
||||
if startedAt.Valid {
|
||||
job.StartedAt = &startedAt.Time
|
||||
}
|
||||
if endedAt.Valid {
|
||||
job.EndedAt = &endedAt.Time
|
||||
}
|
||||
|
||||
if len(datasetsJSON) > 0 {
|
||||
_ = json.Unmarshal(datasetsJSON, &job.Datasets)
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
_ = json.Unmarshal(metadataJSON, &job.Metadata)
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, "", fmt.Errorf("error iterating jobs: %w", err)
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
if len(jobs) > opts.Limit {
|
||||
lastJob := jobs[opts.Limit-1]
|
||||
nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
|
||||
jobs = jobs[:opts.Limit]
|
||||
}
|
||||
|
||||
return jobs, nextCursor, nil
|
||||
}
|
||||
|
||||
// listTasksPaginated is a helper for admin queries that lists all tasks.
|
||||
func (db *DB) listTasksPaginated(baseQuery string, opts ListTasksOptions) ([]*Job, string, error) {
|
||||
if opts.Limit <= 0 || opts.Limit > 100 {
|
||||
opts.Limit = 100
|
||||
}
|
||||
|
||||
cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = baseQuery + `
|
||||
WHERE (? = '' OR (datetime(created_at) || id) < ?)
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT ?`
|
||||
} else {
|
||||
query = baseQuery + `
|
||||
WHERE ($1 = '' OR (created_at::text || id) < $2)
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT $3`
|
||||
}
|
||||
|
||||
fetchLimit := opts.Limit + 1
|
||||
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if db.dbType == DBTypeSQLite {
|
||||
rows, err = db.conn.QueryContext(context.Background(), query, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit)
|
||||
} else {
|
||||
rows, err = db.conn.QueryContext(context.Background(), query, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to list tasks: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var jobs []*Job
|
||||
for rows.Next() {
|
||||
job := &Job{}
|
||||
var createdAt, updatedAt sql.NullTime
|
||||
var startedAt, endedAt sql.NullTime
|
||||
var datasetsJSON, metadataJSON []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
|
||||
&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
|
||||
&createdAt, &updatedAt, &startedAt, &endedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to scan job: %w", err)
|
||||
}
|
||||
|
||||
if createdAt.Valid {
|
||||
job.CreatedAt = createdAt.Time
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
job.UpdatedAt = updatedAt.Time
|
||||
}
|
||||
if startedAt.Valid {
|
||||
job.StartedAt = &startedAt.Time
|
||||
}
|
||||
if endedAt.Valid {
|
||||
job.EndedAt = &endedAt.Time
|
||||
}
|
||||
|
||||
if len(datasetsJSON) > 0 {
|
||||
_ = json.Unmarshal(datasetsJSON, &job.Datasets)
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
_ = json.Unmarshal(metadataJSON, &job.Metadata)
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, "", fmt.Errorf("error iterating jobs: %w", err)
|
||||
}
|
||||
|
||||
nextCursor := ""
|
||||
if len(jobs) > opts.Limit {
|
||||
lastJob := jobs[opts.Limit-1]
|
||||
nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
|
||||
jobs = jobs[:opts.Limit]
|
||||
}
|
||||
|
||||
return jobs, nextCursor, nil
|
||||
}
|
||||
|
||||
// encodeCursor creates an opaque cursor string from created_at and id.
|
||||
// Format: base64url(created_at_RFC3339 + "|" + task_id)
|
||||
func encodeCursor(createdAt time.Time, id string) string {
|
||||
payload := createdAt.UTC().Format(time.RFC3339) + "|" + id
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(payload))
|
||||
}
|
||||
|
||||
// decodeCursor extracts created_at and id from an opaque cursor.
|
||||
// Returns empty strings if cursor is invalid or empty.
|
||||
func decodeCursor(cursor string) (createdAt, id string, err error) {
|
||||
if cursor == "" {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(cursor)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
parts := strings.SplitN(string(payload), "|", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", fmt.Errorf("invalid cursor format")
|
||||
}
|
||||
|
||||
return parts[0], parts[1], nil
|
||||
}
|
||||
219
internal/storage/db_tokens.go
Normal file
219
internal/storage/db_tokens.go
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ShareToken represents a signed token for unauthenticated access
|
||||
type ShareToken struct {
|
||||
Token string `json:"token"`
|
||||
TaskID *string `json:"task_id,omitempty"`
|
||||
ExperimentID *string `json:"experiment_id,omitempty"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
AccessCount int `json:"access_count"`
|
||||
MaxAccesses *int `json:"max_accesses,omitempty"`
|
||||
}
|
||||
|
||||
// CreateShareToken stores a new share token in the database.
|
||||
func (db *DB) CreateShareToken(token string, taskID, experimentID *string, createdBy string, expiresAt *time.Time, maxAccesses *int) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, token, taskID, experimentID, createdBy, expiresAt, maxAccesses)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create share token: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetShareToken retrieves a share token by its value.
|
||||
func (db *DB) GetShareToken(token string) (*ShareToken, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
||||
FROM share_tokens WHERE token = ?`
|
||||
} else {
|
||||
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
||||
FROM share_tokens WHERE token = $1`
|
||||
}
|
||||
|
||||
var t ShareToken
|
||||
var taskID, experimentID sql.NullString
|
||||
var expiresAt sql.NullTime
|
||||
var maxAccesses sql.NullInt64
|
||||
|
||||
err := db.conn.QueryRowContext(context.Background(), query, token).Scan(
|
||||
&t.Token, &taskID, &experimentID, &t.CreatedBy, &t.CreatedAt, &expiresAt, &t.AccessCount, &maxAccesses,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get share token: %w", err)
|
||||
}
|
||||
|
||||
if taskID.Valid {
|
||||
t.TaskID = &taskID.String
|
||||
}
|
||||
if experimentID.Valid {
|
||||
t.ExperimentID = &experimentID.String
|
||||
}
|
||||
if expiresAt.Valid {
|
||||
t.ExpiresAt = &expiresAt.Time
|
||||
}
|
||||
if maxAccesses.Valid {
|
||||
m := int(maxAccesses.Int64)
|
||||
t.MaxAccesses = &m
|
||||
}
|
||||
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
// IncrementTokenAccessCount increments the access count for a token.
|
||||
func (db *DB) IncrementTokenAccessCount(token string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = ?`
|
||||
} else {
|
||||
query = `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = $1`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment token access count: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteShareToken removes a share token from the database.
|
||||
func (db *DB) DeleteShareToken(token string) error {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `DELETE FROM share_tokens WHERE token = ?`
|
||||
} else {
|
||||
query = `DELETE FROM share_tokens WHERE token = $1`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(context.Background(), query, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete share token: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListShareTokensForTask returns all share tokens for a specific task.
|
||||
func (db *DB) ListShareTokensForTask(taskID string) ([]*ShareToken, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
||||
FROM share_tokens WHERE task_id = ? ORDER BY created_at DESC`
|
||||
} else {
|
||||
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
||||
FROM share_tokens WHERE task_id = $1 ORDER BY created_at DESC`
|
||||
}
|
||||
|
||||
return db.queryShareTokens(query, taskID)
|
||||
}
|
||||
|
||||
// ListShareTokensForExperiment returns all share tokens for a specific experiment.
|
||||
func (db *DB) ListShareTokensForExperiment(experimentID string) ([]*ShareToken, error) {
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
||||
FROM share_tokens WHERE experiment_id = ? ORDER BY created_at DESC`
|
||||
} else {
|
||||
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
||||
FROM share_tokens WHERE experiment_id = $1 ORDER BY created_at DESC`
|
||||
}
|
||||
|
||||
return db.queryShareTokens(query, experimentID)
|
||||
}
|
||||
|
||||
// queryShareTokens is a helper to execute share token queries.
|
||||
func (db *DB) queryShareTokens(query string, arg string) ([]*ShareToken, error) {
|
||||
rows, err := db.conn.QueryContext(context.Background(), query, arg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list share tokens: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var tokens []*ShareToken
|
||||
for rows.Next() {
|
||||
var t ShareToken
|
||||
var taskID, experimentID sql.NullString
|
||||
var expiresAt sql.NullTime
|
||||
var maxAccesses sql.NullInt64
|
||||
|
||||
err := rows.Scan(
|
||||
&t.Token, &taskID, &experimentID, &t.CreatedBy, &t.CreatedAt, &expiresAt, &t.AccessCount, &maxAccesses,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan share token: %w", err)
|
||||
}
|
||||
|
||||
if taskID.Valid {
|
||||
t.TaskID = &taskID.String
|
||||
}
|
||||
if experimentID.Valid {
|
||||
t.ExperimentID = &experimentID.String
|
||||
}
|
||||
if expiresAt.Valid {
|
||||
t.ExpiresAt = &expiresAt.Time
|
||||
}
|
||||
if maxAccesses.Valid {
|
||||
m := int(maxAccesses.Int64)
|
||||
t.MaxAccesses = &m
|
||||
}
|
||||
|
||||
tokens = append(tokens, &t)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating share tokens: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// ValidateShareToken checks if a token is valid for accessing a task or experiment.
|
||||
// Returns the token details if valid, nil if invalid or expired.
|
||||
func (db *DB) ValidateShareToken(token string, taskID *string, experimentID *string) (*ShareToken, error) {
|
||||
t, err := db.GetShareToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check if token is for the correct resource
|
||||
if taskID != nil && (t.TaskID == nil || *t.TaskID != *taskID) {
|
||||
return nil, nil
|
||||
}
|
||||
if experimentID != nil && (t.ExperimentID == nil || *t.ExperimentID != *experimentID) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check expiry
|
||||
if t.ExpiresAt != nil && time.Now().After(*t.ExpiresAt) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check max accesses
|
||||
if t.MaxAccesses != nil && t.AccessCount >= *t.MaxAccesses {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
|
@ -11,10 +11,13 @@ CREATE TABLE IF NOT EXISTS jobs (
|
|||
started_at DATETIME,
|
||||
ended_at DATETIME,
|
||||
worker_id TEXT,
|
||||
user_id TEXT,
|
||||
error TEXT,
|
||||
datasets TEXT, -- JSON array
|
||||
metadata TEXT, -- JSON object
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
visibility TEXT NOT NULL DEFAULT 'lab',
|
||||
experiment_id TEXT REFERENCES experiments(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS job_metrics (
|
||||
|
|
@ -142,3 +145,113 @@ CREATE TABLE IF NOT EXISTS websocket_metrics (
|
|||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_websocket_metrics_name_time ON websocket_metrics(metric_name, recorded_at);
|
||||
|
||||
-- Groups and membership for lab-based task sharing
|
||||
CREATE TABLE IF NOT EXISTS groups (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
description TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
created_by TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS group_members (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
role TEXT DEFAULT 'member', -- 'admin', 'member', 'viewer'
|
||||
joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (group_id, user_id),
|
||||
FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- System group for institution visibility (all authenticated users)
|
||||
INSERT OR IGNORE INTO groups (id, name, description, created_by)
|
||||
VALUES ('all-users', 'all-users', 'System group: all authenticated users', 'system');
|
||||
|
||||
-- Invite-and-accept flow: group admins invite; users accept or decline
|
||||
CREATE TABLE IF NOT EXISTS group_invitations (
|
||||
id TEXT PRIMARY KEY,
|
||||
group_id TEXT NOT NULL,
|
||||
invited_user_id TEXT NOT NULL,
|
||||
invited_by TEXT NOT NULL,
|
||||
status TEXT DEFAULT 'pending', -- 'pending', 'accepted', 'declined'
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at DATETIME, -- NULL = 7d default enforced in app layer
|
||||
FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Experiment/project grouping: share a whole experiment, not individual tasks
|
||||
-- Note: experiments table already exists; adding group_id to link with sharing system
|
||||
ALTER TABLE experiments ADD COLUMN group_id TEXT REFERENCES groups(id);
|
||||
|
||||
-- Link tasks to experiments
|
||||
CREATE TABLE IF NOT EXISTS experiment_tasks (
|
||||
experiment_id TEXT NOT NULL,
|
||||
task_id TEXT NOT NULL,
|
||||
added_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (experiment_id, task_id),
|
||||
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Per-user explicit shares with optional expiry
|
||||
CREATE TABLE IF NOT EXISTS task_shares (
|
||||
task_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
granted_by TEXT NOT NULL,
|
||||
granted_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at DATETIME, -- NULL = no expiry; checked at access time
|
||||
PRIMARY KEY (task_id, user_id),
|
||||
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Group-level task association
|
||||
-- Records which group a task is associated with at submit time.
|
||||
-- Actual membership is always resolved live from group_members.
|
||||
CREATE TABLE IF NOT EXISTS task_group_access (
|
||||
task_id TEXT NOT NULL,
|
||||
group_id TEXT NOT NULL,
|
||||
PRIMARY KEY (task_id, group_id),
|
||||
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Signed share tokens for unauthenticated open access (paper reproducibility links)
|
||||
CREATE TABLE IF NOT EXISTS share_tokens (
|
||||
token TEXT PRIMARY KEY, -- cryptographically random (32 bytes, base64url)
|
||||
task_id TEXT, -- NULL if experiment-level
|
||||
experiment_id TEXT, -- NULL if task-level
|
||||
created_by TEXT NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at DATETIME, -- NULL = never expires
|
||||
access_count INTEGER DEFAULT 0,
|
||||
max_accesses INTEGER, -- NULL = unlimited
|
||||
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Audit log for task access
|
||||
CREATE TABLE IF NOT EXISTS task_access_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_id TEXT NOT NULL,
|
||||
user_id TEXT, -- NULL for token-based access
|
||||
token TEXT, -- NULL for session-based access
|
||||
action TEXT NOT NULL, -- 'view', 'clone', 'execute', 'modify'
|
||||
accessed_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
ip_address TEXT,
|
||||
FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Indexes for task sharing performance
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_visibility ON jobs(visibility);
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_visibility_owner ON jobs(visibility, user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_jobs_experiment ON jobs(experiment_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_task_shares_user ON task_shares(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_task_shares_expires ON task_shares(expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_tga_group ON task_group_access(group_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_share_tokens_task ON share_tokens(task_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_task_access_task ON task_access_log(task_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_task_access_user ON task_access_log(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_task_access_token ON task_access_log(token) WHERE token IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_invitations_user ON group_invitations(invited_user_id);
|
||||
|
|
|
|||
Loading…
Reference in a new issue