Add new scheduler component for distributed ML workload orchestration: - Hub-based coordination for multi-worker clusters - Pacing controller for rate limiting job submissions - Priority queue with preemption support - Port allocator for dynamic service discovery - Protocol handlers for worker-scheduler communication - Service manager with OS-specific implementations - Connection management and state persistence - Template system for service deployment Includes comprehensive test suite: - Unit tests for all core components - Integration tests for distributed scenarios - Benchmark tests for performance validation - Mock fixtures for isolated testing Refs: scheduler-architecture.md
228 lines
5.1 KiB
Go
228 lines
5.1 KiB
Go
// Package tests provides shared test utilities for all tests.
|
|
package tests
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/scheduler"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// MockWorker simulates a worker connection for testing
|
|
type MockWorker struct {
|
|
Conn *websocket.Conn
|
|
ID string
|
|
RecvCh chan scheduler.Message
|
|
SendCh chan scheduler.Message
|
|
wg sync.WaitGroup
|
|
mu sync.RWMutex
|
|
closed bool
|
|
T testing.TB
|
|
}
|
|
|
|
// NewMockWorker creates a new mock worker connected to the scheduler
|
|
func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) *MockWorker {
|
|
addr := hub.Addr()
|
|
require.NotEmpty(t, addr, "hub not started")
|
|
|
|
wsURL := "ws://" + addr + "/ws/worker"
|
|
|
|
// Add test token to headers
|
|
header := http.Header{}
|
|
header.Set("Authorization", "Bearer test-token-"+workerID)
|
|
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, header)
|
|
require.NoError(t, err)
|
|
|
|
mw := &MockWorker{
|
|
Conn: conn,
|
|
ID: workerID,
|
|
RecvCh: make(chan scheduler.Message, 100),
|
|
SendCh: make(chan scheduler.Message, 100),
|
|
T: t,
|
|
}
|
|
|
|
// Start receive goroutine
|
|
mw.wg.Add(1)
|
|
go func() {
|
|
defer mw.wg.Done()
|
|
for {
|
|
var msg scheduler.Message
|
|
err := conn.ReadJSON(&msg)
|
|
if err != nil {
|
|
close(mw.RecvCh)
|
|
return
|
|
}
|
|
mw.RecvCh <- msg
|
|
}
|
|
}()
|
|
|
|
// Start send goroutine
|
|
mw.wg.Add(1)
|
|
go func() {
|
|
defer mw.wg.Done()
|
|
for msg := range mw.SendCh {
|
|
if err := conn.WriteJSON(msg); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return mw
|
|
}
|
|
|
|
// Register sends worker registration message and waits for ack
|
|
func (mw *MockWorker) Register(capabilities scheduler.WorkerCapabilities) {
|
|
mw.Send(scheduler.Message{
|
|
Type: scheduler.MsgRegister,
|
|
Payload: MustMarshal(scheduler.WorkerRegistration{
|
|
ID: mw.ID,
|
|
Capabilities: capabilities,
|
|
}),
|
|
})
|
|
|
|
msg := mw.RecvTimeout(2 * time.Second)
|
|
require.Equal(mw.T, scheduler.MsgAck, msg.Type, "expected registration ack")
|
|
}
|
|
|
|
// Send sends a message to the scheduler
|
|
func (mw *MockWorker) Send(msg scheduler.Message) {
|
|
select {
|
|
case mw.SendCh <- msg:
|
|
case <-time.After(time.Second):
|
|
mw.T.Fatal("timeout sending message")
|
|
}
|
|
}
|
|
|
|
// Recv receives a message from the scheduler (blocks)
|
|
func (mw *MockWorker) Recv() scheduler.Message {
|
|
select {
|
|
case msg := <-mw.RecvCh:
|
|
return msg
|
|
case <-time.After(5 * time.Second):
|
|
require.Fail(mw.T, "timeout waiting for message")
|
|
return scheduler.Message{Type: "timeout"}
|
|
}
|
|
}
|
|
|
|
// RecvTimeout receives a message with a custom timeout
|
|
func (mw *MockWorker) RecvTimeout(timeout time.Duration) scheduler.Message {
|
|
select {
|
|
case msg := <-mw.RecvCh:
|
|
return msg
|
|
case <-time.After(timeout):
|
|
require.Fail(mw.T, "timeout waiting for message")
|
|
return scheduler.Message{Type: "timeout"}
|
|
}
|
|
}
|
|
|
|
// RecvNonBlock tries to receive without blocking
|
|
func (mw *MockWorker) RecvNonBlock() (scheduler.Message, bool) {
|
|
select {
|
|
case msg := <-mw.RecvCh:
|
|
return msg, true
|
|
default:
|
|
return scheduler.Message{}, false
|
|
}
|
|
}
|
|
|
|
// SignalReady sends ready for work message
|
|
func (mw *MockWorker) SignalReady(slots scheduler.SlotStatus, reason string) {
|
|
mw.Send(scheduler.Message{
|
|
Type: scheduler.MsgReadyForWork,
|
|
Payload: MustMarshal(scheduler.ReadyPayload{
|
|
WorkerID: mw.ID,
|
|
Slots: slots,
|
|
Reason: reason,
|
|
}),
|
|
})
|
|
}
|
|
|
|
// SendHeartbeat sends a heartbeat message
|
|
func (mw *MockWorker) SendHeartbeat(slots scheduler.SlotStatus) {
|
|
mw.Send(scheduler.Message{
|
|
Type: scheduler.MsgHeartbeat,
|
|
Payload: MustMarshal(scheduler.HeartbeatPayload{
|
|
WorkerID: mw.ID,
|
|
Slots: slots,
|
|
}),
|
|
})
|
|
}
|
|
|
|
// AcceptJob accepts a job assignment
|
|
func (mw *MockWorker) AcceptJob(taskID string) {
|
|
mw.Send(scheduler.Message{
|
|
Type: scheduler.MsgJobAccepted,
|
|
Payload: MustMarshal(scheduler.JobResultPayload{
|
|
TaskID: taskID,
|
|
State: "accepted",
|
|
}),
|
|
})
|
|
}
|
|
|
|
// CompleteJob sends job completion
|
|
func (mw *MockWorker) CompleteJob(taskID string, exitCode int, output string) {
|
|
mw.Send(scheduler.Message{
|
|
Type: scheduler.MsgJobResult,
|
|
Payload: MustMarshal(scheduler.JobResultPayload{
|
|
TaskID: taskID,
|
|
State: "completed",
|
|
ExitCode: exitCode,
|
|
Error: output,
|
|
}),
|
|
})
|
|
}
|
|
|
|
// SendHealth sends service health update
|
|
func (mw *MockWorker) SendHealth(taskID string, healthy bool, message string) {
|
|
mw.Send(scheduler.Message{
|
|
Type: scheduler.MsgServiceHealth,
|
|
Payload: MustMarshal(scheduler.ServiceHealthPayload{
|
|
TaskID: taskID,
|
|
Healthy: healthy,
|
|
Message: message,
|
|
}),
|
|
})
|
|
}
|
|
|
|
// Close closes the worker connection
|
|
func (mw *MockWorker) Close() {
|
|
mw.mu.Lock()
|
|
if mw.closed {
|
|
mw.mu.Unlock()
|
|
return
|
|
}
|
|
mw.closed = true
|
|
mw.mu.Unlock()
|
|
|
|
close(mw.SendCh)
|
|
mw.Conn.Close()
|
|
mw.wg.Wait()
|
|
}
|
|
|
|
// WaitForDisconnect waits for the connection to close
|
|
func (mw *MockWorker) WaitForDisconnect(timeout time.Duration) bool {
|
|
done := make(chan struct{})
|
|
go func() {
|
|
mw.wg.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
return true
|
|
case <-time.After(timeout):
|
|
return false
|
|
}
|
|
}
|
|
|
|
// MustMarshal encodes v as JSON and returns the encoded bytes.
//
// It panics if v cannot be encoded (for example a channel or function value),
// so a broken fixture fails loudly instead of silently sending an empty
// payload.
func MustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic("MustMarshal: " + err.Error())
	}
	return b
}
|