feat(scheduler): add test mode config and TLS detection
- Add DisableTLSForTesting to HubConfig for test environments - Add IsUsingTLS() method to detect scheduler TLS status - Update MockWorker to auto-select ws:// vs wss:// protocol - Set DisableTLSForTesting: true in DefaultHubConfig
This commit is contained in:
parent
c5524562e9
commit
ca913e8878
3 changed files with 22 additions and 11 deletions
|
|
@ -63,6 +63,7 @@ type HubConfig struct {
|
|||
WorkerTokens map[string]string // token -> workerID
|
||||
PluginQuota PluginQuotaConfig // NEW: plugin GPU quota configuration
|
||||
TestGracePeriods map[JobTier]time.Duration // For tests to inject fast grace periods
|
||||
DisableTLSForTesting bool // Force HTTP (ws://) for tests
|
||||
}
|
||||
|
||||
// WorkerConn represents a connected worker
|
||||
|
|
@ -224,8 +225,8 @@ func (h *SchedulerHub) Start() error {
|
|||
h.config.KeyFile = keyFile
|
||||
}
|
||||
|
||||
// Start with TLS if certificates are configured
|
||||
if h.config.CertFile != "" && h.config.KeyFile != "" {
|
||||
// Start with TLS if certificates are configured and not disabled for testing
|
||||
if !h.config.DisableTLSForTesting && h.config.CertFile != "" && h.config.KeyFile != "" {
|
||||
go h.server.ServeTLS(listener, h.config.CertFile, h.config.KeyFile)
|
||||
} else {
|
||||
go h.server.Serve(listener)
|
||||
|
|
@ -250,6 +251,11 @@ func (h *SchedulerHub) Addr() string {
|
|||
return h.listener.Addr().String()
|
||||
}
|
||||
|
||||
// IsUsingTLS returns true if the hub is using TLS (wss://)
|
||||
func (h *SchedulerHub) IsUsingTLS() bool {
|
||||
return !h.config.DisableTLSForTesting && h.config.CertFile != "" && h.config.KeyFile != ""
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the scheduler
|
||||
func (h *SchedulerHub) Stop() {
|
||||
h.cancel()
|
||||
|
|
|
|||
1
tests/fixtures/scheduler_fixture.go
vendored
1
tests/fixtures/scheduler_fixture.go
vendored
|
|
@ -157,6 +157,7 @@ func DefaultHubConfig() scheduler.HubConfig {
|
|||
StarvationThresholdMins: 5,
|
||||
AcceptanceTimeoutSecs: 5,
|
||||
GangAllocTimeoutSecs: 10,
|
||||
DisableTLSForTesting: true, // Use ws:// for tests to avoid TLS complexity
|
||||
// #nosec G101 -- These are test fixture tokens, not real credentials
|
||||
WorkerTokens: tokens,
|
||||
}
|
||||
|
|
|
|||
22
tests/fixtures/scheduler_mock.go
vendored
22
tests/fixtures/scheduler_mock.go
vendored
|
|
@ -21,6 +21,7 @@ type MockWorker struct {
|
|||
ID string
|
||||
RecvCh chan scheduler.Message
|
||||
SendCh chan scheduler.Message
|
||||
Done chan struct{} // Closed when worker disconnects
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
|
|
@ -32,7 +33,12 @@ func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) *
|
|||
addr := hub.Addr()
|
||||
require.NotEmpty(t, addr, "hub not started")
|
||||
|
||||
wsURL := "ws://" + addr + "/ws/worker"
|
||||
// Use correct protocol based on hub TLS configuration
|
||||
protocol := "ws"
|
||||
if hub.IsUsingTLS() {
|
||||
protocol = "wss"
|
||||
}
|
||||
wsURL := protocol + "://" + addr + "/ws/worker"
|
||||
|
||||
// Add test token to headers
|
||||
header := http.Header{}
|
||||
|
|
@ -46,13 +52,13 @@ func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) *
|
|||
ID: workerID,
|
||||
RecvCh: make(chan scheduler.Message, 100),
|
||||
SendCh: make(chan scheduler.Message, 100),
|
||||
Done: make(chan struct{}),
|
||||
T: t,
|
||||
}
|
||||
|
||||
// Start receive goroutine
|
||||
mw.wg.Add(1)
|
||||
go func() {
|
||||
defer mw.wg.Done()
|
||||
mw.wg.Go(func() {
|
||||
defer close(mw.Done) // Signal disconnect when goroutine exits
|
||||
for {
|
||||
var msg scheduler.Message
|
||||
err := conn.ReadJSON(&msg)
|
||||
|
|
@ -62,18 +68,16 @@ func NewMockWorker(t testing.TB, hub *scheduler.SchedulerHub, workerID string) *
|
|||
}
|
||||
mw.RecvCh <- msg
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// Start send goroutine
|
||||
mw.wg.Add(1)
|
||||
go func() {
|
||||
defer mw.wg.Done()
|
||||
mw.wg.Go(func() {
|
||||
for msg := range mw.SendCh {
|
||||
if err := conn.WriteJSON(msg); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
return mw
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue