]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: implement TLS 1.3 version-specific messages
authorFilippo Valsorda <filippo@golang.org>
Sun, 28 Oct 2018 22:04:54 +0000 (18:04 -0400)
committerFilippo Valsorda <filippo@golang.org>
Fri, 2 Nov 2018 22:04:51 +0000 (22:04 +0000)
Note that there is significant code duplication due to extensions with
the same format appearing in different messages in TLS 1.3. This will be
cleaned up in a future refactor once CL 145317 is merged.

Enforcing the presence/absence of each extension in each message is left
to the upper layer, based on both protocol version and extensions
advertised in CH and CR. Duplicated extensions and unknown extensions in
SH, EE, HRR, and CT will be tightened up in a future CL.

The TLS 1.2 CertificateStatus message was restricted to accepting only
type OCSP as any other type (none of which are specified so far) would
have to be negotiated.

Updates #9671

Change-Id: I7c42394c5cc0af01faa84b9b9f25fdc6e7cfbb9e
Reviewed-on: https://go-review.googlesource.com/c/145477
Reviewed-by: Adam Langley <agl@golang.org>
src/crypto/tls/common.go
src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_messages.go
src/crypto/tls/handshake_messages_test.go
src/crypto/tls/handshake_server.go

index 4808c01f9ca02c81e26fe5df80d154e79fd35f4d..d08b096b20900695172476af4b3af748f6d1815d 100644 (file)
@@ -56,19 +56,23 @@ const (
 
 // TLS handshake message types.
 const (
-       typeHelloRequest       uint8 = 0
-       typeClientHello        uint8 = 1
-       typeServerHello        uint8 = 2
-       typeNewSessionTicket   uint8 = 4
-       typeCertificate        uint8 = 11
-       typeServerKeyExchange  uint8 = 12
-       typeCertificateRequest uint8 = 13
-       typeServerHelloDone    uint8 = 14
-       typeCertificateVerify  uint8 = 15
-       typeClientKeyExchange  uint8 = 16
-       typeFinished           uint8 = 20
-       typeCertificateStatus  uint8 = 22
-       typeNextProtocol       uint8 = 67 // Not IANA assigned
+       typeHelloRequest        uint8 = 0
+       typeClientHello         uint8 = 1
+       typeServerHello         uint8 = 2
+       typeNewSessionTicket    uint8 = 4
+       typeEndOfEarlyData      uint8 = 5
+       typeEncryptedExtensions uint8 = 8
+       typeCertificate         uint8 = 11
+       typeServerKeyExchange   uint8 = 12
+       typeCertificateRequest  uint8 = 13
+       typeServerHelloDone     uint8 = 14
+       typeCertificateVerify   uint8 = 15
+       typeClientKeyExchange   uint8 = 16
+       typeFinished            uint8 = 20
+       typeCertificateStatus   uint8 = 22
+       typeKeyUpdate           uint8 = 24
+       typeNextProtocol        uint8 = 67  // Not IANA assigned
+       typeMessageHash         uint8 = 254 // synthetic message
 )
 
 // TLS compression types.
@@ -87,6 +91,7 @@ const (
        extensionSCT                     uint16 = 18
        extensionSessionTicket           uint16 = 35
        extensionPreSharedKey            uint16 = 41
+       extensionEarlyData               uint16 = 42
        extensionSupportedVersions       uint16 = 43
        extensionCookie                  uint16 = 44
        extensionPSKModes                uint16 = 45
index 5af1413935363c10f96abab97877256ed531cb31..36199640956b87483b54da6526e73b5e5b322289 100644 (file)
@@ -990,12 +990,24 @@ func (c *Conn) readHandshake() (interface{}, error) {
        case typeServerHello:
                m = new(serverHelloMsg)
        case typeNewSessionTicket:
-               m = new(newSessionTicketMsg)
+               if c.vers == VersionTLS13 {
+                       m = new(newSessionTicketMsgTLS13)
+               } else {
+                       m = new(newSessionTicketMsg)
+               }
        case typeCertificate:
-               m = new(certificateMsg)
+               if c.vers == VersionTLS13 {
+                       m = new(certificateMsgTLS13)
+               } else {
+                       m = new(certificateMsg)
+               }
        case typeCertificateRequest:
-               m = &certificateRequestMsg{
-                       hasSignatureAlgorithm: c.vers >= VersionTLS12,
+               if c.vers == VersionTLS13 {
+                       m = new(certificateRequestMsgTLS13)
+               } else {
+                       m = &certificateRequestMsg{
+                               hasSignatureAlgorithm: c.vers >= VersionTLS12,
+                       }
                }
        case typeCertificateStatus:
                m = new(certificateStatusMsg)
@@ -1013,6 +1025,12 @@ func (c *Conn) readHandshake() (interface{}, error) {
                m = new(nextProtoMsg)
        case typeFinished:
                m = new(finishedMsg)
+       case typeEncryptedExtensions:
+               m = new(encryptedExtensionsMsg)
+       case typeEndOfEarlyData:
+               m = new(endOfEarlyDataMsg)
+       case typeKeyUpdate:
+               m = new(keyUpdateMsg)
        default:
                return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
        }
index fb74f79bd8830d4ea98f168b6d4f10eb41808c37..322839caacecb2061ac73c39b0b2ad4fce51b10d 100644 (file)
@@ -393,9 +393,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                }
                hs.finishedHash.Write(cs.marshal())
 
-               if cs.statusType == statusTypeOCSP {
-                       c.ocspResponse = cs.response
-               }
+               c.ocspResponse = cs.response
 
                msg, err = c.readHandshake()
                if err != nil {
index d04efc98f6d886eafdfa178f8caf30b1e72b7fd1..82b91cc87e449ed1f9a2b26973c66ce7344fce7e 100644 (file)
@@ -71,6 +71,7 @@ type clientHelloMsg struct {
        supportedVersions                []uint16
        cookie                           []byte
        keyShares                        []keyShare
+       earlyData                        bool
        pskModes                         []uint8
        pskIdentities                    []pskIdentity
        pskBinders                       [][]byte
@@ -239,6 +240,11 @@ func (m *clientHelloMsg) marshal() []byte {
                                        })
                                })
                        }
+                       if m.earlyData {
+                               // RFC 8446, Section 4.2.10
+                               b.AddUint16(extensionEarlyData)
+                               b.AddUint16(0) // empty extension_data
+                       }
                        if len(m.pskModes) > 0 {
                                // RFC 8446, Section 4.2.9
                                b.AddUint16(extensionPSKModes)
@@ -478,6 +484,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
                                }
                                m.keyShares = append(m.keyShares, ks)
                        }
+               case extensionEarlyData:
+                       // RFC 8446, Section 4.2.10
+                       m.earlyData = true
                case extensionPSKModes:
                        // RFC 8446, Section 4.2.9
                        if !readUint8LengthPrefixed(&extData, &m.pskModes) {
@@ -782,6 +791,342 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
        return true
 }
 
+type encryptedExtensionsMsg struct {
+       raw          []byte
+       alpnProtocol string
+}
+
+func (m *encryptedExtensionsMsg) marshal() []byte {
+       if m.raw != nil {
+               return m.raw
+       }
+
+       var b cryptobyte.Builder
+       b.AddUint8(typeEncryptedExtensions)
+       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                       if len(m.alpnProtocol) > 0 {
+                               b.AddUint16(extensionALPN)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                       b.AddBytes([]byte(m.alpnProtocol))
+                                               })
+                                       })
+                               })
+                       }
+               })
+       })
+
+       m.raw = b.BytesOrPanic()
+       return m.raw
+}
+
+func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
+       *m = encryptedExtensionsMsg{raw: data}
+       s := cryptobyte.String(data)
+
+       var extensions cryptobyte.String
+       if !s.Skip(4) || // message type and uint24 length field
+               !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
+               return false
+       }
+
+       for !extensions.Empty() {
+               var extension uint16
+               var extData cryptobyte.String
+               if !extensions.ReadUint16(&extension) ||
+                       !extensions.ReadUint16LengthPrefixed(&extData) {
+                       return false
+               }
+
+               switch extension {
+               case extensionALPN:
+                       var protoList cryptobyte.String
+                       if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
+                               return false
+                       }
+                       var proto cryptobyte.String
+                       if !protoList.ReadUint8LengthPrefixed(&proto) ||
+                               proto.Empty() || !protoList.Empty() {
+                               return false
+                       }
+                       m.alpnProtocol = string(proto)
+               default:
+                       // Ignore unknown extensions.
+                       continue
+               }
+
+               if !extData.Empty() {
+                       return false
+               }
+       }
+
+       return true
+}
+
+type endOfEarlyDataMsg struct{}
+
+func (m *endOfEarlyDataMsg) marshal() []byte {
+       x := make([]byte, 4)
+       x[0] = typeEndOfEarlyData
+       return x
+}
+
+func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
+       return len(data) == 4
+}
+
+type keyUpdateMsg struct {
+       raw             []byte
+       updateRequested bool
+}
+
+func (m *keyUpdateMsg) marshal() []byte {
+       if m.raw != nil {
+               return m.raw
+       }
+
+       var b cryptobyte.Builder
+       b.AddUint8(typeKeyUpdate)
+       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+               if m.updateRequested {
+                       b.AddUint8(1)
+               } else {
+                       b.AddUint8(0)
+               }
+       })
+
+       m.raw = b.BytesOrPanic()
+       return m.raw
+}
+
+func (m *keyUpdateMsg) unmarshal(data []byte) bool {
+       m.raw = data
+       s := cryptobyte.String(data)
+
+       var updateRequested uint8
+       if !s.Skip(4) || // message type and uint24 length field
+               !s.ReadUint8(&updateRequested) || !s.Empty() {
+               return false
+       }
+       switch updateRequested {
+       case 0:
+               m.updateRequested = false
+       case 1:
+               m.updateRequested = true
+       default:
+               return false
+       }
+       return true
+}
+
+type newSessionTicketMsgTLS13 struct {
+       raw          []byte
+       lifetime     uint32
+       ageAdd       uint32
+       nonce        []byte
+       label        []byte
+       maxEarlyData uint32
+}
+
+func (m *newSessionTicketMsgTLS13) marshal() []byte {
+       if m.raw != nil {
+               return m.raw
+       }
+
+       var b cryptobyte.Builder
+       b.AddUint8(typeNewSessionTicket)
+       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+               b.AddUint32(m.lifetime)
+               b.AddUint32(m.ageAdd)
+               b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+                       b.AddBytes(m.nonce)
+               })
+               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                       b.AddBytes(m.label)
+               })
+
+               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                       if m.maxEarlyData > 0 {
+                               b.AddUint16(extensionEarlyData)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint32(m.maxEarlyData)
+                               })
+                       }
+               })
+       })
+
+       m.raw = b.BytesOrPanic()
+       return m.raw
+}
+
+func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
+       *m = newSessionTicketMsgTLS13{raw: data}
+       s := cryptobyte.String(data)
+
+       var extensions cryptobyte.String
+       if !s.Skip(4) || // message type and uint24 length field
+               !s.ReadUint32(&m.lifetime) ||
+               !s.ReadUint32(&m.ageAdd) ||
+               !readUint8LengthPrefixed(&s, &m.nonce) ||
+               !readUint16LengthPrefixed(&s, &m.label) ||
+               !s.ReadUint16LengthPrefixed(&extensions) ||
+               !s.Empty() {
+               return false
+       }
+
+       for !extensions.Empty() {
+               var extension uint16
+               var extData cryptobyte.String
+               if !extensions.ReadUint16(&extension) ||
+                       !extensions.ReadUint16LengthPrefixed(&extData) {
+                       return false
+               }
+
+               switch extension {
+               case extensionEarlyData:
+                       if !extData.ReadUint32(&m.maxEarlyData) {
+                               return false
+                       }
+               default:
+                       // Ignore unknown extensions.
+                       continue
+               }
+
+               if !extData.Empty() {
+                       return false
+               }
+       }
+
+       return true
+}
+
+type certificateRequestMsgTLS13 struct {
+       raw                              []byte
+       ocspStapling                     bool
+       scts                             bool
+       supportedSignatureAlgorithms     []SignatureScheme
+       supportedSignatureAlgorithmsCert []SignatureScheme
+}
+
+func (m *certificateRequestMsgTLS13) marshal() []byte {
+       if m.raw != nil {
+               return m.raw
+       }
+
+       var b cryptobyte.Builder
+       b.AddUint8(typeCertificateRequest)
+       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+               // certificate_request_context (SHALL be zero length unless used for
+               // post-handshake authentication)
+               b.AddUint8(0)
+
+               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                       if m.ocspStapling {
+                               b.AddUint16(extensionStatusRequest)
+                               b.AddUint16(0) // empty extension_data
+                       }
+                       if m.scts {
+                               // RFC 8446, Section 4.4.2.1 makes no mention of
+                               // signed_certificate_timestamp in CertificateRequest, but
+                               // "Extensions in the Certificate message from the client MUST
+                               // correspond to extensions in the CertificateRequest message
+                               // from the server." and it appears in the table in Section 4.2.
+                               b.AddUint16(extensionSCT)
+                               b.AddUint16(0) // empty extension_data
+                       }
+                       if len(m.supportedSignatureAlgorithms) > 0 {
+                               b.AddUint16(extensionSignatureAlgorithms)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               for _, sigAlgo := range m.supportedSignatureAlgorithms {
+                                                       b.AddUint16(uint16(sigAlgo))
+                                               }
+                                       })
+                               })
+                       }
+                       if len(m.supportedSignatureAlgorithmsCert) > 0 {
+                               b.AddUint16(extensionSignatureAlgorithmsCert)
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                               for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
+                                                       b.AddUint16(uint16(sigAlgo))
+                                               }
+                                       })
+                               })
+                       }
+               })
+       })
+
+       m.raw = b.BytesOrPanic()
+       return m.raw
+}
+
+func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
+       *m = certificateRequestMsgTLS13{raw: data}
+       s := cryptobyte.String(data)
+
+       var context, extensions cryptobyte.String
+       if !s.Skip(4) || // message type and uint24 length field
+               !s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
+               !s.ReadUint16LengthPrefixed(&extensions) ||
+               !s.Empty() {
+               return false
+       }
+
+       for !extensions.Empty() {
+               var extension uint16
+               var extData cryptobyte.String
+               if !extensions.ReadUint16(&extension) ||
+                       !extensions.ReadUint16LengthPrefixed(&extData) {
+                       return false
+               }
+
+               switch extension {
+               case extensionStatusRequest:
+                       m.ocspStapling = true
+               case extensionSCT:
+                       m.scts = true
+               case extensionSignatureAlgorithms:
+                       var sigAndAlgs cryptobyte.String
+                       if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
+                               return false
+                       }
+                       for !sigAndAlgs.Empty() {
+                               var sigAndAlg uint16
+                               if !sigAndAlgs.ReadUint16(&sigAndAlg) {
+                                       return false
+                               }
+                               m.supportedSignatureAlgorithms = append(
+                                       m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
+                       }
+               case extensionSignatureAlgorithmsCert:
+                       var sigAndAlgs cryptobyte.String
+                       if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
+                               return false
+                       }
+                       for !sigAndAlgs.Empty() {
+                               var sigAndAlg uint16
+                               if !sigAndAlgs.ReadUint16(&sigAndAlg) {
+                                       return false
+                               }
+                               m.supportedSignatureAlgorithmsCert = append(
+                                       m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
+                       }
+               default:
+                       // Ignore unknown extensions.
+                       continue
+               }
+
+               if !extData.Empty() {
+                       return false
+               }
+       }
+
+       return true
+}
+
 type certificateMsg struct {
        raw          []byte
        certificates [][]byte
@@ -859,6 +1204,131 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
        return true
 }
 
+type certificateMsgTLS13 struct {
+       raw          []byte
+       certificate  Certificate
+       ocspStapling bool
+       scts         bool
+}
+
+func (m *certificateMsgTLS13) marshal() []byte {
+       if m.raw != nil {
+               return m.raw
+       }
+
+       var b cryptobyte.Builder
+       b.AddUint8(typeCertificate)
+       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+               b.AddUint8(0) // certificate_request_context
+               b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+                       for i, cert := range m.certificate.Certificate {
+                               b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       b.AddBytes(cert)
+                               })
+                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       if i > 0 {
+                                               // This library only supports OCSP and SCT for leaf certificates.
+                                               return
+                                       }
+                                       if m.ocspStapling {
+                                               b.AddUint16(extensionStatusRequest)
+                                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                       b.AddUint8(statusTypeOCSP)
+                                                       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                               b.AddBytes(m.certificate.OCSPStaple)
+                                                       })
+                                               })
+                                       }
+                                       if m.scts {
+                                               b.AddUint16(extensionSCT)
+                                               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                               for _, sct := range m.certificate.SignedCertificateTimestamps {
+                                                                       b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                                               b.AddBytes(sct)
+                                                                       })
+                                                               }
+                                                       })
+                                               })
+                                       }
+                               })
+                       }
+               })
+       })
+
+       m.raw = b.BytesOrPanic()
+       return m.raw
+}
+
+func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
+       *m = certificateMsgTLS13{raw: data}
+       s := cryptobyte.String(data)
+
+       var context, certList cryptobyte.String
+       if !s.Skip(4) || // message type and uint24 length field
+               !s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
+               !s.ReadUint24LengthPrefixed(&certList) ||
+               !s.Empty() {
+               return false
+       }
+
+       for !certList.Empty() {
+               var cert []byte
+               var extensions cryptobyte.String
+               if !readUint24LengthPrefixed(&certList, &cert) ||
+                       !certList.ReadUint16LengthPrefixed(&extensions) {
+                       return false
+               }
+               m.certificate.Certificate = append(m.certificate.Certificate, cert)
+               for !extensions.Empty() {
+                       var extension uint16
+                       var extData cryptobyte.String
+                       if !extensions.ReadUint16(&extension) ||
+                               !extensions.ReadUint16LengthPrefixed(&extData) {
+                               return false
+                       }
+                       if len(m.certificate.Certificate) > 1 {
+                               // This library only supports OCSP and SCT for leaf certificates.
+                               continue
+                       }
+
+                       switch extension {
+                       case extensionStatusRequest:
+                               m.ocspStapling = true
+                               var statusType uint8
+                               if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
+                                       !readUint24LengthPrefixed(&extData, &m.certificate.OCSPStaple) ||
+                                       len(m.certificate.OCSPStaple) == 0 {
+                                       return false
+                               }
+                       case extensionSCT:
+                               m.scts = true
+                               var sctList cryptobyte.String
+                               if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
+                                       return false
+                               }
+                               for !sctList.Empty() {
+                                       var sct []byte
+                                       if !readUint16LengthPrefixed(&sctList, &sct) ||
+                                               len(sct) == 0 {
+                                               return false
+                                       }
+                                       m.certificate.SignedCertificateTimestamps = append(
+                                               m.certificate.SignedCertificateTimestamps, sct)
+                               }
+                       default:
+                               // Ignore unknown extensions.
+                               continue
+                       }
+
+                       if !extData.Empty() {
+                               return false
+                       }
+               }
+       }
+       return true
+}
+
 type serverKeyExchangeMsg struct {
        raw []byte
        key []byte
@@ -890,9 +1360,8 @@ func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
 }
 
 type certificateStatusMsg struct {
-       raw        []byte
-       statusType uint8
-       response   []byte
+       raw      []byte
+       response []byte
 }
 
 func (m *certificateStatusMsg) marshal() []byte {
@@ -900,46 +1369,29 @@ func (m *certificateStatusMsg) marshal() []byte {
                return m.raw
        }
 
-       var x []byte
-       if m.statusType == statusTypeOCSP {
-               x = make([]byte, 4+4+len(m.response))
-               x[0] = typeCertificateStatus
-               l := len(m.response) + 4
-               x[1] = byte(l >> 16)
-               x[2] = byte(l >> 8)
-               x[3] = byte(l)
-               x[4] = statusTypeOCSP
-
-               l -= 4
-               x[5] = byte(l >> 16)
-               x[6] = byte(l >> 8)
-               x[7] = byte(l)
-               copy(x[8:], m.response)
-       } else {
-               x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
-       }
+       var b cryptobyte.Builder
+       b.AddUint8(typeCertificateStatus)
+       b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+               b.AddUint8(statusTypeOCSP)
+               b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+                       b.AddBytes(m.response)
+               })
+       })
 
-       m.raw = x
-       return x
+       m.raw = b.BytesOrPanic()
+       return m.raw
 }
 
 func (m *certificateStatusMsg) unmarshal(data []byte) bool {
        m.raw = data
-       if len(data) < 5 {
-               return false
-       }
-       m.statusType = data[4]
+       s := cryptobyte.String(data)
 
-       m.response = nil
-       if m.statusType == statusTypeOCSP {
-               if len(data) < 8 {
-                       return false
-               }
-               respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
-               if uint32(len(data)) != 4+4+respLen {
-                       return false
-               }
-               m.response = data[8:]
+       var statusType uint8
+       if !s.Skip(4) || // message type and uint24 length field
+               !s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
+               !readUint24LengthPrefixed(&s, &m.response) ||
+               len(m.response) == 0 || !s.Empty() {
+               return false
        }
        return true
 }
index d32f33f378871f6189021a5f328534acd25d47ee..ab9e1f50fd4ed7c12638368e492f89d3cdfe6899 100644 (file)
@@ -29,6 +29,12 @@ var tests = []interface{}{
        &nextProtoMsg{},
        &newSessionTicketMsg{},
        &sessionState{},
+       &encryptedExtensionsMsg{},
+       &endOfEarlyDataMsg{},
+       &keyUpdateMsg{},
+       &newSessionTicketMsgTLS13{},
+       &certificateRequestMsgTLS13{},
+       &certificateMsgTLS13{},
 }
 
 func TestMarshalUnmarshal(t *testing.T) {
@@ -184,6 +190,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
                m.pskIdentities = append(m.pskIdentities, psk)
                m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
        }
+       if rand.Intn(10) > 5 {
+               m.earlyData = true
+       }
 
        return reflect.ValueOf(m)
 }
@@ -209,7 +218,9 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        if rand.Intn(10) > 5 {
                m.ticketSupported = true
        }
-       m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
+       if rand.Intn(10) > 5 {
+               m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
+       }
 
        for i := 0; i < rand.Intn(4); i++ {
                m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
@@ -241,6 +252,16 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        return reflect.ValueOf(m)
 }
 
+func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+       m := &encryptedExtensionsMsg{}
+
+       if rand.Intn(10) > 5 {
+               m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
+       }
+
+       return reflect.ValueOf(m)
+}
+
 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m := &certificateMsg{}
        numCerts := rand.Intn(20)
@@ -270,12 +291,7 @@ func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
 
 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m := &certificateStatusMsg{}
-       if rand.Intn(10) > 5 {
-               m.statusType = statusTypeOCSP
-               m.response = randomBytes(rand.Intn(10)+1, rand)
-       } else {
-               m.statusType = 42
-       }
+       m.response = randomBytes(rand.Intn(10)+1, rand)
        return reflect.ValueOf(m)
 }
 
@@ -316,6 +332,66 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
        return reflect.ValueOf(s)
 }
 
+func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+       m := &endOfEarlyDataMsg{}
+       return reflect.ValueOf(m)
+}
+
+func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+       m := &keyUpdateMsg{}
+       m.updateRequested = rand.Intn(10) > 5
+       return reflect.ValueOf(m)
+}
+
+func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
+       m := &newSessionTicketMsgTLS13{}
+       m.lifetime = uint32(rand.Intn(500000))
+       m.ageAdd = uint32(rand.Intn(500000))
+       m.nonce = randomBytes(rand.Intn(100), rand)
+       m.label = randomBytes(rand.Intn(1000), rand)
+       if rand.Intn(10) > 5 {
+               m.maxEarlyData = uint32(rand.Intn(500000))
+       }
+       return reflect.ValueOf(m)
+}
+
+func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
+       m := &certificateRequestMsgTLS13{}
+       if rand.Intn(10) > 5 {
+               m.ocspStapling = true
+       }
+       if rand.Intn(10) > 5 {
+               m.scts = true
+       }
+       if rand.Intn(10) > 5 {
+               m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
+       }
+       if rand.Intn(10) > 5 {
+               m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
+       }
+       return reflect.ValueOf(m)
+}
+
+func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
+       m := &certificateMsgTLS13{}
+       for i := 0; i < rand.Intn(2)+1; i++ {
+               m.certificate.Certificate = append(
+                       m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
+       }
+       if rand.Intn(10) > 5 {
+               m.ocspStapling = true
+               m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
+       }
+       if rand.Intn(10) > 5 {
+               m.scts = true
+               for i := 0; i < rand.Intn(2)+1; i++ {
+                       m.certificate.SignedCertificateTimestamps = append(
+                               m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
+               }
+       }
+       return reflect.ValueOf(m)
+}
+
 func TestRejectEmptySCTList(t *testing.T) {
        // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
 
index bec128f4154643abce29b7eb1f6d0d1df2e1feb0..2c916e853e62265fab6c350381407ebea0db98b2 100644 (file)
@@ -389,7 +389,6 @@ func (hs *serverHandshakeState) doFullHandshake() error {
 
        if hs.hello.ocspStapling {
                certStatus := new(certificateStatusMsg)
-               certStatus.statusType = statusTypeOCSP
                certStatus.response = hs.cert.OCSPStaple
                hs.finishedHash.Write(certStatus.marshal())
                if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {