]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: fix deadlock when Read and Close called concurrently
authorMinaev Mike <minaev.mike@gmail.com>
Fri, 26 Jan 2018 09:17:46 +0000 (09:17 +0000)
committerFilippo Valsorda <filippo@golang.org>
Wed, 25 Jul 2018 23:53:54 +0000 (23:53 +0000)
The existing implementation of TLS connection has a deadlock. It occurs
when client connects to TLS server and doesn't send data for
handshake, so server calls Close on this connection. This is because
server reads data under locked mutex, while Close method tries to
lock the same mutex.

Fixes #23518

Change-Id: I4fb0a2a770f3d911036bfd9a7da7cc41c1b27e19
Reviewed-on: https://go-review.googlesource.com/90155
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_test.go

index cdaa7aba97c8830b19da01e02db43d8652f01002..2adb967537ccc5801f65e8ad210c9f2a524441d9 100644 (file)
@@ -27,15 +27,16 @@ type Conn struct {
        conn     net.Conn
        isClient bool
 
+       // handshakeStatus is 1 if the connection is currently transferring
+       // application data (i.e. is not currently processing a handshake).
+       // This field is only to be accessed with sync/atomic.
+       handshakeStatus uint32
        // constant after handshake; protected by handshakeMutex
        handshakeMutex sync.Mutex
        handshakeErr   error   // error resulting from handshake
        vers           uint16  // TLS version
        haveVers       bool    // version has been negotiated
        config         *Config // configuration passed to constructor
-       // handshakeComplete is true if the connection is currently transferring
-       // application data (i.e. is not currently processing a handshake).
-       handshakeComplete bool
        // handshakes counts the number of handshakes performed on the
        // connection so far. If renegotiation is disabled then this is either
        // zero or one.
@@ -571,12 +572,12 @@ func (c *Conn) readRecord(want recordType) error {
                c.sendAlert(alertInternalError)
                return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
        case recordTypeHandshake, recordTypeChangeCipherSpec:
-               if c.handshakeComplete {
+               if c.handshakeComplete() {
                        c.sendAlert(alertInternalError)
                        return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
                }
        case recordTypeApplicationData:
-               if !c.handshakeComplete {
+               if !c.handshakeComplete() {
                        c.sendAlert(alertInternalError)
                        return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
                }
@@ -1048,7 +1049,7 @@ func (c *Conn) Write(b []byte) (int, error) {
                return 0, err
        }
 
-       if !c.handshakeComplete {
+       if !c.handshakeComplete() {
                return 0, alertInternalError
        }
 
@@ -1114,7 +1115,7 @@ func (c *Conn) handleRenegotiation() error {
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
 
-       c.handshakeComplete = false
+       atomic.StoreUint32(&c.handshakeStatus, 0)
        if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
                c.handshakes++
        }
@@ -1215,11 +1216,9 @@ func (c *Conn) Close() error {
 
        var alertErr error
 
-       c.handshakeMutex.Lock()
-       if c.handshakeComplete {
+       if c.handshakeComplete() {
                alertErr = c.closeNotify()
        }
-       c.handshakeMutex.Unlock()
 
        if err := c.conn.Close(); err != nil {
                return err
@@ -1233,9 +1232,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
 // called once the handshake has completed and does not call CloseWrite on the
 // underlying connection. Most callers should just use Close.
 func (c *Conn) CloseWrite() error {
-       c.handshakeMutex.Lock()
-       defer c.handshakeMutex.Unlock()
-       if !c.handshakeComplete {
+       if !c.handshakeComplete() {
                return errEarlyCloseWrite
        }
 
@@ -1264,7 +1261,7 @@ func (c *Conn) Handshake() error {
        if err := c.handshakeErr; err != nil {
                return err
        }
-       if c.handshakeComplete {
+       if c.handshakeComplete() {
                return nil
        }
 
@@ -1284,7 +1281,7 @@ func (c *Conn) Handshake() error {
                c.flush()
        }
 
-       if c.handshakeErr == nil && !c.handshakeComplete {
+       if c.handshakeErr == nil && !c.handshakeComplete() {
                panic("handshake should have had a result.")
        }
 
@@ -1297,10 +1294,10 @@ func (c *Conn) ConnectionState() ConnectionState {
        defer c.handshakeMutex.Unlock()
 
        var state ConnectionState
-       state.HandshakeComplete = c.handshakeComplete
+       state.HandshakeComplete = c.handshakeComplete()
        state.ServerName = c.serverName
 
-       if c.handshakeComplete {
+       if state.HandshakeComplete {
                state.Version = c.vers
                state.NegotiatedProtocol = c.clientProtocol
                state.DidResume = c.didResume
@@ -1345,7 +1342,7 @@ func (c *Conn) VerifyHostname(host string) error {
        if !c.isClient {
                return errors.New("tls: VerifyHostname called on TLS server connection")
        }
-       if !c.handshakeComplete {
+       if !c.handshakeComplete() {
                return errors.New("tls: handshake has not yet been performed")
        }
        if len(c.verifiedChains) == 0 {
@@ -1353,3 +1350,7 @@ func (c *Conn) VerifyHostname(host string) error {
        }
        return c.peerCertificates[0].VerifyHostname(host)
 }
+
+func (c *Conn) handshakeComplete() bool {
+       return atomic.LoadUint32(&c.handshakeStatus) == 1
+}
index d7fb3682289607acb5c9abad9bd31fa769a16a62..32fdc6d6eb115ae3a2219d4a79f6f61678766eed 100644 (file)
@@ -17,6 +17,7 @@ import (
        "net"
        "strconv"
        "strings"
+       "sync/atomic"
 )
 
 type clientHandshakeState struct {
@@ -266,7 +267,7 @@ func (hs *clientHandshakeState) handshake() error {
 
        c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
        c.didResume = isResume
-       c.handshakeComplete = true
+       atomic.StoreUint32(&c.handshakeStatus, 1)
 
        return nil
 }
index 2ab4e474ec3bac3cdac7fc225d7cca2eb3141b80..79fb3421a8c7c629180b0f9acd8adec14ea41067 100644 (file)
@@ -1617,3 +1617,22 @@ RwBA9Xk1KBNF
                t.Error("A RSA-PSS certificate was parsed like a PKCS1 one, and it will be mistakenly used with rsa_pss_rsae_xxx signature algorithms")
        }
 }
+
+func TestCloseClientConnectionOnIdleServer(t *testing.T) {
+       clientConn, serverConn := net.Pipe()
+       client := Client(clientConn, testConfig.Clone())
+       go func() {
+               var b [1]byte
+               serverConn.Read(b[:])
+               client.Close()
+       }()
+       client.SetWriteDeadline(time.Now().Add(time.Second))
+       err := client.Handshake()
+       if err != nil {
+               if !strings.Contains(err.Error(), "read/write on closed pipe") {
+                       t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error())
+               }
+       } else {
+               t.Errorf("Error expected, but no error returned")
+       }
+}
index 0d685927b37643bcfff59258b0251b9640909b2b..ac491bad390c632610b2ac1b5346a389961bcf80 100644 (file)
@@ -13,6 +13,7 @@ import (
        "errors"
        "fmt"
        "io"
+       "sync/atomic"
 )
 
 // serverHandshakeState contains details of a server handshake in progress.
@@ -103,7 +104,7 @@ func (c *Conn) serverHandshake() error {
        }
 
        c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
-       c.handshakeComplete = true
+       atomic.StoreUint32(&c.handshakeStatus, 1)
 
        return nil
 }
index 69e6cc9bd67fb236ace789f5766dd2e7d31c4c78..01d7b5ceecd8ef72d42498fe025562a02f8579a7 100644 (file)
@@ -1403,3 +1403,21 @@ var testECDSAPrivateKey = &ecdsa.PrivateKey{
 }
 
 var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75"))
+
+func TestCloseServerConnectionOnIdleClient(t *testing.T) {
+       clientConn, serverConn := net.Pipe()
+       server := Server(serverConn, testConfig.Clone())
+       go func() {
+               clientConn.Write([]byte{'0'})
+               server.Close()
+       }()
+       server.SetReadDeadline(time.Now().Add(time.Second))
+       err := server.Handshake()
+       if err != nil {
+               if !strings.Contains(err.Error(), "read/write on closed pipe") {
+                       t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error())
+               }
+       } else {
+               t.Errorf("Error expected, but no error returned")
+       }
+}