]> Cypherpunks repositories - gostls13.git/commitdiff
bufio: add Reader.Discard
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 2 Jan 2015 18:02:15 +0000 (10:02 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 7 Jan 2015 06:37:57 +0000 (06:37 +0000)
Reader.Discard is the complement to Peek. It discards the next n bytes
of input.

We already have Reader.Buffered to see how many bytes of data are
sitting available in memory, and Reader.Peek to get that that buffer
directly. But once you're done with the Peek'd data, you can't get rid
of it, other than Reading it.
Both Read and io.CopyN(ioutil.Discard, bufReader, N) are relatively
slow. People instead resort to multiple blind ReadByte calls, just to
advance the internal b.r variable.

I've wanted this previously, several people have asked for it in the
past on golang-nuts/dev, and somebody just asked me for it again in a
private email. There are a few places in the standard library we'd use
it too.

Change-Id: I85dfad47704a58bd42f6867adbc9e4e1792bc3b0
Reviewed-on: https://go-review.googlesource.com/2260
Reviewed-by: Russ Cox <rsc@golang.org>
src/bufio/bufio.go
src/bufio/bufio_test.go
src/net/http/transfer.go

index d3c68fe6fe539e975f668c6d7e6af66c7348af79..dbbe80e4c2aed376c270e13c1b96d45d13f8d17b 100644 (file)
@@ -144,6 +144,39 @@ func (b *Reader) Peek(n int) ([]byte, error) {
        return b.buf[b.r : b.r+n], err
 }
 
+// Discard skips the next n bytes, returning the number of bytes discarded.
+//
+// If Discard skips fewer than n bytes, it also returns an error.
+// If 0 <= n <= b.Buffered(), Discard is guaranteed to succeed without
+// reading from the underlying io.Reader.
+func (b *Reader) Discard(n int) (discarded int, err error) {
+       if n < 0 {
+               return 0, ErrNegativeCount
+       }
+       if n == 0 {
+               return
+       }
+       remain := n
+       for {
+               skip := b.Buffered()
+               if skip == 0 {
+                       b.fill()
+                       skip = b.Buffered()
+               }
+               if skip > remain {
+                       skip = remain
+               }
+               b.r += skip
+               remain -= skip
+               if remain == 0 {
+                       return n, nil
+               }
+               if b.err != nil {
+                       return n - remain, b.readErr()
+               }
+       }
+}
+
 // Read reads data into p.
 // It returns the number of bytes read into p.
 // It calls Read at most once on the underlying Reader,
index 550dac9173f0083828974c47d8fe3d1acd61eace..666c44e15a938be90bb2f9095fea82a18e9401e7 100644 (file)
@@ -1268,6 +1268,135 @@ func TestWriterReset(t *testing.T) {
        }
 }
 
+func TestReaderDiscard(t *testing.T) {
+       tests := []struct {
+               name     string
+               r        io.Reader
+               bufSize  int // 0 means 16
+               peekSize int
+
+               n int // input to Discard
+
+               want    int   // from Discard
+               wantErr error // from Discard
+
+               wantBuffered int
+       }{
+               {
+                       name:         "normal case",
+                       r:            strings.NewReader("abcdefghijklmnopqrstuvwxyz"),
+                       peekSize:     16,
+                       n:            6,
+                       want:         6,
+                       wantBuffered: 10,
+               },
+               {
+                       name:         "discard causing read",
+                       r:            strings.NewReader("abcdefghijklmnopqrstuvwxyz"),
+                       n:            6,
+                       want:         6,
+                       wantBuffered: 10,
+               },
+               {
+                       name:         "discard all without peek",
+                       r:            strings.NewReader("abcdefghijklmnopqrstuvwxyz"),
+                       n:            26,
+                       want:         26,
+                       wantBuffered: 0,
+               },
+               {
+                       name:         "discard more than end",
+                       r:            strings.NewReader("abcdefghijklmnopqrstuvwxyz"),
+                       n:            27,
+                       want:         26,
+                       wantErr:      io.EOF,
+                       wantBuffered: 0,
+               },
+               // Any error from filling shouldn't show up until we
+               // get past the valid bytes. Here we return we return 5 valid bytes at the same time
+               // as an error, but test that we don't see the error from Discard.
+               {
+                       name: "fill error, discard less",
+                       r: newScriptedReader(func(p []byte) (n int, err error) {
+                               if len(p) < 5 {
+                                       panic("unexpected small read")
+                               }
+                               return 5, errors.New("5-then-error")
+                       }),
+                       n:            4,
+                       want:         4,
+                       wantErr:      nil,
+                       wantBuffered: 1,
+               },
+               {
+                       name: "fill error, discard equal",
+                       r: newScriptedReader(func(p []byte) (n int, err error) {
+                               if len(p) < 5 {
+                                       panic("unexpected small read")
+                               }
+                               return 5, errors.New("5-then-error")
+                       }),
+                       n:            5,
+                       want:         5,
+                       wantErr:      nil,
+                       wantBuffered: 0,
+               },
+               {
+                       name: "fill error, discard more",
+                       r: newScriptedReader(func(p []byte) (n int, err error) {
+                               if len(p) < 5 {
+                                       panic("unexpected small read")
+                               }
+                               return 5, errors.New("5-then-error")
+                       }),
+                       n:            6,
+                       want:         5,
+                       wantErr:      errors.New("5-then-error"),
+                       wantBuffered: 0,
+               },
+               // Discard of 0 shouldn't cause a read:
+               {
+                       name:         "discard zero",
+                       r:            newScriptedReader(), // will panic on Read
+                       n:            0,
+                       want:         0,
+                       wantErr:      nil,
+                       wantBuffered: 0,
+               },
+               {
+                       name:         "discard negative",
+                       r:            newScriptedReader(), // will panic on Read
+                       n:            -1,
+                       want:         0,
+                       wantErr:      ErrNegativeCount,
+                       wantBuffered: 0,
+               },
+       }
+       for _, tt := range tests {
+               br := NewReaderSize(tt.r, tt.bufSize)
+               if tt.peekSize > 0 {
+                       peekBuf, err := br.Peek(tt.peekSize)
+                       if err != nil {
+                               t.Errorf("%s: Peek(%d): %v", tt.name, tt.peekSize, err)
+                               continue
+                       }
+                       if len(peekBuf) != tt.peekSize {
+                               t.Errorf("%s: len(Peek(%d)) = %v; want %v", tt.name, tt.peekSize, len(peekBuf), tt.peekSize)
+                               continue
+                       }
+               }
+               discarded, err := br.Discard(tt.n)
+               if ge, we := fmt.Sprint(err), fmt.Sprint(tt.wantErr); discarded != tt.want || ge != we {
+                       t.Errorf("%s: Discard(%d) = (%v, %v); want (%v, %v)", tt.name, tt.n, discarded, ge, tt.want, we)
+                       continue
+               }
+               if bn := br.Buffered(); bn != tt.wantBuffered {
+                       t.Errorf("%s: after Discard, Buffered = %d; want %d", tt.name, bn, tt.wantBuffered)
+               }
+       }
+
+}
+
 // An onlyReader only implements io.Reader, no matter what other methods the underlying implementation may have.
 type onlyReader struct {
        io.Reader
@@ -1278,6 +1407,23 @@ type onlyWriter struct {
        io.Writer
 }
 
+// A scriptedReader is an io.Reader that executes its steps sequentially.
+type scriptedReader []func(p []byte) (n int, err error)
+
+func (sr *scriptedReader) Read(p []byte) (n int, err error) {
+       if len(*sr) == 0 {
+               panic("too many Read calls on scripted Reader. No steps remain.")
+       }
+       step := (*sr)[0]
+       *sr = (*sr)[1:]
+       return step(p)
+}
+
+func newScriptedReader(steps ...func(p []byte) (n int, err error)) io.Reader {
+       sr := scriptedReader(steps)
+       return &sr
+}
+
 func BenchmarkReaderCopyOptimal(b *testing.B) {
        // Optimal case is where the underlying reader implements io.WriterTo
        srcBuf := bytes.NewBuffer(make([]byte, 8192))
index fd9389adf06600aa0e8efc14609f44ae8219e944..c39d6cff6788b8a3af349bb207b8d5e1d06c66a7 100644 (file)
@@ -638,8 +638,7 @@ func (b *body) readTrailer() error {
        // The common case, since nobody uses trailers.
        buf, err := b.r.Peek(2)
        if bytes.Equal(buf, singleCRLF) {
-               b.r.ReadByte()
-               b.r.ReadByte()
+               b.r.Discard(2)
                return nil
        }
        if len(buf) < 2 {