]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: close connection if OnProxyConnectResponse returns an error
authorDamien Neil <dneil@google.com>
Sat, 3 Feb 2024 00:19:37 +0000 (16:19 -0800)
committerDamien Neil <dneil@google.com>
Tue, 13 Feb 2024 18:54:50 +0000 (18:54 +0000)
Fixes #64804

Change-Id: Ibe56ab8d114b8826e477b0718470d0b9fbfef9b0
Reviewed-on: https://go-review.googlesource.com/c/go/+/560856
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
src/net/http/transport.go
src/net/http/transport_test.go

index 57c70e72f9522c93a643f48c649d2288c7e62640..2a549a9576b9cfc4a7a5e4b5b4835b6fe4f911d3 100644 (file)
@@ -1761,6 +1761,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
                if t.OnProxyConnectResponse != nil {
                        err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp)
                        if err != nil {
+                               conn.Close()
                                return nil, err
                        }
                }
index 3057024b764baa04791e6a68ad14ad1480b0f905..698a43530ad6b4e6a2093e09e37042fa2357eb37 100644 (file)
@@ -1523,6 +1523,24 @@ func TestOnProxyConnectResponse(t *testing.T) {
 
                c := proxy.Client()
 
+               var (
+                       dials  atomic.Int32
+                       closes atomic.Int32
+               )
+               c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+                       conn, err := net.Dial(network, addr)
+                       if err != nil {
+                               return nil, err
+                       }
+                       dials.Add(1)
+                       return noteCloseConn{
+                               Conn: conn,
+                               closeFunc: func() {
+                                       closes.Add(1)
+                               },
+                       }, nil
+               }
+
                c.Transport.(*Transport).Proxy = ProxyURL(pu)
                c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
                        if proxyURL.String() != pu.String() {
@@ -1534,10 +1552,23 @@ func TestOnProxyConnectResponse(t *testing.T) {
                        }
                        return tcase.err
                }
+               wantCloses := int32(0)
                if _, err := c.Head(ts.URL); err != nil {
+                       wantCloses = 1
                        if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
                                t.Errorf("got %v, want %v", err, tcase.err)
                        }
+               } else {
+                       if tcase.err != nil {
+                               t.Errorf("got %v, want nil", err)
+                       }
+               }
+               if got, want := dials.Load(), int32(1); got != want {
+                       t.Errorf("got %v dials, want %v", got, want)
+               }
+               // #64804: If OnProxyConnectResponse returns an error, we should close the conn.
+               if got, want := closes.Load(), wantCloses; got != want {
+                       t.Errorf("got %v closes, want %v", got, want)
                }
        }
 }