From: Filippo Valsorda Date: Mon, 15 Apr 2024 01:56:10 +0000 (+0200) Subject: crypto/internal/mlkem768: various performance optimizations X-Git-Tag: go1.23rc1~459 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=cc1659916d49fdfe93aa8879b7f0d0cfb50f017a;p=gostls13.git crypto/internal/mlkem768: various performance optimizations goos: linux goarch: amd64 pkg: crypto/internal/mlkem768 cpu: Intel(R) Core(TM) i5-7400 CPU @ 3.00GHz │ c0a0ba254c │ 2aeb615fa6 │ │ sec/op │ sec/op vs base │ KeyGen-4 73.36µ ± 0% 67.38µ ± 1% -8.15% (p=0.000 n=20) Encaps-4 108.96µ ± 0% 99.56µ ± 1% -8.63% (p=0.000 n=20) Decaps-4 132.19µ ± 0% 96.85µ ± 0% -26.74% (p=0.000 n=20) RoundTrip/Alice-4 216.4µ ± 0% 173.1µ ± 0% -20.01% (p=0.000 n=20) RoundTrip/Bob-4 109.5µ ± 0% 100.5µ ± 0% -8.19% (p=0.000 n=20) Change-Id: I600116baa0b390bb83950a42c7693cd7806dba9a Reviewed-on: https://go-review.googlesource.com/c/go/+/578797 Reviewed-by: Roland Shoemaker LUCI-TryBot-Result: Go LUCI Reviewed-by: Cherry Mui Auto-Submit: Filippo Valsorda --- diff --git a/src/crypto/internal/mlkem768/mlkem768.go b/src/crypto/internal/mlkem768/mlkem768.go index c6b191c1ae..24bedea84f 100644 --- a/src/crypto/internal/mlkem768/mlkem768.go +++ b/src/crypto/internal/mlkem768/mlkem768.go @@ -68,57 +68,123 @@ const ( SeedSize = 32 + 32 ) -// GenerateKey generates an encapsulation key and a corresponding decapsulation -// key, drawing random bytes from crypto/rand. -// -// The decapsulation key must be kept secret. -func GenerateKey() (encapsulationKey, decapsulationKey []byte, err error) { - d := make([]byte, 32) - if _, err := rand.Read(d); err != nil { - return nil, nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error()) +// A DecapsulationKey is the secret key used to decapsulate a shared key from a +// ciphertext. It includes various precomputed values. +type DecapsulationKey struct { + dk [DecapsulationKeySize]byte + encryptionKey + decryptionKey +} + +// Bytes returns the extended encoding of the decapsulation key, according to +// FIPS 203 (DRAFT). +func (dk *DecapsulationKey) Bytes() []byte { + var b [DecapsulationKeySize]byte + copy(b[:], dk.dk[:]) + return b[:] +} + +// EncapsulationKey returns the public encapsulation key necessary to produce +// ciphertexts. +func (dk *DecapsulationKey) EncapsulationKey() []byte { + var b [EncapsulationKeySize]byte + copy(b[:], dk.dk[decryptionKeySize:]) + return b[:] +} + +// encryptionKey is the parsed and expanded form of a PKE encryption key. +type encryptionKey struct { + t [k]nttElement // ByteDecode₁₂(ek[:384k]) + A [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i) +} + +// decryptionKey is the parsed and expanded form of a PKE decryption key. +type decryptionKey struct { + s [k]nttElement // ByteDecode₁₂(dk[:decryptionKeySize]) +} + +// GenerateKey generates a new decapsulation key, drawing random bytes from +// crypto/rand. The decapsulation key must be kept secret. +func GenerateKey() (*DecapsulationKey, error) { + // The actual logic is in a separate function to outline this allocation. + dk := &DecapsulationKey{} + return generateKey(dk) +} + +func generateKey(dk *DecapsulationKey) (*DecapsulationKey, error) { + var d [32]byte + if _, err := rand.Read(d[:]); err != nil { + return nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error()) } - z := make([]byte, 32) - if _, err := rand.Read(z); err != nil { - return nil, nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error()) + var z [32]byte + if _, err := rand.Read(z[:]); err != nil { + return nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error()) } - ek, dk := kemKeyGen(d, z) - return ek, dk, nil + return kemKeyGen(dk, &d, &z), nil +} + +// NewKeyFromSeed deterministically generates a decapsulation key from a 64-byte +// seed in the "d || z" form. The seed must be uniformly random. +func NewKeyFromSeed(seed []byte) (*DecapsulationKey, error) { + // The actual logic is in a separate function to outline this allocation. + dk := &DecapsulationKey{} + return newKeyFromSeed(dk, seed) } -// NewKeyFromSeed deterministically generates an encapsulation key and a -// corresponding decapsulation key from a 64-byte seed. The seed must be -// uniformly random. -func NewKeyFromSeed(seed []byte) (encapsulationKey, decapsulationKey []byte, err error) { +func newKeyFromSeed(dk *DecapsulationKey, seed []byte) (*DecapsulationKey, error) { if len(seed) != SeedSize { - return nil, nil, errors.New("mlkem768: invalid seed length") + return nil, errors.New("mlkem768: invalid seed length") } - ek, dk := kemKeyGen(seed[:32], seed[32:]) - return ek, dk, nil + d := (*[32]byte)(seed[:32]) + z := (*[32]byte)(seed[32:]) + return kemKeyGen(dk, d, z), nil } -// kemKeyGen generates an encapsulation key and a corresponding decapsulation key. -// -// It implements ML-KEM.KeyGen according to FIPS 203 (DRAFT), Algorithm 15. -func kemKeyGen(d, z []byte) (ek, dk []byte) { - ekPKE, dkPKE := pkeKeyGen(d) - dk = make([]byte, 0, DecapsulationKeySize) - dk = append(dk, dkPKE...) - dk = append(dk, ekPKE...) - H := sha3.New256() - H.Write(ekPKE) - dk = H.Sum(dk) - dk = append(dk, z...) - return ekPKE, dk +// NewKeyFromExtendedEncoding parses a decapsulation key from its FIPS 203 +// (DRAFT) extended encoding. +func NewKeyFromExtendedEncoding(decapsulationKey []byte) (*DecapsulationKey, error) { + // The actual logic is in a separate function to outline this allocation. + dk := &DecapsulationKey{} + return newKeyFromExtendedEncoding(dk, decapsulationKey) } -// pkeKeyGen generates a key pair for the underlying PKE from a 32-byte random seed. +func newKeyFromExtendedEncoding(dk *DecapsulationKey, dkBytes []byte) (*DecapsulationKey, error) { + if len(dkBytes) != DecapsulationKeySize { + return nil, errors.New("mlkem768: invalid decapsulation key length") + } + + // Note that we don't check that H(ek) matches ekPKE, as that's not + // specified in FIPS 203 (DRAFT). This is one reason to prefer the seed + // private key format. + dk.dk = [DecapsulationKeySize]byte(dkBytes) + + dkPKE := dkBytes[:decryptionKeySize] + if err := parseDK(&dk.decryptionKey, dkPKE); err != nil { + return nil, err + } + + ekPKE := dkBytes[decryptionKeySize : decryptionKeySize+encryptionKeySize] + if err := parseEK(&dk.encryptionKey, ekPKE); err != nil { + return nil, err + } + + return dk, nil +} + +// kemKeyGen generates a decapsulation key. // -// It implements K-PKE.KeyGen according to FIPS 203 (DRAFT), Algorithm 12. -func pkeKeyGen(d []byte) (ek, dk []byte) { - G := sha3.Sum512(d) +// It implements ML-KEM.KeyGen according to FIPS 203 (DRAFT), Algorithm 15, and +// K-PKE.KeyGen according to FIPS 203 (DRAFT), Algorithm 12. The two are merged +// to save copies and allocations. +func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey { + if dk == nil { + dk = &DecapsulationKey{} + } + + G := sha3.Sum512(d[:]) ρ, σ := G[:32], G[32:] - A := make([]nttElement, k*k) + A := &dk.A for i := byte(0); i < k; i++ { for j := byte(0); j < k; j++ { // Note that this is consistent with Kyber round 3, rather than with @@ -129,36 +195,51 @@ func pkeKeyGen(d []byte) (ek, dk []byte) { } var N byte - s, e := make([]nttElement, k), make([]nttElement, k) + s := &dk.s for i := range s { s[i] = ntt(samplePolyCBD(σ, N)) N++ } + e := make([]nttElement, k) for i := range e { e[i] = ntt(samplePolyCBD(σ, N)) N++ } - t := make([]nttElement, k) // A ◦ s + e - for i := range t { + t := &dk.t + for i := range t { // t = A ◦ s + e t[i] = e[i] for j := range s { t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j])) } } - ek = make([]byte, 0, encryptionKeySize) + // dkPKE ← ByteEncode₁₂(s) + // ekPKE ← ByteEncode₁₂(t) || ρ + // ek ← ekPKE + // dk ← dkPKE || ek || H(ek) || z + dkB := dk.dk[:0] + + for i := range s { + dkB = polyByteEncode(dkB, s[i]) + } + for i := range t { - ek = polyByteEncode(ek, t[i]) + dkB = polyByteEncode(dkB, t[i]) } - ek = append(ek, ρ...) + dkB = append(dkB, ρ...) - dk = make([]byte, 0, decryptionKeySize) - for i := range s { - dk = polyByteEncode(dk, s[i]) + H := sha3.New256() + H.Write(dkB[decryptionKeySize:]) + dkB = H.Sum(dkB) + + dkB = append(dkB, z[:]...) + + if len(dkB) != len(dk.dk) { + panic("mlkem768: internal error: invalid decapsulation key size") } - return ek, dk + return dk } // Encapsulate generates a shared key and an associated ciphertext from an @@ -167,65 +248,79 @@ func pkeKeyGen(d []byte) (ek, dk []byte) { // // The shared key must be kept secret. func Encapsulate(encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) { + // The actual logic is in a separate function to outline this allocation. + var cc [CiphertextSize]byte + return encapsulate(&cc, encapsulationKey) +} + +func encapsulate(cc *[CiphertextSize]byte, encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) { if len(encapsulationKey) != EncapsulationKeySize { return nil, nil, errors.New("mlkem768: invalid encapsulation key length") } - m := make([]byte, messageSize) - if _, err := rand.Read(m); err != nil { + var m [messageSize]byte + if _, err := rand.Read(m[:]); err != nil { return nil, nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error()) } - ciphertext, sharedKey, err = kemEncaps(encapsulationKey, m) - if err != nil { - return nil, nil, err - } - return ciphertext, sharedKey, nil + return kemEncaps(cc, encapsulationKey, &m) } // kemEncaps generates a shared key and an associated ciphertext. // // It implements ML-KEM.Encaps according to FIPS 203 (DRAFT), Algorithm 16. -func kemEncaps(ek, m []byte) (c, K []byte, err error) { - H := sha3.Sum256(ek) +func kemEncaps(cc *[CiphertextSize]byte, ek []byte, m *[messageSize]byte) (c, K []byte, err error) { + if cc == nil { + cc = &[CiphertextSize]byte{} + } + + H := sha3.Sum256(ek[:]) g := sha3.New512() - g.Write(m) + g.Write(m[:]) g.Write(H[:]) G := g.Sum(nil) K, r := G[:SharedKeySize], G[SharedKeySize:] - c, err = pkeEncrypt(ek, m, r) - return c, K, err + var ex encryptionKey + if err := parseEK(&ex, ek[:]); err != nil { + return nil, nil, err + } + c = pkeEncrypt(cc, &ex, m, r) + return c, K, nil } -// pkeEncrypt encrypt a plaintext message. It expects ek (the encryption key) to -// be 1184 bytes, and m (the message) and rnd (the randomness) to be 32 bytes. +// parseEK parses an encryption key from its encoded form. // -// It implements K-PKE.Encrypt according to FIPS 203 (DRAFT), Algorithm 13. -func pkeEncrypt(ek, m, rnd []byte) ([]byte, error) { - if len(ek) != encryptionKeySize { - return nil, errors.New("mlkem768: invalid encryption key length") - } - if len(m) != messageSize { - return nil, errors.New("mlkem768: invalid messages length") +// It implements the initial stages of K-PKE.Encrypt according to FIPS 203 +// (DRAFT), Algorithm 13. +func parseEK(ex *encryptionKey, ekPKE []byte) error { + if len(ekPKE) != encryptionKeySize { + return errors.New("mlkem768: invalid encryption key length") } - t := make([]nttElement, k) - for i := range t { + for i := range ex.t { var err error - t[i], err = polyByteDecode[nttElement](ek[:encodingSize12]) + ex.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12]) if err != nil { - return nil, err + return err } - ek = ek[encodingSize12:] + ekPKE = ekPKE[encodingSize12:] } - ρ := ek + ρ := ekPKE - AT := make([]nttElement, k*k) for i := byte(0); i < k; i++ { for j := byte(0); j < k; j++ { - // Note that i and j are inverted, as we need the transposed of A. - AT[i*k+j] = sampleNTT(ρ, i, j) + // See the note in pkeKeyGen about the order of the indices being + // consistent with Kyber round 3. + ex.A[i*k+j] = sampleNTT(ρ, j, i) } } + return nil +} + +// pkeEncrypt encrypt a plaintext message. +// +// It implements K-PKE.Encrypt according to FIPS 203 (DRAFT), Algorithm 13, +// although the computation of t and AT is done in parseEK. +func pkeEncrypt(cc *[CiphertextSize]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte { var N byte r, e1 := make([]nttElement, k), make([]ringElement, k) for i := range r { @@ -242,125 +337,107 @@ func pkeEncrypt(ek, m, rnd []byte) ([]byte, error) { for i := range u { u[i] = e1[i] for j := range r { - u[i] = polyAdd(u[i], inverseNTT(nttMul(AT[i*k+j], r[j]))) + // Note that i and j are inverted, as we need the transposed of A. + u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.A[j*k+i], r[j]))) } } - μ, err := ringDecodeAndDecompress1(m) - if err != nil { - return nil, err - } + μ := ringDecodeAndDecompress1(m) var vNTT nttElement // t⊺ ◦ r - for i := range t { - vNTT = polyAdd(vNTT, nttMul(t[i], r[i])) + for i := range ex.t { + vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i])) } v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ) - c := make([]byte, 0, CiphertextSize) + c := cc[:0] for _, f := range u { c = ringCompressAndEncode10(c, f) } c = ringCompressAndEncode4(c, v) - return c, nil + return c } // Decapsulate generates a shared key from a ciphertext and a decapsulation key. -// If the decapsulation key or the ciphertext are not valid, Decapsulate returns -// an error. +// If the ciphertext is not valid, Decapsulate returns an error. // // The shared key must be kept secret. -func Decapsulate(decapsulationKey, ciphertext []byte) (sharedKey []byte, err error) { - if len(decapsulationKey) != DecapsulationKeySize { - return nil, errors.New("mlkem768: invalid decapsulation key length") - } +func Decapsulate(dk *DecapsulationKey, ciphertext []byte) (sharedKey []byte, err error) { if len(ciphertext) != CiphertextSize { return nil, errors.New("mlkem768: invalid ciphertext length") } - return kemDecaps(decapsulationKey, ciphertext) + c := (*[CiphertextSize]byte)(ciphertext) + return kemDecaps(dk, c), nil } // kemDecaps produces a shared key from a ciphertext. // // It implements ML-KEM.Decaps according to FIPS 203 (DRAFT), Algorithm 17. -func kemDecaps(dk, c []byte) (K []byte, err error) { - dkPKE := dk[:decryptionKeySize] - ekPKE := dk[decryptionKeySize : decryptionKeySize+encryptionKeySize] - h := dk[decryptionKeySize+encryptionKeySize : decryptionKeySize+encryptionKeySize+32] - z := dk[decryptionKeySize+encryptionKeySize+32:] - - m, err := pkeDecrypt(dkPKE, c) - if err != nil { - // This is only reachable if the ciphertext or the decryption key are - // encoded incorrectly, so it leaks no information about the message. - return nil, err - } +func kemDecaps(dk *DecapsulationKey, c *[CiphertextSize]byte) (K []byte) { + h := dk.dk[decryptionKeySize+encryptionKeySize : decryptionKeySize+encryptionKeySize+32] + z := dk.dk[decryptionKeySize+encryptionKeySize+32:] + + m := pkeDecrypt(&dk.decryptionKey, c) g := sha3.New512() - g.Write(m) + g.Write(m[:]) g.Write(h) G := g.Sum(nil) Kprime, r := G[:SharedKeySize], G[SharedKeySize:] J := sha3.NewShake256() J.Write(z) - J.Write(c) + J.Write(c[:]) Kout := make([]byte, SharedKeySize) J.Read(Kout) - c1, err := pkeEncrypt(ekPKE, m, r) - if err != nil { - // Likewise, this is only reachable if the encryption key is encoded - // incorrectly, so it leaks no secret information through timing. - return nil, err - } + var cc [CiphertextSize]byte + c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r) - subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c, c1), Kout, Kprime) - return Kout, nil + subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime) + return Kout } -// pkeDecrypt decrypts a ciphertext. It expects dk (the decryption key) to -// be 1152 bytes, and c (the ciphertext) to be 1088 bytes. +// parseDK parses a decryption key from its encoded form. // -// It implements K-PKE.Decrypt according to FIPS 203 (DRAFT), Algorithm 14. -func pkeDecrypt(dk, c []byte) ([]byte, error) { - if len(dk) != decryptionKeySize { - return nil, errors.New("mlkem768: invalid decryption key length") - } - if len(c) != CiphertextSize { - return nil, errors.New("mlkem768: invalid ciphertext length") +// It implements the computation of s from K-PKE.Decrypt according to FIPS 203 +// (DRAFT), Algorithm 14. +func parseDK(dx *decryptionKey, dkPKE []byte) error { + if len(dkPKE) != decryptionKeySize { + return errors.New("mlkem768: invalid decryption key length") } - u := make([]ringElement, k) - for i := range u { - f, err := ringDecodeAndDecompress10(c[:encodingSize10]) + for i := range dx.s { + f, err := polyByteDecode[nttElement](dkPKE[:encodingSize12]) if err != nil { - return nil, err + return err } - u[i] = f - c = c[encodingSize10:] + dx.s[i] = f + dkPKE = dkPKE[encodingSize12:] } - v, err := ringDecodeAndDecompress4(c) - if err != nil { - return nil, err - } + return nil +} - s := make([]nttElement, k) - for i := range s { - f, err := polyByteDecode[nttElement](dk[:encodingSize12]) - if err != nil { - return nil, err - } - s[i] = f - dk = dk[encodingSize12:] +// pkeDecrypt decrypts a ciphertext. +// +// It implements K-PKE.Decrypt according to FIPS 203 (DRAFT), Algorithm 14, +// although the computation of s is done in parseDK. +func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize]byte) []byte { + u := make([]ringElement, k) + for i := range u { + b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)]) + u[i] = ringDecodeAndDecompress10(b) } + b := (*[encodingSize4]byte)(c[encodingSize10*k:]) + v := ringDecodeAndDecompress4(b) + var mask nttElement // s⊺ ◦ NTT(u) - for i := range s { - mask = polyAdd(mask, nttMul(s[i], ntt(u[i]))) + for i := range dx.s { + mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i]))) } w := polySub(v, inverseNTT(mask)) - return ringCompressAndEncode1(nil, w), nil + return ringCompressAndEncode1(nil, w) } // fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced. @@ -397,7 +474,7 @@ const ( barrettShift = 24 // log₂(2¹² * 2¹²) ) -// fieldReduce reduces a value a < q² using Barrett reduction, to avoid +// fieldReduce reduces a value a < 2q² using Barrett reduction, to avoid // potentially variable-time division. func fieldReduce(a uint32) fieldElement { quotient := uint32((uint64(a) * barrettMultiplier) >> barrettShift) @@ -409,6 +486,21 @@ func fieldMul(a, b fieldElement) fieldElement { return fieldReduce(x) } +// fieldMulSub returns a * (b - c). This operation is fused to save a +// fieldReduceOnce after the subtraction. +func fieldMulSub(a, b, c fieldElement) fieldElement { + x := uint32(a) * uint32(b-c+q) + return fieldReduce(x) +} + +// fieldAddMul returns a * b + c * d. This operation is fused to save a +// fieldReduceOnce and a fieldReduce. +func fieldAddMul(a, b, c, d fieldElement) fieldElement { + x := uint32(a) * uint32(b) + x += uint32(c) * uint32(d) + return fieldReduce(x) +} + // compress maps a field element uniformly to the range 0 to 2ᵈ-1, according to // FIPS 203 (DRAFT), Definition 4.5. func compress(x fieldElement, d uint8) uint16 { @@ -558,17 +650,14 @@ func ringCompressAndEncode1(s []byte, f ringElement) []byte { // // It implements ByteDecode₁, according to FIPS 203 (DRAFT), Algorithm 5, // followed by Decompress₁, according to FIPS 203 (DRAFT), Definition 4.6. -func ringDecodeAndDecompress1(b []byte) (ringElement, error) { - if len(b) != encodingSize1 { - return ringElement{}, errors.New("mlkem768: invalid message length") - } +func ringDecodeAndDecompress1(b *[encodingSize1]byte) ringElement { var f ringElement for i := range f { b_i := b[i/8] >> (i % 8) & 1 const halfQ = (q + 1) / 2 // ⌈q/2⌋, rounded up per FIPS 203 (DRAFT), Section 2.3 f[i] = fieldElement(b_i) * halfQ // 0 decompresses to 0, and 1 to ⌈q/2⌋ } - return f, nil + return f } // ringCompressAndEncode4 appends a 128-byte encoding of a ring element to s, @@ -589,16 +678,13 @@ func ringCompressAndEncode4(s []byte, f ringElement) []byte { // // It implements ByteDecode₄, according to FIPS 203 (DRAFT), Algorithm 5, // followed by Decompress₄, according to FIPS 203 (DRAFT), Definition 4.6. -func ringDecodeAndDecompress4(b []byte) (ringElement, error) { - if len(b) != encodingSize4 { - return ringElement{}, errors.New("mlkem768: invalid encoding length") - } +func ringDecodeAndDecompress4(b *[encodingSize4]byte) ringElement { var f ringElement for i := 0; i < n; i += 2 { f[i] = fieldElement(decompress(uint16(b[i/2]&0b1111), 4)) f[i+1] = fieldElement(decompress(uint16(b[i/2]>>4), 4)) } - return f, nil + return f } // ringCompressAndEncode10 appends a 320-byte encoding of a ring element to s, @@ -629,10 +715,8 @@ func ringCompressAndEncode10(s []byte, f ringElement) []byte { // // It implements ByteDecode₁₀, according to FIPS 203 (DRAFT), Algorithm 5, // followed by Decompress₁₀, according to FIPS 203 (DRAFT), Definition 4.6. -func ringDecodeAndDecompress10(b []byte) (ringElement, error) { - if len(b) != encodingSize10 { - return ringElement{}, errors.New("mlkem768: invalid encoding length") - } +func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement { + b := bb[:] var f ringElement for i := 0; i < n; i += 4 { x := uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 @@ -642,7 +726,7 @@ func ringDecodeAndDecompress10(b []byte) (ringElement, error) { f[i+2] = fieldElement(decompress(uint16(x>>20&0b11_1111_1111), 10)) f[i+3] = fieldElement(decompress(uint16(x>>30&0b11_1111_1111), 10)) } - return f, nil + return f } // samplePolyCBD draws a ringElement from the special Dη distribution given a @@ -681,11 +765,12 @@ var gammas = [128]fieldElement{17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, // It implements MultiplyNTTs, according to FIPS 203 (DRAFT), Algorithm 10. func nttMul(f, g nttElement) nttElement { var h nttElement - for i := 0; i < 128; i++ { - a0, a1 := f[2*i], f[2*i+1] - b0, b1 := g[2*i], g[2*i+1] - h[2*i] = fieldAdd(fieldMul(a0, b0), fieldMul(fieldMul(a1, b1), gammas[i])) - h[2*i+1] = fieldAdd(fieldMul(a0, b1), fieldMul(a1, b0)) + // We use i += 2 for bounds check elimination. See https://go.dev/issue/66826. + for i := 0; i < 256; i += 2 { + a0, a1 := f[i], f[i+1] + b0, b1 := g[i], g[i+1] + h[i] = fieldAddMul(a0, b0, fieldMul(a1, b1), gammas[i/2]) + h[i+1] = fieldAddMul(a0, b1, a1, b0) } return h } @@ -702,18 +787,12 @@ func ntt(f ringElement) nttElement { for start := 0; start < 256; start += 2 * len { zeta := zetas[k] k++ - for j := start; j < start+len; j += 2 { - // Loop 2x unrolled for performance. - { - t := fieldMul(zeta, f[j+len]) - f[j+len] = fieldSub(f[j], t) - f[j] = fieldAdd(f[j], t) - } - { - t := fieldMul(zeta, f[j+1+len]) - f[j+1+len] = fieldSub(f[j+1], t) - f[j+1] = fieldAdd(f[j+1], t) - } + // Bounds check elimination hint. + f, flen := f[start:start+len], f[start+len:start+len+len] + for j := 0; j < len; j++ { + t := fieldMul(zeta, flen[j]) + flen[j] = fieldSub(f[j], t) + f[j] = fieldAdd(f[j], t) } } } @@ -729,18 +808,12 @@ func inverseNTT(f nttElement) ringElement { for start := 0; start < 256; start += 2 * len { zeta := zetas[k] k-- - for j := start; j < start+len; j += 2 { - // Loop 2x unrolled for performance. - { - t := f[j] - f[j] = fieldAdd(t, f[j+len]) - f[j+len] = fieldMul(zeta, fieldSub(f[j+len], t)) - } - { - t := f[j+1] - f[j+1] = fieldAdd(t, f[j+1+len]) - f[j+1+len] = fieldMul(zeta, fieldSub(f[j+1+len], t)) - } + // Bounds check elimination hint. + f, flen := f[start:start+len], f[start+len:start+len+len] + for j := 0; j < len; j++ { + t := f[j] + f[j] = fieldAdd(t, flen[j]) + flen[j] = fieldMulSub(zeta, flen[j], t) } } } diff --git a/src/crypto/internal/mlkem768/mlkem768_test.go b/src/crypto/internal/mlkem768/mlkem768_test.go index 6e2ac769ef..b91b42a424 100644 --- a/src/crypto/internal/mlkem768/mlkem768_test.go +++ b/src/crypto/internal/mlkem768/mlkem768_test.go @@ -9,6 +9,7 @@ import ( "crypto/rand" _ "embed" "encoding/hex" + "errors" "flag" "math/big" "strconv" @@ -17,6 +18,16 @@ import ( "golang.org/x/crypto/sha3" ) +func TestFieldReduce(t *testing.T) { + for a := uint32(0); a < 2*q*q; a++ { + got := fieldReduce(a) + exp := fieldElement(a % q) + if got != exp { + t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp) + } + } +} + func TestFieldAdd(t *testing.T) { for a := fieldElement(0); a < q; a++ { for b := fieldElement(0); b < q; b++ { @@ -188,11 +199,11 @@ func TestGammas(t *testing.T) { } func TestRoundTrip(t *testing.T) { - ek, dk, err := GenerateKey() + dk, err := GenerateKey() if err != nil { t.Fatal(err) } - c, Ke, err := Encapsulate(ek) + c, Ke, err := Encapsulate(dk.EncapsulationKey()) if err != nil { t.Fatal(err) } @@ -204,21 +215,21 @@ func TestRoundTrip(t *testing.T) { t.Fail() } - ek1, dk1, err := GenerateKey() + dk1, err := GenerateKey() if err != nil { t.Fatal(err) } - if bytes.Equal(ek, ek1) { + if bytes.Equal(dk.EncapsulationKey(), dk1.EncapsulationKey()) { t.Fail() } - if bytes.Equal(dk, dk1) { + if bytes.Equal(dk.Bytes(), dk1.Bytes()) { t.Fail() } - if bytes.Equal(dk[len(dk)-32:], dk1[len(dk)-32:]) { + if bytes.Equal(dk.Bytes()[EncapsulationKeySize-32:], dk1.Bytes()[EncapsulationKeySize-32:]) { t.Fail() } - c1, Ke1, err := Encapsulate(ek) + c1, Ke1, err := Encapsulate(dk.EncapsulationKey()) if err != nil { t.Fatal(err) } @@ -231,10 +242,11 @@ func TestRoundTrip(t *testing.T) { } func TestBadLengths(t *testing.T) { - ek, dk, err := GenerateKey() + dk, err := GenerateKey() if err != nil { t.Fatal(err) } + ek := dk.EncapsulationKey() for i := 0; i < len(ek)-1; i++ { if _, _, err := Encapsulate(ek[:i]); err == nil { @@ -254,15 +266,15 @@ func TestBadLengths(t *testing.T) { t.Fatal(err) } - for i := 0; i < len(dk)-1; i++ { - if _, err := Decapsulate(dk[:i], c); err == nil { + for i := 0; i < len(dk.Bytes())-1; i++ { + if _, err := NewKeyFromExtendedEncoding(dk.Bytes()[:i]); err == nil { t.Errorf("expected error for dk length %d", i) } } - dkLong := dk + dkLong := dk.Bytes() for i := 0; i < 100; i++ { dkLong = append(dkLong, 0) - if _, err := Decapsulate(dkLong, c); err == nil { + if _, err := NewKeyFromExtendedEncoding(dkLong); err == nil { t.Errorf("expected error for dk length %d", len(dkLong)) } } @@ -281,6 +293,29 @@ func TestBadLengths(t *testing.T) { } } +func EncapsulateDerand(ek, m []byte) (c, K []byte, err error) { + if len(m) != messageSize { + return nil, nil, errors.New("bad message length") + } + return kemEncaps(nil, ek, (*[messageSize]byte)(m)) +} + +func DecapsulateFromBytes(dkBytes []byte, c []byte) ([]byte, error) { + dk, err := NewKeyFromExtendedEncoding(dkBytes) + if err != nil { + return nil, err + } + return Decapsulate(dk, c) +} + +func GenerateKeyDerand(t testing.TB, d, z []byte) ([]byte, *DecapsulationKey) { + if len(d) != 32 || len(z) != 32 { + t.Fatal("bad length") + } + dk := kemKeyGen(nil, (*[32]byte)(d), (*[32]byte)(z)) + return dk.EncapsulationKey(), dk +} + var millionFlag = flag.Bool("million", false, "run the million vector test") // TestPQCrystalsAccumulated accumulates the 10k vectors generated by the @@ -308,19 +343,19 @@ func TestPQCrystalsAccumulated(t *testing.T) { for i := 0; i < n; i++ { s.Read(d) s.Read(z) - ek, dk := kemKeyGen(d, z) + ek, dk := GenerateKeyDerand(t, d, z) o.Write(ek) - o.Write(dk) + o.Write(dk.Bytes()) s.Read(msg) - ct, k, err := kemEncaps(ek, msg) + ct, k, err := EncapsulateDerand(ek, msg) if err != nil { t.Fatal(err) } o.Write(ct) o.Write(k) - kk, err := kemDecaps(dk, ct) + kk, err := Decapsulate(dk, ct) if err != nil { t.Fatal(err) } @@ -329,7 +364,7 @@ func TestPQCrystalsAccumulated(t *testing.T) { } s.Read(ct1) - k1, err := kemDecaps(dk, ct1) + k1, err := Decapsulate(dk, ct1) if err != nil { t.Fatal(err) } @@ -342,25 +377,17 @@ func TestPQCrystalsAccumulated(t *testing.T) { } } -var sinkElement fieldElement - -func BenchmarkSampleNTT(b *testing.B) { - for i := 0; i < b.N; i++ { - sinkElement ^= sampleNTT(bytes.Repeat([]byte("A"), 32), '4', '2')[0] - } -} - var sink byte func BenchmarkKeyGen(b *testing.B) { - d := make([]byte, 32) - rand.Read(d) - z := make([]byte, 32) - rand.Read(z) + var dk DecapsulationKey + var d, z [32]byte + rand.Read(d[:]) + rand.Read(z[:]) b.ResetTimer() for i := 0; i < b.N; i++ { - ek, dk := kemKeyGen(d, z) - sink ^= ek[0] ^ dk[0] + dk := kemKeyGen(&dk, &d, &z) + sink ^= dk.EncapsulationKey()[0] } } @@ -369,12 +396,13 @@ func BenchmarkEncaps(b *testing.B) { rand.Read(d) z := make([]byte, 32) rand.Read(z) - m := make([]byte, 32) - rand.Read(m) - ek, _ := kemKeyGen(d, z) + var m [messageSize]byte + rand.Read(m[:]) + ek, _ := GenerateKeyDerand(b, d, z) + var c [CiphertextSize]byte b.ResetTimer() for i := 0; i < b.N; i++ { - c, K, err := kemEncaps(ek, m) + c, K, err := kemEncaps(&c, ek, &m) if err != nil { b.Fatal(err) } @@ -389,41 +417,42 @@ func BenchmarkDecaps(b *testing.B) { rand.Read(z) m := make([]byte, 32) rand.Read(m) - ek, dk := kemKeyGen(d, z) - c, _, err := kemEncaps(ek, m) + ek, dk := GenerateKeyDerand(b, d, z) + c, _, err := EncapsulateDerand(ek, m) if err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { - K, err := kemDecaps(dk, c) - if err != nil { - b.Fatal(err) - } + K := kemDecaps(dk, (*[CiphertextSize]byte)(c)) sink ^= K[0] } } func BenchmarkRoundTrip(b *testing.B) { - ek, dk, err := GenerateKey() + dk, err := GenerateKey() if err != nil { b.Fatal(err) } + ek := dk.EncapsulationKey() c, _, err := Encapsulate(ek) if err != nil { b.Fatal(err) } b.Run("Alice", func(b *testing.B) { for i := 0; i < b.N; i++ { - ekS, dkS, err := GenerateKey() + dkS, err := GenerateKey() if err != nil { b.Fatal(err) } + ekS := dkS.EncapsulationKey() + sink ^= ekS[0] + Ks, err := Decapsulate(dk, c) if err != nil { b.Fatal(err) } - sink ^= ekS[0] ^ dkS[0] ^ Ks[0] + sink ^= Ks[0] } }) b.Run("Bob", func(b *testing.B) {