// Integration tests for the distributed ML scheduler: hub-based
// coordination, worker registration, heartbeat slot reporting, and
// worker-scheduler websocket protocol handling.
//
// Refs: scheduler-architecture.md
package scheduler_test
import (
	"encoding/json"
	"net/http"
	"net/url"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/scheduler"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
// TestDistributedRoundTrip validates full job lifecycle through scheduler
|
|
func TestDistributedRoundTrip(t *testing.T) {
|
|
// Create scheduler hub with token auth configured
|
|
testToken := "test-token-123"
|
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
|
BindAddr: "localhost:0",
|
|
StateDir: t.TempDir(),
|
|
DefaultBatchSlots: 4,
|
|
AcceptanceTimeoutSecs: 5,
|
|
WorkerTokens: map[string]string{
|
|
testToken: "test-worker",
|
|
},
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
defer hub.Stop()
|
|
|
|
// Start scheduler
|
|
err = hub.Start()
|
|
require.NoError(t, err)
|
|
|
|
// Get scheduler address - use the actual listening address
|
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
|
wsURL := u.String()
|
|
|
|
// Create mock worker connection with auth token
|
|
workerID := "test-worker"
|
|
headers := http.Header{}
|
|
headers.Set("Authorization", "Bearer "+testToken)
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
// Start receive goroutine
|
|
recvCh := make(chan scheduler.Message, 10)
|
|
go func() {
|
|
for {
|
|
var msg scheduler.Message
|
|
err := conn.ReadJSON(&msg)
|
|
if err != nil {
|
|
close(recvCh)
|
|
return
|
|
}
|
|
recvCh <- msg
|
|
}
|
|
}()
|
|
|
|
// Register worker
|
|
err = conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgRegister,
|
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
|
ID: workerID,
|
|
Capabilities: scheduler.WorkerCapabilities{
|
|
GPUCount: 0,
|
|
GPUType: "",
|
|
},
|
|
}),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for ack
|
|
select {
|
|
case msg := <-recvCh:
|
|
require.Equal(t, scheduler.MsgAck, msg.Type)
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for registration ack")
|
|
}
|
|
|
|
// Send heartbeat to show we're alive
|
|
err = conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgHeartbeat,
|
|
Payload: mustMarshal(scheduler.HeartbeatPayload{
|
|
WorkerID: workerID,
|
|
Slots: scheduler.SlotStatus{
|
|
BatchTotal: 4,
|
|
BatchInUse: 0,
|
|
},
|
|
}),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Signal ready for work
|
|
err = conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgReadyForWork,
|
|
Payload: mustMarshal(scheduler.ReadyPayload{
|
|
WorkerID: workerID,
|
|
Slots: scheduler.SlotStatus{
|
|
BatchTotal: 4,
|
|
BatchInUse: 0,
|
|
},
|
|
Reason: "polling",
|
|
}),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait a bit and verify connection is still alive
|
|
time.Sleep(200 * time.Millisecond)
|
|
}
// TestWorkerRegistration validates worker registration flow
|
|
func TestWorkerRegistration(t *testing.T) {
|
|
testToken := "reg-test-token"
|
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
|
BindAddr: "localhost:0",
|
|
StateDir: t.TempDir(),
|
|
DefaultBatchSlots: 4,
|
|
WorkerTokens: map[string]string{
|
|
testToken: "reg-test-worker",
|
|
},
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
defer hub.Stop()
|
|
|
|
// Start scheduler
|
|
err = hub.Start()
|
|
require.NoError(t, err)
|
|
|
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
|
wsURL := u.String()
|
|
|
|
// Connect worker with auth token
|
|
headers := http.Header{}
|
|
headers.Set("Authorization", "Bearer "+testToken)
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
// Receive channel
|
|
recvCh := make(chan scheduler.Message, 10)
|
|
go func() {
|
|
for {
|
|
var msg scheduler.Message
|
|
err := conn.ReadJSON(&msg)
|
|
if err != nil {
|
|
close(recvCh)
|
|
return
|
|
}
|
|
recvCh <- msg
|
|
}
|
|
}()
|
|
|
|
// Register with capabilities
|
|
workerID := "reg-test-worker"
|
|
err = conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgRegister,
|
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
|
ID: workerID,
|
|
Capabilities: scheduler.WorkerCapabilities{
|
|
GPUCount: 2,
|
|
GPUType: "nvidia",
|
|
CPUCount: 8,
|
|
MemoryGB: 32.0,
|
|
},
|
|
}),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Expect ack
|
|
select {
|
|
case msg := <-recvCh:
|
|
assert.Equal(t, scheduler.MsgAck, msg.Type)
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for ack")
|
|
}
|
|
}
// TestHeartbeat validates heartbeat slot reporting
|
|
func TestHeartbeat(t *testing.T) {
|
|
testToken := "hb-test-token"
|
|
hub, err := scheduler.NewHub(scheduler.HubConfig{
|
|
BindAddr: "localhost:0",
|
|
StateDir: t.TempDir(),
|
|
DefaultBatchSlots: 4,
|
|
WorkerTokens: map[string]string{
|
|
testToken: "hb-test-worker",
|
|
},
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
defer hub.Stop()
|
|
|
|
// Start scheduler
|
|
err = hub.Start()
|
|
require.NoError(t, err)
|
|
|
|
u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
|
|
wsURL := u.String()
|
|
|
|
headers := http.Header{}
|
|
headers.Set("Authorization", "Bearer "+testToken)
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
workerID := "hb-test-worker"
|
|
|
|
// Register first
|
|
err = conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgRegister,
|
|
Payload: mustMarshal(scheduler.WorkerRegistration{
|
|
ID: workerID,
|
|
Capabilities: scheduler.WorkerCapabilities{
|
|
GPUCount: 0,
|
|
},
|
|
}),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Send multiple heartbeats
|
|
slots := []scheduler.SlotStatus{
|
|
{BatchTotal: 4, BatchInUse: 0},
|
|
{BatchTotal: 4, BatchInUse: 1},
|
|
{BatchTotal: 4, BatchInUse: 2},
|
|
}
|
|
|
|
for _, slot := range slots {
|
|
err = conn.WriteJSON(scheduler.Message{
|
|
Type: scheduler.MsgHeartbeat,
|
|
Payload: mustMarshal(scheduler.HeartbeatPayload{
|
|
WorkerID: workerID,
|
|
Slots: slot,
|
|
}),
|
|
})
|
|
require.NoError(t, err)
|
|
time.Sleep(50 * time.Millisecond)
|
|
}
|
|
|
|
// Connection should remain healthy
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
// mustMarshal serializes v to JSON, panicking on failure.
//
// Test-only helper: every payload struct used in these tests is
// marshalable, so an error here indicates a programming mistake and
// should abort the run loudly rather than silently yielding a nil
// payload (the previous `b, _ :=` form swallowed the error).
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return b
}