]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Transport.Clone
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 30 Apr 2019 20:03:57 +0000 (20:03 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 3 May 2019 15:17:54 +0000 (15:17 +0000)
Fixes #26013

Change-Id: I2c82bd90ea7ce6f7a8e5b6c460d3982dca681a93
Reviewed-on: https://go-review.googlesource.com/c/go/+/174597
Reviewed-by: Andrew Bonventre <andybons@golang.org>
src/net/http/transport.go
src/net/http/transport_test.go

index 88761909fd503b8443f22d3360c4ea41c98eb174..ca97489eeaa74222dfbad4853a574f9d5db48d8a 100644 (file)
@@ -261,13 +261,14 @@ type Transport struct {
 
        // ReadBufferSize specifies the size of the read buffer used
        // when reading from the transport.
-       //If zero, a default (currently 4KB) is used.
+       // If zero, a default (currently 4KB) is used.
        ReadBufferSize int
 
        // nextProtoOnce guards initialization of TLSNextProto and
        // h2transport (via onceSetNextProtoDefaults)
-       nextProtoOnce sync.Once
-       h2transport   h2Transport // non-nil if http2 wired up
+       nextProtoOnce      sync.Once
+       h2transport        h2Transport // non-nil if http2 wired up
+       tlsNextProtoWasNil bool        // whether TLSNextProto was nil when the Once fired
 
        // ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero
        // TLSClientConfig or Dial, DialTLS or DialContext func is provided. By default, use of any those fields conservatively
@@ -290,6 +291,40 @@ func (t *Transport) readBufferSize() int {
        return 4 << 10
 }
 
+// Clone returns a deep copy of t's exported fields.
+func (t *Transport) Clone() *Transport {
+       t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
+       t2 := &Transport{
+               Proxy:                  t.Proxy,
+               DialContext:            t.DialContext,
+               Dial:                   t.Dial,
+               DialTLS:                t.DialTLS,
+               TLSClientConfig:        t.TLSClientConfig.Clone(),
+               TLSHandshakeTimeout:    t.TLSHandshakeTimeout,
+               DisableKeepAlives:      t.DisableKeepAlives,
+               DisableCompression:     t.DisableCompression,
+               MaxIdleConns:           t.MaxIdleConns,
+               MaxIdleConnsPerHost:    t.MaxIdleConnsPerHost,
+               MaxConnsPerHost:        t.MaxConnsPerHost,
+               IdleConnTimeout:        t.IdleConnTimeout,
+               ResponseHeaderTimeout:  t.ResponseHeaderTimeout,
+               ExpectContinueTimeout:  t.ExpectContinueTimeout,
+               ProxyConnectHeader:     t.ProxyConnectHeader.Clone(),
+               MaxResponseHeaderBytes: t.MaxResponseHeaderBytes,
+               ForceAttemptHTTP2:      t.ForceAttemptHTTP2,
+               WriteBufferSize:        t.WriteBufferSize,
+               ReadBufferSize:         t.ReadBufferSize,
+       }
+       if !t.tlsNextProtoWasNil {
+               npm := map[string]func(authority string, c *tls.Conn) RoundTripper{}
+               for k, v := range t.TLSNextProto {
+                       npm[k] = v
+               }
+               t2.TLSNextProto = npm
+       }
+       return t2
+}
+
 // h2Transport is the interface we expect to be able to call from
 // net/http against an *http2.Transport that's either bundled into
 // h2_bundle.go or supplied by the user via x/net/http2.
@@ -303,6 +338,7 @@ type h2Transport interface {
 // onceSetNextProtoDefaults initializes TLSNextProto.
 // It must be called via t.nextProtoOnce.Do.
 func (t *Transport) onceSetNextProtoDefaults() {
+       t.tlsNextProtoWasNil = (t.TLSNextProto == nil)
        if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") {
                return
        }
index dbfbd5792d4a5b8dafeb5e03e2dbb20be8df3db5..cf2bbe1189cac58cee72bb7f9afe84c9c2ac5cda 100644 (file)
@@ -20,6 +20,7 @@ import (
        "encoding/binary"
        "errors"
        "fmt"
+       "go/token"
        "internal/nettrace"
        "internal/testenv"
        "io"
@@ -5320,3 +5321,53 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) {
                })
        }
 }
+
+func TestTransportClone(t *testing.T) {
+       tr := &Transport{
+               Proxy:                  func(*Request) (*url.URL, error) { panic("") },
+               DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
+               Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
+               DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
+               TLSClientConfig:        new(tls.Config),
+               TLSHandshakeTimeout:    time.Second,
+               DisableKeepAlives:      true,
+               DisableCompression:     true,
+               MaxIdleConns:           1,
+               MaxIdleConnsPerHost:    1,
+               MaxConnsPerHost:        1,
+               IdleConnTimeout:        time.Second,
+               ResponseHeaderTimeout:  time.Second,
+               ExpectContinueTimeout:  time.Second,
+               ProxyConnectHeader:     Header{},
+               MaxResponseHeaderBytes: 1,
+               ForceAttemptHTTP2:      true,
+               TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
+                       "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
+               },
+               ReadBufferSize:  1,
+               WriteBufferSize: 1,
+       }
+       tr2 := tr.Clone()
+       rv := reflect.ValueOf(tr2).Elem()
+       rt := rv.Type()
+       for i := 0; i < rt.NumField(); i++ {
+               sf := rt.Field(i)
+               if !token.IsExported(sf.Name) {
+                       continue
+               }
+               if rv.Field(i).IsZero() {
+                       t.Errorf("cloned field t2.%s is zero", sf.Name)
+               }
+       }
+
+       if _, ok := tr2.TLSNextProto["foo"]; !ok {
+               t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
+       }
+
+       // But test that a nil TLSNextProto is kept nil:
+       tr = new(Transport)
+       tr2 = tr.Clone()
+       if tr2.TLSNextProto != nil {
+               t.Errorf("Transport.TLSNextProto unexpected non-nil")
+       }
+}