]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/base64: This change modifies Go to take strict option when decoding base64
authorXuyang Kang <xuyangkang@gmail.com>
Sun, 17 Jul 2016 07:23:56 +0000 (00:23 -0700)
committerRuss Cox <rsc@golang.org>
Wed, 12 Oct 2016 03:56:18 +0000 (03:56 +0000)
If strict option is enabled, when decoding, instead of skip the padding
bits, it will do strict check to enforce they are set to zero.

Fixes #15656

Change-Id: I869fb725a39cc9dde44dbc4ff0046446e7abc642
Reviewed-on: https://go-review.googlesource.com/24964
Reviewed-by: Russ Cox <rsc@golang.org>
Run-TryBot: Russ Cox <rsc@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/encoding/base64/base64.go
src/encoding/base64/base64_test.go

index c2116d8a3436180bdc66ac8d8061d690d946ba42..d2efad4518b5f9d4a77f7c8c8b257cbcd70049d8 100644 (file)
@@ -23,6 +23,7 @@ type Encoding struct {
        encode    [64]byte
        decodeMap [256]byte
        padChar   rune
+       strict    bool
 }
 
 const (
@@ -62,6 +63,14 @@ func (enc Encoding) WithPadding(padding rune) *Encoding {
        return &enc
 }
 
+// Strict creates a new encoding identical to enc except with
+// strict decoding enabled. In this mode, the decoder requires that
+// trailing padding bits are zero, as described in RFC 4648 section 3.5.
+func (enc Encoding) Strict() *Encoding {
+       enc.strict = true
+       return &enc
+}
+
 // StdEncoding is the standard base64 encoding, as defined in
 // RFC 4648.
 var StdEncoding = NewEncoding(encodeStd)
@@ -311,15 +320,24 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
 
                // Convert 4x 6bit source bytes into 3 bytes
                val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
+               dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
                switch dlen {
                case 4:
-                       dst[2] = byte(val >> 0)
+                       dst[2] = dbuf[2]
+                       dbuf[2] = 0
                        fallthrough
                case 3:
-                       dst[1] = byte(val >> 8)
+                       dst[1] = dbuf[1]
+                       if enc.strict && dbuf[2] != 0 {
+                               return n, end, CorruptInputError(si - 1)
+                       }
+                       dbuf[1] = 0
                        fallthrough
                case 2:
-                       dst[0] = byte(val >> 16)
+                       dst[0] = dbuf[0]
+                       if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
+                               return n, end, CorruptInputError(si - 2)
+                       }
                }
                dst = dst[dinc:]
                n += dlen - 1
index 19ddb92f644b23c118a926e0cd34dff28085509d..e2e1d59f3c0d5c547e7332d32f40ae7e15a2fc5b 100644 (file)
@@ -85,6 +85,11 @@ var encodingTests = []encodingTest{
        {RawStdEncoding, rawRef},
        {RawURLEncoding, rawUrlRef},
        {funnyEncoding, funnyRef},
+       {StdEncoding.Strict(), stdRef},
+       {URLEncoding.Strict(), urlRef},
+       {RawStdEncoding.Strict(), rawRef},
+       {RawURLEncoding.Strict(), rawUrlRef},
+       {funnyEncoding.Strict(), funnyRef},
 }
 
 var bigtest = testpair{
@@ -436,6 +441,22 @@ func TestDecoderIssue7733(t *testing.T) {
        }
 }
 
+func TestDecoderIssue15656(t *testing.T) {
+       _, err := StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
+       want := CorruptInputError(22)
+       if !reflect.DeepEqual(want, err) {
+               t.Errorf("Error = %v; want CorruptInputError(22)", err)
+       }
+       _, err = StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDA==")
+       if err != nil {
+               t.Errorf("Error = %v; want nil", err)
+       }
+       _, err = StdEncoding.DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
+       if err != nil {
+               t.Errorf("Error = %v; want nil", err)
+       }
+}
+
 func BenchmarkEncodeToString(b *testing.B) {
        data := make([]byte, 8192)
        b.SetBytes(int64(len(data)))