fetch_ml/tests/e2e/wss_reverse_proxy_e2e_test.go
Jeremie Fraeys d8cc2a4efa
refactor: Migrate all test imports from api to api/ws package
Updated 6 test files to use proper api/ws package imports:

1. tests/e2e/websocket_e2e_test.go
   - api.NewWSHandler → ws.NewHandler

2. tests/e2e/wss_reverse_proxy_e2e_test.go
   - api.NewWSHandler → ws.NewHandler

3. tests/integration/ws_handler_integration_test.go
   - api.NewWSHandler → wspkg.NewHandler
   - api.Opcode* → wspkg.Opcode*

4. tests/integration/websocket_queue_integration_test.go
   - api.NewWSHandler → wspkg.NewHandler
   - api.Opcode* → wspkg.Opcode*

5. tests/unit/api/ws_test.go
   - api.NewWSHandler → wspkg.NewHandler
   - api.Opcode* → wspkg.Opcode*

6. tests/unit/api/ws_jobs_args_test.go
   - api.Opcode* → wspkg.Opcode*

Removed api/ws_compat.go shim as all tests now use proper imports.

Build status: Compiles successfully
2026-02-17 13:52:20 -05:00

97 lines
2.6 KiB
Go

package tests
import (
"crypto/tls"
"log/slog"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/api/ws"
"github.com/jfraeys/fetch_ml/internal/auth"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// wsUpgradeProxy is an http.Handler that delegates to an httputil.ReverseProxy
// after normalizing the Connection header of upgrade requests.
type wsUpgradeProxy struct {
	// proxy is the reverse proxy every request is ultimately handed to.
	proxy *httputil.ReverseProxy
}

// ServeHTTP hands r to the wrapped reverse proxy. If r carries a non-empty
// Upgrade header, its Connection header is first replaced by the single value
// "upgrade"; requests without an Upgrade header are passed through unchanged.
func (p *wsUpgradeProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	hdr := r.Header
	if hdr.Get("Upgrade") != "" {
		// ReverseProxy forwards upgrade requests by itself; this only makes
		// sure the hop-by-hop Connection header announces the upgrade, no
		// matter which Connection values the client originally sent.
		hdr.Set("Connection", "upgrade")
	}
	p.proxy.ServeHTTP(w, r)
}
// startWSBackendServer starts a plain-HTTP httptest server whose handler is
// built by ws.NewHandler from an auth.Config with Enabled set to false, an
// info-level logger, and an experiment manager constructed from a per-test
// temporary directory. The server is closed through t.Cleanup, so callers do
// not need to close it themselves.
func startWSBackendServer(t *testing.T) *httptest.Server {
	t.Helper()

	var (
		lg     = logging.NewLogger(slog.LevelInfo, false)
		noAuth = &auth.Config{Enabled: false}
		expMgr = experiment.NewManager(t.TempDir())
	)

	// All remaining NewHandler arguments are passed as "" / nil.
	backend := httptest.NewServer(
		ws.NewHandler(noAuth, lg, expMgr, "", nil, nil, nil, nil, nil),
	)
	t.Cleanup(backend.Close)
	return backend
}
// startTLSReverseProxy starts an httptest TLS server that reverse-proxies
// every request to target through a wsUpgradeProxy.
//
// The server is registered with t.Cleanup, matching startWSBackendServer, so
// it is shut down even if a caller forgets to close it. httptest.Server.Close
// may be called more than once, so callers that also close the returned server
// explicitly remain correct.
func startTLSReverseProxy(t *testing.T, target *url.URL) *httptest.Server {
	t.Helper()

	rp := httputil.NewSingleHostReverseProxy(target)
	proxySrv := httptest.NewTLSServer(&wsUpgradeProxy{proxy: rp})
	t.Cleanup(proxySrv.Close)
	return proxySrv
}
// TestWSS_UpgradeThroughTLSReverseProxy checks that a WebSocket client can
// complete a wss:// handshake against a TLS reverse proxy fronting the
// plain-HTTP backend, and that a binary message can be written on the
// resulting connection.
func TestWSS_UpgradeThroughTLSReverseProxy(t *testing.T) {
	// ioTimeout bounds both the handshake and the write so that a proxy which
	// accepts the connection but never completes the upgrade fails the test
	// promptly instead of hanging until the global test timeout.
	const ioTimeout = 5 * time.Second

	backendSrv := startWSBackendServer(t)
	backendURL, err := url.Parse(backendSrv.URL)
	if err != nil {
		t.Fatalf("failed to parse backend url: %v", err)
	}

	proxySrv := startTLSReverseProxy(t, backendURL)
	defer proxySrv.Close()
	proxyURL, err := url.Parse(proxySrv.URL)
	if err != nil {
		t.Fatalf("failed to parse proxy url: %v", err)
	}

	wssURL := url.URL{Scheme: "wss", Host: proxyURL.Host, Path: "/ws"}
	dialer := websocket.Dialer{
		HandshakeTimeout: ioTimeout,
		TLSClientConfig: &tls.Config{
			MinVersion:         tls.VersionTLS12,
			InsecureSkipVerify: true, // test-only (self-signed cert from httptest)
		},
	}

	conn, resp, err := dialer.Dial(wssURL.String(), nil)
	// The response may be non-nil even when the handshake fails, so close its
	// body before looking at err.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if err != nil {
		t.Fatalf("failed to connect via wss through proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()

	// Basic write to ensure upgraded channel is usable.
	_ = conn.SetWriteDeadline(time.Now().Add(ioTimeout))
	// One byte 0x02 followed by 16 zero bytes.
	// NOTE(review): presumably an opcode plus a zeroed 16-byte field — confirm
	// against the ws package's wire format.
	statusMsg := []byte{0x02}
	statusMsg = append(statusMsg, make([]byte, 16)...)
	if err := conn.WriteMessage(websocket.BinaryMessage, statusMsg); err != nil {
		t.Fatalf("failed to write websocket message: %v", err)
	}
}