From: Nan Deng Date: Thu, 23 May 2013 15:10:41 +0000 (-0400) Subject: crypto/rsa: implement PSS signatures. X-Git-Tag: go1.2rc2~1427 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=876455f3ba0e3ee66e177cf901ff5ea9c5aa9f07;p=gostls13.git crypto/rsa: implement PSS signatures. This change contains an implementation of the RSASSA-PSS signature algorithm described in RFC 3447. R=agl, agl CC=gobot, golang-dev, r https://golang.org/cl/9438043 --- diff --git a/src/pkg/crypto/rsa/pss.go b/src/pkg/crypto/rsa/pss.go new file mode 100644 index 0000000000..f9abec3949 --- /dev/null +++ b/src/pkg/crypto/rsa/pss.go @@ -0,0 +1,282 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rsa + +// This file implementes the PSS signature scheme [1]. +// +// [1] http://www.rsa.com/rsalabs/pkcs/files/h11300-wp-pkcs-1v2-2-rsa-cryptography-standard.pdf + +import ( + "bytes" + "crypto" + "errors" + "hash" + "io" + "math/big" +) + +func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) { + // See [1], section 9.1.1 + hLen := hash.Size() + sLen := len(salt) + emLen := (emBits + 7) / 8 + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "message too + // long" and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + + if len(mHash) != hLen { + return nil, errors.New("crypto/rsa: input must be hashed message") + } + + // 3. If emLen < hLen + sLen + 2, output "encoding error" and stop. + + if emLen < hLen+sLen+2 { + return nil, errors.New("crypto/rsa: encoding error") + } + + em := make([]byte, emLen) + db := em[:emLen-sLen-hLen-2+1+sLen] + h := em[emLen-sLen-hLen-2+1+sLen : emLen-1] + + // 4. Generate a random octet string salt of length sLen; if sLen = 0, + // then salt is the empty string. + // + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; + // + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 6. Let H = Hash(M'), an octet string of length hLen. + + var prefix [8]byte + + hash.Write(prefix[:]) + hash.Write(mHash) + hash.Write(salt) + + h = hash.Sum(h[:0]) + hash.Reset() + + // 7. Generate an octet string PS consisting of emLen - sLen - hLen - 2 + // zero octets. The length of PS may be 0. + // + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + + db[emLen-sLen-hLen-2] = 0x01 + copy(db[emLen-sLen-hLen-1:], salt) + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 10. Let maskedDB = DB \xor dbMask. + + mgf1XOR(db, hash, h) + + // 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in + // maskedDB to zero. + + db[0] &= (0xFF >> uint(8*emLen-emBits)) + + // 12. Let EM = maskedDB || H || 0xbc. + em[emLen-1] = 0xBC + + // 13. Output EM. + return em, nil +} + +func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error { + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" + // and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + hLen := hash.Size() + if hLen != len(mHash) { + return ErrVerification + } + + // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. + emLen := (emBits + 7) / 8 + if emLen < hLen+sLen+2 { + return ErrVerification + } + + // 4. If the rightmost octet of EM does not have hexadecimal value + // 0xbc, output "inconsistent" and stop. + if em[len(em)-1] != 0xBC { + return ErrVerification + } + + // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and + // let H be the next hLen octets. + db := em[:emLen-hLen-1] + h := em[emLen-hLen-1 : len(em)-1] + + // 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in + // maskedDB are not all equal to zero, output "inconsistent" and + // stop. + if em[0]&(0xFF<> uint(8*emLen-emBits)) + + if sLen == PSSSaltLengthAuto { + FindSaltLength: + for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- { + switch db[emLen-hLen-sLen-2] { + case 1: + break FindSaltLength + case 0: + continue + default: + return ErrVerification + } + } + if sLen < 0 { + return ErrVerification + } + } else { + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero + // or if the octet at position emLen - hLen - sLen - 1 (the leftmost + // position is "position 1") does not have hexadecimal value 0x01, + // output "inconsistent" and stop. + for _, e := range db[:emLen-hLen-sLen-2] { + if e != 0x00 { + return ErrVerification + } + } + if db[emLen-hLen-sLen-2] != 0x01 { + return ErrVerification + } + } + + // 11. Let salt be the last sLen octets of DB. + salt := db[len(db)-sLen:] + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 13. Let H' = Hash(M'), an octet string of length hLen. + var prefix [8]byte + hash.Write(prefix[:]) + hash.Write(mHash) + hash.Write(salt) + + h0 := hash.Sum(nil) + + // 14. If H = H', output "consistent." Otherwise, output "inconsistent." + if !bytes.Equal(h0, h) { + return ErrVerification + } + return nil +} + +// signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt. +// Note that hashed must be the result of hashing the input message using the +// given hash funcion. salt is a random sequence of bytes whose length will be +// later used to verify the signature. +func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) { + nBits := priv.N.BitLen() + em, err := emsaPSSEncode(hashed, nBits-1, salt, hash.New()) + if err != nil { + return + } + m := new(big.Int).SetBytes(em) + c, err := decrypt(rand, priv, m) + if err != nil { + return + } + s = make([]byte, (nBits+7)/8) + copyWithLeftPad(s, c.Bytes()) + return +} + +const ( + // PSSSaltLengthAuto causes the salt in a PSS signature to be as large + // as possible when signing, and to be auto-detected when verifying. + PSSSaltLengthAuto = 0 + // PSSSaltLengthEqualsHash causes the salt length to equal the length + // of the hash used in the signature. + PSSSaltLengthEqualsHash = -1 +) + +// PSSOptions contains options for creating and verifying PSS signatures. +type PSSOptions struct { + // SaltLength controls the length of the salt used in the PSS + // signature. It can either be a number of bytes, or one of the special + // PSSSaltLength constants. + SaltLength int +} + +func (opts *PSSOptions) saltLength() int { + if opts == nil { + return PSSSaltLengthAuto + } + return opts.SaltLength +} + +// SignPSS calculates the signature of hashed using RSASSA-PSS [1]. +// Note that hashed must be the result of hashing the input message using the +// given hash funcion. The opts argument may be nil, in which case sensible +// defaults are used. +func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) (s []byte, err error) { + saltLength := opts.saltLength() + switch saltLength { + case PSSSaltLengthAuto: + saltLength = (priv.N.BitLen()+7)/8 - 2 - hash.Size() + case PSSSaltLengthEqualsHash: + saltLength = hash.Size() + } + + salt := make([]byte, saltLength) + if _, err = io.ReadFull(rand, salt); err != nil { + return + } + return signPSSWithSalt(rand, priv, hash, hashed, salt) +} + +// VerifyPSS verifies a PSS signature. +// hashed is the result of hashing the input message using the given hash +// function and sig is the signature. A valid signature is indicated by +// returning a nil error. The opts argument may be nil, in which case sensible +// defaults are used. +func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error { + return verifyPSS(pub, hash, hashed, sig, opts.saltLength()) +} + +// verifyPSS verifies a PSS signature with the given salt length. +func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error { + nBits := pub.N.BitLen() + if len(sig) != (nBits+7)/8 { + return ErrVerification + } + s := new(big.Int).SetBytes(sig) + m := encrypt(new(big.Int), pub, s) + emBits := nBits - 1 + emLen := (emBits + 7) / 8 + if emLen < len(m.Bytes()) { + return ErrVerification + } + em := make([]byte, emLen) + copyWithLeftPad(em, m.Bytes()) + if saltLen == PSSSaltLengthEqualsHash { + saltLen = hash.Size() + } + return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New()) +} diff --git a/src/pkg/crypto/rsa/pss_test.go b/src/pkg/crypto/rsa/pss_test.go new file mode 100644 index 0000000000..32e6fc39d2 --- /dev/null +++ b/src/pkg/crypto/rsa/pss_test.go @@ -0,0 +1,249 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rsa + +import ( + "bufio" + "bytes" + "compress/bzip2" + "crypto" + _ "crypto/md5" + "crypto/rand" + "crypto/sha1" + _ "crypto/sha256" + "encoding/hex" + "math/big" + "os" + "strconv" + "strings" + "testing" +) + +func TestEMSAPSS(t *testing.T) { + // Test vector in file pss-int.txt from: ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1-vec.zip + msg := []byte{ + 0x85, 0x9e, 0xef, 0x2f, 0xd7, 0x8a, 0xca, 0x00, 0x30, 0x8b, + 0xdc, 0x47, 0x11, 0x93, 0xbf, 0x55, 0xbf, 0x9d, 0x78, 0xdb, + 0x8f, 0x8a, 0x67, 0x2b, 0x48, 0x46, 0x34, 0xf3, 0xc9, 0xc2, + 0x6e, 0x64, 0x78, 0xae, 0x10, 0x26, 0x0f, 0xe0, 0xdd, 0x8c, + 0x08, 0x2e, 0x53, 0xa5, 0x29, 0x3a, 0xf2, 0x17, 0x3c, 0xd5, + 0x0c, 0x6d, 0x5d, 0x35, 0x4f, 0xeb, 0xf7, 0x8b, 0x26, 0x02, + 0x1c, 0x25, 0xc0, 0x27, 0x12, 0xe7, 0x8c, 0xd4, 0x69, 0x4c, + 0x9f, 0x46, 0x97, 0x77, 0xe4, 0x51, 0xe7, 0xf8, 0xe9, 0xe0, + 0x4c, 0xd3, 0x73, 0x9c, 0x6b, 0xbf, 0xed, 0xae, 0x48, 0x7f, + 0xb5, 0x56, 0x44, 0xe9, 0xca, 0x74, 0xff, 0x77, 0xa5, 0x3c, + 0xb7, 0x29, 0x80, 0x2f, 0x6e, 0xd4, 0xa5, 0xff, 0xa8, 0xba, + 0x15, 0x98, 0x90, 0xfc, + } + salt := []byte{ + 0xe3, 0xb5, 0xd5, 0xd0, 0x02, 0xc1, 0xbc, 0xe5, 0x0c, 0x2b, + 0x65, 0xef, 0x88, 0xa1, 0x88, 0xd8, 0x3b, 0xce, 0x7e, 0x61, + } + expected := []byte{ + 0x66, 0xe4, 0x67, 0x2e, 0x83, 0x6a, 0xd1, 0x21, 0xba, 0x24, + 0x4b, 0xed, 0x65, 0x76, 0xb8, 0x67, 0xd9, 0xa4, 0x47, 0xc2, + 0x8a, 0x6e, 0x66, 0xa5, 0xb8, 0x7d, 0xee, 0x7f, 0xbc, 0x7e, + 0x65, 0xaf, 0x50, 0x57, 0xf8, 0x6f, 0xae, 0x89, 0x84, 0xd9, + 0xba, 0x7f, 0x96, 0x9a, 0xd6, 0xfe, 0x02, 0xa4, 0xd7, 0x5f, + 0x74, 0x45, 0xfe, 0xfd, 0xd8, 0x5b, 0x6d, 0x3a, 0x47, 0x7c, + 0x28, 0xd2, 0x4b, 0xa1, 0xe3, 0x75, 0x6f, 0x79, 0x2d, 0xd1, + 0xdc, 0xe8, 0xca, 0x94, 0x44, 0x0e, 0xcb, 0x52, 0x79, 0xec, + 0xd3, 0x18, 0x3a, 0x31, 0x1f, 0xc8, 0x96, 0xda, 0x1c, 0xb3, + 0x93, 0x11, 0xaf, 0x37, 0xea, 0x4a, 0x75, 0xe2, 0x4b, 0xdb, + 0xfd, 0x5c, 0x1d, 0xa0, 0xde, 0x7c, 0xec, 0xdf, 0x1a, 0x89, + 0x6f, 0x9d, 0x8b, 0xc8, 0x16, 0xd9, 0x7c, 0xd7, 0xa2, 0xc4, + 0x3b, 0xad, 0x54, 0x6f, 0xbe, 0x8c, 0xfe, 0xbc, + } + + hash := sha1.New() + hash.Write(msg) + hashed := hash.Sum(nil) + + encoded, err := emsaPSSEncode(hashed, 1023, salt, sha1.New()) + if err != nil { + t.Errorf("Error from emsaPSSEncode: %s\n", err) + } + if !bytes.Equal(encoded, expected) { + t.Errorf("Bad encoding. got %x, want %x", encoded, expected) + } + + if err = emsaPSSVerify(hashed, encoded, 1023, len(salt), sha1.New()); err != nil { + t.Errorf("Bad verification: %s", err) + } +} + +// TestPSSGolden tests all the test vectors in pss-vect.txt from +// ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1-vec.zip +func TestPSSGolden(t *testing.T) { + inFile, err := os.Open("testdata/pss-vect.txt.bz2") + if err != nil { + t.Fatalf("Failed to open input file: %s", err) + } + defer inFile.Close() + + // The pss-vect.txt file contains RSA keys and then a series of + // signatures. A goroutine is used to preprocess the input by merging + // lines, removing spaces in hex values and identifying the start of + // new keys and signature blocks. + const newKeyMarker = "START NEW KEY" + const newSignatureMarker = "START NEW SIGNATURE" + + values := make(chan string) + + go func() { + defer close(values) + scanner := bufio.NewScanner(bzip2.NewReader(inFile)) + var partialValue string + lastWasValue := true + + for scanner.Scan() { + line := scanner.Text() + switch { + case len(line) == 0: + if len(partialValue) > 0 { + values <- strings.Replace(partialValue, " ", "", -1) + partialValue = "" + lastWasValue = true + } + continue + case strings.HasPrefix(line, "# ======") && lastWasValue: + values <- newKeyMarker + lastWasValue = false + case strings.HasPrefix(line, "# ------") && lastWasValue: + values <- newSignatureMarker + lastWasValue = false + case strings.HasPrefix(line, "#"): + continue + default: + partialValue += line + } + } + if err := scanner.Err(); err != nil { + panic(err) + } + }() + + var key *PublicKey + var hashed []byte + hash := crypto.SHA1 + h := hash.New() + opts := &PSSOptions{ + SaltLength: PSSSaltLengthEqualsHash, + } + + for marker := range values { + switch marker { + case newKeyMarker: + key = new(PublicKey) + nHex, ok := <-values + if !ok { + continue + } + key.N = bigFromHex(nHex) + key.E = intFromHex(<-values) + // We don't care for d, p, q, dP, dQ or qInv. + for i := 0; i < 6; i++ { + <-values + } + case newSignatureMarker: + msg := fromHex(<-values) + <-values // skip salt + sig := fromHex(<-values) + + h.Reset() + h.Write(msg) + hashed = h.Sum(hashed[:0]) + + if err := VerifyPSS(key, hash, hashed, sig, opts); err != nil { + t.Error(err) + } + default: + t.Fatalf("unknown marker: " + marker) + } + } +} + +// TestPSSOpenSSL ensures that we can verify a PSS signature from OpenSSL with +// the default options. OpenSSL sets the salt length to be maximal. +func TestPSSOpenSSL(t *testing.T) { + hash := crypto.SHA256 + h := hash.New() + h.Write([]byte("testing")) + hashed := h.Sum(nil) + + // Generated with `echo -n testing | openssl dgst -sign key.pem -sigopt rsa_padding_mode:pss -sha256 > sig` + sig := []byte{ + 0x95, 0x59, 0x6f, 0xd3, 0x10, 0xa2, 0xe7, 0xa2, 0x92, 0x9d, + 0x4a, 0x07, 0x2e, 0x2b, 0x27, 0xcc, 0x06, 0xc2, 0x87, 0x2c, + 0x52, 0xf0, 0x4a, 0xcc, 0x05, 0x94, 0xf2, 0xc3, 0x2e, 0x20, + 0xd7, 0x3e, 0x66, 0x62, 0xb5, 0x95, 0x2b, 0xa3, 0x93, 0x9a, + 0x66, 0x64, 0x25, 0xe0, 0x74, 0x66, 0x8c, 0x3e, 0x92, 0xeb, + 0xc6, 0xe6, 0xc0, 0x44, 0xf3, 0xb4, 0xb4, 0x2e, 0x8c, 0x66, + 0x0a, 0x37, 0x9c, 0x69, + } + + if err := VerifyPSS(&rsaPrivateKey.PublicKey, hash, hashed, sig, nil); err != nil { + t.Error(err) + } +} + +func TestPSSSigning(t *testing.T) { + var saltLengthCombinations = []struct { + signSaltLength, verifySaltLength int + good bool + }{ + {PSSSaltLengthAuto, PSSSaltLengthAuto, true}, + {PSSSaltLengthEqualsHash, PSSSaltLengthAuto, true}, + {PSSSaltLengthEqualsHash, PSSSaltLengthEqualsHash, true}, + {PSSSaltLengthEqualsHash, 8, false}, + {PSSSaltLengthAuto, PSSSaltLengthEqualsHash, false}, + {8, 8, true}, + } + + hash := crypto.MD5 + h := hash.New() + h.Write([]byte("testing")) + hashed := h.Sum(nil) + var opts PSSOptions + + for i, test := range saltLengthCombinations { + opts.SaltLength = test.signSaltLength + sig, err := SignPSS(rand.Reader, rsaPrivateKey, hash, hashed, &opts) + if err != nil { + t.Errorf("#%d: error while signing: %s", i, err) + continue + } + + opts.SaltLength = test.verifySaltLength + err = VerifyPSS(&rsaPrivateKey.PublicKey, hash, hashed, sig, &opts) + if (err == nil) != test.good { + t.Errorf("#%d: bad result, wanted: %t, got: %s", i, test.good, err) + } + } +} + +func bigFromHex(hex string) *big.Int { + n, ok := new(big.Int).SetString(hex, 16) + if !ok { + panic("bad hex: " + hex) + } + return n +} + +func intFromHex(hex string) int { + i, err := strconv.ParseInt(hex, 16, 32) + if err != nil { + panic(err) + } + return int(i) +} + +func fromHex(hexStr string) []byte { + s, err := hex.DecodeString(hexStr) + if err != nil { + panic(err) + } + return s +} diff --git a/src/pkg/crypto/rsa/rsa_test.go b/src/pkg/crypto/rsa/rsa_test.go index ffd96e62f6..cf193c669f 100644 --- a/src/pkg/crypto/rsa/rsa_test.go +++ b/src/pkg/crypto/rsa/rsa_test.go @@ -120,8 +120,10 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { } func fromBase10(base10 string) *big.Int { - i := new(big.Int) - i.SetString(base10, 10) + i, ok := new(big.Int).SetString(base10, 10) + if !ok { + panic("bad number: " + base10) + } return i } diff --git a/src/pkg/crypto/rsa/testdata/pss-vect.txt.bz2 b/src/pkg/crypto/rsa/testdata/pss-vect.txt.bz2 new file mode 100644 index 0000000000..ad3da1ac4e Binary files /dev/null and b/src/pkg/crypto/rsa/testdata/pss-vect.txt.bz2 differ