]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/hex: make Decode, Decoder, DecodeString agree about partial results and...
authorRuss Cox <rsc@golang.org>
Tue, 14 Nov 2017 18:43:02 +0000 (13:43 -0500)
committerRuss Cox <rsc@golang.org>
Thu, 16 Nov 2017 01:08:20 +0000 (01:08 +0000)
CL 70210 added Decoder for #21590, and in doing so it changed
the existing func Decode to return partial results for decoding errors.
That seems like a good change to make to Decode, but it was
untested (except as used by Decoder), inconsistent with DecodeString
in all error cases, and inconsistent with Decoder in not returning
partial results for odd-length input strings.

This CL makes Decode, DecodeString, and Decoder all agree about
the handling of partial results (they are returned) and error
precedence (the error earliest in the input is reported),
and it documents and tests this.

Change-Id: Ifb7d1e100ecb66fe2ed5ba34a621084d480f16db
Reviewed-on: https://go-review.googlesource.com/78120
Run-TryBot: Russ Cox <rsc@golang.org>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/encoding/hex/hex.go
src/encoding/hex/hex_test.go

index f47b7fa34ea1f5a6893c0e4e955203e2699dd3f8..e4df6cbd4d11ce10503fea96272d3affb057c31d 100644 (file)
@@ -31,7 +31,9 @@ func Encode(dst, src []byte) int {
        return len(src) * 2
 }
 
-// ErrLength results from decoding an odd length slice.
+// ErrLength reports an attempt to decode an odd-length input
+// using Decode or DecodeString.
+// The stream-based Decoder returns io.ErrUnexpectedEOF instead of ErrLength.
 var ErrLength = errors.New("encoding/hex: odd length hex string")
 
 // InvalidByteError values describe errors resulting from an invalid byte in a hex string.
@@ -50,12 +52,11 @@ func DecodedLen(x int) int { return x / 2 }
 //
 // Decode expects that src contain only hexadecimal
 // characters and that src should have an even length.
+// If the input is malformed, Decode returns the number
+// of bytes decoded before the error.
 func Decode(dst, src []byte) (int, error) {
-       if len(src)%2 == 1 {
-               return 0, ErrLength
-       }
-
-       for i := 0; i < len(src)/2; i++ {
+       var i int
+       for i = 0; i < len(src)/2; i++ {
                a, ok := fromHexChar(src[i*2])
                if !ok {
                        return i, InvalidByteError(src[i*2])
@@ -66,8 +67,15 @@ func Decode(dst, src []byte) (int, error) {
                }
                dst[i] = (a << 4) | b
        }
-
-       return len(src) / 2, nil
+       if len(src)%2 == 1 {
+               // Check for invalid char before reporting bad length,
+               // since the invalid char (if present) is an earlier problem.
+               if _, ok := fromHexChar(src[i*2]); !ok {
+                       return i, InvalidByteError(src[i*2])
+               }
+               return i, ErrLength
+       }
+       return i, nil
 }
 
 // fromHexChar converts a hex character into its value and a success flag.
@@ -92,15 +100,17 @@ func EncodeToString(src []byte) string {
 }
 
 // DecodeString returns the bytes represented by the hexadecimal string s.
+//
+// DecodeString expects that src contain only hexadecimal
+// characters and that src should have an even length.
+// If the input is malformed, DecodeString returns a string
+// containing the bytes decoded before the error.
 func DecodeString(s string) ([]byte, error) {
        src := []byte(s)
        // We can use the source slice itself as the destination
        // because the decode loop increments by one and then the 'seen' byte is not used anymore.
-       len, err := Decode(src, src)
-       if err != nil {
-               return nil, err
-       }
-       return src[:len], nil
+       n, err := Decode(src, src)
+       return src[:n], err
 }
 
 // Dump returns a string that contains a hex dump of the given data. The format
@@ -164,7 +174,11 @@ func (d *decoder) Read(p []byte) (n int, err error) {
                numRead, d.err = d.r.Read(d.arr[numCopy:])
                d.in = d.arr[:numCopy+numRead]
                if d.err == io.EOF && len(d.in)%2 != 0 {
-                       d.err = io.ErrUnexpectedEOF
+                       if _, ok := fromHexChar(d.in[len(d.in)-1]); !ok {
+                               d.err = InvalidByteError(d.in[len(d.in)-1])
+                       } else {
+                               d.err = io.ErrUnexpectedEOF
+                       }
                }
        }
 
index d874b39e9597a9deb028d298e4290e45b5e21bc8..b6bab21c4801d87c75a67b7e9a1518bdc7ce02ec 100644 (file)
@@ -78,38 +78,37 @@ func TestDecodeString(t *testing.T) {
        }
 }
 
-type errTest struct {
+var errTests = []struct {
        in  string
+       out string
        err error
+}{
+       {"", "", nil},
+       {"0", "", ErrLength},
+       {"zd4aa", "", InvalidByteError('z')},
+       {"d4aaz", "\xd4\xaa", InvalidByteError('z')},
+       {"30313", "01", ErrLength},
+       {"0g", "", InvalidByteError('g')},
+       {"00gg", "\x00", InvalidByteError('g')},
+       {"0\x01", "", InvalidByteError('\x01')},
+       {"ffeed", "\xff\xee", ErrLength},
 }
 
-var errTests = []errTest{
-       {"0", ErrLength},
-       {"zd4aa", ErrLength},
-       {"0g", InvalidByteError('g')},
-       {"00gg", InvalidByteError('g')},
-       {"0\x01", InvalidByteError('\x01')},
-}
-
-func TestInvalidErr(t *testing.T) {
-       for i, test := range errTests {
-               dst := make([]byte, DecodedLen(len(test.in)))
-               _, err := Decode(dst, []byte(test.in))
-               if err == nil {
-                       t.Errorf("#%d: expected %v; got none", i, test.err)
-               } else if err != test.err {
-                       t.Errorf("#%d: got: %v want: %v", i, err, test.err)
+func TestDecodeErr(t *testing.T) {
+       for _, tt := range errTests {
+               out := make([]byte, len(tt.in)+10)
+               n, err := Decode(out, []byte(tt.in))
+               if string(out[:n]) != tt.out || err != tt.err {
+                       t.Errorf("Decode(%q) = %q, %v, want %q, %v", tt.in, string(out[:n]), err, tt.out, tt.err)
                }
        }
 }
 
-func TestInvalidStringErr(t *testing.T) {
-       for i, test := range errTests {
-               _, err := DecodeString(test.in)
-               if err == nil {
-                       t.Errorf("#%d: expected %v; got none", i, test.err)
-               } else if err != test.err {
-                       t.Errorf("#%d: got: %v want: %v", i, err, test.err)
+func TestDecodeStringErr(t *testing.T) {
+       for _, tt := range errTests {
+               out, err := DecodeString(tt.in)
+               if string(out) != tt.out || err != tt.err {
+                       t.Errorf("DecodeString(%q) = %q, %v, want %q, %v", tt.in, out, err, tt.out, tt.err)
                }
        }
 }
@@ -148,25 +147,17 @@ func TestEncoderDecoder(t *testing.T) {
        }
 }
 
-func TestDecodeErr(t *testing.T) {
-       tests := []struct {
-               in      string
-               wantOut string
-               wantErr error
-       }{
-               {"", "", nil},
-               {"0", "", io.ErrUnexpectedEOF},
-               {"0g", "", InvalidByteError('g')},
-               {"00gg", "\x00", InvalidByteError('g')},
-               {"0\x01", "", InvalidByteError('\x01')},
-               {"ffeed", "\xff\xee", io.ErrUnexpectedEOF},
-       }
-
-       for _, tt := range tests {
+func TestDecoderErr(t *testing.T) {
+       for _, tt := range errTests {
                dec := NewDecoder(strings.NewReader(tt.in))
-               got, err := ioutil.ReadAll(dec)
-               if string(got) != tt.wantOut || err != tt.wantErr {
-                       t.Errorf("NewDecoder(%q) = (%q, %v), want (%q, %v)", tt.in, got, err, tt.wantOut, tt.wantErr)
+               out, err := ioutil.ReadAll(dec)
+               wantErr := tt.err
+               // Decoder is reading from stream, so it reports io.ErrUnexpectedEOF instead of ErrLength.
+               if wantErr == ErrLength {
+                       wantErr = io.ErrUnexpectedEOF
+               }
+               if string(out) != tt.out || err != wantErr {
+                       t.Errorf("NewDecoder(%q) = %q, %v, want %q, %v", tt.in, out, err, tt.out, wantErr)
                }
        }
 }