]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: handle NaNs in Float.Cmp
authorRobert Griesemer <gri@golang.org>
Thu, 5 Mar 2015 21:10:33 +0000 (13:10 -0800)
committerRobert Griesemer <gri@golang.org>
Thu, 12 Mar 2015 18:56:25 +0000 (18:56 +0000)
Also:
- Implemented NewFloat convenience factory function (analogous to
  NewInt and NewRat).
- Implemented convenience accessors for Accuracy values returned
  from Float.Cmp.
- Added test and example.

Change-Id: I985bb4f86e6def222d4b2505417250d29a39c60e
Reviewed-on: https://go-review.googlesource.com/6970
Reviewed-by: Alan Donovan <adonovan@google.com>
src/math/big/float.go
src/math/big/float_test.go
src/math/big/floatexample_test.go

index 73439f493e9269aa9537c69650d009bcbdd973ae..778cc20da5c1139ce5faf7d7791f58400b2482cd 100644 (file)
@@ -65,6 +65,12 @@ type Float struct {
        exp  int32
 }
 
+// NewFloat allocates and returns a new Float set to x,
+// with precision 53 and rounding mode ToNearestEven.
+func NewFloat(x float64) *Float {
+       return new(Float).SetFloat64(x)
+}
+
 // Exponent and precision limits.
 const (
        MaxExp  = math.MaxInt32  // largest supported exponent
@@ -135,13 +141,6 @@ const (
 
 //go:generate stringer -type=Accuracy
 
-func (x *Float) cmpZero() Accuracy {
-       if x.neg {
-               return Above
-       }
-       return Below
-}
-
 // SetPrec sets z's precision to prec and returns the (possibly) rounded
 // value of z. Rounding occurs according to z's rounding mode if the mantissa
 // cannot be represented in prec bits without loss of precision.
@@ -173,6 +172,13 @@ func (z *Float) SetPrec(prec uint) *Float {
        return z
 }
 
+func (x *Float) cmpZero() Accuracy {
+       if x.neg {
+               return Above
+       }
+       return Below
+}
+
 // SetMode sets z's rounding mode to mode and returns an exact z.
 // z remains unchanged otherwise.
 // z.SetMode(z.Mode()) is a cheap way to set z's accuracy to Exact.
@@ -1030,7 +1036,7 @@ func (z *Float) Neg(x *Float) *Float {
 }
 
 // z = x + y, ignoring signs of x and y.
-// x and y must not be 0, Inf, or NaN.
+// x.form and y.form must be finite.
 func (z *Float) uadd(x, y *Float) {
        // Note: This implementation requires 2 shifts most of the
        // time. It is also inefficient if exponents or precisions
@@ -1074,7 +1080,7 @@ func (z *Float) uadd(x, y *Float) {
 }
 
 // z = x - y for x >= y, ignoring signs of x and y.
-// x and y must not be 0, Inf, or NaN.
+// x.form and y.form must be finite.
 func (z *Float) usub(x, y *Float) {
        // This code is symmetric to uadd.
        // We have not factored the common code out because
@@ -1116,7 +1122,7 @@ func (z *Float) usub(x, y *Float) {
 }
 
 // z = x * y, ignoring signs of x and y.
-// x and y must not be 0, Inf, or NaN.
+// x.form and y.form must be finite.
 func (z *Float) umul(x, y *Float) {
        if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
                panic("umul called with empty mantissa")
@@ -1137,7 +1143,7 @@ func (z *Float) umul(x, y *Float) {
 }
 
 // z = x / y, ignoring signs of x and y.
-// x and y must not be 0, Inf, or NaN.
+// x.form and y.form must be finite.
 func (z *Float) uquo(x, y *Float) {
        if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
                panic("uquo called with empty mantissa")
@@ -1184,18 +1190,19 @@ func (z *Float) uquo(x, y *Float) {
        z.round(sbit)
 }
 
-// ucmp returns -1, 0, or 1, depending on whether x < y, x == y, or x > y,
-// while ignoring the signs of x and y. x and y must not be 0, Inf, or NaN.
-func (x *Float) ucmp(y *Float) int {
+// ucmp returns Below, Exact, or Above, depending
+// on whether x < y, x == y, or x > y.
+// x.form and y.form must be finite.
+func (x *Float) ucmp(y *Float) Accuracy {
        if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
                panic("ucmp called with empty mantissa")
        }
 
        switch {
        case x.exp < y.exp:
-               return -1
+               return Below
        case x.exp > y.exp:
-               return 1
+               return Above
        }
        // x.exp == y.exp
 
@@ -1214,13 +1221,13 @@ func (x *Float) ucmp(y *Float) int {
                }
                switch {
                case xm < ym:
-                       return -1
+                       return Below
                case xm > ym:
-                       return 1
+                       return Above
                }
        }
 
-       return 0
+       return Exact
 }
 
 // Handling of sign bit as defined by IEEE 754-2008, section 6.3:
@@ -1286,7 +1293,7 @@ func (z *Float) Add(x, y *Float) *Float {
        } else {
                // x + (-y) == x - y == -(y - x)
                // (-x) + y == y - x == -(x - y)
-               if x.ucmp(y) >= 0 {
+               if x.ucmp(y) == Above {
                        z.usub(x, y)
                } else {
                        z.neg = !z.neg
@@ -1342,7 +1349,7 @@ func (z *Float) Sub(x, y *Float) *Float {
        } else {
                // x - y == x - y == -(y - x)
                // (-x) - (-y) == y - x == -(x - y)
-               if x.ucmp(y) >= 0 {
+               if x.ucmp(y) == Above {
                        z.usub(x, y)
                } else {
                        z.neg = !z.neg
@@ -1441,46 +1448,49 @@ func (z *Float) Quo(x, y *Float) *Float {
 
 // Cmp compares x and y and returns:
 //
-//   -1 if x <  y
-//    0 if x == y (incl. -0 == 0)
-//   +1 if x >  y
+//   Below if x <  y
+//   Exact if x == y (incl. -0 == 0, -Inf == -Inf, and +Inf == +Inf)
+//   Above if x >  y
+//   Undef if any of x, y is NaN
 //
-// Infinities with matching sign are equal.
-// NaN values are never equal.
-// BUG(gri) Float.Cmp does not implement comparing of NaNs.
-func (x *Float) Cmp(y *Float) int {
+func (x *Float) Cmp(y *Float) Accuracy {
        if debugFloat {
                x.validate()
                y.validate()
        }
 
+       if x.form == nan || y.form == nan {
+               return Undef
+       }
+
        mx := x.ord()
        my := y.ord()
        switch {
        case mx < my:
-               return -1
+               return Below
        case mx > my:
-               return +1
+               return Above
        }
        // mx == my
 
        // only if |mx| == 1 we have to compare the mantissae
        switch mx {
        case -1:
-               return -x.ucmp(y)
+               return y.ucmp(x)
        case +1:
-               return +x.ucmp(y)
+               return x.ucmp(y)
        }
 
-       return 0
+       return Exact
 }
 
-func umax32(x, y uint32) uint32 {
-       if x > y {
-               return x
-       }
-       return y
-}
+// The following accessors simplify testing of Cmp results.
+func (acc Accuracy) Eql() bool { return acc == Exact }
+func (acc Accuracy) Neq() bool { return acc != Exact }
+func (acc Accuracy) Lss() bool { return acc == Below }
+func (acc Accuracy) Leq() bool { return acc&Above == 0 }
+func (acc Accuracy) Gtr() bool { return acc == Above }
+func (acc Accuracy) Geq() bool { return acc&Below == 0 }
 
 // ord classifies x and returns:
 //
@@ -1507,3 +1517,10 @@ func (x *Float) ord() int {
        }
        return m
 }
+
+func umax32(x, y uint32) uint32 {
+       if x > y {
+               return x
+       }
+       return y
+}
index edd0056ff306c64a1cefa05347790b50ac5d60e8..dca78a84c53e85b6487cbeb34293dd69079997b8 100644 (file)
@@ -1063,8 +1063,8 @@ func TestFloatAdd32(t *testing.T) {
                                x0, y0 = y0, x0
                        }
 
-                       x := new(Float).SetFloat64(x0)
-                       y := new(Float).SetFloat64(y0)
+                       x := NewFloat(x0)
+                       y := NewFloat(y0)
                        z := new(Float).SetPrec(24)
 
                        z.Add(x, y)
@@ -1096,8 +1096,8 @@ func TestFloatAdd64(t *testing.T) {
                                x0, y0 = y0, x0
                        }
 
-                       x := new(Float).SetFloat64(x0)
-                       y := new(Float).SetFloat64(y0)
+                       x := NewFloat(x0)
+                       y := NewFloat(y0)
                        z := new(Float).SetPrec(53)
 
                        z.Add(x, y)
@@ -1182,8 +1182,8 @@ func TestFloatMul64(t *testing.T) {
                                x0, y0 = y0, x0
                        }
 
-                       x := new(Float).SetFloat64(x0)
-                       y := new(Float).SetFloat64(y0)
+                       x := NewFloat(x0)
+                       y := NewFloat(y0)
                        z := new(Float).SetPrec(53)
 
                        z.Mul(x, y)
@@ -1260,7 +1260,7 @@ func TestFloatQuo(t *testing.T) {
                z := bits.Float()
 
                // compute accurate x as z*y
-               y := new(Float).SetFloat64(3.14159265358979323e123)
+               y := NewFloat(3.14159265358979323e123)
 
                x := new(Float).SetPrec(z.Prec() + y.Prec()).SetMode(ToZero)
                x.Mul(z, y)
@@ -1329,10 +1329,9 @@ func TestFloatQuoSmoke(t *testing.T) {
        }
 }
 
-// TestFloatArithmeticSpecialValues tests that Float operations produce
-// the correct result for all combinations of regular and special value
-// arguments (±0, ±Inf, NaN) and ±1 and ±2.71828 as representatives for
-// nonzero finite values.
+// TestFloatArithmeticSpecialValues tests that Float operations produce the
+// correct results for combinations of zero (±0), finite (±1 and ±2.71828),
+// and non-finite (±Inf, NaN) operands.
 func TestFloatArithmeticSpecialValues(t *testing.T) {
        zero := 0.0
        args := []float64{math.Inf(-1), -2.71828, -1, -zero, zero, 1, 2.71828, math.Inf(1), math.NaN()}
@@ -1442,6 +1441,39 @@ func TestFloatArithmeticRounding(t *testing.T) {
        }
 }
 
-func TestFloatCmp(t *testing.T) {
-       // TODO(gri) implement this
+// TestFloatCmpSpecialValues tests that Cmp produces the correct results for
+// combinations of zero (±0), finite (±1 and ±2.71828), and non-finite (±Inf,
+// NaN) operands.
+func TestFloatCmpSpecialValues(t *testing.T) {
+       zero := 0.0
+       args := []float64{math.Inf(-1), -2.71828, -1, -zero, zero, 1, 2.71828, math.Inf(1), math.NaN()}
+       xx := new(Float)
+       yy := new(Float)
+       for i := 0; i < 4; i++ {
+               for _, x := range args {
+                       xx.SetFloat64(x)
+                       // check conversion is correct
+                       // (no need to do this for y, since we see exactly the
+                       // same values there)
+                       if got, acc := xx.Float64(); !math.IsNaN(x) && (got != x || acc != Exact) {
+                               t.Errorf("Float(%g) == %g (%s)", x, got, acc)
+                       }
+                       for _, y := range args {
+                               yy.SetFloat64(y)
+                               got := xx.Cmp(yy)
+                               want := Undef
+                               switch {
+                               case x < y:
+                                       want = Below
+                               case x == y:
+                                       want = Exact
+                               case x > y:
+                                       want = Above
+                               }
+                               if got != want {
+                                       t.Errorf("(%g).Cmp(%g) = %s; want %s", x, y, got, want)
+                               }
+                       }
+               }
+       }
 }
index 23655d09660f7189478cc715d7d410a3d33a80fd..181c0bc136acb438be0d06c31771a0588a8707d0 100644 (file)
@@ -6,11 +6,10 @@ package big_test
 
 import (
        "fmt"
+       "math"
        "math/big"
 )
 
-// TODO(gri) add more examples
-
 func ExampleFloat_Add() {
        // Operating on numbers of different precision.
        var x, y, z big.Float
@@ -29,11 +28,10 @@ func ExampleFloat_Add() {
 
 func Example_Shift() {
        // Implementing Float "shift" by modifying the (binary) exponents directly.
-       var x big.Float
        for s := -5; s <= 5; s++ {
-               x.SetFloat64(0.5)
-               x.SetMantExp(&x, x.MantExp(nil)+s) // shift x by s
-               fmt.Println(&x)
+               x := big.NewFloat(0.5)
+               x.SetMantExp(x, x.MantExp(nil)+s) // shift x by s
+               fmt.Println(x)
        }
        // Output:
        // 0.015625
@@ -48,3 +46,92 @@ func Example_Shift() {
        // 8
        // 16
 }
+
+func ExampleFloat_Cmp() {
+       inf := math.Inf(1)
+       zero := 0.0
+       nan := math.NaN()
+
+       operands := []float64{-inf, -1.2, -zero, 0, +1.2, +inf, nan}
+
+       fmt.Println("   x     y   cmp   eql  neq  lss  leq  gtr  geq")
+       fmt.Println("-----------------------------------------------")
+       for _, x64 := range operands {
+               x := big.NewFloat(x64)
+               for _, y64 := range operands {
+                       y := big.NewFloat(y64)
+                       t := x.Cmp(y)
+                       fmt.Printf(
+                               "%4s  %4s  %5s   %c    %c    %c    %c    %c    %c\n",
+                               x, y, t,
+                               mark(t.Eql()), mark(t.Neq()), mark(t.Lss()), mark(t.Leq()), mark(t.Gtr()), mark(t.Geq()))
+               }
+               fmt.Println()
+       }
+
+       // Output:
+       //    x     y   cmp   eql  neq  lss  leq  gtr  geq
+       // -----------------------------------------------
+       // -Inf  -Inf  Exact   ●    ○    ○    ●    ○    ●
+       // -Inf  -1.2  Below   ○    ●    ●    ●    ○    ○
+       // -Inf    -0  Below   ○    ●    ●    ●    ○    ○
+       // -Inf     0  Below   ○    ●    ●    ●    ○    ○
+       // -Inf   1.2  Below   ○    ●    ●    ●    ○    ○
+       // -Inf  +Inf  Below   ○    ●    ●    ●    ○    ○
+       // -Inf   NaN  Undef   ○    ●    ○    ○    ○    ○
+       //
+       // -1.2  -Inf  Above   ○    ●    ○    ○    ●    ●
+       // -1.2  -1.2  Exact   ●    ○    ○    ●    ○    ●
+       // -1.2    -0  Below   ○    ●    ●    ●    ○    ○
+       // -1.2     0  Below   ○    ●    ●    ●    ○    ○
+       // -1.2   1.2  Below   ○    ●    ●    ●    ○    ○
+       // -1.2  +Inf  Below   ○    ●    ●    ●    ○    ○
+       // -1.2   NaN  Undef   ○    ●    ○    ○    ○    ○
+       //
+       //   -0  -Inf  Above   ○    ●    ○    ○    ●    ●
+       //   -0  -1.2  Above   ○    ●    ○    ○    ●    ●
+       //   -0    -0  Exact   ●    ○    ○    ●    ○    ●
+       //   -0     0  Exact   ●    ○    ○    ●    ○    ●
+       //   -0   1.2  Below   ○    ●    ●    ●    ○    ○
+       //   -0  +Inf  Below   ○    ●    ●    ●    ○    ○
+       //   -0   NaN  Undef   ○    ●    ○    ○    ○    ○
+       //
+       //    0  -Inf  Above   ○    ●    ○    ○    ●    ●
+       //    0  -1.2  Above   ○    ●    ○    ○    ●    ●
+       //    0    -0  Exact   ●    ○    ○    ●    ○    ●
+       //    0     0  Exact   ●    ○    ○    ●    ○    ●
+       //    0   1.2  Below   ○    ●    ●    ●    ○    ○
+       //    0  +Inf  Below   ○    ●    ●    ●    ○    ○
+       //    0   NaN  Undef   ○    ●    ○    ○    ○    ○
+       //
+       //  1.2  -Inf  Above   ○    ●    ○    ○    ●    ●
+       //  1.2  -1.2  Above   ○    ●    ○    ○    ●    ●
+       //  1.2    -0  Above   ○    ●    ○    ○    ●    ●
+       //  1.2     0  Above   ○    ●    ○    ○    ●    ●
+       //  1.2   1.2  Exact   ●    ○    ○    ●    ○    ●
+       //  1.2  +Inf  Below   ○    ●    ●    ●    ○    ○
+       //  1.2   NaN  Undef   ○    ●    ○    ○    ○    ○
+       //
+       // +Inf  -Inf  Above   ○    ●    ○    ○    ●    ●
+       // +Inf  -1.2  Above   ○    ●    ○    ○    ●    ●
+       // +Inf    -0  Above   ○    ●    ○    ○    ●    ●
+       // +Inf     0  Above   ○    ●    ○    ○    ●    ●
+       // +Inf   1.2  Above   ○    ●    ○    ○    ●    ●
+       // +Inf  +Inf  Exact   ●    ○    ○    ●    ○    ●
+       // +Inf   NaN  Undef   ○    ●    ○    ○    ○    ○
+       //
+       //  NaN  -Inf  Undef   ○    ●    ○    ○    ○    ○
+       //  NaN  -1.2  Undef   ○    ●    ○    ○    ○    ○
+       //  NaN    -0  Undef   ○    ●    ○    ○    ○    ○
+       //  NaN     0  Undef   ○    ●    ○    ○    ○    ○
+       //  NaN   1.2  Undef   ○    ●    ○    ○    ○    ○
+       //  NaN  +Inf  Undef   ○    ●    ○    ○    ○    ○
+       //  NaN   NaN  Undef   ○    ●    ○    ○    ○    ○
+}
+
+func mark(p bool) rune {
+       if p {
+               return '●'
+       }
+       return '○'
+}