]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/rsa: implement PSS signatures.
authorNan Deng <monnand@gmail.com>
Thu, 23 May 2013 15:10:41 +0000 (11:10 -0400)
committerAdam Langley <agl@golang.org>
Thu, 23 May 2013 15:10:41 +0000 (11:10 -0400)
This change contains an implementation of the RSASSA-PSS signature
algorithm described in RFC 3447.

R=agl, agl
CC=gobot, golang-dev, r
https://golang.org/cl/9438043

src/pkg/crypto/rsa/pss.go [new file with mode: 0644]
src/pkg/crypto/rsa/pss_test.go [new file with mode: 0644]
src/pkg/crypto/rsa/rsa_test.go
src/pkg/crypto/rsa/testdata/pss-vect.txt.bz2 [new file with mode: 0644]

diff --git a/src/pkg/crypto/rsa/pss.go b/src/pkg/crypto/rsa/pss.go
new file mode 100644 (file)
index 0000000..f9abec3
--- /dev/null
@@ -0,0 +1,282 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rsa
+
+// This file implementes the PSS signature scheme [1].
+//
+// [1] http://www.rsa.com/rsalabs/pkcs/files/h11300-wp-pkcs-1v2-2-rsa-cryptography-standard.pdf
+
+import (
+       "bytes"
+       "crypto"
+       "errors"
+       "hash"
+       "io"
+       "math/big"
+)
+
+func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
+       // See [1], section 9.1.1
+       hLen := hash.Size()
+       sLen := len(salt)
+       emLen := (emBits + 7) / 8
+
+       // 1.  If the length of M is greater than the input limitation for the
+       //     hash function (2^61 - 1 octets for SHA-1), output "message too
+       //     long" and stop.
+       //
+       // 2.  Let mHash = Hash(M), an octet string of length hLen.
+
+       if len(mHash) != hLen {
+               return nil, errors.New("crypto/rsa: input must be hashed message")
+       }
+
+       // 3.  If emLen < hLen + sLen + 2, output "encoding error" and stop.
+
+       if emLen < hLen+sLen+2 {
+               return nil, errors.New("crypto/rsa: encoding error")
+       }
+
+       em := make([]byte, emLen)
+       db := em[:emLen-sLen-hLen-2+1+sLen]
+       h := em[emLen-sLen-hLen-2+1+sLen : emLen-1]
+
+       // 4.  Generate a random octet string salt of length sLen; if sLen = 0,
+       //     then salt is the empty string.
+       //
+       // 5.  Let
+       //       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
+       //
+       //     M' is an octet string of length 8 + hLen + sLen with eight
+       //     initial zero octets.
+       //
+       // 6.  Let H = Hash(M'), an octet string of length hLen.
+
+       var prefix [8]byte
+
+       hash.Write(prefix[:])
+       hash.Write(mHash)
+       hash.Write(salt)
+
+       h = hash.Sum(h[:0])
+       hash.Reset()
+
+       // 7.  Generate an octet string PS consisting of emLen - sLen - hLen - 2
+       //     zero octets.  The length of PS may be 0.
+       //
+       // 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
+       //     emLen - hLen - 1.
+
+       db[emLen-sLen-hLen-2] = 0x01
+       copy(db[emLen-sLen-hLen-1:], salt)
+
+       // 9.  Let dbMask = MGF(H, emLen - hLen - 1).
+       //
+       // 10. Let maskedDB = DB \xor dbMask.
+
+       mgf1XOR(db, hash, h)
+
+       // 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
+       //     maskedDB to zero.
+
+       db[0] &= (0xFF >> uint(8*emLen-emBits))
+
+       // 12. Let EM = maskedDB || H || 0xbc.
+       em[emLen-1] = 0xBC
+
+       // 13. Output EM.
+       return em, nil
+}
+
+func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
+       // 1.  If the length of M is greater than the input limitation for the
+       //     hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
+       //     and stop.
+       //
+       // 2.  Let mHash = Hash(M), an octet string of length hLen.
+       hLen := hash.Size()
+       if hLen != len(mHash) {
+               return ErrVerification
+       }
+
+       // 3.  If emLen < hLen + sLen + 2, output "inconsistent" and stop.
+       emLen := (emBits + 7) / 8
+       if emLen < hLen+sLen+2 {
+               return ErrVerification
+       }
+
+       // 4.  If the rightmost octet of EM does not have hexadecimal value
+       //     0xbc, output "inconsistent" and stop.
+       if em[len(em)-1] != 0xBC {
+               return ErrVerification
+       }
+
+       // 5.  Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
+       //     let H be the next hLen octets.
+       db := em[:emLen-hLen-1]
+       h := em[emLen-hLen-1 : len(em)-1]
+
+       // 6.  If the leftmost 8 * emLen - emBits bits of the leftmost octet in
+       //     maskedDB are not all equal to zero, output "inconsistent" and
+       //     stop.
+       if em[0]&(0xFF<<uint(8-(8*emLen-emBits))) != 0 {
+               return ErrVerification
+       }
+
+       // 7.  Let dbMask = MGF(H, emLen - hLen - 1).
+       //
+       // 8.  Let DB = maskedDB \xor dbMask.
+       mgf1XOR(db, hash, h)
+
+       // 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
+       //     to zero.
+       db[0] &= (0xFF >> uint(8*emLen-emBits))
+
+       if sLen == PSSSaltLengthAuto {
+       FindSaltLength:
+               for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- {
+                       switch db[emLen-hLen-sLen-2] {
+                       case 1:
+                               break FindSaltLength
+                       case 0:
+                               continue
+                       default:
+                               return ErrVerification
+                       }
+               }
+               if sLen < 0 {
+                       return ErrVerification
+               }
+       } else {
+               // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
+               //     or if the octet at position emLen - hLen - sLen - 1 (the leftmost
+               //     position is "position 1") does not have hexadecimal value 0x01,
+               //     output "inconsistent" and stop.
+               for _, e := range db[:emLen-hLen-sLen-2] {
+                       if e != 0x00 {
+                               return ErrVerification
+                       }
+               }
+               if db[emLen-hLen-sLen-2] != 0x01 {
+                       return ErrVerification
+               }
+       }
+
+       // 11.  Let salt be the last sLen octets of DB.
+       salt := db[len(db)-sLen:]
+
+       // 12.  Let
+       //          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
+       //     M' is an octet string of length 8 + hLen + sLen with eight
+       //     initial zero octets.
+       //
+       // 13. Let H' = Hash(M'), an octet string of length hLen.
+       var prefix [8]byte
+       hash.Write(prefix[:])
+       hash.Write(mHash)
+       hash.Write(salt)
+
+       h0 := hash.Sum(nil)
+
+       // 14. If H = H', output "consistent." Otherwise, output "inconsistent."
+       if !bytes.Equal(h0, h) {
+               return ErrVerification
+       }
+       return nil
+}
+
+// signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt.
+// Note that hashed must be the result of hashing the input message using the
+// given hash funcion. salt is a random sequence of bytes whose length will be
+// later used to verify the signature.
+func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
+       nBits := priv.N.BitLen()
+       em, err := emsaPSSEncode(hashed, nBits-1, salt, hash.New())
+       if err != nil {
+               return
+       }
+       m := new(big.Int).SetBytes(em)
+       c, err := decrypt(rand, priv, m)
+       if err != nil {
+               return
+       }
+       s = make([]byte, (nBits+7)/8)
+       copyWithLeftPad(s, c.Bytes())
+       return
+}
+
+const (
+       // PSSSaltLengthAuto causes the salt in a PSS signature to be as large
+       // as possible when signing, and to be auto-detected when verifying.
+       PSSSaltLengthAuto = 0
+       // PSSSaltLengthEqualsHash causes the salt length to equal the length
+       // of the hash used in the signature.
+       PSSSaltLengthEqualsHash = -1
+)
+
+// PSSOptions contains options for creating and verifying PSS signatures.
+type PSSOptions struct {
+       // SaltLength controls the length of the salt used in the PSS
+       // signature. It can either be a number of bytes, or one of the special
+       // PSSSaltLength constants.
+       SaltLength int
+}
+
+func (opts *PSSOptions) saltLength() int {
+       if opts == nil {
+               return PSSSaltLengthAuto
+       }
+       return opts.SaltLength
+}
+
+// SignPSS calculates the signature of hashed using RSASSA-PSS [1].
+// Note that hashed must be the result of hashing the input message using the
+// given hash funcion. The opts argument may be nil, in which case sensible
+// defaults are used.
+func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) (s []byte, err error) {
+       saltLength := opts.saltLength()
+       switch saltLength {
+       case PSSSaltLengthAuto:
+               saltLength = (priv.N.BitLen()+7)/8 - 2 - hash.Size()
+       case PSSSaltLengthEqualsHash:
+               saltLength = hash.Size()
+       }
+
+       salt := make([]byte, saltLength)
+       if _, err = io.ReadFull(rand, salt); err != nil {
+               return
+       }
+       return signPSSWithSalt(rand, priv, hash, hashed, salt)
+}
+
+// VerifyPSS verifies a PSS signature.
+// hashed is the result of hashing the input message using the given hash
+// function and sig is the signature. A valid signature is indicated by
+// returning a nil error. The opts argument may be nil, in which case sensible
+// defaults are used.
+func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error {
+       return verifyPSS(pub, hash, hashed, sig, opts.saltLength())
+}
+
+// verifyPSS verifies a PSS signature with the given salt length.
+func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error {
+       nBits := pub.N.BitLen()
+       if len(sig) != (nBits+7)/8 {
+               return ErrVerification
+       }
+       s := new(big.Int).SetBytes(sig)
+       m := encrypt(new(big.Int), pub, s)
+       emBits := nBits - 1
+       emLen := (emBits + 7) / 8
+       if emLen < len(m.Bytes()) {
+               return ErrVerification
+       }
+       em := make([]byte, emLen)
+       copyWithLeftPad(em, m.Bytes())
+       if saltLen == PSSSaltLengthEqualsHash {
+               saltLen = hash.Size()
+       }
+       return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New())
+}
diff --git a/src/pkg/crypto/rsa/pss_test.go b/src/pkg/crypto/rsa/pss_test.go
new file mode 100644 (file)
index 0000000..32e6fc3
--- /dev/null
@@ -0,0 +1,249 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rsa
+
+import (
+       "bufio"
+       "bytes"
+       "compress/bzip2"
+       "crypto"
+       _ "crypto/md5"
+       "crypto/rand"
+       "crypto/sha1"
+       _ "crypto/sha256"
+       "encoding/hex"
+       "math/big"
+       "os"
+       "strconv"
+       "strings"
+       "testing"
+)
+
+func TestEMSAPSS(t *testing.T) {
+       // Test vector in file pss-int.txt from: ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1-vec.zip
+       msg := []byte{
+               0x85, 0x9e, 0xef, 0x2f, 0xd7, 0x8a, 0xca, 0x00, 0x30, 0x8b,
+               0xdc, 0x47, 0x11, 0x93, 0xbf, 0x55, 0xbf, 0x9d, 0x78, 0xdb,
+               0x8f, 0x8a, 0x67, 0x2b, 0x48, 0x46, 0x34, 0xf3, 0xc9, 0xc2,
+               0x6e, 0x64, 0x78, 0xae, 0x10, 0x26, 0x0f, 0xe0, 0xdd, 0x8c,
+               0x08, 0x2e, 0x53, 0xa5, 0x29, 0x3a, 0xf2, 0x17, 0x3c, 0xd5,
+               0x0c, 0x6d, 0x5d, 0x35, 0x4f, 0xeb, 0xf7, 0x8b, 0x26, 0x02,
+               0x1c, 0x25, 0xc0, 0x27, 0x12, 0xe7, 0x8c, 0xd4, 0x69, 0x4c,
+               0x9f, 0x46, 0x97, 0x77, 0xe4, 0x51, 0xe7, 0xf8, 0xe9, 0xe0,
+               0x4c, 0xd3, 0x73, 0x9c, 0x6b, 0xbf, 0xed, 0xae, 0x48, 0x7f,
+               0xb5, 0x56, 0x44, 0xe9, 0xca, 0x74, 0xff, 0x77, 0xa5, 0x3c,
+               0xb7, 0x29, 0x80, 0x2f, 0x6e, 0xd4, 0xa5, 0xff, 0xa8, 0xba,
+               0x15, 0x98, 0x90, 0xfc,
+       }
+       salt := []byte{
+               0xe3, 0xb5, 0xd5, 0xd0, 0x02, 0xc1, 0xbc, 0xe5, 0x0c, 0x2b,
+               0x65, 0xef, 0x88, 0xa1, 0x88, 0xd8, 0x3b, 0xce, 0x7e, 0x61,
+       }
+       expected := []byte{
+               0x66, 0xe4, 0x67, 0x2e, 0x83, 0x6a, 0xd1, 0x21, 0xba, 0x24,
+               0x4b, 0xed, 0x65, 0x76, 0xb8, 0x67, 0xd9, 0xa4, 0x47, 0xc2,
+               0x8a, 0x6e, 0x66, 0xa5, 0xb8, 0x7d, 0xee, 0x7f, 0xbc, 0x7e,
+               0x65, 0xaf, 0x50, 0x57, 0xf8, 0x6f, 0xae, 0x89, 0x84, 0xd9,
+               0xba, 0x7f, 0x96, 0x9a, 0xd6, 0xfe, 0x02, 0xa4, 0xd7, 0x5f,
+               0x74, 0x45, 0xfe, 0xfd, 0xd8, 0x5b, 0x6d, 0x3a, 0x47, 0x7c,
+               0x28, 0xd2, 0x4b, 0xa1, 0xe3, 0x75, 0x6f, 0x79, 0x2d, 0xd1,
+               0xdc, 0xe8, 0xca, 0x94, 0x44, 0x0e, 0xcb, 0x52, 0x79, 0xec,
+               0xd3, 0x18, 0x3a, 0x31, 0x1f, 0xc8, 0x96, 0xda, 0x1c, 0xb3,
+               0x93, 0x11, 0xaf, 0x37, 0xea, 0x4a, 0x75, 0xe2, 0x4b, 0xdb,
+               0xfd, 0x5c, 0x1d, 0xa0, 0xde, 0x7c, 0xec, 0xdf, 0x1a, 0x89,
+               0x6f, 0x9d, 0x8b, 0xc8, 0x16, 0xd9, 0x7c, 0xd7, 0xa2, 0xc4,
+               0x3b, 0xad, 0x54, 0x6f, 0xbe, 0x8c, 0xfe, 0xbc,
+       }
+
+       hash := sha1.New()
+       hash.Write(msg)
+       hashed := hash.Sum(nil)
+
+       encoded, err := emsaPSSEncode(hashed, 1023, salt, sha1.New())
+       if err != nil {
+               t.Errorf("Error from emsaPSSEncode: %s\n", err)
+       }
+       if !bytes.Equal(encoded, expected) {
+               t.Errorf("Bad encoding. got %x, want %x", encoded, expected)
+       }
+
+       if err = emsaPSSVerify(hashed, encoded, 1023, len(salt), sha1.New()); err != nil {
+               t.Errorf("Bad verification: %s", err)
+       }
+}
+
+// TestPSSGolden tests all the test vectors in pss-vect.txt from
+// ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1-vec.zip
+func TestPSSGolden(t *testing.T) {
+       inFile, err := os.Open("testdata/pss-vect.txt.bz2")
+       if err != nil {
+               t.Fatalf("Failed to open input file: %s", err)
+       }
+       defer inFile.Close()
+
+       // The pss-vect.txt file contains RSA keys and then a series of
+       // signatures. A goroutine is used to preprocess the input by merging
+       // lines, removing spaces in hex values and identifying the start of
+       // new keys and signature blocks.
+       const newKeyMarker = "START NEW KEY"
+       const newSignatureMarker = "START NEW SIGNATURE"
+
+       values := make(chan string)
+
+       go func() {
+               defer close(values)
+               scanner := bufio.NewScanner(bzip2.NewReader(inFile))
+               var partialValue string
+               lastWasValue := true
+
+               for scanner.Scan() {
+                       line := scanner.Text()
+                       switch {
+                       case len(line) == 0:
+                               if len(partialValue) > 0 {
+                                       values <- strings.Replace(partialValue, " ", "", -1)
+                                       partialValue = ""
+                                       lastWasValue = true
+                               }
+                               continue
+                       case strings.HasPrefix(line, "# ======") && lastWasValue:
+                               values <- newKeyMarker
+                               lastWasValue = false
+                       case strings.HasPrefix(line, "# ------") && lastWasValue:
+                               values <- newSignatureMarker
+                               lastWasValue = false
+                       case strings.HasPrefix(line, "#"):
+                               continue
+                       default:
+                               partialValue += line
+                       }
+               }
+               if err := scanner.Err(); err != nil {
+                       panic(err)
+               }
+       }()
+
+       var key *PublicKey
+       var hashed []byte
+       hash := crypto.SHA1
+       h := hash.New()
+       opts := &PSSOptions{
+               SaltLength: PSSSaltLengthEqualsHash,
+       }
+
+       for marker := range values {
+               switch marker {
+               case newKeyMarker:
+                       key = new(PublicKey)
+                       nHex, ok := <-values
+                       if !ok {
+                               continue
+                       }
+                       key.N = bigFromHex(nHex)
+                       key.E = intFromHex(<-values)
+                       // We don't care for d, p, q, dP, dQ or qInv.
+                       for i := 0; i < 6; i++ {
+                               <-values
+                       }
+               case newSignatureMarker:
+                       msg := fromHex(<-values)
+                       <-values // skip salt
+                       sig := fromHex(<-values)
+
+                       h.Reset()
+                       h.Write(msg)
+                       hashed = h.Sum(hashed[:0])
+
+                       if err := VerifyPSS(key, hash, hashed, sig, opts); err != nil {
+                               t.Error(err)
+                       }
+               default:
+                       t.Fatalf("unknown marker: " + marker)
+               }
+       }
+}
+
+// TestPSSOpenSSL ensures that we can verify a PSS signature from OpenSSL with
+// the default options. OpenSSL sets the salt length to be maximal.
+func TestPSSOpenSSL(t *testing.T) {
+       hash := crypto.SHA256
+       h := hash.New()
+       h.Write([]byte("testing"))
+       hashed := h.Sum(nil)
+
+       // Generated with `echo -n testing | openssl dgst -sign key.pem -sigopt rsa_padding_mode:pss -sha256 > sig`
+       sig := []byte{
+               0x95, 0x59, 0x6f, 0xd3, 0x10, 0xa2, 0xe7, 0xa2, 0x92, 0x9d,
+               0x4a, 0x07, 0x2e, 0x2b, 0x27, 0xcc, 0x06, 0xc2, 0x87, 0x2c,
+               0x52, 0xf0, 0x4a, 0xcc, 0x05, 0x94, 0xf2, 0xc3, 0x2e, 0x20,
+               0xd7, 0x3e, 0x66, 0x62, 0xb5, 0x95, 0x2b, 0xa3, 0x93, 0x9a,
+               0x66, 0x64, 0x25, 0xe0, 0x74, 0x66, 0x8c, 0x3e, 0x92, 0xeb,
+               0xc6, 0xe6, 0xc0, 0x44, 0xf3, 0xb4, 0xb4, 0x2e, 0x8c, 0x66,
+               0x0a, 0x37, 0x9c, 0x69,
+       }
+
+       if err := VerifyPSS(&rsaPrivateKey.PublicKey, hash, hashed, sig, nil); err != nil {
+               t.Error(err)
+       }
+}
+
+func TestPSSSigning(t *testing.T) {
+       var saltLengthCombinations = []struct {
+               signSaltLength, verifySaltLength int
+               good                             bool
+       }{
+               {PSSSaltLengthAuto, PSSSaltLengthAuto, true},
+               {PSSSaltLengthEqualsHash, PSSSaltLengthAuto, true},
+               {PSSSaltLengthEqualsHash, PSSSaltLengthEqualsHash, true},
+               {PSSSaltLengthEqualsHash, 8, false},
+               {PSSSaltLengthAuto, PSSSaltLengthEqualsHash, false},
+               {8, 8, true},
+       }
+
+       hash := crypto.MD5
+       h := hash.New()
+       h.Write([]byte("testing"))
+       hashed := h.Sum(nil)
+       var opts PSSOptions
+
+       for i, test := range saltLengthCombinations {
+               opts.SaltLength = test.signSaltLength
+               sig, err := SignPSS(rand.Reader, rsaPrivateKey, hash, hashed, &opts)
+               if err != nil {
+                       t.Errorf("#%d: error while signing: %s", i, err)
+                       continue
+               }
+
+               opts.SaltLength = test.verifySaltLength
+               err = VerifyPSS(&rsaPrivateKey.PublicKey, hash, hashed, sig, &opts)
+               if (err == nil) != test.good {
+                       t.Errorf("#%d: bad result, wanted: %t, got: %s", i, test.good, err)
+               }
+       }
+}
+
+func bigFromHex(hex string) *big.Int {
+       n, ok := new(big.Int).SetString(hex, 16)
+       if !ok {
+               panic("bad hex: " + hex)
+       }
+       return n
+}
+
+func intFromHex(hex string) int {
+       i, err := strconv.ParseInt(hex, 16, 32)
+       if err != nil {
+               panic(err)
+       }
+       return int(i)
+}
+
+func fromHex(hexStr string) []byte {
+       s, err := hex.DecodeString(hexStr)
+       if err != nil {
+               panic(err)
+       }
+       return s
+}
index ffd96e62f64b72ad5cf74d67ae43e117a6c725a1..cf193c669f3d7e452cbc525226af28a230d5e46a 100644 (file)
@@ -120,8 +120,10 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) {
 }
 
 func fromBase10(base10 string) *big.Int {
-       i := new(big.Int)
-       i.SetString(base10, 10)
+       i, ok := new(big.Int).SetString(base10, 10)
+       if !ok {
+               panic("bad number: " + base10)
+       }
        return i
 }
 
diff --git a/src/pkg/crypto/rsa/testdata/pss-vect.txt.bz2 b/src/pkg/crypto/rsa/testdata/pss-vect.txt.bz2
new file mode 100644 (file)
index 0000000..ad3da1a
Binary files /dev/null and b/src/pkg/crypto/rsa/testdata/pss-vect.txt.bz2 differ