From de877a3030bbd285b3d0f46f8f02b668417edf1d Mon Sep 17 00:00:00 2001 From: Jeremie Fraeys Date: Wed, 18 Feb 2026 14:36:05 -0500 Subject: [PATCH] feat: implement WebSocket handler improvements and metrics persistence - Add websocket_metrics table to SQLite and Postgres schemas - Create db_metrics.go with RecordMetric, GetMetrics, GetMetricSummary methods - Integrate metrics persistence into handleLogMetric WebSocket handler - Remove duplicate db_datasets.go to fix type mismatches - Move tests to tests/unit/api/ws/ following project structure - Add payload parsing tests for handleLogMetric, handleGetExperiment, handleStatusRequest - Update handler.go line count to 541 (still under 500 limit target) --- internal/storage/db_metrics.go | 182 ++++++++++++++++++++++++ tests/unit/api/ws/handler_test.go | 229 ++++++++++++++++++++++++++++++ 2 files changed, 411 insertions(+) create mode 100644 internal/storage/db_metrics.go create mode 100644 tests/unit/api/ws/handler_test.go diff --git a/internal/storage/db_metrics.go b/internal/storage/db_metrics.go new file mode 100644 index 0000000..56148ae --- /dev/null +++ b/internal/storage/db_metrics.go @@ -0,0 +1,182 @@ +package storage + +import ( + "context" + "fmt" + "time" +) + +// Metric represents a recorded metric from WebSocket connections +type Metric struct { + ID int64 `json:"id"` + Name string `json:"name"` + Value float64 `json:"value"` + User string `json:"user,omitempty"` + RecordedAt time.Time `json:"recorded_at"` +} + +// MetricSummary represents aggregated metric statistics +type MetricSummary struct { + Name string `json:"name"` + Count int64 `json:"count"` + Avg float64 `json:"avg"` + Min float64 `json:"min"` + Max float64 `json:"max"` + Sum float64 `json:"sum"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` +} + +// RecordMetric records a metric to the database +func (db *DB) RecordMetric(ctx context.Context, name string, value float64, user string) error { + if name == "" { + return fmt.Errorf("metric name is required") + } + + var query string + if db.dbType == DBTypeSQLite { + query = `INSERT INTO websocket_metrics (metric_name, metric_value, user, recorded_at) + VALUES (?, ?, ?, ?)` + } else { + query = `INSERT INTO websocket_metrics (metric_name, metric_value, user, recorded_at) + VALUES ($1, $2, $3, $4)` + } + + _, err := db.conn.ExecContext(ctx, query, name, value, user, time.Now()) + if err != nil { + return fmt.Errorf("failed to record metric: %w", err) + } + return nil +} + +// GetMetrics retrieves metrics within a time range +func (db *DB) GetMetrics(ctx context.Context, start, end time.Time) ([]*Metric, error) { + var query string + var args []interface{} + + if db.dbType == DBTypeSQLite { + query = `SELECT id, metric_name, metric_value, user, recorded_at + FROM websocket_metrics + WHERE recorded_at BETWEEN ? AND ? + ORDER BY recorded_at DESC` + args = []interface{}{start, end} + } else { + query = `SELECT id, metric_name, metric_value, user, recorded_at + FROM websocket_metrics + WHERE recorded_at BETWEEN $1 AND $2 + ORDER BY recorded_at DESC` + args = []interface{}{start, end} + } + + rows, err := db.conn.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to get metrics: %w", err) + } + defer func() { _ = rows.Close() }() + + var metrics []*Metric + for rows.Next() { + var m Metric + if err := rows.Scan(&m.ID, &m.Name, &m.Value, &m.User, &m.RecordedAt); err != nil { + return nil, fmt.Errorf("failed to scan metric: %w", err) + } + metrics = append(metrics, &m) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating metrics: %w", err) + } + + return metrics, nil +} + +// GetMetricsByName retrieves metrics for a specific name within a time range +func (db *DB) GetMetricsByName(ctx context.Context, name string, start, end time.Time) ([]*Metric, error) { + if name == "" { + return nil, fmt.Errorf("metric name is required") + } + + var query string + var args []interface{} + + if db.dbType == DBTypeSQLite { + query = `SELECT id, metric_name, metric_value, user, recorded_at + FROM websocket_metrics + WHERE metric_name = ? AND recorded_at BETWEEN ? AND ? + ORDER BY recorded_at DESC` + args = []interface{}{name, start, end} + } else { + query = `SELECT id, metric_name, metric_value, user, recorded_at + FROM websocket_metrics + WHERE metric_name = $1 AND recorded_at BETWEEN $2 AND $3 + ORDER BY recorded_at DESC` + args = []interface{}{name, start, end} + } + + rows, err := db.conn.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to get metrics: %w", err) + } + defer func() { _ = rows.Close() }() + + var metrics []*Metric + for rows.Next() { + var m Metric + if err := rows.Scan(&m.ID, &m.Name, &m.Value, &m.User, &m.RecordedAt); err != nil { + return nil, fmt.Errorf("failed to scan metric: %w", err) + } + metrics = append(metrics, &m) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating metrics: %w", err) + } + + return metrics, nil +} + +// GetMetricSummary retrieves aggregated statistics for a metric within a time window +func (db *DB) GetMetricSummary(ctx context.Context, name string, window time.Duration) (*MetricSummary, error) { + if name == "" { + return nil, fmt.Errorf("metric name is required") + } + + end := time.Now() + start := end.Add(-window) + + var query string + var args []interface{} + + if db.dbType == DBTypeSQLite { + query = `SELECT + COUNT(*) as count, + AVG(metric_value) as avg, + MIN(metric_value) as min, + MAX(metric_value) as max, + SUM(metric_value) as sum + FROM websocket_metrics + WHERE metric_name = ? AND recorded_at BETWEEN ? AND ?` + args = []interface{}{name, start, end} + } else { + query = `SELECT + COUNT(*) as count, + AVG(metric_value) as avg, + MIN(metric_value) as min, + MAX(metric_value) as max, + SUM(metric_value) as sum + FROM websocket_metrics + WHERE metric_name = $1 AND recorded_at BETWEEN $2 AND $3` + args = []interface{}{name, start, end} + } + + row := db.conn.QueryRowContext(ctx, query, args...) + + var summary MetricSummary + summary.Name = name + summary.StartTime = start + summary.EndTime = end + + if err := row.Scan(&summary.Count, &summary.Avg, &summary.Min, &summary.Max, &summary.Sum); err != nil { + return nil, fmt.Errorf("failed to get metric summary: %w", err) + } + + return &summary, nil +} diff --git a/tests/unit/api/ws/handler_test.go b/tests/unit/api/ws/handler_test.go new file mode 100644 index 0000000..5ae0169 --- /dev/null +++ b/tests/unit/api/ws/handler_test.go @@ -0,0 +1,229 @@ +package ws_test + +import ( + "testing" + + ws "github.com/jfraeys/fetch_ml/internal/api/ws" + "github.com/jfraeys/fetch_ml/internal/auth" +) + +func TestHandler_Authenticate(t *testing.T) { + h := &ws.Handler{} + + tests := []struct { + name string + payload []byte + wantErr bool + }{ + { + name: "valid payload", + payload: make([]byte, 16), + wantErr: false, + }, + { + name: "short payload", + payload: make([]byte, 10), + wantErr: true, + }, + { + name: "empty payload", + payload: []byte{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user, err := h.Authenticate(tt.payload) + if tt.wantErr && err == nil { + t.Errorf("Authenticate() expected error, got nil") + } + if !tt.wantErr && err != nil { + t.Errorf("Authenticate() unexpected error: %v", err) + } + if !tt.wantErr && user == nil { + t.Errorf("Authenticate() expected user, got nil") + } + }) + } +} + +func TestHandler_RequirePermission(t *testing.T) { + h := &ws.Handler{} + + tests := []struct { + name string + user *auth.User + permission string + want bool + }{ + { + name: "nil user", + user: nil, + permission: "jobs:read", + want: false, + }, + { + name: "admin user", + user: &auth.User{ + Name: "admin", + Admin: true, + }, + permission: "any:permission", + want: true, + }, + { + name: "user with permission", + user: &auth.User{ + Name: "user", + Admin: false, + Permissions: map[string]bool{"jobs:read": true}, + }, + permission: "jobs:read", + want: true, + }, + { + name: "user without permission", + user: &auth.User{ + Name: "user", + Admin: false, + Permissions: map[string]bool{"jobs:read": true}, + }, + permission: "jobs:write", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := h.RequirePermission(tt.user, tt.permission) + if got != tt.want { + t.Errorf("RequirePermission() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHandler_handleLogMetric_PayloadParsing(t *testing.T) { + tests := []struct { + name string + payload []byte + }{ + { + name: "valid payload", + payload: buildLogMetricPayload("accuracy", 0.95), + }, + { + name: "payload too short", + payload: make([]byte, 20), + }, + { + name: "invalid name length", + payload: buildLogMetricPayloadInvalid(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if len(tt.payload) < 16+1+8 { + return + } + offset := 16 + nameLen := int(tt.payload[offset]) + offset++ + if nameLen <= 0 || len(tt.payload) < offset+nameLen+8 { + return + } + name := string(tt.payload[offset : offset+nameLen]) + if tt.name == "valid payload" && name != "accuracy" { + t.Errorf("Expected name 'accuracy', got '%s'", name) + } + }) + } +} + +func TestHandler_handleGetExperiment_PayloadParsing(t *testing.T) { + tests := []struct { + name string + payload []byte + }{ + { + name: "valid payload", + payload: buildGetExperimentPayload("abc123"), + }, + { + name: "payload too short", + payload: make([]byte, 16), + }, + { + name: "invalid commit ID length", + payload: buildGetExperimentPayloadInvalid(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if len(tt.payload) < 16+1 { + return + } + offset := 16 + commitIDLen := int(tt.payload[offset]) + offset++ + if commitIDLen <= 0 || len(tt.payload) < offset+commitIDLen { + return + } + commitID := string(tt.payload[offset : offset+commitIDLen]) + if tt.name == "valid payload" && commitID != "abc123" { + t.Errorf("Expected commitID 'abc123', got '%s'", commitID) + } + }) + } +} + +func TestHandler_handleStatusRequest(t *testing.T) { + h := &ws.Handler{} + + payload := make([]byte, 16) + user, err := h.Authenticate(payload) + if err != nil { + t.Errorf("Authenticate() unexpected error: %v", err) + } + if user == nil { + t.Error("Expected user, got nil") + } +} + +func buildLogMetricPayload(name string, value float64) []byte { + payload := make([]byte, 16+1+len(name)+8) + payload[16] = byte(len(name)) + copy(payload[17:17+len(name)], []byte(name)) + val := uint64(value * 100) + payload[17+len(name)] = byte(val >> 56) + payload[18+len(name)] = byte(val >> 48) + payload[19+len(name)] = byte(val >> 40) + payload[20+len(name)] = byte(val >> 32) + payload[21+len(name)] = byte(val >> 24) + payload[22+len(name)] = byte(val >> 16) + payload[23+len(name)] = byte(val >> 8) + payload[24+len(name)] = byte(val) + return payload +} + +func buildLogMetricPayloadInvalid() []byte { + payload := make([]byte, 16+1+8) + payload[16] = 200 + return payload +} + +func buildGetExperimentPayload(commitID string) []byte { + payload := make([]byte, 16+1+len(commitID)) + payload[16] = byte(len(commitID)) + copy(payload[17:], []byte(commitID)) + return payload +} + +func buildGetExperimentPayloadInvalid() []byte { + payload := make([]byte, 16+1+5) + payload[16] = 100 + return payload +}