// fetch_ml/tests/unit/api/ws_jupyter_test.go

package api_test
import (
"bytes"
"encoding/binary"
"testing"
)
// TestBuildStartJupyterPayload checks the total length of start payloads and
// their leading API key hash, name-length prefix, and name bytes.
//
// Expected layout:
//
//	[hash:16][nameLen:1][name][workspaceLen:2][workspace][passwordLen:1][password]
func TestBuildStartJupyterPayload(t *testing.T) {
	tests := []struct {
		name       string
		apiKeyHash []byte
		svcName    string
		workspace  string
		password   string
		wantLen    int
	}{
		{
			name:       "valid payload",
			apiKeyHash: make([]byte, 16),
			svcName:    "test-service",
			workspace:  "/tmp/workspace",
			password:   "mypass",
			wantLen:    16 + 1 + 12 + 2 + 14 + 1 + 6,
		},
		{
			name:       "empty password",
			apiKeyHash: make([]byte, 16),
			svcName:    "test",
			workspace:  "/tmp",
			password:   "",
			wantLen:    16 + 1 + 4 + 2 + 4 + 1 + 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			payload := buildStartJupyterPayload(t, &startJupyterParams{
				apiKeyHash: tt.apiKeyHash,
				name:       tt.svcName,
				workspace:  tt.workspace,
				password:   tt.password,
			})
			// Fatal rather than Error: every check below indexes into the
			// payload and would panic on a shorter-than-expected one.
			if len(payload) != tt.wantLen {
				t.Fatalf("expected payload length %d, got %d", tt.wantLen, len(payload))
			}
			if !bytes.Equal(payload[0:16], tt.apiKeyHash) {
				t.Error("API key hash mismatch")
			}
			// The name slice below is sized by this prefix, so a wrong
			// prefix must stop the subtest before it is used.
			nameLen := int(payload[16])
			if nameLen != len(tt.svcName) {
				t.Fatalf("expected name length %d, got %d", len(tt.svcName), nameLen)
			}
			if got := string(payload[17 : 17+nameLen]); got != tt.svcName {
				t.Errorf("expected name %q, got %q", tt.svcName, got)
			}
		})
	}
}
// TestBuildStopJupyterPayload checks the stop-request layout:
//
//	[hash:16][serviceIDLen:1][serviceID]
func TestBuildStopJupyterPayload(t *testing.T) {
	// A distinctive hash (not all zeros) so a builder that writes zero
	// padding instead of the real hash is caught.
	apiKeyHash := make([]byte, 16)
	for i := range apiKeyHash {
		apiKeyHash[i] = byte(0xA0 + i)
	}
	serviceID := "test-service-123"
	payload := buildStopJupyterPayload(t, apiKeyHash, serviceID)

	// Fatal rather than Error: the checks below slice the payload and would
	// panic on a short one, hiding the real failure.
	expectedLen := 16 + 1 + len(serviceID)
	if len(payload) != expectedLen {
		t.Fatalf("expected payload length %d, got %d", expectedLen, len(payload))
	}
	if !bytes.Equal(payload[0:16], apiKeyHash) {
		t.Error("API key hash mismatch")
	}
	if idLen := int(payload[16]); idLen != len(serviceID) {
		t.Errorf("expected service ID length %d, got %d", len(serviceID), idLen)
	}
	if actualID := string(payload[17:]); actualID != serviceID {
		t.Errorf("expected service ID %q, got %q", serviceID, actualID)
	}
}
// TestBuildListJupyterPayload checks that a list request carries nothing but
// the 16-byte API key hash.
func TestBuildListJupyterPayload(t *testing.T) {
	hash := make([]byte, 16)
	for idx := 0; idx < len(hash); idx++ {
		hash[idx] = byte(idx)
	}

	got := buildListJupyterPayload(t, hash)

	if n := len(got); n != 16 {
		t.Errorf("expected payload length 16, got %d", n)
	}
	if same := bytes.Equal(got, hash); !same {
		t.Error("API key hash mismatch in list payload")
	}
}
// TestJupyterPayloadValidation checks each payload against the minimum
// length of its own message type. That minimum is the size of the fixed
// fields with every variable-length field empty:
//
//	start: hash(16) + nameLen(1) + workspaceLen(2) + passwordLen(1) = 20
//	stop:  hash(16) + serviceIDLen(1)                              = 17
//
// Comparing against the hash size alone would accept, for example, a bare
// 16-byte start payload that lacks all of its length prefixes.
func TestJupyterPayloadValidation(t *testing.T) {
	const (
		minStartLen = 16 + 1 + 2 + 1
		minStopLen  = 16 + 1
	)
	tests := []struct {
		name      string
		payload   []byte
		minLen    int
		shouldErr bool
	}{
		{
			name:      "start payload too short",
			payload:   make([]byte, 10),
			minLen:    minStartLen,
			shouldErr: true,
		},
		{
			name:      "start payload with only api key hash",
			payload:   make([]byte, 16),
			minLen:    minStartLen,
			shouldErr: true,
		},
		{
			name: "valid minimum start payload",
			payload: buildStartJupyterPayload(t, &startJupyterParams{
				apiKeyHash: make([]byte, 16),
				name:       "a",
				workspace:  "/",
				password:   "",
			}),
			minLen:    minStartLen,
			shouldErr: false,
		},
		{
			name:      "stop payload too short",
			payload:   make([]byte, 5),
			minLen:    minStopLen,
			shouldErr: true,
		},
		{
			name:      "stop payload with only api key hash",
			payload:   make([]byte, 16),
			minLen:    minStopLen,
			shouldErr: true,
		},
		{
			name:      "valid stop payload",
			payload:   buildStopJupyterPayload(t, make([]byte, 16), "test-id"),
			minLen:    minStopLen,
			shouldErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tooShort := len(tt.payload) < tt.minLen
			switch {
			case tt.shouldErr && !tooShort:
				t.Errorf("expected payload shorter than %d bytes, got %d", tt.minLen, len(tt.payload))
			case !tt.shouldErr && tooShort:
				t.Errorf("payload too short: need at least %d bytes, got %d", tt.minLen, len(tt.payload))
			}
		})
	}
}
// TestJupyterPayloadParsing builds a start payload and decodes it field by
// field, checking that every field round-trips. It also checks that no
// trailing bytes follow the password.
func TestJupyterPayloadParsing(t *testing.T) {
	t.Run("parse start jupyter payload", func(t *testing.T) {
		params := &startJupyterParams{
			apiKeyHash: make([]byte, 16),
			name:       "test-notebook",
			workspace:  "/home/user/notebooks",
			password:   "secret123",
		}
		payload := buildStartJupyterPayload(t, params)

		offset := 0
		// take returns the next n bytes and advances offset. If the payload
		// is too short it fails the test instead of panicking on an
		// out-of-range slice.
		take := func(field string, n int) []byte {
			t.Helper()
			if n < 0 || offset+n > len(payload) {
				t.Fatalf("payload truncated reading %s: need %d bytes at offset %d, payload has %d",
					field, n, offset, len(payload))
			}
			b := payload[offset : offset+n]
			offset += n
			return b
		}

		if apiKeyHash := take("api key hash", 16); !bytes.Equal(apiKeyHash, params.apiKeyHash) {
			t.Error("API key hash mismatch after parsing")
		}

		nameLen := int(take("name length", 1)[0])
		if name := string(take("name", nameLen)); name != params.name {
			t.Errorf("expected name %q, got %q", params.name, name)
		}

		// The workspace length prefix is 2 bytes, big-endian.
		workspaceLen := int(binary.BigEndian.Uint16(take("workspace length", 2)))
		if workspace := string(take("workspace", workspaceLen)); workspace != params.workspace {
			t.Errorf("expected workspace %q, got %q", params.workspace, workspace)
		}

		passwordLen := int(take("password length", 1)[0])
		if password := string(take("password", passwordLen)); password != params.password {
			t.Errorf("expected password %q, got %q", params.password, password)
		}

		if offset != len(payload) {
			t.Errorf("expected payload to be fully consumed, %d trailing bytes", len(payload)-offset)
		}
	})
}
// Helper functions to build test payloads.

// startJupyterParams holds the fields that buildStartJupyterPayload
// serializes, in payload order: the API key hash first, then the
// length-prefixed name, workspace, and password.
type startJupyterParams struct {
	apiKeyHash                []byte // copied verbatim to the front of the payload
	name, workspace, password string // each written after its own length prefix
}
// buildStartJupyterPayload encodes a start-Jupyter request in this layout:
//
//	[apiKeyHash][nameLen:1][name][workspaceLen:2, big-endian][workspace][passwordLen:1][password]
//
// The hash is written verbatim; callers in this file pass 16 bytes. A field
// too long for its length prefix fails the test immediately. Silently
// truncating the prefix would otherwise produce a corrupt payload.
func buildStartJupyterPayload(t *testing.T, params *startJupyterParams) []byte {
	t.Helper()

	const (
		maxByteLen   = 0xFF   // largest length a 1-byte prefix can express
		maxUint16Len = 0xFFFF // largest length a 2-byte prefix can express
	)
	if n := len(params.name); n > maxByteLen {
		t.Fatalf("name too long for 1-byte length prefix: %d bytes", n)
	}
	if n := len(params.workspace); n > maxUint16Len {
		t.Fatalf("workspace too long for 2-byte length prefix: %d bytes", n)
	}
	if n := len(params.password); n > maxByteLen {
		t.Fatalf("password too long for 1-byte length prefix: %d bytes", n)
	}

	buf := make([]byte, 0,
		len(params.apiKeyHash)+1+len(params.name)+2+len(params.workspace)+1+len(params.password))

	// API key hash
	buf = append(buf, params.apiKeyHash...)

	// Name length (1 byte) + name
	buf = append(buf, byte(len(params.name)))
	buf = append(buf, params.name...)

	// Workspace length (2 bytes, big-endian) + workspace
	var wsLen [2]byte
	binary.BigEndian.PutUint16(wsLen[:], uint16(len(params.workspace)))
	buf = append(buf, wsLen[:]...)
	buf = append(buf, params.workspace...)

	// Password length (1 byte) + password
	buf = append(buf, byte(len(params.password)))
	buf = append(buf, params.password...)

	return buf
}
// buildStopJupyterPayload encodes a stop-Jupyter request as
//
//	[apiKeyHash][serviceIDLen:1][serviceID]
//
// The hash is written verbatim; callers in this file pass 16 bytes. A 1-byte
// prefix cannot represent a service ID longer than 255 bytes, so such an ID
// fails the test. Silently truncating the prefix would produce an
// inconsistent payload.
func buildStopJupyterPayload(t *testing.T, apiKeyHash []byte, serviceID string) []byte {
	t.Helper()
	if n := len(serviceID); n > 0xFF {
		t.Fatalf("service ID too long for 1-byte length prefix: %d bytes", n)
	}

	payload := make([]byte, 0, len(apiKeyHash)+1+len(serviceID))
	payload = append(payload, apiKeyHash...)
	payload = append(payload, byte(len(serviceID)))
	payload = append(payload, serviceID...)
	return payload
}
// buildListJupyterPayload encodes a list-Jupyter request. The payload is
// just the API key hash.
//
// A copy is returned rather than the caller's slice. Otherwise the payload
// and the hash would share a backing array: mutating one would change the
// other, and comparing the payload against the input hash could never fail.
func buildListJupyterPayload(t *testing.T, apiKeyHash []byte) []byte {
	t.Helper()
	return append(make([]byte, 0, len(apiKeyHash)), apiKeyHash...)
}