]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/mlkem768: add EncapsulationKey type
authorFilippo Valsorda <filippo@golang.org>
Mon, 21 Oct 2024 12:30:46 +0000 (14:30 +0200)
committerGopher Robot <gobot@golang.org>
Tue, 19 Nov 2024 19:25:16 +0000 (19:25 +0000)
Change-Id: I3feacb044caa15ac9bbfc11f5d90bebf8a505510
Reviewed-on: https://go-review.googlesource.com/c/go/+/621980
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
Reviewed-by: Russ Cox <rsc@golang.org>
src/crypto/internal/mlkem768/mlkem768.go
src/crypto/internal/mlkem768/mlkem768_test.go
src/crypto/tls/handshake_client.go
src/crypto/tls/key_schedule.go
src/crypto/tls/key_schedule_test.go

index f152e7682ee9616042e079818c98129ef4431fc1..830841f7380293d68afda4bbe58a181704089c73 100644 (file)
@@ -73,6 +73,8 @@ type DecapsulationKey struct {
 }
 
 // Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
+//
+// The decapsulation key must be kept secret.
 func (dk *DecapsulationKey) Bytes() []byte {
        var b [SeedSize]byte
        copy(b[:], dk.d[:])
@@ -82,17 +84,34 @@ func (dk *DecapsulationKey) Bytes() []byte {
 
 // EncapsulationKey returns the public encapsulation key necessary to produce
 // ciphertexts.
-func (dk *DecapsulationKey) EncapsulationKey() []byte {
+func (dk *DecapsulationKey) EncapsulationKey() *EncapsulationKey {
+       return &EncapsulationKey{
+               ρ:             dk.ρ,
+               h:             dk.h,
+               encryptionKey: dk.encryptionKey,
+       }
+}
+
+// An EncapsulationKey is the public key used to produce ciphertexts to be
+// decapsulated by the corresponding [DecapsulationKey].
+type EncapsulationKey struct {
+       ρ [32]byte // sampleNTT seed for A
+       h [32]byte // H(ek)
+       encryptionKey
+}
+
+// Bytes returns the encapsulation key as a byte slice.
+func (ek *EncapsulationKey) Bytes() []byte {
        // The actual logic is in a separate function to outline this allocation.
        b := make([]byte, 0, EncapsulationKeySize)
-       return dk.encapsulationKey(b)
+       return ek.bytes(b)
 }
 
-func (dk *DecapsulationKey) encapsulationKey(b []byte) []byte {
-       for i := range dk.t {
-               b = polyByteEncode(b, dk.t[i])
+func (ek *EncapsulationKey) bytes(b []byte) []byte {
+       for i := range ek.t {
+               b = polyByteEncode(b, ek.t[i])
        }
-       b = append(b, dk.ρ[:]...)
+       b = append(b, ek.ρ[:]...)
        return b
 }
 
@@ -123,9 +142,9 @@ func generateKey(dk *DecapsulationKey) *DecapsulationKey {
        return kemKeyGen(dk, &d, &z)
 }
 
-// NewKeyFromSeed deterministically generates a decapsulation key from a 64-byte
+// NewDecapsulationKey parses a decapsulation key from a 64-byte
 // seed in the "d || z" form. The seed must be uniformly random.
-func NewKeyFromSeed(seed []byte) (*DecapsulationKey, error) {
+func NewDecapsulationKey(seed []byte) (*DecapsulationKey, error) {
        // The actual logic is in a separate function to outline this allocation.
        dk := &DecapsulationKey{}
        return newKeyFromSeed(dk, seed)
@@ -187,7 +206,7 @@ func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey {
        }
 
        H := sha3.New256()
-       ek := dk.EncapsulationKey()
+       ek := dk.EncapsulationKey().Bytes()
        H.Write(ek)
        H.Sum(dk.h[:0])
 
@@ -196,74 +215,75 @@ func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey {
 
 // Encapsulate generates a shared key and an associated ciphertext from an
 // encapsulation key, drawing random bytes from crypto/rand.
-// If the encapsulation key is not valid, Encapsulate returns an error.
 //
 // The shared key must be kept secret.
-func Encapsulate(encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) {
+func (ek *EncapsulationKey) Encapsulate() (ciphertext, sharedKey []byte) {
        // The actual logic is in a separate function to outline this allocation.
        var cc [CiphertextSize]byte
-       return encapsulate(&cc, encapsulationKey)
+       return ek.encapsulate(&cc)
 }
 
-func encapsulate(cc *[CiphertextSize]byte, encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) {
-       if len(encapsulationKey) != EncapsulationKeySize {
-               return nil, nil, errors.New("mlkem768: invalid encapsulation key length")
-       }
+func (ek *EncapsulationKey) encapsulate(cc *[CiphertextSize]byte) (ciphertext, sharedKey []byte) {
        var m [messageSize]byte
        rand.Read(m[:])
        // Note that the modulus check (step 2 of the encapsulation key check from
        // FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK.
-       return kemEncaps(cc, encapsulationKey, &m)
+       return kemEncaps(cc, ek, &m)
 }
 
 // kemEncaps generates a shared key and an associated ciphertext.
 //
 // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
-func kemEncaps(cc *[CiphertextSize]byte, ek []byte, m *[messageSize]byte) (c, K []byte, err error) {
+func kemEncaps(cc *[CiphertextSize]byte, ek *EncapsulationKey, m *[messageSize]byte) (c, K []byte) {
        if cc == nil {
                cc = &[CiphertextSize]byte{}
        }
 
-       H := sha3.Sum256(ek[:])
        g := sha3.New512()
        g.Write(m[:])
-       g.Write(H[:])
+       g.Write(ek.h[:])
        G := g.Sum(nil)
        K, r := G[:SharedKeySize], G[SharedKeySize:]
-       var ex encryptionKey
-       if err := parseEK(&ex, ek[:]); err != nil {
-               return nil, nil, err
-       }
-       c = pkeEncrypt(cc, &ex, m, r)
-       return c, K, nil
+       c = pkeEncrypt(cc, &ek.encryptionKey, m, r)
+       return c, K
+}
+
+// NewEncapsulationKey parses an encapsulation key from its encoded form.
+// If the encapsulation key is not valid, NewEncapsulationKey returns an error.
+func NewEncapsulationKey(encapsulationKey []byte) (*EncapsulationKey, error) {
+       // The actual logic is in a separate function to outline this allocation.
+       ek := &EncapsulationKey{}
+       return parseEK(ek, encapsulationKey)
 }
 
 // parseEK parses an encryption key from its encoded form.
 //
 // It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
 // Algorithm 14.
-func parseEK(ex *encryptionKey, ekPKE []byte) error {
+func parseEK(ek *EncapsulationKey, ekPKE []byte) (*EncapsulationKey, error) {
        if len(ekPKE) != encryptionKeySize {
-               return errors.New("mlkem768: invalid encryption key length")
+               return nil, errors.New("mlkem768: invalid encapsulation key length")
        }
 
-       for i := range ex.t {
+       ek.h = sha3.Sum256(ekPKE[:])
+
+       for i := range ek.t {
                var err error
-               ex.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
+               ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
                if err != nil {
-                       return err
+                       return nil, err
                }
                ekPKE = ekPKE[encodingSize12:]
        }
-       ρ := ekPKE
+       copy(ek.ρ[:], ekPKE)
 
        for i := byte(0); i < k; i++ {
                for j := byte(0); j < k; j++ {
-                       ex.a[i*k+j] = sampleNTT(ρ, j, i)
+                       ek.a[i*k+j] = sampleNTT(ek.ρ[:], j, i)
                }
        }
 
-       return nil
+       return ek, nil
 }
 
 // pkeEncrypt encrypt a plaintext message.
index 4775f77aeba3b84f9509a7aaa31d9b3e165f2443..295aa95d0adc35e71c97539d02a7c55bbc370eb2 100644 (file)
@@ -20,10 +20,7 @@ func TestRoundTrip(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       c, Ke, err := Encapsulate(dk.EncapsulationKey())
-       if err != nil {
-               t.Fatal(err)
-       }
+       c, Ke := dk.EncapsulationKey().Encapsulate()
        Kd, err := dk.Decapsulate(c)
        if err != nil {
                t.Fatal(err)
@@ -36,17 +33,14 @@ func TestRoundTrip(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       if bytes.Equal(dk.EncapsulationKey(), dk1.EncapsulationKey()) {
+       if bytes.Equal(dk.EncapsulationKey().Bytes(), dk1.EncapsulationKey().Bytes()) {
                t.Fail()
        }
        if bytes.Equal(dk.Bytes(), dk1.Bytes()) {
                t.Fail()
        }
 
-       c1, Ke1, err := Encapsulate(dk.EncapsulationKey())
-       if err != nil {
-               t.Fatal(err)
-       }
+       c1, Ke1 := dk.EncapsulationKey().Encapsulate()
        if bytes.Equal(c, c1) {
                t.Fail()
        }
@@ -61,25 +55,22 @@ func TestBadLengths(t *testing.T) {
                t.Fatal(err)
        }
        ek := dk.EncapsulationKey()
+       ekBytes := dk.EncapsulationKey().Bytes()
+       c, _ := ek.Encapsulate()
 
-       for i := 0; i < len(ek)-1; i++ {
-               if _, _, err := Encapsulate(ek[:i]); err == nil {
+       for i := 0; i < len(ekBytes)-1; i++ {
+               if _, err := NewEncapsulationKey(ekBytes[:i]); err == nil {
                        t.Errorf("expected error for ek length %d", i)
                }
        }
-       ekLong := ek
+       ekLong := ekBytes
        for i := 0; i < 100; i++ {
                ekLong = append(ekLong, 0)
-               if _, _, err := Encapsulate(ekLong); err == nil {
+               if _, err := NewEncapsulationKey(ekLong); err == nil {
                        t.Errorf("expected error for ek length %d", len(ekLong))
                }
        }
 
-       c, _, err := Encapsulate(ek)
-       if err != nil {
-               t.Fatal(err)
-       }
-
        for i := 0; i < len(c)-1; i++ {
                if _, err := dk.Decapsulate(c[:i]); err == nil {
                        t.Errorf("expected error for c length %d", i)
@@ -118,18 +109,15 @@ func TestAccumulated(t *testing.T) {
 
        for i := 0; i < n; i++ {
                s.Read(seed)
-               dk, err := NewKeyFromSeed(seed)
+               dk, err := NewDecapsulationKey(seed)
                if err != nil {
                        t.Fatal(err)
                }
                ek := dk.EncapsulationKey()
-               o.Write(ek)
+               o.Write(ek.Bytes())
 
                s.Read(msg[:])
-               ct, k, err := kemEncaps(nil, ek, &msg)
-               if err != nil {
-                       t.Fatal(err)
-               }
+               ct, k := kemEncaps(nil, ek, &msg)
                o.Write(ct)
                o.Write(k)
 
@@ -165,7 +153,7 @@ func BenchmarkKeyGen(b *testing.B) {
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
                dk := kemKeyGen(&dk, &d, &z)
-               sink ^= dk.EncapsulationKey()[0]
+               sink ^= dk.EncapsulationKey().Bytes()[0]
        }
 }
 
@@ -174,18 +162,19 @@ func BenchmarkEncaps(b *testing.B) {
        rand.Read(seed)
        var m [messageSize]byte
        rand.Read(m[:])
-       dk, err := NewKeyFromSeed(seed)
+       dk, err := NewDecapsulationKey(seed)
        if err != nil {
                b.Fatal(err)
        }
-       ek := dk.EncapsulationKey()
+       ekBytes := dk.EncapsulationKey().Bytes()
        var c [CiphertextSize]byte
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
-               c, K, err := kemEncaps(&c, ek, &m)
+               ek, err := NewEncapsulationKey(ekBytes)
                if err != nil {
                        b.Fatal(err)
                }
+               c, K := kemEncaps(&c, ek, &m)
                sink ^= c[0] ^ K[0]
        }
 }
@@ -196,10 +185,7 @@ func BenchmarkDecaps(b *testing.B) {
                b.Fatal(err)
        }
        ek := dk.EncapsulationKey()
-       c, _, err := Encapsulate(ek)
-       if err != nil {
-               b.Fatal(err)
-       }
+       c, _ := ek.Encapsulate()
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
                K := kemDecaps(dk, (*[CiphertextSize]byte)(c))
@@ -213,7 +199,8 @@ func BenchmarkRoundTrip(b *testing.B) {
                b.Fatal(err)
        }
        ek := dk.EncapsulationKey()
-       c, _, err := Encapsulate(ek)
+       ekBytes := ek.Bytes()
+       c, _ := ek.Encapsulate()
        if err != nil {
                b.Fatal(err)
        }
@@ -223,7 +210,7 @@ func BenchmarkRoundTrip(b *testing.B) {
                        if err != nil {
                                b.Fatal(err)
                        }
-                       ekS := dkS.EncapsulationKey()
+                       ekS := dkS.EncapsulationKey().Bytes()
                        sink ^= ekS[0]
 
                        Ks, err := dk.Decapsulate(c)
@@ -235,7 +222,11 @@ func BenchmarkRoundTrip(b *testing.B) {
        })
        b.Run("Bob", func(b *testing.B) {
                for i := 0; i < b.N; i++ {
-                       cS, Ks, err := Encapsulate(ek)
+                       ek, err := NewEncapsulationKey(ekBytes)
+                       if err != nil {
+                               b.Fatal(err)
+                       }
+                       cS, Ks := ek.Encapsulate()
                        if err != nil {
                                b.Fatal(err)
                        }
index f6bccc40bcb022018758960144d23de76a269b2a..8965ad6eafdb747a5bbd60e5e890e245f70088c5 100644 (file)
@@ -164,7 +164,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon
                        if _, err := io.ReadFull(config.rand(), seed); err != nil {
                                return nil, nil, nil, err
                        }
-                       keyShareKeys.kyber, err = mlkem768.NewKeyFromSeed(seed)
+                       keyShareKeys.kyber, err = mlkem768.NewDecapsulationKey(seed)
                        if err != nil {
                                return nil, nil, nil, err
                        }
@@ -174,7 +174,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon
                        // both, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2.
                        hello.keyShares = []keyShare{
                                {group: x25519Kyber768Draft00, data: append(keyShareKeys.ecdhe.PublicKey().Bytes(),
-                                       keyShareKeys.kyber.EncapsulationKey()...)},
+                                       keyShareKeys.kyber.EncapsulationKey().Bytes()...)},
                                {group: X25519, data: keyShareKeys.ecdhe.PublicKey().Bytes()},
                        }
                } else {
index e8ee9ce9c2febd5a8ca3e0791c44a48094f5b708..3bbfc1b435ffed6429c65e8b5b959e829d5fea56 100644 (file)
@@ -63,19 +63,20 @@ func kyberDecapsulate(dk *mlkem768.DecapsulationKey, c []byte) ([]byte, error) {
        if err != nil {
                return nil, err
        }
-       return kyberSharedSecret(K, c), nil
+       return kyberSharedSecret(c, K), nil
 }
 
 // kyberEncapsulate implements encapsulation according to Kyber Round 3.
 func kyberEncapsulate(ek []byte) (c, ss []byte, err error) {
-       c, ss, err = mlkem768.Encapsulate(ek)
+       k, err := mlkem768.NewEncapsulationKey(ek)
        if err != nil {
                return nil, nil, err
        }
-       return c, kyberSharedSecret(ss, c), nil
+       c, ss = k.Encapsulate()
+       return c, kyberSharedSecret(c, ss), nil
 }
 
-func kyberSharedSecret(K, c []byte) []byte {
+func kyberSharedSecret(c, K []byte) []byte {
        // Package mlkem768 implements ML-KEM, which compared to Kyber removed a
        // final hashing step. Compute SHAKE-256(K || SHA3-256(c), 32) to match Kyber.
        // See https://words.filippo.io/mlkem768/#bonus-track-using-a-ml-kem-implementation-as-kyber-v3.
index 095113ca179b774819722dd17b161d3d334828bd..32532770d416140e7f6ff2c48753fd1c38cdecad 100644 (file)
@@ -124,7 +124,7 @@ func TestKyberEncapsulate(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       ct, ss, err := kyberEncapsulate(dk.EncapsulationKey())
+       ct, ss, err := kyberEncapsulate(dk.EncapsulationKey().Bytes())
        if err != nil {
                t.Fatal(err)
        }