]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: check errors from (*Conn).writeRecord
authorTamir Duberstein <tamird@gmail.com>
Fri, 26 Feb 2016 19:17:29 +0000 (14:17 -0500)
committerAdam Langley <agl@golang.org>
Wed, 2 Mar 2016 18:20:46 +0000 (18:20 +0000)
This promotes a connection hang during TLS handshake to a proper error.
This doesn't fully address #14539 because the error reported in that
case is a write-on-socket-not-connected error, which implies that an
earlier error during connection setup is not being checked, but it is
an improvement over the current behaviour.

Updates #14539.

Change-Id: I0571a752d32d5303db48149ab448226868b19495
Reviewed-on: https://go-review.googlesource.com/19990
Reviewed-by: Adam Langley <agl@golang.org>
src/crypto/tls/conn.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_test.go

index 65b1d4b2e3aa5c215ca3b22a6eced522d16c2a75..89e4c2f74ae5908247df0e9202a22fb9017cf2e7 100644 (file)
@@ -694,12 +694,14 @@ func (c *Conn) sendAlertLocked(err alert) error {
                c.tmp[0] = alertLevelError
        }
        c.tmp[1] = byte(err)
-       c.writeRecord(recordTypeAlert, c.tmp[0:2])
-       // closeNotify is a special case in that it isn't an error:
-       if err != alertCloseNotify {
-               return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
+
+       _, writeErr := c.writeRecord(recordTypeAlert, c.tmp[0:2])
+       if err == alertCloseNotify {
+               // closeNotify is a special case in that it isn't an error.
+               return writeErr
        }
-       return nil
+
+       return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
 }
 
 // sendAlert sends a TLS alert message.
@@ -713,8 +715,11 @@ func (c *Conn) sendAlert(err alert) error {
 // writeRecord writes a TLS record with the given type and payload
 // to the connection and updates the record layer state.
 // c.out.Mutex <= L.
-func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
+func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
        b := c.out.newBlock()
+       defer c.out.freeBlock(b)
+
+       var n int
        for len(data) > 0 {
                m := len(data)
                if m > maxPlaintext {
@@ -759,34 +764,27 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
                        if explicitIVIsSeq {
                                copy(explicitIV, c.out.seq[:])
                        } else {
-                               if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
-                                       break
+                               if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil {
+                                       return n, err
                                }
                        }
                }
                copy(b.data[recordHeaderLen+explicitIVLen:], data)
                c.out.encrypt(b, explicitIVLen)
-               _, err = c.conn.Write(b.data)
-               if err != nil {
-                       break
+               if _, err := c.conn.Write(b.data); err != nil {
+                       return n, err
                }
                n += m
                data = data[m:]
        }
-       c.out.freeBlock(b)
 
        if typ == recordTypeChangeCipherSpec {
-               err = c.out.changeCipherSpec()
-               if err != nil {
-                       // Cannot call sendAlert directly,
-                       // because we already hold c.out.Mutex.
-                       c.tmp[0] = alertLevelError
-                       c.tmp[1] = byte(err.(alert))
-                       c.writeRecord(recordTypeAlert, c.tmp[0:2])
-                       return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
+               if err := c.out.changeCipherSpec(); err != nil {
+                       return n, c.sendAlertLocked(err.(alert))
                }
        }
-       return
+
+       return n, nil
 }
 
 // readHandshake reads the next handshake message from
index b1299229265ef2b6c45fa7ac3b5d281108a17bd6..d38b061edd3f0d22d2db84df7fc7b3b4a5b82e22 100644 (file)
@@ -138,7 +138,9 @@ NextCipherSuite:
                }
        }
 
-       c.writeRecord(recordTypeHandshake, hello.marshal())
+       if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
+               return err
+       }
 
        msg, err := c.readHandshake()
        if err != nil {
@@ -419,7 +421,9 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                        certMsg.certificates = chainToSend.Certificate
                }
                hs.finishedHash.Write(certMsg.marshal())
-               c.writeRecord(recordTypeHandshake, certMsg.marshal())
+               if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+                       return err
+               }
        }
 
        preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0])
@@ -429,7 +433,9 @@ func (hs *clientHandshakeState) doFullHandshake() error {
        }
        if ckx != nil {
                hs.finishedHash.Write(ckx.marshal())
-               c.writeRecord(recordTypeHandshake, ckx.marshal())
+               if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil {
+                       return err
+               }
        }
 
        if chainToSend != nil {
@@ -471,7 +477,9 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                }
 
                hs.finishedHash.Write(certVerify.marshal())
-               c.writeRecord(recordTypeHandshake, certVerify.marshal())
+               if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil {
+                       return err
+               }
        }
 
        hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random)
@@ -615,7 +623,9 @@ func (hs *clientHandshakeState) readSessionTicket() error {
 func (hs *clientHandshakeState) sendFinished(out []byte) error {
        c := hs.c
 
-       c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
+       if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
+               return err
+       }
        if hs.serverHello.nextProtoNeg {
                nextProto := new(nextProtoMsg)
                proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
@@ -624,13 +634,17 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error {
                c.clientProtocolFallback = fallback
 
                hs.finishedHash.Write(nextProto.marshal())
-               c.writeRecord(recordTypeHandshake, nextProto.marshal())
+               if _, err := c.writeRecord(recordTypeHandshake, nextProto.marshal()); err != nil {
+                       return err
+               }
        }
 
        finished := new(finishedMsg)
        finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
        hs.finishedHash.Write(finished.marshal())
-       c.writeRecord(recordTypeHandshake, finished.marshal())
+       if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+               return err
+       }
        copy(out, finished.verifyData)
        return nil
 }
index 9b6c4328a52a6c03a3c0962c7eb1fcfd62de598e..322c64e4614f9bd62a0230b0da540b7a7d314f00 100644 (file)
@@ -12,6 +12,7 @@ import (
        "encoding/base64"
        "encoding/binary"
        "encoding/pem"
+       "errors"
        "fmt"
        "io"
        "net"
@@ -725,3 +726,51 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
                t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
        }
 }
+
+// brokenConn wraps a net.Conn and causes all Writes after a certain number to
+// fail with brokenConnErr.
+type brokenConn struct {
+       net.Conn
+
+       // breakAfter is the number of successful writes that will be allowed
+       // before all subsequent writes fail.
+       breakAfter int
+
+       // numWrites is the number of writes that have been done.
+       numWrites int
+}
+
+// brokenConnErr is the error that brokenConn returns once exhausted.
+var brokenConnErr = errors.New("too many writes to brokenConn")
+
+func (b *brokenConn) Write(data []byte) (int, error) {
+       if b.numWrites >= b.breakAfter {
+               return 0, brokenConnErr
+       }
+
+       b.numWrites++
+       return b.Conn.Write(data)
+}
+
+func TestFailedWrite(t *testing.T) {
+       // Test that a write error during the handshake is returned.
+       for _, breakAfter := range []int{0, 1, 2, 3} {
+               c, s := net.Pipe()
+               done := make(chan bool)
+
+               go func() {
+                       Server(s, testConfig).Handshake()
+                       s.Close()
+                       done <- true
+               }()
+
+               brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
+               err := Client(brokenC, testConfig).Handshake()
+               if err != brokenConnErr {
+                       t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
+               }
+               brokenC.Close()
+
+               <-done
+       }
+}
index dbab60b6bdbfb32588fde699f85053ec9351c12d..facc17d94e2795eb29079497d383779659505cca 100644 (file)
@@ -322,7 +322,9 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
        hs.finishedHash.discardHandshakeBuffer()
        hs.finishedHash.Write(hs.clientHello.marshal())
        hs.finishedHash.Write(hs.hello.marshal())
-       c.writeRecord(recordTypeHandshake, hs.hello.marshal())
+       if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+               return err
+       }
 
        if len(hs.sessionState.certificates) > 0 {
                if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil {
@@ -354,19 +356,25 @@ func (hs *serverHandshakeState) doFullHandshake() error {
        }
        hs.finishedHash.Write(hs.clientHello.marshal())
        hs.finishedHash.Write(hs.hello.marshal())
-       c.writeRecord(recordTypeHandshake, hs.hello.marshal())
+       if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+               return err
+       }
 
        certMsg := new(certificateMsg)
        certMsg.certificates = hs.cert.Certificate
        hs.finishedHash.Write(certMsg.marshal())
-       c.writeRecord(recordTypeHandshake, certMsg.marshal())
+       if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+               return err
+       }
 
        if hs.hello.ocspStapling {
                certStatus := new(certificateStatusMsg)
                certStatus.statusType = statusTypeOCSP
                certStatus.response = hs.cert.OCSPStaple
                hs.finishedHash.Write(certStatus.marshal())
-               c.writeRecord(recordTypeHandshake, certStatus.marshal())
+               if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
+                       return err
+               }
        }
 
        keyAgreement := hs.suite.ka(c.vers)
@@ -377,7 +385,9 @@ func (hs *serverHandshakeState) doFullHandshake() error {
        }
        if skx != nil {
                hs.finishedHash.Write(skx.marshal())
-               c.writeRecord(recordTypeHandshake, skx.marshal())
+               if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil {
+                       return err
+               }
        }
 
        if config.ClientAuth >= RequestClientCert {
@@ -401,12 +411,16 @@ func (hs *serverHandshakeState) doFullHandshake() error {
                        certReq.certificateAuthorities = config.ClientCAs.Subjects()
                }
                hs.finishedHash.Write(certReq.marshal())
-               c.writeRecord(recordTypeHandshake, certReq.marshal())
+               if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
+                       return err
+               }
        }
 
        helloDone := new(serverHelloDoneMsg)
        hs.finishedHash.Write(helloDone.marshal())
-       c.writeRecord(recordTypeHandshake, helloDone.marshal())
+       if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil {
+               return err
+       }
 
        var pub crypto.PublicKey // public key for client auth, if any
 
@@ -632,7 +646,9 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
        }
 
        hs.finishedHash.Write(m.marshal())
-       c.writeRecord(recordTypeHandshake, m.marshal())
+       if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
+               return err
+       }
 
        return nil
 }
@@ -640,12 +656,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
 func (hs *serverHandshakeState) sendFinished(out []byte) error {
        c := hs.c
 
-       c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
+       if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
+               return err
+       }
 
        finished := new(finishedMsg)
        finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
        hs.finishedHash.Write(finished.marshal())
-       c.writeRecord(recordTypeHandshake, finished.marshal())
+       if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+               return err
+       }
 
        c.cipherSuite = hs.suite.id
        copy(out, finished.verifyData)
index f8de4e4551e34be441cd4bd467e87842ef6a918b..afadd62b367e5b2684d537b75bb635c9ea50bc0b 100644 (file)
@@ -80,7 +80,10 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
                cli.writeRecord(recordTypeHandshake, m.marshal())
                c.Close()
        }()
-       err := Server(s, serverConfig).Handshake()
+       hs := serverHandshakeState{
+               c: Server(s, serverConfig),
+       }
+       _, err := hs.readClientHello()
        s.Close()
        if len(expectedSubStr) == 0 {
                if err != nil && err != io.EOF {