]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Transport.GetProxyConnectHeader
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 6 Oct 2020 17:53:11 +0000 (10:53 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 6 Oct 2020 22:02:30 +0000 (22:02 +0000)
Fixes golang/go#41048

Change-Id: I38e01605bffb6f85100c098051b0c416dd77f261
Reviewed-on: https://go-review.googlesource.com/c/go/+/259917
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
src/net/http/transport.go
src/net/http/transport_test.go

index b97c4268b57e71638ab96502188e9c2d116bafb5..454616643054bd8204e64aba4005808eeaa81fd2 100644 (file)
@@ -240,8 +240,18 @@ type Transport struct {
 
        // ProxyConnectHeader optionally specifies headers to send to
        // proxies during CONNECT requests.
+       // To set the header dynamically, see GetProxyConnectHeader.
        ProxyConnectHeader Header
 
+       // GetProxyConnectHeader optionally specifies a func to return
+       // headers to send to proxyURL during a CONNECT request to the
+       // ip:port target.
+       // If it returns an error, the Transport's RoundTrip fails with
+       // that error. It can return (nil, nil) to not add headers.
+       // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is
+       // ignored.
+       GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error)
+
        // MaxResponseHeaderBytes specifies a limit on how many
        // response bytes are allowed in the server's response
        // header.
@@ -313,6 +323,7 @@ func (t *Transport) Clone() *Transport {
                ResponseHeaderTimeout:  t.ResponseHeaderTimeout,
                ExpectContinueTimeout:  t.ExpectContinueTimeout,
                ProxyConnectHeader:     t.ProxyConnectHeader.Clone(),
+               GetProxyConnectHeader:  t.GetProxyConnectHeader,
                MaxResponseHeaderBytes: t.MaxResponseHeaderBytes,
                ForceAttemptHTTP2:      t.ForceAttemptHTTP2,
                WriteBufferSize:        t.WriteBufferSize,
@@ -1623,7 +1634,17 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
                }
        case cm.targetScheme == "https":
                conn := pconn.conn
-               hdr := t.ProxyConnectHeader
+               var hdr Header
+               if t.GetProxyConnectHeader != nil {
+                       var err error
+                       hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr)
+                       if err != nil {
+                               conn.Close()
+                               return nil, err
+                       }
+               } else {
+                       hdr = t.ProxyConnectHeader
+               }
                if hdr == nil {
                        hdr = make(Header)
                }
index f4b76236308fadaa01083a303dcb43fcf639d0b3..a1c9e822b472a4950fda761863cb5cf770c94cb2 100644 (file)
@@ -5174,6 +5174,57 @@ func TestTransportProxyConnectHeader(t *testing.T) {
        }
 }
 
+func TestTransportProxyGetConnectHeader(t *testing.T) {
+       defer afterTest(t)
+       reqc := make(chan *Request, 1)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if r.Method != "CONNECT" {
+                       t.Errorf("method = %q; want CONNECT", r.Method)
+               }
+               reqc <- r
+               c, _, err := w.(Hijacker).Hijack()
+               if err != nil {
+                       t.Errorf("Hijack: %v", err)
+                       return
+               }
+               c.Close()
+       }))
+       defer ts.Close()
+
+       c := ts.Client()
+       c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
+               return url.Parse(ts.URL)
+       }
+       // These should be ignored:
+       c.Transport.(*Transport).ProxyConnectHeader = Header{
+               "User-Agent": {"foo"},
+               "Other":      {"bar"},
+       }
+       c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
+               return Header{
+                       "User-Agent": {"foo2"},
+                       "Other":      {"bar2"},
+               }, nil
+       }
+
+       res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
+       if err == nil {
+               res.Body.Close()
+               t.Errorf("unexpected success")
+       }
+       select {
+       case <-time.After(3 * time.Second):
+               t.Fatal("timeout")
+       case r := <-reqc:
+               if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
+                       t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
+               }
+               if got, want := r.Header.Get("Other"), "bar2"; got != want {
+                       t.Errorf("CONNECT request Other = %q; want %q", got, want)
+               }
+       }
+}
+
 var errFakeRoundTrip = errors.New("fake roundtrip")
 
 type funcRoundTripper func()
@@ -5842,6 +5893,7 @@ func TestTransportClone(t *testing.T) {
                ResponseHeaderTimeout:  time.Second,
                ExpectContinueTimeout:  time.Second,
                ProxyConnectHeader:     Header{},
+               GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
                MaxResponseHeaderBytes: 1,
                ForceAttemptHTTP2:      true,
                TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{