s, b := sliceForAppend(s, encodingSize10)
for i := 0; i < n; i += 4 {
var x uint64
- x |= uint64(compress(f[i+0], 10))
+ x |= uint64(compress(f[i], 10))
x |= uint64(compress(f[i+1], 10)) << 10
x |= uint64(compress(f[i+2], 10)) << 20
x |= uint64(compress(f[i+3], 10)) << 30
return f
}
+// ringCompressAndEncode appends an encoding of a ring element to s,
+// compressing each coefficient to d bits.
+//
+// It implements Compress, according to FIPS 203, Definition 4.7,
+// followed by ByteEncode, according to FIPS 203, Algorithm 5.
+func ringCompressAndEncode(s []byte, f ringElement, d uint8) []byte {
+ var b byte
+ var bIdx uint8
+ for i := 0; i < n; i++ {
+ c := compress(f[i], d)
+ var cIdx uint8
+ for cIdx < d {
+ b |= byte(c>>cIdx) << bIdx
+ bits := min(8-bIdx, d-cIdx)
+ bIdx += bits
+ cIdx += bits
+ if bIdx == 8 {
+ s = append(s, b)
+ b = 0
+ bIdx = 0
+ }
+ }
+ }
+ if bIdx != 0 {
+ panic("mlkem: internal error: bitsFilled != 0")
+ }
+ return s
+}
+
+// ringDecodeAndDecompress decodes an encoding of a ring element where
+// each d bits are mapped to an equidistant distribution.
+//
+// It implements ByteDecode, according to FIPS 203, Algorithm 6,
+// followed by Decompress, according to FIPS 203, Definition 4.8.
+func ringDecodeAndDecompress(b []byte, d uint8) ringElement {
+ var f ringElement
+ var bIdx uint8
+ for i := 0; i < n; i++ {
+ var c uint16
+ var cIdx uint8
+ for cIdx < d {
+ c |= uint16(b[0]>>bIdx) << cIdx
+ c &= (1 << d) - 1
+ bits := min(8-bIdx, d-cIdx)
+ bIdx += bits
+ cIdx += bits
+ if bIdx == 8 {
+ b = b[1:]
+ bIdx = 0
+ }
+ }
+ f[i] = fieldElement(decompress(c, d))
+ }
+ if len(b) != 0 {
+ panic("mlkem: internal error: leftover bytes")
+ }
+ return f
+}
+
+// ringCompressAndEncode5 appends a 160-byte encoding of a ring element to s,
+// compressing eight coefficients per five bytes.
+//
+// It implements Compress₅, according to FIPS 203, Definition 4.7,
+// followed by ByteEncode₅, according to FIPS 203, Algorithm 5.
+func ringCompressAndEncode5(s []byte, f ringElement) []byte {
+ return ringCompressAndEncode(s, f, 5)
+}
+
+// ringDecodeAndDecompress5 decodes a 160-byte encoding of a ring element where
+// each five bits are mapped to an equidistant distribution.
+//
+// It implements ByteDecode₅, according to FIPS 203, Algorithm 6,
+// followed by Decompress₅, according to FIPS 203, Definition 4.8.
+func ringDecodeAndDecompress5(bb *[encodingSize5]byte) ringElement {
+ return ringDecodeAndDecompress(bb[:], 5)
+}
+
+// ringCompressAndEncode11 appends a 352-byte encoding of a ring element to s,
+// compressing eight coefficients per eleven bytes.
+//
+// It implements Compress₁₁, according to FIPS 203, Definition 4.7,
+// followed by ByteEncode₁₁, according to FIPS 203, Algorithm 5.
+func ringCompressAndEncode11(s []byte, f ringElement) []byte {
+ return ringCompressAndEncode(s, f, 11)
+}
+
+// ringDecodeAndDecompress11 decodes a 352-byte encoding of a ring element where
+// each eleven bits are mapped to an equidistant distribution.
+//
+// It implements ByteDecode₁₁, according to FIPS 203, Algorithm 6,
+// followed by Decompress₁₁, according to FIPS 203, Definition 4.8.
+func ringDecodeAndDecompress11(bb *[encodingSize11]byte) ringElement {
+ return ringDecodeAndDecompress(bb[:], 11)
+}
+
// samplePolyCBD draws a ringElement from the special Dη distribution given a
// stream of random bytes generated by the PRF function, according to FIPS 203,
// Algorithm 8 and Definition 4.3.
package mlkem
import (
+ "bytes"
+ "crypto/rand"
"math/big"
+ mathrand "math/rand/v2"
"strconv"
"testing"
)
}
}
+func randomRingElement() ringElement {
+ var r ringElement
+ for i := range r {
+ r[i] = fieldElement(mathrand.IntN(q))
+ }
+ return r
+}
+
+func TestEncodeDecode(t *testing.T) {
+ f := randomRingElement()
+ b := make([]byte, 12*n/8)
+ rand.Read(b)
+
+ // Compare ringCompressAndEncode to ringCompressAndEncodeN.
+ e1 := ringCompressAndEncode(nil, f, 10)
+ e2 := ringCompressAndEncode10(nil, f)
+ if !bytes.Equal(e1, e2) {
+ t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode10 = %x", e1, e2)
+ }
+ e1 = ringCompressAndEncode(nil, f, 4)
+ e2 = ringCompressAndEncode4(nil, f)
+ if !bytes.Equal(e1, e2) {
+ t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode4 = %x", e1, e2)
+ }
+ e1 = ringCompressAndEncode(nil, f, 1)
+ e2 = ringCompressAndEncode1(nil, f)
+ if !bytes.Equal(e1, e2) {
+ t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode1 = %x", e1, e2)
+ }
+
+ // Compare ringDecodeAndDecompress to ringDecodeAndDecompressN.
+ g1 := ringDecodeAndDecompress(b[:encodingSize10], 10)
+ g2 := ringDecodeAndDecompress10((*[encodingSize10]byte)(b))
+ if g1 != g2 {
+ t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress10 = %v", g1, g2)
+ }
+ g1 = ringDecodeAndDecompress(b[:encodingSize4], 4)
+ g2 = ringDecodeAndDecompress4((*[encodingSize4]byte)(b))
+ if g1 != g2 {
+ t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress4 = %v", g1, g2)
+ }
+ g1 = ringDecodeAndDecompress(b[:encodingSize1], 1)
+ g2 = ringDecodeAndDecompress1((*[encodingSize1]byte)(b))
+ if g1 != g2 {
+ t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress1 = %v", g1, g2)
+ }
+
+ // Round-trip ringCompressAndEncode and ringDecodeAndDecompress.
+ for d := 1; d < 12; d++ {
+ encodingSize := d * n / 8
+ g := ringDecodeAndDecompress(b[:encodingSize], uint8(d))
+ out := ringCompressAndEncode(nil, g, uint8(d))
+ if !bytes.Equal(out, b[:encodingSize]) {
+ t.Errorf("roundtrip failed for d = %d", d)
+ }
+ }
+
+ // Round-trip ringCompressAndEncodeN and ringDecodeAndDecompressN.
+ g := ringDecodeAndDecompress10((*[encodingSize10]byte)(b))
+ out := ringCompressAndEncode10(nil, g)
+ if !bytes.Equal(out, b[:encodingSize10]) {
+ t.Errorf("roundtrip failed for specialized 10")
+ }
+ g = ringDecodeAndDecompress4((*[encodingSize4]byte)(b))
+ out = ringCompressAndEncode4(nil, g)
+ if !bytes.Equal(out, b[:encodingSize4]) {
+ t.Errorf("roundtrip failed for specialized 4")
+ }
+ g = ringDecodeAndDecompress1((*[encodingSize1]byte)(b))
+ out = ringCompressAndEncode1(nil, g)
+ if !bytes.Equal(out, b[:encodingSize1]) {
+ t.Errorf("roundtrip failed for specialized 1")
+ }
+}
+
func BitRev7(n uint8) uint8 {
if n>>7 != 0 {
panic("not 7 bits")
--- /dev/null
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build ignore
+
+package main
+
+import (
+ "flag"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "log"
+ "os"
+ "strings"
+)
+
+var replacements = map[string]string{
+ "k": "k1024",
+
+ "CiphertextSize768": "CiphertextSize1024",
+ "EncapsulationKeySize768": "EncapsulationKeySize1024",
+
+ "encryptionKey": "encryptionKey1024",
+ "decryptionKey": "decryptionKey1024",
+
+ "EncapsulationKey768": "EncapsulationKey1024",
+ "NewEncapsulationKey768": "NewEncapsulationKey1024",
+ "parseEK": "parseEK1024",
+
+ "kemEncaps": "kemEncaps1024",
+ "pkeEncrypt": "pkeEncrypt1024",
+
+ "DecapsulationKey768": "DecapsulationKey1024",
+ "NewDecapsulationKey768": "NewDecapsulationKey1024",
+ "newKeyFromSeed": "newKeyFromSeed1024",
+
+ "kemDecaps": "kemDecaps1024",
+ "pkeDecrypt": "pkeDecrypt1024",
+
+ "GenerateKey768": "GenerateKey1024",
+ "generateKey": "generateKey1024",
+
+ "kemKeyGen": "kemKeyGen1024",
+
+ "encodingSize4": "encodingSize5",
+ "encodingSize10": "encodingSize11",
+ "ringCompressAndEncode4": "ringCompressAndEncode5",
+ "ringCompressAndEncode10": "ringCompressAndEncode11",
+ "ringDecodeAndDecompress4": "ringDecodeAndDecompress5",
+ "ringDecodeAndDecompress10": "ringDecodeAndDecompress11",
+}
+
+func main() {
+ inputFile := flag.String("input", "", "")
+ outputFile := flag.String("output", "", "")
+ flag.Parse()
+
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, *inputFile, nil, parser.SkipObjectResolution|parser.ParseComments)
+ if err != nil {
+ log.Fatal(err)
+ }
+ cmap := ast.NewCommentMap(fset, f, f.Comments)
+
+ // Drop header comments.
+ cmap[ast.Node(f)] = nil
+
+ // Remove top-level consts used across the main and generated files.
+ var newDecls []ast.Decl
+ for _, decl := range f.Decls {
+ switch d := decl.(type) {
+ case *ast.GenDecl:
+ if d.Tok == token.CONST {
+ continue // Skip const declarations
+ }
+ if d.Tok == token.IMPORT {
+ cmap[decl] = nil // Drop pre-import comments.
+ }
+ }
+ newDecls = append(newDecls, decl)
+ }
+ f.Decls = newDecls
+
+ // Replace identifiers.
+ ast.Inspect(f, func(n ast.Node) bool {
+ switch x := n.(type) {
+ case *ast.Ident:
+ if replacement, ok := replacements[x.Name]; ok {
+ x.Name = replacement
+ }
+ }
+ return true
+ })
+
+ // Replace identifiers in comments.
+ for _, c := range f.Comments {
+ for _, l := range c.List {
+ for k, v := range replacements {
+ if k == "k" {
+ continue
+ }
+ l.Text = strings.ReplaceAll(l.Text, k, v)
+ }
+ }
+ }
+
+ out, err := os.Create(*outputFile)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer out.Close()
+
+ out.WriteString("// Code generated by generate1024.go. DO NOT EDIT.\n\n")
+
+ f.Comments = cmap.Filter(f).Comments()
+ err = format.Node(out, fset, f)
+ if err != nil {
+ log.Fatal(err)
+ }
+}
--- /dev/null
+// Code generated by generate1024.go. DO NOT EDIT.
+
+package mlkem
+
+import (
+ "crypto/internal/fips/drbg"
+ "crypto/internal/fips/sha3"
+ "crypto/internal/fips/subtle"
+ "errors"
+)
+
+// A DecapsulationKey1024 is the secret key used to decapsulate a shared key from a
+// ciphertext. It includes various precomputed values.
+type DecapsulationKey1024 struct {
+ d [32]byte // decapsulation key seed
+ z [32]byte // implicit rejection sampling seed
+
+ ρ [32]byte // sampleNTT seed for A, stored for the encapsulation key
+ h [32]byte // H(ek), stored for ML-KEM.Decaps_internal
+
+ encryptionKey1024
+ decryptionKey1024
+}
+
+// Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
+//
+// The decapsulation key must be kept secret.
+func (dk *DecapsulationKey1024) Bytes() []byte {
+ var b [SeedSize]byte
+ copy(b[:], dk.d[:])
+ copy(b[32:], dk.z[:])
+ return b[:]
+}
+
+// EncapsulationKey returns the public encapsulation key necessary to produce
+// ciphertexts.
+func (dk *DecapsulationKey1024) EncapsulationKey() *EncapsulationKey1024 {
+ return &EncapsulationKey1024{
+ ρ: dk.ρ,
+ h: dk.h,
+ encryptionKey1024: dk.encryptionKey1024,
+ }
+}
+
+// An EncapsulationKey1024 is the public key used to produce ciphertexts to be
+// decapsulated by the corresponding [DecapsulationKey1024].
+type EncapsulationKey1024 struct {
+ ρ [32]byte // sampleNTT seed for A
+ h [32]byte // H(ek)
+ encryptionKey1024
+}
+
+// Bytes returns the encapsulation key as a byte slice.
+func (ek *EncapsulationKey1024) Bytes() []byte {
+ // The actual logic is in a separate function to outline this allocation.
+ b := make([]byte, 0, EncapsulationKeySize1024)
+ return ek.bytes(b)
+}
+
+func (ek *EncapsulationKey1024) bytes(b []byte) []byte {
+ for i := range ek.t {
+ b = polyByteEncode(b, ek.t[i])
+ }
+ b = append(b, ek.ρ[:]...)
+ return b
+}
+
+// encryptionKey1024 is the parsed and expanded form of a PKE encryption key.
+type encryptionKey1024 struct {
+ t [k1024]nttElement // ByteDecode₁₂(ek[:384k])
+ a [k1024 * k1024]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
+}
+
+// decryptionKey1024 is the parsed and expanded form of a PKE decryption key.
+type decryptionKey1024 struct {
+ s [k1024]nttElement // ByteDecode₁₂(dk[:decryptionKey1024Size])
+}
+
+// GenerateKey1024 generates a new decapsulation key, drawing random bytes from
+// a DRBG. The decapsulation key must be kept secret.
+func GenerateKey1024() (*DecapsulationKey1024, error) {
+ // The actual logic is in a separate function to outline this allocation.
+ dk := &DecapsulationKey1024{}
+ return generateKey1024(dk), nil
+}
+
+func generateKey1024(dk *DecapsulationKey1024) *DecapsulationKey1024 {
+ var d [32]byte
+ drbg.Read(d[:])
+ var z [32]byte
+ drbg.Read(z[:])
+ return kemKeyGen1024(dk, &d, &z)
+}
+
+// NewDecapsulationKey1024 parses a decapsulation key from a 64-byte
+// seed in the "d || z" form. The seed must be uniformly random.
+func NewDecapsulationKey1024(seed []byte) (*DecapsulationKey1024, error) {
+ // The actual logic is in a separate function to outline this allocation.
+ dk := &DecapsulationKey1024{}
+ return newKeyFromSeed1024(dk, seed)
+}
+
+func newKeyFromSeed1024(dk *DecapsulationKey1024, seed []byte) (*DecapsulationKey1024, error) {
+ if len(seed) != SeedSize {
+ return nil, errors.New("mlkem: invalid seed length")
+ }
+ d := (*[32]byte)(seed[:32])
+ z := (*[32]byte)(seed[32:])
+ return kemKeyGen1024(dk, d, z), nil
+}
+
+// kemKeyGen1024 generates a decapsulation key.
+//
+// It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and
+// K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save
+// copies and allocations.
+func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) *DecapsulationKey1024 {
+ if dk == nil {
+ dk = &DecapsulationKey1024{}
+ }
+ dk.d = *d
+ dk.z = *z
+
+ g := sha3.New512()
+ g.Write(d[:])
+ g.Write([]byte{k1024}) // Module dimension as a domain separator.
+ G := g.Sum(make([]byte, 0, 64))
+ ρ, σ := G[:32], G[32:]
+ dk.ρ = [32]byte(ρ)
+
+ A := &dk.a
+ for i := byte(0); i < k1024; i++ {
+ for j := byte(0); j < k1024; j++ {
+ A[i*k1024+j] = sampleNTT(ρ, j, i)
+ }
+ }
+
+ var N byte
+ s := &dk.s
+ for i := range s {
+ s[i] = ntt(samplePolyCBD(σ, N))
+ N++
+ }
+ e := make([]nttElement, k1024)
+ for i := range e {
+ e[i] = ntt(samplePolyCBD(σ, N))
+ N++
+ }
+
+ t := &dk.t
+ for i := range t { // t = A ◦ s + e
+ t[i] = e[i]
+ for j := range s {
+ t[i] = polyAdd(t[i], nttMul(A[i*k1024+j], s[j]))
+ }
+ }
+
+ H := sha3.New256()
+ ek := dk.EncapsulationKey().Bytes()
+ H.Write(ek)
+ H.Sum(dk.h[:0])
+
+ return dk
+}
+
+// Encapsulate generates a shared key and an associated ciphertext from an
+// encapsulation key, drawing random bytes from a DRBG.
+//
+// The shared key must be kept secret.
+func (ek *EncapsulationKey1024) Encapsulate() (ciphertext, sharedKey []byte) {
+ // The actual logic is in a separate function to outline this allocation.
+ var cc [CiphertextSize1024]byte
+ return ek.encapsulate(&cc)
+}
+
+func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (ciphertext, sharedKey []byte) {
+ var m [messageSize]byte
+ drbg.Read(m[:])
+ // Note that the modulus check (step 2 of the encapsulation key check from
+ // FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK1024.
+ return kemEncaps1024(cc, ek, &m)
+}
+
+// kemEncaps1024 generates a shared key and an associated ciphertext.
+//
+// It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
+func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (c, K []byte) {
+ if cc == nil {
+ cc = &[CiphertextSize1024]byte{}
+ }
+
+ g := sha3.New512()
+ g.Write(m[:])
+ g.Write(ek.h[:])
+ G := g.Sum(nil)
+ K, r := G[:SharedKeySize], G[SharedKeySize:]
+ c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r)
+ return c, K
+}
+
+// NewEncapsulationKey1024 parses an encapsulation key from its encoded form.
+// If the encapsulation key is not valid, NewEncapsulationKey1024 returns an error.
+func NewEncapsulationKey1024(encapsulationKey []byte) (*EncapsulationKey1024, error) {
+ // The actual logic is in a separate function to outline this allocation.
+ ek := &EncapsulationKey1024{}
+ return parseEK1024(ek, encapsulationKey)
+}
+
+// parseEK1024 parses an encryption key from its encoded form.
+//
+// It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
+// Algorithm 14.
+func parseEK1024(ek *EncapsulationKey1024, ekPKE []byte) (*EncapsulationKey1024, error) {
+ if len(ekPKE) != EncapsulationKeySize1024 {
+ return nil, errors.New("mlkem: invalid encapsulation key length")
+ }
+
+ h := sha3.New256()
+ h.Write(ekPKE)
+ h.Sum(ek.h[:0])
+
+ for i := range ek.t {
+ var err error
+ ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
+ if err != nil {
+ return nil, err
+ }
+ ekPKE = ekPKE[encodingSize12:]
+ }
+ copy(ek.ρ[:], ekPKE)
+
+ for i := byte(0); i < k1024; i++ {
+ for j := byte(0); j < k1024; j++ {
+ ek.a[i*k1024+j] = sampleNTT(ek.ρ[:], j, i)
+ }
+ }
+
+ return ek, nil
+}
+
+// pkeEncrypt1024 encrypt a plaintext message.
+//
+// It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the
+// computation of t and AT is done in parseEK1024.
+func pkeEncrypt1024(cc *[CiphertextSize1024]byte, ex *encryptionKey1024, m *[messageSize]byte, rnd []byte) []byte {
+ var N byte
+ r, e1 := make([]nttElement, k1024), make([]ringElement, k1024)
+ for i := range r {
+ r[i] = ntt(samplePolyCBD(rnd, N))
+ N++
+ }
+ for i := range e1 {
+ e1[i] = samplePolyCBD(rnd, N)
+ N++
+ }
+ e2 := samplePolyCBD(rnd, N)
+
+ u := make([]ringElement, k1024) // NTT⁻¹(AT ◦ r) + e1
+ for i := range u {
+ u[i] = e1[i]
+ for j := range r {
+ // Note that i and j are inverted, as we need the transposed of A.
+ u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k1024+i], r[j])))
+ }
+ }
+
+ μ := ringDecodeAndDecompress1(m)
+
+ var vNTT nttElement // t⊺ ◦ r
+ for i := range ex.t {
+ vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
+ }
+ v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)
+
+ c := cc[:0]
+ for _, f := range u {
+ c = ringCompressAndEncode11(c, f)
+ }
+ c = ringCompressAndEncode5(c, v)
+
+ return c
+}
+
+// Decapsulate generates a shared key from a ciphertext and a decapsulation key.
+// If the ciphertext is not valid, Decapsulate returns an error.
+//
+// The shared key must be kept secret.
+func (dk *DecapsulationKey1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
+ if len(ciphertext) != CiphertextSize1024 {
+ return nil, errors.New("mlkem: invalid ciphertext length")
+ }
+ c := (*[CiphertextSize1024]byte)(ciphertext)
+ // Note that the hash check (step 3 of the decapsulation input check from
+ // FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always
+ // validly generated by ML-KEM.KeyGen_internal.
+ return kemDecaps1024(dk, c), nil
+}
+
+// kemDecaps1024 produces a shared key from a ciphertext.
+//
+// It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18.
+func kemDecaps1024(dk *DecapsulationKey1024, c *[CiphertextSize1024]byte) (K []byte) {
+ m := pkeDecrypt1024(&dk.decryptionKey1024, c)
+ g := sha3.New512()
+ g.Write(m[:])
+ g.Write(dk.h[:])
+ G := g.Sum(make([]byte, 0, 64))
+ Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
+ J := sha3.NewShake256()
+ J.Write(dk.z[:])
+ J.Write(c[:])
+ Kout := make([]byte, SharedKeySize)
+ J.Read(Kout)
+ var cc [CiphertextSize1024]byte
+ c1 := pkeEncrypt1024(&cc, &dk.encryptionKey1024, (*[32]byte)(m), r)
+
+ subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
+ return Kout
+}
+
+// pkeDecrypt1024 decrypts a ciphertext.
+//
+// It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15,
+// although s is retained from kemKeyGen1024.
+func pkeDecrypt1024(dx *decryptionKey1024, c *[CiphertextSize1024]byte) []byte {
+ u := make([]ringElement, k1024)
+ for i := range u {
+ b := (*[encodingSize11]byte)(c[encodingSize11*i : encodingSize11*(i+1)])
+ u[i] = ringDecodeAndDecompress11(b)
+ }
+
+ b := (*[encodingSize5]byte)(c[encodingSize11*k1024:])
+ v := ringDecodeAndDecompress5(b)
+
+ var mask nttElement // s⊺ ◦ NTT(u)
+ for i := range dx.s {
+ mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
+ }
+ w := polySub(v, inverseNTT(mask))
+
+ return ringCompressAndEncode1(nil, w)
+}
// Package mlkem implements the quantum-resistant key encapsulation method
// ML-KEM (formerly known as Kyber), as specified in [NIST FIPS 203].
//
-// Only the recommended ML-KEM-768 parameter set is provided.
-//
// [NIST FIPS 203]: https://doi.org/10.6028/NIST.FIPS.203
package mlkem
//
// Reviewers unfamiliar with polynomials or linear algebra might find the
// background at https://words.filippo.io/kyber-math/ useful.
+//
+// This file implements the recommended parameter set ML-KEM-768. The ML-KEM-1024
+// parameter set implementation is auto-generated from this file.
+//
+//go:generate go run generate1024.go -input mlkem768.go -output mlkem1024.go
import (
"crypto/internal/fips/drbg"
// encodingSizeX is the byte size of a ringElement or nttElement encoded
// by ByteEncode_X (FIPS 203, Algorithm 5).
encodingSize12 = n * 12 / 8
+ encodingSize11 = n * 11 / 8
encodingSize10 = n * 10 / 8
+ encodingSize5 = n * 5 / 8
encodingSize4 = n * 4 / 8
encodingSize1 = n * 1 / 8
const (
k = 3
- decryptionKeySize = k * encodingSize12
- encryptionKeySize = k*encodingSize12 + 32
-
CiphertextSize768 = k*encodingSize10 + encodingSize4
- EncapsulationKeySize768 = encryptionKeySize
+ EncapsulationKeySize768 = k*encodingSize12 + 32
+)
+
+// ML-KEM-1024 parameters.
+const (
+ k1024 = 4
+
+ CiphertextSize1024 = k1024*encodingSize11 + encodingSize5
+ EncapsulationKeySize1024 = k1024*encodingSize12 + 32
)
// A DecapsulationKey768 is the secret key used to decapsulate a shared key from a
// It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
// Algorithm 14.
func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) {
- if len(ekPKE) != encryptionKeySize {
+ if len(ekPKE) != EncapsulationKeySize768 {
return nil, errors.New("mlkem: invalid encapsulation key length")
}
"testing"
)
+type encapsulationKey interface {
+ Bytes() []byte
+ Encapsulate() ([]byte, []byte)
+}
+
+type decapsulationKey[E encapsulationKey] interface {
+ Bytes() []byte
+ Decapsulate([]byte) ([]byte, error)
+ EncapsulationKey() E
+}
+
func TestRoundTrip(t *testing.T) {
- dk, err := GenerateKey768()
+ t.Run("768", func(t *testing.T) {
+ testRoundTrip(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
+ })
+ t.Run("1024", func(t *testing.T) {
+ testRoundTrip(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
+ })
+}
+
+func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
+ t *testing.T, generateKey func() (D, error),
+ newEncapsulationKey func([]byte) (E, error),
+ newDecapsulationKey func([]byte) (D, error)) {
+ dk, err := generateKey()
if err != nil {
t.Fatal(err)
}
- c, Ke := dk.EncapsulationKey().Encapsulate()
+ ek := dk.EncapsulationKey()
+ c, Ke := ek.Encapsulate()
Kd, err := dk.Decapsulate(c)
if err != nil {
t.Fatal(err)
t.Fail()
}
- dk1, err := GenerateKey768()
+ ek1, err := newEncapsulationKey(ek.Bytes())
if err != nil {
t.Fatal(err)
}
- if bytes.Equal(dk.EncapsulationKey().Bytes(), dk1.EncapsulationKey().Bytes()) {
+ if !bytes.Equal(ek.Bytes(), ek1.Bytes()) {
t.Fail()
}
- if bytes.Equal(dk.Bytes(), dk1.Bytes()) {
+ dk1, err := newDecapsulationKey(dk.Bytes())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(dk.Bytes(), dk1.Bytes()) {
+ t.Fail()
+ }
+ c1, Ke1 := ek1.Encapsulate()
+ Kd1, err := dk1.Decapsulate(c1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(Ke1, Kd1) {
t.Fail()
}
- c1, Ke1 := dk.EncapsulationKey().Encapsulate()
- if bytes.Equal(c, c1) {
+ dk2, err := generateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if bytes.Equal(dk.EncapsulationKey().Bytes(), dk2.EncapsulationKey().Bytes()) {
t.Fail()
}
- if bytes.Equal(Ke, Ke1) {
+ if bytes.Equal(dk.Bytes(), dk2.Bytes()) {
+ t.Fail()
+ }
+
+ c2, Ke2 := dk.EncapsulationKey().Encapsulate()
+ if bytes.Equal(c, c2) {
+ t.Fail()
+ }
+ if bytes.Equal(Ke, Ke2) {
t.Fail()
}
}
func TestBadLengths(t *testing.T) {
- dk, err := GenerateKey768()
+ t.Run("768", func(t *testing.T) {
+ testBadLengths(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
+ })
+ t.Run("1024", func(t *testing.T) {
+ testBadLengths(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
+ })
+}
+
+func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
+ t *testing.T, generateKey func() (D, error),
+ newEncapsulationKey func([]byte) (E, error),
+ newDecapsulationKey func([]byte) (D, error)) {
+ dk, err := generateKey()
+ dkBytes := dk.Bytes()
if err != nil {
t.Fatal(err)
}
ekBytes := dk.EncapsulationKey().Bytes()
c, _ := ek.Encapsulate()
+ for i := 0; i < len(dkBytes)-1; i++ {
+ if _, err := newDecapsulationKey(dkBytes[:i]); err == nil {
+ t.Errorf("expected error for dk length %d", i)
+ }
+ }
+ dkLong := dkBytes
+ for i := 0; i < 100; i++ {
+ dkLong = append(dkLong, 0)
+ if _, err := newDecapsulationKey(dkLong); err == nil {
+ t.Errorf("expected error for dk length %d", len(dkLong))
+ }
+ }
+
for i := 0; i < len(ekBytes)-1; i++ {
- if _, err := NewEncapsulationKey768(ekBytes[:i]); err == nil {
+ if _, err := newEncapsulationKey(ekBytes[:i]); err == nil {
t.Errorf("expected error for ek length %d", i)
}
}
ekLong := ekBytes
for i := 0; i < 100; i++ {
ekLong = append(ekLong, 0)
- if _, err := NewEncapsulationKey768(ekLong); err == nil {
+ if _, err := newEncapsulationKey(ekLong); err == nil {
t.Errorf("expected error for ek length %d", len(ekLong))
}
}