]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/base64: Optimize EncodeToString and DecodeString.
authorEgon Elbre <egonelbre@gmail.com>
Mon, 13 Apr 2015 10:21:00 +0000 (13:21 +0300)
committerNigel Tao <nigeltao@golang.org>
Fri, 24 Apr 2015 01:45:43 +0000 (01:45 +0000)
benchmark                   old ns/op     new ns/op     delta
BenchmarkEncodeToString     31281         23821         -23.85%
BenchmarkDecodeString       156508        82254         -47.44%

benchmark                   old MB/s     new MB/s     speedup
BenchmarkEncodeToString     261.88       343.89       1.31x
BenchmarkDecodeString       69.80        132.81       1.90x

Change-Id: I115e0b18c3a6d5ef6bfdcb3f637644f02f290907
Reviewed-on: https://go-review.googlesource.com/8808
Reviewed-by: Nigel Tao <nigeltao@golang.org>
src/encoding/base64/base64.go

index db4a409cb3d1d36ee36531345d9b84bb27ee296b..3302fb4a7426924d2f9b6f529184193f5bb19df9 100644 (file)
@@ -6,10 +6,8 @@
 package base64
 
 import (
-       "bytes"
        "io"
        "strconv"
-       "strings"
 )
 
 /*
@@ -22,7 +20,7 @@ import (
 // (RFC 1421).  RFC 4648 also defines an alternate encoding, which is
 // the standard encoding with - and _ substituted for + and /.
 type Encoding struct {
-       encode    string
+       encode    [64]byte
        decodeMap [256]byte
        padChar   rune
 }
@@ -40,9 +38,14 @@ const encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz012345678
 // The resulting Encoding uses the default padding character ('='),
 // which may be changed or disabled via WithPadding.
 func NewEncoding(encoder string) *Encoding {
+       if len(encoder) != 64 {
+               panic("encoding alphabet is not 64-bytes long")
+       }
+
        e := new(Encoding)
-       e.encode = encoder
        e.padChar = StdPadding
+       copy(e.encode[:], encoder)
+
        for i := 0; i < len(e.decodeMap); i++ {
                e.decodeMap[i] = 0xFF
        }
@@ -77,13 +80,6 @@ var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
 // This is the same as URLEncoding but omits padding characters.
 var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
 
-var removeNewlinesMapper = func(r rune) rune {
-       if r == '\r' || r == '\n' {
-               return -1
-       }
-       return r
-}
-
 /*
  * Encoder
  */
@@ -99,46 +95,45 @@ func (enc *Encoding) Encode(dst, src []byte) {
                return
        }
 
-       for len(src) > 0 {
-               var b0, b1, b2, b3 byte
+       di, si := 0, 0
+       n := (len(src) / 3) * 3
+       for si < n {
+               // Convert 3x 8bit source bytes into 4 bytes
+               val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
 
-               // Unpack 4x 6-bit source blocks into a 4 byte
-               // destination quantum
-               switch len(src) {
-               default:
-                       b3 = src[2] & 0x3F
-                       b2 = src[2] >> 6
-                       fallthrough
-               case 2:
-                       b2 |= (src[1] << 2) & 0x3F
-                       b1 = src[1] >> 4
-                       fallthrough
-               case 1:
-                       b1 |= (src[0] << 4) & 0x3F
-                       b0 = src[0] >> 2
-               }
+               dst[di+0] = enc.encode[val>>18&0x3F]
+               dst[di+1] = enc.encode[val>>12&0x3F]
+               dst[di+2] = enc.encode[val>>6&0x3F]
+               dst[di+3] = enc.encode[val&0x3F]
 
-               // Encode 6-bit blocks using the base64 alphabet
-               dst[0] = enc.encode[b0]
-               dst[1] = enc.encode[b1]
-               if len(src) >= 3 {
-                       dst[2] = enc.encode[b2]
-                       dst[3] = enc.encode[b3]
-               } else { // Final incomplete quantum
-                       if len(src) >= 2 {
-                               dst[2] = enc.encode[b2]
-                       }
-                       if enc.padChar != NoPadding {
-                               if len(src) < 2 {
-                                       dst[2] = byte(enc.padChar)
-                               }
-                               dst[3] = byte(enc.padChar)
-                       }
-                       break
-               }
+               si += 3
+               di += 4
+       }
 
-               src = src[3:]
-               dst = dst[4:]
+       remain := len(src) - si
+       if remain == 0 {
+               return
+       }
+       // Add the remaining small block
+       val := uint(src[si+0]) << 16
+       if remain == 2 {
+               val |= uint(src[si+1]) << 8
+       }
+
+       dst[di+0] = enc.encode[val>>18&0x3F]
+       dst[di+1] = enc.encode[val>>12&0x3F]
+
+       switch remain {
+       case 2:
+               dst[di+2] = enc.encode[val>>6&0x3F]
+               if enc.padChar != NoPadding {
+                       dst[di+3] = byte(enc.padChar)
+               }
+       case 1:
+               if enc.padChar != NoPadding {
+                       dst[di+2] = byte(enc.padChar)
+                       dst[di+3] = byte(enc.padChar)
+               }
        }
 }
 
@@ -248,67 +243,83 @@ func (e CorruptInputError) Error() string {
 
 // decode is like Decode but returns an additional 'end' value, which
 // indicates if end-of-message padding or a partial quantum was encountered
-// and thus any additional data is an error. This method assumes that src has been
-// stripped of all supported whitespace ('\r' and '\n').
+// and thus any additional data is an error.
 func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
-       olen := len(src)
-       for len(src) > 0 && !end {
+       si := 0
+
+       // skip over newlines
+       for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+               si++
+       }
+
+       for si < len(src) && !end {
                // Decode quantum using the base64 alphabet
                var dbuf [4]byte
                dinc, dlen := 3, 4
 
                for j := range dbuf {
-                       if len(src) == 0 {
+                       if len(src) == si {
                                if enc.padChar != NoPadding || j < 2 {
-                                       return n, false, CorruptInputError(olen - len(src) - j)
+                                       return n, false, CorruptInputError(si - j)
                                }
                                dinc, dlen, end = j-1, j, true
                                break
                        }
-                       in := src[0]
-                       src = src[1:]
+                       in := src[si]
+
+                       si++
+                       // skip over newlines
+                       for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+                               si++
+                       }
+
                        if rune(in) == enc.padChar {
                                // We've reached the end and there's padding
                                switch j {
                                case 0, 1:
                                        // incorrect padding
-                                       return n, false, CorruptInputError(olen - len(src) - 1)
+                                       return n, false, CorruptInputError(si - 1)
                                case 2:
                                        // "==" is expected, the first "=" is already consumed.
-                                       if len(src) == 0 {
+                                       if si == len(src) {
                                                // not enough padding
-                                               return n, false, CorruptInputError(olen)
+                                               return n, false, CorruptInputError(len(src))
                                        }
-                                       if rune(src[0]) != enc.padChar {
+                                       if rune(src[si]) != enc.padChar {
                                                // incorrect padding
-                                               return n, false, CorruptInputError(olen - len(src) - 1)
+                                               return n, false, CorruptInputError(si - 1)
+                                       }
+
+                                       si++
+                                       // skip over newlines
+                                       for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+                                               si++
                                        }
-                                       src = src[1:]
                                }
-                               if len(src) > 0 {
+                               if si < len(src) {
                                        // trailing garbage
-                                       err = CorruptInputError(olen - len(src))
+                                       err = CorruptInputError(si)
                                }
                                dinc, dlen, end = 3, j, true
                                break
                        }
                        dbuf[j] = enc.decodeMap[in]
                        if dbuf[j] == 0xFF {
-                               return n, false, CorruptInputError(olen - len(src) - 1)
+                               return n, false, CorruptInputError(si - 1)
                        }
                }
 
-               // Pack 4x 6-bit source blocks into 3 byte destination
-               // quantum
+               // Convert 4x 6bit source bytes into 3 bytes
+               val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
                switch dlen {
                case 4:
-                       dst[2] = dbuf[2]<<6 | dbuf[3]
+                       dst[2] = byte(val >> 0)
                        fallthrough
                case 3:
-                       dst[1] = dbuf[1]<<4 | dbuf[2]>>2
+                       dst[1] = byte(val >> 8)
                        fallthrough
                case 2:
-                       dst[0] = dbuf[0]<<2 | dbuf[1]>>4
+                       dst[0] = byte(val >> 16)
                }
                dst = dst[dinc:]
                n += dlen - 1
@@ -323,14 +334,12 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
 // number of bytes successfully written and CorruptInputError.
 // New line characters (\r and \n) are ignored.
 func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
-       src = bytes.Map(removeNewlinesMapper, src)
        n, _, err = enc.decode(dst, src)
        return
 }
 
 // DecodeString returns the bytes represented by the base64 string s.
 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
-       s = strings.Map(removeNewlinesMapper, s)
        dbuf := make([]byte, enc.DecodedLen(len(s)))
        n, _, err := enc.decode(dbuf, []byte(s))
        return dbuf[:n], err
@@ -359,6 +368,8 @@ func (d *decoder) Read(p []byte) (n int, err error) {
                return n, nil
        }
 
+       // This code assumes that d.r strips supported whitespace ('\r' and '\n').
+
        // Read a chunk.
        nn := len(p) / 3 * 4
        if nn < 4 {