]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: fix data race on conn.err
authorDave Cheney <dave@cheney.net>
Thu, 6 Sep 2012 07:50:26 +0000 (17:50 +1000)
committerDave Cheney <dave@cheney.net>
Thu, 6 Sep 2012 07:50:26 +0000 (17:50 +1000)
Fixes #3862.

There were many areas where conn.err was being accessed
outside the mutex. This proposal moves the err value to
an embedded struct to make it more obvious when the error
value is being accessed.

As there are no Benchmark tests in this package I cannot
feel confident of the impact of this additional locking,
although most will be uncontended.

R=dvyukov, agl
CC=golang-dev
https://golang.org/cl/6497070

src/pkg/crypto/tls/conn.go
src/pkg/crypto/tls/handshake_client.go

index 455910af41503575b69aa7aaca266fe763efe5e0..5dc344bed5bcc36846a0161b45567b5cd41edd73 100644 (file)
@@ -44,8 +44,7 @@ type Conn struct {
        clientProtocolFallback bool
 
        // first permanent error
-       errMutex sync.Mutex
-       err      error
+       connErr
 
        // input/output
        in, out  halfConn     // in.Mutex < out.Mutex
@@ -56,21 +55,25 @@ type Conn struct {
        tmp [16]byte
 }
 
-func (c *Conn) setError(err error) error {
-       c.errMutex.Lock()
-       defer c.errMutex.Unlock()
+type connErr struct {
+       mu    sync.Mutex
+       value error
+}
+
+func (e *connErr) setError(err error) error {
+       e.mu.Lock()
+       defer e.mu.Unlock()
 
-       if c.err == nil {
-               c.err = err
+       if e.value == nil {
+               e.value = err
        }
        return err
 }
 
-func (c *Conn) error() error {
-       c.errMutex.Lock()
-       defer c.errMutex.Unlock()
-
-       return c.err
+func (e *connErr) error() error {
+       e.mu.Lock()
+       defer e.mu.Unlock()
+       return e.value
 }
 
 // Access to net.Conn methods.
@@ -660,8 +663,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
                        c.tmp[0] = alertLevelError
                        c.tmp[1] = byte(err.(alert))
                        c.writeRecord(recordTypeAlert, c.tmp[0:2])
-                       c.err = &net.OpError{Op: "local error", Err: err}
-                       return n, c.err
+                       return n, c.setError(&net.OpError{Op: "local error", Err: err})
                }
        }
        return
@@ -672,8 +674,8 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
 // c.in.Mutex < L; c.out.Mutex < L.
 func (c *Conn) readHandshake() (interface{}, error) {
        for c.hand.Len() < 4 {
-               if c.err != nil {
-                       return nil, c.err
+               if err := c.error(); err != nil {
+                       return nil, err
                }
                if err := c.readRecord(recordTypeHandshake); err != nil {
                        return nil, err
@@ -684,11 +686,11 @@ func (c *Conn) readHandshake() (interface{}, error) {
        n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
        if n > maxHandshake {
                c.sendAlert(alertInternalError)
-               return nil, c.err
+               return nil, c.error()
        }
        for c.hand.Len() < 4+n {
-               if c.err != nil {
-                       return nil, c.err
+               if err := c.error(); err != nil {
+                       return nil, err
                }
                if err := c.readRecord(recordTypeHandshake); err != nil {
                        return nil, err
@@ -738,12 +740,12 @@ func (c *Conn) readHandshake() (interface{}, error) {
 
 // Write writes data to the connection.
 func (c *Conn) Write(b []byte) (int, error) {
-       if c.err != nil {
-               return 0, c.err
+       if err := c.error(); err != nil {
+               return 0, err
        }
 
-       if c.err = c.Handshake(); c.err != nil {
-               return 0, c.err
+       if err := c.Handshake(); err != nil {
+               return 0, c.setError(err)
        }
 
        c.out.Lock()
@@ -753,9 +755,8 @@ func (c *Conn) Write(b []byte) (int, error) {
                return 0, alertInternalError
        }
 
-       var n int
-       n, c.err = c.writeRecord(recordTypeApplicationData, b)
-       return n, c.err
+       n, err := c.writeRecord(recordTypeApplicationData, b)
+       return n, c.setError(err)
 }
 
 // Read can be made to time out and return a net.Error with Timeout() == true
@@ -768,14 +769,14 @@ func (c *Conn) Read(b []byte) (n int, err error) {
        c.in.Lock()
        defer c.in.Unlock()
 
-       for c.input == nil && c.err == nil {
+       for c.input == nil && c.error() == nil {
                if err := c.readRecord(recordTypeApplicationData); err != nil {
                        // Soft error, like EAGAIN
                        return 0, err
                }
        }
-       if c.err != nil {
-               return 0, c.err
+       if err := c.error(); err != nil {
+               return 0, err
        }
        n, err = c.input.Read(b)
        if c.input.off >= len(c.input.data) {
index 2877f17387dcf272d678e153488e751e12277142..c6637c593c9b03b2c3a5bdee1a8cf86dbd6d8504 100644 (file)
@@ -306,8 +306,8 @@ func (c *Conn) clientHandshake() error {
        serverHash := suite.mac(c.vers, serverMAC)
        c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
        c.readRecord(recordTypeChangeCipherSpec)
-       if c.err != nil {
-               return c.err
+       if err := c.error(); err != nil {
+               return err
        }
 
        msg, err = c.readHandshake()