]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: use SessionState on the client side
authorFilippo Valsorda <filippo@golang.org>
Sun, 21 May 2023 19:17:56 +0000 (21:17 +0200)
committerFilippo Valsorda <filippo@golang.org>
Wed, 24 May 2023 23:56:41 +0000 (23:56 +0000)
Another internal change, that allows exposing the new APIs easily in
following CLs.

For #60105

Change-Id: I9c61b9f6e9d29af633f952444f514bcbbe82fe4e
Reviewed-on: https://go-review.googlesource.com/c/go/+/496819
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Run-TryBot: Filippo Valsorda <filippo@golang.org>

src/crypto/tls/cache.go
src/crypto/tls/common.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_client_tls13.go
src/crypto/tls/handshake_messages_test.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_tls13.go
src/crypto/tls/ticket.go

index 09f58250a8387357587f56dcbb2a9779d97f10e9..a7677611fdc7bf927585cb3cf84628f126a36eea 100644 (file)
@@ -39,7 +39,7 @@ type certCache struct {
        sync.Map
 }
 
-var clientCertCache = new(certCache)
+var globalCertCache = new(certCache)
 
 // activeCert is a handle to a certificate held in the cache. Once there are
 // no alive activeCerts for a given certificate, the certificate is removed
index 58e97306c0e5112e7eedced162bb8739aacb616b..ccaf7d352fcdef77b6ed8f4065af0b5a4a25799c 100644 (file)
@@ -330,25 +330,6 @@ func requiresClientCert(c ClientAuthType) bool {
        }
 }
 
-// ClientSessionState contains the state needed by clients to resume TLS
-// sessions.
-type ClientSessionState struct {
-       sessionTicket      []uint8               // Encrypted ticket used for session resumption with server
-       vers               uint16                // TLS version negotiated for the session
-       cipherSuite        uint16                // Ciphersuite negotiated for the session
-       masterSecret       []byte                // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret
-       serverCertificates []*x509.Certificate   // Certificate chain presented by the server
-       verifiedChains     [][]*x509.Certificate // Certificate chains we built for verification
-       receivedAt         time.Time             // When the session ticket was received from the server
-       ocspResponse       []byte                // Stapled OCSP response presented by the server
-       scts               [][]byte              // SCTs presented by the server
-
-       // TLS 1.3 fields.
-       nonce  []byte    // Ticket nonce sent by the server, to derive PSK
-       useBy  time.Time // Expiration of the ticket lifetime as set by the server
-       ageAdd uint32    // Random obfuscation factor for sending the ticket age
-}
-
 // ClientSessionCache is a cache of ClientSessionState objects that can be used
 // by a client to resume a TLS session with a given server. ClientSessionCache
 // implementations should expect to be called concurrently from different
index 9f74cc4ef9723eb69d95013e7fc7a78a20a5d482..2156e9183b1aad2f322308afe69e0b5497c738ed 100644 (file)
@@ -31,7 +31,8 @@ type clientHandshakeState struct {
        suite        *cipherSuite
        finishedHash finishedHash
        masterSecret []byte
-       session      *ClientSessionState
+       session      *SessionState // the session being resumed
+       ticket       []byte        // a fresh ticket received during this handshake
 }
 
 var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme
@@ -177,11 +178,11 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
        }
        c.serverName = hello.serverName
 
-       cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello)
+       session, earlySecret, binderKey, err := c.loadSession(hello)
        if err != nil {
                return err
        }
-       if cacheKey != "" && session != nil {
+       if session != nil {
                defer func() {
                        // If we got a handshake failure when resuming a session, throw away
                        // the session ticket. See RFC 5077, Section 3.2.
@@ -190,7 +191,9 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
                        // does require servers to abort on invalid binders, so we need to
                        // delete tickets to recover from a corrupted PSK.
                        if err != nil {
-                               c.config.ClientSessionCache.Put(cacheKey, nil)
+                               if cacheKey := c.clientSessionCacheKey(); cacheKey != "" {
+                                       c.config.ClientSessionCache.Put(cacheKey, nil)
+                               }
                        }
                }()
        }
@@ -255,19 +258,13 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
                return err
        }
 
-       // If we had a successful handshake and hs.session is different from
-       // the one already cached - cache a new one.
-       if cacheKey != "" && hs.session != nil && session != hs.session {
-               c.config.ClientSessionCache.Put(cacheKey, hs.session)
-       }
-
        return nil
 }
 
-func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
-       session *ClientSessionState, earlySecret, binderKey []byte, err error) {
+func (c *Conn) loadSession(hello *clientHelloMsg) (
+       session *SessionState, earlySecret, binderKey []byte, err error) {
        if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
-               return "", nil, nil, nil, nil
+               return nil, nil, nil, nil
        }
 
        hello.ticketSupported = true
@@ -282,29 +279,30 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
        // renegotiation is primarily used to allow a client to send a client
        // certificate, which would be skipped if session resumption occurred.
        if c.handshakes != 0 {
-               return "", nil, nil, nil, nil
+               return nil, nil, nil, nil
        }
 
        // Try to resume a previously negotiated TLS session, if available.
-       cacheKey = c.clientSessionCacheKey()
+       cacheKey := c.clientSessionCacheKey()
        if cacheKey == "" {
-               return "", nil, nil, nil, nil
+               return nil, nil, nil, nil
        }
-       session, ok := c.config.ClientSessionCache.Get(cacheKey)
-       if !ok || session == nil {
-               return cacheKey, nil, nil, nil, nil
+       cs, ok := c.config.ClientSessionCache.Get(cacheKey)
+       if !ok || cs == nil {
+               return nil, nil, nil, nil
        }
+       session = cs.session
 
        // Check that version used for the previous session is still valid.
        versOk := false
        for _, v := range hello.supportedVersions {
-               if v == session.vers {
+               if v == session.version {
                        versOk = true
                        break
                }
        }
        if !versOk {
-               return cacheKey, nil, nil, nil, nil
+               return nil, nil, nil, nil
        }
 
        // Check that the cached server certificate is not expired, and that it's
@@ -313,41 +311,41 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
        if !c.config.InsecureSkipVerify {
                if len(session.verifiedChains) == 0 {
                        // The original connection had InsecureSkipVerify, while this doesn't.
-                       return cacheKey, nil, nil, nil, nil
+                       return nil, nil, nil, nil
                }
-               serverCert := session.serverCertificates[0]
+               serverCert := session.peerCertificates[0]
                if c.config.time().After(serverCert.NotAfter) {
                        // Expired certificate, delete the entry.
                        c.config.ClientSessionCache.Put(cacheKey, nil)
-                       return cacheKey, nil, nil, nil, nil
+                       return nil, nil, nil, nil
                }
                if err := serverCert.VerifyHostname(c.config.ServerName); err != nil {
-                       return cacheKey, nil, nil, nil, nil
+                       return nil, nil, nil, nil
                }
        }
 
-       if session.vers != VersionTLS13 {
+       if session.version != VersionTLS13 {
                // In TLS 1.2 the cipher suite must match the resumed session. Ensure we
                // are still offering it.
                if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil {
-                       return cacheKey, nil, nil, nil, nil
+                       return nil, nil, nil, nil
                }
 
-               hello.sessionTicket = session.sessionTicket
+               hello.sessionTicket = cs.ticket
                return
        }
 
        // Check that the session ticket is not expired.
-       if c.config.time().After(session.useBy) {
+       if c.config.time().After(time.Unix(int64(session.useBy), 0)) {
                c.config.ClientSessionCache.Put(cacheKey, nil)
-               return cacheKey, nil, nil, nil, nil
+               return nil, nil, nil, nil
        }
 
        // In TLS 1.3 the KDF hash must match the resumed session. Ensure we
        // offer at least one cipher suite with that hash.
        cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite)
        if cipherSuite == nil {
-               return cacheKey, nil, nil, nil, nil
+               return nil, nil, nil, nil
        }
        cipherSuiteOk := false
        for _, offeredID := range hello.cipherSuites {
@@ -358,32 +356,30 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
                }
        }
        if !cipherSuiteOk {
-               return cacheKey, nil, nil, nil, nil
+               return nil, nil, nil, nil
        }
 
        // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1.
-       ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond)
+       ticketAge := c.config.time().Sub(time.Unix(int64(session.createdAt), 0))
        identity := pskIdentity{
-               label:               session.sessionTicket,
-               obfuscatedTicketAge: ticketAge + session.ageAdd,
+               label:               cs.ticket,
+               obfuscatedTicketAge: uint32(ticketAge/time.Millisecond) + session.ageAdd,
        }
        hello.pskIdentities = []pskIdentity{identity}
        hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())}
 
        // Compute the PSK binders. See RFC 8446, Section 4.2.11.2.
-       psk := cipherSuite.expandLabel(session.masterSecret, "resumption",
-               session.nonce, cipherSuite.hash.Size())
-       earlySecret = cipherSuite.extract(psk, nil)
+       earlySecret = cipherSuite.extract(session.secret, nil)
        binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
        transcript := cipherSuite.hash.New()
        helloBytes, err := hello.marshalWithoutBinders()
        if err != nil {
-               return "", nil, nil, nil, err
+               return nil, nil, nil, err
        }
        transcript.Write(helloBytes)
        pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)}
        if err := hello.updateBinders(pskBinders); err != nil {
-               return "", nil, nil, nil, err
+               return nil, nil, nil, err
        }
 
        return
@@ -485,6 +481,9 @@ func (hs *clientHandshakeState) handshake() error {
                        return err
                }
        }
+       if err := hs.saveSessionTicket(); err != nil {
+               return err
+       }
 
        c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
        c.isHandshakeComplete.Store(true)
@@ -752,7 +751,7 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
                return false, nil
        }
 
-       if hs.session.vers != c.vers {
+       if hs.session.version != c.vers {
                c.sendAlert(alertHandshakeFailure)
                return false, errors.New("tls: server resumed a session with a different version")
        }
@@ -762,9 +761,10 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
                return false, errors.New("tls: server resumed a session with a different cipher suite")
        }
 
-       // Restore masterSecret, peerCerts, and ocspResponse from previous state
-       hs.masterSecret = hs.session.masterSecret
-       c.peerCertificates = hs.session.serverCertificates
+       // Restore master secret and certificates from previous state
+       hs.masterSecret = hs.session.secret
+       c.peerCertificates = hs.session.peerCertificates
+       c.activeCertHandles = hs.c.activeCertHandles
        c.verifiedChains = hs.session.verifiedChains
        c.ocspResponse = hs.session.ocspResponse
        // Let the ServerHello SCTs override the session SCTs from the original
@@ -836,8 +836,13 @@ func (hs *clientHandshakeState) readSessionTicket() error {
        if !hs.serverHello.ticketSupported {
                return nil
        }
-
        c := hs.c
+
+       if !hs.hello.ticketSupported {
+               c.sendAlert(alertIllegalParameter)
+               return errors.New("tls: server sent unrequested session ticket")
+       }
+
        msg, err := c.readHandshake(&hs.finishedHash)
        if err != nil {
                return err
@@ -848,18 +853,29 @@ func (hs *clientHandshakeState) readSessionTicket() error {
                return unexpectedMessageError(sessionTicketMsg, msg)
        }
 
-       hs.session = &ClientSessionState{
-               sessionTicket:      sessionTicketMsg.ticket,
-               vers:               c.vers,
-               cipherSuite:        hs.suite.id,
-               masterSecret:       hs.masterSecret,
-               serverCertificates: c.peerCertificates,
-               verifiedChains:     c.verifiedChains,
-               receivedAt:         c.config.time(),
-               ocspResponse:       c.ocspResponse,
-               scts:               c.scts,
+       hs.ticket = sessionTicketMsg.ticket
+       return nil
+}
+
+func (hs *clientHandshakeState) saveSessionTicket() error {
+       if hs.ticket == nil {
+               return nil
+       }
+       c := hs.c
+
+       cacheKey := c.clientSessionCacheKey()
+       if cacheKey == "" {
+               return nil
+       }
+
+       session, err := c.sessionState()
+       if err != nil {
+               return err
        }
+       session.secret = hs.masterSecret
 
+       cs := &ClientSessionState{ticket: hs.ticket, session: session}
+       c.config.ClientSessionCache.Put(cacheKey, cs)
        return nil
 }
 
@@ -885,7 +901,7 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
        activeHandles := make([]*activeCert, len(certificates))
        certs := make([]*x509.Certificate, len(certificates))
        for i, asn1Data := range certificates {
-               cert, err := clientCertCache.newCert(asn1Data)
+               cert, err := globalCertCache.newCert(asn1Data)
                if err != nil {
                        c.sendAlert(alertBadCertificate)
                        return errors.New("tls: failed to parse certificate from server: " + err.Error())
index fef5038810ca4114129ca21482fef554b6c369cf..cf7c09b08faed9d63f2ed049ee78c787b2e5e6c8 100644 (file)
@@ -916,14 +916,14 @@ func testResumption(t *testing.T, version uint16) {
        }
 
        getTicket := func() []byte {
-               return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
+               return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.ticket
        }
        deleteTicket := func() {
                ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
                clientConfig.ClientSessionCache.Put(ticketKey, nil)
        }
        corruptTicket := func() {
-               clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff
+               clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.session.secret[0] ^= 0xff
        }
        randomKey := func() [32]byte {
                var k [32]byte
index 15e0a748485717839bcbb066172d0fa9c1299416..b26992b19ebe6e1edd0fb8dbf7dde32066c8853f 100644 (file)
@@ -23,7 +23,7 @@ type clientHandshakeStateTLS13 struct {
        hello       *clientHelloMsg
        ecdheKey    *ecdh.PrivateKey
 
-       session     *ClientSessionState
+       session     *SessionState
        earlySecret []byte
        binderKey   []byte
 
@@ -256,8 +256,8 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
                }
                if pskSuite.hash == hs.suite.hash {
                        // Update binders and obfuscated_ticket_age.
-                       ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond)
-                       hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd
+                       ticketAge := c.config.time().Sub(time.Unix(int64(hs.session.createdAt), 0))
+                       hs.hello.pskIdentities[0].obfuscatedTicketAge = uint32(ticketAge/time.Millisecond) + hs.session.ageAdd
 
                        transcript := hs.suite.hash.New()
                        transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
@@ -355,7 +355,8 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
 
        hs.usingPSK = true
        c.didResume = true
-       c.peerCertificates = hs.session.serverCertificates
+       c.peerCertificates = hs.session.peerCertificates
+       c.activeCertHandles = hs.session.activeCertHandles
        c.verifiedChains = hs.session.verifiedChains
        c.ocspResponse = hs.session.ocspResponse
        c.scts = hs.session.scts
@@ -719,28 +720,21 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
                return c.sendAlert(alertInternalError)
        }
 
-       // Save the resumption_master_secret and nonce instead of deriving the PSK
-       // to do the least amount of work on NewSessionTicket messages before we
-       // know if the ticket will be used. Forward secrecy of resumed connections
-       // is guaranteed by the requirement for pskModeDHE.
-       session := &ClientSessionState{
-               sessionTicket:      msg.label,
-               vers:               c.vers,
-               cipherSuite:        c.cipherSuite,
-               masterSecret:       c.resumptionSecret,
-               serverCertificates: c.peerCertificates,
-               verifiedChains:     c.verifiedChains,
-               receivedAt:         c.config.time(),
-               nonce:              msg.nonce,
-               useBy:              c.config.time().Add(lifetime),
-               ageAdd:             msg.ageAdd,
-               ocspResponse:       c.ocspResponse,
-               scts:               c.scts,
-       }
-
-       cacheKey := c.clientSessionCacheKey()
-       if cacheKey != "" {
-               c.config.ClientSessionCache.Put(cacheKey, session)
+       psk := cipherSuite.expandLabel(c.resumptionSecret, "resumption",
+               msg.nonce, cipherSuite.hash.Size())
+
+       session, err := c.sessionState()
+       if err != nil {
+               c.sendAlert(alertInternalError)
+               return err
+       }
+       session.secret = psk
+       session.useBy = uint64(c.config.time().Add(lifetime).Unix())
+       session.ageAdd = msg.ageAdd
+       cs := &ClientSessionState{ticket: msg.label, session: session}
+
+       if cacheKey := c.clientSessionCacheKey(); cacheKey != "" {
+               c.config.ClientSessionCache.Put(cacheKey, cs)
        }
 
        return nil
index b280f0967468a84c91a55b341c77bb8a303bbfbc..85efacf4cb1bc800a58005eac26756904d2e8650 100644 (file)
@@ -6,7 +6,9 @@ package tls
 
 import (
        "bytes"
+       "crypto/x509"
        "encoding/hex"
+       "math"
        "math/rand"
        "reflect"
        "strings"
@@ -71,6 +73,10 @@ func TestMarshalUnmarshal(t *testing.T) {
                        }
                        m.marshal() // to fill any marshal cache in the message
 
+                       if m, ok := m.(*SessionState); ok {
+                               m.activeCertHandles = nil
+                       }
+
                        if !reflect.DeepEqual(m1, m) {
                                t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
                                break
@@ -97,7 +103,7 @@ func TestFuzz(t *testing.T) {
        rand := rand.New(rand.NewSource(0))
        for _, m := range tests {
                for j := 0; j < 1000; j++ {
-                       len := rand.Intn(100)
+                       len := rand.Intn(1000)
                        bytes := randomBytes(len, rand)
                        // This just looks for crashes due to bounds errors etc.
                        m.unmarshal(bytes)
@@ -313,23 +319,59 @@ func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        return reflect.ValueOf(m)
 }
 
+var sessionTestCerts []*x509.Certificate
+
+func init() {
+       cert, err := x509.ParseCertificate(testRSACertificate)
+       if err != nil {
+               panic(err)
+       }
+       sessionTestCerts = append(sessionTestCerts, cert)
+       cert, err = x509.ParseCertificate(testRSACertificateIssuer)
+       if err != nil {
+               panic(err)
+       }
+       sessionTestCerts = append(sessionTestCerts, cert)
+}
+
 func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
        s := &SessionState{}
-       s.version = uint16(rand.Intn(10000))
-       s.cipherSuite = uint16(rand.Intn(10000))
-       s.secret = randomBytes(rand.Intn(100)+1, rand)
+       isTLS13 := rand.Intn(10) > 5
+       if isTLS13 {
+               s.version = VersionTLS13
+       } else {
+               s.version = uint16(rand.Intn(VersionTLS13))
+       }
+       s.isClient = rand.Intn(10) > 5
+       s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
        s.createdAt = uint64(rand.Int63())
-       for i := 0; i < rand.Intn(2)+1; i++ {
-               s.certificate.Certificate = append(
-                       s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
+       s.secret = randomBytes(rand.Intn(100)+1, rand)
+       if s.isClient || rand.Intn(10) > 5 {
+               if rand.Intn(10) > 5 {
+                       s.peerCertificates = sessionTestCerts
+               } else {
+                       s.peerCertificates = sessionTestCerts[:1]
+               }
        }
-       if rand.Intn(10) > 5 {
-               s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
+       if rand.Intn(10) > 5 && s.peerCertificates != nil {
+               s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
        }
-       if rand.Intn(10) > 5 {
+       if rand.Intn(10) > 5 && s.peerCertificates != nil {
                for i := 0; i < rand.Intn(2)+1; i++ {
-                       s.certificate.SignedCertificateTimestamps = append(
-                               s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
+                       s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
+               }
+       }
+       if s.isClient {
+               for i := 0; i < rand.Intn(3); i++ {
+                       if rand.Intn(10) > 5 {
+                               s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
+                       } else {
+                               s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
+                       }
+               }
+               if isTLS13 {
+                       s.useBy = uint64(rand.Int63())
+                       s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
                }
        }
        return reflect.ValueOf(s)
index 5e5badca95710b48328a8fe4dc4d5a3aa7e3b2be..7dda65676ab15455cd5deec9cf09b338ec07f6f3 100644 (file)
@@ -448,7 +448,7 @@ func (hs *serverHandshakeState) checkForResumption() bool {
                return false
        }
 
-       sessionHasClientCerts := len(hs.sessionState.certificate.Certificate) != 0
+       sessionHasClientCerts := len(hs.sessionState.peerCertificates) != 0
        needClientCerts := requiresClientCert(c.config.ClientAuth)
        if needClientCerts && !sessionHasClientCerts {
                return false
@@ -481,7 +481,7 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
                return err
        }
 
-       if err := c.processCertsFromClient(hs.sessionState.certificate); err != nil {
+       if err := c.processCertsFromClient(hs.sessionState.certificate()); err != nil {
                return err
        }
 
@@ -759,27 +759,15 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
        c := hs.c
        m := new(newSessionTicketMsg)
 
-       createdAt := uint64(c.config.time().Unix())
+       state, err := c.sessionState()
+       if err != nil {
+               return err
+       }
+       state.secret = hs.masterSecret
        if hs.sessionState != nil {
                // If this is re-wrapping an old key, then keep
                // the original time it was created.
-               createdAt = hs.sessionState.createdAt
-       }
-
-       var certsFromClient [][]byte
-       for _, cert := range c.peerCertificates {
-               certsFromClient = append(certsFromClient, cert.Raw)
-       }
-       state := SessionState{
-               version:     c.vers,
-               cipherSuite: hs.suite.id,
-               createdAt:   createdAt,
-               secret:      hs.masterSecret,
-               certificate: Certificate{
-                       Certificate:                 certsFromClient,
-                       OCSPStaple:                  c.ocspResponse,
-                       SignedCertificateTimestamps: c.scts,
-               },
+               state.createdAt = hs.sessionState.createdAt
        }
        stateBytes, err := state.Bytes()
        if err != nil {
index f770a21663ecc73d98a914b9fb0acf8655ca4274..6753ad4aee05928ba073ed26c05b18ca10495e9d 100644 (file)
@@ -301,7 +301,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
                // PSK connections don't re-establish client certificates, but carry
                // them over in the session ticket. Ensure the presence of client certs
                // in the ticket is consistent with the configured requirements.
-               sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0
+               sessionHasClientCerts := len(sessionState.peerCertificates) != 0
                needClientCerts := requiresClientCert(c.config.ClientAuth)
                if needClientCerts && !sessionHasClientCerts {
                        continue
@@ -331,7 +331,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
                }
 
                c.didResume = true
-               if err := c.processCertsFromClient(sessionState.certificate); err != nil {
+               if err := c.processCertsFromClient(sessionState.certificate()); err != nil {
                        return err
                }
 
@@ -776,21 +776,11 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
 
        m := new(newSessionTicketMsgTLS13)
 
-       var certsFromClient [][]byte
-       for _, cert := range c.peerCertificates {
-               certsFromClient = append(certsFromClient, cert.Raw)
-       }
-       state := &SessionState{
-               version:     c.vers,
-               cipherSuite: hs.suite.id,
-               createdAt:   uint64(c.config.time().Unix()),
-               secret:      psk,
-               certificate: Certificate{
-                       Certificate:                 certsFromClient,
-                       OCSPStaple:                  c.ocspResponse,
-                       SignedCertificateTimestamps: c.scts,
-               },
+       state, err := c.sessionState()
+       if err != nil {
+               return err
        }
+       state.secret = psk
        stateBytes, err := state.Bytes()
        if err != nil {
                c.sendAlert(alertInternalError)
index dfa0d430c0042e1ecd4ede4143f87b12fab4f2ce..44bedd66de5a7d8f5c8577dc0470b5e32b1e1eec 100644 (file)
@@ -10,6 +10,7 @@ import (
        "crypto/hmac"
        "crypto/sha256"
        "crypto/subtle"
+       "crypto/x509"
        "errors"
        "io"
 
@@ -18,12 +19,63 @@ import (
 
 // A SessionState is a resumable session.
 type SessionState struct {
-       version uint16 // uint16 version;
-       // uint8 revision = 1;
+       // Encoded as a SessionState (in the language of RFC 8446, Section 3).
+       //
+       //   enum { server(1), client(2) } SessionStateType;
+       //
+       //   opaque Certificate<1..2^24-1>;
+       //
+       //   Certificate CertificateChain<0..2^24-1>;
+       //
+       //   struct {
+       //       uint16 version;
+       //       SessionStateType type;
+       //       uint16 cipher_suite;
+       //       uint64 created_at;
+       //       opaque secret<1..2^8-1>;
+       //       CertificateEntry certificate_list<0..2^24-1>;
+       //       select (SessionState.type) {
+       //           case server: /* empty */;
+       //           case client: struct {
+       //               CertificateChain verified_chains<0..2^24-1>; /* excluding leaf */
+       //               select (SessionState.version) {
+       //                   case VersionTLS10..VersionTLS12: /* empty */;
+       //                   case VersionTLS13: struct {
+       //                       uint64 use_by;
+       //                       uint32 age_add;
+       //                   };
+       //               };
+       //           };
+       //       };
+       //   } SessionState;
+       //
+
+       version     uint16
+       isClient    bool
        cipherSuite uint16
-       createdAt   uint64
-       secret      []byte      // opaque master_secret<1..2^8-1>;
-       certificate Certificate // CertificateEntry certificate_list<0..2^24-1>;
+       // createdAt is the generation time of the secret on the sever (which for
+       // TLS 1.0–1.2 might be earlier than the current session) and the time at
+       // which the ticket was received on the client.
+       createdAt         uint64 // seconds since UNIX epoch
+       secret            []byte // master secret for TLS 1.2, or the PSK for TLS 1.3
+       peerCertificates  []*x509.Certificate
+       activeCertHandles []*activeCert
+       ocspResponse      []byte
+       scts              [][]byte
+
+       // Client-side fields.
+       verifiedChains [][]*x509.Certificate
+
+       // Client-side TLS 1.3-only fields.
+       useBy  uint64 // seconds since UNIX epoch
+       ageAdd uint32
+}
+
+// ClientSessionState contains the state needed by clients to resume TLS
+// sessions.
+type ClientSessionState struct {
+       ticket  []byte
+       session *SessionState
 }
 
 // Bytes encodes the session, including any private fields, so that it can be
@@ -31,38 +83,157 @@ type SessionState struct {
 //
 // The specific encoding should be considered opaque and may change incompatibly
 // between Go versions.
-func (m *SessionState) Bytes() ([]byte, error) {
+func (s *SessionState) Bytes() ([]byte, error) {
        var b cryptobyte.Builder
-       b.AddUint16(m.version)
-       b.AddUint8(1) // revision
-       b.AddUint16(m.cipherSuite)
-       addUint64(&b, m.createdAt)
+       b.AddUint16(s.version)
+       if s.isClient {
+               b.AddUint8(2) // client
+       } else {
+               b.AddUint8(1) // server
+       }
+       b.AddUint16(s.cipherSuite)
+       addUint64(&b, s.createdAt)
        b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
-               b.AddBytes(m.secret)
+               b.AddBytes(s.secret)
        })
-       marshalCertificate(&b, m.certificate)
+       marshalCertificate(&b, s.certificate())
+       if s.isClient {
+               b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+                       for _, chain := range s.verifiedChains {
+                               b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+                                       // We elide the first certificate because it's always the leaf.
+                                       if len(chain) == 0 {
+                                               b.SetError(errors.New("tls: internal error: empty verified chain"))
+                                               return
+                                       }
+                                       for _, cert := range chain[1:] {
+                                               b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+                                                       b.AddBytes(cert.Raw)
+                                               })
+                                       }
+                               })
+                       }
+               })
+               if s.version >= VersionTLS13 {
+                       addUint64(&b, s.useBy)
+                       b.AddUint32(s.ageAdd)
+               }
+       }
        return b.Bytes()
 }
 
+func (s *SessionState) certificate() Certificate {
+       return Certificate{
+               Certificate:                 certificatesToBytesSlice(s.peerCertificates),
+               OCSPStaple:                  s.ocspResponse,
+               SignedCertificateTimestamps: s.scts,
+       }
+}
+
+func certificatesToBytesSlice(certs []*x509.Certificate) [][]byte {
+       s := make([][]byte, 0, len(certs))
+       for _, c := range certs {
+               s = append(s, c.Raw)
+       }
+       return s
+}
+
 // ParseSessionState parses a [SessionState] encoded by [SessionState.Bytes].
 func ParseSessionState(data []byte) (*SessionState, error) {
        ss := &SessionState{}
        s := cryptobyte.String(data)
-       var revision uint8
+       var typ uint8
+       var cert Certificate
        if !s.ReadUint16(&ss.version) ||
-               !s.ReadUint8(&revision) ||
-               revision != 1 ||
+               !s.ReadUint8(&typ) ||
+               (typ != 1 && typ != 2) ||
                !s.ReadUint16(&ss.cipherSuite) ||
                !readUint64(&s, &ss.createdAt) ||
                !readUint8LengthPrefixed(&s, &ss.secret) ||
                len(ss.secret) == 0 ||
-               !unmarshalCertificate(&s, &ss.certificate) ||
-               !s.Empty() {
+               !unmarshalCertificate(&s, &cert) {
+               return nil, errors.New("tls: invalid session encoding")
+       }
+       for _, cert := range cert.Certificate {
+               c, err := globalCertCache.newCert(cert)
+               if err != nil {
+                       return nil, err
+               }
+               ss.activeCertHandles = append(ss.activeCertHandles, c)
+               ss.peerCertificates = append(ss.peerCertificates, c.cert)
+       }
+       ss.ocspResponse = cert.OCSPStaple
+       ss.scts = cert.SignedCertificateTimestamps
+       if isClient := typ == 2; !isClient {
+               if !s.Empty() {
+                       return nil, errors.New("tls: invalid session encoding")
+               }
+               return ss, nil
+       }
+       ss.isClient = true
+       if len(ss.peerCertificates) == 0 {
+               return nil, errors.New("tls: no server certificates in client session")
+       }
+       var chainList cryptobyte.String
+       if !s.ReadUint24LengthPrefixed(&chainList) {
+               return nil, errors.New("tls: invalid session encoding")
+       }
+       for !chainList.Empty() {
+               var certList cryptobyte.String
+               if !chainList.ReadUint24LengthPrefixed(&certList) {
+                       return nil, errors.New("tls: invalid session encoding")
+               }
+               var chain []*x509.Certificate
+               chain = append(chain, ss.peerCertificates[0])
+               for !certList.Empty() {
+                       var cert []byte
+                       if !readUint24LengthPrefixed(&certList, &cert) {
+                               return nil, errors.New("tls: invalid session encoding")
+                       }
+                       c, err := globalCertCache.newCert(cert)
+                       if err != nil {
+                               return nil, err
+                       }
+                       ss.activeCertHandles = append(ss.activeCertHandles, c)
+                       chain = append(chain, c.cert)
+               }
+               ss.verifiedChains = append(ss.verifiedChains, chain)
+       }
+       if ss.version < VersionTLS13 {
+               if !s.Empty() {
+                       return nil, errors.New("tls: invalid session encoding")
+               }
+               return ss, nil
+       }
+       if !s.ReadUint64(&ss.useBy) || !s.ReadUint32(&ss.ageAdd) || !s.Empty() {
                return nil, errors.New("tls: invalid session encoding")
        }
        return ss, nil
 }
 
+// sessionState returns a partially filled-out [SessionState] with information
+// from the current connection.
+func (c *Conn) sessionState() (*SessionState, error) {
+       var verifiedChains [][]*x509.Certificate
+       if c.isClient {
+               verifiedChains = c.verifiedChains
+               if len(c.peerCertificates) == 0 {
+                       return nil, errors.New("tls: internal error: empty peer certificates")
+               }
+       }
+       return &SessionState{
+               version:           c.vers,
+               cipherSuite:       c.cipherSuite,
+               createdAt:         uint64(c.config.time().Unix()),
+               peerCertificates:  c.peerCertificates,
+               activeCertHandles: c.activeCertHandles,
+               ocspResponse:      c.ocspResponse,
+               scts:              c.scts,
+               isClient:          c.isClient,
+               verifiedChains:    verifiedChains,
+       }, nil
+}
+
 func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
        if len(c.ticketKeys) == 0 {
                return nil, errors.New("tls: internal error: session ticket keys unavailable")