package tls
import (
+ "container/list"
"crypto"
"crypto/rand"
"crypto/x509"
RequireAndVerifyClientCert
)
+// ClientSessionState contains the state needed by clients to resume TLS
+// sessions.
+type ClientSessionState struct {
+ sessionTicket []uint8 // Encrypted ticket used for session resumption with server
+ vers uint16 // SSL/TLS version negotiated for the session
+ cipherSuite uint16 // Ciphersuite negotiated for the session
+ masterSecret []byte // MasterSecret generated by client on a full handshake
+ serverCertificates []*x509.Certificate // Certificate chain presented by the server
+}
+
+// ClientSessionCache is a cache of ClientSessionState objects that can be used
+// by a client to resume a TLS session with a given server. ClientSessionCache
+// implementations should expect to be called concurrently from different
+// goroutines.
+type ClientSessionCache interface {
+ // Get searches for a ClientSessionState associated with the given key.
+ // On return, ok is true if one was found.
+ Get(sessionKey string) (session *ClientSessionState, ok bool)
+
+ // Put adds the ClientSessionState to the cache with the given key.
+ Put(sessionKey string, cs *ClientSessionState)
+}
+
// A Config structure is used to configure a TLS client or server. After one
// has been passed to a TLS function it must not be modified.
type Config struct {
// connections using that key are compromised.
SessionTicketKey [32]byte
+ // SessionCache is a cache of ClientSessionState entries for TLS session
+ // resumption.
+ ClientSessionCache ClientSessionCache
+
// MinVersion contains the minimum SSL/TLS version that is acceptable.
// If zero, then SSLv3 is taken as the minimum.
MinVersion uint16
unmarshal([]byte) bool
}
+// lruSessionCache is a ClientSessionCache implementation that uses an LRU
+// caching strategy.
+type lruSessionCache struct {
+ sync.Mutex
+
+ m map[string]*list.Element
+ q *list.List
+ capacity int
+}
+
+type lruSessionCacheEntry struct {
+ sessionKey string
+ state *ClientSessionState
+}
+
+// NewLRUClientSessionCache returns a ClientSessionCache with the given
+// capacity that uses an LRU strategy. If capacity is < 1, a default capacity
+// is used instead.
+func NewLRUClientSessionCache(capacity int) ClientSessionCache {
+ const defaultSessionCacheCapacity = 64
+
+ if capacity < 1 {
+ capacity = defaultSessionCacheCapacity
+ }
+ return &lruSessionCache{
+ m: make(map[string]*list.Element),
+ q: list.New(),
+ capacity: capacity,
+ }
+}
+
+// Put adds the provided (sessionKey, cs) pair to the cache.
+func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) {
+ c.Lock()
+ defer c.Unlock()
+
+ if elem, ok := c.m[sessionKey]; ok {
+ entry := elem.Value.(*lruSessionCacheEntry)
+ entry.state = cs
+ c.q.MoveToFront(elem)
+ return
+ }
+
+ if c.q.Len() < c.capacity {
+ entry := &lruSessionCacheEntry{sessionKey, cs}
+ c.m[sessionKey] = c.q.PushFront(entry)
+ return
+ }
+
+ elem := c.q.Back()
+ entry := elem.Value.(*lruSessionCacheEntry)
+ delete(c.m, entry.sessionKey)
+ entry.sessionKey = sessionKey
+ entry.state = cs
+ c.q.MoveToFront(elem)
+ c.m[sessionKey] = elem
+}
+
+// Get returns the ClientSessionState value associated with a given key. It
+// returns (nil, false) if no value is found.
+func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) {
+ c.Lock()
+ defer c.Unlock()
+
+ if elem, ok := c.m[sessionKey]; ok {
+ c.q.MoveToFront(elem)
+ return elem.Value.(*lruSessionCacheEntry).state, true
+ }
+ return nil, false
+}
+
// TODO(jsing): Make these available to both crypto/x509 and crypto/tls.
type dsaSignature struct {
R, S *big.Int
"encoding/asn1"
"errors"
"io"
+ "net"
"strconv"
)
+type clientHandshakeState struct {
+ c *Conn
+ serverHello *serverHelloMsg
+ hello *clientHelloMsg
+ suite *cipherSuite
+ finishedHash finishedHash
+ masterSecret []byte
+ session *ClientSessionState
+}
+
func (c *Conn) clientHandshake() error {
if c.config == nil {
c.config = defaultConfig()
_, err := io.ReadFull(c.config.rand(), hello.random[4:])
if err != nil {
c.sendAlert(alertInternalError)
- return errors.New("short read from Rand")
+ return errors.New("tls: short read from Rand: " + err.Error())
}
if hello.vers >= VersionTLS12 {
hello.signatureAndHashes = supportedSKXSignatureAlgorithms
}
+ var session *ClientSessionState
+ var cacheKey string
+ sessionCache := c.config.ClientSessionCache
+ if c.config.SessionTicketsDisabled {
+ sessionCache = nil
+ }
+
+ if sessionCache != nil {
+ hello.ticketSupported = true
+
+ // Try to resume a previously negotiated TLS session, if
+ // available.
+ cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
+ candidateSession, ok := sessionCache.Get(cacheKey)
+ if ok {
+ // Check that the ciphersuite/version used for the
+ // previous session are still valid.
+ cipherSuiteOk := false
+ for _, id := range hello.cipherSuites {
+ if id == candidateSession.cipherSuite {
+ cipherSuiteOk = true
+ break
+ }
+ }
+
+ versOk := candidateSession.vers >= c.config.minVersion() &&
+ candidateSession.vers <= c.config.maxVersion()
+ if versOk && cipherSuiteOk {
+ session = candidateSession
+ }
+ }
+ }
+
+ if session != nil {
+ hello.sessionTicket = session.sessionTicket
+ // A random session ID is used to detect when the
+ // server accepted the ticket and is resuming a session
+ // (see RFC 5077).
+ hello.sessionId = make([]byte, 16)
+ if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil {
+ c.sendAlert(alertInternalError)
+ return errors.New("tls: short read from Rand: " + err.Error())
+ }
+ }
+
c.writeRecord(recordTypeHandshake, hello.marshal())
msg, err := c.readHandshake()
c.vers = vers
c.haveVers = true
- finishedHash := newFinishedHash(c.vers)
- finishedHash.Write(hello.marshal())
- finishedHash.Write(serverHello.marshal())
+ suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
+ if suite == nil {
+ return c.sendAlert(alertHandshakeFailure)
+ }
- if serverHello.compressionMethod != compressionNone {
- return c.sendAlert(alertUnexpectedMessage)
+ hs := &clientHandshakeState{
+ c: c,
+ serverHello: serverHello,
+ hello: hello,
+ suite: suite,
+ finishedHash: newFinishedHash(c.vers),
+ session: session,
}
- if !hello.nextProtoNeg && serverHello.nextProtoNeg {
- c.sendAlert(alertHandshakeFailure)
- return errors.New("server advertised unrequested NPN")
+ hs.finishedHash.Write(hs.hello.marshal())
+ hs.finishedHash.Write(hs.serverHello.marshal())
+
+ isResume, err := hs.processServerHello()
+ if err != nil {
+ return err
}
- suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
- if suite == nil {
- return c.sendAlert(alertHandshakeFailure)
+ if isResume {
+ if err := hs.establishKeys(); err != nil {
+ return err
+ }
+ if err := hs.readSessionTicket(); err != nil {
+ return err
+ }
+ if err := hs.readFinished(); err != nil {
+ return err
+ }
+ if err := hs.sendFinished(); err != nil {
+ return err
+ }
+ } else {
+ if err := hs.doFullHandshake(); err != nil {
+ return err
+ }
+ if err := hs.establishKeys(); err != nil {
+ return err
+ }
+ if err := hs.sendFinished(); err != nil {
+ return err
+ }
+ if err := hs.readSessionTicket(); err != nil {
+ return err
+ }
+ if err := hs.readFinished(); err != nil {
+ return err
+ }
}
- msg, err = c.readHandshake()
+ if sessionCache != nil && hs.session != nil && session != hs.session {
+ sessionCache.Put(cacheKey, hs.session)
+ }
+
+ c.didResume = isResume
+ c.handshakeComplete = true
+ c.cipherSuite = suite.id
+ return nil
+}
+
+func (hs *clientHandshakeState) doFullHandshake() error {
+ c := hs.c
+
+ msg, err := c.readHandshake()
if err != nil {
return err
}
if !ok || len(certMsg.certificates) == 0 {
return c.sendAlert(alertUnexpectedMessage)
}
- finishedHash.Write(certMsg.marshal())
+ hs.finishedHash.Write(certMsg.marshal())
certs := make([]*x509.Certificate, len(certMsg.certificates))
for i, asn1Data := range certMsg.certificates {
c.peerCertificates = certs
- if serverHello.ocspStapling {
+ if hs.serverHello.ocspStapling {
msg, err = c.readHandshake()
if err != nil {
return err
if !ok {
return c.sendAlert(alertUnexpectedMessage)
}
- finishedHash.Write(cs.marshal())
+ hs.finishedHash.Write(cs.marshal())
if cs.statusType == statusTypeOCSP {
c.ocspResponse = cs.response
return err
}
- keyAgreement := suite.ka(c.vers)
+ keyAgreement := hs.suite.ka(c.vers)
skx, ok := msg.(*serverKeyExchangeMsg)
if ok {
- finishedHash.Write(skx.marshal())
- err = keyAgreement.processServerKeyExchange(c.config, hello, serverHello, certs[0], skx)
+ hs.finishedHash.Write(skx.marshal())
+ err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, certs[0], skx)
if err != nil {
c.sendAlert(alertUnexpectedMessage)
return err
// ClientCertificateType, unless there is some external
// arrangement to the contrary.
- finishedHash.Write(certReq.marshal())
+ hs.finishedHash.Write(certReq.marshal())
var rsaAvail, ecdsaAvail bool
for _, certType := range certReq.certificateTypes {
if !ok {
return c.sendAlert(alertUnexpectedMessage)
}
- finishedHash.Write(shd.marshal())
+ hs.finishedHash.Write(shd.marshal())
// If the server requested a certificate then we have to send a
// Certificate message, even if it's empty because we don't have a
if chainToSend != nil {
certMsg.certificates = chainToSend.Certificate
}
- finishedHash.Write(certMsg.marshal())
+ hs.finishedHash.Write(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal())
}
- preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hello, certs[0])
+ preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0])
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if ckx != nil {
- finishedHash.Write(ckx.marshal())
+ hs.finishedHash.Write(ckx.marshal())
c.writeRecord(recordTypeHandshake, ckx.marshal())
}
switch key := c.config.Certificates[0].PrivateKey.(type) {
case *ecdsa.PrivateKey:
- digest, _, hashId := finishedHash.hashForClientCertificate(signatureECDSA)
+ digest, _, hashId := hs.finishedHash.hashForClientCertificate(signatureECDSA)
r, s, err := ecdsa.Sign(c.config.rand(), key, digest)
if err == nil {
signed, err = asn1.Marshal(ecdsaSignature{r, s})
certVerify.signatureAndHash.signature = signatureECDSA
certVerify.signatureAndHash.hash = hashId
case *rsa.PrivateKey:
- digest, hashFunc, hashId := finishedHash.hashForClientCertificate(signatureRSA)
+ digest, hashFunc, hashId := hs.finishedHash.hashForClientCertificate(signatureRSA)
signed, err = rsa.SignPKCS1v15(c.config.rand(), key, hashFunc, digest)
certVerify.signatureAndHash.signature = signatureRSA
certVerify.signatureAndHash.hash = hashId
}
certVerify.signature = signed
- finishedHash.Write(certVerify.marshal())
+ hs.finishedHash.Write(certVerify.marshal())
c.writeRecord(recordTypeHandshake, certVerify.marshal())
}
- masterSecret := masterFromPreMasterSecret(c.vers, preMasterSecret, hello.random, serverHello.random)
- clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
- keysFromMasterSecret(c.vers, masterSecret, hello.random, serverHello.random, suite.macLen, suite.keyLen, suite.ivLen)
+ hs.masterSecret = masterFromPreMasterSecret(c.vers, preMasterSecret, hs.hello.random, hs.serverHello.random)
+ return nil
+}
+
+func (hs *clientHandshakeState) establishKeys() error {
+ c := hs.c
- var clientCipher interface{}
- var clientHash macFunction
- if suite.cipher != nil {
- clientCipher = suite.cipher(clientKey, clientIV, false /* not for reading */)
- clientHash = suite.mac(c.vers, clientMAC)
+ clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
+ keysFromMasterSecret(c.vers, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
+ var clientCipher, serverCipher interface{}
+ var clientHash, serverHash macFunction
+ if hs.suite.cipher != nil {
+ clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */)
+ clientHash = hs.suite.mac(c.vers, clientMAC)
+ serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */)
+ serverHash = hs.suite.mac(c.vers, serverMAC)
} else {
- clientCipher = suite.aead(clientKey, clientIV)
+ clientCipher = hs.suite.aead(clientKey, clientIV)
+ serverCipher = hs.suite.aead(serverKey, serverIV)
}
+
+ c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
- c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
+ return nil
+}
- if serverHello.nextProtoNeg {
- nextProto := new(nextProtoMsg)
- proto, fallback := mutualProtocol(c.config.NextProtos, serverHello.nextProtos)
- nextProto.proto = proto
- c.clientProtocol = proto
- c.clientProtocolFallback = fallback
+func (hs *clientHandshakeState) serverResumedSession() bool {
+ // If the server responded with the same sessionId then it means the
+ // sessionTicket is being used to resume a TLS session.
+ return hs.session != nil && hs.hello.sessionId != nil &&
+ bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId)
+}
- finishedHash.Write(nextProto.marshal())
- c.writeRecord(recordTypeHandshake, nextProto.marshal())
+func (hs *clientHandshakeState) processServerHello() (bool, error) {
+ c := hs.c
+
+ if hs.serverHello.compressionMethod != compressionNone {
+ return false, c.sendAlert(alertUnexpectedMessage)
}
- finished := new(finishedMsg)
- finished.verifyData = finishedHash.clientSum(masterSecret)
- finishedHash.Write(finished.marshal())
- c.writeRecord(recordTypeHandshake, finished.marshal())
+ if !hs.hello.nextProtoNeg && hs.serverHello.nextProtoNeg {
+ c.sendAlert(alertHandshakeFailure)
+ return false, errors.New("server advertised unrequested NPN")
+ }
- var serverCipher interface{}
- var serverHash macFunction
- if suite.cipher != nil {
- serverCipher = suite.cipher(serverKey, serverIV, true /* for reading */)
- serverHash = suite.mac(c.vers, serverMAC)
- } else {
- serverCipher = suite.aead(serverKey, serverIV)
+ if hs.serverResumedSession() {
+ // Restore masterSecret and peerCerts from previous state
+ hs.masterSecret = hs.session.masterSecret
+ c.peerCertificates = hs.session.serverCertificates
+ return true, nil
}
- c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
+ return false, nil
+}
+
+func (hs *clientHandshakeState) readFinished() error {
+ c := hs.c
+
c.readRecord(recordTypeChangeCipherSpec)
if err := c.error(); err != nil {
return err
}
- msg, err = c.readHandshake()
+ msg, err := c.readHandshake()
if err != nil {
return err
}
return c.sendAlert(alertUnexpectedMessage)
}
- verify := finishedHash.serverSum(masterSecret)
+ verify := hs.finishedHash.serverSum(hs.masterSecret)
if len(verify) != len(serverFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
return c.sendAlert(alertHandshakeFailure)
}
+ hs.finishedHash.Write(serverFinished.marshal())
+ return nil
+}
- c.handshakeComplete = true
- c.cipherSuite = suite.id
+func (hs *clientHandshakeState) readSessionTicket() error {
+ if !hs.serverHello.ticketSupported {
+ return nil
+ }
+
+ c := hs.c
+ msg, err := c.readHandshake()
+ if err != nil {
+ return err
+ }
+ sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
+ if !ok {
+ return c.sendAlert(alertUnexpectedMessage)
+ }
+ hs.finishedHash.Write(sessionTicketMsg.marshal())
+
+ hs.session = &ClientSessionState{
+ sessionTicket: sessionTicketMsg.ticket,
+ vers: c.vers,
+ cipherSuite: hs.suite.id,
+ masterSecret: hs.masterSecret,
+ serverCertificates: c.peerCertificates,
+ }
+
+ return nil
+}
+
+func (hs *clientHandshakeState) sendFinished() error {
+ c := hs.c
+
+ c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
+ if hs.serverHello.nextProtoNeg {
+ nextProto := new(nextProtoMsg)
+ proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
+ nextProto.proto = proto
+ c.clientProtocol = proto
+ c.clientProtocolFallback = fallback
+
+ hs.finishedHash.Write(nextProto.marshal())
+ c.writeRecord(recordTypeHandshake, nextProto.marshal())
+ }
+
+ finished := new(finishedMsg)
+ finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
+ hs.finishedHash.Write(finished.marshal())
+ c.writeRecord(recordTypeHandshake, finished.marshal())
return nil
}
+// clientSessionCacheKey returns a key used to cache sessionTickets that could
+// be used to resume previously negotiated TLS sessions with a server.
+func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
+ if len(config.ServerName) > 0 {
+ return config.ServerName
+ }
+ return serverAddr.String()
+}
+
// mutualProtocol finds the mutual Next Protocol Negotiation protocol given the
// set of client and server supported protocols. The set of client supported
// protocols must not be empty. It returns the resulting protocol and flag
runClientTestTLS10(t, test)
runClientTestTLS12(t, test)
}
+
+func TestClientResumption(t *testing.T) {
+ serverConfig := &Config{
+ CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
+ Certificates: testConfig.Certificates,
+ }
+ clientConfig := &Config{
+ CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
+ InsecureSkipVerify: true,
+ ClientSessionCache: NewLRUClientSessionCache(32),
+ }
+
+ testResumeState := func(test string, didResume bool) {
+ hs, err := testHandshake(clientConfig, serverConfig)
+ if err != nil {
+ t.Fatalf("%s: handshake failed: %s", test, err)
+ }
+ if hs.DidResume != didResume {
+ t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
+ }
+ }
+
+ testResumeState("Handshake", false)
+ testResumeState("Resume", true)
+
+ if _, err := io.ReadFull(serverConfig.rand(), serverConfig.SessionTicketKey[:]); err != nil {
+ t.Fatalf("Failed to invalidate SessionTicketKey")
+ }
+ testResumeState("InvalidSessionTicketKey", false)
+ testResumeState("ResumeAfterInvalidSessionTicketKey", true)
+
+ clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
+ testResumeState("DifferentCipherSuite", false)
+ testResumeState("DifferentCipherSuiteRecovers", true)
+
+ clientConfig.ClientSessionCache = nil
+ testResumeState("WithoutSessionCache", false)
+}
+
+func TestLRUClientSessionCache(t *testing.T) {
+ // Initialize cache of capacity 4.
+ cache := NewLRUClientSessionCache(4)
+ cs := make([]ClientSessionState, 6)
+ keys := []string{"0", "1", "2", "3", "4", "5", "6"}
+
+ // Add 4 entries to the cache and look them up.
+ for i := 0; i < 4; i++ {
+ cache.Put(keys[i], &cs[i])
+ }
+ for i := 0; i < 4; i++ {
+ if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
+ t.Fatalf("session cache failed lookup for added key: %s", keys[i])
+ }
+ }
+
+ // Add 2 more entries to the cache. First 2 should be evicted.
+ for i := 4; i < 6; i++ {
+ cache.Put(keys[i], &cs[i])
+ }
+ for i := 0; i < 2; i++ {
+ if s, ok := cache.Get(keys[i]); ok || s != nil {
+ t.Fatalf("session cache should have evicted key: %s", keys[i])
+ }
+ }
+
+ // Touch entry 2. LRU should evict 3 next.
+ cache.Get(keys[2])
+ cache.Put(keys[0], &cs[0])
+ if s, ok := cache.Get(keys[3]); ok || s != nil {
+ t.Fatalf("session cache should have evicted key 3")
+ }
+
+ // Update entry 0 in place.
+ cache.Put(keys[0], &cs[3])
+ if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
+ t.Fatalf("session cache failed update for key 0")
+ }
+
+ // Adding a nil entry is valid.
+ cache.Put(keys[0], nil)
+ if s, ok := cache.Get(keys[0]); !ok || s != nil {
+ t.Fatalf("failed to add nil entry to cache")
+ }
+}