From 8a2553e380196dda556608e2fe79881004770eb9 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 22 May 2017 20:31:17 -0400 Subject: [PATCH] crypto/rand: only read necessary bytes for Int We only need to read the number of bytes required to store the value "max - 1" to generate a random number in the range [0, max). Before, there was an off-by-one error where an extra byte was read from the io.Reader for inputs like "256" (right at the boundary for a byte). There was a similar off-by-one error in the logic for clearing bits and thus for any input that was a power of 2, there was a 50% chance the read would continue to be retried as the mask failed to remove a bit. Fixes #18165. Change-Id: I548c1368990e23e365591e77980e9086fafb6518 Reviewed-on: https://go-review.googlesource.com/43891 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/crypto/rand/util.go | 17 ++++++++---- src/crypto/rand/util_test.go | 53 ++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/src/crypto/rand/util.go b/src/crypto/rand/util.go index 592c57e763..4dd1711203 100644 --- a/src/crypto/rand/util.go +++ b/src/crypto/rand/util.go @@ -107,16 +107,23 @@ func Int(rand io.Reader, max *big.Int) (n *big.Int, err error) { if max.Sign() <= 0 { panic("crypto/rand: argument to Int is <= 0") } - k := (max.BitLen() + 7) / 8 - - // b is the number of bits in the most significant byte of max. - b := uint(max.BitLen() % 8) + n = new(big.Int) + n.Sub(max, n.SetUint64(1)) + // bitLen is the maximum bit length needed to encode a value < max. + bitLen := n.BitLen() + if bitLen == 0 { + // the only valid result is 0 + return + } + // k is the maximum byte length needed to encode a value < max. + k := (bitLen + 7) / 8 + // b is the number of bits in the most significant byte of max-1. + b := uint(bitLen % 8) if b == 0 { b = 8 } bytes := make([]byte, k) - n = new(big.Int) for { _, err = io.ReadFull(rand, bytes) diff --git a/src/crypto/rand/util_test.go b/src/crypto/rand/util_test.go index 48a2c3fc0c..685624e1b3 100644 --- a/src/crypto/rand/util_test.go +++ b/src/crypto/rand/util_test.go @@ -5,7 +5,10 @@ package rand_test import ( + "bytes" "crypto/rand" + "fmt" + "io" "math/big" mathrand "math/rand" "testing" @@ -45,6 +48,56 @@ func TestInt(t *testing.T) { } } +type countingReader struct { + r io.Reader + n int +} + +func (r *countingReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + r.n += n + return n, err +} + +// Test that Int reads only the necessary number of bytes from the reader for +// max at each bit length +func TestIntReads(t *testing.T) { + for i := 0; i < 32; i++ { + max := int64(1 << uint64(i)) + t.Run(fmt.Sprintf("max=%d", max), func(t *testing.T) { + reader := &countingReader{r: rand.Reader} + + _, err := rand.Int(reader, big.NewInt(max)) + if err != nil { + t.Fatalf("Can't generate random value: %d, %v", max, err) + } + expected := (i + 7) / 8 + if reader.n != expected { + t.Errorf("Int(reader, %d) should read %d bytes, but it read: %d", max, expected, reader.n) + } + }) + } +} + +// Test that Int does not mask out valid return values +func TestIntMask(t *testing.T) { + for max := 1; max <= 256; max++ { + t.Run(fmt.Sprintf("max=%d", max), func(t *testing.T) { + for i := 0; i < max; i++ { + var b bytes.Buffer + b.WriteByte(byte(i)) + n, err := rand.Int(&b, big.NewInt(int64(max))) + if err != nil { + t.Fatalf("Can't generate random value: %d, %v", max, err) + } + if n.Int64() != int64(i) { + t.Errorf("Int(reader, %d) should have returned value of %d, but it returned: %v", max, i, n) + } + } + }) + } +} + func testIntPanics(t *testing.T, b *big.Int) { defer func() { if err := recover(); err == nil { -- 2.50.0