]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/base32: add Encoding.WithPadding, StdPadding, NoPadding
authorGustav Westling <gustav@westling.xyz>
Fri, 24 Mar 2017 23:35:40 +0000 (00:35 +0100)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 22 May 2017 22:50:17 +0000 (22:50 +0000)
Fixes #19478

Change-Id: I9fc186610d79fd003e7b5d88c0955286ebe7d3cf
Reviewed-on: https://go-review.googlesource.com/38634
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/encoding/base32/base32.go
src/encoding/base32/base32_test.go

index 788a06115ac575fa4f00fb455121199a4a2c3b8e..437b41d225fd51799c1ef1b91c8da8d68afd6441 100644 (file)
@@ -23,8 +23,14 @@ import (
 type Encoding struct {
        encode    string
        decodeMap [256]byte
+       padChar   rune
 }
 
+const (
+       StdPadding rune = '=' // Standard padding character
+       NoPadding  rune = -1  // No padding
+)
+
 const encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
 const encodeHex = "0123456789ABCDEFGHIJKLMNOPQRSTUV"
 
@@ -33,6 +39,8 @@ const encodeHex = "0123456789ABCDEFGHIJKLMNOPQRSTUV"
 func NewEncoding(encoder string) *Encoding {
        e := new(Encoding)
        e.encode = encoder
+       e.padChar = StdPadding
+
        for i := 0; i < len(e.decodeMap); i++ {
                e.decodeMap[i] = 0xFF
        }
@@ -57,6 +65,26 @@ var removeNewlinesMapper = func(r rune) rune {
        return r
 }
 
+// WithPadding creates a new encoding identical to enc except
+// with a specified padding character, or NoPadding to disable padding.
+// The padding character must not be '\r' or '\n', must not
+// be contained in the encoding's alphabet and must be a rune equal or
+// below '\xff'.
+func (enc Encoding) WithPadding(padding rune) *Encoding {
+       if padding == '\r' || padding == '\n' || padding > 0xff {
+               panic("invalid padding")
+       }
+
+       for i := 0; i < len(enc.encode); i++ {
+               if rune(enc.encode[i]) == padding {
+                       panic("padding contained in alphabet")
+               }
+       }
+
+       enc.padChar = padding
+       return &enc
+}
+
 /*
  * Encoder
  */
@@ -73,60 +101,63 @@ func (enc *Encoding) Encode(dst, src []byte) {
        }
 
        for len(src) > 0 {
-               var b0, b1, b2, b3, b4, b5, b6, b7 byte
+               var b [8]byte
 
                // Unpack 8x 5-bit source blocks into a 5 byte
                // destination quantum
                switch len(src) {
                default:
-                       b7 = src[4] & 0x1F
-                       b6 = src[4] >> 5
+                       b[7] = src[4] & 0x1F
+                       b[6] = src[4] >> 5
                        fallthrough
                case 4:
-                       b6 |= (src[3] << 3) & 0x1F
-                       b5 = (src[3] >> 2) & 0x1F
-                       b4 = src[3] >> 7
+                       b[6] |= (src[3] << 3) & 0x1F
+                       b[5] = (src[3] >> 2) & 0x1F
+                       b[4] = src[3] >> 7
                        fallthrough
                case 3:
-                       b4 |= (src[2] << 1) & 0x1F
-                       b3 = (src[2] >> 4) & 0x1F
+                       b[4] |= (src[2] << 1) & 0x1F
+                       b[3] = (src[2] >> 4) & 0x1F
                        fallthrough
                case 2:
-                       b3 |= (src[1] << 4) & 0x1F
-                       b2 = (src[1] >> 1) & 0x1F
-                       b1 = (src[1] >> 6) & 0x1F
+                       b[3] |= (src[1] << 4) & 0x1F
+                       b[2] = (src[1] >> 1) & 0x1F
+                       b[1] = (src[1] >> 6) & 0x1F
                        fallthrough
                case 1:
-                       b1 |= (src[0] << 2) & 0x1F
-                       b0 = src[0] >> 3
+                       b[1] |= (src[0] << 2) & 0x1F
+                       b[0] = src[0] >> 3
                }
 
                // Encode 5-bit blocks using the base32 alphabet
-               dst[0] = enc.encode[b0]
-               dst[1] = enc.encode[b1]
-               dst[2] = enc.encode[b2]
-               dst[3] = enc.encode[b3]
-               dst[4] = enc.encode[b4]
-               dst[5] = enc.encode[b5]
-               dst[6] = enc.encode[b6]
-               dst[7] = enc.encode[b7]
+               for i := 0; i < 8; i++ {
+                       if len(dst) > i {
+                               dst[i] = enc.encode[b[i]]
+                       }
+               }
 
                // Pad the final quantum
                if len(src) < 5 {
-                       dst[7] = '='
+                       if enc.padChar == NoPadding {
+                               break
+                       }
+
+                       dst[7] = byte(enc.padChar)
                        if len(src) < 4 {
-                               dst[6] = '='
-                               dst[5] = '='
+                               dst[6] = byte(enc.padChar)
+                               dst[5] = byte(enc.padChar)
                                if len(src) < 3 {
-                                       dst[4] = '='
+                                       dst[4] = byte(enc.padChar)
                                        if len(src) < 2 {
-                                               dst[3] = '='
-                                               dst[2] = '='
+                                               dst[3] = byte(enc.padChar)
+                                               dst[2] = byte(enc.padChar)
                                        }
                                }
                        }
+
                        break
                }
+
                src = src[5:]
                dst = dst[8:]
        }
@@ -219,7 +250,12 @@ func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
 
 // EncodedLen returns the length in bytes of the base32 encoding
 // of an input buffer of length n.
-func (enc *Encoding) EncodedLen(n int) int { return (n + 4) / 5 * 8 }
+func (enc *Encoding) EncodedLen(n int) int {
+       if enc.padChar == NoPadding {
+               return (n*8 + 4) / 5
+       }
+       return (n + 4) / 5 * 8
+}
 
 /*
  * Decoder
index 37db770b02de8a9e9b602f063da9050d90bda801..bd101b5b04f3d5892ba0dbe3613f1aea44956994 100644 (file)
@@ -455,3 +455,33 @@ func BenchmarkDecodeString(b *testing.B) {
                StdEncoding.DecodeString(data)
        }
 }
+
+func TestWithCustomPadding(t *testing.T) {
+       for _, testcase := range pairs {
+               defaultPadding := StdEncoding.EncodeToString([]byte(testcase.decoded))
+               customPadding := StdEncoding.WithPadding('@').EncodeToString([]byte(testcase.decoded))
+               expected := strings.Replace(defaultPadding, "=", "@", -1)
+
+               if expected != customPadding {
+                       t.Errorf("Expected custom %s, got %s", expected, customPadding)
+               }
+               if testcase.encoded != defaultPadding {
+                       t.Errorf("Expected %s, got %s", testcase.encoded, defaultPadding)
+               }
+       }
+}
+
+func TestWithoutPadding(t *testing.T) {
+       for _, testcase := range pairs {
+               defaultPadding := StdEncoding.EncodeToString([]byte(testcase.decoded))
+               customPadding := StdEncoding.WithPadding(NoPadding).EncodeToString([]byte(testcase.decoded))
+               expected := strings.TrimRight(defaultPadding, "=")
+
+               if expected != customPadding {
+                       t.Errorf("Expected custom %s, got %s", expected, customPadding)
+               }
+               if testcase.encoded != defaultPadding {
+                       t.Errorf("Expected %s, got %s", testcase.encoded, defaultPadding)
+               }
+       }
+}