"crypto/sha256"
"encoding/hex"
"fmt"
- "io"
"os"
)
// a buffer that contains a random key. Thus, if the RSA result isn't
// well-formed, the implementation uses a random key in constant time.
func ExampleDecryptPKCS1v15SessionKey() {
- // crypto/rand.Reader is a good source of entropy for blinding the RSA
- // operation.
- rng := rand.Reader
-
// The hybrid scheme should use at least a 16-byte symmetric key. Here
// we read the random key that will be used if the RSA decryption isn't
// well-formed.
key := make([]byte, 32)
- if _, err := io.ReadFull(rng, key); err != nil {
+ if _, err := rand.Read(key); err != nil {
panic("RNG failure")
}
rsaCiphertext, _ := hex.DecodeString("aabbccddeeff")
- if err := rsa.DecryptPKCS1v15SessionKey(rng, rsaPrivateKey, rsaCiphertext, key); err != nil {
+ if err := rsa.DecryptPKCS1v15SessionKey(nil, rsaPrivateKey, rsaCiphertext, key); err != nil {
// Any errors that result will be “public” – meaning that they
// can be determined without any secret information. (For
// instance, if the length of key is impossible given the RSA
}
func ExampleSignPKCS1v15() {
- // crypto/rand.Reader is a good source of entropy for blinding the RSA
- // operation.
- rng := rand.Reader
-
message := []byte("message to be signed")
// Only small messages can be signed directly; thus the hash of a
// of writing (2016).
hashed := sha256.Sum256(message)
- signature, err := rsa.SignPKCS1v15(rng, rsaPrivateKey, crypto.SHA256, hashed[:])
+ signature, err := rsa.SignPKCS1v15(nil, rsaPrivateKey, crypto.SHA256, hashed[:])
if err != nil {
fmt.Fprintf(os.Stderr, "Error from signing: %s\n", err)
return
ciphertext, _ := hex.DecodeString("4d1ee10e8f286390258c51a5e80802844c3e6358ad6690b7285218a7c7ed7fc3a4c7b950fbd04d4b0239cc060dcc7065ca6f84c1756deb71ca5685cadbb82be025e16449b905c568a19c088a1abfad54bf7ecc67a7df39943ec511091a34c0f2348d04e058fcff4d55644de3cd1d580791d4524b92f3e91695582e6e340a1c50b6c6d78e80b4e42c5b4d45e479b492de42bbd39cc642ebb80226bb5200020d501b24a37bcc2ec7f34e596b4fd6b063de4858dbf5a4e3dd18e262eda0ec2d19dbd8e890d672b63d368768360b20c0b6b8592a438fa275e5fa7f60bef0dd39673fd3989cc54d2cb80c08fcd19dacbc265ee1c6014616b0e04ea0328c2a04e73460")
label := []byte("orders")
- // crypto/rand.Reader is a good source of entropy for blinding the RSA
- // operation.
- rng := rand.Reader
-
- plaintext, err := rsa.DecryptOAEP(sha256.New(), rng, test2048Key, ciphertext, label)
+ plaintext, err := rsa.DecryptOAEP(sha256.New(), nil, test2048Key, ciphertext, label)
if err != nil {
fmt.Fprintf(os.Stderr, "Error from decryption: %s\n", err)
return
--- /dev/null
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rsa
+
+import (
+ "math/big"
+ "math/bits"
+)
+
+const (
+ // _W is the number of bits we use for our limbs.
+ _W = bits.UintSize - 1
+ // _MASK selects _W bits from a full machine word.
+ _MASK = (1 << _W) - 1
+)
+
+// choice represents a constant-time boolean. The value of choice is always
+// either 1 or 0. We use an int instead of bool in order to make decisions in
+// constant time by turning it into a mask.
+type choice uint
+
+func not(c choice) choice { return 1 ^ c }
+
+const yes = choice(1)
+const no = choice(0)
+
+// ctSelect returns x if on == 1, and y if on == 0. The execution time of this
+// function does not depend on its inputs. If on is any value besides 1 or 0,
+// the result is undefined.
+func ctSelect(on choice, x, y uint) uint {
+ // When on == 1, mask is 0b111..., otherwise mask is 0b000...
+ mask := -uint(on)
+ // When mask is all zeros, we just have y, otherwise, y cancels with itself.
+ return y ^ (mask & (y ^ x))
+}
+
+// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
+// function does not depend on its inputs.
+func ctEq(x, y uint) choice {
+ // If x != y, then either x - y or y - x will generate a carry.
+ _, c1 := bits.Sub(x, y, 0)
+ _, c2 := bits.Sub(y, x, 0)
+ return not(choice(c1 | c2))
+}
+
+// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
+// function does not depend on its inputs.
+func ctGeq(x, y uint) choice {
+ // If x < y, then x - y generates a carry.
+ _, carry := bits.Sub(x, y, 0)
+ return not(choice(carry))
+}
+
+// nat represents an arbitrary natural number
+//
+// Each nat has an announced length, which is the number of limbs it has stored.
+// Operations on this number are allowed to leak this length, but will not leak
+// any information about the values contained in those limbs.
+type nat struct {
+ // limbs is a little-endian representation in base 2^W with
+ // W = bits.UintSize - 1. The top bit is always unset between operations.
+ //
+ // The top bit is left unset to optimize Montgomery multiplication, in the
+ // inner loop of exponentiation. Using fully saturated limbs would leave us
+ // working with 129-bit numbers on 64-bit platforms, wasting a lot of space,
+ // and thus time.
+ limbs []uint
+}
+
+// expand expands x to n limbs, leaving its value unchanged.
+func (x *nat) expand(n int) *nat {
+ for len(x.limbs) > n {
+ if x.limbs[len(x.limbs)-1] != 0 {
+ panic("rsa: internal error: shrinking nat")
+ }
+ x.limbs = x.limbs[:len(x.limbs)-1]
+ }
+ if cap(x.limbs) < n {
+ newLimbs := make([]uint, n)
+ copy(newLimbs, x.limbs)
+ x.limbs = newLimbs
+ return x
+ }
+ extraLimbs := x.limbs[len(x.limbs):n]
+ for i := range extraLimbs {
+ extraLimbs[i] = 0
+ }
+ x.limbs = x.limbs[:n]
+ return x
+}
+
+// reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).
+func (x *nat) reset(n int) *nat {
+ if cap(x.limbs) < n {
+ x.limbs = make([]uint, n)
+ return x
+ }
+ for i := range x.limbs {
+ x.limbs[i] = 0
+ }
+ x.limbs = x.limbs[:n]
+ return x
+}
+
+// clone returns a new nat, with the same value and announced length as x.
+func (x *nat) clone() *nat {
+ out := &nat{make([]uint, len(x.limbs))}
+ copy(out.limbs, x.limbs)
+ return out
+}
+
+// natFromBig creates a new natural number from a big.Int.
+//
+// The announced length of the resulting nat is based on the actual bit size of
+// the input, ignoring leading zeroes.
+func natFromBig(x *big.Int) *nat {
+ xLimbs := x.Bits()
+ bitSize := bigBitLen(x)
+ requiredLimbs := (bitSize + _W - 1) / _W
+
+ out := &nat{make([]uint, requiredLimbs)}
+ outI := 0
+ shift := 0
+ for i := range xLimbs {
+ xi := uint(xLimbs[i])
+ out.limbs[outI] |= (xi << shift) & _MASK
+ outI++
+ if outI == requiredLimbs {
+ return out
+ }
+ out.limbs[outI] = xi >> (_W - shift)
+ shift++ // this assumes bits.UintSize - _W = 1
+ if shift == _W {
+ shift = 0
+ outI++
+ }
+ }
+ return out
+}
+
+// fillBytes sets bytes to x as a zero-extended big-endian byte slice.
+//
+// If bytes is not long enough to contain the number or at least len(x.limbs)-1
+// limbs, or has zero length, fillBytes will panic.
+func (x *nat) fillBytes(bytes []byte) []byte {
+ if len(bytes) == 0 {
+ panic("nat: fillBytes invoked with too small buffer")
+ }
+ for i := range bytes {
+ bytes[i] = 0
+ }
+ shift := 0
+ outI := len(bytes) - 1
+ for i, limb := range x.limbs {
+ remainingBits := _W
+ for remainingBits >= 8 {
+ bytes[outI] |= byte(limb) << shift
+ consumed := 8 - shift
+ limb >>= consumed
+ remainingBits -= consumed
+ shift = 0
+ outI--
+ if outI < 0 {
+ if limb != 0 || i < len(x.limbs)-1 {
+ panic("nat: fillBytes invoked with too small buffer")
+ }
+ return bytes
+ }
+ }
+ bytes[outI] = byte(limb)
+ shift = remainingBits
+ }
+ return bytes
+}
+
+// natFromBytes converts a slice of big-endian bytes into a nat.
+//
+// The announced length of the output depends on the length of bytes. Unlike
+// big.Int, creating a nat will not remove leading zeros.
+func natFromBytes(bytes []byte) *nat {
+ bitSize := len(bytes) * 8
+ requiredLimbs := (bitSize + _W - 1) / _W
+
+ out := &nat{make([]uint, requiredLimbs)}
+ outI := 0
+ shift := 0
+ for i := len(bytes) - 1; i >= 0; i-- {
+ bi := bytes[i]
+ out.limbs[outI] |= uint(bi) << shift
+ shift += 8
+ if shift >= _W {
+ shift -= _W
+ out.limbs[outI] &= _MASK
+ outI++
+ if shift > 0 {
+ out.limbs[outI] = uint(bi) >> (8 - shift)
+ }
+ }
+ }
+ return out
+}
+
+// cmpEq returns 1 if x == y, and 0 otherwise.
+//
+// Both operands must have the same announced length.
+func (x *nat) cmpEq(y *nat) choice {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+ yLimbs := y.limbs[:size]
+
+ equal := yes
+ for i := 0; i < size; i++ {
+ equal &= ctEq(xLimbs[i], yLimbs[i])
+ }
+ return equal
+}
+
+// cmpGeq returns 1 if x >= y, and 0 otherwise.
+//
+// Both operands must have the same announced length.
+func (x *nat) cmpGeq(y *nat) choice {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+ yLimbs := y.limbs[:size]
+
+ var c uint
+ for i := 0; i < size; i++ {
+ c = (xLimbs[i] - yLimbs[i] - c) >> _W
+ }
+ // If there was a carry, then subtracting y underflowed, so
+ // x is not greater than or equal to y.
+ return not(choice(c))
+}
+
+// assign sets x <- y if on == 1, and does nothing otherwise.
+//
+// Both operands must have the same announced length.
+func (x *nat) assign(on choice, y *nat) *nat {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+ yLimbs := y.limbs[:size]
+
+ for i := 0; i < size; i++ {
+ xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i])
+ }
+ return x
+}
+
+// add computes x += y if on == 1, and does nothing otherwise. It returns the
+// carry of the addition regardless of on.
+//
+// Both operands must have the same announced length.
+func (x *nat) add(on choice, y *nat) (c uint) {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+ yLimbs := y.limbs[:size]
+
+ for i := 0; i < size; i++ {
+ res := xLimbs[i] + yLimbs[i] + c
+ xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
+ c = res >> _W
+ }
+ return
+}
+
+// sub computes x -= y if on == 1, and does nothing otherwise. It returns the
+// borrow of the subtraction regardless of on.
+//
+// Both operands must have the same announced length.
+func (x *nat) sub(on choice, y *nat) (c uint) {
+ // Eliminate bounds checks in the loop.
+ size := len(x.limbs)
+ xLimbs := x.limbs[:size]
+ yLimbs := y.limbs[:size]
+
+ for i := 0; i < size; i++ {
+ res := xLimbs[i] - yLimbs[i] - c
+ xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
+ c = res >> _W
+ }
+ return
+}
+
+// modulus is used for modular arithmetic, precomputing relevant constants.
+//
+// Moduli are assumed to be odd numbers. Moduli can also leak the exact
+// number of bits needed to store their value, and are stored without padding.
+//
+// Their actual value is still kept secret.
+type modulus struct {
+ // The underlying natural number for this modulus.
+ //
+ // This will be stored without any padding, and shouldn't alias with any
+ // other natural number being used.
+ nat *nat
+ leading int // number of leading zeros in the modulus
+ m0inv uint // -nat.limbs[0]⁻¹ mod _W
+}
+
+// minusInverseModW computes -x⁻¹ mod _W with x odd.
+//
+// This operation is used to precompute a constant involved in Montgomery
+// multiplication.
+func minusInverseModW(x uint) uint {
+ // Every iteration of this loop doubles the least-significant bits of
+ // correct inverse in y. The first three bits are already correct (1⁻¹ = 1,
+ // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough
+ // for 61 bits (and wastes only one iteration for 31 bits).
+ //
+ // See https://crypto.stackexchange.com/a/47496.
+ y := x
+ for i := 0; i < 5; i++ {
+ y = y * (2 - x*y)
+ }
+ return (1 << _W) - (y & _MASK)
+}
+
+// modulusFromNat creates a new modulus from a nat.
+//
+// The nat should be odd, nonzero, and the number of significant bits in the
+// number should be leakable. The nat shouldn't be reused.
+func modulusFromNat(nat *nat) *modulus {
+ m := &modulus{}
+ m.nat = nat
+ size := len(m.nat.limbs)
+ for m.nat.limbs[size-1] == 0 {
+ size--
+ }
+ m.nat.limbs = m.nat.limbs[:size]
+ m.leading = _W - bitLen(m.nat.limbs[size-1])
+ m.m0inv = minusInverseModW(m.nat.limbs[0])
+ return m
+}
+
+// bitLen is a version of bits.Len that only leaks the bit length of n, but not
+// its value. bits.Len and bits.LeadingZeros use a lookup table for the
+// low-order bits on some architectures.
+func bitLen(n uint) int {
+ var len int
+ // We assume, here and elsewhere, that comparison to zero is constant time
+ // with respect to different non-zero values.
+ for n != 0 {
+ len++
+ n >>= 1
+ }
+ return len
+}
+
+// bigBitLen is a version of big.Int.BitLen that only leaks the bit length of x,
+// but not its value. big.Int.BitLen uses bits.Len.
+func bigBitLen(x *big.Int) int {
+ xLimbs := x.Bits()
+ fullLimbs := len(xLimbs) - 1
+ topLimb := uint(xLimbs[len(xLimbs)-1])
+ return fullLimbs*bits.UintSize + bitLen(topLimb)
+}
+
+// modulusSize returns the size of m in bytes.
+func modulusSize(m *modulus) int {
+ bits := len(m.nat.limbs)*_W - int(m.leading)
+ return (bits + 7) / 8
+}
+
+// shiftIn calculates x = x << _W + y mod m.
+//
+// This assumes that x is already reduced mod m, and that y < 2^_W.
+func (x *nat) shiftIn(y uint, m *modulus) *nat {
+ d := new(nat).resetFor(m)
+
+ // Eliminate bounds checks in the loop.
+ size := len(m.nat.limbs)
+ xLimbs := x.limbs[:size]
+ dLimbs := d.limbs[:size]
+ mLimbs := m.nat.limbs[:size]
+
+ // Each iteration of this loop computes x = 2x + b mod m, where b is a bit
+ // from y. Effectively, it left-shifts x and adds y one bit at a time,
+ // reducing it every time.
+ //
+ // To do the reduction, each iteration computes both 2x + b and 2x + b - m.
+ // The next iteration (and finally the return line) will use either result
+ // based on whether the subtraction underflowed.
+ needSubtraction := no
+ for i := _W - 1; i >= 0; i-- {
+ carry := (y >> i) & 1
+ var borrow uint
+ for i := 0; i < size; i++ {
+ l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i])
+
+ res := l<<1 + carry
+ xLimbs[i] = res & _MASK
+ carry = res >> _W
+
+ res = xLimbs[i] - mLimbs[i] - borrow
+ dLimbs[i] = res & _MASK
+ borrow = res >> _W
+ }
+ // See modAdd for how carry (aka overflow), borrow (aka underflow), and
+ // needSubtraction relate.
+ needSubtraction = ctEq(carry, borrow)
+ }
+ return x.assign(needSubtraction, d)
+}
+
+// mod calculates out = x mod m.
+//
+// This works regardless how large the value of x is.
+//
+// The output will be resized to the size of m and overwritten.
+func (out *nat) mod(x *nat, m *modulus) *nat {
+ out.resetFor(m)
+ // Working our way from the most significant to the least significant limb,
+ // we can insert each limb at the least significant position, shifting all
+ // previous limbs left by _W. This way each limb will get shifted by the
+ // correct number of bits. We can insert at least N - 1 limbs without
+ // overflowing m. After that, we need to reduce every time we shift.
+ i := len(x.limbs) - 1
+ // For the first N - 1 limbs we can skip the actual shifting and position
+ // them at the shifted position, which starts at min(N - 2, i).
+ start := len(m.nat.limbs) - 2
+ if i < start {
+ start = i
+ }
+ for j := start; j >= 0; j-- {
+ out.limbs[j] = x.limbs[i]
+ i--
+ }
+ // We shift in the remaining limbs, reducing modulo m each time.
+ for i >= 0 {
+ out.shiftIn(x.limbs[i], m)
+ i--
+ }
+ return out
+}
+
+// expandFor ensures out has the right size to work with operations modulo m.
+//
+// This assumes that out has as many or fewer limbs than m, or that the extra
+// limbs are all zero (which may happen when decoding a value that has leading
+// zeroes in its bytes representation that spill over the limb threshold).
+func (out *nat) expandFor(m *modulus) *nat {
+ return out.expand(len(m.nat.limbs))
+}
+
+// resetFor ensures out has the right size to work with operations modulo m.
+//
+// out is zeroed and may start at any size.
+func (out *nat) resetFor(m *modulus) *nat {
+ return out.reset(len(m.nat.limbs))
+}
+
+// modSub computes x = x - y mod m.
+//
+// The length of both operands must be the same as the modulus. Both operands
+// must already be reduced modulo m.
+func (x *nat) modSub(y *nat, m *modulus) *nat {
+ underflow := x.sub(yes, y)
+ // If the subtraction underflowed, add m.
+ x.add(choice(underflow), m.nat)
+ return x
+}
+
+// modAdd computes x = x + y mod m.
+//
+// The length of both operands must be the same as the modulus. Both operands
+// must already be reduced modulo m.
+func (x *nat) modAdd(y *nat, m *modulus) *nat {
+ overflow := x.add(yes, y)
+ underflow := not(x.cmpGeq(m.nat)) // x < m
+
+ // Three cases are possible:
+ //
+ // - overflow = 0, underflow = 0
+ //
+ // In this case, addition fits in our limbs, but we can still subtract away
+ // m without an underflow, so we need to perform the subtraction to reduce
+ // our result.
+ //
+ // - overflow = 0, underflow = 1
+ //
+ // The addition fits in our limbs, but we can't subtract m without
+ // underflowing. The result is already reduced.
+ //
+ // - overflow = 1, underflow = 1
+ //
+ // The addition does not fit in our limbs, and the subtraction's borrow
+ // would cancel out with the addition's carry. We need to subtract m to
+ // reduce our result.
+ //
+ // The overflow = 1, underflow = 0 case is not possible, because y is at
+ // most m - 1, and if adding m - 1 overflows, then subtracting m must
+ // necessarily underflow.
+ needSubtraction := ctEq(overflow, uint(underflow))
+
+ x.sub(needSubtraction, m.nat)
+ return x
+}
+
+// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and
+// n = len(m.nat.limbs).
+//
+// Faster Montgomery multiplication replaces standard modular multiplication for
+// numbers in this representation.
+//
+// This assumes that x is already reduced mod m.
+func (x *nat) montgomeryRepresentation(m *modulus) *nat {
+ for i := 0; i < len(m.nat.limbs); i++ {
+ x.shiftIn(0, m) // x = x * 2^_W mod m
+ }
+ return x
+}
+
+// montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and
+// n = len(m.nat.limbs), using the Montgomery Multiplication technique.
+//
+// All inputs should be the same length, not aliasing d, and already
+// reduced modulo m. d will be resized to the size of m and overwritten.
+func (d *nat) montgomeryMul(a *nat, b *nat, m *modulus) *nat {
+ // See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication
+ // for a description of the algorithm.
+
+ // Eliminate bounds checks in the loop.
+ size := len(m.nat.limbs)
+ aLimbs := a.limbs[:size]
+ bLimbs := b.limbs[:size]
+ dLimbs := d.resetFor(m).limbs[:size]
+ mLimbs := m.nat.limbs[:size]
+
+ var overflow uint
+ for i := 0; i < size; i++ {
+ f := ((dLimbs[0] + aLimbs[i]*bLimbs[0]) * m.m0inv) & _MASK
+ carry := uint(0)
+ for j := 0; j < size; j++ {
+ // z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W
+ hi, lo := bits.Mul(aLimbs[i], bLimbs[j])
+ z_lo, c := bits.Add(dLimbs[j], lo, 0)
+ z_hi, _ := bits.Add(0, hi, c)
+ hi, lo = bits.Mul(f, mLimbs[j])
+ z_lo, c = bits.Add(z_lo, lo, 0)
+ z_hi, _ = bits.Add(z_hi, hi, c)
+ z_lo, c = bits.Add(z_lo, carry, 0)
+ z_hi, _ = bits.Add(z_hi, 0, c)
+ if j > 0 {
+ dLimbs[j-1] = z_lo & _MASK
+ }
+ carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2
+ }
+ z := overflow + carry // z <= 2^(W+1) - 1
+ dLimbs[size-1] = z & _MASK
+ overflow = z >> _W // overflow <= 1
+ }
+ // See modAdd for how overflow, underflow, and needSubtraction relate.
+ underflow := not(d.cmpGeq(m.nat)) // d < m
+ needSubtraction := ctEq(overflow, uint(underflow))
+ d.sub(needSubtraction, m.nat)
+
+ return d
+}
+
+// modMul calculates x *= y mod m.
+//
+// x and y must already be reduced modulo m, they must share its announced
+// length, and they may not alias.
+func (x *nat) modMul(y *nat, m *modulus) *nat {
+ // A Montgomery multiplication by a value out of the Montgomery domain
+ // takes the result out of Montgomery representation.
+ xR := x.clone().montgomeryRepresentation(m) // xR = x * R mod m
+ return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
+}
+
+// exp calculates out = x^e mod m.
+//
+// The exponent e is represented in big-endian order. The output will be resized
+// to the size of m and overwritten. x must already be reduced modulo m.
+func (out *nat) exp(x *nat, e []byte, m *modulus) *nat {
+ // We use a 4 bit window. For our RSA workload, 4 bit windows are faster
+ // than 2 bit windows, but use an extra 12 nats worth of scratch space.
+ // Using bit sizes that don't divide 8 are more complex to implement.
+ table := make([]*nat, (1<<4)-1) // table[i] = x ^ (i+1)
+ table[0] = x.clone().montgomeryRepresentation(m)
+ for i := 1; i < len(table); i++ {
+ table[i] = new(nat).expandFor(m)
+ table[i].montgomeryMul(table[i-1], table[0], m)
+ }
+
+ out.resetFor(m)
+ out.limbs[0] = 1
+ out.montgomeryRepresentation(m)
+ t0 := new(nat).expandFor(m)
+ t1 := new(nat).expandFor(m)
+ for _, b := range e {
+ for _, j := range []int{4, 0} {
+ // Square four times.
+ t1.montgomeryMul(out, out, m)
+ out.montgomeryMul(t1, t1, m)
+ t1.montgomeryMul(out, out, m)
+ out.montgomeryMul(t1, t1, m)
+
+ // Select x^k in constant time from the table.
+ k := uint((b >> j) & 0b1111)
+ for i := range table {
+ t0.assign(ctEq(k, uint(i+1)), table[i])
+ }
+
+ // Multiply by x^k, discarding the result if k = 0.
+ t1.montgomeryMul(out, t0, m)
+ out.assign(not(ctEq(k, 0)), t1)
+ }
+ }
+
+ // By Montgomery multiplying with 1 not in Montgomery representation, we
+ // convert out back from Montgomery representation, because it works out to
+ // dividing by R.
+ t0.assign(yes, out)
+ t1.resetFor(m)
+ t1.limbs[0] = 1
+ out.montgomeryMul(t0, t1, m)
+
+ return out
+}
--- /dev/null
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rsa
+
+import (
+ "bytes"
+ "math/big"
+ "math/bits"
+ "math/rand"
+ "reflect"
+ "testing"
+ "testing/quick"
+)
+
+// Generate generates an even nat. It's used by testing/quick to produce random
+// *nat values for quick.Check invocations.
+func (*nat) Generate(r *rand.Rand, size int) reflect.Value {
+ limbs := make([]uint, size)
+ for i := 0; i < size; i++ {
+ limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
+ }
+ return reflect.ValueOf(&nat{limbs})
+}
+
+func testModAddCommutative(a *nat, b *nat) bool {
+ mLimbs := make([]uint, len(a.limbs))
+ for i := 0; i < len(mLimbs); i++ {
+ mLimbs[i] = _MASK
+ }
+ m := modulusFromNat(&nat{mLimbs})
+ aPlusB := a.clone()
+ aPlusB.modAdd(b, m)
+ bPlusA := b.clone()
+ bPlusA.modAdd(a, m)
+ return aPlusB.cmpEq(bPlusA) == 1
+}
+
+func TestModAddCommutative(t *testing.T) {
+ err := quick.Check(testModAddCommutative, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func testModSubThenAddIdentity(a *nat, b *nat) bool {
+ mLimbs := make([]uint, len(a.limbs))
+ for i := 0; i < len(mLimbs); i++ {
+ mLimbs[i] = _MASK
+ }
+ m := modulusFromNat(&nat{mLimbs})
+ original := a.clone()
+ a.modSub(b, m)
+ a.modAdd(b, m)
+ return a.cmpEq(original) == 1
+}
+
+func TestModSubThenAddIdentity(t *testing.T) {
+ err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func testMontgomeryRoundtrip(a *nat) bool {
+ one := &nat{make([]uint, len(a.limbs))}
+ one.limbs[0] = 1
+ aPlusOne := a.clone()
+ aPlusOne.add(1, one)
+ m := modulusFromNat(aPlusOne)
+ monty := a.clone()
+ monty.montgomeryRepresentation(m)
+ aAgain := monty.clone()
+ aAgain.montgomeryMul(monty, one, m)
+ return a.cmpEq(aAgain) == 1
+}
+
+func TestMontgomeryRoundtrip(t *testing.T) {
+ err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestFromBig(t *testing.T) {
+ expected := []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
+ theBig := new(big.Int).SetBytes(expected)
+ actual := natFromBig(theBig).fillBytes(make([]byte, len(expected)))
+ if !bytes.Equal(actual, expected) {
+ t.Errorf("%+x != %+x", actual, expected)
+ }
+}
+
+func TestFillBytes(t *testing.T) {
+ xBytes := []byte{0xAA, 0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
+ x := natFromBytes(xBytes)
+ for l := 20; l >= len(xBytes); l-- {
+ buf := make([]byte, l)
+ rand.Read(buf)
+ actual := x.fillBytes(buf)
+ expected := make([]byte, l)
+ copy(expected[l-len(xBytes):], xBytes)
+ if !bytes.Equal(actual, expected) {
+ t.Errorf("%d: %+v != %+v", l, actual, expected)
+ }
+ }
+ for l := len(xBytes) - 1; l >= 0; l-- {
+ (func() {
+ defer func() {
+ if recover() == nil {
+ t.Errorf("%d: expected panic", l)
+ }
+ }()
+ x.fillBytes(make([]byte, l))
+ })()
+ }
+}
+
+func TestFromBytes(t *testing.T) {
+ f := func(xBytes []byte) bool {
+ if len(xBytes) == 0 {
+ return true
+ }
+ actual := natFromBytes(xBytes).fillBytes(make([]byte, len(xBytes)))
+ if !bytes.Equal(actual, xBytes) {
+ t.Errorf("%+x != %+x", actual, xBytes)
+ return false
+ }
+ return true
+ }
+
+ err := quick.Check(f, &quick.Config{})
+ if err != nil {
+ t.Error(err)
+ }
+
+ f([]byte{0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
+ f(bytes.Repeat([]byte{0xFF}, _W))
+}
+
+func TestShiftIn(t *testing.T) {
+ if bits.UintSize != 64 {
+ t.Skip("examples are only valid in 64 bit")
+ }
+ examples := []struct {
+ m, x, expected []byte
+ y uint64
+ }{{
+ m: []byte{13},
+ x: []byte{0},
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{7},
+ }, {
+ m: []byte{13},
+ x: []byte{7},
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{11},
+ }, {
+ m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
+ x: make([]byte, 9),
+ y: 0x7FFF_FFFF_FFFF_FFFF,
+ expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ }, {
+ m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
+ x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ y: 0,
+ expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
+ }}
+
+ for i, tt := range examples {
+ m := modulusFromNat(natFromBytes(tt.m))
+ got := natFromBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m)
+ if got.cmpEq(natFromBytes(tt.expected).expandFor(m)) != 1 {
+ t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
+ }
+ }
+}
+
+func TestModulusAndNatSizes(t *testing.T) {
+ // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
+ // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
+ // limbs, if they are not, they fit in three. This can be a problem because
+ // modulus strips leading zeroes and nat does not.
+ m := modulusFromNat(natFromBytes([]byte{
+ 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}))
+ x := natFromBytes([]byte{
+ 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe})
+ x.expandFor(m) // must not panic for shrinking
+}
+
+func TestExpand(t *testing.T) {
+ sliced := []uint{1, 2, 3, 4}
+ examples := []struct {
+ in []uint
+ n int
+ out []uint
+ }{{
+ []uint{1, 2},
+ 4,
+ []uint{1, 2, 0, 0},
+ }, {
+ sliced[:2],
+ 4,
+ []uint{1, 2, 0, 0},
+ }, {
+ []uint{1, 2},
+ 2,
+ []uint{1, 2},
+ }, {
+ []uint{1, 2, 0},
+ 2,
+ []uint{1, 2},
+ }}
+
+ for i, tt := range examples {
+ got := (&nat{tt.in}).expand(tt.n)
+ if len(got.limbs) != len(tt.out) || got.cmpEq(&nat{tt.out}) != 1 {
+ t.Errorf("%d: got %x, expected %x", i, got, tt.out)
+ }
+ }
+}
+
+func TestMod(t *testing.T) {
+ m := modulusFromNat(natFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}))
+ x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
+ out := new(nat)
+ out.mod(x, m)
+ expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
+ if out.cmpEq(expected) != 1 {
+ t.Errorf("%+v != %+v", out, expected)
+ }
+}
+
+func TestModSub(t *testing.T) {
+ m := modulusFromNat(&nat{[]uint{13}})
+ x := &nat{[]uint{6}}
+ y := &nat{[]uint{7}}
+ x.modSub(y, m)
+ expected := &nat{[]uint{12}}
+ if x.cmpEq(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+ x.modSub(y, m)
+ expected = &nat{[]uint{5}}
+ if x.cmpEq(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+}
+
+func TestModAdd(t *testing.T) {
+ m := modulusFromNat(&nat{[]uint{13}})
+ x := &nat{[]uint{6}}
+ y := &nat{[]uint{7}}
+ x.modAdd(y, m)
+ expected := &nat{[]uint{0}}
+ if x.cmpEq(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+ x.modAdd(y, m)
+ expected = &nat{[]uint{7}}
+ if x.cmpEq(expected) != 1 {
+ t.Errorf("%+v != %+v", x, expected)
+ }
+}
+
+func TestExp(t *testing.T) {
+ m := modulusFromNat(&nat{[]uint{13}})
+ x := &nat{[]uint{3}}
+ out := &nat{[]uint{0}}
+ out.exp(x, []byte{12}, m)
+ expected := &nat{[]uint{1}}
+ if out.cmpEq(expected) != 1 {
+ t.Errorf("%+v != %+v", out, expected)
+ }
+}
+
+func makeBenchmarkModulus() *modulus {
+ m := make([]uint, 32)
+ for i := 0; i < 32; i++ {
+ m[i] = _MASK
+ }
+ return modulusFromNat(&nat{limbs: m})
+}
+
+func makeBenchmarkValue() *nat {
+ x := make([]uint, 32)
+ for i := 0; i < 32; i++ {
+ x[i] = _MASK - 1
+ }
+ return &nat{limbs: x}
+}
+
+func makeBenchmarkExponent() []byte {
+ e := make([]byte, 256)
+ for i := 0; i < 32; i++ {
+ e[i] = 0xFF
+ }
+ return e
+}
+
+func BenchmarkModAdd(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.modAdd(y, m)
+ }
+}
+
+func BenchmarkModSub(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.modSub(y, m)
+ }
+}
+
+func BenchmarkMontgomeryRepr(b *testing.B) {
+ x := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.montgomeryRepresentation(m)
+ }
+}
+
+func BenchmarkMontgomeryMul(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ out := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.montgomeryMul(x, y, m)
+ }
+}
+
+func BenchmarkModMul(b *testing.B) {
+ x := makeBenchmarkValue()
+ y := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ x.modMul(y, m)
+ }
+}
+
+func BenchmarkExpBig(b *testing.B) {
+ out := new(big.Int)
+ exponentBytes := makeBenchmarkExponent()
+ x := new(big.Int).SetBytes(exponentBytes)
+ e := new(big.Int).SetBytes(exponentBytes)
+ n := new(big.Int).SetBytes(exponentBytes)
+ one := new(big.Int).SetUint64(1)
+ n.Add(n, one)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.Exp(x, e, n)
+ }
+}
+
+func BenchmarkExp(b *testing.B) {
+ x := makeBenchmarkValue()
+ e := makeBenchmarkExponent()
+ out := makeBenchmarkValue()
+ m := makeBenchmarkModulus()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ out.exp(x, e, m)
+ }
+}
"crypto/subtle"
"errors"
"io"
- "math/big"
)
// This file implements encryption and decryption using PKCS #1 v1.5 padding.
return boring.EncryptRSANoPadding(bkey, em)
}
- m := new(big.Int).SetBytes(em)
- c := encrypt(new(big.Int), pub, m)
- return c.FillBytes(em), nil
+ return encrypt(pub, em), nil
}
// DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS #1 v1.5.
-// If random != nil, it uses RSA blinding to avoid timing side-channel attacks.
+// The random parameter is legacy and ignored, and it can be as nil.
//
// Note that whether this function returns an error or not discloses secret
// information. If an attacker can cause this function to run repeatedly and
return out, nil
}
- valid, out, index, err := decryptPKCS1v15(random, priv, ciphertext)
+ valid, out, index, err := decryptPKCS1v15(priv, ciphertext)
if err != nil {
return nil, err
}
}
// DecryptPKCS1v15SessionKey decrypts a session key using RSA and the padding scheme from PKCS #1 v1.5.
-// If random != nil, it uses RSA blinding to avoid timing side-channel attacks.
+// The random parameter is legacy and ignored, and it can be as nil.
// It returns an error if the ciphertext is the wrong length or if the
// ciphertext is greater than the public modulus. Otherwise, no error is
// returned. If the padding is valid, the resulting plaintext message is copied
return ErrDecryption
}
- valid, em, index, err := decryptPKCS1v15(random, priv, ciphertext)
+ valid, em, index, err := decryptPKCS1v15(priv, ciphertext)
if err != nil {
return err
}
return nil
}
-// decryptPKCS1v15 decrypts ciphertext using priv and blinds the operation if
-// random is not nil. It returns one or zero in valid that indicates whether the
-// plaintext was correctly structured. In either case, the plaintext is
-// returned in em so that it may be read independently of whether it was valid
-// in order to maintain constant memory access patterns. If the plaintext was
-// valid then index contains the index of the original message in em.
-func decryptPKCS1v15(random io.Reader, priv *PrivateKey, ciphertext []byte) (valid int, em []byte, index int, err error) {
+// decryptPKCS1v15 decrypts ciphertext using priv. It returns one or zero in
+// valid that indicates whether the plaintext was correctly structured.
+// In either case, the plaintext is returned in em so that it may be read
+// independently of whether it was valid in order to maintain constant memory
+// access patterns. If the plaintext was valid then index contains the index of
+// the original message in em, to allow constant time padding removal.
+func decryptPKCS1v15(priv *PrivateKey, ciphertext []byte) (valid int, em []byte, index int, err error) {
k := priv.Size()
if k < 11 {
err = ErrDecryption
return
}
} else {
- c := new(big.Int).SetBytes(ciphertext)
- var m *big.Int
- m, err = decrypt(random, priv, c)
+ em, err = decrypt(priv, ciphertext)
if err != nil {
return
}
- em = m.FillBytes(make([]byte, k))
}
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
// function. If hash is zero, hashed is signed directly. This isn't
// advisable except for interoperability.
//
-// If random is not nil then RSA blinding will be used to avoid timing
-// side-channel attacks.
+// The random parameter is legacy and ignored, and it can be as nil.
//
// This function is deterministic. Thus, if the set of possible
// messages is small, an attacker may be able to build a map from
copy(em[k-tLen:k-hashLen], prefix)
copy(em[k-hashLen:k], hashed)
- m := new(big.Int).SetBytes(em)
- c, err := decryptAndCheck(random, priv, m)
- if err != nil {
- return nil, err
- }
-
- return c.FillBytes(em), nil
+ return decryptAndCheck(priv, em)
}
// VerifyPKCS1v15 verifies an RSA PKCS #1 v1.5 signature.
return ErrVerification
}
- c := new(big.Int).SetBytes(sig)
- m := encrypt(new(big.Int), pub, c)
- em := m.FillBytes(make([]byte, k))
+ em := encrypt(pub, sig)
// EM = 0x00 || 0x01 || PS || 0x00 || T
ok := subtle.ConstantTimeByteEq(em[0], 0)
"errors"
"hash"
"io"
- "math/big"
)
// Per RFC 8017, Section 9.1
// Note that hashed must be the result of hashing the input message using the
// given hash function. salt is a random sequence of bytes whose length will be
// later used to verify the signature.
-func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
- emBits := priv.N.BitLen() - 1
+func signPSSWithSalt(priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
+ emBits := bigBitLen(priv.N) - 1
em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
if err != nil {
return nil, err
return s, nil
}
- m := new(big.Int).SetBytes(em)
- c, err := decryptAndCheck(rand, priv, m)
- if err != nil {
- return nil, err
+ // RFC 8017: "Note that the octet length of EM will be one less than k if
+ // modBits - 1 is divisible by 8 and equal to k otherwise, where k is the
+ // length in octets of the RSA modulus n." 🙄
+ //
+ // This is extremely annoying, as all other encrypt and decrypt inputs are
+ // always the exact same size as the modulus. Since it only happens for
+ // weird modulus sizes, fix it by padding inefficiently.
+ if emLen, k := len(em), priv.Size(); emLen < k {
+ emNew := make([]byte, k)
+ copy(emNew[k-emLen:], em)
+ em = emNew
}
- s := make([]byte, priv.Size())
- return c.FillBytes(s), nil
+
+ return decryptAndCheck(priv, em)
}
const (
saltLength := opts.saltLength()
switch saltLength {
case PSSSaltLengthAuto:
- saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
+ saltLength = (bigBitLen(priv.N)-1+7)/8 - 2 - hash.Size()
if saltLength < 0 {
return nil, ErrMessageTooLong
}
if _, err := io.ReadFull(rand, salt); err != nil {
return nil, err
}
- return signPSSWithSalt(rand, priv, hash, digest, salt)
+ return signPSSWithSalt(priv, hash, digest, salt)
}
// VerifyPSS verifies a PSS signature.
if opts.saltLength() < PSSSaltLengthEqualsHash {
return invalidSaltLenErr
}
- s := new(big.Int).SetBytes(sig)
- m := encrypt(new(big.Int), pub, s)
- emBits := pub.N.BitLen() - 1
+
+ emBits := bigBitLen(pub.N) - 1
emLen := (emBits + 7) / 8
- if m.BitLen() > emLen*8 {
- return ErrVerification
+ em := encrypt(pub, sig)
+
+ // Like in signPSSWithSalt, deal with mismatches between emLen and the size
+ // of the modulus. The spec would have us wire emLen into the encoding
+ // function, but we'd rather always encode to the size of the modulus and
+ // then strip leading zeroes if necessary. This only happens for weird
+ // modulus sizes anyway.
+ for len(em) > emLen && len(em) > 0 {
+ if em[0] != 0 {
+ return ErrVerification
+ }
+ em = em[1:]
}
- em := m.FillBytes(make([]byte, emLen))
+
return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
}
}
}
-func TestSignWithPSSSaltLengthAuto(t *testing.T) {
+func TestPSS513(t *testing.T) {
+ // See Issue 42741, and separately, RFC 8017: "Note that the octet length of
+ // EM will be one less than k if modBits - 1 is divisible by 8 and equal to
+ // k otherwise, where k is the length in octets of the RSA modulus n."
key, err := GenerateKey(rand.Reader, 513)
if err != nil {
t.Fatal(err)
if err != nil {
t.Fatal(err)
}
- if len(signature) == 0 {
- t.Fatal("empty signature returned")
+ err = VerifyPSS(&key.PublicKey, crypto.SHA256, digest[:], signature, nil)
+ if err != nil {
+ t.Error(err)
}
}
// over the public key primitive, the PrivateKey type implements the
// Decrypter and Signer interfaces from the crypto package.
//
-// The RSA operations in this package are not implemented using constant-time algorithms.
+// Operations in this package are implemented using constant-time algorithms,
+// except for [GenerateKey], [PrivateKey.Precompute], and [PrivateKey.Validate].
+// Every other operation only leaks the bit size of the involved values, which
+// all depend on the selected key size.
package rsa
import (
"crypto/internal/randutil"
"crypto/rand"
"crypto/subtle"
+ "encoding/binary"
"errors"
"hash"
"io"
"math/big"
)
-var bigZero = big.NewInt(0)
var bigOne = big.NewInt(1)
// A PublicKey represents the public part of an RSA key.
// Size returns the modulus size in bytes. Raw signatures and ciphertexts
// for or by this public key will have the same size.
func (pub *PublicKey) Size() int {
- return (pub.N.BitLen() + 7) / 8
+ return (bigBitLen(pub.N) + 7) / 8
}
// Equal reports whether pub and x have the same value.
// be returned if the size of the salt is too large.
var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key size")
-func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int {
+func encrypt(pub *PublicKey, plaintext []byte) []byte {
boring.Unreachable()
- e := big.NewInt(int64(pub.E))
- c.Exp(m, e, pub.N)
- return c
+
+ N := modulusFromNat(natFromBig(pub.N))
+ m := natFromBytes(plaintext).expandFor(N)
+
+ e := make([]byte, 8)
+ binary.BigEndian.PutUint64(e, uint64(pub.E))
+ for len(e) > 1 && e[0] == 0 {
+ e = e[1:]
+ }
+
+ out := make([]byte, modulusSize(N))
+ return new(nat).exp(m, e, N).fillBytes(out)
}
// EncryptOAEP encrypts the given message with RSA-OAEP.
return boring.EncryptRSANoPadding(bkey, em)
}
- m := new(big.Int)
- m.SetBytes(em)
- c := encrypt(new(big.Int), pub, m)
-
- out := make([]byte, k)
- return c.FillBytes(out), nil
+ return encrypt(pub, em), nil
}
// ErrDecryption represents a failure to decrypt a message.
}
}
-// decrypt performs an RSA decryption, resulting in a plaintext integer. If a
-// random source is given, RSA blinding is used.
-func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) {
+// decrypt performs an RSA decryption of ciphertext into out.
+func decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
if len(priv.Primes) <= 2 {
boring.Unreachable()
}
- // TODO(agl): can we get away with reusing blinds?
- if c.Cmp(priv.N) > 0 {
- err = ErrDecryption
- return
+
+ N := modulusFromNat(natFromBig(priv.N))
+ c := natFromBytes(ciphertext).expandFor(N)
+ if c.cmpGeq(N.nat) == 1 {
+ return nil, ErrDecryption
}
if priv.N.Sign() == 0 {
return nil, ErrDecryption
}
- var ir *big.Int
- if random != nil {
- randutil.MaybeReadByte(random)
-
- // Blinding enabled. Blinding involves multiplying c by r^e.
- // Then the decryption operation performs (m^e * r^e)^d mod n
- // which equals mr mod n. The factor of r can then be removed
- // by multiplying by the multiplicative inverse of r.
-
- var r *big.Int
- ir = new(big.Int)
- for {
- r, err = rand.Int(random, priv.N)
- if err != nil {
- return
- }
- if r.Cmp(bigZero) == 0 {
- r = bigOne
- }
- ok := ir.ModInverse(r, priv.N)
- if ok != nil {
- break
- }
- }
- bigE := big.NewInt(int64(priv.E))
- rpowe := new(big.Int).Exp(r, bigE, priv.N) // N != 0
- cCopy := new(big.Int).Set(c)
- cCopy.Mul(cCopy, rpowe)
- cCopy.Mod(cCopy, priv.N)
- c = cCopy
- }
-
+ // Note that because our private decryption exponents are stored as big.Int,
+ // we potentially leak the exact number of bits of these exponents. This
+ // isn't great, but should be fine.
if priv.Precomputed.Dp == nil {
- m = new(big.Int).Exp(c, priv.D, priv.N)
- } else {
- // We have the precalculated values needed for the CRT.
- m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0])
- m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1])
- m.Sub(m, m2)
- if m.Sign() < 0 {
- m.Add(m, priv.Primes[0])
- }
- m.Mul(m, priv.Precomputed.Qinv)
- m.Mod(m, priv.Primes[0])
- m.Mul(m, priv.Primes[1])
- m.Add(m, m2)
-
- for i, values := range priv.Precomputed.CRTValues {
- prime := priv.Primes[2+i]
- m2.Exp(c, values.Exp, prime)
- m2.Sub(m2, m)
- m2.Mul(m2, values.Coeff)
- m2.Mod(m2, prime)
- if m2.Sign() < 0 {
- m2.Add(m2, prime)
- }
- m2.Mul(m2, values.R)
- m.Add(m, m2)
- }
- }
-
- if ir != nil {
- // Unblind.
- m.Mul(m, ir)
- m.Mod(m, priv.N)
- }
-
- return
+ out := make([]byte, modulusSize(N))
+ return new(nat).exp(c, priv.D.Bytes(), N).fillBytes(out), nil
+ }
+
+ t0 := new(nat)
+ P := modulusFromNat(natFromBig(priv.Primes[0]))
+ Q := modulusFromNat(natFromBig(priv.Primes[1]))
+ // m = c ^ Dp mod p
+ m := new(nat).exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P)
+ // m2 = c ^ Dq mod q
+ m2 := new(nat).exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
+ // m = m - m2 mod p
+ m.modSub(t0.mod(m2, P), P)
+ // m = m * Qinv mod p
+ m.modMul(natFromBig(priv.Precomputed.Qinv).expandFor(P), P)
+ // m = m * q mod N
+ m.expandFor(N).modMul(t0.mod(Q.nat, N), N)
+ // m = m + m2 mod N
+ m.modAdd(m2.expandFor(N), N)
+
+ for i, values := range priv.Precomputed.CRTValues {
+ p := modulusFromNat(natFromBig(priv.Primes[2+i]))
+ // m2 = c ^ Exp mod p
+ m2.exp(t0.mod(c, p), values.Exp.Bytes(), p)
+ // m2 = m2 - m mod p
+ m2.modSub(t0.mod(m, p), p)
+ // m2 = m2 * Coeff mod p
+ m2.modMul(natFromBig(values.Coeff).expandFor(p), p)
+ // m2 = m2 * R mod N
+ R := natFromBig(values.R).expandFor(N)
+ m2.expandFor(N).modMul(R, N)
+ // m = m + m2 mod N
+ m.modAdd(m2, N)
+ }
+
+ out := make([]byte, modulusSize(N))
+ return m.fillBytes(out), nil
}
-func decryptAndCheck(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) {
- m, err = decrypt(random, priv, c)
+func decryptAndCheck(priv *PrivateKey, ciphertext []byte) (m []byte, err error) {
+ m, err = decrypt(priv, ciphertext)
if err != nil {
return nil, err
}
// In order to defend against errors in the CRT computation, m^e is
// calculated, which should match the original ciphertext.
- check := encrypt(new(big.Int), &priv.PublicKey, m)
- if c.Cmp(check) != 0 {
+ check := encrypt(&priv.PublicKey, m)
+ if subtle.ConstantTimeCompare(ciphertext, check) != 1 {
return nil, errors.New("rsa: internal error")
}
return m, nil
// Encryption and decryption of a given message must use the same hash function
// and sha256.New() is a reasonable choice.
//
-// The random parameter, if not nil, is used to blind the private-key operation
-// and avoid timing side-channel attacks. Blinding is purely internal to this
-// function – the random data need not match that used when encrypting.
+// The random parameter is legacy and ignored, and it can be as nil.
//
// The label parameter must match the value given when encrypting. See
// EncryptOAEP for details.
}
return out, nil
}
- c := new(big.Int).SetBytes(ciphertext)
- m, err := decrypt(random, priv, c)
+ em, err := decrypt(priv, ciphertext)
if err != nil {
return nil, err
}
lHash := hash.Sum(nil)
hash.Reset()
- // We probably leak the number of leading zeros.
- // It's not clear that we can do anything about this.
- em := m.FillBytes(make([]byte, k))
-
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
seed := em[1 : hash.Size()+1]
dec, err := DecryptPKCS1v15(nil, priv, enc)
if err != nil {
- t.Fatalf("DecryptPKCS1v15: %v", err)
- }
- if !bytes.Equal(dec, msg) {
- t.Errorf("got:%x want:%x (%+v)", dec, msg, priv)
- }
-
- dec, err = DecryptPKCS1v15(rand.Reader, priv, enc)
- if err != nil {
- t.Fatalf("DecryptPKCS1v15: %v", err)
+ t.Errorf("DecryptPKCS1v15: %v", err)
+ return
}
if !bytes.Equal(dec, msg) {
t.Errorf("got:%x want:%x (%+v)", dec, msg, priv)