fetch_ml/tests/integration/ws_handler_integration_test.go
Jeremie Fraeys cd5640ebd2 Slim and secure: move scripts, clean configs, remove secrets
- Move ci-test.sh and setup.sh to scripts/
- Trim docs/src/zig-cli.md to current structure
- Replace hardcoded secrets with placeholders in configs
- Update .gitignore to block .env*, secrets/, keys, build artifacts
- Slim README.md to reflect current CLI/TUI split
- Add cleanup trap to ci-test.sh
- Ensure no secrets are committed
2025-12-07 13:57:51 -05:00

346 lines
9.1 KiB
Go

//nolint:revive // Package name 'tests' is appropriate for this integration test package
package tests
import (
"encoding/binary"
"math"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/jfraeys/fetch_ml/internal/api"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/queue"
)
// setupWSIntegrationServer starts the stack used by these integration tests:
// an in-memory miniredis instance, a TaskQueue backed by it, an experiment
// manager rooted in a per-test temp dir, and an httptest server serving the
// API WebSocket handler with auth disabled.
//
// The caller owns every returned resource and must close it. Close the server
// and queue before the miniredis instance they depend on.
func setupWSIntegrationServer(t *testing.T) (
	*httptest.Server,
	*queue.TaskQueue,
	*experiment.Manager,
	*miniredis.Miniredis,
) {
	t.Helper()

	// In-memory Redis so the queue needs no external service.
	s, err := miniredis.Run()
	require.NoError(t, err)

	// NOTE(review): the short flush interval presumably makes queue metrics
	// visible quickly during tests — confirm against queue.Config docs.
	queueCfg := queue.Config{
		RedisAddr:            s.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	}
	tq, err := queue.NewTaskQueue(queueCfg)
	if err != nil {
		// require.NoError below ends the test via FailNow, and the caller
		// never receives s to close it, so release miniredis here.
		s.Close()
	}
	require.NoError(t, err)

	logger := logging.NewLogger(0, false)
	expManager := experiment.NewManager(t.TempDir())
	authCfg := &auth.Config{Enabled: false}

	handler := api.NewWSHandler(authCfg, logger, expManager, tq)
	server := httptest.NewServer(handler)
	return server, tq, expManager, s
}
// connectWSIntegration dials the WebSocket endpoint of the httptest server at
// serverURL and fails the test if the handshake fails. The scheme is rewritten
// from http(s) to ws(s). The caller must close the returned connection.
func connectWSIntegration(t *testing.T, serverURL string) *websocket.Conn {
	t.Helper()

	wsURL := "ws" + strings.TrimPrefix(serverURL, "http")
	ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
	// The handshake response body is not needed. Close it whether or not the
	// dial succeeded so its resources are released.
	if resp != nil && resp.Body != nil {
		_ = resp.Body.Close()
	}
	require.NoError(t, err)
	return ws
}
// TestWSHandler_QueueJob_Integration sends a binary queue_job frame. It checks
// that the handler acknowledges it with a success packet and that the job can
// then be pulled from the Redis-backed queue.
func TestWSHandler_QueueJob_Integration(t *testing.T) {
	server, tq, _, s := setupWSIntegrationServer(t)
	// Defers run LIFO: stop the server first, then the queue, and the
	// miniredis instance both depend on last.
	defer s.Close()
	defer func() { _ = tq.Close() }()
	defer server.Close()

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Frame layout: [opcode:1][api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
	apiKeyHash := []byte(strings.Repeat("0", 64))
	commitID := []byte(strings.Repeat("a", 64))
	priority := byte(5)
	jobName := "test-job"

	msg := make([]byte, 0, 1+len(apiKeyHash)+len(commitID)+2+len(jobName))
	msg = append(msg, byte(api.OpcodeQueueJob))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)
	msg = append(msg, priority, byte(len(jobName)))
	msg = append(msg, jobName...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index so an empty reply fails the test instead of panicking.
	require.NotEmpty(t, resp, "empty response from server")
	assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])

	// NOTE(review): the sleep presumably covers enqueueing that finishes after
	// the reply is sent. Drop it if the handler enqueues before responding.
	time.Sleep(100 * time.Millisecond)
	task, err := tq.GetNextTask()
	require.NoError(t, err)
	require.NotNil(t, task)
	assert.Equal(t, jobName, task.JobName)
}
// TestWSHandler_StatusRequest_Integration seeds one queued task and sends a
// binary status_request frame. It checks that the handler replies with a data
// packet (a status payload) rather than a bare success ack.
func TestWSHandler_StatusRequest_Integration(t *testing.T) {
	server, tq, _, s := setupWSIntegrationServer(t)
	// Defers run LIFO: stop the server first, then the queue, and the
	// miniredis instance both depend on last.
	defer s.Close()
	defer func() { _ = tq.Close() }()
	defer server.Close()

	task := &queue.Task{
		ID:        "task-1",
		JobName:   "job-1",
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
	}
	require.NoError(t, tq.AddTask(task))

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Frame layout: [opcode:1][api_key_hash:64]
	apiKeyHash := []byte(strings.Repeat("0", 64))
	msg := make([]byte, 0, 1+len(apiKeyHash))
	msg = append(msg, byte(api.OpcodeStatusRequest))
	msg = append(msg, apiKeyHash...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index so an empty reply fails the test instead of panicking.
	require.NotEmpty(t, resp, "empty response from server")
	assert.Equal(t, byte(api.PacketTypeData), resp[0])
}
// TestWSHandler_CancelJob_Integration seeds a queued task and sends a binary
// cancel_job frame naming it. It checks that the handler acknowledges with a
// success packet and that the stored task's status becomes "cancelled".
func TestWSHandler_CancelJob_Integration(t *testing.T) {
	server, tq, _, s := setupWSIntegrationServer(t)
	// Defers run LIFO: stop the server first, then the queue, and the
	// miniredis instance both depend on last.
	defer s.Close()
	defer func() { _ = tq.Close() }()
	defer server.Close()

	jobName := "job-to-cancel"
	task := &queue.Task{
		ID:        "task-1",
		JobName:   jobName,
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		// Auth is disabled in setupWSIntegrationServer, so ownership is
		// presumably not enforced for this user.
		UserID:    "user",
		CreatedBy: "user",
	}
	require.NoError(t, tq.AddTask(task))

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Frame layout: [opcode:1][api_key_hash:64][job_name_len:1][job_name:var]
	apiKeyHash := []byte(strings.Repeat("0", 64))
	msg := make([]byte, 0, 1+len(apiKeyHash)+1+len(jobName))
	msg = append(msg, byte(api.OpcodeCancelJob))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, byte(len(jobName)))
	msg = append(msg, jobName...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index so an empty reply fails the test instead of panicking.
	require.NotEmpty(t, resp, "empty response from server")
	assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])

	updatedTask, err := tq.GetTask("task-1")
	require.NoError(t, err)
	require.NotNil(t, updatedTask)
	assert.Equal(t, "cancelled", updatedTask.Status)
}
// TestWSHandler_Prune_Integration seeds two experiments and sends a binary
// keep-N prune frame with N=1. It checks that the handler replies with a
// success packet.
func TestWSHandler_Prune_Integration(t *testing.T) {
	server, tq, expManager, s := setupWSIntegrationServer(t)
	// Defers run LIFO: stop the server first, then the queue, and the
	// miniredis instance both depend on last.
	defer s.Close()
	defer func() { _ = tq.Close() }()
	defer server.Close()

	// Seed experiments so the prune has something to act on. Use 64-character
	// commit IDs, the form the other tests here create successfully. Check the
	// errors rather than discarding them, so a silent failure cannot leave the
	// prune running over nothing.
	require.NoError(t, expManager.CreateExperiment(strings.Repeat("b", 64)))
	require.NoError(t, expManager.CreateExperiment(strings.Repeat("c", 64)))

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Frame layout: [opcode:1][api_key_hash:64][prune_type:1][value:4]
	apiKeyHash := []byte(strings.Repeat("0", 64))
	pruneType := byte(0) // keep N
	valueBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(valueBytes, 1) // N = 1

	msg := make([]byte, 0, 1+len(apiKeyHash)+1+len(valueBytes))
	msg = append(msg, byte(api.OpcodePrune))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, pruneType)
	msg = append(msg, valueBytes...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index so an empty reply fails the test instead of panicking.
	require.NotEmpty(t, resp, "empty response from server")
	assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])
}
// TestWSHandler_LogMetric_Integration creates an experiment and sends a binary
// log_metric frame recording "accuracy" = 0.95 at step 100 for it. It checks
// that the handler replies with a success packet.
func TestWSHandler_LogMetric_Integration(t *testing.T) {
	server, tq, expManager, s := setupWSIntegrationServer(t)
	// Defers run LIFO: stop the server first, then the queue, and the
	// miniredis instance both depend on last.
	defer s.Close()
	defer func() { _ = tq.Close() }()
	defer server.Close()

	commitIDStr := strings.Repeat("a", 64)
	require.NoError(t, expManager.CreateExperiment(commitIDStr))

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Frame layout: [opcode:1][api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
	// Multi-byte integers are big-endian; the value is the IEEE-754 bit pattern.
	apiKeyHash := []byte(strings.Repeat("0", 64))
	commitID := []byte(commitIDStr)
	metricName := "accuracy"

	stepBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(stepBytes, 100)
	valueBytes := make([]byte, 8)
	binary.BigEndian.PutUint64(valueBytes, math.Float64bits(0.95))

	msg := make([]byte, 0, 1+len(apiKeyHash)+len(commitID)+len(stepBytes)+len(valueBytes)+1+len(metricName))
	msg = append(msg, byte(api.OpcodeLogMetric))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)
	msg = append(msg, stepBytes...)
	msg = append(msg, valueBytes...)
	msg = append(msg, byte(len(metricName)))
	msg = append(msg, metricName...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index so an empty reply fails the test instead of panicking.
	require.NotEmpty(t, resp, "empty response from server")
	assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])
}
// TestWSHandler_GetExperiment_Integration creates an experiment with metadata
// and sends a binary get_experiment frame for its commit ID. It checks that
// the handler replies with a data packet.
func TestWSHandler_GetExperiment_Integration(t *testing.T) {
	server, tq, expManager, s := setupWSIntegrationServer(t)
	// Defers run LIFO: stop the server first, then the queue, and the
	// miniredis instance both depend on last.
	defer s.Close()
	defer func() { _ = tq.Close() }()
	defer server.Close()

	commitIDStr := strings.Repeat("a", 64)
	require.NoError(t, expManager.CreateExperiment(commitIDStr))
	meta := &experiment.Metadata{
		CommitID: commitIDStr,
		JobName:  "test-job",
	}
	require.NoError(t, expManager.WriteMetadata(meta))

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Frame layout: [opcode:1][api_key_hash:64][commit_id:64]
	apiKeyHash := []byte(strings.Repeat("0", 64))
	commitID := []byte(commitIDStr)
	msg := make([]byte, 0, 1+len(apiKeyHash)+len(commitID))
	msg = append(msg, byte(api.OpcodeGetExperiment))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)

	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	// Guard the index so an empty reply fails the test instead of panicking.
	require.NotEmpty(t, resp, "empty response from server")
	assert.Equal(t, byte(api.PacketTypeData), resp[0])
}