diff --git a/internal/scheduler/hub.go b/internal/scheduler/hub.go index c6629d7..1166064 100644 --- a/internal/scheduler/hub.go +++ b/internal/scheduler/hub.go @@ -63,6 +63,7 @@ type HubConfig struct { WorkerTokens map[string]string // token -> workerID PluginQuota PluginQuotaConfig // NEW: plugin GPU quota configuration TestGracePeriods map[JobTier]time.Duration // For tests to inject fast grace periods + DisableTLSForTesting bool // Force HTTP (ws://) for tests } // WorkerConn represents a connected worker @@ -224,8 +225,8 @@ func (h *SchedulerHub) Start() error { h.config.KeyFile = keyFile } - // Start with TLS if certificates are configured - if h.config.CertFile != "" && h.config.KeyFile != "" { + // Start with TLS if certificates are configured and not disabled for testing + if !h.config.DisableTLSForTesting && h.config.CertFile != "" && h.config.KeyFile != "" { go h.server.ServeTLS(listener, h.config.CertFile, h.config.KeyFile) } else { go h.server.Serve(listener) @@ -250,6 +251,11 @@ func (h *SchedulerHub) Addr() string { return h.listener.Addr().String() } +// IsUsingTLS returns true if the hub is using TLS (wss://) +func (h *SchedulerHub) IsUsingTLS() bool { + return !h.config.DisableTLSForTesting && h.config.CertFile != "" && h.config.KeyFile != "" +} + // Stop gracefully shuts down the scheduler func (h *SchedulerHub) Stop() { h.cancel() diff --git a/tests/fixtures/scheduler_fixture.go b/tests/fixtures/scheduler_fixture.go index c16c3d1..35fc646 100644 --- a/tests/fixtures/scheduler_fixture.go +++ b/tests/fixtures/scheduler_fixture.go @@ -157,6 +157,7 @@ func DefaultHubConfig() scheduler.HubConfig { StarvationThresholdMins: 5, AcceptanceTimeoutSecs: 5, GangAllocTimeoutSecs: 10, + DisableTLSForTesting: true, // Use ws:// for tests to avoid TLS complexity // #nosec G101 -- These are test fixture tokens, not real credentials WorkerTokens: tokens, } diff --git a/tests/fixtures/scheduler_mock.go b/tests/fixtures/scheduler_mock.go index f38b057..47b09d4 100644 --- a/tests/fixtures/scheduler_mock.go +++ b/tests/fixtures/scheduler_mock.go @@ -21,6 +21,7 @@ type MockWorker struct { ID string RecvCh chan scheduler.Message SendCh chan scheduler.Message + Done chan struct{} // Closed when worker disconnects wg sync.WaitGroup mu sync.RWMutex closed bool @@ -32,7 +33,12 @@ func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) * addr := hub.Addr() require.NotEmpty(t, addr, "hub not started") - wsURL := "ws://" + addr + "/ws/worker" + // Use correct protocol based on hub TLS configuration + protocol := "ws" + if hub.IsUsingTLS() { + protocol = "wss" + } + wsURL := protocol + "://" + addr + "/ws/worker" // Add test token to headers header := http.Header{} @@ -46,13 +52,13 @@ func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) * ID: workerID, RecvCh: make(chan scheduler.Message, 100), SendCh: make(chan scheduler.Message, 100), + Done: make(chan struct{}), T: t, } // Start receive goroutine - mw.wg.Add(1) - go func() { - defer mw.wg.Done() + mw.wg.Go(func() { + defer close(mw.Done) // Signal disconnect when goroutine exits for { var msg scheduler.Message err := conn.ReadJSON(&msg) @@ -62,18 +68,16 @@ func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) * } mw.RecvCh <- msg } - }() + }) // Start send goroutine - mw.wg.Add(1) - go func() { - defer mw.wg.Done() + mw.wg.Go(func() { for msg := range mw.SendCh { if err := conn.WriteJSON(msg); err != nil { return } } - }() + }) return mw }