c := proxy.Client()
+ var (
+ dials atomic.Int32
+ closes atomic.Int32
+ )
+ c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ conn, err := net.Dial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ dials.Add(1)
+ return noteCloseConn{
+ Conn: conn,
+ closeFunc: func() {
+ closes.Add(1)
+ },
+ }, nil
+ }
+
c.Transport.(*Transport).Proxy = ProxyURL(pu)
c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
if proxyURL.String() != pu.String() {
}
return tcase.err
}
+ wantCloses := int32(0)
if _, err := c.Head(ts.URL); err != nil {
+ wantCloses = 1
if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
t.Errorf("got %v, want %v", err, tcase.err)
}
+ } else {
+ if tcase.err != nil {
+ t.Errorf("got %v, want nil", err)
+ }
+ }
+ if got, want := dials.Load(), int32(1); got != want {
+ t.Errorf("got %v dials, want %v", got, want)
+ }
+ // #64804: If OnProxyConnectResponse returns an error, we should close the conn.
+ if got, want := closes.Load(), wantCloses; got != want {
+ t.Errorf("got %v closes, want %v", got, want)
}
}
}