]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: add (*Float).Sqrt
authorAlberto Donizetti <alb.donizetti@gmail.com>
Tue, 3 Oct 2017 08:18:00 +0000 (10:18 +0200)
committerRobert Griesemer <gri@golang.org>
Thu, 26 Oct 2017 17:29:27 +0000 (17:29 +0000)
This change adds a Square root method to the big.Float type, with
signature

  (z *Float) Sqrt(x *Float) *Float

Fixes #20460

Change-Id: I050aaed0615fe0894e11c800744600648343c223
Reviewed-on: https://go-review.googlesource.com/67830
Run-TryBot: Alberto Donizetti <alb.donizetti@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
src/math/big/sqrt.go [new file with mode: 0644]
src/math/big/sqrt_test.go [new file with mode: 0644]

diff --git a/src/math/big/sqrt.go b/src/math/big/sqrt.go
new file mode 100644 (file)
index 0000000..4f24fdb
--- /dev/null
@@ -0,0 +1,150 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package big
+
+import "math"
+
+var (
+       nhalf = NewFloat(-0.5)
+       half  = NewFloat(0.5)
+       one   = NewFloat(1.0)
+       two   = NewFloat(2.0)
+)
+
+// Sqrt sets z to the rounded square root of x, and returns it.
+//
+// If z's precision is 0, it is changed to x's precision before the
+// operation. Rounding is performed according to z's precision and
+// rounding mode.
+//
+// The function panics if z < 0. The value of z is undefined in that
+// case.
+func (z *Float) Sqrt(x *Float) *Float {
+       if debugFloat {
+               x.validate()
+       }
+
+       if z.prec == 0 {
+               z.prec = x.prec
+       }
+
+       if x.Sign() == -1 {
+               // following IEEE754-2008 (section 7.2)
+               panic(ErrNaN{"square root of negative operand"})
+       }
+
+       // handle ±0 and +∞
+       if x.form != finite {
+               z.acc = Exact
+               z.form = x.form
+               z.neg = x.neg // IEEE754-2008 requires √±0 = ±0
+               return z
+       }
+
+       // MantExp sets the argument's precision to the receiver's, and
+       // when z.prec > x.prec this will lower z.prec. Restore it after
+       // the MantExp call.
+       prec := z.prec
+       b := x.MantExp(z)
+       z.prec = prec
+
+       // Compute √(z·2**b) as
+       //   √( z)·2**(½b)     if b is even
+       //   √(2z)·2**(⌊½b⌋)   if b > 0 is odd
+       //   √(½z)·2**(⌈½b⌉)   if b < 0 is odd
+       switch b % 2 {
+       case 0:
+               // nothing to do
+       case 1:
+               z.Mul(two, z)
+       case -1:
+               z.Mul(half, z)
+       }
+       // 0.25 <= z < 2.0
+
+       // Solving x² - z = 0 directly requires a Quo call, but it's
+       // faster for small precisions.
+       //
+       // Solving 1/x² - z = 0 avoids the Quo call and is much faster for
+       // high precisions.
+       //
+       // 128bit precision is an empirically chosen threshold.
+       if z.prec <= 128 {
+               z.sqrtDirect(z)
+       } else {
+               z.sqrtInverse(z)
+       }
+
+       // re-attach halved exponent
+       return z.SetMantExp(z, b/2)
+}
+
+// Compute √x (up to prec 128) by solving
+//   t² - x = 0
+// for t, starting with a 53 bits precision guess from math.Sqrt and
+// then using at most two iterations of Newton's method.
+func (z *Float) sqrtDirect(x *Float) {
+       // let
+       //   f(t) = t² - x
+       // then
+       //   g(t) = f(t)/f'(t) = ½(t² - x)/t
+       u := new(Float)
+       g := func(t *Float) *Float {
+               u.prec = t.prec
+               u.Mul(t, t)    // u = t²
+               u.Sub(u, x)    //   = t² - x
+               u.Mul(half, u) //   = ½(t² - x)
+               u.Quo(u, t)    //   = ½(t² - x)/t
+               return u
+       }
+
+       xf, _ := x.Float64()
+       sq := NewFloat(math.Sqrt(xf))
+
+       switch {
+       case z.prec > 128:
+               panic("sqrtDirect: only for z.prec <= 128")
+       case z.prec > 64:
+               sq.prec *= 2
+               sq.Sub(sq, g(sq))
+               fallthrough
+       default:
+               sq.prec *= 2
+               sq.Sub(sq, g(sq))
+       }
+
+       z.Set(sq)
+}
+
+// Compute √x (to z.prec precision) by solving
+//   1/t² - x = 0
+// for t (using Newton's method), and then inverting.
+func (z *Float) sqrtInverse(x *Float) {
+       // let
+       //   f(t) = 1/t² - x
+       // then
+       //   g(t) = f(t)/f'(t) = -½t(1 - xt²)
+       u := new(Float)
+       g := func(t *Float) *Float {
+               u.prec = t.prec
+               u.Mul(t, t)     // u = t²
+               u.Mul(x, u)     //   = xt²
+               u.Sub(one, u)   //   = 1 - xt²
+               u.Mul(nhalf, u) //   = -½(1 - xt²)
+               u.Mul(t, u)     //   = -½t(1 - xt²)
+               return u
+       }
+
+       xf, _ := x.Float64()
+       sqi := NewFloat(1 / math.Sqrt(xf))
+       for prec := 2 * z.prec; sqi.prec < prec; {
+               sqi.prec *= 2
+               sqi.Sub(sqi, g(sqi))
+       }
+       // sqi = 1/√x
+
+       // x/√x = √x
+       z.Mul(x, sqi)
+}
diff --git a/src/math/big/sqrt_test.go b/src/math/big/sqrt_test.go
new file mode 100644 (file)
index 0000000..6a412d6
--- /dev/null
@@ -0,0 +1,123 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package big
+
+import (
+       "fmt"
+       "math"
+       "math/rand"
+       "testing"
+)
+
+// TestFloatSqrt64 tests that Float.Sqrt of numbers with 53bit mantissa
+// behaves like float math.Sqrt.
+func TestFloatSqrt64(t *testing.T) {
+       for i := 0; i < 1e5; i++ {
+               r := rand.Float64()
+
+               got := new(Float).SetPrec(53)
+               got.Sqrt(NewFloat(r))
+               want := NewFloat(math.Sqrt(r))
+               if got.Cmp(want) != 0 {
+                       t.Fatalf("Sqrt(%g) =\n got %g;\nwant %g", r, got, want)
+               }
+       }
+}
+
+func TestFloatSqrt(t *testing.T) {
+       for _, test := range []struct {
+               x    string
+               want string
+       }{
+               // Test values were generated on Wolfram Alpha using query
+               //   'sqrt(N) to 350 digits'
+               // 350 decimal digits give up to 1000 binary digits.
+               {"0.03125", "0.17677669529663688110021109052621225982120898442211850914708496724884155980776337985629844179095519659187673077886403712811560450698134215158051518713749197892665283324093819909447499381264409775757143376369499645074628431682460775184106467733011114982619404115381053858929018135497032545349940642599871090667456829147610370507757690729404938184321879"},
+               {"0.125", "0.35355339059327376220042218105242451964241796884423701829416993449768311961552675971259688358191039318375346155772807425623120901396268430316103037427498395785330566648187639818894998762528819551514286752738999290149256863364921550368212935466022229965238808230762107717858036270994065090699881285199742181334913658295220741015515381458809876368643757"},
+               {"0.5", "0.70710678118654752440084436210484903928483593768847403658833986899536623923105351942519376716382078636750692311545614851246241802792536860632206074854996791570661133296375279637789997525057639103028573505477998580298513726729843100736425870932044459930477616461524215435716072541988130181399762570399484362669827316590441482031030762917619752737287514"},
+               {"2.0", "1.4142135623730950488016887242096980785696718753769480731766797379907324784621070388503875343276415727350138462309122970249248360558507372126441214970999358314132226659275055927557999505011527820605714701095599716059702745345968620147285174186408891986095523292304843087143214508397626036279952514079896872533965463318088296406206152583523950547457503"},
+               {"3.0", "1.7320508075688772935274463415058723669428052538103806280558069794519330169088000370811461867572485756756261414154067030299699450949989524788116555120943736485280932319023055820679748201010846749232650153123432669033228866506722546689218379712270471316603678615880190499865373798593894676503475065760507566183481296061009476021871903250831458295239598"},
+               {"4.0", "2.0"},
+
+               {"1p512", "1p256"},
+               {"4p1024", "2p512"},
+               {"9p2048", "3p1024"},
+
+               {"1p-1024", "1p-512"},
+               {"4p-2048", "2p-1024"},
+               {"9p-4096", "3p-2048"},
+       } {
+               for _, prec := range []uint{24, 53, 64, 65, 100, 128, 129, 200, 256, 400, 600, 800, 1000} {
+                       x := new(Float).SetPrec(prec)
+                       x.Parse(test.x, 10)
+
+                       got := new(Float).SetPrec(prec).Sqrt(x)
+                       want := new(Float).SetPrec(prec)
+                       want.Parse(test.want, 10)
+                       if got.Cmp(want) != 0 {
+                               t.Errorf("prec = %d, Sqrt(%v) =\ngot  %g;\nwant %g",
+                                       prec, test.x, got, want)
+                       }
+
+                       // Square test.
+                       // If got holds the square root of x to precision p, then
+                       //   got = √x + k
+                       // for some k such that |k| < 2**(-p). Thus,
+                       //   got² = (√x + k)² = x + 2k√n + k²
+                       // and the error must satisfy
+                       //   err = |got² - x| ≈ | 2k√n | < 2**(-p+1)*√n
+                       // Ignoring the k² term for simplicity.
+
+                       // err = |got² - x|
+                       // (but do intermediate steps with 32 guard digits to
+                       // avoid introducing spurious rounding-related errors)
+                       sq := new(Float).SetPrec(prec+32).Mul(got, got)
+                       diff := new(Float).Sub(sq, x)
+                       err := diff.Abs(diff).SetPrec(prec)
+
+                       // maxErr = 2**(-p+1)*√x
+                       one := new(Float).SetPrec(prec).SetInt64(1)
+                       maxErr := new(Float).Mul(new(Float).SetMantExp(one, -int(prec)+1), got)
+
+                       if err.Cmp(maxErr) >= 0 {
+                               t.Errorf("prec = %d, Sqrt(%v) =\ngot err  %g;\nwant maxErr %g",
+                                       prec, test.x, err, maxErr)
+                       }
+               }
+       }
+}
+
+func TestFloatSqrtSpecial(t *testing.T) {
+       for _, test := range []struct {
+               x    *Float
+               want *Float
+       }{
+               {NewFloat(+0), NewFloat(+0)},
+               {NewFloat(-0), NewFloat(-0)},
+               {NewFloat(math.Inf(+1)), NewFloat(math.Inf(+1))},
+       } {
+               got := new(Float).Sqrt(test.x)
+               if got.neg != test.want.neg || got.form != test.want.form {
+                       t.Errorf("Sqrt(%v) = %v (neg: %v); want %v (neg: %v)",
+                               test.x, got, got.neg, test.want, test.want.neg)
+               }
+       }
+
+}
+
+// Benchmarks
+
+func BenchmarkFloatSqrt(b *testing.B) {
+       for _, prec := range []uint{64, 128, 256, 1e3, 1e4, 1e5, 1e6} {
+               x := NewFloat(2)
+               z := new(Float).SetPrec(prec)
+               b.Run(fmt.Sprintf("%v", prec), func(b *testing.B) {
+                       b.ReportAllocs()
+                       for n := 0; n < b.N; n++ {
+                               z.Sqrt(x)
+                       }
+               })
+       }
+}