]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: wrap Float.Cmp result in struct to prevent wrong use
authorRobert Griesemer <gri@golang.org>
Sat, 14 Mar 2015 00:24:30 +0000 (17:24 -0700)
committerRobert Griesemer <gri@golang.org>
Sat, 14 Mar 2015 00:48:53 +0000 (00:48 +0000)
Float.Cmp used to return a value < 0, 0, or > 0 depending on how
arguments x, y compared against each other. With the possibility
of NaNs, the result was changed into an Accuracy (to include Undef).
Consequently, Float.Cmp results could still be compared for (in-)
equality with 0, but comparing if < 0 or > 0 would provide the
wrong answer w/o any obvious notice by the compiler.

This change wraps Float.Cmp results into a struct and accessors
are used to access the desired result. This prevents incorrect
use.

Change-Id: I34e6a6c1859251ec99b5cf953e82542025ace56f
Reviewed-on: https://go-review.googlesource.com/7526
Reviewed-by: Rob Pike <r@golang.org>
src/math/big/float.go
src/math/big/float_test.go
src/math/big/floatconv_test.go
src/math/big/floatexample_test.go

index 44691c4783955d948564296a4c38fe2e571561b3..d716c8ca59ffc1f5be45198edd62c1a0b046aa63 100644 (file)
@@ -1446,6 +1446,10 @@ func (z *Float) Quo(x, y *Float) *Float {
        return z
 }
 
+type cmpResult struct {
+       acc Accuracy
+}
+
 // Cmp compares x and y and returns:
 //
 //   Below if x <  y
@@ -1453,44 +1457,45 @@ func (z *Float) Quo(x, y *Float) *Float {
 //   Above if x >  y
 //   Undef if any of x, y is NaN
 //
-func (x *Float) Cmp(y *Float) Accuracy {
+func (x *Float) Cmp(y *Float) cmpResult {
        if debugFloat {
                x.validate()
                y.validate()
        }
 
        if x.form == nan || y.form == nan {
-               return Undef
+               return cmpResult{Undef}
        }
 
        mx := x.ord()
        my := y.ord()
        switch {
        case mx < my:
-               return Below
+               return cmpResult{Below}
        case mx > my:
-               return Above
+               return cmpResult{Above}
        }
        // mx == my
 
        // only if |mx| == 1 we have to compare the mantissae
        switch mx {
        case -1:
-               return y.ucmp(x)
+               return cmpResult{y.ucmp(x)}
        case +1:
-               return x.ucmp(y)
+               return cmpResult{x.ucmp(y)}
        }
 
-       return Exact
+       return cmpResult{Exact}
 }
 
 // The following accessors simplify testing of Cmp results.
-func (acc Accuracy) Eql() bool { return acc == Exact }
-func (acc Accuracy) Neq() bool { return acc != Exact }
-func (acc Accuracy) Lss() bool { return acc == Below }
-func (acc Accuracy) Leq() bool { return acc&Above == 0 }
-func (acc Accuracy) Gtr() bool { return acc == Above }
-func (acc Accuracy) Geq() bool { return acc&Below == 0 }
+func (res cmpResult) Acc() Accuracy { return res.acc }
+func (res cmpResult) Eql() bool     { return res.acc == Exact }
+func (res cmpResult) Neq() bool     { return res.acc != Exact }
+func (res cmpResult) Lss() bool     { return res.acc == Below }
+func (res cmpResult) Leq() bool     { return res.acc&Above == 0 }
+func (res cmpResult) Gtr() bool     { return res.acc == Above }
+func (res cmpResult) Geq() bool     { return res.acc&Below == 0 }
 
 // ord classifies x and returns:
 //
index 86b1c6f7a1ac2842ca524d031059839ad96a34e0..683809bf56764ecd058d71b3adf0ad7bb8ca861f 100644 (file)
@@ -207,7 +207,7 @@ func feq(x, y *Float) bool {
        if x.IsNaN() || y.IsNaN() {
                return x.IsNaN() && y.IsNaN()
        }
-       return x.Cmp(y) == 0 && x.IsNeg() == y.IsNeg()
+       return x.Cmp(y).Eql() && x.IsNeg() == y.IsNeg()
 }
 
 func TestFloatMantExp(t *testing.T) {
@@ -918,7 +918,7 @@ func TestFloatRat(t *testing.T) {
                // inverse conversion
                if res != nil {
                        got := new(Float).SetPrec(64).SetRat(res)
-                       if got.Cmp(x) != 0 {
+                       if got.Cmp(x).Neq() {
                                t.Errorf("%s: got %s; want %s", test.x, got, x)
                        }
                }
@@ -995,7 +995,7 @@ func TestFloatInc(t *testing.T) {
                for i := 0; i < n; i++ {
                        x.Add(&x, &one)
                }
-               if x.Cmp(new(Float).SetInt64(n)) != 0 {
+               if x.Cmp(new(Float).SetInt64(n)).Neq() {
                        t.Errorf("prec = %d: got %s; want %d", prec, &x, n)
                }
        }
@@ -1036,14 +1036,14 @@ func TestFloatAdd(t *testing.T) {
                                        got := new(Float).SetPrec(prec).SetMode(mode)
                                        got.Add(x, y)
                                        want := zbits.round(prec, mode)
-                                       if got.Cmp(want) != 0 {
+                                       if got.Cmp(want).Neq() {
                                                t.Errorf("i = %d, prec = %d, %s:\n\t     %s %v\n\t+    %s %v\n\t=    %s\n\twant %s",
                                                        i, prec, mode, x, xbits, y, ybits, got, want)
                                        }
 
                                        got.Sub(z, x)
                                        want = ybits.round(prec, mode)
-                                       if got.Cmp(want) != 0 {
+                                       if got.Cmp(want).Neq() {
                                                t.Errorf("i = %d, prec = %d, %s:\n\t     %s %v\n\t-    %s %v\n\t=    %s\n\twant %s",
                                                        i, prec, mode, z, zbits, x, xbits, got, want)
                                        }
@@ -1137,7 +1137,7 @@ func TestFloatMul(t *testing.T) {
                                        got := new(Float).SetPrec(prec).SetMode(mode)
                                        got.Mul(x, y)
                                        want := zbits.round(prec, mode)
-                                       if got.Cmp(want) != 0 {
+                                       if got.Cmp(want).Neq() {
                                                t.Errorf("i = %d, prec = %d, %s:\n\t     %s %v\n\t*    %s %v\n\t=    %s\n\twant %s",
                                                        i, prec, mode, x, xbits, y, ybits, got, want)
                                        }
@@ -1147,7 +1147,7 @@ func TestFloatMul(t *testing.T) {
                                        }
                                        got.Quo(z, x)
                                        want = ybits.round(prec, mode)
-                                       if got.Cmp(want) != 0 {
+                                       if got.Cmp(want).Neq() {
                                                t.Errorf("i = %d, prec = %d, %s:\n\t     %s %v\n\t/    %s %v\n\t=    %s\n\twant %s",
                                                        i, prec, mode, z, zbits, x, xbits, got, want)
                                        }
@@ -1230,7 +1230,7 @@ func TestIssue6866(t *testing.T) {
                p.Mul(p, psix)
                z2.Sub(two, p)
 
-               if z1.Cmp(z2) != 0 {
+               if z1.Cmp(z2).Neq() {
                        t.Fatalf("prec %d: got z1 = %s != z2 = %s; want z1 == z2\n", prec, z1, z2)
                }
                if z1.Sign() != 0 {
@@ -1281,7 +1281,7 @@ func TestFloatQuo(t *testing.T) {
                                prec := uint(preci + d)
                                got := new(Float).SetPrec(prec).SetMode(mode).Quo(x, y)
                                want := bits.round(prec, mode)
-                               if got.Cmp(want) != 0 {
+                               if got.Cmp(want).Neq() {
                                        t.Errorf("i = %d, prec = %d, %s:\n\t     %s\n\t/    %s\n\t=    %s\n\twant %s",
                                                i, prec, mode, x, y, got, want)
                                }
@@ -1462,7 +1462,7 @@ func TestFloatCmpSpecialValues(t *testing.T) {
                        }
                        for _, y := range args {
                                yy.SetFloat64(y)
-                               got := xx.Cmp(yy)
+                               got := xx.Cmp(yy).Acc()
                                want := Undef
                                switch {
                                case x < y:
index e7920d0c079882a196c5eb80349e80f1ffb4ce3f..17c8b147862506522fa28d13f31639fe3b11ef56 100644 (file)
@@ -102,7 +102,7 @@ func TestFloatSetFloat64String(t *testing.T) {
                }
                f, _ := x.Float64()
                want := new(Float).SetFloat64(test.x)
-               if x.Cmp(want) != 0 {
+               if x.Cmp(want).Neq() {
                        t.Errorf("%s: got %s (%v); want %v", test.s, &x, f, test.x)
                }
        }
index 181c0bc136acb438be0d06c31771a0588a8707d0..5e4a9cbe89ceb455bc2436a6b5362609455c6488 100644 (file)
@@ -63,7 +63,7 @@ func ExampleFloat_Cmp() {
                        t := x.Cmp(y)
                        fmt.Printf(
                                "%4s  %4s  %5s   %c    %c    %c    %c    %c    %c\n",
-                               x, y, t,
+                               x, y, t.Acc(),
                                mark(t.Eql()), mark(t.Neq()), mark(t.Lss()), mark(t.Leq()), mark(t.Gtr()), mark(t.Geq()))
                }
                fmt.Println()