]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: cleaner handling of exponent under/overflow
authorRobert Griesemer <gri@golang.org>
Fri, 6 Mar 2015 01:32:57 +0000 (17:32 -0800)
committerRobert Griesemer <gri@golang.org>
Tue, 17 Mar 2015 16:09:34 +0000 (16:09 +0000)
Fixed several corner-case bugs and added corresponding tests.

Change-Id: I23096b9caeeff0956f65ab59fa91e168d0e47bb8
Reviewed-on: https://go-review.googlesource.com/7001
Reviewed-by: Alan Donovan <adonovan@google.com>
src/math/big/bits_test.go
src/math/big/float.go
src/math/big/float_test.go
src/math/big/floatconv.go

index 761f75628f8750bf8d138ce9a511f224d2a9866b..3ce24222d741081128119cad87f4b1ca1a8f69f9 100644 (file)
@@ -187,7 +187,12 @@ func (bits Bits) Float() *Float {
 
        // create corresponding float
        z := new(Float).SetInt(x) // normalized
-       z.setExp(int64(z.exp) + int64(min))
+       if e := int64(z.exp) + int64(min); MinExp <= e && e <= MaxExp {
+               z.exp = int32(e)
+       } else {
+               // this should never happen for our test cases
+               panic("exponent out of range")
+       }
        return z
 }
 
index d716c8ca59ffc1f5be45198edd62c1a0b046aa63..a86471e2a51e1847b47093cfb24081bf332e18df 100644 (file)
@@ -154,7 +154,7 @@ func (z *Float) SetPrec(prec uint) *Float {
                z.prec = 0
                if z.form == finite {
                        // truncate z to 0
-                       z.acc = z.cmpZero()
+                       z.acc = makeAcc(z.neg)
                        z.form = zero
                }
                return z
@@ -172,8 +172,8 @@ func (z *Float) SetPrec(prec uint) *Float {
        return z
 }
 
-func (x *Float) cmpZero() Accuracy {
-       if x.neg {
+func makeAcc(above bool) Accuracy {
+       if above {
                return Above
        }
        return Below
@@ -265,22 +265,24 @@ func (x *Float) MantExp(mant *Float) (exp int) {
        return
 }
 
-// setExp sets the exponent for z.
-// If e < MinExp, z becomes ±0; if e > MaxExp, z becomes ±Inf.
-func (z *Float) setExp(e int64) {
-       if debugFloat && z.form != finite {
-               panic("setExp called for non-finite Float")
+func (z *Float) setExpAndRound(exp int64, sbit uint) {
+       if exp < MinExp {
+               // underflow
+               z.acc = makeAcc(z.neg)
+               z.form = zero
+               return
        }
-       switch {
-       case e < MinExp:
-               // TODO(gri) check that accuracy is adjusted if necessary
-               z.form = zero // underflow
-       default:
-               z.exp = int32(e)
-       case e > MaxExp:
-               // TODO(gri) check that accuracy is adjusted if necessary
-               z.form = inf // overflow
+
+       if exp > MaxExp {
+               // overflow
+               z.acc = makeAcc(!z.neg)
+               z.form = inf
+               return
        }
+
+       z.form = finite
+       z.exp = int32(exp)
+       z.round(sbit)
 }
 
 // SetMantExp sets z to mant × 2**exp and and returns z.
@@ -308,7 +310,7 @@ func (z *Float) SetMantExp(mant *Float, exp int) *Float {
        if z.form != finite {
                return z
        }
-       z.setExp(int64(z.exp) + int64(exp))
+       z.setExpAndRound(int64(z.exp)+int64(exp), 0)
        return z
 }
 
@@ -368,14 +370,14 @@ func (x *Float) validate() {
        }
        m := len(x.mant)
        if m == 0 {
-               panic("nonzero finite x with empty mantissa")
+               panic("nonzero finite number with empty mantissa")
        }
        const msb = 1 << (_W - 1)
        if x.mant[m-1]&msb == 0 {
                panic(fmt.Sprintf("msb not set in last word %#x of %s", x.mant[m-1], x.Format('p', 0)))
        }
-       if x.prec <= 0 {
-               panic(fmt.Sprintf("invalid precision %d", x.prec))
+       if x.prec == 0 {
+               panic("zero precision finite number")
        }
 }
 
@@ -507,7 +509,14 @@ func (z *Float) round(sbit uint) {
                        shrVU(z.mant, z.mant, 1)
                        z.mant[n-1] |= 1 << (_W - 1)
                        // adjust exponent
-                       z.exp++
+                       if z.exp < MaxExp {
+                               z.exp++
+                       } else {
+                               // exponent overflow
+                               z.acc = makeAcc(!z.neg)
+                               z.form = inf
+                               return
+                       }
                }
                z.acc = Above
        }
@@ -515,8 +524,6 @@ func (z *Float) round(sbit uint) {
        // zero out trailing bits in least-significant word
        z.mant[0] &^= lsb - 1
 
-       // TODO(gri) can z.mant be all 0s at this point?
-
        // update accuracy
        if z.acc != Exact && z.neg {
                z.acc ^= Below | Above
@@ -655,13 +662,9 @@ func (z *Float) SetInt(x *Int) *Float {
                return z
        }
        // x != 0
-       z.form = finite
        z.mant = z.mant.set(x.abs)
        fnorm(z.mant)
-       z.setExp(int64(bits))
-       if z.prec < bits {
-               z.round(0)
-       }
+       z.setExpAndRound(int64(bits), 0)
        return z
 }
 
@@ -692,7 +695,7 @@ func (z *Float) SetInf(sign int) *Float {
 }
 
 // SetNaN sets z to a NaN value, and returns z.
-// The precision of z is unchanged and the result is always Undef.
+// The precision of z is unchanged and the result accuracy is always Undef.
 func (z *Float) SetNaN() *Float {
        z.acc = Undef
        z.form = nan
@@ -711,14 +714,15 @@ func (z *Float) Set(x *Float) *Float {
        }
        z.acc = Exact
        if z != x {
-               if z.prec == 0 {
-                       z.prec = x.prec
-               }
                z.form = x.form
                z.neg = x.neg
-               z.exp = x.exp
-               z.mant = z.mant.set(x.mant)
-               if z.prec < x.prec {
+               if x.form == finite {
+                       z.exp = x.exp
+                       z.mant = z.mant.set(x.mant)
+               }
+               if z.prec == 0 {
+                       z.prec = x.prec
+               } else if z.prec < x.prec {
                        z.round(0)
                }
        }
@@ -738,8 +742,10 @@ func (z *Float) Copy(x *Float) *Float {
                z.acc = x.acc
                z.form = x.form
                z.neg = x.neg
-               z.mant = z.mant.set(x.mant)
-               z.exp = x.exp
+               if z.form == finite {
+                       z.mant = z.mant.set(x.mant)
+                       z.exp = x.exp
+               }
        }
        return z
 }
@@ -821,7 +827,7 @@ func (x *Float) Int64() (int64, Accuracy) {
        switch x.form {
        case finite:
                // 0 < |x| < +Inf
-               acc := x.cmpZero()
+               acc := makeAcc(x.neg)
                if x.exp <= 0 {
                        // 0 < |x| < 1
                        return 0, acc
@@ -927,7 +933,7 @@ func (x *Float) Int(z *Int) (*Int, Accuracy) {
        switch x.form {
        case finite:
                // 0 < |x| < +Inf
-               acc := x.cmpZero()
+               acc := makeAcc(x.neg)
                if x.exp <= 0 {
                        // 0 < |x| < 1
                        return z.SetInt64(0), acc
@@ -960,7 +966,7 @@ func (x *Float) Int(z *Int) (*Int, Accuracy) {
                return z.SetInt64(0), Exact
 
        case inf:
-               return nil, x.cmpZero()
+               return nil, makeAcc(x.neg)
 
        case nan:
                return nil, Undef
@@ -1010,7 +1016,7 @@ func (x *Float) Rat(z *Rat) (*Rat, Accuracy) {
                return z.SetInt64(0), Exact
 
        case inf:
-               return nil, x.cmpZero()
+               return nil, makeAcc(x.neg)
 
        case nan:
                return nil, Undef
@@ -1035,8 +1041,22 @@ func (z *Float) Neg(x *Float) *Float {
        return z
 }
 
-// z = x + y, ignoring signs of x and y.
-// x.form and y.form must be finite.
+func validateBinaryOperands(x, y *Float) {
+       if !debugFloat {
+               // avoid performance bugs
+               panic("validateBinaryOperands called but debugFloat is not set")
+       }
+       if len(x.mant) == 0 {
+               panic("empty mantissa for x")
+       }
+       if len(y.mant) == 0 {
+               panic("empty mantissa for y")
+       }
+}
+
+// z = x + y, ignoring signs of x and y for the addition
+// but using the sign of z for rounding the result.
+// x and y must have a non-empty mantissa and valid exponent.
 func (z *Float) uadd(x, y *Float) {
        // Note: This implementation requires 2 shifts most of the
        // time. It is also inefficient if exponents or precisions
@@ -1048,8 +1068,8 @@ func (z *Float) uadd(x, y *Float) {
        // Point Addition With Exact Rounding (as in the MPFR Library)"
        // http://www.vinc17.net/research/papers/rnc6.pdf
 
-       if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
-               panic("uadd called with empty mantissa")
+       if debugFloat {
+               validateBinaryOperands(x, y)
        }
 
        // compute exponents ex, ey for mantissa with "binary point"
@@ -1075,20 +1095,20 @@ func (z *Float) uadd(x, y *Float) {
        }
        // len(z.mant) > 0
 
-       z.setExp(ex + int64(len(z.mant))*_W - fnorm(z.mant))
-       z.round(0)
+       z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0)
 }
 
-// z = x - y for x >= y, ignoring signs of x and y.
-// x.form and y.form must be finite.
+// z = x - y for |x| > |y|, ignoring signs of x and y for the subtraction
+// but using the sign of z for rounding the result.
+// x and y must have a non-empty mantissa and valid exponent.
 func (z *Float) usub(x, y *Float) {
        // This code is symmetric to uadd.
        // We have not factored the common code out because
        // eventually uadd (and usub) should be optimized
        // by special-casing, and the code will diverge.
 
-       if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
-               panic("usub called with empty mantissa")
+       if debugFloat {
+               validateBinaryOperands(x, y)
        }
 
        ex := int64(x.exp) - int64(len(x.mant))*_W
@@ -1113,19 +1133,20 @@ func (z *Float) usub(x, y *Float) {
        if len(z.mant) == 0 {
                z.acc = Exact
                z.form = zero
+               z.neg = false
                return
        }
        // len(z.mant) > 0
 
-       z.setExp(ex + int64(len(z.mant))*_W - fnorm(z.mant))
-       z.round(0)
+       z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0)
 }
 
-// z = x * y, ignoring signs of x and y.
-// x.form and y.form must be finite.
+// z = x * y, ignoring signs of x and y for the multiplication
+// but using the sign of z for rounding the result.
+// x and y must have a non-empty mantissa and valid exponent.
 func (z *Float) umul(x, y *Float) {
-       if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
-               panic("umul called with empty mantissa")
+       if debugFloat {
+               validateBinaryOperands(x, y)
        }
 
        // Note: This is doing too much work if the precision
@@ -1137,16 +1158,15 @@ func (z *Float) umul(x, y *Float) {
        e := int64(x.exp) + int64(y.exp)
        z.mant = z.mant.mul(x.mant, y.mant)
 
-       // normalize mantissa
-       z.setExp(e - fnorm(z.mant))
-       z.round(0)
+       z.setExpAndRound(e-fnorm(z.mant), 0)
 }
 
-// z = x / y, ignoring signs of x and y.
-// x.form and y.form must be finite.
+// z = x / y, ignoring signs of x and y for the division
+// but using the sign of z for rounding the result.
+// x and y must have a non-empty mantissa and valid exponent.
 func (z *Float) uquo(x, y *Float) {
-       if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
-               panic("uquo called with empty mantissa")
+       if debugFloat {
+               validateBinaryOperands(x, y)
        }
 
        // mantissa length in words for desired result precision + 1
@@ -1172,13 +1192,8 @@ func (z *Float) uquo(x, y *Float) {
        // divide
        var r nat
        z.mant, r = z.mant.div(nil, xadj, y.mant)
-
-       // determine exponent
        e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W
 
-       // normalize mantissa
-       z.setExp(e - fnorm(z.mant))
-
        // The result is long enough to include (at least) the rounding bit.
        // If there's a non-zero remainder, the corresponding fractional part
        // (if it were computed), would have a non-zero sticky bit (if it were
@@ -1187,15 +1202,16 @@ func (z *Float) uquo(x, y *Float) {
        if len(r) > 0 {
                sbit = 1
        }
-       z.round(sbit)
+
+       z.setExpAndRound(e-fnorm(z.mant), sbit)
 }
 
 // ucmp returns Below, Exact, or Above, depending
-// on whether x < y, x == y, or x > y.
-// x.form and y.form must be finite.
+// on whether |x| < |y|, |x| == |y|, or |x| > |y|.
+// x and y must have a non-empty mantissa and valid exponent.
 func (x *Float) ucmp(y *Float) Accuracy {
-       if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
-               panic("ucmp called with empty mantissa")
+       if debugFloat {
+               validateBinaryOperands(x, y)
        }
 
        switch {
@@ -1284,7 +1300,6 @@ func (z *Float) Add(x, y *Float) *Float {
        }
 
        // x, y != 0
-       z.form = finite
        z.neg = x.neg
        if x.neg == y.neg {
                // x + y == x + y
@@ -1301,11 +1316,6 @@ func (z *Float) Add(x, y *Float) *Float {
                }
        }
 
-       // -0 is only possible for -0 + -0
-       if z.form == zero {
-               z.neg = false
-       }
-
        return z
 }
 
@@ -1340,7 +1350,6 @@ func (z *Float) Sub(x, y *Float) *Float {
        }
 
        // x, y != 0
-       z.form = finite
        z.neg = x.neg
        if x.neg != y.neg {
                // x - (-y) == x + y
@@ -1357,11 +1366,6 @@ func (z *Float) Sub(x, y *Float) *Float {
                }
        }
 
-       // -0 is only possible for -0 - 0
-       if z.form == zero {
-               z.neg = false
-       }
-
        return z
 }
 
@@ -1392,15 +1396,9 @@ func (z *Float) Mul(x, y *Float) *Float {
                return z
        }
 
-       if x.form == zero || y.form == zero {
-               z.acc = Exact
-               z.form = zero
-               return z
-       }
-
        // x, y != 0
-       z.form = finite
        z.umul(x, y)
+
        return z
 }
 
@@ -1426,6 +1424,7 @@ func (z *Float) Quo(x, y *Float) *Float {
                        // TODO(gri) handle Inf separately
                        return z.SetNaN()
                }
+               // x == ±0 || y == ±0
                if x.form == zero {
                        if y.form == zero {
                                return z.SetNaN()
@@ -1433,16 +1432,14 @@ func (z *Float) Quo(x, y *Float) *Float {
                        z.form = zero
                        return z
                }
-               // x != 0
-               if y.form == zero {
-                       z.form = inf
-                       return z
-               }
+               // y == ±0
+               z.form = inf
+               return z
        }
 
        // x, y != 0
-       z.form = finite
        z.uquo(x, y)
+
        return z
 }
 
@@ -1505,6 +1502,7 @@ func (res cmpResult) Geq() bool     { return res.acc&Below == 0 }
 //     +1 if 0 < x < +Inf
 //     +2 if x == +Inf
 //
+// x must not be NaN.
 func (x *Float) ord() int {
        var m int
        switch x.form {
@@ -1514,8 +1512,8 @@ func (x *Float) ord() int {
                return 0
        case inf:
                m = 2
-       case nan:
-               panic("unimplemented")
+       default:
+               panic("unreachable")
        }
        if x.neg {
                m = -m
index 683809bf56764ecd058d71b3adf0ad7bb8ca861f..379352c886d62732173527cd7d7aa006e184d4a5 100644 (file)
@@ -1389,6 +1389,68 @@ func TestFloatArithmeticSpecialValues(t *testing.T) {
        }
 }
 
+func TestFloatArithmeticOverflow(t *testing.T) {
+       for _, test := range []struct {
+               prec       uint
+               mode       RoundingMode
+               op         byte
+               x, y, want string
+               acc        Accuracy
+       }{
+               {4, ToNearestEven, '+', "0", "0", "0", Exact},                // smoke test
+               {4, ToNearestEven, '+', "0x.8p0", "0x.8p0", "0x.8p1", Exact}, // smoke test
+
+               {4, ToNearestEven, '+', "0", "0x.8p2147483647", "0x.8p2147483647", Exact},
+               {4, ToNearestEven, '+', "0x.8p2147483500", "0x.8p2147483647", "0x.8p2147483647", Below}, // rounded to zero
+               {4, ToNearestEven, '+', "0x.8p2147483647", "0x.8p2147483647", "+Inf", Above},            // exponent overflow in +
+               {4, ToNearestEven, '+', "-0x.8p2147483647", "-0x.8p2147483647", "-Inf", Below},          // exponent overflow in +
+               {4, ToNearestEven, '-', "-0x.8p2147483647", "0x.8p2147483647", "-Inf", Below},           // exponent overflow in -
+
+               {4, ToZero, '+', "0x.fp2147483647", "0x.8p2147483643", "0x.fp2147483647", Below}, // rounded to zero
+               {4, ToNearestEven, '+', "0x.fp2147483647", "0x.8p2147483643", "+Inf", Above},     // exponent overflow in rounding
+               {4, AwayFromZero, '+', "0x.fp2147483647", "0x.8p2147483643", "+Inf", Above},      // exponent overflow in rounding
+
+               {4, AwayFromZero, '-', "-0x.fp2147483647", "0x.8p2147483644", "-Inf", Below},       // exponent overflow in rounding
+               {4, ToNearestEven, '-', "-0x.fp2147483647", "0x.8p2147483643", "-Inf", Below},      // exponent overflow in rounding
+               {4, ToZero, '-', "-0x.fp2147483647", "0x.8p2147483643", "-0x.fp2147483647", Above}, // rounded to zero
+
+               {4, ToNearestEven, '+', "0", "0x.8p-2147483648", "0x.8p-2147483648", Exact},
+               {4, ToNearestEven, '+', "0x.8p-2147483648", "0x.8p-2147483648", "0x.8p-2147483647", Exact},
+
+               {4, ToNearestEven, '*', "1", "0x.8p2147483647", "0x.8p2147483647", Exact},
+               {4, ToNearestEven, '*', "2", "0x.8p2147483647", "+Inf", Above},  // exponent overflow in *
+               {4, ToNearestEven, '*', "-2", "0x.8p2147483647", "-Inf", Below}, // exponent overflow in *
+
+               {4, ToNearestEven, '/', "0.5", "0x.8p2147483647", "0x.8p-2147483646", Exact},
+               {4, ToNearestEven, '/', "0x.8p0", "0x.8p2147483647", "0x.8p-2147483646", Exact},
+               {4, ToNearestEven, '/', "0x.8p-1", "0x.8p2147483647", "0x.8p-2147483647", Exact},
+               {4, ToNearestEven, '/', "0x.8p-2", "0x.8p2147483647", "0x.8p-2147483648", Exact},
+               {4, ToNearestEven, '/', "0x.8p-3", "0x.8p2147483647", "0", Below}, // exponent underflow in /
+       } {
+               x := makeFloat(test.x)
+               y := makeFloat(test.y)
+               z := new(Float).SetPrec(test.prec).SetMode(test.mode)
+               switch test.op {
+               case '+':
+                       z.Add(x, y)
+               case '-':
+                       z.Sub(x, y)
+               case '*':
+                       z.Mul(x, y)
+               case '/':
+                       z.Quo(x, y)
+               default:
+                       panic("unreachable")
+               }
+               if got := z.Format('p', 0); got != test.want || z.Acc() != test.acc {
+                       t.Errorf(
+                               "prec = %d (%s): %s %c %s = %s (%s); want %s (%s)",
+                               test.prec, test.mode, x.Format('p', 0), test.op, y.Format('p', 0), got, z.Acc(), test.want, test.acc,
+                       )
+               }
+       }
+}
+
 // TODO(gri) Add tests that check correctness in the presence of aliasing.
 
 // For rounding modes ToNegativeInf and ToPositiveInf, rounding is affected
index a3d0ff97a8d10db2e555bf048cd40ef299e3a385..f6a78b794c611f146c242ce3bd46a1719ccef89a 100644 (file)
@@ -101,8 +101,6 @@ func (z *Float) Scan(r io.ByteScanner, base int) (f *Float, b int, err error) {
        }
        // len(z.mant) > 0
 
-       z.form = finite
-
        // The mantissa may have a decimal point (fcount <= 0) and there
        // may be a nonzero exponent exp. The decimal point amounts to a
        // division by b**(-fcount). An exponent means multiplication by
@@ -142,7 +140,14 @@ func (z *Float) Scan(r io.ByteScanner, base int) (f *Float, b int, err error) {
        // we don't need exp anymore
 
        // apply 2**exp2
-       z.setExp(exp2)
+       if MinExp <= exp2 && exp2 <= MaxExp {
+               z.form = finite
+               z.exp = int32(exp2)
+       } else {
+               f = nil
+               err = fmt.Errorf("exponent overflow")
+               return
+       }
 
        if exp10 == 0 {
                // no decimal exponent to consider
@@ -160,7 +165,6 @@ func (z *Float) Scan(r io.ByteScanner, base int) (f *Float, b int, err error) {
        fpowTen := new(Float).SetInt(new(Int).SetBits(powTen))
 
        // apply 10**exp10
-       // (uquo and umul do the rounding)
        if exp10 < 0 {
                z.uquo(z, fpowTen)
        } else {