]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add SecP256r1/SecP384r1MLKEM1024 hybrid post-quantum key exchanges
authorFilippo Valsorda <filippo@golang.org>
Wed, 19 Nov 2025 16:32:42 +0000 (17:32 +0100)
committerGopher Robot <gobot@golang.org>
Wed, 26 Nov 2025 01:25:00 +0000 (17:25 -0800)
Fixes #71206

Change-Id: If3cf75261c56828b87ae6805bd2913f56a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/722140
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

16 files changed:
api/next/71206.txt [new file with mode: 0644]
doc/godebug.md
doc/next/6-stdlib/99-minor/crypto/tls/71206.md [new file with mode: 0644]
src/crypto/tls/bogo_config.json
src/crypto/tls/common.go
src/crypto/tls/common_string.go
src/crypto/tls/defaults.go
src/crypto/tls/defaults_fips140.go
src/crypto/tls/fips140_test.go
src/crypto/tls/handshake_client.go
src/crypto/tls/handshake_client_tls13.go
src/crypto/tls/handshake_server_tls13.go
src/crypto/tls/key_agreement.go
src/crypto/tls/key_schedule.go
src/crypto/tls/tls_test.go
src/internal/godebugs/table.go

diff --git a/api/next/71206.txt b/api/next/71206.txt
new file mode 100644 (file)
index 0000000..e29fe83
--- /dev/null
@@ -0,0 +1,4 @@
+pkg crypto/tls, const SecP256r1MLKEM768 = 4587 #71206
+pkg crypto/tls, const SecP256r1MLKEM768 CurveID #71206
+pkg crypto/tls, const SecP384r1MLKEM1024 = 4589 #71206
+pkg crypto/tls, const SecP384r1MLKEM1024 CurveID #71206
index 6163d134ce0d4f54e1bbb9bc0a979f2381c4f080..0d3354bc0fd1c03d1c354ad4f02d4c4052b4b1f2 100644 (file)
@@ -168,6 +168,10 @@ allows malformed hostnames containing colons outside of a bracketed IPv6 address
 The default `urlstrictcolons=1` rejects URLs such as `http://localhost:1:2` or `http://::1/`.
 Colons are permitted as part of a bracketed IPv6 address, such as `http://[::1]/`.
 
+Go 1.26 enabled two additional post-quantum key exchange mechanisms:
+SecP256r1MLKEM768 and SecP384r1MLKEM1024. The default can be reverted using the
+[`tlssecpmlkem` setting](/pkg/crypto/tls/#Config.CurvePreferences).
+
 Go 1.26 added a new `tracebacklabels` setting that controls the inclusion of
 goroutine labels set through the the `runtime/pprof` package. Setting `tracebacklabels=1`
 includes these key/value pairs in the goroutine status header of runtime
diff --git a/doc/next/6-stdlib/99-minor/crypto/tls/71206.md b/doc/next/6-stdlib/99-minor/crypto/tls/71206.md
new file mode 100644 (file)
index 0000000..2caaa80
--- /dev/null
@@ -0,0 +1,3 @@
+The hybrid [SecP256r1MLKEM768] and [SecP384r1MLKEM1024] post-quantum key
+exchanges are now enabled by default. They can be disabled by setting
+[Config.CurvePreferences] or with the `tlssecpmlkem=0` GODEBUG setting.
index ed3fc6ec3d6f5b0dd6faa1e2312e21c333456c06..a4664d6e6f823e2016d7aece2909859d6c6d8e46 100644 (file)
         24,
         25,
         29,
-        4588
+        4587,
+        4588,
+        4589
     ],
     "ErrorMap": {
         ":ECH_REJECTED:": ["tls: server rejected ECH"]
index d809624b880a00ada4f54bef187d08e526461ba8..993cfaf7c06ac5920a08ebb86e6b061ffba9892e 100644 (file)
@@ -145,19 +145,31 @@ const (
 type CurveID uint16
 
 const (
-       CurveP256      CurveID = 23
-       CurveP384      CurveID = 24
-       CurveP521      CurveID = 25
-       X25519         CurveID = 29
-       X25519MLKEM768 CurveID = 4588
+       CurveP256          CurveID = 23
+       CurveP384          CurveID = 24
+       CurveP521          CurveID = 25
+       X25519             CurveID = 29
+       X25519MLKEM768     CurveID = 4588
+       SecP256r1MLKEM768  CurveID = 4587
+       SecP384r1MLKEM1024 CurveID = 4589
 )
 
 func isTLS13OnlyKeyExchange(curve CurveID) bool {
-       return curve == X25519MLKEM768
+       switch curve {
+       case X25519MLKEM768, SecP256r1MLKEM768, SecP384r1MLKEM1024:
+               return true
+       default:
+               return false
+       }
 }
 
 func isPQKeyExchange(curve CurveID) bool {
-       return curve == X25519MLKEM768
+       switch curve {
+       case X25519MLKEM768, SecP256r1MLKEM768, SecP384r1MLKEM1024:
+               return true
+       default:
+               return false
+       }
 }
 
 // TLS 1.3 Key Share. See RFC 8446, Section 4.2.8.
@@ -787,6 +799,11 @@ type Config struct {
        // From Go 1.24, the default includes the [X25519MLKEM768] hybrid
        // post-quantum key exchange. To disable it, set CurvePreferences explicitly
        // or use the GODEBUG=tlsmlkem=0 environment variable.
+       //
+       // From Go 1.26, the default includes the [SecP256r1MLKEM768] and
+       // [SecP256r1MLKEM768] hybrid post-quantum key exchanges, too. To disable
+       // them, set CurvePreferences explicitly or use either the
+       // GODEBUG=tlsmlkem=0 or the GODEBUG=tlssecpmlkem=0 environment variable.
        CurvePreferences []CurveID
 
        // DynamicRecordSizingDisabled disables adaptive sizing of TLS records.
index e15dd48838b5308d0d4cfd53e68ecb544a307f72..1e868e7162d3e89101451c390abcf811ec360b39 100644 (file)
@@ -72,16 +72,19 @@ func _() {
        _ = x[CurveP521-25]
        _ = x[X25519-29]
        _ = x[X25519MLKEM768-4588]
+       _ = x[SecP256r1MLKEM768-4587]
+       _ = x[SecP384r1MLKEM1024-4589]
 }
 
 const (
        _CurveID_name_0 = "CurveP256CurveP384CurveP521"
        _CurveID_name_1 = "X25519"
-       _CurveID_name_2 = "X25519MLKEM768"
+       _CurveID_name_2 = "SecP256r1MLKEM768X25519MLKEM768SecP384r1MLKEM1024"
 )
 
 var (
        _CurveID_index_0 = [...]uint8{0, 9, 18, 27}
+       _CurveID_index_2 = [...]uint8{0, 17, 31, 49}
 )
 
 func (i CurveID) String() string {
@@ -91,8 +94,9 @@ func (i CurveID) String() string {
                return _CurveID_name_0[_CurveID_index_0[i]:_CurveID_index_0[i+1]]
        case i == 29:
                return _CurveID_name_1
-       case i == 4588:
-               return _CurveID_name_2
+       case 4587 <= i && i <= 4589:
+               i -= 4587
+               return _CurveID_name_2[_CurveID_index_2[i]:_CurveID_index_2[i+1]]
        default:
                return "CurveID(" + strconv.FormatInt(int64(i), 10) + ")"
        }
index 489a2750dff390823c1db1edf712502773cde28e..8de8d7e0934b07ed68c2b513fde7e6f263a6ad13 100644 (file)
@@ -14,14 +14,24 @@ import (
 // them to apply local policies.
 
 var tlsmlkem = godebug.New("tlsmlkem")
+var tlssecpmlkem = godebug.New("tlssecpmlkem")
 
 // defaultCurvePreferences is the default set of supported key exchanges, as
 // well as the preference order.
 func defaultCurvePreferences() []CurveID {
-       if tlsmlkem.Value() == "0" {
+       switch {
+       // tlsmlkem=0 restores the pre-Go 1.24 default.
+       case tlsmlkem.Value() == "0":
                return []CurveID{X25519, CurveP256, CurveP384, CurveP521}
+       // tlssecpmlkem=0 restores the pre-Go 1.26 default.
+       case tlssecpmlkem.Value() == "0":
+               return []CurveID{X25519MLKEM768, X25519, CurveP256, CurveP384, CurveP521}
+       default:
+               return []CurveID{
+                       X25519MLKEM768, SecP256r1MLKEM768, SecP384r1MLKEM1024,
+                       X25519, CurveP256, CurveP384, CurveP521,
+               }
        }
-       return []CurveID{X25519MLKEM768, X25519, CurveP256, CurveP384, CurveP521}
 }
 
 // defaultSupportedSignatureAlgorithms returns the signature and hash algorithms that
index 00176795eba6b727ee159686e0c63f199a69b757..19132607938a26d5ea9d29395d1cb86c714aab6f 100644 (file)
@@ -32,6 +32,8 @@ var (
        }
        allowedCurvePreferencesFIPS = []CurveID{
                X25519MLKEM768,
+               SecP256r1MLKEM768,
+               SecP384r1MLKEM1024,
                CurveP256,
                CurveP384,
                CurveP521,
index 291a19f44cdaee37bff25b5e84a23918b5e4027b..96273c0fe0ea2fbaa4b996c6a57c6faf2daa24a8 100644 (file)
@@ -43,11 +43,15 @@ func isTLS13CipherSuite(id uint16) bool {
 }
 
 func generateKeyShare(group CurveID) keyShare {
-       key, err := generateECDHEKey(rand.Reader, group)
+       ke, err := keyExchangeForCurveID(group)
        if err != nil {
                panic(err)
        }
-       return keyShare{group: group, data: key.PublicKey().Bytes()}
+       _, shares, err := ke.keyShares(rand.Reader)
+       if err != nil {
+               panic(err)
+       }
+       return shares[0]
 }
 
 func TestFIPSServerProtocolVersion(t *testing.T) {
@@ -132,7 +136,7 @@ func isFIPSCurve(id CurveID) bool {
        switch id {
        case CurveP256, CurveP384, CurveP521:
                return true
-       case X25519MLKEM768:
+       case X25519MLKEM768, SecP256r1MLKEM768, SecP384r1MLKEM1024:
                // Only for the native module.
                return !boring.Enabled
        case X25519:
index c739544db678e7f54ba2b538de2f3acc864e4f23..e1ddcb3f10689a724bf4777ea3be28248174dc9d 100644 (file)
@@ -11,7 +11,6 @@ import (
        "crypto/ecdsa"
        "crypto/ed25519"
        "crypto/hpke"
-       "crypto/internal/fips140/mlkem"
        "crypto/internal/fips140/tls13"
        "crypto/rsa"
        "crypto/subtle"
@@ -142,43 +141,21 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCli
                if len(hello.supportedCurves) == 0 {
                        return nil, nil, nil, errors.New("tls: no supported elliptic curves for ECDHE")
                }
+               // Since the order is fixed, the first one is always the one to send a
+               // key share for. All the PQ hybrids sort first, and produce a fallback
+               // ECDH share.
                curveID := hello.supportedCurves[0]
-               keyShareKeys = &keySharePrivateKeys{curveID: curveID}
-               // Note that if X25519MLKEM768 is supported, it will be first because
-               // the preference order is fixed.
-               if curveID == X25519MLKEM768 {
-                       keyShareKeys.ecdhe, err = generateECDHEKey(config.rand(), X25519)
-                       if err != nil {
-                               return nil, nil, nil, err
-                       }
-                       seed := make([]byte, mlkem.SeedSize)
-                       if _, err := io.ReadFull(config.rand(), seed); err != nil {
-                               return nil, nil, nil, err
-                       }
-                       keyShareKeys.mlkem, err = mlkem.NewDecapsulationKey768(seed)
-                       if err != nil {
-                               return nil, nil, nil, err
-                       }
-                       mlkemEncapsulationKey := keyShareKeys.mlkem.EncapsulationKey().Bytes()
-                       x25519EphemeralKey := keyShareKeys.ecdhe.PublicKey().Bytes()
-                       hello.keyShares = []keyShare{
-                               {group: X25519MLKEM768, data: append(mlkemEncapsulationKey, x25519EphemeralKey...)},
-                       }
-                       // If both X25519MLKEM768 and X25519 are supported, we send both key
-                       // shares (as a fallback) and we reuse the same X25519 ephemeral
-                       // key, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2.
-                       if slices.Contains(hello.supportedCurves, X25519) {
-                               hello.keyShares = append(hello.keyShares, keyShare{group: X25519, data: x25519EphemeralKey})
-                       }
-               } else {
-                       if _, ok := curveForCurveID(curveID); !ok {
-                               return nil, nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
-                       }
-                       keyShareKeys.ecdhe, err = generateECDHEKey(config.rand(), curveID)
-                       if err != nil {
-                               return nil, nil, nil, err
-                       }
-                       hello.keyShares = []keyShare{{group: curveID, data: keyShareKeys.ecdhe.PublicKey().Bytes()}}
+               ke, err := keyExchangeForCurveID(curveID)
+               if err != nil {
+                       return nil, nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
+               }
+               keyShareKeys, hello.keyShares, err = ke.keyShares(config.rand())
+               if err != nil {
+                       return nil, nil, nil, err
+               }
+               // Only send the fallback ECDH share if the corresponding CurveID is enabled.
+               if len(hello.keyShares) == 2 && !slices.Contains(hello.supportedCurves, hello.keyShares[1].group) {
+                       hello.keyShares = hello.keyShares[:1]
                }
        }
 
index 7018bb2336b8f3f01cab57c14ad3abc23c64b590..2912d97f75e6118549abc8973919906fed6fc935 100644 (file)
@@ -10,7 +10,6 @@ import (
        "crypto"
        "crypto/hkdf"
        "crypto/hmac"
-       "crypto/internal/fips140/mlkem"
        "crypto/internal/fips140/tls13"
        "crypto/rsa"
        "crypto/subtle"
@@ -320,22 +319,18 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
                        c.sendAlert(alertIllegalParameter)
                        return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
                }
-               // Note: we don't support selecting X25519MLKEM768 in a HRR, because it
-               // is currently first in preference order, so if it's enabled we'll
-               // always send a key share for it.
-               //
-               // This will have to change once we support multiple hybrid KEMs.
-               if _, ok := curveForCurveID(curveID); !ok {
+               ke, err := keyExchangeForCurveID(curveID)
+               if err != nil {
                        c.sendAlert(alertInternalError)
                        return errors.New("tls: CurvePreferences includes unsupported curve")
                }
-               key, err := generateECDHEKey(c.config.rand(), curveID)
+               hs.keyShareKeys, hello.keyShares, err = ke.keyShares(c.config.rand())
                if err != nil {
                        c.sendAlert(alertInternalError)
                        return err
                }
-               hs.keyShareKeys = &keySharePrivateKeys{curveID: curveID, ecdhe: key}
-               hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
+               // Do not send the fallback ECDH key share in a HRR response.
+               hello.keyShares = hello.keyShares[:1]
        }
 
        if len(hello.pskIdentities) > 0 {
@@ -475,36 +470,16 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
 func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
        c := hs.c
 
-       ecdhePeerData := hs.serverHello.serverShare.data
-       if hs.serverHello.serverShare.group == X25519MLKEM768 {
-               if len(ecdhePeerData) != mlkem.CiphertextSize768+x25519PublicKeySize {
-                       c.sendAlert(alertIllegalParameter)
-                       return errors.New("tls: invalid server X25519MLKEM768 key share")
-               }
-               ecdhePeerData = hs.serverHello.serverShare.data[mlkem.CiphertextSize768:]
-       }
-       peerKey, err := hs.keyShareKeys.ecdhe.Curve().NewPublicKey(ecdhePeerData)
+       ke, err := keyExchangeForCurveID(hs.serverHello.serverShare.group)
        if err != nil {
-               c.sendAlert(alertIllegalParameter)
-               return errors.New("tls: invalid server key share")
+               c.sendAlert(alertInternalError)
+               return err
        }
-       sharedKey, err := hs.keyShareKeys.ecdhe.ECDH(peerKey)
+       sharedKey, err := ke.clientSharedSecret(hs.keyShareKeys, hs.serverHello.serverShare.data)
        if err != nil {
                c.sendAlert(alertIllegalParameter)
                return errors.New("tls: invalid server key share")
        }
-       if hs.serverHello.serverShare.group == X25519MLKEM768 {
-               if hs.keyShareKeys.mlkem == nil {
-                       return c.sendAlert(alertInternalError)
-               }
-               ciphertext := hs.serverHello.serverShare.data[:mlkem.CiphertextSize768]
-               mlkemShared, err := hs.keyShareKeys.mlkem.Decapsulate(ciphertext)
-               if err != nil {
-                       c.sendAlert(alertIllegalParameter)
-                       return errors.New("tls: invalid X25519MLKEM768 server key share")
-               }
-               sharedKey = append(mlkemShared, sharedKey...)
-       }
        c.curveID = hs.serverHello.serverShare.group
 
        earlySecret := hs.earlySecret
index c227371aceb9426f5484712c5f7e007ffb4b4b1b..1182307936c2d4722afd15bba0e1be3dc950cd1e 100644 (file)
@@ -11,7 +11,6 @@ import (
        "crypto/hkdf"
        "crypto/hmac"
        "crypto/hpke"
-       "crypto/internal/fips140/mlkem"
        "crypto/internal/fips140/tls13"
        "crypto/rsa"
        "crypto/tls/internal/fips140tls"
@@ -246,55 +245,16 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error {
        }
        c.curveID = selectedGroup
 
-       ecdhGroup := selectedGroup
-       ecdhData := clientKeyShare.data
-       if selectedGroup == X25519MLKEM768 {
-               ecdhGroup = X25519
-               if len(ecdhData) != mlkem.EncapsulationKeySize768+x25519PublicKeySize {
-                       c.sendAlert(alertIllegalParameter)
-                       return errors.New("tls: invalid X25519MLKEM768 client key share")
-               }
-               ecdhData = ecdhData[mlkem.EncapsulationKeySize768:]
-       }
-       if _, ok := curveForCurveID(ecdhGroup); !ok {
-               c.sendAlert(alertInternalError)
-               return errors.New("tls: CurvePreferences includes unsupported curve")
-       }
-       key, err := generateECDHEKey(c.config.rand(), ecdhGroup)
+       ke, err := keyExchangeForCurveID(selectedGroup)
        if err != nil {
                c.sendAlert(alertInternalError)
-               return err
-       }
-       hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()}
-       peerKey, err := key.Curve().NewPublicKey(ecdhData)
-       if err != nil {
-               c.sendAlert(alertIllegalParameter)
-               return errors.New("tls: invalid client key share")
+               return errors.New("tls: CurvePreferences includes unsupported curve")
        }
-       hs.sharedKey, err = key.ECDH(peerKey)
+       hs.sharedKey, hs.hello.serverShare, err = ke.serverSharedSecret(c.config.rand(), clientKeyShare.data)
        if err != nil {
                c.sendAlert(alertIllegalParameter)
                return errors.New("tls: invalid client key share")
        }
-       if selectedGroup == X25519MLKEM768 {
-               k, err := mlkem.NewEncapsulationKey768(clientKeyShare.data[:mlkem.EncapsulationKeySize768])
-               if err != nil {
-                       c.sendAlert(alertIllegalParameter)
-                       return errors.New("tls: invalid X25519MLKEM768 client key share")
-               }
-               mlkemSharedSecret, ciphertext := k.Encapsulate()
-               // draft-kwiatkowski-tls-ecdhe-mlkem-02, Section 3.1.3: "For
-               // X25519MLKEM768, the shared secret is the concatenation of the ML-KEM
-               // shared secret and the X25519 shared secret. The shared secret is 64
-               // bytes (32 bytes for each part)."
-               hs.sharedKey = append(mlkemSharedSecret, hs.sharedKey...)
-               // draft-kwiatkowski-tls-ecdhe-mlkem-02, Section 3.1.2: "When the
-               // X25519MLKEM768 group is negotiated, the server's key exchange value
-               // is the concatenation of an ML-KEM ciphertext returned from
-               // encapsulation to the client's encapsulation key, and the server's
-               // ephemeral X25519 share."
-               hs.hello.serverShare.data = append(ciphertext, hs.hello.serverShare.data...)
-       }
 
        selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
        if err != nil {
index 88116f941e0b9a3312bf6c19937172b4071e27db..26f7bd2c520176431e286f5a67fff1d70aebc5ad 100644 (file)
@@ -159,17 +159,17 @@ func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint1
 type ecdheKeyAgreement struct {
        version uint16
        isRSA   bool
-       key     *ecdh.PrivateKey
 
        // ckx and preMasterSecret are generated in processServerKeyExchange
        // and returned in generateClientKeyExchange.
        ckx             *clientKeyExchangeMsg
        preMasterSecret []byte
 
-       // curveID and signatureAlgorithm are set by processServerKeyExchange and
-       // generateServerKeyExchange.
+       // curveID, signatureAlgorithm, and key are set by processServerKeyExchange
+       // and generateServerKeyExchange.
        curveID            CurveID
        signatureAlgorithm SignatureScheme
+       key                *ecdh.PrivateKey
 }
 
 func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
@@ -380,3 +380,29 @@ func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHel
 
        return ka.preMasterSecret, ka.ckx, nil
 }
+
+// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman
+// according to RFC 8446, Section 4.2.8.2.
+func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) {
+       curve, ok := curveForCurveID(curveID)
+       if !ok {
+               return nil, errors.New("tls: internal error: unsupported curve")
+       }
+
+       return curve.GenerateKey(rand)
+}
+
+func curveForCurveID(id CurveID) (ecdh.Curve, bool) {
+       switch id {
+       case X25519:
+               return ecdh.X25519(), true
+       case CurveP256:
+               return ecdh.P256(), true
+       case CurveP384:
+               return ecdh.P384(), true
+       case CurveP521:
+               return ecdh.P521(), true
+       default:
+               return nil, false
+       }
+}
index 1426a276bf2de0c0c3a8f892d453fd99ff757a2d..bfa22449c8717846f3a47f88ca46dbcf7bb05c5f 100644 (file)
@@ -5,10 +5,11 @@
 package tls
 
 import (
+       "crypto"
        "crypto/ecdh"
        "crypto/hmac"
-       "crypto/internal/fips140/mlkem"
        "crypto/internal/fips140/tls13"
+       "crypto/mlkem"
        "errors"
        "hash"
        "io"
@@ -50,35 +51,202 @@ func (c *cipherSuiteTLS13) exportKeyingMaterial(s *tls13.MasterSecret, transcrip
 }
 
 type keySharePrivateKeys struct {
-       curveID CurveID
-       ecdhe   *ecdh.PrivateKey
-       mlkem   *mlkem.DecapsulationKey768
+       ecdhe *ecdh.PrivateKey
+       mlkem crypto.Decapsulator
 }
 
-const x25519PublicKeySize = 32
+// A keyExchange implements a TLS 1.3 KEM.
+type keyExchange interface {
+       // keyShares generates one or two key shares.
+       //
+       // The first one will match the id, the second (if present) reuses the
+       // traditional component of the requested hybrid, as allowed by
+       // draft-ietf-tls-hybrid-design-09, Section 3.2.
+       keyShares(rand io.Reader) (*keySharePrivateKeys, []keyShare, error)
 
-// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman
-// according to RFC 8446, Section 4.2.8.2.
-func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) {
-       curve, ok := curveForCurveID(curveID)
-       if !ok {
-               return nil, errors.New("tls: internal error: unsupported curve")
-       }
+       // serverSharedSecret computes the shared secret and the server's key share.
+       serverSharedSecret(rand io.Reader, clientKeyShare []byte) ([]byte, keyShare, error)
 
-       return curve.GenerateKey(rand)
+       // clientSharedSecret computes the shared secret given the server's key
+       // share and the keys generated by keyShares.
+       clientSharedSecret(priv *keySharePrivateKeys, serverKeyShare []byte) ([]byte, error)
 }
 
-func curveForCurveID(id CurveID) (ecdh.Curve, bool) {
+func keyExchangeForCurveID(id CurveID) (keyExchange, error) {
+       newMLKEMPrivateKey768 := func(b []byte) (crypto.Decapsulator, error) {
+               return mlkem.NewDecapsulationKey768(b)
+       }
+       newMLKEMPrivateKey1024 := func(b []byte) (crypto.Decapsulator, error) {
+               return mlkem.NewDecapsulationKey1024(b)
+       }
+       newMLKEMPublicKey768 := func(b []byte) (crypto.Encapsulator, error) {
+               return mlkem.NewEncapsulationKey768(b)
+       }
+       newMLKEMPublicKey1024 := func(b []byte) (crypto.Encapsulator, error) {
+               return mlkem.NewEncapsulationKey1024(b)
+       }
        switch id {
        case X25519:
-               return ecdh.X25519(), true
+               return &ecdhKeyExchange{id, ecdh.X25519()}, nil
        case CurveP256:
-               return ecdh.P256(), true
+               return &ecdhKeyExchange{id, ecdh.P256()}, nil
        case CurveP384:
-               return ecdh.P384(), true
+               return &ecdhKeyExchange{id, ecdh.P384()}, nil
        case CurveP521:
-               return ecdh.P521(), true
+               return &ecdhKeyExchange{id, ecdh.P521()}, nil
+       case X25519MLKEM768:
+               return &hybridKeyExchange{id, ecdhKeyExchange{X25519, ecdh.X25519()},
+                       32, mlkem.EncapsulationKeySize768, mlkem.CiphertextSize768,
+                       newMLKEMPrivateKey768, newMLKEMPublicKey768}, nil
+       case SecP256r1MLKEM768:
+               return &hybridKeyExchange{id, ecdhKeyExchange{CurveP256, ecdh.P256()},
+                       65, mlkem.EncapsulationKeySize768, mlkem.CiphertextSize768,
+                       newMLKEMPrivateKey768, newMLKEMPublicKey768}, nil
+       case SecP384r1MLKEM1024:
+               return &hybridKeyExchange{id, ecdhKeyExchange{CurveP384, ecdh.P384()},
+                       97, mlkem.EncapsulationKeySize1024, mlkem.CiphertextSize1024,
+                       newMLKEMPrivateKey1024, newMLKEMPublicKey1024}, nil
        default:
-               return nil, false
+               return nil, errors.New("tls: unsupported key exchange")
+       }
+}
+
+type ecdhKeyExchange struct {
+       id    CurveID
+       curve ecdh.Curve
+}
+
+func (ke *ecdhKeyExchange) keyShares(rand io.Reader) (*keySharePrivateKeys, []keyShare, error) {
+       priv, err := ke.curve.GenerateKey(rand)
+       if err != nil {
+               return nil, nil, err
+       }
+       return &keySharePrivateKeys{ecdhe: priv}, []keyShare{{ke.id, priv.PublicKey().Bytes()}}, nil
+}
+
+func (ke *ecdhKeyExchange) serverSharedSecret(rand io.Reader, clientKeyShare []byte) ([]byte, keyShare, error) {
+       key, err := ke.curve.GenerateKey(rand)
+       if err != nil {
+               return nil, keyShare{}, err
+       }
+       peerKey, err := ke.curve.NewPublicKey(clientKeyShare)
+       if err != nil {
+               return nil, keyShare{}, err
+       }
+       sharedKey, err := key.ECDH(peerKey)
+       if err != nil {
+               return nil, keyShare{}, err
+       }
+       return sharedKey, keyShare{ke.id, key.PublicKey().Bytes()}, nil
+}
+
+func (ke *ecdhKeyExchange) clientSharedSecret(priv *keySharePrivateKeys, serverKeyShare []byte) ([]byte, error) {
+       peerKey, err := ke.curve.NewPublicKey(serverKeyShare)
+       if err != nil {
+               return nil, err
+       }
+       sharedKey, err := priv.ecdhe.ECDH(peerKey)
+       if err != nil {
+               return nil, err
+       }
+       return sharedKey, nil
+}
+
+type hybridKeyExchange struct {
+       id   CurveID
+       ecdh ecdhKeyExchange
+
+       ecdhElementSize     int
+       mlkemPublicKeySize  int
+       mlkemCiphertextSize int
+
+       newMLKEMPrivateKey func([]byte) (crypto.Decapsulator, error)
+       newMLKEMPublicKey  func([]byte) (crypto.Encapsulator, error)
+}
+
+func (ke *hybridKeyExchange) keyShares(rand io.Reader) (*keySharePrivateKeys, []keyShare, error) {
+       priv, ecdhShares, err := ke.ecdh.keyShares(rand)
+       if err != nil {
+               return nil, nil, err
+       }
+       seed := make([]byte, mlkem.SeedSize)
+       if _, err := io.ReadFull(rand, seed); err != nil {
+               return nil, nil, err
+       }
+       priv.mlkem, err = ke.newMLKEMPrivateKey(seed)
+       if err != nil {
+               return nil, nil, err
+       }
+       var shareData []byte
+       // For X25519MLKEM768, the ML-KEM-768 encapsulation key comes first.
+       // For SecP256r1MLKEM768 and SecP384r1MLKEM1024, the ECDH share comes first.
+       // See draft-ietf-tls-ecdhe-mlkem-02, Section 4.1.
+       if ke.id == X25519MLKEM768 {
+               shareData = append(priv.mlkem.Encapsulator().Bytes(), ecdhShares[0].data...)
+       } else {
+               shareData = append(ecdhShares[0].data, priv.mlkem.Encapsulator().Bytes()...)
+       }
+       return priv, []keyShare{{ke.id, shareData}, ecdhShares[0]}, nil
+}
+
+func (ke *hybridKeyExchange) serverSharedSecret(rand io.Reader, clientKeyShare []byte) ([]byte, keyShare, error) {
+       if len(clientKeyShare) != ke.ecdhElementSize+ke.mlkemPublicKeySize {
+               return nil, keyShare{}, errors.New("tls: invalid client key share length for hybrid key exchange")
+       }
+       var ecdhShareData, mlkemShareData []byte
+       if ke.id == X25519MLKEM768 {
+               mlkemShareData = clientKeyShare[:ke.mlkemPublicKeySize]
+               ecdhShareData = clientKeyShare[ke.mlkemPublicKeySize:]
+       } else {
+               ecdhShareData = clientKeyShare[:ke.ecdhElementSize]
+               mlkemShareData = clientKeyShare[ke.ecdhElementSize:]
+       }
+       ecdhSharedSecret, ks, err := ke.ecdh.serverSharedSecret(rand, ecdhShareData)
+       if err != nil {
+               return nil, keyShare{}, err
+       }
+       mlkemPeerKey, err := ke.newMLKEMPublicKey(mlkemShareData)
+       if err != nil {
+               return nil, keyShare{}, err
+       }
+       mlkemSharedSecret, mlkemKeyShare := mlkemPeerKey.Encapsulate()
+       var sharedKey []byte
+       if ke.id == X25519MLKEM768 {
+               sharedKey = append(mlkemSharedSecret, ecdhSharedSecret...)
+               ks.data = append(mlkemKeyShare, ks.data...)
+       } else {
+               sharedKey = append(ecdhSharedSecret, mlkemSharedSecret...)
+               ks.data = append(ks.data, mlkemKeyShare...)
+       }
+       ks.group = ke.id
+       return sharedKey, ks, nil
+}
+
+func (ke *hybridKeyExchange) clientSharedSecret(priv *keySharePrivateKeys, serverKeyShare []byte) ([]byte, error) {
+       if len(serverKeyShare) != ke.ecdhElementSize+ke.mlkemCiphertextSize {
+               return nil, errors.New("tls: invalid server key share length for hybrid key exchange")
+       }
+       var ecdhShareData, mlkemShareData []byte
+       if ke.id == X25519MLKEM768 {
+               mlkemShareData = serverKeyShare[:ke.mlkemCiphertextSize]
+               ecdhShareData = serverKeyShare[ke.mlkemCiphertextSize:]
+       } else {
+               ecdhShareData = serverKeyShare[:ke.ecdhElementSize]
+               mlkemShareData = serverKeyShare[ke.ecdhElementSize:]
+       }
+       ecdhSharedSecret, err := ke.ecdh.clientSharedSecret(priv, ecdhShareData)
+       if err != nil {
+               return nil, err
+       }
+       mlkemSharedSecret, err := priv.mlkem.Decapsulate(mlkemShareData)
+       if err != nil {
+               return nil, err
+       }
+       var sharedKey []byte
+       if ke.id == X25519MLKEM768 {
+               sharedKey = append(mlkemSharedSecret, ecdhSharedSecret...)
+       } else {
+               sharedKey = append(ecdhSharedSecret, mlkemSharedSecret...)
        }
+       return sharedKey, nil
 }
index 6905f53949933fa19763577b70d0a82eb8b3ca47..af2828fd8da7c5e64be67fdccdc064c85e001fe5 100644 (file)
@@ -11,6 +11,7 @@ import (
        "crypto/ecdh"
        "crypto/ecdsa"
        "crypto/elliptic"
+       "crypto/internal/boring"
        "crypto/rand"
        "crypto/tls/internal/fips140tls"
        "crypto/x509"
@@ -1964,84 +1965,134 @@ func testVerifyCertificates(t *testing.T, version uint16) {
 }
 
 func TestHandshakeMLKEM(t *testing.T) {
-       skipFIPS(t) // No X25519MLKEM768 in FIPS
+       if boring.Enabled && fips140tls.Required() {
+               t.Skip("ML-KEM not supported in BoringCrypto FIPS mode")
+       }
+       defaultWithPQ := []CurveID{X25519MLKEM768, SecP256r1MLKEM768, SecP384r1MLKEM1024,
+               X25519, CurveP256, CurveP384, CurveP521}
+       defaultWithoutPQ := []CurveID{X25519, CurveP256, CurveP384, CurveP521}
        var tests = []struct {
-               name                string
-               clientConfig        func(*Config)
-               serverConfig        func(*Config)
-               preparation         func(*testing.T)
-               expectClientSupport bool
-               expectMLKEM         bool
-               expectHRR           bool
+               name           string
+               clientConfig   func(*Config)
+               serverConfig   func(*Config)
+               preparation    func(*testing.T)
+               expectClient   []CurveID
+               expectSelected CurveID
+               expectHRR      bool
        }{
                {
-                       name:                "Default",
-                       expectClientSupport: true,
-                       expectMLKEM:         true,
-                       expectHRR:           false,
+                       name:           "Default",
+                       expectClient:   defaultWithPQ,
+                       expectSelected: X25519MLKEM768,
                },
                {
                        name: "ClientCurvePreferences",
                        clientConfig: func(config *Config) {
                                config.CurvePreferences = []CurveID{X25519}
                        },
-                       expectClientSupport: false,
+                       expectClient:   []CurveID{X25519},
+                       expectSelected: X25519,
                },
                {
                        name: "ServerCurvePreferencesX25519",
                        serverConfig: func(config *Config) {
                                config.CurvePreferences = []CurveID{X25519}
                        },
-                       expectClientSupport: true,
-                       expectMLKEM:         false,
-                       expectHRR:           false,
+                       expectClient:   defaultWithPQ,
+                       expectSelected: X25519,
                },
                {
                        name: "ServerCurvePreferencesHRR",
                        serverConfig: func(config *Config) {
                                config.CurvePreferences = []CurveID{CurveP256}
                        },
-                       expectClientSupport: true,
-                       expectMLKEM:         false,
-                       expectHRR:           true,
+                       expectClient:   defaultWithPQ,
+                       expectSelected: CurveP256,
+                       expectHRR:      true,
+               },
+               {
+                       name: "SecP256r1MLKEM768-Only",
+                       clientConfig: func(config *Config) {
+                               config.CurvePreferences = []CurveID{SecP256r1MLKEM768}
+                       },
+                       expectClient:   []CurveID{SecP256r1MLKEM768},
+                       expectSelected: SecP256r1MLKEM768,
+               },
+               {
+                       name: "SecP256r1MLKEM768-HRR",
+                       serverConfig: func(config *Config) {
+                               config.CurvePreferences = []CurveID{SecP256r1MLKEM768, CurveP256}
+                       },
+                       expectClient:   defaultWithPQ,
+                       expectSelected: SecP256r1MLKEM768,
+                       expectHRR:      true,
+               },
+               {
+                       name: "SecP384r1MLKEM1024",
+                       clientConfig: func(config *Config) {
+                               config.CurvePreferences = []CurveID{SecP384r1MLKEM1024, CurveP384}
+                       },
+                       expectClient:   []CurveID{SecP384r1MLKEM1024, CurveP384},
+                       expectSelected: SecP384r1MLKEM1024,
+               },
+               {
+                       name: "CurveP256NoHRR",
+                       clientConfig: func(config *Config) {
+                               config.CurvePreferences = []CurveID{SecP256r1MLKEM768, CurveP256}
+                       },
+                       serverConfig: func(config *Config) {
+                               config.CurvePreferences = []CurveID{CurveP256}
+                       },
+                       expectClient:   []CurveID{SecP256r1MLKEM768, CurveP256},
+                       expectSelected: CurveP256,
                },
                {
                        name: "ClientMLKEMOnly",
                        clientConfig: func(config *Config) {
                                config.CurvePreferences = []CurveID{X25519MLKEM768}
                        },
-                       expectClientSupport: true,
-                       expectMLKEM:         true,
+                       expectClient:   []CurveID{X25519MLKEM768},
+                       expectSelected: X25519MLKEM768,
                },
                {
                        name: "ClientSortedCurvePreferences",
                        clientConfig: func(config *Config) {
                                config.CurvePreferences = []CurveID{CurveP256, X25519MLKEM768}
                        },
-                       expectClientSupport: true,
-                       expectMLKEM:         true,
+                       expectClient:   []CurveID{X25519MLKEM768, CurveP256},
+                       expectSelected: X25519MLKEM768,
                },
                {
                        name: "ClientTLSv12",
                        clientConfig: func(config *Config) {
                                config.MaxVersion = VersionTLS12
                        },
-                       expectClientSupport: false,
+                       expectClient:   defaultWithoutPQ,
+                       expectSelected: X25519,
                },
                {
                        name: "ServerTLSv12",
                        serverConfig: func(config *Config) {
                                config.MaxVersion = VersionTLS12
                        },
-                       expectClientSupport: true,
-                       expectMLKEM:         false,
+                       expectClient:   defaultWithPQ,
+                       expectSelected: X25519,
                },
                {
-                       name: "GODEBUG",
+                       name: "GODEBUG tlsmlkem=0",
                        preparation: func(t *testing.T) {
                                t.Setenv("GODEBUG", "tlsmlkem=0")
                        },
-                       expectClientSupport: false,
+                       expectClient:   defaultWithoutPQ,
+                       expectSelected: X25519,
+               },
+               {
+                       name: "GODEBUG tlssecpmlkem=0",
+                       preparation: func(t *testing.T) {
+                               t.Setenv("GODEBUG", "tlssecpmlkem=0")
+                       },
+                       expectClient:   []CurveID{X25519MLKEM768, X25519, CurveP256, CurveP384, CurveP521},
+                       expectSelected: X25519MLKEM768,
                },
        }
 
@@ -2049,6 +2100,9 @@ func TestHandshakeMLKEM(t *testing.T) {
        baseConfig.CurvePreferences = nil
        for _, test := range tests {
                t.Run(test.name, func(t *testing.T) {
+                       if fips140tls.Required() && test.expectSelected == X25519 {
+                               t.Skip("X25519 not supported in FIPS mode")
+                       }
                        if test.preparation != nil {
                                test.preparation(t)
                        } else {
@@ -2059,10 +2113,12 @@ func TestHandshakeMLKEM(t *testing.T) {
                                test.serverConfig(serverConfig)
                        }
                        serverConfig.GetConfigForClient = func(hello *ClientHelloInfo) (*Config, error) {
-                               if !test.expectClientSupport && slices.Contains(hello.SupportedCurves, X25519MLKEM768) {
-                                       return nil, errors.New("client supports X25519MLKEM768")
-                               } else if test.expectClientSupport && !slices.Contains(hello.SupportedCurves, X25519MLKEM768) {
-                                       return nil, errors.New("client does not support X25519MLKEM768")
+                               expectClient := slices.Clone(test.expectClient)
+                               expectClient = slices.DeleteFunc(expectClient, func(c CurveID) bool {
+                                       return fips140tls.Required() && c == X25519
+                               })
+                               if !slices.Equal(hello.SupportedCurves, expectClient) {
+                                       t.Errorf("got client curves %v, expected %v", hello.SupportedCurves, expectClient)
                                }
                                return nil, nil
                        }
@@ -2074,20 +2130,11 @@ func TestHandshakeMLKEM(t *testing.T) {
                        if err != nil {
                                t.Fatal(err)
                        }
-                       if test.expectMLKEM {
-                               if ss.CurveID != X25519MLKEM768 {
-                                       t.Errorf("got CurveID %v (server), expected %v", ss.CurveID, X25519MLKEM768)
-                               }
-                               if cs.CurveID != X25519MLKEM768 {
-                                       t.Errorf("got CurveID %v (client), expected %v", cs.CurveID, X25519MLKEM768)
-                               }
-                       } else {
-                               if ss.CurveID == X25519MLKEM768 {
-                                       t.Errorf("got CurveID %v (server), expected not X25519MLKEM768", ss.CurveID)
-                               }
-                               if cs.CurveID == X25519MLKEM768 {
-                                       t.Errorf("got CurveID %v (client), expected not X25519MLKEM768", cs.CurveID)
-                               }
+                       if ss.CurveID != test.expectSelected {
+                               t.Errorf("server selected curve %v, expected %v", ss.CurveID, test.expectSelected)
+                       }
+                       if cs.CurveID != test.expectSelected {
+                               t.Errorf("client selected curve %v, expected %v", cs.CurveID, test.expectSelected)
                        }
                        if test.expectHRR {
                                if !ss.HelloRetryRequest {
index 4939e6ff10981d18517fceee238542c7db14a920..f707fc34f2feda7971c7a6ff9f6b82c51ba8dd13 100644 (file)
@@ -64,6 +64,7 @@ var All = []Info{
        {Name: "tlsmaxrsasize", Package: "crypto/tls"},
        {Name: "tlsmlkem", Package: "crypto/tls", Changed: 24, Old: "0", Opaque: true},
        {Name: "tlsrsakex", Package: "crypto/tls", Changed: 22, Old: "1"},
+       {Name: "tlssecpmlkem", Package: "crypto/tls", Changed: 26, Old: "0", Opaque: true},
        {Name: "tlssha1", Package: "crypto/tls", Changed: 25, Old: "1"},
        {Name: "tlsunsafeekm", Package: "crypto/tls", Changed: 22, Old: "1"},
        {Name: "updatemaxprocs", Package: "runtime", Changed: 25, Old: "0"},