feat(scheduler): add test mode config and TLS detection

- Add DisableTLSForTesting to HubConfig for test environments
- Add IsUsingTLS() method to detect scheduler TLS status
- Update MockWorker to auto-select ws:// vs wss:// protocol
- Set DisableTLSForTesting: true in DefaultHubConfig
This commit is contained in:
Jeremie Fraeys 2026-03-12 14:05:35 -04:00
parent c5524562e9
commit ca913e8878
No known key found for this signature in database
3 changed files with 22 additions and 11 deletions

View file

@ -63,6 +63,7 @@ type HubConfig struct {
WorkerTokens map[string]string // token -> workerID
PluginQuota PluginQuotaConfig // NEW: plugin GPU quota configuration
TestGracePeriods map[JobTier]time.Duration // For tests to inject fast grace periods
DisableTLSForTesting bool // Force HTTP (ws://) for tests
}
// WorkerConn represents a connected worker
@ -224,8 +225,8 @@ func (h *SchedulerHub) Start() error {
h.config.KeyFile = keyFile
}
// Start with TLS if certificates are configured
if h.config.CertFile != "" && h.config.KeyFile != "" {
// Start with TLS if certificates are configured and not disabled for testing
if !h.config.DisableTLSForTesting && h.config.CertFile != "" && h.config.KeyFile != "" {
go h.server.ServeTLS(listener, h.config.CertFile, h.config.KeyFile)
} else {
go h.server.Serve(listener)
@ -250,6 +251,11 @@ func (h *SchedulerHub) Addr() string {
return h.listener.Addr().String()
}
// IsUsingTLS returns true if the hub is using TLS (wss://)
func (h *SchedulerHub) IsUsingTLS() bool {
return !h.config.DisableTLSForTesting && h.config.CertFile != "" && h.config.KeyFile != ""
}
// Stop gracefully shuts down the scheduler
func (h *SchedulerHub) Stop() {
h.cancel()

View file

@ -157,6 +157,7 @@ func DefaultHubConfig() scheduler.HubConfig {
StarvationThresholdMins: 5,
AcceptanceTimeoutSecs: 5,
GangAllocTimeoutSecs: 10,
DisableTLSForTesting: true, // Use ws:// for tests to avoid TLS complexity
// #nosec G101 -- These are test fixture tokens, not real credentials
WorkerTokens: tokens,
}

View file

@ -21,6 +21,7 @@ type MockWorker struct {
ID string
RecvCh chan scheduler.Message
SendCh chan scheduler.Message
Done chan struct{} // Closed when worker disconnects
wg sync.WaitGroup
mu sync.RWMutex
closed bool
@ -32,7 +33,12 @@ func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) *
addr := hub.Addr()
require.NotEmpty(t, addr, "hub not started")
wsURL := "ws://" + addr + "/ws/worker"
// Use correct protocol based on hub TLS configuration
protocol := "ws"
if hub.IsUsingTLS() {
protocol = "wss"
}
wsURL := protocol + "://" + addr + "/ws/worker"
// Add test token to headers
header := http.Header{}
@ -46,13 +52,13 @@ func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) *
ID: workerID,
RecvCh: make(chan scheduler.Message, 100),
SendCh: make(chan scheduler.Message, 100),
Done: make(chan struct{}),
T: t,
}
// Start receive goroutine
mw.wg.Add(1)
go func() {
defer mw.wg.Done()
mw.wg.Go(func() {
defer close(mw.Done) // Signal disconnect when goroutine exits
for {
var msg scheduler.Message
err := conn.ReadJSON(&msg)
@ -62,18 +68,16 @@ func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) *
}
mw.RecvCh <- msg
}
}()
})
// Start send goroutine
mw.wg.Add(1)
go func() {
defer mw.wg.Done()
mw.wg.Go(func() {
for msg := range mw.SendCh {
if err := conn.WriteJSON(msg); err != nil {
return
}
}
}()
})
return mw
}