]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: split clientHandshake into multiple methods
authorSergey Frolov <sfrolov@google.com>
Tue, 30 May 2017 16:53:11 +0000 (12:53 -0400)
committerAdam Langley <agl@golang.org>
Wed, 9 Aug 2017 22:24:19 +0000 (22:24 +0000)
Change-Id: I23bfaa7e03a21aad4e85baa3bf52bb00c09b75d0
Reviewed-on: https://go-review.googlesource.com/44354
Reviewed-by: Adam Langley <agl@golang.org>
src/crypto/tls/handshake_client.go

index a4ca5d34fb878cb03ee612e61ee936375ba2fa11..f8db66279f79b731929acce50051144af89bcebd 100644 (file)
@@ -29,51 +29,38 @@ type clientHandshakeState struct {
        session      *ClientSessionState
 }
 
-// c.out.Mutex <= L; c.handshakeMutex <= L.
-func (c *Conn) clientHandshake() error {
-       if c.config == nil {
-               c.config = defaultConfig()
-       }
-
-       // This may be a renegotiation handshake, in which case some fields
-       // need to be reset.
-       c.didResume = false
-
-       if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify {
-               return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
+func makeClientHello(config *Config) (*clientHelloMsg, error) {
+       if len(config.ServerName) == 0 && !config.InsecureSkipVerify {
+               return nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
        }
 
        nextProtosLength := 0
-       for _, proto := range c.config.NextProtos {
+       for _, proto := range config.NextProtos {
                if l := len(proto); l == 0 || l > 255 {
-                       return errors.New("tls: invalid NextProtos value")
+                       return nil, errors.New("tls: invalid NextProtos value")
                } else {
                        nextProtosLength += 1 + l
                }
        }
+
        if nextProtosLength > 0xffff {
-               return errors.New("tls: NextProtos values too large")
+               return nil, errors.New("tls: NextProtos values too large")
        }
 
        hello := &clientHelloMsg{
-               vers:                         c.config.maxVersion(),
+               vers:                         config.maxVersion(),
                compressionMethods:           []uint8{compressionNone},
                random:                       make([]byte, 32),
                ocspStapling:                 true,
                scts:                         true,
-               serverName:                   hostnameInSNI(c.config.ServerName),
-               supportedCurves:              c.config.curvePreferences(),
+               serverName:                   hostnameInSNI(config.ServerName),
+               supportedCurves:              config.curvePreferences(),
                supportedPoints:              []uint8{pointFormatUncompressed},
-               nextProtoNeg:                 len(c.config.NextProtos) > 0,
+               nextProtoNeg:                 len(config.NextProtos) > 0,
                secureRenegotiationSupported: true,
-               alpnProtocols:                c.config.NextProtos,
+               alpnProtocols:                config.NextProtos,
        }
-
-       if c.handshakes > 0 {
-               hello.secureRenegotiation = c.clientFinished[:]
-       }
-
-       possibleCipherSuites := c.config.cipherSuites()
+       possibleCipherSuites := config.cipherSuites()
        hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites))
 
 NextCipherSuite:
@@ -92,16 +79,37 @@ NextCipherSuite:
                }
        }
 
-       _, err := io.ReadFull(c.config.rand(), hello.random)
+       _, err := io.ReadFull(config.rand(), hello.random)
        if err != nil {
-               c.sendAlert(alertInternalError)
-               return errors.New("tls: short read from Rand: " + err.Error())
+               return nil, errors.New("tls: short read from Rand: " + err.Error())
        }
 
        if hello.vers >= VersionTLS12 {
                hello.signatureAndHashes = supportedSignatureAlgorithms
        }
 
+       return hello, nil
+}
+
+// c.out.Mutex <= L; c.handshakeMutex <= L.
+func (c *Conn) clientHandshake() error {
+       if c.config == nil {
+               c.config = defaultConfig()
+       }
+
+       // This may be a renegotiation handshake, in which case some fields
+       // need to be reset.
+       c.didResume = false
+
+       hello, err := makeClientHello(c.config)
+       if err != nil {
+               return err
+       }
+
+       if c.handshakes > 0 {
+               hello.secureRenegotiation = c.clientFinished[:]
+       }
+
        var session *ClientSessionState
        var cacheKey string
        sessionCache := c.config.ClientSessionCache
@@ -147,12 +155,36 @@ NextCipherSuite:
                // (see RFC 5077).
                hello.sessionId = make([]byte, 16)
                if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil {
-                       c.sendAlert(alertInternalError)
                        return errors.New("tls: short read from Rand: " + err.Error())
                }
        }
 
-       if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
+       hs := &clientHandshakeState{
+               c:       c,
+               hello:   hello,
+               session: session,
+       }
+
+       if err = hs.handshake(); err != nil {
+               return err
+       }
+
+       // If we had a successful handshake and hs.session is different from
+       // the one already cached - cache a new one
+       if sessionCache != nil && hs.session != nil && session != hs.session {
+               sessionCache.Put(cacheKey, hs.session)
+       }
+
+       return nil
+}
+
+// Does the handshake, either a full one or resumes old session.
+// Requires hs.c, hs.hello, and, optionally, hs.session to be set.
+func (hs *clientHandshakeState) handshake() error {
+       c := hs.c
+
+       // send ClientHello
+       if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
                return err
        }
 
@@ -160,34 +192,19 @@ NextCipherSuite:
        if err != nil {
                return err
        }
-       serverHello, ok := msg.(*serverHelloMsg)
-       if !ok {
-               c.sendAlert(alertUnexpectedMessage)
-               return unexpectedMessageError(serverHello, msg)
-       }
 
-       vers, ok := c.config.mutualVersion(serverHello.vers)
-       if !ok || vers < VersionTLS10 {
-               // TLS 1.0 is the minimum version supported as a client.
-               c.sendAlert(alertProtocolVersion)
-               return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers)
+       var ok bool
+       if hs.serverHello, ok = msg.(*serverHelloMsg); !ok {
+               c.sendAlert(alertUnexpectedMessage)
+               return unexpectedMessageError(hs.serverHello, msg)
        }
-       c.vers = vers
-       c.haveVers = true
 
-       suite := mutualCipherSuite(hello.cipherSuites, serverHello.cipherSuite)
-       if suite == nil {
-               c.sendAlert(alertHandshakeFailure)
-               return errors.New("tls: server chose an unconfigured cipher suite")
+       if err = hs.pickTLSVersion(); err != nil {
+               return err
        }
 
-       hs := &clientHandshakeState{
-               c:            c,
-               serverHello:  serverHello,
-               hello:        hello,
-               suite:        suite,
-               finishedHash: newFinishedHash(c.vers, suite),
-               session:      session,
+       if err = hs.pickCipherSuite(); err != nil {
+               return err
        }
 
        isResume, err := hs.processServerHello()
@@ -195,6 +212,8 @@ NextCipherSuite:
                return err
        }
 
+       hs.finishedHash = newFinishedHash(c.vers, hs.suite)
+
        // No signatures of the handshake are needed in a resumption.
        // Otherwise, in a full handshake, if we don't have any certificates
        // configured then we will never send a CertificateVerify message and
@@ -246,13 +265,33 @@ NextCipherSuite:
                }
        }
 
-       if sessionCache != nil && hs.session != nil && session != hs.session {
-               sessionCache.Put(cacheKey, hs.session)
-       }
-
        c.didResume = isResume
        c.handshakeComplete = true
-       c.cipherSuite = suite.id
+
+       return nil
+}
+
+func (hs *clientHandshakeState) pickTLSVersion() error {
+       vers, ok := hs.c.config.mutualVersion(hs.serverHello.vers)
+       if !ok || vers < VersionTLS10 {
+               // TLS 1.0 is the minimum version supported as a client.
+               hs.c.sendAlert(alertProtocolVersion)
+               return fmt.Errorf("tls: server selected unsupported protocol version %x", hs.serverHello.vers)
+       }
+
+       hs.c.vers = vers
+       hs.c.haveVers = true
+
+       return nil
+}
+
+func (hs *clientHandshakeState) pickCipherSuite() error {
+       if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil {
+               hs.c.sendAlert(alertHandshakeFailure)
+               return errors.New("tls: server chose an unconfigured cipher suite")
+       }
+
+       hs.c.cipherSuite = hs.suite.id
        return nil
 }