Add comprehensive database storage layer for new features: - db_groups.go: Lab group management with members, roles (admin/member/viewer), and group-based task visibility queries - db_tasks.go: Task visibility system (private/lab/institution/open), task sharing with expiry, public clone tokens, and optimized ListTasksForUser() for access control - db_tokens.go: Secure token management for public task access and cloning, with SHA-256 hashed token storage and automatic cleanup - db_audit.go: Audit log persistence with checkpoint chains, tamper detection, and log rotation support - schema_sqlite.sql: Updated schema with: - groups, group_members tables - tasks.visibility enum, task_shares with expiry - access_tokens table with hashed tokens - audit_logs, audit_checkpoints tables - indexes for all foreign keys and query patterns - db_experiments.go: Add CascadeVisibilityToTasks() for propagating visibility changes from experiments to associated tasks
498 lines
14 KiB
Go
498 lines
14 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// TaskIsSharedWithUser returns true if an active, non-expired explicit share exists.
|
|
func (db *DB) TaskIsSharedWithUser(taskID, userID string) bool {
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `SELECT COUNT(*) FROM task_shares
|
|
WHERE task_id = ? AND user_id = ?
|
|
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)`
|
|
} else {
|
|
query = `SELECT COUNT(*) FROM task_shares
|
|
WHERE task_id = $1 AND user_id = $2
|
|
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)`
|
|
}
|
|
var count int
|
|
err := db.conn.QueryRowContext(context.Background(), query, taskID, userID).Scan(&count)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return count > 0
|
|
}
|
|
|
|
// UserSharesGroupWithTask returns true if the user is a live member of any group
|
|
// associated with the task via task_group_access.
|
|
func (db *DB) UserSharesGroupWithTask(userID, taskID string) bool {
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `SELECT COUNT(*) FROM group_members gm
|
|
JOIN task_group_access tga ON tga.group_id = gm.group_id
|
|
WHERE gm.user_id = ? AND tga.task_id = ?`
|
|
} else {
|
|
query = `SELECT COUNT(*) FROM group_members gm
|
|
JOIN task_group_access tga ON tga.group_id = gm.group_id
|
|
WHERE gm.user_id = $1 AND tga.task_id = $2`
|
|
}
|
|
var count int
|
|
err := db.conn.QueryRowContext(context.Background(), query, userID, taskID).Scan(&count)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return count > 0
|
|
}
|
|
|
|
// TaskAllowsPublicClone returns true if the task has allow_public_clone enabled.
|
|
// For now, this checks if the task exists and is visible as 'open'.
|
|
// The allow_public_clone flag would be stored in the task metadata or a separate column.
|
|
func (db *DB) TaskAllowsPublicClone(taskID string) bool {
|
|
// TODO: Implement proper allow_public_clone check
|
|
// For now, return false as a safe default
|
|
return false
|
|
}
|
|
|
|
// AssociateTaskWithGroup records that a task is shared with a specific group.
|
|
// Called at submit time for lab-visibility tasks.
|
|
func (db *DB) AssociateTaskWithGroup(taskID, groupID string) error {
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `INSERT OR IGNORE INTO task_group_access (task_id, group_id) VALUES (?, ?)`
|
|
} else {
|
|
query = `INSERT INTO task_group_access (task_id, group_id) VALUES ($1, $2)
|
|
ON CONFLICT (task_id, group_id) DO NOTHING`
|
|
}
|
|
_, err := db.conn.ExecContext(context.Background(), query, taskID, groupID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to associate task with group: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CountOpenTasksForUserToday returns the number of open tasks created by user today.
|
|
func (db *DB) CountOpenTasksForUserToday(userID string) (int, error) {
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `SELECT COUNT(*) FROM jobs
|
|
WHERE user_id = ?
|
|
AND visibility = 'open'
|
|
AND created_at >= date('now')`
|
|
} else {
|
|
query = `SELECT COUNT(*) FROM jobs
|
|
WHERE user_id = $1
|
|
AND visibility = 'open'
|
|
AND created_at >= CURRENT_DATE`
|
|
}
|
|
var count int
|
|
err := db.conn.QueryRowContext(context.Background(), query, userID).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to count open tasks: %w", err)
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// ListTasksOptions controls pagination for the task-listing methods
// (ListTasksForUser, ListTasksForGroup).
type ListTasksOptions struct {
	// Limit is the maximum number of tasks per page. It is optional: the
	// listing methods replace values <= 0 or > 100 with 100.
	Limit int

	// Cursor is the opaque token returned alongside the previous page. It
	// encodes that page's last (created_at, id) pair. Leave it empty to
	// request the first page.
	Cursor string
}
|
|
|
|
// ListTasksForUser returns tasks visible to a specific user.
|
|
// Uses UNION-based SQL for efficient index usage with cursor pagination.
|
|
func (db *DB) ListTasksForUser(userID string, isAdmin bool, opts ListTasksOptions) ([]*Job, string, error) {
|
|
// Enforce max limit
|
|
if opts.Limit <= 0 || opts.Limit > 100 {
|
|
opts.Limit = 100
|
|
}
|
|
|
|
if isAdmin {
|
|
return db.listTasksPaginated("SELECT * FROM jobs", opts)
|
|
}
|
|
|
|
// Parse cursor if provided
|
|
cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)
|
|
|
|
// UNION approach: one index-able branch per access reason.
|
|
// Explicit column list instead of j.* avoids issues in SQLite strict mode
|
|
// and makes schema expectations explicit across all branches.
|
|
cols := `j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at`
|
|
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = fmt.Sprintf(`
|
|
-- Owner
|
|
SELECT %s FROM jobs j
|
|
WHERE j.user_id = ?
|
|
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
|
|
|
UNION
|
|
|
|
-- Explicit share (expiry checked inline)
|
|
SELECT %s FROM jobs j
|
|
JOIN task_shares ts ON ts.task_id = j.id
|
|
WHERE ts.user_id = ?
|
|
AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP)
|
|
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
|
|
|
UNION
|
|
|
|
-- Lab group membership (live, not snapshotted)
|
|
SELECT %s FROM jobs j
|
|
JOIN task_group_access tga ON tga.task_id = j.id
|
|
JOIN group_members gm ON gm.group_id = tga.group_id
|
|
WHERE gm.user_id = ?
|
|
AND j.visibility = 'lab'
|
|
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
|
|
|
UNION
|
|
|
|
-- Institution or open (all authenticated users)
|
|
SELECT %s FROM jobs j
|
|
WHERE j.visibility IN ('institution', 'open')
|
|
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
|
|
|
ORDER BY created_at DESC, id DESC
|
|
LIMIT ?
|
|
`, cols, cols, cols, cols)
|
|
} else {
|
|
// PostgreSQL version with $N placeholders
|
|
query = fmt.Sprintf(`
|
|
-- Owner
|
|
SELECT %s FROM jobs j
|
|
WHERE j.user_id = $1
|
|
AND ($2 = '' OR (j.created_at::text || j.id) < $3)
|
|
|
|
UNION
|
|
|
|
-- Explicit share (expiry checked inline)
|
|
SELECT %s FROM jobs j
|
|
JOIN task_shares ts ON ts.task_id = j.id
|
|
WHERE ts.user_id = $4
|
|
AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP)
|
|
AND ($5 = '' OR (j.created_at::text || j.id) < $6)
|
|
|
|
UNION
|
|
|
|
-- Lab group membership (live, not snapshotted)
|
|
SELECT %s FROM jobs j
|
|
JOIN task_group_access tga ON tga.task_id = j.id
|
|
JOIN group_members gm ON gm.group_id = tga.group_id
|
|
WHERE gm.user_id = $7
|
|
AND j.visibility = 'lab'
|
|
AND ($8 = '' OR (j.created_at::text || j.id) < $9)
|
|
|
|
UNION
|
|
|
|
-- Institution or open (all authenticated users)
|
|
SELECT %s FROM jobs j
|
|
WHERE j.visibility IN ('institution', 'open')
|
|
AND ($10 = '' OR (j.created_at::text || j.id) < $11)
|
|
|
|
ORDER BY created_at DESC, id DESC
|
|
LIMIT $12
|
|
`, cols, cols, cols, cols)
|
|
}
|
|
|
|
// Fetch limit+1 rows; if len(rows) > limit, there is a next page
|
|
fetchLimit := opts.Limit + 1
|
|
|
|
var rows *sql.Rows
|
|
var err error
|
|
if db.dbType == DBTypeSQLite {
|
|
rows, err = db.conn.QueryContext(context.Background(), query,
|
|
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
|
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
|
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
|
cursorCreatedAt, cursorCreatedAt+cursorID,
|
|
fetchLimit,
|
|
)
|
|
} else {
|
|
rows, err = db.conn.QueryContext(context.Background(), query,
|
|
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
|
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
|
userID, cursorCreatedAt, cursorCreatedAt+cursorID,
|
|
cursorCreatedAt, cursorCreatedAt+cursorID,
|
|
fetchLimit,
|
|
)
|
|
}
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("failed to list tasks: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var jobs []*Job
|
|
for rows.Next() {
|
|
job := &Job{}
|
|
var createdAt, updatedAt sql.NullTime
|
|
var startedAt, endedAt sql.NullTime
|
|
var datasetsJSON, metadataJSON []byte
|
|
|
|
err := rows.Scan(
|
|
&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
|
|
&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
|
|
&createdAt, &updatedAt, &startedAt, &endedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("failed to scan job: %w", err)
|
|
}
|
|
|
|
if createdAt.Valid {
|
|
job.CreatedAt = createdAt.Time
|
|
}
|
|
if updatedAt.Valid {
|
|
job.UpdatedAt = updatedAt.Time
|
|
}
|
|
if startedAt.Valid {
|
|
job.StartedAt = &startedAt.Time
|
|
}
|
|
if endedAt.Valid {
|
|
job.EndedAt = &endedAt.Time
|
|
}
|
|
|
|
if len(datasetsJSON) > 0 {
|
|
_ = json.Unmarshal(datasetsJSON, &job.Datasets)
|
|
}
|
|
if len(metadataJSON) > 0 {
|
|
_ = json.Unmarshal(metadataJSON, &job.Metadata)
|
|
}
|
|
|
|
jobs = append(jobs, job)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, "", fmt.Errorf("error iterating jobs: %w", err)
|
|
}
|
|
|
|
// Determine next cursor
|
|
nextCursor := ""
|
|
if len(jobs) > opts.Limit {
|
|
// More rows available, encode cursor from the last row we actually want to return
|
|
lastJob := jobs[opts.Limit-1]
|
|
nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
|
|
// Trim to requested limit
|
|
jobs = jobs[:opts.Limit]
|
|
}
|
|
|
|
return jobs, nextCursor, nil
|
|
}
|
|
|
|
// ListTasksForGroup returns tasks associated with a group
|
|
func (db *DB) ListTasksForGroup(groupID string, opts ListTasksOptions) ([]*Job, string, error) {
|
|
if opts.Limit <= 0 || opts.Limit > 100 {
|
|
opts.Limit = 100
|
|
}
|
|
|
|
cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)
|
|
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `
|
|
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
|
|
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
|
|
FROM jobs j
|
|
JOIN task_group_access tga ON tga.task_id = j.id
|
|
WHERE tga.group_id = ?
|
|
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
|
|
ORDER BY j.created_at DESC, j.id DESC
|
|
LIMIT ?`
|
|
} else {
|
|
query = `
|
|
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
|
|
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
|
|
FROM jobs j
|
|
JOIN task_group_access tga ON tga.task_id = j.id
|
|
WHERE tga.group_id = $1
|
|
AND ($2 = '' OR (j.created_at::text || j.id) < $3)
|
|
ORDER BY j.created_at DESC, j.id DESC
|
|
LIMIT $4`
|
|
}
|
|
|
|
fetchLimit := opts.Limit + 1
|
|
|
|
var rows *sql.Rows
|
|
var err error
|
|
if db.dbType == DBTypeSQLite {
|
|
rows, err = db.conn.QueryContext(context.Background(), query, groupID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit)
|
|
} else {
|
|
rows, err = db.conn.QueryContext(context.Background(), query, groupID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit)
|
|
}
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("failed to list group tasks: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var jobs []*Job
|
|
for rows.Next() {
|
|
job := &Job{}
|
|
var createdAt, updatedAt sql.NullTime
|
|
var startedAt, endedAt sql.NullTime
|
|
var datasetsJSON, metadataJSON []byte
|
|
|
|
err := rows.Scan(
|
|
&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
|
|
&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
|
|
&createdAt, &updatedAt, &startedAt, &endedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("failed to scan job: %w", err)
|
|
}
|
|
|
|
if createdAt.Valid {
|
|
job.CreatedAt = createdAt.Time
|
|
}
|
|
if updatedAt.Valid {
|
|
job.UpdatedAt = updatedAt.Time
|
|
}
|
|
if startedAt.Valid {
|
|
job.StartedAt = &startedAt.Time
|
|
}
|
|
if endedAt.Valid {
|
|
job.EndedAt = &endedAt.Time
|
|
}
|
|
|
|
if len(datasetsJSON) > 0 {
|
|
_ = json.Unmarshal(datasetsJSON, &job.Datasets)
|
|
}
|
|
if len(metadataJSON) > 0 {
|
|
_ = json.Unmarshal(metadataJSON, &job.Metadata)
|
|
}
|
|
|
|
jobs = append(jobs, job)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, "", fmt.Errorf("error iterating jobs: %w", err)
|
|
}
|
|
|
|
nextCursor := ""
|
|
if len(jobs) > opts.Limit {
|
|
lastJob := jobs[opts.Limit-1]
|
|
nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
|
|
jobs = jobs[:opts.Limit]
|
|
}
|
|
|
|
return jobs, nextCursor, nil
|
|
}
|
|
|
|
// listTasksPaginated is a helper for admin queries that lists all tasks.
|
|
func (db *DB) listTasksPaginated(baseQuery string, opts ListTasksOptions) ([]*Job, string, error) {
|
|
if opts.Limit <= 0 || opts.Limit > 100 {
|
|
opts.Limit = 100
|
|
}
|
|
|
|
cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)
|
|
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = baseQuery + `
|
|
WHERE (? = '' OR (datetime(created_at) || id) < ?)
|
|
ORDER BY created_at DESC, id DESC
|
|
LIMIT ?`
|
|
} else {
|
|
query = baseQuery + `
|
|
WHERE ($1 = '' OR (created_at::text || id) < $2)
|
|
ORDER BY created_at DESC, id DESC
|
|
LIMIT $3`
|
|
}
|
|
|
|
fetchLimit := opts.Limit + 1
|
|
|
|
var rows *sql.Rows
|
|
var err error
|
|
if db.dbType == DBTypeSQLite {
|
|
rows, err = db.conn.QueryContext(context.Background(), query, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit)
|
|
} else {
|
|
rows, err = db.conn.QueryContext(context.Background(), query, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit)
|
|
}
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("failed to list tasks: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var jobs []*Job
|
|
for rows.Next() {
|
|
job := &Job{}
|
|
var createdAt, updatedAt sql.NullTime
|
|
var startedAt, endedAt sql.NullTime
|
|
var datasetsJSON, metadataJSON []byte
|
|
|
|
err := rows.Scan(
|
|
&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
|
|
&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
|
|
&createdAt, &updatedAt, &startedAt, &endedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, "", fmt.Errorf("failed to scan job: %w", err)
|
|
}
|
|
|
|
if createdAt.Valid {
|
|
job.CreatedAt = createdAt.Time
|
|
}
|
|
if updatedAt.Valid {
|
|
job.UpdatedAt = updatedAt.Time
|
|
}
|
|
if startedAt.Valid {
|
|
job.StartedAt = &startedAt.Time
|
|
}
|
|
if endedAt.Valid {
|
|
job.EndedAt = &endedAt.Time
|
|
}
|
|
|
|
if len(datasetsJSON) > 0 {
|
|
_ = json.Unmarshal(datasetsJSON, &job.Datasets)
|
|
}
|
|
if len(metadataJSON) > 0 {
|
|
_ = json.Unmarshal(metadataJSON, &job.Metadata)
|
|
}
|
|
|
|
jobs = append(jobs, job)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, "", fmt.Errorf("error iterating jobs: %w", err)
|
|
}
|
|
|
|
nextCursor := ""
|
|
if len(jobs) > opts.Limit {
|
|
lastJob := jobs[opts.Limit-1]
|
|
nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
|
|
jobs = jobs[:opts.Limit]
|
|
}
|
|
|
|
return jobs, nextCursor, nil
|
|
}
|
|
|
|
// encodeCursor packs a (created_at, id) pagination key into an opaque,
// URL-safe token. The token is unpadded base64url of
// "<created_at in UTC, RFC 3339>|<id>". The RFC 3339 layout drops sub-second
// precision.
func encodeCursor(createdAt time.Time, id string) string {
	buf := make([]byte, 0, 32+len(id))
	buf = createdAt.UTC().AppendFormat(buf, time.RFC3339)
	buf = append(buf, '|')
	buf = append(buf, id...)
	return base64.RawURLEncoding.EncodeToString(buf)
}
|
|
|
|
// decodeCursor reverses encodeCursor. It base64url-decodes the cursor and
// splits the payload at the first "|" into the created_at text and the id.
// The timestamp text is returned as-is and is not parsed.
//
// An empty cursor yields ("", "", nil). Invalid base64, or a payload without
// "|", yields empty strings and a non-nil error.
func decodeCursor(cursor string) (createdAt, id string, err error) {
	if len(cursor) == 0 {
		return "", "", nil
	}

	raw, decodeErr := base64.RawURLEncoding.DecodeString(cursor)
	if decodeErr != nil {
		return "", "", decodeErr
	}

	before, after, found := strings.Cut(string(raw), "|")
	if !found {
		return "", "", fmt.Errorf("invalid cursor format")
	}
	return before, after, nil
}
|