"crypto/sha256"
"crypto/x509"
"encoding/pem"
+ "flag"
"fmt"
"math/big"
"strings"
}
}
+var allFlag = flag.Bool("all", false, "test all key sizes up to 2048")
+
+func TestEverything(t *testing.T) {
+ min := 32
+ max := 560 // any smaller than this and not all tests will run
+ if testing.Short() {
+ min = max
+ }
+ if *allFlag {
+ max = 2048
+ }
+ for size := min; size <= max; size++ {
+ t.Run(fmt.Sprintf("%d", size), func(t *testing.T) {
+ t.Parallel()
+ priv, err := GenerateKey(rand.Reader, size)
+ if err != nil {
+ t.Errorf("GenerateKey(%d): %v", size, err)
+ }
+ if bits := priv.N.BitLen(); bits != size {
+ t.Errorf("key too short (%d vs %d)", bits, size)
+ }
+ testEverything(t, priv)
+ })
+ }
+}
+
+func testEverything(t *testing.T, priv *PrivateKey) {
+ if err := priv.Validate(); err != nil {
+ t.Errorf("Validate() failed: %s", err)
+ }
+
+ msg := []byte("test")
+ enc, err := EncryptPKCS1v15(rand.Reader, &priv.PublicKey, msg)
+ if err == ErrMessageTooLong {
+ t.Log("key too small for EncryptPKCS1v15")
+ } else if err != nil {
+ t.Errorf("EncryptPKCS1v15: %v", err)
+ }
+ if err == nil {
+ dec, err := DecryptPKCS1v15(nil, priv, enc)
+ if err != nil {
+ t.Errorf("DecryptPKCS1v15: %v", err)
+ }
+ err = DecryptPKCS1v15SessionKey(nil, priv, enc, make([]byte, 4))
+ if err != nil {
+ t.Errorf("DecryptPKCS1v15SessionKey: %v", err)
+ }
+ if !bytes.Equal(dec, msg) {
+ t.Errorf("got:%x want:%x (%+v)", dec, msg, priv)
+ }
+ }
+
+ label := []byte("label")
+ enc, err = EncryptOAEP(sha256.New(), rand.Reader, &priv.PublicKey, msg, label)
+ if err == ErrMessageTooLong {
+ t.Log("key too small for EncryptOAEP")
+ } else if err != nil {
+ t.Errorf("EncryptOAEP: %v", err)
+ }
+ if err == nil {
+ dec, err := DecryptOAEP(sha256.New(), nil, priv, enc, label)
+ if err != nil {
+ t.Errorf("DecryptOAEP: %v", err)
+ }
+ if !bytes.Equal(dec, msg) {
+ t.Errorf("got:%x want:%x (%+v)", dec, msg, priv)
+ }
+ }
+
+ hash := sha256.Sum256(msg)
+ sig, err := SignPKCS1v15(nil, priv, crypto.SHA256, hash[:])
+ if err == ErrMessageTooLong {
+ t.Log("key too small for SignPKCS1v15")
+ } else if err != nil {
+ t.Errorf("SignPKCS1v15: %v", err)
+ }
+ if err == nil {
+ err = VerifyPKCS1v15(&priv.PublicKey, crypto.SHA256, hash[:], sig)
+ if err != nil {
+ t.Errorf("VerifyPKCS1v15: %v", err)
+ }
+ sig[1] ^= 0x80
+ err = VerifyPKCS1v15(&priv.PublicKey, crypto.SHA256, hash[:], sig)
+ if err == nil {
+ t.Errorf("VerifyPKCS1v15 success for tampered signature")
+ }
+ sig[1] ^= 0x80
+ hash[1] ^= 0x80
+ err = VerifyPKCS1v15(&priv.PublicKey, crypto.SHA256, hash[:], sig)
+ if err == nil {
+ t.Errorf("VerifyPKCS1v15 success for tampered message")
+ }
+ hash[1] ^= 0x80
+ }
+
+ opts := &PSSOptions{SaltLength: PSSSaltLengthAuto}
+ sig, err = SignPSS(rand.Reader, priv, crypto.SHA256, hash[:], opts)
+ if err == ErrMessageTooLong {
+ t.Log("key too small for SignPSS with PSSSaltLengthAuto")
+ } else if err != nil {
+ t.Errorf("SignPSS: %v", err)
+ }
+ if err == nil {
+ err = VerifyPSS(&priv.PublicKey, crypto.SHA256, hash[:], sig, opts)
+ if err != nil {
+ t.Errorf("VerifyPSS: %v", err)
+ }
+ sig[1] ^= 0x80
+ err = VerifyPSS(&priv.PublicKey, crypto.SHA256, hash[:], sig, opts)
+ if err == nil {
+ t.Errorf("VerifyPSS success for tampered signature")
+ }
+ sig[1] ^= 0x80
+ hash[1] ^= 0x80
+ err = VerifyPSS(&priv.PublicKey, crypto.SHA256, hash[:], sig, opts)
+ if err == nil {
+ t.Errorf("VerifyPSS success for tampered message")
+ }
+ hash[1] ^= 0x80
+ }
+
+ opts.SaltLength = PSSSaltLengthEqualsHash
+ sig, err = SignPSS(rand.Reader, priv, crypto.SHA256, hash[:], opts)
+ if err == ErrMessageTooLong {
+ t.Log("key too small for SignPSS with PSSSaltLengthEqualsHash")
+ } else if err != nil {
+ t.Errorf("SignPSS: %v", err)
+ }
+ if err == nil {
+ err = VerifyPSS(&priv.PublicKey, crypto.SHA256, hash[:], sig, opts)
+ if err != nil {
+ t.Errorf("VerifyPSS: %v", err)
+ }
+ sig[1] ^= 0x80
+ err = VerifyPSS(&priv.PublicKey, crypto.SHA256, hash[:], sig, opts)
+ if err == nil {
+ t.Errorf("VerifyPSS success for tampered signature")
+ }
+ sig[1] ^= 0x80
+ hash[1] ^= 0x80
+ err = VerifyPSS(&priv.PublicKey, crypto.SHA256, hash[:], sig, opts)
+ if err == nil {
+ t.Errorf("VerifyPSS success for tampered message")
+ }
+ hash[1] ^= 0x80
+ }
+}
+
func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
func parseKey(s string) *PrivateKey {