fetch_ml/internal/storage/db_tasks.go
Jeremie Fraeys 2b1ef10514
test(chaos): add worker disconnect chaos test and queue improvements
Chaos testing:
- Add worker_disconnect_chaos_test.go for network partition resilience
- Test scheduler hub recovery and job reassignment scenarios

Queue layer updates:
- event_store.go: add event sourcing for queue operations
- native_queue.go: extend native queue with batch operations and indexing
2026-03-12 12:08:21 -04:00

509 lines
14 KiB
Go

package storage
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"time"
)
// TaskIsSharedWithUser reports whether task_shares contains a row for the
// given task and user whose expires_at is NULL or later than the database's
// CURRENT_TIMESTAMP.
//
// The lookup runs with context.Background(). Any query or scan error is
// swallowed and reported as false, so a failed lookup is indistinguishable
// from "not shared".
func (db *DB) TaskIsSharedWithUser(taskID, userID string) bool {
	// Start from the $N-placeholder form; SQLite needs ? placeholders instead.
	stmt := `SELECT COUNT(*) FROM task_shares
WHERE task_id = $1 AND user_id = $2
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT COUNT(*) FROM task_shares
WHERE task_id = ? AND user_id = ?
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)`
	}

	var matches int
	row := db.conn.QueryRowContext(context.Background(), stmt, taskID, userID)
	if scanErr := row.Scan(&matches); scanErr != nil {
		return false
	}
	return matches > 0
}
// UserSharesGroupWithTask reports whether at least one group links the user
// and the task: group_members must list the user in a group that
// task_group_access associates with the task. Membership is evaluated by the
// join at query time.
//
// The lookup runs with context.Background(); any query or scan error is
// reported as false.
func (db *DB) UserSharesGroupWithTask(userID, taskID string) bool {
	var stmt string
	switch db.dbType {
	case DBTypeSQLite:
		stmt = `SELECT COUNT(*) FROM group_members gm
JOIN task_group_access tga ON tga.group_id = gm.group_id
WHERE gm.user_id = ? AND tga.task_id = ?`
	default:
		stmt = `SELECT COUNT(*) FROM group_members gm
JOIN task_group_access tga ON tga.group_id = gm.group_id
WHERE gm.user_id = $1 AND tga.task_id = $2`
	}

	var hits int
	if scanErr := db.conn.QueryRowContext(context.Background(), stmt, userID, taskID).Scan(&hits); scanErr != nil {
		return false
	}
	return hits > 0
}
// TaskAllowsPublicClone reports whether the jobs row identified by taskID has
// visibility "open"; no other visibility value permits a public clone.
//
// The lookup runs with context.Background(). Any query or scan error is
// reported as false.
func (db *DB) TaskAllowsPublicClone(taskID string) bool {
	stmt := `SELECT visibility FROM jobs WHERE id = $1`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT visibility FROM jobs WHERE id = ?`
	}

	var vis string
	if scanErr := db.conn.QueryRowContext(context.Background(), stmt, taskID).Scan(&vis); scanErr != nil {
		return false
	}
	return vis == "open"
}
// AssociateTaskWithGroup inserts a (task_id, group_id) row into
// task_group_access, recording that the task is shared with the group.
// An already-existing pair is left untouched (INSERT OR IGNORE on SQLite,
// ON CONFLICT DO NOTHING otherwise), so repeating the call is not an error.
//
// NOTE(review): the previous comment says this is called at submit time for
// lab-visibility tasks; that is a caller convention not enforced here.
//
// Returns a wrapped error when the statement fails to execute.
func (db *DB) AssociateTaskWithGroup(taskID, groupID string) error {
	stmt := `INSERT INTO task_group_access (task_id, group_id) VALUES ($1, $2)
ON CONFLICT (task_id, group_id) DO NOTHING`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT OR IGNORE INTO task_group_access (task_id, group_id) VALUES (?, ?)`
	}

	if _, execErr := db.conn.ExecContext(context.Background(), stmt, taskID, groupID); execErr != nil {
		return fmt.Errorf("failed to associate task with group: %w", execErr)
	}
	return nil
}
// CountOpenTasksForUserToday counts the jobs rows owned by userID whose
// visibility is 'open' and whose created_at is on or after the start of the
// current date as computed by the database (date('now') on SQLite,
// CURRENT_DATE otherwise).
//
// Returns 0 and a wrapped error when the query or scan fails.
func (db *DB) CountOpenTasksForUserToday(userID string) (int, error) {
	stmt := `SELECT COUNT(*) FROM jobs
WHERE user_id = $1
AND visibility = 'open'
AND created_at >= CURRENT_DATE`
	if db.dbType == DBTypeSQLite {
		stmt = `SELECT COUNT(*) FROM jobs
WHERE user_id = ?
AND visibility = 'open'
AND created_at >= date('now')`
	}

	var total int
	if scanErr := db.conn.QueryRowContext(context.Background(), stmt, userID).Scan(&total); scanErr != nil {
		return 0, fmt.Errorf("failed to count open tasks: %w", scanErr)
	}
	return total, nil
}
// ListTasksOptions carries the pagination inputs accepted by the task listing
// functions in this file.
type ListTasksOptions struct {
	// Limit is the requested page size. The listing functions replace any
	// value that is <= 0 or > 100 with 100.
	Limit int

	// Cursor is the opaque token returned with the previous page; it encodes
	// a (created_at, id) pair. Leave it empty to request the first page.
	Cursor string
}
// ListTasksForUser returns one page of jobs visible to userID, ordered by
// created_at then id (both descending), together with the cursor for the
// following page ("" when there is no further page).
//
// Admins (isAdmin == true) see every row in jobs via listTasksPaginated.
// Everyone else gets the UNION of four access paths, one index-able branch
// per access reason:
//   - jobs they own (jobs.user_id),
//   - jobs shared with them through a non-expired task_shares row,
//   - 'lab' jobs reachable through task_group_access + group_members,
//   - jobs whose visibility is 'institution' or 'open'.
//
// opts.Limit values that are <= 0 or > 100 are replaced by 100. One extra row
// is fetched to detect whether another page exists. An opts.Cursor that fails
// to decode is treated like an empty cursor, i.e. the first page is returned.
//
// Returns a wrapped error when the query, a row scan, or row iteration fails.
func (db *DB) ListTasksForUser(userID string, isAdmin bool, opts ListTasksOptions) ([]*Job, string, error) {
	// Enforce max limit
	if opts.Limit <= 0 || opts.Limit > 100 {
		opts.Limit = 100
	}

	if isAdmin {
		// listTasksPaginated scans exactly these 13 columns. The jobs table
		// has more (this file also filters on user_id and visibility), so a
		// SELECT * here would make every Scan fail with a column-count
		// mismatch.
		return db.listTasksPaginated(
			"SELECT id, job_name, args, status, priority, datasets, metadata, worker_id, error, created_at, updated_at, started_at, ended_at FROM jobs",
			opts,
		)
	}

	// A cursor that cannot be decoded yields empty parts, which the SQL below
	// interprets as "no cursor"; the decode error is intentionally dropped.
	cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)

	// Explicit column list instead of j.* avoids issues in SQLite strict mode
	// and makes schema expectations explicit across all branches.
	cols := `j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at`

	var query string
	if db.dbType == DBTypeSQLite {
		query = fmt.Sprintf(`
-- Owner
SELECT %s FROM jobs j
WHERE j.user_id = ?
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
UNION
-- Explicit share (expiry checked inline)
SELECT %s FROM jobs j
JOIN task_shares ts ON ts.task_id = j.id
WHERE ts.user_id = ?
AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP)
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
UNION
-- Lab group membership (live, not snapshotted)
SELECT %s FROM jobs j
JOIN task_group_access tga ON tga.task_id = j.id
JOIN group_members gm ON gm.group_id = tga.group_id
WHERE gm.user_id = ?
AND j.visibility = 'lab'
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
UNION
-- Institution or open (all authenticated users)
SELECT %s FROM jobs j
WHERE j.visibility IN ('institution', 'open')
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
ORDER BY created_at DESC, id DESC
LIMIT ?
`, cols, cols, cols, cols)
	} else {
		// PostgreSQL version with $N placeholders
		query = fmt.Sprintf(`
-- Owner
SELECT %s FROM jobs j
WHERE j.user_id = $1
AND ($2 = '' OR (j.created_at::text || j.id) < $3)
UNION
-- Explicit share (expiry checked inline)
SELECT %s FROM jobs j
JOIN task_shares ts ON ts.task_id = j.id
WHERE ts.user_id = $4
AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP)
AND ($5 = '' OR (j.created_at::text || j.id) < $6)
UNION
-- Lab group membership (live, not snapshotted)
SELECT %s FROM jobs j
JOIN task_group_access tga ON tga.task_id = j.id
JOIN group_members gm ON gm.group_id = tga.group_id
WHERE gm.user_id = $7
AND j.visibility = 'lab'
AND ($8 = '' OR (j.created_at::text || j.id) < $9)
UNION
-- Institution or open (all authenticated users)
SELECT %s FROM jobs j
WHERE j.visibility IN ('institution', 'open')
AND ($10 = '' OR (j.created_at::text || j.id) < $11)
ORDER BY created_at DESC, id DESC
LIMIT $12
`, cols, cols, cols, cols)
	}

	// Fetch limit+1 rows; if len(rows) > limit, there is a next page
	fetchLimit := opts.Limit + 1

	// Both dialects bind the same twelve values in the same order: each of
	// the first three branches takes (user, cursor timestamp, cursor key),
	// the last branch takes only the cursor pair, then the row limit.
	cursorKey := cursorCreatedAt + cursorID
	args := []interface{}{
		userID, cursorCreatedAt, cursorKey,
		userID, cursorCreatedAt, cursorKey,
		userID, cursorCreatedAt, cursorKey,
		cursorCreatedAt, cursorKey,
		fetchLimit,
	}

	rows, err := db.conn.QueryContext(context.Background(), query, args...)
	if err != nil {
		return nil, "", fmt.Errorf("failed to list tasks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var jobs []*Job
	for rows.Next() {
		job := &Job{}
		var createdAt, updatedAt sql.NullTime
		var startedAt, endedAt sql.NullTime
		var datasetsJSON, metadataJSON []byte

		err := rows.Scan(
			&job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority,
			&datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error,
			&createdAt, &updatedAt, &startedAt, &endedAt,
		)
		if err != nil {
			return nil, "", fmt.Errorf("failed to scan job: %w", err)
		}

		if createdAt.Valid {
			job.CreatedAt = createdAt.Time
		}
		if updatedAt.Valid {
			job.UpdatedAt = updatedAt.Time
		}
		if startedAt.Valid {
			job.StartedAt = &startedAt.Time
		}
		if endedAt.Valid {
			job.EndedAt = &endedAt.Time
		}

		// Best effort: malformed JSON leaves the field at its zero value
		// instead of failing the whole listing.
		if len(datasetsJSON) > 0 {
			_ = json.Unmarshal(datasetsJSON, &job.Datasets)
		}
		if len(metadataJSON) > 0 {
			_ = json.Unmarshal(metadataJSON, &job.Metadata)
		}

		jobs = append(jobs, job)
	}
	if err = rows.Err(); err != nil {
		return nil, "", fmt.Errorf("error iterating jobs: %w", err)
	}

	// Determine next cursor
	nextCursor := ""
	if len(jobs) > opts.Limit {
		// More rows available: the cursor points at the last row actually
		// returned, and the probe row is dropped.
		lastJob := jobs[opts.Limit-1]
		nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
		jobs = jobs[:opts.Limit]
	}
	return jobs, nextCursor, nil
}
// ListTasksForGroup returns one page of the jobs that task_group_access links
// to groupID, ordered by created_at then id (both descending), plus the cursor
// for the next page ("" when there is none).
//
// opts.Limit values that are <= 0 or > 100 are replaced by 100. One extra row
// is fetched to detect a following page. A cursor that fails to decode is
// treated like an empty cursor (first page).
//
// Returns a wrapped error when the query, a row scan, or row iteration fails.
func (db *DB) ListTasksForGroup(groupID string, opts ListTasksOptions) ([]*Job, string, error) {
	pageSize := opts.Limit
	if pageSize <= 0 || pageSize > 100 {
		pageSize = 100
	}

	// Decode failures leave both parts empty, which the SQL reads as
	// "no cursor"; the error itself is intentionally dropped.
	cursorTS, cursorTaskID, _ := decodeCursor(opts.Cursor)

	// Start from the $N-placeholder form; SQLite needs ? placeholders and
	// datetime() instead of a ::text cast.
	stmt := `
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
FROM jobs j
JOIN task_group_access tga ON tga.task_id = j.id
WHERE tga.group_id = $1
AND ($2 = '' OR (j.created_at::text || j.id) < $3)
ORDER BY j.created_at DESC, j.id DESC
LIMIT $4`
	if db.dbType == DBTypeSQLite {
		stmt = `
SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata,
j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at
FROM jobs j
JOIN task_group_access tga ON tga.task_id = j.id
WHERE tga.group_id = ?
AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
ORDER BY j.created_at DESC, j.id DESC
LIMIT ?`
	}

	// Both dialects take the same four positional arguments.
	rows, err := db.conn.QueryContext(
		context.Background(), stmt,
		groupID, cursorTS, cursorTS+cursorTaskID, pageSize+1,
	)
	if err != nil {
		return nil, "", fmt.Errorf("failed to list group tasks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var page []*Job
	for rows.Next() {
		var (
			item                             = &Job{}
			created, updated, started, ended sql.NullTime
			datasetsRaw, metadataRaw         []byte
		)

		if scanErr := rows.Scan(
			&item.ID, &item.JobName, &item.Args, &item.Status, &item.Priority,
			&datasetsRaw, &metadataRaw, &item.WorkerID, &item.Error,
			&created, &updated, &started, &ended,
		); scanErr != nil {
			return nil, "", fmt.Errorf("failed to scan job: %w", scanErr)
		}

		// NULL timestamps leave the corresponding field at its zero value.
		if created.Valid {
			item.CreatedAt = created.Time
		}
		if updated.Valid {
			item.UpdatedAt = updated.Time
		}
		if started.Valid {
			item.StartedAt = &started.Time
		}
		if ended.Valid {
			item.EndedAt = &ended.Time
		}

		// JSON columns are decoded best-effort; decode errors are ignored.
		if len(datasetsRaw) > 0 {
			_ = json.Unmarshal(datasetsRaw, &item.Datasets)
		}
		if len(metadataRaw) > 0 {
			_ = json.Unmarshal(metadataRaw, &item.Metadata)
		}

		page = append(page, item)
	}
	if err = rows.Err(); err != nil {
		return nil, "", fmt.Errorf("error iterating jobs: %w", err)
	}

	// No probe row came back: this is the last page.
	if len(page) <= pageSize {
		return page, "", nil
	}

	// Drop the probe row and point the cursor at the last row returned.
	tail := page[pageSize-1]
	return page[:pageSize], encodeCursor(tail.CreatedAt, tail.ID), nil
}
// listTasksPaginated runs baseQuery with a cursor filter, ordering and limit
// appended, and returns one page of jobs plus the cursor for the next page
// ("" when there is none).
//
// baseQuery is concatenated into the SQL text as-is, so it must come from
// trusted code, must not contain its own WHERE clause, and must yield exactly
// the 13 columns scanned below in this order: id, job_name, args, status,
// priority, datasets, metadata, worker_id, error, created_at, updated_at,
// started_at, ended_at.
//
// opts.Limit values that are <= 0 or > 100 are replaced by 100. A cursor that
// fails to decode is treated like an empty cursor (first page).
func (db *DB) listTasksPaginated(baseQuery string, opts ListTasksOptions) ([]*Job, string, error) {
	pageSize := opts.Limit
	if pageSize <= 0 || pageSize > 100 {
		pageSize = 100
	}

	// Decode failures leave both parts empty, which the SQL reads as
	// "no cursor"; the error itself is intentionally dropped.
	cursorTS, cursorTaskID, _ := decodeCursor(opts.Cursor)

	// Dialect-specific suffix: SQLite uses ? placeholders and datetime(),
	// every other backend gets $N placeholders and a ::text cast.
	suffix := `
WHERE ($1 = '' OR (created_at::text || id) < $2)
ORDER BY created_at DESC, id DESC
LIMIT $3`
	if db.dbType == DBTypeSQLite {
		suffix = `
WHERE (? = '' OR (datetime(created_at) || id) < ?)
ORDER BY created_at DESC, id DESC
LIMIT ?`
	}

	// One extra row is requested to find out whether another page follows.
	rows, err := db.conn.QueryContext(
		context.Background(), baseQuery+suffix,
		cursorTS, cursorTS+cursorTaskID, pageSize+1,
	)
	if err != nil {
		return nil, "", fmt.Errorf("failed to list tasks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var page []*Job
	for rows.Next() {
		var (
			item                             = &Job{}
			created, updated, started, ended sql.NullTime
			datasetsRaw, metadataRaw         []byte
		)

		if scanErr := rows.Scan(
			&item.ID, &item.JobName, &item.Args, &item.Status, &item.Priority,
			&datasetsRaw, &metadataRaw, &item.WorkerID, &item.Error,
			&created, &updated, &started, &ended,
		); scanErr != nil {
			return nil, "", fmt.Errorf("failed to scan job: %w", scanErr)
		}

		// NULL timestamps leave the corresponding field at its zero value.
		if created.Valid {
			item.CreatedAt = created.Time
		}
		if updated.Valid {
			item.UpdatedAt = updated.Time
		}
		if started.Valid {
			item.StartedAt = &started.Time
		}
		if ended.Valid {
			item.EndedAt = &ended.Time
		}

		// JSON columns are decoded best-effort; decode errors are ignored.
		if len(datasetsRaw) > 0 {
			_ = json.Unmarshal(datasetsRaw, &item.Datasets)
		}
		if len(metadataRaw) > 0 {
			_ = json.Unmarshal(metadataRaw, &item.Metadata)
		}

		page = append(page, item)
	}
	if err = rows.Err(); err != nil {
		return nil, "", fmt.Errorf("error iterating jobs: %w", err)
	}

	// No probe row came back: this is the last page.
	if len(page) <= pageSize {
		return page, "", nil
	}

	// Drop the probe row and point the cursor at the last row returned.
	tail := page[pageSize-1]
	return page[:pageSize], encodeCursor(tail.CreatedAt, tail.ID), nil
}
// encodeCursor creates an opaque cursor string from created_at and id.
// Format: base64url(created_at as "YYYY-MM-DD HH:MM:SS" in UTC + "|" + task_id),
// unpadded.
//
// The timestamp layout is chosen to match what the pagination queries compare
// it against. They concatenate the decoded timestamp with the id and compare
// that text with datetime(created_at) || id on SQLite, and datetime() renders
// "YYYY-MM-DD HH:MM:SS". An RFC 3339 timestamp ("...T...Z") never lines up
// with that rendering: ' ' sorts before 'T', so every row on the cursor's date
// would match again and pages would repeat.
//
// NOTE(review): on PostgreSQL the comparison uses created_at::text, which may
// carry fractional seconds or a zone offset depending on the column type;
// confirm the column is stored at second precision in UTC.
func encodeCursor(createdAt time.Time, id string) string {
	const cursorTimeLayout = "2006-01-02 15:04:05"
	payload := createdAt.UTC().Format(cursorTimeLayout) + "|" + id
	return base64.RawURLEncoding.EncodeToString([]byte(payload))
}
// decodeCursor splits an opaque cursor back into its created_at and id parts.
//
// An empty cursor yields two empty strings and a nil error. Otherwise the
// cursor is decoded as unpadded base64url and split at the first "|"; anything
// after that separator, including further "|" characters, belongs to id.
// A cursor that is not valid base64url, or whose payload has no "|", yields
// two empty strings and a non-nil error.
func decodeCursor(cursor string) (createdAt, id string, err error) {
	if len(cursor) == 0 {
		return "", "", nil
	}

	raw, decodeErr := base64.RawURLEncoding.DecodeString(cursor)
	if decodeErr != nil {
		return "", "", decodeErr
	}

	text := string(raw)
	sep := strings.IndexByte(text, '|')
	if sep < 0 {
		return "", "", fmt.Errorf("invalid cursor format")
	}
	return text[:sep], text[sep+1:], nil
}