package storage

import (
	"context"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"strings"
	"time"
)

// TaskIsSharedWithUser reports whether an active, non-expired explicit share
// exists for the given task/user pair. Any query error is treated as "not
// shared" so the result can be used directly in access checks.
func (db *DB) TaskIsSharedWithUser(taskID, userID string) bool {
	query := `SELECT COUNT(*) FROM task_shares WHERE task_id = $1 AND user_id = $2 AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)`
	if db.dbType == DBTypeSQLite {
		query = `SELECT COUNT(*) FROM task_shares WHERE task_id = ? AND user_id = ? AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)`
	}
	var n int
	if err := db.conn.QueryRowContext(context.Background(), query, taskID, userID).Scan(&n); err != nil {
		return false
	}
	return n > 0
}

// UserSharesGroupWithTask reports whether the user is currently a member of
// any group granted access to the task via task_group_access. Membership is
// evaluated live at query time, not snapshotted. Query errors read as "no
// access".
func (db *DB) UserSharesGroupWithTask(userID, taskID string) bool {
	query := `SELECT COUNT(*) FROM group_members gm JOIN task_group_access tga ON tga.group_id = gm.group_id WHERE gm.user_id = $1 AND tga.task_id = $2`
	if db.dbType == DBTypeSQLite {
		query = `SELECT COUNT(*) FROM group_members gm JOIN task_group_access tga ON tga.group_id = gm.group_id WHERE gm.user_id = ? AND tga.task_id = ?`
	}
	var n int
	if err := db.conn.QueryRowContext(context.Background(), query, userID, taskID).Scan(&n); err != nil {
		return false
	}
	return n > 0
}

// TaskAllowsPublicClone returns true if the task has allow_public_clone enabled.
// Checks if the task visibility is 'open' which allows public cloning.
func (db *DB) TaskAllowsPublicClone(taskID string) bool { var query string if db.dbType == DBTypeSQLite { query = `SELECT visibility FROM jobs WHERE id = ?` } else { query = `SELECT visibility FROM jobs WHERE id = $1` } var visibility string err := db.conn.QueryRowContext(context.Background(), query, taskID).Scan(&visibility) if err != nil { return false } // Only 'open' visibility allows public clone return visibility == "open" } // AssociateTaskWithGroup records that a task is shared with a specific group. // Called at submit time for lab-visibility tasks. func (db *DB) AssociateTaskWithGroup(taskID, groupID string) error { var query string if db.dbType == DBTypeSQLite { query = `INSERT OR IGNORE INTO task_group_access (task_id, group_id) VALUES (?, ?)` } else { query = `INSERT INTO task_group_access (task_id, group_id) VALUES ($1, $2) ON CONFLICT (task_id, group_id) DO NOTHING` } _, err := db.conn.ExecContext(context.Background(), query, taskID, groupID) if err != nil { return fmt.Errorf("failed to associate task with group: %w", err) } return nil } // CountOpenTasksForUserToday returns the number of open tasks created by user today. func (db *DB) CountOpenTasksForUserToday(userID string) (int, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT COUNT(*) FROM jobs WHERE user_id = ? AND visibility = 'open' AND created_at >= date('now')` } else { query = `SELECT COUNT(*) FROM jobs WHERE user_id = $1 AND visibility = 'open' AND created_at >= CURRENT_DATE` } var count int err := db.conn.QueryRowContext(context.Background(), query, userID).Scan(&count) if err != nil { return 0, fmt.Errorf("failed to count open tasks: %w", err) } return count, nil } // ListTasksOptions provides pagination options for task listing. type ListTasksOptions struct { Limit int // required; enforced max 100 Cursor string // (created_at, id) pair encoded as opaque string; empty = first page } // ListTasksForUser returns tasks visible to a specific user. 
// Uses UNION-based SQL for efficient index usage with cursor pagination. func (db *DB) ListTasksForUser(userID string, isAdmin bool, opts ListTasksOptions) ([]*Job, string, error) { // Enforce max limit if opts.Limit <= 0 || opts.Limit > 100 { opts.Limit = 100 } if isAdmin { return db.listTasksPaginated("SELECT * FROM jobs", opts) } // Parse cursor if provided cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor) // UNION approach: one index-able branch per access reason. // Explicit column list instead of j.* avoids issues in SQLite strict mode // and makes schema expectations explicit across all branches. cols := `j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at` var query string if db.dbType == DBTypeSQLite { query = fmt.Sprintf(` -- Owner SELECT %s FROM jobs j WHERE j.user_id = ? AND (? = '' OR (datetime(j.created_at) || j.id) < ?) UNION -- Explicit share (expiry checked inline) SELECT %s FROM jobs j JOIN task_shares ts ON ts.task_id = j.id WHERE ts.user_id = ? AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP) AND (? = '' OR (datetime(j.created_at) || j.id) < ?) UNION -- Lab group membership (live, not snapshotted) SELECT %s FROM jobs j JOIN task_group_access tga ON tga.task_id = j.id JOIN group_members gm ON gm.group_id = tga.group_id WHERE gm.user_id = ? AND j.visibility = 'lab' AND (? = '' OR (datetime(j.created_at) || j.id) < ?) UNION -- Institution or open (all authenticated users) SELECT %s FROM jobs j WHERE j.visibility IN ('institution', 'open') AND (? = '' OR (datetime(j.created_at) || j.id) < ?) ORDER BY created_at DESC, id DESC LIMIT ? 
`, cols, cols, cols, cols) } else { // PostgreSQL version with $N placeholders query = fmt.Sprintf(` -- Owner SELECT %s FROM jobs j WHERE j.user_id = $1 AND ($2 = '' OR (j.created_at::text || j.id) < $3) UNION -- Explicit share (expiry checked inline) SELECT %s FROM jobs j JOIN task_shares ts ON ts.task_id = j.id WHERE ts.user_id = $4 AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP) AND ($5 = '' OR (j.created_at::text || j.id) < $6) UNION -- Lab group membership (live, not snapshotted) SELECT %s FROM jobs j JOIN task_group_access tga ON tga.task_id = j.id JOIN group_members gm ON gm.group_id = tga.group_id WHERE gm.user_id = $7 AND j.visibility = 'lab' AND ($8 = '' OR (j.created_at::text || j.id) < $9) UNION -- Institution or open (all authenticated users) SELECT %s FROM jobs j WHERE j.visibility IN ('institution', 'open') AND ($10 = '' OR (j.created_at::text || j.id) < $11) ORDER BY created_at DESC, id DESC LIMIT $12 `, cols, cols, cols, cols) } // Fetch limit+1 rows; if len(rows) > limit, there is a next page fetchLimit := opts.Limit + 1 var rows *sql.Rows var err error if db.dbType == DBTypeSQLite { rows, err = db.conn.QueryContext(context.Background(), query, userID, cursorCreatedAt, cursorCreatedAt+cursorID, userID, cursorCreatedAt, cursorCreatedAt+cursorID, userID, cursorCreatedAt, cursorCreatedAt+cursorID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit, ) } else { rows, err = db.conn.QueryContext(context.Background(), query, userID, cursorCreatedAt, cursorCreatedAt+cursorID, userID, cursorCreatedAt, cursorCreatedAt+cursorID, userID, cursorCreatedAt, cursorCreatedAt+cursorID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit, ) } if err != nil { return nil, "", fmt.Errorf("failed to list tasks: %w", err) } defer func() { _ = rows.Close() }() var jobs []*Job for rows.Next() { job := &Job{} var createdAt, updatedAt sql.NullTime var startedAt, endedAt sql.NullTime var datasetsJSON, metadataJSON []byte err := rows.Scan( &job.ID, 
&job.JobName, &job.Args, &job.Status, &job.Priority, &datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error, &createdAt, &updatedAt, &startedAt, &endedAt, ) if err != nil { return nil, "", fmt.Errorf("failed to scan job: %w", err) } if createdAt.Valid { job.CreatedAt = createdAt.Time } if updatedAt.Valid { job.UpdatedAt = updatedAt.Time } if startedAt.Valid { job.StartedAt = &startedAt.Time } if endedAt.Valid { job.EndedAt = &endedAt.Time } if len(datasetsJSON) > 0 { _ = json.Unmarshal(datasetsJSON, &job.Datasets) } if len(metadataJSON) > 0 { _ = json.Unmarshal(metadataJSON, &job.Metadata) } jobs = append(jobs, job) } if err = rows.Err(); err != nil { return nil, "", fmt.Errorf("error iterating jobs: %w", err) } // Determine next cursor nextCursor := "" if len(jobs) > opts.Limit { // More rows available, encode cursor from the last row we actually want to return lastJob := jobs[opts.Limit-1] nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID) // Trim to requested limit jobs = jobs[:opts.Limit] } return jobs, nextCursor, nil } // ListTasksForGroup returns tasks associated with a group func (db *DB) ListTasksForGroup(groupID string, opts ListTasksOptions) ([]*Job, string, error) { if opts.Limit <= 0 || opts.Limit > 100 { opts.Limit = 100 } cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor) var query string if db.dbType == DBTypeSQLite { query = ` SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at FROM jobs j JOIN task_group_access tga ON tga.task_id = j.id WHERE tga.group_id = ? AND (? = '' OR (datetime(j.created_at) || j.id) < ?) 
ORDER BY j.created_at DESC, j.id DESC LIMIT ?` } else { query = ` SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at FROM jobs j JOIN task_group_access tga ON tga.task_id = j.id WHERE tga.group_id = $1 AND ($2 = '' OR (j.created_at::text || j.id) < $3) ORDER BY j.created_at DESC, j.id DESC LIMIT $4` } fetchLimit := opts.Limit + 1 var rows *sql.Rows var err error if db.dbType == DBTypeSQLite { rows, err = db.conn.QueryContext(context.Background(), query, groupID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit) } else { rows, err = db.conn.QueryContext(context.Background(), query, groupID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit) } if err != nil { return nil, "", fmt.Errorf("failed to list group tasks: %w", err) } defer func() { _ = rows.Close() }() var jobs []*Job for rows.Next() { job := &Job{} var createdAt, updatedAt sql.NullTime var startedAt, endedAt sql.NullTime var datasetsJSON, metadataJSON []byte err := rows.Scan( &job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, &datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error, &createdAt, &updatedAt, &startedAt, &endedAt, ) if err != nil { return nil, "", fmt.Errorf("failed to scan job: %w", err) } if createdAt.Valid { job.CreatedAt = createdAt.Time } if updatedAt.Valid { job.UpdatedAt = updatedAt.Time } if startedAt.Valid { job.StartedAt = &startedAt.Time } if endedAt.Valid { job.EndedAt = &endedAt.Time } if len(datasetsJSON) > 0 { _ = json.Unmarshal(datasetsJSON, &job.Datasets) } if len(metadataJSON) > 0 { _ = json.Unmarshal(metadataJSON, &job.Metadata) } jobs = append(jobs, job) } if err = rows.Err(); err != nil { return nil, "", fmt.Errorf("error iterating jobs: %w", err) } nextCursor := "" if len(jobs) > opts.Limit { lastJob := jobs[opts.Limit-1] nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID) jobs = jobs[:opts.Limit] } return jobs, nextCursor, nil } // 
listTasksPaginated is a helper for admin queries that lists all tasks. func (db *DB) listTasksPaginated(baseQuery string, opts ListTasksOptions) ([]*Job, string, error) { if opts.Limit <= 0 || opts.Limit > 100 { opts.Limit = 100 } cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor) var query string if db.dbType == DBTypeSQLite { query = baseQuery + ` WHERE (? = '' OR (datetime(created_at) || id) < ?) ORDER BY created_at DESC, id DESC LIMIT ?` } else { query = baseQuery + ` WHERE ($1 = '' OR (created_at::text || id) < $2) ORDER BY created_at DESC, id DESC LIMIT $3` } fetchLimit := opts.Limit + 1 var rows *sql.Rows var err error if db.dbType == DBTypeSQLite { rows, err = db.conn.QueryContext(context.Background(), query, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit) } else { rows, err = db.conn.QueryContext(context.Background(), query, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit) } if err != nil { return nil, "", fmt.Errorf("failed to list tasks: %w", err) } defer func() { _ = rows.Close() }() var jobs []*Job for rows.Next() { job := &Job{} var createdAt, updatedAt sql.NullTime var startedAt, endedAt sql.NullTime var datasetsJSON, metadataJSON []byte err := rows.Scan( &job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, &datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error, &createdAt, &updatedAt, &startedAt, &endedAt, ) if err != nil { return nil, "", fmt.Errorf("failed to scan job: %w", err) } if createdAt.Valid { job.CreatedAt = createdAt.Time } if updatedAt.Valid { job.UpdatedAt = updatedAt.Time } if startedAt.Valid { job.StartedAt = &startedAt.Time } if endedAt.Valid { job.EndedAt = &endedAt.Time } if len(datasetsJSON) > 0 { _ = json.Unmarshal(datasetsJSON, &job.Datasets) } if len(metadataJSON) > 0 { _ = json.Unmarshal(metadataJSON, &job.Metadata) } jobs = append(jobs, job) } if err = rows.Err(); err != nil { return nil, "", fmt.Errorf("error iterating jobs: %w", err) } nextCursor := "" if len(jobs) > opts.Limit { lastJob := 
jobs[opts.Limit-1] nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID) jobs = jobs[:opts.Limit] } return jobs, nextCursor, nil } // encodeCursor creates an opaque cursor string from created_at and id. // Format: base64url(created_at_RFC3339 + "|" + task_id) func encodeCursor(createdAt time.Time, id string) string { payload := createdAt.UTC().Format(time.RFC3339) + "|" + id return base64.RawURLEncoding.EncodeToString([]byte(payload)) } // decodeCursor extracts created_at and id from an opaque cursor. // Returns empty strings if cursor is invalid or empty. func decodeCursor(cursor string) (createdAt, id string, err error) { if cursor == "" { return "", "", nil } payload, err := base64.RawURLEncoding.DecodeString(cursor) if err != nil { return "", "", err } parts := strings.SplitN(string(payload), "|", 2) if len(parts) != 2 { return "", "", fmt.Errorf("invalid cursor format") } return parts[0], parts[1], nil }