]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: don't cache marshal'd bytes
authorRoland Shoemaker <roland@golang.org>
Thu, 18 Apr 2024 17:51:25 +0000 (10:51 -0700)
committerGopher Robot <gobot@golang.org>
Fri, 19 Apr 2024 16:55:49 +0000 (16:55 +0000)
Only cache the wire representation for clientHelloMsg and serverHelloMsg
during unmarshal, which are the only places we actually need to hold
onto them. For everything else, remove the raw field.

This appears to have zero performance impact:

name                                               old time/op   new time/op   delta
CertCache/0-10                                       177µs ± 2%    189µs ±11%   ~     (p=0.700 n=3+3)
CertCache/1-10                                       184µs ± 3%    182µs ± 6%   ~     (p=1.000 n=3+3)
CertCache/2-10                                       187µs ±12%    187µs ± 2%   ~     (p=1.000 n=3+3)
CertCache/3-10                                       204µs ±21%    187µs ± 1%   ~     (p=0.700 n=3+3)
HandshakeServer/RSA-10                               410µs ± 2%    410µs ± 3%   ~     (p=1.000 n=3+3)
HandshakeServer/ECDHE-P256-RSA/TLSv13-10             473µs ± 3%    460µs ± 2%   ~     (p=0.200 n=3+3)
HandshakeServer/ECDHE-P256-RSA/TLSv12-10             498µs ± 3%    489µs ± 2%   ~     (p=0.700 n=3+3)
HandshakeServer/ECDHE-P256-ECDSA-P256/TLSv13-10      140µs ± 5%    138µs ± 5%   ~     (p=1.000 n=3+3)
HandshakeServer/ECDHE-P256-ECDSA-P256/TLSv12-10      132µs ± 1%    133µs ± 2%   ~     (p=0.400 n=3+3)
HandshakeServer/ECDHE-X25519-ECDSA-P256/TLSv13-10    168µs ± 1%    171µs ± 4%   ~     (p=1.000 n=3+3)
HandshakeServer/ECDHE-X25519-ECDSA-P256/TLSv12-10    166µs ± 3%    163µs ± 0%   ~     (p=0.700 n=3+3)
HandshakeServer/ECDHE-P521-ECDSA-P521/TLSv13-10     1.87ms ± 2%   1.81ms ± 0%   ~     (p=0.100 n=3+3)
HandshakeServer/ECDHE-P521-ECDSA-P521/TLSv12-10     1.86ms ± 0%   1.86ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/1MB/TLSv12-10                  6.79ms ± 3%   6.73ms ± 0%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/1MB/TLSv13-10                  6.73ms ± 1%   6.75ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv12-10                  12.8ms ± 2%   12.7ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv13-10                  13.1ms ± 3%   12.8ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/4MB/TLSv12-10                  24.9ms ± 2%   24.7ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/4MB/TLSv13-10                  26.0ms ± 4%   24.9ms ± 1%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/8MB/TLSv12-10                  50.0ms ± 3%   48.9ms ± 0%   ~     (p=0.200 n=3+3)
Throughput/MaxPacket/8MB/TLSv13-10                  49.8ms ± 2%   49.3ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/16MB/TLSv12-10                 97.3ms ± 1%   97.4ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/16MB/TLSv13-10                 97.9ms ± 0%   97.9ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/32MB/TLSv12-10                  195ms ± 0%    194ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/32MB/TLSv13-10                  196ms ± 0%    196ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/64MB/TLSv12-10                  405ms ± 3%    385ms ± 0%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/64MB/TLSv13-10                  391ms ± 1%    388ms ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/1MB/TLSv12-10              6.75ms ± 0%   6.75ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/1MB/TLSv13-10              6.84ms ± 1%   6.77ms ± 0%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/2MB/TLSv12-10              12.8ms ± 1%   12.8ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/DynamicPacket/2MB/TLSv13-10              12.8ms ± 1%   13.0ms ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/4MB/TLSv12-10              24.8ms ± 1%   24.8ms ± 0%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/4MB/TLSv13-10              25.1ms ± 2%   25.1ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/8MB/TLSv12-10              49.2ms ± 2%   48.9ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/8MB/TLSv13-10              49.3ms ± 1%   49.4ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/16MB/TLSv12-10             97.1ms ± 0%   98.0ms ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/16MB/TLSv13-10             98.8ms ± 1%   98.4ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/32MB/TLSv12-10              192ms ± 0%    198ms ± 5%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/32MB/TLSv13-10              194ms ± 0%    196ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/DynamicPacket/64MB/TLSv12-10              385ms ± 1%    384ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/64MB/TLSv13-10              387ms ± 0%    388ms ± 0%   ~     (p=0.400 n=3+3)
Latency/MaxPacket/200kbps/TLSv12-10                  694ms ± 0%    694ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/200kbps/TLSv13-10                  699ms ± 0%    699ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/500kbps/TLSv12-10                  278ms ± 0%    278ms ± 0%   ~     (p=0.400 n=3+3)
Latency/MaxPacket/500kbps/TLSv13-10                  280ms ± 0%    280ms ± 0%   ~     (p=1.000 n=3+3)
Latency/MaxPacket/1000kbps/TLSv12-10                 140ms ± 1%    140ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/1000kbps/TLSv13-10                 141ms ± 0%    141ms ± 0%   ~     (p=1.000 n=3+3)
Latency/MaxPacket/2000kbps/TLSv12-10                70.5ms ± 0%   70.4ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/2000kbps/TLSv13-10                70.7ms ± 0%   70.7ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/5000kbps/TLSv12-10                28.8ms ± 0%   28.8ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/5000kbps/TLSv13-10                28.9ms ± 0%   28.9ms ± 0%   ~     (p=0.700 n=3+3)
Latency/DynamicPacket/200kbps/TLSv12-10              134ms ± 0%    134ms ± 0%   ~     (p=0.700 n=3+3)
Latency/DynamicPacket/200kbps/TLSv13-10              138ms ± 0%    138ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/500kbps/TLSv12-10             54.1ms ± 0%   54.1ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/500kbps/TLSv13-10             55.7ms ± 0%   55.7ms ± 0%   ~     (p=0.100 n=3+3)
Latency/DynamicPacket/1000kbps/TLSv12-10            27.6ms ± 0%   27.6ms ± 0%   ~     (p=0.200 n=3+3)
Latency/DynamicPacket/1000kbps/TLSv13-10            28.4ms ± 0%   28.4ms ± 0%   ~     (p=0.200 n=3+3)
Latency/DynamicPacket/2000kbps/TLSv12-10            14.4ms ± 0%   14.4ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/2000kbps/TLSv13-10            14.6ms ± 0%   14.6ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/5000kbps/TLSv12-10            6.44ms ± 0%   6.45ms ± 0%   ~     (p=0.100 n=3+3)
Latency/DynamicPacket/5000kbps/TLSv13-10            6.49ms ± 0%   6.49ms ± 0%   ~     (p=0.700 n=3+3)

name                                               old speed     new speed     delta
Throughput/MaxPacket/1MB/TLSv12-10                 155MB/s ± 3%  156MB/s ± 0%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/1MB/TLSv13-10                 156MB/s ± 1%  155MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv12-10                 163MB/s ± 2%  165MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv13-10                 160MB/s ± 3%  164MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/4MB/TLSv12-10                 168MB/s ± 2%  170MB/s ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/4MB/TLSv13-10                 162MB/s ± 4%  168MB/s ± 1%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/8MB/TLSv12-10                 168MB/s ± 3%  172MB/s ± 0%   ~     (p=0.200 n=3+3)
Throughput/MaxPacket/8MB/TLSv13-10                 168MB/s ± 2%  170MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/16MB/TLSv12-10                172MB/s ± 1%  172MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/16MB/TLSv13-10                171MB/s ± 0%  171MB/s ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/32MB/TLSv12-10                172MB/s ± 0%  173MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/32MB/TLSv13-10                171MB/s ± 0%  172MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/64MB/TLSv12-10                166MB/s ± 3%  174MB/s ± 0%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/64MB/TLSv13-10                171MB/s ± 1%  173MB/s ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/1MB/TLSv12-10             155MB/s ± 0%  155MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/1MB/TLSv13-10             153MB/s ± 1%  155MB/s ± 0%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/2MB/TLSv12-10             164MB/s ± 1%  164MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/DynamicPacket/2MB/TLSv13-10             163MB/s ± 1%  162MB/s ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/4MB/TLSv12-10             169MB/s ± 1%  169MB/s ± 0%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/4MB/TLSv13-10             167MB/s ± 1%  167MB/s ± 1%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/8MB/TLSv12-10             170MB/s ± 2%  171MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/8MB/TLSv13-10             170MB/s ± 1%  170MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/16MB/TLSv12-10            173MB/s ± 0%  171MB/s ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/16MB/TLSv13-10            170MB/s ± 1%  170MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/32MB/TLSv12-10            175MB/s ± 0%  170MB/s ± 5%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/32MB/TLSv13-10            173MB/s ± 0%  171MB/s ± 1%   ~     (p=0.300 n=3+3)
Throughput/DynamicPacket/64MB/TLSv12-10            174MB/s ± 1%  175MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/64MB/TLSv13-10            174MB/s ± 0%  173MB/s ± 0%   ~     (p=0.400 n=3+3)

Change-Id: Ifa79cce002011850ed8b2835edd34f60e014eea8
Cq-Include-Trybots: luci.golang.try:gotip-linux-amd64-longtest,gotip-linux-arm64-longtest
Reviewed-on: https://go-review.googlesource.com/c/go/+/580215
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Auto-Submit: Roland Shoemaker <roland@golang.org>

src/crypto/tls/common.go
src/crypto/tls/handshake_client_tls13.go
src/crypto/tls/handshake_messages.go
src/crypto/tls/handshake_messages_test.go

index 849e8b0a209d330f6f902509387d1d0fd5a2cfbe..58dc0c231cc7ac1bd12504fe039670506cf6a7d1 100644 (file)
@@ -1445,6 +1445,15 @@ type handshakeMessage interface {
        unmarshal([]byte) bool
 }
 
+type handshakeMessageWithOriginalBytes interface {
+       handshakeMessage
+
+       // originalBytes should return the original bytes that were passed to
+       // unmarshal to create the message. If the message was not produced by
+       // unmarshal, it should return nil.
+       originalBytes() []byte
+}
+
 // lruSessionCache is a ClientSessionCache implementation that uses an LRU
 // caching strategy.
 type lruSessionCache struct {
index a0fc413f8ff4f1546494676d8b846e510f86dbda..bc8670a6f2b5a59156c549da0a60857e7e407e5f 100644 (file)
@@ -249,7 +249,6 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
                hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
        }
 
-       hs.hello.raw = nil
        if len(hs.hello.pskIdentities) > 0 {
                pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
                if pskSuite == nil {
index a86055a0601344e6f966c947e2071508b82ac031..b1920db6c213e05f0914578f43031dd34ad35d41 100644 (file)
@@ -68,7 +68,7 @@ func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
 }
 
 type clientHelloMsg struct {
-       raw                              []byte
+       original                         []byte
        vers                             uint16
        random                           []byte
        sessionId                        []byte
@@ -98,10 +98,6 @@ type clientHelloMsg struct {
 }
 
 func (m *clientHelloMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var exts cryptobyte.Builder
        if len(m.serverName) > 0 {
                // RFC 6066, Section 3
@@ -310,8 +306,7 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
                }
        })
 
-       m.raw, err = b.Bytes()
-       return m.raw, err
+       return b.Bytes()
 }
 
 // marshalWithoutBinders returns the ClientHello through the
@@ -324,16 +319,21 @@ func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
                bindersLen += len(binder)
        }
 
-       fullMessage, err := m.marshal()
-       if err != nil {
-               return nil, err
+       var fullMessage []byte
+       if m.original != nil {
+               fullMessage = m.original
+       } else {
+               var err error
+               fullMessage, err = m.marshal()
+               if err != nil {
+                       return nil, err
+               }
        }
        return fullMessage[:len(fullMessage)-bindersLen], nil
 }
 
-// updateBinders updates the m.pskBinders field, if necessary updating the
-// cached marshaled representation. The supplied binders must have the same
-// length as the current m.pskBinders.
+// updateBinders updates the m.pskBinders field. The supplied binders must have
+// the same length as the current m.pskBinders.
 func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
        if len(pskBinders) != len(m.pskBinders) {
                return errors.New("tls: internal error: pskBinders length mismatch")
@@ -344,30 +344,12 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
                }
        }
        m.pskBinders = pskBinders
-       if m.raw != nil {
-               helloBytes, err := m.marshalWithoutBinders()
-               if err != nil {
-                       return err
-               }
-               lenWithoutBinders := len(helloBytes)
-               b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders])
-               b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
-                       for _, binder := range m.pskBinders {
-                               b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
-                                       b.AddBytes(binder)
-                               })
-                       }
-               })
-               if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) {
-                       return errors.New("tls: internal error: failed to update binders")
-               }
-       }
 
        return nil
 }
 
 func (m *clientHelloMsg) unmarshal(data []byte) bool {
-       *m = clientHelloMsg{raw: data}
+       *m = clientHelloMsg{original: data}
        s := cryptobyte.String(data)
 
        if !s.Skip(4) || // message type and uint24 length field
@@ -625,8 +607,12 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
        return true
 }
 
+func (m *clientHelloMsg) originalBytes() []byte {
+       return m.original
+}
+
 type serverHelloMsg struct {
-       raw                          []byte
+       original                     []byte
        vers                         uint16
        random                       []byte
        sessionId                    []byte
@@ -651,10 +637,6 @@ type serverHelloMsg struct {
 }
 
 func (m *serverHelloMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var exts cryptobyte.Builder
        if m.ocspStapling {
                exts.AddUint16(extensionStatusRequest)
@@ -766,12 +748,11 @@ func (m *serverHelloMsg) marshal() ([]byte, error) {
                }
        })
 
-       m.raw, err = b.Bytes()
-       return m.raw, err
+       return b.Bytes()
 }
 
 func (m *serverHelloMsg) unmarshal(data []byte) bool {
-       *m = serverHelloMsg{raw: data}
+       *m = serverHelloMsg{original: data}
        s := cryptobyte.String(data)
 
        if !s.Skip(4) || // message type and uint24 length field
@@ -888,18 +869,17 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
        return true
 }
 
+func (m *serverHelloMsg) originalBytes() []byte {
+       return m.original
+}
+
 type encryptedExtensionsMsg struct {
-       raw                     []byte
        alpnProtocol            string
        quicTransportParameters []byte
        earlyData               bool
 }
 
 func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var b cryptobyte.Builder
        b.AddUint8(typeEncryptedExtensions)
        b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -929,13 +909,11 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
                })
        })
 
-       var err error
-       m.raw, err = b.Bytes()
-       return m.raw, err
+       return b.Bytes()
 }
 
 func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
-       *m = encryptedExtensionsMsg{raw: data}
+       *m = encryptedExtensionsMsg{}
        s := cryptobyte.String(data)
 
        var extensions cryptobyte.String
@@ -998,15 +976,10 @@ func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
 }
 
 type keyUpdateMsg struct {
-       raw             []byte
        updateRequested bool
 }
 
 func (m *keyUpdateMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var b cryptobyte.Builder
        b.AddUint8(typeKeyUpdate)
        b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1017,13 +990,10 @@ func (m *keyUpdateMsg) marshal() ([]byte, error) {
                }
        })
 
-       var err error
-       m.raw, err = b.Bytes()
-       return m.raw, err
+       return b.Bytes()
 }
 
 func (m *keyUpdateMsg) unmarshal(data []byte) bool {
-       m.raw = data
        s := cryptobyte.String(data)
 
        var updateRequested uint8
@@ -1043,7 +1013,6 @@ func (m *keyUpdateMsg) unmarshal(data []byte) bool {
 }
 
 type newSessionTicketMsgTLS13 struct {
-       raw          []byte
        lifetime     uint32
        ageAdd       uint32
        nonce        []byte
@@ -1052,10 +1021,6 @@ type newSessionTicketMsgTLS13 struct {
 }
 
 func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var b cryptobyte.Builder
        b.AddUint8(typeNewSessionTicket)
        b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1078,13 +1043,11 @@ func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
                })
        })
 
-       var err error
-       m.raw, err = b.Bytes()
-       return m.raw, err
+       return b.Bytes()
 }
 
 func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
-       *m = newSessionTicketMsgTLS13{raw: data}
+       *m = newSessionTicketMsgTLS13{}
        s := cryptobyte.String(data)
 
        var extensions cryptobyte.String
@@ -1125,7 +1088,6 @@ func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
 }
 
 type certificateRequestMsgTLS13 struct {
-       raw                              []byte
        ocspStapling                     bool
        scts                             bool
        supportedSignatureAlgorithms     []SignatureScheme
@@ -1134,10 +1096,6 @@ type certificateRequestMsgTLS13 struct {
 }
 
 func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var b cryptobyte.Builder
        b.AddUint8(typeCertificateRequest)
        b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1194,13 +1152,11 @@ func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
                })
        })
 
-       var err error
-       m.raw, err = b.Bytes()
-       return m.raw, err
+       return b.Bytes()
 }
 
 func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
-       *m = certificateRequestMsgTLS13{raw: data}
+       *m = certificateRequestMsgTLS13{}
        s := cryptobyte.String(data)
 
        var context, extensions cryptobyte.String
@@ -1276,15 +1232,10 @@ func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
 }
 
 type certificateMsg struct {
-       raw          []byte
        certificates [][]byte
 }
 
 func (m *certificateMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var i int
        for _, slice := range m.certificates {
                i += len(slice)
@@ -1311,8 +1262,7 @@ func (m *certificateMsg) marshal() ([]byte, error) {
                y = y[3+len(slice):]
        }
 
-       m.raw = x
-       return m.raw, nil
+       return x, nil
 }
 
 func (m *certificateMsg) unmarshal(data []byte) bool {
@@ -1320,7 +1270,6 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
                return false
        }
 
-       m.raw = data
        certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
        if uint32(len(data)) != certsLen+7 {
                return false
@@ -1353,17 +1302,12 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
 }
 
 type certificateMsgTLS13 struct {
-       raw          []byte
        certificate  Certificate
        ocspStapling bool
        scts         bool
 }
 
 func (m *certificateMsgTLS13) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var b cryptobyte.Builder
        b.AddUint8(typeCertificate)
        b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1379,9 +1323,7 @@ func (m *certificateMsgTLS13) marshal() ([]byte, error) {
                marshalCertificate(b, certificate)
        })
 
-       var err error
-       m.raw, err = b.Bytes()
-       return m.raw, err
+       return b.Bytes()
 }
 
 func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
@@ -1422,7 +1364,7 @@ func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
 }
 
 func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
-       *m = certificateMsgTLS13{raw: data}
+       *m = certificateMsgTLS13{}
        s := cryptobyte.String(data)
 
        var context cryptobyte.String
@@ -1500,14 +1442,10 @@ func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
 }
 
 type serverKeyExchangeMsg struct {
-       raw []byte
        key []byte
 }
 
 func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
        length := len(m.key)
        x := make([]byte, length+4)
        x[0] = typeServerKeyExchange
@@ -1516,12 +1454,10 @@ func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
        x[3] = uint8(length)
        copy(x[4:], m.key)
 
-       m.raw = x
        return x, nil
 }
 
 func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
-       m.raw = data
        if len(data) < 4 {
                return false
        }
@@ -1530,15 +1466,10 @@ func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
 }
 
 type certificateStatusMsg struct {
-       raw      []byte
        response []byte
 }
 
 func (m *certificateStatusMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var b cryptobyte.Builder
        b.AddUint8(typeCertificateStatus)
        b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1548,13 +1479,10 @@ func (m *certificateStatusMsg) marshal() ([]byte, error) {
                })
        })
 
-       var err error
-       m.raw, err = b.Bytes()
-       return m.raw, err
+       return b.Bytes()
 }
 
 func (m *certificateStatusMsg) unmarshal(data []byte) bool {
-       m.raw = data
        s := cryptobyte.String(data)
 
        var statusType uint8
@@ -1580,14 +1508,10 @@ func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
 }
 
 type clientKeyExchangeMsg struct {
-       raw        []byte
        ciphertext []byte
 }
 
 func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
        length := len(m.ciphertext)
        x := make([]byte, length+4)
        x[0] = typeClientKeyExchange
@@ -1596,12 +1520,10 @@ func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
        x[3] = uint8(length)
        copy(x[4:], m.ciphertext)
 
-       m.raw = x
        return x, nil
 }
 
 func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
-       m.raw = data
        if len(data) < 4 {
                return false
        }
@@ -1614,28 +1536,20 @@ func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
 }
 
 type finishedMsg struct {
-       raw        []byte
        verifyData []byte
 }
 
 func (m *finishedMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var b cryptobyte.Builder
        b.AddUint8(typeFinished)
        b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
                b.AddBytes(m.verifyData)
        })
 
-       var err error
-       m.raw, err = b.Bytes()
-       return m.raw, err
+       return b.Bytes()
 }
 
 func (m *finishedMsg) unmarshal(data []byte) bool {
-       m.raw = data
        s := cryptobyte.String(data)
        return s.Skip(1) &&
                readUint24LengthPrefixed(&s, &m.verifyData) &&
@@ -1643,7 +1557,6 @@ func (m *finishedMsg) unmarshal(data []byte) bool {
 }
 
 type certificateRequestMsg struct {
-       raw []byte
        // hasSignatureAlgorithm indicates whether this message includes a list of
        // supported signature algorithms. This change was introduced with TLS 1.2.
        hasSignatureAlgorithm bool
@@ -1654,10 +1567,6 @@ type certificateRequestMsg struct {
 }
 
 func (m *certificateRequestMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        // See RFC 4346, Section 7.4.4.
        length := 1 + len(m.certificateTypes) + 2
        casLength := 0
@@ -1704,13 +1613,10 @@ func (m *certificateRequestMsg) marshal() ([]byte, error) {
                y = y[len(ca):]
        }
 
-       m.raw = x
-       return m.raw, nil
+       return x, nil
 }
 
 func (m *certificateRequestMsg) unmarshal(data []byte) bool {
-       m.raw = data
-
        if len(data) < 5 {
                return false
        }
@@ -1785,17 +1691,12 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
 }
 
 type certificateVerifyMsg struct {
-       raw                   []byte
        hasSignatureAlgorithm bool // format change introduced in TLS 1.2
        signatureAlgorithm    SignatureScheme
        signature             []byte
 }
 
 func (m *certificateVerifyMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        var b cryptobyte.Builder
        b.AddUint8(typeCertificateVerify)
        b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1807,13 +1708,10 @@ func (m *certificateVerifyMsg) marshal() ([]byte, error) {
                })
        })
 
-       var err error
-       m.raw, err = b.Bytes()
-       return m.raw, err
+       return b.Bytes()
 }
 
 func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
-       m.raw = data
        s := cryptobyte.String(data)
 
        if !s.Skip(4) { // message type and uint24 length field
@@ -1828,15 +1726,10 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
 }
 
 type newSessionTicketMsg struct {
-       raw    []byte
        ticket []byte
 }
 
 func (m *newSessionTicketMsg) marshal() ([]byte, error) {
-       if m.raw != nil {
-               return m.raw, nil
-       }
-
        // See RFC 5077, Section 3.3.
        ticketLen := len(m.ticket)
        length := 2 + 4 + ticketLen
@@ -1849,14 +1742,10 @@ func (m *newSessionTicketMsg) marshal() ([]byte, error) {
        x[9] = uint8(ticketLen)
        copy(x[10:], m.ticket)
 
-       m.raw = x
-
-       return m.raw, nil
+       return x, nil
 }
 
 func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
-       m.raw = data
-
        if len(data) < 10 {
                return false
        }
@@ -1891,9 +1780,25 @@ type transcriptHash interface {
        Write([]byte) (int, error)
 }
 
-// transcriptMsg is a helper used to marshal and hash messages which typically
-// are not written to the wire, and as such aren't hashed during Conn.writeRecord.
+// transcriptMsg is a helper used to hash messages which are not hashed when
+// they are read from, or written to, the wire. This is typically the case for
+// messages which are either not sent, or need to be hashed out of order from
+// when they are read/written.
+//
+// For most messages, the message is marshalled using their marshal method,
+// since their wire representation is idempotent. For clientHelloMsg and
+// serverHelloMsg, we store the original wire representation of the message and
+// use that for hashing, since unmarshal/marshal are not idempotent due to
+// extension ordering and other malleable fields, which may cause differences
+// between what was received and what we marshal.
 func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
+       if msgWithOrig, ok := msg.(handshakeMessageWithOriginalBytes); ok {
+               if orig := msgWithOrig.originalBytes(); orig != nil {
+                       h.Write(msgWithOrig.originalBytes())
+                       return nil
+               }
+       }
+
        data, err := msg.marshal()
        if err != nil {
                return err
index 72e8bd8c25690463b3083b7d450300000a31ffa8..6c083f104378dbcf5aeed2ab2b53f58954c95e87 100644 (file)
@@ -53,49 +53,61 @@ func TestMarshalUnmarshal(t *testing.T) {
 
        for i, m := range tests {
                ty := reflect.ValueOf(m).Type()
-
-               n := 100
-               if testing.Short() {
-                       n = 5
-               }
-               for j := 0; j < n; j++ {
-                       v, ok := quick.Value(ty, rand)
-                       if !ok {
-                               t.Errorf("#%d: failed to create value", i)
-                               break
+               t.Run(ty.String(), func(t *testing.T) {
+                       n := 100
+                       if testing.Short() {
+                               n = 5
                        }
+                       for j := 0; j < n; j++ {
+                               v, ok := quick.Value(ty, rand)
+                               if !ok {
+                                       t.Errorf("#%d: failed to create value", i)
+                                       break
+                               }
 
-                       m1 := v.Interface().(handshakeMessage)
-                       marshaled := mustMarshal(t, m1)
-                       if !m.unmarshal(marshaled) {
-                               t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
-                               break
-                       }
-                       m.marshal() // to fill any marshal cache in the message
+                               m1 := v.Interface().(handshakeMessage)
+                               marshaled := mustMarshal(t, m1)
+                               if !m.unmarshal(marshaled) {
+                                       t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
+                                       break
+                               }
 
-                       if m, ok := m.(*SessionState); ok {
-                               m.activeCertHandles = nil
-                       }
+                               if m, ok := m.(*SessionState); ok {
+                                       m.activeCertHandles = nil
+                               }
 
-                       if !reflect.DeepEqual(m1, m) {
-                               t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
-                               break
-                       }
+                               // clientHelloMsg and serverHelloMsg, when unmarshalled, store
+                               // their original representation, for later use in the handshake
+                               // transcript. In order to prevent DeepEqual from failing since
+                               // we didn't create the original message via unmarshalling, nil
+                               // the field.
+                               switch t := m.(type) {
+                               case *clientHelloMsg:
+                                       t.original = nil
+                               case *serverHelloMsg:
+                                       t.original = nil
+                               }
 
-                       if i >= 3 {
-                               // The first three message types (ClientHello,
-                               // ServerHello and Finished) are allowed to
-                               // have parsable prefixes because the extension
-                               // data is optional and the length of the
-                               // Finished varies across versions.
-                               for j := 0; j < len(marshaled); j++ {
-                                       if m.unmarshal(marshaled[0:j]) {
-                                               t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
-                                               break
+                               if !reflect.DeepEqual(m1, m) {
+                                       t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
+                                       break
+                               }
+
+                               if i >= 3 {
+                                       // The first three message types (ClientHello,
+                                       // ServerHello and Finished) are allowed to
+                                       // have parsable prefixes because the extension
+                                       // data is optional and the length of the
+                                       // Finished varies across versions.
+                                       for j := 0; j < len(marshaled); j++ {
+                                               if m.unmarshal(marshaled[0:j]) {
+                                                       t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
+                                                       break
+                                               }
                                        }
                                }
                        }
-               }
+               })
        }
 }