]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/mlkem768: various performance optimizations
authorFilippo Valsorda <filippo@golang.org>
Mon, 15 Apr 2024 01:56:10 +0000 (03:56 +0200)
committerGopher Robot <gobot@golang.org>
Mon, 6 May 2024 15:27:18 +0000 (15:27 +0000)
goos: linux
goarch: amd64
pkg: crypto/internal/mlkem768
cpu: Intel(R) Core(TM) i5-7400 CPU @ 3.00GHz
                  │  c0a0ba254c  │              2aeb615fa6               │
                  │    sec/op    │   sec/op     vs base                  │
KeyGen-4             73.36µ ± 0%   67.38µ ± 1%   -8.15% (p=0.000 n=20)
Encaps-4            108.96µ ± 0%   99.56µ ± 1%   -8.63% (p=0.000 n=20)
Decaps-4            132.19µ ± 0%   96.85µ ± 0%  -26.74% (p=0.000 n=20)
RoundTrip/Alice-4    216.4µ ± 0%   173.1µ ± 0%  -20.01% (p=0.000 n=20)
RoundTrip/Bob-4      109.5µ ± 0%   100.5µ ± 0%   -8.19% (p=0.000 n=20)

Change-Id: I600116baa0b390bb83950a42c7693cd7806dba9a
Reviewed-on: https://go-review.googlesource.com/c/go/+/578797
Reviewed-by: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>

src/crypto/internal/mlkem768/mlkem768.go
src/crypto/internal/mlkem768/mlkem768_test.go

index c6b191c1ae219846882ab489505a66633e3dbf57..24bedea84f5baffc02d82f4ce98bea19052f6a88 100644 (file)
@@ -68,57 +68,123 @@ const (
        SeedSize             = 32 + 32
 )
 
-// GenerateKey generates an encapsulation key and a corresponding decapsulation
-// key, drawing random bytes from crypto/rand.
-//
-// The decapsulation key must be kept secret.
-func GenerateKey() (encapsulationKey, decapsulationKey []byte, err error) {
-       d := make([]byte, 32)
-       if _, err := rand.Read(d); err != nil {
-               return nil, nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
+// A DecapsulationKey is the secret key used to decapsulate a shared key from a
+// ciphertext. It includes various precomputed values.
+type DecapsulationKey struct {
+       dk [DecapsulationKeySize]byte
+       encryptionKey
+       decryptionKey
+}
+
+// Bytes returns the extended encoding of the decapsulation key, according to
+// FIPS 203 (DRAFT).
+func (dk *DecapsulationKey) Bytes() []byte {
+       var b [DecapsulationKeySize]byte
+       copy(b[:], dk.dk[:])
+       return b[:]
+}
+
+// EncapsulationKey returns the public encapsulation key necessary to produce
+// ciphertexts.
+func (dk *DecapsulationKey) EncapsulationKey() []byte {
+       var b [EncapsulationKeySize]byte
+       copy(b[:], dk.dk[decryptionKeySize:])
+       return b[:]
+}
+
+// encryptionKey is the parsed and expanded form of a PKE encryption key.
+type encryptionKey struct {
+       t [k]nttElement     // ByteDecode₁₂(ek[:384k])
+       A [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
+}
+
+// decryptionKey is the parsed and expanded form of a PKE decryption key.
+type decryptionKey struct {
+       s [k]nttElement // ByteDecode₁₂(dk[:decryptionKeySize])
+}
+
+// GenerateKey generates a new decapsulation key, drawing random bytes from
+// crypto/rand. The decapsulation key must be kept secret.
+func GenerateKey() (*DecapsulationKey, error) {
+       // The actual logic is in a separate function to outline this allocation.
+       dk := &DecapsulationKey{}
+       return generateKey(dk)
+}
+
+func generateKey(dk *DecapsulationKey) (*DecapsulationKey, error) {
+       var d [32]byte
+       if _, err := rand.Read(d[:]); err != nil {
+               return nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
        }
-       z := make([]byte, 32)
-       if _, err := rand.Read(z); err != nil {
-               return nil, nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
+       var z [32]byte
+       if _, err := rand.Read(z[:]); err != nil {
+               return nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
        }
-       ek, dk := kemKeyGen(d, z)
-       return ek, dk, nil
+       return kemKeyGen(dk, &d, &z), nil
+}
+
+// NewKeyFromSeed deterministically generates a decapsulation key from a 64-byte
+// seed in the "d || z" form. The seed must be uniformly random.
+func NewKeyFromSeed(seed []byte) (*DecapsulationKey, error) {
+       // The actual logic is in a separate function to outline this allocation.
+       dk := &DecapsulationKey{}
+       return newKeyFromSeed(dk, seed)
 }
 
-// NewKeyFromSeed deterministically generates an encapsulation key and a
-// corresponding decapsulation key from a 64-byte seed. The seed must be
-// uniformly random.
-func NewKeyFromSeed(seed []byte) (encapsulationKey, decapsulationKey []byte, err error) {
+func newKeyFromSeed(dk *DecapsulationKey, seed []byte) (*DecapsulationKey, error) {
        if len(seed) != SeedSize {
-               return nil, nil, errors.New("mlkem768: invalid seed length")
+               return nil, errors.New("mlkem768: invalid seed length")
        }
-       ek, dk := kemKeyGen(seed[:32], seed[32:])
-       return ek, dk, nil
+       d := (*[32]byte)(seed[:32])
+       z := (*[32]byte)(seed[32:])
+       return kemKeyGen(dk, d, z), nil
 }
 
-// kemKeyGen generates an encapsulation key and a corresponding decapsulation key.
-//
-// It implements ML-KEM.KeyGen according to FIPS 203 (DRAFT), Algorithm 15.
-func kemKeyGen(d, z []byte) (ek, dk []byte) {
-       ekPKE, dkPKE := pkeKeyGen(d)
-       dk = make([]byte, 0, DecapsulationKeySize)
-       dk = append(dk, dkPKE...)
-       dk = append(dk, ekPKE...)
-       H := sha3.New256()
-       H.Write(ekPKE)
-       dk = H.Sum(dk)
-       dk = append(dk, z...)
-       return ekPKE, dk
+// NewKeyFromExtendedEncoding parses a decapsulation key from its FIPS 203
+// (DRAFT) extended encoding.
+func NewKeyFromExtendedEncoding(decapsulationKey []byte) (*DecapsulationKey, error) {
+       // The actual logic is in a separate function to outline this allocation.
+       dk := &DecapsulationKey{}
+       return newKeyFromExtendedEncoding(dk, decapsulationKey)
 }
 
-// pkeKeyGen generates a key pair for the underlying PKE from a 32-byte random seed.
+func newKeyFromExtendedEncoding(dk *DecapsulationKey, dkBytes []byte) (*DecapsulationKey, error) {
+       if len(dkBytes) != DecapsulationKeySize {
+               return nil, errors.New("mlkem768: invalid decapsulation key length")
+       }
+
+       // Note that we don't check that H(ek) matches ekPKE, as that's not
+       // specified in FIPS 203 (DRAFT). This is one reason to prefer the seed
+       // private key format.
+       dk.dk = [DecapsulationKeySize]byte(dkBytes)
+
+       dkPKE := dkBytes[:decryptionKeySize]
+       if err := parseDK(&dk.decryptionKey, dkPKE); err != nil {
+               return nil, err
+       }
+
+       ekPKE := dkBytes[decryptionKeySize : decryptionKeySize+encryptionKeySize]
+       if err := parseEK(&dk.encryptionKey, ekPKE); err != nil {
+               return nil, err
+       }
+
+       return dk, nil
+}
+
+// kemKeyGen generates a decapsulation key.
 //
-// It implements K-PKE.KeyGen according to FIPS 203 (DRAFT), Algorithm 12.
-func pkeKeyGen(d []byte) (ek, dk []byte) {
-       G := sha3.Sum512(d)
+// It implements ML-KEM.KeyGen according to FIPS 203 (DRAFT), Algorithm 15, and
+// K-PKE.KeyGen according to FIPS 203 (DRAFT), Algorithm 12. The two are merged
+// to save copies and allocations.
+func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey {
+       if dk == nil {
+               dk = &DecapsulationKey{}
+       }
+
+       G := sha3.Sum512(d[:])
        ρ, σ := G[:32], G[32:]
 
-       A := make([]nttElement, k*k)
+       A := &dk.A
        for i := byte(0); i < k; i++ {
                for j := byte(0); j < k; j++ {
                        // Note that this is consistent with Kyber round 3, rather than with
@@ -129,36 +195,51 @@ func pkeKeyGen(d []byte) (ek, dk []byte) {
        }
 
        var N byte
-       s, e := make([]nttElement, k), make([]nttElement, k)
+       s := &dk.s
        for i := range s {
                s[i] = ntt(samplePolyCBD(σ, N))
                N++
        }
+       e := make([]nttElement, k)
        for i := range e {
                e[i] = ntt(samplePolyCBD(σ, N))
                N++
        }
 
-       t := make([]nttElement, k) // A ◦ s + e
-       for i := range t {
+       t := &dk.t
+       for i := range t { // t = A ◦ s + e
                t[i] = e[i]
                for j := range s {
                        t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j]))
                }
        }
 
-       ek = make([]byte, 0, encryptionKeySize)
+       // dkPKE ← ByteEncode₁₂(s)
+       // ekPKE ← ByteEncode₁₂(t) || ρ
+       // ek ← ekPKE
+       // dk ← dkPKE || ek || H(ek) || z
+       dkB := dk.dk[:0]
+
+       for i := range s {
+               dkB = polyByteEncode(dkB, s[i])
+       }
+
        for i := range t {
-               ek = polyByteEncode(ek, t[i])
+               dkB = polyByteEncode(dkB, t[i])
        }
-       ek = append(ek, ρ...)
+       dkB = append(dkB, ρ...)
 
-       dk = make([]byte, 0, decryptionKeySize)
-       for i := range s {
-               dk = polyByteEncode(dk, s[i])
+       H := sha3.New256()
+       H.Write(dkB[decryptionKeySize:])
+       dkB = H.Sum(dkB)
+
+       dkB = append(dkB, z[:]...)
+
+       if len(dkB) != len(dk.dk) {
+               panic("mlkem768: internal error: invalid decapsulation key size")
        }
 
-       return ek, dk
+       return dk
 }
 
 // Encapsulate generates a shared key and an associated ciphertext from an
@@ -167,65 +248,79 @@ func pkeKeyGen(d []byte) (ek, dk []byte) {
 //
 // The shared key must be kept secret.
 func Encapsulate(encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) {
+       // The actual logic is in a separate function to outline this allocation.
+       var cc [CiphertextSize]byte
+       return encapsulate(&cc, encapsulationKey)
+}
+
+func encapsulate(cc *[CiphertextSize]byte, encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) {
        if len(encapsulationKey) != EncapsulationKeySize {
                return nil, nil, errors.New("mlkem768: invalid encapsulation key length")
        }
-       m := make([]byte, messageSize)
-       if _, err := rand.Read(m); err != nil {
+       var m [messageSize]byte
+       if _, err := rand.Read(m[:]); err != nil {
                return nil, nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
        }
-       ciphertext, sharedKey, err = kemEncaps(encapsulationKey, m)
-       if err != nil {
-               return nil, nil, err
-       }
-       return ciphertext, sharedKey, nil
+       return kemEncaps(cc, encapsulationKey, &m)
 }
 
 // kemEncaps generates a shared key and an associated ciphertext.
 //
 // It implements ML-KEM.Encaps according to FIPS 203 (DRAFT), Algorithm 16.
-func kemEncaps(ek, m []byte) (c, K []byte, err error) {
-       H := sha3.Sum256(ek)
+func kemEncaps(cc *[CiphertextSize]byte, ek []byte, m *[messageSize]byte) (c, K []byte, err error) {
+       if cc == nil {
+               cc = &[CiphertextSize]byte{}
+       }
+
+       H := sha3.Sum256(ek[:])
        g := sha3.New512()
-       g.Write(m)
+       g.Write(m[:])
        g.Write(H[:])
        G := g.Sum(nil)
        K, r := G[:SharedKeySize], G[SharedKeySize:]
-       c, err = pkeEncrypt(ek, m, r)
-       return c, K, err
+       var ex encryptionKey
+       if err := parseEK(&ex, ek[:]); err != nil {
+               return nil, nil, err
+       }
+       c = pkeEncrypt(cc, &ex, m, r)
+       return c, K, nil
 }
 
-// pkeEncrypt encrypt a plaintext message. It expects ek (the encryption key) to
-// be 1184 bytes, and m (the message) and rnd (the randomness) to be 32 bytes.
+// parseEK parses an encryption key from its encoded form.
 //
-// It implements K-PKE.Encrypt according to FIPS 203 (DRAFT), Algorithm 13.
-func pkeEncrypt(ek, m, rnd []byte) ([]byte, error) {
-       if len(ek) != encryptionKeySize {
-               return nil, errors.New("mlkem768: invalid encryption key length")
-       }
-       if len(m) != messageSize {
-               return nil, errors.New("mlkem768: invalid messages length")
+// It implements the initial stages of K-PKE.Encrypt according to FIPS 203
+// (DRAFT), Algorithm 13.
+func parseEK(ex *encryptionKey, ekPKE []byte) error {
+       if len(ekPKE) != encryptionKeySize {
+               return errors.New("mlkem768: invalid encryption key length")
        }
 
-       t := make([]nttElement, k)
-       for i := range t {
+       for i := range ex.t {
                var err error
-               t[i], err = polyByteDecode[nttElement](ek[:encodingSize12])
+               ex.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
                if err != nil {
-                       return nil, err
+                       return err
                }
-               ek = ek[encodingSize12:]
+               ekPKE = ekPKE[encodingSize12:]
        }
-       ρ := ek
+       ρ := ekPKE
 
-       AT := make([]nttElement, k*k)
        for i := byte(0); i < k; i++ {
                for j := byte(0); j < k; j++ {
-                       // Note that i and j are inverted, as we need the transposed of A.
-                       AT[i*k+j] = sampleNTT(ρ, i, j)
+                       // See the note in pkeKeyGen about the order of the indices being
+                       // consistent with Kyber round 3.
+                       ex.A[i*k+j] = sampleNTT(ρ, j, i)
                }
        }
 
+       return nil
+}
+
+// pkeEncrypt encrypt a plaintext message.
+//
+// It implements K-PKE.Encrypt according to FIPS 203 (DRAFT), Algorithm 13,
+// although the computation of t and AT is done in parseEK.
+func pkeEncrypt(cc *[CiphertextSize]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte {
        var N byte
        r, e1 := make([]nttElement, k), make([]ringElement, k)
        for i := range r {
@@ -242,125 +337,107 @@ func pkeEncrypt(ek, m, rnd []byte) ([]byte, error) {
        for i := range u {
                u[i] = e1[i]
                for j := range r {
-                       u[i] = polyAdd(u[i], inverseNTT(nttMul(AT[i*k+j], r[j])))
+                       // Note that i and j are inverted, as we need the transposed of A.
+                       u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.A[j*k+i], r[j])))
                }
        }
 
-       μ, err := ringDecodeAndDecompress1(m)
-       if err != nil {
-               return nil, err
-       }
+       μ := ringDecodeAndDecompress1(m)
 
        var vNTT nttElement // t⊺ ◦ r
-       for i := range t {
-               vNTT = polyAdd(vNTT, nttMul(t[i], r[i]))
+       for i := range ex.t {
+               vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
        }
        v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)
 
-       c := make([]byte, 0, CiphertextSize)
+       c := cc[:0]
        for _, f := range u {
                c = ringCompressAndEncode10(c, f)
        }
        c = ringCompressAndEncode4(c, v)
 
-       return c, nil
+       return c
 }
 
 // Decapsulate generates a shared key from a ciphertext and a decapsulation key.
-// If the decapsulation key or the ciphertext are not valid, Decapsulate returns
-// an error.
+// If the ciphertext is not valid, Decapsulate returns an error.
 //
 // The shared key must be kept secret.
-func Decapsulate(decapsulationKey, ciphertext []byte) (sharedKey []byte, err error) {
-       if len(decapsulationKey) != DecapsulationKeySize {
-               return nil, errors.New("mlkem768: invalid decapsulation key length")
-       }
+func Decapsulate(dk *DecapsulationKey, ciphertext []byte) (sharedKey []byte, err error) {
        if len(ciphertext) != CiphertextSize {
                return nil, errors.New("mlkem768: invalid ciphertext length")
        }
-       return kemDecaps(decapsulationKey, ciphertext)
+       c := (*[CiphertextSize]byte)(ciphertext)
+       return kemDecaps(dk, c), nil
 }
 
 // kemDecaps produces a shared key from a ciphertext.
 //
 // It implements ML-KEM.Decaps according to FIPS 203 (DRAFT), Algorithm 17.
-func kemDecaps(dk, c []byte) (K []byte, err error) {
-       dkPKE := dk[:decryptionKeySize]
-       ekPKE := dk[decryptionKeySize : decryptionKeySize+encryptionKeySize]
-       h := dk[decryptionKeySize+encryptionKeySize : decryptionKeySize+encryptionKeySize+32]
-       z := dk[decryptionKeySize+encryptionKeySize+32:]
-
-       m, err := pkeDecrypt(dkPKE, c)
-       if err != nil {
-               // This is only reachable if the ciphertext or the decryption key are
-               // encoded incorrectly, so it leaks no information about the message.
-               return nil, err
-       }
+func kemDecaps(dk *DecapsulationKey, c *[CiphertextSize]byte) (K []byte) {
+       h := dk.dk[decryptionKeySize+encryptionKeySize : decryptionKeySize+encryptionKeySize+32]
+       z := dk.dk[decryptionKeySize+encryptionKeySize+32:]
+
+       m := pkeDecrypt(&dk.decryptionKey, c)
        g := sha3.New512()
-       g.Write(m)
+       g.Write(m[:])
        g.Write(h)
        G := g.Sum(nil)
        Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
        J := sha3.NewShake256()
        J.Write(z)
-       J.Write(c)
+       J.Write(c[:])
        Kout := make([]byte, SharedKeySize)
        J.Read(Kout)
-       c1, err := pkeEncrypt(ekPKE, m, r)
-       if err != nil {
-               // Likewise, this is only reachable if the encryption key is encoded
-               // incorrectly, so it leaks no secret information through timing.
-               return nil, err
-       }
+       var cc [CiphertextSize]byte
+       c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r)
 
-       subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c, c1), Kout, Kprime)
-       return Kout, nil
+       subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
+       return Kout
 }
 
-// pkeDecrypt decrypts a ciphertext. It expects dk (the decryption key) to
-// be 1152 bytes, and c (the ciphertext) to be 1088 bytes.
+// parseDK parses a decryption key from its encoded form.
 //
-// It implements K-PKE.Decrypt according to FIPS 203 (DRAFT), Algorithm 14.
-func pkeDecrypt(dk, c []byte) ([]byte, error) {
-       if len(dk) != decryptionKeySize {
-               return nil, errors.New("mlkem768: invalid decryption key length")
-       }
-       if len(c) != CiphertextSize {
-               return nil, errors.New("mlkem768: invalid ciphertext length")
+// It implements the computation of s from K-PKE.Decrypt according to FIPS 203
+// (DRAFT), Algorithm 14.
+func parseDK(dx *decryptionKey, dkPKE []byte) error {
+       if len(dkPKE) != decryptionKeySize {
+               return errors.New("mlkem768: invalid decryption key length")
        }
 
-       u := make([]ringElement, k)
-       for i := range u {
-               f, err := ringDecodeAndDecompress10(c[:encodingSize10])
+       for i := range dx.s {
+               f, err := polyByteDecode[nttElement](dkPKE[:encodingSize12])
                if err != nil {
-                       return nil, err
+                       return err
                }
-               u[i] = f
-               c = c[encodingSize10:]
+               dx.s[i] = f
+               dkPKE = dkPKE[encodingSize12:]
        }
 
-       v, err := ringDecodeAndDecompress4(c)
-       if err != nil {
-               return nil, err
-       }
+       return nil
+}
 
-       s := make([]nttElement, k)
-       for i := range s {
-               f, err := polyByteDecode[nttElement](dk[:encodingSize12])
-               if err != nil {
-                       return nil, err
-               }
-               s[i] = f
-               dk = dk[encodingSize12:]
+// pkeDecrypt decrypts a ciphertext.
+//
+// It implements K-PKE.Decrypt according to FIPS 203 (DRAFT), Algorithm 14,
+// although the computation of s is done in parseDK.
+func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize]byte) []byte {
+       u := make([]ringElement, k)
+       for i := range u {
+               b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
+               u[i] = ringDecodeAndDecompress10(b)
        }
 
+       b := (*[encodingSize4]byte)(c[encodingSize10*k:])
+       v := ringDecodeAndDecompress4(b)
+
        var mask nttElement // s⊺ ◦ NTT(u)
-       for i := range s {
-               mask = polyAdd(mask, nttMul(s[i], ntt(u[i])))
+       for i := range dx.s {
+               mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
        }
        w := polySub(v, inverseNTT(mask))
 
-       return ringCompressAndEncode1(nil, w), nil
+       return ringCompressAndEncode1(nil, w)
 }
 
 // fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced.
@@ -397,7 +474,7 @@ const (
        barrettShift      = 24   // log₂(2¹² * 2¹²)
 )
 
-// fieldReduce reduces a value a < q² using Barrett reduction, to avoid
+// fieldReduce reduces a value a < 2q² using Barrett reduction, to avoid
 // potentially variable-time division.
 func fieldReduce(a uint32) fieldElement {
        quotient := uint32((uint64(a) * barrettMultiplier) >> barrettShift)
@@ -409,6 +486,21 @@ func fieldMul(a, b fieldElement) fieldElement {
        return fieldReduce(x)
 }
 
+// fieldMulSub returns a * (b - c). This operation is fused to save a
+// fieldReduceOnce after the subtraction.
+func fieldMulSub(a, b, c fieldElement) fieldElement {
+       x := uint32(a) * uint32(b-c+q)
+       return fieldReduce(x)
+}
+
+// fieldAddMul returns a * b + c * d. This operation is fused to save a
+// fieldReduceOnce and a fieldReduce.
+func fieldAddMul(a, b, c, d fieldElement) fieldElement {
+       x := uint32(a) * uint32(b)
+       x += uint32(c) * uint32(d)
+       return fieldReduce(x)
+}
+
 // compress maps a field element uniformly to the range 0 to 2ᵈ-1, according to
 // FIPS 203 (DRAFT), Definition 4.5.
 func compress(x fieldElement, d uint8) uint16 {
@@ -558,17 +650,14 @@ func ringCompressAndEncode1(s []byte, f ringElement) []byte {
 //
 // It implements ByteDecode₁, according to FIPS 203 (DRAFT), Algorithm 5,
 // followed by Decompress₁, according to FIPS 203 (DRAFT), Definition 4.6.
-func ringDecodeAndDecompress1(b []byte) (ringElement, error) {
-       if len(b) != encodingSize1 {
-               return ringElement{}, errors.New("mlkem768: invalid message length")
-       }
+func ringDecodeAndDecompress1(b *[encodingSize1]byte) ringElement {
        var f ringElement
        for i := range f {
                b_i := b[i/8] >> (i % 8) & 1
                const halfQ = (q + 1) / 2        // ⌈q/2⌋, rounded up per FIPS 203 (DRAFT), Section 2.3
                f[i] = fieldElement(b_i) * halfQ // 0 decompresses to 0, and 1 to ⌈q/2⌋
        }
-       return f, nil
+       return f
 }
 
 // ringCompressAndEncode4 appends a 128-byte encoding of a ring element to s,
@@ -589,16 +678,13 @@ func ringCompressAndEncode4(s []byte, f ringElement) []byte {
 //
 // It implements ByteDecode₄, according to FIPS 203 (DRAFT), Algorithm 5,
 // followed by Decompress₄, according to FIPS 203 (DRAFT), Definition 4.6.
-func ringDecodeAndDecompress4(b []byte) (ringElement, error) {
-       if len(b) != encodingSize4 {
-               return ringElement{}, errors.New("mlkem768: invalid encoding length")
-       }
+func ringDecodeAndDecompress4(b *[encodingSize4]byte) ringElement {
        var f ringElement
        for i := 0; i < n; i += 2 {
                f[i] = fieldElement(decompress(uint16(b[i/2]&0b1111), 4))
                f[i+1] = fieldElement(decompress(uint16(b[i/2]>>4), 4))
        }
-       return f, nil
+       return f
 }
 
 // ringCompressAndEncode10 appends a 320-byte encoding of a ring element to s,
@@ -629,10 +715,8 @@ func ringCompressAndEncode10(s []byte, f ringElement) []byte {
 //
 // It implements ByteDecode₁₀, according to FIPS 203 (DRAFT), Algorithm 5,
 // followed by Decompress₁₀, according to FIPS 203 (DRAFT), Definition 4.6.
-func ringDecodeAndDecompress10(b []byte) (ringElement, error) {
-       if len(b) != encodingSize10 {
-               return ringElement{}, errors.New("mlkem768: invalid encoding length")
-       }
+func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement {
+       b := bb[:]
        var f ringElement
        for i := 0; i < n; i += 4 {
                x := uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32
@@ -642,7 +726,7 @@ func ringDecodeAndDecompress10(b []byte) (ringElement, error) {
                f[i+2] = fieldElement(decompress(uint16(x>>20&0b11_1111_1111), 10))
                f[i+3] = fieldElement(decompress(uint16(x>>30&0b11_1111_1111), 10))
        }
-       return f, nil
+       return f
 }
 
 // samplePolyCBD draws a ringElement from the special Dη distribution given a
@@ -681,11 +765,12 @@ var gammas = [128]fieldElement{17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637,
 // It implements MultiplyNTTs, according to FIPS 203 (DRAFT), Algorithm 10.
 func nttMul(f, g nttElement) nttElement {
        var h nttElement
-       for i := 0; i < 128; i++ {
-               a0, a1 := f[2*i], f[2*i+1]
-               b0, b1 := g[2*i], g[2*i+1]
-               h[2*i] = fieldAdd(fieldMul(a0, b0), fieldMul(fieldMul(a1, b1), gammas[i]))
-               h[2*i+1] = fieldAdd(fieldMul(a0, b1), fieldMul(a1, b0))
+       // We use i += 2 for bounds check elimination. See https://go.dev/issue/66826.
+       for i := 0; i < 256; i += 2 {
+               a0, a1 := f[i], f[i+1]
+               b0, b1 := g[i], g[i+1]
+               h[i] = fieldAddMul(a0, b0, fieldMul(a1, b1), gammas[i/2])
+               h[i+1] = fieldAddMul(a0, b1, a1, b0)
        }
        return h
 }
@@ -702,18 +787,12 @@ func ntt(f ringElement) nttElement {
                for start := 0; start < 256; start += 2 * len {
                        zeta := zetas[k]
                        k++
-                       for j := start; j < start+len; j += 2 {
-                               // Loop 2x unrolled for performance.
-                               {
-                                       t := fieldMul(zeta, f[j+len])
-                                       f[j+len] = fieldSub(f[j], t)
-                                       f[j] = fieldAdd(f[j], t)
-                               }
-                               {
-                                       t := fieldMul(zeta, f[j+1+len])
-                                       f[j+1+len] = fieldSub(f[j+1], t)
-                                       f[j+1] = fieldAdd(f[j+1], t)
-                               }
+                       // Bounds check elimination hint.
+                       f, flen := f[start:start+len], f[start+len:start+len+len]
+                       for j := 0; j < len; j++ {
+                               t := fieldMul(zeta, flen[j])
+                               flen[j] = fieldSub(f[j], t)
+                               f[j] = fieldAdd(f[j], t)
                        }
                }
        }
@@ -729,18 +808,12 @@ func inverseNTT(f nttElement) ringElement {
                for start := 0; start < 256; start += 2 * len {
                        zeta := zetas[k]
                        k--
-                       for j := start; j < start+len; j += 2 {
-                               // Loop 2x unrolled for performance.
-                               {
-                                       t := f[j]
-                                       f[j] = fieldAdd(t, f[j+len])
-                                       f[j+len] = fieldMul(zeta, fieldSub(f[j+len], t))
-                               }
-                               {
-                                       t := f[j+1]
-                                       f[j+1] = fieldAdd(t, f[j+1+len])
-                                       f[j+1+len] = fieldMul(zeta, fieldSub(f[j+1+len], t))
-                               }
+                       // Bounds check elimination hint.
+                       f, flen := f[start:start+len], f[start+len:start+len+len]
+                       for j := 0; j < len; j++ {
+                               t := f[j]
+                               f[j] = fieldAdd(t, flen[j])
+                               flen[j] = fieldMulSub(zeta, flen[j], t)
                        }
                }
        }
index 6e2ac769efcb71bd0fa52ca8cf8c99cea073deea..b91b42a424f8dae15916b0c1dd2667cd6df58dde 100644 (file)
@@ -9,6 +9,7 @@ import (
        "crypto/rand"
        _ "embed"
        "encoding/hex"
+       "errors"
        "flag"
        "math/big"
        "strconv"
@@ -17,6 +18,16 @@ import (
        "golang.org/x/crypto/sha3"
 )
 
+func TestFieldReduce(t *testing.T) {
+       for a := uint32(0); a < 2*q*q; a++ {
+               got := fieldReduce(a)
+               exp := fieldElement(a % q)
+               if got != exp {
+                       t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp)
+               }
+       }
+}
+
 func TestFieldAdd(t *testing.T) {
        for a := fieldElement(0); a < q; a++ {
                for b := fieldElement(0); b < q; b++ {
@@ -188,11 +199,11 @@ func TestGammas(t *testing.T) {
 }
 
 func TestRoundTrip(t *testing.T) {
-       ek, dk, err := GenerateKey()
+       dk, err := GenerateKey()
        if err != nil {
                t.Fatal(err)
        }
-       c, Ke, err := Encapsulate(ek)
+       c, Ke, err := Encapsulate(dk.EncapsulationKey())
        if err != nil {
                t.Fatal(err)
        }
@@ -204,21 +215,21 @@ func TestRoundTrip(t *testing.T) {
                t.Fail()
        }
 
-       ek1, dk1, err := GenerateKey()
+       dk1, err := GenerateKey()
        if err != nil {
                t.Fatal(err)
        }
-       if bytes.Equal(ek, ek1) {
+       if bytes.Equal(dk.EncapsulationKey(), dk1.EncapsulationKey()) {
                t.Fail()
        }
-       if bytes.Equal(dk, dk1) {
+       if bytes.Equal(dk.Bytes(), dk1.Bytes()) {
                t.Fail()
        }
-       if bytes.Equal(dk[len(dk)-32:], dk1[len(dk)-32:]) {
+       if bytes.Equal(dk.Bytes()[EncapsulationKeySize-32:], dk1.Bytes()[EncapsulationKeySize-32:]) {
                t.Fail()
        }
 
-       c1, Ke1, err := Encapsulate(ek)
+       c1, Ke1, err := Encapsulate(dk.EncapsulationKey())
        if err != nil {
                t.Fatal(err)
        }
@@ -231,10 +242,11 @@ func TestRoundTrip(t *testing.T) {
 }
 
 func TestBadLengths(t *testing.T) {
-       ek, dk, err := GenerateKey()
+       dk, err := GenerateKey()
        if err != nil {
                t.Fatal(err)
        }
+       ek := dk.EncapsulationKey()
 
        for i := 0; i < len(ek)-1; i++ {
                if _, _, err := Encapsulate(ek[:i]); err == nil {
@@ -254,15 +266,15 @@ func TestBadLengths(t *testing.T) {
                t.Fatal(err)
        }
 
-       for i := 0; i < len(dk)-1; i++ {
-               if _, err := Decapsulate(dk[:i], c); err == nil {
+       for i := 0; i < len(dk.Bytes())-1; i++ {
+               if _, err := NewKeyFromExtendedEncoding(dk.Bytes()[:i]); err == nil {
                        t.Errorf("expected error for dk length %d", i)
                }
        }
-       dkLong := dk
+       dkLong := dk.Bytes()
        for i := 0; i < 100; i++ {
                dkLong = append(dkLong, 0)
-               if _, err := Decapsulate(dkLong, c); err == nil {
+               if _, err := NewKeyFromExtendedEncoding(dkLong); err == nil {
                        t.Errorf("expected error for dk length %d", len(dkLong))
                }
        }
@@ -281,6 +293,29 @@ func TestBadLengths(t *testing.T) {
        }
 }
 
+func EncapsulateDerand(ek, m []byte) (c, K []byte, err error) {
+       if len(m) != messageSize {
+               return nil, nil, errors.New("bad message length")
+       }
+       return kemEncaps(nil, ek, (*[messageSize]byte)(m))
+}
+
+func DecapsulateFromBytes(dkBytes []byte, c []byte) ([]byte, error) {
+       dk, err := NewKeyFromExtendedEncoding(dkBytes)
+       if err != nil {
+               return nil, err
+       }
+       return Decapsulate(dk, c)
+}
+
+func GenerateKeyDerand(t testing.TB, d, z []byte) ([]byte, *DecapsulationKey) {
+       if len(d) != 32 || len(z) != 32 {
+               t.Fatal("bad length")
+       }
+       dk := kemKeyGen(nil, (*[32]byte)(d), (*[32]byte)(z))
+       return dk.EncapsulationKey(), dk
+}
+
 var millionFlag = flag.Bool("million", false, "run the million vector test")
 
 // TestPQCrystalsAccumulated accumulates the 10k vectors generated by the
@@ -308,19 +343,19 @@ func TestPQCrystalsAccumulated(t *testing.T) {
        for i := 0; i < n; i++ {
                s.Read(d)
                s.Read(z)
-               ek, dk := kemKeyGen(d, z)
+               ek, dk := GenerateKeyDerand(t, d, z)
                o.Write(ek)
-               o.Write(dk)
+               o.Write(dk.Bytes())
 
                s.Read(msg)
-               ct, k, err := kemEncaps(ek, msg)
+               ct, k, err := EncapsulateDerand(ek, msg)
                if err != nil {
                        t.Fatal(err)
                }
                o.Write(ct)
                o.Write(k)
 
-               kk, err := kemDecaps(dk, ct)
+               kk, err := Decapsulate(dk, ct)
                if err != nil {
                        t.Fatal(err)
                }
@@ -329,7 +364,7 @@ func TestPQCrystalsAccumulated(t *testing.T) {
                }
 
                s.Read(ct1)
-               k1, err := kemDecaps(dk, ct1)
+               k1, err := Decapsulate(dk, ct1)
                if err != nil {
                        t.Fatal(err)
                }
@@ -342,25 +377,17 @@ func TestPQCrystalsAccumulated(t *testing.T) {
        }
 }
 
-var sinkElement fieldElement
-
-func BenchmarkSampleNTT(b *testing.B) {
-       for i := 0; i < b.N; i++ {
-               sinkElement ^= sampleNTT(bytes.Repeat([]byte("A"), 32), '4', '2')[0]
-       }
-}
-
 var sink byte
 
 func BenchmarkKeyGen(b *testing.B) {
-       d := make([]byte, 32)
-       rand.Read(d)
-       z := make([]byte, 32)
-       rand.Read(z)
+       var dk DecapsulationKey
+       var d, z [32]byte
+       rand.Read(d[:])
+       rand.Read(z[:])
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
-               ek, dk := kemKeyGen(d, z)
-               sink ^= ek[0] ^ dk[0]
+               dk := kemKeyGen(&dk, &d, &z)
+               sink ^= dk.EncapsulationKey()[0]
        }
 }
 
@@ -369,12 +396,13 @@ func BenchmarkEncaps(b *testing.B) {
        rand.Read(d)
        z := make([]byte, 32)
        rand.Read(z)
-       m := make([]byte, 32)
-       rand.Read(m)
-       ek, _ := kemKeyGen(d, z)
+       var m [messageSize]byte
+       rand.Read(m[:])
+       ek, _ := GenerateKeyDerand(b, d, z)
+       var c [CiphertextSize]byte
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
-               c, K, err := kemEncaps(ek, m)
+               c, K, err := kemEncaps(&c, ek, &m)
                if err != nil {
                        b.Fatal(err)
                }
@@ -389,41 +417,42 @@ func BenchmarkDecaps(b *testing.B) {
        rand.Read(z)
        m := make([]byte, 32)
        rand.Read(m)
-       ek, dk := kemKeyGen(d, z)
-       c, _, err := kemEncaps(ek, m)
+       ek, dk := GenerateKeyDerand(b, d, z)
+       c, _, err := EncapsulateDerand(ek, m)
        if err != nil {
                b.Fatal(err)
        }
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
-               K, err := kemDecaps(dk, c)
-               if err != nil {
-                       b.Fatal(err)
-               }
+               K := kemDecaps(dk, (*[CiphertextSize]byte)(c))
                sink ^= K[0]
        }
 }
 
 func BenchmarkRoundTrip(b *testing.B) {
-       ek, dk, err := GenerateKey()
+       dk, err := GenerateKey()
        if err != nil {
                b.Fatal(err)
        }
+       ek := dk.EncapsulationKey()
        c, _, err := Encapsulate(ek)
        if err != nil {
                b.Fatal(err)
        }
        b.Run("Alice", func(b *testing.B) {
                for i := 0; i < b.N; i++ {
-                       ekS, dkS, err := GenerateKey()
+                       dkS, err := GenerateKey()
                        if err != nil {
                                b.Fatal(err)
                        }
+                       ekS := dkS.EncapsulationKey()
+                       sink ^= ekS[0]
+
                        Ks, err := Decapsulate(dk, c)
                        if err != nil {
                                b.Fatal(err)
                        }
-                       sink ^= ekS[0] ^ dkS[0] ^ Ks[0]
+                       sink ^= Ks[0]
                }
        })
        b.Run("Bob", func(b *testing.B) {