if max.Sign() <= 0 {
panic("crypto/rand: argument to Int is <= 0")
}
- k := (max.BitLen() + 7) / 8
-
- // b is the number of bits in the most significant byte of max.
- b := uint(max.BitLen() % 8)
+ n = new(big.Int)
+ n.Sub(max, n.SetUint64(1))
+ // bitLen is the maximum bit length needed to encode a value < max.
+ bitLen := n.BitLen()
+ if bitLen == 0 {
+ // the only valid result is 0
+ return
+ }
+ // k is the maximum byte length needed to encode a value < max.
+ k := (bitLen + 7) / 8
+ // b is the number of bits in the most significant byte of max-1.
+ b := uint(bitLen % 8)
if b == 0 {
b = 8
}
bytes := make([]byte, k)
- n = new(big.Int)
for {
_, err = io.ReadFull(rand, bytes)
package rand_test
import (
+ "bytes"
"crypto/rand"
+ "fmt"
+ "io"
"math/big"
mathrand "math/rand"
"testing"
}
}
+type countingReader struct {
+ r io.Reader
+ n int
+}
+
+func (r *countingReader) Read(p []byte) (n int, err error) {
+ n, err = r.r.Read(p)
+ r.n += n
+ return n, err
+}
+
+// Test that Int reads only the necessary number of bytes from the reader for
+// max at each bit length
+func TestIntReads(t *testing.T) {
+ for i := 0; i < 32; i++ {
+ max := int64(1 << uint64(i))
+ t.Run(fmt.Sprintf("max=%d", max), func(t *testing.T) {
+ reader := &countingReader{r: rand.Reader}
+
+ _, err := rand.Int(reader, big.NewInt(max))
+ if err != nil {
+ t.Fatalf("Can't generate random value: %d, %v", max, err)
+ }
+ expected := (i + 7) / 8
+ if reader.n != expected {
+ t.Errorf("Int(reader, %d) should read %d bytes, but it read: %d", max, expected, reader.n)
+ }
+ })
+ }
+}
+
+// Test that Int does not mask out valid return values
+func TestIntMask(t *testing.T) {
+ for max := 1; max <= 256; max++ {
+ t.Run(fmt.Sprintf("max=%d", max), func(t *testing.T) {
+ for i := 0; i < max; i++ {
+ var b bytes.Buffer
+ b.WriteByte(byte(i))
+ n, err := rand.Int(&b, big.NewInt(int64(max)))
+ if err != nil {
+ t.Fatalf("Can't generate random value: %d, %v", max, err)
+ }
+ if n.Int64() != int64(i) {
+ t.Errorf("Int(reader, %d) should have returned value of %d, but it returned: %v", max, i, n)
+ }
+ }
+ })
+ }
+}
+
func testIntPanics(t *testing.T, b *big.Int) {
defer func() {
if err := recover(); err == nil {