]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.boringcrypto] crypto/internal/boring: make accesses to RSA types with finalizers...
authorFilippo Valsorda <filippo@golang.org>
Mon, 2 Mar 2020 20:52:56 +0000 (15:52 -0500)
committerFilippo Valsorda <filippo@golang.org>
Tue, 3 Mar 2020 20:42:07 +0000 (20:42 +0000)
RSA key types have a finalizer that will free the underlying C value
when the Go one is garbage collected. It's important that the finalizer
doesn't run while a cgo call is using the underlying C value, so they
require runtime.KeepAlive calls after each use.

This is error prone, so replace it with a closure that provides access
to the underlying C value and then automatically calls KeepAlive.

AES, HMAC, and ECDSA also need KeepAlives, but they have much fewer call
sites, so avoid the complexity for now.

Change-Id: I6d6f38297cd1cf384a1639974d9739a939cbdbcc
Reviewed-on: https://go-review.googlesource.com/c/go/+/221822
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Katie Hockman <katie@golang.org>
src/crypto/internal/boring/rsa.go

index 1ebf6044ba0949d12a030adf5180f07a3c549ece..9f4f53e0d85c67db30546211b49a613f2b23d7ac 100644 (file)
@@ -45,7 +45,8 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error)
 }
 
 type PublicKeyRSA struct {
-       key *C.GO_RSA
+       // _key MUST NOT be accessed directly. Instead, use the withKey method.
+       _key *C.GO_RSA
 }
 
 func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) {
@@ -57,21 +58,26 @@ func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) {
                !bigToBn(&key.e, E) {
                return nil, fail("BN_bin2bn")
        }
-       k := &PublicKeyRSA{key: key}
-       // Note: Because of the finalizer, any time k.key is passed to cgo,
-       // that call must be followed by a call to runtime.KeepAlive(k),
-       // to make sure k is not collected (and finalized) before the cgo
-       // call returns.
+       k := &PublicKeyRSA{_key: key}
        runtime.SetFinalizer(k, (*PublicKeyRSA).finalize)
        return k, nil
 }
 
 func (k *PublicKeyRSA) finalize() {
-       C._goboringcrypto_RSA_free(k.key)
+       C._goboringcrypto_RSA_free(k._key)
+}
+
+func (k *PublicKeyRSA) withKey(f func(*C.GO_RSA) C.int) C.int {
+       // Because of the finalizer, any time _key is passed to cgo, that call must
+       // be followed by a call to runtime.KeepAlive, to make sure k is not
+       // collected (and finalized) before the cgo call returns.
+       defer runtime.KeepAlive(k)
+       return f(k._key)
 }
 
 type PrivateKeyRSA struct {
-       key *C.GO_RSA
+       // _key MUST NOT be accessed directly. Instead, use the withKey method.
+       _key *C.GO_RSA
 }
 
 func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) {
@@ -89,20 +95,24 @@ func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, err
                !bigToBn(&key.iqmp, Qinv) {
                return nil, fail("BN_bin2bn")
        }
-       k := &PrivateKeyRSA{key: key}
-       // Note: Because of the finalizer, any time k.key is passed to cgo,
-       // that call must be followed by a call to runtime.KeepAlive(k),
-       // to make sure k is not collected (and finalized) before the cgo
-       // call returns.
+       k := &PrivateKeyRSA{_key: key}
        runtime.SetFinalizer(k, (*PrivateKeyRSA).finalize)
        return k, nil
 }
 
 func (k *PrivateKeyRSA) finalize() {
-       C._goboringcrypto_RSA_free(k.key)
+       C._goboringcrypto_RSA_free(k._key)
+}
+
+func (k *PrivateKeyRSA) withKey(f func(*C.GO_RSA) C.int) C.int {
+       // Because of the finalizer, any time _key is passed to cgo, that call must
+       // be followed by a call to runtime.KeepAlive, to make sure k is not
+       // collected (and finalized) before the cgo call returns.
+       defer runtime.KeepAlive(k)
+       return f(k._key)
 }
 
-func setupRSA(gokey interface{}, key *C.GO_RSA,
+func setupRSA(withKey func(func(*C.GO_RSA) C.int) C.int,
        padding C.int, h hash.Hash, label []byte, saltLen int, ch crypto.Hash,
        init func(*C.GO_EVP_PKEY_CTX) C.int) (pkey *C.GO_EVP_PKEY, ctx *C.GO_EVP_PKEY_CTX, err error) {
        defer func() {
@@ -122,12 +132,11 @@ func setupRSA(gokey interface{}, key *C.GO_RSA,
        if pkey == nil {
                return nil, nil, fail("EVP_PKEY_new")
        }
-       if C._goboringcrypto_EVP_PKEY_set1_RSA(pkey, key) == 0 {
+       if withKey(func(key *C.GO_RSA) C.int {
+               return C._goboringcrypto_EVP_PKEY_set1_RSA(pkey, key)
+       }) == 0 {
                return nil, nil, fail("EVP_PKEY_set1_RSA")
        }
-       // key is freed by the finalizer on gokey, which is a PrivateKeyRSA or a
-       // PublicKeyRSA. Ensure it doesn't run until after the cgo calls that use key.
-       runtime.KeepAlive(gokey)
        ctx = C._goboringcrypto_EVP_PKEY_CTX_new(pkey, nil)
        if ctx == nil {
                return nil, nil, fail("EVP_PKEY_CTX_new")
@@ -174,13 +183,13 @@ func setupRSA(gokey interface{}, key *C.GO_RSA,
        return pkey, ctx, nil
 }
 
-func cryptRSA(gokey interface{}, key *C.GO_RSA,
+func cryptRSA(withKey func(func(*C.GO_RSA) C.int) C.int,
        padding C.int, h hash.Hash, label []byte, saltLen int, ch crypto.Hash,
        init func(*C.GO_EVP_PKEY_CTX) C.int,
        crypt func(*C.GO_EVP_PKEY_CTX, *C.uint8_t, *C.size_t, *C.uint8_t, C.size_t) C.int,
        in []byte) ([]byte, error) {
 
-       pkey, ctx, err := setupRSA(gokey, key, padding, h, label, saltLen, ch, init)
+       pkey, ctx, err := setupRSA(withKey, padding, h, label, saltLen, ch, init)
        if err != nil {
                return nil, err
        }
@@ -199,27 +208,27 @@ func cryptRSA(gokey interface{}, key *C.GO_RSA,
 }
 
 func DecryptRSAOAEP(h hash.Hash, priv *PrivateKeyRSA, ciphertext, label []byte) ([]byte, error) {
-       return cryptRSA(priv, priv.key, C.GO_RSA_PKCS1_OAEP_PADDING, h, label, 0, 0, decryptInit, decrypt, ciphertext)
+       return cryptRSA(priv.withKey, C.GO_RSA_PKCS1_OAEP_PADDING, h, label, 0, 0, decryptInit, decrypt, ciphertext)
 }
 
 func EncryptRSAOAEP(h hash.Hash, pub *PublicKeyRSA, msg, label []byte) ([]byte, error) {
-       return cryptRSA(pub, pub.key, C.GO_RSA_PKCS1_OAEP_PADDING, h, label, 0, 0, encryptInit, encrypt, msg)
+       return cryptRSA(pub.withKey, C.GO_RSA_PKCS1_OAEP_PADDING, h, label, 0, 0, encryptInit, encrypt, msg)
 }
 
 func DecryptRSAPKCS1(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) {
-       return cryptRSA(priv, priv.key, C.GO_RSA_PKCS1_PADDING, nil, nil, 0, 0, decryptInit, decrypt, ciphertext)
+       return cryptRSA(priv.withKey, C.GO_RSA_PKCS1_PADDING, nil, nil, 0, 0, decryptInit, decrypt, ciphertext)
 }
 
 func EncryptRSAPKCS1(pub *PublicKeyRSA, msg []byte) ([]byte, error) {
-       return cryptRSA(pub, pub.key, C.GO_RSA_PKCS1_PADDING, nil, nil, 0, 0, encryptInit, encrypt, msg)
+       return cryptRSA(pub.withKey, C.GO_RSA_PKCS1_PADDING, nil, nil, 0, 0, encryptInit, encrypt, msg)
 }
 
 func DecryptRSANoPadding(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) {
-       return cryptRSA(priv, priv.key, C.GO_RSA_NO_PADDING, nil, nil, 0, 0, decryptInit, decrypt, ciphertext)
+       return cryptRSA(priv.withKey, C.GO_RSA_NO_PADDING, nil, nil, 0, 0, decryptInit, decrypt, ciphertext)
 }
 
 func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) {
-       return cryptRSA(pub, pub.key, C.GO_RSA_NO_PADDING, nil, nil, 0, 0, encryptInit, encrypt, msg)
+       return cryptRSA(pub.withKey, C.GO_RSA_NO_PADDING, nil, nil, 0, 0, encryptInit, encrypt, msg)
 }
 
 // These dumb wrappers work around the fact that cgo functions cannot be used as values directly.
@@ -248,12 +257,15 @@ func SignRSAPSS(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte, saltLen int)
        if saltLen == 0 {
                saltLen = -1
        }
-       out := make([]byte, C._goboringcrypto_RSA_size(priv.key))
+       var out []byte
        var outLen C.size_t
-       if C._goboringcrypto_RSA_sign_pss_mgf1(priv.key, &outLen, base(out), C.size_t(len(out)), base(hashed), C.size_t(len(hashed)), md, nil, C.int(saltLen)) == 0 {
+       if priv.withKey(func(key *C.GO_RSA) C.int {
+               out = make([]byte, C._goboringcrypto_RSA_size(key))
+               return C._goboringcrypto_RSA_sign_pss_mgf1(key, &outLen, base(out), C.size_t(len(out)),
+                       base(hashed), C.size_t(len(hashed)), md, nil, C.int(saltLen))
+       }) == 0 {
                return nil, fail("RSA_sign_pss_mgf1")
        }
-       runtime.KeepAlive(priv)
 
        return out[:outLen], nil
 }
@@ -266,22 +278,27 @@ func VerifyRSAPSS(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte, saltLen
        if saltLen == 0 {
                saltLen = -2 // auto-recover
        }
-       if C._goboringcrypto_RSA_verify_pss_mgf1(pub.key, base(hashed), C.size_t(len(hashed)), md, nil, C.int(saltLen), base(sig), C.size_t(len(sig))) == 0 {
+       if pub.withKey(func(key *C.GO_RSA) C.int {
+               return C._goboringcrypto_RSA_verify_pss_mgf1(key, base(hashed), C.size_t(len(hashed)),
+                       md, nil, C.int(saltLen), base(sig), C.size_t(len(sig)))
+       }) == 0 {
                return fail("RSA_verify_pss_mgf1")
        }
-       runtime.KeepAlive(pub)
        return nil
 }
 
 func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, error) {
-       out := make([]byte, C._goboringcrypto_RSA_size(priv.key))
        if h == 0 {
                // No hashing.
+               var out []byte
                var outLen C.size_t
-               if C._goboringcrypto_RSA_sign_raw(priv.key, &outLen, base(out), C.size_t(len(out)), base(hashed), C.size_t(len(hashed)), C.GO_RSA_PKCS1_PADDING) == 0 {
+               if priv.withKey(func(key *C.GO_RSA) C.int {
+                       out = make([]byte, C._goboringcrypto_RSA_size(key))
+                       return C._goboringcrypto_RSA_sign_raw(key, &outLen, base(out), C.size_t(len(out)),
+                               base(hashed), C.size_t(len(hashed)), C.GO_RSA_PKCS1_PADDING)
+               }) == 0 {
                        return nil, fail("RSA_sign_raw")
                }
-               runtime.KeepAlive(priv)
                return out[:outLen], nil
        }
 
@@ -290,16 +307,22 @@ func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte,
                return nil, errors.New("crypto/rsa: unsupported hash function: " + strconv.Itoa(int(h)))
        }
        nid := C._goboringcrypto_EVP_MD_type(md)
+       var out []byte
        var outLen C.uint
-       if C._goboringcrypto_RSA_sign(nid, base(hashed), C.uint(len(hashed)), base(out), &outLen, priv.key) == 0 {
+       if priv.withKey(func(key *C.GO_RSA) C.int {
+               out = make([]byte, C._goboringcrypto_RSA_size(key))
+               return C._goboringcrypto_RSA_sign(nid, base(hashed), C.uint(len(hashed)),
+                       base(out), &outLen, key)
+       }) == 0 {
                return nil, fail("RSA_sign")
        }
-       runtime.KeepAlive(priv)
        return out[:outLen], nil
 }
 
 func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error {
-       size := int(C._goboringcrypto_RSA_size(pub.key))
+       size := int(pub.withKey(func(key *C.GO_RSA) C.int {
+               return C.int(C._goboringcrypto_RSA_size(key))
+       }))
        if len(sig) < size {
                // BoringCrypto requires sig to be same size as RSA key, so pad with leading zeros.
                zsig := make([]byte, size)
@@ -309,13 +332,15 @@ func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) err
        if h == 0 {
                var outLen C.size_t
                out := make([]byte, size)
-               if C._goboringcrypto_RSA_verify_raw(pub.key, &outLen, base(out), C.size_t(len(out)), base(sig), C.size_t(len(sig)), C.GO_RSA_PKCS1_PADDING) == 0 {
+               if pub.withKey(func(key *C.GO_RSA) C.int {
+                       return C._goboringcrypto_RSA_verify_raw(key, &outLen, base(out),
+                               C.size_t(len(out)), base(sig), C.size_t(len(sig)), C.GO_RSA_PKCS1_PADDING)
+               }) == 0 {
                        return fail("RSA_verify")
                }
                if subtle.ConstantTimeCompare(hashed, out[:outLen]) != 1 {
                        return fail("RSA_verify")
                }
-               runtime.KeepAlive(pub)
                return nil
        }
        md := cryptoHashToMD(h)
@@ -323,9 +348,11 @@ func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) err
                return errors.New("crypto/rsa: unsupported hash function")
        }
        nid := C._goboringcrypto_EVP_MD_type(md)
-       if C._goboringcrypto_RSA_verify(nid, base(hashed), C.size_t(len(hashed)), base(sig), C.size_t(len(sig)), pub.key) == 0 {
+       if pub.withKey(func(key *C.GO_RSA) C.int {
+               return C._goboringcrypto_RSA_verify(nid, base(hashed), C.size_t(len(hashed)),
+                       base(sig), C.size_t(len(sig)), key)
+       }) == 0 {
                return fail("RSA_verify")
        }
-       runtime.KeepAlive(pub)
        return nil
 }