]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptest: redirect example.com requests to server
authorSean Liao <sean@liao.dev>
Sat, 19 Apr 2025 12:50:35 +0000 (13:50 +0100)
committerSean Liao <sean@liao.dev>
Sat, 23 Aug 2025 19:28:10 +0000 (12:28 -0700)
The default server cert used by NewServer already includes example.com
in its DNSNames, and by default, the client's RootCA configuration
means it won't trust a response from the real example.com.

Fixes #31054

Change-Id: I0686977e5ffe2c2f22f3fc09a47ee8ecc44765db
Reviewed-on: https://go-review.googlesource.com/c/go/+/666855
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

doc/next/6-stdlib/99-minor/net/http/httptest/31054.md [new file with mode: 0644]
src/net/http/httptest/server.go
src/net/http/httptest/server_test.go

diff --git a/doc/next/6-stdlib/99-minor/net/http/httptest/31054.md b/doc/next/6-stdlib/99-minor/net/http/httptest/31054.md
new file mode 100644 (file)
index 0000000..ef6a489
--- /dev/null
@@ -0,0 +1,2 @@
+The HTTP client returned by [Server.Client] will now redirect requests for
+`example.com` and any subdomains to the server being tested.
index fa549231796925fc61b23d53d477280cfaafaece..7ae2561b71971dbdb76a558b9d378d5623a17e0b 100644 (file)
@@ -7,6 +7,7 @@
 package httptest
 
 import (
+       "context"
        "crypto/tls"
        "crypto/x509"
        "flag"
@@ -126,8 +127,24 @@ func (s *Server) Start() {
        if s.URL != "" {
                panic("Server already started")
        }
+
        if s.client == nil {
-               s.client = &http.Client{Transport: &http.Transport{}}
+               tr := &http.Transport{}
+               dialer := net.Dialer{}
+               // User code may set either of Dial or DialContext, with DialContext taking precedence.
+               // We set DialContext here to preserve any context values that are passed in,
+               // but fall back to Dial if the user has set it.
+               tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+                       if tr.Dial != nil {
+                               return tr.Dial(network, addr)
+                       }
+                       if addr == "example.com:80" || strings.HasSuffix(addr, ".example.com:80") {
+                               addr = s.Listener.Addr().String()
+                       }
+                       return dialer.DialContext(ctx, network, addr)
+               }
+               s.client = &http.Client{Transport: tr}
+
        }
        s.URL = "http://" + s.Listener.Addr().String()
        s.wrap()
@@ -173,12 +190,23 @@ func (s *Server) StartTLS() {
        }
        certpool := x509.NewCertPool()
        certpool.AddCert(s.certificate)
-       s.client.Transport = &http.Transport{
+       tr := &http.Transport{
                TLSClientConfig: &tls.Config{
                        RootCAs: certpool,
                },
                ForceAttemptHTTP2: s.EnableHTTP2,
        }
+       dialer := net.Dialer{}
+       tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+               if tr.Dial != nil {
+                       return tr.Dial(network, addr)
+               }
+               if addr == "example.com:443" || strings.HasSuffix(addr, ".example.com:443") {
+                       addr = s.Listener.Addr().String()
+               }
+               return dialer.DialContext(ctx, network, addr)
+       }
+       s.client.Transport = tr
        s.Listener = tls.NewListener(s.Listener, s.TLS)
        s.URL = "https://" + s.Listener.Addr().String()
        s.wrap()
@@ -300,6 +328,8 @@ func (s *Server) Certificate() *x509.Certificate {
 // It is configured to trust the server's TLS test certificate and will
 // close its idle connections on [Server.Close].
 // Use Server.URL as the base URL to send requests to the server.
+// The returned client will also redirect any requests to "example.com"
+// or its subdomains to the server.
 func (s *Server) Client() *http.Client {
        return s.client
 }
index c96a0ff3379b4dec79d2341d67c721c80bdf9d82..f3cfa7c2dbdb76d8ff52a4bd612439b5c450af69 100644 (file)
@@ -293,3 +293,40 @@ func TestTLSServerWithHTTP2(t *testing.T) {
                })
        }
 }
+
+func TestClientExampleCom(t *testing.T) {
+       modes := []struct {
+               proto string
+               host  string
+       }{
+               {"http", "example.com"},
+               {"http", "foo.example.com"},
+               {"https", "example.com"},
+               {"https", "foo.example.com"},
+       }
+
+       for _, tt := range modes {
+               t.Run(tt.proto+" "+tt.host, func(t *testing.T) {
+                       cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+                               w.Header().Set("requested-hostname", r.Host)
+                       }))
+                       switch tt.proto {
+                       case "https":
+                               cst.EnableHTTP2 = true
+                               cst.StartTLS()
+                       default:
+                               cst.Start()
+                       }
+
+                       defer cst.Close()
+
+                       res, err := cst.Client().Get(tt.proto + "://" + tt.host)
+                       if err != nil {
+                               t.Fatalf("Failed to make request: %v", err)
+                       }
+                       if got, want := res.Header.Get("requested-hostname"), tt.host; got != want {
+                               t.Fatalf("Requested hostname mismatch\ngot: %q\nwant: %q", got, want)
+                       }
+               })
+       }
+}