fetch_ml/tests/fixtures/scheduler_mock.go
Jeremie Fraeys 43e6446587
feat(scheduler): implement multi-tenant job scheduler with gang scheduling
Add new scheduler component for distributed ML workload orchestration:
- Hub-based coordination for multi-worker clusters
- Pacing controller for rate limiting job submissions
- Priority queue with preemption support
- Port allocator for dynamic service discovery
- Protocol handlers for worker-scheduler communication
- Service manager with OS-specific implementations
- Connection management and state persistence
- Template system for service deployment

Includes comprehensive test suite:
- Unit tests for all core components
- Integration tests for distributed scenarios
- Benchmark tests for performance validation
- Mock fixtures for isolated testing

Refs: scheduler-architecture.md
2026-02-26 12:03:23 -05:00

228 lines
5.1 KiB
Go

// Package tests provides shared test utilities for all tests.
package tests
import (
"encoding/json"
"net/http"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/scheduler"
"github.com/stretchr/testify/require"
)
// MockWorker simulates a worker connection for testing.
//
// NewMockWorker dials the scheduler hub over a websocket and starts two
// goroutines: one decodes incoming JSON messages into RecvCh, the other writes
// messages taken from SendCh to the connection.
type MockWorker struct {
// Conn is the websocket connection dialed by NewMockWorker.
Conn *websocket.Conn
// ID is the worker ID passed to NewMockWorker; it is embedded in the
// registration, ready and heartbeat payloads.
ID string
// RecvCh carries messages read from Conn. It is closed when a read fails.
RecvCh chan scheduler.Message
// SendCh queues messages to be written to Conn. Close closes it.
SendCh chan scheduler.Message
// wg tracks the receive and send goroutines started by NewMockWorker.
wg sync.WaitGroup
// mu guards closed.
mu sync.RWMutex
// closed is set by Close so that repeated Close calls return early.
closed bool
// T is the test handle used by the helper methods to report failures.
T testing.TB
}
// NewMockWorker creates a new mock worker connected to the scheduler.
//
// It dials the hub's /ws/worker websocket endpoint with an Authorization
// header of the form "Bearer test-token-<workerID>". The test is failed
// through require if the hub reports an empty address or the dial fails.
//
// Two background goroutines are started and tracked by the worker's
// WaitGroup:
//   - the receiver decodes JSON messages from the connection into RecvCh and
//     closes RecvCh when a read fails;
//   - the sender writes messages taken from SendCh to the connection until
//     SendCh is closed, a write fails, or the receiver has exited.
//
// Callers should call Close when they are done with the worker.
func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) *MockWorker {
	t.Helper()

	addr := hub.Addr()
	require.NotEmpty(t, addr, "hub not started")
	wsURL := "ws://" + addr + "/ws/worker"

	// Add test token to headers.
	header := http.Header{}
	header.Set("Authorization", "Bearer test-token-"+workerID)

	conn, _, err := websocket.DefaultDialer.Dial(wsURL, header)
	require.NoError(t, err)

	mw := &MockWorker{
		Conn:   conn,
		ID:     workerID,
		RecvCh: make(chan scheduler.Message, 100),
		SendCh: make(chan scheduler.Message, 100),
		T:      t,
	}

	// readDone is closed when the receiver exits so the sender can stop as
	// well. Without it the sender stays blocked on SendCh after a remote
	// disconnect, and wg.Wait (used by WaitForDisconnect) cannot return until
	// Close is called.
	readDone := make(chan struct{})

	mw.wg.Add(2)

	// Receiver: forward every decoded message to RecvCh until a read fails.
	go func() {
		defer mw.wg.Done()
		defer close(readDone)
		defer close(mw.RecvCh)
		for {
			var msg scheduler.Message
			if err := conn.ReadJSON(&msg); err != nil {
				return
			}
			mw.RecvCh <- msg
		}
	}()

	// Sender: write queued messages until SendCh is closed, a write fails,
	// or the receiver has stopped.
	go func() {
		defer mw.wg.Done()
		for {
			select {
			case msg, ok := <-mw.SendCh:
				if !ok {
					return
				}
				if err := conn.WriteJSON(msg); err != nil {
					return
				}
			case <-readDone:
				return
			}
		}
	}()

	return mw
}
// Register sends a MsgRegister message carrying the worker's ID and the given
// capabilities, then waits up to two seconds for the next incoming message and
// requires it to be of type MsgAck.
func (mw *MockWorker) Register(capabilities scheduler.WorkerCapabilities) {
	registration := scheduler.WorkerRegistration{
		ID:           mw.ID,
		Capabilities: capabilities,
	}
	mw.Send(scheduler.Message{
		Type:    scheduler.MsgRegister,
		Payload: MustMarshal(registration),
	})

	reply := mw.RecvTimeout(2 * time.Second)
	require.Equal(mw.T, scheduler.MsgAck, reply.Type, "expected registration ack")
}
// Send queues msg on SendCh for delivery to the scheduler. If the message
// cannot be queued within one second, the test is failed with T.Fatal.
func (mw *MockWorker) Send(msg scheduler.Message) {
	deadline := time.After(time.Second)
	select {
	case mw.SendCh <- msg:
		return
	case <-deadline:
	}
	mw.T.Fatal("timeout sending message")
}
// Recv receives the next message from the scheduler, waiting at most five
// seconds. On timeout the test is failed via require.Fail and the returned
// value is a Message whose Type is "timeout". If RecvCh has been closed, the
// zero Message is returned immediately.
func (mw *MockWorker) Recv() scheduler.Message {
	return mw.RecvTimeout(5 * time.Second)
}
// RecvTimeout receives the next message from RecvCh, waiting at most timeout.
// On timeout the test is failed via require.Fail; the Message with Type
// "timeout" is the value returned should execution continue past that call.
// If RecvCh has been closed, the zero Message is returned immediately.
func (mw *MockWorker) RecvTimeout(timeout time.Duration) scheduler.Message {
	expired := time.After(timeout)
	select {
	case <-expired:
		require.Fail(mw.T, "timeout waiting for message")
		return scheduler.Message{Type: "timeout"}
	case received := <-mw.RecvCh:
		return received
	}
}
// RecvNonBlock returns a pending message from RecvCh without waiting. The
// boolean is false when nothing was ready to receive. Note that a closed
// RecvCh is always ready, so in that case the zero Message and true are
// returned.
func (mw *MockWorker) RecvNonBlock() (scheduler.Message, bool) {
	select {
	case pending := <-mw.RecvCh:
		return pending, true
	default:
	}
	return scheduler.Message{}, false
}
// SignalReady sends a MsgReadyForWork message whose payload carries the
// worker's ID, the given slot status and the reason string.
func (mw *MockWorker) SignalReady(slots scheduler.SlotStatus, reason string) {
	ready := scheduler.ReadyPayload{
		WorkerID: mw.ID,
		Slots:    slots,
		Reason:   reason,
	}
	mw.Send(scheduler.Message{
		Type:    scheduler.MsgReadyForWork,
		Payload: MustMarshal(ready),
	})
}
// SendHeartbeat sends a MsgHeartbeat message whose payload carries the
// worker's ID and the given slot status.
func (mw *MockWorker) SendHeartbeat(slots scheduler.SlotStatus) {
	beat := scheduler.HeartbeatPayload{
		WorkerID: mw.ID,
		Slots:    slots,
	}
	mw.Send(scheduler.Message{
		Type:    scheduler.MsgHeartbeat,
		Payload: MustMarshal(beat),
	})
}
// AcceptJob sends a MsgJobAccepted message for taskID with the payload state
// set to "accepted".
func (mw *MockWorker) AcceptJob(taskID string) {
	accepted := scheduler.JobResultPayload{
		TaskID: taskID,
		State:  "accepted",
	}
	mw.Send(scheduler.Message{
		Type:    scheduler.MsgJobAccepted,
		Payload: MustMarshal(accepted),
	})
}
// CompleteJob sends a MsgJobResult message for taskID with the payload state
// set to "completed" and the given exit code. The output string is carried in
// the payload's Error field.
func (mw *MockWorker) CompleteJob(taskID string, exitCode int, output string) {
	result := scheduler.JobResultPayload{
		TaskID:   taskID,
		State:    "completed",
		ExitCode: exitCode,
		Error:    output,
	}
	mw.Send(scheduler.Message{
		Type:    scheduler.MsgJobResult,
		Payload: MustMarshal(result),
	})
}
// SendHealth sends a MsgServiceHealth message for taskID carrying the healthy
// flag and a free-form message.
func (mw *MockWorker) SendHealth(taskID string, healthy bool, message string) {
	health := scheduler.ServiceHealthPayload{
		TaskID:  taskID,
		Healthy: healthy,
		Message: message,
	}
	mw.Send(scheduler.Message{
		Type:    scheduler.MsgServiceHealth,
		Payload: MustMarshal(health),
	})
}
// Close shuts the worker down: it closes SendCh, closes the websocket
// connection and waits for the background goroutines tracked by wg to finish.
// Only the first call does any work; later calls return immediately.
func (mw *MockWorker) Close() {
	// Flip the flag under the lock so exactly one caller proceeds to tear down.
	mw.mu.Lock()
	alreadyClosed := mw.closed
	mw.closed = true
	mw.mu.Unlock()

	if alreadyClosed {
		return
	}

	close(mw.SendCh)
	mw.Conn.Close()
	mw.wg.Wait()
}
// WaitForDisconnect waits for the worker's background goroutines (tracked by
// wg) to exit. It returns true if they all finish within timeout and false
// otherwise.
func (mw *MockWorker) WaitForDisconnect(timeout time.Duration) bool {
	// wg.Wait cannot be used in a select directly, so bridge it to a channel.
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		mw.wg.Wait()
	}()

	select {
	case <-time.After(timeout):
		return false
	case <-finished:
		return true
	}
}
// MustMarshal marshals v to JSON and returns the encoded bytes. It panics if
// json.Marshal reports an error (for example, when v contains a value that
// cannot be encoded, such as a channel), so a bad fixture is reported at the
// point of construction rather than producing an empty payload.
func MustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic("MustMarshal: " + err.Error())
	}
	return b
}