]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: don't treat an alternate protocol as a known round tripper
authorIan Lance Taylor <iant@golang.org>
Tue, 28 Jan 2020 00:35:28 +0000 (16:35 -0800)
committerIan Lance Taylor <iant@golang.org>
Wed, 29 Jan 2020 04:04:52 +0000 (04:04 +0000)
As of CL 175857, the client code checks for known round tripper
implementations, and uses simpler cancellation code when it finds one.
However, this code was not considering the case of a request that uses
a user-defined protocol, where the user-defined protocol was
registered with the transport to use a different round tripper.
The effect was that round trippers that worked with earlier
releases would not see the expected cancellation semantics with tip.

Fixes #36820

Change-Id: I60e75b5d0badcfb9fde9d73a966ba1d3f7aa42b1
Reviewed-on: https://go-review.googlesource.com/c/go/+/216618
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/client.go
src/net/http/omithttp2.go
src/net/http/transport.go
src/net/http/transport_test.go

index 6a8c59a6702f343ae4f391babb6d0a5432299202..a496f1c0c7534ee45ccd86765dd558fa22c5780c 100644 (file)
@@ -288,10 +288,17 @@ func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool {
 
 // knownRoundTripperImpl reports whether rt is a RoundTripper that's
 // maintained by the Go team and known to implement the latest
-// optional semantics (notably contexts).
-func knownRoundTripperImpl(rt RoundTripper) bool {
-       switch rt.(type) {
-       case *Transport, *http2Transport:
+// optional semantics (notably contexts). The Request is used
+// to check whether this particular request is using an alternate protocol,
+// in which case we need to check the RoundTripper for that protocol.
+func knownRoundTripperImpl(rt RoundTripper, req *Request) bool {
+       switch t := rt.(type) {
+       case *Transport:
+               if altRT := t.alternateRoundTripper(req); altRT != nil {
+                       return knownRoundTripperImpl(altRT, req)
+               }
+               return true
+       case *http2Transport, http2noDialH2RoundTripper:
                return true
        }
        // There's a very minor chance of a false positive with this.
@@ -319,7 +326,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi
        if deadline.IsZero() {
                return nop, alwaysFalse
        }
-       knownTransport := knownRoundTripperImpl(rt)
+       knownTransport := knownRoundTripperImpl(rt, req)
        oldCtx := req.Context()
 
        if req.Cancel == nil && knownTransport {
index a0b33e9aad1e9b09c13cc9af43c8e59faa041862..307d93a3b14b3c7464ac1dce9924fbc5cc9c640f 100644 (file)
@@ -36,6 +36,10 @@ type http2erringRoundTripper struct{}
 
 func (http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) }
 
+type http2noDialH2RoundTripper struct{}
+
+func (http2noDialH2RoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) }
+
 type http2noDialClientConnPool struct {
        http2clientConnPool http2clientConnPool
 }
index fa2303ca30b1acf16673b20ca71b543c1e164bf4..d0bfdb412cb834db70238008e1a58d3a5449c6b9 100644 (file)
@@ -469,6 +469,17 @@ func (t *Transport) useRegisteredProtocol(req *Request) bool {
        return true
 }
 
+// alternateRoundTripper returns the alternate RoundTripper to use
+// for this request if the Request's URL scheme requires one,
+// or nil for the normal case of using the Transport.
+func (t *Transport) alternateRoundTripper(req *Request) RoundTripper {
+       if !t.useRegisteredProtocol(req) {
+               return nil
+       }
+       altProto, _ := t.altProto.Load().(map[string]RoundTripper)
+       return altProto[req.URL.Scheme]
+}
+
 // roundTrip implements a RoundTripper over HTTP.
 func (t *Transport) roundTrip(req *Request) (*Response, error) {
        t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
@@ -500,12 +511,9 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
                }
        }
 
-       if t.useRegisteredProtocol(req) {
-               altProto, _ := t.altProto.Load().(map[string]RoundTripper)
-               if altRT := altProto[scheme]; altRT != nil {
-                       if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
-                               return resp, err
-                       }
+       if altRT := t.alternateRoundTripper(req); altRT != nil {
+               if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
+                       return resp, err
                }
        }
        if !isHTTP {
index 5fc60e1842f46e69ee767587196088e122d57d24..3ca7ce93b29a5be4ae32c36564a9b765847aee05 100644 (file)
@@ -6143,3 +6143,35 @@ func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
                t.Errorf("error occurred: %v", err)
        }
 }
+
+// Issue 36820
+// Test that we use the older backward compatible cancellation protocol
+// when a RoundTripper is registered via RegisterProtocol.
+func TestAltProtoCancellation(t *testing.T) {
+       defer afterTest(t)
+       tr := &Transport{}
+       c := &Client{
+               Transport: tr,
+               Timeout:   time.Millisecond,
+       }
+       tr.RegisterProtocol("timeout", timeoutProto{})
+       _, err := c.Get("timeout://bar.com/path")
+       if err == nil {
+               t.Error("request unexpectedly succeeded")
+       } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) {
+               t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr)
+       }
+}
+
+var timeoutProtoErr = errors.New("canceled as expected")
+
+type timeoutProto struct{}
+
+func (timeoutProto) RoundTrip(req *Request) (*Response, error) {
+       select {
+       case <-req.Cancel:
+               return nil, timeoutProtoErr
+       case <-time.After(5 * time.Second):
+               return nil, errors.New("request was not canceled")
+       }
+}