z.neg = true // z cannot be zero if x is positive
return z
}
+
+// Sqrt sets z to ⌊√x⌋, the largest integer such that z² ≤ x, and returns z.
+// It panics if x is negative.
+func (z *Int) Sqrt(x *Int) *Int {
+ if x.neg {
+ panic("square root of negative number")
+ }
+ z.neg = false
+ z.abs = z.abs.sqrt(x.abs)
+ return z
+}
"encoding/hex"
"fmt"
"math/rand"
+ "strings"
"testing"
"testing/quick"
)
n := NewInt(10)
n.Rand(rand.New(rand.NewSource(9)), n)
}
+
+func TestSqrt(t *testing.T) {
+ root := 0
+ r := new(Int)
+ for i := 0; i < 10000; i++ {
+ if (root+1)*(root+1) <= i {
+ root++
+ }
+ n := NewInt(int64(i))
+ r.SetInt64(-2)
+ r.Sqrt(n)
+ if r.Cmp(NewInt(int64(root))) != 0 {
+ t.Errorf("Sqrt(%v) = %v, want %v", n, r, root)
+ }
+ }
+
+ for i := 0; i < 1000; i += 10 {
+ n, _ := new(Int).SetString("1"+strings.Repeat("0", i), 10)
+ r := new(Int).Sqrt(n)
+ root, _ := new(Int).SetString("1"+strings.Repeat("0", i/2), 10)
+ if r.Cmp(root) != 0 {
+ t.Errorf("Sqrt(1e%d) = %v, want 1e%d", i, r, i/2)
+ }
+ }
+
+ // Test aliasing.
+ r.SetInt64(100)
+ r.Sqrt(r)
+ if r.Int64() != 10 {
+ t.Errorf("Sqrt(100) = %v, want 10 (aliased output)", r.Int64())
+ }
+}
+
+func BenchmarkSqrt(b *testing.B) {
+ n, _ := new(Int).SetString("1"+strings.Repeat("0", 1001), 10)
+ b.ResetTimer()
+ t := new(Int)
+ for i := 0; i < b.N; i++ {
+ t.Sqrt(n)
+ }
+}
return z.norm()
}
+
+// sqrt sets z = ⌊√x⌋
+func (z nat) sqrt(x nat) nat {
+ if x.cmp(natOne) <= 0 {
+ return z.set(x)
+ }
+ if alias(z, x) {
+ z = nil
+ }
+
+ // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
+ // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
+ // https://members.loria.fr/PZimmermann/mca/pub226.html
+ // If x is one less than a perfect square, the sequence oscillates between the correct z and z+1;
+ // otherwise it converges to the correct z and stays there.
+ var z1, z2 nat
+ z1 = z
+ z1 = z1.setUint64(1)
+ z1 = z1.shl(z1, uint(x.bitLen()/2+1)) // must be ≥ √x
+ for n := 0; ; n++ {
+ z2, _ = z2.div(nil, x, z1)
+ z2 = z2.add(z2, z1)
+ z2 = z2.shr(z2, 1)
+ if z2.cmp(z1) >= 0 {
+ // z1 is answer.
+ // Figure out whether z1 or z2 is currently aliased to z by looking at loop count.
+ if n&1 == 0 {
+ return z1
+ }
+ return z.set(z1)
+ }
+ z1, z2 = z2, z1
+ }
+}