fetch_ml/tests/integration/ws_handler_integration_test.go
Jeremie Fraeys d8cc2a4efa
refactor: Migrate all test imports from api to api/ws package
Updated 6 test files to use proper api/ws package imports:

1. tests/e2e/websocket_e2e_test.go
   - api.NewWSHandler → ws.NewHandler

2. tests/e2e/wss_reverse_proxy_e2e_test.go
   - api.NewWSHandler → ws.NewHandler

3. tests/integration/ws_handler_integration_test.go
   - api.NewWSHandler → wspkg.NewHandler
   - api.Opcode* → wspkg.Opcode*

4. tests/integration/websocket_queue_integration_test.go
   - api.NewWSHandler → wspkg.NewHandler
   - api.Opcode* → wspkg.Opcode*

5. tests/unit/api/ws_test.go
   - api.NewWSHandler → wspkg.NewHandler
   - api.Opcode* → wspkg.Opcode*

6. tests/unit/api/ws_jobs_args_test.go
   - api.Opcode* → wspkg.Opcode*

Removed api/ws_compat.go shim as all tests now use proper imports.

Build status: Compiles successfully
2026-02-17 13:52:20 -05:00

1063 lines
32 KiB
Go

//nolint:revive // Package name 'tests' is appropriate for this integration test package
package tests
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"math"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/jfraeys/fetch_ml/internal/api"
wspkg "github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
"github.com/jfraeys/fetch_ml/internal/manifest"
"github.com/jfraeys/fetch_ml/internal/queue"
"github.com/jfraeys/fetch_ml/internal/storage"
"github.com/jfraeys/fetch_ml/internal/worker"
)
// setupWSIntegrationServerWithDataDir starts an in-process HTTP test server that
// serves the WebSocket handler, passing dataDir through to the handler
// constructor as its data directory argument.
//
// It wires together a miniredis instance, a task queue pointed at that instance,
// an experiment manager rooted in a fresh temp dir, and a database opened from a
// temp-dir path and initialized with the schema for storage.DBTypeSQLite. Auth is
// disabled, and the Jupyter service manager, security config and audit logger are
// passed as nil.
//
// The caller owns every returned value and is expected to close them. If setup
// aborts part-way, whatever was already started is released here, because the
// caller never receives it in that case.
func setupWSIntegrationServerWithDataDir(t *testing.T, dataDir string) (
	*httptest.Server,
	*queue.TaskQueue,
	*experiment.Manager,
	*miniredis.Miniredis,
	*storage.DB,
) {
	t.Helper()

	s, err := miniredis.Run()
	require.NoError(t, err)

	var (
		tq    *queue.TaskQueue
		db    *storage.DB
		ready bool
	)
	// require.* aborts the test via FailNow, which still runs cleanups. Only
	// release resources when setup did not complete; on success the caller's
	// own deferred Close calls are responsible for them.
	t.Cleanup(func() {
		if ready {
			return
		}
		if db != nil {
			_ = db.Close()
		}
		if tq != nil {
			_ = tq.Close()
		}
		s.Close()
	})

	queueCfg := queue.Config{
		RedisAddr:            s.Addr(),
		MetricsFlushInterval: 10 * time.Millisecond,
	}
	tq, err = queue.NewTaskQueue(queueCfg)
	require.NoError(t, err)

	logger := logging.NewLogger(0, false)
	expManager := experiment.NewManager(t.TempDir())
	authConfig := &auth.Config{Enabled: false}

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err = storage.NewDBFromPath(dbPath)
	require.NoError(t, err)
	schema, err := storage.SchemaForDBType(storage.DBTypeSQLite)
	require.NoError(t, err)
	require.NoError(t, db.Initialize(schema))

	handler := wspkg.NewHandler(
		authConfig,
		logger,
		expManager,
		dataDir,
		tq,
		db,
		nil, // jupyterServiceMgr
		nil, // securityConfig
		nil, // auditLogger
	)
	server := httptest.NewServer(handler)

	ready = true
	return server, tq, expManager, s, db
}
// decodeDataPacket parses a data packet as laid out by this helper's reads:
// one packet-type byte, 8 bytes that are skipped, then two uvarint
// length-prefixed fields (the data type string followed by the payload).
//
// It returns the data type and the payload bytes. The payload aliases resp.
// Any malformed or truncated packet fails the test immediately; declared
// lengths are validated against the bytes actually remaining so an oversized
// length cannot cause an out-of-range slice panic.
func decodeDataPacket(t *testing.T, resp []byte) (string, []byte) {
	t.Helper()

	// NOTE(review): the 8 skipped bytes are presumably a timestamp — confirm
	// against the packet encoder.
	const headerLen = 1 + 8
	require.GreaterOrEqual(t, len(resp), headerLen)
	if resp[0] != byte(api.PacketTypeData) {
		t.Fatalf("expected PacketTypeData=%d, got %d", api.PacketTypeData, resp[0])
	}

	rest := resp[headerLen:]
	// readField consumes one uvarint length prefix and that many bytes from rest.
	readField := func(field string) []byte {
		size, n := binary.Uvarint(rest)
		require.Greater(t, n, 0, "%s length is not a valid uvarint", field)
		rest = rest[n:]
		// Compare as uint64 before converting so a huge length cannot wrap
		// into a negative or small int.
		require.LessOrEqual(t, size, uint64(len(rest)), "%s length exceeds remaining packet bytes", field)
		out := rest[:int(size)]
		rest = rest[int(size):]
		return out
	}

	dataType := string(readField("data type"))
	payload := readField("payload")
	return dataType, payload
}
// TestWSHandler_ValidateRequest_TaskID_RunManifestMissingForRunning_Fails queues a
// task with status "running" without writing any run manifest for it, sends a
// validate request for that task ID, and expects both the overall report and the
// "run_manifest" check to be not ok.
func TestWSHandler_ValidateRequest_TaskID_RunManifestMissingForRunning_Fails(t *testing.T) {
	server, tq, expMgr, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// 40 hex characters: "61" repeated 20 times.
	commitIDStr := strings.Repeat("61", 20)
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	filesPath := expMgr.GetFilesPath(commitIDStr)
	depsContent := []byte("numpy==1.0.0\n")
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), depsContent, 0600))
	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))

	// SHA-256 of the requirements file content recorded in the task metadata.
	depsSum := sha256.Sum256(depsContent)
	depSha := hex.EncodeToString(depsSum[:])

	taskID := "task-run-manifest-missing"
	task := &queue.Task{
		ID:        taskID,
		JobName:   "job",
		Status:    "running",
		Priority:  1,
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
		Metadata: map[string]string{
			"commit_id":                       commitIDStr,
			"experiment_manifest_overall_sha": man.OverallSHA,
			"deps_manifest_name":              "requirements.txt",
			"deps_manifest_sha256":            depSha,
		},
	}
	require.NoError(t, tq.AddTask(task))

	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	// Frame as built here: opcode, 16-byte API key hash, one byte set to 1
	// (presumably the "look up by task ID" selector — confirm against the
	// handler), one length byte, then the task ID.
	msg := make([]byte, 0, 1+16+1+1+len(taskID))
	msg = append(msg, byte(wspkg.OpcodeValidateRequest))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, byte(1))
	msg = append(msg, byte(len(taskID)))
	msg = append(msg, []byte(taskID)...)
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	dataType, payload := decodeDataPacket(t, resp)
	require.Equal(t, "validate", dataType)

	var report map[string]any
	require.NoError(t, json.Unmarshal(payload, &report))
	// Compare raw interface values and use checked assertions so a missing or
	// mistyped field fails the test instead of panicking.
	require.Equal(t, false, report["ok"])
	checks, ok := report["checks"].(map[string]any)
	require.True(t, ok, "report[\"checks\"] should be a JSON object")
	runManifestCheck, ok := checks["run_manifest"].(map[string]any)
	require.True(t, ok, "checks[\"run_manifest\"] should be a JSON object")
	require.Equal(t, false, runManifestCheck["ok"])
}
// TestWSHandler_ValidateRequest_TaskID_RunManifestCommitMismatch_Fails queues a
// "completed" task and writes a run manifest under the "finished" directory whose
// CommitID differs from the task's commit_id metadata. The validate report and its
// "run_manifest_commit_id" check must be not ok, and the check's "expected" value
// must be the task's commit ID.
func TestWSHandler_ValidateRequest_TaskID_RunManifestCommitMismatch_Fails(t *testing.T) {
	server, tq, expMgr, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// 40 hex characters: "61" repeated 20 times.
	commitIDStr := strings.Repeat("61", 20)
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	filesPath := expMgr.GetFilesPath(commitIDStr)
	depsContent := []byte("numpy==1.0.0\n")
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), depsContent, 0600))
	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))

	// SHA-256 of the requirements file content recorded in the task metadata.
	depsSum := sha256.Sum256(depsContent)
	depSha := hex.EncodeToString(depsSum[:])

	taskID := "task-run-manifest-commit-mismatch"
	task := &queue.Task{
		ID:        taskID,
		JobName:   "job",
		Status:    "completed",
		Priority:  1,
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
		Metadata: map[string]string{
			"commit_id":                       commitIDStr,
			"experiment_manifest_overall_sha": man.OverallSHA,
			"deps_manifest_name":              "requirements.txt",
			"deps_manifest_sha256":            depSha,
		},
	}
	require.NoError(t, tq.AddTask(task))

	jobDir := filepath.Join(expMgr.BasePath(), "finished", task.JobName)
	require.NoError(t, os.MkdirAll(jobDir, 0750))
	runMan := manifest.NewRunManifest("run-test", task.ID, task.JobName, task.CreatedAt)
	// Deliberately different from commitIDStr: this is the mismatch under test.
	runMan.CommitID = strings.Repeat("62", 20)
	runMan.DepsManifestName = "requirements.txt"
	runMan.DepsManifestSHA = depSha
	runMan.MarkStarted(time.Now().UTC().Add(-2 * time.Second))
	exitCode := 0
	runMan.MarkFinished(time.Now().UTC(), &exitCode, nil)
	require.NoError(t, runMan.WriteToDir(jobDir))

	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	// Frame as built here: opcode, 16-byte API key hash, one byte set to 1
	// (presumably the "look up by task ID" selector — confirm against the
	// handler), one length byte, then the task ID.
	msg := make([]byte, 0, 1+16+1+1+len(taskID))
	msg = append(msg, byte(wspkg.OpcodeValidateRequest))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, byte(1))
	msg = append(msg, byte(len(taskID)))
	msg = append(msg, []byte(taskID)...)
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	dataType, payload := decodeDataPacket(t, resp)
	require.Equal(t, "validate", dataType)

	var report map[string]any
	require.NoError(t, json.Unmarshal(payload, &report))
	// Compare raw interface values and use checked assertions so a missing or
	// mistyped field fails the test instead of panicking.
	require.Equal(t, false, report["ok"])
	checks, ok := report["checks"].(map[string]any)
	require.True(t, ok, "report[\"checks\"] should be a JSON object")
	commitCheck, ok := checks["run_manifest_commit_id"].(map[string]any)
	require.True(t, ok, "checks[\"run_manifest_commit_id\"] should be a JSON object")
	require.Equal(t, false, commitCheck["ok"])
	require.Equal(t, commitIDStr, commitCheck["expected"])
}
// TestWSHandler_ValidateRequest_TaskID_RunManifestLocationMismatch_Fails queues a
// task with status "running" but writes its run manifest under the "finished"
// directory. The validate report and its "run_manifest_location" check must be
// not ok, with "expected" equal to "running" and "actual" equal to "finished".
func TestWSHandler_ValidateRequest_TaskID_RunManifestLocationMismatch_Fails(t *testing.T) {
	server, tq, expMgr, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// 40 hex characters: "61" repeated 20 times.
	commitIDStr := strings.Repeat("61", 20)
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	filesPath := expMgr.GetFilesPath(commitIDStr)
	depsContent := []byte("numpy==1.0.0\n")
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), depsContent, 0600))
	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))

	// SHA-256 of the requirements file content recorded in the task metadata.
	depsSum := sha256.Sum256(depsContent)
	depSha := hex.EncodeToString(depsSum[:])

	taskID := "task-run-manifest-location-mismatch"
	task := &queue.Task{
		ID:        taskID,
		JobName:   "job",
		Status:    "running",
		Priority:  1,
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
		Metadata: map[string]string{
			"commit_id":                       commitIDStr,
			"experiment_manifest_overall_sha": man.OverallSHA,
			"deps_manifest_name":              "requirements.txt",
			"deps_manifest_sha256":            depSha,
		},
	}
	require.NoError(t, tq.AddTask(task))

	// Intentionally write manifest to the wrong bucket.
	jobDir := filepath.Join(expMgr.BasePath(), "finished", task.JobName)
	require.NoError(t, os.MkdirAll(jobDir, 0750))
	runMan := manifest.NewRunManifest("run-test", task.ID, task.JobName, task.CreatedAt)
	runMan.CommitID = commitIDStr
	runMan.DepsManifestName = "requirements.txt"
	runMan.DepsManifestSHA = depSha
	runMan.MarkStarted(time.Now().UTC().Add(-2 * time.Second))
	require.NoError(t, runMan.WriteToDir(jobDir))

	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	// Frame as built here: opcode, 16-byte API key hash, one byte set to 1
	// (presumably the "look up by task ID" selector — confirm against the
	// handler), one length byte, then the task ID.
	msg := make([]byte, 0, 1+16+1+1+len(taskID))
	msg = append(msg, byte(wspkg.OpcodeValidateRequest))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, byte(1))
	msg = append(msg, byte(len(taskID)))
	msg = append(msg, []byte(taskID)...)
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	dataType, payload := decodeDataPacket(t, resp)
	require.Equal(t, "validate", dataType)

	var report map[string]any
	require.NoError(t, json.Unmarshal(payload, &report))
	// Compare raw interface values and use checked assertions so a missing or
	// mistyped field fails the test instead of panicking.
	require.Equal(t, false, report["ok"])
	checks, ok := report["checks"].(map[string]any)
	require.True(t, ok, "report[\"checks\"] should be a JSON object")
	loc, ok := checks["run_manifest_location"].(map[string]any)
	require.True(t, ok, "checks[\"run_manifest_location\"] should be a JSON object")
	require.Equal(t, false, loc["ok"])
	require.Equal(t, "running", loc["expected"])
	require.Equal(t, "finished", loc["actual"])
}
// TestWSHandler_ValidateRequest_TaskID_RunManifestLifecycleOrdering_Fails queues a
// "completed" task and writes a run manifest whose finish time is one second
// before its start time. The validate report and its "run_manifest_lifecycle"
// check must be not ok.
func TestWSHandler_ValidateRequest_TaskID_RunManifestLifecycleOrdering_Fails(t *testing.T) {
	server, tq, expMgr, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// 40 hex characters: "61" repeated 20 times.
	commitIDStr := strings.Repeat("61", 20)
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	filesPath := expMgr.GetFilesPath(commitIDStr)
	depsContent := []byte("numpy==1.0.0\n")
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), depsContent, 0600))
	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))

	// SHA-256 of the requirements file content recorded in the task metadata.
	depsSum := sha256.Sum256(depsContent)
	depSha := hex.EncodeToString(depsSum[:])

	taskID := "task-run-manifest-lifecycle-ordering"
	task := &queue.Task{
		ID:        taskID,
		JobName:   "job",
		Status:    "completed",
		Priority:  1,
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
		Metadata: map[string]string{
			"commit_id":                       commitIDStr,
			"experiment_manifest_overall_sha": man.OverallSHA,
			"deps_manifest_name":              "requirements.txt",
			"deps_manifest_sha256":            depSha,
		},
	}
	require.NoError(t, tq.AddTask(task))

	jobDir := filepath.Join(expMgr.BasePath(), "finished", task.JobName)
	require.NoError(t, os.MkdirAll(jobDir, 0750))
	runMan := manifest.NewRunManifest("run-test", task.ID, task.JobName, task.CreatedAt)
	runMan.CommitID = commitIDStr
	runMan.DepsManifestName = "requirements.txt"
	runMan.DepsManifestSHA = depSha
	// Finish precedes start: this inverted ordering is the condition under test.
	start := time.Now().UTC()
	end := start.Add(-1 * time.Second)
	runMan.MarkStarted(start)
	exitCode := 0
	runMan.MarkFinished(end, &exitCode, nil)
	require.NoError(t, runMan.WriteToDir(jobDir))

	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	// Frame as built here: opcode, 16-byte API key hash, one byte set to 1
	// (presumably the "look up by task ID" selector — confirm against the
	// handler), one length byte, then the task ID.
	msg := make([]byte, 0, 1+16+1+1+len(taskID))
	msg = append(msg, byte(wspkg.OpcodeValidateRequest))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, byte(1))
	msg = append(msg, byte(len(taskID)))
	msg = append(msg, []byte(taskID)...)
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	dataType, payload := decodeDataPacket(t, resp)
	require.Equal(t, "validate", dataType)

	var report map[string]any
	require.NoError(t, json.Unmarshal(payload, &report))
	// Compare raw interface values and use checked assertions so a missing or
	// mistyped field fails the test instead of panicking.
	require.Equal(t, false, report["ok"])
	checks, ok := report["checks"].(map[string]any)
	require.True(t, ok, "report[\"checks\"] should be a JSON object")
	lifecycle, ok := checks["run_manifest_lifecycle"].(map[string]any)
	require.True(t, ok, "checks[\"run_manifest_lifecycle\"] should be a JSON object")
	require.Equal(t, false, lifecycle["ok"])
}
// TestWSHandler_ValidateRequest_TaskID_InvalidResources queues a "queued" task
// whose CPU field is -1, validates it by task ID, and expects both the overall
// report and the "resources" check to be not ok.
func TestWSHandler_ValidateRequest_TaskID_InvalidResources(t *testing.T) {
	server, tq, expMgr, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// 40 hex characters: "61" repeated 20 times.
	commitIDStr := strings.Repeat("61", 20)
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	filesPath := expMgr.GetFilesPath(commitIDStr)
	depsContent := []byte("numpy==1.0.0\n")
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), depsContent, 0600))
	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))

	// SHA-256 of the requirements file content recorded in the task metadata.
	depsSum := sha256.Sum256(depsContent)
	depSha := hex.EncodeToString(depsSum[:])

	taskID := "task-invalid-resources"
	task := &queue.Task{
		ID:        taskID,
		JobName:   "job",
		Status:    "queued",
		Priority:  1,
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
		Metadata: map[string]string{
			"commit_id":                       commitIDStr,
			"experiment_manifest_overall_sha": man.OverallSHA,
			"deps_manifest_name":              "requirements.txt",
			"deps_manifest_sha256":            depSha,
		},
		CPU: -1, // the invalid resource value under test
	}
	require.NoError(t, tq.AddTask(task))

	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	// Frame as built here: opcode, 16-byte API key hash, one byte set to 1
	// (presumably the "look up by task ID" selector — confirm against the
	// handler), one length byte, then the task ID.
	msg := make([]byte, 0, 1+16+1+1+len(taskID))
	msg = append(msg, byte(wspkg.OpcodeValidateRequest))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, byte(1))
	msg = append(msg, byte(len(taskID)))
	msg = append(msg, []byte(taskID)...)
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	dataType, payload := decodeDataPacket(t, resp)
	require.Equal(t, "validate", dataType)

	var report map[string]any
	require.NoError(t, json.Unmarshal(payload, &report))
	// Compare raw interface values and use checked assertions so a missing or
	// mistyped field fails the test instead of panicking.
	require.Equal(t, false, report["ok"])
	checks, ok := report["checks"].(map[string]any)
	require.True(t, ok, "report[\"checks\"] should be a JSON object")
	res, ok := checks["resources"].(map[string]any)
	require.True(t, ok, "checks[\"resources\"] should be a JSON object")
	require.Equal(t, false, res["ok"])
}
// TestWSHandler_ValidateRequest_TaskID_SnapshotMismatch creates a snapshot
// directory under the handler's data dir, queues a task whose "snapshot_sha256"
// metadata is 64 zero characters, and validates it by task ID. The report and its
// "snapshot" check must be not ok, and the check's "actual" value must equal the
// hash worker.DirOverallSHA256Hex computes for the snapshot directory.
func TestWSHandler_ValidateRequest_TaskID_SnapshotMismatch(t *testing.T) {
	dataDir := t.TempDir()
	server, tq, expMgr, s, db := setupWSIntegrationServerWithDataDir(t, dataDir)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// 40 hex characters: "61" repeated 20 times.
	commitIDStr := strings.Repeat("61", 20)
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	filesPath := expMgr.GetFilesPath(commitIDStr)
	depsContent := []byte("numpy==1.0.0\n")
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), depsContent, 0600))
	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))

	// SHA-256 of the requirements file content recorded in the task metadata.
	depsSum := sha256.Sum256(depsContent)
	depSha := hex.EncodeToString(depsSum[:])

	// Snapshot fixture: <dataDir>/snapshots/<snapshotID>/hello.txt.
	snapshotID := "snap-1"
	snapPath := filepath.Join(dataDir, "snapshots", snapshotID)
	require.NoError(t, os.MkdirAll(snapPath, 0750))
	require.NoError(t, os.WriteFile(filepath.Join(snapPath, "hello.txt"), []byte("hello"), 0600))
	actualSnap, err := worker.DirOverallSHA256Hex(snapPath)
	require.NoError(t, err)

	taskID := "task-snap-mismatch"
	task := &queue.Task{
		ID:         taskID,
		JobName:    "job",
		Status:     "queued",
		Priority:   1,
		CreatedAt:  time.Now(),
		UserID:     "user",
		CreatedBy:  "user",
		SnapshotID: snapshotID,
		Metadata: map[string]string{
			"commit_id":                       commitIDStr,
			"experiment_manifest_overall_sha": man.OverallSHA,
			"deps_manifest_name":              "requirements.txt",
			"deps_manifest_sha256":            depSha,
			// Deliberately wrong hash: this is the mismatch under test.
			"snapshot_sha256": strings.Repeat("0", 64),
		},
	}
	require.NoError(t, tq.AddTask(task))

	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	// Frame as built here: opcode, 16-byte API key hash, one byte set to 1
	// (presumably the "look up by task ID" selector — confirm against the
	// handler), one length byte, then the task ID.
	msg := make([]byte, 0, 1+16+1+1+len(taskID))
	msg = append(msg, byte(wspkg.OpcodeValidateRequest))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, byte(1))
	msg = append(msg, byte(len(taskID)))
	msg = append(msg, []byte(taskID)...)
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)
	dataType, payload := decodeDataPacket(t, resp)
	require.Equal(t, "validate", dataType)

	var report map[string]any
	require.NoError(t, json.Unmarshal(payload, &report))
	// Compare raw interface values and use checked assertions so a missing or
	// mistyped field fails the test instead of panicking.
	require.Equal(t, false, report["ok"])
	checks, ok := report["checks"].(map[string]any)
	require.True(t, ok, "report[\"checks\"] should be a JSON object")
	snap, ok := checks["snapshot"].(map[string]any)
	require.True(t, ok, "checks[\"snapshot\"] should be a JSON object")
	require.Equal(t, false, snap["ok"])
	require.Equal(t, actualSnap, snap["actual"])
}
// setupWSIntegrationServer starts the WebSocket integration test server with an
// empty data directory argument.
//
// It is a thin wrapper over setupWSIntegrationServerWithDataDir so the two
// fixtures cannot drift apart. The caller owns every returned value and is
// expected to close them.
func setupWSIntegrationServer(t *testing.T) (
	*httptest.Server,
	*queue.TaskQueue,
	*experiment.Manager,
	*miniredis.Miniredis,
	*storage.DB,
) {
	t.Helper()
	return setupWSIntegrationServerWithDataDir(t, "")
}
// connectWSIntegration dials the test server over WebSocket and returns the open
// connection, failing the test if the dial fails.
//
// serverURL is the server's http(s) URL; its "http" prefix is replaced with "ws"
// (so "https://..." becomes "wss://..."). The caller must close the returned
// connection.
func connectWSIntegration(t *testing.T, serverURL string) *websocket.Conn {
	// Attribute dial failures to the calling test rather than this helper.
	t.Helper()

	wsURL := "ws" + strings.TrimPrefix(serverURL, "http")
	ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
	// Close the handshake response body before asserting so it is released
	// even when the dial failed.
	if resp != nil && resp.Body != nil {
		_ = resp.Body.Close()
	}
	require.NoError(t, err)
	return ws
}
// TestWSHandler_QueueJob_Integration sends a queue_job frame that includes the
// optional resource request tail, expects a success packet, and then checks that
// the next task in the queue carries the requested job name and resources.
func TestWSHandler_QueueJob_Integration(t *testing.T) {
	server, tq, expMgr, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Prepare queue_job message
	// Protocol: [opcode:1][api_key_hash:16][commit_id:20][priority:1][job_name_len:1][job_name:var]
	opcode := byte(wspkg.OpcodeQueueJob)
	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
	// commitID is twenty 'a' bytes; commitIDStr is the same value hex-encoded
	// ('a' is 0x61).
	commitID := make([]byte, 20)
	copy(commitID, []byte(strings.Repeat("a", 20)))
	commitIDStr := strings.Repeat("61", 20)

	// Pre-create experiment files so enqueue can compute expected provenance (deps manifest + manifest overall sha).
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	filesPath := expMgr.GetFilesPath(commitIDStr)
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "train.py"), []byte("print('ok')\n"), 0600))
	require.NoError(t, os.WriteFile(filepath.Join(filesPath, "requirements.txt"), []byte("numpy==1.0.0\n"), 0600))
	man, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(man))

	priority := byte(5)
	jobName := "test-job"
	gpuMem := "8GB"

	msg := make([]byte, 0, 1+len(apiKeyHash)+len(commitID)+1+1+len(jobName)+3+1+len(gpuMem))
	msg = append(msg, opcode)
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)
	msg = append(msg, priority)
	msg = append(msg, byte(len(jobName)))
	msg = append(msg, []byte(jobName)...)
	// Optional resource request tail: [cpu:1][memory_gb:1][gpu:1][gpu_mem_len:1][gpu_mem:var]
	msg = append(msg, byte(4))  // cpu
	msg = append(msg, byte(16)) // memory_gb
	msg = append(msg, byte(1))  // gpu
	msg = append(msg, byte(len(gpuMem)))
	msg = append(msg, []byte(gpuMem)...)

	// Send message
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Guard the index below: an empty reply must fail the test, not panic.
	require.NotEmpty(t, resp)
	// Verify success response
	assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])

	// Verify task in Redis
	time.Sleep(100 * time.Millisecond)
	task, err := tq.GetNextTask()
	require.NoError(t, err)
	require.NotNil(t, task)
	assert.Equal(t, jobName, task.JobName)
	assert.Equal(t, 4, task.CPU)
	assert.Equal(t, 16, task.MemoryGB)
	assert.Equal(t, 1, task.GPU)
	assert.Equal(t, gpuMem, task.GPUMemory)
}
// TestWSHandler_StatusRequest_Integration adds one queued task, sends a
// status_request frame, and expects the reply to be a data packet. Only the
// packet type is asserted; the payload is not inspected.
func TestWSHandler_StatusRequest_Integration(t *testing.T) {
	server, tq, _, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	// Add a task to queue
	task := &queue.Task{
		ID:        "task-1",
		JobName:   "job-1",
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		UserID:    "user",
		CreatedBy: "user",
	}
	require.NoError(t, tq.AddTask(task))

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Prepare status_request message
	// Protocol: [opcode:1][api_key_hash:16]
	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	msg := make([]byte, 0, 1+len(apiKeyHash))
	msg = append(msg, byte(wspkg.OpcodeStatusRequest))
	msg = append(msg, apiKeyHash...)

	// Send message
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Guard the index below: an empty reply must fail the test, not panic.
	require.NotEmpty(t, resp)
	// Status replies carry a payload, so a data packet is expected.
	assert.Equal(t, byte(api.PacketTypeData), resp[0])
}
// TestWSHandler_ValidateRequest_Integration validates an experiment by commit ID
// (rather than by task ID) and expects an ok report that echoes the hex commit ID.
func TestWSHandler_ValidateRequest_Integration(t *testing.T) {
	server, tq, expMgr, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	conn := connectWSIntegration(t, server.URL)
	defer func() { _ = conn.Close() }()

	// Raw commit ID is twenty 'a' bytes; its hex form is "61" repeated 20 times.
	rawCommit := []byte(strings.Repeat("a", 20))
	commitIDStr := strings.Repeat("61", 20)

	// Experiment fixture: two source files plus a generated, persisted manifest.
	require.NoError(t, expMgr.CreateExperiment(commitIDStr))
	srcDir := expMgr.GetFilesPath(commitIDStr)
	trainPath := filepath.Join(srcDir, "train.py")
	require.NoError(t, os.WriteFile(trainPath, []byte("print('ok')\n"), 0600))
	reqPath := filepath.Join(srcDir, "requirements.txt")
	require.NoError(t, os.WriteFile(reqPath, []byte("numpy==1.0.0\n"), 0600))
	expManifest, err := expMgr.GenerateManifest(commitIDStr)
	require.NoError(t, err)
	require.NoError(t, expMgr.WriteManifest(expManifest))

	// Frame: opcode, 16-byte API key hash of ASCII '0', a zero byte, a length
	// byte of 20, then the raw commit ID.
	keyHash := []byte(strings.Repeat("0", 16))
	frame := []byte{byte(wspkg.OpcodeValidateRequest)}
	frame = append(frame, keyHash...)
	frame = append(frame, byte(0), byte(20))
	frame = append(frame, rawCommit...)
	require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, frame))

	_, reply, err := conn.ReadMessage()
	require.NoError(t, err)
	kind, body := decodeDataPacket(t, reply)
	require.Equal(t, "validate", kind)

	type validateReport struct {
		OK       bool   `json:"ok"`
		CommitID string `json:"commit_id"`
	}
	var got validateReport
	require.NoError(t, json.Unmarshal(body, &got))
	require.True(t, got.OK)
	require.Equal(t, commitIDStr, got.CommitID)
}
// TestWSHandler_CancelJob_Integration queues a task, sends a cancel_job frame for
// its job name, expects a success packet, and then checks that the stored task's
// status is "cancelled".
func TestWSHandler_CancelJob_Integration(t *testing.T) {
	server, tq, _, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	// Add a task to queue
	task := &queue.Task{
		ID:        "task-1",
		JobName:   "job-to-cancel",
		Status:    "queued",
		Priority:  10,
		CreatedAt: time.Now(),
		UserID:    "user", // Auth disabled so this matches any user
		CreatedBy: "user",
	}
	require.NoError(t, tq.AddTask(task))

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Prepare cancel_job message
	// Protocol: [opcode:1][api_key_hash:16][job_name_len:1][job_name:var]
	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
	jobName := "job-to-cancel"

	msg := make([]byte, 0, 1+len(apiKeyHash)+1+len(jobName))
	msg = append(msg, byte(wspkg.OpcodeCancelJob))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, byte(len(jobName)))
	msg = append(msg, []byte(jobName)...)

	// Send message
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Guard the index below: an empty reply must fail the test, not panic.
	require.NotEmpty(t, resp)
	// Verify success response
	assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])

	// Verify task cancelled
	updatedTask, err := tq.GetTask("task-1")
	require.NoError(t, err)
	// A nil task with a nil error would otherwise panic on the field access.
	require.NotNil(t, updatedTask)
	assert.Equal(t, "cancelled", updatedTask.Status)
}
// TestWSHandler_Prune_Integration sends a prune frame with prune type 0 and a
// value of 1 and expects a success packet. The set of experiments remaining after
// the prune is not inspected.
func TestWSHandler_Prune_Integration(t *testing.T) {
	server, tq, expManager, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	// Create some experiments. Creation errors are deliberately not asserted.
	// NOTE(review): confirm these IDs are accepted by CreateExperiment before
	// tightening this to require.NoError.
	_ = expManager.CreateExperiment("commit-1")
	_ = expManager.CreateExperiment("commit-2")

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Prepare prune message
	// Protocol: [opcode:1][api_key_hash:16][prune_type:1][value:4]
	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
	pruneType := byte(0) // Keep N
	value := uint32(1)   // Keep 1
	valueBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(valueBytes, value)

	msg := make([]byte, 0, 1+len(apiKeyHash)+1+len(valueBytes))
	msg = append(msg, byte(wspkg.OpcodePrune))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, pruneType)
	msg = append(msg, valueBytes...)

	// Send message
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Guard the index below: an empty reply must fail the test, not panic.
	require.NotEmpty(t, resp)
	// Verify success response
	assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])
}
// TestWSHandler_LogMetric_Integration creates an experiment with metadata, sends a
// log_metric frame (step 100, value 0.95, name "accuracy") for it, and expects a
// success packet.
func TestWSHandler_LogMetric_Integration(t *testing.T) {
	server, tq, expManager, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	// Create experiment. The 20-character ID doubles as the 20-byte commit_id
	// field of the frame below.
	commitIDStr := strings.Repeat("a", 20)
	require.NoError(t, expManager.CreateExperiment(commitIDStr))

	// Write metadata to ensure proper initialization
	meta := &experiment.Metadata{
		CommitID: commitIDStr,
		JobName:  "test-job",
	}
	require.NoError(t, expManager.WriteMetadata(meta))

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Prepare log_metric message
	// Protocol: [opcode:1][api_key_hash:16][commit_id:20][step:4][value:8][name_len:1][name:var]
	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
	commitID := []byte(commitIDStr)
	metricName := "accuracy"

	// Step is a big-endian uint32; the value is the big-endian IEEE 754 bit
	// pattern of the float64.
	stepBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(stepBytes, uint32(100))
	valueBytes := make([]byte, 8)
	binary.BigEndian.PutUint64(valueBytes, math.Float64bits(0.95))

	msg := make([]byte, 0, 1+len(apiKeyHash)+len(commitID)+len(stepBytes)+len(valueBytes)+1+len(metricName))
	msg = append(msg, byte(wspkg.OpcodeLogMetric))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)
	msg = append(msg, stepBytes...)
	msg = append(msg, valueBytes...)
	msg = append(msg, byte(len(metricName)))
	msg = append(msg, []byte(metricName)...)

	// Send message
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Guard the index below: an empty reply must fail the test, not panic.
	require.NotEmpty(t, resp)
	// Verify success response
	assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])
}
// TestWSHandler_GetExperiment_Integration creates an experiment with metadata,
// sends a get_experiment frame for its commit ID, and asserts that the reply is an
// error packet.
//
// NOTE(review): the error expectation holds even though the experiment and its
// metadata were just created; the reason is not visible in this file — confirm it
// is intentional.
func TestWSHandler_GetExperiment_Integration(t *testing.T) {
	server, tq, expManager, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	// Create experiment and metadata. The 20-character ID doubles as the
	// 20-byte commit_id field of the frame below.
	commitIDStr := strings.Repeat("a", 20)
	require.NoError(t, expManager.CreateExperiment(commitIDStr))
	meta := &experiment.Metadata{
		CommitID: commitIDStr,
		JobName:  "test-job",
	}
	require.NoError(t, expManager.WriteMetadata(meta))

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	// Prepare get_experiment message
	// Protocol: [opcode:1][api_key_hash:16][commit_id:20]
	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))
	commitID := []byte(commitIDStr)

	msg := make([]byte, 0, 1+len(apiKeyHash)+len(commitID))
	msg = append(msg, byte(wspkg.OpcodeGetExperiment))
	msg = append(msg, apiKeyHash...)
	msg = append(msg, commitID...)

	// Send message
	require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

	// Read response
	_, resp, err := ws.ReadMessage()
	require.NoError(t, err)

	// Guard the index below: an empty reply must fail the test, not panic.
	require.NotEmpty(t, resp)
	// Verify error response (PacketTypeError)
	assert.Equal(t, byte(api.PacketTypeError), resp[0])
}
// TestWSHandler_DatasetListRegisterInfoSearch_Integration drives four dataset
// opcodes over a single connection, in order: list, register, info, search. For
// each reply only the packet type is asserted; payloads are not decoded.
func TestWSHandler_DatasetListRegisterInfoSearch_Integration(t *testing.T) {
	server, tq, _, s, db := setupWSIntegrationServer(t)
	defer server.Close()
	defer func() { _ = tq.Close() }()
	defer s.Close()
	defer func() { _ = db.Close() }()

	ws := connectWSIntegration(t, server.URL)
	defer func() { _ = ws.Close() }()

	apiKeyHash := make([]byte, 16)
	copy(apiKeyHash, []byte(strings.Repeat("0", 16)))

	// 1) List before anything is registered: expect a data packet.
	{
		msg := make([]byte, 0, 1+16)
		msg = append(msg, byte(wspkg.OpcodeDatasetList))
		msg = append(msg, apiKeyHash...)
		require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

		_, resp, err := ws.ReadMessage()
		require.NoError(t, err)
		// Guard the index below: an empty reply must fail the test, not panic.
		require.NotEmpty(t, resp)
		assert.Equal(t, byte(api.PacketTypeData), resp[0])
	}

	// 2) Register dataset
	// Frame: [opcode:1][api_key_hash:16][name_len:1][name][url_len:2 big-endian][url]
	name := "mnist"
	urlStr := "https://example.com/mnist.tar.gz"
	{
		msg := make([]byte, 0, 1+16+1+len(name)+2+len(urlStr))
		msg = append(msg, byte(wspkg.OpcodeDatasetRegister))
		msg = append(msg, apiKeyHash...)
		msg = append(msg, byte(len(name)))
		msg = append(msg, []byte(name)...)
		urlLen := make([]byte, 2)
		binary.BigEndian.PutUint16(urlLen, uint16(len(urlStr)))
		msg = append(msg, urlLen...)
		msg = append(msg, []byte(urlStr)...)
		require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

		_, resp, err := ws.ReadMessage()
		require.NoError(t, err)
		require.NotEmpty(t, resp)
		assert.Equal(t, byte(api.PacketTypeSuccess), resp[0])
	}

	// 3) Info should return PacketTypeData
	{
		msg := make([]byte, 0, 1+16+1+len(name))
		msg = append(msg, byte(wspkg.OpcodeDatasetInfo))
		msg = append(msg, apiKeyHash...)
		msg = append(msg, byte(len(name)))
		msg = append(msg, []byte(name)...)
		require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

		_, resp, err := ws.ReadMessage()
		require.NoError(t, err)
		require.NotEmpty(t, resp)
		assert.Equal(t, byte(api.PacketTypeData), resp[0])
	}

	// 4) Search should return PacketTypeData
	term := "mn"
	{
		msg := make([]byte, 0, 1+16+1+len(term))
		msg = append(msg, byte(wspkg.OpcodeDatasetSearch))
		msg = append(msg, apiKeyHash...)
		msg = append(msg, byte(len(term)))
		msg = append(msg, []byte(term)...)
		require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, msg))

		_, resp, err := ws.ReadMessage()
		require.NoError(t, err)
		require.NotEmpty(t, resp)
		assert.Equal(t, byte(api.PacketTypeData), resp[0])
	}
}