]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/fips/mlkem: implement ML-KEM-1024
authorFilippo Valsorda <filippo@golang.org>
Wed, 23 Oct 2024 09:36:56 +0000 (11:36 +0200)
committerGopher Robot <gobot@golang.org>
Tue, 19 Nov 2024 20:43:01 +0000 (20:43 +0000)
Decided to automatically duplicate the high-level code to avoid growing
the ML-KEM-768 data structures.

For #70122

Change-Id: I5c705b71ee1e23adba9113d5cf6b6e505c028967
Reviewed-on: https://go-review.googlesource.com/c/go/+/621983
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
src/crypto/internal/fips/mlkem/field.go
src/crypto/internal/fips/mlkem/field_test.go
src/crypto/internal/fips/mlkem/generate1024.go [new file with mode: 0644]
src/crypto/internal/fips/mlkem/mlkem1024.go [new file with mode: 0644]
src/crypto/internal/fips/mlkem/mlkem768.go
src/crypto/internal/fips/mlkem/mlkem_test.go [moved from src/crypto/internal/fips/mlkem/mlkem768_test.go with 64% similarity]

index 9c22ec86f997a1105a65ce1767269841b2575b37..1532f031f2b051f59135d2044314d24854b2171b 100644 (file)
@@ -263,7 +263,7 @@ func ringCompressAndEncode10(s []byte, f ringElement) []byte {
        s, b := sliceForAppend(s, encodingSize10)
        for i := 0; i < n; i += 4 {
                var x uint64
-               x |= uint64(compress(f[i+0], 10))
+               x |= uint64(compress(f[i], 10))
                x |= uint64(compress(f[i+1], 10)) << 10
                x |= uint64(compress(f[i+2], 10)) << 20
                x |= uint64(compress(f[i+3], 10)) << 30
@@ -296,6 +296,101 @@ func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement {
        return f
 }
 
+// ringCompressAndEncode appends an encoding of a ring element to s,
+// compressing each coefficient to d bits.
+//
+// It implements Compress, according to FIPS 203, Definition 4.7,
+// followed by ByteEncode, according to FIPS 203, Algorithm 5.
+func ringCompressAndEncode(s []byte, f ringElement, d uint8) []byte {
+       var b byte
+       var bIdx uint8
+       for i := 0; i < n; i++ {
+               c := compress(f[i], d)
+               var cIdx uint8
+               for cIdx < d {
+                       b |= byte(c>>cIdx) << bIdx
+                       bits := min(8-bIdx, d-cIdx)
+                       bIdx += bits
+                       cIdx += bits
+                       if bIdx == 8 {
+                               s = append(s, b)
+                               b = 0
+                               bIdx = 0
+                       }
+               }
+       }
+       if bIdx != 0 {
+               panic("mlkem: internal error: bitsFilled != 0")
+       }
+       return s
+}
+
+// ringDecodeAndDecompress decodes an encoding of a ring element where
+// each d bits are mapped to an equidistant distribution.
+//
+// It implements ByteDecode, according to FIPS 203, Algorithm 6,
+// followed by Decompress, according to FIPS 203, Definition 4.8.
+func ringDecodeAndDecompress(b []byte, d uint8) ringElement {
+       var f ringElement
+       var bIdx uint8
+       for i := 0; i < n; i++ {
+               var c uint16
+               var cIdx uint8
+               for cIdx < d {
+                       c |= uint16(b[0]>>bIdx) << cIdx
+                       c &= (1 << d) - 1
+                       bits := min(8-bIdx, d-cIdx)
+                       bIdx += bits
+                       cIdx += bits
+                       if bIdx == 8 {
+                               b = b[1:]
+                               bIdx = 0
+                       }
+               }
+               f[i] = fieldElement(decompress(c, d))
+       }
+       if len(b) != 0 {
+               panic("mlkem: internal error: leftover bytes")
+       }
+       return f
+}
+
+// ringCompressAndEncode5 appends a 160-byte encoding of a ring element to s,
+// compressing eight coefficients per five bytes.
+//
+// It implements Compress₅, according to FIPS 203, Definition 4.7,
+// followed by ByteEncode₅, according to FIPS 203, Algorithm 5.
+func ringCompressAndEncode5(s []byte, f ringElement) []byte {
+       return ringCompressAndEncode(s, f, 5)
+}
+
+// ringDecodeAndDecompress5 decodes a 160-byte encoding of a ring element where
+// each five bits are mapped to an equidistant distribution.
+//
+// It implements ByteDecode₅, according to FIPS 203, Algorithm 6,
+// followed by Decompress₅, according to FIPS 203, Definition 4.8.
+func ringDecodeAndDecompress5(bb *[encodingSize5]byte) ringElement {
+       return ringDecodeAndDecompress(bb[:], 5)
+}
+
+// ringCompressAndEncode11 appends a 352-byte encoding of a ring element to s,
+// compressing eight coefficients per eleven bytes.
+//
+// It implements Compress₁₁, according to FIPS 203, Definition 4.7,
+// followed by ByteEncode₁₁, according to FIPS 203, Algorithm 5.
+func ringCompressAndEncode11(s []byte, f ringElement) []byte {
+       return ringCompressAndEncode(s, f, 11)
+}
+
+// ringDecodeAndDecompress11 decodes a 352-byte encoding of a ring element where
+// each eleven bits are mapped to an equidistant distribution.
+//
+// It implements ByteDecode₁₁, according to FIPS 203, Algorithm 6,
+// followed by Decompress₁₁, according to FIPS 203, Definition 4.8.
+func ringDecodeAndDecompress11(bb *[encodingSize11]byte) ringElement {
+       return ringDecodeAndDecompress(bb[:], 11)
+}
+
 // samplePolyCBD draws a ringElement from the special Dη distribution given a
 // stream of random bytes generated by the PRF function, according to FIPS 203,
 // Algorithm 8 and Definition 4.3.
index a842913627b3eec9ba53c53715ea4f714a2de174..3a6d983803ce5dda3d68171c65c36a0c8d09fc45 100644 (file)
@@ -5,7 +5,10 @@
 package mlkem
 
 import (
+       "bytes"
+       "crypto/rand"
        "math/big"
+       mathrand "math/rand/v2"
        "strconv"
        "testing"
 )
@@ -151,6 +154,81 @@ func TestDecompress(t *testing.T) {
        }
 }
 
+func randomRingElement() ringElement {
+       var r ringElement
+       for i := range r {
+               r[i] = fieldElement(mathrand.IntN(q))
+       }
+       return r
+}
+
+func TestEncodeDecode(t *testing.T) {
+       f := randomRingElement()
+       b := make([]byte, 12*n/8)
+       rand.Read(b)
+
+       // Compare ringCompressAndEncode to ringCompressAndEncodeN.
+       e1 := ringCompressAndEncode(nil, f, 10)
+       e2 := ringCompressAndEncode10(nil, f)
+       if !bytes.Equal(e1, e2) {
+               t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode10 = %x", e1, e2)
+       }
+       e1 = ringCompressAndEncode(nil, f, 4)
+       e2 = ringCompressAndEncode4(nil, f)
+       if !bytes.Equal(e1, e2) {
+               t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode4 = %x", e1, e2)
+       }
+       e1 = ringCompressAndEncode(nil, f, 1)
+       e2 = ringCompressAndEncode1(nil, f)
+       if !bytes.Equal(e1, e2) {
+               t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode1 = %x", e1, e2)
+       }
+
+       // Compare ringDecodeAndDecompress to ringDecodeAndDecompressN.
+       g1 := ringDecodeAndDecompress(b[:encodingSize10], 10)
+       g2 := ringDecodeAndDecompress10((*[encodingSize10]byte)(b))
+       if g1 != g2 {
+               t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress10 = %v", g1, g2)
+       }
+       g1 = ringDecodeAndDecompress(b[:encodingSize4], 4)
+       g2 = ringDecodeAndDecompress4((*[encodingSize4]byte)(b))
+       if g1 != g2 {
+               t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress4 = %v", g1, g2)
+       }
+       g1 = ringDecodeAndDecompress(b[:encodingSize1], 1)
+       g2 = ringDecodeAndDecompress1((*[encodingSize1]byte)(b))
+       if g1 != g2 {
+               t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress1 = %v", g1, g2)
+       }
+
+       // Round-trip ringCompressAndEncode and ringDecodeAndDecompress.
+       for d := 1; d < 12; d++ {
+               encodingSize := d * n / 8
+               g := ringDecodeAndDecompress(b[:encodingSize], uint8(d))
+               out := ringCompressAndEncode(nil, g, uint8(d))
+               if !bytes.Equal(out, b[:encodingSize]) {
+                       t.Errorf("roundtrip failed for d = %d", d)
+               }
+       }
+
+       // Round-trip ringCompressAndEncodeN and ringDecodeAndDecompressN.
+       g := ringDecodeAndDecompress10((*[encodingSize10]byte)(b))
+       out := ringCompressAndEncode10(nil, g)
+       if !bytes.Equal(out, b[:encodingSize10]) {
+               t.Errorf("roundtrip failed for specialized 10")
+       }
+       g = ringDecodeAndDecompress4((*[encodingSize4]byte)(b))
+       out = ringCompressAndEncode4(nil, g)
+       if !bytes.Equal(out, b[:encodingSize4]) {
+               t.Errorf("roundtrip failed for specialized 4")
+       }
+       g = ringDecodeAndDecompress1((*[encodingSize1]byte)(b))
+       out = ringCompressAndEncode1(nil, g)
+       if !bytes.Equal(out, b[:encodingSize1]) {
+               t.Errorf("roundtrip failed for specialized 1")
+       }
+}
+
 func BitRev7(n uint8) uint8 {
        if n>>7 != 0 {
                panic("not 7 bits")
diff --git a/src/crypto/internal/fips/mlkem/generate1024.go b/src/crypto/internal/fips/mlkem/generate1024.go
new file mode 100644 (file)
index 0000000..7ed68de
--- /dev/null
@@ -0,0 +1,123 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build ignore
+
+package main
+
+import (
+       "flag"
+       "go/ast"
+       "go/format"
+       "go/parser"
+       "go/token"
+       "log"
+       "os"
+       "strings"
+)
+
+var replacements = map[string]string{
+       "k": "k1024",
+
+       "CiphertextSize768":       "CiphertextSize1024",
+       "EncapsulationKeySize768": "EncapsulationKeySize1024",
+
+       "encryptionKey": "encryptionKey1024",
+       "decryptionKey": "decryptionKey1024",
+
+       "EncapsulationKey768":    "EncapsulationKey1024",
+       "NewEncapsulationKey768": "NewEncapsulationKey1024",
+       "parseEK":                "parseEK1024",
+
+       "kemEncaps":  "kemEncaps1024",
+       "pkeEncrypt": "pkeEncrypt1024",
+
+       "DecapsulationKey768":    "DecapsulationKey1024",
+       "NewDecapsulationKey768": "NewDecapsulationKey1024",
+       "newKeyFromSeed":         "newKeyFromSeed1024",
+
+       "kemDecaps":  "kemDecaps1024",
+       "pkeDecrypt": "pkeDecrypt1024",
+
+       "GenerateKey768": "GenerateKey1024",
+       "generateKey":    "generateKey1024",
+
+       "kemKeyGen": "kemKeyGen1024",
+
+       "encodingSize4":             "encodingSize5",
+       "encodingSize10":            "encodingSize11",
+       "ringCompressAndEncode4":    "ringCompressAndEncode5",
+       "ringCompressAndEncode10":   "ringCompressAndEncode11",
+       "ringDecodeAndDecompress4":  "ringDecodeAndDecompress5",
+       "ringDecodeAndDecompress10": "ringDecodeAndDecompress11",
+}
+
+func main() {
+       inputFile := flag.String("input", "", "")
+       outputFile := flag.String("output", "", "")
+       flag.Parse()
+
+       fset := token.NewFileSet()
+       f, err := parser.ParseFile(fset, *inputFile, nil, parser.SkipObjectResolution|parser.ParseComments)
+       if err != nil {
+               log.Fatal(err)
+       }
+       cmap := ast.NewCommentMap(fset, f, f.Comments)
+
+       // Drop header comments.
+       cmap[ast.Node(f)] = nil
+
+       // Remove top-level consts used across the main and generated files.
+       var newDecls []ast.Decl
+       for _, decl := range f.Decls {
+               switch d := decl.(type) {
+               case *ast.GenDecl:
+                       if d.Tok == token.CONST {
+                               continue // Skip const declarations
+                       }
+                       if d.Tok == token.IMPORT {
+                               cmap[decl] = nil // Drop pre-import comments.
+                       }
+               }
+               newDecls = append(newDecls, decl)
+       }
+       f.Decls = newDecls
+
+       // Replace identifiers.
+       ast.Inspect(f, func(n ast.Node) bool {
+               switch x := n.(type) {
+               case *ast.Ident:
+                       if replacement, ok := replacements[x.Name]; ok {
+                               x.Name = replacement
+                       }
+               }
+               return true
+       })
+
+       // Replace identifiers in comments.
+       for _, c := range f.Comments {
+               for _, l := range c.List {
+                       for k, v := range replacements {
+                               if k == "k" {
+                                       continue
+                               }
+                               l.Text = strings.ReplaceAll(l.Text, k, v)
+                       }
+               }
+       }
+
+       out, err := os.Create(*outputFile)
+       if err != nil {
+               log.Fatal(err)
+       }
+       defer out.Close()
+
+       out.WriteString("// Code generated by generate1024.go. DO NOT EDIT.\n\n")
+
+       f.Comments = cmap.Filter(f).Comments()
+       err = format.Node(out, fset, f)
+       if err != nil {
+               log.Fatal(err)
+       }
+}
diff --git a/src/crypto/internal/fips/mlkem/mlkem1024.go b/src/crypto/internal/fips/mlkem/mlkem1024.go
new file mode 100644 (file)
index 0000000..c77dae8
--- /dev/null
@@ -0,0 +1,342 @@
+// Code generated by generate1024.go. DO NOT EDIT.
+
+package mlkem
+
+import (
+       "crypto/internal/fips/drbg"
+       "crypto/internal/fips/sha3"
+       "crypto/internal/fips/subtle"
+       "errors"
+)
+
+// A DecapsulationKey1024 is the secret key used to decapsulate a shared key from a
+// ciphertext. It includes various precomputed values.
+type DecapsulationKey1024 struct {
+       d [32]byte // decapsulation key seed
+       z [32]byte // implicit rejection sampling seed
+
+       ρ [32]byte // sampleNTT seed for A, stored for the encapsulation key
+       h [32]byte // H(ek), stored for ML-KEM.Decaps_internal
+
+       encryptionKey1024
+       decryptionKey1024
+}
+
+// Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
+//
+// The decapsulation key must be kept secret.
+func (dk *DecapsulationKey1024) Bytes() []byte {
+       var b [SeedSize]byte
+       copy(b[:], dk.d[:])
+       copy(b[32:], dk.z[:])
+       return b[:]
+}
+
+// EncapsulationKey returns the public encapsulation key necessary to produce
+// ciphertexts.
+func (dk *DecapsulationKey1024) EncapsulationKey() *EncapsulationKey1024 {
+       return &EncapsulationKey1024{
+               ρ:                 dk.ρ,
+               h:                 dk.h,
+               encryptionKey1024: dk.encryptionKey1024,
+       }
+}
+
+// An EncapsulationKey1024 is the public key used to produce ciphertexts to be
+// decapsulated by the corresponding [DecapsulationKey1024].
+type EncapsulationKey1024 struct {
+       ρ [32]byte // sampleNTT seed for A
+       h [32]byte // H(ek)
+       encryptionKey1024
+}
+
+// Bytes returns the encapsulation key as a byte slice.
+func (ek *EncapsulationKey1024) Bytes() []byte {
+       // The actual logic is in a separate function to outline this allocation.
+       b := make([]byte, 0, EncapsulationKeySize1024)
+       return ek.bytes(b)
+}
+
+func (ek *EncapsulationKey1024) bytes(b []byte) []byte {
+       for i := range ek.t {
+               b = polyByteEncode(b, ek.t[i])
+       }
+       b = append(b, ek.ρ[:]...)
+       return b
+}
+
+// encryptionKey1024 is the parsed and expanded form of a PKE encryption key.
+type encryptionKey1024 struct {
+       t [k1024]nttElement         // ByteDecode₁₂(ek[:384k])
+       a [k1024 * k1024]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
+}
+
+// decryptionKey1024 is the parsed and expanded form of a PKE decryption key.
+type decryptionKey1024 struct {
+       s [k1024]nttElement // ByteDecode₁₂(dk[:decryptionKey1024Size])
+}
+
+// GenerateKey1024 generates a new decapsulation key, drawing random bytes from
+// a DRBG. The decapsulation key must be kept secret.
+func GenerateKey1024() (*DecapsulationKey1024, error) {
+       // The actual logic is in a separate function to outline this allocation.
+       dk := &DecapsulationKey1024{}
+       return generateKey1024(dk), nil
+}
+
+func generateKey1024(dk *DecapsulationKey1024) *DecapsulationKey1024 {
+       var d [32]byte
+       drbg.Read(d[:])
+       var z [32]byte
+       drbg.Read(z[:])
+       return kemKeyGen1024(dk, &d, &z)
+}
+
+// NewDecapsulationKey1024 parses a decapsulation key from a 64-byte
+// seed in the "d || z" form. The seed must be uniformly random.
+func NewDecapsulationKey1024(seed []byte) (*DecapsulationKey1024, error) {
+       // The actual logic is in a separate function to outline this allocation.
+       dk := &DecapsulationKey1024{}
+       return newKeyFromSeed1024(dk, seed)
+}
+
+func newKeyFromSeed1024(dk *DecapsulationKey1024, seed []byte) (*DecapsulationKey1024, error) {
+       if len(seed) != SeedSize {
+               return nil, errors.New("mlkem: invalid seed length")
+       }
+       d := (*[32]byte)(seed[:32])
+       z := (*[32]byte)(seed[32:])
+       return kemKeyGen1024(dk, d, z), nil
+}
+
+// kemKeyGen1024 generates a decapsulation key.
+//
+// It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and
+// K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save
+// copies and allocations.
+func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) *DecapsulationKey1024 {
+       if dk == nil {
+               dk = &DecapsulationKey1024{}
+       }
+       dk.d = *d
+       dk.z = *z
+
+       g := sha3.New512()
+       g.Write(d[:])
+       g.Write([]byte{k1024}) // Module dimension as a domain separator.
+       G := g.Sum(make([]byte, 0, 64))
+       ρ, σ := G[:32], G[32:]
+       dk.ρ = [32]byte(ρ)
+
+       A := &dk.a
+       for i := byte(0); i < k1024; i++ {
+               for j := byte(0); j < k1024; j++ {
+                       A[i*k1024+j] = sampleNTT(ρ, j, i)
+               }
+       }
+
+       var N byte
+       s := &dk.s
+       for i := range s {
+               s[i] = ntt(samplePolyCBD(σ, N))
+               N++
+       }
+       e := make([]nttElement, k1024)
+       for i := range e {
+               e[i] = ntt(samplePolyCBD(σ, N))
+               N++
+       }
+
+       t := &dk.t
+       for i := range t { // t = A ◦ s + e
+               t[i] = e[i]
+               for j := range s {
+                       t[i] = polyAdd(t[i], nttMul(A[i*k1024+j], s[j]))
+               }
+       }
+
+       H := sha3.New256()
+       ek := dk.EncapsulationKey().Bytes()
+       H.Write(ek)
+       H.Sum(dk.h[:0])
+
+       return dk
+}
+
+// Encapsulate generates a shared key and an associated ciphertext from an
+// encapsulation key, drawing random bytes from a DRBG.
+//
+// The shared key must be kept secret.
+func (ek *EncapsulationKey1024) Encapsulate() (ciphertext, sharedKey []byte) {
+       // The actual logic is in a separate function to outline this allocation.
+       var cc [CiphertextSize1024]byte
+       return ek.encapsulate(&cc)
+}
+
+func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (ciphertext, sharedKey []byte) {
+       var m [messageSize]byte
+       drbg.Read(m[:])
+       // Note that the modulus check (step 2 of the encapsulation key check from
+       // FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK1024.
+       return kemEncaps1024(cc, ek, &m)
+}
+
+// kemEncaps1024 generates a shared key and an associated ciphertext.
+//
+// It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
+func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (c, K []byte) {
+       if cc == nil {
+               cc = &[CiphertextSize1024]byte{}
+       }
+
+       g := sha3.New512()
+       g.Write(m[:])
+       g.Write(ek.h[:])
+       G := g.Sum(nil)
+       K, r := G[:SharedKeySize], G[SharedKeySize:]
+       c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r)
+       return c, K
+}
+
+// NewEncapsulationKey1024 parses an encapsulation key from its encoded form.
+// If the encapsulation key is not valid, NewEncapsulationKey1024 returns an error.
+func NewEncapsulationKey1024(encapsulationKey []byte) (*EncapsulationKey1024, error) {
+       // The actual logic is in a separate function to outline this allocation.
+       ek := &EncapsulationKey1024{}
+       return parseEK1024(ek, encapsulationKey)
+}
+
+// parseEK1024 parses an encryption key from its encoded form.
+//
+// It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
+// Algorithm 14.
+func parseEK1024(ek *EncapsulationKey1024, ekPKE []byte) (*EncapsulationKey1024, error) {
+       if len(ekPKE) != EncapsulationKeySize1024 {
+               return nil, errors.New("mlkem: invalid encapsulation key length")
+       }
+
+       h := sha3.New256()
+       h.Write(ekPKE)
+       h.Sum(ek.h[:0])
+
+       for i := range ek.t {
+               var err error
+               ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
+               if err != nil {
+                       return nil, err
+               }
+               ekPKE = ekPKE[encodingSize12:]
+       }
+       copy(ek.ρ[:], ekPKE)
+
+       for i := byte(0); i < k1024; i++ {
+               for j := byte(0); j < k1024; j++ {
+                       ek.a[i*k1024+j] = sampleNTT(ek.ρ[:], j, i)
+               }
+       }
+
+       return ek, nil
+}
+
+// pkeEncrypt1024 encrypt a plaintext message.
+//
+// It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the
+// computation of t and AT is done in parseEK1024.
+func pkeEncrypt1024(cc *[CiphertextSize1024]byte, ex *encryptionKey1024, m *[messageSize]byte, rnd []byte) []byte {
+       var N byte
+       r, e1 := make([]nttElement, k1024), make([]ringElement, k1024)
+       for i := range r {
+               r[i] = ntt(samplePolyCBD(rnd, N))
+               N++
+       }
+       for i := range e1 {
+               e1[i] = samplePolyCBD(rnd, N)
+               N++
+       }
+       e2 := samplePolyCBD(rnd, N)
+
+       u := make([]ringElement, k1024) // NTT⁻¹(AT ◦ r) + e1
+       for i := range u {
+               u[i] = e1[i]
+               for j := range r {
+                       // Note that i and j are inverted, as we need the transposed of A.
+                       u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k1024+i], r[j])))
+               }
+       }
+
+       μ := ringDecodeAndDecompress1(m)
+
+       var vNTT nttElement // t⊺ ◦ r
+       for i := range ex.t {
+               vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
+       }
+       v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)
+
+       c := cc[:0]
+       for _, f := range u {
+               c = ringCompressAndEncode11(c, f)
+       }
+       c = ringCompressAndEncode5(c, v)
+
+       return c
+}
+
+// Decapsulate generates a shared key from a ciphertext and a decapsulation key.
+// If the ciphertext is not valid, Decapsulate returns an error.
+//
+// The shared key must be kept secret.
+func (dk *DecapsulationKey1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
+       if len(ciphertext) != CiphertextSize1024 {
+               return nil, errors.New("mlkem: invalid ciphertext length")
+       }
+       c := (*[CiphertextSize1024]byte)(ciphertext)
+       // Note that the hash check (step 3 of the decapsulation input check from
+       // FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always
+       // validly generated by ML-KEM.KeyGen_internal.
+       return kemDecaps1024(dk, c), nil
+}
+
+// kemDecaps1024 produces a shared key from a ciphertext.
+//
+// It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18.
+func kemDecaps1024(dk *DecapsulationKey1024, c *[CiphertextSize1024]byte) (K []byte) {
+       m := pkeDecrypt1024(&dk.decryptionKey1024, c)
+       g := sha3.New512()
+       g.Write(m[:])
+       g.Write(dk.h[:])
+       G := g.Sum(make([]byte, 0, 64))
+       Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
+       J := sha3.NewShake256()
+       J.Write(dk.z[:])
+       J.Write(c[:])
+       Kout := make([]byte, SharedKeySize)
+       J.Read(Kout)
+       var cc [CiphertextSize1024]byte
+       c1 := pkeEncrypt1024(&cc, &dk.encryptionKey1024, (*[32]byte)(m), r)
+
+       subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
+       return Kout
+}
+
+// pkeDecrypt1024 decrypts a ciphertext.
+//
+// It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15,
+// although s is retained from kemKeyGen1024.
+func pkeDecrypt1024(dx *decryptionKey1024, c *[CiphertextSize1024]byte) []byte {
+       u := make([]ringElement, k1024)
+       for i := range u {
+               b := (*[encodingSize11]byte)(c[encodingSize11*i : encodingSize11*(i+1)])
+               u[i] = ringDecodeAndDecompress11(b)
+       }
+
+       b := (*[encodingSize5]byte)(c[encodingSize11*k1024:])
+       v := ringDecodeAndDecompress5(b)
+
+       var mask nttElement // s⊺ ◦ NTT(u)
+       for i := range dx.s {
+               mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
+       }
+       w := polySub(v, inverseNTT(mask))
+
+       return ringCompressAndEncode1(nil, w)
+}
index afbd31abe5d75fac7e75c220a5fc5fe035223a27..8cd6fffbcd13acf979c974d50261be292077e689 100644 (file)
@@ -5,8 +5,6 @@
 // Package mlkem implements the quantum-resistant key encapsulation method
 // ML-KEM (formerly known as Kyber), as specified in [NIST FIPS 203].
 //
-// Only the recommended ML-KEM-768 parameter set is provided.
-//
 // [NIST FIPS 203]: https://doi.org/10.6028/NIST.FIPS.203
 package mlkem
 
@@ -19,6 +17,11 @@ package mlkem
 //
 // Reviewers unfamiliar with polynomials or linear algebra might find the
 // background at https://words.filippo.io/kyber-math/ useful.
+//
+// This file implements the recommended parameter set ML-KEM-768. The ML-KEM-1024
+// parameter set implementation is auto-generated from this file.
+//
+//go:generate go run generate1024.go -input mlkem768.go -output mlkem1024.go
 
 import (
        "crypto/internal/fips/drbg"
@@ -35,7 +38,9 @@ const (
        // encodingSizeX is the byte size of a ringElement or nttElement encoded
        // by ByteEncode_X (FIPS 203, Algorithm 5).
        encodingSize12 = n * 12 / 8
+       encodingSize11 = n * 11 / 8
        encodingSize10 = n * 10 / 8
+       encodingSize5  = n * 5 / 8
        encodingSize4  = n * 4 / 8
        encodingSize1  = n * 1 / 8
 
@@ -49,11 +54,16 @@ const (
 const (
        k = 3
 
-       decryptionKeySize = k * encodingSize12
-       encryptionKeySize = k*encodingSize12 + 32
-
        CiphertextSize768       = k*encodingSize10 + encodingSize4
-       EncapsulationKeySize768 = encryptionKeySize
+       EncapsulationKeySize768 = k*encodingSize12 + 32
+)
+
+// ML-KEM-1024 parameters.
+const (
+       k1024 = 4
+
+       CiphertextSize1024       = k1024*encodingSize11 + encodingSize5
+       EncapsulationKeySize1024 = k1024*encodingSize12 + 32
 )
 
 // A DecapsulationKey768 is the secret key used to decapsulate a shared key from a
@@ -258,7 +268,7 @@ func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, erro
 // It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
 // Algorithm 14.
 func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) {
-       if len(ekPKE) != encryptionKeySize {
+       if len(ekPKE) != EncapsulationKeySize768 {
                return nil, errors.New("mlkem: invalid encapsulation key length")
        }
 
similarity index 64%
rename from src/crypto/internal/fips/mlkem/mlkem768_test.go
rename to src/crypto/internal/fips/mlkem/mlkem_test.go
index 28d17fe81a752ef55f3f71f8aa2f534556fcab32..acd8f4821b963afcb94c6d738bec30056a16a4ef 100644 (file)
@@ -14,12 +14,36 @@ import (
        "testing"
 )
 
+type encapsulationKey interface {
+       Bytes() []byte
+       Encapsulate() ([]byte, []byte)
+}
+
+type decapsulationKey[E encapsulationKey] interface {
+       Bytes() []byte
+       Decapsulate([]byte) ([]byte, error)
+       EncapsulationKey() E
+}
+
 func TestRoundTrip(t *testing.T) {
-       dk, err := GenerateKey768()
+       t.Run("768", func(t *testing.T) {
+               testRoundTrip(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
+       })
+       t.Run("1024", func(t *testing.T) {
+               testRoundTrip(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
+       })
+}
+
+func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
+       t *testing.T, generateKey func() (D, error),
+       newEncapsulationKey func([]byte) (E, error),
+       newDecapsulationKey func([]byte) (D, error)) {
+       dk, err := generateKey()
        if err != nil {
                t.Fatal(err)
        }
-       c, Ke := dk.EncapsulationKey().Encapsulate()
+       ek := dk.EncapsulationKey()
+       c, Ke := ek.Encapsulate()
        Kd, err := dk.Decapsulate(c)
        if err != nil {
                t.Fatal(err)
@@ -28,28 +52,64 @@ func TestRoundTrip(t *testing.T) {
                t.Fail()
        }
 
-       dk1, err := GenerateKey768()
+       ek1, err := newEncapsulationKey(ek.Bytes())
        if err != nil {
                t.Fatal(err)
        }
-       if bytes.Equal(dk.EncapsulationKey().Bytes(), dk1.EncapsulationKey().Bytes()) {
+       if !bytes.Equal(ek.Bytes(), ek1.Bytes()) {
                t.Fail()
        }
-       if bytes.Equal(dk.Bytes(), dk1.Bytes()) {
+       dk1, err := newDecapsulationKey(dk.Bytes())
+       if err != nil {
+               t.Fatal(err)
+       }
+       if !bytes.Equal(dk.Bytes(), dk1.Bytes()) {
+               t.Fail()
+       }
+       c1, Ke1 := ek1.Encapsulate()
+       Kd1, err := dk1.Decapsulate(c1)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if !bytes.Equal(Ke1, Kd1) {
                t.Fail()
        }
 
-       c1, Ke1 := dk.EncapsulationKey().Encapsulate()
-       if bytes.Equal(c, c1) {
+       dk2, err := generateKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+       if bytes.Equal(dk.EncapsulationKey().Bytes(), dk2.EncapsulationKey().Bytes()) {
                t.Fail()
        }
-       if bytes.Equal(Ke, Ke1) {
+       if bytes.Equal(dk.Bytes(), dk2.Bytes()) {
+               t.Fail()
+       }
+
+       c2, Ke2 := dk.EncapsulationKey().Encapsulate()
+       if bytes.Equal(c, c2) {
+               t.Fail()
+       }
+       if bytes.Equal(Ke, Ke2) {
                t.Fail()
        }
 }
 
 func TestBadLengths(t *testing.T) {
-       dk, err := GenerateKey768()
+       t.Run("768", func(t *testing.T) {
+               testBadLengths(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
+       })
+       t.Run("1024", func(t *testing.T) {
+               testBadLengths(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
+       })
+}
+
+func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
+       t *testing.T, generateKey func() (D, error),
+       newEncapsulationKey func([]byte) (E, error),
+       newDecapsulationKey func([]byte) (D, error)) {
+       dk, err := generateKey()
+       dkBytes := dk.Bytes()
        if err != nil {
                t.Fatal(err)
        }
@@ -57,15 +117,28 @@ func TestBadLengths(t *testing.T) {
        ekBytes := dk.EncapsulationKey().Bytes()
        c, _ := ek.Encapsulate()
 
+       for i := 0; i < len(dkBytes)-1; i++ {
+               if _, err := newDecapsulationKey(dkBytes[:i]); err == nil {
+                       t.Errorf("expected error for dk length %d", i)
+               }
+       }
+       dkLong := dkBytes
+       for i := 0; i < 100; i++ {
+               dkLong = append(dkLong, 0)
+               if _, err := newDecapsulationKey(dkLong); err == nil {
+                       t.Errorf("expected error for dk length %d", len(dkLong))
+               }
+       }
+
        for i := 0; i < len(ekBytes)-1; i++ {
-               if _, err := NewEncapsulationKey768(ekBytes[:i]); err == nil {
+               if _, err := newEncapsulationKey(ekBytes[:i]); err == nil {
                        t.Errorf("expected error for ek length %d", i)
                }
        }
        ekLong := ekBytes
        for i := 0; i < 100; i++ {
                ekLong = append(ekLong, 0)
-               if _, err := NewEncapsulationKey768(ekLong); err == nil {
+               if _, err := newEncapsulationKey(ekLong); err == nil {
                        t.Errorf("expected error for ek length %d", len(ekLong))
                }
        }