]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/hex: add NewEncoder, NewDecoder
authorTim Cooper <tim.cooper@layeh.com>
Thu, 12 Oct 2017 01:05:03 +0000 (22:05 -0300)
committerJoe Tsai <thebrokentoaster@gmail.com>
Fri, 20 Oct 2017 23:47:07 +0000 (23:47 +0000)
NewEncoder returns an io.Writer that writes all incoming bytes as
hexadecimal characters to the underlying io.Writer. NewDecoder returns an
io.Reader that does the inverse.

Fixes #21590

Change-Id: Iebe0813faf365b42598f19a9aa41768f571dc0a8
Reviewed-on: https://go-review.googlesource.com/70210
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/encoding/hex/hex.go
src/encoding/hex/hex_test.go

index 18e0c09ef3889771b0a35e423b60cd16a9de4938..f47b7fa34ea1f5a6893c0e4e955203e2699dd3f8 100644 (file)
@@ -58,11 +58,11 @@ func Decode(dst, src []byte) (int, error) {
        for i := 0; i < len(src)/2; i++ {
                a, ok := fromHexChar(src[i*2])
                if !ok {
-                       return 0, InvalidByteError(src[i*2])
+                       return i, InvalidByteError(src[i*2])
                }
                b, ok := fromHexChar(src[i*2+1])
                if !ok {
-                       return 0, InvalidByteError(src[i*2+1])
+                       return i, InvalidByteError(src[i*2+1])
                }
                dst[i] = (a << 4) | b
        }
@@ -113,6 +113,77 @@ func Dump(data []byte) string {
        return buf.String()
 }
 
+// bufferSize is the number of hexadecimal characters to buffer in encoder and decoder.
+const bufferSize = 1024
+
+type encoder struct {
+       w   io.Writer
+       err error
+       out [bufferSize]byte // output buffer
+}
+
+// NewEncoder returns an io.Writer that writes lowercase hexadecimal characters to w.
+func NewEncoder(w io.Writer) io.Writer {
+       return &encoder{w: w}
+}
+
+func (e *encoder) Write(p []byte) (n int, err error) {
+       for len(p) > 0 && e.err == nil {
+               chunkSize := bufferSize / 2
+               if len(p) < chunkSize {
+                       chunkSize = len(p)
+               }
+
+               var written int
+               encoded := Encode(e.out[:], p[:chunkSize])
+               written, e.err = e.w.Write(e.out[:encoded])
+               n += written / 2
+               p = p[chunkSize:]
+       }
+       return n, e.err
+}
+
+type decoder struct {
+       r   io.Reader
+       err error
+       in  []byte           // input buffer (encoded form)
+       arr [bufferSize]byte // backing array for in
+}
+
+// NewDecoder returns an io.Reader that decodes hexadecimal characters from r.
+// NewDecoder expects that r contain only an even number of hexadecimal characters.
+func NewDecoder(r io.Reader) io.Reader {
+       return &decoder{r: r}
+}
+
+func (d *decoder) Read(p []byte) (n int, err error) {
+       // Fill internal buffer with sufficient bytes to decode
+       if len(d.in) < 2 && d.err == nil {
+               var numCopy, numRead int
+               numCopy = copy(d.arr[:], d.in) // Copies either 0 or 1 bytes
+               numRead, d.err = d.r.Read(d.arr[numCopy:])
+               d.in = d.arr[:numCopy+numRead]
+               if d.err == io.EOF && len(d.in)%2 != 0 {
+                       d.err = io.ErrUnexpectedEOF
+               }
+       }
+
+       // Decode internal buffer into output buffer
+       if numAvail := len(d.in) / 2; len(p) > numAvail {
+               p = p[:numAvail]
+       }
+       numDec, err := Decode(p, d.in[:len(p)*2])
+       d.in = d.in[2*numDec:]
+       if err != nil {
+               d.in, d.err = nil, err // Decode error; discard input remainder
+       }
+
+       if len(d.in) < 2 {
+               return numDec, d.err // Only expose errors when buffer fully consumed
+       }
+       return numDec, nil
+}
+
 // Dumper returns a WriteCloser that writes a hex dump of all written data to
 // w. The format of the dump matches the output of `hexdump -C` on the command
 // line.
index e6dc765c9578fd40bd67d73a954f94b7a69651f9..d874b39e9597a9deb028d298e4290e45b5e21bc8 100644 (file)
@@ -7,6 +7,9 @@ package hex
 import (
        "bytes"
        "fmt"
+       "io"
+       "io/ioutil"
+       "strings"
        "testing"
 )
 
@@ -111,6 +114,63 @@ func TestInvalidStringErr(t *testing.T) {
        }
 }
 
+func TestEncoderDecoder(t *testing.T) {
+       for _, multiplier := range []int{1, 128, 192} {
+               for _, test := range encDecTests {
+                       input := bytes.Repeat(test.dec, multiplier)
+                       output := strings.Repeat(test.enc, multiplier)
+
+                       var buf bytes.Buffer
+                       enc := NewEncoder(&buf)
+                       r := struct{ io.Reader }{bytes.NewReader(input)} // io.Reader only; not io.WriterTo
+                       if n, err := io.CopyBuffer(enc, r, make([]byte, 7)); n != int64(len(input)) || err != nil {
+                               t.Errorf("encoder.Write(%q*%d) = (%d, %v), want (%d, nil)", test.dec, multiplier, n, err, len(input))
+                               continue
+                       }
+
+                       if encDst := buf.String(); encDst != output {
+                               t.Errorf("buf(%q*%d) = %v, want %v", test.dec, multiplier, encDst, output)
+                               continue
+                       }
+
+                       dec := NewDecoder(&buf)
+                       var decBuf bytes.Buffer
+                       w := struct{ io.Writer }{&decBuf} // io.Writer only; not io.ReaderFrom
+                       if _, err := io.CopyBuffer(w, dec, make([]byte, 7)); err != nil || decBuf.Len() != len(input) {
+                               t.Errorf("decoder.Read(%q*%d) = (%d, %v), want (%d, nil)", test.enc, multiplier, decBuf.Len(), err, len(input))
+                       }
+
+                       if !bytes.Equal(decBuf.Bytes(), input) {
+                               t.Errorf("decBuf(%q*%d) = %v, want %v", test.dec, multiplier, decBuf.Bytes(), input)
+                               continue
+                       }
+               }
+       }
+}
+
+func TestDecodeErr(t *testing.T) {
+       tests := []struct {
+               in      string
+               wantOut string
+               wantErr error
+       }{
+               {"", "", nil},
+               {"0", "", io.ErrUnexpectedEOF},
+               {"0g", "", InvalidByteError('g')},
+               {"00gg", "\x00", InvalidByteError('g')},
+               {"0\x01", "", InvalidByteError('\x01')},
+               {"ffeed", "\xff\xee", io.ErrUnexpectedEOF},
+       }
+
+       for _, tt := range tests {
+               dec := NewDecoder(strings.NewReader(tt.in))
+               got, err := ioutil.ReadAll(dec)
+               if string(got) != tt.wantOut || err != tt.wantErr {
+                       t.Errorf("NewDecoder(%q) = (%q, %v), want (%q, %v)", tt.in, got, err, tt.wantOut, tt.wantErr)
+               }
+       }
+}
+
 func TestDumper(t *testing.T) {
        var in [40]byte
        for i := range in {