From 6b29f7bfbe9ca985ce2419285f0b56b5428e1ffe Mon Sep 17 00:00:00 2001 From: Adam Langley Date: Wed, 12 Feb 2014 11:20:01 -0500 Subject: [PATCH] crypto/tls: better error messages. LGTM=bradfitz R=golang-codereviews, bradfitz CC=golang-codereviews https://golang.org/cl/60580046 --- src/pkg/crypto/tls/common.go | 5 +++ src/pkg/crypto/tls/conn.go | 23 ++++++++---- src/pkg/crypto/tls/handshake_client.go | 41 ++++++++++++++------- src/pkg/crypto/tls/handshake_server.go | 40 +++++++++++++------- src/pkg/crypto/tls/handshake_server_test.go | 15 ++++---- src/pkg/crypto/tls/key_agreement.go | 15 ++++---- 6 files changed, 90 insertions(+), 49 deletions(-) diff --git a/src/pkg/crypto/tls/common.go b/src/pkg/crypto/tls/common.go index 66a36de1e0..3382853ee6 100644 --- a/src/pkg/crypto/tls/common.go +++ b/src/pkg/crypto/tls/common.go @@ -9,6 +9,7 @@ import ( "crypto" "crypto/rand" "crypto/x509" + "fmt" "io" "math/big" "strings" @@ -540,3 +541,7 @@ func initDefaultCipherSuites() { varDefaultCipherSuites[i] = suite.id } } + +func unexpectedMessageError(wanted, got interface{}) error { + return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) +} diff --git a/src/pkg/crypto/tls/conn.go b/src/pkg/crypto/tls/conn.go index 17421ee8e2..c33549c9ef 100644 --- a/src/pkg/crypto/tls/conn.go +++ b/src/pkg/crypto/tls/conn.go @@ -12,6 +12,7 @@ import ( "crypto/subtle" "crypto/x509" "errors" + "fmt" "io" "net" "sync" @@ -518,14 +519,17 @@ func (c *Conn) readRecord(want recordType) error { // else application data. (We don't support renegotiation.) switch want { default: - return c.sendAlert(alertInternalError) + c.sendAlert(alertInternalError) + return errors.New("tls: unknown record type requested") case recordTypeHandshake, recordTypeChangeCipherSpec: if c.handshakeComplete { - return c.sendAlert(alertInternalError) + c.sendAlert(alertInternalError) + return errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete") } case recordTypeApplicationData: if !c.handshakeComplete { - return c.sendAlert(alertInternalError) + c.sendAlert(alertInternalError) + return errors.New("tls: application data record requested before handshake complete") } } @@ -562,10 +566,12 @@ Again: vers := uint16(b.data[1])<<8 | uint16(b.data[2]) n := int(b.data[3])<<8 | int(b.data[4]) if c.haveVers && vers != c.vers { - return c.sendAlert(alertProtocolVersion) + c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers) } if n > maxCiphertext { - return c.sendAlert(alertRecordOverflow) + c.sendAlert(alertRecordOverflow) + return fmt.Errorf("tls: oversized record received with length %d", n) } if !c.haveVers { // First message, be extra suspicious: @@ -577,7 +583,8 @@ Again: // well under a kilobyte. If the length is >= 12 kB, // it's probably not real. if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return fmt.Errorf("tls: first record does not look like a TLS handshake") } } if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil { @@ -990,10 +997,10 @@ func (c *Conn) VerifyHostname(host string) error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() if !c.isClient { - return errors.New("VerifyHostname called on TLS server connection") + return errors.New("tls: VerifyHostname called on TLS server connection") } if !c.handshakeComplete { - return errors.New("TLS handshake has not yet been performed") + return errors.New("tls: handshake has not yet been performed") } return c.peerCertificates[0].VerifyHostname(host) } diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go index dbbccfee46..fd1303eebb 100644 --- a/src/pkg/crypto/tls/handshake_client.go +++ b/src/pkg/crypto/tls/handshake_client.go @@ -12,6 +12,7 @@ import ( "crypto/x509" "encoding/asn1" "errors" + "fmt" "io" "net" "strconv" @@ -126,20 +127,23 @@ NextCipherSuite: } serverHello, ok := msg.(*serverHelloMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) } vers, ok := c.config.mutualVersion(serverHello.vers) if !ok || vers < VersionTLS10 { // TLS 1.0 is the minimum version supported as a client. - return c.sendAlert(alertProtocolVersion) + c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers) } c.vers = vers c.haveVers = true suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite) if suite == nil { - return c.sendAlert(alertHandshakeFailure) + c.sendAlert(alertHandshakeFailure) + return fmt.Errorf("tls: server selected an unsupported cipher suite") } hs := &clientHandshakeState{ @@ -209,7 +213,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { } certMsg, ok := msg.(*certificateMsg) if !ok || len(certMsg.certificates) == 0 { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) } hs.finishedHash.Write(certMsg.marshal()) @@ -218,7 +223,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { cert, err := x509.ParseCertificate(asn1Data) if err != nil { c.sendAlert(alertBadCertificate) - return errors.New("failed to parse certificate from server: " + err.Error()) + return errors.New("tls: failed to parse certificate from server: " + err.Error()) } certs[i] = cert } @@ -248,7 +253,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { case *rsa.PublicKey, *ecdsa.PublicKey: break default: - return c.sendAlert(alertUnsupportedCertificate) + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) } c.peerCertificates = certs @@ -260,7 +266,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { } cs, ok := msg.(*certificateStatusMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(cs, msg) } hs.finishedHash.Write(cs.marshal()) @@ -371,7 +378,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { shd, ok := msg.(*serverHelloDoneMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(shd, msg) } hs.finishedHash.Write(shd.marshal()) @@ -421,7 +429,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { err = errors.New("unknown private key type") } if err != nil { - return c.sendAlert(alertInternalError) + c.sendAlert(alertInternalError) + return errors.New("tls: failed to sign handshake with client certificate: " + err.Error()) } certVerify.signature = signed @@ -466,12 +475,13 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { c := hs.c if hs.serverHello.compressionMethod != compressionNone { - return false, c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return false, errors.New("tls: server selected unsupported compression format") } if !hs.hello.nextProtoNeg && hs.serverHello.nextProtoNeg { c.sendAlert(alertHandshakeFailure) - return false, errors.New("server advertised unrequested NPN") + return false, errors.New("server advertised unrequested NPN extension") } if hs.serverResumedSession() { @@ -497,13 +507,15 @@ func (hs *clientHandshakeState) readFinished() error { } serverFinished, ok := msg.(*finishedMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) } verify := hs.finishedHash.serverSum(hs.masterSecret) if len(verify) != len(serverFinished.verifyData) || subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { - return c.sendAlert(alertHandshakeFailure) + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server's Finished message was incorrect") } hs.finishedHash.Write(serverFinished.marshal()) return nil @@ -521,7 +533,8 @@ func (hs *clientHandshakeState) readSessionTicket() error { } sessionTicketMsg, ok := msg.(*newSessionTicketMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(sessionTicketMsg, msg) } hs.finishedHash.Write(sessionTicketMsg.marshal()) diff --git a/src/pkg/crypto/tls/handshake_server.go b/src/pkg/crypto/tls/handshake_server.go index e441ccbcce..12e5ff1e58 100644 --- a/src/pkg/crypto/tls/handshake_server.go +++ b/src/pkg/crypto/tls/handshake_server.go @@ -12,6 +12,7 @@ import ( "crypto/x509" "encoding/asn1" "errors" + "fmt" "io" ) @@ -100,11 +101,13 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { var ok bool hs.clientHello, ok = msg.(*clientHelloMsg) if !ok { - return false, c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return false, unexpectedMessageError(hs.clientHello, msg) } c.vers, ok = config.mutualVersion(hs.clientHello.vers) if !ok { - return false, c.sendAlert(alertProtocolVersion) + c.sendAlert(alertProtocolVersion) + return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) } c.haveVers = true @@ -142,14 +145,16 @@ Curves: } if !foundCompression { - return false, c.sendAlert(alertHandshakeFailure) + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: client does not support uncompressed connections") } hs.hello.vers = c.vers hs.hello.random = make([]byte, 32) _, err = io.ReadFull(config.rand(), hs.hello.random) if err != nil { - return false, c.sendAlert(alertInternalError) + c.sendAlert(alertInternalError) + return false, err } hs.hello.secureRenegotiation = hs.clientHello.secureRenegotiation hs.hello.compressionMethod = compressionNone @@ -166,7 +171,8 @@ Curves: } if len(config.Certificates) == 0 { - return false, c.sendAlert(alertInternalError) + c.sendAlert(alertInternalError) + return false, errors.New("tls: no certificates configured") } hs.cert = &config.Certificates[0] if len(hs.clientHello.serverName) > 0 { @@ -195,7 +201,8 @@ Curves: } if hs.suite == nil { - return false, c.sendAlert(alertHandshakeFailure) + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: no cipher suite supported by both client and server") } return false, nil @@ -345,7 +352,8 @@ func (hs *serverHandshakeState) doFullHandshake() error { // certificate message, even if it's empty. if config.ClientAuth >= RequestClientCert { if certMsg, ok = msg.(*certificateMsg); !ok { - return c.sendAlert(alertHandshakeFailure) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) } hs.finishedHash.Write(certMsg.marshal()) @@ -372,7 +380,8 @@ func (hs *serverHandshakeState) doFullHandshake() error { // Get client key exchange ckx, ok := msg.(*clientKeyExchangeMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(ckx, msg) } hs.finishedHash.Write(ckx.marshal()) @@ -389,7 +398,8 @@ func (hs *serverHandshakeState) doFullHandshake() error { } certVerify, ok := msg.(*certificateVerifyMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) } switch key := pub.(type) { @@ -469,7 +479,8 @@ func (hs *serverHandshakeState) readFinished() error { } nextProto, ok := msg.(*nextProtoMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(nextProto, msg) } hs.finishedHash.Write(nextProto.marshal()) c.clientProtocol = nextProto.proto @@ -481,13 +492,15 @@ func (hs *serverHandshakeState) readFinished() error { } clientFinished, ok := msg.(*finishedMsg) if !ok { - return c.sendAlert(alertUnexpectedMessage) + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientFinished, msg) } verify := hs.finishedHash.clientSum(hs.masterSecret) if len(verify) != len(clientFinished.verifyData) || subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { - return c.sendAlert(alertHandshakeFailure) + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client's Finished message is incorrect") } hs.finishedHash.Write(clientFinished.marshal()) @@ -590,7 +603,8 @@ func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (c case *ecdsa.PublicKey, *rsa.PublicKey: pub = key default: - return nil, c.sendAlert(alertUnsupportedCertificate) + c.sendAlert(alertUnsupportedCertificate) + return nil, fmt.Errorf("tls: client's certificate contains an unsupported public key of type %T", certs[0].PublicKey) } c.peerCertificates = certs return pub, nil diff --git a/src/pkg/crypto/tls/handshake_server_test.go b/src/pkg/crypto/tls/handshake_server_test.go index a8cf462c70..4f41ab9b78 100644 --- a/src/pkg/crypto/tls/handshake_server_test.go +++ b/src/pkg/crypto/tls/handshake_server_test.go @@ -20,6 +20,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "testing" "time" ) @@ -53,7 +54,7 @@ func init() { testConfig.BuildNameToCertificate() } -func testClientHelloFailure(t *testing.T, m handshakeMessage, expected error) { +func testClientHelloFailure(t *testing.T, m handshakeMessage, expectedSubStr string) { // Create in-memory network connection, // send message to server. Should return // expected error. @@ -68,20 +69,20 @@ func testClientHelloFailure(t *testing.T, m handshakeMessage, expected error) { }() err := Server(s, testConfig).Handshake() s.Close() - if e, ok := err.(*net.OpError); !ok || e.Err != expected { - t.Errorf("Got error: %s; expected: %s", err, expected) + if err == nil || !strings.Contains(err.Error(), expectedSubStr) { + t.Errorf("Got error: %s; expected to match substring '%s'", err, expectedSubStr) } } func TestSimpleError(t *testing.T) { - testClientHelloFailure(t, &serverHelloDoneMsg{}, alertUnexpectedMessage) + testClientHelloFailure(t, &serverHelloDoneMsg{}, "unexpected handshake message") } var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x0205} func TestRejectBadProtocolVersion(t *testing.T) { for _, v := range badProtocolVersions { - testClientHelloFailure(t, &clientHelloMsg{vers: v}, alertProtocolVersion) + testClientHelloFailure(t, &clientHelloMsg{vers: v}, "unsupported, maximum protocol version") } } @@ -91,7 +92,7 @@ func TestNoSuiteOverlap(t *testing.T) { cipherSuites: []uint16{0xff00}, compressionMethods: []uint8{0}, } - testClientHelloFailure(t, clientHello, alertHandshakeFailure) + testClientHelloFailure(t, clientHello, "no cipher suite supported by both client and server") } func TestNoCompressionOverlap(t *testing.T) { @@ -100,7 +101,7 @@ func TestNoCompressionOverlap(t *testing.T) { cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, compressionMethods: []uint8{0xff}, } - testClientHelloFailure(t, clientHello, alertHandshakeFailure) + testClientHelloFailure(t, clientHello, "client does not support uncompressed connections") } func TestTLS12OnlyCipherSuites(t *testing.T) { diff --git a/src/pkg/crypto/tls/key_agreement.go b/src/pkg/crypto/tls/key_agreement.go index 7e820c1e7e..861faf0e85 100644 --- a/src/pkg/crypto/tls/key_agreement.go +++ b/src/pkg/crypto/tls/key_agreement.go @@ -19,6 +19,9 @@ import ( "math/big" ) +var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") +var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") + // rsaKeyAgreement implements the standard TLS key agreement where the client // encrypts the pre-master secret to the server's public key. type rsaKeyAgreement struct{} @@ -35,14 +38,14 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certifi } if len(ckx.ciphertext) < 2 { - return nil, errors.New("bad ClientKeyExchange") + return nil, errClientKeyExchange } ciphertext := ckx.ciphertext if version != VersionSSL30 { ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) if ciphertextLen != len(ckx.ciphertext)-2 { - return nil, errors.New("bad ClientKeyExchange") + return nil, errClientKeyExchange } ciphertext = ckx.ciphertext[2:] } @@ -61,7 +64,7 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certifi } func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { - return errors.New("unexpected ServerKeyExchange") + return errors.New("tls: unexpected ServerKeyExchange") } func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { @@ -271,11 +274,11 @@ Curve: func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { - return nil, errors.New("bad ClientKeyExchange") + return nil, errClientKeyExchange } x, y := elliptic.Unmarshal(ka.curve, ckx.ciphertext[1:]) if x == nil { - return nil, errors.New("bad ClientKeyExchange") + return nil, errClientKeyExchange } x, _ = ka.curve.ScalarMult(x, y, ka.privateKey) preMasterSecret := make([]byte, (ka.curve.Params().BitSize+7)>>3) @@ -285,8 +288,6 @@ func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Cert return preMasterSecret, nil } -var errServerKeyExchange = errors.New("invalid ServerKeyExchange") - func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { if len(skx.key) < 4 { return errServerKeyExchange -- 2.48.1