]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: split connErr to avoid read/write races.
authorAdam Langley <agl@golang.org>
Mon, 3 Mar 2014 14:01:44 +0000 (09:01 -0500)
committerAdam Langley <agl@golang.org>
Mon, 3 Mar 2014 14:01:44 +0000 (09:01 -0500)
Currently a write error will cause future reads to return that same error.
However, there may have been extra information from a peer pending on
the read direction that is now unavailable.

This change splits the single connErr into errors for the read, write and
handshake. (Splitting off the handshake error is needed because both read
and write paths check the handshake error.)

Fixes #7414.

LGTM=bradfitz, r
R=golang-codereviews, r, bradfitz
CC=golang-codereviews
https://golang.org/cl/69090044

src/pkg/crypto/tls/conn.go
src/pkg/crypto/tls/handshake_client.go
src/pkg/crypto/tls/handshake_server.go

index f1bb3f613d1115159e2ea891310d1cb08c7816ad..d25ad287aa3f7a3635a3208497369097f2e67f92 100644 (file)
@@ -28,6 +28,7 @@ type Conn struct {
 
        // constant after handshake; protected by handshakeMutex
        handshakeMutex    sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
+       handshakeErr      error      // error resulting from handshake
        vers              uint16     // TLS version
        haveVers          bool       // version has been negotiated
        config            *Config    // configuration passed to constructor
@@ -45,9 +46,6 @@ type Conn struct {
        clientProtocol         string
        clientProtocolFallback bool
 
-       // first permanent error
-       connErr
-
        // input/output
        in, out  halfConn     // in.Mutex < out.Mutex
        rawInput *block       // raw input, right off the wire
@@ -57,27 +55,6 @@ type Conn struct {
        tmp [16]byte
 }
 
-type connErr struct {
-       mu    sync.Mutex
-       value error
-}
-
-func (e *connErr) setError(err error) error {
-       e.mu.Lock()
-       defer e.mu.Unlock()
-
-       if e.value == nil {
-               e.value = err
-       }
-       return err
-}
-
-func (e *connErr) error() error {
-       e.mu.Lock()
-       defer e.mu.Unlock()
-       return e.value
-}
-
 // Access to net.Conn methods.
 // Cannot just embed net.Conn because that would
 // export the struct field too.
@@ -116,6 +93,8 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
 // connection, either sending or receiving.
 type halfConn struct {
        sync.Mutex
+
+       err     error       // first permanent error
        version uint16      // protocol version
        cipher  interface{} // cipher algorithm
        mac     macFunction
@@ -129,6 +108,18 @@ type halfConn struct {
        inDigestBuf, outDigestBuf []byte
 }
 
+func (hc *halfConn) setErrorLocked(err error) error {
+       hc.err = err
+       return err
+}
+
+func (hc *halfConn) error() error {
+       hc.Lock()
+       err := hc.err
+       hc.Unlock()
+       return err
+}
+
 // prepareCipherSpec sets the encryption and MAC states
 // that a subsequent changeCipherSpec will use.
 func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
@@ -520,16 +511,16 @@ func (c *Conn) readRecord(want recordType) error {
        switch want {
        default:
                c.sendAlert(alertInternalError)
-               return errors.New("tls: unknown record type requested")
+               return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
        case recordTypeHandshake, recordTypeChangeCipherSpec:
                if c.handshakeComplete {
                        c.sendAlert(alertInternalError)
-                       return errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete")
+                       return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete"))
                }
        case recordTypeApplicationData:
                if !c.handshakeComplete {
                        c.sendAlert(alertInternalError)
-                       return errors.New("tls: application data record requested before handshake complete")
+                       return c.in.setErrorLocked(errors.New("tls: application data record requested before handshake complete"))
                }
        }
 
@@ -548,7 +539,7 @@ Again:
                //      err = io.ErrUnexpectedEOF
                // }
                if e, ok := err.(net.Error); !ok || !e.Temporary() {
-                       c.setError(err)
+                       c.in.setErrorLocked(err)
                }
                return err
        }
@@ -560,18 +551,18 @@ Again:
        // an SSLv2 client.
        if want == recordTypeHandshake && typ == 0x80 {
                c.sendAlert(alertProtocolVersion)
-               return errors.New("tls: unsupported SSLv2 handshake received")
+               return c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
        }
 
        vers := uint16(b.data[1])<<8 | uint16(b.data[2])
        n := int(b.data[3])<<8 | int(b.data[4])
        if c.haveVers && vers != c.vers {
                c.sendAlert(alertProtocolVersion)
-               return fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers)
+               return c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
        }
        if n > maxCiphertext {
                c.sendAlert(alertRecordOverflow)
-               return fmt.Errorf("tls: oversized record received with length %d", n)
+               return c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
        }
        if !c.haveVers {
                // First message, be extra suspicious:
@@ -584,7 +575,7 @@ Again:
                // it's probably not real.
                if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
                        c.sendAlert(alertUnexpectedMessage)
-                       return fmt.Errorf("tls: first record does not look like a TLS handshake")
+                       return c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
                }
        }
        if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
@@ -592,7 +583,7 @@ Again:
                        err = io.ErrUnexpectedEOF
                }
                if e, ok := err.(net.Error); !ok || !e.Temporary() {
-                       c.setError(err)
+                       c.in.setErrorLocked(err)
                }
                return err
        }
@@ -601,27 +592,27 @@ Again:
        b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
        ok, off, err := c.in.decrypt(b)
        if !ok {
-               return c.sendAlert(err)
+               c.in.setErrorLocked(c.sendAlert(err))
        }
        b.off = off
        data := b.data[b.off:]
        if len(data) > maxPlaintext {
-               c.sendAlert(alertRecordOverflow)
+               err := c.sendAlert(alertRecordOverflow)
                c.in.freeBlock(b)
-               return c.error()
+               return c.in.setErrorLocked(err)
        }
 
        switch typ {
        default:
-               c.sendAlert(alertUnexpectedMessage)
+               c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
 
        case recordTypeAlert:
                if len(data) != 2 {
-                       c.sendAlert(alertUnexpectedMessage)
+                       c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                        break
                }
                if alert(data[1]) == alertCloseNotify {
-                       c.setError(io.EOF)
+                       c.in.setErrorLocked(io.EOF)
                        break
                }
                switch data[0] {
@@ -630,24 +621,24 @@ Again:
                        c.in.freeBlock(b)
                        goto Again
                case alertLevelError:
-                       c.setError(&net.OpError{Op: "remote error", Err: alert(data[1])})
+                       c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
                default:
-                       c.sendAlert(alertUnexpectedMessage)
+                       c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                }
 
        case recordTypeChangeCipherSpec:
                if typ != want || len(data) != 1 || data[0] != 1 {
-                       c.sendAlert(alertUnexpectedMessage)
+                       c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                        break
                }
                err := c.in.changeCipherSpec()
                if err != nil {
-                       c.sendAlert(err.(alert))
+                       c.in.setErrorLocked(c.sendAlert(err.(alert)))
                }
 
        case recordTypeApplicationData:
                if typ != want {
-                       c.sendAlert(alertUnexpectedMessage)
+                       c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
                        break
                }
                c.input = b
@@ -656,7 +647,7 @@ Again:
        case recordTypeHandshake:
                // TODO(rsc): Should at least pick off connection close.
                if typ != want {
-                       return c.sendAlert(alertNoRenegotiation)
+                       return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
                }
                c.hand.Write(data)
        }
@@ -664,7 +655,7 @@ Again:
        if b != nil {
                c.in.freeBlock(b)
        }
-       return c.error()
+       return c.in.err
 }
 
 // sendAlert sends a TLS alert message.
@@ -680,7 +671,7 @@ func (c *Conn) sendAlertLocked(err alert) error {
        c.writeRecord(recordTypeAlert, c.tmp[0:2])
        // closeNotify is a special case in that it isn't an error:
        if err != alertCloseNotify {
-               return c.setError(&net.OpError{Op: "local error", Err: err})
+               return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
        }
        return nil
 }
@@ -766,7 +757,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
                        c.tmp[0] = alertLevelError
                        c.tmp[1] = byte(err.(alert))
                        c.writeRecord(recordTypeAlert, c.tmp[0:2])
-                       return n, c.setError(&net.OpError{Op: "local error", Err: err})
+                       return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
                }
        }
        return
@@ -777,7 +768,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
 // c.in.Mutex < L; c.out.Mutex < L.
 func (c *Conn) readHandshake() (interface{}, error) {
        for c.hand.Len() < 4 {
-               if err := c.error(); err != nil {
+               if err := c.in.err; err != nil {
                        return nil, err
                }
                if err := c.readRecord(recordTypeHandshake); err != nil {
@@ -788,11 +779,10 @@ func (c *Conn) readHandshake() (interface{}, error) {
        data := c.hand.Bytes()
        n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
        if n > maxHandshake {
-               c.sendAlert(alertInternalError)
-               return nil, c.error()
+               return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
        }
        for c.hand.Len() < 4+n {
-               if err := c.error(); err != nil {
+               if err := c.in.err; err != nil {
                        return nil, err
                }
                if err := c.readRecord(recordTypeHandshake); err != nil {
@@ -831,8 +821,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
        case typeFinished:
                m = new(finishedMsg)
        default:
-               c.sendAlert(alertUnexpectedMessage)
-               return nil, alertUnexpectedMessage
+               return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
        }
 
        // The handshake message unmarshallers
@@ -841,25 +830,24 @@ func (c *Conn) readHandshake() (interface{}, error) {
        data = append([]byte(nil), data...)
 
        if !m.unmarshal(data) {
-               c.sendAlert(alertUnexpectedMessage)
-               return nil, alertUnexpectedMessage
+               return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
        }
        return m, nil
 }
 
 // Write writes data to the connection.
 func (c *Conn) Write(b []byte) (int, error) {
-       if err := c.error(); err != nil {
-               return 0, err
-       }
-
        if err := c.Handshake(); err != nil {
-               return 0, c.setError(err)
+               return 0, err
        }
 
        c.out.Lock()
        defer c.out.Unlock()
 
+       if err := c.out.err; err != nil {
+               return 0, err
+       }
+
        if !c.handshakeComplete {
                return 0, alertInternalError
        }
@@ -878,14 +866,14 @@ func (c *Conn) Write(b []byte) (int, error) {
                if _, ok := c.out.cipher.(cipher.BlockMode); ok {
                        n, err := c.writeRecord(recordTypeApplicationData, b[:1])
                        if err != nil {
-                               return n, c.setError(err)
+                               return n, c.out.setErrorLocked(err)
                        }
                        m, b = 1, b[1:]
                }
        }
 
        n, err := c.writeRecord(recordTypeApplicationData, b)
-       return n + m, c.setError(err)
+       return n + m, c.out.setErrorLocked(err)
 }
 
 // Read can be made to time out and return a net.Error with Timeout() == true
@@ -902,13 +890,13 @@ func (c *Conn) Read(b []byte) (n int, err error) {
        // CBC IV. So this loop ignores a limited number of empty records.
        const maxConsecutiveEmptyRecords = 100
        for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
-               for c.input == nil && c.error() == nil {
+               for c.input == nil && c.in.err == nil {
                        if err := c.readRecord(recordTypeApplicationData); err != nil {
                                // Soft error, like EAGAIN
                                return 0, err
                        }
                }
-               if err := c.error(); err != nil {
+               if err := c.in.err; err != nil {
                        return 0, err
                }
 
@@ -949,16 +937,19 @@ func (c *Conn) Close() error {
 func (c *Conn) Handshake() error {
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
-       if err := c.error(); err != nil {
+       if err := c.handshakeErr; err != nil {
                return err
        }
        if c.handshakeComplete {
                return nil
        }
+
        if c.isClient {
-               return c.clientHandshake()
+               c.handshakeErr = c.clientHandshake()
+       } else {
+               c.handshakeErr = c.serverHandshake()
        }
-       return c.serverHandshake()
+       return c.handshakeErr
 }
 
 // ConnectionState returns basic TLS details about the connection.
index ce0a306fcd24837f7441b3bf54254ef911a0cb80..a320fde1bc7cb82c2f91f76cbf089b783179bc9c 100644 (file)
@@ -501,7 +501,7 @@ func (hs *clientHandshakeState) readFinished() error {
        c := hs.c
 
        c.readRecord(recordTypeChangeCipherSpec)
-       if err := c.error(); err != nil {
+       if err := c.in.error(); err != nil {
                return err
        }
 
index 7fb5365ff68215c1fca9eeb426dec612f0d6a78d..75111eba0042f54c2cd42b59bd1f82161c957c84 100644 (file)
@@ -470,7 +470,7 @@ func (hs *serverHandshakeState) readFinished() error {
        c := hs.c
 
        c.readRecord(recordTypeChangeCipherSpec)
-       if err := c.error(); err != nil {
+       if err := c.in.error(); err != nil {
                return err
        }