]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: better error messages.
authorAdam Langley <agl@golang.org>
Wed, 12 Feb 2014 16:20:01 +0000 (11:20 -0500)
committerAdam Langley <agl@golang.org>
Wed, 12 Feb 2014 16:20:01 +0000 (11:20 -0500)
LGTM=bradfitz
R=golang-codereviews, bradfitz
CC=golang-codereviews
https://golang.org/cl/60580046

src/pkg/crypto/tls/common.go
src/pkg/crypto/tls/conn.go
src/pkg/crypto/tls/handshake_client.go
src/pkg/crypto/tls/handshake_server.go
src/pkg/crypto/tls/handshake_server_test.go
src/pkg/crypto/tls/key_agreement.go

index 66a36de1e0eff234646fe12c1c504b03b750e734..3382853ee660a37fd43db1dbf273d7205497a6a2 100644 (file)
@@ -9,6 +9,7 @@ import (
        "crypto"
        "crypto/rand"
        "crypto/x509"
+       "fmt"
        "io"
        "math/big"
        "strings"
@@ -540,3 +541,7 @@ func initDefaultCipherSuites() {
                varDefaultCipherSuites[i] = suite.id
        }
 }
+
+func unexpectedMessageError(wanted, got interface{}) error {
+       return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted)
+}
index 17421ee8e29f928bceca7c881cfb236c0f451860..c33549c9efffa3ee83d17b6f6d21e0eaef213483 100644 (file)
@@ -12,6 +12,7 @@ import (
        "crypto/subtle"
        "crypto/x509"
        "errors"
+       "fmt"
        "io"
        "net"
        "sync"
@@ -518,14 +519,17 @@ func (c *Conn) readRecord(want recordType) error {
        // else application data.  (We don't support renegotiation.)
        switch want {
        default:
-               return c.sendAlert(alertInternalError)
+               c.sendAlert(alertInternalError)
+               return errors.New("tls: unknown record type requested")
        case recordTypeHandshake, recordTypeChangeCipherSpec:
                if c.handshakeComplete {
-                       return c.sendAlert(alertInternalError)
+                       c.sendAlert(alertInternalError)
+                       return errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete")
                }
        case recordTypeApplicationData:
                if !c.handshakeComplete {
-                       return c.sendAlert(alertInternalError)
+                       c.sendAlert(alertInternalError)
+                       return errors.New("tls: application data record requested before handshake complete")
                }
        }
 
@@ -562,10 +566,12 @@ Again:
        vers := uint16(b.data[1])<<8 | uint16(b.data[2])
        n := int(b.data[3])<<8 | int(b.data[4])
        if c.haveVers && vers != c.vers {
-               return c.sendAlert(alertProtocolVersion)
+               c.sendAlert(alertProtocolVersion)
+               return fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers)
        }
        if n > maxCiphertext {
-               return c.sendAlert(alertRecordOverflow)
+               c.sendAlert(alertRecordOverflow)
+               return fmt.Errorf("tls: oversized record received with length %d", n)
        }
        if !c.haveVers {
                // First message, be extra suspicious:
@@ -577,7 +583,8 @@ Again:
                // well under a kilobyte.  If the length is >= 12 kB,
                // it's probably not real.
                if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
-                       return c.sendAlert(alertUnexpectedMessage)
+                       c.sendAlert(alertUnexpectedMessage)
+                       return fmt.Errorf("tls: first record does not look like a TLS handshake")
                }
        }
        if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
@@ -990,10 +997,10 @@ func (c *Conn) VerifyHostname(host string) error {
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
        if !c.isClient {
-               return errors.New("VerifyHostname called on TLS server connection")
+               return errors.New("tls: VerifyHostname called on TLS server connection")
        }
        if !c.handshakeComplete {
-               return errors.New("TLS handshake has not yet been performed")
+               return errors.New("tls: handshake has not yet been performed")
        }
        return c.peerCertificates[0].VerifyHostname(host)
 }
index dbbccfee46de4f691cf927a6b05253830a2821f1..fd1303eebb947ff7d3a35ad46530c44f6460a05c 100644 (file)
@@ -12,6 +12,7 @@ import (
        "crypto/x509"
        "encoding/asn1"
        "errors"
+       "fmt"
        "io"
        "net"
        "strconv"
@@ -126,20 +127,23 @@ NextCipherSuite:
        }
        serverHello, ok := msg.(*serverHelloMsg)
        if !ok {
-               return c.sendAlert(alertUnexpectedMessage)
+               c.sendAlert(alertUnexpectedMessage)
+               return unexpectedMessageError(serverHello, msg)
        }
 
        vers, ok := c.config.mutualVersion(serverHello.vers)
        if !ok || vers < VersionTLS10 {
                // TLS 1.0 is the minimum version supported as a client.
-               return c.sendAlert(alertProtocolVersion)
+               c.sendAlert(alertProtocolVersion)
+               return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers)
        }
        c.vers = vers
        c.haveVers = true
 
        suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
        if suite == nil {
-               return c.sendAlert(alertHandshakeFailure)
+               c.sendAlert(alertHandshakeFailure)
+               return fmt.Errorf("tls: server selected an unsupported cipher suite")
        }
 
        hs := &clientHandshakeState{
@@ -209,7 +213,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
        }
        certMsg, ok := msg.(*certificateMsg)
        if !ok || len(certMsg.certificates) == 0 {
-               return c.sendAlert(alertUnexpectedMessage)
+               c.sendAlert(alertUnexpectedMessage)
+               return unexpectedMessageError(certMsg, msg)
        }
        hs.finishedHash.Write(certMsg.marshal())
 
@@ -218,7 +223,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                cert, err := x509.ParseCertificate(asn1Data)
                if err != nil {
                        c.sendAlert(alertBadCertificate)
-                       return errors.New("failed to parse certificate from server: " + err.Error())
+                       return errors.New("tls: failed to parse certificate from server: " + err.Error())
                }
                certs[i] = cert
        }
@@ -248,7 +253,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
        case *rsa.PublicKey, *ecdsa.PublicKey:
                break
        default:
-               return c.sendAlert(alertUnsupportedCertificate)
+               c.sendAlert(alertUnsupportedCertificate)
+               return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey)
        }
 
        c.peerCertificates = certs
@@ -260,7 +266,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                }
                cs, ok := msg.(*certificateStatusMsg)
                if !ok {
-                       return c.sendAlert(alertUnexpectedMessage)
+                       c.sendAlert(alertUnexpectedMessage)
+                       return unexpectedMessageError(cs, msg)
                }
                hs.finishedHash.Write(cs.marshal())
 
@@ -371,7 +378,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
 
        shd, ok := msg.(*serverHelloDoneMsg)
        if !ok {
-               return c.sendAlert(alertUnexpectedMessage)
+               c.sendAlert(alertUnexpectedMessage)
+               return unexpectedMessageError(shd, msg)
        }
        hs.finishedHash.Write(shd.marshal())
 
@@ -421,7 +429,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
                        err = errors.New("unknown private key type")
                }
                if err != nil {
-                       return c.sendAlert(alertInternalError)
+                       c.sendAlert(alertInternalError)
+                       return errors.New("tls: failed to sign handshake with client certificate: " + err.Error())
                }
                certVerify.signature = signed
 
@@ -466,12 +475,13 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
        c := hs.c
 
        if hs.serverHello.compressionMethod != compressionNone {
-               return false, c.sendAlert(alertUnexpectedMessage)
+               c.sendAlert(alertUnexpectedMessage)
+               return false, errors.New("tls: server selected unsupported compression format")
        }
 
        if !hs.hello.nextProtoNeg && hs.serverHello.nextProtoNeg {
                c.sendAlert(alertHandshakeFailure)
-               return false, errors.New("server advertised unrequested NPN")
+               return false, errors.New("server advertised unrequested NPN extension")
        }
 
        if hs.serverResumedSession() {
@@ -497,13 +507,15 @@ func (hs *clientHandshakeState) readFinished() error {
        }
        serverFinished, ok := msg.(*finishedMsg)
        if !ok {
-               return c.sendAlert(alertUnexpectedMessage)
+               c.sendAlert(alertUnexpectedMessage)
+               return unexpectedMessageError(serverFinished, msg)
        }
 
        verify := hs.finishedHash.serverSum(hs.masterSecret)
        if len(verify) != len(serverFinished.verifyData) ||
                subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
-               return c.sendAlert(alertHandshakeFailure)
+               c.sendAlert(alertHandshakeFailure)
+               return errors.New("tls: server's Finished message was incorrect")
        }
        hs.finishedHash.Write(serverFinished.marshal())
        return nil
@@ -521,7 +533,8 @@ func (hs *clientHandshakeState) readSessionTicket() error {
        }
        sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
        if !ok {
-               return c.sendAlert(alertUnexpectedMessage)
+               c.sendAlert(alertUnexpectedMessage)
+               return unexpectedMessageError(sessionTicketMsg, msg)
        }
        hs.finishedHash.Write(sessionTicketMsg.marshal())
 
index e441ccbcce57746b2d060aa7ee6775169e5a659e..12e5ff1e5895b9052b2751e043a51c8db7e39a78 100644 (file)
@@ -12,6 +12,7 @@ import (
        "crypto/x509"
        "encoding/asn1"
        "errors"
+       "fmt"
        "io"
 )
 
@@ -100,11 +101,13 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
        var ok bool
        hs.clientHello, ok = msg.(*clientHelloMsg)
        if !ok {
-               return false, c.sendAlert(alertUnexpectedMessage)
+               c.sendAlert(alertUnexpectedMessage)
+               return false, unexpectedMessageError(hs.clientHello, msg)
        }
        c.vers, ok = config.mutualVersion(hs.clientHello.vers)
        if !ok {
-               return false, c.sendAlert(alertProtocolVersion)
+               c.sendAlert(alertProtocolVersion)
+               return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
        }
        c.haveVers = true
 
@@ -142,14 +145,16 @@ Curves:
        }
 
        if !foundCompression {
-               return false, c.sendAlert(alertHandshakeFailure)
+               c.sendAlert(alertHandshakeFailure)
+               return false, errors.New("tls: client does not support uncompressed connections")
        }
 
        hs.hello.vers = c.vers
        hs.hello.random = make([]byte, 32)
        _, err = io.ReadFull(config.rand(), hs.hello.random)
        if err != nil {
-               return false, c.sendAlert(alertInternalError)
+               c.sendAlert(alertInternalError)
+               return false, err
        }
        hs.hello.secureRenegotiation = hs.clientHello.secureRenegotiation
        hs.hello.compressionMethod = compressionNone
@@ -166,7 +171,8 @@ Curves:
        }
 
        if len(config.Certificates) == 0 {
-               return false, c.sendAlert(alertInternalError)
+               c.sendAlert(alertInternalError)
+               return false, errors.New("tls: no certificates configured")
        }
        hs.cert = &config.Certificates[0]
        if len(hs.clientHello.serverName) > 0 {
@@ -195,7 +201,8 @@ Curves:
        }
 
        if hs.suite == nil {
-               return false, c.sendAlert(alertHandshakeFailure)
+               c.sendAlert(alertHandshakeFailure)
+               return false, errors.New("tls: no cipher suite supported by both client and server")
        }
 
        return false, nil
@@ -345,7 +352,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
        // certificate message, even if it's empty.
        if config.ClientAuth >= RequestClientCert {
                if certMsg, ok = msg.(*certificateMsg); !ok {
-                       return c.sendAlert(alertHandshakeFailure)
+                       c.sendAlert(alertUnexpectedMessage)
+                       return unexpectedMessageError(certMsg, msg)
                }
                hs.finishedHash.Write(certMsg.marshal())
 
@@ -372,7 +380,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
        // Get client key exchange
        ckx, ok := msg.(*clientKeyExchangeMsg)
        if !ok {
-               return c.sendAlert(alertUnexpectedMessage)
+               c.sendAlert(alertUnexpectedMessage)
+               return unexpectedMessageError(ckx, msg)
        }
        hs.finishedHash.Write(ckx.marshal())
 
@@ -389,7 +398,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
                }
                certVerify, ok := msg.(*certificateVerifyMsg)
                if !ok {
-                       return c.sendAlert(alertUnexpectedMessage)
+                       c.sendAlert(alertUnexpectedMessage)
+                       return unexpectedMessageError(certVerify, msg)
                }
 
                switch key := pub.(type) {
@@ -469,7 +479,8 @@ func (hs *serverHandshakeState) readFinished() error {
                }
                nextProto, ok := msg.(*nextProtoMsg)
                if !ok {
-                       return c.sendAlert(alertUnexpectedMessage)
+                       c.sendAlert(alertUnexpectedMessage)
+                       return unexpectedMessageError(nextProto, msg)
                }
                hs.finishedHash.Write(nextProto.marshal())
                c.clientProtocol = nextProto.proto
@@ -481,13 +492,15 @@ func (hs *serverHandshakeState) readFinished() error {
        }
        clientFinished, ok := msg.(*finishedMsg)
        if !ok {
-               return c.sendAlert(alertUnexpectedMessage)
+               c.sendAlert(alertUnexpectedMessage)
+               return unexpectedMessageError(clientFinished, msg)
        }
 
        verify := hs.finishedHash.clientSum(hs.masterSecret)
        if len(verify) != len(clientFinished.verifyData) ||
                subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
-               return c.sendAlert(alertHandshakeFailure)
+               c.sendAlert(alertHandshakeFailure)
+               return errors.New("tls: client's Finished message is incorrect")
        }
 
        hs.finishedHash.Write(clientFinished.marshal())
@@ -590,7 +603,8 @@ func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (c
                case *ecdsa.PublicKey, *rsa.PublicKey:
                        pub = key
                default:
-                       return nil, c.sendAlert(alertUnsupportedCertificate)
+                       c.sendAlert(alertUnsupportedCertificate)
+                       return nil, fmt.Errorf("tls: client's certificate contains an unsupported public key of type %T", certs[0].PublicKey)
                }
                c.peerCertificates = certs
                return pub, nil
index a8cf462c70ee3d3312d8506cb6c0c3e6602ea989..4f41ab9b78c2866aa4c0495413a28e793fbffc2a 100644 (file)
@@ -20,6 +20,7 @@ import (
        "os"
        "os/exec"
        "path/filepath"
+       "strings"
        "testing"
        "time"
 )
@@ -53,7 +54,7 @@ func init() {
        testConfig.BuildNameToCertificate()
 }
 
-func testClientHelloFailure(t *testing.T, m handshakeMessage, expected error) {
+func testClientHelloFailure(t *testing.T, m handshakeMessage, expectedSubStr string) {
        // Create in-memory network connection,
        // send message to server.  Should return
        // expected error.
@@ -68,20 +69,20 @@ func testClientHelloFailure(t *testing.T, m handshakeMessage, expected error) {
        }()
        err := Server(s, testConfig).Handshake()
        s.Close()
-       if e, ok := err.(*net.OpError); !ok || e.Err != expected {
-               t.Errorf("Got error: %s; expected: %s", err, expected)
+       if err == nil || !strings.Contains(err.Error(), expectedSubStr) {
+               t.Errorf("Got error: %s; expected to match substring '%s'", err, expectedSubStr)
        }
 }
 
 func TestSimpleError(t *testing.T) {
-       testClientHelloFailure(t, &serverHelloDoneMsg{}, alertUnexpectedMessage)
+       testClientHelloFailure(t, &serverHelloDoneMsg{}, "unexpected handshake message")
 }
 
 var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x0205}
 
 func TestRejectBadProtocolVersion(t *testing.T) {
        for _, v := range badProtocolVersions {
-               testClientHelloFailure(t, &clientHelloMsg{vers: v}, alertProtocolVersion)
+               testClientHelloFailure(t, &clientHelloMsg{vers: v}, "unsupported, maximum protocol version")
        }
 }
 
@@ -91,7 +92,7 @@ func TestNoSuiteOverlap(t *testing.T) {
                cipherSuites:       []uint16{0xff00},
                compressionMethods: []uint8{0},
        }
-       testClientHelloFailure(t, clientHello, alertHandshakeFailure)
+       testClientHelloFailure(t, clientHello, "no cipher suite supported by both client and server")
 }
 
 func TestNoCompressionOverlap(t *testing.T) {
@@ -100,7 +101,7 @@ func TestNoCompressionOverlap(t *testing.T) {
                cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
                compressionMethods: []uint8{0xff},
        }
-       testClientHelloFailure(t, clientHello, alertHandshakeFailure)
+       testClientHelloFailure(t, clientHello, "client does not support uncompressed connections")
 }
 
 func TestTLS12OnlyCipherSuites(t *testing.T) {
index 7e820c1e7e91c978e41f489fa04ddfae80899c24..861faf0e85a0495c058de1f543986ad4dd04121f 100644 (file)
@@ -19,6 +19,9 @@ import (
        "math/big"
 )
 
+var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
+var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
+
 // rsaKeyAgreement implements the standard TLS key agreement where the client
 // encrypts the pre-master secret to the server's public key.
 type rsaKeyAgreement struct{}
@@ -35,14 +38,14 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certifi
        }
 
        if len(ckx.ciphertext) < 2 {
-               return nil, errors.New("bad ClientKeyExchange")
+               return nil, errClientKeyExchange
        }
 
        ciphertext := ckx.ciphertext
        if version != VersionSSL30 {
                ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
                if ciphertextLen != len(ckx.ciphertext)-2 {
-                       return nil, errors.New("bad ClientKeyExchange")
+                       return nil, errClientKeyExchange
                }
                ciphertext = ckx.ciphertext[2:]
        }
@@ -61,7 +64,7 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certifi
 }
 
 func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
-       return errors.New("unexpected ServerKeyExchange")
+       return errors.New("tls: unexpected ServerKeyExchange")
 }
 
 func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
@@ -271,11 +274,11 @@ Curve:
 
 func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
        if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
-               return nil, errors.New("bad ClientKeyExchange")
+               return nil, errClientKeyExchange
        }
        x, y := elliptic.Unmarshal(ka.curve, ckx.ciphertext[1:])
        if x == nil {
-               return nil, errors.New("bad ClientKeyExchange")
+               return nil, errClientKeyExchange
        }
        x, _ = ka.curve.ScalarMult(x, y, ka.privateKey)
        preMasterSecret := make([]byte, (ka.curve.Params().BitSize+7)>>3)
@@ -285,8 +288,6 @@ func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Cert
        return preMasterSecret, nil
 }
 
-var errServerKeyExchange = errors.New("invalid ServerKeyExchange")
-
 func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
        if len(skx.key) < 4 {
                return errServerKeyExchange