--- /dev/null
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package mlkem768
+
+import (
+ "errors"
+ "internal/byteorder"
+
+ "golang.org/x/crypto/sha3"
+)
+
+// fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced.
+type fieldElement uint16
+
+// fieldCheckReduced checks that a value a is < q.
+func fieldCheckReduced(a uint16) (fieldElement, error) {
+ if a >= q {
+ return 0, errors.New("unreduced field element")
+ }
+ return fieldElement(a), nil
+}
+
+// fieldReduceOnce reduces a value a < 2q.
+func fieldReduceOnce(a uint16) fieldElement {
+ x := a - q
+ // If x underflowed, then x >= 2¹⁶ - q > 2¹⁵, so the top bit is set.
+ x += (x >> 15) * q
+ return fieldElement(x)
+}
+
+func fieldAdd(a, b fieldElement) fieldElement {
+ x := uint16(a + b)
+ return fieldReduceOnce(x)
+}
+
+func fieldSub(a, b fieldElement) fieldElement {
+ x := uint16(a - b + q)
+ return fieldReduceOnce(x)
+}
+
+const (
+ barrettMultiplier = 5039 // 2¹² * 2¹² / q
+ barrettShift = 24 // log₂(2¹² * 2¹²)
+)
+
+// fieldReduce reduces a value a < 2q² using Barrett reduction, to avoid
+// potentially variable-time division.
+func fieldReduce(a uint32) fieldElement {
+ quotient := uint32((uint64(a) * barrettMultiplier) >> barrettShift)
+ return fieldReduceOnce(uint16(a - quotient*q))
+}
+
+func fieldMul(a, b fieldElement) fieldElement {
+ x := uint32(a) * uint32(b)
+ return fieldReduce(x)
+}
+
+// fieldMulSub returns a * (b - c). This operation is fused to save a
+// fieldReduceOnce after the subtraction.
+func fieldMulSub(a, b, c fieldElement) fieldElement {
+ x := uint32(a) * uint32(b-c+q)
+ return fieldReduce(x)
+}
+
+// fieldAddMul returns a * b + c * d. This operation is fused to save a
+// fieldReduceOnce and a fieldReduce.
+func fieldAddMul(a, b, c, d fieldElement) fieldElement {
+ x := uint32(a) * uint32(b)
+ x += uint32(c) * uint32(d)
+ return fieldReduce(x)
+}
+
+// compress maps a field element uniformly to the range 0 to 2ᵈ-1, according to
+// FIPS 203, Definition 4.7.
+func compress(x fieldElement, d uint8) uint16 {
+ // We want to compute (x * 2ᵈ) / q, rounded to nearest integer, with 1/2
+ // rounding up (see FIPS 203, Section 2.3).
+
+ // Barrett reduction produces a quotient and a remainder in the range [0, 2q),
+ // such that dividend = quotient * q + remainder.
+ dividend := uint32(x) << d // x * 2ᵈ
+ quotient := uint32(uint64(dividend) * barrettMultiplier >> barrettShift)
+ remainder := dividend - quotient*q
+
+ // Since the remainder is in the range [0, 2q), not [0, q), we need to
+ // portion it into three spans for rounding.
+ //
+ // [ 0, q/2 ) -> round to 0
+ // [ q/2, q + q/2 ) -> round to 1
+ // [ q + q/2, 2q ) -> round to 2
+ //
+ // We can convert that to the following logic: add 1 if remainder > q/2,
+ // then add 1 again if remainder > q + q/2.
+ //
+ // Note that if remainder > x, then ⌊x⌋ - remainder underflows, and the top
+ // bit of the difference will be set.
+ quotient += (q/2 - remainder) >> 31 & 1
+ quotient += (q + q/2 - remainder) >> 31 & 1
+
+ // quotient might have overflowed at this point, so reduce it by masking.
+ var mask uint32 = (1 << d) - 1
+ return uint16(quotient & mask)
+}
+
+// decompress maps a number x between 0 and 2ᵈ-1 uniformly to the full range of
+// field elements, according to FIPS 203, Definition 4.8.
+func decompress(y uint16, d uint8) fieldElement {
+ // We want to compute (y * q) / 2ᵈ, rounded to nearest integer, with 1/2
+ // rounding up (see FIPS 203, Section 2.3).
+
+ dividend := uint32(y) * q
+ quotient := dividend >> d // (y * q) / 2ᵈ
+
+ // The d'th least-significant bit of the dividend (the most significant bit
+ // of the remainder) is 1 for the top half of the values that divide to the
+ // same quotient, which are the ones that round up.
+ quotient += dividend >> (d - 1) & 1
+
+ // quotient is at most (2¹¹-1) * q / 2¹¹ + 1 = 3328, so it didn't overflow.
+ return fieldElement(quotient)
+}
+
+// ringElement is a polynomial, an element of R_q, represented as an array
+// according to FIPS 203, Section 2.4.4.
+type ringElement [n]fieldElement
+
+// polyAdd adds two ringElements or nttElements.
+func polyAdd[T ~[n]fieldElement](a, b T) (s T) {
+ for i := range s {
+ s[i] = fieldAdd(a[i], b[i])
+ }
+ return s
+}
+
+// polySub subtracts two ringElements or nttElements.
+func polySub[T ~[n]fieldElement](a, b T) (s T) {
+ for i := range s {
+ s[i] = fieldSub(a[i], b[i])
+ }
+ return s
+}
+
+// polyByteEncode appends the 384-byte encoding of f to b.
+//
+// It implements ByteEncode₁₂, according to FIPS 203, Algorithm 5.
+func polyByteEncode[T ~[n]fieldElement](b []byte, f T) []byte {
+ out, B := sliceForAppend(b, encodingSize12)
+ for i := 0; i < n; i += 2 {
+ x := uint32(f[i]) | uint32(f[i+1])<<12
+ B[0] = uint8(x)
+ B[1] = uint8(x >> 8)
+ B[2] = uint8(x >> 16)
+ B = B[3:]
+ }
+ return out
+}
+
+// polyByteDecode decodes the 384-byte encoding of a polynomial, checking that
+// all the coefficients are properly reduced. This fulfills the "Modulus check"
+// step of ML-KEM Encapsulation.
+//
+// It implements ByteDecode₁₂, according to FIPS 203, Algorithm 6.
+func polyByteDecode[T ~[n]fieldElement](b []byte) (T, error) {
+ if len(b) != encodingSize12 {
+ return T{}, errors.New("mlkem768: invalid encoding length")
+ }
+ var f T
+ for i := 0; i < n; i += 2 {
+ d := uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16
+ const mask12 = 0b1111_1111_1111
+ var err error
+ if f[i], err = fieldCheckReduced(uint16(d & mask12)); err != nil {
+ return T{}, errors.New("mlkem768: invalid polynomial encoding")
+ }
+ if f[i+1], err = fieldCheckReduced(uint16(d >> 12)); err != nil {
+ return T{}, errors.New("mlkem768: invalid polynomial encoding")
+ }
+ b = b[3:]
+ }
+ return f, nil
+}
+
+// sliceForAppend takes a slice and a requested number of bytes. It returns a
+// slice with the contents of the given slice followed by that many bytes and a
+// second slice that aliases into it and contains only the extra bytes. If the
+// original slice has sufficient capacity then no allocation is performed.
+func sliceForAppend(in []byte, n int) (head, tail []byte) {
+ if total := len(in) + n; cap(in) >= total {
+ head = in[:total]
+ } else {
+ head = make([]byte, total)
+ copy(head, in)
+ }
+ tail = head[len(in):]
+ return
+}
+
+// ringCompressAndEncode1 appends a 32-byte encoding of a ring element to s,
+// compressing one coefficients per bit.
+//
+// It implements Compress₁, according to FIPS 203, Definition 4.7,
+// followed by ByteEncode₁, according to FIPS 203, Algorithm 5.
+func ringCompressAndEncode1(s []byte, f ringElement) []byte {
+ s, b := sliceForAppend(s, encodingSize1)
+ for i := range b {
+ b[i] = 0
+ }
+ for i := range f {
+ b[i/8] |= uint8(compress(f[i], 1) << (i % 8))
+ }
+ return s
+}
+
+// ringDecodeAndDecompress1 decodes a 32-byte slice to a ring element where each
+// bit is mapped to 0 or ⌈q/2⌋.
+//
+// It implements ByteDecode₁, according to FIPS 203, Algorithm 6,
+// followed by Decompress₁, according to FIPS 203, Definition 4.8.
+func ringDecodeAndDecompress1(b *[encodingSize1]byte) ringElement {
+ var f ringElement
+ for i := range f {
+ b_i := b[i/8] >> (i % 8) & 1
+ const halfQ = (q + 1) / 2 // ⌈q/2⌋, rounded up per FIPS 203, Section 2.3
+ f[i] = fieldElement(b_i) * halfQ // 0 decompresses to 0, and 1 to ⌈q/2⌋
+ }
+ return f
+}
+
+// ringCompressAndEncode4 appends a 128-byte encoding of a ring element to s,
+// compressing two coefficients per byte.
+//
+// It implements Compress₄, according to FIPS 203, Definition 4.7,
+// followed by ByteEncode₄, according to FIPS 203, Algorithm 5.
+func ringCompressAndEncode4(s []byte, f ringElement) []byte {
+ s, b := sliceForAppend(s, encodingSize4)
+ for i := 0; i < n; i += 2 {
+ b[i/2] = uint8(compress(f[i], 4) | compress(f[i+1], 4)<<4)
+ }
+ return s
+}
+
+// ringDecodeAndDecompress4 decodes a 128-byte encoding of a ring element where
+// each four bits are mapped to an equidistant distribution.
+//
+// It implements ByteDecode₄, according to FIPS 203, Algorithm 6,
+// followed by Decompress₄, according to FIPS 203, Definition 4.8.
+func ringDecodeAndDecompress4(b *[encodingSize4]byte) ringElement {
+ var f ringElement
+ for i := 0; i < n; i += 2 {
+ f[i] = fieldElement(decompress(uint16(b[i/2]&0b1111), 4))
+ f[i+1] = fieldElement(decompress(uint16(b[i/2]>>4), 4))
+ }
+ return f
+}
+
+// ringCompressAndEncode10 appends a 320-byte encoding of a ring element to s,
+// compressing four coefficients per five bytes.
+//
+// It implements Compress₁₀, according to FIPS 203, Definition 4.7,
+// followed by ByteEncode₁₀, according to FIPS 203, Algorithm 5.
+func ringCompressAndEncode10(s []byte, f ringElement) []byte {
+ s, b := sliceForAppend(s, encodingSize10)
+ for i := 0; i < n; i += 4 {
+ var x uint64
+ x |= uint64(compress(f[i+0], 10))
+ x |= uint64(compress(f[i+1], 10)) << 10
+ x |= uint64(compress(f[i+2], 10)) << 20
+ x |= uint64(compress(f[i+3], 10)) << 30
+ b[0] = uint8(x)
+ b[1] = uint8(x >> 8)
+ b[2] = uint8(x >> 16)
+ b[3] = uint8(x >> 24)
+ b[4] = uint8(x >> 32)
+ b = b[5:]
+ }
+ return s
+}
+
+// ringDecodeAndDecompress10 decodes a 320-byte encoding of a ring element where
+// each ten bits are mapped to an equidistant distribution.
+//
+// It implements ByteDecode₁₀, according to FIPS 203, Algorithm 6,
+// followed by Decompress₁₀, according to FIPS 203, Definition 4.8.
+func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement {
+ b := bb[:]
+ var f ringElement
+ for i := 0; i < n; i += 4 {
+ x := uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32
+ b = b[5:]
+ f[i] = fieldElement(decompress(uint16(x>>0&0b11_1111_1111), 10))
+ f[i+1] = fieldElement(decompress(uint16(x>>10&0b11_1111_1111), 10))
+ f[i+2] = fieldElement(decompress(uint16(x>>20&0b11_1111_1111), 10))
+ f[i+3] = fieldElement(decompress(uint16(x>>30&0b11_1111_1111), 10))
+ }
+ return f
+}
+
+// samplePolyCBD draws a ringElement from the special Dη distribution given a
+// stream of random bytes generated by the PRF function, according to FIPS 203,
+// Algorithm 8 and Definition 4.3.
+func samplePolyCBD(s []byte, b byte) ringElement {
+ prf := sha3.NewShake256()
+ prf.Write(s)
+ prf.Write([]byte{b})
+ B := make([]byte, 64*2) // η = 2
+ prf.Read(B)
+
+ // SamplePolyCBD simply draws four (2η) bits for each coefficient, and adds
+ // the first two and subtracts the last two.
+
+ var f ringElement
+ for i := 0; i < n; i += 2 {
+ b := B[i/2]
+ b_7, b_6, b_5, b_4 := b>>7, b>>6&1, b>>5&1, b>>4&1
+ b_3, b_2, b_1, b_0 := b>>3&1, b>>2&1, b>>1&1, b&1
+ f[i] = fieldSub(fieldElement(b_0+b_1), fieldElement(b_2+b_3))
+ f[i+1] = fieldSub(fieldElement(b_4+b_5), fieldElement(b_6+b_7))
+ }
+ return f
+}
+
+// nttElement is an NTT representation, an element of T_q, represented as an
+// array according to FIPS 203, Section 2.4.4.
+type nttElement [n]fieldElement
+
+// gammas are the values ζ^2BitRev7(i)+1 mod q for each index i, according to
+// FIPS 203, Appendix A (with negative values reduced to positive).
+var gammas = [128]fieldElement{17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175}
+
+// nttMul multiplies two nttElements.
+//
+// It implements MultiplyNTTs, according to FIPS 203, Algorithm 11.
+func nttMul(f, g nttElement) nttElement {
+ var h nttElement
+ // We use i += 2 for bounds check elimination. See https://go.dev/issue/66826.
+ for i := 0; i < 256; i += 2 {
+ a0, a1 := f[i], f[i+1]
+ b0, b1 := g[i], g[i+1]
+ h[i] = fieldAddMul(a0, b0, fieldMul(a1, b1), gammas[i/2])
+ h[i+1] = fieldAddMul(a0, b1, a1, b0)
+ }
+ return h
+}
+
+// zetas are the values ζ^BitRev7(k) mod q for each index k, according to FIPS
+// 203, Appendix A.
+var zetas = [128]fieldElement{1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154}
+
+// ntt maps a ringElement to its nttElement representation.
+//
+// It implements NTT, according to FIPS 203, Algorithm 9.
+func ntt(f ringElement) nttElement {
+ k := 1
+ for len := 128; len >= 2; len /= 2 {
+ for start := 0; start < 256; start += 2 * len {
+ zeta := zetas[k]
+ k++
+ // Bounds check elimination hint.
+ f, flen := f[start:start+len], f[start+len:start+len+len]
+ for j := 0; j < len; j++ {
+ t := fieldMul(zeta, flen[j])
+ flen[j] = fieldSub(f[j], t)
+ f[j] = fieldAdd(f[j], t)
+ }
+ }
+ }
+ return nttElement(f)
+}
+
+// inverseNTT maps a nttElement back to the ringElement it represents.
+//
+// It implements NTT⁻¹, according to FIPS 203, Algorithm 10.
+func inverseNTT(f nttElement) ringElement {
+ k := 127
+ for len := 2; len <= 128; len *= 2 {
+ for start := 0; start < 256; start += 2 * len {
+ zeta := zetas[k]
+ k--
+ // Bounds check elimination hint.
+ f, flen := f[start:start+len], f[start+len:start+len+len]
+ for j := 0; j < len; j++ {
+ t := f[j]
+ f[j] = fieldAdd(t, flen[j])
+ flen[j] = fieldMulSub(zeta, flen[j], t)
+ }
+ }
+ }
+ for i := range f {
+ f[i] = fieldMul(f[i], 3303) // 3303 = 128⁻¹ mod q
+ }
+ return ringElement(f)
+}
+
+// sampleNTT draws a uniformly random nttElement from a stream of uniformly
+// random bytes generated by the XOF function, according to FIPS 203,
+// Algorithm 7.
+func sampleNTT(rho []byte, ii, jj byte) nttElement {
+ B := sha3.NewShake128()
+ B.Write(rho)
+ B.Write([]byte{ii, jj})
+
+ // SampleNTT essentially draws 12 bits at a time from r, interprets them in
+ // little-endian, and rejects values higher than q, until it drew 256
+ // values. (The rejection rate is approximately 19%.)
+ //
+ // To do this from a bytes stream, it draws three bytes at a time, and
+ // splits them into two uint16 appropriately masked.
+ //
+ // r₀ r₁ r₂
+ // |- - - - - - - -|- - - - - - - -|- - - - - - - -|
+ //
+ // Uint16(r₀ || r₁)
+ // |- - - - - - - - - - - - - - - -|
+ // |- - - - - - - - - - - -|
+ // d₁
+ //
+ // Uint16(r₁ || r₂)
+ // |- - - - - - - - - - - - - - - -|
+ // |- - - - - - - - - - - -|
+ // d₂
+ //
+ // Note that in little-endian, the rightmost bits are the most significant
+ // bits (dropped with a mask) and the leftmost bits are the least
+ // significant bits (dropped with a right shift).
+
+ var a nttElement
+ var j int // index into a
+ var buf [24]byte // buffered reads from B
+ off := len(buf) // index into buf, starts in a "buffer fully consumed" state
+ for {
+ if off >= len(buf) {
+ B.Read(buf[:])
+ off = 0
+ }
+ d1 := byteorder.LeUint16(buf[off:]) & 0b1111_1111_1111
+ d2 := byteorder.LeUint16(buf[off+1:]) >> 4
+ off += 3
+ if d1 < q {
+ a[j] = fieldElement(d1)
+ j++
+ }
+ if j >= len(a) {
+ break
+ }
+ if d2 < q {
+ a[j] = fieldElement(d2)
+ j++
+ }
+ if j >= len(a) {
+ break
+ }
+ }
+ return a
+}
--- /dev/null
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package mlkem768
+
+import (
+ "math/big"
+ "strconv"
+ "testing"
+)
+
+func TestFieldReduce(t *testing.T) {
+ for a := uint32(0); a < 2*q*q; a++ {
+ got := fieldReduce(a)
+ exp := fieldElement(a % q)
+ if got != exp {
+ t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp)
+ }
+ }
+}
+
+func TestFieldAdd(t *testing.T) {
+ for a := fieldElement(0); a < q; a++ {
+ for b := fieldElement(0); b < q; b++ {
+ got := fieldAdd(a, b)
+ exp := (a + b) % q
+ if got != exp {
+ t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
+ }
+ }
+ }
+}
+
+func TestFieldSub(t *testing.T) {
+ for a := fieldElement(0); a < q; a++ {
+ for b := fieldElement(0); b < q; b++ {
+ got := fieldSub(a, b)
+ exp := (a - b + q) % q
+ if got != exp {
+ t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
+ }
+ }
+ }
+}
+
+func TestFieldMul(t *testing.T) {
+ for a := fieldElement(0); a < q; a++ {
+ for b := fieldElement(0); b < q; b++ {
+ got := fieldMul(a, b)
+ exp := fieldElement((uint32(a) * uint32(b)) % q)
+ if got != exp {
+ t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
+ }
+ }
+ }
+}
+
+func TestDecompressCompress(t *testing.T) {
+ for _, bits := range []uint8{1, 4, 10} {
+ for a := uint16(0); a < 1<<bits; a++ {
+ f := decompress(a, bits)
+ if f >= q {
+ t.Fatalf("decompress(%d, %d) = %d >= q", a, bits, f)
+ }
+ got := compress(f, bits)
+ if got != a {
+ t.Fatalf("compress(decompress(%d, %d), %d) = %d", a, bits, bits, got)
+ }
+ }
+
+ for a := fieldElement(0); a < q; a++ {
+ c := compress(a, bits)
+ if c >= 1<<bits {
+ t.Fatalf("compress(%d, %d) = %d >= 2^bits", a, bits, c)
+ }
+ got := decompress(c, bits)
+ diff := min(a-got, got-a, a-got+q, got-a+q)
+ ceil := q / (1 << bits)
+ if diff > fieldElement(ceil) {
+ t.Fatalf("decompress(compress(%d, %d), %d) = %d (diff %d, max diff %d)",
+ a, bits, bits, got, diff, ceil)
+ }
+ }
+ }
+}
+
+func CompressRat(x fieldElement, d uint8) uint16 {
+ if x >= q {
+ panic("x out of range")
+ }
+ if d <= 0 || d >= 12 {
+ panic("d out of range")
+ }
+
+ precise := big.NewRat((1<<d)*int64(x), q) // (2ᵈ / q) * x == (2ᵈ * x) / q
+
+ // FloatString rounds halves away from 0, and our result should always be positive,
+ // so it should work as we expect. (There's no direct way to round a Rat.)
+ rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
+ if err != nil {
+ panic(err)
+ }
+
+ // If we rounded up, `rounded` may be equal to 2ᵈ, so we perform a final reduction.
+ return uint16(rounded % (1 << d))
+}
+
+func TestCompress(t *testing.T) {
+ for d := 1; d < 12; d++ {
+ for n := 0; n < q; n++ {
+ expected := CompressRat(fieldElement(n), uint8(d))
+ result := compress(fieldElement(n), uint8(d))
+ if result != expected {
+ t.Errorf("compress(%d, %d): got %d, expected %d", n, d, result, expected)
+ }
+ }
+ }
+}
+
+func DecompressRat(y uint16, d uint8) fieldElement {
+ if y >= 1<<d {
+ panic("y out of range")
+ }
+ if d <= 0 || d >= 12 {
+ panic("d out of range")
+ }
+
+ precise := big.NewRat(q*int64(y), 1<<d) // (q / 2ᵈ) * y == (q * y) / 2ᵈ
+
+ // FloatString rounds halves away from 0, and our result should always be positive,
+ // so it should work as we expect. (There's no direct way to round a Rat.)
+ rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
+ if err != nil {
+ panic(err)
+ }
+
+ // If we rounded up, `rounded` may be equal to q, so we perform a final reduction.
+ return fieldElement(rounded % q)
+}
+
+func TestDecompress(t *testing.T) {
+ for d := 1; d < 12; d++ {
+ for n := 0; n < (1 << d); n++ {
+ expected := DecompressRat(uint16(n), uint8(d))
+ result := decompress(uint16(n), uint8(d))
+ if result != expected {
+ t.Errorf("decompress(%d, %d): got %d, expected %d", n, d, result, expected)
+ }
+ }
+ }
+}
+
+func BitRev7(n uint8) uint8 {
+ if n>>7 != 0 {
+ panic("not 7 bits")
+ }
+ var r uint8
+ r |= n >> 6 & 0b0000_0001
+ r |= n >> 4 & 0b0000_0010
+ r |= n >> 2 & 0b0000_0100
+ r |= n /**/ & 0b0000_1000
+ r |= n << 2 & 0b0001_0000
+ r |= n << 4 & 0b0010_0000
+ r |= n << 6 & 0b0100_0000
+ return r
+}
+
+func TestZetas(t *testing.T) {
+ ζ := big.NewInt(17)
+ q := big.NewInt(q)
+ for k, zeta := range zetas {
+ // ζ^BitRev7(k) mod q
+ exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))), q)
+ if big.NewInt(int64(zeta)).Cmp(exp) != 0 {
+ t.Errorf("zetas[%d] = %v, expected %v", k, zeta, exp)
+ }
+ }
+}
+
+func TestGammas(t *testing.T) {
+ ζ := big.NewInt(17)
+ q := big.NewInt(q)
+ for k, gamma := range gammas {
+ // ζ^2BitRev7(i)+1
+ exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))*2+1), q)
+ if big.NewInt(int64(gamma)).Cmp(exp) != 0 {
+ t.Errorf("gammas[%d] = %v, expected %v", k, gamma, exp)
+ }
+ }
+}
"crypto/rand"
"crypto/subtle"
"errors"
- "internal/byteorder"
"golang.org/x/crypto/sha3"
)
return ringCompressAndEncode1(nil, w)
}
-
-// fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced.
-type fieldElement uint16
-
-// fieldCheckReduced checks that a value a is < q.
-func fieldCheckReduced(a uint16) (fieldElement, error) {
- if a >= q {
- return 0, errors.New("unreduced field element")
- }
- return fieldElement(a), nil
-}
-
-// fieldReduceOnce reduces a value a < 2q.
-func fieldReduceOnce(a uint16) fieldElement {
- x := a - q
- // If x underflowed, then x >= 2¹⁶ - q > 2¹⁵, so the top bit is set.
- x += (x >> 15) * q
- return fieldElement(x)
-}
-
-func fieldAdd(a, b fieldElement) fieldElement {
- x := uint16(a + b)
- return fieldReduceOnce(x)
-}
-
-func fieldSub(a, b fieldElement) fieldElement {
- x := uint16(a - b + q)
- return fieldReduceOnce(x)
-}
-
-const (
- barrettMultiplier = 5039 // 2¹² * 2¹² / q
- barrettShift = 24 // log₂(2¹² * 2¹²)
-)
-
-// fieldReduce reduces a value a < 2q² using Barrett reduction, to avoid
-// potentially variable-time division.
-func fieldReduce(a uint32) fieldElement {
- quotient := uint32((uint64(a) * barrettMultiplier) >> barrettShift)
- return fieldReduceOnce(uint16(a - quotient*q))
-}
-
-func fieldMul(a, b fieldElement) fieldElement {
- x := uint32(a) * uint32(b)
- return fieldReduce(x)
-}
-
-// fieldMulSub returns a * (b - c). This operation is fused to save a
-// fieldReduceOnce after the subtraction.
-func fieldMulSub(a, b, c fieldElement) fieldElement {
- x := uint32(a) * uint32(b-c+q)
- return fieldReduce(x)
-}
-
-// fieldAddMul returns a * b + c * d. This operation is fused to save a
-// fieldReduceOnce and a fieldReduce.
-func fieldAddMul(a, b, c, d fieldElement) fieldElement {
- x := uint32(a) * uint32(b)
- x += uint32(c) * uint32(d)
- return fieldReduce(x)
-}
-
-// compress maps a field element uniformly to the range 0 to 2ᵈ-1, according to
-// FIPS 203, Definition 4.7.
-func compress(x fieldElement, d uint8) uint16 {
- // We want to compute (x * 2ᵈ) / q, rounded to nearest integer, with 1/2
- // rounding up (see FIPS 203, Section 2.3).
-
- // Barrett reduction produces a quotient and a remainder in the range [0, 2q),
- // such that dividend = quotient * q + remainder.
- dividend := uint32(x) << d // x * 2ᵈ
- quotient := uint32(uint64(dividend) * barrettMultiplier >> barrettShift)
- remainder := dividend - quotient*q
-
- // Since the remainder is in the range [0, 2q), not [0, q), we need to
- // portion it into three spans for rounding.
- //
- // [ 0, q/2 ) -> round to 0
- // [ q/2, q + q/2 ) -> round to 1
- // [ q + q/2, 2q ) -> round to 2
- //
- // We can convert that to the following logic: add 1 if remainder > q/2,
- // then add 1 again if remainder > q + q/2.
- //
- // Note that if remainder > x, then ⌊x⌋ - remainder underflows, and the top
- // bit of the difference will be set.
- quotient += (q/2 - remainder) >> 31 & 1
- quotient += (q + q/2 - remainder) >> 31 & 1
-
- // quotient might have overflowed at this point, so reduce it by masking.
- var mask uint32 = (1 << d) - 1
- return uint16(quotient & mask)
-}
-
-// decompress maps a number x between 0 and 2ᵈ-1 uniformly to the full range of
-// field elements, according to FIPS 203, Definition 4.8.
-func decompress(y uint16, d uint8) fieldElement {
- // We want to compute (y * q) / 2ᵈ, rounded to nearest integer, with 1/2
- // rounding up (see FIPS 203, Section 2.3).
-
- dividend := uint32(y) * q
- quotient := dividend >> d // (y * q) / 2ᵈ
-
- // The d'th least-significant bit of the dividend (the most significant bit
- // of the remainder) is 1 for the top half of the values that divide to the
- // same quotient, which are the ones that round up.
- quotient += dividend >> (d - 1) & 1
-
- // quotient is at most (2¹¹-1) * q / 2¹¹ + 1 = 3328, so it didn't overflow.
- return fieldElement(quotient)
-}
-
-// ringElement is a polynomial, an element of R_q, represented as an array
-// according to FIPS 203, Section 2.4.4.
-type ringElement [n]fieldElement
-
-// polyAdd adds two ringElements or nttElements.
-func polyAdd[T ~[n]fieldElement](a, b T) (s T) {
- for i := range s {
- s[i] = fieldAdd(a[i], b[i])
- }
- return s
-}
-
-// polySub subtracts two ringElements or nttElements.
-func polySub[T ~[n]fieldElement](a, b T) (s T) {
- for i := range s {
- s[i] = fieldSub(a[i], b[i])
- }
- return s
-}
-
-// polyByteEncode appends the 384-byte encoding of f to b.
-//
-// It implements ByteEncode₁₂, according to FIPS 203, Algorithm 5.
-func polyByteEncode[T ~[n]fieldElement](b []byte, f T) []byte {
- out, B := sliceForAppend(b, encodingSize12)
- for i := 0; i < n; i += 2 {
- x := uint32(f[i]) | uint32(f[i+1])<<12
- B[0] = uint8(x)
- B[1] = uint8(x >> 8)
- B[2] = uint8(x >> 16)
- B = B[3:]
- }
- return out
-}
-
-// polyByteDecode decodes the 384-byte encoding of a polynomial, checking that
-// all the coefficients are properly reduced. This fulfills the "Modulus check"
-// step of ML-KEM Encapsulation.
-//
-// It implements ByteDecode₁₂, according to FIPS 203, Algorithm 6.
-func polyByteDecode[T ~[n]fieldElement](b []byte) (T, error) {
- if len(b) != encodingSize12 {
- return T{}, errors.New("mlkem768: invalid encoding length")
- }
- var f T
- for i := 0; i < n; i += 2 {
- d := uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16
- const mask12 = 0b1111_1111_1111
- var err error
- if f[i], err = fieldCheckReduced(uint16(d & mask12)); err != nil {
- return T{}, errors.New("mlkem768: invalid polynomial encoding")
- }
- if f[i+1], err = fieldCheckReduced(uint16(d >> 12)); err != nil {
- return T{}, errors.New("mlkem768: invalid polynomial encoding")
- }
- b = b[3:]
- }
- return f, nil
-}
-
-// sliceForAppend takes a slice and a requested number of bytes. It returns a
-// slice with the contents of the given slice followed by that many bytes and a
-// second slice that aliases into it and contains only the extra bytes. If the
-// original slice has sufficient capacity then no allocation is performed.
-func sliceForAppend(in []byte, n int) (head, tail []byte) {
- if total := len(in) + n; cap(in) >= total {
- head = in[:total]
- } else {
- head = make([]byte, total)
- copy(head, in)
- }
- tail = head[len(in):]
- return
-}
-
-// ringCompressAndEncode1 appends a 32-byte encoding of a ring element to s,
-// compressing one coefficients per bit.
-//
-// It implements Compress₁, according to FIPS 203, Definition 4.7,
-// followed by ByteEncode₁, according to FIPS 203, Algorithm 5.
-func ringCompressAndEncode1(s []byte, f ringElement) []byte {
- s, b := sliceForAppend(s, encodingSize1)
- for i := range b {
- b[i] = 0
- }
- for i := range f {
- b[i/8] |= uint8(compress(f[i], 1) << (i % 8))
- }
- return s
-}
-
-// ringDecodeAndDecompress1 decodes a 32-byte slice to a ring element where each
-// bit is mapped to 0 or ⌈q/2⌋.
-//
-// It implements ByteDecode₁, according to FIPS 203, Algorithm 6,
-// followed by Decompress₁, according to FIPS 203, Definition 4.8.
-func ringDecodeAndDecompress1(b *[encodingSize1]byte) ringElement {
- var f ringElement
- for i := range f {
- b_i := b[i/8] >> (i % 8) & 1
- const halfQ = (q + 1) / 2 // ⌈q/2⌋, rounded up per FIPS 203, Section 2.3
- f[i] = fieldElement(b_i) * halfQ // 0 decompresses to 0, and 1 to ⌈q/2⌋
- }
- return f
-}
-
-// ringCompressAndEncode4 appends a 128-byte encoding of a ring element to s,
-// compressing two coefficients per byte.
-//
-// It implements Compress₄, according to FIPS 203, Definition 4.7,
-// followed by ByteEncode₄, according to FIPS 203, Algorithm 5.
-func ringCompressAndEncode4(s []byte, f ringElement) []byte {
- s, b := sliceForAppend(s, encodingSize4)
- for i := 0; i < n; i += 2 {
- b[i/2] = uint8(compress(f[i], 4) | compress(f[i+1], 4)<<4)
- }
- return s
-}
-
-// ringDecodeAndDecompress4 decodes a 128-byte encoding of a ring element where
-// each four bits are mapped to an equidistant distribution.
-//
-// It implements ByteDecode₄, according to FIPS 203, Algorithm 6,
-// followed by Decompress₄, according to FIPS 203, Definition 4.8.
-func ringDecodeAndDecompress4(b *[encodingSize4]byte) ringElement {
- var f ringElement
- for i := 0; i < n; i += 2 {
- f[i] = fieldElement(decompress(uint16(b[i/2]&0b1111), 4))
- f[i+1] = fieldElement(decompress(uint16(b[i/2]>>4), 4))
- }
- return f
-}
-
-// ringCompressAndEncode10 appends a 320-byte encoding of a ring element to s,
-// compressing four coefficients per five bytes.
-//
-// It implements Compress₁₀, according to FIPS 203, Definition 4.7,
-// followed by ByteEncode₁₀, according to FIPS 203, Algorithm 5.
-func ringCompressAndEncode10(s []byte, f ringElement) []byte {
- s, b := sliceForAppend(s, encodingSize10)
- for i := 0; i < n; i += 4 {
- var x uint64
- x |= uint64(compress(f[i+0], 10))
- x |= uint64(compress(f[i+1], 10)) << 10
- x |= uint64(compress(f[i+2], 10)) << 20
- x |= uint64(compress(f[i+3], 10)) << 30
- b[0] = uint8(x)
- b[1] = uint8(x >> 8)
- b[2] = uint8(x >> 16)
- b[3] = uint8(x >> 24)
- b[4] = uint8(x >> 32)
- b = b[5:]
- }
- return s
-}
-
-// ringDecodeAndDecompress10 decodes a 320-byte encoding of a ring element where
-// each ten bits are mapped to an equidistant distribution.
-//
-// It implements ByteDecode₁₀, according to FIPS 203, Algorithm 6,
-// followed by Decompress₁₀, according to FIPS 203, Definition 4.8.
-func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement {
- b := bb[:]
- var f ringElement
- for i := 0; i < n; i += 4 {
- x := uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32
- b = b[5:]
- f[i] = fieldElement(decompress(uint16(x>>0&0b11_1111_1111), 10))
- f[i+1] = fieldElement(decompress(uint16(x>>10&0b11_1111_1111), 10))
- f[i+2] = fieldElement(decompress(uint16(x>>20&0b11_1111_1111), 10))
- f[i+3] = fieldElement(decompress(uint16(x>>30&0b11_1111_1111), 10))
- }
- return f
-}
-
-// samplePolyCBD draws a ringElement from the special Dη distribution given a
-// stream of random bytes generated by the PRF function, according to FIPS 203,
-// Algorithm 8 and Definition 4.3.
-func samplePolyCBD(s []byte, b byte) ringElement {
- prf := sha3.NewShake256()
- prf.Write(s)
- prf.Write([]byte{b})
- B := make([]byte, 64*η)
- prf.Read(B)
-
- // SamplePolyCBD simply draws four (2η) bits for each coefficient, and adds
- // the first two and subtracts the last two.
-
- var f ringElement
- for i := 0; i < n; i += 2 {
- b := B[i/2]
- b_7, b_6, b_5, b_4 := b>>7, b>>6&1, b>>5&1, b>>4&1
- b_3, b_2, b_1, b_0 := b>>3&1, b>>2&1, b>>1&1, b&1
- f[i] = fieldSub(fieldElement(b_0+b_1), fieldElement(b_2+b_3))
- f[i+1] = fieldSub(fieldElement(b_4+b_5), fieldElement(b_6+b_7))
- }
- return f
-}
-
-// nttElement is an NTT representation, an element of T_q, represented as an
-// array according to FIPS 203, Section 2.4.4.
-type nttElement [n]fieldElement
-
-// gammas are the values ζ^2BitRev7(i)+1 mod q for each index i, according to
-// FIPS 203, Appendix A (with negative values reduced to positive).
-var gammas = [128]fieldElement{17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175}
-
-// nttMul multiplies two nttElements.
-//
-// It implements MultiplyNTTs, according to FIPS 203, Algorithm 11.
-func nttMul(f, g nttElement) nttElement {
- var h nttElement
- // We use i += 2 for bounds check elimination. See https://go.dev/issue/66826.
- for i := 0; i < 256; i += 2 {
- a0, a1 := f[i], f[i+1]
- b0, b1 := g[i], g[i+1]
- h[i] = fieldAddMul(a0, b0, fieldMul(a1, b1), gammas[i/2])
- h[i+1] = fieldAddMul(a0, b1, a1, b0)
- }
- return h
-}
-
-// zetas are the values ζ^BitRev7(k) mod q for each index k, according to FIPS
-// 203, Appendix A.
-var zetas = [128]fieldElement{1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154}
-
-// ntt maps a ringElement to its nttElement representation.
-//
-// It implements NTT, according to FIPS 203, Algorithm 9.
-func ntt(f ringElement) nttElement {
- k := 1
- for len := 128; len >= 2; len /= 2 {
- for start := 0; start < 256; start += 2 * len {
- zeta := zetas[k]
- k++
- // Bounds check elimination hint.
- f, flen := f[start:start+len], f[start+len:start+len+len]
- for j := 0; j < len; j++ {
- t := fieldMul(zeta, flen[j])
- flen[j] = fieldSub(f[j], t)
- f[j] = fieldAdd(f[j], t)
- }
- }
- }
- return nttElement(f)
-}
-
-// inverseNTT maps a nttElement back to the ringElement it represents.
-//
-// It implements NTT⁻¹, according to FIPS 203, Algorithm 10.
-func inverseNTT(f nttElement) ringElement {
- k := 127
- for len := 2; len <= 128; len *= 2 {
- for start := 0; start < 256; start += 2 * len {
- zeta := zetas[k]
- k--
- // Bounds check elimination hint.
- f, flen := f[start:start+len], f[start+len:start+len+len]
- for j := 0; j < len; j++ {
- t := f[j]
- f[j] = fieldAdd(t, flen[j])
- flen[j] = fieldMulSub(zeta, flen[j], t)
- }
- }
- }
- for i := range f {
- f[i] = fieldMul(f[i], 3303) // 3303 = 128⁻¹ mod q
- }
- return ringElement(f)
-}
-
-// sampleNTT draws a uniformly random nttElement from a stream of uniformly
-// random bytes generated by the XOF function, according to FIPS 203,
-// Algorithm 7.
-func sampleNTT(rho []byte, ii, jj byte) nttElement {
- B := sha3.NewShake128()
- B.Write(rho)
- B.Write([]byte{ii, jj})
-
- // SampleNTT essentially draws 12 bits at a time from r, interprets them in
- // little-endian, and rejects values higher than q, until it drew 256
- // values. (The rejection rate is approximately 19%.)
- //
- // To do this from a bytes stream, it draws three bytes at a time, and
- // splits them into two uint16 appropriately masked.
- //
- // r₀ r₁ r₂
- // |- - - - - - - -|- - - - - - - -|- - - - - - - -|
- //
- // Uint16(r₀ || r₁)
- // |- - - - - - - - - - - - - - - -|
- // |- - - - - - - - - - - -|
- // d₁
- //
- // Uint16(r₁ || r₂)
- // |- - - - - - - - - - - - - - - -|
- // |- - - - - - - - - - - -|
- // d₂
- //
- // Note that in little-endian, the rightmost bits are the most significant
- // bits (dropped with a mask) and the leftmost bits are the least
- // significant bits (dropped with a right shift).
-
- var a nttElement
- var j int // index into a
- var buf [24]byte // buffered reads from B
- off := len(buf) // index into buf, starts in a "buffer fully consumed" state
- for {
- if off >= len(buf) {
- B.Read(buf[:])
- off = 0
- }
- d1 := byteorder.LeUint16(buf[off:]) & 0b1111_1111_1111
- d2 := byteorder.LeUint16(buf[off+1:]) >> 4
- off += 3
- if d1 < q {
- a[j] = fieldElement(d1)
- j++
- }
- if j >= len(a) {
- break
- }
- if d2 < q {
- a[j] = fieldElement(d2)
- j++
- }
- if j >= len(a) {
- break
- }
- }
- return a
-}
_ "embed"
"encoding/hex"
"flag"
- "math/big"
- "strconv"
"testing"
"golang.org/x/crypto/sha3"
)
-func TestFieldReduce(t *testing.T) {
- for a := uint32(0); a < 2*q*q; a++ {
- got := fieldReduce(a)
- exp := fieldElement(a % q)
- if got != exp {
- t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp)
- }
- }
-}
-
-func TestFieldAdd(t *testing.T) {
- for a := fieldElement(0); a < q; a++ {
- for b := fieldElement(0); b < q; b++ {
- got := fieldAdd(a, b)
- exp := (a + b) % q
- if got != exp {
- t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
- }
- }
- }
-}
-
-func TestFieldSub(t *testing.T) {
- for a := fieldElement(0); a < q; a++ {
- for b := fieldElement(0); b < q; b++ {
- got := fieldSub(a, b)
- exp := (a - b + q) % q
- if got != exp {
- t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
- }
- }
- }
-}
-
-func TestFieldMul(t *testing.T) {
- for a := fieldElement(0); a < q; a++ {
- for b := fieldElement(0); b < q; b++ {
- got := fieldMul(a, b)
- exp := fieldElement((uint32(a) * uint32(b)) % q)
- if got != exp {
- t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
- }
- }
- }
-}
-
-func TestDecompressCompress(t *testing.T) {
- for _, bits := range []uint8{1, 4, 10} {
- for a := uint16(0); a < 1<<bits; a++ {
- f := decompress(a, bits)
- if f >= q {
- t.Fatalf("decompress(%d, %d) = %d >= q", a, bits, f)
- }
- got := compress(f, bits)
- if got != a {
- t.Fatalf("compress(decompress(%d, %d), %d) = %d", a, bits, bits, got)
- }
- }
-
- for a := fieldElement(0); a < q; a++ {
- c := compress(a, bits)
- if c >= 1<<bits {
- t.Fatalf("compress(%d, %d) = %d >= 2^bits", a, bits, c)
- }
- got := decompress(c, bits)
- diff := min(a-got, got-a, a-got+q, got-a+q)
- ceil := q / (1 << bits)
- if diff > fieldElement(ceil) {
- t.Fatalf("decompress(compress(%d, %d), %d) = %d (diff %d, max diff %d)",
- a, bits, bits, got, diff, ceil)
- }
- }
- }
-}
-
-func CompressRat(x fieldElement, d uint8) uint16 {
- if x >= q {
- panic("x out of range")
- }
- if d <= 0 || d >= 12 {
- panic("d out of range")
- }
-
- precise := big.NewRat((1<<d)*int64(x), q) // (2ᵈ / q) * x == (2ᵈ * x) / q
-
- // FloatString rounds halves away from 0, and our result should always be positive,
- // so it should work as we expect. (There's no direct way to round a Rat.)
- rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
- if err != nil {
- panic(err)
- }
-
- // If we rounded up, `rounded` may be equal to 2ᵈ, so we perform a final reduction.
- return uint16(rounded % (1 << d))
-}
-
-func TestCompress(t *testing.T) {
- for d := 1; d < 12; d++ {
- for n := 0; n < q; n++ {
- expected := CompressRat(fieldElement(n), uint8(d))
- result := compress(fieldElement(n), uint8(d))
- if result != expected {
- t.Errorf("compress(%d, %d): got %d, expected %d", n, d, result, expected)
- }
- }
- }
-}
-
-func DecompressRat(y uint16, d uint8) fieldElement {
- if y >= 1<<d {
- panic("y out of range")
- }
- if d <= 0 || d >= 12 {
- panic("d out of range")
- }
-
- precise := big.NewRat(q*int64(y), 1<<d) // (q / 2ᵈ) * y == (q * y) / 2ᵈ
-
- // FloatString rounds halves away from 0, and our result should always be positive,
- // so it should work as we expect. (There's no direct way to round a Rat.)
- rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
- if err != nil {
- panic(err)
- }
-
- // If we rounded up, `rounded` may be equal to q, so we perform a final reduction.
- return fieldElement(rounded % q)
-}
-
-func TestDecompress(t *testing.T) {
- for d := 1; d < 12; d++ {
- for n := 0; n < (1 << d); n++ {
- expected := DecompressRat(uint16(n), uint8(d))
- result := decompress(uint16(n), uint8(d))
- if result != expected {
- t.Errorf("decompress(%d, %d): got %d, expected %d", n, d, result, expected)
- }
- }
- }
-}
-
-func BitRev7(n uint8) uint8 {
- if n>>7 != 0 {
- panic("not 7 bits")
- }
- var r uint8
- r |= n >> 6 & 0b0000_0001
- r |= n >> 4 & 0b0000_0010
- r |= n >> 2 & 0b0000_0100
- r |= n /**/ & 0b0000_1000
- r |= n << 2 & 0b0001_0000
- r |= n << 4 & 0b0010_0000
- r |= n << 6 & 0b0100_0000
- return r
-}
-
-func TestZetas(t *testing.T) {
- ζ := big.NewInt(17)
- q := big.NewInt(q)
- for k, zeta := range zetas {
- // ζ^BitRev7(k) mod q
- exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))), q)
- if big.NewInt(int64(zeta)).Cmp(exp) != 0 {
- t.Errorf("zetas[%d] = %v, expected %v", k, zeta, exp)
- }
- }
-}
-
-func TestGammas(t *testing.T) {
- ζ := big.NewInt(17)
- q := big.NewInt(q)
- for k, gamma := range gammas {
- // ζ^2BitRev7(i)+1
- exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))*2+1), q)
- if big.NewInt(int64(gamma)).Cmp(exp) != 0 {
- t.Errorf("gammas[%d] = %v, expected %v", k, gamma, exp)
- }
- }
-}
-
func TestRoundTrip(t *testing.T) {
dk, err := GenerateKey()
if err != nil {