From e5b13401c6b19f58a8439f1019a80fe540c0c687 Mon Sep 17 00:00:00 2001 From: Minaev Mike Date: Fri, 26 Jan 2018 09:17:46 +0000 Subject: [PATCH] crypto/tls: fix deadlock when Read and Close called concurrently The existing implementation of TLS connection has a deadlock. It occurs when client connects to TLS server and doesn't send data for handshake, so server calls Close on this connection. This is because server reads data under locked mutex, while Close method tries to lock the same mutex. Fixes #23518 Change-Id: I4fb0a2a770f3d911036bfd9a7da7cc41c1b27e19 Reviewed-on: https://go-review.googlesource.com/90155 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Filippo Valsorda --- src/crypto/tls/conn.go | 37 +++++++++++++------------ src/crypto/tls/handshake_client.go | 3 +- src/crypto/tls/handshake_client_test.go | 19 +++++++++++++ src/crypto/tls/handshake_server.go | 3 +- src/crypto/tls/handshake_server_test.go | 18 ++++++++++++ 5 files changed, 60 insertions(+), 20 deletions(-) diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index cdaa7aba97..2adb967537 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -27,15 +27,16 @@ type Conn struct { conn net.Conn isClient bool + // handshakeStatus is 1 if the connection is currently transferring + // application data (i.e. is not currently processing a handshake). + // This field is only to be accessed with sync/atomic. + handshakeStatus uint32 // constant after handshake; protected by handshakeMutex handshakeMutex sync.Mutex handshakeErr error // error resulting from handshake vers uint16 // TLS version haveVers bool // version has been negotiated config *Config // configuration passed to constructor - // handshakeComplete is true if the connection is currently transferring - // application data (i.e. is not currently processing a handshake). - handshakeComplete bool // handshakes counts the number of handshakes performed on the // connection so far. If renegotiation is disabled then this is either // zero or one. @@ -571,12 +572,12 @@ func (c *Conn) readRecord(want recordType) error { c.sendAlert(alertInternalError) return c.in.setErrorLocked(errors.New("tls: unknown record type requested")) case recordTypeHandshake, recordTypeChangeCipherSpec: - if c.handshakeComplete { + if c.handshakeComplete() { c.sendAlert(alertInternalError) return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake")) } case recordTypeApplicationData: - if !c.handshakeComplete { + if !c.handshakeComplete() { c.sendAlert(alertInternalError) return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake")) } @@ -1048,7 +1049,7 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, err } - if !c.handshakeComplete { + if !c.handshakeComplete() { return 0, alertInternalError } @@ -1114,7 +1115,7 @@ func (c *Conn) handleRenegotiation() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - c.handshakeComplete = false + atomic.StoreUint32(&c.handshakeStatus, 0) if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { c.handshakes++ } @@ -1215,11 +1216,9 @@ func (c *Conn) Close() error { var alertErr error - c.handshakeMutex.Lock() - if c.handshakeComplete { + if c.handshakeComplete() { alertErr = c.closeNotify() } - c.handshakeMutex.Unlock() if err := c.conn.Close(); err != nil { return err @@ -1233,9 +1232,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com // called once the handshake has completed and does not call CloseWrite on the // underlying connection. Most callers should just use Close. func (c *Conn) CloseWrite() error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - if !c.handshakeComplete { + if !c.handshakeComplete() { return errEarlyCloseWrite } @@ -1264,7 +1261,7 @@ func (c *Conn) Handshake() error { if err := c.handshakeErr; err != nil { return err } - if c.handshakeComplete { + if c.handshakeComplete() { return nil } @@ -1284,7 +1281,7 @@ func (c *Conn) Handshake() error { c.flush() } - if c.handshakeErr == nil && !c.handshakeComplete { + if c.handshakeErr == nil && !c.handshakeComplete() { panic("handshake should have had a result.") } @@ -1297,10 +1294,10 @@ func (c *Conn) ConnectionState() ConnectionState { defer c.handshakeMutex.Unlock() var state ConnectionState - state.HandshakeComplete = c.handshakeComplete + state.HandshakeComplete = c.handshakeComplete() state.ServerName = c.serverName - if c.handshakeComplete { + if state.HandshakeComplete { state.Version = c.vers state.NegotiatedProtocol = c.clientProtocol state.DidResume = c.didResume @@ -1345,7 +1342,7 @@ func (c *Conn) VerifyHostname(host string) error { if !c.isClient { return errors.New("tls: VerifyHostname called on TLS server connection") } - if !c.handshakeComplete { + if !c.handshakeComplete() { return errors.New("tls: handshake has not yet been performed") } if len(c.verifiedChains) == 0 { @@ -1353,3 +1350,7 @@ func (c *Conn) VerifyHostname(host string) error { } return c.peerCertificates[0].VerifyHostname(host) } + +func (c *Conn) handshakeComplete() bool { + return atomic.LoadUint32(&c.handshakeStatus) == 1 +} diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index d7fb368228..32fdc6d6eb 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -17,6 +17,7 @@ import ( "net" "strconv" "strings" + "sync/atomic" ) type clientHandshakeState struct { @@ -266,7 +267,7 @@ func (hs *clientHandshakeState) handshake() error { c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) c.didResume = isResume - c.handshakeComplete = true + atomic.StoreUint32(&c.handshakeStatus, 1) return nil } diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go index 2ab4e474ec..79fb3421a8 100644 --- a/src/crypto/tls/handshake_client_test.go +++ b/src/crypto/tls/handshake_client_test.go @@ -1617,3 +1617,22 @@ RwBA9Xk1KBNF t.Error("A RSA-PSS certificate was parsed like a PKCS1 one, and it will be mistakenly used with rsa_pss_rsae_xxx signature algorithms") } } + +func TestCloseClientConnectionOnIdleServer(t *testing.T) { + clientConn, serverConn := net.Pipe() + client := Client(clientConn, testConfig.Clone()) + go func() { + var b [1]byte + serverConn.Read(b[:]) + client.Close() + }() + client.SetWriteDeadline(time.Now().Add(time.Second)) + err := client.Handshake() + if err != nil { + if !strings.Contains(err.Error(), "read/write on closed pipe") { + t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error()) + } + } else { + t.Errorf("Error expected, but no error returned") + } +} diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index 0d685927b3..ac491bad39 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -13,6 +13,7 @@ import ( "errors" "fmt" "io" + "sync/atomic" ) // serverHandshakeState contains details of a server handshake in progress. @@ -103,7 +104,7 @@ func (c *Conn) serverHandshake() error { } c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) - c.handshakeComplete = true + atomic.StoreUint32(&c.handshakeStatus, 1) return nil } diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go index 69e6cc9bd6..01d7b5ceec 100644 --- a/src/crypto/tls/handshake_server_test.go +++ b/src/crypto/tls/handshake_server_test.go @@ -1403,3 +1403,21 @@ var testECDSAPrivateKey = &ecdsa.PrivateKey{ } var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75")) + +func TestCloseServerConnectionOnIdleClient(t *testing.T) { + clientConn, serverConn := net.Pipe() + server := Server(serverConn, testConfig.Clone()) + go func() { + clientConn.Write([]byte{'0'}) + server.Close() + }() + server.SetReadDeadline(time.Now().Add(time.Second)) + err := server.Handshake() + if err != nil { + if !strings.Contains(err.Error(), "read/write on closed pipe") { + t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error()) + } + } else { + t.Errorf("Error expected, but no error returned") + } +} -- 2.48.1