]> Cypherpunks repositories - gostls13.git/commitdiff
base64: cut out some middle layers
authorRuss Cox <rsc@golang.org>
Wed, 24 Jun 2009 22:52:31 +0000 (15:52 -0700)
committerRuss Cox <rsc@golang.org>
Wed, 24 Jun 2009 22:52:31 +0000 (15:52 -0700)
R=austin
DELTA=352  (67 added, 196 deleted, 89 changed)
OCL=30694
CL=30713

src/pkg/Make.deps
src/pkg/base64/base64.go
src/pkg/base64/base64_test.go

index 44a4bdaa178a2d269132f468fb81981ca7d0f551..9df9db4165331c9cb34a52bdffd108618fbb7277 100644 (file)
@@ -1,5 +1,5 @@
 archive/tar.install: bufio.install bytes.install io.install os.install strconv.install
-base64.install: io.install os.install strconv.install
+base64.install: bytes.install io.install os.install strconv.install
 bignum.install: fmt.install
 bufio.install: io.install os.install utf8.install
 bytes.install: utf8.install
index b680cf79fc1f12a931e05efe15feb3f1f9dde94e..3db0a83da9338377755d8960d4c0d27eb896e456 100644 (file)
@@ -6,6 +6,7 @@
 package base64
 
 import (
+       "bytes";
        "io";
        "os";
        "strconv";
@@ -56,7 +57,7 @@ var URLEncoding = NewEncoding(encodeURL);
 
 // Encode encodes src using the encoding enc, writing
 // EncodedLen(len(input)) bytes to dst.
-// 
+//
 // The encoding pads the output to a multiple of 4 bytes,
 // so Encode is not appropriate for use on individual blocks
 // of a large data stream.  Use NewEncoder() instead.
@@ -66,6 +67,11 @@ func (enc *Encoding) Encode(src, dst []byte) {
        }
 
        for len(src) > 0 {
+               dst[0] = 0;
+               dst[1] = 0;
+               dst[2] = 0;
+               dst[3] = 0;
+
                // Unpack 4x 6-bit source blocks into a 4 byte
                // destination quantum
                switch len(src) {
@@ -101,120 +107,76 @@ func (enc *Encoding) Encode(src, dst []byte) {
        }
 }
 
-// encodeBlocker is a restricted FIFO for byte data that always
-// returns byte arrays whose lengths are some multiple of 3.
-type encodeBlocker struct {
-       // The overflow buffer contains data that should be returned
-       // before any data in nextbuf.
-       buffer [3]byte;
-       bufpos int;
-       nextbuf []byte;
-}
-
-// put appends the data contained in buf to the encode blocker's
-// buffer.  In general, you have to get everything out before you can
-// put another array.
-func (eb *encodeBlocker) put(buf []byte) {
-       if eb.nextbuf != nil {
-               panic("there is already a nextbuf");
-       }
-
-       // If we have anything in the overflow buffer, fill it up the
-       // rest of the way so we can return the overflow buffer.
-       bpos := 0;
-       if eb.bufpos != 0 {
-               for ; eb.bufpos < 3 && bpos < len(buf); eb.bufpos++ {
-                       eb.buffer[eb.bufpos] = buf[bpos];
-                       bpos++;
-               }
-       }
-
-       if bpos < len(buf) {
-               eb.nextbuf = buf[bpos:len(buf)];
-       }
-}
-
-// get retrieves an input quantum aligned byte array from the encode
-// blocker.
-func (eb *encodeBlocker) get() []byte {
-       // If there is data in the overflow buffer, return it first
-       if eb.bufpos > 0 {
-               if eb.bufpos < 3 {
-                       // We don't have a full quantum
-                       return nil;
-               }
-               eb.bufpos = 0;
-               return &eb.buffer;
-       }
-
-       // No overflow buffer, so return nextbuf.  However, it has to
-       // be quantum-aligned, so copy the tail of the data into the
-       // overflow buffer for next time.
-       end := len(eb.nextbuf)/3*3;
-       for i := end; i < len(eb.nextbuf); i++ {
-               eb.buffer[eb.bufpos] = eb.nextbuf[i];
-               eb.bufpos++;
-       }
-       b := eb.nextbuf[0:end];
-       eb.nextbuf = nil;
-       if end == 0 {
-               return nil;
-       }
-       return b;
-}
-
-// size returns the number of bytes remaining in the encode blocker's
-// buffer.
-func (eb *encodeBlocker) size() int {
-       return (eb.bufpos + len(eb.nextbuf))/3*3;
-}
-
 type encoder struct {
-       w io.Writer;
-       enc *Encoding;
        err os.Error;
-       eb encodeBlocker;
+       enc *Encoding;
+       w io.Writer;
+       buf [3]byte;            // buffered data waiting to be encoded
+       nbuf int;                       // number of bytes in buf
+       out [1024]byte;         // output buffer
 }
 
-func (e *encoder) Write(b []byte) (int, os.Error) {
+func (e *encoder) Write(p []byte) (n int, err os.Error) {
        if e.err != nil {
                return 0, e.err;
        }
 
-       e.eb.put(b);
-
-       output := make([]byte, e.eb.size()/3*4);
-       opos := 0;
+       // Leading fringe.
+       if e.nbuf > 0 {
+               var i int;
+               for i = 0; i < len(p) && e.nbuf < 3; i++ {
+                       e.buf[e.nbuf] = p[i];
+                       e.nbuf++;
+               }
+               n += i;
+               p = p[i:len(p)];
+               if e.nbuf < 3 {
+                       return;
+               }
+               e.enc.Encode(&e.buf, &e.out);
+               var _ int;
+               if _, e.err = e.w.Write(e.out[0:4]); e.err != nil {
+                       return n, e.err;
+               }
+               e.nbuf = 0;
+       }
 
-       for {
-               block := e.eb.get();
-               if block == nil {
-                       break;
+       // Large interior chunks.
+       for len(p) > 3 {
+               nn := len(e.out)/4 * 3;
+               if nn > len(p) {
+                       nn = len(p);
+               }
+               nn -= nn % 3;
+               if nn > 0 {
+                       e.enc.Encode(p[0:nn], &e.out);
+                       var _ int;
+                       if _, e.err = e.w.Write(e.out[0:nn/3*4]); e.err != nil {
+                               return n, e.err;
+                       }
                }
-               e.enc.Encode(block, output[opos:len(output)]);
-               opos += len(block)/3*4;
+               n += nn;
+               p = p[nn:len(p)];
        }
 
-       n, err := e.w.Write(output);
-       if err != nil {
-               e.err = io.ErrShortWrite;
-               return n/4*3, e.err;
+       // Trailing fringe.
+       for i := 0; i < len(p); i++ {
+               e.buf[i] = p[i];
        }
-       return len(b), nil;
+       e.nbuf = len(p);
+       n += len(p);
+       return;
 }
 
-// Close flushes any pending output from the encoder.  It is an error
-// to call Write after calling Close.
+// Close flushes any pending output from the encoder.
+// It is an error to call Write after calling Close.
 func (e *encoder) Close() os.Error {
        // If there's anything left in the buffer, flush it out
-       if e.err == nil && e.eb.bufpos > 0 {
-               var output [4]byte;
-               e.enc.Encode(e.eb.buffer[0:e.eb.bufpos], &output);
-               e.eb.bufpos = 0;
-               n, err := e.w.Write(&output);
-               if err != nil {
-                       e.err = io.ErrShortWrite;
-               }
+       if e.err == nil && e.nbuf > 0 {
+               e.enc.Encode(e.buf[0:e.nbuf], &e.out);
+               e.nbuf = 0;
+               var _ int;
+               _, e.err = e.w.Write(e.out[0:4]);
        }
        return e.err;
 }
@@ -225,7 +187,7 @@ func (e *encoder) Close() os.Error {
 // writing, the caller must Close the returned encoder to flush any
 // partially written blocks.
 func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
-       return &encoder{w: w, enc: enc};
+       return &encoder{enc: enc, w: w};
 }
 
 // EncodedLen returns the length in bytes of the base64 encoding
@@ -254,7 +216,7 @@ func (enc *Encoding) decode(src, dst []byte) (n int, end bool, err os.Error) {
                var dbuf [4]byte;
                dlen := 4;
 
-dbufloop:
+       dbufloop:
                for j := 0; j < 4; j++ {
                        in := src[i*4+j];
                        if in == '=' && j >= 2 && i == len(src)/4 - 1 {
@@ -305,193 +267,68 @@ func (enc *Encoding) Decode(src, dst []byte) (n int, err os.Error) {
        return;
 }
 
-// quantumReader wraps a regular reader and ensures that each read
-// will return a slice whose length is a multiple of 4-bytes.
-type quantumReader struct {
+type decoder struct {
+       err os.Error;
+       enc *Encoding;
        r io.Reader;
-       buf [4]byte;
-       buflen int;
+       end bool;               // saw end of message
+       buf [1024]byte; // leftover input
+       nbuf int;
+       out []byte;             // leftover decoded output
+       outbuf [1024/4*3]byte;
 }
 
-func (q *quantumReader) Read(p []byte) (int, os.Error) {
-       // Copy buffered data into the output
-       for i := 0; i < q.buflen; i++ {
-               p[i] = q.buf[i];
+func (d *decoder) Read(p []byte) (n int, err os.Error) {
+       if d.err != nil {
+               return 0, d.err;
        }
 
-       // Read more data into the output
-       n, err := q.r.Read(p[q.buflen:len(p)]);
-
-       // Buffer tail data that does not fit into the quanta
-       end := (q.buflen+n)/4*4;
-       for i := end; i < q.buflen+n; i++ {
-               q.buf[i-end] = p[i];
+       // Use leftover decoded output from last read.
+       if len(d.out) > 0 {
+               n = bytes.Copy(p, d.out);
+               d.out = d.out[n:len(d.out)];
+               return n, nil;
        }
 
-       // Is EOF misaligned?
-       if err == os.EOF && q.buflen > 0 {
-               err = io.ErrUnexpectedEOF;
+       // Read a chunk.
+       nn := len(p)/3*4;
+       if nn < 4 {
+               nn = 4;
        }
-
-       return end, err;
-}
-
-// decodeBlocker takes a sequence of arbitrary size output byte slices
-// and makes them available as a stream of byte slices whose lengths
-// are always a multiple of 3.
-type decodeBlocker struct {
-       output []byte;
-       noutput int;
-       overflow [3]byte;
-       overflowstart int;
-}
-
-// flush flushes as much data from the overflow buffer as possible in
-// to the current output buffer, reseting the output buffer to nil if
-// it fills it up.  It returns the number of bytes written to the
-// output buffer.
-func (db *decodeBlocker) flush() int {
-       // Copy overflow into the beginning of this buffer
-       i := 0;
-       for ; i < len(db.output) && db.overflowstart < 3; i++ {
-               db.output[i] = db.overflow[db.overflowstart];
-               db.overflowstart++;
+       if nn > len(d.buf) {
+               nn = len(d.buf);
        }
-       if i == len(db.output) {
-               db.output = nil;
-       } else {
-               db.output = db.output[i:len(db.output)];
-       }
-       return i;
-}
-
-// use begins using a new output buffer.  Any data that did not fit in
-// the previous output buffer will be placed at the beginning of this
-// buffer.
-func (db *decodeBlocker) use(buf []byte) {
-       db.output = buf;
-       db.noutput = 0;
-       // Copy left-over overflow from the previous buffer into this
-       // buffer
-       db.noutput += db.flush();
-}
-
-// checkout retrieve the next slice to fill with data.  The length of
-// the returned slice will always be a multiple of 3.  It returns nil
-// if there is no more buffer space.
-func (db *decodeBlocker) checkout() []byte {
-       // If we can use the output buffer, do so
-       if len(db.output) >= 3 {
-               end := len(db.output)/3*3;
-               return db.output[0:end];
-       } else if db.overflowstart == 3 {
-               // Fill the overflow buffer
-               db.overflowstart = 0;
-               return &db.overflow;
+       nn, d.err = io.ReadAtLeast(d.r, d.buf[d.nbuf:nn], 4-d.nbuf);
+       d.nbuf += nn;
+       if d.nbuf < 4 {
+               return 0, d.err;
        }
-       // We're out of space
-       return nil;
-}
 
-// checking indicates that we're done with the checked-out slice and
-// that we wrote count bytes to it.
-func (db *decodeBlocker) checkin(count int) {
-       if db.overflowstart == 3 {
-               // Wrote to the output buffer
-               db.noutput += count;
-               db.output = db.output[count:len(db.output)];
+       // Decode chunk into p, or d.out and then p if p is too small.
+       nr := d.nbuf/4 * 4;
+       nw := d.nbuf/4 * 3;
+       if nw > len(p) {
+               nw, d.end, d.err = d.enc.decode(d.buf[0:nr], &d.outbuf);
+               d.out = d.outbuf[0:nw];
+               n = bytes.Copy(p, d.out);
+               d.out = d.out[n:len(d.out)];
        } else {
-               // Wrote to the overflow buffer.  Flush what we can to
-               // the output buffer.
-               n := db.flush();
-               if n > count {
-                       n = count;
-               }
-               db.noutput += n;
+               n, d.end, d.err = d.enc.decode(d.buf[0:nr], p);
        }
-}
-
-// remaining returns the number of bytes remaining in the decode
-// blocker's buffer.  This will always be a multiple of 3.
-func (db *decodeBlocker) remaining() int {
-       return (len(db.output)+2)/3*3;
-}
-
-// outlen returns the number of bytes written to the output buffer.
-func (db *decodeBlocker) outlen() int {
-       return db.noutput;
-}
-
-type decoder struct {
-       r quantumReader;
-       enc *Encoding;
-       db decodeBlocker;
-       err os.Error;
-       // Have we definitely reached the end of the message?
-       end bool;
-}
-
-func min(a int, b int) int {
-       if a < b {
-               return a;
-       }
-       return b;
-}
-
-func (d *decoder) Read(output []byte) (int, os.Error) {
-       if d.err != nil {
-               return 0, d.err;
-       }
-
-       d.db.use(output);
-
-       var inbuf [512]byte;
-
-       // Read enough data to fill either our input buffer or our
-       // output buffer.
-       maxin := min(d.db.remaining()/3*4, len(inbuf));
-       n, err := d.r.Read(inbuf[0:maxin]);
-
-       // Decode into output buffer.
-       ipos := 0;
-       for ipos < n {
-               outbuf := d.db.checkout();
-               if outbuf == nil {
-                       // Out of output buffer space
-                       break;
-               }
-
-               inlen := min(len(outbuf)/3*4, n - ipos);
-               if d.end {
-                       // We've seen end-of-message padding, but
-                       // there's more data.  The RFC says this is an
-                       // error.
-                       // XXX Should shift character count
-                       d.err = CorruptInputError(0);
-                       break;
-               }
-               count := 0;
-               count, d.end, d.err = d.enc.decode(inbuf[ipos:ipos+inlen], outbuf);
-               d.db.checkin(count);
-               if d.err != nil {
-                       // XXX Should shift character count
-                       break;
-               }
-               ipos += inlen;
+       d.nbuf -= nr;
+       for i := 0; i < d.nbuf; i++ {
+               d.buf[i] = d.buf[i+nr];
        }
 
-       if err != nil && d.err == nil {
+       if d.err == nil {
                d.err = err;
        }
-
-       return d.db.outlen(), d.err;
+       return n, d.err;
 }
 
 // NewDecoder constructs a new base64 stream decoder.
 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
-       return &decoder{r: quantumReader{r:r},
-                       enc: enc,
-                       db: decodeBlocker{overflowstart: 3}};
+       return &decoder{enc: enc, r: r};
 }
 
 // DecodeLen returns the maximum length in bytes of the decoded data
index 1071200a0df4e62b2bdc05c4d77e9e2738d0349e..d11d99a88199998bafd615810a36c162b8529335 100644 (file)
@@ -6,6 +6,7 @@ package base64
 
 import (
        "base64";
+       "bytes";
        "io";
        "os";
        "reflect";
@@ -167,3 +168,36 @@ func TestDecodeCorrupt(t *testing.T) {
                }
        }
 }
+
+func TestBig(t *testing.T) {
+       n := 3*1000+1;
+       raw := make([]byte, n);
+       const alpha = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
+       for i := 0; i < n; i++ {
+               raw[i] = alpha[i%len(alpha)];
+       }
+       encoded := new(io.ByteBuffer);
+       w := NewEncoder(StdEncoding, encoded);
+       nn, err := w.Write(raw);
+       if nn != n || err != nil {
+               t.Fatalf("Encoder.Write(raw) = %d, %v want %d, nil", nn, err, n);
+       }
+       err = w.Close();
+       if err != nil {
+               t.Fatalf("Encoder.Close() = %v want nil", err);
+       }
+       decoded, err := io.ReadAll(NewDecoder(StdEncoding, encoded));
+       if err != nil {
+               t.Fatalf("io.ReadAll(NewDecoder(...)): %v", err);
+       }
+       
+       if !bytes.Equal(raw, decoded) {
+               var i int;
+               for i = 0; i < len(decoded) && i < len(raw); i++ {
+                       if decoded[i] != raw[i] {
+                               break;
+                       }
+               }
+               t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i);
+       }
+}