From c436eadbc36704012be727457f464d8fbf950638 Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Mon, 27 Jan 2020 16:35:28 -0800 Subject: [PATCH] net/http: don't treat an alternate protocol as a known round tripper As of CL 175857, the client code checks for known round tripper implementations, and uses simpler cancellation code when it finds one. However, this code was not considering the case of a request that uses a user-defined protocol, where the user-defined protocol was registered with the transport to use a different round tripper. The effect was that round trippers that worked with earlier releases would not see the expected cancellation semantics with tip. Fixes #36820 Change-Id: I60e75b5d0badcfb9fde9d73a966ba1d3f7aa42b1 Reviewed-on: https://go-review.googlesource.com/c/go/+/216618 Run-TryBot: Ian Lance Taylor TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick --- src/net/http/client.go | 17 ++++++++++++----- src/net/http/omithttp2.go | 4 ++++ src/net/http/transport.go | 20 ++++++++++++++------ src/net/http/transport_test.go | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 11 deletions(-) diff --git a/src/net/http/client.go b/src/net/http/client.go index 6a8c59a670..a496f1c0c7 100644 --- a/src/net/http/client.go +++ b/src/net/http/client.go @@ -288,10 +288,17 @@ func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool { // knownRoundTripperImpl reports whether rt is a RoundTripper that's // maintained by the Go team and known to implement the latest -// optional semantics (notably contexts). -func knownRoundTripperImpl(rt RoundTripper) bool { - switch rt.(type) { - case *Transport, *http2Transport: +// optional semantics (notably contexts). The Request is used +// to check whether this particular request is using an alternate protocol, +// in which case we need to check the RoundTripper for that protocol. +func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { + switch t := rt.(type) { + case *Transport: + if altRT := t.alternateRoundTripper(req); altRT != nil { + return knownRoundTripperImpl(altRT, req) + } + return true + case *http2Transport, http2noDialH2RoundTripper: return true } // There's a very minor chance of a false positive with this. @@ -319,7 +326,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi if deadline.IsZero() { return nop, alwaysFalse } - knownTransport := knownRoundTripperImpl(rt) + knownTransport := knownRoundTripperImpl(rt, req) oldCtx := req.Context() if req.Cancel == nil && knownTransport { diff --git a/src/net/http/omithttp2.go b/src/net/http/omithttp2.go index a0b33e9aad..307d93a3b1 100644 --- a/src/net/http/omithttp2.go +++ b/src/net/http/omithttp2.go @@ -36,6 +36,10 @@ type http2erringRoundTripper struct{} func (http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } +type http2noDialH2RoundTripper struct{} + +func (http2noDialH2RoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } + type http2noDialClientConnPool struct { http2clientConnPool http2clientConnPool } diff --git a/src/net/http/transport.go b/src/net/http/transport.go index fa2303ca30..d0bfdb412c 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -469,6 +469,17 @@ func (t *Transport) useRegisteredProtocol(req *Request) bool { return true } +// alternateRoundTripper returns the alternate RoundTripper to use +// for this request if the Request's URL scheme requires one, +// or nil for the normal case of using the Transport. +func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { + if !t.useRegisteredProtocol(req) { + return nil + } + altProto, _ := t.altProto.Load().(map[string]RoundTripper) + return altProto[req.URL.Scheme] +} + // roundTrip implements a RoundTripper over HTTP. func (t *Transport) roundTrip(req *Request) (*Response, error) { t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) @@ -500,12 +511,9 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } } - if t.useRegisteredProtocol(req) { - altProto, _ := t.altProto.Load().(map[string]RoundTripper) - if altRT := altProto[scheme]; altRT != nil { - if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { - return resp, err - } + if altRT := t.alternateRoundTripper(req); altRT != nil { + if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { + return resp, err } } if !isHTTP { diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 5fc60e1842..3ca7ce93b2 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -6143,3 +6143,35 @@ func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { t.Errorf("error occurred: %v", err) } } + +// Issue 36820 +// Test that we use the older backward compatible cancellation protocol +// when a RoundTripper is registered via RegisterProtocol. +func TestAltProtoCancellation(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + c := &Client{ + Transport: tr, + Timeout: time.Millisecond, + } + tr.RegisterProtocol("timeout", timeoutProto{}) + _, err := c.Get("timeout://bar.com/path") + if err == nil { + t.Error("request unexpectedly succeeded") + } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) { + t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr) + } +} + +var timeoutProtoErr = errors.New("canceled as expected") + +type timeoutProto struct{} + +func (timeoutProto) RoundTrip(req *Request) (*Response, error) { + select { + case <-req.Cancel: + return nil, timeoutProtoErr + case <-time.After(5 * time.Second): + return nil, errors.New("request was not canceled") + } +} -- 2.48.1