From ae4aac00bba5d1d616408a1c07bd4ef5691e3a00 Mon Sep 17 00:00:00 2001 From: Hiroshi Ioka Date: Tue, 2 Aug 2016 14:41:53 +0900 Subject: [PATCH] encoding/asn1: reduce allocations in Marshal MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Current code uses trees of bytes.Buffer as data representation. Each bytes.Buffer takes 4k bytes at least, so it's waste of memory. The change introduces trees of lazy-encoder as alternative one which reduce allocations. name old time/op new time/op delta Marshal-4 64.7µs ± 2% 42.0µs ± 1% -35.07% (p=0.000 n=9+10) name old alloc/op new alloc/op delta Marshal-4 35.1kB ± 0% 7.6kB ± 0% -78.27% (p=0.000 n=10+10) name old allocs/op new allocs/op delta Marshal-4 503 ± 0% 293 ± 0% -41.75% (p=0.000 n=10+10) Change-Id: I32b96c20b8df00414b282d69743d71a598a11336 Reviewed-on: https://go-review.googlesource.com/27030 Reviewed-by: Adam Langley Reviewed-by: Brad Fitzpatrick Run-TryBot: Adam Langley TryBot-Result: Gobot Gobot --- src/encoding/asn1/asn1_test.go | 6 +- src/encoding/asn1/marshal.go | 615 +++++++++++++++--------------- src/encoding/asn1/marshal_test.go | 10 + 3 files changed, 317 insertions(+), 314 deletions(-) diff --git a/src/encoding/asn1/asn1_test.go b/src/encoding/asn1/asn1_test.go index f8623fa9a2..81f4dba8c2 100644 --- a/src/encoding/asn1/asn1_test.go +++ b/src/encoding/asn1/asn1_test.go @@ -132,9 +132,9 @@ func TestParseBigInt(t *testing.T) { if ret.String() != test.base10 { t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10) } - fw := newForkableWriter() - marshalBigInt(fw, ret) - result := fw.Bytes() + e := makeBigInt(ret) + result := make([]byte, e.Len()) + e.Encode(result) if !bytes.Equal(result, test.in) { t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in) } diff --git a/src/encoding/asn1/marshal.go b/src/encoding/asn1/marshal.go index 30797ef099..f0664d3d46 100644 --- a/src/encoding/asn1/marshal.go +++ b/src/encoding/asn1/marshal.go @@ -5,77 +5,125 @@ package asn1 import ( - "bytes" "errors" "fmt" - "io" "math/big" "reflect" "time" "unicode/utf8" ) -// A forkableWriter is an in-memory buffer that can be -// 'forked' to create new forkableWriters that bracket the -// original. After -// pre, post := w.fork() -// the overall sequence of bytes represented is logically w+pre+post. -type forkableWriter struct { - *bytes.Buffer - pre, post *forkableWriter +var ( + byte00Encoder encoder = byteEncoder(0x00) + byteFFEncoder encoder = byteEncoder(0xff) +) + +// encoder represents a ASN.1 element that is waiting to be marshaled. +type encoder interface { + // Len returns the number of bytes needed to marshal this element. + Len() int + // Encode encodes this element by writing Len() bytes to dst. + Encode(dst []byte) +} + +type byteEncoder byte + +func (c byteEncoder) Len() int { + return 1 } -func newForkableWriter() *forkableWriter { - return &forkableWriter{new(bytes.Buffer), nil, nil} +func (c byteEncoder) Encode(dst []byte) { + dst[0] = byte(c) } -func (f *forkableWriter) fork() (pre, post *forkableWriter) { - if f.pre != nil || f.post != nil { - panic("have already forked") +type bytesEncoder []byte + +func (b bytesEncoder) Len() int { + return len(b) +} + +func (b bytesEncoder) Encode(dst []byte) { + if copy(dst, b) != len(b) { + panic("internal error") } - f.pre = newForkableWriter() - f.post = newForkableWriter() - return f.pre, f.post } -func (f *forkableWriter) Len() (l int) { - l += f.Buffer.Len() - if f.pre != nil { - l += f.pre.Len() +type stringEncoder string + +func (s stringEncoder) Len() int { + return len(s) +} + +func (s stringEncoder) Encode(dst []byte) { + if copy(dst, s) != len(s) { + panic("internal error") } - if f.post != nil { - l += f.post.Len() +} + +type multiEncoder []encoder + +func (m multiEncoder) Len() int { + var size int + for _, e := range m { + size += e.Len() } - return + return size } -func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) { - n, err = out.Write(f.Bytes()) - if err != nil { - return +func (m multiEncoder) Encode(dst []byte) { + var off int + for _, e := range m { + e.Encode(dst[off:]) + off += e.Len() } +} - var nn int +type taggedEncoder struct { + // scratch contains temporary space for encoding the tag and length of + // an element in order to avoid extra allocations. + scratch [8]byte + tag encoder + body encoder +} - if f.pre != nil { - nn, err = f.pre.writeTo(out) - n += nn - if err != nil { - return - } +func (t *taggedEncoder) Len() int { + return t.tag.Len() + t.body.Len() +} + +func (t *taggedEncoder) Encode(dst []byte) { + t.tag.Encode(dst) + t.body.Encode(dst[t.tag.Len():]) +} + +type int64Encoder int64 + +func (i int64Encoder) Len() int { + n := 1 + + for i > 127 { + n++ + i >>= 8 } - if f.post != nil { - nn, err = f.post.writeTo(out) - n += nn + for i < -128 { + n++ + i >>= 8 } - return + + return n } -func marshalBase128Int(out *forkableWriter, n int64) (err error) { +func (i int64Encoder) Encode(dst []byte) { + n := i.Len() + + for j := 0; j < n; j++ { + dst[j] = byte(i >> uint((n-1-j)*8)) + } +} + +func base128IntLength(n int64) int { if n == 0 { - err = out.WriteByte(0) - return + return 1 } l := 0 @@ -83,54 +131,29 @@ func marshalBase128Int(out *forkableWriter, n int64) (err error) { l++ } + return l +} + +func appendBase128Int(dst []byte, n int64) []byte { + l := base128IntLength(n) + for i := l - 1; i >= 0; i-- { o := byte(n >> uint(i*7)) o &= 0x7f if i != 0 { o |= 0x80 } - err = out.WriteByte(o) - if err != nil { - return - } - } - - return nil -} -func marshalInt64(out *forkableWriter, i int64) (err error) { - n := int64Length(i) - - for ; n > 0; n-- { - err = out.WriteByte(byte(i >> uint((n-1)*8))) - if err != nil { - return - } + dst = append(dst, o) } - return nil + return dst } -func int64Length(i int64) (numBytes int) { - numBytes = 1 - - for i > 127 { - numBytes++ - i >>= 8 - } - - for i < -128 { - numBytes++ - i >>= 8 - } - - return -} - -func marshalBigInt(out *forkableWriter, n *big.Int) (err error) { +func makeBigInt(n *big.Int) encoder { if n.Sign() < 0 { // A negative number has to be converted to two's-complement - // form. So we'll subtract 1 and invert. If the + // form. So we'll invert and subtract 1. If the // most-significant-bit isn't set then we'll need to pad the // beginning with 0xff in order to keep the number negative. nMinus1 := new(big.Int).Neg(n) @@ -140,41 +163,31 @@ func marshalBigInt(out *forkableWriter, n *big.Int) (err error) { bytes[i] ^= 0xff } if len(bytes) == 0 || bytes[0]&0x80 == 0 { - err = out.WriteByte(0xff) - if err != nil { - return - } + return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}) } - _, err = out.Write(bytes) + return bytesEncoder(bytes) } else if n.Sign() == 0 { // Zero is written as a single 0 zero rather than no bytes. - err = out.WriteByte(0x00) + return byte00Encoder } else { bytes := n.Bytes() if len(bytes) > 0 && bytes[0]&0x80 != 0 { // We'll have to pad this with 0x00 in order to stop it // looking like a negative number. - err = out.WriteByte(0) - if err != nil { - return - } + return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}) } - _, err = out.Write(bytes) + return bytesEncoder(bytes) } - return } -func marshalLength(out *forkableWriter, i int) (err error) { +func appendLength(dst []byte, i int) []byte { n := lengthLength(i) for ; n > 0; n-- { - err = out.WriteByte(byte(i >> uint((n-1)*8))) - if err != nil { - return - } + dst = append(dst, byte(i>>uint((n-1)*8))) } - return nil + return dst } func lengthLength(i int) (numBytes int) { @@ -186,123 +199,104 @@ func lengthLength(i int) (numBytes int) { return } -func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) { +func appendTagAndLength(dst []byte, t tagAndLength) []byte { b := uint8(t.class) << 6 if t.isCompound { b |= 0x20 } if t.tag >= 31 { b |= 0x1f - err = out.WriteByte(b) - if err != nil { - return - } - err = marshalBase128Int(out, int64(t.tag)) - if err != nil { - return - } + dst = append(dst, b) + dst = appendBase128Int(dst, int64(t.tag)) } else { b |= uint8(t.tag) - err = out.WriteByte(b) - if err != nil { - return - } + dst = append(dst, b) } if t.length >= 128 { l := lengthLength(t.length) - err = out.WriteByte(0x80 | byte(l)) - if err != nil { - return - } - err = marshalLength(out, t.length) - if err != nil { - return - } + dst = append(dst, 0x80|byte(l)) + dst = appendLength(dst, t.length) } else { - err = out.WriteByte(byte(t.length)) - if err != nil { - return - } + dst = append(dst, byte(t.length)) } - return nil + return dst } -func marshalBitString(out *forkableWriter, b BitString) (err error) { - paddingBits := byte((8 - b.BitLength%8) % 8) - err = out.WriteByte(paddingBits) - if err != nil { - return - } - _, err = out.Write(b.Bytes) - return +type bitStringEncoder BitString + +func (b bitStringEncoder) Len() int { + return len(b.Bytes) + 1 } -func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) { - if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { - return StructuralError{"invalid object identifier"} +func (b bitStringEncoder) Encode(dst []byte) { + dst[0] = byte((8 - b.BitLength%8) % 8) + if copy(dst[1:], b.Bytes) != len(b.Bytes) { + panic("internal error") } +} - err = marshalBase128Int(out, int64(oid[0]*40+oid[1])) - if err != nil { - return +type oidEncoder []int + +func (oid oidEncoder) Len() int { + l := base128IntLength(int64(oid[0]*40 + oid[1])) + for i := 2; i < len(oid); i++ { + l += base128IntLength(int64(oid[i])) } + return l +} + +func (oid oidEncoder) Encode(dst []byte) { + dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1])) for i := 2; i < len(oid); i++ { - err = marshalBase128Int(out, int64(oid[i])) - if err != nil { - return - } + dst = appendBase128Int(dst, int64(oid[i])) + } +} + +func makeObjectIdentifier(oid []int) (e encoder, err error) { + if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { + return nil, StructuralError{"invalid object identifier"} } - return + return oidEncoder(oid), nil } -func marshalPrintableString(out *forkableWriter, s string) (err error) { - b := []byte(s) - for _, c := range b { - if !isPrintable(c) { - return StructuralError{"PrintableString contains invalid character"} +func makePrintableString(s string) (e encoder, err error) { + for i := 0; i < len(s); i++ { + if !isPrintable(s[i]) { + return nil, StructuralError{"PrintableString contains invalid character"} } } - _, err = out.Write(b) - return + return stringEncoder(s), nil } -func marshalIA5String(out *forkableWriter, s string) (err error) { - b := []byte(s) - for _, c := range b { - if c > 127 { - return StructuralError{"IA5String contains invalid character"} +func makeIA5String(s string) (e encoder, err error) { + for i := 0; i < len(s); i++ { + if s[i] > 127 { + return nil, StructuralError{"IA5String contains invalid character"} } } - _, err = out.Write(b) - return + return stringEncoder(s), nil } -func marshalUTF8String(out *forkableWriter, s string) (err error) { - _, err = out.Write([]byte(s)) - return +func makeUTF8String(s string) encoder { + return stringEncoder(s) } -func marshalTwoDigits(out *forkableWriter, v int) (err error) { - err = out.WriteByte(byte('0' + (v/10)%10)) - if err != nil { - return - } - return out.WriteByte(byte('0' + v%10)) +func appendTwoDigits(dst []byte, v int) []byte { + return append(dst, byte('0'+(v/10)%10), byte('0'+v%10)) } -func marshalFourDigits(out *forkableWriter, v int) (err error) { +func appendFourDigits(dst []byte, v int) []byte { var bytes [4]byte for i := range bytes { bytes[3-i] = '0' + byte(v%10) v /= 10 } - _, err = out.Write(bytes[:]) - return + return append(dst, bytes[:]...) } func outsideUTCRange(t time.Time) bool { @@ -310,80 +304,75 @@ func outsideUTCRange(t time.Time) bool { return year < 1950 || year >= 2050 } -func marshalUTCTime(out *forkableWriter, t time.Time) (err error) { +func makeUTCTime(t time.Time) (e encoder, err error) { + dst := make([]byte, 0, 18) + + dst, err = appendUTCTime(dst, t) + if err != nil { + return nil, err + } + + return bytesEncoder(dst), nil +} + +func makeGeneralizedTime(t time.Time) (e encoder, err error) { + dst := make([]byte, 0, 20) + + dst, err = appendGeneralizedTime(dst, t) + if err != nil { + return nil, err + } + + return bytesEncoder(dst), nil +} + +func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) { year := t.Year() switch { case 1950 <= year && year < 2000: - err = marshalTwoDigits(out, year-1900) + dst = appendTwoDigits(dst, year-1900) case 2000 <= year && year < 2050: - err = marshalTwoDigits(out, year-2000) + dst = appendTwoDigits(dst, year-2000) default: - return StructuralError{"cannot represent time as UTCTime"} - } - if err != nil { - return + return nil, StructuralError{"cannot represent time as UTCTime"} } - return marshalTimeCommon(out, t) + return appendTimeCommon(dst, t), nil } -func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) { +func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) { year := t.Year() if year < 0 || year > 9999 { - return StructuralError{"cannot represent time as GeneralizedTime"} - } - if err = marshalFourDigits(out, year); err != nil { - return + return nil, StructuralError{"cannot represent time as GeneralizedTime"} } - return marshalTimeCommon(out, t) + dst = appendFourDigits(dst, year) + + return appendTimeCommon(dst, t), nil } -func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) { +func appendTimeCommon(dst []byte, t time.Time) []byte { _, month, day := t.Date() - err = marshalTwoDigits(out, int(month)) - if err != nil { - return - } - - err = marshalTwoDigits(out, day) - if err != nil { - return - } + dst = appendTwoDigits(dst, int(month)) + dst = appendTwoDigits(dst, day) hour, min, sec := t.Clock() - err = marshalTwoDigits(out, hour) - if err != nil { - return - } - - err = marshalTwoDigits(out, min) - if err != nil { - return - } - - err = marshalTwoDigits(out, sec) - if err != nil { - return - } + dst = appendTwoDigits(dst, hour) + dst = appendTwoDigits(dst, min) + dst = appendTwoDigits(dst, sec) _, offset := t.Zone() switch { case offset/60 == 0: - err = out.WriteByte('Z') - return + return append(dst, 'Z') case offset > 0: - err = out.WriteByte('+') + dst = append(dst, '+') case offset < 0: - err = out.WriteByte('-') - } - - if err != nil { - return + dst = append(dst, '-') } offsetMinutes := offset / 60 @@ -391,13 +380,10 @@ func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) { offsetMinutes = -offsetMinutes } - err = marshalTwoDigits(out, offsetMinutes/60) - if err != nil { - return - } + dst = appendTwoDigits(dst, offsetMinutes/60) + dst = appendTwoDigits(dst, offsetMinutes%60) - err = marshalTwoDigits(out, offsetMinutes%60) - return + return dst } func stripTagAndLength(in []byte) []byte { @@ -408,114 +394,124 @@ func stripTagAndLength(in []byte) []byte { return in[offset:] } -func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) { +func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) { switch value.Type() { case flagType: - return nil + return bytesEncoder(nil), nil case timeType: t := value.Interface().(time.Time) if params.timeType == TagGeneralizedTime || outsideUTCRange(t) { - return marshalGeneralizedTime(out, t) - } else { - return marshalUTCTime(out, t) + return makeGeneralizedTime(t) } + return makeUTCTime(t) case bitStringType: - return marshalBitString(out, value.Interface().(BitString)) + return bitStringEncoder(value.Interface().(BitString)), nil case objectIdentifierType: - return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier)) + return makeObjectIdentifier(value.Interface().(ObjectIdentifier)) case bigIntType: - return marshalBigInt(out, value.Interface().(*big.Int)) + return makeBigInt(value.Interface().(*big.Int)), nil } switch v := value; v.Kind() { case reflect.Bool: if v.Bool() { - return out.WriteByte(255) - } else { - return out.WriteByte(0) + return byteFFEncoder, nil } + return byte00Encoder, nil case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return marshalInt64(out, v.Int()) + return int64Encoder(v.Int()), nil case reflect.Struct: t := v.Type() startingField := 0 + n := t.NumField() + if n == 0 { + return bytesEncoder(nil), nil + } + // If the first element of the structure is a non-empty // RawContents, then we don't bother serializing the rest. - if t.NumField() > 0 && t.Field(0).Type == rawContentsType { + if t.Field(0).Type == rawContentsType { s := v.Field(0) if s.Len() > 0 { - bytes := make([]byte, s.Len()) - for i := 0; i < s.Len(); i++ { - bytes[i] = uint8(s.Index(i).Uint()) - } + bytes := s.Bytes() /* The RawContents will contain the tag and * length fields but we'll also be writing * those ourselves, so we strip them out of * bytes */ - _, err = out.Write(stripTagAndLength(bytes)) - return - } else { - startingField = 1 + return bytesEncoder(stripTagAndLength(bytes)), nil } + + startingField = 1 } - for i := startingField; i < t.NumField(); i++ { - var pre *forkableWriter - pre, out = out.fork() - err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1"))) - if err != nil { - return + switch n1 := n - startingField; n1 { + case 0: + return bytesEncoder(nil), nil + case 1: + return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1"))) + default: + m := make([]encoder, n1) + for i := 0; i < n1; i++ { + m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1"))) + if err != nil { + return nil, err + } } + + return multiEncoder(m), nil } - return case reflect.Slice: sliceType := v.Type() if sliceType.Elem().Kind() == reflect.Uint8 { - bytes := make([]byte, v.Len()) - for i := 0; i < v.Len(); i++ { - bytes[i] = uint8(v.Index(i).Uint()) - } - _, err = out.Write(bytes) - return + return bytesEncoder(v.Bytes()), nil } var fp fieldParameters - for i := 0; i < v.Len(); i++ { - var pre *forkableWriter - pre, out = out.fork() - err = marshalField(pre, v.Index(i), fp) - if err != nil { - return + + switch l := v.Len(); l { + case 0: + return bytesEncoder(nil), nil + case 1: + return makeField(v.Index(0), fp) + default: + m := make([]encoder, l) + + for i := 0; i < l; i++ { + m[i], err = makeField(v.Index(i), fp) + if err != nil { + return nil, err + } } + + return multiEncoder(m), nil } - return case reflect.String: switch params.stringType { case TagIA5String: - return marshalIA5String(out, v.String()) + return makeIA5String(v.String()) case TagPrintableString: - return marshalPrintableString(out, v.String()) + return makePrintableString(v.String()) default: - return marshalUTF8String(out, v.String()) + return makeUTF8String(v.String()), nil } } - return StructuralError{"unknown Go type"} + return nil, StructuralError{"unknown Go type"} } -func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) { +func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) { if !v.IsValid() { - return fmt.Errorf("asn1: cannot marshal nil value") + return nil, fmt.Errorf("asn1: cannot marshal nil value") } // If the field is an interface{} then recurse into it. if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 { - return marshalField(out, v.Elem(), params) + return makeField(v.Elem(), params) } if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty { - return + return bytesEncoder(nil), nil } if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) { @@ -523,7 +519,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) defaultValue.SetInt(*params.defaultValue) if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) { - return + return bytesEncoder(nil), nil } } @@ -532,37 +528,36 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) // behaviour, but it's what Go has traditionally done. if params.optional && params.defaultValue == nil { if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) { - return + return bytesEncoder(nil), nil } } if v.Type() == rawValueType { rv := v.Interface().(RawValue) if len(rv.FullBytes) != 0 { - _, err = out.Write(rv.FullBytes) - } else { - err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}) - if err != nil { - return - } - _, err = out.Write(rv.Bytes) + return bytesEncoder(rv.FullBytes), nil } - return + + t := new(taggedEncoder) + + t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})) + t.body = bytesEncoder(rv.Bytes) + + return t, nil } tag, isCompound, ok := getUniversalType(v.Type()) if !ok { - err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())} - return + return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())} } class := ClassUniversal if params.timeType != 0 && tag != TagUTCTime { - return StructuralError{"explicit time type given to non-time member"} + return nil, StructuralError{"explicit time type given to non-time member"} } if params.stringType != 0 && tag != TagPrintableString { - return StructuralError{"explicit string type given to non-string member"} + return nil, StructuralError{"explicit string type given to non-string member"} } switch tag { @@ -574,7 +569,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) for _, r := range v.String() { if r >= utf8.RuneSelf || !isPrintable(byte(r)) { if !utf8.ValidString(v.String()) { - return errors.New("asn1: string not valid UTF-8") + return nil, errors.New("asn1: string not valid UTF-8") } tag = TagUTF8String break @@ -591,46 +586,46 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) if params.set { if tag != TagSequence { - return StructuralError{"non sequence tagged as set"} + return nil, StructuralError{"non sequence tagged as set"} } tag = TagSet } - tags, body := out.fork() + t := new(taggedEncoder) - err = marshalBody(body, v, params) + t.body, err = makeBody(v, params) if err != nil { - return + return nil, err } - bodyLen := body.Len() + bodyLen := t.body.Len() - var explicitTag *forkableWriter if params.explicit { - explicitTag, tags = tags.fork() - } + t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound})) - if !params.explicit && params.tag != nil { - // implicit tag. - tag = *params.tag - class = ClassContextSpecific - } + tt := new(taggedEncoder) - err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound}) - if err != nil { - return - } + tt.body = t - if params.explicit { - err = marshalTagAndLength(explicitTag, tagAndLength{ + tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{ class: ClassContextSpecific, tag: *params.tag, - length: bodyLen + tags.Len(), + length: bodyLen + t.tag.Len(), isCompound: true, - }) + })) + + return tt, nil + } + + if params.tag != nil { + // implicit tag. + tag = *params.tag + class = ClassContextSpecific } - return err + t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound})) + + return t, nil } // Marshal returns the ASN.1 encoding of val. @@ -643,13 +638,11 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) // printable: causes strings to be marshaled as ASN.1, PrintableString strings. // utf8: causes strings to be marshaled as ASN.1, UTF8 strings func Marshal(val interface{}) ([]byte, error) { - var out bytes.Buffer - v := reflect.ValueOf(val) - f := newForkableWriter() - err := marshalField(f, v, fieldParameters{}) + e, err := makeField(reflect.ValueOf(val), fieldParameters{}) if err != nil { return nil, err } - _, err = f.writeTo(&out) - return out.Bytes(), err + b := make([]byte, e.Len()) + e.Encode(b) + return b, nil } diff --git a/src/encoding/asn1/marshal_test.go b/src/encoding/asn1/marshal_test.go index cdca8aa336..6af770fcc3 100644 --- a/src/encoding/asn1/marshal_test.go +++ b/src/encoding/asn1/marshal_test.go @@ -173,3 +173,13 @@ func TestInvalidUTF8(t *testing.T) { t.Errorf("invalid UTF8 string was accepted") } } + +func BenchmarkMarshal(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, test := range marshalTests { + Marshal(test.in) + } + } +} -- 2.48.1