From 58a2db181b7cb2d51e462b6ea9c0026bba520055 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Sun, 23 Oct 2022 21:48:29 +0200 Subject: [PATCH] crypto/rsa: allocate nats on the stack for RSA 2048 MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit With a small tweak and the help of the inliner, we preallocate enough nat backing space to do RSA-2048 on the stack. We keep the length of the preallocated slices at zero so they don't silently mask missing expandFor calls. Surprisingly enough, this doesn't move the CPU benchmark needle much, but probably reduces GC pressure on larger applications. name old time/op new time/op delta DecryptPKCS1v15/2048-8 1.25ms ± 0% 1.22ms ± 1% -1.68% (p=0.000 n=10+9) DecryptPKCS1v15/3072-8 3.78ms ± 0% 3.73ms ± 1% -1.33% (p=0.000 n=9+10) DecryptPKCS1v15/4096-8 8.62ms ± 0% 8.45ms ± 1% -1.98% (p=0.000 n=8+10) EncryptPKCS1v15/2048-8 140µs ± 1% 136µs ± 0% -2.43% (p=0.000 n=9+9) DecryptOAEP/2048-8 1.25ms ± 0% 1.24ms ± 0% -0.83% (p=0.000 n=8+10) EncryptOAEP/2048-8 140µs ± 0% 137µs ± 0% -1.82% (p=0.000 n=8+10) SignPKCS1v15/2048-8 1.29ms ± 0% 1.29ms ± 1% ~ (p=0.574 n=8+8) VerifyPKCS1v15/2048-8 139µs ± 0% 136µs ± 0% -2.12% (p=0.000 n=9+10) SignPSS/2048-8 1.30ms ± 0% 1.28ms ± 0% -0.96% (p=0.000 n=8+10) VerifyPSS/2048-8 140µs ± 0% 137µs ± 0% -1.99% (p=0.000 n=10+8) name old alloc/op new alloc/op delta DecryptPKCS1v15/2048-8 15.0kB ± 0% 0.5kB ± 0% -96.58% (p=0.000 n=10+10) DecryptPKCS1v15/3072-8 24.6kB ± 0% 3.3kB ± 0% -86.74% (p=0.000 n=10+10) DecryptPKCS1v15/4096-8 38.9kB ± 0% 4.5kB ± 0% -88.50% (p=0.000 n=10+10) EncryptPKCS1v15/2048-8 18.0kB ± 0% 1.2kB ± 0% -93.48% (p=0.000 n=10+10) DecryptOAEP/2048-8 15.2kB ± 0% 0.7kB ± 0% -95.10% (p=0.000 n=10+10) EncryptOAEP/2048-8 18.2kB ± 0% 1.4kB ± 0% -92.29% (p=0.000 n=10+10) SignPKCS1v15/2048-8 21.9kB ± 0% 0.8kB ± 0% -96.50% (p=0.000 n=10+10) VerifyPKCS1v15/2048-8 17.7kB ± 0% 0.9kB ± 0% -94.85% (p=0.000 n=10+10) SignPSS/2048-8 22.3kB ± 0% 1.2kB ± 0% -94.77% (p=0.000 n=10+10) VerifyPSS/2048-8 17.9kB ± 0% 1.1kB ± 0% -93.75% (p=0.000 n=10+10) name old allocs/op new allocs/op delta DecryptPKCS1v15/2048-8 124 ± 0% 3 ± 0% -97.58% (p=0.000 n=10+10) DecryptPKCS1v15/3072-8 140 ± 0% 9 ± 0% -93.57% (p=0.000 n=10+10) DecryptPKCS1v15/4096-8 158 ± 0% 9 ± 0% -94.30% (p=0.000 n=10+10) EncryptPKCS1v15/2048-8 80.0 ± 0% 7.0 ± 0% -91.25% (p=0.000 n=10+10) DecryptOAEP/2048-8 130 ± 0% 9 ± 0% -93.08% (p=0.000 n=10+10) EncryptOAEP/2048-8 86.0 ± 0% 13.0 ± 0% -84.88% (p=0.000 n=10+10) SignPKCS1v15/2048-8 162 ± 0% 4 ± 0% -97.53% (p=0.000 n=10+10) VerifyPKCS1v15/2048-8 79.0 ± 0% 6.0 ± 0% -92.41% (p=0.000 n=10+10) SignPSS/2048-8 167 ± 0% 9 ± 0% -94.61% (p=0.000 n=10+10) VerifyPSS/2048-8 84.0 ± 0% 11.0 ± 0% -86.90% (p=0.000 n=10+10) Change-Id: I511a2f5f6f596bbec68a0a411e83a9d04080d72a Reviewed-on: https://go-review.googlesource.com/c/go/+/445021 Run-TryBot: Filippo Valsorda Reviewed-by: Joedian Reid Reviewed-by: Roland Shoemaker TryBot-Result: Gopher Robot --- src/crypto/rsa/nat.go | 102 ++++++++++++++++++++++--------------- src/crypto/rsa/nat_test.go | 34 ++++++------- src/crypto/rsa/rsa.go | 34 ++++++------- 3 files changed, 95 insertions(+), 75 deletions(-) diff --git a/src/crypto/rsa/nat.go b/src/crypto/rsa/nat.go index 61b4ffec48..5398d10606 100644 --- a/src/crypto/rsa/nat.go +++ b/src/crypto/rsa/nat.go @@ -69,6 +69,20 @@ type nat struct { limbs []uint } +// preallocTarget is the size in bits of the numbers used to implement the most +// common and most performant RSA key size. It's also enough to cover some of +// the operations of key sizes up to 4096. +const preallocTarget = 2048 +const preallocLimbs = (preallocTarget + _W) / _W + +// newNat returns a new nat with a size of zero, just like new(nat), but with +// the preallocated capacity to hold a number of up to preallocTarget bits. +// newNat inlines, so the allocation can live on the stack. +func newNat() *nat { + limbs := make([]uint, 0, preallocLimbs) + return &nat{limbs} +} + // expand expands x to n limbs, leaving its value unchanged. func (x *nat) expand(n int) *nat { for len(x.limbs) > n { @@ -104,40 +118,40 @@ func (x *nat) reset(n int) *nat { return x } -// clone returns a new nat, with the same value and announced length as x. -func (x *nat) clone() *nat { - out := &nat{make([]uint, len(x.limbs))} - copy(out.limbs, x.limbs) - return out +// set assigns x = y, optionally resizing x to the appropriate size. +func (x *nat) set(y *nat) *nat { + x.reset(len(y.limbs)) + copy(x.limbs, y.limbs) + return x } -// natFromBig creates a new natural number from a big.Int. +// set assigns x = n, optionally resizing n to the appropriate size. // -// The announced length of the resulting nat is based on the actual bit size of -// the input, ignoring leading zeroes. -func natFromBig(x *big.Int) *nat { - xLimbs := x.Bits() - bitSize := bigBitLen(x) +// The announced length of x is set based on the actual bit size of the input, +// ignoring leading zeroes. +func (x *nat) setBig(n *big.Int) *nat { + bitSize := bigBitLen(n) requiredLimbs := (bitSize + _W - 1) / _W + x.reset(requiredLimbs) - out := &nat{make([]uint, requiredLimbs)} outI := 0 shift := 0 - for i := range xLimbs { - xi := uint(xLimbs[i]) - out.limbs[outI] |= (xi << shift) & _MASK + limbs := n.Bits() + for i := range limbs { + xi := uint(limbs[i]) + x.limbs[outI] |= (xi << shift) & _MASK outI++ if outI == requiredLimbs { - return out + return x } - out.limbs[outI] = xi >> (_W - shift) + x.limbs[outI] = xi >> (_W - shift) shift++ // this assumes bits.UintSize - _W = 1 if shift == _W { shift = 0 outI++ } } - return out + return x } // fillBytes sets bytes to x as a zero-extended big-endian byte slice. @@ -175,31 +189,32 @@ func (x *nat) fillBytes(bytes []byte) []byte { return bytes } -// natFromBytes converts a slice of big-endian bytes into a nat. +// setBytes assigns x = b, where b is a slice of big-endian bytes, optionally +// resizing n to the appropriate size. // -// The announced length of the output depends on the length of bytes. Unlike +// The announced length of the output depends only on the length of b. Unlike // big.Int, creating a nat will not remove leading zeros. -func natFromBytes(bytes []byte) *nat { - bitSize := len(bytes) * 8 +func (x *nat) setBytes(b []byte) *nat { + bitSize := len(b) * 8 requiredLimbs := (bitSize + _W - 1) / _W + x.reset(requiredLimbs) - out := &nat{make([]uint, requiredLimbs)} outI := 0 shift := 0 - for i := len(bytes) - 1; i >= 0; i-- { - bi := bytes[i] - out.limbs[outI] |= uint(bi) << shift + for i := len(b) - 1; i >= 0; i-- { + bi := b[i] + x.limbs[outI] |= uint(bi) << shift shift += 8 if shift >= _W { shift -= _W - out.limbs[outI] &= _MASK + x.limbs[outI] &= _MASK outI++ if shift > 0 { - out.limbs[outI] = uint(bi) >> (8 - shift) + x.limbs[outI] = uint(bi) >> (8 - shift) } } } - return out + return x } // cmpEq returns 1 if x == y, and 0 otherwise. @@ -306,7 +321,7 @@ type modulus struct { // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs). func rr(m *modulus) *nat { - rr := new(nat).expandFor(m) + rr := newNat().expandFor(m) // R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the // most significant limb to 1. We then get to R*R by shifting left by _W // n + 1 times. @@ -387,7 +402,7 @@ func modulusSize(m *modulus) int { // // This assumes that x is already reduced mod m, and that y < 2^_W. func (x *nat) shiftIn(y uint, m *modulus) *nat { - d := new(nat).resetFor(m) + d := newNat().resetFor(m) // Eliminate bounds checks in the loop. size := len(m.nat.limbs) @@ -528,7 +543,7 @@ func (x *nat) modAdd(y *nat, m *modulus) *nat { func (x *nat) montgomeryRepresentation(m *modulus) *nat { // A Montgomery multiplication (which computes a * b / R) by R * R works out // to a multiplication by R, which takes the value out of the Montgomery domain. - return x.montgomeryMul(x.clone(), m.RR, m) + return x.montgomeryMul(newNat().set(x), m.RR, m) } // montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and @@ -539,8 +554,8 @@ func (x *nat) montgomeryReduction(m *modulus) *nat { // By Montgomery multiplying with 1 not in Montgomery representation, we // convert out back from Montgomery representation, because it works out to // dividing by R. - t0 := x.clone() - t1 := new(nat).expandFor(m) + t0 := newNat().set(x) + t1 := newNat().expandFor(m) t1.limbs[0] = 1 return x.montgomeryMul(t0, t1, m) } @@ -599,8 +614,8 @@ func (d *nat) montgomeryMul(a *nat, b *nat, m *modulus) *nat { func (x *nat) modMul(y *nat, m *modulus) *nat { // A Montgomery multiplication by a value out of the Montgomery domain // takes the result out of Montgomery representation. - xR := x.clone().montgomeryRepresentation(m) // xR = x * R mod m - return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m + xR := newNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m + return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m } // exp calculates out = x^e mod m. @@ -611,18 +626,23 @@ func (out *nat) exp(x *nat, e []byte, m *modulus) *nat { // We use a 4 bit window. For our RSA workload, 4 bit windows are faster // than 2 bit windows, but use an extra 12 nats worth of scratch space. // Using bit sizes that don't divide 8 are more complex to implement. - table := make([]*nat, (1<<4)-1) // table[i] = x ^ (i+1) - table[0] = x.clone().montgomeryRepresentation(m) + + table := [(1 << 4) - 1]*nat{ // table[i] = x ^ (i+1) + // newNat calls are unrolled so they are allocated on the stack. + newNat(), newNat(), newNat(), newNat(), newNat(), + newNat(), newNat(), newNat(), newNat(), newNat(), + newNat(), newNat(), newNat(), newNat(), newNat(), + } + table[0].set(x).montgomeryRepresentation(m) for i := 1; i < len(table); i++ { - table[i] = new(nat).expandFor(m) table[i].montgomeryMul(table[i-1], table[0], m) } out.resetFor(m) out.limbs[0] = 1 out.montgomeryRepresentation(m) - t0 := new(nat).expandFor(m) - t1 := new(nat).expandFor(m) + t0 := newNat().expandFor(m) + t1 := newNat().expandFor(m) for _, b := range e { for _, j := range []int{4, 0} { // Square four times. diff --git a/src/crypto/rsa/nat_test.go b/src/crypto/rsa/nat_test.go index 3e6eb10f61..d72ba119e3 100644 --- a/src/crypto/rsa/nat_test.go +++ b/src/crypto/rsa/nat_test.go @@ -30,9 +30,9 @@ func testModAddCommutative(a *nat, b *nat) bool { mLimbs[i] = _MASK } m := modulusFromNat(&nat{mLimbs}) - aPlusB := a.clone() + aPlusB := new(nat).set(a) aPlusB.modAdd(b, m) - bPlusA := b.clone() + bPlusA := new(nat).set(b) bPlusA.modAdd(a, m) return aPlusB.cmpEq(bPlusA) == 1 } @@ -50,7 +50,7 @@ func testModSubThenAddIdentity(a *nat, b *nat) bool { mLimbs[i] = _MASK } m := modulusFromNat(&nat{mLimbs}) - original := a.clone() + original := new(nat).set(a) a.modSub(b, m) a.modAdd(b, m) return a.cmpEq(original) == 1 @@ -66,12 +66,12 @@ func TestModSubThenAddIdentity(t *testing.T) { func testMontgomeryRoundtrip(a *nat) bool { one := &nat{make([]uint, len(a.limbs))} one.limbs[0] = 1 - aPlusOne := a.clone() + aPlusOne := new(nat).set(a) aPlusOne.add(1, one) m := modulusFromNat(aPlusOne) - monty := a.clone() + monty := new(nat).set(a) monty.montgomeryRepresentation(m) - aAgain := monty.clone() + aAgain := new(nat).set(monty) aAgain.montgomeryMul(monty, one, m) return a.cmpEq(aAgain) == 1 } @@ -86,7 +86,7 @@ func TestMontgomeryRoundtrip(t *testing.T) { func TestFromBig(t *testing.T) { expected := []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} theBig := new(big.Int).SetBytes(expected) - actual := natFromBig(theBig).fillBytes(make([]byte, len(expected))) + actual := new(nat).setBig(theBig).fillBytes(make([]byte, len(expected))) if !bytes.Equal(actual, expected) { t.Errorf("%+x != %+x", actual, expected) } @@ -94,7 +94,7 @@ func TestFromBig(t *testing.T) { func TestFillBytes(t *testing.T) { xBytes := []byte{0xAA, 0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} - x := natFromBytes(xBytes) + x := new(nat).setBytes(xBytes) for l := 20; l >= len(xBytes); l-- { buf := make([]byte, l) rand.Read(buf) @@ -122,7 +122,7 @@ func TestFromBytes(t *testing.T) { if len(xBytes) == 0 { return true } - actual := natFromBytes(xBytes).fillBytes(make([]byte, len(xBytes))) + actual := new(nat).setBytes(xBytes).fillBytes(make([]byte, len(xBytes))) if !bytes.Equal(actual, xBytes) { t.Errorf("%+x != %+x", actual, xBytes) return false @@ -169,9 +169,9 @@ func TestShiftIn(t *testing.T) { }} for i, tt := range examples { - m := modulusFromNat(natFromBytes(tt.m)) - got := natFromBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m) - if got.cmpEq(natFromBytes(tt.expected).expandFor(m)) != 1 { + m := modulusFromNat(new(nat).setBytes(tt.m)) + got := new(nat).setBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m) + if got.cmpEq(new(nat).setBytes(tt.expected).expandFor(m)) != 1 { t.Errorf("%d: got %x, expected %x", i, got, tt.expected) } } @@ -182,10 +182,10 @@ func TestModulusAndNatSizes(t *testing.T) { // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two // limbs, if they are not, they fit in three. This can be a problem because // modulus strips leading zeroes and nat does not. - m := modulusFromNat(natFromBytes([]byte{ + m := modulusFromNat(new(nat).setBytes([]byte{ 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})) - x := natFromBytes([]byte{ + x := new(nat).setBytes([]byte{ 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}) x.expandFor(m) // must not panic for shrinking @@ -224,11 +224,11 @@ func TestExpand(t *testing.T) { } func TestMod(t *testing.T) { - m := modulusFromNat(natFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})) - x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) + m := modulusFromNat(new(nat).setBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})) + x := new(nat).setBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) out := new(nat) out.mod(x, m) - expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09}) + expected := new(nat).setBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09}) if out.cmpEq(expected) != 1 { t.Errorf("%+v != %+v", out, expected) } diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go index c71d35ab5e..55b2331527 100644 --- a/src/crypto/rsa/rsa.go +++ b/src/crypto/rsa/rsa.go @@ -314,9 +314,9 @@ func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey Dq: Dq, Qinv: Qinv, CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute - n: modulusFromNat(natFromBig(N)), - p: modulusFromNat(natFromBig(P)), - q: modulusFromNat(natFromBig(Q)), + n: modulusFromNat(newNat().setBig(N)), + p: modulusFromNat(newNat().setBig(P)), + q: modulusFromNat(newNat().setBig(Q)), }, } return key, nil @@ -454,12 +454,12 @@ var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key siz func encrypt(pub *PublicKey, plaintext []byte) []byte { boring.Unreachable() - N := modulusFromNat(natFromBig(pub.N)) - m := natFromBytes(plaintext).expandFor(N) + N := modulusFromNat(newNat().setBig(pub.N)) + m := newNat().setBytes(plaintext).expandFor(N) e := intToBytes(pub.E) out := make([]byte, modulusSize(N)) - return new(nat).exp(m, e, N).fillBytes(out) + return newNat().exp(m, e, N).fillBytes(out) } // intToBytes returns i as a big-endian slice of bytes with no leading zeroes, @@ -553,9 +553,9 @@ var ErrVerification = errors.New("crypto/rsa: verification error") // in the future. func (priv *PrivateKey) Precompute() { if priv.Precomputed.n == nil && len(priv.Primes) == 2 { - priv.Precomputed.n = modulusFromNat(natFromBig(priv.N)) - priv.Precomputed.p = modulusFromNat(natFromBig(priv.Primes[0])) - priv.Precomputed.q = modulusFromNat(natFromBig(priv.Primes[1])) + priv.Precomputed.n = modulusFromNat(newNat().setBig(priv.N)) + priv.Precomputed.p = modulusFromNat(newNat().setBig(priv.Primes[0])) + priv.Precomputed.q = modulusFromNat(newNat().setBig(priv.Primes[1])) } // Fill in the backwards-compatibility *big.Int values. @@ -600,9 +600,9 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) { N := priv.Precomputed.n if N == nil { - N = modulusFromNat(natFromBig(priv.N)) + N = modulusFromNat(newNat().setBig(priv.N)) } - c := natFromBytes(ciphertext).expandFor(N) + c := newNat().setBytes(ciphertext).expandFor(N) if c.cmpGeq(N.nat) == 1 { return nil, ErrDecryption } @@ -612,18 +612,18 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) { var m *nat if priv.Precomputed.n == nil { - m = new(nat).exp(c, priv.D.Bytes(), N) + m = newNat().exp(c, priv.D.Bytes(), N) } else { - t0 := new(nat) + t0 := newNat() P, Q := priv.Precomputed.p, priv.Precomputed.q // m = c ^ Dp mod p - m = new(nat).exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P) + m = newNat().exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P) // m2 = c ^ Dq mod q - m2 := new(nat).exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q) + m2 := newNat().exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q) // m = m - m2 mod p m.modSub(t0.mod(m2, P), P) // m = m * Qinv mod p - m.modMul(natFromBig(priv.Precomputed.Qinv).expandFor(P), P) + m.modMul(newNat().setBig(priv.Precomputed.Qinv).expandFor(P), P) // m = m * q mod N m.expandFor(N).modMul(t0.mod(Q.nat, N), N) // m = m + m2 mod N @@ -631,7 +631,7 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) { } if check { - c1 := new(nat).exp(m, intToBytes(priv.E), N) + c1 := newNat().exp(m, intToBytes(priv.E), N) if c1.cmpEq(c) != 1 { return nil, ErrDecryption } -- 2.50.0