]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/base64: optimize DecodeString
authorJosselin Costanzi <josselin@costanzi.fr>
Wed, 22 Mar 2017 20:19:15 +0000 (21:19 +0100)
committerIan Lance Taylor <iant@golang.org>
Mon, 9 Oct 2017 15:39:51 +0000 (15:39 +0000)
Optimize base64 decoding speed by adding 32-bits and 64-bits specialized
methods that don't perform any error checking and fall back to the more
complex decodeQuantum method when a non-base64 character is present.

On a 64-bits cpu:

name                 old time/op    new time/op     delta
DecodeString/2-4       70.0ns ± 6%     69.2ns ± 0%     ~     (p=0.169 n=5+8)
DecodeString/4-4       91.3ns ± 2%     80.4ns ± 0%  -11.89%  (p=0.001 n=5+10)
DecodeString/8-4        126ns ± 5%      106ns ± 0%  -16.14%  (p=0.000 n=5+7)
DecodeString/64-4       652ns ±21%      361ns ± 0%  -44.57%  (p=0.000 n=5+7)
DecodeString/8192-4    61.0µs ±13%     31.5µs ± 1%  -48.38%  (p=0.001 n=5+9)

name                 old speed      new speed       delta
DecodeString/2-4     57.2MB/s ± 6%   57.7MB/s ± 2%     ~     (p=0.419 n=5+9)
DecodeString/4-4     87.7MB/s ± 2%   99.5MB/s ± 0%  +13.45%  (p=0.001 n=5+10)
DecodeString/8-4     94.8MB/s ± 5%  112.6MB/s ± 1%  +18.82%  (p=0.001 n=5+9)
DecodeString/64-4     136MB/s ±19%    243MB/s ± 0%  +78.17%  (p=0.003 n=5+7)
DecodeString/8192-4   180MB/s ±11%    347MB/s ± 1%  +92.94%  (p=0.001 n=5+9)

Improves #19636

Change-Id: Ic10a454851093a7e1d46ca0c140deed73535d990
Reviewed-on: https://go-review.googlesource.com/38632
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/cmd/dist/deps.go
src/encoding/base64/base64.go
src/encoding/base64/base64_test.go
src/go/build/deps_test.go

index cd7eaaea0d95c37a18748864f34ecb438c2125f9..660db750000395e1e62dfecdb217b319258e6151 100644 (file)
@@ -461,8 +461,9 @@ var builddeps = map[string][]string{
        },
 
        "encoding/base64": {
-               "io",      // encoding/base64
-               "strconv", // encoding/base64
+               "encoding/binary", // encoding/base64
+               "io",              // encoding/base64
+               "strconv",         // encoding/base64
        },
 
        "encoding/binary": {
index b208f9e4d8dff46f37d866a8333dfa2ee006892d..9a99370f1e53082f3efbde3b888deadb0e721228 100644 (file)
@@ -6,6 +6,7 @@
 package base64
 
 import (
+       "encoding/binary"
        "io"
        "strconv"
 )
@@ -269,121 +270,110 @@ func (e CorruptInputError) Error() string {
        return "illegal base64 data at input byte " + strconv.FormatInt(int64(e), 10)
 }
 
-// decode is like Decode but returns an additional 'end' value, which
-// indicates if end-of-message padding or a partial quantum was encountered
-// and thus any additional data is an error.
-func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
-       si := 0
-
-       for si < len(src) && !end {
-               // Decode quantum using the base64 alphabet
-               var dbuf [4]byte
-               dinc, dlen := 3, 4
-
-               for j := 0; j < len(dbuf); j++ {
-                       if len(src) == si {
-                               switch {
-                               case j == 0:
-                                       return n, false, nil
-                               case j == 1, enc.padChar != NoPadding:
-                                       return n, false, CorruptInputError(si - j)
-                               }
-                               dinc, dlen, end = j-1, j, true
-                               break
+// decodeQuantum decodes up to 4 base64 bytes. It takes for parameters
+// the destination buffer dst, the source buffer src and an index in the
+// source buffer si.
+// It returns the number of bytes read from src, the number of bytes written
+// to dst, and an error, if any.
+func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
+       // Decode quantum using the base64 alphabet
+       var dbuf [4]byte
+       dinc, dlen := 3, 4
+
+       for j := 0; j < len(dbuf); j++ {
+               if len(src) == si {
+                       switch {
+                       case j == 0:
+                               return si, 0, nil
+                       case j == 1, enc.padChar != NoPadding:
+                               return si, 0, CorruptInputError(si - j)
                        }
-                       in := src[si]
+                       dinc, dlen = j-1, j
+                       break
+               }
+               in := src[si]
+               si++
 
-                       si++
+               out := enc.decodeMap[in]
+               if out != 0xff {
+                       dbuf[j] = out
+                       continue
+               }
 
-                       out := enc.decodeMap[in]
-                       if out != 0xFF {
-                               dbuf[j] = out
-                               continue
-                       }
+               if in == '\n' || in == '\r' {
+                       j--
+                       continue
+               }
 
-                       if in == '\n' || in == '\r' {
-                               j--
-                               continue
-                       }
-                       if rune(in) == enc.padChar {
-                               // We've reached the end and there's padding
-                               switch j {
-                               case 0, 1:
-                                       // incorrect padding
-                                       return n, false, CorruptInputError(si - 1)
-                               case 2:
-                                       // "==" is expected, the first "=" is already consumed.
-                                       // skip over newlines
-                                       for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
-                                               si++
-                                       }
-                                       if si == len(src) {
-                                               // not enough padding
-                                               return n, false, CorruptInputError(len(src))
-                                       }
-                                       if rune(src[si]) != enc.padChar {
-                                               // incorrect padding
-                                               return n, false, CorruptInputError(si - 1)
-                                       }
-
-                                       si++
-                               }
-                               // skip over newlines
-                               for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
-                                       si++
-                               }
-                               if si < len(src) {
-                                       // trailing garbage
-                                       err = CorruptInputError(si)
-                               }
-                               dinc, dlen, end = 3, j, true
-                               break
-                       }
-                       return n, false, CorruptInputError(si - 1)
+               if rune(in) != enc.padChar {
+                       return si, 0, CorruptInputError(si - 1)
                }
 
-               // Convert 4x 6bit source bytes into 3 bytes
-               val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
-               dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
-               switch dlen {
-               case 4:
-                       dst[2] = dbuf[2]
-                       dbuf[2] = 0
-                       fallthrough
-               case 3:
-                       dst[1] = dbuf[1]
-                       if enc.strict && dbuf[2] != 0 {
-                               return n, end, CorruptInputError(si - 1)
-                       }
-                       dbuf[1] = 0
-                       fallthrough
+               // We've reached the end and there's padding
+               switch j {
+               case 0, 1:
+                       // incorrect padding
+                       return si, 0, CorruptInputError(si - 1)
                case 2:
-                       dst[0] = dbuf[0]
-                       if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
-                               return n, end, CorruptInputError(si - 2)
+                       // "==" is expected, the first "=" is already consumed.
+                       // skip over newlines
+                       for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+                               si++
+                       }
+                       if si == len(src) {
+                               // not enough padding
+                               return si, 0, CorruptInputError(len(src))
                        }
+                       if rune(src[si]) != enc.padChar {
+                               // incorrect padding
+                               return si, 0, CorruptInputError(si - 1)
+                       }
+
+                       si++
+               }
+
+               // skip over newlines
+               for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+                       si++
+               }
+               if si < len(src) {
+                       // trailing garbage
+                       err = CorruptInputError(si)
                }
-               dst = dst[dinc:]
-               n += dlen - 1
+               dinc, dlen = 3, j
+               break
        }
 
-       return n, end, err
-}
+       // Convert 4x 6bit source bytes into 3 bytes
+       val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
+       dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
+       switch dlen {
+       case 4:
+               dst[2] = dbuf[2]
+               dbuf[2] = 0
+               fallthrough
+       case 3:
+               dst[1] = dbuf[1]
+               if enc.strict && dbuf[2] != 0 {
+                       return si, 0, CorruptInputError(si - 1)
+               }
+               dbuf[1] = 0
+               fallthrough
+       case 2:
+               dst[0] = dbuf[0]
+               if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
+                       return si, 0, CorruptInputError(si - 2)
+               }
+       }
+       dst = dst[dinc:]
 
-// Decode decodes src using the encoding enc. It writes at most
-// DecodedLen(len(src)) bytes to dst and returns the number of bytes
-// written. If src contains invalid base64 data, it will return the
-// number of bytes successfully written and CorruptInputError.
-// New line characters (\r and \n) are ignored.
-func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
-       n, _, err = enc.decode(dst, src)
-       return
+       return si, dlen - 1, err
 }
 
 // DecodeString returns the bytes represented by the base64 string s.
 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
        dbuf := make([]byte, enc.DecodedLen(len(s)))
-       n, _, err := enc.decode(dbuf, []byte(s))
+       n, err := enc.Decode(dbuf, []byte(s))
        return dbuf[:n], err
 }
 
@@ -392,7 +382,6 @@ type decoder struct {
        readErr error // error from r.Read
        enc     *Encoding
        r       io.Reader
-       end     bool       // saw end of message
        buf     [1024]byte // leftover input
        nbuf    int
        out     []byte // leftover decoded output
@@ -430,9 +419,8 @@ func (d *decoder) Read(p []byte) (n int, err error) {
                if d.enc.padChar == NoPadding && d.nbuf > 0 {
                        // Decode final fragment, without padding.
                        var nw int
-                       nw, _, d.err = d.enc.decode(d.outbuf[:], d.buf[:d.nbuf])
+                       nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
                        d.nbuf = 0
-                       d.end = true
                        d.out = d.outbuf[:nw]
                        n = copy(p, d.out)
                        d.out = d.out[n:]
@@ -454,18 +442,138 @@ func (d *decoder) Read(p []byte) (n int, err error) {
        nr := d.nbuf / 4 * 4
        nw := d.nbuf / 4 * 3
        if nw > len(p) {
-               nw, d.end, d.err = d.enc.decode(d.outbuf[:], d.buf[:nr])
+               nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
                d.out = d.outbuf[:nw]
                n = copy(p, d.out)
                d.out = d.out[n:]
        } else {
-               n, d.end, d.err = d.enc.decode(p, d.buf[:nr])
+               n, d.err = d.enc.Decode(p, d.buf[:nr])
        }
        d.nbuf -= nr
        copy(d.buf[:d.nbuf], d.buf[nr:])
        return n, d.err
 }
 
+// Decode decodes src using the encoding enc. It writes at most
+// DecodedLen(len(src)) bytes to dst and returns the number of bytes
+// written. If src contains invalid base64 data, it will return the
+// number of bytes successfully written and CorruptInputError.
+// New line characters (\r and \n) are ignored.
+func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
+       if len(src) == 0 {
+               return 0, nil
+       }
+
+       si := 0
+       ilen := len(src)
+       olen := len(dst)
+       for strconv.IntSize >= 64 && ilen-si >= 8 && olen-n >= 8 {
+               if ok := enc.decode64(dst[n:], src[si:]); ok {
+                       n += 6
+                       si += 8
+               } else {
+                       var ninc int
+                       si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
+                       n += ninc
+                       if err != nil {
+                               return n, err
+                       }
+               }
+       }
+
+       for ilen-si >= 4 && olen-n >= 4 {
+               if ok := enc.decode32(dst[n:], src[si:]); ok {
+                       n += 3
+                       si += 4
+               } else {
+                       var ninc int
+                       si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
+                       n += ninc
+                       if err != nil {
+                               return n, err
+                       }
+               }
+       }
+
+       for si < len(src) {
+               var ninc int
+               si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
+               n += ninc
+               if err != nil {
+                       return n, err
+               }
+       }
+       return n, err
+}
+
+// decode32 tries to decode 4 base64 char into 3 bytes.
+// len(dst) and len(src) must both be >= 4.
+// Returns true if decode succeeded.
+func (enc *Encoding) decode32(dst, src []byte) bool {
+       var dn, n uint32
+       if n = uint32(enc.decodeMap[src[0]]); n == 0xff {
+               return false
+       }
+       dn |= n << 26
+       if n = uint32(enc.decodeMap[src[1]]); n == 0xff {
+               return false
+       }
+       dn |= n << 20
+       if n = uint32(enc.decodeMap[src[2]]); n == 0xff {
+               return false
+       }
+       dn |= n << 14
+       if n = uint32(enc.decodeMap[src[3]]); n == 0xff {
+               return false
+       }
+       dn |= n << 8
+
+       binary.BigEndian.PutUint32(dst, dn)
+       return true
+}
+
+// decode64 tries to decode 8 base64 char into 6 bytes.
+// len(dst) and len(src) must both be >= 8.
+// Returns true if decode succeeded.
+func (enc *Encoding) decode64(dst, src []byte) bool {
+       var dn, n uint64
+       if n = uint64(enc.decodeMap[src[0]]); n == 0xff {
+               return false
+       }
+       dn |= n << 58
+       if n = uint64(enc.decodeMap[src[1]]); n == 0xff {
+               return false
+       }
+       dn |= n << 52
+       if n = uint64(enc.decodeMap[src[2]]); n == 0xff {
+               return false
+       }
+       dn |= n << 46
+       if n = uint64(enc.decodeMap[src[3]]); n == 0xff {
+               return false
+       }
+       dn |= n << 40
+       if n = uint64(enc.decodeMap[src[4]]); n == 0xff {
+               return false
+       }
+       dn |= n << 34
+       if n = uint64(enc.decodeMap[src[5]]); n == 0xff {
+               return false
+       }
+       dn |= n << 28
+       if n = uint64(enc.decodeMap[src[6]]); n == 0xff {
+               return false
+       }
+       dn |= n << 22
+       if n = uint64(enc.decodeMap[src[7]]); n == 0xff {
+               return false
+       }
+       dn |= n << 16
+
+       binary.BigEndian.PutUint64(dst, dn)
+       return true
+}
+
 type newlineFilteringReader struct {
        wrapped io.Reader
 }
index 05011fbdf38d8fd48728c3d4f41201d7c53f447b..9f5c493dbfe8485f8248ffeadb23d398c6b27d3b 100644 (file)
@@ -152,12 +152,9 @@ func TestDecode(t *testing.T) {
                for _, tt := range encodingTests {
                        encoded := tt.conv(p.encoded)
                        dbuf := make([]byte, tt.enc.DecodedLen(len(encoded)))
-                       count, end, err := tt.enc.decode(dbuf, []byte(encoded))
+                       count, err := tt.enc.Decode(dbuf, []byte(encoded))
                        testEqual(t, "Decode(%q) = error %v, want %v", encoded, err, error(nil))
                        testEqual(t, "Decode(%q) = length %v, want %v", encoded, count, len(p.decoded))
-                       if len(encoded) > 0 {
-                               testEqual(t, "Decode(%q) = end %v, want %v", encoded, end, len(p.decoded)%3 != 0)
-                       }
                        testEqual(t, "Decode(%q) = %q, want %q", encoded, string(dbuf[0:count]), p.decoded)
 
                        dbuf, err = tt.enc.DecodeString(encoded)
index 8f485f1632b3ecdb3a8ba9a5c5ef095127683b7c..0494a155ef1ea2b14ecd2d64072ae9b6293b885a 100644 (file)
@@ -101,7 +101,7 @@ var pkgDeps = map[string][]string{
        "crypto/cipher":       {"L2", "crypto/subtle"},
        "crypto/subtle":       {},
        "encoding/base32":     {"L2"},
-       "encoding/base64":     {"L2"},
+       "encoding/base64":     {"L2", "encoding/binary"},
        "encoding/binary":     {"L2", "reflect"},
        "hash":                {"L2"}, // interfaces
        "hash/adler32":        {"L2", "hash"},