]> Cypherpunks repositories - gostls13.git/commitdiff
asn1:
authorAdam Langley <agl@golang.org>
Wed, 18 Nov 2009 02:09:41 +0000 (18:09 -0800)
committerAdam Langley <agl@golang.org>
Wed, 18 Nov 2009 02:09:41 +0000 (18:09 -0800)
  * add Marshal
  * add BitString.RightAlign
  * change to using a *time.Time (from time.Time) since that's what
    the time package uses.
  * return the unparsed data from Unmarshal.

R=rsc
CC=golang-dev
https://golang.org/cl/156047

src/pkg/asn1/Makefile
src/pkg/asn1/asn1.go
src/pkg/asn1/asn1_test.go
src/pkg/asn1/common.go [new file with mode: 0644]
src/pkg/asn1/marshal.go [new file with mode: 0644]
src/pkg/asn1/marshal_test.go [new file with mode: 0644]
src/pkg/crypto/x509/x509.go

index 8ad3fb78da6cc92f931fd2c7843300c0b1fdddad..32848743adc6ba00300005fe77a5f51db73c7ad9 100644 (file)
@@ -7,5 +7,7 @@ include $(GOROOT)/src/Make.$(GOARCH)
 TARG=asn1
 GOFILES=\
        asn1.go\
+       common.go\
+       marshal.go\
 
 include $(GOROOT)/src/Make.pkg
index 3afd6fbb1e90fbf42bdc4662cad89bf652b20761..5e264dc5cae6b1e52bfd661a3bdd6311f4005414 100644 (file)
@@ -23,8 +23,6 @@ import (
        "fmt";
        "os";
        "reflect";
-       "strconv";
-       "strings";
        "time";
 )
 
@@ -111,6 +109,24 @@ func (b BitString) At(i int) int {
        return int(b.Bytes[x]>>y) & 1;
 }
 
+// RightAlign returns a slice where the padding bits are at the beginning. The
+// slice may share memory with the BitString.
+func (b BitString) RightAlign() []byte {
+       shift := uint(8 - (b.BitLength % 8));
+       if shift == 8 || len(b.Bytes) == 0 {
+               return b.Bytes
+       }
+
+       a := make([]byte, len(b.Bytes));
+       a[0] = b.Bytes[0] >> shift;
+       for i := 1; i < len(b.Bytes); i++ {
+               a[i] = b.Bytes[i-1] << (8 - shift);
+               a[i] |= b.Bytes[i] >> shift;
+       }
+
+       return a;
+}
+
 // parseBitString parses an ASN.1 bit string from the given byte array and returns it.
 func parseBitString(bytes []byte) (ret BitString, err os.Error) {
        if len(bytes) == 0 {
@@ -204,7 +220,7 @@ func twoDigits(bytes []byte, max int) (int, bool) {
 
 // parseUTCTime parses the UTCTime from the given byte array and returns the
 // resulting time.
-func parseUTCTime(bytes []byte) (ret time.Time, err os.Error) {
+func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) {
        // A UTCTime can take the following formats:
        //
        //             1111111
@@ -220,11 +236,13 @@ func parseUTCTime(bytes []byte) (ret time.Time, err os.Error) {
                err = SyntaxError{"UTCTime too short"};
                return;
        }
+       ret = new(time.Time);
+
        var ok1, ok2, ok3, ok4, ok5 bool;
        year, ok1 := twoDigits(bytes[0:2], 99);
        // RFC 5280, section 5.1.2.4 says that years 2050 or later use another date
        // scheme.
-       if year > 50 {
+       if year >= 50 {
                ret.Year = 1900 + int64(year)
        } else {
                ret.Year = 2000 + int64(year)
@@ -333,39 +351,6 @@ type RawValue struct {
 
 // Tagging
 
-// ASN.1 objects have metadata preceeding them:
-//   the tag: the type of the object
-//   a flag denoting if this object is compound or not
-//   the class type: the namespace of the tag
-//   the length of the object, in bytes
-
-// Here are some standard tags and classes
-
-const (
-       tagBoolean              = 1;
-       tagInteger              = 2;
-       tagBitString            = 3;
-       tagOctetString          = 4;
-       tagOID                  = 6;
-       tagSequence             = 16;
-       tagSet                  = 17;
-       tagPrintableString      = 19;
-       tagIA5String            = 22;
-       tagUTCTime              = 23;
-)
-
-const (
-       classUniversal          = 0;
-       classApplication        = 1;
-       classContextSpecific    = 2;
-       classPrivate            = 3;
-)
-
-type tagAndLength struct {
-       class, tag, length      int;
-       isCompound              bool;
-}
-
 // parseTagAndLength parses an ASN.1 tag and length pair from the given offset
 // into a byte array. It returns the parsed data and the new offset. SET and
 // SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
@@ -428,97 +413,6 @@ func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset i
        return;
 }
 
-// ASN.1 has IMPLICIT and EXPLICIT tags, which can be translated as "instead
-// of" and "in addition to". When not specified, every primitive type has a
-// default tag in the UNIVERSAL class.
-//
-// For example: a BIT STRING is tagged [UNIVERSAL 3] by default (although ASN.1
-// doesn't actually have a UNIVERSAL keyword). However, by saying [IMPLICIT
-// CONTEXT-SPECIFIC 42], that means that the tag is replaced by another.
-//
-// On the other hand, if it said [EXPLICIT CONTEXT-SPECIFIC 10], then an
-// /additional/ tag would wrap the default tag. This explicit tag will have the
-// compound flag set.
-//
-// (This is used in order to remove ambiguity with optional elements.)
-//
-// You can layer EXPLICIT and IMPLICIT tags to an arbitrary depth, however we
-// don't support that here. We support a single layer of EXPLICIT or IMPLICIT
-// tagging with tag strings on the fields of a structure.
-
-// fieldParameters is the parsed representation of tag string from a structure field.
-type fieldParameters struct {
-       optional        bool;   // true iff the field is OPTIONAL
-       explicit        bool;   // true iff and EXPLICIT tag is in use.
-       defaultValue    *int64; // a default value for INTEGER typed fields (maybe nil).
-       tag             *int;   // the EXPLICIT or IMPLICIT tag (maybe nil).
-
-       // Invariants:
-       //   if explicit is set, tag is non-nil.
-}
-
-// Given a tag string with the format specified in the package comment,
-// parseFieldParameters will parse it into a fieldParameters structure,
-// ignoring unknown parts of the string.
-func parseFieldParameters(str string) (ret fieldParameters) {
-       for _, part := range strings.Split(str, ",", 0) {
-               switch {
-               case part == "optional":
-                       ret.optional = true
-               case part == "explicit":
-                       ret.explicit = true;
-                       if ret.tag == nil {
-                               ret.tag = new(int);
-                               *ret.tag = 0;
-                       }
-               case strings.HasPrefix(part, "default:"):
-                       i, err := strconv.Atoi64(part[8:len(part)]);
-                       if err == nil {
-                               ret.defaultValue = new(int64);
-                               *ret.defaultValue = i;
-                       }
-               case strings.HasPrefix(part, "tag:"):
-                       i, err := strconv.Atoi(part[4:len(part)]);
-                       if err == nil {
-                               ret.tag = new(int);
-                               *ret.tag = i;
-                       }
-               }
-       }
-       return;
-}
-
-// Given a reflected Go type, getUniversalType returns the default tag number
-// and expected compound flag.
-func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) {
-       switch t {
-       case objectIdentifierType:
-               return tagOID, false, true
-       case bitStringType:
-               return tagBitString, false, true
-       case timeType:
-               return tagUTCTime, false, true
-       }
-       switch i := t.(type) {
-       case *reflect.BoolType:
-               return tagBoolean, false, true
-       case *reflect.IntType:
-               return tagInteger, false, true
-       case *reflect.Int64Type:
-               return tagInteger, false, true
-       case *reflect.StructType:
-               return tagSequence, true, true
-       case *reflect.SliceType:
-               if _, ok := t.(*reflect.SliceType).Elem().(*reflect.Uint8Type); ok {
-                       return tagOctetString, false, true
-               }
-               return tagSequence, true, true;
-       case *reflect.StringType:
-               return tagPrintableString, false, true
-       }
-       return 0, false, false;
-}
-
 // parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
 // a number of ASN.1 values from the given byte array and returns them as a
 // slice of Go values of the given type.
@@ -564,7 +458,7 @@ func parseSequenceOf(bytes []byte, sliceType *reflect.SliceType, elemType reflec
 var (
        bitStringType           = reflect.Typeof(BitString{});
        objectIdentifierType    = reflect.Typeof(ObjectIdentifier{});
-       timeType                = reflect.Typeof(time.Time{});
+       timeType                = reflect.Typeof(&time.Time{});
        rawValueType            = reflect.Typeof(RawValue{});
 )
 
@@ -732,11 +626,11 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
                err = err1;
                return;
        case timeType:
-               structValue := v.(*reflect.StructValue);
+               ptrValue := v.(*reflect.PtrValue);
                time, err1 := parseUTCTime(innerBytes);
                offset += t.length;
                if err1 == nil {
-                       structValue.Set(reflect.NewValue(time).(*reflect.StructValue))
+                       ptrValue.Set(reflect.NewValue(time).(*reflect.PtrValue))
                }
                err = err1;
                return;
@@ -871,8 +765,11 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
 //
 // Other ASN.1 types are not supported; if it encounters them,
 // Unmarshal returns a parse error.
-func Unmarshal(val interface{}, b []byte) os.Error {
+func Unmarshal(val interface{}, b []byte) (rest []byte, err os.Error) {
        v := reflect.NewValue(val).(*reflect.PtrValue).Elem();
-       _, err := parseField(v, b, 0, fieldParameters{});
-       return err;
+       offset, err := parseField(v, b, 0, fieldParameters{});
+       if err != nil {
+               return nil, err
+       }
+       return b[offset:len(b)], nil;
 }
index 6d537fb5e38cc77f0ecaaf351d5727838f9fa98f..0e818dc300a629f31c94b7a6a31ccc9a4ba92030 100644 (file)
@@ -89,6 +89,31 @@ func TestBitStringAt(t *testing.T) {
        }
 }
 
+type bitStringRightAlignTest struct {
+       in      []byte;
+       inlen   int;
+       out     []byte;
+}
+
+var bitStringRightAlignTests = []bitStringRightAlignTest{
+       bitStringRightAlignTest{[]byte{0x80}, 1, []byte{0x01}},
+       bitStringRightAlignTest{[]byte{0x80, 0x80}, 9, []byte{0x01, 0x01}},
+       bitStringRightAlignTest{[]byte{}, 0, []byte{}},
+       bitStringRightAlignTest{[]byte{0xce}, 8, []byte{0xce}},
+       bitStringRightAlignTest{[]byte{0xce, 0x47}, 16, []byte{0xce, 0x47}},
+       bitStringRightAlignTest{[]byte{0x34, 0x50}, 12, []byte{0x03, 0x45}},
+}
+
+func TestBitStringRightAlign(t *testing.T) {
+       for i, test := range bitStringRightAlignTests {
+               bs := BitString{test.in, test.inlen};
+               out := bs.RightAlign();
+               if bytes.Compare(out, test.out) != 0 {
+                       t.Errorf("#%d got: %x want: %x", i, out, test.out)
+               }
+       }
+}
+
 type objectIdentifierTest struct {
        in      []byte;
        ok      bool;
@@ -120,22 +145,22 @@ func TestObjectIdentifier(t *testing.T) {
 type timeTest struct {
        in      string;
        ok      bool;
-       out     time.Time;
+       out     *time.Time;
 }
 
 var timeTestData = []timeTest{
-       timeTest{"910506164540-0700", true, time.Time{1991, 05, 06, 16, 45, 40, 0, -7 * 60 * 60, ""}},
-       timeTest{"910506164540+0730", true, time.Time{1991, 05, 06, 16, 45, 40, 0, 7*60*60 + 30*60, ""}},
-       timeTest{"910506234540Z", true, time.Time{1991, 05, 06, 23, 45, 40, 0, 0, ""}},
-       timeTest{"9105062345Z", true, time.Time{1991, 05, 06, 23, 45, 0, 0, 0, ""}},
-       timeTest{"a10506234540Z", false, time.Time{}},
-       timeTest{"91a506234540Z", false, time.Time{}},
-       timeTest{"9105a6234540Z", false, time.Time{}},
-       timeTest{"910506a34540Z", false, time.Time{}},
-       timeTest{"910506334a40Z", false, time.Time{}},
-       timeTest{"91050633444aZ", false, time.Time{}},
-       timeTest{"910506334461Z", false, time.Time{}},
-       timeTest{"910506334400Za", false, time.Time{}},
+       timeTest{"910506164540-0700", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, -7 * 60 * 60, ""}},
+       timeTest{"910506164540+0730", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 7*60*60 + 30*60, ""}},
+       timeTest{"910506234540Z", true, &time.Time{1991, 05, 06, 23, 45, 40, 0, 0, ""}},
+       timeTest{"9105062345Z", true, &time.Time{1991, 05, 06, 23, 45, 0, 0, 0, ""}},
+       timeTest{"a10506234540Z", false, nil},
+       timeTest{"91a506234540Z", false, nil},
+       timeTest{"9105a6234540Z", false, nil},
+       timeTest{"910506a34540Z", false, nil},
+       timeTest{"910506334a40Z", false, nil},
+       timeTest{"91050633444aZ", false, nil},
+       timeTest{"910506334461Z", false, nil},
+       timeTest{"910506334400Za", false, nil},
 }
 
 func TestTime(t *testing.T) {
@@ -199,14 +224,16 @@ func newString(s string) *string  { return &s }
 func newBool(b bool) *bool     { return &b }
 
 var parseFieldParametersTestData []parseFieldParametersTest = []parseFieldParametersTest{
-       parseFieldParametersTest{"", fieldParameters{false, false, nil, nil}},
-       parseFieldParametersTest{"optional", fieldParameters{true, false, nil, nil}},
-       parseFieldParametersTest{"explicit", fieldParameters{false, true, nil, new(int)}},
-       parseFieldParametersTest{"optional,explicit", fieldParameters{true, true, nil, new(int)}},
-       parseFieldParametersTest{"default:42", fieldParameters{false, false, newInt64(42), nil}},
-       parseFieldParametersTest{"tag:17", fieldParameters{false, false, nil, newInt(17)}},
-       parseFieldParametersTest{"optional,explicit,default:42,tag:17", fieldParameters{true, true, newInt64(42), newInt(17)}},
-       parseFieldParametersTest{"optional,explicit,default:42,tag:17,rubbish1", fieldParameters{true, true, newInt64(42), newInt(17)}},
+       parseFieldParametersTest{"", fieldParameters{false, false, nil, nil, 0}},
+       parseFieldParametersTest{"ia5", fieldParameters{false, false, nil, nil, tagIA5String}},
+       parseFieldParametersTest{"printable", fieldParameters{false, false, nil, nil, tagPrintableString}},
+       parseFieldParametersTest{"optional", fieldParameters{true, false, nil, nil, 0}},
+       parseFieldParametersTest{"explicit", fieldParameters{false, true, nil, new(int), 0}},
+       parseFieldParametersTest{"optional,explicit", fieldParameters{true, true, nil, new(int), 0}},
+       parseFieldParametersTest{"default:42", fieldParameters{false, false, newInt64(42), nil, 0}},
+       parseFieldParametersTest{"tag:17", fieldParameters{false, false, nil, newInt(17), 0}},
+       parseFieldParametersTest{"optional,explicit,default:42,tag:17", fieldParameters{true, true, newInt64(42), newInt(17), 0}},
+       parseFieldParametersTest{"optional,explicit,default:42,tag:17,rubbish1", fieldParameters{true, true, newInt64(42), newInt(17), 0}},
 }
 
 func TestParseFieldParameters(t *testing.T) {
@@ -258,7 +285,7 @@ func TestUnmarshal(t *testing.T) {
                zv := reflect.MakeZero(pv.Type().(*reflect.PtrType).Elem());
                pv.(*reflect.PtrValue).PointTo(zv);
                val := pv.Interface();
-               err := Unmarshal(val, test.in);
+               _, err := Unmarshal(val, test.in);
                if err != nil {
                        t.Errorf("Unmarshal failed at index %d %v", i, err)
                }
@@ -298,7 +325,7 @@ type AttributeTypeAndValue struct {
 }
 
 type Validity struct {
-       NotBefore, NotAfter time.Time;
+       NotBefore, NotAfter *time.Time;
 }
 
 type PublicKeyInfo struct {
@@ -309,7 +336,7 @@ type PublicKeyInfo struct {
 func TestCertificate(t *testing.T) {
        // This is a minimal, self-signed certificate that should parse correctly.
        var cert Certificate;
-       if err := Unmarshal(&cert, derEncodedSelfSignedCertBytes); err != nil {
+       if _, err := Unmarshal(&cert, derEncodedSelfSignedCertBytes); err != nil {
                t.Errorf("Unmarshal failed: %v", err)
        }
        if !reflect.DeepEqual(cert, derEncodedSelfSignedCert) {
@@ -322,7 +349,7 @@ func TestCertificateWithNUL(t *testing.T) {
        // NUL isn't a permitted character in a PrintableString.
 
        var cert Certificate;
-       if err := Unmarshal(&cert, derEncodedPaypalNULCertBytes); err == nil {
+       if _, err := Unmarshal(&cert, derEncodedPaypalNULCertBytes); err == nil {
                t.Error("Unmarshal succeeded, should not have")
        }
 }
@@ -340,7 +367,7 @@ var derEncodedSelfSignedCert = Certificate{
                        RelativeDistinguishedName{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 3}, Value: "false.example.com"}},
                        RelativeDistinguishedName{AttributeTypeAndValue{Type: ObjectIdentifier{1, 2, 840, 113549, 1, 9, 1}, Value: "false@example.com"}},
                },
-               Validity: Validity{NotBefore: time.Time{Year: 2009, Month: 10, Day: 8, Hour: 0, Minute: 25, Second: 53, Weekday: 0, ZoneOffset: 0, Zone: ""}, NotAfter: time.Time{Year: 2010, Month: 10, Day: 8, Hour: 0, Minute: 25, Second: 53, Weekday: 0, ZoneOffset: 0, Zone: ""}},
+               Validity: Validity{NotBefore: &time.Time{Year: 2009, Month: 10, Day: 8, Hour: 0, Minute: 25, Second: 53, Weekday: 0, ZoneOffset: 0, Zone: ""}, NotAfter: &time.Time{Year: 2010, Month: 10, Day: 8, Hour: 0, Minute: 25, Second: 53, Weekday: 0, ZoneOffset: 0, Zone: ""}},
                Subject: RDNSequence{
                        RelativeDistinguishedName{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 6}, Value: "XX"}},
                        RelativeDistinguishedName{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 8}, Value: "Some-State"}},
diff --git a/src/pkg/asn1/common.go b/src/pkg/asn1/common.go
new file mode 100644 (file)
index 0000000..3021493
--- /dev/null
@@ -0,0 +1,140 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package asn1
+
+import (
+       "reflect";
+       "strconv";
+       "strings";
+)
+
+// ASN.1 objects have metadata preceeding them:
+//   the tag: the type of the object
+//   a flag denoting if this object is compound or not
+//   the class type: the namespace of the tag
+//   the length of the object, in bytes
+
+// Here are some standard tags and classes
+
+const (
+       tagBoolean              = 1;
+       tagInteger              = 2;
+       tagBitString            = 3;
+       tagOctetString          = 4;
+       tagOID                  = 6;
+       tagSequence             = 16;
+       tagSet                  = 17;
+       tagPrintableString      = 19;
+       tagIA5String            = 22;
+       tagUTCTime              = 23;
+)
+
+const (
+       classUniversal          = 0;
+       classApplication        = 1;
+       classContextSpecific    = 2;
+       classPrivate            = 3;
+)
+
+type tagAndLength struct {
+       class, tag, length      int;
+       isCompound              bool;
+}
+
+// ASN.1 has IMPLICIT and EXPLICIT tags, which can be translated as "instead
+// of" and "in addition to". When not specified, every primitive type has a
+// default tag in the UNIVERSAL class.
+//
+// For example: a BIT STRING is tagged [UNIVERSAL 3] by default (although ASN.1
+// doesn't actually have a UNIVERSAL keyword). However, by saying [IMPLICIT
+// CONTEXT-SPECIFIC 42], that means that the tag is replaced by another.
+//
+// On the other hand, if it said [EXPLICIT CONTEXT-SPECIFIC 10], then an
+// /additional/ tag would wrap the default tag. This explicit tag will have the
+// compound flag set.
+//
+// (This is used in order to remove ambiguity with optional elements.)
+//
+// You can layer EXPLICIT and IMPLICIT tags to an arbitrary depth, however we
+// don't support that here. We support a single layer of EXPLICIT or IMPLICIT
+// tagging with tag strings on the fields of a structure.
+
+// fieldParameters is the parsed representation of tag string from a structure field.
+type fieldParameters struct {
+       optional        bool;   // true iff the field is OPTIONAL
+       explicit        bool;   // true iff and EXPLICIT tag is in use.
+       defaultValue    *int64; // a default value for INTEGER typed fields (maybe nil).
+       tag             *int;   // the EXPLICIT or IMPLICIT tag (maybe nil).
+       stringType      int;    // the string tag to use when marshaling.
+
+       // Invariants:
+       //   if explicit is set, tag is non-nil.
+}
+
+// Given a tag string with the format specified in the package comment,
+// parseFieldParameters will parse it into a fieldParameters structure,
+// ignoring unknown parts of the string.
+func parseFieldParameters(str string) (ret fieldParameters) {
+       for _, part := range strings.Split(str, ",", 0) {
+               switch {
+               case part == "optional":
+                       ret.optional = true
+               case part == "explicit":
+                       ret.explicit = true;
+                       if ret.tag == nil {
+                               ret.tag = new(int);
+                               *ret.tag = 0;
+                       }
+               case part == "ia5":
+                       ret.stringType = tagIA5String
+               case part == "printable":
+                       ret.stringType = tagPrintableString
+               case strings.HasPrefix(part, "default:"):
+                       i, err := strconv.Atoi64(part[8:len(part)]);
+                       if err == nil {
+                               ret.defaultValue = new(int64);
+                               *ret.defaultValue = i;
+                       }
+               case strings.HasPrefix(part, "tag:"):
+                       i, err := strconv.Atoi(part[4:len(part)]);
+                       if err == nil {
+                               ret.tag = new(int);
+                               *ret.tag = i;
+                       }
+               }
+       }
+       return;
+}
+
+// Given a reflected Go type, getUniversalType returns the default tag number
+// and expected compound flag.
+func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) {
+       switch t {
+       case objectIdentifierType:
+               return tagOID, false, true
+       case bitStringType:
+               return tagBitString, false, true
+       case timeType:
+               return tagUTCTime, false, true
+       }
+       switch i := t.(type) {
+       case *reflect.BoolType:
+               return tagBoolean, false, true
+       case *reflect.IntType:
+               return tagInteger, false, true
+       case *reflect.Int64Type:
+               return tagInteger, false, true
+       case *reflect.StructType:
+               return tagSequence, true, true
+       case *reflect.SliceType:
+               if _, ok := t.(*reflect.SliceType).Elem().(*reflect.Uint8Type); ok {
+                       return tagOctetString, false, true
+               }
+               return tagSequence, true, true;
+       case *reflect.StringType:
+               return tagPrintableString, false, true
+       }
+       return 0, false, false;
+}
diff --git a/src/pkg/asn1/marshal.go b/src/pkg/asn1/marshal.go
new file mode 100644 (file)
index 0000000..2e9aa14
--- /dev/null
@@ -0,0 +1,400 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package asn1
+
+import (
+       "bytes";
+       "fmt";
+       "io";
+       "os";
+       "reflect";
+       "strings";
+       "time";
+)
+
+// A forkableWriter is an in-memory buffer that can be
+// 'forked' to create new forkableWriters that bracket the
+// original.  After
+//    pre, post := w.fork();
+// the overall sequence of bytes represented is logically w+pre+post.
+type forkableWriter struct {
+       *bytes.Buffer;
+       pre, post       *forkableWriter;
+}
+
+func newForkableWriter() *forkableWriter {
+       return &forkableWriter{bytes.NewBuffer(nil), nil, nil}
+}
+
+func (f *forkableWriter) fork() (pre, post *forkableWriter) {
+       f.pre = newForkableWriter();
+       f.post = newForkableWriter();
+       return f.pre, f.post;
+}
+
+func (f *forkableWriter) Len() (l int) {
+       l += f.Buffer.Len();
+       if f.pre != nil {
+               l += f.pre.Len()
+       }
+       if f.post != nil {
+               l += f.post.Len()
+       }
+       return;
+}
+
+func (f *forkableWriter) writeTo(out io.Writer) (n int, err os.Error) {
+       n, err = out.Write(f.Bytes());
+       if err != nil {
+               return
+       }
+
+       var nn int;
+
+       if f.pre != nil {
+               nn, err = f.pre.writeTo(out);
+               n += nn;
+               if err != nil {
+                       return
+               }
+       }
+
+       if f.pre != nil {
+               nn, err = f.post.writeTo(out);
+               n += nn;
+       }
+       return;
+}
+
+func marshalBase128Int(out *forkableWriter, i int64) (err os.Error) {
+       if i == 0 {
+               err = out.WriteByte(0);
+               return;
+       }
+
+       for i > 0 {
+               next := i >> 7;
+               o := byte(i & 0x7f);
+               if next > 0 {
+                       o |= 0x80
+               }
+               err = out.WriteByte(o);
+               if err != nil {
+                       return
+               }
+               i = next;
+       }
+
+       return nil;
+}
+
+func base128Length(i int) (numBytes int) {
+       if i == 0 {
+               return 1
+       }
+
+       for i > 0 {
+               numBytes++;
+               i >>= 7;
+       }
+
+       return;
+}
+
+func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err os.Error) {
+       b := uint8(t.class) << 6;
+       if t.isCompound {
+               b |= 0x20
+       }
+       if t.tag >= 31 {
+               b |= 0x1f;
+               err = out.WriteByte(b);
+               if err != nil {
+                       return
+               }
+               err = marshalBase128Int(out, int64(t.tag));
+               if err != nil {
+                       return
+               }
+       } else {
+               b |= uint8(t.tag);
+               err = out.WriteByte(b);
+               if err != nil {
+                       return
+               }
+       }
+
+       if t.length >= 128 {
+               err = out.WriteByte(byte(base128Length(t.length)));
+               if err != nil {
+                       return
+               }
+               err = marshalBase128Int(out, int64(t.length));
+               if err != nil {
+                       return
+               }
+       } else {
+               err = out.WriteByte(byte(t.length));
+               if err != nil {
+                       return
+               }
+       }
+
+       return nil;
+}
+
+func marshalBitString(out *forkableWriter, b BitString) (err os.Error) {
+       paddingBits := byte((8 - b.BitLength%8) % 8);
+       err = out.WriteByte(paddingBits);
+       if err != nil {
+               return
+       }
+       _, err = out.Write(b.Bytes);
+       return;
+}
+
+func marshalObjectIdentifier(out *forkableWriter, oid []int) (err os.Error) {
+       if len(oid) < 2 || oid[0] > 6 || oid[1] >= 40 {
+               return StructuralError{"invalid object identifier"}
+       }
+
+       err = out.WriteByte(byte(oid[0]*40 + oid[1]));
+       if err != nil {
+               return
+       }
+       for i := 2; i < len(oid); i++ {
+               err = marshalBase128Int(out, int64(oid[i]));
+               if err != nil {
+                       return
+               }
+       }
+
+       return;
+}
+
+func marshalPrintableString(out *forkableWriter, s string) (err os.Error) {
+       b := strings.Bytes(s);
+       for _, c := range b {
+               if !isPrintable(c) {
+                       return StructuralError{"PrintableString contains invalid character"}
+               }
+       }
+
+       _, err = out.Write(b);
+       return;
+}
+
+func marshalIA5String(out *forkableWriter, s string) (err os.Error) {
+       b := strings.Bytes(s);
+       for _, c := range b {
+               if c > 127 {
+                       return StructuralError{"IA5String contains invalid character"}
+               }
+       }
+
+       _, err = out.Write(b);
+       return;
+}
+
+func marshalTwoDigits(out *forkableWriter, v int) (err os.Error) {
+       err = out.WriteByte(byte('0' + (v/10)%10));
+       if err != nil {
+               return
+       }
+       return out.WriteByte(byte('0' + v%10));
+}
+
+func marshalUTCTime(out *forkableWriter, t *time.Time) (err os.Error) {
+       switch {
+       case 1950 <= t.Year && t.Year < 2000:
+               err = marshalTwoDigits(out, int(t.Year-1900))
+       case 2000 <= t.Year && t.Year < 2050:
+               err = marshalTwoDigits(out, int(t.Year-2000))
+       default:
+               return StructuralError{"Cannot represent time as UTCTime"}
+       }
+
+       if err != nil {
+               return
+       }
+
+       err = marshalTwoDigits(out, t.Month);
+       if err != nil {
+               return
+       }
+
+       err = marshalTwoDigits(out, t.Day);
+       if err != nil {
+               return
+       }
+
+       err = marshalTwoDigits(out, t.Hour);
+       if err != nil {
+               return
+       }
+
+       err = marshalTwoDigits(out, t.Minute);
+       if err != nil {
+               return
+       }
+
+       err = marshalTwoDigits(out, t.Second);
+       if err != nil {
+               return
+       }
+
+       switch {
+       case t.ZoneOffset/60 == 0:
+               err = out.WriteByte('Z');
+               return;
+       case t.ZoneOffset > 0:
+               err = out.WriteByte('+')
+       case t.ZoneOffset < 0:
+               err = out.WriteByte('-')
+       }
+
+       if err != nil {
+               return
+       }
+
+       offsetMinutes := t.ZoneOffset / 60;
+       if offsetMinutes < 0 {
+               offsetMinutes = -offsetMinutes
+       }
+
+       err = marshalTwoDigits(out, offsetMinutes/60);
+       if err != nil {
+               return
+       }
+
+       err = marshalTwoDigits(out, offsetMinutes%60);
+       return;
+}
+
+func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err os.Error) {
+       switch value.Type() {
+       case timeType:
+               return marshalUTCTime(out, value.Interface().(*time.Time))
+       case bitStringType:
+               return marshalBitString(out, value.Interface().(BitString))
+       case objectIdentifierType:
+               return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
+       }
+
+       switch v := value.(type) {
+       case *reflect.BoolValue:
+               if v.Get() {
+                       return out.WriteByte(1)
+               } else {
+                       return out.WriteByte(0)
+               }
+       case *reflect.IntValue:
+               return marshalBase128Int(out, int64(v.Get()))
+       case *reflect.Int64Value:
+               return marshalBase128Int(out, v.Get())
+       case *reflect.StructValue:
+               t := v.Type().(*reflect.StructType);
+               for i := 0; i < t.NumField(); i++ {
+                       err = marshalField(out, v.Field(i), parseFieldParameters(t.Field(i).Tag));
+                       if err != nil {
+                               return
+                       }
+               }
+               return;
+       case *reflect.SliceValue:
+               sliceType := v.Type().(*reflect.SliceType);
+               if _, ok := sliceType.Elem().(*reflect.Uint8Type); ok {
+                       bytes := make([]byte, v.Len());
+                       for i := 0; i < v.Len(); i++ {
+                               bytes[i] = v.Elem(i).(*reflect.Uint8Value).Get()
+                       }
+                       _, err = out.Write(bytes);
+                       return;
+               }
+
+               var params fieldParameters;
+               for i := 0; i < v.Len(); i++ {
+                       err = marshalField(out, v.Elem(i), params);
+                       if err != nil {
+                               return
+                       }
+               }
+               return;
+       case *reflect.StringValue:
+               if params.stringType == tagIA5String {
+                       return marshalIA5String(out, v.Get())
+               } else {
+                       return marshalPrintableString(out, v.Get())
+               }
+               return;
+       }
+
+       return StructuralError{"unknown Go type"};
+}
+
+func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err os.Error) {
+       tag, isCompound, ok := getUniversalType(v.Type());
+       if !ok {
+               err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())};
+               return;
+       }
+       class := classUniversal;
+
+       if params.stringType != 0 {
+               if tag != tagPrintableString {
+                       return StructuralError{"Explicit string type given to non-string member"}
+               }
+               tag = params.stringType;
+       }
+
+       tags, body := out.fork();
+
+       err = marshalBody(body, v, params);
+       if err != nil {
+               return
+       }
+
+       bodyLen := body.Len();
+
+       var explicitTag *forkableWriter;
+       if params.explicit {
+               explicitTag, tags = tags.fork()
+       }
+
+       if !params.explicit && params.tag != nil {
+               // implicit tag.
+               tag = *params.tag;
+               class = classContextSpecific;
+       }
+
+       err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound});
+       if err != nil {
+               return
+       }
+
+       if params.explicit {
+               err = marshalTagAndLength(explicitTag, tagAndLength{
+                       class: classContextSpecific,
+                       tag: *params.tag,
+                       length: bodyLen + tags.Len(),
+                       isCompound: true,
+               })
+       }
+
+       return nil;
+}
+
+// Marshal serialises val as an ASN.1 structure and writes the result to out.
+// In the case of an error, no output is produced.
+func Marshal(out io.Writer, val interface{}) os.Error {
+       v := reflect.NewValue(val);
+       f := newForkableWriter();
+       err := marshalField(f, v, fieldParameters{});
+       if err != nil {
+               return err
+       }
+       _, err = f.writeTo(out);
+       return err;
+}
diff --git a/src/pkg/asn1/marshal_test.go b/src/pkg/asn1/marshal_test.go
new file mode 100644 (file)
index 0000000..c2ce1e4
--- /dev/null
@@ -0,0 +1,78 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package asn1
+
+import (
+       "bytes";
+       "encoding/hex";
+       "testing";
+       "time";
+)
+
+type intStruct struct {
+       A int;
+}
+
+type nestedStruct struct {
+       A intStruct;
+}
+
+type marshalTest struct {
+       in      interface{};
+       out     string; // hex encoded
+}
+
+type implicitTagTest struct {
+       A int "implicit,tag:5";
+}
+
+type explicitTagTest struct {
+       A int "explicit,tag:5";
+}
+
+type ia5StringTest struct {
+       A string "ia5";
+}
+
+type printableStringTest struct {
+       A string "printable";
+}
+
+func setPST(t *time.Time) *time.Time {
+       t.ZoneOffset = -28800;
+       return t;
+}
+
+var marshalTests = []marshalTest{
+       marshalTest{10, "02010a"},
+       marshalTest{intStruct{64}, "3003020140"},
+       marshalTest{nestedStruct{intStruct{127}}, "3005300302017f"},
+       marshalTest{[]byte{1, 2, 3}, "0403010203"},
+       marshalTest{implicitTagTest{64}, "3003850140"},
+       marshalTest{explicitTagTest{64}, "3005a503020140"},
+       marshalTest{time.SecondsToUTC(0), "170d3730303130313030303030305a"},
+       marshalTest{time.SecondsToUTC(1258325776), "170d3039313131353232353631365a"},
+       marshalTest{setPST(time.SecondsToUTC(1258325776)), "17113039313131353232353631362d30383030"},
+       marshalTest{BitString{[]byte{0x80}, 1}, "03020780"},
+       marshalTest{BitString{[]byte{0x81, 0xf0}, 12}, "03030481f0"},
+       marshalTest{ObjectIdentifier([]int{1, 2, 3, 4}), "06032a0304"},
+       marshalTest{"test", "130474657374"},
+       marshalTest{ia5StringTest{"test"}, "3006160474657374"},
+       marshalTest{printableStringTest{"test"}, "3006130474657374"},
+}
+
+func TestMarshal(t *testing.T) {
+       for i, test := range marshalTests {
+               buf := bytes.NewBuffer(nil);
+               err := Marshal(buf, test.in);
+               if err != nil {
+                       t.Errorf("#%d failed: %s", i, err)
+               }
+               out, _ := hex.DecodeString(test.out);
+               if bytes.Compare(out, buf.Bytes()) != 0 {
+                       t.Errorf("#%d got: %x want %x", i, buf.Bytes(), out)
+               }
+       }
+}
index 572b4e63948c03441e8023b718858b534a18e7ad..c1488e41e09c649c944c700c97ec907a20627767 100644 (file)
@@ -32,7 +32,7 @@ func rawValueIsInteger(raw *asn1.RawValue) bool {
 // ParsePKCS1PrivateKey returns an RSA private key from its ASN.1 PKCS#1 DER encoded form.
 func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err os.Error) {
        var priv pkcs1PrivateKey;
-       err = asn1.Unmarshal(&priv, der);
+       _, err = asn1.Unmarshal(&priv, der);
        if err != nil {
                return
        }