package httptest
import (
+ "context"
"crypto/tls"
"crypto/x509"
"flag"
if s.URL != "" {
panic("Server already started")
}
+
if s.client == nil {
- s.client = &http.Client{Transport: &http.Transport{}}
+ tr := &http.Transport{}
+ dialer := net.Dialer{}
+ // User code may set either of Dial or DialContext, with DialContext taking precedence.
+ // We set DialContext here to preserve any context values that are passed in,
+ // but fall back to Dial if the user has set it.
+ tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ if tr.Dial != nil {
+ return tr.Dial(network, addr)
+ }
+ if addr == "example.com:80" || strings.HasSuffix(addr, ".example.com:80") {
+ addr = s.Listener.Addr().String()
+ }
+ return dialer.DialContext(ctx, network, addr)
+ }
+ s.client = &http.Client{Transport: tr}
+
}
s.URL = "http://" + s.Listener.Addr().String()
s.wrap()
}
certpool := x509.NewCertPool()
certpool.AddCert(s.certificate)
- s.client.Transport = &http.Transport{
+ tr := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: certpool,
},
ForceAttemptHTTP2: s.EnableHTTP2,
}
+ dialer := net.Dialer{}
+ tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ if tr.Dial != nil {
+ return tr.Dial(network, addr)
+ }
+ if addr == "example.com:443" || strings.HasSuffix(addr, ".example.com:443") {
+ addr = s.Listener.Addr().String()
+ }
+ return dialer.DialContext(ctx, network, addr)
+ }
+ s.client.Transport = tr
s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String()
s.wrap()
// It is configured to trust the server's TLS test certificate and will
// close its idle connections on [Server.Close].
// Use Server.URL as the base URL to send requests to the server.
+// The returned client will also redirect any requests to "example.com"
+// or its subdomains to the server.
func (s *Server) Client() *http.Client {
return s.client
}
})
}
}
+
+func TestClientExampleCom(t *testing.T) {
+ modes := []struct {
+ proto string
+ host string
+ }{
+ {"http", "example.com"},
+ {"http", "foo.example.com"},
+ {"https", "example.com"},
+ {"https", "foo.example.com"},
+ }
+
+ for _, tt := range modes {
+ t.Run(tt.proto+" "+tt.host, func(t *testing.T) {
+ cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("requested-hostname", r.Host)
+ }))
+ switch tt.proto {
+ case "https":
+ cst.EnableHTTP2 = true
+ cst.StartTLS()
+ default:
+ cst.Start()
+ }
+
+ defer cst.Close()
+
+ res, err := cst.Client().Get(tt.proto + "://" + tt.host)
+ if err != nil {
+ t.Fatalf("Failed to make request: %v", err)
+ }
+ if got, want := res.Header.Get("requested-hostname"), tt.host; got != want {
+ t.Fatalf("Requested hostname mismatch\ngot: %q\nwant: %q", got, want)
+ }
+ })
+ }
+}