package ws_test

import (
	"encoding/binary"
	"testing"

	ws "github.com/jfraeys/fetch_ml/internal/api/ws"
	"github.com/jfraeys/fetch_ml/internal/auth"
)

// payloadHeaderLen is the size of the fixed-length prefix that every payload
// built in this file starts with; opcode-specific fields begin at this offset.
// NOTE(review): TestHandler_Authenticate treats exactly this many bytes as a
// valid authentication payload, so the prefix is presumably the credential
// consumed by Handler.Authenticate — confirm against the handler.
const payloadHeaderLen = 16

// TestHandler_Authenticate expects Authenticate to accept a payload of
// exactly payloadHeaderLen bytes (returning a non-nil user) and to reject
// shorter or empty payloads with an error.
func TestHandler_Authenticate(t *testing.T) {
	h := &ws.Handler{}

	tests := []struct {
		name    string
		payload []byte
		wantErr bool
	}{
		{name: "valid payload", payload: make([]byte, payloadHeaderLen), wantErr: false},
		{name: "short payload", payload: make([]byte, 10), wantErr: true},
		{name: "empty payload", payload: []byte{}, wantErr: true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			user, err := h.Authenticate(tt.payload)
			if tt.wantErr {
				if err == nil {
					t.Errorf("Authenticate() expected error, got nil")
				}
				return
			}
			// Stop here on an unexpected error: a nil user would only be a
			// consequence of it and would add noise to the report.
			if err != nil {
				t.Fatalf("Authenticate() unexpected error: %v", err)
			}
			if user == nil {
				t.Errorf("Authenticate() expected user, got nil")
			}
		})
	}
}

// TestHandler_RequirePermission expects RequirePermission to deny a nil user,
// grant any permission to an admin, and otherwise grant only the permissions
// present (and true) in the user's Permissions map.
func TestHandler_RequirePermission(t *testing.T) {
	h := &ws.Handler{}

	tests := []struct {
		name       string
		user       *auth.User
		permission string
		want       bool
	}{
		{
			name:       "nil user",
			user:       nil,
			permission: "jobs:read",
			want:       false,
		},
		{
			name: "admin user",
			user: &auth.User{
				Name:  "admin",
				Admin: true,
			},
			permission: "any:permission",
			want:       true,
		},
		{
			name: "user with permission",
			user: &auth.User{
				Name:        "user",
				Admin:       false,
				Permissions: map[string]bool{"jobs:read": true},
			},
			permission: "jobs:read",
			want:       true,
		},
		{
			name: "user without permission",
			user: &auth.User{
				Name:        "user",
				Admin:       false,
				Permissions: map[string]bool{"jobs:read": true},
			},
			permission: "jobs:write",
			want:       false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := h.RequirePermission(tt.user, tt.permission); got != tt.want {
				t.Errorf("RequirePermission() = %v, want %v", got, tt.want)
			}
		})
	}
}

// parseLogMetricPayload decodes the layout produced by buildLogMetricPayload:
// payloadHeaderLen prefix bytes, a 1-byte name length, the name bytes, then an
// 8-byte big-endian value. ok is false when the payload is shorter than the
// minimum layout, the name length is zero, or the declared name length
// overruns the payload.
func parseLogMetricPayload(payload []byte) (name string, value uint64, ok bool) {
	if len(payload) < payloadHeaderLen+1+8 {
		return "", 0, false
	}
	offset := payloadHeaderLen
	nameLen := int(payload[offset])
	offset++
	if nameLen == 0 || len(payload) < offset+nameLen+8 {
		return "", 0, false
	}
	name = string(payload[offset : offset+nameLen])
	offset += nameLen
	return name, binary.BigEndian.Uint64(payload[offset : offset+8]), true
}

// TestHandler_handleLogMetric_PayloadParsing checks that well-formed
// log-metric payloads decode to the expected name/value and that truncated or
// inconsistent payloads are rejected. Every case asserts its outcome, so an
// unexpected early rejection of a valid payload fails the test.
func TestHandler_handleLogMetric_PayloadParsing(t *testing.T) {
	tests := []struct {
		name      string
		payload   []byte
		wantOK    bool
		wantName  string
		wantValue uint64
	}{
		{
			name:     "valid payload",
			payload:  buildLogMetricPayload("accuracy", 0.95),
			wantOK:   true,
			wantName: "accuracy",
			// buildLogMetricPayload stores uint64(value*100).
			wantValue: 95,
		},
		{
			name:    "payload too short",
			payload: make([]byte, 20),
			wantOK:  false,
		},
		{
			name:    "invalid name length",
			payload: buildLogMetricPayloadInvalid(),
			wantOK:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			name, value, ok := parseLogMetricPayload(tt.payload)
			if ok != tt.wantOK {
				t.Fatalf("parse ok = %v, want %v", ok, tt.wantOK)
			}
			if !ok {
				return
			}
			if name != tt.wantName {
				t.Errorf("name = %q, want %q", name, tt.wantName)
			}
			if value != tt.wantValue {
				t.Errorf("value = %d, want %d", value, tt.wantValue)
			}
		})
	}
}

// parseGetExperimentPayload decodes the layout produced by
// buildGetExperimentPayload: payloadHeaderLen prefix bytes, a 1-byte commit ID
// length, then the commit ID bytes. ok is false when the payload is too short,
// the length byte is zero, or the declared length overruns the payload.
func parseGetExperimentPayload(payload []byte) (commitID string, ok bool) {
	if len(payload) < payloadHeaderLen+1 {
		return "", false
	}
	offset := payloadHeaderLen
	idLen := int(payload[offset])
	offset++
	if idLen == 0 || len(payload) < offset+idLen {
		return "", false
	}
	return string(payload[offset : offset+idLen]), true
}

// TestHandler_handleGetExperiment_PayloadParsing checks that a well-formed
// get-experiment payload decodes to the expected commit ID and that truncated
// or inconsistent payloads are rejected.
func TestHandler_handleGetExperiment_PayloadParsing(t *testing.T) {
	tests := []struct {
		name         string
		payload      []byte
		wantOK       bool
		wantCommitID string
	}{
		{
			name:         "valid payload",
			payload:      buildGetExperimentPayload("abc123"),
			wantOK:       true,
			wantCommitID: "abc123",
		},
		{
			name:    "payload too short",
			payload: make([]byte, payloadHeaderLen),
			wantOK:  false,
		},
		{
			name:    "invalid commit ID length",
			payload: buildGetExperimentPayloadInvalid(),
			wantOK:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			commitID, ok := parseGetExperimentPayload(tt.payload)
			if ok != tt.wantOK {
				t.Fatalf("parse ok = %v, want %v", ok, tt.wantOK)
			}
			if ok && commitID != tt.wantCommitID {
				t.Errorf("commitID = %q, want %q", commitID, tt.wantCommitID)
			}
		})
	}
}

// TestHandler_handleStatusRequest checks that a header-only payload
// authenticates successfully.
// NOTE(review): despite its name, this test only calls Authenticate; the
// status-request handler itself is not exercised here.
func TestHandler_handleStatusRequest(t *testing.T) {
	h := &ws.Handler{}
	payload := make([]byte, payloadHeaderLen)

	user, err := h.Authenticate(payload)
	if err != nil {
		t.Fatalf("Authenticate() unexpected error: %v", err)
	}
	if user == nil {
		t.Error("Expected user, got nil")
	}
}

// buildLogMetricPayload returns a zeroed payloadHeaderLen-byte prefix followed
// by a 1-byte name length, the name, and uint64(value*100) as 8 big-endian
// bytes. It panics if name does not fit in a 1-byte length field.
// NOTE(review): the value is encoded as a scaled integer, not
// math.Float64bits(value); confirm this matches what the handler decodes.
func buildLogMetricPayload(name string, value float64) []byte {
	if len(name) > 0xFF {
		panic("buildLogMetricPayload: name longer than 255 bytes")
	}
	payload := make([]byte, payloadHeaderLen+1+len(name)+8)
	payload[payloadHeaderLen] = byte(len(name))
	n := copy(payload[payloadHeaderLen+1:], name)
	binary.BigEndian.PutUint64(payload[payloadHeaderLen+1+n:], uint64(value*100))
	return payload
}

// buildLogMetricPayloadInvalid returns a payload that is long enough for the
// minimum log-metric layout but declares a 200-byte name, which overruns the
// 8 bytes actually remaining after the length byte.
func buildLogMetricPayloadInvalid() []byte {
	payload := make([]byte, payloadHeaderLen+1+8)
	payload[payloadHeaderLen] = 200
	return payload
}

// buildGetExperimentPayload returns a zeroed payloadHeaderLen-byte prefix
// followed by a 1-byte commit ID length and the commit ID bytes. It panics if
// commitID does not fit in a 1-byte length field.
func buildGetExperimentPayload(commitID string) []byte {
	if len(commitID) > 0xFF {
		panic("buildGetExperimentPayload: commit ID longer than 255 bytes")
	}
	payload := make([]byte, payloadHeaderLen+1+len(commitID))
	payload[payloadHeaderLen] = byte(len(commitID))
	copy(payload[payloadHeaderLen+1:], commitID)
	return payload
}

// buildGetExperimentPayloadInvalid returns a payload whose length byte
// declares a 100-byte commit ID while only 5 bytes follow it.
func buildGetExperimentPayloadInvalid() []byte {
	payload := make([]byte, payloadHeaderLen+1+5)
	payload[payloadHeaderLen] = 100
	return payload
}