]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix data race when sharing request body between client and server
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 7 Jan 2014 18:40:56 +0000 (10:40 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 7 Jan 2014 18:40:56 +0000 (10:40 -0800)
A server Handler (e.g. a proxy) can receive a Request, and
then turn around and give a copy of that Request.Body out to
the Transport. So then two goroutines own that Request.Body
(the server and the http client), and both think they can
close it on failure.  Therefore, all incoming server requests
bodies (always *http.body from transfer.go) need to be
thread-safe.

Fixes #6995

R=golang-codereviews, r
CC=golang-codereviews
https://golang.org/cl/46570043

src/pkg/net/http/response_test.go
src/pkg/net/http/serve_test.go
src/pkg/net/http/transfer.go
src/pkg/net/http/transfer_test.go

index 5044306a876cb1022a26c6051e29fd306f5434ce..f73172189e94985fbd9c815be98024fb376ce506 100644 (file)
@@ -14,6 +14,7 @@ import (
        "io/ioutil"
        "net/url"
        "reflect"
+       "regexp"
        "strings"
        "testing"
 )
@@ -406,8 +407,7 @@ func TestWriteResponse(t *testing.T) {
                        t.Errorf("#%d: %v", i, err)
                        continue
                }
-               bout := bytes.NewBuffer(nil)
-               err = resp.Write(bout)
+               err = resp.Write(ioutil.Discard)
                if err != nil {
                        t.Errorf("#%d: %v", i, err)
                        continue
@@ -506,6 +506,9 @@ func TestReadResponseCloseInMiddle(t *testing.T) {
                rest, err := ioutil.ReadAll(bufr)
                checkErr(err, "ReadAll on remainder")
                if e, g := "Next Request Here", string(rest); e != g {
+                       g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string {
+                               return fmt.Sprintf("x(repeated x%d)", len(match))
+                       })
                        fatalf("remainder = %q, expected %q", g, e)
                }
        }
index e4d9b340bee1f15dbd0174fd5870bd1a37107770..8f382fa6ea11adc05944c4b5e621ec64dada3ef2 100644 (file)
@@ -2090,6 +2090,64 @@ func TestNoContentTypeOnNotModified(t *testing.T) {
        }
 }
 
+// Issue 6995
+// A server Handler can receive a Request, and then turn around and
+// give a copy of that Request.Body out to the Transport (e.g. any
+// proxy).  So then two people own that Request.Body (both the server
+// and the http client), and both think they can close it on failure.
+// Therefore, all incoming server requests Bodies need to be thread-safe.
+func TestTransportAndServerSharedBodyRace(t *testing.T) {
+       defer afterTest(t)
+
+       const bodySize = 1 << 20
+
+       unblockBackend := make(chan bool)
+       backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+               io.CopyN(rw, req.Body, bodySize/2)
+               <-unblockBackend
+       }))
+       defer backend.Close()
+
+       backendRespc := make(chan *Response, 1)
+       proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+               if req.RequestURI == "/foo" {
+                       rw.Write([]byte("bar"))
+                       return
+               }
+               req2, _ := NewRequest("POST", backend.URL, req.Body)
+               req2.ContentLength = bodySize
+
+               bresp, err := DefaultClient.Do(req2)
+               if err != nil {
+                       t.Errorf("Proxy outbound request: %v", err)
+                       return
+               }
+               _, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/4)
+               if err != nil {
+                       t.Errorf("Proxy copy error: %v", err)
+                       return
+               }
+               backendRespc <- bresp // to close later
+
+               // Try to cause a race: Both the DefaultTransport and the proxy handler's Server
+               // will try to read/close req.Body (aka req2.Body)
+               DefaultTransport.(*Transport).CancelRequest(req2)
+               rw.Write([]byte("OK"))
+       }))
+       defer proxy.Close()
+
+       req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize))
+       res, err := DefaultClient.Do(req)
+       if err != nil {
+               t.Fatalf("Original request: %v", err)
+       }
+
+       // Cleanup, so we don't leak goroutines.
+       res.Body.Close()
+       close(unblockBackend)
+       (<-backendRespc).Body.Close()
+}
+
 func TestResponseWriterWriteStringAllocs(t *testing.T) {
        ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
                if r.URL.Path == "/s" {
index bacd83732de57ff1fac89c7b664d934eb6db2a36..4a2bda19facf83d23bf79e51560c100ffc674327 100644 (file)
@@ -14,6 +14,7 @@ import (
        "net/textproto"
        "strconv"
        "strings"
+       "sync"
 )
 
 // transferWriter inspects the fields of a user-supplied Request or Response,
@@ -331,17 +332,17 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
                if noBodyExpected(t.RequestMethod) {
                        t.Body = eofReader
                } else {
-                       t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
+                       t.Body = &body{src: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
                }
        case realLength == 0:
                t.Body = eofReader
        case realLength > 0:
-               t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close}
+               t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close}
        default:
                // realLength < 0, i.e. "Content-Length" not mentioned in header
                if t.Close {
                        // Close semantics (i.e. HTTP/1.0)
-                       t.Body = &body{Reader: r, closing: t.Close}
+                       t.Body = &body{src: r, closing: t.Close}
                } else {
                        // Persistent connection (i.e. HTTP/1.1)
                        t.Body = eofReader
@@ -514,11 +515,13 @@ func fixTrailer(header Header, te []string) (Header, error) {
 // Close ensures that the body has been fully read
 // and then reads the trailer if necessary.
 type body struct {
-       io.Reader
+       src     io.Reader
        hdr     interface{}   // non-nil (Response or Request) value means read trailer
        r       *bufio.Reader // underlying wire-format reader for the trailer
        closing bool          // is the connection to be closed after reading body?
-       closed  bool
+
+       mu     sync.Mutex // guards closed, and calls to Read and Close
+       closed bool
 }
 
 // ErrBodyReadAfterClose is returned when reading a Request or Response
@@ -528,10 +531,17 @@ type body struct {
 var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
 
 func (b *body) Read(p []byte) (n int, err error) {
+       b.mu.Lock()
+       defer b.mu.Unlock()
        if b.closed {
                return 0, ErrBodyReadAfterClose
        }
-       n, err = b.Reader.Read(p)
+       return b.readLocked(p)
+}
+
+// Must hold b.mu.
+func (b *body) readLocked(p []byte) (n int, err error) {
+       n, err = b.src.Read(p)
 
        if err == io.EOF {
                // Chunked case. Read the trailer.
@@ -543,7 +553,7 @@ func (b *body) Read(p []byte) (n int, err error) {
                } else {
                        // If the server declared the Content-Length, our body is a LimitedReader
                        // and we need to check whether this EOF arrived early.
-                       if lr, ok := b.Reader.(*io.LimitedReader); ok && lr.N > 0 {
+                       if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 {
                                err = io.ErrUnexpectedEOF
                        }
                }
@@ -618,6 +628,8 @@ func (b *body) readTrailer() error {
 }
 
 func (b *body) Close() error {
+       b.mu.Lock()
+       defer b.mu.Unlock()
        if b.closed {
                return nil
        }
@@ -629,12 +641,25 @@ func (b *body) Close() error {
        default:
                // Fully consume the body, which will also lead to us reading
                // the trailer headers after the body, if present.
-               _, err = io.Copy(ioutil.Discard, b)
+               _, err = io.Copy(ioutil.Discard, bodyLocked{b})
        }
        b.closed = true
        return err
 }
 
+// bodyLocked is a io.Reader reading from a *body when its mutex is
+// already held.
+type bodyLocked struct {
+       b *body
+}
+
+func (bl bodyLocked) Read(p []byte) (n int, err error) {
+       if bl.b.closed {
+               return 0, ErrBodyReadAfterClose
+       }
+       return bl.b.readLocked(p)
+}
+
 // parseContentLength trims whitespace from s and returns -1 if no value
 // is set, or the value if it's >= 0.
 func parseContentLength(cl string) (int64, error) {
index 8627a374c8f6ba1cc73bc30c022e888d021139d8..fb5ef37a0f0e605f7e11d220130709f0a6a37606 100644 (file)
@@ -12,9 +12,9 @@ import (
 
 func TestBodyReadBadTrailer(t *testing.T) {
        b := &body{
-               Reader: strings.NewReader("foobar"),
-               hdr:    true, // force reading the trailer
-               r:      bufio.NewReader(strings.NewReader("")),
+               src: strings.NewReader("foobar"),
+               hdr: true, // force reading the trailer
+               r:   bufio.NewReader(strings.NewReader("")),
        }
        buf := make([]byte, 7)
        n, err := b.Read(buf[:3])