// indicates if end-of-message padding or a partial quantum was encountered
// and thus any additional data is an error.
func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
- var inIdx int
si := 0
- // skip over newlines
- for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
- si++
- }
-
for si < len(src) && !end {
// Decode quantum using the base64 alphabet
var dbuf [4]byte
dinc, dlen := 3, 4
- for j := range dbuf {
+ for j := 0; j < len(dbuf); j++ {
if len(src) == si {
- if enc.padChar != NoPadding || j < 2 {
+ switch {
+ case j == 0:
+ return n, false, nil
+ case j == 1, enc.padChar != NoPadding:
return n, false, CorruptInputError(si - j)
}
dinc, dlen, end = j-1, j, true
break
}
in := src[si]
- inIdx = si
si++
- // skip over newlines
- for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
- si++
+
+ out := enc.decodeMap[in]
+ if out != 0xFF {
+ dbuf[j] = out
+ continue
}
+ if in == '\n' || in == '\r' {
+ j--
+ continue
+ }
if rune(in) == enc.padChar {
// We've reached the end and there's padding
switch j {
case 0, 1:
// incorrect padding
- return n, false, CorruptInputError(inIdx)
+ return n, false, CorruptInputError(si - 1)
case 2:
// "==" is expected, the first "=" is already consumed.
+ // skip over newlines
+ for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+ si++
+ }
if si == len(src) {
// not enough padding
return n, false, CorruptInputError(len(src))
}
si++
- // skip over newlines
- for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
- si++
- }
+ }
+ // skip over newlines
+ for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+ si++
}
if si < len(src) {
// trailing garbage
dinc, dlen, end = 3, j, true
break
}
- dbuf[j] = enc.decodeMap[in]
- if dbuf[j] == 0xFF {
- return n, false, CorruptInputError(inIdx)
- }
+ return n, false, CorruptInputError(si - 1)
}
// Convert 4x 6bit source bytes into 3 bytes
import (
"bytes"
"errors"
+ "fmt"
"io"
"io/ioutil"
"reflect"
offset int // -1 means no corruption.
}{
{"", -1},
+ {"\n", -1},
+ {"AAA=\n", -1},
+ {"AAAA\n", -1},
{"!!!!", 0},
{"====", 0},
{"x===", 1},
}
func BenchmarkDecodeString(b *testing.B) {
- data := StdEncoding.EncodeToString(make([]byte, 8192))
- b.SetBytes(int64(len(data)))
- for i := 0; i < b.N; i++ {
- StdEncoding.DecodeString(data)
+ sizes := []int{2, 4, 8, 64, 8192}
+ benchFunc := func(b *testing.B, benchSize int) {
+ data := StdEncoding.EncodeToString(make([]byte, benchSize))
+ b.SetBytes(int64(len(data)))
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ StdEncoding.DecodeString(data)
+ }
+ }
+ for _, size := range sizes {
+ b.Run(fmt.Sprintf("%d", size), func(b *testing.B) {
+ benchFunc(b, size)
+ })
}
}