]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: change Conn.handshakeStatus type to atomic.Bool
authorLudi Rehak <ludi317@gmail.com>
Tue, 9 Aug 2022 16:36:17 +0000 (09:36 -0700)
committerGopher Robot <gobot@golang.org>
Thu, 11 Aug 2022 13:58:45 +0000 (13:58 +0000)
Change the type of Conn.handshakeStatus from an atomically
accessed uint32 to an atomic.Bool. Change its name to
Conn.isHandshakeComplete to indicate it is a boolean value.
Eliminate the handshakeComplete() helper function, which checks
for equality with 1, in favor of the simpler
c.isHandshakeComplete.Load().

Change-Id: I084c83956fff266e2145847e8645372bef6ae9df
Reviewed-on: https://go-review.googlesource.com/c/go/+/422296
Auto-Submit: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Than McIntosh <thanm@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>

src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_tls13.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_tls13.go

index 1861a312f1c47afb413092d66abcb866306ac64f..b1a7dcc42f72c79489dfbcfcfb5ece8e043fff9c 100644 (file)
@@ -30,11 +30,10 @@ type Conn struct {
        isClient    bool
        handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
 
-       // handshakeStatus is 1 if the connection is currently transferring
+       // isHandshakeComplete is true if the connection is currently transferring
        // application data (i.e. is not currently processing a handshake).
-       // handshakeStatus == 1 implies handshakeErr == nil.
-       // This field is only to be accessed with sync/atomic.
-       handshakeStatus uint32
+       // isHandshakeComplete is true implies handshakeErr == nil.
+       isHandshakeComplete atomic.Bool
        // constant after handshake; protected by handshakeMutex
        handshakeMutex sync.Mutex
        handshakeErr   error   // error resulting from handshake
@@ -604,7 +603,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
        if c.in.err != nil {
                return c.in.err
        }
-       handshakeComplete := c.handshakeComplete()
+       handshakeComplete := c.isHandshakeComplete.Load()
 
        // This function modifies c.rawInput, which owns the c.input memory.
        if c.input.Len() != 0 {
@@ -1130,7 +1129,7 @@ func (c *Conn) Write(b []byte) (int, error) {
                return 0, err
        }
 
-       if !c.handshakeComplete() {
+       if !c.isHandshakeComplete.Load() {
                return 0, alertInternalError
        }
 
@@ -1200,7 +1199,7 @@ func (c *Conn) handleRenegotiation() error {
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
 
-       atomic.StoreUint32(&c.handshakeStatus, 0)
+       c.isHandshakeComplete.Store(false)
        if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
                c.handshakes++
        }
@@ -1337,7 +1336,7 @@ func (c *Conn) Close() error {
        }
 
        var alertErr error
-       if c.handshakeComplete() {
+       if c.isHandshakeComplete.Load() {
                if err := c.closeNotify(); err != nil {
                        alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
                }
@@ -1355,7 +1354,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
 // called once the handshake has completed and does not call CloseWrite on the
 // underlying connection. Most callers should just use Close.
 func (c *Conn) CloseWrite() error {
-       if !c.handshakeComplete() {
+       if !c.isHandshakeComplete.Load() {
                return errEarlyCloseWrite
        }
 
@@ -1409,7 +1408,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
        // Fast sync/atomic-based exit if there is no handshake in flight and the
        // last one succeeded without an error. Avoids the expensive context setup
        // and mutex for most Read and Write calls.
-       if c.handshakeComplete() {
+       if c.isHandshakeComplete.Load() {
                return nil
        }
 
@@ -1452,7 +1451,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
        if err := c.handshakeErr; err != nil {
                return err
        }
-       if c.handshakeComplete() {
+       if c.isHandshakeComplete.Load() {
                return nil
        }
 
@@ -1468,10 +1467,10 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
                c.flush()
        }
 
-       if c.handshakeErr == nil && !c.handshakeComplete() {
+       if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
                c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
        }
-       if c.handshakeErr != nil && c.handshakeComplete() {
+       if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
                panic("tls: internal error: handshake returned an error but is marked successful")
        }
 
@@ -1487,7 +1486,7 @@ func (c *Conn) ConnectionState() ConnectionState {
 
 func (c *Conn) connectionStateLocked() ConnectionState {
        var state ConnectionState
-       state.HandshakeComplete = c.handshakeComplete()
+       state.HandshakeComplete = c.isHandshakeComplete.Load()
        state.Version = c.vers
        state.NegotiatedProtocol = c.clientProtocol
        state.DidResume = c.didResume
@@ -1531,7 +1530,7 @@ func (c *Conn) VerifyHostname(host string) error {
        if !c.isClient {
                return errors.New("tls: VerifyHostname called on TLS server connection")
        }
-       if !c.handshakeComplete() {
+       if !c.isHandshakeComplete.Load() {
                return errors.New("tls: handshake has not yet been performed")
        }
        if len(c.verifiedChains) == 0 {
@@ -1539,7 +1538,3 @@ func (c *Conn) VerifyHostname(host string) error {
        }
        return c.peerCertificates[0].VerifyHostname(host)
 }
-
-func (c *Conn) handshakeComplete() bool {
-       return atomic.LoadUint32(&c.handshakeStatus) == 1
-}
index e61e3eb5409690947452c8579f642d71e6ab3e0e..e07cf79629ad2e5f54ede2faa492a0c9ed479d56 100644 (file)
@@ -19,7 +19,6 @@ import (
        "io"
        "net"
        "strings"
-       "sync/atomic"
        "time"
 )
 
@@ -455,7 +454,7 @@ func (hs *clientHandshakeState) handshake() error {
        }
 
        c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
-       atomic.StoreUint32(&c.handshakeStatus, 1)
+       c.isHandshakeComplete.Store(true)
 
        return nil
 }
index c7989867f5637b99f41268dec35188ff1208934a..ac783afdfcaf19887b243c3b952f0bd4e01ce58f 100644 (file)
@@ -12,7 +12,6 @@ import (
        "crypto/rsa"
        "errors"
        "hash"
-       "sync/atomic"
        "time"
 )
 
@@ -104,7 +103,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
                return err
        }
 
-       atomic.StoreUint32(&c.handshakeStatus, 1)
+       c.isHandshakeComplete.Store(true)
 
        return nil
 }
index 7606305c1dfbcc1573cee069158ff660b3ae468c..844e887af3d677dc35a2734124d97ad041daed09 100644 (file)
@@ -16,7 +16,6 @@ import (
        "fmt"
        "hash"
        "io"
-       "sync/atomic"
        "time"
 )
 
@@ -122,7 +121,7 @@ func (hs *serverHandshakeState) handshake() error {
        }
 
        c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
-       atomic.StoreUint32(&c.handshakeStatus, 1)
+       c.isHandshakeComplete.Store(true)
 
        return nil
 }
index 03a477f7bed54b658f069fec8f7eb15bf1b3de9f..712f3589b3fb7b2b54545c14475a7d6a3515baeb 100644 (file)
@@ -14,7 +14,6 @@ import (
        "errors"
        "hash"
        "io"
-       "sync/atomic"
        "time"
 )
 
@@ -82,7 +81,7 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
                return err
        }
 
-       atomic.StoreUint32(&c.handshakeStatus, 1)
+       c.isHandshakeComplete.Store(true)
 
        return nil
 }