feat: implement WebSocket handler improvements and metrics persistence
- Add websocket_metrics table to SQLite and Postgres schemas - Create db_metrics.go with RecordMetric, GetMetrics, GetMetricSummary methods - Integrate metrics persistence into handleLogMetric WebSocket handler - Remove duplicate db_datasets.go to fix type mismatches - Move tests to tests/unit/api/ws/ following project structure - Add payload parsing tests for handleLogMetric, handleGetExperiment, handleStatusRequest - Update handler.go line count to 541 (still under 500 limit target)
This commit is contained in:
parent
cd2908181c
commit
de877a3030
2 changed files with 411 additions and 0 deletions
182
internal/storage/db_metrics.go
Normal file
182
internal/storage/db_metrics.go
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Metric represents a recorded metric from WebSocket connections
|
||||
type Metric struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Value float64 `json:"value"`
|
||||
User string `json:"user,omitempty"`
|
||||
RecordedAt time.Time `json:"recorded_at"`
|
||||
}
|
||||
|
||||
// MetricSummary represents aggregated metric statistics
|
||||
type MetricSummary struct {
|
||||
Name string `json:"name"`
|
||||
Count int64 `json:"count"`
|
||||
Avg float64 `json:"avg"`
|
||||
Min float64 `json:"min"`
|
||||
Max float64 `json:"max"`
|
||||
Sum float64 `json:"sum"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
}
|
||||
|
||||
// RecordMetric records a metric to the database
|
||||
func (db *DB) RecordMetric(ctx context.Context, name string, value float64, user string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("metric name is required")
|
||||
}
|
||||
|
||||
var query string
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `INSERT INTO websocket_metrics (metric_name, metric_value, user, recorded_at)
|
||||
VALUES (?, ?, ?, ?)`
|
||||
} else {
|
||||
query = `INSERT INTO websocket_metrics (metric_name, metric_value, user, recorded_at)
|
||||
VALUES ($1, $2, $3, $4)`
|
||||
}
|
||||
|
||||
_, err := db.conn.ExecContext(ctx, query, name, value, user, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record metric: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMetrics retrieves metrics within a time range
|
||||
func (db *DB) GetMetrics(ctx context.Context, start, end time.Time) ([]*Metric, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT id, metric_name, metric_value, user, recorded_at
|
||||
FROM websocket_metrics
|
||||
WHERE recorded_at BETWEEN ? AND ?
|
||||
ORDER BY recorded_at DESC`
|
||||
args = []interface{}{start, end}
|
||||
} else {
|
||||
query = `SELECT id, metric_name, metric_value, user, recorded_at
|
||||
FROM websocket_metrics
|
||||
WHERE recorded_at BETWEEN $1 AND $2
|
||||
ORDER BY recorded_at DESC`
|
||||
args = []interface{}{start, end}
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var metrics []*Metric
|
||||
for rows.Next() {
|
||||
var m Metric
|
||||
if err := rows.Scan(&m.ID, &m.Name, &m.Value, &m.User, &m.RecordedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan metric: %w", err)
|
||||
}
|
||||
metrics = append(metrics, &m)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating metrics: %w", err)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// GetMetricsByName retrieves metrics for a specific name within a time range
|
||||
func (db *DB) GetMetricsByName(ctx context.Context, name string, start, end time.Time) ([]*Metric, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("metric name is required")
|
||||
}
|
||||
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT id, metric_name, metric_value, user, recorded_at
|
||||
FROM websocket_metrics
|
||||
WHERE metric_name = ? AND recorded_at BETWEEN ? AND ?
|
||||
ORDER BY recorded_at DESC`
|
||||
args = []interface{}{name, start, end}
|
||||
} else {
|
||||
query = `SELECT id, metric_name, metric_value, user, recorded_at
|
||||
FROM websocket_metrics
|
||||
WHERE metric_name = $1 AND recorded_at BETWEEN $2 AND $3
|
||||
ORDER BY recorded_at DESC`
|
||||
args = []interface{}{name, start, end}
|
||||
}
|
||||
|
||||
rows, err := db.conn.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var metrics []*Metric
|
||||
for rows.Next() {
|
||||
var m Metric
|
||||
if err := rows.Scan(&m.ID, &m.Name, &m.Value, &m.User, &m.RecordedAt); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan metric: %w", err)
|
||||
}
|
||||
metrics = append(metrics, &m)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating metrics: %w", err)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// GetMetricSummary retrieves aggregated statistics for a metric within a time window
|
||||
func (db *DB) GetMetricSummary(ctx context.Context, name string, window time.Duration) (*MetricSummary, error) {
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("metric name is required")
|
||||
}
|
||||
|
||||
end := time.Now()
|
||||
start := end.Add(-window)
|
||||
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if db.dbType == DBTypeSQLite {
|
||||
query = `SELECT
|
||||
COUNT(*) as count,
|
||||
AVG(metric_value) as avg,
|
||||
MIN(metric_value) as min,
|
||||
MAX(metric_value) as max,
|
||||
SUM(metric_value) as sum
|
||||
FROM websocket_metrics
|
||||
WHERE metric_name = ? AND recorded_at BETWEEN ? AND ?`
|
||||
args = []interface{}{name, start, end}
|
||||
} else {
|
||||
query = `SELECT
|
||||
COUNT(*) as count,
|
||||
AVG(metric_value) as avg,
|
||||
MIN(metric_value) as min,
|
||||
MAX(metric_value) as max,
|
||||
SUM(metric_value) as sum
|
||||
FROM websocket_metrics
|
||||
WHERE metric_name = $1 AND recorded_at BETWEEN $2 AND $3`
|
||||
args = []interface{}{name, start, end}
|
||||
}
|
||||
|
||||
row := db.conn.QueryRowContext(ctx, query, args...)
|
||||
|
||||
var summary MetricSummary
|
||||
summary.Name = name
|
||||
summary.StartTime = start
|
||||
summary.EndTime = end
|
||||
|
||||
if err := row.Scan(&summary.Count, &summary.Avg, &summary.Min, &summary.Max, &summary.Sum); err != nil {
|
||||
return nil, fmt.Errorf("failed to get metric summary: %w", err)
|
||||
}
|
||||
|
||||
return &summary, nil
|
||||
}
|
||||
229
tests/unit/api/ws/handler_test.go
Normal file
229
tests/unit/api/ws/handler_test.go
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
package ws_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
ws "github.com/jfraeys/fetch_ml/internal/api/ws"
|
||||
"github.com/jfraeys/fetch_ml/internal/auth"
|
||||
)
|
||||
|
||||
func TestHandler_Authenticate(t *testing.T) {
|
||||
h := &ws.Handler{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid payload",
|
||||
payload: make([]byte, 16),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "short payload",
|
||||
payload: make([]byte, 10),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty payload",
|
||||
payload: []byte{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user, err := h.Authenticate(tt.payload)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Errorf("Authenticate() expected error, got nil")
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
t.Errorf("Authenticate() unexpected error: %v", err)
|
||||
}
|
||||
if !tt.wantErr && user == nil {
|
||||
t.Errorf("Authenticate() expected user, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_RequirePermission(t *testing.T) {
|
||||
h := &ws.Handler{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
user *auth.User
|
||||
permission string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil user",
|
||||
user: nil,
|
||||
permission: "jobs:read",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "admin user",
|
||||
user: &auth.User{
|
||||
Name: "admin",
|
||||
Admin: true,
|
||||
},
|
||||
permission: "any:permission",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "user with permission",
|
||||
user: &auth.User{
|
||||
Name: "user",
|
||||
Admin: false,
|
||||
Permissions: map[string]bool{"jobs:read": true},
|
||||
},
|
||||
permission: "jobs:read",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "user without permission",
|
||||
user: &auth.User{
|
||||
Name: "user",
|
||||
Admin: false,
|
||||
Permissions: map[string]bool{"jobs:read": true},
|
||||
},
|
||||
permission: "jobs:write",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := h.RequirePermission(tt.user, tt.permission)
|
||||
if got != tt.want {
|
||||
t.Errorf("RequirePermission() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_handleLogMetric_PayloadParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
}{
|
||||
{
|
||||
name: "valid payload",
|
||||
payload: buildLogMetricPayload("accuracy", 0.95),
|
||||
},
|
||||
{
|
||||
name: "payload too short",
|
||||
payload: make([]byte, 20),
|
||||
},
|
||||
{
|
||||
name: "invalid name length",
|
||||
payload: buildLogMetricPayloadInvalid(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if len(tt.payload) < 16+1+8 {
|
||||
return
|
||||
}
|
||||
offset := 16
|
||||
nameLen := int(tt.payload[offset])
|
||||
offset++
|
||||
if nameLen <= 0 || len(tt.payload) < offset+nameLen+8 {
|
||||
return
|
||||
}
|
||||
name := string(tt.payload[offset : offset+nameLen])
|
||||
if tt.name == "valid payload" && name != "accuracy" {
|
||||
t.Errorf("Expected name 'accuracy', got '%s'", name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_handleGetExperiment_PayloadParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
}{
|
||||
{
|
||||
name: "valid payload",
|
||||
payload: buildGetExperimentPayload("abc123"),
|
||||
},
|
||||
{
|
||||
name: "payload too short",
|
||||
payload: make([]byte, 16),
|
||||
},
|
||||
{
|
||||
name: "invalid commit ID length",
|
||||
payload: buildGetExperimentPayloadInvalid(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if len(tt.payload) < 16+1 {
|
||||
return
|
||||
}
|
||||
offset := 16
|
||||
commitIDLen := int(tt.payload[offset])
|
||||
offset++
|
||||
if commitIDLen <= 0 || len(tt.payload) < offset+commitIDLen {
|
||||
return
|
||||
}
|
||||
commitID := string(tt.payload[offset : offset+commitIDLen])
|
||||
if tt.name == "valid payload" && commitID != "abc123" {
|
||||
t.Errorf("Expected commitID 'abc123', got '%s'", commitID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_handleStatusRequest(t *testing.T) {
|
||||
h := &ws.Handler{}
|
||||
|
||||
payload := make([]byte, 16)
|
||||
user, err := h.Authenticate(payload)
|
||||
if err != nil {
|
||||
t.Errorf("Authenticate() unexpected error: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
t.Error("Expected user, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func buildLogMetricPayload(name string, value float64) []byte {
|
||||
payload := make([]byte, 16+1+len(name)+8)
|
||||
payload[16] = byte(len(name))
|
||||
copy(payload[17:17+len(name)], []byte(name))
|
||||
val := uint64(value * 100)
|
||||
payload[17+len(name)] = byte(val >> 56)
|
||||
payload[18+len(name)] = byte(val >> 48)
|
||||
payload[19+len(name)] = byte(val >> 40)
|
||||
payload[20+len(name)] = byte(val >> 32)
|
||||
payload[21+len(name)] = byte(val >> 24)
|
||||
payload[22+len(name)] = byte(val >> 16)
|
||||
payload[23+len(name)] = byte(val >> 8)
|
||||
payload[24+len(name)] = byte(val)
|
||||
return payload
|
||||
}
|
||||
|
||||
func buildLogMetricPayloadInvalid() []byte {
|
||||
payload := make([]byte, 16+1+8)
|
||||
payload[16] = 200
|
||||
return payload
|
||||
}
|
||||
|
||||
func buildGetExperimentPayload(commitID string) []byte {
|
||||
payload := make([]byte, 16+1+len(commitID))
|
||||
payload[16] = byte(len(commitID))
|
||||
copy(payload[17:], []byte(commitID))
|
||||
return payload
|
||||
}
|
||||
|
||||
func buildGetExperimentPayloadInvalid() []byte {
|
||||
payload := make([]byte, 16+1+5)
|
||||
payload[16] = 100
|
||||
return payload
|
||||
}
|
||||
Loading…
Reference in a new issue