]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/base64: Optimize DecodeString
authorJosselin Costanzi <josselin@costanzi.fr>
Sat, 7 Jan 2017 13:12:57 +0000 (14:12 +0100)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 24 Apr 2017 22:40:23 +0000 (22:40 +0000)
Optimize DecodeString for the common case where most of the input isn't
a newline or a padding character.
Also add some testcases found when fuzzing this implementation against
upstream.
Change Decode benchmark to run with different input sizes.

name                 old time/op    new time/op    delta
DecodeString/2-4       71.5ns ± 4%    70.0ns ± 6%     ~     (p=0.246 n=5+5)
DecodeString/4-4        112ns ±25%      91ns ± 2%     ~     (p=0.056 n=5+5)
DecodeString/8-4        136ns ± 5%     126ns ± 5%   -7.33%  (p=0.016 n=5+5)
DecodeString/64-4       872ns ±29%     652ns ±21%  -25.23%  (p=0.032 n=5+5)
DecodeString/8192-4    90.9µs ±21%    61.0µs ±13%  -32.87%  (p=0.008 n=5+5)

name                 old speed      new speed      delta
DecodeString/2-4     56.0MB/s ± 4%  57.2MB/s ± 6%     ~     (p=0.310 n=5+5)
DecodeString/4-4     73.4MB/s ±23%  87.7MB/s ± 2%     ~     (p=0.056 n=5+5)
DecodeString/8-4     87.8MB/s ± 5%  94.8MB/s ± 5%   +7.98%  (p=0.016 n=5+5)
DecodeString/64-4     103MB/s ±24%   136MB/s ±19%  +32.63%  (p=0.032 n=5+5)
DecodeString/8192-4   122MB/s ±19%   180MB/s ±11%  +47.75%  (p=0.008 n=5+5)

Improves #19636

Change-Id: I39667f4fb682a12b3137946d017ad999553c5780
Reviewed-on: https://go-review.googlesource.com/34950
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/encoding/base64/base64.go
src/encoding/base64/base64_test.go

index 5a384315f9665a5d9a62fc89c54c9043449439ee..b208f9e4d8dff46f37d866a8333dfa2ee006892d 100644 (file)
@@ -273,44 +273,50 @@ func (e CorruptInputError) Error() string {
 // indicates if end-of-message padding or a partial quantum was encountered
 // and thus any additional data is an error.
 func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
-       var inIdx int
        si := 0
 
-       // skip over newlines
-       for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
-               si++
-       }
-
        for si < len(src) && !end {
                // Decode quantum using the base64 alphabet
                var dbuf [4]byte
                dinc, dlen := 3, 4
 
-               for j := range dbuf {
+               for j := 0; j < len(dbuf); j++ {
                        if len(src) == si {
-                               if enc.padChar != NoPadding || j < 2 {
+                               switch {
+                               case j == 0:
+                                       return n, false, nil
+                               case j == 1, enc.padChar != NoPadding:
                                        return n, false, CorruptInputError(si - j)
                                }
                                dinc, dlen, end = j-1, j, true
                                break
                        }
                        in := src[si]
-                       inIdx = si
 
                        si++
-                       // skip over newlines
-                       for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
-                               si++
+
+                       out := enc.decodeMap[in]
+                       if out != 0xFF {
+                               dbuf[j] = out
+                               continue
                        }
 
+                       if in == '\n' || in == '\r' {
+                               j--
+                               continue
+                       }
                        if rune(in) == enc.padChar {
                                // We've reached the end and there's padding
                                switch j {
                                case 0, 1:
                                        // incorrect padding
-                                       return n, false, CorruptInputError(inIdx)
+                                       return n, false, CorruptInputError(si - 1)
                                case 2:
                                        // "==" is expected, the first "=" is already consumed.
+                                       // skip over newlines
+                                       for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+                                               si++
+                                       }
                                        if si == len(src) {
                                                // not enough padding
                                                return n, false, CorruptInputError(len(src))
@@ -321,10 +327,10 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
                                        }
 
                                        si++
-                                       // skip over newlines
-                                       for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
-                                               si++
-                                       }
+                               }
+                               // skip over newlines
+                               for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+                                       si++
                                }
                                if si < len(src) {
                                        // trailing garbage
@@ -333,10 +339,7 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
                                dinc, dlen, end = 3, j, true
                                break
                        }
-                       dbuf[j] = enc.decodeMap[in]
-                       if dbuf[j] == 0xFF {
-                               return n, false, CorruptInputError(inIdx)
-                       }
+                       return n, false, CorruptInputError(si - 1)
                }
 
                // Convert 4x 6bit source bytes into 3 bytes
index ce52202dd8aa3a6c2f08d2458068866f08a738a9..8ebf2b155339223e2dd8ef228e6487def4b7ec9a 100644 (file)
@@ -7,6 +7,7 @@ package base64
 import (
        "bytes"
        "errors"
+       "fmt"
        "io"
        "io/ioutil"
        "reflect"
@@ -202,6 +203,9 @@ func TestDecodeCorrupt(t *testing.T) {
                offset int // -1 means no corruption.
        }{
                {"", -1},
+               {"\n", -1},
+               {"AAA=\n", -1},
+               {"AAAA\n", -1},
                {"!!!!", 0},
                {"====", 0},
                {"x===", 1},
@@ -468,10 +472,19 @@ func BenchmarkEncodeToString(b *testing.B) {
 }
 
 func BenchmarkDecodeString(b *testing.B) {
-       data := StdEncoding.EncodeToString(make([]byte, 8192))
-       b.SetBytes(int64(len(data)))
-       for i := 0; i < b.N; i++ {
-               StdEncoding.DecodeString(data)
+       sizes := []int{2, 4, 8, 64, 8192}
+       benchFunc := func(b *testing.B, benchSize int) {
+               data := StdEncoding.EncodeToString(make([]byte, benchSize))
+               b.SetBytes(int64(len(data)))
+               b.ResetTimer()
+               for i := 0; i < b.N; i++ {
+                       StdEncoding.DecodeString(data)
+               }
+       }
+       for _, size := range sizes {
+               b.Run(fmt.Sprintf("%d", size), func(b *testing.B) {
+                       benchFunc(b, size)
+               })
        }
 }