go Client(c, clientConfig).Handshake()
srv := Server(s, testConfig)
- msg, err := srv.readHandshake()
+ msg, err := srv.readHandshake(nil)
if err != nil {
t.Fatal(err)
}
}
type handshakeMessage interface {
- marshal() []byte
+ marshal() ([]byte, error)
unmarshal([]byte) bool
}
return n, nil
}
-// writeRecord writes a TLS record with the given type and payload to the
-// connection and updates the record layer state.
-func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
+// writeHandshakeRecord writes a handshake message to the connection and updates
+// the record layer state. If transcript is non-nil the marshalled message is
+// written to it.
+func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
c.out.Lock()
defer c.out.Unlock()
- return c.writeRecordLocked(typ, data)
+ data, err := msg.marshal()
+ if err != nil {
+ return 0, err
+ }
+ if transcript != nil {
+ transcript.Write(data)
+ }
+
+ return c.writeRecordLocked(recordTypeHandshake, data)
+}
+
+// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and
+// updates the record layer state.
+func (c *Conn) writeChangeCipherRecord() error {
+ c.out.Lock()
+ defer c.out.Unlock()
+ _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
+ return err
}
// readHandshake reads the next handshake message from
-// the record layer.
-func (c *Conn) readHandshake() (any, error) {
+// the record layer. If transcript is non-nil, the message
+// is written to the passed transcriptHash.
+func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
for c.hand.Len() < 4 {
if err := c.readRecord(); err != nil {
return nil, err
if !m.unmarshal(data) {
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
+
+ if transcript != nil {
+ transcript.Write(data)
+ }
+
return m, nil
}
return errors.New("tls: internal error: unexpected renegotiation")
}
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
return c.handleRenegotiation()
}
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
defer c.out.Unlock()
msg := &keyUpdateMsg{}
- _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal())
+ msgBytes, err := msg.marshal()
+ if err != nil {
+ return err
+ }
+ _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
if err != nil {
// Surface the error at the next write.
c.out.setErrorLocked(err)
}
c.serverName = hello.serverName
- cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
+ cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello)
+ if err != nil {
+ return err
+ }
if cacheKey != "" && session != nil {
defer func() {
// If we got a handshake failure when resuming a session, throw away
}()
}
- if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
+ if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
return err
}
- msg, err := c.readHandshake()
+ // serverHelloMsg is not included in the transcript
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
}
func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
- session *ClientSessionState, earlySecret, binderKey []byte) {
+ session *ClientSessionState, earlySecret, binderKey []byte, err error) {
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
- return "", nil, nil, nil
+ return "", nil, nil, nil, nil
}
hello.ticketSupported = true
// renegotiation is primarily used to allow a client to send a client
// certificate, which would be skipped if session resumption occurred.
if c.handshakes != 0 {
- return "", nil, nil, nil
+ return "", nil, nil, nil, nil
}
// Try to resume a previously negotiated TLS session, if available.
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
session, ok := c.config.ClientSessionCache.Get(cacheKey)
if !ok || session == nil {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
// Check that version used for the previous session is still valid.
}
}
if !versOk {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
// Check that the cached server certificate is not expired, and that it's
if !c.config.InsecureSkipVerify {
if len(session.verifiedChains) == 0 {
// The original connection had InsecureSkipVerify, while this doesn't.
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
serverCert := session.serverCertificates[0]
if c.config.time().After(serverCert.NotAfter) {
// Expired certificate, delete the entry.
c.config.ClientSessionCache.Put(cacheKey, nil)
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
if err := serverCert.VerifyHostname(c.config.ServerName); err != nil {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
}
// In TLS 1.2 the cipher suite must match the resumed session. Ensure we
// are still offering it.
if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
hello.sessionTicket = session.sessionTicket
// Check that the session ticket is not expired.
if c.config.time().After(session.useBy) {
c.config.ClientSessionCache.Put(cacheKey, nil)
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
// In TLS 1.3 the KDF hash must match the resumed session. Ensure we
// offer at least one cipher suite with that hash.
cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite)
if cipherSuite == nil {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
cipherSuiteOk := false
for _, offeredID := range hello.cipherSuites {
}
}
if !cipherSuiteOk {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
// Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1.
earlySecret = cipherSuite.extract(psk, nil)
binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
transcript := cipherSuite.hash.New()
- transcript.Write(hello.marshalWithoutBinders())
+ helloBytes, err := hello.marshalWithoutBinders()
+ if err != nil {
+ return "", nil, nil, nil, err
+ }
+ transcript.Write(helloBytes)
pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)}
- hello.updateBinders(pskBinders)
+ if err := hello.updateBinders(pskBinders); err != nil {
+ return "", nil, nil, nil, err
+ }
return
}
hs.finishedHash.discardHandshakeBuffer()
}
- hs.finishedHash.Write(hs.hello.marshal())
- hs.finishedHash.Write(hs.serverHello.marshal())
+ if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil {
+ return err
+ }
+ if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil {
+ return err
+ }
c.buffering = true
c.didResume = isResume
func (hs *clientHandshakeState) doFullHandshake() error {
c := hs.c
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
- hs.finishedHash.Write(certMsg.marshal())
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: received unexpected CertificateStatus message")
}
- hs.finishedHash.Write(cs.marshal())
c.ocspResponse = cs.response
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
skx, ok := msg.(*serverKeyExchangeMsg)
if ok {
- hs.finishedHash.Write(skx.marshal())
err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx)
if err != nil {
c.sendAlert(alertUnexpectedMessage)
return err
}
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
certReq, ok := msg.(*certificateRequestMsg)
if ok {
certRequested = true
- hs.finishedHash.Write(certReq.marshal())
cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq)
if chainToSend, err = c.getClientCertificate(cri); err != nil {
return err
}
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(shd, msg)
}
- hs.finishedHash.Write(shd.marshal())
// If the server requested a certificate then we have to send a
// Certificate message, even if it's empty because we don't have a
if certRequested {
certMsg = new(certificateMsg)
certMsg.certificates = chainToSend.Certificate
- hs.finishedHash.Write(certMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
return err
}
}
return err
}
if ckx != nil {
- hs.finishedHash.Write(ckx.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil {
return err
}
}
return err
}
- hs.finishedHash.Write(certVerify.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil {
return err
}
}
return err
}
- msg, err := c.readHandshake()
+ // finishedMsg is included in the transcript, but not until after we
+ // check the client version, since the state before this message was
+ // sent is used during verification.
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: server's Finished message was incorrect")
}
- hs.finishedHash.Write(serverFinished.marshal())
+
+ if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil {
+ return err
+ }
+
copy(out, verify)
return nil
}
}
c := hs.c
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(sessionTicketMsg, msg)
}
- hs.finishedHash.Write(sessionTicketMsg.marshal())
hs.session = &ClientSessionState{
sessionTicket: sessionTicketMsg.ticket,
func (hs *clientHandshakeState) sendFinished(out []byte) error {
c := hs.c
- if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
+ if err := c.writeChangeCipherRecord(); err != nil {
return err
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
- hs.finishedHash.Write(finished.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
return err
}
copy(out, finished.verifyData)
cipherSuite: TLS_RSA_WITH_AES_128_GCM_SHA256,
alpnProtocol: "how-about-this",
}
- serverHelloBytes := serverHello.marshal()
+ serverHelloBytes := mustMarshal(t, serverHello)
s.Write([]byte{
byte(recordTypeHandshake),
random: make([]byte, 32),
cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
}
- serverHelloBytes := serverHello.marshal()
+ serverHelloBytes := mustMarshal(t, serverHello)
s.Write([]byte{
byte(recordTypeHandshake),
}
hs.transcript = hs.suite.hash.New()
- hs.transcript.Write(hs.hello.marshal())
+
+ if err := transcriptMsg(hs.hello, hs.transcript); err != nil {
+ return err
+ }
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
if err := hs.sendDummyChangeCipherSpec(); err != nil {
}
}
- hs.transcript.Write(hs.serverHello.marshal())
+ if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
+ return err
+ }
c.buffering = true
if err := hs.processServerHello(); err != nil {
}
hs.sentDummyCCS = true
- _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
- return err
+ return hs.c.writeChangeCipherRecord()
}
// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
- hs.transcript.Write(hs.serverHello.marshal())
+ if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
+ return err
+ }
// The only HelloRetryRequest extensions we support are key_share and
// cookie, and clients must abort the handshake if the HRR would not result
transcript := hs.suite.hash.New()
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
transcript.Write(chHash)
- transcript.Write(hs.serverHello.marshal())
- transcript.Write(hs.hello.marshalWithoutBinders())
+ if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
+ return err
+ }
+ helloBytes, err := hs.hello.marshalWithoutBinders()
+ if err != nil {
+ return err
+ }
+ transcript.Write(helloBytes)
pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)}
- hs.hello.updateBinders(pskBinders)
+ if err := hs.hello.updateBinders(pskBinders); err != nil {
+ return err
+ }
} else {
// Server selected a cipher suite incompatible with the PSK.
hs.hello.pskIdentities = nil
}
}
- hs.transcript.Write(hs.hello.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
- msg, err := c.readHandshake()
+ // serverHelloMsg is not included in the transcript
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
if !hs.usingPSK {
earlySecret = hs.suite.extract(nil, nil)
}
+
handshakeSecret := hs.suite.extract(sharedKey,
hs.suite.deriveSecret(earlySecret, "derived", nil))
func (hs *clientHandshakeStateTLS13) readServerParameters() error {
c := hs.c
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(encryptedExtensions, msg)
}
- hs.transcript.Write(encryptedExtensions.marshal())
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil {
c.sendAlert(alertUnsupportedExtension)
return nil
}
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
certReq, ok := msg.(*certificateRequestMsgTLS13)
if ok {
- hs.transcript.Write(certReq.marshal())
-
hs.certReq = certReq
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(hs.transcript)
if err != nil {
return err
}
c.sendAlert(alertDecodeError)
return errors.New("tls: received empty certificates message")
}
- hs.transcript.Write(certMsg.marshal())
c.scts = certMsg.certificate.SignedCertificateTimestamps
c.ocspResponse = certMsg.certificate.OCSPStaple
return err
}
- msg, err = c.readHandshake()
+ // certificateVerifyMsg is included in the transcript, but not until
+ // after we verify the handshake signature, since the state before
+ // this message was sent is used.
+ msg, err = c.readHandshake(nil)
if err != nil {
return err
}
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
- hs.transcript.Write(certVerify.marshal())
+ if err := transcriptMsg(certVerify, hs.transcript); err != nil {
+ return err
+ }
return nil
}
func (hs *clientHandshakeStateTLS13) readServerFinished() error {
c := hs.c
- msg, err := c.readHandshake()
+ // finishedMsg is included in the transcript, but not until after we
+ // check the client version, since the state before this message was
+ // sent is used during verification.
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
return errors.New("tls: invalid server finished hash")
}
- hs.transcript.Write(finished.marshal())
+ if err := transcriptMsg(finished, hs.transcript); err != nil {
+ return err
+ }
// Derive secrets that take context through the server Finished.
certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0
- hs.transcript.Write(certMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
}
certVerifyMsg.signature = sig
- hs.transcript.Write(certVerifyMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
- hs.transcript.Write(finished.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
package tls
import (
+ "errors"
"fmt"
"strings"
pskBinders [][]byte
}
-func (m *clientHelloMsg) marshal() []byte {
+func (m *clientHelloMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
+ }
+
+ var exts cryptobyte.Builder
+ if len(m.serverName) > 0 {
+ // RFC 6066, Section 3
+ exts.AddUint16(extensionServerName)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8(0) // name_type = host_name
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes([]byte(m.serverName))
+ })
+ })
+ })
+ }
+ if m.ocspStapling {
+ // RFC 4366, Section 3.6
+ exts.AddUint16(extensionStatusRequest)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8(1) // status_type = ocsp
+ exts.AddUint16(0) // empty responder_id_list
+ exts.AddUint16(0) // empty request_extensions
+ })
+ }
+ if len(m.supportedCurves) > 0 {
+ // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
+ exts.AddUint16(extensionSupportedCurves)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, curve := range m.supportedCurves {
+ exts.AddUint16(uint16(curve))
+ }
+ })
+ })
+ }
+ if len(m.supportedPoints) > 0 {
+ // RFC 4492, Section 5.1.2
+ exts.AddUint16(extensionSupportedPoints)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.supportedPoints)
+ })
+ })
+ }
+ if m.ticketSupported {
+ // RFC 5077, Section 3.2
+ exts.AddUint16(extensionSessionTicket)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.sessionTicket)
+ })
+ }
+ if len(m.supportedSignatureAlgorithms) > 0 {
+ // RFC 5246, Section 7.4.1.4.1
+ exts.AddUint16(extensionSignatureAlgorithms)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, sigAlgo := range m.supportedSignatureAlgorithms {
+ exts.AddUint16(uint16(sigAlgo))
+ }
+ })
+ })
+ }
+ if len(m.supportedSignatureAlgorithmsCert) > 0 {
+ // RFC 8446, Section 4.2.3
+ exts.AddUint16(extensionSignatureAlgorithmsCert)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
+ exts.AddUint16(uint16(sigAlgo))
+ }
+ })
+ })
+ }
+ if m.secureRenegotiationSupported {
+ // RFC 5746, Section 3.2
+ exts.AddUint16(extensionRenegotiationInfo)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.secureRenegotiation)
+ })
+ })
+ }
+ if len(m.alpnProtocols) > 0 {
+ // RFC 7301, Section 3.1
+ exts.AddUint16(extensionALPN)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, proto := range m.alpnProtocols {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes([]byte(proto))
+ })
+ }
+ })
+ })
+ }
+ if m.scts {
+ // RFC 6962, Section 3.3.1
+ exts.AddUint16(extensionSCT)
+ exts.AddUint16(0) // empty extension_data
+ }
+ if len(m.supportedVersions) > 0 {
+ // RFC 8446, Section 4.2.1
+ exts.AddUint16(extensionSupportedVersions)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, vers := range m.supportedVersions {
+ exts.AddUint16(vers)
+ }
+ })
+ })
+ }
+ if len(m.cookie) > 0 {
+ // RFC 8446, Section 4.2.2
+ exts.AddUint16(extensionCookie)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.cookie)
+ })
+ })
+ }
+ if len(m.keyShares) > 0 {
+ // RFC 8446, Section 4.2.8
+ exts.AddUint16(extensionKeyShare)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, ks := range m.keyShares {
+ exts.AddUint16(uint16(ks.group))
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(ks.data)
+ })
+ }
+ })
+ })
+ }
+ if m.earlyData {
+ // RFC 8446, Section 4.2.10
+ exts.AddUint16(extensionEarlyData)
+ exts.AddUint16(0) // empty extension_data
+ }
+ if len(m.pskModes) > 0 {
+ // RFC 8446, Section 4.2.9
+ exts.AddUint16(extensionPSKModes)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.pskModes)
+ })
+ })
+ }
+ if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
+ // RFC 8446, Section 4.2.11
+ exts.AddUint16(extensionPreSharedKey)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, psk := range m.pskIdentities {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(psk.label)
+ })
+ exts.AddUint32(psk.obfuscatedTicketAge)
+ }
+ })
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, binder := range m.pskBinders {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(binder)
+ })
+ }
+ })
+ })
+ }
+ extBytes, err := exts.Bytes()
+ if err != nil {
+ return nil, err
}
var b cryptobyte.Builder
b.AddBytes(m.compressionMethods)
})
- // If extensions aren't present, omit them.
- var extensionsPresent bool
- bWithoutExtensions := *b
-
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- if len(m.serverName) > 0 {
- // RFC 6066, Section 3
- b.AddUint16(extensionServerName)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8(0) // name_type = host_name
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes([]byte(m.serverName))
- })
- })
- })
- }
- if m.ocspStapling {
- // RFC 4366, Section 3.6
- b.AddUint16(extensionStatusRequest)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8(1) // status_type = ocsp
- b.AddUint16(0) // empty responder_id_list
- b.AddUint16(0) // empty request_extensions
- })
- }
- if len(m.supportedCurves) > 0 {
- // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
- b.AddUint16(extensionSupportedCurves)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, curve := range m.supportedCurves {
- b.AddUint16(uint16(curve))
- }
- })
- })
- }
- if len(m.supportedPoints) > 0 {
- // RFC 4492, Section 5.1.2
- b.AddUint16(extensionSupportedPoints)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.supportedPoints)
- })
- })
- }
- if m.ticketSupported {
- // RFC 5077, Section 3.2
- b.AddUint16(extensionSessionTicket)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.sessionTicket)
- })
- }
- if len(m.supportedSignatureAlgorithms) > 0 {
- // RFC 5246, Section 7.4.1.4.1
- b.AddUint16(extensionSignatureAlgorithms)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, sigAlgo := range m.supportedSignatureAlgorithms {
- b.AddUint16(uint16(sigAlgo))
- }
- })
- })
- }
- if len(m.supportedSignatureAlgorithmsCert) > 0 {
- // RFC 8446, Section 4.2.3
- b.AddUint16(extensionSignatureAlgorithmsCert)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
- b.AddUint16(uint16(sigAlgo))
- }
- })
- })
- }
- if m.secureRenegotiationSupported {
- // RFC 5746, Section 3.2
- b.AddUint16(extensionRenegotiationInfo)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.secureRenegotiation)
- })
- })
- }
- if len(m.alpnProtocols) > 0 {
- // RFC 7301, Section 3.1
- b.AddUint16(extensionALPN)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, proto := range m.alpnProtocols {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes([]byte(proto))
- })
- }
- })
- })
- }
- if m.scts {
- // RFC 6962, Section 3.3.1
- b.AddUint16(extensionSCT)
- b.AddUint16(0) // empty extension_data
- }
- if len(m.supportedVersions) > 0 {
- // RFC 8446, Section 4.2.1
- b.AddUint16(extensionSupportedVersions)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, vers := range m.supportedVersions {
- b.AddUint16(vers)
- }
- })
- })
- }
- if len(m.cookie) > 0 {
- // RFC 8446, Section 4.2.2
- b.AddUint16(extensionCookie)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.cookie)
- })
- })
- }
- if len(m.keyShares) > 0 {
- // RFC 8446, Section 4.2.8
- b.AddUint16(extensionKeyShare)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, ks := range m.keyShares {
- b.AddUint16(uint16(ks.group))
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(ks.data)
- })
- }
- })
- })
- }
- if m.earlyData {
- // RFC 8446, Section 4.2.10
- b.AddUint16(extensionEarlyData)
- b.AddUint16(0) // empty extension_data
- }
- if len(m.pskModes) > 0 {
- // RFC 8446, Section 4.2.9
- b.AddUint16(extensionPSKModes)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.pskModes)
- })
- })
- }
- if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
- // RFC 8446, Section 4.2.11
- b.AddUint16(extensionPreSharedKey)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, psk := range m.pskIdentities {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(psk.label)
- })
- b.AddUint32(psk.obfuscatedTicketAge)
- }
- })
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, binder := range m.pskBinders {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(binder)
- })
- }
- })
- })
- }
-
- extensionsPresent = len(b.BytesOrPanic()) > 2
- })
-
- if !extensionsPresent {
- *b = bWithoutExtensions
+ if len(extBytes) > 0 {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(extBytes)
+ })
}
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
// marshalWithoutBinders returns the ClientHello through the
// PreSharedKeyExtension.identities field, according to RFC 8446, Section
// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
-func (m *clientHelloMsg) marshalWithoutBinders() []byte {
+func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
bindersLen := 2 // uint16 length prefix
for _, binder := range m.pskBinders {
bindersLen += 1 // uint8 length prefix
bindersLen += len(binder)
}
- fullMessage := m.marshal()
- return fullMessage[:len(fullMessage)-bindersLen]
+ fullMessage, err := m.marshal()
+ if err != nil {
+ return nil, err
+ }
+ return fullMessage[:len(fullMessage)-bindersLen], nil
}
// updateBinders updates the m.pskBinders field, if necessary updating the
// cached marshaled representation. The supplied binders must have the same
// length as the current m.pskBinders.
-func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) {
+func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
if len(pskBinders) != len(m.pskBinders) {
- panic("tls: internal error: pskBinders length mismatch")
+ return errors.New("tls: internal error: pskBinders length mismatch")
}
for i := range m.pskBinders {
if len(pskBinders[i]) != len(m.pskBinders[i]) {
- panic("tls: internal error: pskBinders length mismatch")
+ return errors.New("tls: internal error: pskBinders length mismatch")
}
}
m.pskBinders = pskBinders
if m.raw != nil {
- lenWithoutBinders := len(m.marshalWithoutBinders())
+ helloBytes, err := m.marshalWithoutBinders()
+ if err != nil {
+ return err
+ }
+ lenWithoutBinders := len(helloBytes)
b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders])
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, binder := range m.pskBinders {
}
})
if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) {
- panic("tls: internal error: failed to update binders")
+ return errors.New("tls: internal error: failed to update binders")
}
}
+
+ return nil
}
func (m *clientHelloMsg) unmarshal(data []byte) bool {
selectedGroup CurveID
}
-func (m *serverHelloMsg) marshal() []byte {
+func (m *serverHelloMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
+ }
+
+ var exts cryptobyte.Builder
+ if m.ocspStapling {
+ exts.AddUint16(extensionStatusRequest)
+ exts.AddUint16(0) // empty extension_data
+ }
+ if m.ticketSupported {
+ exts.AddUint16(extensionSessionTicket)
+ exts.AddUint16(0) // empty extension_data
+ }
+ if m.secureRenegotiationSupported {
+ exts.AddUint16(extensionRenegotiationInfo)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.secureRenegotiation)
+ })
+ })
+ }
+ if len(m.alpnProtocol) > 0 {
+ exts.AddUint16(extensionALPN)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes([]byte(m.alpnProtocol))
+ })
+ })
+ })
+ }
+ if len(m.scts) > 0 {
+ exts.AddUint16(extensionSCT)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, sct := range m.scts {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(sct)
+ })
+ }
+ })
+ })
+ }
+ if m.supportedVersion != 0 {
+ exts.AddUint16(extensionSupportedVersions)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16(m.supportedVersion)
+ })
+ }
+ if m.serverShare.group != 0 {
+ exts.AddUint16(extensionKeyShare)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16(uint16(m.serverShare.group))
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.serverShare.data)
+ })
+ })
+ }
+ if m.selectedIdentityPresent {
+ exts.AddUint16(extensionPreSharedKey)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16(m.selectedIdentity)
+ })
+ }
+
+ if len(m.cookie) > 0 {
+ exts.AddUint16(extensionCookie)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.cookie)
+ })
+ })
+ }
+ if m.selectedGroup != 0 {
+ exts.AddUint16(extensionKeyShare)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16(uint16(m.selectedGroup))
+ })
+ }
+ if len(m.supportedPoints) > 0 {
+ exts.AddUint16(extensionSupportedPoints)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.supportedPoints)
+ })
+ })
+ }
+
+ extBytes, err := exts.Bytes()
+ if err != nil {
+ return nil, err
}
var b cryptobyte.Builder
b.AddUint16(m.cipherSuite)
b.AddUint8(m.compressionMethod)
- // If extensions aren't present, omit them.
- var extensionsPresent bool
- bWithoutExtensions := *b
-
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- if m.ocspStapling {
- b.AddUint16(extensionStatusRequest)
- b.AddUint16(0) // empty extension_data
- }
- if m.ticketSupported {
- b.AddUint16(extensionSessionTicket)
- b.AddUint16(0) // empty extension_data
- }
- if m.secureRenegotiationSupported {
- b.AddUint16(extensionRenegotiationInfo)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.secureRenegotiation)
- })
- })
- }
- if len(m.alpnProtocol) > 0 {
- b.AddUint16(extensionALPN)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes([]byte(m.alpnProtocol))
- })
- })
- })
- }
- if len(m.scts) > 0 {
- b.AddUint16(extensionSCT)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, sct := range m.scts {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(sct)
- })
- }
- })
- })
- }
- if m.supportedVersion != 0 {
- b.AddUint16(extensionSupportedVersions)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16(m.supportedVersion)
- })
- }
- if m.serverShare.group != 0 {
- b.AddUint16(extensionKeyShare)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16(uint16(m.serverShare.group))
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.serverShare.data)
- })
- })
- }
- if m.selectedIdentityPresent {
- b.AddUint16(extensionPreSharedKey)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16(m.selectedIdentity)
- })
- }
-
- if len(m.cookie) > 0 {
- b.AddUint16(extensionCookie)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.cookie)
- })
- })
- }
- if m.selectedGroup != 0 {
- b.AddUint16(extensionKeyShare)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16(uint16(m.selectedGroup))
- })
- }
- if len(m.supportedPoints) > 0 {
- b.AddUint16(extensionSupportedPoints)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.supportedPoints)
- })
- })
- }
-
- extensionsPresent = len(b.BytesOrPanic()) > 2
- })
-
- if !extensionsPresent {
- *b = bWithoutExtensions
+ if len(extBytes) > 0 {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(extBytes)
+ })
}
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *serverHelloMsg) unmarshal(data []byte) bool {
alpnProtocol string
}
-func (m *encryptedExtensionsMsg) marshal() []byte {
+func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
type endOfEarlyDataMsg struct{}
-func (m *endOfEarlyDataMsg) marshal() []byte {
+func (m *endOfEarlyDataMsg) marshal() ([]byte, error) {
x := make([]byte, 4)
x[0] = typeEndOfEarlyData
- return x
+ return x, nil
}
func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
updateRequested bool
}
-func (m *keyUpdateMsg) marshal() []byte {
+func (m *keyUpdateMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
}
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *keyUpdateMsg) unmarshal(data []byte) bool {
maxEarlyData uint32
}
-func (m *newSessionTicketMsgTLS13) marshal() []byte {
+func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
certificateAuthorities [][]byte
}
-func (m *certificateRequestMsgTLS13) marshal() []byte {
+func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
certificates [][]byte
}
-func (m *certificateMsg) marshal() (x []byte) {
+func (m *certificateMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var i int
}
length := 3 + 3*len(m.certificates) + i
- x = make([]byte, 4+length)
+ x := make([]byte, 4+length)
x[0] = typeCertificate
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
}
m.raw = x
- return
+ return m.raw, nil
}
func (m *certificateMsg) unmarshal(data []byte) bool {
scts bool
}
-func (m *certificateMsgTLS13) marshal() []byte {
+func (m *certificateMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
marshalCertificate(b, certificate)
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
key []byte
}
-func (m *serverKeyExchangeMsg) marshal() []byte {
+func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
length := len(m.key)
x := make([]byte, length+4)
copy(x[4:], m.key)
m.raw = x
- return x
+ return x, nil
}
func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
response []byte
}
-func (m *certificateStatusMsg) marshal() []byte {
+func (m *certificateStatusMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *certificateStatusMsg) unmarshal(data []byte) bool {
type serverHelloDoneMsg struct{}
-func (m *serverHelloDoneMsg) marshal() []byte {
+func (m *serverHelloDoneMsg) marshal() ([]byte, error) {
x := make([]byte, 4)
x[0] = typeServerHelloDone
- return x
+ return x, nil
}
func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
ciphertext []byte
}
-func (m *clientKeyExchangeMsg) marshal() []byte {
+func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
length := len(m.ciphertext)
x := make([]byte, length+4)
copy(x[4:], m.ciphertext)
m.raw = x
- return x
+ return x, nil
}
func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
verifyData []byte
}
-func (m *finishedMsg) marshal() []byte {
+func (m *finishedMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
b.AddBytes(m.verifyData)
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *finishedMsg) unmarshal(data []byte) bool {
certificateAuthorities [][]byte
}
-func (m *certificateRequestMsg) marshal() (x []byte) {
+func (m *certificateRequestMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
// See RFC 4346, Section 7.4.4.
length += 2 + 2*len(m.supportedSignatureAlgorithms)
}
- x = make([]byte, 4+length)
+ x := make([]byte, 4+length)
x[0] = typeCertificateRequest
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
}
m.raw = x
- return
+ return m.raw, nil
}
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
signature []byte
}
-func (m *certificateVerifyMsg) marshal() (x []byte) {
+func (m *certificateVerifyMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
ticket []byte
}
-func (m *newSessionTicketMsg) marshal() (x []byte) {
+func (m *newSessionTicketMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
// See RFC 5077, Section 3.3.
ticketLen := len(m.ticket)
length := 2 + 4 + ticketLen
- x = make([]byte, 4+length)
+ x := make([]byte, 4+length)
x[0] = typeNewSessionTicket
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
m.raw = x
- return
+ return m.raw, nil
}
func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
type helloRequestMsg struct {
}
-func (*helloRequestMsg) marshal() []byte {
- return []byte{typeHelloRequest, 0, 0, 0}
+func (*helloRequestMsg) marshal() ([]byte, error) {
+ return []byte{typeHelloRequest, 0, 0, 0}, nil
}
func (*helloRequestMsg) unmarshal(data []byte) bool {
return len(data) == 4
}
+
+type transcriptHash interface {
+ Write([]byte) (int, error)
+}
+
+// transcriptMsg is a helper used to marshal and hash messages which typically
+// are not written to the wire, and as such aren't hashed during Conn.writeRecord.
+func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
+ data, err := msg.marshal()
+ if err != nil {
+ return err
+ }
+ h.Write(data)
+ return nil
+}
&certificateMsgTLS13{},
}
+func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
+ t.Helper()
+ b, err := msg.marshal()
+ if err != nil {
+ t.Fatal(err)
+ }
+ return b
+}
+
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(time.Now().UnixNano()))
}
m1 := v.Interface().(handshakeMessage)
- marshaled := m1.marshal()
+ marshaled := mustMarshal(t, m1)
m2 := iface.(handshakeMessage)
if !m2.unmarshal(marshaled) {
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
var random [32]byte
sct := []byte{0x42, 0x42, 0x42, 0x42}
- serverHello := serverHelloMsg{
+ serverHello := &serverHelloMsg{
vers: VersionTLS12,
random: random[:],
scts: [][]byte{sct},
}
- serverHelloBytes := serverHello.marshal()
+ serverHelloBytes := mustMarshal(t, serverHello)
var serverHelloCopy serverHelloMsg
if !serverHelloCopy.unmarshal(serverHelloBytes) {
// not be zero length.
var random [32]byte
- serverHello := serverHelloMsg{
+ serverHello := &serverHelloMsg{
vers: VersionTLS12,
random: random[:],
scts: [][]byte{nil},
}
- serverHelloBytes := serverHello.marshal()
+ serverHelloBytes := mustMarshal(t, serverHello)
var serverHelloCopy serverHelloMsg
if serverHelloCopy.unmarshal(serverHelloBytes) {
// readClientHello reads a ClientHello message and selects the protocol version.
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
- msg, err := c.readHandshake()
+ // clientHelloMsg is included in the transcript, but we haven't initialized
+ // it yet. The respective handshake functions will record it themselves.
+ msg, err := c.readHandshake(nil)
if err != nil {
return nil, err
}
hs.hello.ticketSupported = hs.sessionState.usedOldKey
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
hs.finishedHash.discardHandshakeBuffer()
- hs.finishedHash.Write(hs.clientHello.marshal())
- hs.finishedHash.Write(hs.hello.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+ if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
+ return err
+ }
+ if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
// certificates won't be used.
hs.finishedHash.discardHandshakeBuffer()
}
- hs.finishedHash.Write(hs.clientHello.marshal())
- hs.finishedHash.Write(hs.hello.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+ if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
+ return err
+ }
+ if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate
- hs.finishedHash.Write(certMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
return err
}
if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg)
certStatus.response = hs.cert.OCSPStaple
- hs.finishedHash.Write(certStatus.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil {
return err
}
}
return err
}
if skx != nil {
- hs.finishedHash.Write(skx.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil {
return err
}
}
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
- hs.finishedHash.Write(certReq.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil {
return err
}
}
helloDone := new(serverHelloDoneMsg)
- hs.finishedHash.Write(helloDone.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil {
return err
}
var pub crypto.PublicKey // public key for client auth, if any
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
- hs.finishedHash.Write(certMsg.marshal())
if err := c.processCertsFromClient(Certificate{
Certificate: certMsg.certificates,
pub = c.peerCertificates[0].PublicKey
}
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(ckx, msg)
}
- hs.finishedHash.Write(ckx.marshal())
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
if err != nil {
// to the client's certificate. This allows us to verify that the client is in
// possession of the private key of the certificate.
if len(c.peerCertificates) > 0 {
- msg, err = c.readHandshake()
+ // certificateVerifyMsg is included in the transcript, but not until
+ // after we verify the handshake signature, since the state before
+ // this message was sent is used.
+ msg, err = c.readHandshake(nil)
if err != nil {
return err
}
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
- hs.finishedHash.Write(certVerify.marshal())
+ if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil {
+ return err
+ }
}
hs.finishedHash.discardHandshakeBuffer()
return err
}
- msg, err := c.readHandshake()
+ // finishedMsg is included in the transcript, but not until after we
+ // check the client version, since the state before this message was
+ // sent is used during verification.
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
return errors.New("tls: client's Finished message is incorrect")
}
- hs.finishedHash.Write(clientFinished.marshal())
+ if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil {
+ return err
+ }
+
copy(out, verify)
return nil
}
masterSecret: hs.masterSecret,
certificates: certsFromClient,
}
- var err error
- m.ticket, err = c.encryptTicket(state.marshal())
+ stateBytes, err := state.marshal()
+ if err != nil {
+ return err
+ }
+ m.ticket, err = c.encryptTicket(stateBytes)
if err != nil {
return err
}
- hs.finishedHash.Write(m.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {
return err
}
func (hs *serverHandshakeState) sendFinished(out []byte) error {
c := hs.c
- if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
+ if err := c.writeChangeCipherRecord(); err != nil {
return err
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
- hs.finishedHash.Write(finished.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
return err
}
testClientHelloFailure(t, serverConfig, m, "")
}
+// testFatal is a hack to prevent the compiler from complaining that there is a
+// call to t.Fatal from a non-test goroutine
+func testFatal(t *testing.T, err error) {
+ t.Helper()
+ t.Fatal(err)
+}
+
func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) {
c, s := localPipe(t)
go func() {
if ch, ok := m.(*clientHelloMsg); ok {
cli.vers = ch.vers
}
- cli.writeRecord(recordTypeHandshake, m.marshal())
+ if _, err := cli.writeHandshakeRecord(m, nil); err != nil {
+ testFatal(t, err)
+ }
c.Close()
}()
ctx := context.Background()
go func() {
cli := Client(c, testConfig)
cli.vers = clientHello.vers
- cli.writeRecord(recordTypeHandshake, clientHello.marshal())
+ if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
+ testFatal(t, err)
+ }
buf := make([]byte, 1024)
n, err := c.Read(buf)
go func() {
cli := Client(c, testConfig)
cli.vers = clientHello.vers
- cli.writeRecord(recordTypeHandshake, clientHello.marshal())
- reply, err := cli.readHandshake()
+ if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
+ testFatal(t, err)
+ }
+ reply, err := cli.readHandshake(nil)
c.Close()
if err != nil {
replyChan <- err
go func() {
cli := Client(c, testConfig)
cli.vers = clientHello.vers
- cli.writeRecord(recordTypeHandshake, clientHello.marshal())
- reply, err := cli.readHandshake()
+ if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
+ testFatal(t, err)
+ }
+ reply, err := cli.readHandshake(nil)
c.Close()
if err != nil {
replyChan <- err
go func() {
cli := Client(c, testConfig)
cli.vers = clientHello.vers
- cli.writeRecord(recordTypeHandshake, clientHello.marshal())
+ if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
+ testFatal(t, err)
+ }
c.Close()
}()
conn := Server(s, serverConfig)
c.sendAlert(alertInternalError)
return errors.New("tls: internal error: failed to clone hash")
}
- transcript.Write(hs.clientHello.marshalWithoutBinders())
+ clientHelloBytes, err := hs.clientHello.marshalWithoutBinders()
+ if err != nil {
+ c.sendAlert(alertInternalError)
+ return err
+ }
+ transcript.Write(clientHelloBytes)
pskBinder := hs.suite.finishedHash(binderKey, transcript)
if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) {
c.sendAlert(alertDecryptError)
}
hs.sentDummyCCS = true
- _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
- return err
+ return hs.c.writeChangeCipherRecord()
}
func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error {
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. See RFC 8446, Section 4.4.1.
- hs.transcript.Write(hs.clientHello.marshal())
+ if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
+ return err
+ }
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
selectedGroup: selectedGroup,
}
- hs.transcript.Write(helloRetryRequest.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil {
return err
}
return err
}
- msg, err := c.readHandshake()
+ // clientHelloMsg is not included in the transcript.
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
c := hs.c
- hs.transcript.Write(hs.clientHello.marshal())
- hs.transcript.Write(hs.hello.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+ if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
+ return err
+ }
+ if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
- hs.transcript.Write(encryptedExtensions.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
return err
}
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
- hs.transcript.Write(certReq.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil {
return err
}
}
certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0
- hs.transcript.Write(certMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
}
certVerifyMsg.signature = sig
- hs.transcript.Write(certVerifyMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
- hs.transcript.Write(finished.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
finishedMsg := &finishedMsg{
verifyData: hs.clientFinished,
}
- hs.transcript.Write(finishedMsg.marshal())
+ if err := transcriptMsg(finishedMsg, hs.transcript); err != nil {
+ return err
+ }
if !hs.shouldSendSessionTickets() {
return nil
SignedCertificateTimestamps: c.scts,
},
}
- var err error
- m.label, err = c.encryptTicket(state.marshal())
+ stateBytes, err := state.marshal()
+ if err != nil {
+ c.sendAlert(alertInternalError)
+ return err
+ }
+ m.label, err = c.encryptTicket(stateBytes)
if err != nil {
return err
}
// ticket_nonce, which must be unique per connection, is always left at
// zero because we only ever send one ticket per connection.
- if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
+ if _, err := c.writeHandshakeRecord(m, nil); err != nil {
return err
}
// If we requested a client certificate, then the client must send a
// certificate message. If it's empty, no CertificateVerify is sent.
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
- hs.transcript.Write(certMsg.marshal())
if err := c.processCertsFromClient(certMsg.certificate); err != nil {
return err
}
if len(certMsg.certificate.Certificate) != 0 {
- msg, err = c.readHandshake()
+ // certificateVerifyMsg is included in the transcript, but not until
+ // after we verify the handshake signature, since the state before
+ // this message was sent is used.
+ msg, err = c.readHandshake(nil)
if err != nil {
return err
}
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
- hs.transcript.Write(certVerify.marshal())
+ if err := transcriptMsg(certVerify, hs.transcript); err != nil {
+ return err
+ }
}
// If we waited until the client certificates to send session tickets, we
func (hs *serverHandshakeStateTLS13) readClientFinished() error {
c := hs.c
- msg, err := c.readHandshake()
+ // finishedMsg is not included in the transcript.
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
"crypto/ecdh"
"crypto/hmac"
"errors"
+ "fmt"
"hash"
"io"
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(context)
})
+ hkdfLabelBytes, err := hkdfLabel.Bytes()
+ if err != nil {
+ // Rather than calling BytesOrPanic, we explicitly handle this error, in
+ // order to provide a reasonable error message. It should be basically
+ // impossible for this to panic, and routing errors back through the
+ // tree rooted in this function is quite painful. The labels are fixed
+ // size, and the context is either a fixed-length computed hash, or
+ // parsed from a field which has the same length limitation. As such, an
+ // error here is likely to only be caused during development.
+ //
+ // NOTE: another reasonable approach here might be to return a
+ // randomized slice if we encounter an error, which would break the
+ // connection, but avoid panicking. This would perhaps be safer but
+ // significantly more confusing to users.
+ panic(fmt.Errorf("failed to construct HKDF label: %s", err))
+ }
out := make([]byte, length)
- n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out)
+ n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out)
if err != nil || n != length {
panic("tls: HKDF-Expand-Label invocation failed unexpectedly")
}
usedOldKey bool
}
-func (m *sessionState) marshal() []byte {
+func (m *sessionState) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(m.vers)
b.AddUint16(m.cipherSuite)
})
}
})
- return b.BytesOrPanic()
+ return b.Bytes()
}
func (m *sessionState) unmarshal(data []byte) bool {
certificate Certificate // CertificateEntry certificate_list<0..2^24-1>;
}
-func (m *sessionStateTLS13) marshal() []byte {
+func (m *sessionStateTLS13) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(VersionTLS13)
b.AddUint8(0) // revision
b.AddBytes(m.resumptionSecret)
})
marshalCertificate(&b, m.certificate)
- return b.BytesOrPanic()
+ return b.Bytes()
}
func (m *sessionStateTLS13) unmarshal(data []byte) bool {