]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: make the MaxBytesReader.Read error sticky
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 10 May 2016 22:09:23 +0000 (15:09 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 11 May 2016 17:10:58 +0000 (17:10 +0000)
Fixes #14981

Change-Id: I39b906d119ca96815801a0fbef2dbe524a3246ff
Reviewed-on: https://go-review.googlesource.com/23009
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/request.go
src/net/http/request_test.go

index 1bde114909c38eaff04cead872059f9b575d19a3..45507d23d14beb2b33e55f98f5a9e7a756188146 100644 (file)
@@ -885,68 +885,56 @@ func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
 }
 
 type maxBytesReader struct {
-       w       ResponseWriter
-       r       io.ReadCloser // underlying reader
-       n       int64         // max bytes remaining
-       stopped bool
-       sawEOF  bool
+       w   ResponseWriter
+       r   io.ReadCloser // underlying reader
+       n   int64         // max bytes remaining
+       err error         // sticky error
 }
 
 func (l *maxBytesReader) tooLarge() (n int, err error) {
-       if !l.stopped {
-               l.stopped = true
-
-               // The server code and client code both use
-               // maxBytesReader. This "requestTooLarge" check is
-               // only used by the server code. To prevent binaries
-               // which only using the HTTP Client code (such as
-               // cmd/go) from also linking in the HTTP server, don't
-               // use a static type assertion to the server
-               // "*response" type. Check this interface instead:
-               type requestTooLarger interface {
-                       requestTooLarge()
-               }
-               if res, ok := l.w.(requestTooLarger); ok {
-                       res.requestTooLarge()
-               }
-       }
-       return 0, errors.New("http: request body too large")
+       l.err = errors.New("http: request body too large")
+       return 0, l.err
 }
 
 func (l *maxBytesReader) Read(p []byte) (n int, err error) {
-       toRead := l.n
-       if l.n == 0 {
-               if l.sawEOF {
-                       return l.tooLarge()
-               }
-               // The underlying io.Reader may not return (0, io.EOF)
-               // at EOF if the requested size is 0, so read 1 byte
-               // instead. The io.Reader docs are a bit ambiguous
-               // about the return value of Read when 0 bytes are
-               // requested, and {bytes,strings}.Reader gets it wrong
-               // too (it returns (0, nil) even at EOF).
-               toRead = 1
+       if l.err != nil {
+               return 0, l.err
+       }
+       if len(p) == 0 {
+               return 0, nil
        }
-       if int64(len(p)) > toRead {
-               p = p[:toRead]
+       // If they asked for a 32KB byte read but only 5 bytes are
+       // remaining, no need to read 32KB. 6 bytes will answer the
+       // question of the whether we hit the limit or go past it.
+       if int64(len(p)) > l.n+1 {
+               p = p[:l.n+1]
        }
        n, err = l.r.Read(p)
-       if err == io.EOF {
-               l.sawEOF = true
-       }
-       if l.n == 0 {
-               // If we had zero bytes to read remaining (but hadn't seen EOF)
-               // and we get a byte here, that means we went over our limit.
-               if n > 0 {
-                       return l.tooLarge()
-               }
-               return 0, err
+
+       if int64(n) <= l.n {
+               l.n -= int64(n)
+               l.err = err
+               return n, err
        }
-       l.n -= int64(n)
-       if l.n < 0 {
-               l.n = 0
+
+       n = int(l.n)
+       l.n = 0
+
+       // The server code and client code both use
+       // maxBytesReader. This "requestTooLarge" check is
+       // only used by the server code. To prevent binaries
+       // which only using the HTTP Client code (such as
+       // cmd/go) from also linking in the HTTP server, don't
+       // use a static type assertion to the server
+       // "*response" type. Check this interface instead:
+       type requestTooLarger interface {
+               requestTooLarge()
        }
-       return
+       if res, ok := l.w.(requestTooLarger); ok {
+               res.requestTooLarge()
+       }
+       l.err = errors.New("http: request body too large")
+       return n, l.err
 }
 
 func (l *maxBytesReader) Close() error {
index 82c7af3cda8dbedeff0a0bc4f79a9115dba1b936..a4c88c02915ccdf3ada229bda2b6524b1f1623c0 100644 (file)
@@ -679,6 +679,46 @@ func TestIssue10884_MaxBytesEOF(t *testing.T) {
        }
 }
 
+// Issue 14981: MaxBytesReader's return error wasn't sticky. It
+// doesn't technically need to be, but people expected it to be.
+func TestMaxBytesReaderStickyError(t *testing.T) {
+       isSticky := func(r io.Reader) error {
+               var log bytes.Buffer
+               buf := make([]byte, 1000)
+               var firstErr error
+               for {
+                       n, err := r.Read(buf)
+                       fmt.Fprintf(&log, "Read(%d) = %d, %v\n", len(buf), n, err)
+                       if err == nil {
+                               continue
+                       }
+                       if firstErr == nil {
+                               firstErr = err
+                               continue
+                       }
+                       if !reflect.DeepEqual(err, firstErr) {
+                               return fmt.Errorf("non-sticky error. got log:\n%s", log.Bytes())
+                       }
+                       t.Logf("Got log: %s", log.Bytes())
+                       return nil
+               }
+       }
+       tests := [...]struct {
+               readable int
+               limit    int64
+       }{
+               0: {99, 100},
+               1: {100, 100},
+               2: {101, 100},
+       }
+       for i, tt := range tests {
+               rc := MaxBytesReader(nil, ioutil.NopCloser(bytes.NewReader(make([]byte, tt.readable))), tt.limit)
+               if err := isSticky(rc); err != nil {
+                       t.Errorf("%d. error: %v", i, err)
+               }
+       }
+}
+
 func testMissingFile(t *testing.T, req *Request) {
        f, fh, err := req.FormFile("missing")
        if f != nil {