- Fix YAML tags in auth config struct (json -> yaml)
- Update CLI configs to use pre-hashed API keys
- Remove double hashing in WebSocket client
- Fix port mapping (9102 -> 9103) in CLI commands
- Update permission keys to use jobs:read, jobs:create, etc.
- Clean up all debug logging from CLI and server
- All user roles now authenticate correctly:
  * Admin: can queue jobs and see all jobs
  * Researcher: can queue jobs and see their own jobs
  * Analyst: can see status (read-only access)

Multi-user authentication is now fully functional.
341 lines
9.1 KiB
Go
341 lines
9.1 KiB
Go
//nolint:revive // Package name 'tests' is appropriate for this integration test package
|
|
package tests
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"math"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/alicebob/miniredis/v2"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/api"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
"github.com/jfraeys/fetch_ml/internal/queue"
|
|
)
|
|
|
|
func setupWSIntegrationServer(t *testing.T) (*httptest.Server, *queue.TaskQueue, *experiment.Manager, *miniredis.Miniredis) {
|
|
// Setup miniredis
|
|
s, err := miniredis.Run()
|
|
require.NoError(t, err)
|
|
|
|
// Setup TaskQueue
|
|
queueCfg := queue.Config{
|
|
RedisAddr: s.Addr(),
|
|
MetricsFlushInterval: 10 * time.Millisecond,
|
|
}
|
|
tq, err := queue.NewTaskQueue(queueCfg)
|
|
require.NoError(t, err)
|
|
|
|
// Setup dependencies
|
|
logger := logging.NewLogger(0, false)
|
|
expManager := experiment.NewManager(t.TempDir())
|
|
authCfg := &auth.Config{Enabled: false}
|
|
|
|
// Create handler
|
|
handler := api.NewWSHandler(authCfg, logger, expManager, tq)
|
|
|
|
// Setup test server
|
|
server := httptest.NewServer(handler)
|
|
|
|
return server, tq, expManager, s
|
|
}
|
|
|
|
func connectWSIntegration(t *testing.T, serverURL string) *websocket.Conn {
|
|
wsURL := "ws" + strings.TrimPrefix(serverURL, "http")
|
|
ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if resp != nil && resp.Body != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
require.NoError(t, err)
|
|
return ws
|
|
}
|
|
|
|
func TestWSHandler_QueueJob_Integration(t *testing.T) {
|
|
server, tq, _, s := setupWSIntegrationServer(t)
|
|
defer server.Close()
|
|
defer func() { _ = tq.Close() }()
|
|
defer s.Close()
|
|
|
|
ws := connectWSIntegration(t, server.URL)
|
|
defer func() { _ = ws.Close() }()
|
|
|
|
// Prepare queue_job message
|
|
// Protocol: [opcode:1][api_key_hash:64][commit_id:64][priority:1][job_name_len:1][job_name:var]
|
|
opcode := byte(api.OpcodeQueueJob)
|
|
apiKeyHash := make([]byte, 64)
|
|
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
|
|
commitID := make([]byte, 64)
|
|
copy(commitID, []byte(strings.Repeat("a", 64)))
|
|
priority := byte(5)
|
|
jobName := "test-job"
|
|
jobNameLen := byte(len(jobName))
|
|
|
|
var msg []byte
|
|
msg = append(msg, opcode)
|
|
msg = append(msg, apiKeyHash...)
|
|
msg = append(msg, commitID...)
|
|
msg = append(msg, priority)
|
|
msg = append(msg, jobNameLen)
|
|
msg = append(msg, []byte(jobName)...)
|
|
|
|
// Send message
|
|
err := ws.WriteMessage(websocket.BinaryMessage, msg)
|
|
require.NoError(t, err)
|
|
|
|
// Read response
|
|
_, resp, err := ws.ReadMessage()
|
|
require.NoError(t, err)
|
|
|
|
// Verify success response (PacketTypeSuccess = 0x00)
|
|
assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])
|
|
|
|
// Verify task in Redis
|
|
time.Sleep(100 * time.Millisecond)
|
|
task, err := tq.GetNextTask()
|
|
require.NoError(t, err)
|
|
require.NotNil(t, task)
|
|
assert.Equal(t, jobName, task.JobName)
|
|
}
|
|
|
|
func TestWSHandler_StatusRequest_Integration(t *testing.T) {
|
|
server, tq, _, s := setupWSIntegrationServer(t)
|
|
defer server.Close()
|
|
defer func() { _ = tq.Close() }()
|
|
defer s.Close()
|
|
|
|
// Add a task to queue
|
|
task := &queue.Task{
|
|
ID: "task-1",
|
|
JobName: "job-1",
|
|
Status: "queued",
|
|
Priority: 10,
|
|
CreatedAt: time.Now(),
|
|
UserID: "user",
|
|
CreatedBy: "user",
|
|
}
|
|
err := tq.AddTask(task)
|
|
require.NoError(t, err)
|
|
|
|
ws := connectWSIntegration(t, server.URL)
|
|
defer func() { _ = ws.Close() }()
|
|
|
|
// Prepare status_request message
|
|
// Protocol: [opcode:1][api_key_hash:64]
|
|
opcode := byte(api.OpcodeStatusRequest)
|
|
apiKeyHash := make([]byte, 64)
|
|
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
|
|
|
|
var msg []byte
|
|
msg = append(msg, opcode)
|
|
msg = append(msg, apiKeyHash...)
|
|
|
|
// Send message
|
|
err = ws.WriteMessage(websocket.BinaryMessage, msg)
|
|
require.NoError(t, err)
|
|
|
|
// Read response
|
|
_, resp, err := ws.ReadMessage()
|
|
require.NoError(t, err)
|
|
|
|
// Verify success response (PacketTypeData = 0x04 for status with payload)
|
|
assert.Equal(t, byte(api.PacketTypeData), resp[0])
|
|
}
|
|
|
|
func TestWSHandler_CancelJob_Integration(t *testing.T) {
|
|
server, tq, _, s := setupWSIntegrationServer(t)
|
|
defer server.Close()
|
|
defer func() { _ = tq.Close() }()
|
|
defer s.Close()
|
|
|
|
// Add a task to queue
|
|
task := &queue.Task{
|
|
ID: "task-1",
|
|
JobName: "job-to-cancel",
|
|
Status: "queued",
|
|
Priority: 10,
|
|
CreatedAt: time.Now(),
|
|
UserID: "user", // Auth disabled so this matches any user
|
|
CreatedBy: "user",
|
|
}
|
|
err := tq.AddTask(task)
|
|
require.NoError(t, err)
|
|
|
|
ws := connectWSIntegration(t, server.URL)
|
|
defer func() { _ = ws.Close() }()
|
|
|
|
// Prepare cancel_job message
|
|
// Protocol: [opcode:1][api_key_hash:64][job_name_len:1][job_name:var]
|
|
opcode := byte(api.OpcodeCancelJob)
|
|
apiKeyHash := make([]byte, 64)
|
|
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
|
|
jobName := "job-to-cancel"
|
|
jobNameLen := byte(len(jobName))
|
|
|
|
var msg []byte
|
|
msg = append(msg, opcode)
|
|
msg = append(msg, apiKeyHash...)
|
|
msg = append(msg, jobNameLen)
|
|
msg = append(msg, []byte(jobName)...)
|
|
|
|
// Send message
|
|
err = ws.WriteMessage(websocket.BinaryMessage, msg)
|
|
require.NoError(t, err)
|
|
|
|
// Read response
|
|
_, resp, err := ws.ReadMessage()
|
|
require.NoError(t, err)
|
|
|
|
// Verify success response
|
|
assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])
|
|
|
|
// Verify task cancelled
|
|
updatedTask, err := tq.GetTask("task-1")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "cancelled", updatedTask.Status)
|
|
}
|
|
|
|
func TestWSHandler_Prune_Integration(t *testing.T) {
|
|
server, tq, expManager, s := setupWSIntegrationServer(t)
|
|
defer server.Close()
|
|
defer func() { _ = tq.Close() }()
|
|
defer s.Close()
|
|
|
|
// Create some experiments
|
|
_ = expManager.CreateExperiment("commit-1")
|
|
_ = expManager.CreateExperiment("commit-2")
|
|
|
|
ws := connectWSIntegration(t, server.URL)
|
|
defer func() { _ = ws.Close() }()
|
|
|
|
// Prepare prune message
|
|
// Protocol: [opcode:1][api_key_hash:64][prune_type:1][value:4]
|
|
opcode := byte(api.OpcodePrune)
|
|
apiKeyHash := make([]byte, 64)
|
|
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
|
|
pruneType := byte(0) // Keep N
|
|
value := uint32(1) // Keep 1
|
|
valueBytes := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(valueBytes, value)
|
|
|
|
var msg []byte
|
|
msg = append(msg, opcode)
|
|
msg = append(msg, apiKeyHash...)
|
|
msg = append(msg, pruneType)
|
|
msg = append(msg, valueBytes...)
|
|
|
|
// Send message
|
|
err := ws.WriteMessage(websocket.BinaryMessage, msg)
|
|
require.NoError(t, err)
|
|
|
|
// Read response
|
|
_, resp, err := ws.ReadMessage()
|
|
require.NoError(t, err)
|
|
|
|
// Verify success response
|
|
assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])
|
|
}
|
|
|
|
func TestWSHandler_LogMetric_Integration(t *testing.T) {
|
|
server, tq, expManager, s := setupWSIntegrationServer(t)
|
|
defer server.Close()
|
|
defer func() { _ = tq.Close() }()
|
|
defer s.Close()
|
|
|
|
// Create experiment
|
|
commitIDStr := strings.Repeat("a", 64)
|
|
err := expManager.CreateExperiment(commitIDStr)
|
|
require.NoError(t, err)
|
|
|
|
ws := connectWSIntegration(t, server.URL)
|
|
defer func() { _ = ws.Close() }()
|
|
|
|
// Prepare log_metric message
|
|
// Protocol: [opcode:1][api_key_hash:64][commit_id:64][step:4][value:8][name_len:1][name:var]
|
|
opcode := byte(api.OpcodeLogMetric)
|
|
apiKeyHash := make([]byte, 64)
|
|
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
|
|
commitID := []byte(commitIDStr)
|
|
step := uint32(100)
|
|
value := 0.95
|
|
valueBits := math.Float64bits(value)
|
|
metricName := "accuracy"
|
|
nameLen := byte(len(metricName))
|
|
|
|
stepBytes := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(stepBytes, step)
|
|
valueBytes := make([]byte, 8)
|
|
binary.BigEndian.PutUint64(valueBytes, valueBits)
|
|
|
|
var msg []byte
|
|
msg = append(msg, opcode)
|
|
msg = append(msg, apiKeyHash...)
|
|
msg = append(msg, commitID...)
|
|
msg = append(msg, stepBytes...)
|
|
msg = append(msg, valueBytes...)
|
|
msg = append(msg, nameLen)
|
|
msg = append(msg, []byte(metricName)...)
|
|
|
|
// Send message
|
|
err = ws.WriteMessage(websocket.BinaryMessage, msg)
|
|
require.NoError(t, err)
|
|
|
|
// Read response
|
|
_, resp, err := ws.ReadMessage()
|
|
require.NoError(t, err)
|
|
|
|
// Verify success response
|
|
assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])
|
|
}
|
|
|
|
func TestWSHandler_GetExperiment_Integration(t *testing.T) {
|
|
server, tq, expManager, s := setupWSIntegrationServer(t)
|
|
defer server.Close()
|
|
defer func() { _ = tq.Close() }()
|
|
defer s.Close()
|
|
|
|
// Create experiment and metadata
|
|
commitIDStr := strings.Repeat("a", 64)
|
|
err := expManager.CreateExperiment(commitIDStr)
|
|
require.NoError(t, err)
|
|
|
|
meta := &experiment.Metadata{
|
|
CommitID: commitIDStr,
|
|
JobName: "test-job",
|
|
}
|
|
err = expManager.WriteMetadata(meta)
|
|
require.NoError(t, err)
|
|
|
|
ws := connectWSIntegration(t, server.URL)
|
|
defer func() { _ = ws.Close() }()
|
|
|
|
// Prepare get_experiment message
|
|
// Protocol: [opcode:1][api_key_hash:64][commit_id:64]
|
|
opcode := byte(api.OpcodeGetExperiment)
|
|
apiKeyHash := make([]byte, 64)
|
|
copy(apiKeyHash, []byte(strings.Repeat("0", 64)))
|
|
commitID := []byte(commitIDStr)
|
|
|
|
var msg []byte
|
|
msg = append(msg, opcode)
|
|
msg = append(msg, apiKeyHash...)
|
|
msg = append(msg, commitID...)
|
|
|
|
// Send message
|
|
err = ws.WriteMessage(websocket.BinaryMessage, msg)
|
|
require.NoError(t, err)
|
|
|
|
// Read response
|
|
_, resp, err := ws.ReadMessage()
|
|
require.NoError(t, err)
|
|
|
|
// Verify success response (PacketTypeData)
|
|
assert.Equal(t, byte(api.PacketTypeData), resp[0])
|
|
}
|