]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/openpgp/packet: add basic routines
authorAdam Langley <agl@golang.org>
Thu, 3 Feb 2011 14:22:40 +0000 (09:22 -0500)
committerAdam Langley <agl@golang.org>
Thu, 3 Feb 2011 14:22:40 +0000 (09:22 -0500)
Since nobody suggested major changes to the higher level API, I'm
splitting up the lower level code for review. This is the first of the
changes for the packet reading/writing code.

It deliberately doesn't include a Makefile because the package is
incomplete.

R=bradfitzgo
CC=golang-dev
https://golang.org/cl/4080051

src/pkg/crypto/openpgp/packet/packet.go [new file with mode: 0644]
src/pkg/crypto/openpgp/packet/packet_test.go [new file with mode: 0644]

diff --git a/src/pkg/crypto/openpgp/packet/packet.go b/src/pkg/crypto/openpgp/packet/packet.go
new file mode 100644 (file)
index 0000000..80e25e2
--- /dev/null
@@ -0,0 +1,395 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package implements parsing and serialisation of OpenPGP packets, as
+// specified in RFC 4880.
+package packet
+
+import (
+       "crypto/aes"
+       "crypto/cast5"
+       "crypto/cipher"
+       "crypto/openpgp/error"
+       "io"
+       "os"
+)
+
+// readFull is the same as io.ReadFull except that reading zero bytes returns
+// ErrUnexpectedEOF rather than EOF.
+func readFull(r io.Reader, buf []byte) (n int, err os.Error) {
+       n, err = io.ReadFull(r, buf)
+       if err == os.EOF {
+               err = io.ErrUnexpectedEOF
+       }
+       return
+}
+
+// readLength reads an OpenPGP length from r. See RFC 4880, section 4.2.2.
+func readLength(r io.Reader) (length int64, isPartial bool, err os.Error) {
+       var buf [4]byte
+       _, err = readFull(r, buf[:1])
+       if err != nil {
+               return
+       }
+       switch {
+       case buf[0] < 192:
+               length = int64(buf[0])
+       case buf[0] < 224:
+               length = int64(buf[0]-192) << 8
+               _, err = readFull(r, buf[0:1])
+               if err != nil {
+                       return
+               }
+               length += int64(buf[0]) + 192
+       case buf[0] < 255:
+               length = int64(1) << (buf[0] & 0x1f)
+               isPartial = true
+       default:
+               _, err = readFull(r, buf[0:4])
+               if err != nil {
+                       return
+               }
+               length = int64(buf[0])<<24 |
+                       int64(buf[1])<<16 |
+                       int64(buf[2])<<8 |
+                       int64(buf[3])
+       }
+       return
+}
+
+// partialLengthReader wraps an io.Reader and handles OpenPGP partial lengths.
+// The continuation lengths are parsed and removed from the stream and EOF is
+// returned at the end of the packet. See RFC 4880, section 4.2.2.4.
+type partialLengthReader struct {
+       r         io.Reader
+       remaining int64
+       isPartial bool
+}
+
+func (r *partialLengthReader) Read(p []byte) (n int, err os.Error) {
+       for r.remaining == 0 {
+               if !r.isPartial {
+                       return 0, os.EOF
+               }
+               r.remaining, r.isPartial, err = readLength(r.r)
+               if err != nil {
+                       return 0, err
+               }
+       }
+
+       toRead := int64(len(p))
+       if toRead > r.remaining {
+               toRead = r.remaining
+       }
+
+       n, err = r.r.Read(p[:int(toRead)])
+       r.remaining -= int64(n)
+       if n < int(toRead) && err == os.EOF {
+               err = io.ErrUnexpectedEOF
+       }
+       return
+}
+
+// A spanReader is an io.LimitReader, but it returns ErrUnexpectedEOF if the
+// underlying Reader returns EOF before the limit has been reached.
+type spanReader struct {
+       r io.Reader
+       n int64
+}
+
+func (l *spanReader) Read(p []byte) (n int, err os.Error) {
+       if l.n <= 0 {
+               return 0, os.EOF
+       }
+       if int64(len(p)) > l.n {
+               p = p[0:l.n]
+       }
+       n, err = l.r.Read(p)
+       l.n -= int64(n)
+       if l.n > 0 && err == os.EOF {
+               err = io.ErrUnexpectedEOF
+       }
+       return
+}
+
+// readHeader parses a packet header and returns an io.Reader which will return
+// the contents of the packet. See RFC 4880, section 4.2.
+func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader, err os.Error) {
+       var buf [4]byte
+       _, err = io.ReadFull(r, buf[:1])
+       if err != nil {
+               return
+       }
+       if buf[0]&0x80 == 0 {
+               err = error.StructuralError("tag byte does not have MSB set")
+               return
+       }
+       if buf[0]&0x40 == 0 {
+               // Old format packet
+               tag = packetType((buf[0] & 0x3f) >> 2)
+               lengthType := buf[0] & 3
+               if lengthType == 3 {
+                       length = -1
+                       contents = r
+                       return
+               }
+               lengthBytes := 1 << lengthType
+               _, err = readFull(r, buf[0:lengthBytes])
+               if err != nil {
+                       return
+               }
+               for i := 0; i < lengthBytes; i++ {
+                       length <<= 8
+                       length |= int64(buf[i])
+               }
+               contents = &spanReader{r, length}
+               return
+       }
+
+       // New format packet
+       tag = packetType(buf[0] & 0x3f)
+       length, isPartial, err := readLength(r)
+       if err != nil {
+               return
+       }
+       if isPartial {
+               contents = &partialLengthReader{
+                       remaining: length,
+                       isPartial: true,
+                       r:         r,
+               }
+               length = -1
+       } else {
+               contents = &spanReader{r, length}
+       }
+       return
+}
+
+// serialiseHeader writes an OpenPGP packet header to w. See RFC 4880, section
+// 4.2.
+func serialiseHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
+       var buf [5]byte
+       var n int
+
+       buf[0] = 0x80 | 0x40 | byte(ptype)
+       if length < 192 {
+               buf[1] = byte(length)
+               n = 2
+       } else if length < 8384 {
+               length -= 192
+               buf[1] = byte(length >> 8)
+               buf[2] = byte(length)
+               n = 3
+       } else {
+               buf[0] = 255
+               buf[1] = byte(length >> 24)
+               buf[2] = byte(length >> 16)
+               buf[3] = byte(length >> 8)
+               buf[4] = byte(length)
+               n = 5
+       }
+
+       _, err = w.Write(buf[:n])
+       return
+}
+
+// Packet represents an OpenPGP packet. Users are expected to try casting
+// instances of this interface to specific packet types.
+type Packet interface {
+       parse(io.Reader) os.Error
+}
+
+// consumeAll reads from the given Reader until error, returning the number of
+// bytes read.
+func consumeAll(r io.Reader) (n int64, err os.Error) {
+       var m int
+       var buf [1024]byte
+
+       for {
+               m, err = r.Read(buf[:])
+               n += int64(m)
+               if err == os.EOF {
+                       err = nil
+                       return
+               }
+               if err != nil {
+                       return
+               }
+       }
+
+       panic("unreachable")
+}
+
+// packetType represents the numeric ids of the different OpenPGP packet types. See
+// http://www.iana.org/assignments/pgp-parameters/pgp-parameters.xhtml#pgp-parameters-2
+type packetType uint8
+
+const (
+       packetTypeEncryptedKey              packetType = 1
+       packetTypeSignature                 packetType = 2
+       packetTypeSymmetricKeyEncrypted     packetType = 3
+       packetTypeOnePassSignature          packetType = 4
+       packetTypePrivateKey                packetType = 5
+       packetTypePublicKey                 packetType = 6
+       packetTypePrivateSubkey             packetType = 7
+       packetTypeCompressed                packetType = 8
+       packetTypeSymmetricallyEncrypted    packetType = 9
+       packetTypeLiteralData               packetType = 11
+       packetTypeUserId                    packetType = 13
+       packetTypePublicSubkey              packetType = 14
+       packetTypeSymmetricallyEncryptedMDC packetType = 18
+)
+
+// Read reads a single OpenPGP packet from the given io.Reader. If there is an
+// error parsing a packet, the whole packet is consumed from the input.
+func Read(r io.Reader) (p Packet, err os.Error) {
+       tag, _, contents, err := readHeader(r)
+       if err != nil {
+               return
+       }
+
+       switch tag {
+       case packetTypeEncryptedKey:
+               p = new(EncryptedKey)
+       case packetTypeSignature:
+               p = new(Signature)
+       case packetTypeSymmetricKeyEncrypted:
+               p = new(SymmetricKeyEncrypted)
+       case packetTypeOnePassSignature:
+               p = new(OnePassSignature)
+       case packetTypePrivateKey, packetTypePrivateSubkey:
+               pk := new(PrivateKey)
+               if tag == packetTypePrivateSubkey {
+                       pk.IsSubKey = true
+               }
+               p = pk
+       case packetTypePublicKey, packetTypePublicSubkey:
+               pk := new(PublicKey)
+               if tag == packetTypePublicSubkey {
+                       pk.IsSubKey = true
+               }
+               p = pk
+       case packetTypeCompressed:
+               p = new(Compressed)
+       case packetTypeSymmetricallyEncrypted:
+               p = new(SymmetricallyEncrypted)
+       case packetTypeLiteralData:
+               p = new(LiteralData)
+       case packetTypeUserId:
+               p = new(UserId)
+       case packetTypeSymmetricallyEncryptedMDC:
+               se := new(SymmetricallyEncrypted)
+               se.MDC = true
+               p = se
+       default:
+               err = error.UnknownPacketTypeError(tag)
+       }
+       if p != nil {
+               err = p.parse(contents)
+       }
+       if err != nil {
+               consumeAll(contents)
+       }
+       return
+}
+
+// SignatureType represents the different semantic meanings of an OpenPGP
+// signature. See RFC 4880, section 5.2.1.
+type SignatureType uint8
+
+const (
+       SigTypeBinary        SignatureType = 0
+       SigTypeText          SignatureType = 1
+       SigTypeGenericCert   = 0x10
+       SigTypePersonaCert   = 0x11
+       SigTypeCasualCert    = 0x12
+       SigTypePositiveCert  = 0x13
+       SigTypeSubkeyBinding = 0x18
+)
+
+// PublicKeyAlgorithm represents the different public key system specified for
+// OpenPGP. See
+// http://www.iana.org/assignments/pgp-parameters/pgp-parameters.xhtml#pgp-parameters-12
+type PublicKeyAlgorithm uint8
+
+const (
+       PubKeyAlgoRSA            PublicKeyAlgorithm = 1
+       PubKeyAlgoRSAEncryptOnly PublicKeyAlgorithm = 2
+       PubKeyAlgoRSASignOnly    PublicKeyAlgorithm = 3
+       PubKeyAlgoElgamal        PublicKeyAlgorithm = 16
+       PubKeyAlgoDSA            PublicKeyAlgorithm = 17
+)
+
+// CipherFunction represents the different block ciphers specified for OpenPGP. See
+// http://www.iana.org/assignments/pgp-parameters/pgp-parameters.xhtml#pgp-parameters-13
+type CipherFunction uint8
+
+const (
+       CipherCAST5  = 3
+       CipherAES128 = 7
+       CipherAES192 = 8
+       CipherAES256 = 9
+)
+
+// keySize returns the key size, in bytes, of cipher.
+func (cipher CipherFunction) keySize() int {
+       switch cipher {
+       case CipherCAST5:
+               return cast5.KeySize
+       case CipherAES128:
+               return 16
+       case CipherAES192:
+               return 24
+       case CipherAES256:
+               return 32
+       }
+       return 0
+}
+
+// blockSize returns the block size, in bytes, of cipher.
+func (cipher CipherFunction) blockSize() int {
+       switch cipher {
+       case CipherCAST5:
+               return 8
+       case CipherAES128, CipherAES192, CipherAES256:
+               return 16
+       }
+       return 0
+}
+
+// new returns a fresh instance of the given cipher.
+func (cipher CipherFunction) new(key []byte) (block cipher.Block) {
+       switch cipher {
+       case CipherCAST5:
+               block, _ = cast5.NewCipher(key)
+       case CipherAES128, CipherAES192, CipherAES256:
+               block, _ = aes.NewCipher(key)
+       }
+       return
+}
+
+// readMPI reads a big integer from r. The bit length returned is the bit
+// length that was specified in r. This is preserved so that the integer can be
+// reserialised exactly.
+func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) {
+       var buf [2]byte
+       _, err = readFull(r, buf[0:])
+       if err != nil {
+               return
+       }
+       bitLength = uint16(buf[0])<<8 | uint16(buf[1])
+       numBytes := (int(bitLength) + 7) / 8
+       mpi = make([]byte, numBytes)
+       _, err = readFull(r, mpi)
+       return
+}
+
+// writeMPI serialises a big integer to r.
+func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) {
+       _, err = w.Write([]byte{byte(bitLength >> 8), byte(bitLength)})
+       if err == nil {
+               _, err = w.Write(mpiBytes)
+       }
+       return
+}
diff --git a/src/pkg/crypto/openpgp/packet/packet_test.go b/src/pkg/crypto/openpgp/packet/packet_test.go
new file mode 100644 (file)
index 0000000..050b734
--- /dev/null
@@ -0,0 +1,192 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package packet
+
+import (
+       "bytes"
+       "crypto/openpgp/error"
+       "encoding/hex"
+       "fmt"
+       "io"
+       "io/ioutil"
+       "os"
+       "testing"
+)
+
+func TestReadFull(t *testing.T) {
+       var out [4]byte
+
+       b := bytes.NewBufferString("foo")
+       n, err := readFull(b, out[:3])
+       if n != 3 || err != nil {
+               t.Errorf("full read failed n:%d err:%s", n, err)
+       }
+
+       b = bytes.NewBufferString("foo")
+       n, err = readFull(b, out[:4])
+       if n != 3 || err != io.ErrUnexpectedEOF {
+               t.Errorf("partial read failed n:%d err:%s", n, err)
+       }
+
+       b = bytes.NewBuffer(nil)
+       n, err = readFull(b, out[:3])
+       if n != 0 || err != io.ErrUnexpectedEOF {
+               t.Errorf("empty read failed n:%d err:%s", n, err)
+       }
+}
+
+func readerFromHex(s string) io.Reader {
+       data, err := hex.DecodeString(s)
+       if err != nil {
+               panic("readerFromHex: bad input")
+       }
+       return bytes.NewBuffer(data)
+}
+
+var readLengthTests = []struct {
+       hexInput  string
+       length    int64
+       isPartial bool
+       err       os.Error
+}{
+       {"", 0, false, io.ErrUnexpectedEOF},
+       {"1f", 31, false, nil},
+       {"c0", 0, false, io.ErrUnexpectedEOF},
+       {"c101", 256 + 1 + 192, false, nil},
+       {"e0", 1, true, nil},
+       {"e1", 2, true, nil},
+       {"e2", 4, true, nil},
+       {"ff", 0, false, io.ErrUnexpectedEOF},
+       {"ff00", 0, false, io.ErrUnexpectedEOF},
+       {"ff0000", 0, false, io.ErrUnexpectedEOF},
+       {"ff000000", 0, false, io.ErrUnexpectedEOF},
+       {"ff00000000", 0, false, nil},
+       {"ff01020304", 16909060, false, nil},
+}
+
+func TestReadLength(t *testing.T) {
+       for i, test := range readLengthTests {
+               length, isPartial, err := readLength(readerFromHex(test.hexInput))
+               if test.err != nil {
+                       if err != test.err {
+                               t.Errorf("%d: expected different error got:%s want:%s", i, err, test.err)
+                       }
+                       continue
+               }
+               if err != nil {
+                       t.Errorf("%d: unexpected error: %s", i, err)
+                       continue
+               }
+               if length != test.length || isPartial != test.isPartial {
+                       t.Errorf("%d: bad result got:(%d,%t) want:(%d,%t)", i, length, isPartial, test.length, test.isPartial)
+               }
+       }
+}
+
+var partialLengthReaderTests = []struct {
+       hexInput  string
+       err       os.Error
+       hexOutput string
+}{
+       {"e0", io.ErrUnexpectedEOF, ""},
+       {"e001", io.ErrUnexpectedEOF, ""},
+       {"e0010102", nil, "0102"},
+       {"ff00000000", nil, ""},
+       {"e10102e1030400", nil, "01020304"},
+       {"e101", io.ErrUnexpectedEOF, ""},
+}
+
+func TestPartialLengthReader(t *testing.T) {
+       for i, test := range partialLengthReaderTests {
+               r := &partialLengthReader{readerFromHex(test.hexInput), 0, true}
+               out, err := ioutil.ReadAll(r)
+               if test.err != nil {
+                       if err != test.err {
+                               t.Errorf("%d: expected different error got:%s want:%s", i, err, test.err)
+                       }
+                       continue
+               }
+               if err != nil {
+                       t.Errorf("%d: unexpected error: %s", i, err)
+                       continue
+               }
+
+               got := fmt.Sprintf("%x", out)
+               if got != test.hexOutput {
+                       t.Errorf("%d: got:%s want:%s", test.hexOutput, got)
+               }
+       }
+}
+
+var readHeaderTests = []struct {
+       hexInput        string
+       structuralError bool
+       unexpectedEOF   bool
+       tag             int
+       length          int64
+       hexOutput       string
+}{
+       {"", false, false, 0, 0, ""},
+       {"7f", true, false, 0, 0, ""},
+
+       // Old format headers
+       {"80", false, true, 0, 0, ""},
+       {"8001", false, true, 0, 1, ""},
+       {"800102", false, false, 0, 1, "02"},
+       {"81000102", false, false, 0, 1, "02"},
+       {"820000000102", false, false, 0, 1, "02"},
+       {"860000000102", false, false, 1, 1, "02"},
+       {"83010203", false, false, 0, -1, "010203"},
+
+       // New format headers
+       {"c0", false, true, 0, 0, ""},
+       {"c000", false, false, 0, 0, ""},
+       {"c00102", false, false, 0, 1, "02"},
+       {"c0020203", false, false, 0, 2, "0203"},
+       {"c00202", false, true, 0, 2, ""},
+       {"c3020203", false, false, 3, 2, "0203"},
+}
+
+func TestReadHeader(t *testing.T) {
+       for i, test := range readHeaderTests {
+               tag, length, contents, err := readHeader(readerFromHex(test.hexInput))
+               if test.structuralError {
+                       if _, ok := err.(error.StructuralError); ok {
+                               continue
+                       }
+                       t.Errorf("%d: expected StructuralError, got:%s", i, err)
+                       continue
+               }
+               if err != nil {
+                       if len(test.hexInput) == 0 && err == os.EOF {
+                               continue
+                       }
+                       if !test.unexpectedEOF || err != io.ErrUnexpectedEOF {
+                               t.Errorf("%d: unexpected error from readHeader: %s", i, err)
+                       }
+                       continue
+               }
+               if int(tag) != test.tag || length != test.length {
+                       t.Errorf("%d: got:(%d,%d) want:(%d,%d)", i, int(tag), length, test.tag, test.length)
+                       continue
+               }
+
+               body, err := ioutil.ReadAll(contents)
+               if err != nil {
+                       if !test.unexpectedEOF || err != io.ErrUnexpectedEOF {
+                               t.Errorf("%d: unexpected error from contents: %s", i, err)
+                       }
+                       continue
+               }
+               if test.unexpectedEOF {
+                       t.Errorf("%d: expected ErrUnexpectedEOF from contents but got no error", i)
+                       continue
+               }
+               got := fmt.Sprintf("%x", body)
+               if got != test.hexOutput {
+                       t.Errorf("%d: got:%s want:%s", i, got, test.hexOutput)
+               }
+       }
+}