package tests import ( "crypto/tls" "log/slog" "net/http" "net/http/httptest" "net/http/httputil" "net/url" "testing" "time" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/api/datasets" "github.com/jfraeys/fetch_ml/internal/api/jobs" jupyterj "github.com/jfraeys/fetch_ml/internal/api/jupyter" "github.com/jfraeys/fetch_ml/internal/api/ws" "github.com/jfraeys/fetch_ml/internal/auth" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/logging" ) type wsUpgradeProxy struct { proxy *httputil.ReverseProxy } func (p *wsUpgradeProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // ReverseProxy will forward Upgrade requests; we additionally ensure hop-by-hop // headers used for WS are preserved. if r.Header.Get("Upgrade") != "" { r.Header.Del("Connection") r.Header.Add("Connection", "upgrade") } p.proxy.ServeHTTP(w, r) } func startWSBackendServer(t *testing.T) *httptest.Server { t.Helper() logger := logging.NewLogger(slog.LevelInfo, false) authConfig := &auth.Config{Enabled: false} expManager := experiment.NewManager(t.TempDir()) jobsHandler := jobs.NewHandler(expManager, logger, nil, nil, authConfig, nil) jupyterHandler := jupyterj.NewHandler(logger, nil, authConfig) datasetsHandler := datasets.NewHandler(logger, nil, "") h := ws.NewHandler(authConfig, logger, expManager, "", nil, nil, nil, nil, nil, jobsHandler, jupyterHandler, datasetsHandler) srv := httptest.NewServer(h) t.Cleanup(srv.Close) return srv } func startTLSReverseProxy(t *testing.T, target *url.URL) *httptest.Server { t.Helper() rp := httputil.NewSingleHostReverseProxy(target) proxyHandler := &wsUpgradeProxy{proxy: rp} proxySrv := httptest.NewTLSServer(proxyHandler) return proxySrv } func TestWSS_UpgradeThroughTLSReverseProxy(t *testing.T) { backendSrv := startWSBackendServer(t) backendURL, err := url.Parse(backendSrv.URL) if err != nil { t.Fatalf("failed to parse backend url: %v", err) } proxySrv := startTLSReverseProxy(t, backendURL) defer proxySrv.Close() proxyURL, err := url.Parse(proxySrv.URL) if err != nil { t.Fatalf("failed to parse proxy url: %v", err) } wssURL := url.URL{Scheme: "wss", Host: proxyURL.Host, Path: "/ws"} dialer := websocket.Dialer{ TLSClientConfig: &tls.Config{ MinVersion: tls.VersionTLS12, InsecureSkipVerify: true, // test-only (self-signed cert from httptest) }, } conn, resp, err := dialer.Dial(wssURL.String(), nil) if resp != nil && resp.Body != nil { defer func() { _ = resp.Body.Close() }() } if err != nil { t.Fatalf("failed to connect via wss through proxy: %v", err) } defer func() { _ = conn.Close() }() // Basic write to ensure upgraded channel is usable. _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) statusMsg := []byte{0x02} statusMsg = append(statusMsg, make([]byte, 16)...) if err := conn.WriteMessage(websocket.BinaryMessage, statusMsg); err != nil { t.Fatalf("failed to write websocket message: %v", err) } }