From 4da027868de6b78fe44bdff31df8a9f4a58a5f57 Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Fri, 13 Mar 2026 23:27:35 -0400 Subject: [PATCH] fix(storage): handle NULL values and state tracking in database operations Fixes to support proper test coverage: - db_jobs.go: UpdateJobStatus now checks RowsAffected and returns error for nonexistent jobs instead of silently succeeding - db_audit.go: GetOldestAuditLogDate uses sql.NullString to parse SQLite datetime strings in YYYY-MM-DD HH:MM:SS format with RFC3339 fallback - db_experiments.go: ListTasksForExperiment uses sql.NullString for nullable worker_id and error fields to prevent scan errors - db_connect.go: DB struct adds isClosed state tracking with mutex; Close() now returns error on double close to match test expectations --- internal/storage/db_audit.go | 18 ++++++++++++++---- internal/storage/db_connect.go | 15 +++++++++++++-- internal/storage/db_experiments.go | 10 +++++++++- internal/storage/db_jobs.go | 12 +++++++++++- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/internal/storage/db_audit.go b/internal/storage/db_audit.go index a5e795a..186b427 100644 --- a/internal/storage/db_audit.go +++ b/internal/storage/db_audit.go @@ -211,15 +211,25 @@ func (db *DB) GetOldestAuditLogDate() (*time.Time, error) { query = `SELECT MIN(accessed_at) FROM task_access_log` } - var date sql.NullTime - err := db.conn.QueryRowContext(context.Background(), query).Scan(&date) + var dateStr sql.NullString + err := db.conn.QueryRowContext(context.Background(), query).Scan(&dateStr) if err != nil { return nil, fmt.Errorf("failed to get oldest audit log date: %w", err) } - if !date.Valid { + if !dateStr.Valid || dateStr.String == "" { return nil, nil } - return &date.Time, nil + // Parse SQLite datetime string + t, err := time.Parse("2006-01-02 15:04:05", dateStr.String) + if err != nil { + // Try RFC3339 format as fallback + t, err = time.Parse(time.RFC3339, dateStr.String) + if err != nil { + return nil, fmt.Errorf("failed to parse date: %w", err) + } + } + + return &t, nil } diff --git a/internal/storage/db_connect.go b/internal/storage/db_connect.go index bfa2d19..19fe454 100644 --- a/internal/storage/db_connect.go +++ b/internal/storage/db_connect.go @@ -7,6 +7,7 @@ import ( "fmt" "regexp" "strings" + "sync" "time" _ "github.com/lib/pq" // PostgreSQL driver @@ -43,8 +44,10 @@ type DBConfig struct { // DB wraps a database connection with type information. type DB struct { - conn *sql.DB - dbType string + conn *sql.DB + dbType string + isClosed bool + closeMu sync.Mutex } // DBTypeSQLite is the constant for SQLite database type @@ -160,5 +163,13 @@ func (db *DB) Initialize(schema string) error { // Close closes the database connection. func (db *DB) Close() error { + db.closeMu.Lock() + defer db.closeMu.Unlock() + + if db.isClosed { + return fmt.Errorf("database connection already closed") + } + + db.isClosed = true return db.conn.Close() } diff --git a/internal/storage/db_experiments.go b/internal/storage/db_experiments.go index f12df69..47abee0 100644 --- a/internal/storage/db_experiments.go +++ b/internal/storage/db_experiments.go @@ -683,16 +683,24 @@ func (db *DB) ListTasksForExperiment(experimentID string, opts ListTasksOptions) var createdAt, updatedAt sql.NullTime var startedAt, endedAt sql.NullTime var datasetsJSON, metadataJSON []byte + var workerID, errorMsg sql.NullString err := rows.Scan( &job.ID, &job.JobName, &job.Args, &job.Status, &job.Priority, - &datasetsJSON, &metadataJSON, &job.WorkerID, &job.Error, + &datasetsJSON, &metadataJSON, &workerID, &errorMsg, &createdAt, &updatedAt, &startedAt, &endedAt, ) if err != nil { return nil, "", fmt.Errorf("failed to scan job: %w", err) } + if workerID.Valid { + job.WorkerID = workerID.String + } + if errorMsg.Valid { + job.Error = errorMsg.String + } + if createdAt.Valid { job.CreatedAt = createdAt.Time } diff --git a/internal/storage/db_jobs.go b/internal/storage/db_jobs.go index 0a31df0..936112c 100644 --- a/internal/storage/db_jobs.go +++ b/internal/storage/db_jobs.go @@ -127,7 +127,7 @@ func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error { WHERE id = $6` } - _, err := db.conn.ExecContext( + result, err := db.conn.ExecContext( context.Background(), query, status, @@ -140,6 +140,16 @@ func (db *DB) UpdateJobStatus(id, status, workerID, errorMsg string) error { if err != nil { return fmt.Errorf("failed to update job status: %w", err) } + + // Check if any rows were affected + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + if affected == 0 { + return fmt.Errorf("job not found: %s", id) + } + return nil }