From: Filippo Valsorda Date: Wed, 23 Oct 2024 09:36:56 +0000 (+0200) Subject: crypto/internal/fips/mlkem: implement ML-KEM-1024 X-Git-Tag: go1.24rc1~244 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=ed413f3fe018e2290b1ebd6cc0975b7e828e1a6c;p=gostls13.git crypto/internal/fips/mlkem: implement ML-KEM-1024 Decided to automatically duplicate the high-level code to avoid growing the ML-KEM-768 data structures. For #70122 Change-Id: I5c705b71ee1e23adba9113d5cf6b6e505c028967 Reviewed-on: https://go-review.googlesource.com/c/go/+/621983 Auto-Submit: Filippo Valsorda Reviewed-by: Roland Shoemaker Reviewed-by: Daniel McCarney LUCI-TryBot-Result: Go LUCI Reviewed-by: Dmitri Shuralyov --- diff --git a/src/crypto/internal/fips/mlkem/field.go b/src/crypto/internal/fips/mlkem/field.go index 9c22ec86f9..1532f031f2 100644 --- a/src/crypto/internal/fips/mlkem/field.go +++ b/src/crypto/internal/fips/mlkem/field.go @@ -263,7 +263,7 @@ func ringCompressAndEncode10(s []byte, f ringElement) []byte { s, b := sliceForAppend(s, encodingSize10) for i := 0; i < n; i += 4 { var x uint64 - x |= uint64(compress(f[i+0], 10)) + x |= uint64(compress(f[i], 10)) x |= uint64(compress(f[i+1], 10)) << 10 x |= uint64(compress(f[i+2], 10)) << 20 x |= uint64(compress(f[i+3], 10)) << 30 @@ -296,6 +296,101 @@ func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement { return f } +// ringCompressAndEncode appends an encoding of a ring element to s, +// compressing each coefficient to d bits. +// +// It implements Compress, according to FIPS 203, Definition 4.7, +// followed by ByteEncode, according to FIPS 203, Algorithm 5. +func ringCompressAndEncode(s []byte, f ringElement, d uint8) []byte { + var b byte + var bIdx uint8 + for i := 0; i < n; i++ { + c := compress(f[i], d) + var cIdx uint8 + for cIdx < d { + b |= byte(c>>cIdx) << bIdx + bits := min(8-bIdx, d-cIdx) + bIdx += bits + cIdx += bits + if bIdx == 8 { + s = append(s, b) + b = 0 + bIdx = 0 + } + } + } + if bIdx != 0 { + panic("mlkem: internal error: bitsFilled != 0") + } + return s +} + +// ringDecodeAndDecompress decodes an encoding of a ring element where +// each d bits are mapped to an equidistant distribution. +// +// It implements ByteDecode, according to FIPS 203, Algorithm 6, +// followed by Decompress, according to FIPS 203, Definition 4.8. +func ringDecodeAndDecompress(b []byte, d uint8) ringElement { + var f ringElement + var bIdx uint8 + for i := 0; i < n; i++ { + var c uint16 + var cIdx uint8 + for cIdx < d { + c |= uint16(b[0]>>bIdx) << cIdx + c &= (1 << d) - 1 + bits := min(8-bIdx, d-cIdx) + bIdx += bits + cIdx += bits + if bIdx == 8 { + b = b[1:] + bIdx = 0 + } + } + f[i] = fieldElement(decompress(c, d)) + } + if len(b) != 0 { + panic("mlkem: internal error: leftover bytes") + } + return f +} + +// ringCompressAndEncode5 appends a 160-byte encoding of a ring element to s, +// compressing eight coefficients per five bytes. +// +// It implements Compress₅, according to FIPS 203, Definition 4.7, +// followed by ByteEncode₅, according to FIPS 203, Algorithm 5. +func ringCompressAndEncode5(s []byte, f ringElement) []byte { + return ringCompressAndEncode(s, f, 5) +} + +// ringDecodeAndDecompress5 decodes a 160-byte encoding of a ring element where +// each five bits are mapped to an equidistant distribution. +// +// It implements ByteDecode₅, according to FIPS 203, Algorithm 6, +// followed by Decompress₅, according to FIPS 203, Definition 4.8. +func ringDecodeAndDecompress5(bb *[encodingSize5]byte) ringElement { + return ringDecodeAndDecompress(bb[:], 5) +} + +// ringCompressAndEncode11 appends a 352-byte encoding of a ring element to s, +// compressing eight coefficients per eleven bytes. +// +// It implements Compress₁₁, according to FIPS 203, Definition 4.7, +// followed by ByteEncode₁₁, according to FIPS 203, Algorithm 5. +func ringCompressAndEncode11(s []byte, f ringElement) []byte { + return ringCompressAndEncode(s, f, 11) +} + +// ringDecodeAndDecompress11 decodes a 352-byte encoding of a ring element where +// each eleven bits are mapped to an equidistant distribution. +// +// It implements ByteDecode₁₁, according to FIPS 203, Algorithm 6, +// followed by Decompress₁₁, according to FIPS 203, Definition 4.8. +func ringDecodeAndDecompress11(bb *[encodingSize11]byte) ringElement { + return ringDecodeAndDecompress(bb[:], 11) +} + // samplePolyCBD draws a ringElement from the special Dη distribution given a // stream of random bytes generated by the PRF function, according to FIPS 203, // Algorithm 8 and Definition 4.3. diff --git a/src/crypto/internal/fips/mlkem/field_test.go b/src/crypto/internal/fips/mlkem/field_test.go index a842913627..3a6d983803 100644 --- a/src/crypto/internal/fips/mlkem/field_test.go +++ b/src/crypto/internal/fips/mlkem/field_test.go @@ -5,7 +5,10 @@ package mlkem import ( + "bytes" + "crypto/rand" "math/big" + mathrand "math/rand/v2" "strconv" "testing" ) @@ -151,6 +154,81 @@ func TestDecompress(t *testing.T) { } } +func randomRingElement() ringElement { + var r ringElement + for i := range r { + r[i] = fieldElement(mathrand.IntN(q)) + } + return r +} + +func TestEncodeDecode(t *testing.T) { + f := randomRingElement() + b := make([]byte, 12*n/8) + rand.Read(b) + + // Compare ringCompressAndEncode to ringCompressAndEncodeN. + e1 := ringCompressAndEncode(nil, f, 10) + e2 := ringCompressAndEncode10(nil, f) + if !bytes.Equal(e1, e2) { + t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode10 = %x", e1, e2) + } + e1 = ringCompressAndEncode(nil, f, 4) + e2 = ringCompressAndEncode4(nil, f) + if !bytes.Equal(e1, e2) { + t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode4 = %x", e1, e2) + } + e1 = ringCompressAndEncode(nil, f, 1) + e2 = ringCompressAndEncode1(nil, f) + if !bytes.Equal(e1, e2) { + t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode1 = %x", e1, e2) + } + + // Compare ringDecodeAndDecompress to ringDecodeAndDecompressN. + g1 := ringDecodeAndDecompress(b[:encodingSize10], 10) + g2 := ringDecodeAndDecompress10((*[encodingSize10]byte)(b)) + if g1 != g2 { + t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress10 = %v", g1, g2) + } + g1 = ringDecodeAndDecompress(b[:encodingSize4], 4) + g2 = ringDecodeAndDecompress4((*[encodingSize4]byte)(b)) + if g1 != g2 { + t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress4 = %v", g1, g2) + } + g1 = ringDecodeAndDecompress(b[:encodingSize1], 1) + g2 = ringDecodeAndDecompress1((*[encodingSize1]byte)(b)) + if g1 != g2 { + t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress1 = %v", g1, g2) + } + + // Round-trip ringCompressAndEncode and ringDecodeAndDecompress. + for d := 1; d < 12; d++ { + encodingSize := d * n / 8 + g := ringDecodeAndDecompress(b[:encodingSize], uint8(d)) + out := ringCompressAndEncode(nil, g, uint8(d)) + if !bytes.Equal(out, b[:encodingSize]) { + t.Errorf("roundtrip failed for d = %d", d) + } + } + + // Round-trip ringCompressAndEncodeN and ringDecodeAndDecompressN. + g := ringDecodeAndDecompress10((*[encodingSize10]byte)(b)) + out := ringCompressAndEncode10(nil, g) + if !bytes.Equal(out, b[:encodingSize10]) { + t.Errorf("roundtrip failed for specialized 10") + } + g = ringDecodeAndDecompress4((*[encodingSize4]byte)(b)) + out = ringCompressAndEncode4(nil, g) + if !bytes.Equal(out, b[:encodingSize4]) { + t.Errorf("roundtrip failed for specialized 4") + } + g = ringDecodeAndDecompress1((*[encodingSize1]byte)(b)) + out = ringCompressAndEncode1(nil, g) + if !bytes.Equal(out, b[:encodingSize1]) { + t.Errorf("roundtrip failed for specialized 1") + } +} + func BitRev7(n uint8) uint8 { if n>>7 != 0 { panic("not 7 bits") diff --git a/src/crypto/internal/fips/mlkem/generate1024.go b/src/crypto/internal/fips/mlkem/generate1024.go new file mode 100644 index 0000000000..7ed68debdb --- /dev/null +++ b/src/crypto/internal/fips/mlkem/generate1024.go @@ -0,0 +1,123 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build ignore + +package main + +import ( + "flag" + "go/ast" + "go/format" + "go/parser" + "go/token" + "log" + "os" + "strings" +) + +var replacements = map[string]string{ + "k": "k1024", + + "CiphertextSize768": "CiphertextSize1024", + "EncapsulationKeySize768": "EncapsulationKeySize1024", + + "encryptionKey": "encryptionKey1024", + "decryptionKey": "decryptionKey1024", + + "EncapsulationKey768": "EncapsulationKey1024", + "NewEncapsulationKey768": "NewEncapsulationKey1024", + "parseEK": "parseEK1024", + + "kemEncaps": "kemEncaps1024", + "pkeEncrypt": "pkeEncrypt1024", + + "DecapsulationKey768": "DecapsulationKey1024", + "NewDecapsulationKey768": "NewDecapsulationKey1024", + "newKeyFromSeed": "newKeyFromSeed1024", + + "kemDecaps": "kemDecaps1024", + "pkeDecrypt": "pkeDecrypt1024", + + "GenerateKey768": "GenerateKey1024", + "generateKey": "generateKey1024", + + "kemKeyGen": "kemKeyGen1024", + + "encodingSize4": "encodingSize5", + "encodingSize10": "encodingSize11", + "ringCompressAndEncode4": "ringCompressAndEncode5", + "ringCompressAndEncode10": "ringCompressAndEncode11", + "ringDecodeAndDecompress4": "ringDecodeAndDecompress5", + "ringDecodeAndDecompress10": "ringDecodeAndDecompress11", +} + +func main() { + inputFile := flag.String("input", "", "") + outputFile := flag.String("output", "", "") + flag.Parse() + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, *inputFile, nil, parser.SkipObjectResolution|parser.ParseComments) + if err != nil { + log.Fatal(err) + } + cmap := ast.NewCommentMap(fset, f, f.Comments) + + // Drop header comments. + cmap[ast.Node(f)] = nil + + // Remove top-level consts used across the main and generated files. + var newDecls []ast.Decl + for _, decl := range f.Decls { + switch d := decl.(type) { + case *ast.GenDecl: + if d.Tok == token.CONST { + continue // Skip const declarations + } + if d.Tok == token.IMPORT { + cmap[decl] = nil // Drop pre-import comments. + } + } + newDecls = append(newDecls, decl) + } + f.Decls = newDecls + + // Replace identifiers. + ast.Inspect(f, func(n ast.Node) bool { + switch x := n.(type) { + case *ast.Ident: + if replacement, ok := replacements[x.Name]; ok { + x.Name = replacement + } + } + return true + }) + + // Replace identifiers in comments. + for _, c := range f.Comments { + for _, l := range c.List { + for k, v := range replacements { + if k == "k" { + continue + } + l.Text = strings.ReplaceAll(l.Text, k, v) + } + } + } + + out, err := os.Create(*outputFile) + if err != nil { + log.Fatal(err) + } + defer out.Close() + + out.WriteString("// Code generated by generate1024.go. DO NOT EDIT.\n\n") + + f.Comments = cmap.Filter(f).Comments() + err = format.Node(out, fset, f) + if err != nil { + log.Fatal(err) + } +} diff --git a/src/crypto/internal/fips/mlkem/mlkem1024.go b/src/crypto/internal/fips/mlkem/mlkem1024.go new file mode 100644 index 0000000000..c77dae8f74 --- /dev/null +++ b/src/crypto/internal/fips/mlkem/mlkem1024.go @@ -0,0 +1,342 @@ +// Code generated by generate1024.go. DO NOT EDIT. + +package mlkem + +import ( + "crypto/internal/fips/drbg" + "crypto/internal/fips/sha3" + "crypto/internal/fips/subtle" + "errors" +) + +// A DecapsulationKey1024 is the secret key used to decapsulate a shared key from a +// ciphertext. It includes various precomputed values. +type DecapsulationKey1024 struct { + d [32]byte // decapsulation key seed + z [32]byte // implicit rejection sampling seed + + ρ [32]byte // sampleNTT seed for A, stored for the encapsulation key + h [32]byte // H(ek), stored for ML-KEM.Decaps_internal + + encryptionKey1024 + decryptionKey1024 +} + +// Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form. +// +// The decapsulation key must be kept secret. +func (dk *DecapsulationKey1024) Bytes() []byte { + var b [SeedSize]byte + copy(b[:], dk.d[:]) + copy(b[32:], dk.z[:]) + return b[:] +} + +// EncapsulationKey returns the public encapsulation key necessary to produce +// ciphertexts. +func (dk *DecapsulationKey1024) EncapsulationKey() *EncapsulationKey1024 { + return &EncapsulationKey1024{ + ρ: dk.ρ, + h: dk.h, + encryptionKey1024: dk.encryptionKey1024, + } +} + +// An EncapsulationKey1024 is the public key used to produce ciphertexts to be +// decapsulated by the corresponding [DecapsulationKey1024]. +type EncapsulationKey1024 struct { + ρ [32]byte // sampleNTT seed for A + h [32]byte // H(ek) + encryptionKey1024 +} + +// Bytes returns the encapsulation key as a byte slice. +func (ek *EncapsulationKey1024) Bytes() []byte { + // The actual logic is in a separate function to outline this allocation. + b := make([]byte, 0, EncapsulationKeySize1024) + return ek.bytes(b) +} + +func (ek *EncapsulationKey1024) bytes(b []byte) []byte { + for i := range ek.t { + b = polyByteEncode(b, ek.t[i]) + } + b = append(b, ek.ρ[:]...) + return b +} + +// encryptionKey1024 is the parsed and expanded form of a PKE encryption key. +type encryptionKey1024 struct { + t [k1024]nttElement // ByteDecode₁₂(ek[:384k]) + a [k1024 * k1024]nttElement // A[i*k+j] = sampleNTT(ρ, j, i) +} + +// decryptionKey1024 is the parsed and expanded form of a PKE decryption key. +type decryptionKey1024 struct { + s [k1024]nttElement // ByteDecode₁₂(dk[:decryptionKey1024Size]) +} + +// GenerateKey1024 generates a new decapsulation key, drawing random bytes from +// a DRBG. The decapsulation key must be kept secret. +func GenerateKey1024() (*DecapsulationKey1024, error) { + // The actual logic is in a separate function to outline this allocation. + dk := &DecapsulationKey1024{} + return generateKey1024(dk), nil +} + +func generateKey1024(dk *DecapsulationKey1024) *DecapsulationKey1024 { + var d [32]byte + drbg.Read(d[:]) + var z [32]byte + drbg.Read(z[:]) + return kemKeyGen1024(dk, &d, &z) +} + +// NewDecapsulationKey1024 parses a decapsulation key from a 64-byte +// seed in the "d || z" form. The seed must be uniformly random. +func NewDecapsulationKey1024(seed []byte) (*DecapsulationKey1024, error) { + // The actual logic is in a separate function to outline this allocation. + dk := &DecapsulationKey1024{} + return newKeyFromSeed1024(dk, seed) +} + +func newKeyFromSeed1024(dk *DecapsulationKey1024, seed []byte) (*DecapsulationKey1024, error) { + if len(seed) != SeedSize { + return nil, errors.New("mlkem: invalid seed length") + } + d := (*[32]byte)(seed[:32]) + z := (*[32]byte)(seed[32:]) + return kemKeyGen1024(dk, d, z), nil +} + +// kemKeyGen1024 generates a decapsulation key. +// +// It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and +// K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save +// copies and allocations. +func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) *DecapsulationKey1024 { + if dk == nil { + dk = &DecapsulationKey1024{} + } + dk.d = *d + dk.z = *z + + g := sha3.New512() + g.Write(d[:]) + g.Write([]byte{k1024}) // Module dimension as a domain separator. + G := g.Sum(make([]byte, 0, 64)) + ρ, σ := G[:32], G[32:] + dk.ρ = [32]byte(ρ) + + A := &dk.a + for i := byte(0); i < k1024; i++ { + for j := byte(0); j < k1024; j++ { + A[i*k1024+j] = sampleNTT(ρ, j, i) + } + } + + var N byte + s := &dk.s + for i := range s { + s[i] = ntt(samplePolyCBD(σ, N)) + N++ + } + e := make([]nttElement, k1024) + for i := range e { + e[i] = ntt(samplePolyCBD(σ, N)) + N++ + } + + t := &dk.t + for i := range t { // t = A ◦ s + e + t[i] = e[i] + for j := range s { + t[i] = polyAdd(t[i], nttMul(A[i*k1024+j], s[j])) + } + } + + H := sha3.New256() + ek := dk.EncapsulationKey().Bytes() + H.Write(ek) + H.Sum(dk.h[:0]) + + return dk +} + +// Encapsulate generates a shared key and an associated ciphertext from an +// encapsulation key, drawing random bytes from a DRBG. +// +// The shared key must be kept secret. +func (ek *EncapsulationKey1024) Encapsulate() (ciphertext, sharedKey []byte) { + // The actual logic is in a separate function to outline this allocation. + var cc [CiphertextSize1024]byte + return ek.encapsulate(&cc) +} + +func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (ciphertext, sharedKey []byte) { + var m [messageSize]byte + drbg.Read(m[:]) + // Note that the modulus check (step 2 of the encapsulation key check from + // FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK1024. + return kemEncaps1024(cc, ek, &m) +} + +// kemEncaps1024 generates a shared key and an associated ciphertext. +// +// It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17. +func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (c, K []byte) { + if cc == nil { + cc = &[CiphertextSize1024]byte{} + } + + g := sha3.New512() + g.Write(m[:]) + g.Write(ek.h[:]) + G := g.Sum(nil) + K, r := G[:SharedKeySize], G[SharedKeySize:] + c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r) + return c, K +} + +// NewEncapsulationKey1024 parses an encapsulation key from its encoded form. +// If the encapsulation key is not valid, NewEncapsulationKey1024 returns an error. +func NewEncapsulationKey1024(encapsulationKey []byte) (*EncapsulationKey1024, error) { + // The actual logic is in a separate function to outline this allocation. + ek := &EncapsulationKey1024{} + return parseEK1024(ek, encapsulationKey) +} + +// parseEK1024 parses an encryption key from its encoded form. +// +// It implements the initial stages of K-PKE.Encrypt according to FIPS 203, +// Algorithm 14. +func parseEK1024(ek *EncapsulationKey1024, ekPKE []byte) (*EncapsulationKey1024, error) { + if len(ekPKE) != EncapsulationKeySize1024 { + return nil, errors.New("mlkem: invalid encapsulation key length") + } + + h := sha3.New256() + h.Write(ekPKE) + h.Sum(ek.h[:0]) + + for i := range ek.t { + var err error + ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12]) + if err != nil { + return nil, err + } + ekPKE = ekPKE[encodingSize12:] + } + copy(ek.ρ[:], ekPKE) + + for i := byte(0); i < k1024; i++ { + for j := byte(0); j < k1024; j++ { + ek.a[i*k1024+j] = sampleNTT(ek.ρ[:], j, i) + } + } + + return ek, nil +} + +// pkeEncrypt1024 encrypt a plaintext message. +// +// It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the +// computation of t and AT is done in parseEK1024. +func pkeEncrypt1024(cc *[CiphertextSize1024]byte, ex *encryptionKey1024, m *[messageSize]byte, rnd []byte) []byte { + var N byte + r, e1 := make([]nttElement, k1024), make([]ringElement, k1024) + for i := range r { + r[i] = ntt(samplePolyCBD(rnd, N)) + N++ + } + for i := range e1 { + e1[i] = samplePolyCBD(rnd, N) + N++ + } + e2 := samplePolyCBD(rnd, N) + + u := make([]ringElement, k1024) // NTT⁻¹(AT ◦ r) + e1 + for i := range u { + u[i] = e1[i] + for j := range r { + // Note that i and j are inverted, as we need the transposed of A. + u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k1024+i], r[j]))) + } + } + + μ := ringDecodeAndDecompress1(m) + + var vNTT nttElement // t⊺ ◦ r + for i := range ex.t { + vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i])) + } + v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ) + + c := cc[:0] + for _, f := range u { + c = ringCompressAndEncode11(c, f) + } + c = ringCompressAndEncode5(c, v) + + return c +} + +// Decapsulate generates a shared key from a ciphertext and a decapsulation key. +// If the ciphertext is not valid, Decapsulate returns an error. +// +// The shared key must be kept secret. +func (dk *DecapsulationKey1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) { + if len(ciphertext) != CiphertextSize1024 { + return nil, errors.New("mlkem: invalid ciphertext length") + } + c := (*[CiphertextSize1024]byte)(ciphertext) + // Note that the hash check (step 3 of the decapsulation input check from + // FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always + // validly generated by ML-KEM.KeyGen_internal. + return kemDecaps1024(dk, c), nil +} + +// kemDecaps1024 produces a shared key from a ciphertext. +// +// It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18. +func kemDecaps1024(dk *DecapsulationKey1024, c *[CiphertextSize1024]byte) (K []byte) { + m := pkeDecrypt1024(&dk.decryptionKey1024, c) + g := sha3.New512() + g.Write(m[:]) + g.Write(dk.h[:]) + G := g.Sum(make([]byte, 0, 64)) + Kprime, r := G[:SharedKeySize], G[SharedKeySize:] + J := sha3.NewShake256() + J.Write(dk.z[:]) + J.Write(c[:]) + Kout := make([]byte, SharedKeySize) + J.Read(Kout) + var cc [CiphertextSize1024]byte + c1 := pkeEncrypt1024(&cc, &dk.encryptionKey1024, (*[32]byte)(m), r) + + subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime) + return Kout +} + +// pkeDecrypt1024 decrypts a ciphertext. +// +// It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15, +// although s is retained from kemKeyGen1024. +func pkeDecrypt1024(dx *decryptionKey1024, c *[CiphertextSize1024]byte) []byte { + u := make([]ringElement, k1024) + for i := range u { + b := (*[encodingSize11]byte)(c[encodingSize11*i : encodingSize11*(i+1)]) + u[i] = ringDecodeAndDecompress11(b) + } + + b := (*[encodingSize5]byte)(c[encodingSize11*k1024:]) + v := ringDecodeAndDecompress5(b) + + var mask nttElement // s⊺ ◦ NTT(u) + for i := range dx.s { + mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i]))) + } + w := polySub(v, inverseNTT(mask)) + + return ringCompressAndEncode1(nil, w) +} diff --git a/src/crypto/internal/fips/mlkem/mlkem768.go b/src/crypto/internal/fips/mlkem/mlkem768.go index afbd31abe5..8cd6fffbcd 100644 --- a/src/crypto/internal/fips/mlkem/mlkem768.go +++ b/src/crypto/internal/fips/mlkem/mlkem768.go @@ -5,8 +5,6 @@ // Package mlkem implements the quantum-resistant key encapsulation method // ML-KEM (formerly known as Kyber), as specified in [NIST FIPS 203]. // -// Only the recommended ML-KEM-768 parameter set is provided. -// // [NIST FIPS 203]: https://doi.org/10.6028/NIST.FIPS.203 package mlkem @@ -19,6 +17,11 @@ package mlkem // // Reviewers unfamiliar with polynomials or linear algebra might find the // background at https://words.filippo.io/kyber-math/ useful. +// +// This file implements the recommended parameter set ML-KEM-768. The ML-KEM-1024 +// parameter set implementation is auto-generated from this file. +// +//go:generate go run generate1024.go -input mlkem768.go -output mlkem1024.go import ( "crypto/internal/fips/drbg" @@ -35,7 +38,9 @@ const ( // encodingSizeX is the byte size of a ringElement or nttElement encoded // by ByteEncode_X (FIPS 203, Algorithm 5). encodingSize12 = n * 12 / 8 + encodingSize11 = n * 11 / 8 encodingSize10 = n * 10 / 8 + encodingSize5 = n * 5 / 8 encodingSize4 = n * 4 / 8 encodingSize1 = n * 1 / 8 @@ -49,11 +54,16 @@ const ( const ( k = 3 - decryptionKeySize = k * encodingSize12 - encryptionKeySize = k*encodingSize12 + 32 - CiphertextSize768 = k*encodingSize10 + encodingSize4 - EncapsulationKeySize768 = encryptionKeySize + EncapsulationKeySize768 = k*encodingSize12 + 32 +) + +// ML-KEM-1024 parameters. +const ( + k1024 = 4 + + CiphertextSize1024 = k1024*encodingSize11 + encodingSize5 + EncapsulationKeySize1024 = k1024*encodingSize12 + 32 ) // A DecapsulationKey768 is the secret key used to decapsulate a shared key from a @@ -258,7 +268,7 @@ func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, erro // It implements the initial stages of K-PKE.Encrypt according to FIPS 203, // Algorithm 14. func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) { - if len(ekPKE) != encryptionKeySize { + if len(ekPKE) != EncapsulationKeySize768 { return nil, errors.New("mlkem: invalid encapsulation key length") } diff --git a/src/crypto/internal/fips/mlkem/mlkem768_test.go b/src/crypto/internal/fips/mlkem/mlkem_test.go similarity index 64% rename from src/crypto/internal/fips/mlkem/mlkem768_test.go rename to src/crypto/internal/fips/mlkem/mlkem_test.go index 28d17fe81a..acd8f4821b 100644 --- a/src/crypto/internal/fips/mlkem/mlkem768_test.go +++ b/src/crypto/internal/fips/mlkem/mlkem_test.go @@ -14,12 +14,36 @@ import ( "testing" ) +type encapsulationKey interface { + Bytes() []byte + Encapsulate() ([]byte, []byte) +} + +type decapsulationKey[E encapsulationKey] interface { + Bytes() []byte + Decapsulate([]byte) ([]byte, error) + EncapsulationKey() E +} + func TestRoundTrip(t *testing.T) { - dk, err := GenerateKey768() + t.Run("768", func(t *testing.T) { + testRoundTrip(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768) + }) + t.Run("1024", func(t *testing.T) { + testRoundTrip(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024) + }) +} + +func testRoundTrip[E encapsulationKey, D decapsulationKey[E]]( + t *testing.T, generateKey func() (D, error), + newEncapsulationKey func([]byte) (E, error), + newDecapsulationKey func([]byte) (D, error)) { + dk, err := generateKey() if err != nil { t.Fatal(err) } - c, Ke := dk.EncapsulationKey().Encapsulate() + ek := dk.EncapsulationKey() + c, Ke := ek.Encapsulate() Kd, err := dk.Decapsulate(c) if err != nil { t.Fatal(err) @@ -28,28 +52,64 @@ func TestRoundTrip(t *testing.T) { t.Fail() } - dk1, err := GenerateKey768() + ek1, err := newEncapsulationKey(ek.Bytes()) if err != nil { t.Fatal(err) } - if bytes.Equal(dk.EncapsulationKey().Bytes(), dk1.EncapsulationKey().Bytes()) { + if !bytes.Equal(ek.Bytes(), ek1.Bytes()) { t.Fail() } - if bytes.Equal(dk.Bytes(), dk1.Bytes()) { + dk1, err := newDecapsulationKey(dk.Bytes()) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(dk.Bytes(), dk1.Bytes()) { + t.Fail() + } + c1, Ke1 := ek1.Encapsulate() + Kd1, err := dk1.Decapsulate(c1) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(Ke1, Kd1) { t.Fail() } - c1, Ke1 := dk.EncapsulationKey().Encapsulate() - if bytes.Equal(c, c1) { + dk2, err := generateKey() + if err != nil { + t.Fatal(err) + } + if bytes.Equal(dk.EncapsulationKey().Bytes(), dk2.EncapsulationKey().Bytes()) { t.Fail() } - if bytes.Equal(Ke, Ke1) { + if bytes.Equal(dk.Bytes(), dk2.Bytes()) { + t.Fail() + } + + c2, Ke2 := dk.EncapsulationKey().Encapsulate() + if bytes.Equal(c, c2) { + t.Fail() + } + if bytes.Equal(Ke, Ke2) { t.Fail() } } func TestBadLengths(t *testing.T) { - dk, err := GenerateKey768() + t.Run("768", func(t *testing.T) { + testBadLengths(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768) + }) + t.Run("1024", func(t *testing.T) { + testBadLengths(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024) + }) +} + +func testBadLengths[E encapsulationKey, D decapsulationKey[E]]( + t *testing.T, generateKey func() (D, error), + newEncapsulationKey func([]byte) (E, error), + newDecapsulationKey func([]byte) (D, error)) { + dk, err := generateKey() + dkBytes := dk.Bytes() if err != nil { t.Fatal(err) } @@ -57,15 +117,28 @@ func TestBadLengths(t *testing.T) { ekBytes := dk.EncapsulationKey().Bytes() c, _ := ek.Encapsulate() + for i := 0; i < len(dkBytes)-1; i++ { + if _, err := newDecapsulationKey(dkBytes[:i]); err == nil { + t.Errorf("expected error for dk length %d", i) + } + } + dkLong := dkBytes + for i := 0; i < 100; i++ { + dkLong = append(dkLong, 0) + if _, err := newDecapsulationKey(dkLong); err == nil { + t.Errorf("expected error for dk length %d", len(dkLong)) + } + } + for i := 0; i < len(ekBytes)-1; i++ { - if _, err := NewEncapsulationKey768(ekBytes[:i]); err == nil { + if _, err := newEncapsulationKey(ekBytes[:i]); err == nil { t.Errorf("expected error for ek length %d", i) } } ekLong := ekBytes for i := 0; i < 100; i++ { ekLong = append(ekLong, 0) - if _, err := NewEncapsulationKey768(ekLong); err == nil { + if _, err := newEncapsulationKey(ekLong); err == nil { t.Errorf("expected error for ek length %d", len(ekLong)) } }