From f2db0dca0b0399c08319d22cbcbfa83be2bb781a Mon Sep 17 00:00:00 2001 From: Sean Liao Date: Sat, 19 Apr 2025 13:50:35 +0100 Subject: [PATCH] net/http/httptest: redirect example.com requests to server The default server cert used by NewServer already includes example.com in its DNSNames, and by default, the client's RootCA configuration means it won't trust a response from the real example.com. Fixes #31054 Change-Id: I0686977e5ffe2c2f22f3fc09a47ee8ecc44765db Reviewed-on: https://go-review.googlesource.com/c/go/+/666855 Reviewed-by: Damien Neil Reviewed-by: Carlos Amedee LUCI-TryBot-Result: Go LUCI --- .../99-minor/net/http/httptest/31054.md | 2 + src/net/http/httptest/server.go | 34 ++++++++++++++++- src/net/http/httptest/server_test.go | 37 +++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 doc/next/6-stdlib/99-minor/net/http/httptest/31054.md diff --git a/doc/next/6-stdlib/99-minor/net/http/httptest/31054.md b/doc/next/6-stdlib/99-minor/net/http/httptest/31054.md new file mode 100644 index 0000000000..ef6a4898f2 --- /dev/null +++ b/doc/next/6-stdlib/99-minor/net/http/httptest/31054.md @@ -0,0 +1,2 @@ +The HTTP client returned by [Server.Client] will now redirect requests for +`example.com` and any subdomains to the server being tested. diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go index fa54923179..7ae2561b71 100644 --- a/src/net/http/httptest/server.go +++ b/src/net/http/httptest/server.go @@ -7,6 +7,7 @@ package httptest import ( + "context" "crypto/tls" "crypto/x509" "flag" @@ -126,8 +127,24 @@ func (s *Server) Start() { if s.URL != "" { panic("Server already started") } + if s.client == nil { - s.client = &http.Client{Transport: &http.Transport{}} + tr := &http.Transport{} + dialer := net.Dialer{} + // User code may set either of Dial or DialContext, with DialContext taking precedence. + // We set DialContext here to preserve any context values that are passed in, + // but fall back to Dial if the user has set it. + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + if tr.Dial != nil { + return tr.Dial(network, addr) + } + if addr == "example.com:80" || strings.HasSuffix(addr, ".example.com:80") { + addr = s.Listener.Addr().String() + } + return dialer.DialContext(ctx, network, addr) + } + s.client = &http.Client{Transport: tr} + } s.URL = "http://" + s.Listener.Addr().String() s.wrap() @@ -173,12 +190,23 @@ func (s *Server) StartTLS() { } certpool := x509.NewCertPool() certpool.AddCert(s.certificate) - s.client.Transport = &http.Transport{ + tr := &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: certpool, }, ForceAttemptHTTP2: s.EnableHTTP2, } + dialer := net.Dialer{} + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + if tr.Dial != nil { + return tr.Dial(network, addr) + } + if addr == "example.com:443" || strings.HasSuffix(addr, ".example.com:443") { + addr = s.Listener.Addr().String() + } + return dialer.DialContext(ctx, network, addr) + } + s.client.Transport = tr s.Listener = tls.NewListener(s.Listener, s.TLS) s.URL = "https://" + s.Listener.Addr().String() s.wrap() @@ -300,6 +328,8 @@ func (s *Server) Certificate() *x509.Certificate { // It is configured to trust the server's TLS test certificate and will // close its idle connections on [Server.Close]. // Use Server.URL as the base URL to send requests to the server. +// The returned client will also redirect any requests to "example.com" +// or its subdomains to the server. func (s *Server) Client() *http.Client { return s.client } diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go index c96a0ff337..f3cfa7c2db 100644 --- a/src/net/http/httptest/server_test.go +++ b/src/net/http/httptest/server_test.go @@ -293,3 +293,40 @@ func TestTLSServerWithHTTP2(t *testing.T) { }) } } + +func TestClientExampleCom(t *testing.T) { + modes := []struct { + proto string + host string + }{ + {"http", "example.com"}, + {"http", "foo.example.com"}, + {"https", "example.com"}, + {"https", "foo.example.com"}, + } + + for _, tt := range modes { + t.Run(tt.proto+" "+tt.host, func(t *testing.T) { + cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("requested-hostname", r.Host) + })) + switch tt.proto { + case "https": + cst.EnableHTTP2 = true + cst.StartTLS() + default: + cst.Start() + } + + defer cst.Close() + + res, err := cst.Client().Get(tt.proto + "://" + tt.host) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + if got, want := res.Header.Get("requested-hostname"), tt.host; got != want { + t.Fatalf("Requested hostname mismatch\ngot: %q\nwant: %q", got, want) + } + }) + } +} -- 2.52.0