]> Cypherpunks repositories - gostls13.git/commitdiff
math/rand: add Shuffle
authorJosh Bleecher Snyder <josharian@gmail.com>
Wed, 26 Jul 2017 00:53:30 +0000 (17:53 -0700)
committerIan Lance Taylor <iant@golang.org>
Fri, 8 Sep 2017 13:03:02 +0000 (13:03 +0000)
Shuffle uses the Fisher-Yates algorithm.

Since this is new API, it affords us the opportunity
to use a much faster Int31n implementation that mostly avoids division.
As a result, BenchmarkPerm30ViaShuffle is
about 30% faster than BenchmarkPerm30,
despite requiring a separate initialization loop
and using function calls to swap elements.

Fixes #20480
Updates #16213
Updates #21211

Change-Id: Ib8956c4bebed9d84f193eb98282ec16ee7c2b2d5
Reviewed-on: https://go-review.googlesource.com/51891
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/math/rand/example_test.go
src/math/rand/rand.go
src/math/rand/rand_test.go

index 614eeaed515e8ecacafecb86e75be0ba9dd23712..aa1f2bcc7341ee0258efaacc9406fd648501a9fe 100644 (file)
@@ -8,6 +8,7 @@ import (
        "fmt"
        "math/rand"
        "os"
+       "strings"
        "text/tabwriter"
 )
 
@@ -105,3 +106,34 @@ func ExamplePerm() {
        // 2
        // 0
 }
+
+func ExampleShuffle() {
+       words := strings.Fields("ink runs from the corners of my mouth")
+       rand.Shuffle(len(words), func(i, j int) {
+               words[i], words[j] = words[j], words[i]
+       })
+       fmt.Println(words)
+
+       // Output:
+       // [mouth my the of runs corners from ink]
+}
+
+func ExampleShuffle_slicesInUnison() {
+       numbers := []byte("12345")
+       letters := []byte("ABCDE")
+       // Shuffle numbers, swapping corresponding entries in letters at the same time.
+       rand.Shuffle(len(numbers), func(i, j int) {
+               numbers[i], numbers[j] = numbers[j], numbers[i]
+               letters[i], letters[j] = letters[j], letters[i]
+       })
+       for i := range numbers {
+               fmt.Printf("%c: %c\n", letters[i], numbers[i])
+       }
+
+       // Output:
+       // C: 3
+       // D: 4
+       // A: 1
+       // E: 5
+       // B: 2
+}
index fe99c948ac2f604e9234f360a55efafc0fcc9686..a607409a1649dc18ce1ceafc14e3821926668cd7 100644 (file)
@@ -135,6 +135,30 @@ func (r *Rand) Int31n(n int32) int32 {
        return v % n
 }
 
+// int31n returns, as an int32, a non-negative pseudo-random number in [0,n).
+// n must be > 0, but int31n does not check this; the caller must ensure it.
+// int31n exists because Int31n is inefficient, but Go 1 compatibility
+// requires that the stream of values produced by math/rand remain unchanged.
+// int31n can thus only be used internally, by newly introduced APIs.
+//
+// For implementation details, see:
+// http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction
+// http://lemire.me/blog/2016/06/30/fast-random-shuffling
+func (r *Rand) int31n(n int32) int32 {
+       v := r.Uint32()
+       prod := uint64(v) * uint64(n)
+       low := uint32(prod)
+       if low < uint32(n) {
+               thresh := uint32(-n) % uint32(n)
+               for low < thresh {
+                       v = r.Uint32()
+                       prod = uint64(v) * uint64(n)
+                       low = uint32(prod)
+               }
+       }
+       return int32(prod >> 32)
+}
+
 // Intn returns, as an int, a non-negative pseudo-random number in [0,n).
 // It panics if n <= 0.
 func (r *Rand) Intn(n int) int {
@@ -202,6 +226,31 @@ func (r *Rand) Perm(n int) []int {
        return m
 }
 
+// Shuffle pseudo-randomizes the order of elements.
+// n is the number of elements. Shuffle panics if n < 0.
+// swap swaps the elements with indexes i and j.
+func (r *Rand) Shuffle(n int, swap func(i, j int)) {
+       if n < 0 {
+               panic("invalid argument to Shuffle")
+       }
+
+       // Fisher-Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
+       // Shuffle really ought not be called with n that doesn't fit in 32 bits.
+       // Not only will it take a very long time, but with 2³¹! possible permutations,
+       // there's no way that any PRNG can have a big enough internal state to
+       // generate even a minuscule percentage of the possible permutations.
+       // Nevertheless, the right API signature accepts an int n, so handle it as best we can.
+       i := n - 1
+       for ; i > 1<<31-1-1; i-- {
+               j := int(r.Int63n(int64(i + 1)))
+               swap(i, j)
+       }
+       for ; i > 0; i-- {
+               j := int(r.int31n(int32(i + 1)))
+               swap(i, j)
+       }
+}
+
 // Read generates len(p) random bytes and writes them into p. It
 // always returns len(p) and a nil error.
 // Read should not be called concurrently with any other Rand method.
@@ -288,6 +337,11 @@ func Float32() float32 { return globalRand.Float32() }
 // from the default Source.
 func Perm(n int) []int { return globalRand.Perm(n) }
 
+// Shuffle pseudo-randomizes the order of elements using the default Source.
+// n is the number of elements. Shuffle panics if n <= 0.
+// swap swaps the elements with indexes i and j.
+func Shuffle(n int, swap func(i, j int)) { globalRand.Shuffle(n, swap) }
+
 // Read generates len(p) random bytes from the default Source and
 // writes them into p. It always returns len(p) and a nil error.
 // Read, unlike the Rand.Read method, is safe for concurrent use.
index da065159d9355da6f492d6270647a961fc897327..1a13accde9638ee33677761c6d15072b0059528c 100644 (file)
@@ -450,6 +450,113 @@ func TestReadSeedReset(t *testing.T) {
        }
 }
 
+func TestShuffleSmall(t *testing.T) {
+       // Check that Shuffle allows n=0 and n=1, but that swap is never called for them.
+       r := New(NewSource(1))
+       for n := 0; n <= 1; n++ {
+               r.Shuffle(n, func(i, j int) { t.Fatalf("swap called, n=%d i=%d j=%d", n, i, j) })
+       }
+}
+
+// encodePerm converts from a permuted slice of length n, such as Perm generates, to an int in [0, n!).
+// See https://en.wikipedia.org/wiki/Lehmer_code.
+// encodePerm modifies the input slice.
+func encodePerm(s []int) int {
+       // Convert to Lehmer code.
+       for i, x := range s {
+               r := s[i+1:]
+               for j, y := range r {
+                       if y > x {
+                               r[j]--
+                       }
+               }
+       }
+       // Convert to int in [0, n!).
+       m := 0
+       fact := 1
+       for i := len(s) - 1; i >= 0; i-- {
+               m += s[i] * fact
+               fact *= len(s) - i
+       }
+       return m
+}
+
+// TestUniformFactorial tests several ways of generating a uniform value in [0, n!).
+func TestUniformFactorial(t *testing.T) {
+       r := New(NewSource(testSeeds[0]))
+       top := 6
+       if testing.Short() {
+               top = 4
+       }
+       for n := 3; n <= top; n++ {
+               t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) {
+                       // Calculate n!.
+                       nfact := 1
+                       for i := 2; i <= n; i++ {
+                               nfact *= i
+                       }
+
+                       // Test a few different ways to generate a uniform distribution.
+                       p := make([]int, n) // re-usable slice for Shuffle generator
+                       tests := [...]struct {
+                               name string
+                               fn   func() int
+                       }{
+                               {name: "Int31n", fn: func() int { return int(r.Int31n(int32(nfact))) }},
+                               {name: "int31n", fn: func() int { return int(r.int31n(int32(nfact))) }},
+                               {name: "Perm", fn: func() int { return encodePerm(r.Perm(n)) }},
+                               {name: "Shuffle", fn: func() int {
+                                       // Generate permutation using Shuffle.
+                                       for i := range p {
+                                               p[i] = i
+                                       }
+                                       r.Shuffle(n, func(i, j int) { p[i], p[j] = p[j], p[i] })
+                                       return encodePerm(p)
+                               }},
+                       }
+
+                       for _, test := range tests {
+                               t.Run(test.name, func(t *testing.T) {
+                                       // Gather chi-squared values and check that they follow
+                                       // the expected normal distribution given n!-1 degrees of freedom.
+                                       // See https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test and
+                                       // https://www.johndcook.com/Beautiful_Testing_ch10.pdf.
+                                       nsamples := 10 * nfact
+                                       if nsamples < 200 {
+                                               nsamples = 200
+                                       }
+                                       samples := make([]float64, nsamples)
+                                       for i := range samples {
+                                               // Generate some uniformly distributed values and count their occurrences.
+                                               const iters = 1000
+                                               counts := make([]int, nfact)
+                                               for i := 0; i < iters; i++ {
+                                                       counts[test.fn()]++
+                                               }
+                                               // Calculate chi-squared and add to samples.
+                                               want := iters / float64(nfact)
+                                               var χ2 float64
+                                               for _, have := range counts {
+                                                       err := float64(have) - want
+                                                       χ2 += err * err
+                                               }
+                                               χ2 /= want
+                                               samples[i] = χ2
+                                       }
+
+                                       // Check that our samples approximate the appropriate normal distribution.
+                                       dof := float64(nfact - 1)
+                                       expected := &statsResults{mean: dof, stddev: math.Sqrt(2 * dof)}
+                                       errorScale := max(1.0, expected.stddev)
+                                       expected.closeEnough = 0.10 * errorScale
+                                       expected.maxError = 0.08 // TODO: What is the right value here? See issue 21211.
+                                       checkSampleDistribution(t, samples, expected)
+                               })
+                       }
+               })
+       }
+}
+
 // Benchmarks
 
 func BenchmarkInt63Threadsafe(b *testing.B) {
@@ -514,6 +621,30 @@ func BenchmarkPerm30(b *testing.B) {
        }
 }
 
+func BenchmarkPerm30ViaShuffle(b *testing.B) {
+       r := New(NewSource(1))
+       for n := b.N; n > 0; n-- {
+               p := make([]int, 30)
+               for i := range p {
+                       p[i] = i
+               }
+               r.Shuffle(30, func(i, j int) { p[i], p[j] = p[j], p[i] })
+       }
+}
+
+// BenchmarkShuffleOverhead uses a minimal swap function
+// to measure just the shuffling overhead.
+func BenchmarkShuffleOverhead(b *testing.B) {
+       r := New(NewSource(1))
+       for n := b.N; n > 0; n-- {
+               r.Shuffle(52, func(i, j int) {
+                       if i < 0 || i >= 52 || j < 0 || j >= 52 {
+                               b.Fatalf("bad swap(%d, %d)", i, j)
+                       }
+               })
+       }
+}
+
 func BenchmarkRead3(b *testing.B) {
        r := New(NewSource(1))
        buf := make([]byte, 3)