]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: buffer handshake messages.
authorAdam Langley <agl@golang.org>
Wed, 1 Jun 2016 21:41:09 +0000 (14:41 -0700)
committerAndrew Gerrand <adg@golang.org>
Wed, 1 Jun 2016 23:26:04 +0000 (23:26 +0000)
This change causes TLS handshake messages to be buffered and written in
a single Write to the underlying net.Conn.

There are two reasons to want to do this:

Firstly, it's slightly preferable to do this in order to save sending
several, small packets over the network where a single one will do.

Secondly, since 37c28759ca46cf381a466e32168a793165d9c9e9 errors from
Write have been returned from a handshake. This means that, if a peer
closes the connection during a handshake, a “broken pipe” error may
result from tls.Conn.Handshake(). This can mask any, more detailed,
fatal alerts that the peer may have sent because a read will never
happen.

Buffering handshake messages means that the peer will not receive, and
possibly reject, any of a flow while it's still being written.

Fixes #15709

Change-Id: I38dcff1abecc06e52b2de647ea98713ce0fb9a21
Reviewed-on: https://go-review.googlesource.com/23609
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Andrew Gerrand <adg@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_server.go

index 40c17440d652dc694d057a114d787447aeec3319..87bef23d91ff874625cefac2da3a15f6a1332881 100644 (file)
@@ -71,10 +71,12 @@ type Conn struct {
        clientProtocolFallback bool
 
        // input/output
-       in, out  halfConn     // in.Mutex < out.Mutex
-       rawInput *block       // raw input, right off the wire
-       input    *block       // application data waiting to be read
-       hand     bytes.Buffer // handshake data waiting to be read
+       in, out   halfConn     // in.Mutex < out.Mutex
+       rawInput  *block       // raw input, right off the wire
+       input     *block       // application data waiting to be read
+       hand      bytes.Buffer // handshake data waiting to be read
+       buffering bool         // whether records are buffered in sendBuf
+       sendBuf   []byte       // a buffer of records waiting to be sent
 
        // bytesSent counts the bytes of application data sent.
        // packetsSent counts packets.
@@ -803,6 +805,30 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
        return n
 }
 
+// c.out.Mutex <= L.
+func (c *Conn) write(data []byte) (int, error) {
+       if c.buffering {
+               c.sendBuf = append(c.sendBuf, data...)
+               return len(data), nil
+       }
+
+       n, err := c.conn.Write(data)
+       c.bytesSent += int64(n)
+       return n, err
+}
+
+func (c *Conn) flush() (int, error) {
+       if len(c.sendBuf) == 0 {
+               return 0, nil
+       }
+
+       n, err := c.conn.Write(c.sendBuf)
+       c.bytesSent += int64(n)
+       c.sendBuf = nil
+       c.buffering = false
+       return n, err
+}
+
 // writeRecordLocked writes a TLS record with the given type and payload to the
 // connection and updates the record layer state.
 // c.out.Mutex <= L.
@@ -862,10 +888,9 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
                }
                copy(b.data[recordHeaderLen+explicitIVLen:], data)
                c.out.encrypt(b, explicitIVLen)
-               if _, err := c.conn.Write(b.data); err != nil {
+               if _, err := c.write(b.data); err != nil {
                        return n, err
                }
-               c.bytesSent += int64(m)
                n += m
                data = data[m:]
        }
index 475737b9893246b2d855f6dc0f6d2d40969f2776..f789e6f888f9adc219abe97d6571935675a02f97 100644 (file)
@@ -206,6 +206,7 @@ NextCipherSuite:
        hs.finishedHash.Write(hs.hello.marshal())
        hs.finishedHash.Write(hs.serverHello.marshal())
 
+       c.buffering = true
        if isResume {
                if err := hs.establishKeys(); err != nil {
                        return err
@@ -220,6 +221,9 @@ NextCipherSuite:
                if err := hs.sendFinished(c.clientFinished[:]); err != nil {
                        return err
                }
+               if _, err := c.flush(); err != nil {
+                       return err
+               }
        } else {
                if err := hs.doFullHandshake(); err != nil {
                        return err
@@ -230,6 +234,9 @@ NextCipherSuite:
                if err := hs.sendFinished(c.clientFinished[:]); err != nil {
                        return err
                }
+               if _, err := c.flush(); err != nil {
+                       return err
+               }
                c.clientFinishedIsFirst = true
                if err := hs.readSessionTicket(); err != nil {
                        return err
index 40b0770e12c6377eaa750f2ccacc983840280cd4..c5000e590738a468c60c46b514e5f968cea3de12 100644 (file)
@@ -983,7 +983,7 @@ func (b *brokenConn) Write(data []byte) (int, error) {
 
 func TestFailedWrite(t *testing.T) {
        // Test that a write error during the handshake is returned.
-       for _, breakAfter := range []int{0, 1, 2, 3} {
+       for _, breakAfter := range []int{0, 1} {
                c, s := net.Pipe()
                done := make(chan bool)
 
@@ -1003,3 +1003,45 @@ func TestFailedWrite(t *testing.T) {
                <-done
        }
 }
+
+// writeCountingConn wraps a net.Conn and counts the number of Write calls.
+type writeCountingConn struct {
+       net.Conn
+
+       // numWrites is the number of writes that have been done.
+       numWrites int
+}
+
+func (wcc *writeCountingConn) Write(data []byte) (int, error) {
+       wcc.numWrites++
+       return wcc.Conn.Write(data)
+}
+
+func TestBuffering(t *testing.T) {
+       c, s := net.Pipe()
+       done := make(chan bool)
+
+       clientWCC := &writeCountingConn{Conn: c}
+       serverWCC := &writeCountingConn{Conn: s}
+
+       go func() {
+               Server(serverWCC, testConfig).Handshake()
+               serverWCC.Close()
+               done <- true
+       }()
+
+       err := Client(clientWCC, testConfig).Handshake()
+       if err != nil {
+               t.Fatal(err)
+       }
+       clientWCC.Close()
+       <-done
+
+       if n := clientWCC.numWrites; n != 2 {
+               t.Errorf("expected client handshake to complete with only two writes, but saw %d", n)
+       }
+
+       if n := serverWCC.numWrites; n != 2 {
+               t.Errorf("expected server handshake to complete with only two writes, but saw %d", n)
+       }
+}
index cf617df19f5c0023d173fc6120718a76898de151..1aac72956142d689c64e7fe661d4fbf42e9d67d0 100644 (file)
@@ -52,6 +52,7 @@ func (c *Conn) serverHandshake() error {
        }
 
        // For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3
+       c.buffering = true
        if isResume {
                // The client has included a session ticket and so we do an abbreviated handshake.
                if err := hs.doResumeHandshake(); err != nil {
@@ -71,6 +72,9 @@ func (c *Conn) serverHandshake() error {
                if err := hs.sendFinished(c.serverFinished[:]); err != nil {
                        return err
                }
+               if _, err := c.flush(); err != nil {
+                       return err
+               }
                c.clientFinishedIsFirst = false
                if err := hs.readFinished(nil); err != nil {
                        return err
@@ -89,12 +93,16 @@ func (c *Conn) serverHandshake() error {
                        return err
                }
                c.clientFinishedIsFirst = true
+               c.buffering = true
                if err := hs.sendSessionTicket(); err != nil {
                        return err
                }
                if err := hs.sendFinished(nil); err != nil {
                        return err
                }
+               if _, err := c.flush(); err != nil {
+                       return err
+               }
        }
        c.handshakeComplete = true
 
@@ -430,6 +438,10 @@ func (hs *serverHandshakeState) doFullHandshake() error {
                return err
        }
 
+       if _, err := c.flush(); err != nil {
+               return err
+       }
+
        var pub crypto.PublicKey // public key for client auth, if any
 
        msg, err := c.readHandshake()