97 lines
2.6 KiB
Go
97 lines
2.6 KiB
Go
package tests
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jfraeys/fetch_ml/internal/api"
|
|
"github.com/jfraeys/fetch_ml/internal/auth"
|
|
"github.com/jfraeys/fetch_ml/internal/experiment"
|
|
"github.com/jfraeys/fetch_ml/internal/logging"
|
|
)
|
|
|
|
// wsUpgradeProxy wraps an httputil.ReverseProxy and, for requests that carry
// an Upgrade header, forces the Connection header to "upgrade" before
// forwarding. This ensures the hop-by-hop headers a WebSocket upgrade relies
// on are preserved by the reverse proxy.
type wsUpgradeProxy struct {
	// proxy is the reverse proxy every request is forwarded through.
	proxy *httputil.ReverseProxy
}

// ServeHTTP forwards r through the wrapped reverse proxy. ReverseProxy already
// forwards Upgrade requests. When r has a non-empty Upgrade header, this
// method additionally replaces the Connection header with "upgrade", so the
// upgrade is recognised even if the client sent some other Connection value.
//
// The rewrite is applied to a clone of r. net/http handlers must not modify
// the Request they are given, apart from reading its body.
func (p *wsUpgradeProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Header.Get("Upgrade") != "" {
		// Clone deep-copies the header map, so the caller's request is left untouched.
		r = r.Clone(r.Context())
		// Set drops any existing Connection values, the same as Del followed by Add.
		r.Header.Set("Connection", "upgrade")
	}
	p.proxy.ServeHTTP(w, r)
}
|
|
|
|
func startWSBackendServer(t *testing.T) *httptest.Server {
|
|
t.Helper()
|
|
|
|
logger := logging.NewLogger(slog.LevelInfo, false)
|
|
authConfig := &auth.Config{Enabled: false}
|
|
expManager := experiment.NewManager(t.TempDir())
|
|
h := api.NewWSHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil)
|
|
|
|
srv := httptest.NewServer(h)
|
|
t.Cleanup(srv.Close)
|
|
return srv
|
|
}
|
|
|
|
func startTLSReverseProxy(t *testing.T, target *url.URL) *httptest.Server {
|
|
t.Helper()
|
|
|
|
rp := httputil.NewSingleHostReverseProxy(target)
|
|
proxyHandler := &wsUpgradeProxy{proxy: rp}
|
|
|
|
proxySrv := httptest.NewTLSServer(proxyHandler)
|
|
return proxySrv
|
|
}
|
|
|
|
func TestWSS_UpgradeThroughTLSReverseProxy(t *testing.T) {
|
|
backendSrv := startWSBackendServer(t)
|
|
backendURL, err := url.Parse(backendSrv.URL)
|
|
if err != nil {
|
|
t.Fatalf("failed to parse backend url: %v", err)
|
|
}
|
|
|
|
proxySrv := startTLSReverseProxy(t, backendURL)
|
|
defer proxySrv.Close()
|
|
|
|
proxyURL, err := url.Parse(proxySrv.URL)
|
|
if err != nil {
|
|
t.Fatalf("failed to parse proxy url: %v", err)
|
|
}
|
|
|
|
wssURL := url.URL{Scheme: "wss", Host: proxyURL.Host, Path: "/ws"}
|
|
|
|
dialer := websocket.Dialer{
|
|
TLSClientConfig: &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
InsecureSkipVerify: true, // test-only (self-signed cert from httptest)
|
|
},
|
|
}
|
|
|
|
conn, resp, err := dialer.Dial(wssURL.String(), nil)
|
|
if resp != nil && resp.Body != nil {
|
|
defer func() { _ = resp.Body.Close() }()
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("failed to connect via wss through proxy: %v", err)
|
|
}
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
// Basic write to ensure upgraded channel is usable.
|
|
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
statusMsg := []byte{0x02}
|
|
statusMsg = append(statusMsg, make([]byte, 16)...)
|
|
if err := conn.WriteMessage(websocket.BinaryMessage, statusMsg); err != nil {
|
|
t.Fatalf("failed to write websocket message: %v", err)
|
|
}
|
|
}
|