fetch_ml/internal/storage/db_tokens.go
Jeremie Fraeys fbcf4d38e5
feat(storage): add groups, tasks, tokens, and audit database schemas
Add comprehensive database storage layer for new features:

- db_groups.go: Lab group management with members, roles (admin/member/viewer),
  and group-based task visibility queries

- db_tasks.go: Task visibility system (private/lab/institution/open),
  task sharing with expiry, public clone tokens, and optimized
  ListTasksForUser() for access control

- db_tokens.go: Secure token management for public task access and cloning,
  with SHA-256 hashed token storage and automatic cleanup

- db_audit.go: Audit log persistence with checkpoint chains, tamper
  detection, and log rotation support

- schema_sqlite.sql: Updated schema with:
  - groups, group_members tables
  - tasks.visibility enum, task_shares with expiry
  - access_tokens table with hashed tokens
  - audit_logs, audit_checkpoints tables
  - indexes for all foreign keys and query patterns

- db_experiments.go: Add CascadeVisibilityToTasks() for propagating
  visibility changes from experiments to associated tasks
2026-03-08 12:48:42 -04:00

219 lines
6.6 KiB
Go

package storage
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"
)
// ShareToken mirrors one row of the share_tokens table: a token value that
// grants access to a task and/or an experiment without a logged-in user.
//
// NOTE(review): earlier docs called this a "signed" token, but nothing in this
// file signs or verifies the value; confirm where (if anywhere) that happens.
type ShareToken struct {
	// Token is the token value as stored in share_tokens.token.
	Token string `json:"token"`

	// TaskID is the task the token is scoped to; nil when task_id is NULL.
	TaskID *string `json:"task_id,omitempty"`

	// ExperimentID is the experiment the token is scoped to; nil when
	// experiment_id is NULL.
	ExperimentID *string `json:"experiment_id,omitempty"`

	// CreatedBy identifies whoever created the token.
	CreatedBy string `json:"created_by"`

	// CreatedAt is read from created_at (not set by CreateShareToken;
	// presumably a column default — confirm against the schema).
	CreatedAt time.Time `json:"created_at"`

	// ExpiresAt is the optional expiry instant; nil means it never expires.
	ExpiresAt *time.Time `json:"expires_at,omitempty"`

	// AccessCount is how many uses have been recorded for the token.
	AccessCount int `json:"access_count"`

	// MaxAccesses is the optional usage cap; nil means unlimited.
	MaxAccesses *int `json:"max_accesses,omitempty"`
}
// CreateShareToken inserts a new row into share_tokens.
//
// token is written verbatim (this function does not hash it) and must be
// non-empty. A row with an empty token would be returned by
// GetShareToken(""), so it could match a request that simply omitted the
// token. taskID and experimentID select what the token is scoped to and may
// each be nil. expiresAt and maxAccesses may also be nil; ValidateShareToken
// skips the corresponding check when they are. created_at is not supplied
// here; it is presumably filled by a column default (confirm against the
// schema).
func (db *DB) CreateShareToken(token string, taskID, experimentID *string, createdBy string, expiresAt *time.Time, maxAccesses *int) error {
	if token == "" {
		// Refuse before touching the database: an empty key is trivially guessable.
		return fmt.Errorf("failed to create share token: token must not be empty")
	}
	var query string
	if db.dbType == DBTypeSQLite {
		query = `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses)
VALUES (?, ?, ?, ?, ?, ?)`
	} else {
		query = `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses)
VALUES ($1, $2, $3, $4, $5, $6)`
	}
	_, err := db.conn.ExecContext(context.Background(), query, token, taskID, experimentID, createdBy, expiresAt, maxAccesses)
	if err != nil {
		return fmt.Errorf("failed to create share token: %w", err)
	}
	return nil
}
// GetShareToken fetches the share_tokens row whose token column equals token.
//
// It returns (nil, nil) when no row matches, so callers must nil-check the
// result. A non-nil error means the query or scan itself failed. NULL values
// in task_id, experiment_id, expires_at and max_accesses come back as nil
// pointers on the returned ShareToken.
func (db *DB) GetShareToken(token string) (*ShareToken, error) {
	var query string
	if db.dbType == DBTypeSQLite {
		query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
FROM share_tokens WHERE token = ?`
	} else {
		query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses
FROM share_tokens WHERE token = $1`
	}

	var t ShareToken
	// Nullable columns are scanned into sql.Null* holders first so that SQL
	// NULL can be mapped to a nil pointer below.
	var taskID, experimentID sql.NullString
	var expiresAt sql.NullTime
	var maxAccesses sql.NullInt64
	err := db.conn.QueryRowContext(context.Background(), query, token).Scan(
		&t.Token, &taskID, &experimentID, &t.CreatedBy, &t.CreatedAt, &expiresAt, &t.AccessCount, &maxAccesses,
	)
	if errors.Is(err, sql.ErrNoRows) {
		// An unknown token is reported as "not found", not as an error.
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get share token: %w", err)
	}

	if taskID.Valid {
		t.TaskID = &taskID.String
	}
	if experimentID.Valid {
		t.ExperimentID = &experimentID.String
	}
	if expiresAt.Valid {
		t.ExpiresAt = &expiresAt.Time
	}
	if maxAccesses.Valid {
		m := int(maxAccesses.Int64)
		t.MaxAccesses = &m
	}
	return &t, nil
}
// IncrementTokenAccessCount adds one to access_count on the share_tokens row
// whose token matches. The affected-row count is never inspected, so a token
// that does not exist is silently ignored rather than reported.
func (db *DB) IncrementTokenAccessCount(token string) error {
	// Postgres-style placeholder unless the backend is SQLite.
	stmt := `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = $1`
	if db.dbType == DBTypeSQLite {
		stmt = `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = ?`
	}
	if _, execErr := db.conn.ExecContext(context.Background(), stmt, token); execErr != nil {
		return fmt.Errorf("failed to increment token access count: %w", execErr)
	}
	return nil
}
// DeleteShareToken removes the share_tokens row whose token matches exactly.
// The affected-row count is not checked, so deleting an unknown token is not
// an error.
func (db *DB) DeleteShareToken(token string) error {
	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `DELETE FROM share_tokens WHERE token = ?`
	default:
		stmt = `DELETE FROM share_tokens WHERE token = $1`
	}

	_, execErr := db.conn.ExecContext(context.Background(), stmt, token)
	if execErr == nil {
		return nil
	}
	return fmt.Errorf("failed to delete share token: %w", execErr)
}
// ListShareTokensForTask returns every share token whose task_id equals
// taskID, newest (by created_at) first. Row scanning and error wrapping are
// delegated to queryShareTokens.
func (db *DB) ListShareTokensForTask(taskID string) ([]*ShareToken, error) {
	const selectList = "SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses\n"

	placeholder := "$1"
	if db.dbType == DBTypeSQLite {
		placeholder = "?"
	}
	query := selectList + "FROM share_tokens WHERE task_id = " + placeholder + " ORDER BY created_at DESC"
	return db.queryShareTokens(query, taskID)
}
// ListShareTokensForExperiment returns every share token whose experiment_id
// equals experimentID, newest (by created_at) first. Row scanning and error
// wrapping are delegated to queryShareTokens.
func (db *DB) ListShareTokensForExperiment(experimentID string) ([]*ShareToken, error) {
	var bind string
	switch db.dbType {
	case DBTypeSQLite:
		bind = "?"
	default:
		bind = "$1"
	}
	return db.queryShareTokens(
		"SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses\n"+
			"FROM share_tokens WHERE experiment_id = "+bind+" ORDER BY created_at DESC",
		experimentID,
	)
}
// queryShareTokens runs a share_tokens SELECT that takes a single bind
// argument and turns every returned row into a *ShareToken.
//
// The query must project exactly these columns, in this order: token,
// task_id, experiment_id, created_by, created_at, expires_at, access_count,
// max_accesses. SQL NULLs in the optional columns become nil pointers. When
// no rows match, the returned slice is nil rather than an empty slice.
func (db *DB) queryShareTokens(query string, arg string) ([]*ShareToken, error) {
	rows, queryErr := db.conn.QueryContext(context.Background(), query, arg)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to list share tokens: %w", queryErr)
	}
	defer func() { _ = rows.Close() }()

	var out []*ShareToken
	for rows.Next() {
		tok := new(ShareToken)
		// Fresh holders each iteration: the pointers taken below must not be
		// shared between tokens.
		var (
			task, exp sql.NullString
			expiry    sql.NullTime
			limit     sql.NullInt64
		)
		if scanErr := rows.Scan(
			&tok.Token, &task, &exp, &tok.CreatedBy, &tok.CreatedAt, &expiry, &tok.AccessCount, &limit,
		); scanErr != nil {
			return nil, fmt.Errorf("failed to scan share token: %w", scanErr)
		}

		if task.Valid {
			tok.TaskID = &task.String
		}
		if exp.Valid {
			tok.ExperimentID = &exp.String
		}
		if expiry.Valid {
			tok.ExpiresAt = &expiry.Time
		}
		if limit.Valid {
			n := int(limit.Int64)
			tok.MaxAccesses = &n
		}
		out = append(out, tok)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating share tokens: %w", iterErr)
	}
	return out, nil
}
// ValidateShareToken decides whether token currently grants access to the
// given task and/or experiment, and returns the stored token when it does.
//
// The token is rejected, yielding (nil, nil), when:
//   - no such token exists;
//   - a non-nil taskID or experimentID differs from the ID on the token
//     (or the token has no ID of that kind);
//   - ExpiresAt is set and already in the past;
//   - MaxAccesses is set and AccessCount has reached it.
//
// A non-nil error is returned only when the lookup itself fails. Passing nil
// for both IDs skips the resource check entirely. NOTE(review): confirm that
// every caller supplies at least one ID.
//
// This method only reads; it does not increment AccessCount. NOTE(review):
// the limit is checked here but bumped by a separate statement, so concurrent
// requests could together exceed MaxAccesses.
func (db *DB) ValidateShareToken(token string, taskID *string, experimentID *string) (*ShareToken, error) {
	st, err := db.GetShareToken(token)
	if err != nil {
		return nil, err
	}

	// idMatches reports whether a requested ID (nil = "don't care") agrees
	// with the ID recorded on the token.
	idMatches := func(requested, stored *string) bool {
		return requested == nil || (stored != nil && *stored == *requested)
	}

	switch {
	case st == nil:
		return nil, nil // unknown token
	case !idMatches(taskID, st.TaskID), !idMatches(experimentID, st.ExperimentID):
		return nil, nil // scoped to a different resource
	case st.ExpiresAt != nil && time.Now().After(*st.ExpiresAt):
		return nil, nil // expired
	case st.MaxAccesses != nil && st.AccessCount >= *st.MaxAccesses:
		return nil, nil // usage limit reached
	default:
		return st, nil
	}
}