]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/rsa: allocate nats on the stack for RSA 2048
authorFilippo Valsorda <filippo@golang.org>
Sun, 23 Oct 2022 19:48:29 +0000 (21:48 +0200)
committerFilippo Valsorda <filippo@golang.org>
Sat, 19 Nov 2022 16:50:07 +0000 (16:50 +0000)
With a small tweak and the help of the inliner, we preallocate enough
nat backing space to do RSA-2048 on the stack.

We keep the length of the preallocated slices at zero so they don't
silently mask missing expandFor calls.

Surprisingly enough, this doesn't move the CPU benchmark needle much,
but probably reduces GC pressure on larger applications.

name                    old time/op    new time/op    delta
DecryptPKCS1v15/2048-8    1.25ms ± 0%    1.22ms ± 1%   -1.68%  (p=0.000 n=10+9)
DecryptPKCS1v15/3072-8    3.78ms ± 0%    3.73ms ± 1%   -1.33%  (p=0.000 n=9+10)
DecryptPKCS1v15/4096-8    8.62ms ± 0%    8.45ms ± 1%   -1.98%  (p=0.000 n=8+10)
EncryptPKCS1v15/2048-8     140µs ± 1%     136µs ± 0%   -2.43%  (p=0.000 n=9+9)
DecryptOAEP/2048-8        1.25ms ± 0%    1.24ms ± 0%   -0.83%  (p=0.000 n=8+10)
EncryptOAEP/2048-8         140µs ± 0%     137µs ± 0%   -1.82%  (p=0.000 n=8+10)
SignPKCS1v15/2048-8       1.29ms ± 0%    1.29ms ± 1%     ~     (p=0.574 n=8+8)
VerifyPKCS1v15/2048-8      139µs ± 0%     136µs ± 0%   -2.12%  (p=0.000 n=9+10)
SignPSS/2048-8            1.30ms ± 0%    1.28ms ± 0%   -0.96%  (p=0.000 n=8+10)
VerifyPSS/2048-8           140µs ± 0%     137µs ± 0%   -1.99%  (p=0.000 n=10+8)

name                    old alloc/op   new alloc/op   delta
DecryptPKCS1v15/2048-8    15.0kB ± 0%     0.5kB ± 0%  -96.58%  (p=0.000 n=10+10)
DecryptPKCS1v15/3072-8    24.6kB ± 0%     3.3kB ± 0%  -86.74%  (p=0.000 n=10+10)
DecryptPKCS1v15/4096-8    38.9kB ± 0%     4.5kB ± 0%  -88.50%  (p=0.000 n=10+10)
EncryptPKCS1v15/2048-8    18.0kB ± 0%     1.2kB ± 0%  -93.48%  (p=0.000 n=10+10)
DecryptOAEP/2048-8        15.2kB ± 0%     0.7kB ± 0%  -95.10%  (p=0.000 n=10+10)
EncryptOAEP/2048-8        18.2kB ± 0%     1.4kB ± 0%  -92.29%  (p=0.000 n=10+10)
SignPKCS1v15/2048-8       21.9kB ± 0%     0.8kB ± 0%  -96.50%  (p=0.000 n=10+10)
VerifyPKCS1v15/2048-8     17.7kB ± 0%     0.9kB ± 0%  -94.85%  (p=0.000 n=10+10)
SignPSS/2048-8            22.3kB ± 0%     1.2kB ± 0%  -94.77%  (p=0.000 n=10+10)
VerifyPSS/2048-8          17.9kB ± 0%     1.1kB ± 0%  -93.75%  (p=0.000 n=10+10)

name                    old allocs/op  new allocs/op  delta
DecryptPKCS1v15/2048-8       124 ± 0%         3 ± 0%  -97.58%  (p=0.000 n=10+10)
DecryptPKCS1v15/3072-8       140 ± 0%         9 ± 0%  -93.57%  (p=0.000 n=10+10)
DecryptPKCS1v15/4096-8       158 ± 0%         9 ± 0%  -94.30%  (p=0.000 n=10+10)
EncryptPKCS1v15/2048-8      80.0 ± 0%       7.0 ± 0%  -91.25%  (p=0.000 n=10+10)
DecryptOAEP/2048-8           130 ± 0%         9 ± 0%  -93.08%  (p=0.000 n=10+10)
EncryptOAEP/2048-8          86.0 ± 0%      13.0 ± 0%  -84.88%  (p=0.000 n=10+10)
SignPKCS1v15/2048-8          162 ± 0%         4 ± 0%  -97.53%  (p=0.000 n=10+10)
VerifyPKCS1v15/2048-8       79.0 ± 0%       6.0 ± 0%  -92.41%  (p=0.000 n=10+10)
SignPSS/2048-8               167 ± 0%         9 ± 0%  -94.61%  (p=0.000 n=10+10)
VerifyPSS/2048-8            84.0 ± 0%      11.0 ± 0%  -86.90%  (p=0.000 n=10+10)

Change-Id: I511a2f5f6f596bbec68a0a411e83a9d04080d72a
Reviewed-on: https://go-review.googlesource.com/c/go/+/445021
Run-TryBot: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Joedian Reid <joedian@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/crypto/rsa/nat.go
src/crypto/rsa/nat_test.go
src/crypto/rsa/rsa.go

index 61b4ffec488096c4aaee515d71226597e0101a9b..5398d10606760102fab38a06762cf1dbef6da7a5 100644 (file)
@@ -69,6 +69,20 @@ type nat struct {
        limbs []uint
 }
 
+// preallocTarget is the size in bits of the numbers used to implement the most
+// common and most performant RSA key size. It's also enough to cover some of
+// the operations of key sizes up to 4096.
+const preallocTarget = 2048
+const preallocLimbs = (preallocTarget + _W) / _W
+
+// newNat returns a new nat with a size of zero, just like new(nat), but with
+// the preallocated capacity to hold a number of up to preallocTarget bits.
+// newNat inlines, so the allocation can live on the stack.
+func newNat() *nat {
+       limbs := make([]uint, 0, preallocLimbs)
+       return &nat{limbs}
+}
+
 // expand expands x to n limbs, leaving its value unchanged.
 func (x *nat) expand(n int) *nat {
        for len(x.limbs) > n {
@@ -104,40 +118,40 @@ func (x *nat) reset(n int) *nat {
        return x
 }
 
-// clone returns a new nat, with the same value and announced length as x.
-func (x *nat) clone() *nat {
-       out := &nat{make([]uint, len(x.limbs))}
-       copy(out.limbs, x.limbs)
-       return out
+// set assigns x = y, optionally resizing x to the appropriate size.
+func (x *nat) set(y *nat) *nat {
+       x.reset(len(y.limbs))
+       copy(x.limbs, y.limbs)
+       return x
 }
 
-// natFromBig creates a new natural number from a big.Int.
+// set assigns x = n, optionally resizing n to the appropriate size.
 //
-// The announced length of the resulting nat is based on the actual bit size of
-// the input, ignoring leading zeroes.
-func natFromBig(x *big.Int) *nat {
-       xLimbs := x.Bits()
-       bitSize := bigBitLen(x)
+// The announced length of x is set based on the actual bit size of the input,
+// ignoring leading zeroes.
+func (x *nat) setBig(n *big.Int) *nat {
+       bitSize := bigBitLen(n)
        requiredLimbs := (bitSize + _W - 1) / _W
+       x.reset(requiredLimbs)
 
-       out := &nat{make([]uint, requiredLimbs)}
        outI := 0
        shift := 0
-       for i := range xLimbs {
-               xi := uint(xLimbs[i])
-               out.limbs[outI] |= (xi << shift) & _MASK
+       limbs := n.Bits()
+       for i := range limbs {
+               xi := uint(limbs[i])
+               x.limbs[outI] |= (xi << shift) & _MASK
                outI++
                if outI == requiredLimbs {
-                       return out
+                       return x
                }
-               out.limbs[outI] = xi >> (_W - shift)
+               x.limbs[outI] = xi >> (_W - shift)
                shift++ // this assumes bits.UintSize - _W = 1
                if shift == _W {
                        shift = 0
                        outI++
                }
        }
-       return out
+       return x
 }
 
 // fillBytes sets bytes to x as a zero-extended big-endian byte slice.
@@ -175,31 +189,32 @@ func (x *nat) fillBytes(bytes []byte) []byte {
        return bytes
 }
 
-// natFromBytes converts a slice of big-endian bytes into a nat.
+// setBytes assigns x = b, where b is a slice of big-endian bytes, optionally
+// resizing n to the appropriate size.
 //
-// The announced length of the output depends on the length of bytes. Unlike
+// The announced length of the output depends only on the length of b. Unlike
 // big.Int, creating a nat will not remove leading zeros.
-func natFromBytes(bytes []byte) *nat {
-       bitSize := len(bytes) * 8
+func (x *nat) setBytes(b []byte) *nat {
+       bitSize := len(b) * 8
        requiredLimbs := (bitSize + _W - 1) / _W
+       x.reset(requiredLimbs)
 
-       out := &nat{make([]uint, requiredLimbs)}
        outI := 0
        shift := 0
-       for i := len(bytes) - 1; i >= 0; i-- {
-               bi := bytes[i]
-               out.limbs[outI] |= uint(bi) << shift
+       for i := len(b) - 1; i >= 0; i-- {
+               bi := b[i]
+               x.limbs[outI] |= uint(bi) << shift
                shift += 8
                if shift >= _W {
                        shift -= _W
-                       out.limbs[outI] &= _MASK
+                       x.limbs[outI] &= _MASK
                        outI++
                        if shift > 0 {
-                               out.limbs[outI] = uint(bi) >> (8 - shift)
+                               x.limbs[outI] = uint(bi) >> (8 - shift)
                        }
                }
        }
-       return out
+       return x
 }
 
 // cmpEq returns 1 if x == y, and 0 otherwise.
@@ -306,7 +321,7 @@ type modulus struct {
 
 // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
 func rr(m *modulus) *nat {
-       rr := new(nat).expandFor(m)
+       rr := newNat().expandFor(m)
        // R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
        // most significant limb to 1. We then get to R*R by shifting left by _W
        // n + 1 times.
@@ -387,7 +402,7 @@ func modulusSize(m *modulus) int {
 //
 // This assumes that x is already reduced mod m, and that y < 2^_W.
 func (x *nat) shiftIn(y uint, m *modulus) *nat {
-       d := new(nat).resetFor(m)
+       d := newNat().resetFor(m)
 
        // Eliminate bounds checks in the loop.
        size := len(m.nat.limbs)
@@ -528,7 +543,7 @@ func (x *nat) modAdd(y *nat, m *modulus) *nat {
 func (x *nat) montgomeryRepresentation(m *modulus) *nat {
        // A Montgomery multiplication (which computes a * b / R) by R * R works out
        // to a multiplication by R, which takes the value out of the Montgomery domain.
-       return x.montgomeryMul(x.clone(), m.RR, m)
+       return x.montgomeryMul(newNat().set(x), m.RR, m)
 }
 
 // montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
@@ -539,8 +554,8 @@ func (x *nat) montgomeryReduction(m *modulus) *nat {
        // By Montgomery multiplying with 1 not in Montgomery representation, we
        // convert out back from Montgomery representation, because it works out to
        // dividing by R.
-       t0 := x.clone()
-       t1 := new(nat).expandFor(m)
+       t0 := newNat().set(x)
+       t1 := newNat().expandFor(m)
        t1.limbs[0] = 1
        return x.montgomeryMul(t0, t1, m)
 }
@@ -599,8 +614,8 @@ func (d *nat) montgomeryMul(a *nat, b *nat, m *modulus) *nat {
 func (x *nat) modMul(y *nat, m *modulus) *nat {
        // A Montgomery multiplication by a value out of the Montgomery domain
        // takes the result out of Montgomery representation.
-       xR := x.clone().montgomeryRepresentation(m) // xR = x * R mod m
-       return x.montgomeryMul(xR, y, m)            // x = xR * y / R mod m
+       xR := newNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
+       return x.montgomeryMul(xR, y, m)                  // x = xR * y / R mod m
 }
 
 // exp calculates out = x^e mod m.
@@ -611,18 +626,23 @@ func (out *nat) exp(x *nat, e []byte, m *modulus) *nat {
        // We use a 4 bit window. For our RSA workload, 4 bit windows are faster
        // than 2 bit windows, but use an extra 12 nats worth of scratch space.
        // Using bit sizes that don't divide 8 are more complex to implement.
-       table := make([]*nat, (1<<4)-1) // table[i] = x ^ (i+1)
-       table[0] = x.clone().montgomeryRepresentation(m)
+
+       table := [(1 << 4) - 1]*nat{ // table[i] = x ^ (i+1)
+               // newNat calls are unrolled so they are allocated on the stack.
+               newNat(), newNat(), newNat(), newNat(), newNat(),
+               newNat(), newNat(), newNat(), newNat(), newNat(),
+               newNat(), newNat(), newNat(), newNat(), newNat(),
+       }
+       table[0].set(x).montgomeryRepresentation(m)
        for i := 1; i < len(table); i++ {
-               table[i] = new(nat).expandFor(m)
                table[i].montgomeryMul(table[i-1], table[0], m)
        }
 
        out.resetFor(m)
        out.limbs[0] = 1
        out.montgomeryRepresentation(m)
-       t0 := new(nat).expandFor(m)
-       t1 := new(nat).expandFor(m)
+       t0 := newNat().expandFor(m)
+       t1 := newNat().expandFor(m)
        for _, b := range e {
                for _, j := range []int{4, 0} {
                        // Square four times.
index 3e6eb10f61a7fa1844323806cab9175bd8c70a65..d72ba119e3c0f626498927f6b6d533e88d569d0a 100644 (file)
@@ -30,9 +30,9 @@ func testModAddCommutative(a *nat, b *nat) bool {
                mLimbs[i] = _MASK
        }
        m := modulusFromNat(&nat{mLimbs})
-       aPlusB := a.clone()
+       aPlusB := new(nat).set(a)
        aPlusB.modAdd(b, m)
-       bPlusA := b.clone()
+       bPlusA := new(nat).set(b)
        bPlusA.modAdd(a, m)
        return aPlusB.cmpEq(bPlusA) == 1
 }
@@ -50,7 +50,7 @@ func testModSubThenAddIdentity(a *nat, b *nat) bool {
                mLimbs[i] = _MASK
        }
        m := modulusFromNat(&nat{mLimbs})
-       original := a.clone()
+       original := new(nat).set(a)
        a.modSub(b, m)
        a.modAdd(b, m)
        return a.cmpEq(original) == 1
@@ -66,12 +66,12 @@ func TestModSubThenAddIdentity(t *testing.T) {
 func testMontgomeryRoundtrip(a *nat) bool {
        one := &nat{make([]uint, len(a.limbs))}
        one.limbs[0] = 1
-       aPlusOne := a.clone()
+       aPlusOne := new(nat).set(a)
        aPlusOne.add(1, one)
        m := modulusFromNat(aPlusOne)
-       monty := a.clone()
+       monty := new(nat).set(a)
        monty.montgomeryRepresentation(m)
-       aAgain := monty.clone()
+       aAgain := new(nat).set(monty)
        aAgain.montgomeryMul(monty, one, m)
        return a.cmpEq(aAgain) == 1
 }
@@ -86,7 +86,7 @@ func TestMontgomeryRoundtrip(t *testing.T) {
 func TestFromBig(t *testing.T) {
        expected := []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
        theBig := new(big.Int).SetBytes(expected)
-       actual := natFromBig(theBig).fillBytes(make([]byte, len(expected)))
+       actual := new(nat).setBig(theBig).fillBytes(make([]byte, len(expected)))
        if !bytes.Equal(actual, expected) {
                t.Errorf("%+x != %+x", actual, expected)
        }
@@ -94,7 +94,7 @@ func TestFromBig(t *testing.T) {
 
 func TestFillBytes(t *testing.T) {
        xBytes := []byte{0xAA, 0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
-       x := natFromBytes(xBytes)
+       x := new(nat).setBytes(xBytes)
        for l := 20; l >= len(xBytes); l-- {
                buf := make([]byte, l)
                rand.Read(buf)
@@ -122,7 +122,7 @@ func TestFromBytes(t *testing.T) {
                if len(xBytes) == 0 {
                        return true
                }
-               actual := natFromBytes(xBytes).fillBytes(make([]byte, len(xBytes)))
+               actual := new(nat).setBytes(xBytes).fillBytes(make([]byte, len(xBytes)))
                if !bytes.Equal(actual, xBytes) {
                        t.Errorf("%+x != %+x", actual, xBytes)
                        return false
@@ -169,9 +169,9 @@ func TestShiftIn(t *testing.T) {
        }}
 
        for i, tt := range examples {
-               m := modulusFromNat(natFromBytes(tt.m))
-               got := natFromBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m)
-               if got.cmpEq(natFromBytes(tt.expected).expandFor(m)) != 1 {
+               m := modulusFromNat(new(nat).setBytes(tt.m))
+               got := new(nat).setBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m)
+               if got.cmpEq(new(nat).setBytes(tt.expected).expandFor(m)) != 1 {
                        t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
                }
        }
@@ -182,10 +182,10 @@ func TestModulusAndNatSizes(t *testing.T) {
        // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
        // limbs, if they are not, they fit in three. This can be a problem because
        // modulus strips leading zeroes and nat does not.
-       m := modulusFromNat(natFromBytes([]byte{
+       m := modulusFromNat(new(nat).setBytes([]byte{
                0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}))
-       x := natFromBytes([]byte{
+       x := new(nat).setBytes([]byte{
                0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe})
        x.expandFor(m) // must not panic for shrinking
@@ -224,11 +224,11 @@ func TestExpand(t *testing.T) {
 }
 
 func TestMod(t *testing.T) {
-       m := modulusFromNat(natFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}))
-       x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
+       m := modulusFromNat(new(nat).setBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}))
+       x := new(nat).setBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
        out := new(nat)
        out.mod(x, m)
-       expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
+       expected := new(nat).setBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
        if out.cmpEq(expected) != 1 {
                t.Errorf("%+v != %+v", out, expected)
        }
index c71d35ab5eaf523a53ba5a7edcefbedd4b933709..55b233152705c56ca19055e0c011eac43832a388 100644 (file)
@@ -314,9 +314,9 @@ func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey
                                Dq:        Dq,
                                Qinv:      Qinv,
                                CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute
-                               n:         modulusFromNat(natFromBig(N)),
-                               p:         modulusFromNat(natFromBig(P)),
-                               q:         modulusFromNat(natFromBig(Q)),
+                               n:         modulusFromNat(newNat().setBig(N)),
+                               p:         modulusFromNat(newNat().setBig(P)),
+                               q:         modulusFromNat(newNat().setBig(Q)),
                        },
                }
                return key, nil
@@ -454,12 +454,12 @@ var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key siz
 func encrypt(pub *PublicKey, plaintext []byte) []byte {
        boring.Unreachable()
 
-       N := modulusFromNat(natFromBig(pub.N))
-       m := natFromBytes(plaintext).expandFor(N)
+       N := modulusFromNat(newNat().setBig(pub.N))
+       m := newNat().setBytes(plaintext).expandFor(N)
        e := intToBytes(pub.E)
 
        out := make([]byte, modulusSize(N))
-       return new(nat).exp(m, e, N).fillBytes(out)
+       return newNat().exp(m, e, N).fillBytes(out)
 }
 
 // intToBytes returns i as a big-endian slice of bytes with no leading zeroes,
@@ -553,9 +553,9 @@ var ErrVerification = errors.New("crypto/rsa: verification error")
 // in the future.
 func (priv *PrivateKey) Precompute() {
        if priv.Precomputed.n == nil && len(priv.Primes) == 2 {
-               priv.Precomputed.n = modulusFromNat(natFromBig(priv.N))
-               priv.Precomputed.p = modulusFromNat(natFromBig(priv.Primes[0]))
-               priv.Precomputed.q = modulusFromNat(natFromBig(priv.Primes[1]))
+               priv.Precomputed.n = modulusFromNat(newNat().setBig(priv.N))
+               priv.Precomputed.p = modulusFromNat(newNat().setBig(priv.Primes[0]))
+               priv.Precomputed.q = modulusFromNat(newNat().setBig(priv.Primes[1]))
        }
 
        // Fill in the backwards-compatibility *big.Int values.
@@ -600,9 +600,9 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
 
        N := priv.Precomputed.n
        if N == nil {
-               N = modulusFromNat(natFromBig(priv.N))
+               N = modulusFromNat(newNat().setBig(priv.N))
        }
-       c := natFromBytes(ciphertext).expandFor(N)
+       c := newNat().setBytes(ciphertext).expandFor(N)
        if c.cmpGeq(N.nat) == 1 {
                return nil, ErrDecryption
        }
@@ -612,18 +612,18 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
 
        var m *nat
        if priv.Precomputed.n == nil {
-               m = new(nat).exp(c, priv.D.Bytes(), N)
+               m = newNat().exp(c, priv.D.Bytes(), N)
        } else {
-               t0 := new(nat)
+               t0 := newNat()
                P, Q := priv.Precomputed.p, priv.Precomputed.q
                // m = c ^ Dp mod p
-               m = new(nat).exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P)
+               m = newNat().exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P)
                // m2 = c ^ Dq mod q
-               m2 := new(nat).exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
+               m2 := newNat().exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
                // m = m - m2 mod p
                m.modSub(t0.mod(m2, P), P)
                // m = m * Qinv mod p
-               m.modMul(natFromBig(priv.Precomputed.Qinv).expandFor(P), P)
+               m.modMul(newNat().setBig(priv.Precomputed.Qinv).expandFor(P), P)
                // m = m * q mod N
                m.expandFor(N).modMul(t0.mod(Q.nat, N), N)
                // m = m + m2 mod N
@@ -631,7 +631,7 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
        }
 
        if check {
-               c1 := new(nat).exp(m, intToBytes(priv.E), N)
+               c1 := newNat().exp(m, intToBytes(priv.E), N)
                if c1.cmpEq(c) != 1 {
                        return nil, ErrDecryption
                }