diff --git a/internal/storage/db_audit.go b/internal/storage/db_audit.go new file mode 100644 index 0000000..a5e795a --- /dev/null +++ b/internal/storage/db_audit.go @@ -0,0 +1,225 @@ +package storage + +import ( + "context" + "database/sql" + "fmt" + "log" + "time" +) + +// AuditLogEntry represents a single audit log entry +type AuditLogEntry struct { + ID int64 `json:"id"` + TaskID string `json:"task_id"` + UserID *string `json:"user_id,omitempty"` + Token *string `json:"token,omitempty"` + Action string `json:"action"` + AccessedAt time.Time `json:"accessed_at"` + IPAddress *string `json:"ip_address,omitempty"` +} + +// LogTaskAccess records an access event in the audit log. +func (db *DB) LogTaskAccess(taskID string, userID, token, action, ipAddress *string) error { + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address) + VALUES (?, ?, ?, ?, ?)` + } else { + query = `INSERT INTO task_access_log (task_id, user_id, token, action, ip_address) + VALUES ($1, $2, $3, $4, $5)` + } + + _, err := db.conn.ExecContext(context.Background(), query, taskID, userID, token, action, ipAddress) + if err != nil { + return fmt.Errorf("failed to log task access: %w", err) + } + return nil +} + +// GetAuditLogForTask retrieves audit log entries for a specific task. +func (db *DB) GetAuditLogForTask(taskID string, limit int) ([]*AuditLogEntry, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address + FROM task_access_log + WHERE task_id = ? 
+ ORDER BY accessed_at DESC + LIMIT ?` + } else { + query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address + FROM task_access_log + WHERE task_id = $1 + ORDER BY accessed_at DESC + LIMIT $2` + } + + rows, err := db.conn.QueryContext(context.Background(), query, taskID, limit) + if err != nil { + return nil, fmt.Errorf("failed to get audit log: %w", err) + } + defer func() { + if err := rows.Close(); err != nil { + log.Printf("ERROR: failed to close rows: %v", err) + } + }() + + return db.scanAuditLogEntries(rows) +} + +// GetAuditLogForUser retrieves audit log entries for a specific user. +func (db *DB) GetAuditLogForUser(userID string, limit int) ([]*AuditLogEntry, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address + FROM task_access_log + WHERE user_id = ? + ORDER BY accessed_at DESC + LIMIT ?` + } else { + query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address + FROM task_access_log + WHERE user_id = $1 + ORDER BY accessed_at DESC + LIMIT $2` + } + + rows, err := db.conn.QueryContext(context.Background(), query, userID, limit) + if err != nil { + return nil, fmt.Errorf("failed to get audit log: %w", err) + } + defer func() { + if err := rows.Close(); err != nil { + log.Printf("ERROR: failed to close rows: %v", err) + } + }() + + return db.scanAuditLogEntries(rows) +} + +// GetAuditLogForToken retrieves audit log entries for a specific token. +func (db *DB) GetAuditLogForToken(token string, limit int) ([]*AuditLogEntry, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address + FROM task_access_log + WHERE token = ? 
+ ORDER BY accessed_at DESC + LIMIT ?` + } else { + query = `SELECT id, task_id, user_id, token, action, accessed_at, ip_address + FROM task_access_log + WHERE token = $1 + ORDER BY accessed_at DESC + LIMIT $2` + } + + rows, err := db.conn.QueryContext(context.Background(), query, token, limit) + if err != nil { + return nil, fmt.Errorf("failed to get audit log: %w", err) + } + defer func() { + if err := rows.Close(); err != nil { + log.Printf("ERROR: failed to close rows: %v", err) + } + }() + + return db.scanAuditLogEntries(rows) +} + +// scanAuditLogEntries scans rows into AuditLogEntry structs. +func (db *DB) scanAuditLogEntries(rows *sql.Rows) ([]*AuditLogEntry, error) { + var entries []*AuditLogEntry + for rows.Next() { + var e AuditLogEntry + var userID, token, ipAddress sql.NullString + + err := rows.Scan(&e.ID, &e.TaskID, &userID, &token, &e.Action, &e.AccessedAt, &ipAddress) + if err != nil { + return nil, fmt.Errorf("failed to scan audit log entry: %w", err) + } + + if userID.Valid { + e.UserID = &userID.String + } + if token.Valid { + e.Token = &token.String + } + if ipAddress.Valid { + e.IPAddress = &ipAddress.String + } + + entries = append(entries, &e) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating audit log entries: %w", err) + } + + return entries, nil +} + +// DeleteOldAuditLogs removes audit log entries older than the retention period. +// This should be called by a nightly job. +func (db *DB) DeleteOldAuditLogs(retentionDays int) (int64, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `DELETE FROM task_access_log + WHERE accessed_at < datetime('now', '-' || ? 
|| ' days')` + } else { + query = `DELETE FROM task_access_log + WHERE accessed_at < NOW() - INTERVAL '1 day' * $1` + } + + result, err := db.conn.ExecContext(context.Background(), query, retentionDays) + if err != nil { + return 0, fmt.Errorf("failed to delete old audit logs: %w", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return 0, fmt.Errorf("failed to get rows affected: %w", err) + } + + return affected, nil +} + +// CountAuditLogs returns the total number of audit log entries. +func (db *DB) CountAuditLogs() (int64, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT COUNT(*) FROM task_access_log` + } else { + query = `SELECT COUNT(*) FROM task_access_log` + } + + var count int64 + err := db.conn.QueryRowContext(context.Background(), query).Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to count audit logs: %w", err) + } + + return count, nil +} + +// GetOldestAuditLogDate returns the date of the oldest audit log entry. +func (db *DB) GetOldestAuditLogDate() (*time.Time, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT MIN(accessed_at) FROM task_access_log` + } else { + query = `SELECT MIN(accessed_at) FROM task_access_log` + } + + var date sql.NullTime + err := db.conn.QueryRowContext(context.Background(), query).Scan(&date) + if err != nil { + return nil, fmt.Errorf("failed to get oldest audit log date: %w", err) + } + + if !date.Valid { + return nil, nil + } + + return &date.Time, nil +} diff --git a/internal/storage/db_experiments.go b/internal/storage/db_experiments.go index 7cea79a..972862a 100644 --- a/internal/storage/db_experiments.go +++ b/internal/storage/db_experiments.go @@ -562,3 +562,170 @@ func (db *DB) SearchDatasets(ctx context.Context, term string, limit int) ([]*Da } return out, nil } + +// AssociateTaskWithExperiment links a task to an experiment. 
+func (db *DB) AssociateTaskWithExperiment(taskID, experimentID string) error { + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT OR IGNORE INTO experiment_tasks (experiment_id, task_id) VALUES (?, ?)` + } else { + query = `INSERT INTO experiment_tasks (experiment_id, task_id) VALUES ($1, $2) + ON CONFLICT (experiment_id, task_id) DO NOTHING` + } + _, err := db.conn.ExecContext(context.Background(), query, experimentID, taskID) + if err != nil { + return fmt.Errorf("failed to associate task with experiment: %w", err) + } + return nil +} + +// GetExperimentVisibility returns the visibility level for an experiment. +func (db *DB) GetExperimentVisibility(experimentID string) (string, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT COALESCE(j.visibility, 'private') + FROM experiment_tasks et + JOIN jobs j ON j.id = et.task_id + WHERE et.experiment_id = ? + ORDER BY j.created_at DESC + LIMIT 1` + } else { + query = `SELECT COALESCE(j.visibility, 'private') + FROM experiment_tasks et + JOIN jobs j ON j.id = et.task_id + WHERE et.experiment_id = $1 + ORDER BY j.created_at DESC + LIMIT 1` + } + + var visibility string + err := db.conn.QueryRowContext(context.Background(), query, experimentID).Scan(&visibility) + if err != nil { + if err == sql.ErrNoRows { + return "private", nil + } + return "", fmt.Errorf("failed to get experiment visibility: %w", err) + } + return visibility, nil +} + +// CascadeExperimentVisibility updates visibility for all tasks in an experiment. +func (db *DB) CascadeExperimentVisibility(experimentID, visibility string) error { + var query string + if db.dbType == DBTypeSQLite { + query = `UPDATE jobs + SET visibility = ? + WHERE id IN ( + SELECT task_id FROM experiment_tasks WHERE experiment_id = ? 
+ )` + } else { + query = `UPDATE jobs + SET visibility = $1 + WHERE id IN ( + SELECT task_id FROM experiment_tasks WHERE experiment_id = $2 + )` + } + _, err := db.conn.ExecContext(context.Background(), query, visibility, experimentID) + if err != nil { + return fmt.Errorf("failed to cascade visibility: %w", err) + } + return nil +} + +// ListTasksForExperiment returns all tasks associated with an experiment. +func (db *DB) ListTasksForExperiment(experimentID string, opts ListTasksOptions) ([]*Job, string, error) { + if opts.Limit <= 0 || opts.Limit > 100 { + opts.Limit = 100 + } + + cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor) + + var query string + if db.dbType == DBTypeSQLite { + query = ` + SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, + j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at + FROM jobs j + JOIN experiment_tasks et ON et.task_id = j.id + WHERE et.experiment_id = ? + AND (? = '' OR (datetime(j.created_at) || j.id) < ?) 
+ ORDER BY j.created_at DESC, j.id DESC + LIMIT ?` + } else { + query = ` + SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, + j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at + FROM jobs j + JOIN experiment_tasks et ON et.task_id = j.id + WHERE et.experiment_id = $1 + AND ($2 = '' OR (j.created_at::text || j.id) < $3) + ORDER BY j.created_at DESC, j.id DESC + LIMIT $4` + } + + fetchLimit := opts.Limit + 1 + + var rows *sql.Rows + var err error + if db.dbType == DBTypeSQLite { + rows, err = db.conn.QueryContext(context.Background(), query, experimentID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit) + } else { + rows, err = db.conn.QueryContext(context.Background(), query, experimentID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit) + } + if err != nil { + return nil, "", fmt.Errorf("failed to list experiment tasks: %w", err) + } + defer func() { _ = rows.Close() }() + + var jobs []*Job + for rows.Next() { + job := &Job{} + var createdAt, updatedAt sql.NullTime + var startedAt, endedAt sql.NullTime + var datasetsJSON, metadataJSON []byte + + err := rows.Scan( + &job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, + &datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error, + &createdAt, &updatedAt, &startedAt, &endedAt, + ) + if err != nil { + return nil, "", fmt.Errorf("failed to scan job: %w", err) + } + + if createdAt.Valid { + job.CreatedAt = createdAt.Time + } + if updatedAt.Valid { + job.UpdatedAt = updatedAt.Time + } + if startedAt.Valid { + job.StartedAt = &startedAt.Time + } + if endedAt.Valid { + job.EndedAt = &endedAt.Time + } + + if len(datasetsJSON) > 0 { + _ = json.Unmarshal(datasetsJSON, &job.Datasets) + } + if len(metadataJSON) > 0 { + _ = json.Unmarshal(metadataJSON, &job.Metadata) + } + + jobs = append(jobs, job) + } + + if err = rows.Err(); err != nil { + return nil, "", fmt.Errorf("error iterating jobs: %w", err) + } + + nextCursor := "" + if len(jobs) > opts.Limit { 
+ lastJob := jobs[opts.Limit-1] + nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID) + jobs = jobs[:opts.Limit] + } + + return jobs, nextCursor, nil +} diff --git a/internal/storage/db_groups.go b/internal/storage/db_groups.go new file mode 100644 index 0000000..44bfc0d --- /dev/null +++ b/internal/storage/db_groups.go @@ -0,0 +1,541 @@ +package storage + +import ( + "context" + "database/sql" + "fmt" + "log" + "os" + "time" +) + +// Group represents a lab group +type Group struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + CreatedAt time.Time `json:"created_at"` + CreatedBy string `json:"created_by"` +} + +// GroupInvitation represents a pending invitation +type GroupInvitation struct { + ID string `json:"id"` + GroupID string `json:"group_id"` + InvitedUserID string `json:"invited_user_id"` + InvitedBy string `json:"invited_by"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` +} + +// CreateGroup creates a new lab group +func (db *DB) CreateGroup(name, description, createdBy string) (*Group, error) { + id := fmt.Sprintf("group_%d", time.Now().UnixNano()) + + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)` + } else { + query = `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)` + } + + _, err := db.conn.ExecContext(context.Background(), query, id, name, description, createdBy) + if err != nil { + return nil, fmt.Errorf("failed to create group: %w", err) + } + + // Add creator as admin + if db.dbType == DBTypeSQLite { + query = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'admin')` + } else { + query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'admin')` + } + _, err = db.conn.ExecContext(context.Background(), query, id, createdBy) + if err != nil { + 
return nil, fmt.Errorf("failed to add creator as admin: %w", err) + } + + return &Group{ + ID: id, + Name: name, + Description: description, + CreatedBy: createdBy, + CreatedAt: time.Now(), + }, nil +} + +// ListGroupsForUser returns all groups the user is a member of +func (db *DB) ListGroupsForUser(userID string) ([]Group, error) { + var query string + if db.dbType == DBTypeSQLite { + query = ` + SELECT g.id, g.name, g.description, g.created_at, g.created_by + FROM groups g + JOIN group_members gm ON gm.group_id = g.id + WHERE gm.user_id = ? + ORDER BY g.created_at DESC` + } else { + query = ` + SELECT g.id, g.name, g.description, g.created_at, g.created_by + FROM groups g + JOIN group_members gm ON gm.group_id = g.id + WHERE gm.user_id = $1 + ORDER BY g.created_at DESC` + } + + rows, err := db.conn.QueryContext(context.Background(), query, userID) + if err != nil { + return nil, fmt.Errorf("failed to list groups: %w", err) + } + defer func() { + if err := rows.Close(); err != nil { + log.Printf("ERROR: failed to close rows: %v", err) + } + }() + + var groups []Group + for rows.Next() { + var g Group + var createdAt sql.NullTime + err := rows.Scan(&g.ID, &g.Name, &g.Description, &createdAt, &g.CreatedBy) + if err != nil { + return nil, fmt.Errorf("failed to scan group: %w", err) + } + if createdAt.Valid { + g.CreatedAt = createdAt.Time + } + groups = append(groups, g) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating groups: %w", err) + } + + return groups, nil +} + +// IsGroupAdmin checks if user is an admin of the group +func (db *DB) IsGroupAdmin(userID, groupID string) (bool, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT COUNT(*) FROM group_members WHERE group_id = ? AND user_id = ? 
AND role = 'admin'` + } else { + query = `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND user_id = $2 AND role = 'admin'` + } + + var count int + err := db.conn.QueryRowContext(context.Background(), query, groupID, userID).Scan(&count) + if err != nil { + return false, fmt.Errorf("failed to check admin status: %w", err) + } + return count > 0, nil +} + +// IsGroupMember checks if user is a member of the group +func (db *DB) IsGroupMember(userID, groupID string) (bool, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT COUNT(*) FROM group_members WHERE group_id = ? AND user_id = ?` + } else { + query = `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND user_id = $2` + } + + var count int + err := db.conn.QueryRowContext(context.Background(), query, groupID, userID).Scan(&count) + if err != nil { + return false, fmt.Errorf("failed to check membership: %w", err) + } + return count > 0, nil +} + +// CreateGroupInvitation creates a new group invitation +func (db *DB) CreateGroupInvitation(groupID, invitedUserID, invitedBy string) (*GroupInvitation, error) { + id := fmt.Sprintf("inv_%d", time.Now().UnixNano()) + expiresAt := time.Now().Add(7 * 24 * time.Hour) // 7 days + + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT INTO group_invitations (id, group_id, invited_user_id, invited_by, expires_at) VALUES (?, ?, ?, ?, ?)` + } else { + query = `INSERT INTO group_invitations (id, group_id, invited_user_id, invited_by, expires_at) VALUES ($1, $2, $3, $4, $5)` + } + + _, err := db.conn.ExecContext(context.Background(), query, id, groupID, invitedUserID, invitedBy, expiresAt) + if err != nil { + return nil, fmt.Errorf("failed to create invitation: %w", err) + } + + return &GroupInvitation{ + ID: id, + GroupID: groupID, + InvitedUserID: invitedUserID, + InvitedBy: invitedBy, + Status: "pending", + CreatedAt: time.Now(), + ExpiresAt: &expiresAt, + }, nil +} + +// GetInvitation retrieves an invitation by ID 
+func (db *DB) GetInvitation(invitationID string) (*GroupInvitation, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE id = ?` + } else { + query = `SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at FROM group_invitations WHERE id = $1` + } + + var inv GroupInvitation + var createdAt, expiresAt sql.NullTime + err := db.conn.QueryRowContext(context.Background(), query, invitationID).Scan( + &inv.ID, &inv.GroupID, &inv.InvitedUserID, &inv.InvitedBy, &inv.Status, &createdAt, &expiresAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("invitation not found") + } + return nil, fmt.Errorf("failed to get invitation: %w", err) + } + + if createdAt.Valid { + inv.CreatedAt = createdAt.Time + } + if expiresAt.Valid { + inv.ExpiresAt = &expiresAt.Time + } + + return &inv, nil +} + +// ListPendingInvitationsForUser returns pending invitations for a user +func (db *DB) ListPendingInvitationsForUser(userID string) ([]GroupInvitation, error) { + var query string + if db.dbType == DBTypeSQLite { + query = ` + SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at + FROM group_invitations + WHERE invited_user_id = ? 
AND status = 'pending' + ORDER BY created_at DESC` + } else { + query = ` + SELECT id, group_id, invited_user_id, invited_by, status, created_at, expires_at + FROM group_invitations + WHERE invited_user_id = $1 AND status = 'pending' + ORDER BY created_at DESC` + } + + rows, err := db.conn.QueryContext(context.Background(), query, userID) + if err != nil { + return nil, fmt.Errorf("failed to list invitations: %w", err) + } + defer func() { + if err := rows.Close(); err != nil { + log.Printf("ERROR: failed to close rows: %v", err) + } + }() + + var invitations []GroupInvitation + for rows.Next() { + var inv GroupInvitation + var createdAt, expiresAt sql.NullTime + err := rows.Scan(&inv.ID, &inv.GroupID, &inv.InvitedUserID, &inv.InvitedBy, &inv.Status, &createdAt, &expiresAt) + if err != nil { + return nil, fmt.Errorf("failed to scan invitation: %w", err) + } + if createdAt.Valid { + inv.CreatedAt = createdAt.Time + } + if expiresAt.Valid { + inv.ExpiresAt = &expiresAt.Time + } + invitations = append(invitations, inv) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating invitations: %w", err) + } + + return invitations, nil +} + +// AcceptInvitation accepts a group invitation +func (db *DB) AcceptInvitation(invitationID, userID string) error { + tx, err := db.conn.BeginTx(context.Background(), nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer func() { + if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { + log.Printf("ERROR: failed to rollback transaction: %v", err) + } + }() + + // Get invitation details + var groupID string + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT group_id FROM group_invitations WHERE id = ? AND invited_user_id = ? 
AND status = 'pending'` + } else { + query = `SELECT group_id FROM group_invitations WHERE id = $1 AND invited_user_id = $2 AND status = 'pending'` + } + + err = tx.QueryRowContext(context.Background(), query, invitationID, userID).Scan(&groupID) + if err != nil { + return fmt.Errorf("invitation not found or already processed: %w", err) + } + + // Add user to group + if db.dbType == DBTypeSQLite { + query = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'member')` + } else { + query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'member')` + } + _, err = tx.ExecContext(context.Background(), query, groupID, userID) + if err != nil { + return fmt.Errorf("failed to add member: %w", err) + } + + // Update invitation status + if db.dbType == DBTypeSQLite { + query = `UPDATE group_invitations SET status = 'accepted' WHERE id = ?` + } else { + query = `UPDATE group_invitations SET status = 'accepted' WHERE id = $1` + } + _, err = tx.ExecContext(context.Background(), query, invitationID) + if err != nil { + return fmt.Errorf("failed to update invitation: %w", err) + } + + return tx.Commit() +} + +// DeclineInvitation declines a group invitation +func (db *DB) DeclineInvitation(invitationID, userID string) error { + var query string + if db.dbType == DBTypeSQLite { + query = `UPDATE group_invitations SET status = 'declined' WHERE id = ? AND invited_user_id = ?` + } else { + query = `UPDATE group_invitations SET status = 'declined' WHERE id = $1 AND invited_user_id = $2` + } + + _, err := db.conn.ExecContext(context.Background(), query, invitationID, userID) + if err != nil { + return fmt.Errorf("failed to decline invitation: %w", err) + } + return nil +} + +// RemoveGroupMember removes a member from a group +func (db *DB) RemoveGroupMember(groupID, userID string) error { + var query string + if db.dbType == DBTypeSQLite { + query = `DELETE FROM group_members WHERE group_id = ? 
AND user_id = ?` + } else { + query = `DELETE FROM group_members WHERE group_id = $1 AND user_id = $2` + } + + _, err := db.conn.ExecContext(context.Background(), query, groupID, userID) + if err != nil { + return fmt.Errorf("failed to remove member: %w", err) + } + return nil +} + +// UserRoleInTaskGroups returns the highest role (admin > member > viewer) the user +// holds across all groups associated with the task. Returns empty string if no access. +func (db *DB) UserRoleInTaskGroups(userID, taskID string) string { + var query string + if db.dbType == DBTypeSQLite { + query = ` + SELECT gm.role FROM group_members gm + JOIN task_group_access tga ON tga.group_id = gm.group_id + WHERE gm.user_id = ? AND tga.task_id = ? + ORDER BY CASE gm.role + WHEN 'admin' THEN 1 + WHEN 'member' THEN 2 + WHEN 'viewer' THEN 3 + END + LIMIT 1` + } else { + query = ` + SELECT gm.role FROM group_members gm + JOIN task_group_access tga ON tga.group_id = gm.group_id + WHERE gm.user_id = $1 AND tga.task_id = $2 + ORDER BY CASE gm.role + WHEN 'admin' THEN 1 + WHEN 'member' THEN 2 + WHEN 'viewer' THEN 3 + END + LIMIT 1` + } + + var role string + err := db.conn.QueryRowContext(context.Background(), query, userID, taskID).Scan(&role) + if err != nil { + return "" + } + return role +} + +// GetOrCreateDefaultLabGroup creates the auto-provisioned lab group if it doesn't exist. +// Returns the group ID. If no DEFAULT_LAB_GROUP env var is set, returns empty string. 
+func (db *DB) GetOrCreateDefaultLabGroup(createdBy string) (string, error) { + groupName := os.Getenv("DEFAULT_LAB_GROUP") + if groupName == "" { + return "", nil + } + + // Use transaction to prevent race conditions during concurrent startup + tx, err := db.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}) + if err != nil { + return "", fmt.Errorf("failed to begin transaction: %w", err) + } + defer func() { + if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { + log.Printf("ERROR: failed to rollback transaction: %v", err) + } + }() + + // Check if group exists + var groupID string + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT id FROM groups WHERE name = ?` + } else { + query = `SELECT id FROM groups WHERE name = $1` + } + err = tx.QueryRowContext(context.Background(), query, groupName).Scan(&groupID) + if err == nil { + if commitErr := tx.Commit(); commitErr != nil { + return "", fmt.Errorf("failed to commit transaction: %w", commitErr) + } + return groupID, nil + } + if err != sql.ErrNoRows { + return "", fmt.Errorf("failed to check for default lab group: %w", err) + } + + // Create the group within the transaction + id := fmt.Sprintf("group_%d", time.Now().UnixNano()) + if db.dbType == DBTypeSQLite { + query = `INSERT INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)` + } else { + query = `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4)` + } + _, err = tx.ExecContext(context.Background(), query, id, groupName, "Auto-provisioned default lab group", createdBy) + if err != nil { + return "", fmt.Errorf("failed to create default lab group: %w", err) + } + + // Add creator as admin within the transaction + if db.dbType == DBTypeSQLite { + query = `INSERT INTO group_members (group_id, user_id, role) VALUES (?, ?, 'admin')` + } else { + query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, 'admin')` + } + _, err = 
tx.ExecContext(context.Background(), query, id, createdBy) + if err != nil { + return "", fmt.Errorf("failed to add creator as admin: %w", err) + } + + if err := tx.Commit(); err != nil { + return "", fmt.Errorf("failed to commit transaction: %w", err) + } + + return id, nil +} + +// EnsureUserInGroup adds a user to a group if not already a member. +// Default role is 'member'. Returns nil if already a member. +func (db *DB) EnsureUserInGroup(groupID, userID string, role string) error { + if role == "" { + role = "member" + } + + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT OR IGNORE INTO group_members (group_id, user_id, role) VALUES (?, ?, ?)` + } else { + query = `INSERT INTO group_members (group_id, user_id, role) VALUES ($1, $2, $3) + ON CONFLICT (group_id, user_id) DO NOTHING` + } + + _, err := db.conn.ExecContext(context.Background(), query, groupID, userID, role) + if err != nil { + return fmt.Errorf("failed to add user to group: %w", err) + } + return nil +} + +// EnsureAllUsersGroup creates the 'all-users' system group if it doesn't exist. +// This group is used for institution visibility. +func (db *DB) EnsureAllUsersGroup() (string, error) { + const groupID = "all-users" + const groupName = "all-users" + + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT OR IGNORE INTO groups (id, name, description, created_by) VALUES (?, ?, ?, ?)` + } else { + query = `INSERT INTO groups (id, name, description, created_by) VALUES ($1, $2, $3, $4) + ON CONFLICT (id) DO NOTHING` + } + + _, err := db.conn.ExecContext(context.Background(), query, groupID, groupName, + "System group: all authenticated users", "system") + if err != nil { + return "", fmt.Errorf("failed to ensure all-users group: %w", err) + } + return groupID, nil +} + +// EnsureUserInAllUsersGroup adds a user to the 'all-users' system group. 
+func (db *DB) EnsureUserInAllUsersGroup(userID string) error { + allUsersGroupID, err := db.EnsureAllUsersGroup() + if err != nil { + return err + } + return db.EnsureUserInGroup(allUsersGroupID, userID, "member") +} + +// ProvisionUserOnFirstLogin adds a new user to the default groups: +// 1. The 'all-users' system group (for institution visibility) +// 2. The DEFAULT_LAB_GROUP if configured (for lab visibility) +// This should be called when a user first authenticates/logs in. +func (db *DB) ProvisionUserOnFirstLogin(userID string) error { + // Add to all-users system group first + if err := db.EnsureUserInAllUsersGroup(userID); err != nil { + return fmt.Errorf("failed to add user to all-users group: %w", err) + } + + // Add to default lab group if configured + defaultGroupID, err := db.GetOrCreateDefaultLabGroup("system") + if err != nil { + return fmt.Errorf("failed to get/create default lab group: %w", err) + } + if defaultGroupID != "" { + if err := db.EnsureUserInGroup(defaultGroupID, userID, "member"); err != nil { + return fmt.Errorf("failed to add user to default lab group: %w", err) + } + } + + return nil +} + +// DeprovisionUser removes a user from all groups when they are deactivated/deleted. +// This prevents deactivated users from retaining institution-visibility access. 
+func (db *DB) DeprovisionUser(userID string) error { + var query string + if db.dbType == DBTypeSQLite { + query = `DELETE FROM group_members WHERE user_id = ?` + } else { + query = `DELETE FROM group_members WHERE user_id = $1` + } + + _, err := db.conn.ExecContext(context.Background(), query, userID) + if err != nil { + return fmt.Errorf("failed to remove user from groups: %w", err) + } + return nil +} diff --git a/internal/storage/db_jobs.go b/internal/storage/db_jobs.go index 2dba42e..bff8aae 100644 --- a/internal/storage/db_jobs.go +++ b/internal/storage/db_jobs.go @@ -163,6 +163,7 @@ func (db *DB) ListJobs(status string, limit int) ([]*Job, error) { if db.dbType == DBTypeSQLite { query += " LIMIT ?" } else { + //nolint:gosec // G202: This builds a positional parameter $N where N is an internal counter, not user input query += fmt.Sprintf(" LIMIT $%d", len(args)+1) } args = append(args, limit) diff --git a/internal/storage/db_tasks.go b/internal/storage/db_tasks.go new file mode 100644 index 0000000..962c409 --- /dev/null +++ b/internal/storage/db_tasks.go @@ -0,0 +1,498 @@ +package storage + +import ( + "context" + "database/sql" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "time" +) + +// TaskIsSharedWithUser returns true if an active, non-expired explicit share exists. +func (db *DB) TaskIsSharedWithUser(taskID, userID string) bool { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT COUNT(*) FROM task_shares + WHERE task_id = ? AND user_id = ? 
+ AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)` + } else { + query = `SELECT COUNT(*) FROM task_shares + WHERE task_id = $1 AND user_id = $2 + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)` + } + var count int + err := db.conn.QueryRowContext(context.Background(), query, taskID, userID).Scan(&count) + if err != nil { + return false + } + return count > 0 +} + +// UserSharesGroupWithTask returns true if the user is a live member of any group +// associated with the task via task_group_access. +func (db *DB) UserSharesGroupWithTask(userID, taskID string) bool { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT COUNT(*) FROM group_members gm + JOIN task_group_access tga ON tga.group_id = gm.group_id + WHERE gm.user_id = ? AND tga.task_id = ?` + } else { + query = `SELECT COUNT(*) FROM group_members gm + JOIN task_group_access tga ON tga.group_id = gm.group_id + WHERE gm.user_id = $1 AND tga.task_id = $2` + } + var count int + err := db.conn.QueryRowContext(context.Background(), query, userID, taskID).Scan(&count) + if err != nil { + return false + } + return count > 0 +} + +// TaskAllowsPublicClone returns true if the task has allow_public_clone enabled. +// For now, this checks if the task exists and is visible as 'open'. +// The allow_public_clone flag would be stored in the task metadata or a separate column. +func (db *DB) TaskAllowsPublicClone(taskID string) bool { + // TODO: Implement proper allow_public_clone check + // For now, return false as a safe default + return false +} + +// AssociateTaskWithGroup records that a task is shared with a specific group. +// Called at submit time for lab-visibility tasks. 
+func (db *DB) AssociateTaskWithGroup(taskID, groupID string) error { + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT OR IGNORE INTO task_group_access (task_id, group_id) VALUES (?, ?)` + } else { + query = `INSERT INTO task_group_access (task_id, group_id) VALUES ($1, $2) + ON CONFLICT (task_id, group_id) DO NOTHING` + } + _, err := db.conn.ExecContext(context.Background(), query, taskID, groupID) + if err != nil { + return fmt.Errorf("failed to associate task with group: %w", err) + } + return nil +} + +// CountOpenTasksForUserToday returns the number of open tasks created by user today. +func (db *DB) CountOpenTasksForUserToday(userID string) (int, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT COUNT(*) FROM jobs + WHERE user_id = ? + AND visibility = 'open' + AND created_at >= date('now')` + } else { + query = `SELECT COUNT(*) FROM jobs + WHERE user_id = $1 + AND visibility = 'open' + AND created_at >= CURRENT_DATE` + } + var count int + err := db.conn.QueryRowContext(context.Background(), query, userID).Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to count open tasks: %w", err) + } + return count, nil +} + +// ListTasksOptions provides pagination options for task listing. +type ListTasksOptions struct { + Limit int // required; enforced max 100 + Cursor string // (created_at, id) pair encoded as opaque string; empty = first page +} + +// ListTasksForUser returns tasks visible to a specific user. +// Uses UNION-based SQL for efficient index usage with cursor pagination. 
+func (db *DB) ListTasksForUser(userID string, isAdmin bool, opts ListTasksOptions) ([]*Job, string, error) {
+	// Enforce max limit
+	if opts.Limit <= 0 || opts.Limit > 100 {
+		opts.Limit = 100
+	}
+
+	if isAdmin {
+		return db.listTasksPaginated("SELECT id, job_name, args, status, priority, datasets, metadata, worker_id, error, created_at, updated_at, started_at, ended_at FROM jobs", opts)
+	}
+
+	// Parse cursor if provided
+	cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor)
+
+	// UNION approach: one index-able branch per access reason.
+	// Explicit column list instead of j.* avoids issues in SQLite strict mode
+	// and makes schema expectations explicit across all branches.
+	cols := `j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at`
+
+	var query string
+	if db.dbType == DBTypeSQLite {
+		query = fmt.Sprintf(`
+			-- Owner
+			SELECT %s FROM jobs j
+			WHERE j.user_id = ?
+			AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
+
+			UNION
+
+			-- Explicit share (expiry checked inline)
+			SELECT %s FROM jobs j
+			JOIN task_shares ts ON ts.task_id = j.id
+			WHERE ts.user_id = ?
+			AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP)
+			AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
+
+			UNION
+
+			-- Lab group membership (live, not snapshotted)
+			SELECT %s FROM jobs j
+			JOIN task_group_access tga ON tga.task_id = j.id
+			JOIN group_members gm ON gm.group_id = tga.group_id
+			WHERE gm.user_id = ?
+			AND j.visibility = 'lab'
+			AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
+
+			UNION
+
+			-- Institution or open (all authenticated users)
+			SELECT %s FROM jobs j
+			WHERE j.visibility IN ('institution', 'open')
+			AND (? = '' OR (datetime(j.created_at) || j.id) < ?)
+
+			ORDER BY created_at DESC, id DESC
+			LIMIT ?
+ `, cols, cols, cols, cols) + } else { + // PostgreSQL version with $N placeholders + query = fmt.Sprintf(` + -- Owner + SELECT %s FROM jobs j + WHERE j.user_id = $1 + AND ($2 = '' OR (j.created_at::text || j.id) < $3) + + UNION + + -- Explicit share (expiry checked inline) + SELECT %s FROM jobs j + JOIN task_shares ts ON ts.task_id = j.id + WHERE ts.user_id = $4 + AND (ts.expires_at IS NULL OR ts.expires_at > CURRENT_TIMESTAMP) + AND ($5 = '' OR (j.created_at::text || j.id) < $6) + + UNION + + -- Lab group membership (live, not snapshotted) + SELECT %s FROM jobs j + JOIN task_group_access tga ON tga.task_id = j.id + JOIN group_members gm ON gm.group_id = tga.group_id + WHERE gm.user_id = $7 + AND j.visibility = 'lab' + AND ($8 = '' OR (j.created_at::text || j.id) < $9) + + UNION + + -- Institution or open (all authenticated users) + SELECT %s FROM jobs j + WHERE j.visibility IN ('institution', 'open') + AND ($10 = '' OR (j.created_at::text || j.id) < $11) + + ORDER BY created_at DESC, id DESC + LIMIT $12 + `, cols, cols, cols, cols) + } + + // Fetch limit+1 rows; if len(rows) > limit, there is a next page + fetchLimit := opts.Limit + 1 + + var rows *sql.Rows + var err error + if db.dbType == DBTypeSQLite { + rows, err = db.conn.QueryContext(context.Background(), query, + userID, cursorCreatedAt, cursorCreatedAt+cursorID, + userID, cursorCreatedAt, cursorCreatedAt+cursorID, + userID, cursorCreatedAt, cursorCreatedAt+cursorID, + cursorCreatedAt, cursorCreatedAt+cursorID, + fetchLimit, + ) + } else { + rows, err = db.conn.QueryContext(context.Background(), query, + userID, cursorCreatedAt, cursorCreatedAt+cursorID, + userID, cursorCreatedAt, cursorCreatedAt+cursorID, + userID, cursorCreatedAt, cursorCreatedAt+cursorID, + cursorCreatedAt, cursorCreatedAt+cursorID, + fetchLimit, + ) + } + if err != nil { + return nil, "", fmt.Errorf("failed to list tasks: %w", err) + } + defer func() { _ = rows.Close() }() + + var jobs []*Job + for rows.Next() { + job := &Job{} + var 
createdAt, updatedAt sql.NullTime + var startedAt, endedAt sql.NullTime + var datasetsJSON, metadataJSON []byte + + err := rows.Scan( + &job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, + &datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error, + &createdAt, &updatedAt, &startedAt, &endedAt, + ) + if err != nil { + return nil, "", fmt.Errorf("failed to scan job: %w", err) + } + + if createdAt.Valid { + job.CreatedAt = createdAt.Time + } + if updatedAt.Valid { + job.UpdatedAt = updatedAt.Time + } + if startedAt.Valid { + job.StartedAt = &startedAt.Time + } + if endedAt.Valid { + job.EndedAt = &endedAt.Time + } + + if len(datasetsJSON) > 0 { + _ = json.Unmarshal(datasetsJSON, &job.Datasets) + } + if len(metadataJSON) > 0 { + _ = json.Unmarshal(metadataJSON, &job.Metadata) + } + + jobs = append(jobs, job) + } + + if err = rows.Err(); err != nil { + return nil, "", fmt.Errorf("error iterating jobs: %w", err) + } + + // Determine next cursor + nextCursor := "" + if len(jobs) > opts.Limit { + // More rows available, encode cursor from the last row we actually want to return + lastJob := jobs[opts.Limit-1] + nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID) + // Trim to requested limit + jobs = jobs[:opts.Limit] + } + + return jobs, nextCursor, nil +} + +// ListTasksForGroup returns tasks associated with a group +func (db *DB) ListTasksForGroup(groupID string, opts ListTasksOptions) ([]*Job, string, error) { + if opts.Limit <= 0 || opts.Limit > 100 { + opts.Limit = 100 + } + + cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor) + + var query string + if db.dbType == DBTypeSQLite { + query = ` + SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, + j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at + FROM jobs j + JOIN task_group_access tga ON tga.task_id = j.id + WHERE tga.group_id = ? + AND (? = '' OR (datetime(j.created_at) || j.id) < ?) 
+ ORDER BY j.created_at DESC, j.id DESC + LIMIT ?` + } else { + query = ` + SELECT j.id, j.job_name, j.args, j.status, j.priority, j.datasets, j.metadata, + j.worker_id, j.error, j.created_at, j.updated_at, j.started_at, j.ended_at + FROM jobs j + JOIN task_group_access tga ON tga.task_id = j.id + WHERE tga.group_id = $1 + AND ($2 = '' OR (j.created_at::text || j.id) < $3) + ORDER BY j.created_at DESC, j.id DESC + LIMIT $4` + } + + fetchLimit := opts.Limit + 1 + + var rows *sql.Rows + var err error + if db.dbType == DBTypeSQLite { + rows, err = db.conn.QueryContext(context.Background(), query, groupID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit) + } else { + rows, err = db.conn.QueryContext(context.Background(), query, groupID, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit) + } + if err != nil { + return nil, "", fmt.Errorf("failed to list group tasks: %w", err) + } + defer func() { _ = rows.Close() }() + + var jobs []*Job + for rows.Next() { + job := &Job{} + var createdAt, updatedAt sql.NullTime + var startedAt, endedAt sql.NullTime + var datasetsJSON, metadataJSON []byte + + err := rows.Scan( + &job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, + &datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error, + &createdAt, &updatedAt, &startedAt, &endedAt, + ) + if err != nil { + return nil, "", fmt.Errorf("failed to scan job: %w", err) + } + + if createdAt.Valid { + job.CreatedAt = createdAt.Time + } + if updatedAt.Valid { + job.UpdatedAt = updatedAt.Time + } + if startedAt.Valid { + job.StartedAt = &startedAt.Time + } + if endedAt.Valid { + job.EndedAt = &endedAt.Time + } + + if len(datasetsJSON) > 0 { + _ = json.Unmarshal(datasetsJSON, &job.Datasets) + } + if len(metadataJSON) > 0 { + _ = json.Unmarshal(metadataJSON, &job.Metadata) + } + + jobs = append(jobs, job) + } + + if err = rows.Err(); err != nil { + return nil, "", fmt.Errorf("error iterating jobs: %w", err) + } + + nextCursor := "" + if len(jobs) > opts.Limit { + lastJob := 
jobs[opts.Limit-1] + nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID) + jobs = jobs[:opts.Limit] + } + + return jobs, nextCursor, nil +} + +// listTasksPaginated is a helper for admin queries that lists all tasks. +func (db *DB) listTasksPaginated(baseQuery string, opts ListTasksOptions) ([]*Job, string, error) { + if opts.Limit <= 0 || opts.Limit > 100 { + opts.Limit = 100 + } + + cursorCreatedAt, cursorID, _ := decodeCursor(opts.Cursor) + + var query string + if db.dbType == DBTypeSQLite { + query = baseQuery + ` + WHERE (? = '' OR (datetime(created_at) || id) < ?) + ORDER BY created_at DESC, id DESC + LIMIT ?` + } else { + query = baseQuery + ` + WHERE ($1 = '' OR (created_at::text || id) < $2) + ORDER BY created_at DESC, id DESC + LIMIT $3` + } + + fetchLimit := opts.Limit + 1 + + var rows *sql.Rows + var err error + if db.dbType == DBTypeSQLite { + rows, err = db.conn.QueryContext(context.Background(), query, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit) + } else { + rows, err = db.conn.QueryContext(context.Background(), query, cursorCreatedAt, cursorCreatedAt+cursorID, fetchLimit) + } + if err != nil { + return nil, "", fmt.Errorf("failed to list tasks: %w", err) + } + defer func() { _ = rows.Close() }() + + var jobs []*Job + for rows.Next() { + job := &Job{} + var createdAt, updatedAt sql.NullTime + var startedAt, endedAt sql.NullTime + var datasetsJSON, metadataJSON []byte + + err := rows.Scan( + &job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, + &datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error, + &createdAt, &updatedAt, &startedAt, &endedAt, + ) + if err != nil { + return nil, "", fmt.Errorf("failed to scan job: %w", err) + } + + if createdAt.Valid { + job.CreatedAt = createdAt.Time + } + if updatedAt.Valid { + job.UpdatedAt = updatedAt.Time + } + if startedAt.Valid { + job.StartedAt = &startedAt.Time + } + if endedAt.Valid { + job.EndedAt = &endedAt.Time + } + + if len(datasetsJSON) > 0 { + _ = 
json.Unmarshal(datasetsJSON, &job.Datasets)
+		}
+		if len(metadataJSON) > 0 {
+			_ = json.Unmarshal(metadataJSON, &job.Metadata)
+		}
+
+		jobs = append(jobs, job)
+	}
+
+	if err = rows.Err(); err != nil {
+		return nil, "", fmt.Errorf("error iterating jobs: %w", err)
+	}
+
+	nextCursor := ""
+	if len(jobs) > opts.Limit {
+		lastJob := jobs[opts.Limit-1]
+		nextCursor = encodeCursor(lastJob.CreatedAt, lastJob.ID)
+		jobs = jobs[:opts.Limit]
+	}
+
+	return jobs, nextCursor, nil
+}
+
+// encodeCursor creates an opaque cursor string from created_at and id.
+// Format: base64url(created_at "2006-01-02 15:04:05" in UTC + "|" + task_id); the fixed-width layout matches SQL datetime() text so cursor predicates compare correctly.
+func encodeCursor(createdAt time.Time, id string) string {
+	payload := createdAt.UTC().Format("2006-01-02 15:04:05") + "|" + id // must match SQLite datetime() output; RFC3339's 'T'/'Z' would never compare equal against (datetime(created_at) || id)
+	return base64.RawURLEncoding.EncodeToString([]byte(payload))
+}
+
+// decodeCursor extracts created_at and id from an opaque cursor.
+// Returns empty strings if cursor is invalid or empty.
+func decodeCursor(cursor string) (createdAt, id string, err error) {
+	if cursor == "" {
+		return "", "", nil
+	}
+
+	payload, err := base64.RawURLEncoding.DecodeString(cursor)
+	if err != nil {
+		return "", "", err
+	}
+
+	parts := strings.SplitN(string(payload), "|", 2)
+	if len(parts) != 2 {
+		return "", "", fmt.Errorf("invalid cursor format")
+	}
+
+	return parts[0], parts[1], nil
+}
diff --git a/internal/storage/db_tokens.go b/internal/storage/db_tokens.go
new file mode 100644
index 0000000..a33d10e
--- /dev/null
+++ b/internal/storage/db_tokens.go
@@ -0,0 +1,219 @@
+package storage
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"time"
+)
+
+// ShareToken represents a signed token for unauthenticated access
+type ShareToken struct {
+	Token        string     `json:"token"`
+	TaskID       *string    `json:"task_id,omitempty"`
+	ExperimentID *string    `json:"experiment_id,omitempty"`
+	CreatedBy    string     `json:"created_by"`
+	CreatedAt    time.Time  `json:"created_at"`
+	ExpiresAt    *time.Time `json:"expires_at,omitempty"`
+	AccessCount  int        `json:"access_count"`
+	MaxAccesses  *int       
`json:"max_accesses,omitempty"` +} + +// CreateShareToken stores a new share token in the database. +func (db *DB) CreateShareToken(token string, taskID, experimentID *string, createdBy string, expiresAt *time.Time, maxAccesses *int) error { + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses) + VALUES (?, ?, ?, ?, ?, ?)` + } else { + query = `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses) + VALUES ($1, $2, $3, $4, $5, $6)` + } + + _, err := db.conn.ExecContext(context.Background(), query, token, taskID, experimentID, createdBy, expiresAt, maxAccesses) + if err != nil { + return fmt.Errorf("failed to create share token: %w", err) + } + return nil +} + +// GetShareToken retrieves a share token by its value. +func (db *DB) GetShareToken(token string) (*ShareToken, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses + FROM share_tokens WHERE token = ?` + } else { + query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses + FROM share_tokens WHERE token = $1` + } + + var t ShareToken + var taskID, experimentID sql.NullString + var expiresAt sql.NullTime + var maxAccesses sql.NullInt64 + + err := db.conn.QueryRowContext(context.Background(), query, token).Scan( + &t.Token, &taskID, &experimentID, &t.CreatedBy, &t.CreatedAt, &expiresAt, &t.AccessCount, &maxAccesses, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("failed to get share token: %w", err) + } + + if taskID.Valid { + t.TaskID = &taskID.String + } + if experimentID.Valid { + t.ExperimentID = &experimentID.String + } + if expiresAt.Valid { + t.ExpiresAt = &expiresAt.Time + } + if maxAccesses.Valid { + m := int(maxAccesses.Int64) + 
t.MaxAccesses = &m + } + + return &t, nil +} + +// IncrementTokenAccessCount increments the access count for a token. +func (db *DB) IncrementTokenAccessCount(token string) error { + var query string + if db.dbType == DBTypeSQLite { + query = `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = ?` + } else { + query = `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = $1` + } + + _, err := db.conn.ExecContext(context.Background(), query, token) + if err != nil { + return fmt.Errorf("failed to increment token access count: %w", err) + } + return nil +} + +// DeleteShareToken removes a share token from the database. +func (db *DB) DeleteShareToken(token string) error { + var query string + if db.dbType == DBTypeSQLite { + query = `DELETE FROM share_tokens WHERE token = ?` + } else { + query = `DELETE FROM share_tokens WHERE token = $1` + } + + _, err := db.conn.ExecContext(context.Background(), query, token) + if err != nil { + return fmt.Errorf("failed to delete share token: %w", err) + } + return nil +} + +// ListShareTokensForTask returns all share tokens for a specific task. +func (db *DB) ListShareTokensForTask(taskID string) ([]*ShareToken, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses + FROM share_tokens WHERE task_id = ? ORDER BY created_at DESC` + } else { + query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses + FROM share_tokens WHERE task_id = $1 ORDER BY created_at DESC` + } + + return db.queryShareTokens(query, taskID) +} + +// ListShareTokensForExperiment returns all share tokens for a specific experiment. 
+func (db *DB) ListShareTokensForExperiment(experimentID string) ([]*ShareToken, error) { + var query string + if db.dbType == DBTypeSQLite { + query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses + FROM share_tokens WHERE experiment_id = ? ORDER BY created_at DESC` + } else { + query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses + FROM share_tokens WHERE experiment_id = $1 ORDER BY created_at DESC` + } + + return db.queryShareTokens(query, experimentID) +} + +// queryShareTokens is a helper to execute share token queries. +func (db *DB) queryShareTokens(query string, arg string) ([]*ShareToken, error) { + rows, err := db.conn.QueryContext(context.Background(), query, arg) + if err != nil { + return nil, fmt.Errorf("failed to list share tokens: %w", err) + } + defer func() { _ = rows.Close() }() + + var tokens []*ShareToken + for rows.Next() { + var t ShareToken + var taskID, experimentID sql.NullString + var expiresAt sql.NullTime + var maxAccesses sql.NullInt64 + + err := rows.Scan( + &t.Token, &taskID, &experimentID, &t.CreatedBy, &t.CreatedAt, &expiresAt, &t.AccessCount, &maxAccesses, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan share token: %w", err) + } + + if taskID.Valid { + t.TaskID = &taskID.String + } + if experimentID.Valid { + t.ExperimentID = &experimentID.String + } + if expiresAt.Valid { + t.ExpiresAt = &expiresAt.Time + } + if maxAccesses.Valid { + m := int(maxAccesses.Int64) + t.MaxAccesses = &m + } + + tokens = append(tokens, &t) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating share tokens: %w", err) + } + + return tokens, nil +} + +// ValidateShareToken checks if a token is valid for accessing a task or experiment. +// Returns the token details if valid, nil if invalid or expired. 
+func (db *DB) ValidateShareToken(token string, taskID *string, experimentID *string) (*ShareToken, error) {
+	t, err := db.GetShareToken(token)
+	if err != nil {
+		return nil, err
+	}
+	if t == nil {
+		return nil, nil
+	}
+
+	// Check if token is for the correct resource
+	if taskID != nil && (t.TaskID == nil || *t.TaskID != *taskID) {
+		return nil, nil
+	}
+	if experimentID != nil && (t.ExperimentID == nil || *t.ExperimentID != *experimentID) {
+		return nil, nil
+	}
+
+	// Check expiry
+	if t.ExpiresAt != nil && time.Now().After(*t.ExpiresAt) {
+		return nil, nil
+	}
+
+	// Check max accesses. NOTE(review): this check plus the separate unconditional IncrementTokenAccessCount is a check-then-act race; for a hard cap, make the UPDATE conditional (… AND (max_accesses IS NULL OR access_count < max_accesses)) and inspect RowsAffected.
+	if t.MaxAccesses != nil && t.AccessCount >= *t.MaxAccesses {
+		return nil, nil
+	}
+
+	return t, nil
+}
diff --git a/internal/storage/schema_sqlite.sql b/internal/storage/schema_sqlite.sql
index c011d45..0bee329 100644
--- a/internal/storage/schema_sqlite.sql
+++ b/internal/storage/schema_sqlite.sql
@@ -11,10 +11,13 @@ CREATE TABLE IF NOT EXISTS jobs (
     started_at DATETIME,
     ended_at DATETIME,
     worker_id TEXT,
+    user_id TEXT,
     error TEXT,
     datasets TEXT, -- JSON array
     metadata TEXT, -- JSON object
-    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
+    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    visibility TEXT NOT NULL DEFAULT 'lab',
+    experiment_id TEXT REFERENCES experiments(id)
 );
 
 CREATE TABLE IF NOT EXISTS job_metrics (
@@ -142,3 +145,113 @@ CREATE TABLE IF NOT EXISTS websocket_metrics (
 );
 
 CREATE INDEX IF NOT EXISTS idx_websocket_metrics_name_time ON websocket_metrics(metric_name, recorded_at);
+
+-- Groups and membership for lab-based task sharing
+CREATE TABLE IF NOT EXISTS groups (
+    id TEXT PRIMARY KEY,
+    name TEXT NOT NULL UNIQUE,
+    description TEXT,
+    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    created_by TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS group_members (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    role TEXT DEFAULT 'member', -- 'admin', 'member', 'viewer'
+    joined_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    PRIMARY KEY (group_id, user_id),
+    
FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE
+);
+
+-- System group for institution visibility (all authenticated users)
+INSERT OR IGNORE INTO groups (id, name, description, created_by)
+VALUES ('all-users', 'all-users', 'System group: all authenticated users', 'system');
+
+-- Invite-and-accept flow: group admins invite; users accept or decline
+CREATE TABLE IF NOT EXISTS group_invitations (
+    id TEXT PRIMARY KEY,
+    group_id TEXT NOT NULL,
+    invited_user_id TEXT NOT NULL,
+    invited_by TEXT NOT NULL,
+    status TEXT DEFAULT 'pending', -- 'pending', 'accepted', 'declined'
+    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    expires_at DATETIME, -- NULL = 7d default enforced in app layer
+    FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE
+);
+
+-- Experiment/project grouping: share a whole experiment, not individual tasks
+-- Note: experiments table already exists; adding group_id to link with sharing system
+ALTER TABLE experiments ADD COLUMN group_id TEXT REFERENCES groups(id); -- NOTE(review): ADD COLUMN has no IF NOT EXISTS in SQLite, so unlike the tables above this statement is not idempotent; a second run of this schema fails here — guard it in the migration layer
+
+-- Link tasks to experiments
+CREATE TABLE IF NOT EXISTS experiment_tasks (
+    experiment_id TEXT NOT NULL,
+    task_id TEXT NOT NULL,
+    added_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    PRIMARY KEY (experiment_id, task_id),
+    FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE,
+    FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
+);
+
+-- Per-user explicit shares with optional expiry
+CREATE TABLE IF NOT EXISTS task_shares (
+    task_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    granted_by TEXT NOT NULL,
+    granted_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    expires_at DATETIME, -- NULL = no expiry; checked at access time
+    PRIMARY KEY (task_id, user_id),
+    FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE
+);
+
+-- Group-level task association
+-- Records which group a task is associated with at submit time.
+-- Actual membership is always resolved live from group_members.
+CREATE TABLE IF NOT EXISTS task_group_access ( + task_id TEXT NOT NULL, + group_id TEXT NOT NULL, + PRIMARY KEY (task_id, group_id), + FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE, + FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE +); + +-- Signed share tokens for unauthenticated open access (paper reproducibility links) +CREATE TABLE IF NOT EXISTS share_tokens ( + token TEXT PRIMARY KEY, -- cryptographically random (32 bytes, base64url) + task_id TEXT, -- NULL if experiment-level + experiment_id TEXT, -- NULL if task-level + created_by TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + expires_at DATETIME, -- NULL = never expires + access_count INTEGER DEFAULT 0, + max_accesses INTEGER, -- NULL = unlimited + FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE, + FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE +); + +-- Audit log for task access +CREATE TABLE IF NOT EXISTS task_access_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id TEXT NOT NULL, + user_id TEXT, -- NULL for token-based access + token TEXT, -- NULL for session-based access + action TEXT NOT NULL, -- 'view', 'clone', 'execute', 'modify' + accessed_at DATETIME DEFAULT CURRENT_TIMESTAMP, + ip_address TEXT, + FOREIGN KEY (task_id) REFERENCES jobs(id) ON DELETE CASCADE +); + +-- Indexes for task sharing performance +CREATE INDEX IF NOT EXISTS idx_jobs_visibility ON jobs(visibility); +CREATE INDEX IF NOT EXISTS idx_jobs_user_id ON jobs(user_id); +CREATE INDEX IF NOT EXISTS idx_jobs_visibility_owner ON jobs(visibility, user_id); +CREATE INDEX IF NOT EXISTS idx_jobs_experiment ON jobs(experiment_id); +CREATE INDEX IF NOT EXISTS idx_task_shares_user ON task_shares(user_id); +CREATE INDEX IF NOT EXISTS idx_task_shares_expires ON task_shares(expires_at); +CREATE INDEX IF NOT EXISTS idx_tga_group ON task_group_access(group_id); +CREATE INDEX IF NOT EXISTS idx_share_tokens_task ON share_tokens(task_id); +CREATE INDEX IF NOT 
EXISTS idx_task_access_task ON task_access_log(task_id); +CREATE INDEX IF NOT EXISTS idx_task_access_user ON task_access_log(user_id); +CREATE INDEX IF NOT EXISTS idx_task_access_token ON task_access_log(token) WHERE token IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_invitations_user ON group_invitations(invited_user_id);