package storage

import (
	"context"
	"database/sql"
	"fmt"
	"time"
)

// ShareToken is one row of the share_tokens table: a token string that
// grants unauthenticated access to a task and/or an experiment.
//
// Nil pointer fields correspond to NULL columns. A nil ExpiresAt or a nil
// MaxAccesses means ValidateShareToken does not enforce that limit.
type ShareToken struct {
	Token        string     `json:"token"`
	TaskID       *string    `json:"task_id,omitempty"`
	ExperimentID *string    `json:"experiment_id,omitempty"`
	CreatedBy    string     `json:"created_by"`
	CreatedAt    time.Time  `json:"created_at"`
	ExpiresAt    *time.Time `json:"expires_at,omitempty"`
	AccessCount  int        `json:"access_count"`
	MaxAccesses  *int       `json:"max_accesses,omitempty"`
}

// CreateShareToken inserts a new row into share_tokens.
//
// taskID, experimentID, expiresAt and maxAccesses may be nil. They are passed
// straight through as pointer arguments, and database/sql's default parameter
// converter writes a nil pointer as NULL. created_at and access_count are not
// supplied by this insert. Presumably the schema gives them defaults (the
// schema is not visible here).
func (db *DB) CreateShareToken(token string, taskID, experimentID *string, createdBy string, expiresAt *time.Time, maxAccesses *int) error {
	// Numbered ($N) placeholders for the non-SQLite backend; SQLite uses "?".
	stmt := `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses) VALUES ($1, $2, $3, $4, $5, $6)`
	if db.dbType == DBTypeSQLite {
		stmt = `INSERT INTO share_tokens (token, task_id, experiment_id, created_by, expires_at, max_accesses) VALUES (?, ?, ?, ?, ?, ?)`
	}

	ctx := context.Background()
	if _, err := db.conn.ExecContext(ctx, stmt, token, taskID, experimentID, createdBy, expiresAt, maxAccesses); err != nil {
		return fmt.Errorf("failed to create share token: %w", err)
	}
	return nil
}

// GetShareToken fetches the share token whose token column equals token.
// A missing row is not an error. In that case the result is (nil, nil), so
// callers must check the returned pointer as well as err.
func (db *DB) GetShareToken(token string) (*ShareToken, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses FROM share_tokens WHERE token = ?` } else { query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses FROM share_tokens WHERE token = $1` } var t ShareToken var taskID, experimentID sql.NullString var expiresAt sql.NullTime var maxAccesses sql.NullInt64 err := db.conn.QueryRowContext(context.Background(), query, token).Scan( &t.Token, &taskID, &experimentID, &t.CreatedBy, &t.CreatedAt, &expiresAt, &t.AccessCount, &maxAccesses, ) if err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, fmt.Errorf("failed to get share token: %w", err) } if taskID.Valid { t.TaskID = &taskID.String } if experimentID.Valid { t.ExperimentID = &experimentID.String } if expiresAt.Valid { t.ExpiresAt = &expiresAt.Time } if maxAccesses.Valid { m := int(maxAccesses.Int64) t.MaxAccesses = &m } return &t, nil } // IncrementTokenAccessCount increments the access count for a token. func (db *DB) IncrementTokenAccessCount(token string) error { var query string if db.dbType == DBTypeSQLite { query = `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = ?` } else { query = `UPDATE share_tokens SET access_count = access_count + 1 WHERE token = $1` } _, err := db.conn.ExecContext(context.Background(), query, token) if err != nil { return fmt.Errorf("failed to increment token access count: %w", err) } return nil } // DeleteShareToken removes a share token from the database. 
func (db *DB) DeleteShareToken(token string) error { var query string if db.dbType == DBTypeSQLite { query = `DELETE FROM share_tokens WHERE token = ?` } else { query = `DELETE FROM share_tokens WHERE token = $1` } _, err := db.conn.ExecContext(context.Background(), query, token) if err != nil { return fmt.Errorf("failed to delete share token: %w", err) } return nil } // ListShareTokensForTask returns all share tokens for a specific task. func (db *DB) ListShareTokensForTask(taskID string) ([]*ShareToken, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses FROM share_tokens WHERE task_id = ? ORDER BY created_at DESC` } else { query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses FROM share_tokens WHERE task_id = $1 ORDER BY created_at DESC` } return db.queryShareTokens(query, taskID) } // ListShareTokensForExperiment returns all share tokens for a specific experiment. func (db *DB) ListShareTokensForExperiment(experimentID string) ([]*ShareToken, error) { var query string if db.dbType == DBTypeSQLite { query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses FROM share_tokens WHERE experiment_id = ? ORDER BY created_at DESC` } else { query = `SELECT token, task_id, experiment_id, created_by, created_at, expires_at, access_count, max_accesses FROM share_tokens WHERE experiment_id = $1 ORDER BY created_at DESC` } return db.queryShareTokens(query, experimentID) } // queryShareTokens is a helper to execute share token queries. 
func (db *DB) queryShareTokens(query string, arg string) ([]*ShareToken, error) { rows, err := db.conn.QueryContext(context.Background(), query, arg) if err != nil { return nil, fmt.Errorf("failed to list share tokens: %w", err) } defer func() { _ = rows.Close() }() var tokens []*ShareToken for rows.Next() { var t ShareToken var taskID, experimentID sql.NullString var expiresAt sql.NullTime var maxAccesses sql.NullInt64 err := rows.Scan( &t.Token, &taskID, &experimentID, &t.CreatedBy, &t.CreatedAt, &expiresAt, &t.AccessCount, &maxAccesses, ) if err != nil { return nil, fmt.Errorf("failed to scan share token: %w", err) } if taskID.Valid { t.TaskID = &taskID.String } if experimentID.Valid { t.ExperimentID = &experimentID.String } if expiresAt.Valid { t.ExpiresAt = &expiresAt.Time } if maxAccesses.Valid { m := int(maxAccesses.Int64) t.MaxAccesses = &m } tokens = append(tokens, &t) } if err = rows.Err(); err != nil { return nil, fmt.Errorf("error iterating share tokens: %w", err) } return tokens, nil } // ValidateShareToken checks if a token is valid for accessing a task or experiment. // Returns the token details if valid, nil if invalid or expired. func (db *DB) ValidateShareToken(token string, taskID *string, experimentID *string) (*ShareToken, error) { t, err := db.GetShareToken(token) if err != nil { return nil, err } if t == nil { return nil, nil } // Check if token is for the correct resource if taskID != nil && (t.TaskID == nil || *t.TaskID != *taskID) { return nil, nil } if experimentID != nil && (t.ExperimentID == nil || *t.ExperimentID != *experimentID) { return nil, nil } // Check expiry if t.ExpiresAt != nil && time.Now().After(*t.ExpiresAt) { return nil, nil } // Check max accesses if t.MaxAccesses != nil && t.AccessCount >= *t.MaxAccesses { return nil, nil } return t, nil }