package base64
import (
+ "encoding/binary"
"io"
"strconv"
)
return "illegal base64 data at input byte " + strconv.FormatInt(int64(e), 10)
}
-// decode is like Decode but returns an additional 'end' value, which
-// indicates if end-of-message padding or a partial quantum was encountered
-// and thus any additional data is an error.
-func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
- si := 0
-
- for si < len(src) && !end {
- // Decode quantum using the base64 alphabet
- var dbuf [4]byte
- dinc, dlen := 3, 4
-
- for j := 0; j < len(dbuf); j++ {
- if len(src) == si {
- switch {
- case j == 0:
- return n, false, nil
- case j == 1, enc.padChar != NoPadding:
- return n, false, CorruptInputError(si - j)
- }
- dinc, dlen, end = j-1, j, true
- break
+// decodeQuantum decodes up to 4 base64 bytes. It takes for parameters
+// the destination buffer dst, the source buffer src and an index in the
+// source buffer si.
+// It returns the number of bytes read from src, the number of bytes written
+// to dst, and an error, if any.
+func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
+ // Decode quantum using the base64 alphabet
+ var dbuf [4]byte
+ dinc, dlen := 3, 4
+
+ for j := 0; j < len(dbuf); j++ {
+ if len(src) == si {
+ switch {
+ case j == 0:
+ return si, 0, nil
+ case j == 1, enc.padChar != NoPadding:
+ return si, 0, CorruptInputError(si - j)
}
- in := src[si]
+ dinc, dlen = j-1, j
+ break
+ }
+ in := src[si]
+ si++
- si++
+ out := enc.decodeMap[in]
+ if out != 0xff {
+ dbuf[j] = out
+ continue
+ }
- out := enc.decodeMap[in]
- if out != 0xFF {
- dbuf[j] = out
- continue
- }
+ if in == '\n' || in == '\r' {
+ j--
+ continue
+ }
- if in == '\n' || in == '\r' {
- j--
- continue
- }
- if rune(in) == enc.padChar {
- // We've reached the end and there's padding
- switch j {
- case 0, 1:
- // incorrect padding
- return n, false, CorruptInputError(si - 1)
- case 2:
- // "==" is expected, the first "=" is already consumed.
- // skip over newlines
- for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
- si++
- }
- if si == len(src) {
- // not enough padding
- return n, false, CorruptInputError(len(src))
- }
- if rune(src[si]) != enc.padChar {
- // incorrect padding
- return n, false, CorruptInputError(si - 1)
- }
-
- si++
- }
- // skip over newlines
- for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
- si++
- }
- if si < len(src) {
- // trailing garbage
- err = CorruptInputError(si)
- }
- dinc, dlen, end = 3, j, true
- break
- }
- return n, false, CorruptInputError(si - 1)
+ if rune(in) != enc.padChar {
+ return si, 0, CorruptInputError(si - 1)
}
- // Convert 4x 6bit source bytes into 3 bytes
- val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
- dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
- switch dlen {
- case 4:
- dst[2] = dbuf[2]
- dbuf[2] = 0
- fallthrough
- case 3:
- dst[1] = dbuf[1]
- if enc.strict && dbuf[2] != 0 {
- return n, end, CorruptInputError(si - 1)
- }
- dbuf[1] = 0
- fallthrough
+ // We've reached the end and there's padding
+ switch j {
+ case 0, 1:
+ // incorrect padding
+ return si, 0, CorruptInputError(si - 1)
case 2:
- dst[0] = dbuf[0]
- if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
- return n, end, CorruptInputError(si - 2)
+ // "==" is expected, the first "=" is already consumed.
+ // skip over newlines
+ for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+ si++
+ }
+ if si == len(src) {
+ // not enough padding
+ return si, 0, CorruptInputError(len(src))
}
+ if rune(src[si]) != enc.padChar {
+ // incorrect padding
+ return si, 0, CorruptInputError(si - 1)
+ }
+
+ si++
+ }
+
+ // skip over newlines
+ for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
+ si++
+ }
+ if si < len(src) {
+ // trailing garbage
+ err = CorruptInputError(si)
}
- dst = dst[dinc:]
- n += dlen - 1
+ dinc, dlen = 3, j
+ break
}
- return n, end, err
-}
+ // Convert 4x 6bit source bytes into 3 bytes
+ val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
+ dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
+ switch dlen {
+ case 4:
+ dst[2] = dbuf[2]
+ dbuf[2] = 0
+ fallthrough
+ case 3:
+ dst[1] = dbuf[1]
+ if enc.strict && dbuf[2] != 0 {
+ return si, 0, CorruptInputError(si - 1)
+ }
+ dbuf[1] = 0
+ fallthrough
+ case 2:
+ dst[0] = dbuf[0]
+ if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
+ return si, 0, CorruptInputError(si - 2)
+ }
+ }
+ dst = dst[dinc:]
-// Decode decodes src using the encoding enc. It writes at most
-// DecodedLen(len(src)) bytes to dst and returns the number of bytes
-// written. If src contains invalid base64 data, it will return the
-// number of bytes successfully written and CorruptInputError.
-// New line characters (\r and \n) are ignored.
-func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
- n, _, err = enc.decode(dst, src)
- return
+ return si, dlen - 1, err
}
// DecodeString returns the bytes represented by the base64 string s.
func (enc *Encoding) DecodeString(s string) ([]byte, error) {
dbuf := make([]byte, enc.DecodedLen(len(s)))
- n, _, err := enc.decode(dbuf, []byte(s))
+ n, err := enc.Decode(dbuf, []byte(s))
return dbuf[:n], err
}
readErr error // error from r.Read
enc *Encoding
r io.Reader
- end bool // saw end of message
buf [1024]byte // leftover input
nbuf int
out []byte // leftover decoded output
if d.enc.padChar == NoPadding && d.nbuf > 0 {
// Decode final fragment, without padding.
var nw int
- nw, _, d.err = d.enc.decode(d.outbuf[:], d.buf[:d.nbuf])
+ nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
d.nbuf = 0
- d.end = true
d.out = d.outbuf[:nw]
n = copy(p, d.out)
d.out = d.out[n:]
nr := d.nbuf / 4 * 4
nw := d.nbuf / 4 * 3
if nw > len(p) {
- nw, d.end, d.err = d.enc.decode(d.outbuf[:], d.buf[:nr])
+ nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
d.out = d.outbuf[:nw]
n = copy(p, d.out)
d.out = d.out[n:]
} else {
- n, d.end, d.err = d.enc.decode(p, d.buf[:nr])
+ n, d.err = d.enc.Decode(p, d.buf[:nr])
}
d.nbuf -= nr
copy(d.buf[:d.nbuf], d.buf[nr:])
return n, d.err
}
+// Decode decodes src using the encoding enc. It writes at most
+// DecodedLen(len(src)) bytes to dst and returns the number of bytes
+// written. If src contains invalid base64 data, it will return the
+// number of bytes successfully written and CorruptInputError.
+// New line characters (\r and \n) are ignored.
+func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
+ if len(src) == 0 {
+ return 0, nil
+ }
+
+ si := 0
+ ilen := len(src)
+ olen := len(dst)
+ for strconv.IntSize >= 64 && ilen-si >= 8 && olen-n >= 8 {
+ if ok := enc.decode64(dst[n:], src[si:]); ok {
+ n += 6
+ si += 8
+ } else {
+ var ninc int
+ si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
+ n += ninc
+ if err != nil {
+ return n, err
+ }
+ }
+ }
+
+ for ilen-si >= 4 && olen-n >= 4 {
+ if ok := enc.decode32(dst[n:], src[si:]); ok {
+ n += 3
+ si += 4
+ } else {
+ var ninc int
+ si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
+ n += ninc
+ if err != nil {
+ return n, err
+ }
+ }
+ }
+
+ for si < len(src) {
+ var ninc int
+ si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
+ n += ninc
+ if err != nil {
+ return n, err
+ }
+ }
+ return n, err
+}
+
+// decode32 tries to decode 4 base64 char into 3 bytes.
+// len(dst) and len(src) must both be >= 4.
+// Returns true if decode succeeded.
+func (enc *Encoding) decode32(dst, src []byte) bool {
+ var dn, n uint32
+ if n = uint32(enc.decodeMap[src[0]]); n == 0xff {
+ return false
+ }
+ dn |= n << 26
+ if n = uint32(enc.decodeMap[src[1]]); n == 0xff {
+ return false
+ }
+ dn |= n << 20
+ if n = uint32(enc.decodeMap[src[2]]); n == 0xff {
+ return false
+ }
+ dn |= n << 14
+ if n = uint32(enc.decodeMap[src[3]]); n == 0xff {
+ return false
+ }
+ dn |= n << 8
+
+ binary.BigEndian.PutUint32(dst, dn)
+ return true
+}
+
+// decode64 tries to decode 8 base64 char into 6 bytes.
+// len(dst) and len(src) must both be >= 8.
+// Returns true if decode succeeded.
+func (enc *Encoding) decode64(dst, src []byte) bool {
+ var dn, n uint64
+ if n = uint64(enc.decodeMap[src[0]]); n == 0xff {
+ return false
+ }
+ dn |= n << 58
+ if n = uint64(enc.decodeMap[src[1]]); n == 0xff {
+ return false
+ }
+ dn |= n << 52
+ if n = uint64(enc.decodeMap[src[2]]); n == 0xff {
+ return false
+ }
+ dn |= n << 46
+ if n = uint64(enc.decodeMap[src[3]]); n == 0xff {
+ return false
+ }
+ dn |= n << 40
+ if n = uint64(enc.decodeMap[src[4]]); n == 0xff {
+ return false
+ }
+ dn |= n << 34
+ if n = uint64(enc.decodeMap[src[5]]); n == 0xff {
+ return false
+ }
+ dn |= n << 28
+ if n = uint64(enc.decodeMap[src[6]]); n == 0xff {
+ return false
+ }
+ dn |= n << 22
+ if n = uint64(enc.decodeMap[src[7]]); n == 0xff {
+ return false
+ }
+ dn |= n << 16
+
+ binary.BigEndian.PutUint64(dst, dn)
+ return true
+}
+
type newlineFilteringReader struct {
wrapped io.Reader
}