]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: add (*Int).FillBytes
authorFilippo Valsorda <filippo@golang.org>
Tue, 28 Apr 2020 01:52:38 +0000 (21:52 -0400)
committerFilippo Valsorda <filippo@golang.org>
Tue, 5 May 2020 00:36:44 +0000 (00:36 +0000)
Replaced almost every use of Bytes with FillBytes.

Note that the approved proposal was for

    func (*Int) FillBytes(buf []byte)

while this implements

    func (*Int) FillBytes(buf []byte) []byte

because the latter was far nicer to use in all callsites.

Fixes #35833

Change-Id: Ia912df123e5d79b763845312ea3d9a8051343c0a
Reviewed-on: https://go-review.googlesource.com/c/go/+/230397
Reviewed-by: Robert Griesemer <gri@golang.org>
src/crypto/elliptic/elliptic.go
src/crypto/rsa/pkcs1v15.go
src/crypto/rsa/pss.go
src/crypto/rsa/rsa.go
src/crypto/tls/key_schedule.go
src/crypto/x509/sec1.go
src/math/big/int.go
src/math/big/int_test.go
src/math/big/nat.go

index e2f71cdb63babf2fcd72bacff4acafdcd5264b60..bd5168c5fd842cf44b5db9c5498198a9760999ad 100644 (file)
@@ -277,7 +277,7 @@ var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f}
 func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) {
        N := curve.Params().N
        bitSize := N.BitLen()
-       byteLen := (bitSize + 7) >> 3
+       byteLen := (bitSize + 7) / 8
        priv = make([]byte, byteLen)
 
        for x == nil {
@@ -304,15 +304,14 @@ func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err e
 
 // Marshal converts a point into the uncompressed form specified in section 4.3.6 of ANSI X9.62.
 func Marshal(curve Curve, x, y *big.Int) []byte {
-       byteLen := (curve.Params().BitSize + 7) >> 3
+       byteLen := (curve.Params().BitSize + 7) / 8
 
        ret := make([]byte, 1+2*byteLen)
        ret[0] = 4 // uncompressed point
 
-       xBytes := x.Bytes()
-       copy(ret[1+byteLen-len(xBytes):], xBytes)
-       yBytes := y.Bytes()
-       copy(ret[1+2*byteLen-len(yBytes):], yBytes)
+       x.FillBytes(ret[1 : 1+byteLen])
+       y.FillBytes(ret[1+byteLen : 1+2*byteLen])
+
        return ret
 }
 
@@ -320,7 +319,7 @@ func Marshal(curve Curve, x, y *big.Int) []byte {
 // It is an error if the point is not in uncompressed form or is not on the curve.
 // On error, x = nil.
 func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
-       byteLen := (curve.Params().BitSize + 7) >> 3
+       byteLen := (curve.Params().BitSize + 7) / 8
        if len(data) != 1+2*byteLen {
                return
        }
index 499242ffc5b574196f51eb273925097c7270ec5f..3208119ae1ff4707111950cc77673e2f4813a4c8 100644 (file)
@@ -61,8 +61,7 @@ func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) ([]byte, error)
        m := new(big.Int).SetBytes(em)
        c := encrypt(new(big.Int), pub, m)
 
-       copyWithLeftPad(em, c.Bytes())
-       return em, nil
+       return c.FillBytes(em), nil
 }
 
 // DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5.
@@ -150,7 +149,7 @@ func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid
                return
        }
 
-       em = leftPad(m.Bytes(), k)
+       em = m.FillBytes(make([]byte, k))
        firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
        secondByteIsTwo := subtle.ConstantTimeByteEq(em[1], 2)
 
@@ -256,8 +255,7 @@ func SignPKCS1v15(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []b
                return nil, err
        }
 
-       copyWithLeftPad(em, c.Bytes())
-       return em, nil
+       return c.FillBytes(em), nil
 }
 
 // VerifyPKCS1v15 verifies an RSA PKCS#1 v1.5 signature.
@@ -286,7 +284,7 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte)
 
        c := new(big.Int).SetBytes(sig)
        m := encrypt(new(big.Int), pub, c)
-       em := leftPad(m.Bytes(), k)
+       em := m.FillBytes(make([]byte, k))
        // EM = 0x00 || 0x01 || PS || 0x00 || T
 
        ok := subtle.ConstantTimeByteEq(em[0], 0)
@@ -323,13 +321,3 @@ func pkcs1v15HashInfo(hash crypto.Hash, inLen int) (hashLen int, prefix []byte,
        }
        return
 }
-
-// copyWithLeftPad copies src to the end of dest, padding with zero bytes as
-// needed.
-func copyWithLeftPad(dest, src []byte) {
-       numPaddingBytes := len(dest) - len(src)
-       for i := 0; i < numPaddingBytes; i++ {
-               dest[i] = 0
-       }
-       copy(dest[numPaddingBytes:], src)
-}
index f9844d87329a843e53784fd032017ac0e3ad5956..b2adbedb28fa859aa9ee0a23ac5313ca53370e02 100644 (file)
@@ -207,20 +207,19 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
 // Note that hashed must be the result of hashing the input message using the
 // given hash function. salt is a random sequence of bytes whose length will be
 // later used to verify the signature.
-func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
+func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
        emBits := priv.N.BitLen() - 1
        em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
        if err != nil {
-               return
+               return nil, err
        }
        m := new(big.Int).SetBytes(em)
        c, err := decryptAndCheck(rand, priv, m)
        if err != nil {
-               return
+               return nil, err
        }
-       s = make([]byte, priv.Size())
-       copyWithLeftPad(s, c.Bytes())
-       return
+       s := make([]byte, priv.Size())
+       return c.FillBytes(s), nil
 }
 
 const (
@@ -296,11 +295,9 @@ func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts
        m := encrypt(new(big.Int), pub, s)
        emBits := pub.N.BitLen() - 1
        emLen := (emBits + 7) / 8
-       emBytes := m.Bytes()
-       if emLen < len(emBytes) {
+       if m.BitLen() > emLen*8 {
                return ErrVerification
        }
-       em := make([]byte, emLen)
-       copyWithLeftPad(em, emBytes)
+       em := m.FillBytes(make([]byte, emLen))
        return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
 }
index b4bfa13defbdf04136d10e640dd5b336142272b7..28eb5926c1a54773fcb4dcf15805b24be1c44db7 100644 (file)
@@ -416,16 +416,9 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l
        m := new(big.Int)
        m.SetBytes(em)
        c := encrypt(new(big.Int), pub, m)
-       out := c.Bytes()
 
-       if len(out) < k {
-               // If the output is too small, we need to left-pad with zeros.
-               t := make([]byte, k)
-               copy(t[k-len(out):], out)
-               out = t
-       }
-
-       return out, nil
+       out := make([]byte, k)
+       return c.FillBytes(out), nil
 }
 
 // ErrDecryption represents a failure to decrypt a message.
@@ -597,12 +590,9 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext
        lHash := hash.Sum(nil)
        hash.Reset()
 
-       // Converting the plaintext number to bytes will strip any
-       // leading zeros so we may have to left pad. We do this unconditionally
-       // to avoid leaking timing information. (Although we still probably
-       // leak the number of leading zeros. It's not clear that we can do
-       // anything about this.)
-       em := leftPad(m.Bytes(), k)
+       // We probably leak the number of leading zeros.
+       // It's not clear that we can do anything about this.
+       em := m.FillBytes(make([]byte, k))
 
        firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
 
@@ -643,15 +633,3 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext
 
        return rest[index+1:], nil
 }
-
-// leftPad returns a new slice of length size. The contents of input are right
-// aligned in the new slice.
-func leftPad(input []byte, size int) (out []byte) {
-       n := len(input)
-       if n > size {
-               n = size
-       }
-       out = make([]byte, size)
-       copy(out[len(out)-n:], input)
-       return
-}
index 2aab323202f7d3dd9ce07a267447f5d3e7c10342..314016979afb82a0635d13ecd73f472ffbcd3654 100644 (file)
@@ -173,11 +173,8 @@ func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte {
        }
 
        xShared, _ := curve.ScalarMult(x, y, p.privateKey)
-       sharedKey := make([]byte, (curve.Params().BitSize+7)>>3)
-       xBytes := xShared.Bytes()
-       copy(sharedKey[len(sharedKey)-len(xBytes):], xBytes)
-
-       return sharedKey
+       sharedKey := make([]byte, (curve.Params().BitSize+7)/8)
+       return xShared.FillBytes(sharedKey)
 }
 
 type x25519Parameters struct {
index 0bfb90cd5464aa7dc44c3761a7fe966a8e1b3cb0..52c108ff1d6240b15c5db5bece3d2c99f392c3b3 100644 (file)
@@ -52,13 +52,10 @@ func MarshalECPrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
 // marshalECPrivateKey marshals an EC private key into ASN.1, DER format and
 // sets the curve ID to the given OID, or omits it if OID is nil.
 func marshalECPrivateKeyWithOID(key *ecdsa.PrivateKey, oid asn1.ObjectIdentifier) ([]byte, error) {
-       privateKeyBytes := key.D.Bytes()
-       paddedPrivateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
-       copy(paddedPrivateKey[len(paddedPrivateKey)-len(privateKeyBytes):], privateKeyBytes)
-
+       privateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
        return asn1.Marshal(ecPrivateKey{
                Version:       1,
-               PrivateKey:    paddedPrivateKey,
+               PrivateKey:    key.D.FillBytes(privateKey),
                NamedCurveOID: oid,
                PublicKey:     asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)},
        })
index 8816cf5266cc4f9eab65efb89ef1ed8f01d24d5c..65f32487b58c017c8aa224c7e805cf0d2fcd20b0 100644 (file)
@@ -447,11 +447,26 @@ func (z *Int) SetBytes(buf []byte) *Int {
 }
 
 // Bytes returns the absolute value of x as a big-endian byte slice.
+//
+// To use a fixed length slice, or a preallocated one, use FillBytes.
 func (x *Int) Bytes() []byte {
        buf := make([]byte, len(x.abs)*_S)
        return buf[x.abs.bytes(buf):]
 }
 
+// FillBytes sets buf to the absolute value of x, storing it as a zero-extended
+// big-endian byte slice, and returns buf.
+//
+// If the absolute value of x doesn't fit in buf, FillBytes will panic.
+func (x *Int) FillBytes(buf []byte) []byte {
+       // Clear whole buffer. (This gets optimized into a memclr.)
+       for i := range buf {
+               buf[i] = 0
+       }
+       x.abs.bytes(buf)
+       return buf
+}
+
 // BitLen returns the length of the absolute value of x in bits.
 // The bit length of 0 is 0.
 func (x *Int) BitLen() int {
index e3a1587b3f0ad7acb0cb56091c9fa69902c89eb8..3c8557323a03245424a837b56d4480a2531e9193 100644 (file)
@@ -1840,3 +1840,57 @@ func BenchmarkDiv(b *testing.B) {
                })
        }
 }
+
+func TestFillBytes(t *testing.T) {
+       checkResult := func(t *testing.T, buf []byte, want *Int) {
+               t.Helper()
+               got := new(Int).SetBytes(buf)
+               if got.CmpAbs(want) != 0 {
+                       t.Errorf("got 0x%x, want 0x%x: %x", got, want, buf)
+               }
+       }
+       panics := func(f func()) (panic bool) {
+               defer func() { panic = recover() != nil }()
+               f()
+               return
+       }
+
+       for _, n := range []string{
+               "0",
+               "1000",
+               "0xffffffff",
+               "-0xffffffff",
+               "0xffffffffffffffff",
+               "0x10000000000000000",
+               "0xabababababababababababababababababababababababababa",
+               "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
+       } {
+               t.Run(n, func(t *testing.T) {
+                       t.Logf(n)
+                       x, ok := new(Int).SetString(n, 0)
+                       if !ok {
+                               panic("invalid test entry")
+                       }
+
+                       // Perfectly sized buffer.
+                       byteLen := (x.BitLen() + 7) / 8
+                       buf := make([]byte, byteLen)
+                       checkResult(t, x.FillBytes(buf), x)
+
+                       // Way larger, checking all bytes get zeroed.
+                       buf = make([]byte, 100)
+                       for i := range buf {
+                               buf[i] = 0xff
+                       }
+                       checkResult(t, x.FillBytes(buf), x)
+
+                       // Too small.
+                       if byteLen > 0 {
+                               buf = make([]byte, byteLen-1)
+                               if !panics(func() { x.FillBytes(buf) }) {
+                                       t.Errorf("expected panic for small buffer and value %x", x)
+                               }
+                       }
+               })
+       }
+}
index c31ec5156b81d8830b189245be2b399c6f8d6e2b..6a3989bf9d82bf1c9ec3b57aedead466002bfb8e 100644 (file)
@@ -1476,19 +1476,26 @@ func (z nat) expNNMontgomery(x, y, m nat) nat {
 }
 
 // bytes writes the value of z into buf using big-endian encoding.
-// len(buf) must be >= len(z)*_S. The value of z is encoded in the
-// slice buf[i:]. The number i of unused bytes at the beginning of
-// buf is returned as result.
+// The value of z is encoded in the slice buf[i:]. If the value of z
+// cannot be represented in buf, bytes panics. The number i of unused
+// bytes at the beginning of buf is returned as result.
 func (z nat) bytes(buf []byte) (i int) {
        i = len(buf)
        for _, d := range z {
                for j := 0; j < _S; j++ {
                        i--
-                       buf[i] = byte(d)
+                       if i >= 0 {
+                               buf[i] = byte(d)
+                       } else if byte(d) != 0 {
+                               panic("math/big: buffer too small to fit value")
+                       }
                        d >>= 8
                }
        }
 
+       if i < 0 {
+               i = 0
+       }
        for i < len(buf) && buf[i] == 0 {
                i++
        }