]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/base32, encoding/base64: fix issues with decoder whitespace handling
authorPhilip K. Warren <pkwarren@gmail.com>
Tue, 12 Mar 2013 05:50:10 +0000 (01:50 -0400)
committerRuss Cox <rsc@golang.org>
Tue, 12 Mar 2013 05:50:10 +0000 (01:50 -0400)
Adds a new reader to filter newlines, which fixes errors seen in the
decoder chunking code. Found additional issues with whitespace handling
after the first padding character.
Fixes #4779.

R=minux.ma, rsc, bradfitz
CC=golang-dev
https://golang.org/cl/7311069

src/pkg/encoding/base32/base32.go
src/pkg/encoding/base32/base32_test.go
src/pkg/encoding/base64/base64.go
src/pkg/encoding/base64/base64_test.go

index 6c5d8d3a9e9c50ae7526e7a3180750c0670e06a7..fe17b732207c82219d97bae3a39b0be83aab0536 100644 (file)
@@ -6,8 +6,10 @@
 package base32
 
 import (
+       "bytes"
        "io"
        "strconv"
+       "strings"
 )
 
 /*
@@ -48,6 +50,13 @@ var StdEncoding = NewEncoding(encodeStd)
 // It is typically used in DNS.
 var HexEncoding = NewEncoding(encodeHex)
 
+var removeNewlinesMapper = func(r rune) rune {
+       if r == '\r' || r == '\n' {
+               return -1
+       }
+       return r
+}
+
 /*
  * Encoder
  */
@@ -228,7 +237,8 @@ func (e CorruptInputError) Error() string {
 
 // decode is like Decode but returns an additional 'end' value, which
 // indicates if end-of-message padding was encountered and thus any
-// additional data is an error.
+// additional data is an error. This method assumes that src has been
+// stripped of all supported whitespace ('\r' and '\n').
 func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
        olen := len(src)
        for len(src) > 0 && !end {
@@ -242,10 +252,6 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
                        }
                        in := src[0]
                        src = src[1:]
-                       if in == '\r' || in == '\n' {
-                               // Ignore this character.
-                               continue
-                       }
                        if in == '=' && j >= 2 && len(src) < 8 {
                                // We've reached the end and there's padding
                                if len(src)+j < 8-1 {
@@ -317,12 +323,14 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
 // number of bytes successfully written and CorruptInputError.
 // New line characters (\r and \n) are ignored.
 func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
+       src = bytes.Map(removeNewlinesMapper, src)
        n, _, err = enc.decode(dst, src)
        return
 }
 
 // DecodeString returns the bytes represented by the base32 string s.
 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
+       s = strings.Map(removeNewlinesMapper, s)
        dbuf := make([]byte, enc.DecodedLen(len(s)))
        n, err := enc.Decode(dbuf, []byte(s))
        return dbuf[:n], err
@@ -387,9 +395,34 @@ func (d *decoder) Read(p []byte) (n int, err error) {
        return n, d.err
 }
 
+type newlineFilteringReader struct {
+       wrapped io.Reader
+}
+
+func (r *newlineFilteringReader) Read(p []byte) (int, error) {
+       n, err := r.wrapped.Read(p)
+       for n > 0 {
+               offset := 0
+               for i, b := range p[0:n] {
+                       if b != '\r' && b != '\n' {
+                               if i != offset {
+                                       p[offset] = b
+                               }
+                               offset++
+                       }
+               }
+               if offset > 0 {
+                       return offset, err
+               }
+               // Previous buffer entirely whitespace, read again
+               n, err = r.wrapped.Read(p)
+       }
+       return n, err
+}
+
 // NewDecoder constructs a new base32 stream decoder.
 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
-       return &decoder{enc: enc, r: r}
+       return &decoder{enc: enc, r: &newlineFilteringReader{r}}
 }
 
 // DecodedLen returns the maximum length in bytes of the decoded data
index b62bfeebf62738c55566e28b86f0f2d1654c0afd..63298d1c94c180d34797b797768b80e7976f4229 100644 (file)
@@ -8,6 +8,7 @@ import (
        "bytes"
        "io"
        "io/ioutil"
+       "strings"
        "testing"
 )
 
@@ -216,9 +217,21 @@ func TestBig(t *testing.T) {
        }
 }
 
+func testStringEncoding(t *testing.T, expected string, examples []string) {
+       for _, e := range examples {
+               buf, err := StdEncoding.DecodeString(e)
+               if err != nil {
+                       t.Errorf("Decode(%q) failed: %v", e, err)
+                       continue
+               }
+               if s := string(buf); s != expected {
+                       t.Errorf("Decode(%q) = %q, want %q", e, s, expected)
+               }
+       }
+}
+
 func TestNewLineCharacters(t *testing.T) {
        // Each of these should decode to the string "sure", without errors.
-       const expected = "sure"
        examples := []string{
                "ON2XEZI=",
                "ON2XEZI=\r",
@@ -230,14 +243,44 @@ func TestNewLineCharacters(t *testing.T) {
                "ON2XEZ\nI=",
                "ON2XEZI\n=",
        }
-       for _, e := range examples {
-               buf, err := StdEncoding.DecodeString(e)
-               if err != nil {
-                       t.Errorf("Decode(%q) failed: %v", e, err)
-                       continue
-               }
-               if s := string(buf); s != expected {
-                       t.Errorf("Decode(%q) = %q, want %q", e, s, expected)
-               }
+       testStringEncoding(t, "sure", examples)
+
+       // Each of these should decode to the string "foobar", without errors.
+       examples = []string{
+               "MZXW6YTBOI======",
+               "MZXW6YTBOI=\r\n=====",
+       }
+       testStringEncoding(t, "foobar", examples)
+}
+
+func TestDecoderIssue4779(t *testing.T) {
+       encoded := `JRXXEZLNEBUXA43VNUQGI33MN5ZCA43JOQQGC3LFOQWCAY3PNZZWKY3UMV2HK4
+RAMFSGS4DJONUWG2LOM4QGK3DJOQWCA43FMQQGI3YKMVUXK43NN5SCA5DFNVYG64RANFXGG2LENFSH
+K3TUEB2XIIDMMFRG64TFEBSXIIDEN5WG64TFEBWWCZ3OMEQGC3DJOF2WCLRAKV2CAZLONFWQUYLEEB
+WWS3TJNUQHMZLONFQW2LBAOF2WS4ZANZXXG5DSOVSCAZLYMVZGG2LUMF2GS33OEB2WY3DBNVRW6IDM
+MFRG64TJOMQG42LTNEQHK5AKMFWGS4LVNFYCAZLYEBSWCIDDN5WW233EN4QGG33OONSXC5LBOQXCAR
+DVNFZSAYLVORSSA2LSOVZGKIDEN5WG64RANFXAU4TFOBZGK2DFNZSGK4TJOQQGS3RAOZXWY5LQORQX
+IZJAOZSWY2LUEBSXG43FEBRWS3DMOVWSAZDPNRXXEZJAMV2SAZTVM5UWC5BANZ2WY3DBBJYGC4TJMF
+2HK4ROEBCXQY3FOB2GK5LSEBZWS3TUEBXWGY3BMVRWC5BAMN2XA2LEMF2GC5BANZXW4IDQOJXWSZDF
+NZ2CYIDTOVXHIIDJNYFGG5LMOBQSA4LVNEQG6ZTGNFRWSYJAMRSXGZLSOVXHIIDNN5WGY2LUEBQW42
+LNEBUWIIDFON2CA3DBMJXXE5LNFY==
+====`
+       encodedShort := strings.Replace(encoded, "\n", "", -1)
+
+       dec := NewDecoder(StdEncoding, bytes.NewBufferString(encoded))
+       res1, err := ioutil.ReadAll(dec)
+       if err != nil {
+               t.Errorf("ReadAll failed: %v", err)
+       }
+
+       dec = NewDecoder(StdEncoding, bytes.NewBufferString(encodedShort))
+       var res2 []byte
+       res2, err = ioutil.ReadAll(dec)
+       if err != nil {
+               t.Errorf("ReadAll failed: %v", err)
+       }
+
+       if !bytes.Equal(res1, res2) {
+               t.Error("Decoded results not equal")
        }
 }
index 26dd7f7b99f687952e7931d6318f5380a5dd14d4..85e398fd0b7be3b59e85df6f8a565cb53d0f8e91 100644 (file)
@@ -6,8 +6,10 @@
 package base64
 
 import (
+       "bytes"
        "io"
        "strconv"
+       "strings"
 )
 
 /*
@@ -49,6 +51,13 @@ var StdEncoding = NewEncoding(encodeStd)
 // It is typically used in URLs and file names.
 var URLEncoding = NewEncoding(encodeURL)
 
+var removeNewlinesMapper = func(r rune) rune {
+       if r == '\r' || r == '\n' {
+               return -1
+       }
+       return r
+}
+
 /*
  * Encoder
  */
@@ -208,7 +217,8 @@ func (e CorruptInputError) Error() string {
 
 // decode is like Decode but returns an additional 'end' value, which
 // indicates if end-of-message padding was encountered and thus any
-// additional data is an error.
+// additional data is an error. This method assumes that src has been
+// stripped of all supported whitespace ('\r' and '\n').
 func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
        olen := len(src)
        for len(src) > 0 && !end {
@@ -222,10 +232,6 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
                        }
                        in := src[0]
                        src = src[1:]
-                       if in == '\r' || in == '\n' {
-                               // Ignore this character.
-                               continue
-                       }
                        if in == '=' && j >= 2 && len(src) < 4 {
                                // We've reached the end and there's padding
                                if len(src)+j < 4-1 {
@@ -271,12 +277,14 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
 // number of bytes successfully written and CorruptInputError.
 // New line characters (\r and \n) are ignored.
 func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
+       src = bytes.Map(removeNewlinesMapper, src)
        n, _, err = enc.decode(dst, src)
        return
 }
 
 // DecodeString returns the bytes represented by the base64 string s.
 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
+       s = strings.Map(removeNewlinesMapper, s)
        dbuf := make([]byte, enc.DecodedLen(len(s)))
        n, err := enc.Decode(dbuf, []byte(s))
        return dbuf[:n], err
@@ -341,9 +349,34 @@ func (d *decoder) Read(p []byte) (n int, err error) {
        return n, d.err
 }
 
+type newlineFilteringReader struct {
+       wrapped io.Reader
+}
+
+func (r *newlineFilteringReader) Read(p []byte) (int, error) {
+       n, err := r.wrapped.Read(p)
+       for n > 0 {
+               offset := 0
+               for i, b := range p[0:n] {
+                       if b != '\r' && b != '\n' {
+                               if i != offset {
+                                       p[offset] = b
+                               }
+                               offset++
+                       }
+               }
+               if offset > 0 {
+                       return offset, err
+               }
+               // Previous buffer entirely whitespace, read again
+               n, err = r.wrapped.Read(p)
+       }
+       return n, err
+}
+
 // NewDecoder constructs a new base64 stream decoder.
 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
-       return &decoder{enc: enc, r: r}
+       return &decoder{enc: enc, r: &newlineFilteringReader{r}}
 }
 
 // DecodedLen returns the maximum length in bytes of the decoded data
index 71c2bfce7f3aedd0ff9eb829f70d9bb451365771..579591a88d7a0b44f33eef97c1beb53f7eba34fe 100644 (file)
@@ -9,6 +9,7 @@ import (
        "errors"
        "io"
        "io/ioutil"
+       "strings"
        "testing"
        "time"
 )
@@ -225,6 +226,8 @@ func TestNewLineCharacters(t *testing.T) {
                "c3V\nyZ\rQ==",
                "c3VyZ\nQ==",
                "c3VyZQ\n==",
+               "c3VyZQ=\n=",
+               "c3VyZQ=\r\n\r\n=",
        }
        for _, e := range examples {
                buf, err := StdEncoding.DecodeString(e)
@@ -285,3 +288,40 @@ func TestDecoderIssue3577(t *testing.T) {
                t.Errorf("timeout; Decoder blocked without returning an error")
        }
 }
+
+func TestDecoderIssue4779(t *testing.T) {
+       encoded := `CP/EAT8AAAEF
+AQEBAQEBAAAAAAAAAAMAAQIEBQYHCAkKCwEAAQUBAQEBAQEAAAAAAAAAAQACAwQFBgcICQoLEAAB
+BAEDAgQCBQcGCAUDDDMBAAIRAwQhEjEFQVFhEyJxgTIGFJGhsUIjJBVSwWIzNHKC0UMHJZJT8OHx
+Y3M1FqKygyZEk1RkRcKjdDYX0lXiZfKzhMPTdePzRieUpIW0lcTU5PSltcXV5fVWZnaGlqa2xtbm
+9jdHV2d3h5ent8fX5/cRAAICAQIEBAMEBQYHBwYFNQEAAhEDITESBEFRYXEiEwUygZEUobFCI8FS
+0fAzJGLhcoKSQ1MVY3M08SUGFqKygwcmNcLSRJNUoxdkRVU2dGXi8rOEw9N14/NGlKSFtJXE1OT0
+pbXF1eX1VmZ2hpamtsbW5vYnN0dXZ3eHl6e3x//aAAwDAQACEQMRAD8A9VSSSSUpJJJJSkkkJ+Tj
+1kiy1jCJJDnAcCTykpKkuQ6p/jN6FgmxlNduXawwAzaGH+V6jn/R/wCt71zdn+N/qL3kVYFNYB4N
+ji6PDVjWpKp9TSXnvTf8bFNjg3qOEa2n6VlLpj/rT/pf567DpX1i6L1hs9Py67X8mqdtg/rUWbbf
++gkp0kkkklKSSSSUpJJJJT//0PVUkkklKVLq3WMDpGI7KzrNjADtYNXvI/Mqr/Pd/q9W3vaxjnvM
+NaCXE9gNSvGPrf8AWS3qmba5jjsJhoB0DAf0NDf6sevf+/lf8Hj0JJATfWT6/dV6oXU1uOLQeKKn
+EQP+Hubtfe/+R7Mf/g7f5xcocp++Z11JMCJPgFBxOg7/AOuqDx8I/ikpkXkmSdU8mJIJA/O8EMAy
+j+mSARB/17pKVXYWHXjsj7yIex0PadzXMO1zT5KHoNA3HT8ietoGhgjsfA+CSnvvqh/jJtqsrwOv
+2b6NGNzXfTYexzJ+nU7/ALkf4P8Awv6P9KvTQQ4AgyDqCF85Pho3CTB7eHwXoH+LT65uZbX9X+o2
+bqbPb06551Y4
+`
+       encodedShort := strings.Replace(encoded, "\n", "", -1)
+
+       dec := NewDecoder(StdEncoding, bytes.NewBufferString(encoded))
+       res1, err := ioutil.ReadAll(dec)
+       if err != nil {
+               t.Errorf("ReadAll failed: %v", err)
+       }
+
+       dec = NewDecoder(StdEncoding, bytes.NewBufferString(encodedShort))
+       var res2 []byte
+       res2, err = ioutil.ReadAll(dec)
+       if err != nil {
+               t.Errorf("ReadAll failed: %v", err)
+       }
+
+       if !bytes.Equal(res1, res2) {
+               t.Error("Decoded results not equal")
+       }
+}