From 9854fc3e86e1ba6d3a186bcadec814e73c562c78 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Mon, 21 Oct 2024 16:29:23 +0200 Subject: [PATCH] crypto/internal/mlkem768: add -768 suffix to all exported identifiers In preparation for introducing ML-KEM-1024. Aside from the constants at the top, all other changes were automated. Change-Id: I0fafce9a776c7b0b9179be1c858709cabf60e80f Reviewed-on: https://go-review.googlesource.com/c/go/+/621981 Auto-Submit: Filippo Valsorda Reviewed-by: Roland Shoemaker Reviewed-by: Russ Cox Reviewed-by: Daniel McCarney LUCI-TryBot-Result: Go LUCI --- src/crypto/internal/mlkem768/mlkem768.go | 108 +++++++++--------- src/crypto/internal/mlkem768/mlkem768_test.go | 32 +++--- src/crypto/tls/handshake_client.go | 2 +- src/crypto/tls/handshake_client_tls13.go | 2 +- src/crypto/tls/handshake_server_tls13.go | 2 +- src/crypto/tls/key_schedule.go | 6 +- src/crypto/tls/key_schedule_test.go | 2 +- 7 files changed, 76 insertions(+), 78 deletions(-) diff --git a/src/crypto/internal/mlkem768/mlkem768.go b/src/crypto/internal/mlkem768/mlkem768.go index 830841f738..1e46c46df9 100644 --- a/src/crypto/internal/mlkem768/mlkem768.go +++ b/src/crypto/internal/mlkem768/mlkem768.go @@ -33,35 +33,33 @@ const ( n = 256 q = 3329 - log2q = 12 - - // ML-KEM-768 parameters. The code makes assumptions based on these values, - // they can't be changed blindly. - k = 3 - η = 2 - du = 10 - dv = 4 - // encodingSizeX is the byte size of a ringElement or nttElement encoded // by ByteEncode_X (FIPS 203, Algorithm 5). - encodingSize12 = n * log2q / 8 - encodingSize10 = n * du / 8 - encodingSize4 = n * dv / 8 + encodingSize12 = n * 12 / 8 + encodingSize10 = n * 10 / 8 + encodingSize4 = n * 4 / 8 encodingSize1 = n * 1 / 8 - messageSize = encodingSize1 + messageSize = encodingSize1 + + SharedKeySize = 32 + SeedSize = 32 + 32 +) + +// ML-KEM-768 parameters. +const ( + k = 3 + decryptionKeySize = k * encodingSize12 encryptionKeySize = k*encodingSize12 + 32 - CiphertextSize = k*encodingSize10 + encodingSize4 - EncapsulationKeySize = encryptionKeySize - SharedKeySize = 32 - SeedSize = 32 + 32 + CiphertextSize768 = k*encodingSize10 + encodingSize4 + EncapsulationKeySize768 = encryptionKeySize ) -// A DecapsulationKey is the secret key used to decapsulate a shared key from a +// A DecapsulationKey768 is the secret key used to decapsulate a shared key from a // ciphertext. It includes various precomputed values. -type DecapsulationKey struct { +type DecapsulationKey768 struct { d [32]byte // decapsulation key seed z [32]byte // implicit rejection sampling seed @@ -75,7 +73,7 @@ type DecapsulationKey struct { // Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form. // // The decapsulation key must be kept secret. -func (dk *DecapsulationKey) Bytes() []byte { +func (dk *DecapsulationKey768) Bytes() []byte { var b [SeedSize]byte copy(b[:], dk.d[:]) copy(b[32:], dk.z[:]) @@ -84,30 +82,30 @@ func (dk *DecapsulationKey) Bytes() []byte { // EncapsulationKey returns the public encapsulation key necessary to produce // ciphertexts. -func (dk *DecapsulationKey) EncapsulationKey() *EncapsulationKey { - return &EncapsulationKey{ +func (dk *DecapsulationKey768) EncapsulationKey() *EncapsulationKey768 { + return &EncapsulationKey768{ ρ: dk.ρ, h: dk.h, encryptionKey: dk.encryptionKey, } } -// An EncapsulationKey is the public key used to produce ciphertexts to be -// decapsulated by the corresponding [DecapsulationKey]. -type EncapsulationKey struct { +// An EncapsulationKey768 is the public key used to produce ciphertexts to be +// decapsulated by the corresponding [DecapsulationKey768]. +type EncapsulationKey768 struct { ρ [32]byte // sampleNTT seed for A h [32]byte // H(ek) encryptionKey } // Bytes returns the encapsulation key as a byte slice. -func (ek *EncapsulationKey) Bytes() []byte { +func (ek *EncapsulationKey768) Bytes() []byte { // The actual logic is in a separate function to outline this allocation. - b := make([]byte, 0, EncapsulationKeySize) + b := make([]byte, 0, EncapsulationKeySize768) return ek.bytes(b) } -func (ek *EncapsulationKey) bytes(b []byte) []byte { +func (ek *EncapsulationKey768) bytes(b []byte) []byte { for i := range ek.t { b = polyByteEncode(b, ek.t[i]) } @@ -126,15 +124,15 @@ type decryptionKey struct { s [k]nttElement // ByteDecode₁₂(dk[:decryptionKeySize]) } -// GenerateKey generates a new decapsulation key, drawing random bytes from +// GenerateKey768 generates a new decapsulation key, drawing random bytes from // crypto/rand. The decapsulation key must be kept secret. -func GenerateKey() (*DecapsulationKey, error) { +func GenerateKey768() (*DecapsulationKey768, error) { // The actual logic is in a separate function to outline this allocation. - dk := &DecapsulationKey{} + dk := &DecapsulationKey768{} return generateKey(dk), nil } -func generateKey(dk *DecapsulationKey) *DecapsulationKey { +func generateKey(dk *DecapsulationKey768) *DecapsulationKey768 { var d [32]byte rand.Read(d[:]) var z [32]byte @@ -142,15 +140,15 @@ func generateKey(dk *DecapsulationKey) *DecapsulationKey { return kemKeyGen(dk, &d, &z) } -// NewDecapsulationKey parses a decapsulation key from a 64-byte +// NewDecapsulationKey768 parses a decapsulation key from a 64-byte // seed in the "d || z" form. The seed must be uniformly random. -func NewDecapsulationKey(seed []byte) (*DecapsulationKey, error) { +func NewDecapsulationKey768(seed []byte) (*DecapsulationKey768, error) { // The actual logic is in a separate function to outline this allocation. - dk := &DecapsulationKey{} + dk := &DecapsulationKey768{} return newKeyFromSeed(dk, seed) } -func newKeyFromSeed(dk *DecapsulationKey, seed []byte) (*DecapsulationKey, error) { +func newKeyFromSeed(dk *DecapsulationKey768, seed []byte) (*DecapsulationKey768, error) { if len(seed) != SeedSize { return nil, errors.New("mlkem768: invalid seed length") } @@ -164,9 +162,9 @@ func newKeyFromSeed(dk *DecapsulationKey, seed []byte) (*DecapsulationKey, error // It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and // K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save // copies and allocations. -func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey { +func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) *DecapsulationKey768 { if dk == nil { - dk = &DecapsulationKey{} + dk = &DecapsulationKey768{} } dk.d = *d dk.z = *z @@ -217,13 +215,13 @@ func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey { // encapsulation key, drawing random bytes from crypto/rand. // // The shared key must be kept secret. -func (ek *EncapsulationKey) Encapsulate() (ciphertext, sharedKey []byte) { +func (ek *EncapsulationKey768) Encapsulate() (ciphertext, sharedKey []byte) { // The actual logic is in a separate function to outline this allocation. - var cc [CiphertextSize]byte + var cc [CiphertextSize768]byte return ek.encapsulate(&cc) } -func (ek *EncapsulationKey) encapsulate(cc *[CiphertextSize]byte) (ciphertext, sharedKey []byte) { +func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (ciphertext, sharedKey []byte) { var m [messageSize]byte rand.Read(m[:]) // Note that the modulus check (step 2 of the encapsulation key check from @@ -234,9 +232,9 @@ func (ek *EncapsulationKey) encapsulate(cc *[CiphertextSize]byte) (ciphertext, s // kemEncaps generates a shared key and an associated ciphertext. // // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17. -func kemEncaps(cc *[CiphertextSize]byte, ek *EncapsulationKey, m *[messageSize]byte) (c, K []byte) { +func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (c, K []byte) { if cc == nil { - cc = &[CiphertextSize]byte{} + cc = &[CiphertextSize768]byte{} } g := sha3.New512() @@ -248,11 +246,11 @@ func kemEncaps(cc *[CiphertextSize]byte, ek *EncapsulationKey, m *[messageSize]b return c, K } -// NewEncapsulationKey parses an encapsulation key from its encoded form. -// If the encapsulation key is not valid, NewEncapsulationKey returns an error. -func NewEncapsulationKey(encapsulationKey []byte) (*EncapsulationKey, error) { +// NewEncapsulationKey768 parses an encapsulation key from its encoded form. +// If the encapsulation key is not valid, NewEncapsulationKey768 returns an error. +func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, error) { // The actual logic is in a separate function to outline this allocation. - ek := &EncapsulationKey{} + ek := &EncapsulationKey768{} return parseEK(ek, encapsulationKey) } @@ -260,7 +258,7 @@ func NewEncapsulationKey(encapsulationKey []byte) (*EncapsulationKey, error) { // // It implements the initial stages of K-PKE.Encrypt according to FIPS 203, // Algorithm 14. -func parseEK(ek *EncapsulationKey, ekPKE []byte) (*EncapsulationKey, error) { +func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) { if len(ekPKE) != encryptionKeySize { return nil, errors.New("mlkem768: invalid encapsulation key length") } @@ -290,7 +288,7 @@ func parseEK(ek *EncapsulationKey, ekPKE []byte) (*EncapsulationKey, error) { // // It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the // computation of t and AT is done in parseEK. -func pkeEncrypt(cc *[CiphertextSize]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte { +func pkeEncrypt(cc *[CiphertextSize768]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte { var N byte r, e1 := make([]nttElement, k), make([]ringElement, k) for i := range r { @@ -333,11 +331,11 @@ func pkeEncrypt(cc *[CiphertextSize]byte, ex *encryptionKey, m *[messageSize]byt // If the ciphertext is not valid, Decapsulate returns an error. // // The shared key must be kept secret. -func (dk *DecapsulationKey) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) { - if len(ciphertext) != CiphertextSize { +func (dk *DecapsulationKey768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) { + if len(ciphertext) != CiphertextSize768 { return nil, errors.New("mlkem768: invalid ciphertext length") } - c := (*[CiphertextSize]byte)(ciphertext) + c := (*[CiphertextSize768]byte)(ciphertext) // Note that the hash check (step 3 of the decapsulation input check from // FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always // validly generated by ML-KEM.KeyGen_internal. @@ -347,7 +345,7 @@ func (dk *DecapsulationKey) Decapsulate(ciphertext []byte) (sharedKey []byte, er // kemDecaps produces a shared key from a ciphertext. // // It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18. -func kemDecaps(dk *DecapsulationKey, c *[CiphertextSize]byte) (K []byte) { +func kemDecaps(dk *DecapsulationKey768, c *[CiphertextSize768]byte) (K []byte) { m := pkeDecrypt(&dk.decryptionKey, c) g := sha3.New512() g.Write(m[:]) @@ -359,7 +357,7 @@ func kemDecaps(dk *DecapsulationKey, c *[CiphertextSize]byte) (K []byte) { J.Write(c[:]) Kout := make([]byte, SharedKeySize) J.Read(Kout) - var cc [CiphertextSize]byte + var cc [CiphertextSize768]byte c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r) subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime) @@ -370,7 +368,7 @@ func kemDecaps(dk *DecapsulationKey, c *[CiphertextSize]byte) (K []byte) { // // It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15, // although s is retained from kemKeyGen. -func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize]byte) []byte { +func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize768]byte) []byte { u := make([]ringElement, k) for i := range u { b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)]) diff --git a/src/crypto/internal/mlkem768/mlkem768_test.go b/src/crypto/internal/mlkem768/mlkem768_test.go index 295aa95d0a..58dc138810 100644 --- a/src/crypto/internal/mlkem768/mlkem768_test.go +++ b/src/crypto/internal/mlkem768/mlkem768_test.go @@ -16,7 +16,7 @@ import ( ) func TestRoundTrip(t *testing.T) { - dk, err := GenerateKey() + dk, err := GenerateKey768() if err != nil { t.Fatal(err) } @@ -29,7 +29,7 @@ func TestRoundTrip(t *testing.T) { t.Fail() } - dk1, err := GenerateKey() + dk1, err := GenerateKey768() if err != nil { t.Fatal(err) } @@ -50,7 +50,7 @@ func TestRoundTrip(t *testing.T) { } func TestBadLengths(t *testing.T) { - dk, err := GenerateKey() + dk, err := GenerateKey768() if err != nil { t.Fatal(err) } @@ -59,14 +59,14 @@ func TestBadLengths(t *testing.T) { c, _ := ek.Encapsulate() for i := 0; i < len(ekBytes)-1; i++ { - if _, err := NewEncapsulationKey(ekBytes[:i]); err == nil { + if _, err := NewEncapsulationKey768(ekBytes[:i]); err == nil { t.Errorf("expected error for ek length %d", i) } } ekLong := ekBytes for i := 0; i < 100; i++ { ekLong = append(ekLong, 0) - if _, err := NewEncapsulationKey(ekLong); err == nil { + if _, err := NewEncapsulationKey768(ekLong); err == nil { t.Errorf("expected error for ek length %d", len(ekLong)) } } @@ -105,11 +105,11 @@ func TestAccumulated(t *testing.T) { o := sha3.NewShake128() seed := make([]byte, SeedSize) var msg [messageSize]byte - ct1 := make([]byte, CiphertextSize) + ct1 := make([]byte, CiphertextSize768) for i := 0; i < n; i++ { s.Read(seed) - dk, err := NewDecapsulationKey(seed) + dk, err := NewDecapsulationKey768(seed) if err != nil { t.Fatal(err) } @@ -146,7 +146,7 @@ func TestAccumulated(t *testing.T) { var sink byte func BenchmarkKeyGen(b *testing.B) { - var dk DecapsulationKey + var dk DecapsulationKey768 var d, z [32]byte rand.Read(d[:]) rand.Read(z[:]) @@ -162,15 +162,15 @@ func BenchmarkEncaps(b *testing.B) { rand.Read(seed) var m [messageSize]byte rand.Read(m[:]) - dk, err := NewDecapsulationKey(seed) + dk, err := NewDecapsulationKey768(seed) if err != nil { b.Fatal(err) } ekBytes := dk.EncapsulationKey().Bytes() - var c [CiphertextSize]byte + var c [CiphertextSize768]byte b.ResetTimer() for i := 0; i < b.N; i++ { - ek, err := NewEncapsulationKey(ekBytes) + ek, err := NewEncapsulationKey768(ekBytes) if err != nil { b.Fatal(err) } @@ -180,7 +180,7 @@ func BenchmarkEncaps(b *testing.B) { } func BenchmarkDecaps(b *testing.B) { - dk, err := GenerateKey() + dk, err := GenerateKey768() if err != nil { b.Fatal(err) } @@ -188,13 +188,13 @@ func BenchmarkDecaps(b *testing.B) { c, _ := ek.Encapsulate() b.ResetTimer() for i := 0; i < b.N; i++ { - K := kemDecaps(dk, (*[CiphertextSize]byte)(c)) + K := kemDecaps(dk, (*[CiphertextSize768]byte)(c)) sink ^= K[0] } } func BenchmarkRoundTrip(b *testing.B) { - dk, err := GenerateKey() + dk, err := GenerateKey768() if err != nil { b.Fatal(err) } @@ -206,7 +206,7 @@ func BenchmarkRoundTrip(b *testing.B) { } b.Run("Alice", func(b *testing.B) { for i := 0; i < b.N; i++ { - dkS, err := GenerateKey() + dkS, err := GenerateKey768() if err != nil { b.Fatal(err) } @@ -222,7 +222,7 @@ func BenchmarkRoundTrip(b *testing.B) { }) b.Run("Bob", func(b *testing.B) { for i := 0; i < b.N; i++ { - ek, err := NewEncapsulationKey(ekBytes) + ek, err := NewEncapsulationKey768(ekBytes) if err != nil { b.Fatal(err) } diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index 8965ad6eaf..1c14476d60 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -164,7 +164,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon if _, err := io.ReadFull(config.rand(), seed); err != nil { return nil, nil, nil, err } - keyShareKeys.kyber, err = mlkem768.NewDecapsulationKey(seed) + keyShareKeys.kyber, err = mlkem768.NewDecapsulationKey768(seed) if err != nil { return nil, nil, nil, err } diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index cdef806ab0..fbec7431a1 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -481,7 +481,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { ecdhePeerData := hs.serverHello.serverShare.data if hs.serverHello.serverShare.group == x25519Kyber768Draft00 { - if len(ecdhePeerData) != x25519PublicKeySize+mlkem768.CiphertextSize { + if len(ecdhePeerData) != x25519PublicKeySize+mlkem768.CiphertextSize768 { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid server key share") } diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go index 29add50d6e..3591aa1f11 100644 --- a/src/crypto/tls/handshake_server_tls13.go +++ b/src/crypto/tls/handshake_server_tls13.go @@ -223,7 +223,7 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error { ecdhData := clientKeyShare.data if selectedGroup == x25519Kyber768Draft00 { ecdhGroup = X25519 - if len(ecdhData) != x25519PublicKeySize+mlkem768.EncapsulationKeySize { + if len(ecdhData) != x25519PublicKeySize+mlkem768.EncapsulationKeySize768 { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid Kyber client key share") } diff --git a/src/crypto/tls/key_schedule.go b/src/crypto/tls/key_schedule.go index 3bbfc1b435..8377807ba5 100644 --- a/src/crypto/tls/key_schedule.go +++ b/src/crypto/tls/key_schedule.go @@ -54,11 +54,11 @@ func (c *cipherSuiteTLS13) exportKeyingMaterial(s *tls13.MasterSecret, transcrip type keySharePrivateKeys struct { curveID CurveID ecdhe *ecdh.PrivateKey - kyber *mlkem768.DecapsulationKey + kyber *mlkem768.DecapsulationKey768 } // kyberDecapsulate implements decapsulation according to Kyber Round 3. -func kyberDecapsulate(dk *mlkem768.DecapsulationKey, c []byte) ([]byte, error) { +func kyberDecapsulate(dk *mlkem768.DecapsulationKey768, c []byte) ([]byte, error) { K, err := dk.Decapsulate(c) if err != nil { return nil, err @@ -68,7 +68,7 @@ func kyberDecapsulate(dk *mlkem768.DecapsulationKey, c []byte) ([]byte, error) { // kyberEncapsulate implements encapsulation according to Kyber Round 3. func kyberEncapsulate(ek []byte) (c, ss []byte, err error) { - k, err := mlkem768.NewEncapsulationKey(ek) + k, err := mlkem768.NewEncapsulationKey768(ek) if err != nil { return nil, nil, err } diff --git a/src/crypto/tls/key_schedule_test.go b/src/crypto/tls/key_schedule_test.go index 32532770d4..766370ff21 100644 --- a/src/crypto/tls/key_schedule_test.go +++ b/src/crypto/tls/key_schedule_test.go @@ -120,7 +120,7 @@ func TestTrafficKey(t *testing.T) { } func TestKyberEncapsulate(t *testing.T) { - dk, err := mlkem768.GenerateKey() + dk, err := mlkem768.GenerateKey768() if err != nil { t.Fatal(err) } -- 2.48.1