]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/mlkem: swap order of return values of Encapsulate
authorAlec Bakholdin <alecbakholdin@gmail.com>
Mon, 23 Dec 2024 01:36:59 +0000 (20:36 -0500)
committerGopher Robot <gobot@golang.org>
Thu, 26 Dec 2024 20:33:05 +0000 (12:33 -0800)
Per FIPS 203 (https://csrc.nist.gov/pubs/fips/203/final), the order of return values should be sharedKey, ciphertext. This commit simply swaps those return values and updates any consumers of the Encapsulate() method to respect the new order.

Fixes #70950

Change-Id: I2a0d605e3baf7fe69510d60d3d35bbac18f883c9
Reviewed-on: https://go-review.googlesource.com/c/go/+/638376
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Austin Clements <austin@google.com>
Auto-Submit: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
src/crypto/internal/fips140/mlkem/cast.go
src/crypto/internal/fips140/mlkem/mlkem1024.go
src/crypto/internal/fips140/mlkem/mlkem768.go
src/crypto/mlkem/mlkem1024.go
src/crypto/mlkem/mlkem768.go
src/crypto/mlkem/mlkem_test.go
src/crypto/tls/handshake_server_tls13.go

index d3ae84ec3f112284dbf349a7dd4ee1cbb73047bc..a432d1fdab0e2b8668b142b73109e70051bd8ac0 100644 (file)
@@ -40,7 +40,7 @@ func init() {
                dk := &DecapsulationKey768{}
                kemKeyGen(dk, d, z)
                ek := dk.EncapsulationKey()
-               c, Ke := ek.EncapsulateInternal(m)
+               Ke, c := ek.EncapsulateInternal(m)
                Kd, err := dk.Decapsulate(c)
                if err != nil {
                        return err
index 5aa3c69243b3465cbd6fe7a28fce3fb9c6490ebe..c924c382933c5f19bda70a8512ce78cc80a79fb8 100644 (file)
@@ -189,7 +189,7 @@ func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) {
 // the first operational use (if not exported before the first use)."
 func kemPCT1024(dk *DecapsulationKey1024) error {
        ek := dk.EncapsulationKey()
-       c, K := ek.Encapsulate()
+       K, c := ek.Encapsulate()
        K1, err := dk.Decapsulate(c)
        if err != nil {
                return err
@@ -204,13 +204,13 @@ func kemPCT1024(dk *DecapsulationKey1024) error {
 // encapsulation key, drawing random bytes from a DRBG.
 //
 // The shared key must be kept secret.
-func (ek *EncapsulationKey1024) Encapsulate() (ciphertext, sharedKey []byte) {
+func (ek *EncapsulationKey1024) Encapsulate() (sharedKey, ciphertext []byte) {
        // The actual logic is in a separate function to outline this allocation.
        var cc [CiphertextSize1024]byte
        return ek.encapsulate(&cc)
 }
 
-func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (ciphertext, sharedKey []byte) {
+func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (sharedKey, ciphertext []byte) {
        var m [messageSize]byte
        drbg.Read(m[:])
        // Note that the modulus check (step 2 of the encapsulation key check from
@@ -221,7 +221,7 @@ func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (ciphe
 
 // EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
 // use in tests.
-func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (ciphertext, sharedKey []byte) {
+func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
        cc := &[CiphertextSize1024]byte{}
        return kemEncaps1024(cc, ek, m)
 }
@@ -229,14 +229,14 @@ func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (ciphertext, sh
 // kemEncaps1024 generates a shared key and an associated ciphertext.
 //
 // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
-func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (c, K []byte) {
+func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (K, c []byte) {
        g := sha3.New512()
        g.Write(m[:])
        g.Write(ek.h[:])
        G := g.Sum(nil)
        K, r := G[:SharedKeySize], G[SharedKeySize:]
        c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r)
-       return c, K
+       return K, c
 }
 
 // NewEncapsulationKey1024 parses an encapsulation key from its encoded form.
index 0c91ceadc4284e0ae41db94b43d8220cd85f2a1f..2c1cb5c33fcdd110ea13a894dca1a10e59418a62 100644 (file)
@@ -246,7 +246,7 @@ func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) {
 // the first operational use (if not exported before the first use)."
 func kemPCT(dk *DecapsulationKey768) error {
        ek := dk.EncapsulationKey()
-       c, K := ek.Encapsulate()
+       K, c := ek.Encapsulate()
        K1, err := dk.Decapsulate(c)
        if err != nil {
                return err
@@ -261,13 +261,13 @@ func kemPCT(dk *DecapsulationKey768) error {
 // encapsulation key, drawing random bytes from a DRBG.
 //
 // The shared key must be kept secret.
-func (ek *EncapsulationKey768) Encapsulate() (ciphertext, sharedKey []byte) {
+func (ek *EncapsulationKey768) Encapsulate() (sharedKey, ciphertext []byte) {
        // The actual logic is in a separate function to outline this allocation.
        var cc [CiphertextSize768]byte
        return ek.encapsulate(&cc)
 }
 
-func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (ciphertext, sharedKey []byte) {
+func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (sharedKey, ciphertext []byte) {
        var m [messageSize]byte
        drbg.Read(m[:])
        // Note that the modulus check (step 2 of the encapsulation key check from
@@ -278,7 +278,7 @@ func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (ciphert
 
 // EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
 // use in tests.
-func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (ciphertext, sharedKey []byte) {
+func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
        cc := &[CiphertextSize768]byte{}
        return kemEncaps(cc, ek, m)
 }
@@ -286,14 +286,14 @@ func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (ciphertext, sha
 // kemEncaps generates a shared key and an associated ciphertext.
 //
 // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
-func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (c, K []byte) {
+func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (K, c []byte) {
        g := sha3.New512()
        g.Write(m[:])
        g.Write(ek.h[:])
        G := g.Sum(nil)
        K, r := G[:SharedKeySize], G[SharedKeySize:]
        c = pkeEncrypt(cc, &ek.encryptionKey, m, r)
-       return c, K
+       return K, c
 }
 
 // NewEncapsulationKey768 parses an encapsulation key from its encoded form.
index 530aacf00ff7166298aa7d031699a4decfc75c38..05bad1eb2aa42c248c7caa2736bc2839589a570d 100644 (file)
@@ -91,6 +91,6 @@ func (ek *EncapsulationKey1024) Bytes() []byte {
 // encapsulation key, drawing random bytes from crypto/rand.
 //
 // The shared key must be kept secret.
-func (ek *EncapsulationKey1024) Encapsulate() (ciphertext, sharedKey []byte) {
+func (ek *EncapsulationKey1024) Encapsulate() (sharedKey, ciphertext []byte) {
        return ek.key.Encapsulate()
 }
index d6f5c94171e947ffef4b7c3f0740b9c7e9f38811..c367c551e6fc024cbc241ec636311ed06d1a8d6b 100644 (file)
@@ -101,6 +101,6 @@ func (ek *EncapsulationKey768) Bytes() []byte {
 // encapsulation key, drawing random bytes from crypto/rand.
 //
 // The shared key must be kept secret.
-func (ek *EncapsulationKey768) Encapsulate() (ciphertext, sharedKey []byte) {
+func (ek *EncapsulationKey768) Encapsulate() (sharedKey, ciphertext []byte) {
        return ek.key.Encapsulate()
 }
index ddc52dab975d3292ad12c20a0f5943e04b4761fa..207d6d48c3cb0be828fdc69946c2cf0734520045 100644 (file)
@@ -43,7 +43,7 @@ func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
                t.Fatal(err)
        }
        ek := dk.EncapsulationKey()
-       c, Ke := ek.Encapsulate()
+       Ke, c := ek.Encapsulate()
        Kd, err := dk.Decapsulate(c)
        if err != nil {
                t.Fatal(err)
@@ -66,7 +66,7 @@ func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
        if !bytes.Equal(dk.Bytes(), dk1.Bytes()) {
                t.Fail()
        }
-       c1, Ke1 := ek1.Encapsulate()
+       Ke1, c1 := ek1.Encapsulate()
        Kd1, err := dk1.Decapsulate(c1)
        if err != nil {
                t.Fatal(err)
@@ -86,7 +86,7 @@ func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
                t.Fail()
        }
 
-       c2, Ke2 := dk.EncapsulationKey().Encapsulate()
+       Ke2, c2 := dk.EncapsulationKey().Encapsulate()
        if bytes.Equal(c, c2) {
                t.Fail()
        }
@@ -115,7 +115,7 @@ func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
        }
        ek := dk.EncapsulationKey()
        ekBytes := dk.EncapsulationKey().Bytes()
-       c, _ := ek.Encapsulate()
+       _, c := ek.Encapsulate()
 
        for i := 0; i < len(dkBytes)-1; i++ {
                if _, err := newDecapsulationKey(dkBytes[:i]); err == nil {
@@ -189,7 +189,7 @@ func TestAccumulated(t *testing.T) {
                o.Write(ek.Bytes())
 
                s.Read(msg[:])
-               ct, k := ek.key.EncapsulateInternal(&msg)
+               k, ct := ek.key.EncapsulateInternal(&msg)
                o.Write(ct)
                o.Write(k)
 
@@ -244,7 +244,7 @@ func BenchmarkEncaps(b *testing.B) {
                if err != nil {
                        b.Fatal(err)
                }
-               c, K := ek.key.EncapsulateInternal(&m)
+               K, c := ek.key.EncapsulateInternal(&m)
                sink ^= c[0] ^ K[0]
        }
 }
@@ -255,7 +255,7 @@ func BenchmarkDecaps(b *testing.B) {
                b.Fatal(err)
        }
        ek := dk.EncapsulationKey()
-       c, _ := ek.Encapsulate()
+       _, c := ek.Encapsulate()
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
                K, _ := dk.Decapsulate(c)
@@ -270,7 +270,7 @@ func BenchmarkRoundTrip(b *testing.B) {
        }
        ek := dk.EncapsulationKey()
        ekBytes := ek.Bytes()
-       c, _ := ek.Encapsulate()
+       _, c := ek.Encapsulate()
        if err != nil {
                b.Fatal(err)
        }
@@ -296,7 +296,7 @@ func BenchmarkRoundTrip(b *testing.B) {
                        if err != nil {
                                b.Fatal(err)
                        }
-                       cS, Ks := ek.Encapsulate()
+                       Ks, cS := ek.Encapsulate()
                        if err != nil {
                                b.Fatal(err)
                        }
index 3552d89ba3bc6f85ce8c37c8140d870195eaf216..76fff6974e74030a9b28a066ea2249b992f195a1 100644 (file)
@@ -280,7 +280,7 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error {
                        c.sendAlert(alertIllegalParameter)
                        return errors.New("tls: invalid X25519MLKEM768 client key share")
                }
-               ciphertext, mlkemSharedSecret := k.Encapsulate()
+               mlkemSharedSecret, ciphertext := k.Encapsulate()
                // draft-kwiatkowski-tls-ecdhe-mlkem-02, Section 3.1.3: "For
                // X25519MLKEM768, the shared secret is the concatenation of the ML-KEM
                // shared secret and the X25519 shared secret. The shared secret is 64