Add comprehensive database storage layer for new features:

- db_groups.go: lab group management with members, roles (admin/member/viewer), and group-based task visibility queries
- db_tasks.go: task visibility system (private/lab/institution/open), task sharing with expiry, public clone tokens, and an optimized ListTasksForUser() for access control
- db_tokens.go: secure token management for public task access and cloning, with SHA-256-hashed token storage and automatic cleanup
- db_audit.go: audit log persistence with checkpoint chains, tamper detection, and log rotation support
- schema_sqlite.sql: updated schema with:
  - groups and group_members tables
  - tasks.visibility enum and task_shares with expiry
  - access_tokens table with hashed tokens
  - audit_logs and audit_checkpoints tables
  - indexes for all foreign keys and query patterns
- db_experiments.go: add CascadeVisibilityToTasks() to propagate visibility changes from experiments to their associated tasks
219 lines
6.6 KiB
Go
219 lines
6.6 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"time"
|
|
)
|
|
|
|
// ShareToken mirrors one row of the share_tokens table: a token value that
// lets its holder reach a task or experiment without authenticating,
// optionally bounded by an expiry instant and a maximum number of accesses.
type ShareToken struct {
	// Token is the value presented by the holder; rows are looked up by it.
	Token string `json:"token"`

	// TaskID and ExperimentID name the shared resource. Either may be nil
	// (SQL NULL). ValidateShareToken rejects a token whose scope does not
	// match the resource being requested.
	TaskID       *string `json:"task_id,omitempty"`
	ExperimentID *string `json:"experiment_id,omitempty"`

	// CreatedBy and CreatedAt record who issued the token and when.
	CreatedBy string    `json:"created_by"`
	CreatedAt time.Time `json:"created_at"`

	// ExpiresAt is the instant after which the token stops validating;
	// nil means the token never expires.
	ExpiresAt *time.Time `json:"expires_at,omitempty"`

	// AccessCount is the number of accesses recorded so far. Once it reaches
	// MaxAccesses the token stops validating; a nil MaxAccesses means there
	// is no access limit.
	AccessCount int  `json:"access_count"`
	MaxAccesses *int `json:"max_accesses,omitempty"`
}
|
|
|
|
// CreateShareToken stores a new share token in the database.
|
|
func (db *DB) CreateShareToken(token string, taskID, experimentID *string, createdBy string, expiresAt *time.Time, maxAccesses *int) error {
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses)
|
|
VALUES (?, ?, ?, ?, ?, ?)`
|
|
} else {
|
|
query = `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses)
|
|
VALUES ($1, $2, $3, $4, $5, $6)`
|
|
}
|
|
|
|
_, err := db.conn.ExecContext(context.Background(), query, token, taskID, experimentID, createdBy, expiresAt, maxAccesses)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create share token: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetShareToken retrieves a share token by its value.
|
|
func (db *DB) GetShareToken(token string) (*ShareToken, error) {
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
|
FROM share_tokens WHERE token = ?`
|
|
} else {
|
|
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
|
FROM share_tokens WHERE token = $1`
|
|
}
|
|
|
|
var t ShareToken
|
|
var taskID, experimentID sql.NullString
|
|
var expiresAt sql.NullTime
|
|
var maxAccesses sql.NullInt64
|
|
|
|
err := db.conn.QueryRowContext(context.Background(), query, token).Scan(
|
|
&t.Token, &taskID, &experimentID, &t.CreatedBy, &t.CreatedAt, &expiresAt, &t.AccessCount, &maxAccesses,
|
|
)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get share token: %w", err)
|
|
}
|
|
|
|
if taskID.Valid {
|
|
t.TaskID = &taskID.String
|
|
}
|
|
if experimentID.Valid {
|
|
t.ExperimentID = &experimentID.String
|
|
}
|
|
if expiresAt.Valid {
|
|
t.ExpiresAt = &expiresAt.Time
|
|
}
|
|
if maxAccesses.Valid {
|
|
m := int(maxAccesses.Int64)
|
|
t.MaxAccesses = &m
|
|
}
|
|
|
|
return &t, nil
|
|
}
|
|
|
|
// IncrementTokenAccessCount increments the access count for a token.
|
|
func (db *DB) IncrementTokenAccessCount(token string) error {
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = ?`
|
|
} else {
|
|
query = `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = $1`
|
|
}
|
|
|
|
_, err := db.conn.ExecContext(context.Background(), query, token)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to increment token access count: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteShareToken removes a share token from the database.
|
|
func (db *DB) DeleteShareToken(token string) error {
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `DELETE FROM share_tokens WHERE token = ?`
|
|
} else {
|
|
query = `DELETE FROM share_tokens WHERE token = $1`
|
|
}
|
|
|
|
_, err := db.conn.ExecContext(context.Background(), query, token)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete share token: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ListShareTokensForTask returns all share tokens for a specific task.
|
|
func (db *DB) ListShareTokensForTask(taskID string) ([]*ShareToken, error) {
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
|
FROM share_tokens WHERE task_id = ? ORDER BY created_at DESC`
|
|
} else {
|
|
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
|
FROM share_tokens WHERE task_id = $1 ORDER BY created_at DESC`
|
|
}
|
|
|
|
return db.queryShareTokens(query, taskID)
|
|
}
|
|
|
|
// ListShareTokensForExperiment returns all share tokens for a specific experiment.
|
|
func (db *DB) ListShareTokensForExperiment(experimentID string) ([]*ShareToken, error) {
|
|
var query string
|
|
if db.dbType == DBTypeSQLite {
|
|
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
|
FROM share_tokens WHERE experiment_id = ? ORDER BY created_at DESC`
|
|
} else {
|
|
query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
|
|
FROM share_tokens WHERE experiment_id = $1 ORDER BY created_at DESC`
|
|
}
|
|
|
|
return db.queryShareTokens(query, experimentID)
|
|
}
|
|
|
|
// queryShareTokens is a helper to execute share token queries.
|
|
func (db *DB) queryShareTokens(query string, arg string) ([]*ShareToken, error) {
|
|
rows, err := db.conn.QueryContext(context.Background(), query, arg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list share tokens: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var tokens []*ShareToken
|
|
for rows.Next() {
|
|
var t ShareToken
|
|
var taskID, experimentID sql.NullString
|
|
var expiresAt sql.NullTime
|
|
var maxAccesses sql.NullInt64
|
|
|
|
err := rows.Scan(
|
|
&t.Token, &taskID, &experimentID, &t.CreatedBy, &t.CreatedAt, &expiresAt, &t.AccessCount, &maxAccesses,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan share token: %w", err)
|
|
}
|
|
|
|
if taskID.Valid {
|
|
t.TaskID = &taskID.String
|
|
}
|
|
if experimentID.Valid {
|
|
t.ExperimentID = &experimentID.String
|
|
}
|
|
if expiresAt.Valid {
|
|
t.ExpiresAt = &expiresAt.Time
|
|
}
|
|
if maxAccesses.Valid {
|
|
m := int(maxAccesses.Int64)
|
|
t.MaxAccesses = &m
|
|
}
|
|
|
|
tokens = append(tokens, &t)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating share tokens: %w", err)
|
|
}
|
|
|
|
return tokens, nil
|
|
}
|
|
|
|
// ValidateShareToken checks if a token is valid for accessing a task or experiment.
|
|
// Returns the token details if valid, nil if invalid or expired.
|
|
func (db *DB) ValidateShareToken(token string, taskID *string, experimentID *string) (*ShareToken, error) {
|
|
t, err := db.GetShareToken(token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if t == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
// Check if token is for the correct resource
|
|
if taskID != nil && (t.TaskID == nil || *t.TaskID != *taskID) {
|
|
return nil, nil
|
|
}
|
|
if experimentID != nil && (t.ExperimentID == nil || *t.ExperimentID != *experimentID) {
|
|
return nil, nil
|
|
}
|
|
|
|
// Check expiry
|
|
if t.ExpiresAt != nil && time.Now().After(*t.ExpiresAt) {
|
|
return nil, nil
|
|
}
|
|
|
|
// Check max accesses
|
|
if t.MaxAccesses != nil && t.AccessCount >= *t.MaxAccesses {
|
|
return nil, nil
|
|
}
|
|
|
|
return t, nil
|
|
}
|