// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
- handshakeErr error // error resulting from handshake
- vers uint16 // TLS version
- haveVers bool // version has been negotiated
- config *Config // configuration passed to constructor
+ // handshakeCond, if not nil, indicates that a goroutine is committed
+ // to running the handshake for this Conn. Other goroutines that need
+ // to wait for the handshake can wait on this, under handshakeMutex.
+ handshakeCond *sync.Cond
+ handshakeErr error // error resulting from handshake
+ vers uint16 // TLS version
+ haveVers bool // version has been negotiated
+ config *Config // configuration passed to constructor
// handshakeComplete is true if the connection is currently transfering
// application data (i.e. is not currently processing a handshake).
handshakeComplete bool
// need to check whether a handshake is pending (such as Write) to
// block.
//
- // Thus we take c.handshakeMutex first and, if we find that a handshake
- // is needed, then we unlock, acquire c.in and c.handshakeMutex in the
- // correct order, and check again.
+ // Thus we first take c.handshakeMutex to check whether a handshake is
+ // needed.
+ //
+ // If so then, previously, this code would unlock handshakeMutex and
+ // then lock c.in and handshakeMutex in the correct order to run the
+ // handshake. The problem was that it was possible for a Read to
+ // complete the handshake once handshakeMutex was unlocked and then
+ // keep c.in while waiting for network data. Thus a concurrent
+ // operation could be blocked on c.in.
+ //
+ // Thus handshakeCond is used to signal that a goroutine is committed
+ // to running the handshake and other goroutines can wait on it if they
+ // need. handshakeCond is protected by handshakeMutex.
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
- for i := 0; i < 2; i++ {
- if i == 1 {
- c.handshakeMutex.Unlock()
- c.in.Lock()
- defer c.in.Unlock()
- c.handshakeMutex.Lock()
- }
-
+ for {
if err := c.handshakeErr; err != nil {
return err
}
if c.handshakeComplete {
return nil
}
+ if c.handshakeCond == nil {
+ break
+ }
+
+ c.handshakeCond.Wait()
+ }
+
+ // Set handshakeCond to indicate that this goroutine is committing to
+ // running the handshake.
+ c.handshakeCond = sync.NewCond(&c.handshakeMutex)
+ c.handshakeMutex.Unlock()
+
+ c.in.Lock()
+ defer c.in.Unlock()
+
+ c.handshakeMutex.Lock()
+
+ // The handshake cannot have completed when handshakeMutex was unlocked
+ // because this goroutine set handshakeCond.
+ if c.handshakeErr != nil || c.handshakeComplete {
+ panic("handshake should not have been able to complete after handshakeCond was set")
}
if c.isClient {
// alert that might be left in the buffer.
c.flush()
}
+
+ if c.handshakeErr == nil && !c.handshakeComplete {
+ panic("handshake should have had a result.")
+ }
+
+ // Wake any other goroutines that are waiting for this handshake to
+ // complete.
+ c.handshakeCond.Broadcast()
+ c.handshakeCond = nil
+
return c.handshakeErr
}
t.Errorf("expected server handshake to complete with one write, but saw %d", n)
}
}
+
+func TestHandshakeRace(t *testing.T) {
+ // This test races a Read and Write to try and complete a handshake in
+ // order to provide some evidence that there are no races or deadlocks
+ // in the handshake locking.
+ for i := 0; i < 32; i++ {
+ c, s := net.Pipe()
+
+ go func() {
+ server := Server(s, testConfig)
+ if err := server.Handshake(); err != nil {
+ panic(err)
+ }
+
+ var request [1]byte
+ if n, err := server.Read(request[:]); err != nil || n != 1 {
+ panic(err)
+ }
+
+ server.Write(request[:])
+ server.Close()
+ }()
+
+ startWrite := make(chan struct{})
+ startRead := make(chan struct{})
+ readDone := make(chan struct{})
+
+ client := Client(c, testConfig)
+ go func() {
+ <-startWrite
+ var request [1]byte
+ client.Write(request[:])
+ }()
+
+ go func() {
+ <-startRead
+ var reply [1]byte
+ if n, err := client.Read(reply[:]); err != nil || n != 1 {
+ panic(err)
+ }
+ c.Close()
+ readDone <- struct{}{}
+ }()
+
+ if i&1 == 1 {
+ startWrite <- struct{}{}
+ startRead <- struct{}{}
+ } else {
+ startRead <- struct{}{}
+ startWrite <- struct{}{}
+ }
+ <-readDone
+ }
+}