]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: rewrite some messages with golang.org/x/crypto/cryptobyte
authorFilippo Valsorda <filippo@golang.org>
Thu, 25 Oct 2018 01:22:00 +0000 (21:22 -0400)
committerFilippo Valsorda <filippo@golang.org>
Mon, 29 Oct 2018 17:05:55 +0000 (17:05 +0000)
As a first round, rewrite those handshake message types which can be
reused in TLS 1.3 with golang.org/x/crypto/cryptobyte. All other types
changed significantly in TLS 1.3 and will require separate
implementations. They will be ported to cryptobyte in a later CL.

The only semantic changes should be enforcing the random length on the
marshaling side, enforcing a couple more "must not be empty" on the
unmarshaling side, and checking the rest of the SNI list even if we only
take the first.

Change-Id: Idd2ced60c558fafcf02ee489195b6f3b4735fe22
Reviewed-on: https://go-review.googlesource.com/c/144115
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_messages.go
src/crypto/tls/handshake_messages_test.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_test.go
src/go/build/deps_test.go

index 8e236434409d2a9a5898775f6bef5dc18d7c2ebe..dae5fd103a1abe1065f3c18db772cc1aea137e10 100644 (file)
@@ -899,7 +899,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
                m = new(certificateMsg)
        case typeCertificateRequest:
                m = &certificateRequestMsg{
-                       hasSignatureAndHash: c.vers >= VersionTLS12,
+                       hasSignatureAlgorithm: c.vers >= VersionTLS12,
                }
        case typeCertificateStatus:
                m = new(certificateStatusMsg)
@@ -911,7 +911,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
                m = new(clientKeyExchangeMsg)
        case typeCertificateVerify:
                m = &certificateVerifyMsg{
-                       hasSignatureAndHash: c.vers >= VersionTLS12,
+                       hasSignatureAlgorithm: c.vers >= VersionTLS12,
                }
        case typeNextProtocol:
                m = new(nextProtoMsg)
index af290e33a7852b140d3fa89d28077198616aac72..fb74f79bd8830d4ea98f168b6d4f10eb41808c37 100644 (file)
@@ -471,7 +471,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
 
        if chainToSend != nil && len(chainToSend.Certificate) > 0 {
                certVerify := &certificateVerifyMsg{
-                       hasSignatureAndHash: c.vers >= VersionTLS12,
+                       hasSignatureAlgorithm: c.vers >= VersionTLS12,
                }
 
                key, ok := chainToSend.PrivateKey.(crypto.Signer)
@@ -486,7 +486,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                        return err
                }
                // SignatureAndHashAlgorithm was introduced in TLS 1.2.
-               if certVerify.hasSignatureAndHash {
+               if certVerify.hasSignatureAlgorithm {
                        certVerify.signatureAlgorithm = signatureAlgorithm
                }
                digest, err := hs.finishedHash.hashForClientCertificate(sigType, hashFunc, hs.masterSecret)
@@ -739,7 +739,7 @@ func (hs *clientHandshakeState) getCertificate(certReq *certificateRequestMsg) (
        if c.config.GetClientCertificate != nil {
                var signatureSchemes []SignatureScheme
 
-               if !certReq.hasSignatureAndHash {
+               if !certReq.hasSignatureAlgorithm {
                        // Prior to TLS 1.2, the signature schemes were not
                        // included in the certificate request message. In this
                        // case we use a plausible list based on the acceptable
index c5d995060731da52f51ede9958d0cc87a0558586..d6785550a29949bdf0ca460439c8cb489bccbb1d 100644 (file)
@@ -5,9 +5,49 @@
 package tls
 
 import (
+       "fmt"
+       "golang_org/x/crypto/cryptobyte"
        "strings"
 )
 
+// The marshalingFunction type is an adapter to allow the use of ordinary
+// functions as cryptobyte.MarshalingValue.
+type marshalingFunction func(b *cryptobyte.Builder) error
+
+func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
+       return f(b)
+}
+
+// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
+// the length of the sequence is not the value specified, it produces an error.
+func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
+       b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
+               if len(v) != n {
+                       return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
+               }
+               b.AddBytes(v)
+               return nil
+       }))
+}
+
+// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
+// []byte instead of a cryptobyte.String.
+func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
+       return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
+}
+
+// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
+// []byte instead of a cryptobyte.String.
+func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
+       return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
+}
+
+// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
+// []byte instead of a cryptobyte.String.
+func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
+       return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
+}
+
 type clientHelloMsg struct {
        raw                          []byte
        vers                         uint16
@@ -34,442 +74,289 @@ func (m *clientHelloMsg) marshal() []byte {
                return m.raw
        }
 
-       length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
-       numExtensions := 0
-       extensionsLength := 0
-       if m.nextProtoNeg {
-               numExtensions++
-       }
-       if m.ocspStapling {
-               extensionsLength += 1 + 2 + 2
-               numExtensions++
-       }
-       if len(m.serverName) > 0 {
-               extensionsLength += 5 + len(m.serverName)
-               numExtensions++
-       }
-       if len(m.supportedCurves) > 0 {
-               extensionsLength += 2 + 2*len(m.supportedCurves)
-               numExtensions++
-       }
-       if len(m.supportedPoints) > 0 {
-               extensionsLength += 1 + len(m.supportedPoints)
-               numExtensions++
-       }
-       if m.ticketSupported {
-               extensionsLength += len(m.sessionTicket)
-               numExtensions++
-       }
-       if len(m.supportedSignatureAlgorithms) > 0 {
-               extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
-               numExtensions++
-       }
-       if m.secureRenegotiationSupported {
-               extensionsLength += 1 + len(m.secureRenegotiation)
-               numExtensions++
-       }
-       if len(m.alpnProtocols) > 0 {
-               extensionsLength += 2
-               for _, s := range m.alpnProtocols {
-                       if l := len(s); l == 0 || l > 255 {
-                               panic("invalid ALPN protocol")
+       var b cryptobyte.Builder
+       b.AddUint8(typeClientHello)
+       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+               b.AddUint16(m.vers)
+               addBytesWithLength(b, m.random, 32)
+               b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                       b.AddBytes(m.sessionId)
+               })
+               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                       for _, suite := range m.cipherSuites {
+                               b.AddUint16(suite)
+                       }
+               })
+               b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                       b.AddBytes(m.compressionMethods)
+               })
+
+               // If extensions aren't present, omit them.
+               var extensionsPresent bool
+               bWithoutExtensions := *b
+
+               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                       if m.nextProtoNeg {
+                               // draft-agl-tls-nextprotoneg-04
+                               b.AddUint16(extensionNextProtoNeg)
+                               b.AddUint16(0) // empty extension_data
+                       }
+                       if len(m.serverName) > 0 {
+                               // RFC 6066, Section 3
+                               b.AddUint16(extensionServerName)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               b.AddUint8(0) // name_type = host_name
+                                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                       b.AddBytes([]byte(m.serverName))
+                                               })
+                                       })
+                               })
+                       }
+                       if m.ocspStapling {
+                               // RFC 4366, Section 3.6
+                               b.AddUint16(extensionStatusRequest)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint8(1)  // status_type = ocsp
+                                       b.AddUint16(0) // empty responder_id_list
+                                       b.AddUint16(0) // empty request_extensions
+                               })
+                       }
+                       if len(m.supportedCurves) > 0 {
+                               // RFC 4492, Section 5.1.1
+                               b.AddUint16(extensionSupportedCurves)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               for _, curve := range m.supportedCurves {
+                                                       b.AddUint16(uint16(curve))
+                                               }
+                                       })
+                               })
+                       }
+                       if len(m.supportedPoints) > 0 {
+                               // RFC 4492, Section 5.1.2
+                               b.AddUint16(extensionSupportedPoints)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               b.AddBytes(m.supportedPoints)
+                                       })
+                               })
+                       }
+                       if m.ticketSupported {
+                               // RFC 5077, Section 3.2
+                               b.AddUint16(extensionSessionTicket)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddBytes(m.sessionTicket)
+                               })
+                       }
+                       if len(m.supportedSignatureAlgorithms) > 0 {
+                               // RFC 5246, Section 7.4.1.4.1
+                               b.AddUint16(extensionSignatureAlgorithms)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               for _, sigAlgo := range m.supportedSignatureAlgorithms {
+                                                       b.AddUint16(uint16(sigAlgo))
+                                               }
+                                       })
+                               })
+                       }
+                       if m.secureRenegotiationSupported {
+                               // RFC 5746, Section 3.2
+                               b.AddUint16(extensionRenegotiationInfo)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               b.AddBytes(m.secureRenegotiation)
+                                       })
+                               })
+                       }
+                       if len(m.alpnProtocols) > 0 {
+                               // RFC 7301, Section 3.1
+                               b.AddUint16(extensionALPN)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               for _, proto := range m.alpnProtocols {
+                                                       b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                               b.AddBytes([]byte(proto))
+                                                       })
+                                               }
+                                       })
+                               })
+                       }
+                       if m.scts {
+                               // RFC 6962, Section 3.3.1
+                               b.AddUint16(extensionSCT)
+                               b.AddUint16(0) // empty extension_data
                        }
-                       extensionsLength++
-                       extensionsLength += len(s)
-               }
-               numExtensions++
-       }
-       if m.scts {
-               numExtensions++
-       }
-       if numExtensions > 0 {
-               extensionsLength += 4 * numExtensions
-               length += 2 + extensionsLength
-       }
-
-       x := make([]byte, 4+length)
-       x[0] = typeClientHello
-       x[1] = uint8(length >> 16)
-       x[2] = uint8(length >> 8)
-       x[3] = uint8(length)
-       x[4] = uint8(m.vers >> 8)
-       x[5] = uint8(m.vers)
-       copy(x[6:38], m.random)
-       x[38] = uint8(len(m.sessionId))
-       copy(x[39:39+len(m.sessionId)], m.sessionId)
-       y := x[39+len(m.sessionId):]
-       y[0] = uint8(len(m.cipherSuites) >> 7)
-       y[1] = uint8(len(m.cipherSuites) << 1)
-       for i, suite := range m.cipherSuites {
-               y[2+i*2] = uint8(suite >> 8)
-               y[3+i*2] = uint8(suite)
-       }
-       z := y[2+len(m.cipherSuites)*2:]
-       z[0] = uint8(len(m.compressionMethods))
-       copy(z[1:], m.compressionMethods)
-
-       z = z[1+len(m.compressionMethods):]
-       if numExtensions > 0 {
-               z[0] = byte(extensionsLength >> 8)
-               z[1] = byte(extensionsLength)
-               z = z[2:]
-       }
-       if m.nextProtoNeg {
-               z[0] = byte(extensionNextProtoNeg >> 8)
-               z[1] = byte(extensionNextProtoNeg & 0xff)
-               // The length is always 0
-               z = z[4:]
-       }
-       if len(m.serverName) > 0 {
-               z[0] = byte(extensionServerName >> 8)
-               z[1] = byte(extensionServerName & 0xff)
-               l := len(m.serverName) + 5
-               z[2] = byte(l >> 8)
-               z[3] = byte(l)
-               z = z[4:]
-
-               // RFC 3546, Section 3.1
-               //
-               // struct {
-               //     NameType name_type;
-               //     select (name_type) {
-               //         case host_name: HostName;
-               //     } name;
-               // } ServerName;
-               //
-               // enum {
-               //     host_name(0), (255)
-               // } NameType;
-               //
-               // opaque HostName<1..2^16-1>;
-               //
-               // struct {
-               //     ServerName server_name_list<1..2^16-1>
-               // } ServerNameList;
-
-               z[0] = byte((len(m.serverName) + 3) >> 8)
-               z[1] = byte(len(m.serverName) + 3)
-               z[3] = byte(len(m.serverName) >> 8)
-               z[4] = byte(len(m.serverName))
-               copy(z[5:], []byte(m.serverName))
-               z = z[l:]
-       }
-       if m.ocspStapling {
-               // RFC 4366, Section 3.6
-               z[0] = byte(extensionStatusRequest >> 8)
-               z[1] = byte(extensionStatusRequest)
-               z[2] = 0
-               z[3] = 5
-               z[4] = 1 // OCSP type
-               // Two zero valued uint16s for the two lengths.
-               z = z[9:]
-       }
-       if len(m.supportedCurves) > 0 {
-               // RFC 4492, Section 5.5.1
-               z[0] = byte(extensionSupportedCurves >> 8)
-               z[1] = byte(extensionSupportedCurves)
-               l := 2 + 2*len(m.supportedCurves)
-               z[2] = byte(l >> 8)
-               z[3] = byte(l)
-               l -= 2
-               z[4] = byte(l >> 8)
-               z[5] = byte(l)
-               z = z[6:]
-               for _, curve := range m.supportedCurves {
-                       z[0] = byte(curve >> 8)
-                       z[1] = byte(curve)
-                       z = z[2:]
-               }
-       }
-       if len(m.supportedPoints) > 0 {
-               // RFC 4492, Section 5.5.2
-               z[0] = byte(extensionSupportedPoints >> 8)
-               z[1] = byte(extensionSupportedPoints)
-               l := 1 + len(m.supportedPoints)
-               z[2] = byte(l >> 8)
-               z[3] = byte(l)
-               l--
-               z[4] = byte(l)
-               z = z[5:]
-               for _, pointFormat := range m.supportedPoints {
-                       z[0] = pointFormat
-                       z = z[1:]
-               }
-       }
-       if m.ticketSupported {
-               // RFC 5077, Section 3.2
-               z[0] = byte(extensionSessionTicket >> 8)
-               z[1] = byte(extensionSessionTicket)
-               l := len(m.sessionTicket)
-               z[2] = byte(l >> 8)
-               z[3] = byte(l)
-               z = z[4:]
-               copy(z, m.sessionTicket)
-               z = z[len(m.sessionTicket):]
-       }
-       if len(m.supportedSignatureAlgorithms) > 0 {
-               // RFC 5246, Section 7.4.1.4.1
-               z[0] = byte(extensionSignatureAlgorithms >> 8)
-               z[1] = byte(extensionSignatureAlgorithms)
-               l := 2 + 2*len(m.supportedSignatureAlgorithms)
-               z[2] = byte(l >> 8)
-               z[3] = byte(l)
-               z = z[4:]
-
-               l -= 2
-               z[0] = byte(l >> 8)
-               z[1] = byte(l)
-               z = z[2:]
-               for _, sigAlgo := range m.supportedSignatureAlgorithms {
-                       z[0] = byte(sigAlgo >> 8)
-                       z[1] = byte(sigAlgo)
-                       z = z[2:]
-               }
-       }
-       if m.secureRenegotiationSupported {
-               z[0] = byte(extensionRenegotiationInfo >> 8)
-               z[1] = byte(extensionRenegotiationInfo & 0xff)
-               z[2] = 0
-               z[3] = byte(len(m.secureRenegotiation) + 1)
-               z[4] = byte(len(m.secureRenegotiation))
-               z = z[5:]
-               copy(z, m.secureRenegotiation)
-               z = z[len(m.secureRenegotiation):]
-       }
-       if len(m.alpnProtocols) > 0 {
-               z[0] = byte(extensionALPN >> 8)
-               z[1] = byte(extensionALPN & 0xff)
-               lengths := z[2:]
-               z = z[6:]
-
-               stringsLength := 0
-               for _, s := range m.alpnProtocols {
-                       l := len(s)
-                       z[0] = byte(l)
-                       copy(z[1:], s)
-                       z = z[1+l:]
-                       stringsLength += 1 + l
-               }
 
-               lengths[2] = byte(stringsLength >> 8)
-               lengths[3] = byte(stringsLength)
-               stringsLength += 2
-               lengths[0] = byte(stringsLength >> 8)
-               lengths[1] = byte(stringsLength)
-       }
-       if m.scts {
-               // RFC 6962, Section 3.3.1
-               z[0] = byte(extensionSCT >> 8)
-               z[1] = byte(extensionSCT)
-               // zero uint16 for the zero-length extension_data
-               z = z[4:]
-       }
+                       extensionsPresent = len(b.BytesOrPanic()) > 2
+               })
 
-       m.raw = x
+               if !extensionsPresent {
+                       *b = bWithoutExtensions
+               }
+       })
 
-       return x
+       m.raw = b.BytesOrPanic()
+       return m.raw
 }
 
 func (m *clientHelloMsg) unmarshal(data []byte) bool {
-       if len(data) < 42 {
-               return false
-       }
-       m.raw = data
-       m.vers = uint16(data[4])<<8 | uint16(data[5])
-       m.random = data[6:38]
-       sessionIdLen := int(data[38])
-       if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
-               return false
-       }
-       m.sessionId = data[39 : 39+sessionIdLen]
-       data = data[39+sessionIdLen:]
-       if len(data) < 2 {
+       *m = clientHelloMsg{raw: data}
+       s := cryptobyte.String(data)
+
+       if !s.Skip(4) || // message type and uint24 length field
+               !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
+               !readUint8LengthPrefixed(&s, &m.sessionId) {
                return false
        }
-       // cipherSuiteLen is the number of bytes of cipher suite numbers. Since
-       // they are uint16s, the number must be even.
-       cipherSuiteLen := int(data[0])<<8 | int(data[1])
-       if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
+
+       var cipherSuites cryptobyte.String
+       if !s.ReadUint16LengthPrefixed(&cipherSuites) {
                return false
        }
-       numCipherSuites := cipherSuiteLen / 2
-       m.cipherSuites = make([]uint16, numCipherSuites)
-       for i := 0; i < numCipherSuites; i++ {
-               m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
-               if m.cipherSuites[i] == scsvRenegotiation {
+       m.cipherSuites = []uint16{}
+       m.secureRenegotiationSupported = false
+       for !cipherSuites.Empty() {
+               var suite uint16
+               if !cipherSuites.ReadUint16(&suite) {
+                       return false
+               }
+               if suite == scsvRenegotiation {
                        m.secureRenegotiationSupported = true
                }
+               m.cipherSuites = append(m.cipherSuites, suite)
        }
-       data = data[2+cipherSuiteLen:]
-       if len(data) < 1 {
-               return false
-       }
-       compressionMethodsLen := int(data[0])
-       if len(data) < 1+compressionMethodsLen {
+
+       if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
                return false
        }
-       m.compressionMethods = data[1 : 1+compressionMethodsLen]
-
-       data = data[1+compressionMethodsLen:]
-
-       m.nextProtoNeg = false
-       m.serverName = ""
-       m.ocspStapling = false
-       m.ticketSupported = false
-       m.sessionTicket = nil
-       m.supportedSignatureAlgorithms = nil
-       m.alpnProtocols = nil
-       m.scts = false
 
-       if len(data) == 0 {
+       if s.Empty() {
                // ClientHello is optionally followed by extension data
                return true
        }
-       if len(data) < 2 {
-               return false
-       }
 
-       extensionsLength := int(data[0])<<8 | int(data[1])
-       data = data[2:]
-       if extensionsLength != len(data) {
+       var extensions cryptobyte.String
+       if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
                return false
        }
 
-       for len(data) != 0 {
-               if len(data) < 4 {
-                       return false
-               }
-               extension := uint16(data[0])<<8 | uint16(data[1])
-               length := int(data[2])<<8 | int(data[3])
-               data = data[4:]
-               if len(data) < length {
+       for !extensions.Empty() {
+               var extension uint16
+               var extData cryptobyte.String
+               if !extensions.ReadUint16(&extension) ||
+                       !extensions.ReadUint16LengthPrefixed(&extData) {
                        return false
                }
 
                switch extension {
                case extensionServerName:
-                       d := data[:length]
-                       if len(d) < 2 {
+                       // RFC 6066, Section 3
+                       var nameList cryptobyte.String
+                       if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
                                return false
                        }
-                       namesLen := int(d[0])<<8 | int(d[1])
-                       d = d[2:]
-                       if len(d) != namesLen {
-                               return false
-                       }
-                       for len(d) > 0 {
-                               if len(d) < 3 {
+                       for !nameList.Empty() {
+                               var nameType uint8
+                               var serverName cryptobyte.String
+                               if !nameList.ReadUint8(&nameType) ||
+                                       !nameList.ReadUint16LengthPrefixed(&serverName) ||
+                                       serverName.Empty() {
                                        return false
                                }
-                               nameType := d[0]
-                               nameLen := int(d[1])<<8 | int(d[2])
-                               d = d[3:]
-                               if len(d) < nameLen {
+                               if nameType != 0 {
+                                       continue
+                               }
+                               if len(m.serverName) != 0 {
+                                       // Multiple names of the same name_type are prohibited.
                                        return false
                                }
-                               if nameType == 0 {
-                                       m.serverName = string(d[:nameLen])
-                                       // An SNI value may not include a trailing dot.
-                                       // See RFC 6066, Section 3.
-                                       if strings.HasSuffix(m.serverName, ".") {
-                                               return false
-                                       }
-                                       break
+                               m.serverName = string(serverName)
+                               // An SNI value may not include a trailing dot.
+                               if strings.HasSuffix(m.serverName, ".") {
+                                       return false
                                }
-                               d = d[nameLen:]
                        }
                case extensionNextProtoNeg:
-                       if length > 0 {
-                               return false
-                       }
+                       // draft-agl-tls-nextprotoneg-04
                        m.nextProtoNeg = true
                case extensionStatusRequest:
-                       m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
-               case extensionSupportedCurves:
-                       // RFC 4492, Section 5.5.1
-                       if length < 2 {
+                       // RFC 4366, Section 3.6
+                       var statusType uint8
+                       var ignored cryptobyte.String
+                       if !extData.ReadUint8(&statusType) ||
+                               !extData.ReadUint16LengthPrefixed(&ignored) ||
+                               !extData.ReadUint16LengthPrefixed(&ignored) {
                                return false
                        }
-                       l := int(data[0])<<8 | int(data[1])
-                       if l%2 == 1 || length != l+2 {
+                       m.ocspStapling = statusType == statusTypeOCSP
+               case extensionSupportedCurves:
+                       // RFC 4492, Section 5.1.1
+                       var curves cryptobyte.String
+                       if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
                                return false
                        }
-                       numCurves := l / 2
-                       m.supportedCurves = make([]CurveID, numCurves)
-                       d := data[2:]
-                       for i := 0; i < numCurves; i++ {
-                               m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
-                               d = d[2:]
+                       for !curves.Empty() {
+                               var curve uint16
+                               if !curves.ReadUint16(&curve) {
+                                       return false
+                               }
+                               m.supportedCurves = append(m.supportedCurves, CurveID(curve))
                        }
                case extensionSupportedPoints:
-                       // RFC 4492, Section 5.5.2
-                       if length < 1 {
-                               return false
-                       }
-                       l := int(data[0])
-                       if length != l+1 {
+                       // RFC 4492, Section 5.1.2
+                       if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
+                               len(m.supportedPoints) == 0 {
                                return false
                        }
-                       m.supportedPoints = make([]uint8, l)
-                       copy(m.supportedPoints, data[1:])
                case extensionSessionTicket:
                        // RFC 5077, Section 3.2
                        m.ticketSupported = true
-                       m.sessionTicket = data[:length]
+                       extData.ReadBytes(&m.sessionTicket, len(extData))
                case extensionSignatureAlgorithms:
                        // RFC 5246, Section 7.4.1.4.1
-                       if length < 2 || length&1 != 0 {
+                       var sigAndAlgs cryptobyte.String
+                       if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
                                return false
                        }
-                       l := int(data[0])<<8 | int(data[1])
-                       if l != length-2 {
-                               return false
-                       }
-                       n := l / 2
-                       d := data[2:]
-                       m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
-                       for i := range m.supportedSignatureAlgorithms {
-                               m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
-                               d = d[2:]
+                       for !sigAndAlgs.Empty() {
+                               var sigAndAlg uint16
+                               if !sigAndAlgs.ReadUint16(&sigAndAlg) {
+                                       return false
+                               }
+                               m.supportedSignatureAlgorithms = append(
+                                       m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
                        }
                case extensionRenegotiationInfo:
-                       if length == 0 {
+                       // RFC 5746, Section 3.2
+                       if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
                                return false
                        }
-                       d := data[:length]
-                       l := int(d[0])
-                       d = d[1:]
-                       if l != len(d) {
-                               return false
-                       }
-
-                       m.secureRenegotiation = d
                        m.secureRenegotiationSupported = true
                case extensionALPN:
-                       if length < 2 {
+                       // RFC 7301, Section 3.1
+                       var protoList cryptobyte.String
+                       if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
                                return false
                        }
-                       l := int(data[0])<<8 | int(data[1])
-                       if l != length-2 {
-                               return false
-                       }
-                       d := data[2:length]
-                       for len(d) != 0 {
-                               stringLen := int(d[0])
-                               d = d[1:]
-                               if stringLen == 0 || stringLen > len(d) {
+                       for !protoList.Empty() {
+                               var proto cryptobyte.String
+                               if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
                                        return false
                                }
-                               m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
-                               d = d[stringLen:]
+                               m.alpnProtocols = append(m.alpnProtocols, string(proto))
                        }
                case extensionSCT:
+                       // RFC 6962, Section 3.3.1
                        m.scts = true
-                       if length != 0 {
-                               return false
-                       }
+               default:
+                       // Ignore unknown extensions.
+                       continue
+               }
+
+               if !extData.Empty() {
+                       return false
                }
-               data = data[length:]
        }
 
        return true
@@ -497,280 +384,165 @@ func (m *serverHelloMsg) marshal() []byte {
                return m.raw
        }
 
-       length := 38 + len(m.sessionId)
-       numExtensions := 0
-       extensionsLength := 0
-
-       nextProtoLen := 0
-       if m.nextProtoNeg {
-               numExtensions++
-               for _, v := range m.nextProtos {
-                       nextProtoLen += len(v)
-               }
-               nextProtoLen += len(m.nextProtos)
-               extensionsLength += nextProtoLen
-       }
-       if m.ocspStapling {
-               numExtensions++
-       }
-       if m.ticketSupported {
-               numExtensions++
-       }
-       if m.secureRenegotiationSupported {
-               extensionsLength += 1 + len(m.secureRenegotiation)
-               numExtensions++
-       }
-       if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
-               if alpnLen >= 256 {
-                       panic("invalid ALPN protocol")
-               }
-               extensionsLength += 2 + 1 + alpnLen
-               numExtensions++
-       }
-       sctLen := 0
-       if len(m.scts) > 0 {
-               for _, sct := range m.scts {
-                       sctLen += len(sct) + 2
-               }
-               extensionsLength += 2 + sctLen
-               numExtensions++
-       }
+       var b cryptobyte.Builder
+       b.AddUint8(typeServerHello)
+       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+               b.AddUint16(m.vers)
+               addBytesWithLength(b, m.random, 32)
+               b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                       b.AddBytes(m.sessionId)
+               })
+               b.AddUint16(m.cipherSuite)
+               b.AddUint8(m.compressionMethod)
+
+               // If extensions aren't present, omit them.
+               var extensionsPresent bool
+               bWithoutExtensions := *b
+
+               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                       if m.nextProtoNeg {
+                               b.AddUint16(extensionNextProtoNeg)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       for _, proto := range m.nextProtos {
+                                               b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                       b.AddBytes([]byte(proto))
+                                               })
+                                       }
+                               })
+                       }
+                       if m.ocspStapling {
+                               b.AddUint16(extensionStatusRequest)
+                               b.AddUint16(0) // empty extension_data
+                       }
+                       if m.ticketSupported {
+                               b.AddUint16(extensionSessionTicket)
+                               b.AddUint16(0) // empty extension_data
+                       }
+                       if m.secureRenegotiationSupported {
+                               b.AddUint16(extensionRenegotiationInfo)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               b.AddBytes(m.secureRenegotiation)
+                                       })
+                               })
+                       }
+                       if len(m.alpnProtocol) > 0 {
+                               b.AddUint16(extensionALPN)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                       b.AddBytes([]byte(m.alpnProtocol))
+                                               })
+                                       })
+                               })
+                       }
+                       if len(m.scts) > 0 {
+                               b.AddUint16(extensionSCT)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               for _, sct := range m.scts {
+                                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                               b.AddBytes(sct)
+                                                       })
+                                               }
+                                       })
+                               })
+                       }
 
-       if numExtensions > 0 {
-               extensionsLength += 4 * numExtensions
-               length += 2 + extensionsLength
-       }
+                       extensionsPresent = len(b.BytesOrPanic()) > 2
+               })
 
-       x := make([]byte, 4+length)
-       x[0] = typeServerHello
-       x[1] = uint8(length >> 16)
-       x[2] = uint8(length >> 8)
-       x[3] = uint8(length)
-       x[4] = uint8(m.vers >> 8)
-       x[5] = uint8(m.vers)
-       copy(x[6:38], m.random)
-       x[38] = uint8(len(m.sessionId))
-       copy(x[39:39+len(m.sessionId)], m.sessionId)
-       z := x[39+len(m.sessionId):]
-       z[0] = uint8(m.cipherSuite >> 8)
-       z[1] = uint8(m.cipherSuite)
-       z[2] = m.compressionMethod
-
-       z = z[3:]
-       if numExtensions > 0 {
-               z[0] = byte(extensionsLength >> 8)
-               z[1] = byte(extensionsLength)
-               z = z[2:]
-       }
-       if m.nextProtoNeg {
-               z[0] = byte(extensionNextProtoNeg >> 8)
-               z[1] = byte(extensionNextProtoNeg & 0xff)
-               z[2] = byte(nextProtoLen >> 8)
-               z[3] = byte(nextProtoLen)
-               z = z[4:]
-
-               for _, v := range m.nextProtos {
-                       l := len(v)
-                       if l > 255 {
-                               l = 255
-                       }
-                       z[0] = byte(l)
-                       copy(z[1:], []byte(v[0:l]))
-                       z = z[1+l:]
+               if !extensionsPresent {
+                       *b = bWithoutExtensions
                }
-       }
-       if m.ocspStapling {
-               z[0] = byte(extensionStatusRequest >> 8)
-               z[1] = byte(extensionStatusRequest)
-               z = z[4:]
-       }
-       if m.ticketSupported {
-               z[0] = byte(extensionSessionTicket >> 8)
-               z[1] = byte(extensionSessionTicket)
-               z = z[4:]
-       }
-       if m.secureRenegotiationSupported {
-               z[0] = byte(extensionRenegotiationInfo >> 8)
-               z[1] = byte(extensionRenegotiationInfo & 0xff)
-               z[2] = 0
-               z[3] = byte(len(m.secureRenegotiation) + 1)
-               z[4] = byte(len(m.secureRenegotiation))
-               z = z[5:]
-               copy(z, m.secureRenegotiation)
-               z = z[len(m.secureRenegotiation):]
-       }
-       if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
-               z[0] = byte(extensionALPN >> 8)
-               z[1] = byte(extensionALPN & 0xff)
-               l := 2 + 1 + alpnLen
-               z[2] = byte(l >> 8)
-               z[3] = byte(l)
-               l -= 2
-               z[4] = byte(l >> 8)
-               z[5] = byte(l)
-               l -= 1
-               z[6] = byte(l)
-               copy(z[7:], []byte(m.alpnProtocol))
-               z = z[7+alpnLen:]
-       }
-       if sctLen > 0 {
-               z[0] = byte(extensionSCT >> 8)
-               z[1] = byte(extensionSCT)
-               l := sctLen + 2
-               z[2] = byte(l >> 8)
-               z[3] = byte(l)
-               z[4] = byte(sctLen >> 8)
-               z[5] = byte(sctLen)
-
-               z = z[6:]
-               for _, sct := range m.scts {
-                       z[0] = byte(len(sct) >> 8)
-                       z[1] = byte(len(sct))
-                       copy(z[2:], sct)
-                       z = z[len(sct)+2:]
-               }
-       }
+       })
 
-       m.raw = x
-
-       return x
+       m.raw = b.BytesOrPanic()
+       return m.raw
 }
 
 func (m *serverHelloMsg) unmarshal(data []byte) bool {
-       if len(data) < 42 {
+       *m = serverHelloMsg{raw: data}
+       s := cryptobyte.String(data)
+
+       if !s.Skip(4) || // message type and uint24 length field
+               !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
+               !readUint8LengthPrefixed(&s, &m.sessionId) ||
+               !s.ReadUint16(&m.cipherSuite) ||
+               !s.ReadUint8(&m.compressionMethod) {
                return false
        }
-       m.raw = data
-       m.vers = uint16(data[4])<<8 | uint16(data[5])
-       m.random = data[6:38]
-       sessionIdLen := int(data[38])
-       if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
-               return false
-       }
-       m.sessionId = data[39 : 39+sessionIdLen]
-       data = data[39+sessionIdLen:]
-       if len(data) < 3 {
-               return false
-       }
-       m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
-       m.compressionMethod = data[2]
-       data = data[3:]
-
-       m.nextProtoNeg = false
-       m.nextProtos = nil
-       m.ocspStapling = false
-       m.scts = nil
-       m.ticketSupported = false
-       m.alpnProtocol = ""
-
-       if len(data) == 0 {
+
+       if s.Empty() {
                // ServerHello is optionally followed by extension data
                return true
        }
-       if len(data) < 2 {
-               return false
-       }
 
-       extensionsLength := int(data[0])<<8 | int(data[1])
-       data = data[2:]
-       if len(data) != extensionsLength {
+       var extensions cryptobyte.String
+       if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
                return false
        }
 
-       for len(data) != 0 {
-               if len(data) < 4 {
-                       return false
-               }
-               extension := uint16(data[0])<<8 | uint16(data[1])
-               length := int(data[2])<<8 | int(data[3])
-               data = data[4:]
-               if len(data) < length {
+       for !extensions.Empty() {
+               var extension uint16
+               var extData cryptobyte.String
+               if !extensions.ReadUint16(&extension) ||
+                       !extensions.ReadUint16LengthPrefixed(&extData) {
                        return false
                }
 
                switch extension {
                case extensionNextProtoNeg:
                        m.nextProtoNeg = true
-                       d := data[:length]
-                       for len(d) > 0 {
-                               l := int(d[0])
-                               d = d[1:]
-                               if l == 0 || l > len(d) {
+                       for !extData.Empty() {
+                               var proto cryptobyte.String
+                               if !extData.ReadUint8LengthPrefixed(&proto) ||
+                                       proto.Empty() {
                                        return false
                                }
-                               m.nextProtos = append(m.nextProtos, string(d[:l]))
-                               d = d[l:]
+                               m.nextProtos = append(m.nextProtos, string(proto))
                        }
                case extensionStatusRequest:
-                       if length > 0 {
-                               return false
-                       }
                        m.ocspStapling = true
                case extensionSessionTicket:
-                       if length > 0 {
-                               return false
-                       }
                        m.ticketSupported = true
                case extensionRenegotiationInfo:
-                       if length == 0 {
+                       if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
                                return false
                        }
-                       d := data[:length]
-                       l := int(d[0])
-                       d = d[1:]
-                       if l != len(d) {
-                               return false
-                       }
-
-                       m.secureRenegotiation = d
                        m.secureRenegotiationSupported = true
                case extensionALPN:
-                       d := data[:length]
-                       if len(d) < 3 {
+                       var protoList cryptobyte.String
+                       if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
                                return false
                        }
-                       l := int(d[0])<<8 | int(d[1])
-                       if l != len(d)-2 {
+                       var proto cryptobyte.String
+                       if !protoList.ReadUint8LengthPrefixed(&proto) ||
+                               proto.Empty() || !protoList.Empty() {
                                return false
                        }
-                       d = d[2:]
-                       l = int(d[0])
-                       if l != len(d)-1 {
-                               return false
-                       }
-                       d = d[1:]
-                       if len(d) == 0 {
-                               // ALPN protocols must not be empty.
-                               return false
-                       }
-                       m.alpnProtocol = string(d)
+                       m.alpnProtocol = string(proto)
                case extensionSCT:
-                       d := data[:length]
-
-                       if len(d) < 2 {
-                               return false
-                       }
-                       l := int(d[0])<<8 | int(d[1])
-                       d = d[2:]
-                       if len(d) != l || l == 0 {
+                       var sctList cryptobyte.String
+                       if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
                                return false
                        }
-
-                       m.scts = make([][]byte, 0, 3)
-                       for len(d) != 0 {
-                               if len(d) < 2 {
-                                       return false
-                               }
-                               sctLen := int(d[0])<<8 | int(d[1])
-                               d = d[2:]
-                               if sctLen == 0 || len(d) < sctLen {
+                       for !sctList.Empty() {
+                               var sct []byte
+                               if !readUint16LengthPrefixed(&sctList, &sct) ||
+                                       len(sct) == 0 {
                                        return false
                                }
-                               m.scts = append(m.scts, d[:sctLen])
-                               d = d[sctLen:]
+                               m.scts = append(m.scts, sct)
                        }
+               default:
+                       // Ignore unknown extensions.
+                       continue
+               }
+
+               if !extData.Empty() {
+                       return false
                }
-               data = data[length:]
        }
 
        return true
@@ -989,26 +761,27 @@ type finishedMsg struct {
        verifyData []byte
 }
 
-func (m *finishedMsg) marshal() (x []byte) {
+func (m *finishedMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
        }
 
-       x = make([]byte, 4+len(m.verifyData))
-       x[0] = typeFinished
-       x[3] = byte(len(m.verifyData))
-       copy(x[4:], m.verifyData)
-       m.raw = x
-       return
+       var b cryptobyte.Builder
+       b.AddUint8(typeFinished)
+       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+               b.AddBytes(m.verifyData)
+       })
+
+       m.raw = b.BytesOrPanic()
+       return m.raw
 }
 
 func (m *finishedMsg) unmarshal(data []byte) bool {
        m.raw = data
-       if len(data) < 4 {
-               return false
-       }
-       m.verifyData = data[4:]
-       return true
+       s := cryptobyte.String(data)
+       return s.Skip(1) &&
+               readUint24LengthPrefixed(&s, &m.verifyData) &&
+               s.Empty()
 }
 
 type nextProtoMsg struct {
@@ -1073,10 +846,9 @@ func (m *nextProtoMsg) unmarshal(data []byte) bool {
 
 type certificateRequestMsg struct {
        raw []byte
-       // hasSignatureAndHash indicates whether this message includes a list
-       // of signature and hash functions. This change was introduced with TLS
-       // 1.2.
-       hasSignatureAndHash bool
+       // hasSignatureAlgorithm indicates whether this message includes a list of
+       // supported signature algorithms. This change was introduced with TLS 1.2.
+       hasSignatureAlgorithm bool
 
        certificateTypes             []byte
        supportedSignatureAlgorithms []SignatureScheme
@@ -1096,7 +868,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
        }
        length += casLength
 
-       if m.hasSignatureAndHash {
+       if m.hasSignatureAlgorithm {
                length += 2 + 2*len(m.supportedSignatureAlgorithms)
        }
 
@@ -1111,7 +883,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
        copy(x[5:], m.certificateTypes)
        y := x[5+len(m.certificateTypes):]
 
-       if m.hasSignatureAndHash {
+       if m.hasSignatureAlgorithm {
                n := len(m.supportedSignatureAlgorithms) * 2
                y[0] = uint8(n >> 8)
                y[1] = uint8(n)
@@ -1163,7 +935,7 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
 
        data = data[numCertTypes:]
 
-       if m.hasSignatureAndHash {
+       if m.hasSignatureAlgorithm {
                if len(data) < 2 {
                        return false
                }
@@ -1215,10 +987,10 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
 }
 
 type certificateVerifyMsg struct {
-       raw                 []byte
-       hasSignatureAndHash bool
-       signatureAlgorithm  SignatureScheme
-       signature           []byte
+       raw                   []byte
+       hasSignatureAlgorithm bool // format change introduced in TLS 1.2
+       signatureAlgorithm    SignatureScheme
+       signature             []byte
 }
 
 func (m *certificateVerifyMsg) marshal() (x []byte) {
@@ -1226,62 +998,34 @@ func (m *certificateVerifyMsg) marshal() (x []byte) {
                return m.raw
        }
 
-       // See RFC 4346, Section 7.4.8.
-       siglength := len(m.signature)
-       length := 2 + siglength
-       if m.hasSignatureAndHash {
-               length += 2
-       }
-       x = make([]byte, 4+length)
-       x[0] = typeCertificateVerify
-       x[1] = uint8(length >> 16)
-       x[2] = uint8(length >> 8)
-       x[3] = uint8(length)
-       y := x[4:]
-       if m.hasSignatureAndHash {
-               y[0] = uint8(m.signatureAlgorithm >> 8)
-               y[1] = uint8(m.signatureAlgorithm)
-               y = y[2:]
-       }
-       y[0] = uint8(siglength >> 8)
-       y[1] = uint8(siglength)
-       copy(y[2:], m.signature)
-
-       m.raw = x
+       var b cryptobyte.Builder
+       b.AddUint8(typeCertificateVerify)
+       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+               if m.hasSignatureAlgorithm {
+                       b.AddUint16(uint16(m.signatureAlgorithm))
+               }
+               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                       b.AddBytes(m.signature)
+               })
+       })
 
-       return
+       m.raw = b.BytesOrPanic()
+       return m.raw
 }
 
 func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
        m.raw = data
+       s := cryptobyte.String(data)
 
-       if len(data) < 6 {
-               return false
-       }
-
-       length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
-       if uint32(len(data))-4 != length {
-               return false
-       }
-
-       data = data[4:]
-       if m.hasSignatureAndHash {
-               m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
-               data = data[2:]
-       }
-
-       if len(data) < 2 {
+       if !s.Skip(4) { // message type and uint24 length field
                return false
        }
-       siglength := int(data[0])<<8 + int(data[1])
-       data = data[2:]
-       if len(data) != siglength {
-               return false
+       if m.hasSignatureAlgorithm {
+               if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
+                       return false
+               }
        }
-
-       m.signature = data
-
-       return true
+       return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
 }
 
 type newSessionTicketMsg struct {
index c8cc0d6c5a0f69c39300f5a5c053bf3a33e814bc..fbc294b64ed4c2de732ed5a9b1eb0bb110216eb9 100644 (file)
@@ -20,7 +20,9 @@ var tests = []interface{}{
 
        &certificateMsg{},
        &certificateRequestMsg{},
-       &certificateVerifyMsg{},
+       &certificateVerifyMsg{
+               hasSignatureAlgorithm: true,
+       },
        &certificateStatusMsg{},
        &clientKeyExchangeMsg{},
        &nextProtoMsg{},
@@ -149,6 +151,10 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        if rand.Intn(10) > 5 {
                m.scts = true
        }
+       if rand.Intn(10) > 5 {
+               m.secureRenegotiationSupported = true
+               m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
+       }
 
        return reflect.ValueOf(m)
 }
@@ -180,6 +186,11 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
                m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
        }
 
+       if rand.Intn(10) > 5 {
+               m.secureRenegotiationSupported = true
+               m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
+       }
+
        return reflect.ValueOf(m)
 }
 
@@ -204,6 +215,8 @@ func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value
 
 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m := &certificateVerifyMsg{}
+       m.hasSignatureAlgorithm = true
+       m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
        m.signature = randomBytes(rand.Intn(15)+1, rand)
        return reflect.ValueOf(m)
 }
index b077c9058041dbd30416a5893cbf8523cc549bf8..bec128f4154643abce29b7eb1f6d0d1df2e1feb0 100644 (file)
@@ -418,7 +418,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
                        byte(certTypeECDSASign),
                }
                if c.vers >= VersionTLS12 {
-                       certReq.hasSignatureAndHash = true
+                       certReq.hasSignatureAlgorithm = true
                        certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms
                }
 
index e14adbd7664fb4a41c586f771d89be4cb72511d5..01de92d97108b33227210b53146cc441bea6aa61 100644 (file)
@@ -101,13 +101,17 @@ var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x020
 
 func TestRejectBadProtocolVersion(t *testing.T) {
        for _, v := range badProtocolVersions {
-               testClientHelloFailure(t, testConfig, &clientHelloMsg{vers: v}, "unsupported, maximum protocol version")
+               testClientHelloFailure(t, testConfig, &clientHelloMsg{
+                       vers:   v,
+                       random: make([]byte, 32),
+               }, "unsupported, maximum protocol version")
        }
 }
 
 func TestNoSuiteOverlap(t *testing.T) {
        clientHello := &clientHelloMsg{
                vers:               VersionTLS10,
+               random:             make([]byte, 32),
                cipherSuites:       []uint16{0xff00},
                compressionMethods: []uint8{compressionNone},
        }
@@ -117,6 +121,7 @@ func TestNoSuiteOverlap(t *testing.T) {
 func TestNoCompressionOverlap(t *testing.T) {
        clientHello := &clientHelloMsg{
                vers:               VersionTLS10,
+               random:             make([]byte, 32),
                cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
                compressionMethods: []uint8{0xff},
        }
@@ -126,6 +131,7 @@ func TestNoCompressionOverlap(t *testing.T) {
 func TestNoRC4ByDefault(t *testing.T) {
        clientHello := &clientHelloMsg{
                vers:               VersionTLS10,
+               random:             make([]byte, 32),
                cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
                compressionMethods: []uint8{compressionNone},
        }
@@ -137,7 +143,11 @@ func TestNoRC4ByDefault(t *testing.T) {
 }
 
 func TestRejectSNIWithTrailingDot(t *testing.T) {
-       testClientHelloFailure(t, testConfig, &clientHelloMsg{vers: VersionTLS12, serverName: "foo.com."}, "unexpected message")
+       testClientHelloFailure(t, testConfig, &clientHelloMsg{
+               vers:       VersionTLS12,
+               random:     make([]byte, 32),
+               serverName: "foo.com.",
+       }, "unexpected message")
 }
 
 func TestDontSelectECDSAWithRSAKey(t *testing.T) {
@@ -145,6 +155,7 @@ func TestDontSelectECDSAWithRSAKey(t *testing.T) {
        // won't be selected if the server's private key doesn't support it.
        clientHello := &clientHelloMsg{
                vers:               VersionTLS10,
+               random:             make([]byte, 32),
                cipherSuites:       []uint16{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
                compressionMethods: []uint8{compressionNone},
                supportedCurves:    []CurveID{CurveP256},
@@ -170,6 +181,7 @@ func TestDontSelectRSAWithECDSAKey(t *testing.T) {
        // won't be selected if the server's private key doesn't support it.
        clientHello := &clientHelloMsg{
                vers:               VersionTLS10,
+               random:             make([]byte, 32),
                cipherSuites:       []uint16{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
                compressionMethods: []uint8{compressionNone},
                supportedCurves:    []CurveID{CurveP256},
@@ -242,11 +254,9 @@ func TestRenegotiationExtension(t *testing.T) {
 func TestTLS12OnlyCipherSuites(t *testing.T) {
        // Test that a Server doesn't select a TLS 1.2-only cipher suite when
        // the client negotiates TLS 1.1.
-       var zeros [32]byte
-
        clientHello := &clientHelloMsg{
                vers:   VersionTLS11,
-               random: zeros[:],
+               random: make([]byte, 32),
                cipherSuites: []uint16{
                        // The Server, by default, will use the client's
                        // preference order. So the GCM cipher suite
@@ -878,6 +888,7 @@ func TestHandshakeServerSNIGetCertificateError(t *testing.T) {
 
        clientHello := &clientHelloMsg{
                vers:               VersionTLS10,
+               random:             make([]byte, 32),
                cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
                compressionMethods: []uint8{compressionNone},
                serverName:         "test",
@@ -898,6 +909,7 @@ func TestHandshakeServerEmptyCertificates(t *testing.T) {
 
        clientHello := &clientHelloMsg{
                vers:               VersionTLS10,
+               random:             make([]byte, 32),
                cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
                compressionMethods: []uint8{compressionNone},
        }
@@ -909,6 +921,7 @@ func TestHandshakeServerEmptyCertificates(t *testing.T) {
 
        clientHello = &clientHelloMsg{
                vers:               VersionTLS10,
+               random:             make([]byte, 32),
                cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
                compressionMethods: []uint8{compressionNone},
        }
@@ -1212,6 +1225,7 @@ func TestSNIGivenOnFailure(t *testing.T) {
 
        clientHello := &clientHelloMsg{
                vers:               VersionTLS10,
+               random:             make([]byte, 32),
                cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
                compressionMethods: []uint8{compressionNone},
                serverName:         expectedServerName,
index 5e5b5ed655e20a19abe0b98f5d40e00b2f660f16..2bb63f4e8421679753e40be970f97e541783780b 100644 (file)
@@ -389,7 +389,7 @@ var pkgDeps = map[string][]string{
 
        // SSL/TLS.
        "crypto/tls": {
-               "L4", "CRYPTO-MATH", "OS",
+               "L4", "CRYPTO-MATH", "OS", "golang_org/x/crypto/cryptobyte",
                "container/list", "crypto/x509", "encoding/pem", "net", "syscall",
        },
        "crypto/x509": {