From: Lakshay Garg Date: Wed, 28 Jun 2017 16:37:55 +0000 (+0530) Subject: math: implement the erfinv function X-Git-Tag: go1.10beta1~1477 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=77412b9300582a5f1882b66e1b39ac7d96178dbe;p=gostls13.git math: implement the erfinv function This commit defines the inverse of error function (erfinv) in the math package. The function is based on the rational approximation of percentage points of normal distribution available at https://www.jstor.org/stable/pdf/2347330.pdf. Fixes #6359 Change-Id: Icfe4508f623e0574c7fffdbf7aa929540fd4c944 Reviewed-on: https://go-review.googlesource.com/46990 Run-TryBot: Robert Griesemer TryBot-Result: Gobot Gobot Reviewed-by: Robert Griesemer --- diff --git a/src/math/all_test.go b/src/math/all_test.go index 11bb8b2564..7c1794f3d7 100644 --- a/src/math/all_test.go +++ b/src/math/all_test.go @@ -211,6 +211,18 @@ var erfc = []float64{ 7.9630075582117758758440411e-01, 1.7806938696800922672994468e+00, } +var erfinv = []float64{ + 4.746037673358033586786350696e-01, + 8.559054432692110956388764172e-01, + -2.45427830571707336251331946e-02, + -4.78116683518973366268905506e-01, + 1.479804430319470983648120853e+00, + 2.654485787128896161882650211e-01, + 5.027444534221520197823192493e-01, + 2.466703532707627818954585670e-01, + 1.632011465103005426240343116e-01, + -1.06672334642196900710000389e+00, +} var exp = []float64{ 1.4533071302642137507696589e+02, 2.2958822575694449002537581e+03, @@ -942,6 +954,23 @@ var erfcSC = []float64{ NaN(), } +var vferfinvSC = []float64{ + 1, + -1, + 0, + Inf(-1), + Inf(1), + NaN(), +} +var erfinvSC = []float64{ + Inf(+1), + Inf(-1), + 0, + NaN(), + NaN(), + NaN(), +} + var vfexpSC = []float64{ Inf(-1), -2000, @@ -2122,6 +2151,30 @@ func TestErfc(t *testing.T) { } } +func TestErfinv(t *testing.T) { + for i := 0; i < len(vf); i++ { + a := vf[i] / 10 + if f := Erfinv(a); !veryclose(erfinv[i], f) { + t.Errorf("Erfinv(%g) = %g, want %g", a, f, erfinv[i]) + } + } + for i := 0; i < len(vferfinvSC); i++ { + if f := Erfinv(vferfinvSC[i]); !alike(erfinvSC[i], f) { + t.Errorf("Erfinv(%g) = %g, want %g", vferfinvSC[i], f, erfinvSC[i]) + } + } + for x := -0.9; x <= 0.90; x += 1e-2 { + if f := Erf(Erfinv(x)); !close(x, f) { + t.Errorf("Erf(Erfinv(%g)) = %g, want %g", x, f, x) + } + } + for x := -0.9; x <= 0.90; x += 1e-2 { + if f := Erfinv(Erf(x)); !close(x, f) { + t.Errorf("Erfinv(Erf(%g)) = %g, want %g", x, f, x) + } + } +} + func TestExp(t *testing.T) { testExp(t, Exp, "Exp") testExp(t, ExpGo, "ExpGo") @@ -2975,6 +3028,14 @@ func BenchmarkErfc(b *testing.B) { GlobalF = x } +func BenchmarkErfinv(b *testing.B) { + x := 0.0 + for i := 0; i < b.N; i++ { + x = Erfinv(.5) + } + GlobalF = x +} + func BenchmarkExp(b *testing.B) { x := 0.0 for i := 0; i < b.N; i++ { diff --git a/src/math/erfinv.go b/src/math/erfinv.go new file mode 100644 index 0000000000..3ea38e0355 --- /dev/null +++ b/src/math/erfinv.go @@ -0,0 +1,116 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package math + +/* + Inverse of the floating-point error function. +*/ + +// This implementation is based on the rational approximation +// of percentage points of normal distribution available from +// http://www.jstor.org/stable/2347330. + +const ( + // Coefficients for approximation to erf in |x| <= 0.85 + a0 = 1.1975323115670912564578e0 + a1 = 4.7072688112383978012285e1 + a2 = 6.9706266534389598238465e2 + a3 = 4.8548868893843886794648e3 + a4 = 1.6235862515167575384252e4 + a5 = 2.3782041382114385731252e4 + a6 = 1.1819493347062294404278e4 + a7 = 8.8709406962545514830200e2 + b0 = 1.0000000000000000000e0 + b1 = 4.2313330701600911252e1 + b2 = 6.8718700749205790830e2 + b3 = 5.3941960214247511077e3 + b4 = 2.1213794301586595867e4 + b5 = 3.9307895800092710610e4 + b6 = 2.8729085735721942674e4 + b7 = 5.2264952788528545610e3 + // Coefficients for approximation to erf in 0.85 < |x| <= 1-2*exp(-25) + c0 = 1.42343711074968357734e0 + c1 = 4.63033784615654529590e0 + c2 = 5.76949722146069140550e0 + c3 = 3.64784832476320460504e0 + c4 = 1.27045825245236838258e0 + c5 = 2.41780725177450611770e-1 + c6 = 2.27238449892691845833e-2 + c7 = 7.74545014278341407640e-4 + d0 = 1.4142135623730950488016887e0 + d1 = 2.9036514445419946173133295e0 + d2 = 2.3707661626024532365971225e0 + d3 = 9.7547832001787427186894837e-1 + d4 = 2.0945065210512749128288442e-1 + d5 = 2.1494160384252876777097297e-2 + d6 = 7.7441459065157709165577218e-4 + d7 = 1.4859850019840355905497876e-9 + // Coefficients for approximation to erf in 1-2*exp(-25) < |x| < 1 + e0 = 6.65790464350110377720e0 + e1 = 5.46378491116411436990e0 + e2 = 1.78482653991729133580e0 + e3 = 2.96560571828504891230e-1 + e4 = 2.65321895265761230930e-2 + e5 = 1.24266094738807843860e-3 + e6 = 2.71155556874348757815e-5 + e7 = 2.01033439929228813265e-7 + f0 = 1.414213562373095048801689e0 + f1 = 8.482908416595164588112026e-1 + f2 = 1.936480946950659106176712e-1 + f3 = 2.103693768272068968719679e-2 + f4 = 1.112800997078859844711555e-3 + f5 = 2.611088405080593625138020e-5 + f6 = 2.010321207683943062279931e-7 + f7 = 2.891024605872965461538222e-15 +) + +// Erfinv returns the inverse error function of x. +// +// Special cases are: +// Erfinv(1) = +Inf +// Erfinv(-1) = -Inf +// Erfinv(x) = NaN if x < -1 or x > 1 +// Erfinv(NaN) = NaN +func Erfinv(x float64) float64 { + // special cases + if IsNaN(x) || x <= -1 || x >= 1 { + if x == -1 || x == 1 { + return Inf(int(x)) + } + return NaN() + } + + sign := false + if x < 0 { + x = -x + sign = true + } + + var ans float64 + if x <= 0.85 { // |x| <= 0.85 + r := 0.180625 - 0.25*x*x + z1 := ((((((a7*r+a6)*r+a5)*r+a4)*r+a3)*r+a2)*r+a1)*r + a0 + z2 := ((((((b7*r+b6)*r+b5)*r+b4)*r+b3)*r+b2)*r+b1)*r + b0 + ans = (x * z1) / z2 + } else { + var z1, z2 float64 + r := Sqrt(Ln2 - Log(1.0-x)) + if r <= 5.0 { + r -= 1.6 + z1 = ((((((c7*r+c6)*r+c5)*r+c4)*r+c3)*r+c2)*r+c1)*r + c0 + z2 = ((((((d7*r+d6)*r+d5)*r+d4)*r+d3)*r+d2)*r+d1)*r + d0 + } else { + r -= 5.0 + z1 = ((((((e7*r+e6)*r+e5)*r+e4)*r+e3)*r+e2)*r+e1)*r + e0 + z2 = ((((((f7*r+f6)*r+f5)*r+f4)*r+f3)*r+f2)*r+f1)*r + f0 + } + ans = z1 / z2 + } + + if sign { + return -ans + } + return ans +}