]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Transport.OnProxyConnectResponse
authorcuiweixie <cuiweixie@gmail.com>
Wed, 2 Nov 2022 13:07:46 +0000 (21:07 +0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 9 Nov 2022 17:41:34 +0000 (17:41 +0000)
Fixes #54299

Change-Id: I3a29527bde7ac71f3824e771982db4257234e9ef
Reviewed-on: https://go-review.googlesource.com/c/go/+/447216
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: xie cui <523516579@qq.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
api/next/54299.txt [new file with mode: 0644]
src/net/http/transport.go
src/net/http/transport_test.go

diff --git a/api/next/54299.txt b/api/next/54299.txt
new file mode 100644 (file)
index 0000000..19bac0c
--- /dev/null
@@ -0,0 +1 @@
+pkg net/http, type Transport struct, OnProxyConnectResponse func(context.Context, *url.URL, *Request, *Response) error #54299
\ No newline at end of file
index 671d9959ea63a7317970e17755267f91df2216c7..2a508ec41b1e0a91f06b94f5e4ee59b16b2a0005 100644 (file)
@@ -120,6 +120,11 @@ type Transport struct {
        // If Proxy is nil or returns a nil *URL, no proxy is used.
        Proxy func(*Request) (*url.URL, error)
 
+       // OnProxyConnectResponse is called when the Transport gets an HTTP response from
+       // a proxy for a CONNECT request. It's called before the check for a 200 OK response.
+       // If it returns an error, the request fails with that error.
+       OnProxyConnectResponse func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error
+
        // DialContext specifies the dial function for creating unencrypted TCP connections.
        // If DialContext is nil (and the deprecated Dial below is also nil),
        // then the transport dials using package net.
@@ -309,6 +314,7 @@ func (t *Transport) Clone() *Transport {
        t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
        t2 := &Transport{
                Proxy:                  t.Proxy,
+               OnProxyConnectResponse: t.OnProxyConnectResponse,
                DialContext:            t.DialContext,
                Dial:                   t.Dial,
                DialTLS:                t.DialTLS,
@@ -1716,6 +1722,14 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
                        conn.Close()
                        return nil, err
                }
+
+               if t.OnProxyConnectResponse != nil {
+                       err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp)
+                       if err != nil {
+                               return nil, err
+                       }
+               }
+
                if resp.StatusCode != 200 {
                        _, text, ok := strings.Cut(resp.Status, " ")
                        conn.Close()
index a581845516090190c47e3873c15e7e64eadb19bb..b637e40cb429b7e3420117d03335e3ffc0b21049 100644 (file)
@@ -1465,6 +1465,98 @@ func TestTransportProxy(t *testing.T) {
        }
 }
 
+func TestOnProxyConnectResponse(t *testing.T) {
+
+       var tcases = []struct {
+               proxyStatusCode int
+               err             error
+       }{
+               {
+                       StatusOK,
+                       nil,
+               },
+               {
+                       StatusForbidden,
+                       errors.New("403"),
+               },
+       }
+       for _, tcase := range tcases {
+               h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
+
+               })
+
+               h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
+                       // Implement an entire CONNECT proxy
+                       if r.Method == "CONNECT" {
+                               if tcase.proxyStatusCode != StatusOK {
+                                       w.WriteHeader(tcase.proxyStatusCode)
+                                       return
+                               }
+                               hijacker, ok := w.(Hijacker)
+                               if !ok {
+                                       t.Errorf("hijack not allowed")
+                                       return
+                               }
+                               clientConn, _, err := hijacker.Hijack()
+                               if err != nil {
+                                       t.Errorf("hijacking failed")
+                                       return
+                               }
+                               res := &Response{
+                                       StatusCode: StatusOK,
+                                       Proto:      "HTTP/1.1",
+                                       ProtoMajor: 1,
+                                       ProtoMinor: 1,
+                                       Header:     make(Header),
+                               }
+
+                               targetConn, err := net.Dial("tcp", r.URL.Host)
+                               if err != nil {
+                                       t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
+                                       return
+                               }
+
+                               if err := res.Write(clientConn); err != nil {
+                                       t.Errorf("Writing 200 OK failed: %v", err)
+                                       return
+                               }
+
+                               go io.Copy(targetConn, clientConn)
+                               go func() {
+                                       io.Copy(clientConn, targetConn)
+                                       targetConn.Close()
+                               }()
+                       }
+               })
+               ts := newClientServerTest(t, https1Mode, h1).ts
+               proxy := newClientServerTest(t, https1Mode, h2).ts
+
+               pu, err := url.Parse(proxy.URL)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               c := proxy.Client()
+
+               c.Transport.(*Transport).Proxy = ProxyURL(pu)
+               c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
+                       if proxyURL.String() != pu.String() {
+                               t.Errorf("proxy url got %s, want %s", proxyURL, pu)
+                       }
+
+                       if "https://"+connectReq.URL.String() != ts.URL {
+                               t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
+                       }
+                       return tcase.err
+               }
+               if _, err := c.Head(ts.URL); err != nil {
+                       if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
+                               t.Errorf("got %v, want %v", err, tcase.err)
+                       }
+               }
+       }
+}
+
 // Issue 28012: verify that the Transport closes its TCP connection to http proxies
 // when they're slow to reply to HTTPS CONNECT responses.
 func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
@@ -5906,7 +5998,10 @@ func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
 
 func TestTransportClone(t *testing.T) {
        tr := &Transport{
-               Proxy:                  func(*Request) (*url.URL, error) { panic("") },
+               Proxy: func(*Request) (*url.URL, error) { panic("") },
+               OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
+                       return nil
+               },
                DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
                Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
                DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },