]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add server-side ECH
authorRoland Shoemaker <roland@golang.org>
Wed, 30 Oct 2024 03:22:27 +0000 (20:22 -0700)
committerGopher Robot <gobot@golang.org>
Thu, 21 Nov 2024 22:50:04 +0000 (22:50 +0000)
Adds support for server-side ECH.

We make a couple of implementation decisions that are not completely
in-line with the spec. In particular, we don't enforce that the SNI
matches the ECHConfig public_name, and we implement a hybrid
shared/backend mode (rather than shared or split mode, as described in
Section 7). Both of these match the behavior of BoringSSL.

The hybrid server mode will either act as a shared mode server, where-in
the server accepts "outer" client hellos and unwraps them before
processing the "inner" hello, or accepts bare "inner" hellos initially.
This lets the server operate either transparently as a shared mode
server, or a backend server, in Section 7 terminology. This seems like
the best implementation choice for a TLS library.

Fixes #68500

Change-Id: Ife69db7c1886610742e95e76b0ca92587e6d7ed4
Reviewed-on: https://go-review.googlesource.com/c/go/+/623576
Reviewed-by: Filippo Valsorda <filippo@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
Auto-Submit: Roland Shoemaker <roland@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
14 files changed:
api/next/68500.txt [new file with mode: 0644]
doc/next/6-stdlib/99-minor/crypto/tls/68500.md [new file with mode: 0644]
src/crypto/tls/bogo_config.json
src/crypto/tls/bogo_shim_test.go
src/crypto/tls/common.go
src/crypto/tls/ech.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_tls13.go
src/crypto/tls/handshake_messages.go
src/crypto/tls/handshake_messages_test.go
src/crypto/tls/handshake_server.go
src/crypto/tls/handshake_server_test.go
src/crypto/tls/handshake_server_tls13.go
src/crypto/tls/tls_test.go

diff --git a/api/next/68500.txt b/api/next/68500.txt
new file mode 100644 (file)
index 0000000..6c979c4
--- /dev/null
@@ -0,0 +1,5 @@
+pkg crypto/tls, type Config struct, EncryptedClientHelloKeys []EncryptedClientHelloKey #68500
+pkg crypto/tls, type EncryptedClientHelloKey struct #68500
+pkg crypto/tls, type EncryptedClientHelloKey struct, Config []uint8 #68500
+pkg crypto/tls, type EncryptedClientHelloKey struct, PrivateKey []uint8 #68500
+pkg crypto/tls, type EncryptedClientHelloKey struct, SendAsRetry bool #68500
diff --git a/doc/next/6-stdlib/99-minor/crypto/tls/68500.md b/doc/next/6-stdlib/99-minor/crypto/tls/68500.md
new file mode 100644 (file)
index 0000000..f1618ca
--- /dev/null
@@ -0,0 +1,2 @@
+The TLS server now supports Encrypted Client Hello (ECH). This feature can be
+enabled by populating the [Config.EncryptedClientHelloKeys] field.
\ No newline at end of file
index 2363dd5d659accd9560c560b97bbd48bfd92b4d2..cfd95792acc3dc856e2e04d73f6329d363f271bf 100644 (file)
@@ -12,7 +12,7 @@
 
         "TLS-ECH-Client-Reject-ResumeInnerSession-TLS12": "We won't attempt to negotiate 1.2 if ECH is enabled (we could possibly test this if we had the ability to indicate not to send ECH on resumption?)",
 
-        "TLS-ECH-Client-Reject-EarlyDataRejected": "We don't support switiching out ECH configs with this level of granularity",
+        "TLS-ECH-Client-Reject-EarlyDataRejected": "Go does not support early (0-RTT) data",
 
         "TLS-ECH-Client-NoNPN": "We don't support NPN",
 
         "TLS-ECH-Client-NoSupportedConfigs": "We don't support fallback to cleartext when there are no valid ECH configs",
         "TLS-ECH-Client-SkipInvalidPublicName": "We don't support fallback to cleartext when there are no valid ECH configs",
 
+        "TLS-ECH-Server-EarlyData": "Go does not support early (0-RTT) data",
+        "TLS-ECH-Server-EarlyDataRejected": "Go does not support early (0-RTT) data",
+
+        "CurveTest-Client-Kyber-TLS13": "Temporarily disabled since the curve ID is not exposed and it cannot be correctly configured",
+        "CurveTest-Server-Kyber-TLS13": "Temporarily disabled since the curve ID is not exposed and it cannot be correctly configured",
 
-        "*ECH-Server*": "no ECH server support",
         "SendV2ClientHello*": "We don't support SSLv2",
         "*QUIC*": "No QUIC support",
         "Compliance-fips*": "No FIPS",
         "EarlyData-UnexpectedHandshake-Server-TLS13": "TODO: first pass, this should be fixed",
         "EarlyData-CipherMismatch-Client-TLS13": "TODO: first pass, this should be fixed",
         "Resume-Server-UnofferedCipher-TLS13": "TODO: first pass, this should be fixed"
-    }
+    },
+    "AllCurves": [
+        23,
+        24,
+        25,
+        29
+    ]
 }
index bb78c6459b6d2fb5d0a755ce54f4d58856eebdbc..b294608b3ca857cce0c3f56de540190ea7997399 100644 (file)
@@ -76,6 +76,9 @@ var (
        onResumeExpectECHAccepted  = flag.Bool("on-resume-expect-ech-accept", false, "")
        _                          = flag.Bool("on-resume-expect-no-ech-name-override", false, "")
        expectedServerName         = flag.String("expect-server-name", "", "")
+       echServerConfig            = flagStringSlice("ech-server-config", "")
+       echServerKey               = flagStringSlice("ech-server-key", "")
+       echServerRetryConfig       = flagStringSlice("ech-is-retry-config", "")
 
        expectSessionMiss = flag.Bool("expect-session-miss", false, "")
 
@@ -105,12 +108,12 @@ func flagStringSlice(name, usage string) *stringSlice {
        return f
 }
 
-func (saf stringSlice) String() string {
-       return strings.Join(saf, ",")
+func (saf *stringSlice) String() string {
+       return strings.Join(*saf, ",")
 }
 
-func (saf stringSlice) Set(s string) error {
-       saf = append(saf, s)
+func (saf *stringSlice) Set(s string) error {
+       *saf = append(*saf, s)
        return nil
 }
 
@@ -248,6 +251,29 @@ func bogoShim() {
                }
        }
 
+       if len(*echServerConfig) != 0 {
+               if len(*echServerConfig) != len(*echServerKey) || len(*echServerConfig) != len(*echServerRetryConfig) {
+                       log.Fatal("-ech-server-config, -ech-server-key, and -ech-is-retry-config mismatch")
+               }
+
+               for i, c := range *echServerConfig {
+                       configBytes, err := base64.StdEncoding.DecodeString(c)
+                       if err != nil {
+                               log.Fatalf("parse ech-server-config err: %s", err)
+                       }
+                       privBytes, err := base64.StdEncoding.DecodeString((*echServerKey)[i])
+                       if err != nil {
+                               log.Fatalf("parse ech-server-key err: %s", err)
+                       }
+
+                       cfg.EncryptedClientHelloKeys = append(cfg.EncryptedClientHelloKeys, EncryptedClientHelloKey{
+                               Config:      configBytes,
+                               PrivateKey:  privBytes,
+                               SendAsRetry: (*echServerRetryConfig)[i] == "1",
+                       })
+               }
+       }
+
        for i := 0; i < *resumeCount+1; i++ {
                if i > 0 && (*onResumeECHConfigListB64 != "") {
                        echConfigList, err := base64.StdEncoding.DecodeString(*onResumeECHConfigListB64)
@@ -446,8 +472,11 @@ func TestBogoSuite(t *testing.T) {
        // are present in the output. They are only checked if -bogo-filter
        // was not passed.
        assertResults := map[string]string{
-               "CurveTest-Client-Kyber-TLS13": "PASS",
-               "CurveTest-Server-Kyber-TLS13": "PASS",
+               // TODO: these tests are temporarily disabled, since we don't expose the
+               // necessary curve ID, and it's currently not possible to correctly
+               // configure it.
+               // "CurveTest-Client-Kyber-TLS13": "PASS",
+               // "CurveTest-Server-Kyber-TLS13": "PASS",
        }
 
        for name, result := range results.Tests {
index 1f73e50d24870b5816d481e367b71c19596424a8..56f2acf520f29159fd3b732ea1b73129ff54d1d1 100644 (file)
@@ -791,8 +791,10 @@ type Config struct {
 
        // EncryptedClientHelloConfigList is a serialized ECHConfigList. If
        // provided, clients will attempt to connect to servers using Encrypted
-       // Client Hello (ECH) using one of the provided ECHConfigs. Servers
-       // currently ignore this field.
+       // Client Hello (ECH) using one of the provided ECHConfigs.
+       //
+       // Servers do not use this field. In order to configure ECH for servers, see
+       // the EncryptedClientHelloKeys field.
        //
        // If the list contains no valid ECH configs, the handshake will fail
        // and return an error.
@@ -810,9 +812,11 @@ type Config struct {
        EncryptedClientHelloConfigList []byte
 
        // EncryptedClientHelloRejectionVerify, if not nil, is called when ECH is
-       // rejected, in order to verify the ECH provider certificate in the outer
-       // Client Hello. If it returns a non-nil error, the handshake is aborted and
-       // that error results.
+       // rejected by the remote server, in order to verify the ECH provider
+       // certificate in the outer Client Hello. If it returns a non-nil error, the
+       // handshake is aborted and that error results.
+       //
+       // On the server side this field is not used.
        //
        // Unlike VerifyPeerCertificate and VerifyConnection, normal certificate
        // verification will not be performed before calling
@@ -824,6 +828,20 @@ type Config struct {
        // when ECH is rejected, even if set, and InsecureSkipVerify is ignored.
        EncryptedClientHelloRejectionVerify func(ConnectionState) error
 
+       // EncryptedClientHelloKeys are the ECH keys to use when a client
+       // attempts ECH.
+       //
+       // If EncryptedClientHelloKeys is set, MinVersion, if set, must be
+       // VersionTLS13.
+       //
+       // If a client attempts ECH, but it is rejected by the server, the server
+       // will send a list of configs to retry based on the set of
+       // EncryptedClientHelloKeys which have the SendAsRetry field set.
+       //
+       // On the client side, this field is ignored. In order to configure ECH for
+       // clients, see the EncryptedClientHelloConfigList field.
+       EncryptedClientHelloKeys []EncryptedClientHelloKey
+
        // mutex protects sessionTicketKeys and autoSessionTicketKeys.
        mutex sync.RWMutex
        // sessionTicketKeys contains zero or more ticket keys. If set, it means
@@ -837,6 +855,24 @@ type Config struct {
        autoSessionTicketKeys []ticketKey
 }
 
+// EncryptedClientHelloKey holds a private key that is associated
+// with a specific ECH config known to a client.
+type EncryptedClientHelloKey struct {
+       // Config should be a marshalled ECHConfig associated with PrivateKey. This
+       // must match the config provided to clients byte-for-byte. The config
+       // should only specify the DHKEM(X25519, HKDF-SHA256) KEM ID (0x0020), the
+       // HKDF-SHA256 KDF ID (0x0001), and a subset of the following AEAD IDs:
+       // AES-128-GCM (0x0000), AES-256-GCM (0x0001), ChaCha20Poly1305 (0x0002).
+       Config []byte
+       // PrivateKey should be a marshalled private key. Currently, we expect
+       // this to be the output of [ecdh.PrivateKey.Bytes].
+       PrivateKey []byte
+       // SendAsRetry indicates if Config should be sent as part of the list of
+       // retry configs when ECH is requested by the client but rejected by the
+       // server.
+       SendAsRetry bool
+}
+
 const (
        // ticketKeyLifetime is how long a ticket key remains valid and can be used to
        // resume a client connection.
@@ -913,6 +949,7 @@ func (c *Config) Clone() *Config {
                KeyLogWriter:                        c.KeyLogWriter,
                EncryptedClientHelloConfigList:      c.EncryptedClientHelloConfigList,
                EncryptedClientHelloRejectionVerify: c.EncryptedClientHelloRejectionVerify,
+               EncryptedClientHelloKeys:            c.EncryptedClientHelloKeys,
                sessionTicketKeys:                   c.sessionTicketKeys,
                autoSessionTicketKeys:               c.autoSessionTicketKeys,
        }
index 7bf68589f8702c474da6eb5838e1f7ee493aab21..55d52179c2c823d5176ba777a99cb2e2b5c4e1e9 100644 (file)
@@ -5,13 +5,28 @@
 package tls
 
 import (
+       "bytes"
        "crypto/internal/hpke"
        "errors"
+       "fmt"
+       "slices"
        "strings"
 
        "golang.org/x/crypto/cryptobyte"
 )
 
+// sortedSupportedAEADs is just a sorted version of hpke.SupportedAEADS.
+// We need this so that when we insert them into ECHConfigs the ordering
+// is stable.
+var sortedSupportedAEADs []uint16
+
+func init() {
+       for aeadID := range hpke.SupportedAEADs {
+               sortedSupportedAEADs = append(sortedSupportedAEADs, aeadID)
+       }
+       slices.Sort(sortedSupportedAEADs)
+}
+
 type echCipher struct {
        KDFID  uint16
        AEADID uint16
@@ -40,12 +55,77 @@ type echConfig struct {
 
 var errMalformedECHConfig = errors.New("tls: malformed ECHConfigList")
 
+func parseECHConfig(enc []byte) (skip bool, ec echConfig, err error) {
+       s := cryptobyte.String(enc)
+       ec.raw = []byte(enc)
+       if !s.ReadUint16(&ec.Version) {
+               return false, echConfig{}, errMalformedECHConfig
+       }
+       if !s.ReadUint16(&ec.Length) {
+               return false, echConfig{}, errMalformedECHConfig
+       }
+       if len(ec.raw) < int(ec.Length)+4 {
+               return false, echConfig{}, errMalformedECHConfig
+       }
+       ec.raw = ec.raw[:ec.Length+4]
+       if ec.Version != extensionEncryptedClientHello {
+               s.Skip(int(ec.Length))
+               return true, echConfig{}, nil
+       }
+       if !s.ReadUint8(&ec.ConfigID) {
+               return false, echConfig{}, errMalformedECHConfig
+       }
+       if !s.ReadUint16(&ec.KemID) {
+               return false, echConfig{}, errMalformedECHConfig
+       }
+       if !readUint16LengthPrefixed(&s, &ec.PublicKey) {
+               return false, echConfig{}, errMalformedECHConfig
+       }
+       var cipherSuites cryptobyte.String
+       if !s.ReadUint16LengthPrefixed(&cipherSuites) {
+               return false, echConfig{}, errMalformedECHConfig
+       }
+       for !cipherSuites.Empty() {
+               var c echCipher
+               if !cipherSuites.ReadUint16(&c.KDFID) {
+                       return false, echConfig{}, errMalformedECHConfig
+               }
+               if !cipherSuites.ReadUint16(&c.AEADID) {
+                       return false, echConfig{}, errMalformedECHConfig
+               }
+               ec.SymmetricCipherSuite = append(ec.SymmetricCipherSuite, c)
+       }
+       if !s.ReadUint8(&ec.MaxNameLength) {
+               return false, echConfig{}, errMalformedECHConfig
+       }
+       var publicName cryptobyte.String
+       if !s.ReadUint8LengthPrefixed(&publicName) {
+               return false, echConfig{}, errMalformedECHConfig
+       }
+       ec.PublicName = publicName
+       var extensions cryptobyte.String
+       if !s.ReadUint16LengthPrefixed(&extensions) {
+               return false, echConfig{}, errMalformedECHConfig
+       }
+       for !extensions.Empty() {
+               var e echExtension
+               if !extensions.ReadUint16(&e.Type) {
+                       return false, echConfig{}, errMalformedECHConfig
+               }
+               if !extensions.ReadUint16LengthPrefixed((*cryptobyte.String)(&e.Data)) {
+                       return false, echConfig{}, errMalformedECHConfig
+               }
+               ec.Extensions = append(ec.Extensions, e)
+       }
+
+       return false, ec, nil
+}
+
 // parseECHConfigList parses a draft-ietf-tls-esni-18 ECHConfigList, returning a
 // slice of parsed ECHConfigs, in the same order they were parsed, or an error
 // if the list is malformed.
 func parseECHConfigList(data []byte) ([]echConfig, error) {
        s := cryptobyte.String(data)
-       // Skip the length prefix
        var length uint16
        if !s.ReadUint16(&length) {
                return nil, errMalformedECHConfig
@@ -55,69 +135,18 @@ func parseECHConfigList(data []byte) ([]echConfig, error) {
        }
        var configs []echConfig
        for len(s) > 0 {
-               var ec echConfig
-               ec.raw = []byte(s)
-               if !s.ReadUint16(&ec.Version) {
-                       return nil, errMalformedECHConfig
-               }
-               if !s.ReadUint16(&ec.Length) {
-                       return nil, errMalformedECHConfig
-               }
-               if len(ec.raw) < int(ec.Length)+4 {
-                       return nil, errMalformedECHConfig
-               }
-               ec.raw = ec.raw[:ec.Length+4]
-               if ec.Version != extensionEncryptedClientHello {
-                       s.Skip(int(ec.Length))
-                       continue
-               }
-               if !s.ReadUint8(&ec.ConfigID) {
-                       return nil, errMalformedECHConfig
-               }
-               if !s.ReadUint16(&ec.KemID) {
-                       return nil, errMalformedECHConfig
-               }
-               if !s.ReadUint16LengthPrefixed((*cryptobyte.String)(&ec.PublicKey)) {
-                       return nil, errMalformedECHConfig
-               }
-               var cipherSuites cryptobyte.String
-               if !s.ReadUint16LengthPrefixed(&cipherSuites) {
-                       return nil, errMalformedECHConfig
-               }
-               for !cipherSuites.Empty() {
-                       var c echCipher
-                       if !cipherSuites.ReadUint16(&c.KDFID) {
-                               return nil, errMalformedECHConfig
-                       }
-                       if !cipherSuites.ReadUint16(&c.AEADID) {
-                               return nil, errMalformedECHConfig
-                       }
-                       ec.SymmetricCipherSuite = append(ec.SymmetricCipherSuite, c)
-               }
-               if !s.ReadUint8(&ec.MaxNameLength) {
-                       return nil, errMalformedECHConfig
+               if len(s) < 4 {
+                       return nil, errors.New("tls: malformed ECHConfig")
                }
-               var publicName cryptobyte.String
-               if !s.ReadUint8LengthPrefixed(&publicName) {
-                       return nil, errMalformedECHConfig
+               configLen := uint16(s[2])<<8 | uint16(s[3])
+               skip, ec, err := parseECHConfig(s)
+               if err != nil {
+                       return nil, err
                }
-               ec.PublicName = publicName
-               var extensions cryptobyte.String
-               if !s.ReadUint16LengthPrefixed(&extensions) {
-                       return nil, errMalformedECHConfig
+               s = s[configLen+4:]
+               if !skip {
+                       configs = append(configs, ec)
                }
-               for !extensions.Empty() {
-                       var e echExtension
-                       if !extensions.ReadUint16(&e.Type) {
-                               return nil, errMalformedECHConfig
-                       }
-                       if !extensions.ReadUint16LengthPrefixed((*cryptobyte.String)(&e.Data)) {
-                               return nil, errMalformedECHConfig
-                       }
-                       ec.Extensions = append(ec.Extensions, e)
-               }
-
-               configs = append(configs, ec)
        }
        return configs, nil
 }
@@ -195,6 +224,175 @@ func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, e
        return append(h, make([]byte, paddingLen)...), nil
 }
 
+func skipUint8LengthPrefixed(s *cryptobyte.String) bool {
+       var skip uint8
+       if !s.ReadUint8(&skip) {
+               return false
+       }
+       return s.Skip(int(skip))
+}
+
+func skipUint16LengthPrefixed(s *cryptobyte.String) bool {
+       var skip uint16
+       if !s.ReadUint16(&skip) {
+               return false
+       }
+       return s.Skip(int(skip))
+}
+
+type rawExtension struct {
+       extType uint16
+       data    []byte
+}
+
+func extractRawExtensions(hello *clientHelloMsg) ([]rawExtension, error) {
+       s := cryptobyte.String(hello.original)
+       if !s.Skip(4+2+32) || // header, version, random
+               !skipUint8LengthPrefixed(&s) || // session ID
+               !skipUint16LengthPrefixed(&s) || // cipher suites
+               !skipUint8LengthPrefixed(&s) { // compression methods
+               return nil, errors.New("tls: malformed outer client hello")
+       }
+       var rawExtensions []rawExtension
+       var extensions cryptobyte.String
+       if !s.ReadUint16LengthPrefixed(&extensions) {
+               return nil, errors.New("tls: malformed outer client hello")
+       }
+
+       for !extensions.Empty() {
+               var extension uint16
+               var extData cryptobyte.String
+               if !extensions.ReadUint16(&extension) ||
+                       !extensions.ReadUint16LengthPrefixed(&extData) {
+                       return nil, errors.New("tls: invalid inner client hello")
+               }
+               rawExtensions = append(rawExtensions, rawExtension{extension, extData})
+       }
+       return rawExtensions, nil
+}
+
+func decodeInnerClientHello(outer *clientHelloMsg, encoded []byte) (*clientHelloMsg, error) {
+       // Reconstructing the inner client hello from its encoded form is somewhat
+       // complicated. It is missing its header (message type and length), session
+       // ID, and the extensions may be compressed. Since we need to put the
+       // extensions back in the same order as they were in the raw outer hello,
+       // and since we don't store the raw extensions, or the order we parsed them
+       // in, we need to reparse the raw extensions from the outer hello in order
+       // to properly insert them into the inner hello. This _should_ result in raw
+       // bytes which match the hello as it was generated by the client.
+       innerReader := cryptobyte.String(encoded)
+       var versionAndRandom, sessionID, cipherSuites, compressionMethods []byte
+       var extensions cryptobyte.String
+       if !innerReader.ReadBytes(&versionAndRandom, 2+32) ||
+               !readUint8LengthPrefixed(&innerReader, &sessionID) ||
+               len(sessionID) != 0 ||
+               !readUint16LengthPrefixed(&innerReader, &cipherSuites) ||
+               !readUint8LengthPrefixed(&innerReader, &compressionMethods) ||
+               !innerReader.ReadUint16LengthPrefixed(&extensions) {
+               return nil, errors.New("tls: invalid inner client hello")
+       }
+
+       // The specification says we must verify that the trailing padding is all
+       // zeros. This is kind of weird for TLS messages, where we generally just
+       // throw away any trailing garbage.
+       for _, p := range innerReader {
+               if p != 0 {
+                       return nil, errors.New("tls: invalid inner client hello")
+               }
+       }
+
+       rawOuterExts, err := extractRawExtensions(outer)
+       if err != nil {
+               return nil, err
+       }
+
+       recon := cryptobyte.NewBuilder(nil)
+       recon.AddUint8(typeClientHello)
+       recon.AddUint24LengthPrefixed(func(recon *cryptobyte.Builder) {
+               recon.AddBytes(versionAndRandom)
+               recon.AddUint8LengthPrefixed(func(recon *cryptobyte.Builder) {
+                       recon.AddBytes(outer.sessionId)
+               })
+               recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
+                       recon.AddBytes(cipherSuites)
+               })
+               recon.AddUint8LengthPrefixed(func(recon *cryptobyte.Builder) {
+                       recon.AddBytes(compressionMethods)
+               })
+               recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
+                       for !extensions.Empty() {
+                               var extension uint16
+                               var extData cryptobyte.String
+                               if !extensions.ReadUint16(&extension) ||
+                                       !extensions.ReadUint16LengthPrefixed(&extData) {
+                                       recon.SetError(errors.New("tls: invalid inner client hello"))
+                                       return
+                               }
+                               if extension == extensionECHOuterExtensions {
+                                       if !extData.ReadUint8LengthPrefixed(&extData) {
+                                               recon.SetError(errors.New("tls: invalid inner client hello"))
+                                               return
+                                       }
+                                       var i int
+                                       for !extData.Empty() {
+                                               var extType uint16
+                                               if !extData.ReadUint16(&extType) {
+                                                       recon.SetError(errors.New("tls: invalid inner client hello"))
+                                                       return
+                                               }
+                                               if extType == extensionEncryptedClientHello {
+                                                       recon.SetError(errors.New("tls: invalid outer extensions"))
+                                                       return
+                                               }
+                                               for ; i <= len(rawOuterExts); i++ {
+                                                       if i == len(rawOuterExts) {
+                                                               recon.SetError(errors.New("tls: invalid outer extensions"))
+                                                               return
+                                                       }
+                                                       if rawOuterExts[i].extType == extType {
+                                                               break
+                                                       }
+                                               }
+                                               recon.AddUint16(rawOuterExts[i].extType)
+                                               recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
+                                                       recon.AddBytes(rawOuterExts[i].data)
+                                               })
+                                       }
+                               } else {
+                                       recon.AddUint16(extension)
+                                       recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
+                                               recon.AddBytes(extData)
+                                       })
+                               }
+                       }
+               })
+       })
+
+       reconBytes, err := recon.Bytes()
+       if err != nil {
+               return nil, err
+       }
+       inner := &clientHelloMsg{}
+       if !inner.unmarshal(reconBytes) {
+               return nil, errors.New("tls: invalid reconstructed inner client hello")
+       }
+
+       if !bytes.Equal(inner.encryptedClientHello, []byte{uint8(innerECHExt)}) {
+               return nil, errors.New("tls: client sent invalid encrypted_client_hello extension")
+       }
+
+       if len(inner.supportedVersions) != 1 || (len(inner.supportedVersions) >= 1 && inner.supportedVersions[0] != VersionTLS13) {
+               return nil, errors.New("tls: client sent encrypted_client_hello extension and offered incompatible versions")
+       }
+
+       return inner, nil
+}
+
+func decryptECHPayload(context *hpke.Receipient, hello, payload []byte) ([]byte, error) {
+       outerAAD := bytes.Replace(hello[4:], payload, make([]byte, len(payload)), 1)
+       return context.Open(outerAAD, payload)
+}
+
 func generateOuterECHExt(id uint8, kdfID, aeadID uint16, encodedKey []byte, payload []byte) ([]byte, error) {
        var b cryptobyte.Builder
        b.AddUint8(0) // outer
@@ -206,7 +404,7 @@ func generateOuterECHExt(id uint8, kdfID, aeadID uint16, encodedKey []byte, payl
        return b.Bytes()
 }
 
-func computeAndUpdateOuterECHExtension(outer, inner *clientHelloMsg, ech *echContext, useKey bool) error {
+func computeAndUpdateOuterECHExtension(outer, inner *clientHelloMsg, ech *echClientContext, useKey bool) error {
        var encapKey []byte
        if useKey {
                encapKey = ech.encapsulatedKey
@@ -281,3 +479,153 @@ type ECHRejectionError struct {
 func (e *ECHRejectionError) Error() string {
        return "tls: server rejected ECH"
 }
+
+var errMalformedECHExt = errors.New("tls: malformed encrypted_client_hello extension")
+
+type echExtType uint8
+
+const (
+       innerECHExt echExtType = 1
+       outerECHExt echExtType = 0
+)
+
+func parseECHExt(ext []byte) (echType echExtType, cs echCipher, configID uint8, encap []byte, payload []byte, err error) {
+       data := make([]byte, len(ext))
+       copy(data, ext)
+       s := cryptobyte.String(data)
+       var echInt uint8
+       if !s.ReadUint8(&echInt) {
+               err = errMalformedECHExt
+               return
+       }
+       echType = echExtType(echInt)
+       if echType == innerECHExt {
+               if !s.Empty() {
+                       err = errMalformedECHExt
+                       return
+               }
+               return echType, cs, 0, nil, nil, nil
+       }
+       if echType != outerECHExt {
+               err = errMalformedECHExt
+               return
+       }
+       if !s.ReadUint16(&cs.KDFID) {
+               err = errMalformedECHExt
+               return
+       }
+       if !s.ReadUint16(&cs.AEADID) {
+               err = errMalformedECHExt
+               return
+       }
+       if !s.ReadUint8(&configID) {
+               err = errMalformedECHExt
+               return
+       }
+       if !readUint16LengthPrefixed(&s, &encap) {
+               err = errMalformedECHExt
+               return
+       }
+       if !readUint16LengthPrefixed(&s, &payload) {
+               err = errMalformedECHExt
+               return
+       }
+
+       // NOTE: clone encap and payload so that mutating them does not mutate the
+       // raw extension bytes.
+       return echType, cs, configID, bytes.Clone(encap), bytes.Clone(payload), nil
+}
+
+func marshalEncryptedClientHelloConfigList(configs []EncryptedClientHelloKey) ([]byte, error) {
+       builder := cryptobyte.NewBuilder(nil)
+       builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
+               for _, c := range configs {
+                       builder.AddBytes(c.Config)
+               }
+       })
+       return builder.Bytes()
+}
+
+func (c *Conn) processECHClientHello(outer *clientHelloMsg) (*clientHelloMsg, *echServerContext, error) {
+       echType, echCiphersuite, configID, encap, payload, err := parseECHExt(outer.encryptedClientHello)
+       if err != nil {
+               c.sendAlert(alertDecodeError)
+               return nil, nil, errors.New("tls: client sent invalid encrypted_client_hello extension")
+       }
+
+       if echType == innerECHExt {
+               return outer, &echServerContext{inner: true}, nil
+       }
+
+       if len(c.config.EncryptedClientHelloKeys) == 0 {
+               return outer, nil, nil
+       }
+
+       for _, echKey := range c.config.EncryptedClientHelloKeys {
+               skip, config, err := parseECHConfig(echKey.Config)
+               if err != nil || skip {
+                       c.sendAlert(alertInternalError)
+                       return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys Config: %s", err)
+               }
+               if skip {
+                       continue
+               }
+               echPriv, err := hpke.ParseHPKEPrivateKey(config.KemID, echKey.PrivateKey)
+               if err != nil {
+                       c.sendAlert(alertInternalError)
+                       return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys PrivateKey: %s", err)
+               }
+               info := append([]byte("tls ech\x00"), echKey.Config...)
+               hpkeContext, err := hpke.SetupReceipient(hpke.DHKEM_X25519_HKDF_SHA256, echCiphersuite.KDFID, echCiphersuite.AEADID, echPriv, info, encap)
+               if err != nil {
+                       // attempt next trial decryption
+                       continue
+               }
+
+               encodedInner, err := decryptECHPayload(hpkeContext, outer.original, payload)
+               if err != nil {
+                       // attempt next trial decryption
+                       continue
+               }
+
+               // NOTE: we do not enforce that the sent server_name matches the ECH
+               // configs PublicName, since this is not particularly important, and
+               // the client already had to know what it was in order to properly
+               // encrypt the payload. This is only a MAY in the spec, so we're not
+               // doing anything revolutionary.
+
+               echInner, err := decodeInnerClientHello(outer, encodedInner)
+               if err != nil {
+                       c.sendAlert(alertIllegalParameter)
+                       return nil, nil, errors.New("tls: client sent invalid encrypted_client_hello extension")
+               }
+
+               c.echAccepted = true
+
+               return echInner, &echServerContext{
+                       hpkeContext: hpkeContext,
+                       configID:    configID,
+                       ciphersuite: echCiphersuite,
+               }, nil
+       }
+
+       return outer, nil, nil
+}
+
+func buildRetryConfigList(keys []EncryptedClientHelloKey) ([]byte, error) {
+       var atLeastOneRetryConfig bool
+       var retryBuilder cryptobyte.Builder
+       retryBuilder.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+               for _, c := range keys {
+                       if !c.SendAsRetry {
+                               continue
+                       }
+                       atLeastOneRetryConfig = true
+                       b.AddBytes(c.Config)
+               }
+       })
+       if !atLeastOneRetryConfig {
+               return nil, nil
+       }
+       return retryBuilder.Bytes()
+}
index 2ee1136b790120acb20577ce34fffc77ec5c5586..548b5f0acdb1ae4036c74fa6dbfdea2875959ce5 100644 (file)
@@ -43,7 +43,7 @@ type clientHandshakeState struct {
 
 var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme
 
-func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echContext, error) {
+func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echClientContext, error) {
        config := c.config
        if len(config.ServerName) == 0 && !config.InsecureSkipVerify {
                return nil, nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
@@ -201,7 +201,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon
                hello.quicTransportParameters = p
        }
 
-       var ech *echContext
+       var ech *echClientContext
        if c.config.EncryptedClientHelloConfigList != nil {
                if c.config.MinVersion != 0 && c.config.MinVersion < VersionTLS13 {
                        return nil, nil, nil, errors.New("tls: MinVersion must be >= VersionTLS13 if EncryptedClientHelloConfigList is populated")
@@ -217,7 +217,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon
                if echConfig == nil {
                        return nil, nil, nil, errors.New("tls: EncryptedClientHelloConfigList contains no valid configs")
                }
-               ech = &echContext{config: echConfig}
+               ech = &echClientContext{config: echConfig}
                hello.encryptedClientHello = []byte{1} // indicate inner hello
                // We need to explicitly set these 1.2 fields to nil, as we do not
                // marshal them when encoding the inner hello, otherwise transcripts
@@ -246,7 +246,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon
        return hello, keyShareKeys, ech, nil
 }
 
-type echContext struct {
+type echClientContext struct {
        config          *echConfig
        hpkeContext     *hpke.Sender
        encapsulatedKey []byte
index 53f16651661d57aa1321205ce0da49cfac822ef7..3f4cadb675ea225aff1ad08991ec4b0aac83517c 100644 (file)
@@ -39,7 +39,7 @@ type clientHandshakeStateTLS13 struct {
        masterSecret  *tls13.MasterSecret
        trafficSecret []byte // client_application_traffic_secret_0
 
-       echContext *echContext
+       echContext *echClientContext
 }
 
 // handshake requires hs.c, hs.hello, hs.serverHello, hs.keyShareKeys, and,
@@ -105,7 +105,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
 
                        if hs.serverHello.encryptedClientHello != nil {
                                c.sendAlert(alertUnsupportedExtension)
-                               return errors.New("tls: unexpected encrypted_client_hello extension in server hello despite ECH being accepted")
+                               return errors.New("tls: unexpected encrypted client hello extension in server hello despite ECH being accepted")
                        }
 
                        if hs.hello.serverName == "" && hs.serverHello.serverNameAck {
@@ -288,7 +288,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
        } else if hs.serverHello.encryptedClientHello != nil {
                // Unsolicited ECH extension should be rejected
                c.sendAlert(alertUnsupportedExtension)
-               return errors.New("tls: unexpected ECH extension in serverHello")
+               return errors.New("tls: unexpected encrypted client hello extension in serverHello")
        }
 
        // The only HelloRetryRequest extensions we support are key_share and
@@ -604,7 +604,7 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
        }
        if hs.echContext != nil && !hs.echContext.echRejected && encryptedExtensions.echRetryConfigs != nil {
                c.sendAlert(alertUnsupportedExtension)
-               return errors.New("tls: server sent ECH retry configs after accepting ECH")
+               return errors.New("tls: server sent encrypted client hello retry configs after accepting encrypted client hello")
        }
 
        return nil
index 823caff603ef1e7e78431b2041643a7096df25b7..fa00d7b741100eac23a40a9096f5e44c3e1dcbff 100644 (file)
@@ -97,7 +97,7 @@ type clientHelloMsg struct {
        pskBinders                       [][]byte
        quicTransportParameters          []byte
        encryptedClientHello             []byte
-       // extensions are only populated on the server-side of a handshake
+       // extensions are only populated on the servers-ide of a handshake
        extensions []uint16
 }
 
@@ -662,6 +662,10 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
                                }
                                m.pskBinders = append(m.pskBinders, binder)
                        }
+               case extensionEncryptedClientHello:
+                       if !extData.ReadBytes(&m.encryptedClientHello, len(extData)) {
+                               return false
+                       }
                default:
                        // Ignore unknown extensions.
                        continue
index 2c360e6a50dfc560f32211f9baaff8bb95a55f4a..e4112bfc3e747523d0a05fc03b3a57b22f1fe69b 100644 (file)
@@ -232,6 +232,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        if rand.Intn(10) > 5 {
                m.earlyData = true
        }
+       if rand.Intn(10) > 5 {
+               m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
+       }
 
        return reflect.ValueOf(m)
 }
index 6fb1755a2f426149550152acd0010b65b6411da2..6fe962869118a4e7c7c8776c47d03a50b6b191ce 100644 (file)
@@ -41,7 +41,7 @@ type serverHandshakeState struct {
 
 // serverHandshake performs a TLS handshake as a server.
 func (c *Conn) serverHandshake(ctx context.Context) error {
-       clientHello, err := c.readClientHello(ctx)
+       clientHello, ech, err := c.readClientHello(ctx)
        if err != nil {
                return err
        }
@@ -51,6 +51,7 @@ func (c *Conn) serverHandshake(ctx context.Context) error {
                        c:           c,
                        ctx:         ctx,
                        clientHello: clientHello,
+                       echContext:  ech,
                }
                return hs.handshake()
        }
@@ -131,17 +132,27 @@ func (hs *serverHandshakeState) handshake() error {
 }
 
 // readClientHello reads a ClientHello message and selects the protocol version.
-func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
+func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, *echServerContext, error) {
        // clientHelloMsg is included in the transcript, but we haven't initialized
        // it yet. The respective handshake functions will record it themselves.
        msg, err := c.readHandshake(nil)
        if err != nil {
-               return nil, err
+               return nil, nil, err
        }
        clientHello, ok := msg.(*clientHelloMsg)
        if !ok {
                c.sendAlert(alertUnexpectedMessage)
-               return nil, unexpectedMessageError(clientHello, msg)
+               return nil, nil, unexpectedMessageError(clientHello, msg)
+       }
+
+       // ECH processing has to be done before we do any other negotiation based on
+       // the contents of the client hello, since we may swap it out completely.
+       var ech *echServerContext
+       if len(clientHello.encryptedClientHello) != 0 {
+               clientHello, ech, err = c.processECHClientHello(clientHello)
+               if err != nil {
+                       return nil, nil, err
+               }
        }
 
        var configForClient *Config
@@ -150,7 +161,7 @@ func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
                chi := clientHelloInfo(ctx, c, clientHello)
                if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
                        c.sendAlert(alertInternalError)
-                       return nil, err
+                       return nil, nil, err
                } else if configForClient != nil {
                        c.config = configForClient
                }
@@ -164,18 +175,30 @@ func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
        c.vers, ok = c.config.mutualVersion(roleServer, clientVersions)
        if !ok {
                c.sendAlert(alertProtocolVersion)
-               return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions)
+               return nil, nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions)
        }
        c.haveVers = true
        c.in.version = c.vers
        c.out.version = c.vers
 
+       // This check reflects some odd specification implied behavior. Client-facing servers
+       // are supposed to reject hellos with outer ECH and inner ECH that offers 1.2, but
+       // backend servers are allowed to accept hellos with inner ECH that offer 1.2, since
+       // they cannot expect client-facing servers to behave properly. Since we act as both
+       // a client-facing and backend server, we only enforce 1.3 being negotiated if we
+       // saw a hello with outer ECH first. The spec probably should've made this an error,
+       // but it didn't, and this matches the boringssl behavior.
+       if c.vers != VersionTLS13 && (ech != nil && !ech.inner) {
+               c.sendAlert(alertIllegalParameter)
+               return nil, nil, errors.New("tls: Encrypted Client Hello cannot be used pre-TLS 1.3")
+       }
+
        if c.config.MinVersion == 0 && c.vers < VersionTLS12 {
                tls10server.Value() // ensure godebug is initialized
                tls10server.IncNonDefault()
        }
 
-       return clientHello, nil
+       return clientHello, ech, nil
 }
 
 func (hs *serverHandshakeState) processClientHello() error {
index 01eae15a6b98c09e22ff066b7d43e0448ee52e81..84b086f05159a7dd8f1016bd960e0a9dd83fa141 100644 (file)
@@ -54,12 +54,13 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
        }()
        ctx := context.Background()
        conn := Server(s, serverConfig)
-       ch, err := conn.readClientHello(ctx)
+       ch, ech, err := conn.readClientHello(ctx)
        if conn.vers == VersionTLS13 {
                hs := serverHandshakeStateTLS13{
                        c:           conn,
                        ctx:         ctx,
                        clientHello: ch,
+                       echContext:  ech,
                }
                if err == nil {
                        err = hs.processClientHello()
@@ -1518,7 +1519,7 @@ func TestSNIGivenOnFailure(t *testing.T) {
        }()
        conn := Server(s, serverConfig)
        ctx := context.Background()
-       ch, err := conn.readClientHello(ctx)
+       ch, _, err := conn.readClientHello(ctx)
        hs := serverHandshakeState{
                c:           conn,
                ctx:         ctx,
index 64c6b1349cdb385f6c95fffe5a213b3e34039865..76521d8b4706dfa549f1820d03d239c9bcab2273 100644 (file)
@@ -9,8 +9,10 @@ import (
        "context"
        "crypto"
        "crypto/hmac"
+       "crypto/internal/fips140/hkdf"
        "crypto/internal/fips140/mlkem"
        "crypto/internal/fips140/tls13"
+       "crypto/internal/hpke"
        "crypto/rsa"
        "crypto/tls/internal/fips140tls"
        "errors"
@@ -26,6 +28,18 @@ import (
 // messages cause too much work in session ticket decryption attempts.
 const maxClientPSKIdentities = 5
 
+type echServerContext struct {
+       hpkeContext *hpke.Receipient
+       configID    uint8
+       ciphersuite echCipher
+       transcript  hash.Hash
+       // inner indicates that the initial client_hello we recieved contained an
+       // encrypted_client_hello extension that indicated it was an "inner" hello.
+       // We don't do any additional processing of the hello in this case, so all
+       // fields above are unset.
+       inner bool
+}
+
 type serverHandshakeStateTLS13 struct {
        c               *Conn
        ctx             context.Context
@@ -44,6 +58,7 @@ type serverHandshakeStateTLS13 struct {
        trafficSecret   []byte // client_application_traffic_secret_0
        transcript      hash.Hash
        clientFinished  []byte
+       echContext      *echServerContext
 }
 
 func (hs *serverHandshakeStateTLS13) handshake() error {
@@ -531,6 +546,22 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID)
                selectedGroup:     selectedGroup,
        }
 
+       if hs.echContext != nil {
+               // Compute the acceptance message.
+               helloRetryRequest.encryptedClientHello = make([]byte, 8)
+               confTranscript := cloneHash(hs.transcript, hs.suite.hash)
+               if err := transcriptMsg(helloRetryRequest, confTranscript); err != nil {
+                       return nil, err
+               }
+               acceptConfirmation := tls13.ExpandLabel(hs.suite.hash.New,
+                       hkdf.Extract(hs.suite.hash.New, hs.clientHello.random, nil),
+                       "hrr ech accept confirmation",
+                       confTranscript.Sum(nil),
+                       8,
+               )
+               helloRetryRequest.encryptedClientHello = acceptConfirmation
+       }
+
        if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil {
                return nil, err
        }
@@ -551,6 +582,45 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID)
                return nil, unexpectedMessageError(clientHello, msg)
        }
 
+       if hs.echContext != nil {
+               if len(clientHello.encryptedClientHello) == 0 {
+                       c.sendAlert(alertMissingExtension)
+                       return nil, errors.New("tls: second client hello missing encrypted client hello extension")
+               }
+
+               echType, echCiphersuite, configID, encap, payload, err := parseECHExt(clientHello.encryptedClientHello)
+               if err != nil {
+                       c.sendAlert(alertDecodeError)
+                       return nil, errors.New("tls: client sent invalid encrypted client hello extension")
+               }
+
+               if echType == outerECHExt && hs.echContext.inner || echType == innerECHExt && !hs.echContext.inner {
+                       c.sendAlert(alertDecodeError)
+                       return nil, errors.New("tls: unexpected switch in encrypted client hello extension type")
+               }
+
+               if echType == outerECHExt {
+                       if echCiphersuite != hs.echContext.ciphersuite || configID != hs.echContext.configID || len(encap) != 0 {
+                               c.sendAlert(alertIllegalParameter)
+                               return nil, errors.New("tls: second client hello encrypted client hello extension does not match")
+                       }
+
+                       encodedInner, err := decryptECHPayload(hs.echContext.hpkeContext, clientHello.original, payload)
+                       if err != nil {
+                               c.sendAlert(alertDecryptError)
+                               return nil, errors.New("tls: failed to decrypt second client hello encrypted client hello extension payload")
+                       }
+
+                       echInner, err := decodeInnerClientHello(clientHello, encodedInner)
+                       if err != nil {
+                               c.sendAlert(alertIllegalParameter)
+                               return nil, errors.New("tls: client sent invalid encrypted client hello extension")
+                       }
+
+                       clientHello = echInner
+               }
+       }
+
        if len(clientHello.keyShares) != 1 {
                c.sendAlert(alertIllegalParameter)
                return nil, errors.New("tls: client didn't send one key share in second ClientHello")
@@ -638,9 +708,27 @@ func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool {
 func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
        c := hs.c
 
+       if hs.echContext != nil {
+               copy(hs.hello.random[32-8:], make([]byte, 8))
+               echTranscript := cloneHash(hs.transcript, hs.suite.hash)
+               echTranscript.Write(hs.clientHello.original)
+               if err := transcriptMsg(hs.hello, echTranscript); err != nil {
+                       return err
+               }
+               // compute the acceptance message
+               acceptConfirmation := tls13.ExpandLabel(hs.suite.hash.New,
+                       hkdf.Extract(hs.suite.hash.New, hs.clientHello.random, nil),
+                       "ech accept confirmation",
+                       echTranscript.Sum(nil),
+                       8,
+               )
+               copy(hs.hello.random[32-8:], acceptConfirmation)
+       }
+
        if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
                return err
        }
+
        if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
                return err
        }
@@ -691,6 +779,16 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
                encryptedExtensions.earlyData = hs.earlyData
        }
 
+       // If client sent ECH extension, but we didn't accept it,
+       // send retry configs, if available.
+       if len(hs.c.config.EncryptedClientHelloKeys) > 0 && len(hs.clientHello.encryptedClientHello) > 0 && hs.echContext == nil {
+               encryptedExtensions.echRetryConfigs, err = buildRetryConfigList(hs.c.config.EncryptedClientHelloKeys)
+               if err != nil {
+                       c.sendAlert(alertInternalError)
+                       return err
+               }
+       }
+
        if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
                return err
        }
index a6c03361e92f7e0e4de6356c12ca6b6055661aa0..394d517bc4c75f23e325fe42b5009fb87a6e347c 100644 (file)
@@ -8,8 +8,10 @@ import (
        "bytes"
        "context"
        "crypto"
+       "crypto/ecdh"
        "crypto/ecdsa"
        "crypto/elliptic"
+       "crypto/internal/hpke"
        "crypto/rand"
        "crypto/x509"
        "crypto/x509/pkix"
@@ -29,6 +31,8 @@ import (
        "strings"
        "testing"
        "time"
+
+       "golang.org/x/crypto/cryptobyte"
 )
 
 var rsaCertPEM = `-----BEGIN CERTIFICATE-----
@@ -880,6 +884,10 @@ func TestCloneNonFuncFields(t *testing.T) {
                        f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
                case "EncryptedClientHelloConfigList":
                        f.Set(reflect.ValueOf([]byte{'x'}))
+               case "EncryptedClientHelloKeys":
+                       f.Set(reflect.ValueOf([]EncryptedClientHelloKey{
+                               {Config: []byte{1}, PrivateKey: []byte{1}},
+                       }))
                case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
                        continue // these are unexported fields that are handled separately
                default:
@@ -2072,6 +2080,120 @@ func TestLargeCertMsg(t *testing.T) {
                },
        }
        if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
-               t.Fatalf("unexpected failure :%s", err)
+               t.Fatalf("unexpected failure: %s", err)
+       }
+}
+
+func TestECH(t *testing.T) {
+       k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+       if err != nil {
+               t.Fatal(err)
+       }
+       tmpl := &x509.Certificate{
+               SerialNumber: big.NewInt(1),
+               DNSNames:     []string{"public.example"},
+               NotBefore:    time.Now().Add(-time.Hour),
+               NotAfter:     time.Now().Add(time.Hour),
+       }
+       publicCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, k.Public(), k)
+       if err != nil {
+               t.Fatal(err)
+       }
+       publicCert, err := x509.ParseCertificate(publicCertDER)
+       if err != nil {
+               t.Fatal(err)
+       }
+       tmpl.DNSNames[0] = "secret.example"
+       secretCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, k.Public(), k)
+       if err != nil {
+               t.Fatal(err)
+       }
+       secretCert, err := x509.ParseCertificate(secretCertDER)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       marshalECHConfig := func(id uint8, pubKey []byte, publicName string, maxNameLen uint8) []byte {
+               builder := cryptobyte.NewBuilder(nil)
+               builder.AddUint16(extensionEncryptedClientHello)
+               builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
+                       builder.AddUint8(id)
+                       builder.AddUint16(hpke.DHKEM_X25519_HKDF_SHA256) // The only DHKEM we support
+                       builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
+                               builder.AddBytes(pubKey)
+                       })
+                       builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
+                               for _, aeadID := range sortedSupportedAEADs {
+                                       builder.AddUint16(hpke.KDF_HKDF_SHA256) // The only KDF we support
+                                       builder.AddUint16(aeadID)
+                               }
+                       })
+                       builder.AddUint8(maxNameLen)
+                       builder.AddUint8LengthPrefixed(func(builder *cryptobyte.Builder) {
+                               builder.AddBytes([]byte(publicName))
+                       })
+                       builder.AddUint16(0) // extensions
+               })
+
+               return builder.BytesOrPanic()
+       }
+
+       echKey, err := ecdh.X25519().GenerateKey(rand.Reader)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       echConfig := marshalECHConfig(123, echKey.PublicKey().Bytes(), "public.example", 32)
+
+       builder := cryptobyte.NewBuilder(nil)
+       builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
+               builder.AddBytes(echConfig)
+       })
+       echConfigList := builder.BytesOrPanic()
+
+       clientConfig, serverConfig := testConfig.Clone(), testConfig.Clone()
+       clientConfig.InsecureSkipVerify = false
+       clientConfig.Rand = rand.Reader
+       clientConfig.Time = nil
+       clientConfig.MinVersion = VersionTLS13
+       clientConfig.ServerName = "secret.example"
+       clientConfig.RootCAs = x509.NewCertPool()
+       clientConfig.RootCAs.AddCert(secretCert)
+       clientConfig.RootCAs.AddCert(publicCert)
+       clientConfig.EncryptedClientHelloConfigList = echConfigList
+       serverConfig.InsecureSkipVerify = false
+       serverConfig.Rand = rand.Reader
+       serverConfig.Time = nil
+       serverConfig.MinVersion = VersionTLS13
+       serverConfig.ServerName = "public.example"
+       serverConfig.Certificates = []Certificate{
+               {Certificate: [][]byte{publicCertDER}, PrivateKey: k},
+               {Certificate: [][]byte{secretCertDER}, PrivateKey: k},
+       }
+       serverConfig.EncryptedClientHelloKeys = []EncryptedClientHelloKey{
+               {Config: echConfig, PrivateKey: echKey.Bytes(), SendAsRetry: true},
+       }
+
+       ss, cs, err := testHandshake(t, clientConfig, serverConfig)
+       if err != nil {
+               t.Fatalf("unexpected failure: %s", err)
+       }
+       if !ss.ECHAccepted {
+               t.Fatal("server ConnectionState shows ECH not accepted")
+       }
+       if !cs.ECHAccepted {
+               t.Fatal("client ConnectionState shows ECH not accepted")
+       }
+       if cs.ServerName != "secret.example" || ss.ServerName != "secret.example" {
+               t.Fatalf("unexpected ConnectionState.ServerName, want %q, got server:%q, client: %q", "secret.example", ss.ServerName, cs.ServerName)
+       }
+       if len(cs.VerifiedChains) != 1 {
+               t.Fatal("unexpect number of certificate chains")
+       }
+       if len(cs.VerifiedChains[0]) != 1 {
+               t.Fatal("unexpect number of certificates")
+       }
+       if !cs.VerifiedChains[0][0].Equal(secretCert) {
+               t.Fatal("unexpected certificate")
        }
 }