"io"
"net"
"sync"
+ "sync/atomic"
"time"
)
input *block // application data waiting to be read
hand bytes.Buffer // handshake data waiting to be read
+ // activeCall is an atomic int32; the low bit is whether Close has
+ // been called. the rest of the bits are the number of goroutines
+ // in Conn.Write.
+ activeCall int32
+
tmp [16]byte
}
return m, nil
}
+var errClosed = errors.New("crypto/tls: use of closed connection")
+
// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
+ // interlock with Close below
+ for {
+ x := atomic.LoadInt32(&c.activeCall)
+ if x&1 != 0 {
+ return 0, errClosed
+ }
+ if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
+ defer atomic.AddInt32(&c.activeCall, -2)
+ break
+ }
+ }
+
if err := c.Handshake(); err != nil {
return 0, err
}
// Close closes the connection.
func (c *Conn) Close() error {
+ // Interlock with Conn.Write above.
+ var x int32
+ for {
+ x = atomic.LoadInt32(&c.activeCall)
+ if x&1 != 0 {
+ return errClosed
+ }
+ if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
+ break
+ }
+ }
+ if x != 0 {
+ // io.Writer and io.Closer should not be used concurrently.
+ // If Close is called while a Write is currently in-flight,
+ // interpret that as a sign that this Close is really just
+ // being used to break the Write and/or clean up resources and
+ // avoid sending the alertCloseNotify, which may block
+ // waiting on handshakeMutex or the c.out mutex.
+ return c.conn.Close()
+ }
+
var alertErr error
c.handshakeMutex.Lock()
import (
"bytes"
+ "errors"
"fmt"
"internal/testenv"
"io"
c.Close()
}
}
+
+func TestConnCloseBreakingWrite(t *testing.T) {
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ srvCh := make(chan *Conn, 1)
+ var serr error
+ var sconn net.Conn
+ go func() {
+ var err error
+ sconn, err = ln.Accept()
+ if err != nil {
+ serr = err
+ srvCh <- nil
+ return
+ }
+ serverConfig := *testConfig
+ srv := Server(sconn, &serverConfig)
+ if err := srv.Handshake(); err != nil {
+ serr = fmt.Errorf("handshake: %v", err)
+ srvCh <- nil
+ return
+ }
+ srvCh <- srv
+ }()
+
+ cconn, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cconn.Close()
+
+ conn := &changeImplConn{
+ Conn: cconn,
+ }
+
+ clientConfig := *testConfig
+ tconn := Client(conn, &clientConfig)
+ if err := tconn.Handshake(); err != nil {
+ t.Fatal(err)
+ }
+
+ srv := <-srvCh
+ if srv == nil {
+ t.Fatal(serr)
+ }
+ defer sconn.Close()
+
+ connClosed := make(chan struct{})
+ conn.closeFunc = func() error {
+ close(connClosed)
+ return nil
+ }
+
+ inWrite := make(chan bool, 1)
+ var errConnClosed = errors.New("conn closed for test")
+ conn.writeFunc = func(p []byte) (n int, err error) {
+ inWrite <- true
+ <-connClosed
+ return 0, errConnClosed
+ }
+
+ closeReturned := make(chan bool, 1)
+ go func() {
+ <-inWrite
+ tconn.Close() // test that this doesn't block forever.
+ closeReturned <- true
+ }()
+
+ _, err = tconn.Write([]byte("foo"))
+ if err != errConnClosed {
+ t.Errorf("Write error = %v; want errConnClosed", err)
+ }
+
+ <-closeReturned
+ if err := tconn.Close(); err != errClosed {
+ t.Errorf("Close error = %v; want errClosed", err)
+ }
+}
+
+// changeImplConn is a net.Conn which can change its Write and Close
+// methods.
+type changeImplConn struct {
+ net.Conn
+ writeFunc func([]byte) (int, error)
+ closeFunc func() error
+}
+
+func (w *changeImplConn) Write(p []byte) (n int, err error) {
+ if w.writeFunc != nil {
+ return w.writeFunc(p)
+ }
+ return w.Conn.Write(p)
+}
+
+func (w *changeImplConn) Close() error {
+ if w.closeFunc != nil {
+ return w.closeFunc()
+ }
+ return w.Conn.Close()
+}