fetch_ml/internal/api/ws/handler_test.go
Jeremie Fraeys a4e2ecdbe6
refactor: co-locate api, audit, auth tests with source code
Move unit tests from tests/unit/ to internal/ following Go conventions:
- tests/unit/api/* -> internal/api/* (WebSocket handlers, helpers, duplicate detection)
- tests/unit/audit/* -> internal/audit/* (alert, sealed, verifier tests)
- tests/unit/auth/* -> internal/auth/* (API key, keychain, user manager)
- tests/unit/crypto/kms/* -> internal/auth/kms/* (cache, protocol tests)

Update import paths in test files to reflect new locations.

Benefits:
- Tests live alongside the code they test
- Easier navigation and maintenance
- Clearer package boundaries
- Follows standard Go project layout
2026-03-12 16:34:54 -04:00

229 lines
4.9 KiB
Go

package ws_test
import (
"testing"
ws "github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
)
// TestHandler_Authenticate drives Handler.Authenticate with a 16-byte payload,
// which is expected to succeed and yield a non-nil user, and with a 10-byte
// and an empty payload, which are both expected to be rejected with an error.
func TestHandler_Authenticate(t *testing.T) {
	handler := &ws.Handler{}

	cases := []struct {
		name    string
		payload []byte
		wantErr bool
	}{
		{name: "valid payload", payload: make([]byte, 16), wantErr: false},
		{name: "short payload", payload: make([]byte, 10), wantErr: true},
		{name: "empty payload", payload: []byte{}, wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			user, err := handler.Authenticate(tc.payload)

			if tc.wantErr {
				// For rejected payloads only the presence of an error is checked.
				if err == nil {
					t.Errorf("Authenticate() expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Errorf("Authenticate() unexpected error: %v", err)
			}
			if user == nil {
				t.Errorf("Authenticate() expected user, got nil")
			}
		})
	}
}
// TestHandler_RequirePermission checks Handler.RequirePermission against four
// cases:
//   - a nil user is denied;
//   - an admin user is granted an arbitrary permission;
//   - a non-admin user is granted only permissions present (and true) in its
//     Permissions map.
func TestHandler_RequirePermission(t *testing.T) {
	handler := &ws.Handler{}

	cases := []struct {
		name       string
		user       *auth.User
		permission string
		want       bool
	}{
		{
			name:       "nil user",
			user:       nil,
			permission: "jobs:read",
			want:       false,
		},
		{
			name:       "admin user",
			user:       &auth.User{Name: "admin", Admin: true},
			permission: "any:permission",
			want:       true,
		},
		{
			name:       "user with permission",
			user:       &auth.User{Name: "user", Admin: false, Permissions: map[string]bool{"jobs:read": true}},
			permission: "jobs:read",
			want:       true,
		},
		{
			name:       "user without permission",
			user:       &auth.User{Name: "user", Admin: false, Permissions: map[string]bool{"jobs:read": true}},
			permission: "jobs:write",
			want:       false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := handler.RequirePermission(tc.user, tc.permission); got != tc.want {
				t.Errorf("RequirePermission() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestHandler_handleLogMetric_PayloadParsing checks the log-metric payload
// layout used by buildLogMetricPayload. The layout is:
//   - a 16-byte prefix;
//   - a 1-byte name length;
//   - the name bytes;
//   - an 8-byte value.
//
// Each case states whether the payload must pass the length checks (wantOK).
// For accepted payloads the test also checks which name is decoded. Rejected
// cases therefore fail loudly if they are accepted, and the valid case fails
// if it is rejected, instead of silently returning.
//
// NOTE(review): this test decodes the payload locally rather than calling
// the handler, so it validates the test builders against the expected layout,
// not the handler's own parser.
func TestHandler_handleLogMetric_PayloadParsing(t *testing.T) {
	const (
		prefixLen = 16
		valueLen  = 8
	)

	// parse applies the same bounds checks the handler-side layout implies and
	// returns the metric name plus whether the payload is well-formed.
	parse := func(p []byte) (string, bool) {
		if len(p) < prefixLen+1+valueLen {
			return "", false
		}
		nameLen := int(p[prefixLen])
		start := prefixLen + 1
		if nameLen == 0 || len(p) < start+nameLen+valueLen {
			return "", false
		}
		return string(p[start : start+nameLen]), true
	}

	tests := []struct {
		name     string
		payload  []byte
		wantOK   bool
		wantName string
	}{
		{name: "valid payload", payload: buildLogMetricPayload("accuracy", 0.95), wantOK: true, wantName: "accuracy"},
		{name: "payload too short", payload: make([]byte, 20), wantOK: false},
		{name: "invalid name length", payload: buildLogMetricPayloadInvalid(), wantOK: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			name, ok := parse(tt.payload)
			if ok != tt.wantOK {
				t.Fatalf("payload accepted = %v, want %v", ok, tt.wantOK)
			}
			if ok && name != tt.wantName {
				t.Errorf("name = %q, want %q", name, tt.wantName)
			}
		})
	}
}
// TestHandler_handleGetExperiment_PayloadParsing checks the get-experiment
// payload layout used by buildGetExperimentPayload: a 16-byte prefix, a 1-byte
// commit-ID length, then the commit-ID bytes.
//
// Each case states whether the payload must pass the length checks (wantOK),
// so malformed payloads that are accepted, or a valid payload that is rejected,
// fail the test instead of returning silently.
//
// NOTE(review): this test decodes the payload locally rather than calling
// the handler, so it validates the test builders against the expected layout,
// not the handler's own parser.
func TestHandler_handleGetExperiment_PayloadParsing(t *testing.T) {
	const prefixLen = 16

	// parse returns the commit ID and whether the payload is long enough to
	// contain the length byte plus the number of ID bytes it announces.
	parse := func(p []byte) (string, bool) {
		if len(p) < prefixLen+1 {
			return "", false
		}
		idLen := int(p[prefixLen])
		start := prefixLen + 1
		if idLen == 0 || len(p) < start+idLen {
			return "", false
		}
		return string(p[start : start+idLen]), true
	}

	tests := []struct {
		name         string
		payload      []byte
		wantOK       bool
		wantCommitID string
	}{
		{name: "valid payload", payload: buildGetExperimentPayload("abc123"), wantOK: true, wantCommitID: "abc123"},
		{name: "payload too short", payload: make([]byte, 16), wantOK: false},
		{name: "invalid commit ID length", payload: buildGetExperimentPayloadInvalid(), wantOK: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			commitID, ok := parse(tt.payload)
			if ok != tt.wantOK {
				t.Fatalf("payload accepted = %v, want %v", ok, tt.wantOK)
			}
			if ok && commitID != tt.wantCommitID {
				t.Errorf("commitID = %q, want %q", commitID, tt.wantCommitID)
			}
		})
	}
}
// TestHandler_handleStatusRequest authenticates a zeroed 16-byte payload and
// expects a non-nil user with no error.
//
// NOTE(review): despite its name, this test only calls Authenticate; it does
// not exercise any status-request handling. The name is kept so existing
// `go test -run` filters keep matching.
func TestHandler_handleStatusRequest(t *testing.T) {
	h := &ws.Handler{}
	payload := make([]byte, 16)

	user, err := h.Authenticate(payload)
	if err != nil {
		// The returned user is not meaningful once Authenticate fails, so stop
		// here rather than report a second, derivative failure.
		t.Fatalf("Authenticate() unexpected error: %v", err)
	}
	if user == nil {
		t.Error("Expected user, got nil")
	}
}
// buildLogMetricPayload assembles a log-metric test payload laid out as:
//   - 16 zero bytes;
//   - one length byte, byte(len(name)), which wraps for names over 255 bytes;
//   - the raw name bytes;
//   - uint64(value*100), written most-significant byte first.
//
// The last field is a two-decimal fixed-point value whose fractional remainder
// is discarded by the conversion.
func buildLogMetricPayload(name string, value float64) []byte {
	const prefix = 16
	buf := make([]byte, prefix+1+len(name)+8)
	buf[prefix] = byte(len(name))
	pos := prefix + 1 + copy(buf[prefix+1:], name)

	// Emit the fixed-point value in big-endian order.
	fixed := uint64(value * 100)
	for shift := 56; shift >= 0; shift -= 8 {
		buf[pos] = byte(fixed >> shift)
		pos++
	}
	return buf
}
// buildLogMetricPayloadInvalid returns a 25-byte, otherwise zeroed payload
// whose name-length byte claims a 200-byte name that is not actually present.
func buildLogMetricPayloadInvalid() []byte {
	const (
		prefix     = 16
		valueLen   = 8
		claimedLen = 200
	)
	out := make([]byte, prefix+1+valueLen)
	out[prefix] = claimedLen
	return out
}
// buildGetExperimentPayload assembles a get-experiment test payload: 16 zero
// bytes, one length byte (byte(len(commitID)), which wraps for IDs longer than
// 255 bytes), then the raw commit-ID bytes.
func buildGetExperimentPayload(commitID string) []byte {
	const prefix = 16
	out := make([]byte, prefix+1, prefix+1+len(commitID))
	out[prefix] = byte(len(commitID))
	return append(out, commitID...)
}
// buildGetExperimentPayloadInvalid returns a 22-byte, otherwise zeroed
// payload. Its length byte claims a 100-byte commit ID, although only 5 bytes
// follow it.
func buildGetExperimentPayloadInvalid() []byte {
	const (
		prefix     = 16
		trailing   = 5
		claimedLen = 100
	)
	out := make([]byte, prefix+1+trailing)
	out[prefix] = claimedLen
	return out
}