fetch_ml/internal/api/ws_test.go
Jeremie Fraeys 803677be57 feat: implement Go backend with comprehensive API and internal packages
- Add API server with WebSocket support and REST endpoints
- Implement authentication system with API keys and permissions
- Add task queue system with Redis backend and error handling
- Include storage layer with database migrations and schemas
- Add comprehensive logging, metrics, and telemetry
- Implement security middleware and network utilities
- Add experiment management and container orchestration
- Include configuration management with smart defaults
2025-12-04 16:53:53 -05:00

335 lines
8.4 KiB
Go

package api
import (
"encoding/binary"
"math"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupTestServer wires up everything a WebSocket handler test needs:
// an in-memory Redis (miniredis), a queue.TaskQueue pointed at it, an
// experiment.Manager rooted in a per-test temporary directory, and an
// httptest.Server serving NewWSHandler with authentication disabled.
//
// The caller owns the returned server, task queue and miniredis instance and
// is responsible for closing them (the tests in this file defer their Close).
func setupTestServer(t *testing.T) (*httptest.Server, *queue.TaskQueue, *experiment.Manager, *miniredis.Miniredis) {
	t.Helper()

	// In-memory Redis so the tests do not need an external server.
	s, err := miniredis.Run()
	require.NoError(t, err)

	queueCfg := queue.Config{
		RedisAddr:            s.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	}
	tq, err := queue.NewTaskQueue(queueCfg)
	if err != nil {
		// The caller never receives s on this path, so release it here
		// instead of leaking the miniredis listener.
		s.Close()
		t.Fatalf("creating task queue: %v", err)
	}

	logger := logging.NewLogger(0, false)
	expManager := experiment.NewManager(t.TempDir())
	authCfg := &auth.AuthConfig{Enabled: false}

	handler := NewWSHandler(authCfg, logger, expManager, tq)
	server := httptest.NewServer(handler)
	return server, tq, expManager, s
}
// connectWS opens a WebSocket client connection to the httptest server at
// serverURL. The http/https scheme is rewritten to ws/wss. The test fails if
// the handshake fails.
//
// A read deadline is set on the returned connection. A handler that never
// replies then makes ReadMessage return an error, which fails the test
// promptly instead of hanging until the go test timeout.
func connectWS(t *testing.T, serverURL string) *websocket.Conn {
	t.Helper()

	// Generous upper bound for one request/response exchange in a test.
	const readTimeout = 10 * time.Second

	// "http://..." -> "ws://...", "https://..." -> "wss://...".
	wsURL := "ws" + strings.TrimPrefix(serverURL, "http")
	ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	require.NoError(t, err)

	if err := ws.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
		ws.Close()
		t.Fatalf("setting websocket read deadline: %v", err)
	}
	return ws
}
// TestWSHandler_QueueJob sends a queue_job frame over the WebSocket. It checks
// that the handler replies with a success packet and that the job can then
// be read back from the task queue.
func TestWSHandler_QueueJob(t *testing.T) {
	server, tq, _, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Protocol: [opcode:1][api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
	apiKeyHash := []byte(strings.Repeat("0", 64))
	commitID := []byte(strings.Repeat("a", 64))
	priority := byte(5)
	jobName := "test-job"

	msg := make([]byte, 0, 1+len(apiKeyHash)+len(commitID)+1+1+len(jobName))
	msg = append(msg, byte(OpcodeQueueJob))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)
	msg = append(msg, priority, byte(len(jobName)))
	msg = append(msg, jobName...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index below: an empty frame would otherwise panic and abort
	// the whole test binary instead of failing just this test.
	require.NotEmpty(t, resp)
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])

	// NOTE(review): the sleep suggests the enqueue may complete after the
	// ack is sent; if the handler enqueues synchronously it can be dropped.
	time.Sleep(100 * time.Millisecond)
	task, err := tq.GetNextTask()
	require.NoError(t, err)
	require.NotNil(t, task)
	assert.Equal(t, jobName, task.JobName)
}
// TestWSHandler_StatusRequest seeds the task queue with one task, sends a
// status_request frame, and checks that the handler replies with a data
// packet.
func TestWSHandler_StatusRequest(t *testing.T) {
	server, tq, _, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Put one task in the queue before asking for status.
	task := &queue.Task{
		ID:        "task-1",
		JobName:   "job-1",
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
	}
	require.NoError(t, tq.AddTask(task))

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Protocol: [opcode:1][api_key_hash:64]
	msg := make([]byte, 0, 1+64)
	msg = append(msg, byte(OpcodeStatusRequest))
	msg = append(msg, strings.Repeat("0", 64)...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index below so an empty frame fails the test instead of panicking.
	require.NotEmpty(t, resp)
	// A status reply carries a payload, so a data packet is expected rather
	// than a bare success packet.
	assert.Equal(t, byte(PacketTypeData), resp[0])
}
// TestWSHandler_CancelJob queues a task, sends a cancel_job frame naming it,
// and checks for a success reply. It then checks that the stored task's
// status became "cancelled".
func TestWSHandler_CancelJob(t *testing.T) {
	server, tq, _, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	jobName := "job-to-cancel"
	task := &queue.Task{
		ID:        "task-1",
		JobName:   jobName,
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		UserID:    "user", // Auth disabled so this matches any user
		CreatedBy: "user",
	}
	require.NoError(t, tq.AddTask(task))

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Protocol: [opcode:1][api_key_hash:64][job_name_len:1][job_name:var]
	msg := make([]byte, 0, 1+64+1+len(jobName))
	msg = append(msg, byte(OpcodeCancelJob))
	msg = append(msg, strings.Repeat("0", 64)...)
	msg = append(msg, byte(len(jobName)))
	msg = append(msg, jobName...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index below so an empty frame fails the test instead of panicking.
	require.NotEmpty(t, resp)
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])

	updatedTask, err := tq.GetTask("task-1")
	require.NoError(t, err)
	require.NotNil(t, updatedTask)
	assert.Equal(t, "cancelled", updatedTask.Status)
}
// TestWSHandler_Prune creates a couple of experiments and sends a prune frame
// asking to keep 1. It checks that the handler replies with a success packet.
func TestWSHandler_Prune(t *testing.T) {
	server, tq, expManager, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Setup failures are logged rather than fatal, keeping the best-effort
	// behaviour while making a setup problem visible.
	// NOTE(review): confirm whether CreateExperiment accepts non-64-char IDs
	// like these; if it does, these could become require.NoError.
	for _, commit := range []string{"commit-1", "commit-2"} {
		if err := expManager.CreateExperiment(commit); err != nil {
			t.Logf("CreateExperiment(%q): %v", commit, err)
		}
	}

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Protocol: [opcode:1][api_key_hash:64][prune_type:1][value:4]
	pruneType := byte(0) // Keep N
	value := uint32(1)   // Keep 1
	valueBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(valueBytes, value)

	msg := make([]byte, 0, 1+64+1+len(valueBytes))
	msg = append(msg, byte(OpcodePrune))
	msg = append(msg, strings.Repeat("0", 64)...)
	msg = append(msg, pruneType)
	msg = append(msg, valueBytes...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index below so an empty frame fails the test instead of panicking.
	require.NotEmpty(t, resp)
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])
}
// TestWSHandler_LogMetric creates an experiment and sends a log_metric frame
// for it ("accuracy" = 0.95 at step 100). It checks that the handler replies
// with a success packet.
func TestWSHandler_LogMetric(t *testing.T) {
	server, tq, expManager, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	// Create the experiment the metric will be logged against.
	commitIDStr := strings.Repeat("a", 64)
	require.NoError(t, expManager.CreateExperiment(commitIDStr))

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Protocol: [opcode:1][api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
	// step and value are big-endian; value is the IEEE-754 bit pattern of a float64.
	step := uint32(100)
	value := 0.95
	metricName := "accuracy"

	stepBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(stepBytes, step)
	valueBytes := make([]byte, 8)
	binary.BigEndian.PutUint64(valueBytes, math.Float64bits(value))

	msg := make([]byte, 0, 1+64+len(commitIDStr)+len(stepBytes)+len(valueBytes)+1+len(metricName))
	msg = append(msg, byte(OpcodeLogMetric))
	msg = append(msg, strings.Repeat("0", 64)...)
	msg = append(msg, commitIDStr...)
	msg = append(msg, stepBytes...)
	msg = append(msg, valueBytes...)
	msg = append(msg, byte(len(metricName)))
	msg = append(msg, metricName...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index below so an empty frame fails the test instead of panicking.
	require.NotEmpty(t, resp)
	assert.Equal(t, byte(PacketTypeSuccess), resp[0])
}
// TestWSHandler_GetExperiment creates an experiment with metadata and sends a
// get_experiment frame for its commit ID. It checks that the handler replies
// with a data packet.
func TestWSHandler_GetExperiment(t *testing.T) {
	server, tq, expManager, s := setupTestServer(t)
	defer server.Close()
	defer tq.Close()
	defer s.Close()

	commitIDStr := strings.Repeat("a", 64)
	require.NoError(t, expManager.CreateExperiment(commitIDStr))
	meta := &experiment.Metadata{
		CommitID: commitIDStr,
		JobName:  "test-job",
	}
	require.NoError(t, expManager.WriteMetadata(meta))

	ws := connectWS(t, server.URL)
	defer ws.Close()

	// Protocol: [opcode:1][api_key_hash:64][commit_id:64]
	msg := make([]byte, 0, 1+64+len(commitIDStr))
	msg = append(msg, byte(OpcodeGetExperiment))
	msg = append(msg, strings.Repeat("0", 64)...)
	msg = append(msg, commitIDStr...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index below so an empty frame fails the test instead of panicking.
	require.NotEmpty(t, resp)
	assert.Equal(t, byte(PacketTypeData), resp[0])
}