- Add websocket_metrics table to SQLite and Postgres schemas
- Create db_metrics.go with RecordMetric, GetMetrics, GetMetricSummary methods
- Integrate metrics persistence into handleLogMetric WebSocket handler
- Remove duplicate db_datasets.go to fix type mismatches
- Move tests to tests/unit/api/ws/ following project structure
- Add payload parsing tests for handleLogMetric, handleGetExperiment, handleStatusRequest
- Update handler.go line count to 541 (still over the 500-line target)
229 lines
4.9 KiB
Go
229 lines
4.9 KiB
Go
package ws_test
|
|
|
|
import (
|
|
"testing"
|
|
|
|
ws "github.com/jfraeys/fetch_ml/internal/api/ws"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
)
|
|
|
|
func TestHandler_Authenticate(t *testing.T) {
|
|
h := &ws.Handler{}
|
|
|
|
tests := []struct {
|
|
name string
|
|
payload []byte
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "valid payload",
|
|
payload: make([]byte, 16),
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "short payload",
|
|
payload: make([]byte, 10),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "empty payload",
|
|
payload: []byte{},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
user, err := h.Authenticate(tt.payload)
|
|
if tt.wantErr && err == nil {
|
|
t.Errorf("Authenticate() expected error, got nil")
|
|
}
|
|
if !tt.wantErr && err != nil {
|
|
t.Errorf("Authenticate() unexpected error: %v", err)
|
|
}
|
|
if !tt.wantErr && user == nil {
|
|
t.Errorf("Authenticate() expected user, got nil")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandler_RequirePermission(t *testing.T) {
|
|
h := &ws.Handler{}
|
|
|
|
tests := []struct {
|
|
name string
|
|
user *auth.User
|
|
permission string
|
|
want bool
|
|
}{
|
|
{
|
|
name: "nil user",
|
|
user: nil,
|
|
permission: "jobs:read",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "admin user",
|
|
user: &auth.User{
|
|
Name: "admin",
|
|
Admin: true,
|
|
},
|
|
permission: "any:permission",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "user with permission",
|
|
user: &auth.User{
|
|
Name: "user",
|
|
Admin: false,
|
|
Permissions: map[string]bool{"jobs:read": true},
|
|
},
|
|
permission: "jobs:read",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "user without permission",
|
|
user: &auth.User{
|
|
Name: "user",
|
|
Admin: false,
|
|
Permissions: map[string]bool{"jobs:read": true},
|
|
},
|
|
permission: "jobs:write",
|
|
want: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := h.RequirePermission(tt.user, tt.permission)
|
|
if got != tt.want {
|
|
t.Errorf("RequirePermission() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandler_handleLogMetric_PayloadParsing(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
payload []byte
|
|
}{
|
|
{
|
|
name: "valid payload",
|
|
payload: buildLogMetricPayload("accuracy", 0.95),
|
|
},
|
|
{
|
|
name: "payload too short",
|
|
payload: make([]byte, 20),
|
|
},
|
|
{
|
|
name: "invalid name length",
|
|
payload: buildLogMetricPayloadInvalid(),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if len(tt.payload) < 16+1+8 {
|
|
return
|
|
}
|
|
offset := 16
|
|
nameLen := int(tt.payload[offset])
|
|
offset++
|
|
if nameLen <= 0 || len(tt.payload) < offset+nameLen+8 {
|
|
return
|
|
}
|
|
name := string(tt.payload[offset : offset+nameLen])
|
|
if tt.name == "valid payload" && name != "accuracy" {
|
|
t.Errorf("Expected name 'accuracy', got '%s'", name)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandler_handleGetExperiment_PayloadParsing(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
payload []byte
|
|
}{
|
|
{
|
|
name: "valid payload",
|
|
payload: buildGetExperimentPayload("abc123"),
|
|
},
|
|
{
|
|
name: "payload too short",
|
|
payload: make([]byte, 16),
|
|
},
|
|
{
|
|
name: "invalid commit ID length",
|
|
payload: buildGetExperimentPayloadInvalid(),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if len(tt.payload) < 16+1 {
|
|
return
|
|
}
|
|
offset := 16
|
|
commitIDLen := int(tt.payload[offset])
|
|
offset++
|
|
if commitIDLen <= 0 || len(tt.payload) < offset+commitIDLen {
|
|
return
|
|
}
|
|
commitID := string(tt.payload[offset : offset+commitIDLen])
|
|
if tt.name == "valid payload" && commitID != "abc123" {
|
|
t.Errorf("Expected commitID 'abc123', got '%s'", commitID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandler_handleStatusRequest(t *testing.T) {
|
|
h := &ws.Handler{}
|
|
|
|
payload := make([]byte, 16)
|
|
user, err := h.Authenticate(payload)
|
|
if err != nil {
|
|
t.Errorf("Authenticate() unexpected error: %v", err)
|
|
}
|
|
if user == nil {
|
|
t.Error("Expected user, got nil")
|
|
}
|
|
}
|
|
|
|
// buildLogMetricPayload encodes a log-metric request for the parsing tests.
//
// Layout written here:
//
//	[0:16)        zero-filled prefix (presumably the slot Authenticate reads — confirm)
//	[16]          name length, one byte (so names are limited to 255 bytes)
//	[17:17+n)     metric name
//	[17+n:25+n)   value*100 rounded to the nearest integer, big-endian
//
// NOTE(review): the handler's value decoding is not visible from this file;
// confirm it expects this fixed-point form rather than math.Float64bits.
//
// It panics if name is longer than 255 bytes, because the length byte would
// otherwise wrap silently and produce a corrupt payload.
func buildLogMetricPayload(name string, value float64) []byte {
	const prefixLen = 16
	if len(name) > 255 {
		panic("buildLogMetricPayload: metric name longer than 255 bytes")
	}

	payload := make([]byte, prefixLen+1+len(name)+8)
	payload[prefixLen] = byte(len(name))
	off := prefixLen + 1
	off += copy(payload[off:], name)

	// Round instead of truncating: 0.29*100 is 28.999999999999996 in float64
	// and would otherwise encode as 28.
	scaled := value * 100
	if scaled < 0 {
		scaled -= 0.5
	} else {
		scaled += 0.5
	}
	// Convert via int64 so negative values become well-defined two's
	// complement; a negative float converted straight to uint64 is
	// implementation-dependent in Go.
	fixed := uint64(int64(scaled))

	// Big-endian: fill from the least significant byte at the end.
	for i := 7; i >= 0; i-- {
		payload[off+i] = byte(fixed)
		fixed >>= 8
	}
	return payload
}
|
|
|
|
// buildLogMetricPayloadInvalid returns a minimum-size log-metric payload
// (16-byte prefix, length byte, 8 value bytes) whose name-length byte claims
// 200 bytes — far more than the buffer actually holds.
func buildLogMetricPayloadInvalid() []byte {
	const (
		size           = 16 + 1 + 8
		lengthAt       = 16
		claimedNameLen = 200
	)
	buf := make([]byte, size)
	buf[lengthAt] = claimedNameLen
	return buf
}
|
|
|
|
// buildGetExperimentPayload encodes a get-experiment request for the parsing
// tests: a 16-byte zero-filled prefix, a 1-byte commit-ID length, then the
// commit ID bytes.
//
// It panics if commitID is longer than 255 bytes, because the single length
// byte would otherwise wrap and describe the wrong length.
func buildGetExperimentPayload(commitID string) []byte {
	if len(commitID) > 255 {
		panic("buildGetExperimentPayload: commit ID longer than 255 bytes")
	}
	payload := make([]byte, 16+1+len(commitID))
	payload[16] = byte(len(commitID))
	copy(payload[17:], commitID)
	return payload
}
|
|
|
|
// buildGetExperimentPayloadInvalid returns a 22-byte get-experiment payload
// whose commit-ID length byte claims 100 bytes while only 5 bytes follow it.
func buildGetExperimentPayloadInvalid() []byte {
	const (
		header    = 16 + 1 // prefix plus the length byte
		trailing  = 5
		claimedID = 100
	)
	out := make([]byte, header+trailing)
	out[header-1] = claimedID
	return out
}
|