// knownRoundTripperImpl reports whether rt is a RoundTripper that's
// maintained by the Go team and known to implement the latest
-// optional semantics (notably contexts).
-func knownRoundTripperImpl(rt RoundTripper) bool {
- switch rt.(type) {
- case *Transport, *http2Transport:
+// optional semantics (notably contexts). The Request is used
+// to check whether this particular request is using an alternate protocol,
+// in which case we need to check the RoundTripper for that protocol.
+func knownRoundTripperImpl(rt RoundTripper, req *Request) bool {
+ switch t := rt.(type) {
+ case *Transport:
+ if altRT := t.alternateRoundTripper(req); altRT != nil {
+ return knownRoundTripperImpl(altRT, req)
+ }
+ return true
+ case *http2Transport, http2noDialH2RoundTripper:
return true
}
// There's a very minor chance of a false positive with this.
if deadline.IsZero() {
return nop, alwaysFalse
}
- knownTransport := knownRoundTripperImpl(rt)
+ knownTransport := knownRoundTripperImpl(rt, req)
oldCtx := req.Context()
if req.Cancel == nil && knownTransport {
func (http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) }
+type http2noDialH2RoundTripper struct{}
+
+func (http2noDialH2RoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) }
+
type http2noDialClientConnPool struct {
http2clientConnPool http2clientConnPool
}
return true
}
+// alternateRoundTripper returns the alternate RoundTripper to use
+// for this request if the Request's URL scheme requires one,
+// or nil for the normal case of using the Transport.
+func (t *Transport) alternateRoundTripper(req *Request) RoundTripper {
+ if !t.useRegisteredProtocol(req) {
+ return nil
+ }
+ altProto, _ := t.altProto.Load().(map[string]RoundTripper)
+ return altProto[req.URL.Scheme]
+}
+
// roundTrip implements a RoundTripper over HTTP.
func (t *Transport) roundTrip(req *Request) (*Response, error) {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
}
}
- if t.useRegisteredProtocol(req) {
- altProto, _ := t.altProto.Load().(map[string]RoundTripper)
- if altRT := altProto[scheme]; altRT != nil {
- if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
- return resp, err
- }
+ if altRT := t.alternateRoundTripper(req); altRT != nil {
+ if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
+ return resp, err
}
}
if !isHTTP {
t.Errorf("error occurred: %v", err)
}
}
+
+// Issue 36820
+// Test that we use the older backward compatible cancellation protocol
+// when a RoundTripper is registered via RegisterProtocol.
+func TestAltProtoCancellation(t *testing.T) {
+ defer afterTest(t)
+ tr := &Transport{}
+ c := &Client{
+ Transport: tr,
+ Timeout: time.Millisecond,
+ }
+ tr.RegisterProtocol("timeout", timeoutProto{})
+ _, err := c.Get("timeout://bar.com/path")
+ if err == nil {
+ t.Error("request unexpectedly succeeded")
+ } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) {
+ t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr)
+ }
+}
+
+var timeoutProtoErr = errors.New("canceled as expected")
+
+type timeoutProto struct{}
+
+func (timeoutProto) RoundTrip(req *Request) (*Response, error) {
+ select {
+ case <-req.Cancel:
+ return nil, timeoutProtoErr
+ case <-time.After(5 * time.Second):
+ return nil, errors.New("request was not canceled")
+ }
+}