]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: ensure Request.Body.Close is called once and only once
authorRoss Light <ross@zombiezen.com>
Sat, 26 Sep 2020 15:49:56 +0000 (08:49 -0700)
committerRuss Cox <rsc@golang.org>
Fri, 16 Oct 2020 16:53:27 +0000 (16:53 +0000)
Makes *Request.write always close the body, so that callers no longer
have to close the body on returned errors, which was the trigger for
double-close behavior.

Fixes #40382

Change-Id: I128f7ec70415f240d82154cfca134b3f692191e3
Reviewed-on: https://go-review.googlesource.com/c/go/+/257819
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Trust: Damien Neil <dneil@google.com>
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>

src/net/http/client_test.go
src/net/http/request.go
src/net/http/transfer.go
src/net/http/transport.go

index 80807fae7a42d64c485179abedb664258cda679d..4bd62735e840c7371f9a2b0507b9997e153985ad 100644 (file)
@@ -2026,3 +2026,60 @@ func TestClientPopulatesNilResponseBody(t *testing.T) {
                t.Errorf("substitute Response.Body was unexpectedly non-empty: %q", b)
        }
 }
+
+// Issue 40382: Client calls Close multiple times on Request.Body.
+func TestClientCallsCloseOnlyOnce(t *testing.T) {
+       setParallel(t)
+       defer afterTest(t)
+       cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.WriteHeader(StatusNoContent)
+       }))
+       defer cst.close()
+
+       // Issue occurred non-deterministically: needed to occur after a successful
+       // write (into TCP buffer) but before end of body.
+       for i := 0; i < 50 && !t.Failed(); i++ {
+               body := &issue40382Body{t: t, n: 300000}
+               req, err := NewRequest(MethodPost, cst.ts.URL, body)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               resp, err := cst.tr.RoundTrip(req)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               resp.Body.Close()
+       }
+}
+
+// issue40382Body is an io.ReadCloser for TestClientCallsCloseOnlyOnce.
+// Its Read reads n bytes before returning io.EOF.
+// Its Close returns nil but fails the test if called more than once.
+type issue40382Body struct {
+       t                *testing.T
+       n                int
+       closeCallsAtomic int32
+}
+
+func (b *issue40382Body) Read(p []byte) (int, error) {
+       switch {
+       case b.n == 0:
+               return 0, io.EOF
+       case b.n < len(p):
+               p = p[:b.n]
+               fallthrough
+       default:
+               for i := range p {
+                       p[i] = 'x'
+               }
+               b.n -= len(p)
+               return len(p), nil
+       }
+}
+
+func (b *issue40382Body) Close() error {
+       if atomic.AddInt32(&b.closeCallsAtomic, 1) == 2 {
+               b.t.Error("Body closed more than once")
+       }
+       return nil
+}
index 183606d0fffe194861b4443af4fa314d73817cf1..df73d5f62d4f08db4f1eac518d27e00332181a6f 100644 (file)
@@ -544,6 +544,7 @@ var errMissingHost = errors.New("http: Request.Write on Request with no Host or
 
 // extraHeaders may be nil
 // waitForContinue may be nil
+// always closes body
 func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) (err error) {
        trace := httptrace.ContextClientTrace(r.Context())
        if trace != nil && trace.WroteRequest != nil {
@@ -553,6 +554,15 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
                        })
                }()
        }
+       closed := false
+       defer func() {
+               if closed {
+                       return
+               }
+               if closeErr := r.closeBody(); closeErr != nil && err == nil {
+                       err = closeErr
+               }
+       }()
 
        // Find the target host. Prefer the Host: header, but if that
        // is not given, use the host from the request URL.
@@ -671,6 +681,7 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
                        trace.Wait100Continue()
                }
                if !waitForContinue() {
+                       closed = true
                        r.closeBody()
                        return nil
                }
@@ -683,6 +694,7 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
        }
 
        // Write body and trailer
+       closed = true
        err = tw.writeBody(w)
        if err != nil {
                if tw.bodyReadError == err {
@@ -1387,10 +1399,11 @@ func (r *Request) wantsClose() bool {
        return hasToken(r.Header.get("Connection"), "close")
 }
 
-func (r *Request) closeBody() {
-       if r.Body != nil {
-               r.Body.Close()
+func (r *Request) closeBody() error {
+       if r.Body == nil {
+               return nil
        }
+       return r.Body.Close()
 }
 
 func (r *Request) isReplayable() bool {
index ab009177bc7bb3f0f735af365887f5b0eb1cfd62..c3234f30cc31063a769036527bd4d146b2e2b0f0 100644 (file)
@@ -330,9 +330,18 @@ func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace)
        return nil
 }
 
-func (t *transferWriter) writeBody(w io.Writer) error {
-       var err error
+// always closes t.BodyCloser
+func (t *transferWriter) writeBody(w io.Writer) (err error) {
        var ncopy int64
+       closed := false
+       defer func() {
+               if closed || t.BodyCloser == nil {
+                       return
+               }
+               if closeErr := t.BodyCloser.Close(); closeErr != nil && err == nil {
+                       err = closeErr
+               }
+       }()
 
        // Write body. We "unwrap" the body first if it was wrapped in a
        // nopCloser or readTrackingBody. This is to ensure that we can take advantage of
@@ -369,6 +378,7 @@ func (t *transferWriter) writeBody(w io.Writer) error {
                }
        }
        if t.BodyCloser != nil {
+               closed = true
                if err := t.BodyCloser.Close(); err != nil {
                        return err
                }
index d5ee5645fba7a0d3022fc200b7f568de077c2faf..29d7434f2a88997fae16ea3f9fc291f595f54aef 100644 (file)
@@ -623,7 +623,8 @@ var errCannotRewind = errors.New("net/http: cannot rewind body after connection
 
 type readTrackingBody struct {
        io.ReadCloser
-       didRead bool
+       didRead  bool
+       didClose bool
 }
 
 func (r *readTrackingBody) Read(data []byte) (int, error) {
@@ -631,6 +632,11 @@ func (r *readTrackingBody) Read(data []byte) (int, error) {
        return r.ReadCloser.Read(data)
 }
 
+func (r *readTrackingBody) Close() error {
+       r.didClose = true
+       return r.ReadCloser.Close()
+}
+
 // setupRewindBody returns a new request with a custom body wrapper
 // that can report whether the body needs rewinding.
 // This lets rewindBody avoid an error result when the request
@@ -649,10 +655,12 @@ func setupRewindBody(req *Request) *Request {
 // rewindBody takes care of closing req.Body when appropriate
 // (in all cases except when rewindBody returns req unmodified).
 func rewindBody(req *Request) (rewound *Request, err error) {
-       if req.Body == nil || req.Body == NoBody || !req.Body.(*readTrackingBody).didRead {
+       if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) {
                return req, nil // nothing to rewind
        }
-       req.closeBody()
+       if !req.Body.(*readTrackingBody).didClose {
+               req.closeBody()
+       }
        if req.GetBody == nil {
                return nil, errCannotRewind
        }
@@ -2379,7 +2387,7 @@ func (pc *persistConn) writeLoop() {
                                // Request.Body are high priority.
                                // Set it here before sending on the
                                // channels below or calling
-                               // pc.close() which tears town
+                               // pc.close() which tears down
                                // connections and causes other
                                // errors.
                                wr.req.setError(err)
@@ -2388,7 +2396,6 @@ func (pc *persistConn) writeLoop() {
                                err = pc.bw.Flush()
                        }
                        if err != nil {
-                               wr.req.Request.closeBody()
                                if pc.nwrite == startBytesWritten {
                                        err = nothingWrittenError{err}
                                }