]> Cypherpunks repositories - gostls13.git/commitdiff
encoding/asn1: reduce allocations in Marshal
authorHiroshi Ioka <hirochachacha@gmail.com>
Tue, 2 Aug 2016 05:41:53 +0000 (14:41 +0900)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 13 Sep 2016 21:05:27 +0000 (21:05 +0000)
Current code uses trees of bytes.Buffer as data representation.
Each bytes.Buffer takes 4k bytes at least, so it's waste of memory.
The change introduces trees of lazy-encoder as
alternative one which reduce allocations.

name       old time/op    new time/op    delta
Marshal-4    64.7µs ± 2%    42.0µs ± 1%  -35.07%   (p=0.000 n=9+10)

name       old alloc/op   new alloc/op   delta
Marshal-4    35.1kB ± 0%     7.6kB ± 0%  -78.27%  (p=0.000 n=10+10)

name       old allocs/op  new allocs/op  delta
Marshal-4       503 ± 0%       293 ± 0%  -41.75%  (p=0.000 n=10+10)

Change-Id: I32b96c20b8df00414b282d69743d71a598a11336
Reviewed-on: https://go-review.googlesource.com/27030
Reviewed-by: Adam Langley <agl@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Adam Langley <agl@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/encoding/asn1/asn1_test.go
src/encoding/asn1/marshal.go
src/encoding/asn1/marshal_test.go

index f8623fa9a216df6fe3c00a7b36174e12a8ab782e..81f4dba8c290cd2e641166a83d13c0717efea15d 100644 (file)
@@ -132,9 +132,9 @@ func TestParseBigInt(t *testing.T) {
                        if ret.String() != test.base10 {
                                t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10)
                        }
-                       fw := newForkableWriter()
-                       marshalBigInt(fw, ret)
-                       result := fw.Bytes()
+                       e := makeBigInt(ret)
+                       result := make([]byte, e.Len())
+                       e.Encode(result)
                        if !bytes.Equal(result, test.in) {
                                t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in)
                        }
index 30797ef0996e547415977d29110a259d7b16c263..f0664d3d46b1eb9a9af06bb1e0636d6e225d9f67 100644 (file)
 package asn1
 
 import (
-       "bytes"
        "errors"
        "fmt"
-       "io"
        "math/big"
        "reflect"
        "time"
        "unicode/utf8"
 )
 
-// A forkableWriter is an in-memory buffer that can be
-// 'forked' to create new forkableWriters that bracket the
-// original. After
-//    pre, post := w.fork()
-// the overall sequence of bytes represented is logically w+pre+post.
-type forkableWriter struct {
-       *bytes.Buffer
-       pre, post *forkableWriter
+var (
+       byte00Encoder encoder = byteEncoder(0x00)
+       byteFFEncoder encoder = byteEncoder(0xff)
+)
+
+// encoder represents a ASN.1 element that is waiting to be marshaled.
+type encoder interface {
+       // Len returns the number of bytes needed to marshal this element.
+       Len() int
+       // Encode encodes this element by writing Len() bytes to dst.
+       Encode(dst []byte)
+}
+
+type byteEncoder byte
+
+func (c byteEncoder) Len() int {
+       return 1
 }
 
-func newForkableWriter() *forkableWriter {
-       return &forkableWriter{new(bytes.Buffer), nil, nil}
+func (c byteEncoder) Encode(dst []byte) {
+       dst[0] = byte(c)
 }
 
-func (f *forkableWriter) fork() (pre, post *forkableWriter) {
-       if f.pre != nil || f.post != nil {
-               panic("have already forked")
+type bytesEncoder []byte
+
+func (b bytesEncoder) Len() int {
+       return len(b)
+}
+
+func (b bytesEncoder) Encode(dst []byte) {
+       if copy(dst, b) != len(b) {
+               panic("internal error")
        }
-       f.pre = newForkableWriter()
-       f.post = newForkableWriter()
-       return f.pre, f.post
 }
 
-func (f *forkableWriter) Len() (l int) {
-       l += f.Buffer.Len()
-       if f.pre != nil {
-               l += f.pre.Len()
+type stringEncoder string
+
+func (s stringEncoder) Len() int {
+       return len(s)
+}
+
+func (s stringEncoder) Encode(dst []byte) {
+       if copy(dst, s) != len(s) {
+               panic("internal error")
        }
-       if f.post != nil {
-               l += f.post.Len()
+}
+
+type multiEncoder []encoder
+
+func (m multiEncoder) Len() int {
+       var size int
+       for _, e := range m {
+               size += e.Len()
        }
-       return
+       return size
 }
 
-func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
-       n, err = out.Write(f.Bytes())
-       if err != nil {
-               return
+func (m multiEncoder) Encode(dst []byte) {
+       var off int
+       for _, e := range m {
+               e.Encode(dst[off:])
+               off += e.Len()
        }
+}
 
-       var nn int
+type taggedEncoder struct {
+       // scratch contains temporary space for encoding the tag and length of
+       // an element in order to avoid extra allocations.
+       scratch [8]byte
+       tag     encoder
+       body    encoder
+}
 
-       if f.pre != nil {
-               nn, err = f.pre.writeTo(out)
-               n += nn
-               if err != nil {
-                       return
-               }
+func (t *taggedEncoder) Len() int {
+       return t.tag.Len() + t.body.Len()
+}
+
+func (t *taggedEncoder) Encode(dst []byte) {
+       t.tag.Encode(dst)
+       t.body.Encode(dst[t.tag.Len():])
+}
+
+type int64Encoder int64
+
+func (i int64Encoder) Len() int {
+       n := 1
+
+       for i > 127 {
+               n++
+               i >>= 8
        }
 
-       if f.post != nil {
-               nn, err = f.post.writeTo(out)
-               n += nn
+       for i < -128 {
+               n++
+               i >>= 8
        }
-       return
+
+       return n
 }
 
-func marshalBase128Int(out *forkableWriter, n int64) (err error) {
+func (i int64Encoder) Encode(dst []byte) {
+       n := i.Len()
+
+       for j := 0; j < n; j++ {
+               dst[j] = byte(i >> uint((n-1-j)*8))
+       }
+}
+
+func base128IntLength(n int64) int {
        if n == 0 {
-               err = out.WriteByte(0)
-               return
+               return 1
        }
 
        l := 0
@@ -83,54 +131,29 @@ func marshalBase128Int(out *forkableWriter, n int64) (err error) {
                l++
        }
 
+       return l
+}
+
+func appendBase128Int(dst []byte, n int64) []byte {
+       l := base128IntLength(n)
+
        for i := l - 1; i >= 0; i-- {
                o := byte(n >> uint(i*7))
                o &= 0x7f
                if i != 0 {
                        o |= 0x80
                }
-               err = out.WriteByte(o)
-               if err != nil {
-                       return
-               }
-       }
-
-       return nil
-}
 
-func marshalInt64(out *forkableWriter, i int64) (err error) {
-       n := int64Length(i)
-
-       for ; n > 0; n-- {
-               err = out.WriteByte(byte(i >> uint((n-1)*8)))
-               if err != nil {
-                       return
-               }
+               dst = append(dst, o)
        }
 
-       return nil
+       return dst
 }
 
-func int64Length(i int64) (numBytes int) {
-       numBytes = 1
-
-       for i > 127 {
-               numBytes++
-               i >>= 8
-       }
-
-       for i < -128 {
-               numBytes++
-               i >>= 8
-       }
-
-       return
-}
-
-func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
+func makeBigInt(n *big.Int) encoder {
        if n.Sign() < 0 {
                // A negative number has to be converted to two's-complement
-               // form. So we'll subtract 1 and invert. If the
+               // form. So we'll invert and subtract 1. If the
                // most-significant-bit isn't set then we'll need to pad the
                // beginning with 0xff in order to keep the number negative.
                nMinus1 := new(big.Int).Neg(n)
@@ -140,41 +163,31 @@ func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
                        bytes[i] ^= 0xff
                }
                if len(bytes) == 0 || bytes[0]&0x80 == 0 {
-                       err = out.WriteByte(0xff)
-                       if err != nil {
-                               return
-                       }
+                       return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)})
                }
-               _, err = out.Write(bytes)
+               return bytesEncoder(bytes)
        } else if n.Sign() == 0 {
                // Zero is written as a single 0 zero rather than no bytes.
-               err = out.WriteByte(0x00)
+               return byte00Encoder
        } else {
                bytes := n.Bytes()
                if len(bytes) > 0 && bytes[0]&0x80 != 0 {
                        // We'll have to pad this with 0x00 in order to stop it
                        // looking like a negative number.
-                       err = out.WriteByte(0)
-                       if err != nil {
-                               return
-                       }
+                       return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)})
                }
-               _, err = out.Write(bytes)
+               return bytesEncoder(bytes)
        }
-       return
 }
 
-func marshalLength(out *forkableWriter, i int) (err error) {
+func appendLength(dst []byte, i int) []byte {
        n := lengthLength(i)
 
        for ; n > 0; n-- {
-               err = out.WriteByte(byte(i >> uint((n-1)*8)))
-               if err != nil {
-                       return
-               }
+               dst = append(dst, byte(i>>uint((n-1)*8)))
        }
 
-       return nil
+       return dst
 }
 
 func lengthLength(i int) (numBytes int) {
@@ -186,123 +199,104 @@ func lengthLength(i int) (numBytes int) {
        return
 }
 
-func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
+func appendTagAndLength(dst []byte, t tagAndLength) []byte {
        b := uint8(t.class) << 6
        if t.isCompound {
                b |= 0x20
        }
        if t.tag >= 31 {
                b |= 0x1f
-               err = out.WriteByte(b)
-               if err != nil {
-                       return
-               }
-               err = marshalBase128Int(out, int64(t.tag))
-               if err != nil {
-                       return
-               }
+               dst = append(dst, b)
+               dst = appendBase128Int(dst, int64(t.tag))
        } else {
                b |= uint8(t.tag)
-               err = out.WriteByte(b)
-               if err != nil {
-                       return
-               }
+               dst = append(dst, b)
        }
 
        if t.length >= 128 {
                l := lengthLength(t.length)
-               err = out.WriteByte(0x80 | byte(l))
-               if err != nil {
-                       return
-               }
-               err = marshalLength(out, t.length)
-               if err != nil {
-                       return
-               }
+               dst = append(dst, 0x80|byte(l))
+               dst = appendLength(dst, t.length)
        } else {
-               err = out.WriteByte(byte(t.length))
-               if err != nil {
-                       return
-               }
+               dst = append(dst, byte(t.length))
        }
 
-       return nil
+       return dst
 }
 
-func marshalBitString(out *forkableWriter, b BitString) (err error) {
-       paddingBits := byte((8 - b.BitLength%8) % 8)
-       err = out.WriteByte(paddingBits)
-       if err != nil {
-               return
-       }
-       _, err = out.Write(b.Bytes)
-       return
+type bitStringEncoder BitString
+
+func (b bitStringEncoder) Len() int {
+       return len(b.Bytes) + 1
 }
 
-func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
-       if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
-               return StructuralError{"invalid object identifier"}
+func (b bitStringEncoder) Encode(dst []byte) {
+       dst[0] = byte((8 - b.BitLength%8) % 8)
+       if copy(dst[1:], b.Bytes) != len(b.Bytes) {
+               panic("internal error")
        }
+}
 
-       err = marshalBase128Int(out, int64(oid[0]*40+oid[1]))
-       if err != nil {
-               return
+type oidEncoder []int
+
+func (oid oidEncoder) Len() int {
+       l := base128IntLength(int64(oid[0]*40 + oid[1]))
+       for i := 2; i < len(oid); i++ {
+               l += base128IntLength(int64(oid[i]))
        }
+       return l
+}
+
+func (oid oidEncoder) Encode(dst []byte) {
+       dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
        for i := 2; i < len(oid); i++ {
-               err = marshalBase128Int(out, int64(oid[i]))
-               if err != nil {
-                       return
-               }
+               dst = appendBase128Int(dst, int64(oid[i]))
+       }
+}
+
+func makeObjectIdentifier(oid []int) (e encoder, err error) {
+       if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
+               return nil, StructuralError{"invalid object identifier"}
        }
 
-       return
+       return oidEncoder(oid), nil
 }
 
-func marshalPrintableString(out *forkableWriter, s string) (err error) {
-       b := []byte(s)
-       for _, c := range b {
-               if !isPrintable(c) {
-                       return StructuralError{"PrintableString contains invalid character"}
+func makePrintableString(s string) (e encoder, err error) {
+       for i := 0; i < len(s); i++ {
+               if !isPrintable(s[i]) {
+                       return nil, StructuralError{"PrintableString contains invalid character"}
                }
        }
 
-       _, err = out.Write(b)
-       return
+       return stringEncoder(s), nil
 }
 
-func marshalIA5String(out *forkableWriter, s string) (err error) {
-       b := []byte(s)
-       for _, c := range b {
-               if c > 127 {
-                       return StructuralError{"IA5String contains invalid character"}
+func makeIA5String(s string) (e encoder, err error) {
+       for i := 0; i < len(s); i++ {
+               if s[i] > 127 {
+                       return nil, StructuralError{"IA5String contains invalid character"}
                }
        }
 
-       _, err = out.Write(b)
-       return
+       return stringEncoder(s), nil
 }
 
-func marshalUTF8String(out *forkableWriter, s string) (err error) {
-       _, err = out.Write([]byte(s))
-       return
+func makeUTF8String(s string) encoder {
+       return stringEncoder(s)
 }
 
-func marshalTwoDigits(out *forkableWriter, v int) (err error) {
-       err = out.WriteByte(byte('0' + (v/10)%10))
-       if err != nil {
-               return
-       }
-       return out.WriteByte(byte('0' + v%10))
+func appendTwoDigits(dst []byte, v int) []byte {
+       return append(dst, byte('0'+(v/10)%10), byte('0'+v%10))
 }
 
-func marshalFourDigits(out *forkableWriter, v int) (err error) {
+func appendFourDigits(dst []byte, v int) []byte {
        var bytes [4]byte
        for i := range bytes {
                bytes[3-i] = '0' + byte(v%10)
                v /= 10
        }
-       _, err = out.Write(bytes[:])
-       return
+       return append(dst, bytes[:]...)
 }
 
 func outsideUTCRange(t time.Time) bool {
@@ -310,80 +304,75 @@ func outsideUTCRange(t time.Time) bool {
        return year < 1950 || year >= 2050
 }
 
-func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
+func makeUTCTime(t time.Time) (e encoder, err error) {
+       dst := make([]byte, 0, 18)
+
+       dst, err = appendUTCTime(dst, t)
+       if err != nil {
+               return nil, err
+       }
+
+       return bytesEncoder(dst), nil
+}
+
+func makeGeneralizedTime(t time.Time) (e encoder, err error) {
+       dst := make([]byte, 0, 20)
+
+       dst, err = appendGeneralizedTime(dst, t)
+       if err != nil {
+               return nil, err
+       }
+
+       return bytesEncoder(dst), nil
+}
+
+func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
        year := t.Year()
 
        switch {
        case 1950 <= year && year < 2000:
-               err = marshalTwoDigits(out, year-1900)
+               dst = appendTwoDigits(dst, year-1900)
        case 2000 <= year && year < 2050:
-               err = marshalTwoDigits(out, year-2000)
+               dst = appendTwoDigits(dst, year-2000)
        default:
-               return StructuralError{"cannot represent time as UTCTime"}
-       }
-       if err != nil {
-               return
+               return nil, StructuralError{"cannot represent time as UTCTime"}
        }
 
-       return marshalTimeCommon(out, t)
+       return appendTimeCommon(dst, t), nil
 }
 
-func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) {
+func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
        year := t.Year()
        if year < 0 || year > 9999 {
-               return StructuralError{"cannot represent time as GeneralizedTime"}
-       }
-       if err = marshalFourDigits(out, year); err != nil {
-               return
+               return nil, StructuralError{"cannot represent time as GeneralizedTime"}
        }
 
-       return marshalTimeCommon(out, t)
+       dst = appendFourDigits(dst, year)
+
+       return appendTimeCommon(dst, t), nil
 }
 
-func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
+func appendTimeCommon(dst []byte, t time.Time) []byte {
        _, month, day := t.Date()
 
-       err = marshalTwoDigits(out, int(month))
-       if err != nil {
-               return
-       }
-
-       err = marshalTwoDigits(out, day)
-       if err != nil {
-               return
-       }
+       dst = appendTwoDigits(dst, int(month))
+       dst = appendTwoDigits(dst, day)
 
        hour, min, sec := t.Clock()
 
-       err = marshalTwoDigits(out, hour)
-       if err != nil {
-               return
-       }
-
-       err = marshalTwoDigits(out, min)
-       if err != nil {
-               return
-       }
-
-       err = marshalTwoDigits(out, sec)
-       if err != nil {
-               return
-       }
+       dst = appendTwoDigits(dst, hour)
+       dst = appendTwoDigits(dst, min)
+       dst = appendTwoDigits(dst, sec)
 
        _, offset := t.Zone()
 
        switch {
        case offset/60 == 0:
-               err = out.WriteByte('Z')
-               return
+               return append(dst, 'Z')
        case offset > 0:
-               err = out.WriteByte('+')
+               dst = append(dst, '+')
        case offset < 0:
-               err = out.WriteByte('-')
-       }
-
-       if err != nil {
-               return
+               dst = append(dst, '-')
        }
 
        offsetMinutes := offset / 60
@@ -391,13 +380,10 @@ func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
                offsetMinutes = -offsetMinutes
        }
 
-       err = marshalTwoDigits(out, offsetMinutes/60)
-       if err != nil {
-               return
-       }
+       dst = appendTwoDigits(dst, offsetMinutes/60)
+       dst = appendTwoDigits(dst, offsetMinutes%60)
 
-       err = marshalTwoDigits(out, offsetMinutes%60)
-       return
+       return dst
 }
 
 func stripTagAndLength(in []byte) []byte {
@@ -408,114 +394,124 @@ func stripTagAndLength(in []byte) []byte {
        return in[offset:]
 }
 
-func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
+func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
        switch value.Type() {
        case flagType:
-               return nil
+               return bytesEncoder(nil), nil
        case timeType:
                t := value.Interface().(time.Time)
                if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
-                       return marshalGeneralizedTime(out, t)
-               } else {
-                       return marshalUTCTime(out, t)
+                       return makeGeneralizedTime(t)
                }
+               return makeUTCTime(t)
        case bitStringType:
-               return marshalBitString(out, value.Interface().(BitString))
+               return bitStringEncoder(value.Interface().(BitString)), nil
        case objectIdentifierType:
-               return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
+               return makeObjectIdentifier(value.Interface().(ObjectIdentifier))
        case bigIntType:
-               return marshalBigInt(out, value.Interface().(*big.Int))
+               return makeBigInt(value.Interface().(*big.Int)), nil
        }
 
        switch v := value; v.Kind() {
        case reflect.Bool:
                if v.Bool() {
-                       return out.WriteByte(255)
-               } else {
-                       return out.WriteByte(0)
+                       return byteFFEncoder, nil
                }
+               return byte00Encoder, nil
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-               return marshalInt64(out, v.Int())
+               return int64Encoder(v.Int()), nil
        case reflect.Struct:
                t := v.Type()
 
                startingField := 0
 
+               n := t.NumField()
+               if n == 0 {
+                       return bytesEncoder(nil), nil
+               }
+
                // If the first element of the structure is a non-empty
                // RawContents, then we don't bother serializing the rest.
-               if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
+               if t.Field(0).Type == rawContentsType {
                        s := v.Field(0)
                        if s.Len() > 0 {
-                               bytes := make([]byte, s.Len())
-                               for i := 0; i < s.Len(); i++ {
-                                       bytes[i] = uint8(s.Index(i).Uint())
-                               }
+                               bytes := s.Bytes()
                                /* The RawContents will contain the tag and
                                 * length fields but we'll also be writing
                                 * those ourselves, so we strip them out of
                                 * bytes */
-                               _, err = out.Write(stripTagAndLength(bytes))
-                               return
-                       } else {
-                               startingField = 1
+                               return bytesEncoder(stripTagAndLength(bytes)), nil
                        }
+
+                       startingField = 1
                }
 
-               for i := startingField; i < t.NumField(); i++ {
-                       var pre *forkableWriter
-                       pre, out = out.fork()
-                       err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
-                       if err != nil {
-                               return
+               switch n1 := n - startingField; n1 {
+               case 0:
+                       return bytesEncoder(nil), nil
+               case 1:
+                       return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
+               default:
+                       m := make([]encoder, n1)
+                       for i := 0; i < n1; i++ {
+                               m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
+                               if err != nil {
+                                       return nil, err
+                               }
                        }
+
+                       return multiEncoder(m), nil
                }
-               return
        case reflect.Slice:
                sliceType := v.Type()
                if sliceType.Elem().Kind() == reflect.Uint8 {
-                       bytes := make([]byte, v.Len())
-                       for i := 0; i < v.Len(); i++ {
-                               bytes[i] = uint8(v.Index(i).Uint())
-                       }
-                       _, err = out.Write(bytes)
-                       return
+                       return bytesEncoder(v.Bytes()), nil
                }
 
                var fp fieldParameters
-               for i := 0; i < v.Len(); i++ {
-                       var pre *forkableWriter
-                       pre, out = out.fork()
-                       err = marshalField(pre, v.Index(i), fp)
-                       if err != nil {
-                               return
+
+               switch l := v.Len(); l {
+               case 0:
+                       return bytesEncoder(nil), nil
+               case 1:
+                       return makeField(v.Index(0), fp)
+               default:
+                       m := make([]encoder, l)
+
+                       for i := 0; i < l; i++ {
+                               m[i], err = makeField(v.Index(i), fp)
+                               if err != nil {
+                                       return nil, err
+                               }
                        }
+
+                       return multiEncoder(m), nil
                }
-               return
        case reflect.String:
                switch params.stringType {
                case TagIA5String:
-                       return marshalIA5String(out, v.String())
+                       return makeIA5String(v.String())
                case TagPrintableString:
-                       return marshalPrintableString(out, v.String())
+                       return makePrintableString(v.String())
                default:
-                       return marshalUTF8String(out, v.String())
+                       return makeUTF8String(v.String()), nil
                }
        }
 
-       return StructuralError{"unknown Go type"}
+       return nil, StructuralError{"unknown Go type"}
 }
 
-func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
+func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
        if !v.IsValid() {
-               return fmt.Errorf("asn1: cannot marshal nil value")
+               return nil, fmt.Errorf("asn1: cannot marshal nil value")
        }
        // If the field is an interface{} then recurse into it.
        if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
-               return marshalField(out, v.Elem(), params)
+               return makeField(v.Elem(), params)
        }
 
        if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
-               return
+               return bytesEncoder(nil), nil
        }
 
        if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
@@ -523,7 +519,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
                defaultValue.SetInt(*params.defaultValue)
 
                if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
-                       return
+                       return bytesEncoder(nil), nil
                }
        }
 
@@ -532,37 +528,36 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
        // behaviour, but it's what Go has traditionally done.
        if params.optional && params.defaultValue == nil {
                if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
-                       return
+                       return bytesEncoder(nil), nil
                }
        }
 
        if v.Type() == rawValueType {
                rv := v.Interface().(RawValue)
                if len(rv.FullBytes) != 0 {
-                       _, err = out.Write(rv.FullBytes)
-               } else {
-                       err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
-                       if err != nil {
-                               return
-                       }
-                       _, err = out.Write(rv.Bytes)
+                       return bytesEncoder(rv.FullBytes), nil
                }
-               return
+
+               t := new(taggedEncoder)
+
+               t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
+               t.body = bytesEncoder(rv.Bytes)
+
+               return t, nil
        }
 
        tag, isCompound, ok := getUniversalType(v.Type())
        if !ok {
-               err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
-               return
+               return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
        }
        class := ClassUniversal
 
        if params.timeType != 0 && tag != TagUTCTime {
-               return StructuralError{"explicit time type given to non-time member"}
+               return nil, StructuralError{"explicit time type given to non-time member"}
        }
 
        if params.stringType != 0 && tag != TagPrintableString {
-               return StructuralError{"explicit string type given to non-string member"}
+               return nil, StructuralError{"explicit string type given to non-string member"}
        }
 
        switch tag {
@@ -574,7 +569,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
                        for _, r := range v.String() {
                                if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
                                        if !utf8.ValidString(v.String()) {
-                                               return errors.New("asn1: string not valid UTF-8")
+                                               return nil, errors.New("asn1: string not valid UTF-8")
                                        }
                                        tag = TagUTF8String
                                        break
@@ -591,46 +586,46 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
 
        if params.set {
                if tag != TagSequence {
-                       return StructuralError{"non sequence tagged as set"}
+                       return nil, StructuralError{"non sequence tagged as set"}
                }
                tag = TagSet
        }
 
-       tags, body := out.fork()
+       t := new(taggedEncoder)
 
-       err = marshalBody(body, v, params)
+       t.body, err = makeBody(v, params)
        if err != nil {
-               return
+               return nil, err
        }
 
-       bodyLen := body.Len()
+       bodyLen := t.body.Len()
 
-       var explicitTag *forkableWriter
        if params.explicit {
-               explicitTag, tags = tags.fork()
-       }
+               t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
 
-       if !params.explicit && params.tag != nil {
-               // implicit tag.
-               tag = *params.tag
-               class = ClassContextSpecific
-       }
+               tt := new(taggedEncoder)
 
-       err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
-       if err != nil {
-               return
-       }
+               tt.body = t
 
-       if params.explicit {
-               err = marshalTagAndLength(explicitTag, tagAndLength{
+               tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
                        class:      ClassContextSpecific,
                        tag:        *params.tag,
-                       length:     bodyLen + tags.Len(),
+                       length:     bodyLen + t.tag.Len(),
                        isCompound: true,
-               })
+               }))
+
+               return tt, nil
+       }
+
+       if params.tag != nil {
+               // implicit tag.
+               tag = *params.tag
+               class = ClassContextSpecific
        }
 
-       return err
+       t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
+
+       return t, nil
 }
 
 // Marshal returns the ASN.1 encoding of val.
@@ -643,13 +638,11 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
 //     printable:      causes strings to be marshaled as ASN.1, PrintableString strings.
 //     utf8:           causes strings to be marshaled as ASN.1, UTF8 strings
 func Marshal(val interface{}) ([]byte, error) {
-       var out bytes.Buffer
-       v := reflect.ValueOf(val)
-       f := newForkableWriter()
-       err := marshalField(f, v, fieldParameters{})
+       e, err := makeField(reflect.ValueOf(val), fieldParameters{})
        if err != nil {
                return nil, err
        }
-       _, err = f.writeTo(&out)
-       return out.Bytes(), err
+       b := make([]byte, e.Len())
+       e.Encode(b)
+       return b, nil
 }
index cdca8aa33638d8eba823a0a730ced38a239c85f3..6af770fcc3424e67c72c1e6f76026dd2a7083d7f 100644 (file)
@@ -173,3 +173,13 @@ func TestInvalidUTF8(t *testing.T) {
                t.Errorf("invalid UTF8 string was accepted")
        }
 }
+
+func BenchmarkMarshal(b *testing.B) {
+       b.ReportAllocs()
+
+       for i := 0; i < b.N; i++ {
+               for _, test := range marshalTests {
+                       Marshal(test.in)
+               }
+       }
+}