Replaced almost every use of Bytes with FillBytes.
Note that the approved proposal was for
func (*Int) FillBytes(buf []byte)
while this implements
func (*Int) FillBytes(buf []byte) []byte
because the latter was far nicer to use in all callsites.
Fixes #35833
Change-Id: Ia912df123e5d79b763845312ea3d9a8051343c0a
Reviewed-on: https://go-review.googlesource.com/c/go/+/230397
Reviewed-by: Robert Griesemer <gri@golang.org>
func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) {
N := curve.Params().N
bitSize := N.BitLen()
- byteLen := (bitSize + 7) >> 3
+ byteLen := (bitSize + 7) / 8
priv = make([]byte, byteLen)
for x == nil {
// Marshal converts a point into the uncompressed form specified in section 4.3.6 of ANSI X9.62.
func Marshal(curve Curve, x, y *big.Int) []byte {
- byteLen := (curve.Params().BitSize + 7) >> 3
+ byteLen := (curve.Params().BitSize + 7) / 8
ret := make([]byte, 1+2*byteLen)
ret[0] = 4 // uncompressed point
- xBytes := x.Bytes()
- copy(ret[1+byteLen-len(xBytes):], xBytes)
- yBytes := y.Bytes()
- copy(ret[1+2*byteLen-len(yBytes):], yBytes)
+ x.FillBytes(ret[1 : 1+byteLen])
+ y.FillBytes(ret[1+byteLen : 1+2*byteLen])
+
return ret
}
// It is an error if the point is not in uncompressed form or is not on the curve.
// On error, x = nil.
func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
- byteLen := (curve.Params().BitSize + 7) >> 3
+ byteLen := (curve.Params().BitSize + 7) / 8
if len(data) != 1+2*byteLen {
return
}
m := new(big.Int).SetBytes(em)
c := encrypt(new(big.Int), pub, m)
- copyWithLeftPad(em, c.Bytes())
- return em, nil
+ return c.FillBytes(em), nil
}
// DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5.
return
}
- em = leftPad(m.Bytes(), k)
+ em = m.FillBytes(make([]byte, k))
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
secondByteIsTwo := subtle.ConstantTimeByteEq(em[1], 2)
return nil, err
}
- copyWithLeftPad(em, c.Bytes())
- return em, nil
+ return c.FillBytes(em), nil
}
// VerifyPKCS1v15 verifies an RSA PKCS#1 v1.5 signature.
c := new(big.Int).SetBytes(sig)
m := encrypt(new(big.Int), pub, c)
- em := leftPad(m.Bytes(), k)
+ em := m.FillBytes(make([]byte, k))
// EM = 0x00 || 0x01 || PS || 0x00 || T
ok := subtle.ConstantTimeByteEq(em[0], 0)
}
return
}
-
-// copyWithLeftPad copies src to the end of dest, padding with zero bytes as
-// needed.
-func copyWithLeftPad(dest, src []byte) {
- numPaddingBytes := len(dest) - len(src)
- for i := 0; i < numPaddingBytes; i++ {
- dest[i] = 0
- }
- copy(dest[numPaddingBytes:], src)
-}
// Note that hashed must be the result of hashing the input message using the
// given hash function. salt is a random sequence of bytes whose length will be
// later used to verify the signature.
-func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
+func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
emBits := priv.N.BitLen() - 1
em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
if err != nil {
- return
+ return nil, err
}
m := new(big.Int).SetBytes(em)
c, err := decryptAndCheck(rand, priv, m)
if err != nil {
- return
+ return nil, err
}
- s = make([]byte, priv.Size())
- copyWithLeftPad(s, c.Bytes())
- return
+ s := make([]byte, priv.Size())
+ return c.FillBytes(s), nil
}
const (
m := encrypt(new(big.Int), pub, s)
emBits := pub.N.BitLen() - 1
emLen := (emBits + 7) / 8
- emBytes := m.Bytes()
- if emLen < len(emBytes) {
+ if m.BitLen() > emLen*8 {
return ErrVerification
}
- em := make([]byte, emLen)
- copyWithLeftPad(em, emBytes)
+ em := m.FillBytes(make([]byte, emLen))
return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
}
m := new(big.Int)
m.SetBytes(em)
c := encrypt(new(big.Int), pub, m)
- out := c.Bytes()
- if len(out) < k {
- // If the output is too small, we need to left-pad with zeros.
- t := make([]byte, k)
- copy(t[k-len(out):], out)
- out = t
- }
-
- return out, nil
+ out := make([]byte, k)
+ return c.FillBytes(out), nil
}
// ErrDecryption represents a failure to decrypt a message.
lHash := hash.Sum(nil)
hash.Reset()
- // Converting the plaintext number to bytes will strip any
- // leading zeros so we may have to left pad. We do this unconditionally
- // to avoid leaking timing information. (Although we still probably
- // leak the number of leading zeros. It's not clear that we can do
- // anything about this.)
- em := leftPad(m.Bytes(), k)
+ // We probably leak the number of leading zeros.
+ // It's not clear that we can do anything about this.
+ em := m.FillBytes(make([]byte, k))
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
return rest[index+1:], nil
}
-
-// leftPad returns a new slice of length size. The contents of input are right
-// aligned in the new slice.
-func leftPad(input []byte, size int) (out []byte) {
- n := len(input)
- if n > size {
- n = size
- }
- out = make([]byte, size)
- copy(out[len(out)-n:], input)
- return
-}
}
xShared, _ := curve.ScalarMult(x, y, p.privateKey)
- sharedKey := make([]byte, (curve.Params().BitSize+7)>>3)
- xBytes := xShared.Bytes()
- copy(sharedKey[len(sharedKey)-len(xBytes):], xBytes)
-
- return sharedKey
+ sharedKey := make([]byte, (curve.Params().BitSize+7)/8)
+ return xShared.FillBytes(sharedKey)
}
type x25519Parameters struct {
// marshalECPrivateKey marshals an EC private key into ASN.1, DER format and
// sets the curve ID to the given OID, or omits it if OID is nil.
func marshalECPrivateKeyWithOID(key *ecdsa.PrivateKey, oid asn1.ObjectIdentifier) ([]byte, error) {
- privateKeyBytes := key.D.Bytes()
- paddedPrivateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
- copy(paddedPrivateKey[len(paddedPrivateKey)-len(privateKeyBytes):], privateKeyBytes)
-
+ privateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
return asn1.Marshal(ecPrivateKey{
Version: 1,
- PrivateKey: paddedPrivateKey,
+ PrivateKey: key.D.FillBytes(privateKey),
NamedCurveOID: oid,
PublicKey: asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)},
})
}
// Bytes returns the absolute value of x as a big-endian byte slice.
+//
+// To use a fixed length slice, or a preallocated one, use FillBytes.
func (x *Int) Bytes() []byte {
buf := make([]byte, len(x.abs)*_S)
return buf[x.abs.bytes(buf):]
}
+// FillBytes sets buf to the absolute value of x, storing it as a zero-extended
+// big-endian byte slice, and returns buf.
+//
+// If the absolute value of x doesn't fit in buf, FillBytes will panic.
+func (x *Int) FillBytes(buf []byte) []byte {
+ // Clear whole buffer. (This gets optimized into a memclr.)
+ for i := range buf {
+ buf[i] = 0
+ }
+ x.abs.bytes(buf)
+ return buf
+}
+
// BitLen returns the length of the absolute value of x in bits.
// The bit length of 0 is 0.
func (x *Int) BitLen() int {
})
}
}
+
+func TestFillBytes(t *testing.T) {
+ checkResult := func(t *testing.T, buf []byte, want *Int) {
+ t.Helper()
+ got := new(Int).SetBytes(buf)
+ if got.CmpAbs(want) != 0 {
+ t.Errorf("got 0x%x, want 0x%x: %x", got, want, buf)
+ }
+ }
+ panics := func(f func()) (panic bool) {
+ defer func() { panic = recover() != nil }()
+ f()
+ return
+ }
+
+ for _, n := range []string{
+ "0",
+ "1000",
+ "0xffffffff",
+ "-0xffffffff",
+ "0xffffffffffffffff",
+ "0x10000000000000000",
+ "0xabababababababababababababababababababababababababa",
+ "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
+ } {
+ t.Run(n, func(t *testing.T) {
+ t.Logf(n)
+ x, ok := new(Int).SetString(n, 0)
+ if !ok {
+ panic("invalid test entry")
+ }
+
+ // Perfectly sized buffer.
+ byteLen := (x.BitLen() + 7) / 8
+ buf := make([]byte, byteLen)
+ checkResult(t, x.FillBytes(buf), x)
+
+ // Way larger, checking all bytes get zeroed.
+ buf = make([]byte, 100)
+ for i := range buf {
+ buf[i] = 0xff
+ }
+ checkResult(t, x.FillBytes(buf), x)
+
+ // Too small.
+ if byteLen > 0 {
+ buf = make([]byte, byteLen-1)
+ if !panics(func() { x.FillBytes(buf) }) {
+ t.Errorf("expected panic for small buffer and value %x", x)
+ }
+ }
+ })
+ }
+}
}
// bytes writes the value of z into buf using big-endian encoding.
-// len(buf) must be >= len(z)*_S. The value of z is encoded in the
-// slice buf[i:]. The number i of unused bytes at the beginning of
-// buf is returned as result.
+// The value of z is encoded in the slice buf[i:]. If the value of z
+// cannot be represented in buf, bytes panics. The number i of unused
+// bytes at the beginning of buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
i = len(buf)
for _, d := range z {
for j := 0; j < _S; j++ {
i--
- buf[i] = byte(d)
+ if i >= 0 {
+ buf[i] = byte(d)
+ } else if byte(d) != 0 {
+ panic("math/big: buffer too small to fit value")
+ }
d >>= 8
}
}
+ if i < 0 {
+ i = 0
+ }
for i < len(buf) && buf[i] == 0 {
i++
}