]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/rand: only read necessary bytes for Int
authorWade Simmons <wade@wades.im>
Tue, 23 May 2017 00:31:17 +0000 (20:31 -0400)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 23 May 2017 21:02:14 +0000 (21:02 +0000)
We only need to read the number of bytes required to store the value
"max - 1" to generate a random number in the range [0, max).

Before, there was an off-by-one error where an extra byte was read from
the io.Reader for inputs like "256" (right at the boundary for a byte).
There was a similar off-by-one error in the logic for clearing bits and
thus for any input that was a power of 2, there was a 50% chance the
read would continue to be retried as the mask failed to remove a bit.

Fixes #18165.

Change-Id: I548c1368990e23e365591e77980e9086fafb6518
Reviewed-on: https://go-review.googlesource.com/43891
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/crypto/rand/util.go
src/crypto/rand/util_test.go

index 592c57e76386e97859580e70bb2854759e2cbef9..4dd171120341f13b5aa7517a04de453011441104 100644 (file)
@@ -107,16 +107,23 @@ func Int(rand io.Reader, max *big.Int) (n *big.Int, err error) {
        if max.Sign() <= 0 {
                panic("crypto/rand: argument to Int is <= 0")
        }
-       k := (max.BitLen() + 7) / 8
-
-       // b is the number of bits in the most significant byte of max.
-       b := uint(max.BitLen() % 8)
+       n = new(big.Int)
+       n.Sub(max, n.SetUint64(1))
+       // bitLen is the maximum bit length needed to encode a value < max.
+       bitLen := n.BitLen()
+       if bitLen == 0 {
+               // the only valid result is 0
+               return
+       }
+       // k is the maximum byte length needed to encode a value < max.
+       k := (bitLen + 7) / 8
+       // b is the number of bits in the most significant byte of max-1.
+       b := uint(bitLen % 8)
        if b == 0 {
                b = 8
        }
 
        bytes := make([]byte, k)
-       n = new(big.Int)
 
        for {
                _, err = io.ReadFull(rand, bytes)
index 48a2c3fc0cb73bc2013cc575bde087e1cf01eea0..685624e1b3dfbea04abfe46291dd8f294a63b6ba 100644 (file)
@@ -5,7 +5,10 @@
 package rand_test
 
 import (
+       "bytes"
        "crypto/rand"
+       "fmt"
+       "io"
        "math/big"
        mathrand "math/rand"
        "testing"
@@ -45,6 +48,56 @@ func TestInt(t *testing.T) {
        }
 }
 
+type countingReader struct {
+       r io.Reader
+       n int
+}
+
+func (r *countingReader) Read(p []byte) (n int, err error) {
+       n, err = r.r.Read(p)
+       r.n += n
+       return n, err
+}
+
+// Test that Int reads only the necessary number of bytes from the reader for
+// max at each bit length
+func TestIntReads(t *testing.T) {
+       for i := 0; i < 32; i++ {
+               max := int64(1 << uint64(i))
+               t.Run(fmt.Sprintf("max=%d", max), func(t *testing.T) {
+                       reader := &countingReader{r: rand.Reader}
+
+                       _, err := rand.Int(reader, big.NewInt(max))
+                       if err != nil {
+                               t.Fatalf("Can't generate random value: %d, %v", max, err)
+                       }
+                       expected := (i + 7) / 8
+                       if reader.n != expected {
+                               t.Errorf("Int(reader, %d) should read %d bytes, but it read: %d", max, expected, reader.n)
+                       }
+               })
+       }
+}
+
+// Test that Int does not mask out valid return values
+func TestIntMask(t *testing.T) {
+       for max := 1; max <= 256; max++ {
+               t.Run(fmt.Sprintf("max=%d", max), func(t *testing.T) {
+                       for i := 0; i < max; i++ {
+                               var b bytes.Buffer
+                               b.WriteByte(byte(i))
+                               n, err := rand.Int(&b, big.NewInt(int64(max)))
+                               if err != nil {
+                                       t.Fatalf("Can't generate random value: %d, %v", max, err)
+                               }
+                               if n.Int64() != int64(i) {
+                                       t.Errorf("Int(reader, %d) should have returned value of %d, but it returned: %v", max, i, n)
+                               }
+                       }
+               })
+       }
+}
+
 func testIntPanics(t *testing.T, b *big.Int) {
        defer func() {
                if err := recover(); err == nil {