]> Cypherpunks repositories - gostls13.git/commitdiff
Add RSA PKCS#1 v1.5 support.
authorAdam Langley <agl@golang.org>
Fri, 30 Oct 2009 00:38:25 +0000 (17:38 -0700)
committerAdam Langley <agl@golang.org>
Fri, 30 Oct 2009 00:38:25 +0000 (17:38 -0700)
R=go-dev
APPROVED=rsc
DELTA=407  (400 added, 0 deleted, 7 changed)
OCL=36007
CL=36146

src/pkg/crypto/rsa/Makefile
src/pkg/crypto/rsa/pkcs1v15.go [new file with mode: 0644]
src/pkg/crypto/rsa/pkcs1v15_test.go [new file with mode: 0644]
src/pkg/crypto/rsa/rsa.go
src/pkg/crypto/rsa/rsa_test.go

index dd501dfad4ebb9f3a9b8069c74f1183c5e80f1d5..ef0bd7c496b0ee80fe05b59ec3ce3776e5314b39 100644 (file)
@@ -7,5 +7,6 @@ include $(GOROOT)/src/Make.$(GOARCH)
 TARG=crypto/rsa
 GOFILES=\
        rsa.go\
+       pkcs1v15.go\
 
 include $(GOROOT)/src/Make.pkg
diff --git a/src/pkg/crypto/rsa/pkcs1v15.go b/src/pkg/crypto/rsa/pkcs1v15.go
new file mode 100644 (file)
index 0000000..9fb4584
--- /dev/null
@@ -0,0 +1,137 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rsa
+
+import (
+       "bytes";
+       big "gmp";
+       "io";
+       "os";
+)
+
+// This file implements encryption and decryption using PKCS#1 v1.5 padding.
+
+// EncryptPKCS1v15 encrypts the given message with RSA and the padding scheme from PKCS#1 v1.5.
+// The message must be no longer than the length of the public modulus minus 11 bytes.
+// WARNING: use of this function to encrypt plaintexts other than session keys
+// is dangerous. Use RSA OAEP in new protocols.
+func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) (out []byte, err os.Error) {
+       k := (pub.N.Len() + 7)/8;
+       if len(msg) > k-11 {
+               err = MessageTooLongError{};
+               return;
+       }
+
+       // EM = 0x02 || PS || 0x00 || M
+       em := make([]byte, k-1);
+       em[0] = 2;
+       ps, mm := em[1:len(em)-len(msg)-1], em[len(em)-len(msg):len(em)];
+       err = nonZeroRandomBytes(ps, rand);
+       if err != nil {
+               return;
+       }
+       em[len(em)-len(msg)-1] = 0;
+       bytes.Copy(mm, msg);
+
+       m := new(big.Int).SetBytes(em);
+       c := encrypt(new(big.Int), pub, m);
+       out = c.Bytes();
+       return;
+}
+
+// DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5.
+// If rand != nil, it uses RSA blinding to avoid timing side-channel attacks.
+func DecryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (out []byte, err os.Error) {
+       valid, out, err := decryptPKCS1v15(rand, priv, ciphertext);
+       if err == nil && valid == 0 {
+               err = DecryptionError{};
+       }
+
+       return;
+}
+
+// DecryptPKCS1v15SessionKey decrypts a session key using RSA and the padding scheme from PKCS#1 v1.5.
+// If rand != nil, it uses RSA blinding to avoid timing side-channel attacks.
+// It returns an error if the ciphertext is the wrong length or if the
+// ciphertext is greater than the public modulus. Otherwise, no error is
+// returned. If the padding is valid, the resulting plaintext message is copied
+// into key. Otherwise, key is unchanged. These alternatives occur in constant
+// time. It is intended that the user of this function generate a random
+// session key beforehand and continue the protocol with the resulting value.
+// This will remove any possibility that an attacker can learn any information
+// about the plaintext.
+// See ``Chosen Ciphertext Attacks Against Protocols Based on the RSA
+// Encryption Standard PKCS #1'', Daniel Bleichenbacher, Advances in Cryptology
+// (Crypto '98),
+func DecryptPKCS1v15SessionKey(rand io.Reader, priv *PrivateKey, ciphertext []byte, key []byte) (err os.Error) {
+       k := (priv.N.Len() + 7)/8;
+       if k-(len(key)+3+8) < 0 {
+               err = DecryptionError{};
+               return;
+       }
+
+       valid, msg, err := decryptPKCS1v15(rand, priv, ciphertext);
+       if err != nil {
+               return;
+       }
+
+       valid &= constantTimeEq(int32(len(msg)), int32(len(key)));
+       constantTimeCopy(valid, key, msg);
+       return;
+}
+
+func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid int, msg []byte, err os.Error) {
+       k := (priv.N.Len() + 7)/8;
+       if k < 11 {
+               err = DecryptionError{};
+               return;
+       }
+
+       c := new(big.Int).SetBytes(ciphertext);
+       m, err := decrypt(rand, priv, c);
+       if err != nil {
+               return;
+       }
+
+       em := leftPad(m.Bytes(), k);
+       firstByteIsZero := constantTimeByteEq(em[0], 0);
+       secondByteIsTwo := constantTimeByteEq(em[1], 2);
+
+       // The remainder of the plaintext must be a string of non-zero random
+       // octets, followed by a 0, followed by the message.
+       //   lookingForIndex: 1 iff we are still looking for the zero.
+       //   index: the offset of the first zero byte.
+       var lookingForIndex, index int;
+       lookingForIndex = 1;
+
+       for i := 2; i < len(em); i++ {
+               equals0 := constantTimeByteEq(em[i], 0);
+               index = constantTimeSelect(lookingForIndex & equals0, i, index);
+               lookingForIndex = constantTimeSelect(equals0, 0, lookingForIndex);
+       }
+
+       valid = firstByteIsZero & secondByteIsTwo & (^lookingForIndex & 1);
+       msg = em[index+1 : len(em)];
+       return;
+}
+
+// nonZeroRandomBytes fills the given slice with non-zero random octets.
+func nonZeroRandomBytes(s []byte, rand io.Reader) (err os.Error) {
+       _, err = io.ReadFull(rand, s);
+       if err != nil {
+               return;
+       }
+
+       for i := 0; i < len(s); i++ {
+               for s[i] == 0 {
+                       _, err = rand.Read(s[i:i+1]);
+                       if err != nil {
+                               return;
+                       }
+               }
+       }
+
+       return;
+}
diff --git a/src/pkg/crypto/rsa/pkcs1v15_test.go b/src/pkg/crypto/rsa/pkcs1v15_test.go
new file mode 100644 (file)
index 0000000..a062bc4
--- /dev/null
@@ -0,0 +1,182 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rsa
+
+import (
+       "bytes";
+       "encoding/base64";
+       big "gmp";
+       "os";
+       "io";
+       "strings";
+       "testing";
+       "testing/quick";
+)
+
+func decodeBase64(in string) []byte {
+       out := make([]byte, base64.StdEncoding.DecodedLen(len(in)));
+       n, err := base64.StdEncoding.Decode(strings.Bytes(in), out);
+       if err != nil {
+               return nil;
+       }
+       return out[0:n];
+}
+
+type DecryptPKCS1v15Test struct {
+       in, out string;
+}
+
+// These test vectors were generated with `openssl rsautl -pkcs -encrypt`
+var decryptPKCS1v15Tests = []DecryptPKCS1v15Test{
+       DecryptPKCS1v15Test{
+               "gIcUIoVkD6ATMBk/u/nlCZCCWRKdkfjCgFdo35VpRXLduiKXhNz1XupLLzTXAybEq15juc+EgY5o0DHv/nt3yg==",
+               "x",
+       },
+       DecryptPKCS1v15Test{
+               "Y7TOCSqofGhkRb+jaVRLzK8xw2cSo1IVES19utzv6hwvx+M8kFsoWQm5DzBeJCZTCVDPkTpavUuEbgp8hnUGDw==",
+               "testing.",
+       },
+       DecryptPKCS1v15Test{
+               "arReP9DJtEVyV2Dg3dDp4c/PSk1O6lxkoJ8HcFupoRorBZG+7+1fDAwT1olNddFnQMjmkb8vxwmNMoTAT/BFjQ==",
+               "testing.\n",
+       },
+       DecryptPKCS1v15Test{
+               "WtaBXIoGC54+vH0NH0CHHE+dRDOsMc/6BrfFu2lEqcKL9+uDuWaf+Xj9mrbQCjjZcpQuX733zyok/jsnqe/Ftw==",
+               "01234567890123456789012345678901234567890123456789012",
+       },
+}
+
+func TestDecryptPKCS1v15(t *testing.T) {
+       for i, test := range decryptPKCS1v15Tests {
+               out, err := DecryptPKCS1v15(nil, rsaPrivateKey, decodeBase64(test.in));
+               if err != nil {
+                       t.Errorf("#%d error decrypting", i);
+               }
+               want := strings.Bytes(test.out);
+               if bytes.Compare(out, want) != 0 {
+                       t.Errorf("#%d got:%#v want:%#v", i, out, want);
+               }
+       }
+}
+
+func TestEncryptPKCS1v15(t *testing.T) {
+       urandom, err := os.Open("/dev/urandom", os.O_RDONLY, 0);
+       if err != nil {
+               t.Errorf("Failed to open /dev/urandom");
+       }
+       k := (rsaPrivateKey.N.Len() + 7)/8;
+
+       tryEncryptDecrypt := func(in []byte, blind bool) bool {
+               if len(in) > k-11 {
+                       in = in[0 : k-11];
+               }
+
+               ciphertext, err := EncryptPKCS1v15(urandom, &rsaPrivateKey.PublicKey, in);
+               if err != nil {
+                       t.Errorf("error encrypting: %s", err);
+                       return false;
+               }
+
+               var rand io.Reader;
+               if !blind {
+                       rand = nil;
+               } else {
+                       rand = urandom;
+               }
+               plaintext, err := DecryptPKCS1v15(rand, rsaPrivateKey, ciphertext);
+               if err != nil {
+                       t.Errorf("error decrypting: %s", err);
+                       return false;
+               }
+
+               if bytes.Compare(plaintext, in) != 0 {
+                       t.Errorf("output mismatch: %#v %#v", plaintext, in);
+                       return false;
+               }
+               return true;
+       };
+
+       quick.Check(tryEncryptDecrypt, nil);
+}
+
+// These test vectors were generated with `openssl rsautl -pkcs -encrypt`
+var decryptPKCS1v15SessionKeyTests = []DecryptPKCS1v15Test{
+       DecryptPKCS1v15Test{
+               "e6ukkae6Gykq0fKzYwULpZehX+UPXYzMoB5mHQUDEiclRbOTqas4Y0E6nwns1BBpdvEJcilhl5zsox/6DtGsYg==",
+               "1234",
+       },
+       DecryptPKCS1v15Test{
+               "Dtis4uk/q/LQGGqGk97P59K03hkCIVFMEFZRgVWOAAhxgYpCRG0MX2adptt92l67IqMki6iVQyyt0TtX3IdtEw==",
+               "FAIL",
+       },
+       DecryptPKCS1v15Test{
+               "LIyFyCYCptPxrvTxpol8F3M7ZivlMsf53zs0vHRAv+rDIh2YsHS69ePMoPMe3TkOMZ3NupiL3takPxIs1sK+dw==",
+               "abcd",
+       },
+       DecryptPKCS1v15Test{
+               "bafnobel46bKy76JzqU/RIVOH0uAYvzUtauKmIidKgM0sMlvobYVAVQPeUQ/oTGjbIZ1v/6Gyi5AO4DtHruGdw==",
+               "FAIL",
+       },
+}
+
+func TestEncryptPKCS1v15SessionKey(t *testing.T) {
+       for i, test := range decryptPKCS1v15SessionKeyTests {
+               key := strings.Bytes("FAIL");
+               err := DecryptPKCS1v15SessionKey(nil, rsaPrivateKey, decodeBase64(test.in), key);
+               if err != nil {
+                       t.Errorf("#%d error decrypting", i);
+               }
+               want := strings.Bytes(test.out);
+               if bytes.Compare(key, want) != 0 {
+                       t.Errorf("#%d got:%#v want:%#v", i, key, want);
+               }
+       }
+}
+
+func TestNonZeroRandomBytes(t *testing.T) {
+       urandom, err := os.Open("/dev/urandom", os.O_RDONLY, 0);
+       if err != nil {
+               t.Errorf("Failed to open /dev/urandom");
+       }
+
+       b := make([]byte, 512);
+       err = nonZeroRandomBytes(b, urandom);
+       if err != nil {
+               t.Errorf("returned error: %s", err);
+       }
+       for _, b := range b {
+               if b == 0 {
+                       t.Errorf("Zero octet found");
+                       return;
+               }
+       }
+}
+
+func bigFromString(s string) *big.Int {
+       ret := new(big.Int);
+       ret.SetString(s, 10);
+       return ret;
+}
+
+// In order to generate new test vectors you'll need the PEM form of this key:
+// -----BEGIN RSA PRIVATE KEY-----
+// MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0
+// fd7Ai2KW5ToIwzFofvJcS/STa6HA5gQenRUCAwEAAQJBAIq9amn00aS0h/CrjXqu
+// /ThglAXJmZhOMPVn4eiu7/ROixi9sex436MaVeMqSNf7Ex9a8fRNfWss7Sqd9eWu
+// RTUCIQDasvGASLqmjeffBNLTXV2A5g4t+kLVCpsEIZAycV5GswIhANEPLmax0ME/
+// EO+ZJ79TJKN5yiGBRsv5yvx5UiHxajEXAiAhAol5N4EUyq6I9w1rYdhPMGpLfk7A
+// IU2snfRJ6Nq2CQIgFrPsWRCkV+gOYcajD17rEqmuLrdIRexpg8N1DOSXoJ8CIGlS
+// tAboUGBxTDq3ZroNism3DaMIbKPyYrAqhKov1h5V
+// -----END RSA PRIVATE KEY-----
+
+var rsaPrivateKey = &PrivateKey{
+       PublicKey: PublicKey{
+               N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"),
+               E: 65537,
+       },
+       D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"),
+       P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"),
+       Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"),
+}
index 0873ff544921f403175db668735c07a8a63ac347..adc9c8f021ea263cb9b8d0657adf165635ba62ad 100644 (file)
@@ -313,14 +313,14 @@ func constantTimeCompare(x, y []byte) int {
        return constantTimeByteEq(v, 0);
 }
 
-// constantTimeSelect returns a if mask is 1 and b if mask is 0.
-// Its behaviour is undefined if mask takes any other value.
-func constantTimeSelect(mask, a, b int) int {
-       return ^(mask-1)&a | (mask-1)&b;
+// constantTimeSelect returns a if v is 1 and b if v is 0.
+// Its behaviour is undefined if v takes any other value.
+func constantTimeSelect(v, a, b int) int {
+       return ^(v-1)&a | (v-1)&b;
 }
 
 // constantTimeByteEq returns 1 if a == b and 0 otherwise.
-func constantTimeByteEq(a, b uint8) (mask int) {
+func constantTimeByteEq(a, b uint8) int {
        x := ^(a^b);
        x &= x>>4;
        x &= x>>2;
@@ -329,6 +329,29 @@ func constantTimeByteEq(a, b uint8) (mask int) {
        return int(x);
 }
 
+// constantTimeEq returns 1 if a == b and 0 otherwise.
+func constantTimeEq(a, b int32) int {
+       x := ^(a^b);
+       x &= x>>16;
+       x &= x>>8;
+       x &= x>>4;
+       x &= x>>2;
+       x &= x>>1;
+
+       return int(x&1);
+}
+
+// constantTimeCopy copies the contents of y into x iff v == 1. If v == 0, x is left unchanged.
+// Its behaviour is undefined if v takes any other value.
+func constantTimeCopy(v int, x, y []byte) {
+       xmask := byte(v - 1);
+       ymask := byte(^(v - 1));
+       for i := 0; i < len(x); i++ {
+               x[i] = x[i] & xmask | y[i] & ymask;
+       }
+       return;
+}
+
 // decrypt performs an RSA decryption, resulting in a plaintext integer. If a
 // random source is given, RSA blinding is used.
 func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.Error) {
@@ -345,8 +368,9 @@ func decrypt(rand io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.E
                // which equals mr mod n. The factor of r can then be removed
                // by multipling by the multiplicative inverse of r.
 
-               r, err := randomNumber(rand, priv.N);
-               if err != nil {
+               r, err1 := randomNumber(rand, priv.N);
+               if err1 != nil {
+                       err = err1;
                        return;
                }
                ir = modInverse(r, priv.N);
index a30982c940963c8a4430b9686294a7717d74902e..df0d160ddea84cb44ad5145086a4f8d6faf9ae28 100644 (file)
@@ -10,6 +10,7 @@ import (
        big     "gmp";
                "os";
                "testing";
+               "testing/quick";
 )
 
 func TestKeyGeneration(t *testing.T) {
@@ -140,12 +141,67 @@ var testConstandTimeByteEqData = []TestConstantTimeByteEqStruct{
        TestConstantTimeByteEqStruct{0xff, 0xfe, 0},
 }
 
+func ByteEq(a, b uint8) int {
+       if a == b {
+               return 1;
+       }
+       return 0;
+}
+
 func TestConstantTimeByteEq(t *testing.T) {
        for i, test := range testConstandTimeByteEqData {
                if r := constantTimeByteEq(test.a, test.b); r != test.out {
                        t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out);
                }
        }
+       err := quick.CheckEqual(constantTimeByteEq, ByteEq, nil);
+       if err != nil {
+               t.Error(err);
+       }
+}
+
+func Eq(a, b int32) int {
+       if a == b {
+               return 1;
+       }
+       return 0;
+}
+
+func TestConstantTimeEq(t *testing.T) {
+       err := quick.CheckEqual(constantTimeEq, Eq, nil);
+       if err != nil {
+               t.Error(err);
+       }
+}
+
+func Copy(v int, x, y []byte) []byte {
+       if len(x) > len(y) {
+               x = x[0:len(y)];
+       } else {
+               y = y[0:len(x)];
+       }
+       if v == 1 {
+               bytes.Copy(x, y);
+       }
+       return x;
+}
+
+func constantTimeCopyWrapper(v int, x, y []byte) []byte {
+       if len(x) > len(y) {
+               x = x[0:len(y)];
+       } else {
+               y = y[0:len(x)];
+       }
+       v &= 1;
+       constantTimeCopy(v, x, y);
+       return x;
+}
+
+func TestConstantTimeCopy(t *testing.T) {
+       err := quick.CheckEqual(constantTimeCopyWrapper, Copy, nil);
+       if err != nil {
+               t.Error(err);
+       }
 }
 
 // testEncryptOAEPData contains a subset of the vectors from RSA's "Test vectors for RSA-OAEP".