]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: add (*Int).Sqrt
authorRuss Cox <rsc@golang.org>
Mon, 10 Oct 2016 20:18:43 +0000 (16:18 -0400)
committerRuss Cox <rsc@golang.org>
Mon, 17 Oct 2016 20:30:19 +0000 (20:30 +0000)
This is needed for some of the more complex primality tests
(to filter out exact squares), and while the code is simple the
boundary conditions are not obvious, so it seems worth having
in the library.

Change-Id: Ica994a6b6c1e412a6f6d9c3cf823f9b653c6bcbd
Reviewed-on: https://go-review.googlesource.com/30706
Run-TryBot: Russ Cox <rsc@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
src/math/big/int.go
src/math/big/int_test.go
src/math/big/nat.go

index 51dc6f78ff9a434f26778612fd49a7d0bea1f7bd..a2c1b580f524e453a61d8ed45cfdbb4ead2f1ab5 100644 (file)
@@ -924,3 +924,14 @@ func (z *Int) Not(x *Int) *Int {
        z.neg = true // z cannot be zero if x is positive
        return z
 }
+
+// Sqrt sets z to ⌊√x⌋, the largest integer such that z² ≤ x, and returns z.
+// It panics if x is negative.
+func (z *Int) Sqrt(x *Int) *Int {
+       if x.neg {
+               panic("square root of negative number")
+       }
+       z.neg = false
+       z.abs = z.abs.sqrt(x.abs)
+       return z
+}
index 18f5be749dd7966053b8134c6b01fa61a4e3460b..b8e0778ca3201783a524a5c473754d74f94db946 100644 (file)
@@ -9,6 +9,7 @@ import (
        "encoding/hex"
        "fmt"
        "math/rand"
+       "strings"
        "testing"
        "testing/quick"
 )
@@ -1453,3 +1454,44 @@ func TestIssue2607(t *testing.T) {
        n := NewInt(10)
        n.Rand(rand.New(rand.NewSource(9)), n)
 }
+
+func TestSqrt(t *testing.T) {
+       root := 0
+       r := new(Int)
+       for i := 0; i < 10000; i++ {
+               if (root+1)*(root+1) <= i {
+                       root++
+               }
+               n := NewInt(int64(i))
+               r.SetInt64(-2)
+               r.Sqrt(n)
+               if r.Cmp(NewInt(int64(root))) != 0 {
+                       t.Errorf("Sqrt(%v) = %v, want %v", n, r, root)
+               }
+       }
+
+       for i := 0; i < 1000; i += 10 {
+               n, _ := new(Int).SetString("1"+strings.Repeat("0", i), 10)
+               r := new(Int).Sqrt(n)
+               root, _ := new(Int).SetString("1"+strings.Repeat("0", i/2), 10)
+               if r.Cmp(root) != 0 {
+                       t.Errorf("Sqrt(1e%d) = %v, want 1e%d", i, r, i/2)
+               }
+       }
+
+       // Test aliasing.
+       r.SetInt64(100)
+       r.Sqrt(r)
+       if r.Int64() != 10 {
+               t.Errorf("Sqrt(100) = %v, want 10 (aliased output)", r.Int64())
+       }
+}
+
+func BenchmarkSqrt(b *testing.B) {
+       n, _ := new(Int).SetString("1"+strings.Repeat("0", 1001), 10)
+       b.ResetTimer()
+       t := new(Int)
+       for i := 0; i < b.N; i++ {
+               t.Sqrt(n)
+       }
+}
index 4a3b7ae33f2be14e8c26f8f3cb7a932a9da03d1d..9b1a626c4cf18acd5d2df494bdc2ca740f13272e 100644 (file)
@@ -1223,3 +1223,37 @@ func (z nat) setBytes(buf []byte) nat {
 
        return z.norm()
 }
+
+// sqrt sets z = ⌊√x⌋
+func (z nat) sqrt(x nat) nat {
+       if x.cmp(natOne) <= 0 {
+               return z.set(x)
+       }
+       if alias(z, x) {
+               z = nil
+       }
+
+       // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
+       // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
+       // https://members.loria.fr/PZimmermann/mca/pub226.html
+       // If x is one less than a perfect square, the sequence oscillates between the correct z and z+1;
+       // otherwise it converges to the correct z and stays there.
+       var z1, z2 nat
+       z1 = z
+       z1 = z1.setUint64(1)
+       z1 = z1.shl(z1, uint(x.bitLen()/2+1)) // must be ≥ √x
+       for n := 0; ; n++ {
+               z2, _ = z2.div(nil, x, z1)
+               z2 = z2.add(z2, z1)
+               z2 = z2.shr(z2, 1)
+               if z2.cmp(z1) >= 0 {
+                       // z1 is answer.
+                       // Figure out whether z1 or z2 is currently aliased to z by looking at loop count.
+                       if n&1 == 0 {
+                               return z1
+                       }
+                       return z.set(z1)
+               }
+               z1, z2 = z2, z1
+       }
+}