]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: make Transport.RoundTrip check context.Done earlier
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 13 Jun 2018 15:12:28 +0000 (15:12 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 13 Jun 2018 17:12:57 +0000 (17:12 +0000)
Fixes #25852

Change-Id: I35c630367c8f1934dcffc0b0e08891d55a903518
Reviewed-on: https://go-review.googlesource.com/118560
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Andrew Bonventre <andybons@golang.org>
src/net/http/transport.go
src/net/http/transport_test.go

index 9b5ea52c9b4c51eebaa78e0a78a25fadecf67aea..a298e2ef03c34e31f5b121a856cd4741a17b8bbc 100644 (file)
@@ -370,6 +370,13 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
        }
 
        for {
+               select {
+               case <-ctx.Done():
+                       req.closeBody()
+                       return nil, ctx.Err()
+               default:
+               }
+
                // treq gets modified by roundTrip, so we need to recreate for each retry.
                treq := &transportRequest{Request: req, trace: trace}
                cm, err := t.connectMethodForRequest(treq)
index 01a209c6330c6251f3faf62a772ea06cbb0ca2c2..a02867a2d0d442145c9936f5747126fc35e2ecd1 100644 (file)
@@ -4544,3 +4544,28 @@ func TestNoBodyOnChunked304Response(t *testing.T) {
 type funcWriter func([]byte) (int, error)
 
 func (f funcWriter) Write(p []byte) (int, error) { return f(p) }
+
+type doneContext struct {
+       context.Context
+       err error
+}
+
+func (doneContext) Done() <-chan struct{} {
+       c := make(chan struct{})
+       close(c)
+       return c
+}
+
+func (d doneContext) Err() error { return d.err }
+
+// Issue 25852: Transport should check whether Context is done early.
+func TestTransportCheckContextDoneEarly(t *testing.T) {
+       tr := &Transport{}
+       req, _ := NewRequest("GET", "http://fake.example/", nil)
+       wantErr := errors.New("some error")
+       req = req.WithContext(doneContext{context.Background(), wantErr})
+       _, err := tr.RoundTrip(req)
+       if err != wantErr {
+               t.Errorf("error = %v; want %v", err, wantErr)
+       }
+}