package scheduler_test

import (
	"encoding/json"
	"net/http"
	"net/url"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/jfraeys/fetch_ml/internal/scheduler"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestDistributedRoundTrip drives a mock worker through the scheduler's
// websocket protocol: authenticate, register, wait for the ack, send a
// heartbeat, signal readiness for work, and finally check that the scheduler
// has not dropped the connection.
func TestDistributedRoundTrip(t *testing.T) {
	// Create scheduler hub with token auth configured.
	testToken := "test-token-123"
	hub, err := scheduler.NewHub(scheduler.HubConfig{
		BindAddr:              "localhost:0",
		StateDir:              t.TempDir(),
		DefaultBatchSlots:     4,
		AcceptanceTimeoutSecs: 5,
		WorkerTokens: map[string]string{
			testToken: "test-worker",
		},
	}, nil)
	require.NoError(t, err)
	defer hub.Stop()

	// Start scheduler.
	err = hub.Start()
	require.NoError(t, err)

	// BindAddr uses port 0, so build the URL from the address the hub reports.
	u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
	wsURL := u.String()

	// Create mock worker connection with auth token.
	workerID := "test-worker"
	headers := http.Header{}
	headers.Set("Authorization", "Bearer "+testToken)
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
	require.NoError(t, err)
	defer conn.Close()

	// Forward every incoming message to recvCh. The loop ends on the first
	// read error (for example once the deferred conn.Close runs) and closes
	// the channel so receivers can tell the connection is gone.
	recvCh := make(chan scheduler.Message, 10)
	go func() {
		defer close(recvCh)
		for {
			var msg scheduler.Message
			if err := conn.ReadJSON(&msg); err != nil {
				return
			}
			recvCh <- msg
		}
	}()

	// Register worker.
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgRegister,
		Payload: mustMarshal(scheduler.WorkerRegistration{
			ID: workerID,
			Capabilities: scheduler.WorkerCapabilities{
				GPUCount: 0,
				GPUType:  "",
			},
		}),
	})
	require.NoError(t, err)

	// Wait for ack. A closed channel means the connection died first; report
	// that explicitly instead of comparing a zero-value message.
	select {
	case msg, ok := <-recvCh:
		require.True(t, ok, "connection closed before registration ack")
		require.Equal(t, scheduler.MsgAck, msg.Type)
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for registration ack")
	}

	// Send heartbeat to show we're alive.
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgHeartbeat,
		Payload: mustMarshal(scheduler.HeartbeatPayload{
			WorkerID: workerID,
			Slots: scheduler.SlotStatus{
				BatchTotal: 4,
				BatchInUse: 0,
			},
		}),
	})
	require.NoError(t, err)

	// Signal ready for work.
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgReadyForWork,
		Payload: mustMarshal(scheduler.ReadyPayload{
			WorkerID: workerID,
			Slots: scheduler.SlotStatus{
				BatchTotal: 4,
				BatchInUse: 0,
			},
			Reason: "polling",
		}),
	})
	require.NoError(t, err)

	// Give the scheduler time to react, then verify the connection is still
	// alive: drain whatever arrived and fail if the reader closed the channel.
	time.Sleep(200 * time.Millisecond)
drain:
	for {
		select {
		case _, ok := <-recvCh:
			require.True(t, ok, "scheduler closed the connection after ready-for-work")
		default:
			break drain
		}
	}
}

// TestWorkerRegistration registers a worker that advertises GPU, CPU and
// memory capabilities and expects the scheduler to answer with an ack.
func TestWorkerRegistration(t *testing.T) {
	testToken := "reg-test-token"
	hub, err := scheduler.NewHub(scheduler.HubConfig{
		BindAddr:          "localhost:0",
		StateDir:          t.TempDir(),
		DefaultBatchSlots: 4,
		WorkerTokens: map[string]string{
			testToken: "reg-test-worker",
		},
	}, nil)
	require.NoError(t, err)
	defer hub.Stop()

	// Start scheduler.
	err = hub.Start()
	require.NoError(t, err)

	u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
	wsURL := u.String()

	// Connect worker with auth token.
	headers := http.Header{}
	headers.Set("Authorization", "Bearer "+testToken)
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
	require.NoError(t, err)
	defer conn.Close()

	// Forward incoming messages to recvCh; the channel is closed on the first
	// read error so receivers can detect a dropped connection.
	recvCh := make(chan scheduler.Message, 10)
	go func() {
		defer close(recvCh)
		for {
			var msg scheduler.Message
			if err := conn.ReadJSON(&msg); err != nil {
				return
			}
			recvCh <- msg
		}
	}()

	// Register with capabilities.
	workerID := "reg-test-worker"
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgRegister,
		Payload: mustMarshal(scheduler.WorkerRegistration{
			ID: workerID,
			Capabilities: scheduler.WorkerCapabilities{
				GPUCount: 2,
				GPUType:  "nvidia",
				CPUCount: 8,
				MemoryGB: 32.0,
			},
		}),
	})
	require.NoError(t, err)

	// Expect ack. A closed channel means the connection died before the ack.
	select {
	case msg, ok := <-recvCh:
		require.True(t, ok, "connection closed before ack")
		assert.Equal(t, scheduler.MsgAck, msg.Type)
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for ack")
	}
}

// TestHeartbeat registers a worker and then sends a series of heartbeats with
// increasing slot usage, requiring every write to succeed.
func TestHeartbeat(t *testing.T) {
	testToken := "hb-test-token"
	hub, err := scheduler.NewHub(scheduler.HubConfig{
		BindAddr:          "localhost:0",
		StateDir:          t.TempDir(),
		DefaultBatchSlots: 4,
		WorkerTokens: map[string]string{
			testToken: "hb-test-worker",
		},
	}, nil)
	require.NoError(t, err)
	defer hub.Stop()

	// Start scheduler.
	err = hub.Start()
	require.NoError(t, err)

	u := &url.URL{Scheme: "ws", Host: hub.Addr(), Path: "/ws/worker"}
	wsURL := u.String()

	headers := http.Header{}
	headers.Set("Authorization", "Bearer "+testToken)
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
	require.NoError(t, err)
	defer conn.Close()

	workerID := "hb-test-worker"

	// Register first.
	err = conn.WriteJSON(scheduler.Message{
		Type: scheduler.MsgRegister,
		Payload: mustMarshal(scheduler.WorkerRegistration{
			ID: workerID,
			Capabilities: scheduler.WorkerCapabilities{
				GPUCount: 0,
			},
		}),
	})
	require.NoError(t, err)

	// Consume the registration ack, as the other registration tests do, so
	// heartbeats are only sent once the scheduler has accepted the worker.
	// The deadline bounds the wait and is cleared again afterwards.
	require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
	var ack scheduler.Message
	require.NoError(t, conn.ReadJSON(&ack), "reading registration ack")
	require.Equal(t, scheduler.MsgAck, ack.Type)
	require.NoError(t, conn.SetReadDeadline(time.Time{}))

	// Send multiple heartbeats with growing slot usage.
	slots := []scheduler.SlotStatus{
		{BatchTotal: 4, BatchInUse: 0},
		{BatchTotal: 4, BatchInUse: 1},
		{BatchTotal: 4, BatchInUse: 2},
	}
	for _, slot := range slots {
		err = conn.WriteJSON(scheduler.Message{
			Type: scheduler.MsgHeartbeat,
			Payload: mustMarshal(scheduler.HeartbeatPayload{
				WorkerID: workerID,
				Slots:    slot,
			}),
		})
		require.NoError(t, err)
		time.Sleep(50 * time.Millisecond)
	}

	// Leave the scheduler a moment to process the last heartbeat before the
	// deferred teardown closes the connection and stops the hub.
	time.Sleep(100 * time.Millisecond)
}

// mustMarshal returns the JSON encoding of v. It panics if v cannot be
// encoded, so a broken fixture fails loudly instead of sending an empty
// payload.
func mustMarshal(v any) []byte {
	b, err := json.Marshal(v)
	if err != nil {
		panic(err)
	}
	return b
}