]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/base32: support custom and disabled padding when decoding
authorGustav Westling <zegl@westling.xyz>
Sun, 2 Jul 2017 11:27:47 +0000 (13:27 +0200)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 6 Jul 2017 18:05:22 +0000 (18:05 +0000)
CL 38634 added support for custom (and disabled) padding characters
when encoding, but didn't update the decoding paths. This adds
decoding support.

Fixes #20854

Change-Id: I9fb1a0aaebb27f1204c9f726a780d5784eb71024
Reviewed-on: https://go-review.googlesource.com/47341
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/encoding/base32/base32.go
src/encoding/base32/base32_test.go

index 437b41d225fd51799c1ef1b91c8da8d68afd6441..0270e8f4d4031da95b70f061a47f47eba7df1d98 100644 (file)
@@ -279,19 +279,28 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
                dlen := 8
 
                for j := 0; j < 8; {
-                       if len(src) == 0 {
+
+                       // We have reached the end and are missing padding
+                       if len(src) == 0 && enc.padChar != NoPadding {
                                return n, false, CorruptInputError(olen - len(src) - j)
                        }
+
+                       // We have reached the end and are not expecing any padding
+                       if len(src) == 0 && enc.padChar == NoPadding {
+                               dlen, end = j, true
+                               break
+                       }
+
                        in := src[0]
                        src = src[1:]
-                       if in == '=' && j >= 2 && len(src) < 8 {
+                       if in == byte(enc.padChar) && j >= 2 && len(src) < 8 {
                                // We've reached the end and there's padding
                                if len(src)+j < 8-1 {
                                        // not enough padding
                                        return n, false, CorruptInputError(olen)
                                }
                                for k := 0; k < 8-1-j; k++ {
-                                       if len(src) > k && src[k] != '=' {
+                                       if len(src) > k && src[k] != byte(enc.padChar) {
                                                // incorrect padding
                                                return n, false, CorruptInputError(olen - len(src) + k - 1)
                                        }
@@ -484,4 +493,11 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
 
 // DecodedLen returns the maximum length in bytes of the decoded data
 // corresponding to n bytes of base32-encoded data.
-func (enc *Encoding) DecodedLen(n int) int { return n / 8 * 5 }
+func (enc *Encoding) DecodedLen(n int) int {
+       if enc.padChar == NoPadding {
+               // +6 represents the missing padding
+               return (n + 6) / 8 * 5
+       }
+
+       return n / 8 * 5
+}
index ee7525c997c11cc7977e94dbc88bfbb958e5f0d4..6fe292b476ebebb84d79581c1b4fccb86deeba64 100644 (file)
@@ -490,3 +490,91 @@ func TestWithoutPadding(t *testing.T) {
                }
        }
 }
+
+func TestDecodeWithPadding(t *testing.T) {
+       encodings := []*Encoding{
+               StdEncoding,
+               StdEncoding.WithPadding('-'),
+               StdEncoding.WithPadding(NoPadding),
+       }
+
+       for i, enc := range encodings {
+               for _, pair := range pairs {
+
+                       input := pair.decoded
+                       encoded := enc.EncodeToString([]byte(input))
+
+                       decoded, err := enc.DecodeString(encoded)
+                       if err != nil {
+                               t.Errorf("DecodeString Error for encoding %d (%q): %v", i, input, err)
+                       }
+
+                       if input != string(decoded) {
+                               t.Errorf("Unexpected result for encoding %d: got %q; want %q", i, decoded, input)
+                       }
+               }
+       }
+}
+
+func TestDecodeWithWrongPadding(t *testing.T) {
+       encoded := StdEncoding.EncodeToString([]byte("foobar"))
+
+       _, err := StdEncoding.WithPadding('-').DecodeString(encoded)
+       if err == nil {
+               t.Error("expected error")
+       }
+
+       _, err = StdEncoding.WithPadding(NoPadding).DecodeString(encoded)
+       if err == nil {
+               t.Error("expected error")
+       }
+}
+
+func TestEncodedDecodedLen(t *testing.T) {
+       type test struct {
+               in      int
+               wantEnc int
+               wantDec int
+       }
+       data := bytes.Repeat([]byte("x"), 100)
+       for _, test := range []struct {
+               name  string
+               enc   *Encoding
+               cases []test
+       }{
+               {"StdEncoding", StdEncoding, []test{
+                       {0, 0, 0},
+                       {1, 8, 5},
+                       {5, 8, 5},
+                       {6, 16, 10},
+                       {10, 16, 10},
+               }},
+               {"NoPadding", StdEncoding.WithPadding(NoPadding), []test{
+                       {0, 0, 0},
+                       {1, 2, 5},
+                       {2, 4, 5},
+                       {5, 8, 5},
+                       {6, 10, 10},
+                       {7, 12, 10},
+                       {10, 16, 10},
+                       {11, 18, 15},
+               }},
+       } {
+               t.Run(test.name, func(t *testing.T) {
+                       for _, tc := range test.cases {
+                               encLen := test.enc.EncodedLen(tc.in)
+                               decLen := test.enc.DecodedLen(encLen)
+                               enc := test.enc.EncodeToString(data[:tc.in])
+                               if len(enc) != encLen {
+                                       t.Fatalf("EncodedLen(%d) = %d but encoded to %q (%d)", tc.in, encLen, enc, len(enc))
+                               }
+                               if encLen != tc.wantEnc {
+                                       t.Fatalf("EncodedLen(%d) = %d; want %d", tc.in, encLen, tc.wantEnc)
+                               }
+                               if decLen != tc.wantDec {
+                                       t.Fatalf("DecodedLen(%d) = %d; want %d", encLen, decLen, tc.wantDec)
+                               }
+                       }
+               })
+       }
+}