]> Cypherpunks repositories - gostls13.git/commitdiff
net: fix probabilities in DNS SRV shuffleByWeight
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 17 Apr 2014 20:28:40 +0000 (13:28 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 17 Apr 2014 20:28:40 +0000 (13:28 -0700)
Patch from msolo. Just moving it to a CL.
The test fails before and passes with the fix.

Fixes #7098

LGTM=msolo, rsc
R=rsc, iant, msolo
CC=golang-codereviews
https://golang.org/cl/88900044

src/pkg/net/dnsclient.go
src/pkg/net/dnsclient_test.go [new file with mode: 0644]

index 01db4372945c230f7df3a4f6d72bb7bae353e321..9bffa11f9165df0a338cdd300fbd9c304225575a 100644 (file)
@@ -191,10 +191,10 @@ func (addrs byPriorityWeight) shuffleByWeight() {
        }
        for sum > 0 && len(addrs) > 1 {
                s := 0
-               n := rand.Intn(sum + 1)
+               n := rand.Intn(sum)
                for i := range addrs {
                        s += int(addrs[i].Weight)
-                       if s >= n {
+                       if s > n {
                                if i > 0 {
                                        t := addrs[i]
                                        copy(addrs[1:i+1], addrs[0:i])
diff --git a/src/pkg/net/dnsclient_test.go b/src/pkg/net/dnsclient_test.go
new file mode 100644 (file)
index 0000000..435eb35
--- /dev/null
@@ -0,0 +1,69 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "math/rand"
+       "testing"
+)
+
+func checkDistribution(t *testing.T, data []*SRV, margin float64) {
+       sum := 0
+       for _, srv := range data {
+               sum += int(srv.Weight)
+       }
+
+       results := make(map[string]int)
+
+       count := 1000
+       for j := 0; j < count; j++ {
+               d := make([]*SRV, len(data))
+               copy(d, data)
+               byPriorityWeight(d).shuffleByWeight()
+               key := d[0].Target
+               results[key] = results[key] + 1
+       }
+
+       actual := results[data[0].Target]
+       expected := float64(count) * float64(data[0].Weight) / float64(sum)
+       diff := float64(actual) - expected
+       t.Logf("actual: %v diff: %v e: %v m: %v", actual, diff, expected, margin)
+       if diff < 0 {
+               diff = -diff
+       }
+       if diff > (expected * margin) {
+               t.Errorf("missed target weight: expected %v, %v", expected, actual)
+       }
+}
+
+func testUniformity(t *testing.T, size int, margin float64) {
+       rand.Seed(1)
+       data := make([]*SRV, size)
+       for i := 0; i < size; i++ {
+               data[i] = &SRV{Target: string('a' + i), Weight: 1}
+       }
+       checkDistribution(t, data, margin)
+}
+
+func TestUniformity(t *testing.T) {
+       testUniformity(t, 2, 0.05)
+       testUniformity(t, 3, 0.10)
+       testUniformity(t, 10, 0.20)
+       testWeighting(t, 0.05)
+}
+
+func testWeighting(t *testing.T, margin float64) {
+       rand.Seed(1)
+       data := []*SRV{
+               {Target: "a", Weight: 60},
+               {Target: "b", Weight: 30},
+               {Target: "c", Weight: 10},
+       }
+       checkDistribution(t, data, margin)
+}
+
+func TestWeighting(t *testing.T) {
+       testWeighting(t, 0.05)
+}