From 0568cda10aa9a3b26e519da15bcff18ab6ef974d Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Wed, 16 Oct 2024 14:50:22 +0200 Subject: [PATCH] crypto/internal/mlkem768: move field implementation to its own file Change-Id: Id2676f1fd446feda506a3f1d4fbdadffe87ecc95 Reviewed-on: https://go-review.googlesource.com/c/go/+/621978 Reviewed-by: Michael Knyszek Auto-Submit: Filippo Valsorda Reviewed-by: Russ Cox LUCI-TryBot-Result: Go LUCI Reviewed-by: Daniel McCarney --- src/crypto/internal/mlkem768/field.go | 456 ++++++++++++++++++ src/crypto/internal/mlkem768/field_test.go | 191 ++++++++ src/crypto/internal/mlkem768/mlkem768.go | 445 ----------------- src/crypto/internal/mlkem768/mlkem768_test.go | 182 ------- 4 files changed, 647 insertions(+), 627 deletions(-) create mode 100644 src/crypto/internal/mlkem768/field.go create mode 100644 src/crypto/internal/mlkem768/field_test.go diff --git a/src/crypto/internal/mlkem768/field.go b/src/crypto/internal/mlkem768/field.go new file mode 100644 index 0000000000..e0cde03354 --- /dev/null +++ b/src/crypto/internal/mlkem768/field.go @@ -0,0 +1,456 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mlkem768 + +import ( + "errors" + "internal/byteorder" + + "golang.org/x/crypto/sha3" +) + +// fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced. +type fieldElement uint16 + +// fieldCheckReduced checks that a value a is < q. +func fieldCheckReduced(a uint16) (fieldElement, error) { + if a >= q { + return 0, errors.New("unreduced field element") + } + return fieldElement(a), nil +} + +// fieldReduceOnce reduces a value a < 2q. +func fieldReduceOnce(a uint16) fieldElement { + x := a - q + // If x underflowed, then x >= 2¹⁶ - q > 2¹⁵, so the top bit is set. + x += (x >> 15) * q + return fieldElement(x) +} + +func fieldAdd(a, b fieldElement) fieldElement { + x := uint16(a + b) + return fieldReduceOnce(x) +} + +func fieldSub(a, b fieldElement) fieldElement { + x := uint16(a - b + q) + return fieldReduceOnce(x) +} + +const ( + barrettMultiplier = 5039 // 2¹² * 2¹² / q + barrettShift = 24 // log₂(2¹² * 2¹²) +) + +// fieldReduce reduces a value a < 2q² using Barrett reduction, to avoid +// potentially variable-time division. +func fieldReduce(a uint32) fieldElement { + quotient := uint32((uint64(a) * barrettMultiplier) >> barrettShift) + return fieldReduceOnce(uint16(a - quotient*q)) +} + +func fieldMul(a, b fieldElement) fieldElement { + x := uint32(a) * uint32(b) + return fieldReduce(x) +} + +// fieldMulSub returns a * (b - c). This operation is fused to save a +// fieldReduceOnce after the subtraction. +func fieldMulSub(a, b, c fieldElement) fieldElement { + x := uint32(a) * uint32(b-c+q) + return fieldReduce(x) +} + +// fieldAddMul returns a * b + c * d. This operation is fused to save a +// fieldReduceOnce and a fieldReduce. +func fieldAddMul(a, b, c, d fieldElement) fieldElement { + x := uint32(a) * uint32(b) + x += uint32(c) * uint32(d) + return fieldReduce(x) +} + +// compress maps a field element uniformly to the range 0 to 2ᵈ-1, according to +// FIPS 203, Definition 4.7. +func compress(x fieldElement, d uint8) uint16 { + // We want to compute (x * 2ᵈ) / q, rounded to nearest integer, with 1/2 + // rounding up (see FIPS 203, Section 2.3). + + // Barrett reduction produces a quotient and a remainder in the range [0, 2q), + // such that dividend = quotient * q + remainder. + dividend := uint32(x) << d // x * 2ᵈ + quotient := uint32(uint64(dividend) * barrettMultiplier >> barrettShift) + remainder := dividend - quotient*q + + // Since the remainder is in the range [0, 2q), not [0, q), we need to + // portion it into three spans for rounding. + // + // [ 0, q/2 ) -> round to 0 + // [ q/2, q + q/2 ) -> round to 1 + // [ q + q/2, 2q ) -> round to 2 + // + // We can convert that to the following logic: add 1 if remainder > q/2, + // then add 1 again if remainder > q + q/2. + // + // Note that if remainder > x, then ⌊x⌋ - remainder underflows, and the top + // bit of the difference will be set. + quotient += (q/2 - remainder) >> 31 & 1 + quotient += (q + q/2 - remainder) >> 31 & 1 + + // quotient might have overflowed at this point, so reduce it by masking. + var mask uint32 = (1 << d) - 1 + return uint16(quotient & mask) +} + +// decompress maps a number x between 0 and 2ᵈ-1 uniformly to the full range of +// field elements, according to FIPS 203, Definition 4.8. +func decompress(y uint16, d uint8) fieldElement { + // We want to compute (y * q) / 2ᵈ, rounded to nearest integer, with 1/2 + // rounding up (see FIPS 203, Section 2.3). + + dividend := uint32(y) * q + quotient := dividend >> d // (y * q) / 2ᵈ + + // The d'th least-significant bit of the dividend (the most significant bit + // of the remainder) is 1 for the top half of the values that divide to the + // same quotient, which are the ones that round up. + quotient += dividend >> (d - 1) & 1 + + // quotient is at most (2¹¹-1) * q / 2¹¹ + 1 = 3328, so it didn't overflow. + return fieldElement(quotient) +} + +// ringElement is a polynomial, an element of R_q, represented as an array +// according to FIPS 203, Section 2.4.4. +type ringElement [n]fieldElement + +// polyAdd adds two ringElements or nttElements. +func polyAdd[T ~[n]fieldElement](a, b T) (s T) { + for i := range s { + s[i] = fieldAdd(a[i], b[i]) + } + return s +} + +// polySub subtracts two ringElements or nttElements. +func polySub[T ~[n]fieldElement](a, b T) (s T) { + for i := range s { + s[i] = fieldSub(a[i], b[i]) + } + return s +} + +// polyByteEncode appends the 384-byte encoding of f to b. +// +// It implements ByteEncode₁₂, according to FIPS 203, Algorithm 5. +func polyByteEncode[T ~[n]fieldElement](b []byte, f T) []byte { + out, B := sliceForAppend(b, encodingSize12) + for i := 0; i < n; i += 2 { + x := uint32(f[i]) | uint32(f[i+1])<<12 + B[0] = uint8(x) + B[1] = uint8(x >> 8) + B[2] = uint8(x >> 16) + B = B[3:] + } + return out +} + +// polyByteDecode decodes the 384-byte encoding of a polynomial, checking that +// all the coefficients are properly reduced. This fulfills the "Modulus check" +// step of ML-KEM Encapsulation. +// +// It implements ByteDecode₁₂, according to FIPS 203, Algorithm 6. +func polyByteDecode[T ~[n]fieldElement](b []byte) (T, error) { + if len(b) != encodingSize12 { + return T{}, errors.New("mlkem768: invalid encoding length") + } + var f T + for i := 0; i < n; i += 2 { + d := uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 + const mask12 = 0b1111_1111_1111 + var err error + if f[i], err = fieldCheckReduced(uint16(d & mask12)); err != nil { + return T{}, errors.New("mlkem768: invalid polynomial encoding") + } + if f[i+1], err = fieldCheckReduced(uint16(d >> 12)); err != nil { + return T{}, errors.New("mlkem768: invalid polynomial encoding") + } + b = b[3:] + } + return f, nil +} + +// sliceForAppend takes a slice and a requested number of bytes. It returns a +// slice with the contents of the given slice followed by that many bytes and a +// second slice that aliases into it and contains only the extra bytes. If the +// original slice has sufficient capacity then no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} + +// ringCompressAndEncode1 appends a 32-byte encoding of a ring element to s, +// compressing one coefficients per bit. +// +// It implements Compress₁, according to FIPS 203, Definition 4.7, +// followed by ByteEncode₁, according to FIPS 203, Algorithm 5. +func ringCompressAndEncode1(s []byte, f ringElement) []byte { + s, b := sliceForAppend(s, encodingSize1) + for i := range b { + b[i] = 0 + } + for i := range f { + b[i/8] |= uint8(compress(f[i], 1) << (i % 8)) + } + return s +} + +// ringDecodeAndDecompress1 decodes a 32-byte slice to a ring element where each +// bit is mapped to 0 or ⌈q/2⌋. +// +// It implements ByteDecode₁, according to FIPS 203, Algorithm 6, +// followed by Decompress₁, according to FIPS 203, Definition 4.8. +func ringDecodeAndDecompress1(b *[encodingSize1]byte) ringElement { + var f ringElement + for i := range f { + b_i := b[i/8] >> (i % 8) & 1 + const halfQ = (q + 1) / 2 // ⌈q/2⌋, rounded up per FIPS 203, Section 2.3 + f[i] = fieldElement(b_i) * halfQ // 0 decompresses to 0, and 1 to ⌈q/2⌋ + } + return f +} + +// ringCompressAndEncode4 appends a 128-byte encoding of a ring element to s, +// compressing two coefficients per byte. +// +// It implements Compress₄, according to FIPS 203, Definition 4.7, +// followed by ByteEncode₄, according to FIPS 203, Algorithm 5. +func ringCompressAndEncode4(s []byte, f ringElement) []byte { + s, b := sliceForAppend(s, encodingSize4) + for i := 0; i < n; i += 2 { + b[i/2] = uint8(compress(f[i], 4) | compress(f[i+1], 4)<<4) + } + return s +} + +// ringDecodeAndDecompress4 decodes a 128-byte encoding of a ring element where +// each four bits are mapped to an equidistant distribution. +// +// It implements ByteDecode₄, according to FIPS 203, Algorithm 6, +// followed by Decompress₄, according to FIPS 203, Definition 4.8. +func ringDecodeAndDecompress4(b *[encodingSize4]byte) ringElement { + var f ringElement + for i := 0; i < n; i += 2 { + f[i] = fieldElement(decompress(uint16(b[i/2]&0b1111), 4)) + f[i+1] = fieldElement(decompress(uint16(b[i/2]>>4), 4)) + } + return f +} + +// ringCompressAndEncode10 appends a 320-byte encoding of a ring element to s, +// compressing four coefficients per five bytes. +// +// It implements Compress₁₀, according to FIPS 203, Definition 4.7, +// followed by ByteEncode₁₀, according to FIPS 203, Algorithm 5. +func ringCompressAndEncode10(s []byte, f ringElement) []byte { + s, b := sliceForAppend(s, encodingSize10) + for i := 0; i < n; i += 4 { + var x uint64 + x |= uint64(compress(f[i+0], 10)) + x |= uint64(compress(f[i+1], 10)) << 10 + x |= uint64(compress(f[i+2], 10)) << 20 + x |= uint64(compress(f[i+3], 10)) << 30 + b[0] = uint8(x) + b[1] = uint8(x >> 8) + b[2] = uint8(x >> 16) + b[3] = uint8(x >> 24) + b[4] = uint8(x >> 32) + b = b[5:] + } + return s +} + +// ringDecodeAndDecompress10 decodes a 320-byte encoding of a ring element where +// each ten bits are mapped to an equidistant distribution. +// +// It implements ByteDecode₁₀, according to FIPS 203, Algorithm 6, +// followed by Decompress₁₀, according to FIPS 203, Definition 4.8. +func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement { + b := bb[:] + var f ringElement + for i := 0; i < n; i += 4 { + x := uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 + b = b[5:] + f[i] = fieldElement(decompress(uint16(x>>0&0b11_1111_1111), 10)) + f[i+1] = fieldElement(decompress(uint16(x>>10&0b11_1111_1111), 10)) + f[i+2] = fieldElement(decompress(uint16(x>>20&0b11_1111_1111), 10)) + f[i+3] = fieldElement(decompress(uint16(x>>30&0b11_1111_1111), 10)) + } + return f +} + +// samplePolyCBD draws a ringElement from the special Dη distribution given a +// stream of random bytes generated by the PRF function, according to FIPS 203, +// Algorithm 8 and Definition 4.3. +func samplePolyCBD(s []byte, b byte) ringElement { + prf := sha3.NewShake256() + prf.Write(s) + prf.Write([]byte{b}) + B := make([]byte, 64*2) // η = 2 + prf.Read(B) + + // SamplePolyCBD simply draws four (2η) bits for each coefficient, and adds + // the first two and subtracts the last two. + + var f ringElement + for i := 0; i < n; i += 2 { + b := B[i/2] + b_7, b_6, b_5, b_4 := b>>7, b>>6&1, b>>5&1, b>>4&1 + b_3, b_2, b_1, b_0 := b>>3&1, b>>2&1, b>>1&1, b&1 + f[i] = fieldSub(fieldElement(b_0+b_1), fieldElement(b_2+b_3)) + f[i+1] = fieldSub(fieldElement(b_4+b_5), fieldElement(b_6+b_7)) + } + return f +} + +// nttElement is an NTT representation, an element of T_q, represented as an +// array according to FIPS 203, Section 2.4.4. +type nttElement [n]fieldElement + +// gammas are the values ζ^2BitRev7(i)+1 mod q for each index i, according to +// FIPS 203, Appendix A (with negative values reduced to positive). +var gammas = [128]fieldElement{17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175} + +// nttMul multiplies two nttElements. +// +// It implements MultiplyNTTs, according to FIPS 203, Algorithm 11. +func nttMul(f, g nttElement) nttElement { + var h nttElement + // We use i += 2 for bounds check elimination. See https://go.dev/issue/66826. + for i := 0; i < 256; i += 2 { + a0, a1 := f[i], f[i+1] + b0, b1 := g[i], g[i+1] + h[i] = fieldAddMul(a0, b0, fieldMul(a1, b1), gammas[i/2]) + h[i+1] = fieldAddMul(a0, b1, a1, b0) + } + return h +} + +// zetas are the values ζ^BitRev7(k) mod q for each index k, according to FIPS +// 203, Appendix A. +var zetas = [128]fieldElement{1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154} + +// ntt maps a ringElement to its nttElement representation. +// +// It implements NTT, according to FIPS 203, Algorithm 9. +func ntt(f ringElement) nttElement { + k := 1 + for len := 128; len >= 2; len /= 2 { + for start := 0; start < 256; start += 2 * len { + zeta := zetas[k] + k++ + // Bounds check elimination hint. + f, flen := f[start:start+len], f[start+len:start+len+len] + for j := 0; j < len; j++ { + t := fieldMul(zeta, flen[j]) + flen[j] = fieldSub(f[j], t) + f[j] = fieldAdd(f[j], t) + } + } + } + return nttElement(f) +} + +// inverseNTT maps a nttElement back to the ringElement it represents. +// +// It implements NTT⁻¹, according to FIPS 203, Algorithm 10. +func inverseNTT(f nttElement) ringElement { + k := 127 + for len := 2; len <= 128; len *= 2 { + for start := 0; start < 256; start += 2 * len { + zeta := zetas[k] + k-- + // Bounds check elimination hint. + f, flen := f[start:start+len], f[start+len:start+len+len] + for j := 0; j < len; j++ { + t := f[j] + f[j] = fieldAdd(t, flen[j]) + flen[j] = fieldMulSub(zeta, flen[j], t) + } + } + } + for i := range f { + f[i] = fieldMul(f[i], 3303) // 3303 = 128⁻¹ mod q + } + return ringElement(f) +} + +// sampleNTT draws a uniformly random nttElement from a stream of uniformly +// random bytes generated by the XOF function, according to FIPS 203, +// Algorithm 7. +func sampleNTT(rho []byte, ii, jj byte) nttElement { + B := sha3.NewShake128() + B.Write(rho) + B.Write([]byte{ii, jj}) + + // SampleNTT essentially draws 12 bits at a time from r, interprets them in + // little-endian, and rejects values higher than q, until it drew 256 + // values. (The rejection rate is approximately 19%.) + // + // To do this from a bytes stream, it draws three bytes at a time, and + // splits them into two uint16 appropriately masked. + // + // r₀ r₁ r₂ + // |- - - - - - - -|- - - - - - - -|- - - - - - - -| + // + // Uint16(r₀ || r₁) + // |- - - - - - - - - - - - - - - -| + // |- - - - - - - - - - - -| + // d₁ + // + // Uint16(r₁ || r₂) + // |- - - - - - - - - - - - - - - -| + // |- - - - - - - - - - - -| + // d₂ + // + // Note that in little-endian, the rightmost bits are the most significant + // bits (dropped with a mask) and the leftmost bits are the least + // significant bits (dropped with a right shift). + + var a nttElement + var j int // index into a + var buf [24]byte // buffered reads from B + off := len(buf) // index into buf, starts in a "buffer fully consumed" state + for { + if off >= len(buf) { + B.Read(buf[:]) + off = 0 + } + d1 := byteorder.LeUint16(buf[off:]) & 0b1111_1111_1111 + d2 := byteorder.LeUint16(buf[off+1:]) >> 4 + off += 3 + if d1 < q { + a[j] = fieldElement(d1) + j++ + } + if j >= len(a) { + break + } + if d2 < q { + a[j] = fieldElement(d2) + j++ + } + if j >= len(a) { + break + } + } + return a +} diff --git a/src/crypto/internal/mlkem768/field_test.go b/src/crypto/internal/mlkem768/field_test.go new file mode 100644 index 0000000000..6350ebc82f --- /dev/null +++ b/src/crypto/internal/mlkem768/field_test.go @@ -0,0 +1,191 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mlkem768 + +import ( + "math/big" + "strconv" + "testing" +) + +func TestFieldReduce(t *testing.T) { + for a := uint32(0); a < 2*q*q; a++ { + got := fieldReduce(a) + exp := fieldElement(a % q) + if got != exp { + t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp) + } + } +} + +func TestFieldAdd(t *testing.T) { + for a := fieldElement(0); a < q; a++ { + for b := fieldElement(0); b < q; b++ { + got := fieldAdd(a, b) + exp := (a + b) % q + if got != exp { + t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp) + } + } + } +} + +func TestFieldSub(t *testing.T) { + for a := fieldElement(0); a < q; a++ { + for b := fieldElement(0); b < q; b++ { + got := fieldSub(a, b) + exp := (a - b + q) % q + if got != exp { + t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp) + } + } + } +} + +func TestFieldMul(t *testing.T) { + for a := fieldElement(0); a < q; a++ { + for b := fieldElement(0); b < q; b++ { + got := fieldMul(a, b) + exp := fieldElement((uint32(a) * uint32(b)) % q) + if got != exp { + t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp) + } + } + } +} + +func TestDecompressCompress(t *testing.T) { + for _, bits := range []uint8{1, 4, 10} { + for a := uint16(0); a < 1<= q { + t.Fatalf("decompress(%d, %d) = %d >= q", a, bits, f) + } + got := compress(f, bits) + if got != a { + t.Fatalf("compress(decompress(%d, %d), %d) = %d", a, bits, bits, got) + } + } + + for a := fieldElement(0); a < q; a++ { + c := compress(a, bits) + if c >= 1<= 2^bits", a, bits, c) + } + got := decompress(c, bits) + diff := min(a-got, got-a, a-got+q, got-a+q) + ceil := q / (1 << bits) + if diff > fieldElement(ceil) { + t.Fatalf("decompress(compress(%d, %d), %d) = %d (diff %d, max diff %d)", + a, bits, bits, got, diff, ceil) + } + } + } +} + +func CompressRat(x fieldElement, d uint8) uint16 { + if x >= q { + panic("x out of range") + } + if d <= 0 || d >= 12 { + panic("d out of range") + } + + precise := big.NewRat((1<= 1<= 12 { + panic("d out of range") + } + + precise := big.NewRat(q*int64(y), 1<>7 != 0 { + panic("not 7 bits") + } + var r uint8 + r |= n >> 6 & 0b0000_0001 + r |= n >> 4 & 0b0000_0010 + r |= n >> 2 & 0b0000_0100 + r |= n /**/ & 0b0000_1000 + r |= n << 2 & 0b0001_0000 + r |= n << 4 & 0b0010_0000 + r |= n << 6 & 0b0100_0000 + return r +} + +func TestZetas(t *testing.T) { + ζ := big.NewInt(17) + q := big.NewInt(q) + for k, zeta := range zetas { + // ζ^BitRev7(k) mod q + exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))), q) + if big.NewInt(int64(zeta)).Cmp(exp) != 0 { + t.Errorf("zetas[%d] = %v, expected %v", k, zeta, exp) + } + } +} + +func TestGammas(t *testing.T) { + ζ := big.NewInt(17) + q := big.NewInt(q) + for k, gamma := range gammas { + // ζ^2BitRev7(i)+1 + exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))*2+1), q) + if big.NewInt(int64(gamma)).Cmp(exp) != 0 { + t.Errorf("gammas[%d] = %v, expected %v", k, gamma, exp) + } + } +} diff --git a/src/crypto/internal/mlkem768/mlkem768.go b/src/crypto/internal/mlkem768/mlkem768.go index 0daf359446..45f4b78056 100644 --- a/src/crypto/internal/mlkem768/mlkem768.go +++ b/src/crypto/internal/mlkem768/mlkem768.go @@ -24,7 +24,6 @@ import ( "crypto/rand" "crypto/subtle" "errors" - "internal/byteorder" "golang.org/x/crypto/sha3" ) @@ -375,447 +374,3 @@ func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize]byte) []byte { return ringCompressAndEncode1(nil, w) } - -// fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced. -type fieldElement uint16 - -// fieldCheckReduced checks that a value a is < q. -func fieldCheckReduced(a uint16) (fieldElement, error) { - if a >= q { - return 0, errors.New("unreduced field element") - } - return fieldElement(a), nil -} - -// fieldReduceOnce reduces a value a < 2q. -func fieldReduceOnce(a uint16) fieldElement { - x := a - q - // If x underflowed, then x >= 2¹⁶ - q > 2¹⁵, so the top bit is set. - x += (x >> 15) * q - return fieldElement(x) -} - -func fieldAdd(a, b fieldElement) fieldElement { - x := uint16(a + b) - return fieldReduceOnce(x) -} - -func fieldSub(a, b fieldElement) fieldElement { - x := uint16(a - b + q) - return fieldReduceOnce(x) -} - -const ( - barrettMultiplier = 5039 // 2¹² * 2¹² / q - barrettShift = 24 // log₂(2¹² * 2¹²) -) - -// fieldReduce reduces a value a < 2q² using Barrett reduction, to avoid -// potentially variable-time division. -func fieldReduce(a uint32) fieldElement { - quotient := uint32((uint64(a) * barrettMultiplier) >> barrettShift) - return fieldReduceOnce(uint16(a - quotient*q)) -} - -func fieldMul(a, b fieldElement) fieldElement { - x := uint32(a) * uint32(b) - return fieldReduce(x) -} - -// fieldMulSub returns a * (b - c). This operation is fused to save a -// fieldReduceOnce after the subtraction. -func fieldMulSub(a, b, c fieldElement) fieldElement { - x := uint32(a) * uint32(b-c+q) - return fieldReduce(x) -} - -// fieldAddMul returns a * b + c * d. This operation is fused to save a -// fieldReduceOnce and a fieldReduce. -func fieldAddMul(a, b, c, d fieldElement) fieldElement { - x := uint32(a) * uint32(b) - x += uint32(c) * uint32(d) - return fieldReduce(x) -} - -// compress maps a field element uniformly to the range 0 to 2ᵈ-1, according to -// FIPS 203, Definition 4.7. -func compress(x fieldElement, d uint8) uint16 { - // We want to compute (x * 2ᵈ) / q, rounded to nearest integer, with 1/2 - // rounding up (see FIPS 203, Section 2.3). - - // Barrett reduction produces a quotient and a remainder in the range [0, 2q), - // such that dividend = quotient * q + remainder. - dividend := uint32(x) << d // x * 2ᵈ - quotient := uint32(uint64(dividend) * barrettMultiplier >> barrettShift) - remainder := dividend - quotient*q - - // Since the remainder is in the range [0, 2q), not [0, q), we need to - // portion it into three spans for rounding. - // - // [ 0, q/2 ) -> round to 0 - // [ q/2, q + q/2 ) -> round to 1 - // [ q + q/2, 2q ) -> round to 2 - // - // We can convert that to the following logic: add 1 if remainder > q/2, - // then add 1 again if remainder > q + q/2. - // - // Note that if remainder > x, then ⌊x⌋ - remainder underflows, and the top - // bit of the difference will be set. - quotient += (q/2 - remainder) >> 31 & 1 - quotient += (q + q/2 - remainder) >> 31 & 1 - - // quotient might have overflowed at this point, so reduce it by masking. - var mask uint32 = (1 << d) - 1 - return uint16(quotient & mask) -} - -// decompress maps a number x between 0 and 2ᵈ-1 uniformly to the full range of -// field elements, according to FIPS 203, Definition 4.8. -func decompress(y uint16, d uint8) fieldElement { - // We want to compute (y * q) / 2ᵈ, rounded to nearest integer, with 1/2 - // rounding up (see FIPS 203, Section 2.3). - - dividend := uint32(y) * q - quotient := dividend >> d // (y * q) / 2ᵈ - - // The d'th least-significant bit of the dividend (the most significant bit - // of the remainder) is 1 for the top half of the values that divide to the - // same quotient, which are the ones that round up. - quotient += dividend >> (d - 1) & 1 - - // quotient is at most (2¹¹-1) * q / 2¹¹ + 1 = 3328, so it didn't overflow. - return fieldElement(quotient) -} - -// ringElement is a polynomial, an element of R_q, represented as an array -// according to FIPS 203, Section 2.4.4. -type ringElement [n]fieldElement - -// polyAdd adds two ringElements or nttElements. -func polyAdd[T ~[n]fieldElement](a, b T) (s T) { - for i := range s { - s[i] = fieldAdd(a[i], b[i]) - } - return s -} - -// polySub subtracts two ringElements or nttElements. -func polySub[T ~[n]fieldElement](a, b T) (s T) { - for i := range s { - s[i] = fieldSub(a[i], b[i]) - } - return s -} - -// polyByteEncode appends the 384-byte encoding of f to b. -// -// It implements ByteEncode₁₂, according to FIPS 203, Algorithm 5. -func polyByteEncode[T ~[n]fieldElement](b []byte, f T) []byte { - out, B := sliceForAppend(b, encodingSize12) - for i := 0; i < n; i += 2 { - x := uint32(f[i]) | uint32(f[i+1])<<12 - B[0] = uint8(x) - B[1] = uint8(x >> 8) - B[2] = uint8(x >> 16) - B = B[3:] - } - return out -} - -// polyByteDecode decodes the 384-byte encoding of a polynomial, checking that -// all the coefficients are properly reduced. This fulfills the "Modulus check" -// step of ML-KEM Encapsulation. -// -// It implements ByteDecode₁₂, according to FIPS 203, Algorithm 6. -func polyByteDecode[T ~[n]fieldElement](b []byte) (T, error) { - if len(b) != encodingSize12 { - return T{}, errors.New("mlkem768: invalid encoding length") - } - var f T - for i := 0; i < n; i += 2 { - d := uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 - const mask12 = 0b1111_1111_1111 - var err error - if f[i], err = fieldCheckReduced(uint16(d & mask12)); err != nil { - return T{}, errors.New("mlkem768: invalid polynomial encoding") - } - if f[i+1], err = fieldCheckReduced(uint16(d >> 12)); err != nil { - return T{}, errors.New("mlkem768: invalid polynomial encoding") - } - b = b[3:] - } - return f, nil -} - -// sliceForAppend takes a slice and a requested number of bytes. It returns a -// slice with the contents of the given slice followed by that many bytes and a -// second slice that aliases into it and contains only the extra bytes. If the -// original slice has sufficient capacity then no allocation is performed. -func sliceForAppend(in []byte, n int) (head, tail []byte) { - if total := len(in) + n; cap(in) >= total { - head = in[:total] - } else { - head = make([]byte, total) - copy(head, in) - } - tail = head[len(in):] - return -} - -// ringCompressAndEncode1 appends a 32-byte encoding of a ring element to s, -// compressing one coefficients per bit. -// -// It implements Compress₁, according to FIPS 203, Definition 4.7, -// followed by ByteEncode₁, according to FIPS 203, Algorithm 5. -func ringCompressAndEncode1(s []byte, f ringElement) []byte { - s, b := sliceForAppend(s, encodingSize1) - for i := range b { - b[i] = 0 - } - for i := range f { - b[i/8] |= uint8(compress(f[i], 1) << (i % 8)) - } - return s -} - -// ringDecodeAndDecompress1 decodes a 32-byte slice to a ring element where each -// bit is mapped to 0 or ⌈q/2⌋. -// -// It implements ByteDecode₁, according to FIPS 203, Algorithm 6, -// followed by Decompress₁, according to FIPS 203, Definition 4.8. -func ringDecodeAndDecompress1(b *[encodingSize1]byte) ringElement { - var f ringElement - for i := range f { - b_i := b[i/8] >> (i % 8) & 1 - const halfQ = (q + 1) / 2 // ⌈q/2⌋, rounded up per FIPS 203, Section 2.3 - f[i] = fieldElement(b_i) * halfQ // 0 decompresses to 0, and 1 to ⌈q/2⌋ - } - return f -} - -// ringCompressAndEncode4 appends a 128-byte encoding of a ring element to s, -// compressing two coefficients per byte. -// -// It implements Compress₄, according to FIPS 203, Definition 4.7, -// followed by ByteEncode₄, according to FIPS 203, Algorithm 5. -func ringCompressAndEncode4(s []byte, f ringElement) []byte { - s, b := sliceForAppend(s, encodingSize4) - for i := 0; i < n; i += 2 { - b[i/2] = uint8(compress(f[i], 4) | compress(f[i+1], 4)<<4) - } - return s -} - -// ringDecodeAndDecompress4 decodes a 128-byte encoding of a ring element where -// each four bits are mapped to an equidistant distribution. -// -// It implements ByteDecode₄, according to FIPS 203, Algorithm 6, -// followed by Decompress₄, according to FIPS 203, Definition 4.8. -func ringDecodeAndDecompress4(b *[encodingSize4]byte) ringElement { - var f ringElement - for i := 0; i < n; i += 2 { - f[i] = fieldElement(decompress(uint16(b[i/2]&0b1111), 4)) - f[i+1] = fieldElement(decompress(uint16(b[i/2]>>4), 4)) - } - return f -} - -// ringCompressAndEncode10 appends a 320-byte encoding of a ring element to s, -// compressing four coefficients per five bytes. -// -// It implements Compress₁₀, according to FIPS 203, Definition 4.7, -// followed by ByteEncode₁₀, according to FIPS 203, Algorithm 5. -func ringCompressAndEncode10(s []byte, f ringElement) []byte { - s, b := sliceForAppend(s, encodingSize10) - for i := 0; i < n; i += 4 { - var x uint64 - x |= uint64(compress(f[i+0], 10)) - x |= uint64(compress(f[i+1], 10)) << 10 - x |= uint64(compress(f[i+2], 10)) << 20 - x |= uint64(compress(f[i+3], 10)) << 30 - b[0] = uint8(x) - b[1] = uint8(x >> 8) - b[2] = uint8(x >> 16) - b[3] = uint8(x >> 24) - b[4] = uint8(x >> 32) - b = b[5:] - } - return s -} - -// ringDecodeAndDecompress10 decodes a 320-byte encoding of a ring element where -// each ten bits are mapped to an equidistant distribution. -// -// It implements ByteDecode₁₀, according to FIPS 203, Algorithm 6, -// followed by Decompress₁₀, according to FIPS 203, Definition 4.8. -func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement { - b := bb[:] - var f ringElement - for i := 0; i < n; i += 4 { - x := uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 - b = b[5:] - f[i] = fieldElement(decompress(uint16(x>>0&0b11_1111_1111), 10)) - f[i+1] = fieldElement(decompress(uint16(x>>10&0b11_1111_1111), 10)) - f[i+2] = fieldElement(decompress(uint16(x>>20&0b11_1111_1111), 10)) - f[i+3] = fieldElement(decompress(uint16(x>>30&0b11_1111_1111), 10)) - } - return f -} - -// samplePolyCBD draws a ringElement from the special Dη distribution given a -// stream of random bytes generated by the PRF function, according to FIPS 203, -// Algorithm 8 and Definition 4.3. -func samplePolyCBD(s []byte, b byte) ringElement { - prf := sha3.NewShake256() - prf.Write(s) - prf.Write([]byte{b}) - B := make([]byte, 64*η) - prf.Read(B) - - // SamplePolyCBD simply draws four (2η) bits for each coefficient, and adds - // the first two and subtracts the last two. - - var f ringElement - for i := 0; i < n; i += 2 { - b := B[i/2] - b_7, b_6, b_5, b_4 := b>>7, b>>6&1, b>>5&1, b>>4&1 - b_3, b_2, b_1, b_0 := b>>3&1, b>>2&1, b>>1&1, b&1 - f[i] = fieldSub(fieldElement(b_0+b_1), fieldElement(b_2+b_3)) - f[i+1] = fieldSub(fieldElement(b_4+b_5), fieldElement(b_6+b_7)) - } - return f -} - -// nttElement is an NTT representation, an element of T_q, represented as an -// array according to FIPS 203, Section 2.4.4. -type nttElement [n]fieldElement - -// gammas are the values ζ^2BitRev7(i)+1 mod q for each index i, according to -// FIPS 203, Appendix A (with negative values reduced to positive). -var gammas = [128]fieldElement{17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175} - -// nttMul multiplies two nttElements. -// -// It implements MultiplyNTTs, according to FIPS 203, Algorithm 11. -func nttMul(f, g nttElement) nttElement { - var h nttElement - // We use i += 2 for bounds check elimination. See https://go.dev/issue/66826. - for i := 0; i < 256; i += 2 { - a0, a1 := f[i], f[i+1] - b0, b1 := g[i], g[i+1] - h[i] = fieldAddMul(a0, b0, fieldMul(a1, b1), gammas[i/2]) - h[i+1] = fieldAddMul(a0, b1, a1, b0) - } - return h -} - -// zetas are the values ζ^BitRev7(k) mod q for each index k, according to FIPS -// 203, Appendix A. -var zetas = [128]fieldElement{1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154} - -// ntt maps a ringElement to its nttElement representation. -// -// It implements NTT, according to FIPS 203, Algorithm 9. -func ntt(f ringElement) nttElement { - k := 1 - for len := 128; len >= 2; len /= 2 { - for start := 0; start < 256; start += 2 * len { - zeta := zetas[k] - k++ - // Bounds check elimination hint. - f, flen := f[start:start+len], f[start+len:start+len+len] - for j := 0; j < len; j++ { - t := fieldMul(zeta, flen[j]) - flen[j] = fieldSub(f[j], t) - f[j] = fieldAdd(f[j], t) - } - } - } - return nttElement(f) -} - -// inverseNTT maps a nttElement back to the ringElement it represents. -// -// It implements NTT⁻¹, according to FIPS 203, Algorithm 10. -func inverseNTT(f nttElement) ringElement { - k := 127 - for len := 2; len <= 128; len *= 2 { - for start := 0; start < 256; start += 2 * len { - zeta := zetas[k] - k-- - // Bounds check elimination hint. - f, flen := f[start:start+len], f[start+len:start+len+len] - for j := 0; j < len; j++ { - t := f[j] - f[j] = fieldAdd(t, flen[j]) - flen[j] = fieldMulSub(zeta, flen[j], t) - } - } - } - for i := range f { - f[i] = fieldMul(f[i], 3303) // 3303 = 128⁻¹ mod q - } - return ringElement(f) -} - -// sampleNTT draws a uniformly random nttElement from a stream of uniformly -// random bytes generated by the XOF function, according to FIPS 203, -// Algorithm 7. -func sampleNTT(rho []byte, ii, jj byte) nttElement { - B := sha3.NewShake128() - B.Write(rho) - B.Write([]byte{ii, jj}) - - // SampleNTT essentially draws 12 bits at a time from r, interprets them in - // little-endian, and rejects values higher than q, until it drew 256 - // values. (The rejection rate is approximately 19%.) - // - // To do this from a bytes stream, it draws three bytes at a time, and - // splits them into two uint16 appropriately masked. - // - // r₀ r₁ r₂ - // |- - - - - - - -|- - - - - - - -|- - - - - - - -| - // - // Uint16(r₀ || r₁) - // |- - - - - - - - - - - - - - - -| - // |- - - - - - - - - - - -| - // d₁ - // - // Uint16(r₁ || r₂) - // |- - - - - - - - - - - - - - - -| - // |- - - - - - - - - - - -| - // d₂ - // - // Note that in little-endian, the rightmost bits are the most significant - // bits (dropped with a mask) and the leftmost bits are the least - // significant bits (dropped with a right shift). - - var a nttElement - var j int // index into a - var buf [24]byte // buffered reads from B - off := len(buf) // index into buf, starts in a "buffer fully consumed" state - for { - if off >= len(buf) { - B.Read(buf[:]) - off = 0 - } - d1 := byteorder.LeUint16(buf[off:]) & 0b1111_1111_1111 - d2 := byteorder.LeUint16(buf[off+1:]) >> 4 - off += 3 - if d1 < q { - a[j] = fieldElement(d1) - j++ - } - if j >= len(a) { - break - } - if d2 < q { - a[j] = fieldElement(d2) - j++ - } - if j >= len(a) { - break - } - } - return a -} diff --git a/src/crypto/internal/mlkem768/mlkem768_test.go b/src/crypto/internal/mlkem768/mlkem768_test.go index 5d129e11df..4775f77aeb 100644 --- a/src/crypto/internal/mlkem768/mlkem768_test.go +++ b/src/crypto/internal/mlkem768/mlkem768_test.go @@ -10,193 +10,11 @@ import ( _ "embed" "encoding/hex" "flag" - "math/big" - "strconv" "testing" "golang.org/x/crypto/sha3" ) -func TestFieldReduce(t *testing.T) { - for a := uint32(0); a < 2*q*q; a++ { - got := fieldReduce(a) - exp := fieldElement(a % q) - if got != exp { - t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp) - } - } -} - -func TestFieldAdd(t *testing.T) { - for a := fieldElement(0); a < q; a++ { - for b := fieldElement(0); b < q; b++ { - got := fieldAdd(a, b) - exp := (a + b) % q - if got != exp { - t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp) - } - } - } -} - -func TestFieldSub(t *testing.T) { - for a := fieldElement(0); a < q; a++ { - for b := fieldElement(0); b < q; b++ { - got := fieldSub(a, b) - exp := (a - b + q) % q - if got != exp { - t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp) - } - } - } -} - -func TestFieldMul(t *testing.T) { - for a := fieldElement(0); a < q; a++ { - for b := fieldElement(0); b < q; b++ { - got := fieldMul(a, b) - exp := fieldElement((uint32(a) * uint32(b)) % q) - if got != exp { - t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp) - } - } - } -} - -func TestDecompressCompress(t *testing.T) { - for _, bits := range []uint8{1, 4, 10} { - for a := uint16(0); a < 1<= q { - t.Fatalf("decompress(%d, %d) = %d >= q", a, bits, f) - } - got := compress(f, bits) - if got != a { - t.Fatalf("compress(decompress(%d, %d), %d) = %d", a, bits, bits, got) - } - } - - for a := fieldElement(0); a < q; a++ { - c := compress(a, bits) - if c >= 1<= 2^bits", a, bits, c) - } - got := decompress(c, bits) - diff := min(a-got, got-a, a-got+q, got-a+q) - ceil := q / (1 << bits) - if diff > fieldElement(ceil) { - t.Fatalf("decompress(compress(%d, %d), %d) = %d (diff %d, max diff %d)", - a, bits, bits, got, diff, ceil) - } - } - } -} - -func CompressRat(x fieldElement, d uint8) uint16 { - if x >= q { - panic("x out of range") - } - if d <= 0 || d >= 12 { - panic("d out of range") - } - - precise := big.NewRat((1<= 1<= 12 { - panic("d out of range") - } - - precise := big.NewRat(q*int64(y), 1<>7 != 0 { - panic("not 7 bits") - } - var r uint8 - r |= n >> 6 & 0b0000_0001 - r |= n >> 4 & 0b0000_0010 - r |= n >> 2 & 0b0000_0100 - r |= n /**/ & 0b0000_1000 - r |= n << 2 & 0b0001_0000 - r |= n << 4 & 0b0010_0000 - r |= n << 6 & 0b0100_0000 - return r -} - -func TestZetas(t *testing.T) { - ζ := big.NewInt(17) - q := big.NewInt(q) - for k, zeta := range zetas { - // ζ^BitRev7(k) mod q - exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))), q) - if big.NewInt(int64(zeta)).Cmp(exp) != 0 { - t.Errorf("zetas[%d] = %v, expected %v", k, zeta, exp) - } - } -} - -func TestGammas(t *testing.T) { - ζ := big.NewInt(17) - q := big.NewInt(q) - for k, gamma := range gammas { - // ζ^2BitRev7(i)+1 - exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))*2+1), q) - if big.NewInt(int64(gamma)).Cmp(exp) != 0 { - t.Errorf("gammas[%d] = %v, expected %v", k, gamma, exp) - } - } -} - func TestRoundTrip(t *testing.T) { dk, err := GenerateKey() if err != nil { -- 2.48.1