]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: implement TLS 1.3 record layer and cipher suites
authorFilippo Valsorda <filippo@golang.org>
Sat, 27 Oct 2018 16:50:25 +0000 (12:50 -0400)
committerFilippo Valsorda <filippo@golang.org>
Fri, 2 Nov 2018 21:54:38 +0000 (21:54 +0000)
Updates #9671

Change-Id: I1ea7b724975c0841d01f4536eebb23956b30d5ea
Reviewed-on: https://go-review.googlesource.com/c/145297
Reviewed-by: Adam Langley <agl@golang.org>
src/crypto/tls/cipher_suites.go
src/crypto/tls/common.go
src/crypto/tls/conn.go
src/crypto/tls/tls_test.go

index e937235876e7ece3a6e9df6c89b0454793d59d83..d948fac8cdd304efbd36645a1d25cab0c3589d5b 100644 (file)
@@ -5,6 +5,7 @@
 package tls
 
 import (
+       "crypto"
        "crypto/aes"
        "crypto/cipher"
        "crypto/des"
@@ -58,8 +59,7 @@ const (
        suiteDefaultOff
 )
 
-// A cipherSuite is a specific combination of key agreement, cipher and MAC
-// function. All cipher suites currently assume RSA key agreement.
+// A cipherSuite is a specific combination of key agreement, cipher and MAC function.
 type cipherSuite struct {
        id uint16
        // the lengths, in bytes, of the key material needed for each component.
@@ -71,7 +71,7 @@ type cipherSuite struct {
        flags  int
        cipher func(key, iv []byte, isRead bool) interface{}
        mac    func(version uint16, macKey []byte) macFunction
-       aead   func(key, fixedNonce []byte) cipher.AEAD
+       aead   func(key, fixedNonce []byte) aead
 }
 
 var cipherSuites = []*cipherSuite{
@@ -103,6 +103,21 @@ var cipherSuites = []*cipherSuite{
        {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteDefaultOff, cipherRC4, macSHA1, nil},
 }
 
+// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
+// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
+type cipherSuiteTLS13 struct {
+       id     uint16
+       keyLen int
+       aead   func(key, fixedNonce []byte) aead
+       hash   crypto.Hash
+}
+
+var cipherSuitesTLS13 = []*cipherSuiteTLS13{
+       {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
+       {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
+       {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384},
+}
+
 func cipherRC4(key, iv []byte, isRead bool) interface{} {
        cipher, _ := rc4.NewCipher(key)
        return cipher
@@ -161,36 +176,41 @@ type aead interface {
        explicitNonceLen() int
 }
 
-// fixedNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
+const (
+       aeadNonceLength   = 12
+       noncePrefixLength = 4
+)
+
+// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
 // each call.
-type fixedNonceAEAD struct {
+type prefixNonceAEAD struct {
        // nonce contains the fixed part of the nonce in the first four bytes.
-       nonce [12]byte
+       nonce [aeadNonceLength]byte
        aead  cipher.AEAD
 }
 
-func (f *fixedNonceAEAD) NonceSize() int        { return 8 }
-func (f *fixedNonceAEAD) Overhead() int         { return f.aead.Overhead() }
-func (f *fixedNonceAEAD) explicitNonceLen() int { return 8 }
+func (f *prefixNonceAEAD) NonceSize() int        { return aeadNonceLength - noncePrefixLength }
+func (f *prefixNonceAEAD) Overhead() int         { return f.aead.Overhead() }
+func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() }
 
-func (f *fixedNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
+func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
        copy(f.nonce[4:], nonce)
        return f.aead.Seal(out, f.nonce[:], plaintext, additionalData)
 }
 
-func (f *fixedNonceAEAD) Open(out, nonce, plaintext, additionalData []byte) ([]byte, error) {
+func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
        copy(f.nonce[4:], nonce)
-       return f.aead.Open(out, f.nonce[:], plaintext, additionalData)
+       return f.aead.Open(out, f.nonce[:], ciphertext, additionalData)
 }
 
 // xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
 // before each call.
 type xorNonceAEAD struct {
-       nonceMask [12]byte
+       nonceMask [aeadNonceLength]byte
        aead      cipher.AEAD
 }
 
-func (f *xorNonceAEAD) NonceSize() int        { return 8 }
+func (f *xorNonceAEAD) NonceSize() int        { return 8 } // 64-bit sequence number
 func (f *xorNonceAEAD) Overhead() int         { return f.aead.Overhead() }
 func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
 
@@ -206,11 +226,11 @@ func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte
        return result
 }
 
-func (f *xorNonceAEAD) Open(out, nonce, plaintext, additionalData []byte) ([]byte, error) {
+func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
        for i, b := range nonce {
                f.nonceMask[4+i] ^= b
        }
-       result, err := f.aead.Open(out, f.nonceMask[:], plaintext, additionalData)
+       result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
        for i, b := range nonce {
                f.nonceMask[4+i] ^= b
        }
@@ -218,7 +238,10 @@ func (f *xorNonceAEAD) Open(out, nonce, plaintext, additionalData []byte) ([]byt
        return result, err
 }
 
-func aeadAESGCM(key, fixedNonce []byte) cipher.AEAD {
+func aeadAESGCM(key, noncePrefix []byte) aead {
+       if len(noncePrefix) != noncePrefixLength {
+               panic("tls: internal error: wrong nonce length")
+       }
        aes, err := aes.NewCipher(key)
        if err != nil {
                panic(err)
@@ -228,19 +251,40 @@ func aeadAESGCM(key, fixedNonce []byte) cipher.AEAD {
                panic(err)
        }
 
-       ret := &fixedNonceAEAD{aead: aead}
-       copy(ret.nonce[:], fixedNonce)
+       ret := &prefixNonceAEAD{aead: aead}
+       copy(ret.nonce[:], noncePrefix)
        return ret
 }
 
-func aeadChaCha20Poly1305(key, fixedNonce []byte) cipher.AEAD {
+func aeadAESGCMTLS13(key, nonceMask []byte) aead {
+       if len(nonceMask) != aeadNonceLength {
+               panic("tls: internal error: wrong nonce length")
+       }
+       aes, err := aes.NewCipher(key)
+       if err != nil {
+               panic(err)
+       }
+       aead, err := cipher.NewGCM(aes)
+       if err != nil {
+               panic(err)
+       }
+
+       ret := &xorNonceAEAD{aead: aead}
+       copy(ret.nonceMask[:], nonceMask)
+       return ret
+}
+
+func aeadChaCha20Poly1305(key, nonceMask []byte) aead {
+       if len(nonceMask) != aeadNonceLength {
+               panic("tls: internal error: wrong nonce length")
+       }
        aead, err := chacha20poly1305.New(key)
        if err != nil {
                panic(err)
        }
 
        ret := &xorNonceAEAD{aead: aead}
-       copy(ret.nonceMask[:], fixedNonce)
+       copy(ret.nonceMask[:], nonceMask)
        return ret
 }
 
@@ -371,6 +415,7 @@ func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
 //
 // Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
 const (
+       // TLS 1.0 - 1.2 cipher suites.
        TLS_RSA_WITH_RC4_128_SHA                uint16 = 0x0005
        TLS_RSA_WITH_3DES_EDE_CBC_SHA           uint16 = 0x000a
        TLS_RSA_WITH_AES_128_CBC_SHA            uint16 = 0x002f
@@ -394,6 +439,11 @@ const (
        TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305    uint16 = 0xcca8
        TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305  uint16 = 0xcca9
 
+       // TLS 1.3 cipher suites.
+       TLS_AES_128_GCM_SHA256       uint16 = 0x1301
+       TLS_AES_256_GCM_SHA384       uint16 = 0x1302
+       TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303
+
        // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator
        // that the client is doing version fallback. See RFC 7507.
        TLS_FALLBACK_SCSV uint16 = 0x5600
index 717c5f0b0e8b8739cbd080762655d90b2ec4642c..4808c01f9ca02c81e26fe5df80d154e79fd35f4d 100644 (file)
@@ -26,14 +26,19 @@ const (
        VersionTLS10 = 0x0301
        VersionTLS11 = 0x0302
        VersionTLS12 = 0x0303
+
+       // VersionTLS13 is under development in this library and can't be selected
+       // nor negotiated yet on either side.
+       VersionTLS13 = 0x0304
 )
 
 const (
-       maxPlaintext      = 16384        // maximum plaintext payload length
-       maxCiphertext     = 16384 + 2048 // maximum ciphertext payload length
-       recordHeaderLen   = 5            // record header length
-       maxHandshake      = 65536        // maximum handshake we support (protocol max is 16 MB)
-       maxWarnAlertCount = 5            // maximum number of consecutive warning alerts
+       maxPlaintext       = 16384        // maximum plaintext payload length
+       maxCiphertext      = 16384 + 2048 // maximum ciphertext payload length
+       maxCiphertextTLS13 = 16384 + 256  // maximum ciphertext length in TLS 1.3
+       recordHeaderLen    = 5            // record header length
+       maxHandshake       = 65536        // maximum handshake we support (protocol max is 16 MB)
+       maxUselessRecords  = 5            // maximum number of consecutive non-advancing records
 
        minVersion = VersionTLS10
        maxVersion = VersionTLS12
@@ -942,8 +947,9 @@ func defaultConfig() *Config {
 }
 
 var (
-       once                   sync.Once
-       varDefaultCipherSuites []uint16
+       once                        sync.Once
+       varDefaultCipherSuites      []uint16
+       varDefaultCipherSuitesTLS13 []uint16
 )
 
 func defaultCipherSuites() []uint16 {
@@ -951,19 +957,24 @@ func defaultCipherSuites() []uint16 {
        return varDefaultCipherSuites
 }
 
+func defaultCipherSuitesTLS13() []uint16 {
+       once.Do(initDefaultCipherSuites)
+       return varDefaultCipherSuitesTLS13
+}
+
 func initDefaultCipherSuites() {
        var topCipherSuites []uint16
 
        // Check the cpu flags for each platform that has optimized GCM implementations.
-       // Worst case, these variables will just all be false
-       hasGCMAsmAMD64 := cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
+       // Worst case, these variables will just all be false.
+       var (
+               hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
+               hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
+               // Keep in sync with crypto/aes/cipher_s390x.go.
+               hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
 
-       hasGCMAsmARM64 := cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
-
-       // Keep in sync with crypto/aes/cipher_s390x.go.
-       hasGCMAsmS390X := cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
-
-       hasGCMAsm := hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X
+               hasGCMAsm = hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X
+       )
 
        if hasGCMAsm {
                // If AES-GCM hardware is provided then prioritise AES-GCM
@@ -976,6 +987,11 @@ func initDefaultCipherSuites() {
                        TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
                        TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
                }
+               varDefaultCipherSuitesTLS13 = []uint16{
+                       TLS_AES_128_GCM_SHA256,
+                       TLS_CHACHA20_POLY1305_SHA256,
+                       TLS_AES_256_GCM_SHA384,
+               }
        } else {
                // Without AES-GCM hardware, we put the ChaCha20-Poly1305
                // cipher suites first.
@@ -987,6 +1003,11 @@ func initDefaultCipherSuites() {
                        TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
                        TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
                }
+               varDefaultCipherSuitesTLS13 = []uint16{
+                       TLS_CHACHA20_POLY1305_SHA256,
+                       TLS_AES_128_GCM_SHA256,
+                       TLS_AES_256_GCM_SHA384,
+               }
        }
 
        varDefaultCipherSuites = make([]uint16, 0, len(cipherSuites))
index dae5fd103a1abe1065f3c18db772cc1aea137e10..5af1413935363c10f96abab97877256ed531cb31 100644 (file)
@@ -94,9 +94,9 @@ type Conn struct {
        bytesSent   int64
        packetsSent int64
 
-       // warnCount counts the number of consecutive warning alerts received
+       // retryCount counts the number of consecutive warning alerts received
        // by Conn.readRecord. Protected by in.Mutex.
-       warnCount int
+       retryCount int
 
        // activeCall is an atomic int32; the low bit is whether Close has
        // been called. the rest of the bits are the number of goroutines
@@ -291,9 +291,17 @@ type cbcMode interface {
 
 // decrypt authenticates and decrypts the record if protection is active at
 // this stage. The returned plaintext might overlap with the input.
-func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) {
+func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
+       var plaintext []byte
+       typ := recordType(record[0])
        payload := record[recordHeaderLen:]
 
+       // In TLS 1.3, change_cipher_spec messages are to be ignored without being
+       // decrypted. See RFC 8446, Appendix D.4.
+       if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
+               return payload, typ, nil
+       }
+
        paddingGood := byte(255)
        paddingLen := 0
 
@@ -305,7 +313,7 @@ func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) {
                        c.XORKeyStream(payload, payload)
                case aead:
                        if len(payload) < explicitNonceLen {
-                               return nil, alertBadRecordMAC
+                               return nil, 0, alertBadRecordMAC
                        }
                        nonce := payload[:explicitNonceLen]
                        if len(nonce) == 0 {
@@ -313,22 +321,27 @@ func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) {
                        }
                        payload = payload[explicitNonceLen:]
 
-                       copy(hc.additionalData[:], hc.seq[:])
-                       copy(hc.additionalData[8:], record[:3])
-                       n := len(payload) - c.Overhead()
-                       hc.additionalData[11] = byte(n >> 8)
-                       hc.additionalData[12] = byte(n)
+                       additionalData := hc.additionalData[:]
+                       if hc.version == VersionTLS13 {
+                               additionalData = record[:recordHeaderLen]
+                       } else {
+                               copy(additionalData, hc.seq[:])
+                               copy(additionalData[8:], record[:3])
+                               n := len(payload) - c.Overhead()
+                               additionalData[11] = byte(n >> 8)
+                               additionalData[12] = byte(n)
+                       }
 
                        var err error
-                       plaintext, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:])
+                       plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
                        if err != nil {
-                               return nil, alertBadRecordMAC
+                               return nil, 0, alertBadRecordMAC
                        }
                case cbcMode:
                        blockSize := c.BlockSize()
                        minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) // TODO: vuln?
                        if len(payload)%blockSize != 0 || len(payload) < minPayload {
-                               return nil, alertBadRecordMAC
+                               return nil, 0, alertBadRecordMAC
                        }
 
                        if explicitNonceLen > 0 {
@@ -351,6 +364,26 @@ func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) {
                default:
                        panic("unknown cipher type")
                }
+
+               if hc.version == VersionTLS13 {
+                       if typ != recordTypeApplicationData {
+                               return nil, 0, alertUnexpectedMessage
+                       }
+                       if len(plaintext) > maxPlaintext+1 {
+                               return nil, 0, alertRecordOverflow
+                       }
+                       // Remove padding and find the ContentType scanning from the end.
+                       for i := len(plaintext) - 1; i >= 0; i-- {
+                               if plaintext[i] != 0 {
+                                       typ = recordType(plaintext[i])
+                                       plaintext = plaintext[:i]
+                                       break
+                               }
+                               if i == 0 {
+                                       return nil, 0, alertUnexpectedMessage
+                               }
+                       }
+               }
        } else {
                plaintext = payload
        }
@@ -358,7 +391,7 @@ func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) {
        if hc.mac != nil {
                macSize := hc.mac.Size()
                if len(payload) < macSize {
-                       return nil, alertBadRecordMAC
+                       return nil, 0, alertBadRecordMAC
                }
 
                n := len(payload) - macSize - paddingLen
@@ -369,14 +402,14 @@ func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) {
                localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
 
                if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
-                       return nil, alertBadRecordMAC
+                       return nil, 0, alertBadRecordMAC
                }
 
                plaintext = payload[:n]
        }
 
        hc.incSeq()
-       return plaintext, nil
+       return plaintext, typ, nil
 }
 
 // sliceForAppend extends the input slice by n bytes. head is the full extended
@@ -438,12 +471,24 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
                        nonce = hc.seq[:]
                }
 
-               copy(hc.additionalData[:], hc.seq[:])
-               copy(hc.additionalData[8:], record[:3])
-               hc.additionalData[11] = byte(len(payload) >> 8)
-               hc.additionalData[12] = byte(len(payload))
+               if hc.version == VersionTLS13 {
+                       record = append(record, payload...)
+
+                       // Encrypt the actual ContentType and replace the plaintext one.
+                       record = append(record, record[0])
+                       record[0] = byte(recordTypeApplicationData)
 
-               record = c.Seal(record, nonce, payload, hc.additionalData[:])
+                       n := len(payload) + 1 + c.Overhead()
+                       record[3] = byte(n >> 8)
+                       record[4] = byte(n)
+
+                       record = c.Seal(record[:recordHeaderLen],
+                               nonce, record[recordHeaderLen:], record[:recordHeaderLen])
+               } else {
+                       copy(hc.additionalData[:], hc.seq[:])
+                       copy(hc.additionalData[8:], record)
+                       record = c.Seal(record, nonce, payload, hc.additionalData[:])
+               }
        case cbcMode:
                blockSize := c.BlockSize()
                plaintextLen := len(payload) + len(mac)
@@ -546,7 +591,7 @@ func (c *Conn) readRecord(want recordType) error {
 
        vers := uint16(hdr[1])<<8 | uint16(hdr[2])
        n := int(hdr[3])<<8 | int(hdr[4])
-       if c.haveVers && vers != c.vers {
+       if c.haveVers && c.vers != VersionTLS13 && vers != c.vers {
                c.sendAlert(alertProtocolVersion)
                msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
                return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
@@ -560,7 +605,7 @@ func (c *Conn) readRecord(want recordType) error {
                        return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
                }
        }
-       if n > maxCiphertext {
+       if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
                c.sendAlert(alertRecordOverflow)
                msg := fmt.Sprintf("oversized record received with length %d", n)
                return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
@@ -574,7 +619,7 @@ func (c *Conn) readRecord(want recordType) error {
 
        // Process message.
        record := c.rawInput.Next(recordHeaderLen + n)
-       data, err := c.in.decrypt(record)
+       data, typ, err := c.in.decrypt(record)
        if err != nil {
                return c.in.setErrorLocked(c.sendAlert(err.(alert)))
        }
@@ -582,9 +627,31 @@ func (c *Conn) readRecord(want recordType) error {
                return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
        }
 
-       if typ != recordTypeAlert && len(data) > 0 {
-               // this is a valid non-alert message: reset the count of alerts
-               c.warnCount = 0
+       // Application Data messages are always protected.
+       if c.in.cipher == nil && typ == recordTypeApplicationData {
+               return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
+       }
+
+       if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
+               // This is a state-advancing message: reset the retry count.
+               c.retryCount = 0
+       }
+
+       // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
+       if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
+               return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
+       }
+
+       // In TLS 1.3, change_cipher_spec records are ignored until the Finished.
+       // See RFC 8446, Appendix D.4. Note that according to Section 5, a server
+       // can send a ChangeCipherSpec before its ServerHello, when c.vers is still
+       // unset. That's not useful though and suspicious if the server then selects
+       // a lower protocol version, so don't allow that.
+       if c.vers == VersionTLS13 && typ == recordTypeChangeCipherSpec {
+               if len(data) != 1 || data[0] != 1 || c.handshakeComplete() {
+                       return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
+               }
+               return c.retryReadRecord(want)
        }
 
        switch typ {
@@ -598,14 +665,12 @@ func (c *Conn) readRecord(want recordType) error {
                if alert(data[1]) == alertCloseNotify {
                        return c.in.setErrorLocked(io.EOF)
                }
+               if c.vers == VersionTLS13 {
+                       return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
+               }
                switch data[0] {
                case alertLevelWarning:
-                       c.warnCount++
-                       if c.warnCount > maxWarnAlertCount {
-                               c.sendAlert(alertUnexpectedMessage)
-                               return c.in.setErrorLocked(errors.New("tls: too many warn alerts"))
-                       }
-                       return c.readRecord(want) // Drop the record on the floor and retry.
+                       return c.retryReadRecord(want) // Drop the record on the floor and retry.
                case alertLevelError:
                        return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
                default:
@@ -629,6 +694,9 @@ func (c *Conn) readRecord(want recordType) error {
                if typ != want {
                        return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                }
+               if len(data) == 0 {
+                       return c.retryReadRecord(want)
+               }
                // Note that data is owned by c.rawInput, following the Next call above,
                // to avoid copying the plaintext. This is safe because c.rawInput is
                // not read from or written to until c.input is drained.
@@ -636,14 +704,35 @@ func (c *Conn) readRecord(want recordType) error {
                return nil
 
        case recordTypeHandshake:
-               if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) {
+               if typ != want && !c.isRenegotiationAcceptable() {
                        return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
                }
+               if len(data) == 0 {
+                       return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
+               }
                c.hand.Write(data)
                return nil
        }
 }
 
+// retryReadRecord recurses into readRecord to drop a non-advancing record, like
+// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3.
+func (c *Conn) retryReadRecord(want recordType) error {
+       c.retryCount++
+       if c.retryCount > maxUselessRecords {
+               c.sendAlert(alertUnexpectedMessage)
+               return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
+       }
+       return c.readRecord(want)
+}
+
+func (c *Conn) isRenegotiationAcceptable() bool {
+       return c.isClient &&
+               c.vers != VersionTLS13 &&
+               c.handshakeComplete() &&
+               c.config.Renegotiation != RenegotiateNever
+}
+
 // atLeastReader reads from R, stopping with EOF once at least N bytes have been
 // read. It is different from an io.LimitedReader in that it doesn't cut short
 // the last Read call, and in that it considers an early EOF an error.
@@ -767,6 +856,9 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
                        panic("unknown cipher type")
                }
        }
+       if c.vers == VersionTLS13 {
+               payloadBytes-- // encrypted ContentType
+       }
 
        // Allow packet growth in arithmetic progression up to max.
        pkt := c.packetsSent
@@ -822,6 +914,10 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
                        // Some TLS servers fail if the record version is
                        // greater than TLS 1.0 for the initial ClientHello.
                        vers = VersionTLS10
+               } else if vers == VersionTLS13 {
+                       // TLS 1.3 froze the record layer version to 1.2.
+                       // See RFC 8446, Section 5.1.
+                       vers = VersionTLS12
                }
                c.outBuf[1] = byte(vers >> 8)
                c.outBuf[2] = byte(vers)
index 7542699bdc9ad607f47d68c0fe99c27943a42d66..e9abe01280fd21ab0b48d6b2eecfbdc822accf39 100644 (file)
@@ -589,7 +589,7 @@ func TestWarningAlertFlood(t *testing.T) {
                if err == nil {
                        return errors.New("unexpected lack of error from server")
                }
-               const expected = "too many warn"
+               const expected = "too many ignored"
                if str := err.Error(); !strings.Contains(str, expected) {
                        return fmt.Errorf("expected error containing %q, but saw: %s", expected, str)
                }
@@ -610,7 +610,7 @@ func TestWarningAlertFlood(t *testing.T) {
                t.Fatal(err)
        }
 
-       for i := 0; i < maxWarnAlertCount+1; i++ {
+       for i := 0; i < maxUselessRecords+1; i++ {
                conn.sendAlert(alertNoRenegotiation)
        }