limbs []uint
}
+// preallocTarget is the size in bits of the numbers used to implement the most
+// common and most performant RSA key size. It's also enough to cover some of
+// the operations of key sizes up to 4096.
+const preallocTarget = 2048
+const preallocLimbs = (preallocTarget + _W) / _W
+
+// newNat returns a new nat with a size of zero, just like new(nat), but with
+// the preallocated capacity to hold a number of up to preallocTarget bits.
+// newNat inlines, so the allocation can live on the stack.
+func newNat() *nat {
+ limbs := make([]uint, 0, preallocLimbs)
+ return &nat{limbs}
+}
+
// expand expands x to n limbs, leaving its value unchanged.
func (x *nat) expand(n int) *nat {
for len(x.limbs) > n {
return x
}
-// clone returns a new nat, with the same value and announced length as x.
-func (x *nat) clone() *nat {
- out := &nat{make([]uint, len(x.limbs))}
- copy(out.limbs, x.limbs)
- return out
+// set assigns x = y, optionally resizing x to the appropriate size.
+func (x *nat) set(y *nat) *nat {
+ x.reset(len(y.limbs))
+ copy(x.limbs, y.limbs)
+ return x
}
-// natFromBig creates a new natural number from a big.Int.
+// set assigns x = n, optionally resizing n to the appropriate size.
//
-// The announced length of the resulting nat is based on the actual bit size of
-// the input, ignoring leading zeroes.
-func natFromBig(x *big.Int) *nat {
- xLimbs := x.Bits()
- bitSize := bigBitLen(x)
+// The announced length of x is set based on the actual bit size of the input,
+// ignoring leading zeroes.
+func (x *nat) setBig(n *big.Int) *nat {
+ bitSize := bigBitLen(n)
requiredLimbs := (bitSize + _W - 1) / _W
+ x.reset(requiredLimbs)
- out := &nat{make([]uint, requiredLimbs)}
outI := 0
shift := 0
- for i := range xLimbs {
- xi := uint(xLimbs[i])
- out.limbs[outI] |= (xi << shift) & _MASK
+ limbs := n.Bits()
+ for i := range limbs {
+ xi := uint(limbs[i])
+ x.limbs[outI] |= (xi << shift) & _MASK
outI++
if outI == requiredLimbs {
- return out
+ return x
}
- out.limbs[outI] = xi >> (_W - shift)
+ x.limbs[outI] = xi >> (_W - shift)
shift++ // this assumes bits.UintSize - _W = 1
if shift == _W {
shift = 0
outI++
}
}
- return out
+ return x
}
// fillBytes sets bytes to x as a zero-extended big-endian byte slice.
return bytes
}
-// natFromBytes converts a slice of big-endian bytes into a nat.
+// setBytes assigns x = b, where b is a slice of big-endian bytes, optionally
+// resizing n to the appropriate size.
//
-// The announced length of the output depends on the length of bytes. Unlike
+// The announced length of the output depends only on the length of b. Unlike
// big.Int, creating a nat will not remove leading zeros.
-func natFromBytes(bytes []byte) *nat {
- bitSize := len(bytes) * 8
+func (x *nat) setBytes(b []byte) *nat {
+ bitSize := len(b) * 8
requiredLimbs := (bitSize + _W - 1) / _W
+ x.reset(requiredLimbs)
- out := &nat{make([]uint, requiredLimbs)}
outI := 0
shift := 0
- for i := len(bytes) - 1; i >= 0; i-- {
- bi := bytes[i]
- out.limbs[outI] |= uint(bi) << shift
+ for i := len(b) - 1; i >= 0; i-- {
+ bi := b[i]
+ x.limbs[outI] |= uint(bi) << shift
shift += 8
if shift >= _W {
shift -= _W
- out.limbs[outI] &= _MASK
+ x.limbs[outI] &= _MASK
outI++
if shift > 0 {
- out.limbs[outI] = uint(bi) >> (8 - shift)
+ x.limbs[outI] = uint(bi) >> (8 - shift)
}
}
}
- return out
+ return x
}
// cmpEq returns 1 if x == y, and 0 otherwise.
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
func rr(m *modulus) *nat {
- rr := new(nat).expandFor(m)
+ rr := newNat().expandFor(m)
// R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
// most significant limb to 1. We then get to R*R by shifting left by _W
// n + 1 times.
//
// This assumes that x is already reduced mod m, and that y < 2^_W.
func (x *nat) shiftIn(y uint, m *modulus) *nat {
- d := new(nat).resetFor(m)
+ d := newNat().resetFor(m)
// Eliminate bounds checks in the loop.
size := len(m.nat.limbs)
func (x *nat) montgomeryRepresentation(m *modulus) *nat {
// A Montgomery multiplication (which computes a * b / R) by R * R works out
// to a multiplication by R, which takes the value out of the Montgomery domain.
- return x.montgomeryMul(x.clone(), m.RR, m)
+ return x.montgomeryMul(newNat().set(x), m.RR, m)
}
// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
// By Montgomery multiplying with 1 not in Montgomery representation, we
// convert out back from Montgomery representation, because it works out to
// dividing by R.
- t0 := x.clone()
- t1 := new(nat).expandFor(m)
+ t0 := newNat().set(x)
+ t1 := newNat().expandFor(m)
t1.limbs[0] = 1
return x.montgomeryMul(t0, t1, m)
}
func (x *nat) modMul(y *nat, m *modulus) *nat {
// A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation.
- xR := x.clone().montgomeryRepresentation(m) // xR = x * R mod m
- return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
+ xR := newNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
+ return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
}
// exp calculates out = x^e mod m.
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
// Using bit sizes that don't divide 8 are more complex to implement.
- table := make([]*nat, (1<<4)-1) // table[i] = x ^ (i+1)
- table[0] = x.clone().montgomeryRepresentation(m)
+
+ table := [(1 << 4) - 1]*nat{ // table[i] = x ^ (i+1)
+ // newNat calls are unrolled so they are allocated on the stack.
+ newNat(), newNat(), newNat(), newNat(), newNat(),
+ newNat(), newNat(), newNat(), newNat(), newNat(),
+ newNat(), newNat(), newNat(), newNat(), newNat(),
+ }
+ table[0].set(x).montgomeryRepresentation(m)
for i := 1; i < len(table); i++ {
- table[i] = new(nat).expandFor(m)
table[i].montgomeryMul(table[i-1], table[0], m)
}
out.resetFor(m)
out.limbs[0] = 1
out.montgomeryRepresentation(m)
- t0 := new(nat).expandFor(m)
- t1 := new(nat).expandFor(m)
+ t0 := newNat().expandFor(m)
+ t1 := newNat().expandFor(m)
for _, b := range e {
for _, j := range []int{4, 0} {
// Square four times.
mLimbs[i] = _MASK
}
m := modulusFromNat(&nat{mLimbs})
- aPlusB := a.clone()
+ aPlusB := new(nat).set(a)
aPlusB.modAdd(b, m)
- bPlusA := b.clone()
+ bPlusA := new(nat).set(b)
bPlusA.modAdd(a, m)
return aPlusB.cmpEq(bPlusA) == 1
}
mLimbs[i] = _MASK
}
m := modulusFromNat(&nat{mLimbs})
- original := a.clone()
+ original := new(nat).set(a)
a.modSub(b, m)
a.modAdd(b, m)
return a.cmpEq(original) == 1
func testMontgomeryRoundtrip(a *nat) bool {
one := &nat{make([]uint, len(a.limbs))}
one.limbs[0] = 1
- aPlusOne := a.clone()
+ aPlusOne := new(nat).set(a)
aPlusOne.add(1, one)
m := modulusFromNat(aPlusOne)
- monty := a.clone()
+ monty := new(nat).set(a)
monty.montgomeryRepresentation(m)
- aAgain := monty.clone()
+ aAgain := new(nat).set(monty)
aAgain.montgomeryMul(monty, one, m)
return a.cmpEq(aAgain) == 1
}
func TestFromBig(t *testing.T) {
expected := []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
theBig := new(big.Int).SetBytes(expected)
- actual := natFromBig(theBig).fillBytes(make([]byte, len(expected)))
+ actual := new(nat).setBig(theBig).fillBytes(make([]byte, len(expected)))
if !bytes.Equal(actual, expected) {
t.Errorf("%+x != %+x", actual, expected)
}
func TestFillBytes(t *testing.T) {
xBytes := []byte{0xAA, 0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
- x := natFromBytes(xBytes)
+ x := new(nat).setBytes(xBytes)
for l := 20; l >= len(xBytes); l-- {
buf := make([]byte, l)
rand.Read(buf)
if len(xBytes) == 0 {
return true
}
- actual := natFromBytes(xBytes).fillBytes(make([]byte, len(xBytes)))
+ actual := new(nat).setBytes(xBytes).fillBytes(make([]byte, len(xBytes)))
if !bytes.Equal(actual, xBytes) {
t.Errorf("%+x != %+x", actual, xBytes)
return false
}}
for i, tt := range examples {
- m := modulusFromNat(natFromBytes(tt.m))
- got := natFromBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m)
- if got.cmpEq(natFromBytes(tt.expected).expandFor(m)) != 1 {
+ m := modulusFromNat(new(nat).setBytes(tt.m))
+ got := new(nat).setBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m)
+ if got.cmpEq(new(nat).setBytes(tt.expected).expandFor(m)) != 1 {
t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
}
}
// 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
// limbs, if they are not, they fit in three. This can be a problem because
// modulus strips leading zeroes and nat does not.
- m := modulusFromNat(natFromBytes([]byte{
+ m := modulusFromNat(new(nat).setBytes([]byte{
0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}))
- x := natFromBytes([]byte{
+ x := new(nat).setBytes([]byte{
0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe})
x.expandFor(m) // must not panic for shrinking
}
func TestMod(t *testing.T) {
- m := modulusFromNat(natFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}))
- x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
+ m := modulusFromNat(new(nat).setBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}))
+ x := new(nat).setBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
out := new(nat)
out.mod(x, m)
- expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
+ expected := new(nat).setBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
if out.cmpEq(expected) != 1 {
t.Errorf("%+v != %+v", out, expected)
}
Dq: Dq,
Qinv: Qinv,
CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute
- n: modulusFromNat(natFromBig(N)),
- p: modulusFromNat(natFromBig(P)),
- q: modulusFromNat(natFromBig(Q)),
+ n: modulusFromNat(newNat().setBig(N)),
+ p: modulusFromNat(newNat().setBig(P)),
+ q: modulusFromNat(newNat().setBig(Q)),
},
}
return key, nil
func encrypt(pub *PublicKey, plaintext []byte) []byte {
boring.Unreachable()
- N := modulusFromNat(natFromBig(pub.N))
- m := natFromBytes(plaintext).expandFor(N)
+ N := modulusFromNat(newNat().setBig(pub.N))
+ m := newNat().setBytes(plaintext).expandFor(N)
e := intToBytes(pub.E)
out := make([]byte, modulusSize(N))
- return new(nat).exp(m, e, N).fillBytes(out)
+ return newNat().exp(m, e, N).fillBytes(out)
}
// intToBytes returns i as a big-endian slice of bytes with no leading zeroes,
// in the future.
func (priv *PrivateKey) Precompute() {
if priv.Precomputed.n == nil && len(priv.Primes) == 2 {
- priv.Precomputed.n = modulusFromNat(natFromBig(priv.N))
- priv.Precomputed.p = modulusFromNat(natFromBig(priv.Primes[0]))
- priv.Precomputed.q = modulusFromNat(natFromBig(priv.Primes[1]))
+ priv.Precomputed.n = modulusFromNat(newNat().setBig(priv.N))
+ priv.Precomputed.p = modulusFromNat(newNat().setBig(priv.Primes[0]))
+ priv.Precomputed.q = modulusFromNat(newNat().setBig(priv.Primes[1]))
}
// Fill in the backwards-compatibility *big.Int values.
N := priv.Precomputed.n
if N == nil {
- N = modulusFromNat(natFromBig(priv.N))
+ N = modulusFromNat(newNat().setBig(priv.N))
}
- c := natFromBytes(ciphertext).expandFor(N)
+ c := newNat().setBytes(ciphertext).expandFor(N)
if c.cmpGeq(N.nat) == 1 {
return nil, ErrDecryption
}
var m *nat
if priv.Precomputed.n == nil {
- m = new(nat).exp(c, priv.D.Bytes(), N)
+ m = newNat().exp(c, priv.D.Bytes(), N)
} else {
- t0 := new(nat)
+ t0 := newNat()
P, Q := priv.Precomputed.p, priv.Precomputed.q
// m = c ^ Dp mod p
- m = new(nat).exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P)
+ m = newNat().exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P)
// m2 = c ^ Dq mod q
- m2 := new(nat).exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
+ m2 := newNat().exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
// m = m - m2 mod p
m.modSub(t0.mod(m2, P), P)
// m = m * Qinv mod p
- m.modMul(natFromBig(priv.Precomputed.Qinv).expandFor(P), P)
+ m.modMul(newNat().setBig(priv.Precomputed.Qinv).expandFor(P), P)
// m = m * q mod N
m.expandFor(N).modMul(t0.mod(Q.nat, N), N)
// m = m + m2 mod N
}
if check {
- c1 := new(nat).exp(m, intToBytes(priv.E), N)
+ c1 := newNat().exp(m, intToBytes(priv.E), N)
if c1.cmpEq(c) != 1 {
return nil, ErrDecryption
}