This promotes a connection hang during TLS handshake to a proper error.
This doesn't fully address #14539 because the error reported in that
case is a write-on-socket-not-connected error, which implies that an
earlier error during connection setup is not being checked, but it is
an improvement over the current behaviour.
Updates #14539.
Change-Id: I0571a752d32d5303db48149ab448226868b19495
Reviewed-on: https://go-review.googlesource.com/19990
Reviewed-by: Adam Langley <agl@golang.org>
c.tmp[0] = alertLevelError
}
c.tmp[1] = byte(err)
- c.writeRecord(recordTypeAlert, c.tmp[0:2])
- // closeNotify is a special case in that it isn't an error:
- if err != alertCloseNotify {
- return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
+
+ _, writeErr := c.writeRecord(recordTypeAlert, c.tmp[0:2])
+ if err == alertCloseNotify {
+ // closeNotify is a special case in that it isn't an error.
+ return writeErr
}
- return nil
+
+ return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
// sendAlert sends a TLS alert message.
// writeRecord writes a TLS record with the given type and payload
// to the connection and updates the record layer state.
// c.out.Mutex <= L.
-func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
+func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
b := c.out.newBlock()
+ defer c.out.freeBlock(b)
+
+ var n int
for len(data) > 0 {
m := len(data)
if m > maxPlaintext {
if explicitIVIsSeq {
copy(explicitIV, c.out.seq[:])
} else {
- if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
- break
+ if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil {
+ return n, err
}
}
}
copy(b.data[recordHeaderLen+explicitIVLen:], data)
c.out.encrypt(b, explicitIVLen)
- _, err = c.conn.Write(b.data)
- if err != nil {
- break
+ if _, err := c.conn.Write(b.data); err != nil {
+ return n, err
}
n += m
data = data[m:]
}
- c.out.freeBlock(b)
if typ == recordTypeChangeCipherSpec {
- err = c.out.changeCipherSpec()
- if err != nil {
- // Cannot call sendAlert directly,
- // because we already hold c.out.Mutex.
- c.tmp[0] = alertLevelError
- c.tmp[1] = byte(err.(alert))
- c.writeRecord(recordTypeAlert, c.tmp[0:2])
- return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
+ if err := c.out.changeCipherSpec(); err != nil {
+ return n, c.sendAlertLocked(err.(alert))
}
}
- return
+
+ return n, nil
}
// readHandshake reads the next handshake message from
}
}
- c.writeRecord(recordTypeHandshake, hello.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
+ return err
+ }
msg, err := c.readHandshake()
if err != nil {
certMsg.certificates = chainToSend.Certificate
}
hs.finishedHash.Write(certMsg.marshal())
- c.writeRecord(recordTypeHandshake, certMsg.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+ return err
+ }
}
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0])
}
if ckx != nil {
hs.finishedHash.Write(ckx.marshal())
- c.writeRecord(recordTypeHandshake, ckx.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil {
+ return err
+ }
}
if chainToSend != nil {
}
hs.finishedHash.Write(certVerify.marshal())
- c.writeRecord(recordTypeHandshake, certVerify.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil {
+ return err
+ }
}
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random)
func (hs *clientHandshakeState) sendFinished(out []byte) error {
c := hs.c
- c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
+ if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
+ return err
+ }
if hs.serverHello.nextProtoNeg {
nextProto := new(nextProtoMsg)
proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
c.clientProtocolFallback = fallback
hs.finishedHash.Write(nextProto.marshal())
- c.writeRecord(recordTypeHandshake, nextProto.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, nextProto.marshal()); err != nil {
+ return err
+ }
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
hs.finishedHash.Write(finished.marshal())
- c.writeRecord(recordTypeHandshake, finished.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+ return err
+ }
copy(out, finished.verifyData)
return nil
}
"encoding/base64"
"encoding/binary"
"encoding/pem"
+ "errors"
"fmt"
"io"
"net"
t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
}
}
+
+// brokenConn wraps a net.Conn and causes all Writes after a certain number to
+// fail with brokenConnErr.
+type brokenConn struct {
+ net.Conn
+
+ // breakAfter is the number of successful writes that will be allowed
+ // before all subsequent writes fail.
+ breakAfter int
+
+ // numWrites is the number of writes that have been done.
+ numWrites int
+}
+
+// brokenConnErr is the error that brokenConn returns once exhausted.
+var brokenConnErr = errors.New("too many writes to brokenConn")
+
+func (b *brokenConn) Write(data []byte) (int, error) {
+ if b.numWrites >= b.breakAfter {
+ return 0, brokenConnErr
+ }
+
+ b.numWrites++
+ return b.Conn.Write(data)
+}
+
+func TestFailedWrite(t *testing.T) {
+ // Test that a write error during the handshake is returned.
+ for _, breakAfter := range []int{0, 1, 2, 3} {
+ c, s := net.Pipe()
+ done := make(chan bool)
+
+ go func() {
+ Server(s, testConfig).Handshake()
+ s.Close()
+ done <- true
+ }()
+
+ brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
+ err := Client(brokenC, testConfig).Handshake()
+ if err != brokenConnErr {
+ t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
+ }
+ brokenC.Close()
+
+ <-done
+ }
+}
hs.finishedHash.discardHandshakeBuffer()
hs.finishedHash.Write(hs.clientHello.marshal())
hs.finishedHash.Write(hs.hello.marshal())
- c.writeRecord(recordTypeHandshake, hs.hello.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+ return err
+ }
if len(hs.sessionState.certificates) > 0 {
if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil {
}
hs.finishedHash.Write(hs.clientHello.marshal())
hs.finishedHash.Write(hs.hello.marshal())
- c.writeRecord(recordTypeHandshake, hs.hello.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+ return err
+ }
certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate
hs.finishedHash.Write(certMsg.marshal())
- c.writeRecord(recordTypeHandshake, certMsg.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+ return err
+ }
if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg)
certStatus.statusType = statusTypeOCSP
certStatus.response = hs.cert.OCSPStaple
hs.finishedHash.Write(certStatus.marshal())
- c.writeRecord(recordTypeHandshake, certStatus.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
+ return err
+ }
}
keyAgreement := hs.suite.ka(c.vers)
}
if skx != nil {
hs.finishedHash.Write(skx.marshal())
- c.writeRecord(recordTypeHandshake, skx.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil {
+ return err
+ }
}
if config.ClientAuth >= RequestClientCert {
certReq.certificateAuthorities = config.ClientCAs.Subjects()
}
hs.finishedHash.Write(certReq.marshal())
- c.writeRecord(recordTypeHandshake, certReq.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
+ return err
+ }
}
helloDone := new(serverHelloDoneMsg)
hs.finishedHash.Write(helloDone.marshal())
- c.writeRecord(recordTypeHandshake, helloDone.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil {
+ return err
+ }
var pub crypto.PublicKey // public key for client auth, if any
}
hs.finishedHash.Write(m.marshal())
- c.writeRecord(recordTypeHandshake, m.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
+ return err
+ }
return nil
}
func (hs *serverHandshakeState) sendFinished(out []byte) error {
c := hs.c
- c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
+ if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
+ return err
+ }
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
hs.finishedHash.Write(finished.marshal())
- c.writeRecord(recordTypeHandshake, finished.marshal())
+ if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+ return err
+ }
c.cipherSuite = hs.suite.id
copy(out, finished.verifyData)
cli.writeRecord(recordTypeHandshake, m.marshal())
c.Close()
}()
- err := Server(s, serverConfig).Handshake()
+ hs := serverHandshakeState{
+ c: Server(s, serverConfig),
+ }
+ _, err := hs.readClientHello()
s.Close()
if len(expectedSubStr) == 0 {
if err != nil && err != io.EOF {