From 00c73f5c6e4b80a24eb19218c006c8a3f08e1ed8 Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Thu, 5 Mar 2015 17:32:57 -0800 Subject: [PATCH] math/big: cleaner handling of exponent under/overflow Fixed several corner-case bugs and added corresponding tests. Change-Id: I23096b9caeeff0956f65ab59fa91e168d0e47bb8 Reviewed-on: https://go-review.googlesource.com/7001 Reviewed-by: Alan Donovan --- src/math/big/bits_test.go | 7 +- src/math/big/float.go | 202 ++++++++++++++++++------------------- src/math/big/float_test.go | 62 ++++++++++++ src/math/big/floatconv.go | 12 ++- 4 files changed, 176 insertions(+), 107 deletions(-) diff --git a/src/math/big/bits_test.go b/src/math/big/bits_test.go index 761f75628f..3ce24222d7 100644 --- a/src/math/big/bits_test.go +++ b/src/math/big/bits_test.go @@ -187,7 +187,12 @@ func (bits Bits) Float() *Float { // create corresponding float z := new(Float).SetInt(x) // normalized - z.setExp(int64(z.exp) + int64(min)) + if e := int64(z.exp) + int64(min); MinExp <= e && e <= MaxExp { + z.exp = int32(e) + } else { + // this should never happen for our test cases + panic("exponent out of range") + } return z } diff --git a/src/math/big/float.go b/src/math/big/float.go index d716c8ca59..a86471e2a5 100644 --- a/src/math/big/float.go +++ b/src/math/big/float.go @@ -154,7 +154,7 @@ func (z *Float) SetPrec(prec uint) *Float { z.prec = 0 if z.form == finite { // truncate z to 0 - z.acc = z.cmpZero() + z.acc = makeAcc(z.neg) z.form = zero } return z @@ -172,8 +172,8 @@ func (z *Float) SetPrec(prec uint) *Float { return z } -func (x *Float) cmpZero() Accuracy { - if x.neg { +func makeAcc(above bool) Accuracy { + if above { return Above } return Below @@ -265,22 +265,24 @@ func (x *Float) MantExp(mant *Float) (exp int) { return } -// setExp sets the exponent for z. -// If e < MinExp, z becomes ±0; if e > MaxExp, z becomes ±Inf. -func (z *Float) setExp(e int64) { - if debugFloat && z.form != finite { - panic("setExp called for non-finite Float") +func (z *Float) setExpAndRound(exp int64, sbit uint) { + if exp < MinExp { + // underflow + z.acc = makeAcc(z.neg) + z.form = zero + return } - switch { - case e < MinExp: - // TODO(gri) check that accuracy is adjusted if necessary - z.form = zero // underflow - default: - z.exp = int32(e) - case e > MaxExp: - // TODO(gri) check that accuracy is adjusted if necessary - z.form = inf // overflow + + if exp > MaxExp { + // overflow + z.acc = makeAcc(!z.neg) + z.form = inf + return } + + z.form = finite + z.exp = int32(exp) + z.round(sbit) } // SetMantExp sets z to mant × 2**exp and and returns z. @@ -308,7 +310,7 @@ func (z *Float) SetMantExp(mant *Float, exp int) *Float { if z.form != finite { return z } - z.setExp(int64(z.exp) + int64(exp)) + z.setExpAndRound(int64(z.exp)+int64(exp), 0) return z } @@ -368,14 +370,14 @@ func (x *Float) validate() { } m := len(x.mant) if m == 0 { - panic("nonzero finite x with empty mantissa") + panic("nonzero finite number with empty mantissa") } const msb = 1 << (_W - 1) if x.mant[m-1]&msb == 0 { panic(fmt.Sprintf("msb not set in last word %#x of %s", x.mant[m-1], x.Format('p', 0))) } - if x.prec <= 0 { - panic(fmt.Sprintf("invalid precision %d", x.prec)) + if x.prec == 0 { + panic("zero precision finite number") } } @@ -507,7 +509,14 @@ func (z *Float) round(sbit uint) { shrVU(z.mant, z.mant, 1) z.mant[n-1] |= 1 << (_W - 1) // adjust exponent - z.exp++ + if z.exp < MaxExp { + z.exp++ + } else { + // exponent overflow + z.acc = makeAcc(!z.neg) + z.form = inf + return + } } z.acc = Above } @@ -515,8 +524,6 @@ func (z *Float) round(sbit uint) { // zero out trailing bits in least-significant word z.mant[0] &^= lsb - 1 - // TODO(gri) can z.mant be all 0s at this point? - // update accuracy if z.acc != Exact && z.neg { z.acc ^= Below | Above @@ -655,13 +662,9 @@ func (z *Float) SetInt(x *Int) *Float { return z } // x != 0 - z.form = finite z.mant = z.mant.set(x.abs) fnorm(z.mant) - z.setExp(int64(bits)) - if z.prec < bits { - z.round(0) - } + z.setExpAndRound(int64(bits), 0) return z } @@ -692,7 +695,7 @@ func (z *Float) SetInf(sign int) *Float { } // SetNaN sets z to a NaN value, and returns z. -// The precision of z is unchanged and the result is always Undef. +// The precision of z is unchanged and the result accuracy is always Undef. func (z *Float) SetNaN() *Float { z.acc = Undef z.form = nan @@ -711,14 +714,15 @@ func (z *Float) Set(x *Float) *Float { } z.acc = Exact if z != x { - if z.prec == 0 { - z.prec = x.prec - } z.form = x.form z.neg = x.neg - z.exp = x.exp - z.mant = z.mant.set(x.mant) - if z.prec < x.prec { + if x.form == finite { + z.exp = x.exp + z.mant = z.mant.set(x.mant) + } + if z.prec == 0 { + z.prec = x.prec + } else if z.prec < x.prec { z.round(0) } } @@ -738,8 +742,10 @@ func (z *Float) Copy(x *Float) *Float { z.acc = x.acc z.form = x.form z.neg = x.neg - z.mant = z.mant.set(x.mant) - z.exp = x.exp + if z.form == finite { + z.mant = z.mant.set(x.mant) + z.exp = x.exp + } } return z } @@ -821,7 +827,7 @@ func (x *Float) Int64() (int64, Accuracy) { switch x.form { case finite: // 0 < |x| < +Inf - acc := x.cmpZero() + acc := makeAcc(x.neg) if x.exp <= 0 { // 0 < |x| < 1 return 0, acc @@ -927,7 +933,7 @@ func (x *Float) Int(z *Int) (*Int, Accuracy) { switch x.form { case finite: // 0 < |x| < +Inf - acc := x.cmpZero() + acc := makeAcc(x.neg) if x.exp <= 0 { // 0 < |x| < 1 return z.SetInt64(0), acc @@ -960,7 +966,7 @@ func (x *Float) Int(z *Int) (*Int, Accuracy) { return z.SetInt64(0), Exact case inf: - return nil, x.cmpZero() + return nil, makeAcc(x.neg) case nan: return nil, Undef @@ -1010,7 +1016,7 @@ func (x *Float) Rat(z *Rat) (*Rat, Accuracy) { return z.SetInt64(0), Exact case inf: - return nil, x.cmpZero() + return nil, makeAcc(x.neg) case nan: return nil, Undef @@ -1035,8 +1041,22 @@ func (z *Float) Neg(x *Float) *Float { return z } -// z = x + y, ignoring signs of x and y. -// x.form and y.form must be finite. +func validateBinaryOperands(x, y *Float) { + if !debugFloat { + // avoid performance bugs + panic("validateBinaryOperands called but debugFloat is not set") + } + if len(x.mant) == 0 { + panic("empty mantissa for x") + } + if len(y.mant) == 0 { + panic("empty mantissa for y") + } +} + +// z = x + y, ignoring signs of x and y for the addition +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. func (z *Float) uadd(x, y *Float) { // Note: This implementation requires 2 shifts most of the // time. It is also inefficient if exponents or precisions @@ -1048,8 +1068,8 @@ func (z *Float) uadd(x, y *Float) { // Point Addition With Exact Rounding (as in the MPFR Library)" // http://www.vinc17.net/research/papers/rnc6.pdf - if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) { - panic("uadd called with empty mantissa") + if debugFloat { + validateBinaryOperands(x, y) } // compute exponents ex, ey for mantissa with "binary point" @@ -1075,20 +1095,20 @@ func (z *Float) uadd(x, y *Float) { } // len(z.mant) > 0 - z.setExp(ex + int64(len(z.mant))*_W - fnorm(z.mant)) - z.round(0) + z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0) } -// z = x - y for x >= y, ignoring signs of x and y. -// x.form and y.form must be finite. +// z = x - y for |x| > |y|, ignoring signs of x and y for the subtraction +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. func (z *Float) usub(x, y *Float) { // This code is symmetric to uadd. // We have not factored the common code out because // eventually uadd (and usub) should be optimized // by special-casing, and the code will diverge. - if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) { - panic("usub called with empty mantissa") + if debugFloat { + validateBinaryOperands(x, y) } ex := int64(x.exp) - int64(len(x.mant))*_W @@ -1113,19 +1133,20 @@ func (z *Float) usub(x, y *Float) { if len(z.mant) == 0 { z.acc = Exact z.form = zero + z.neg = false return } // len(z.mant) > 0 - z.setExp(ex + int64(len(z.mant))*_W - fnorm(z.mant)) - z.round(0) + z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0) } -// z = x * y, ignoring signs of x and y. -// x.form and y.form must be finite. +// z = x * y, ignoring signs of x and y for the multiplication +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. func (z *Float) umul(x, y *Float) { - if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) { - panic("umul called with empty mantissa") + if debugFloat { + validateBinaryOperands(x, y) } // Note: This is doing too much work if the precision @@ -1137,16 +1158,15 @@ func (z *Float) umul(x, y *Float) { e := int64(x.exp) + int64(y.exp) z.mant = z.mant.mul(x.mant, y.mant) - // normalize mantissa - z.setExp(e - fnorm(z.mant)) - z.round(0) + z.setExpAndRound(e-fnorm(z.mant), 0) } -// z = x / y, ignoring signs of x and y. -// x.form and y.form must be finite. +// z = x / y, ignoring signs of x and y for the division +// but using the sign of z for rounding the result. +// x and y must have a non-empty mantissa and valid exponent. func (z *Float) uquo(x, y *Float) { - if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) { - panic("uquo called with empty mantissa") + if debugFloat { + validateBinaryOperands(x, y) } // mantissa length in words for desired result precision + 1 @@ -1172,13 +1192,8 @@ func (z *Float) uquo(x, y *Float) { // divide var r nat z.mant, r = z.mant.div(nil, xadj, y.mant) - - // determine exponent e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W - // normalize mantissa - z.setExp(e - fnorm(z.mant)) - // The result is long enough to include (at least) the rounding bit. // If there's a non-zero remainder, the corresponding fractional part // (if it were computed), would have a non-zero sticky bit (if it were @@ -1187,15 +1202,16 @@ func (z *Float) uquo(x, y *Float) { if len(r) > 0 { sbit = 1 } - z.round(sbit) + + z.setExpAndRound(e-fnorm(z.mant), sbit) } // ucmp returns Below, Exact, or Above, depending -// on whether x < y, x == y, or x > y. -// x.form and y.form must be finite. +// on whether |x| < |y|, |x| == |y|, or |x| > |y|. +// x and y must have a non-empty mantissa and valid exponent. func (x *Float) ucmp(y *Float) Accuracy { - if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) { - panic("ucmp called with empty mantissa") + if debugFloat { + validateBinaryOperands(x, y) } switch { @@ -1284,7 +1300,6 @@ func (z *Float) Add(x, y *Float) *Float { } // x, y != 0 - z.form = finite z.neg = x.neg if x.neg == y.neg { // x + y == x + y @@ -1301,11 +1316,6 @@ func (z *Float) Add(x, y *Float) *Float { } } - // -0 is only possible for -0 + -0 - if z.form == zero { - z.neg = false - } - return z } @@ -1340,7 +1350,6 @@ func (z *Float) Sub(x, y *Float) *Float { } // x, y != 0 - z.form = finite z.neg = x.neg if x.neg != y.neg { // x - (-y) == x + y @@ -1357,11 +1366,6 @@ func (z *Float) Sub(x, y *Float) *Float { } } - // -0 is only possible for -0 - 0 - if z.form == zero { - z.neg = false - } - return z } @@ -1392,15 +1396,9 @@ func (z *Float) Mul(x, y *Float) *Float { return z } - if x.form == zero || y.form == zero { - z.acc = Exact - z.form = zero - return z - } - // x, y != 0 - z.form = finite z.umul(x, y) + return z } @@ -1426,6 +1424,7 @@ func (z *Float) Quo(x, y *Float) *Float { // TODO(gri) handle Inf separately return z.SetNaN() } + // x == ±0 || y == ±0 if x.form == zero { if y.form == zero { return z.SetNaN() @@ -1433,16 +1432,14 @@ func (z *Float) Quo(x, y *Float) *Float { z.form = zero return z } - // x != 0 - if y.form == zero { - z.form = inf - return z - } + // y == ±0 + z.form = inf + return z } // x, y != 0 - z.form = finite z.uquo(x, y) + return z } @@ -1505,6 +1502,7 @@ func (res cmpResult) Geq() bool { return res.acc&Below == 0 } // +1 if 0 < x < +Inf // +2 if x == +Inf // +// x must not be NaN. func (x *Float) ord() int { var m int switch x.form { @@ -1514,8 +1512,8 @@ func (x *Float) ord() int { return 0 case inf: m = 2 - case nan: - panic("unimplemented") + default: + panic("unreachable") } if x.neg { m = -m diff --git a/src/math/big/float_test.go b/src/math/big/float_test.go index 683809bf56..379352c886 100644 --- a/src/math/big/float_test.go +++ b/src/math/big/float_test.go @@ -1389,6 +1389,68 @@ func TestFloatArithmeticSpecialValues(t *testing.T) { } } +func TestFloatArithmeticOverflow(t *testing.T) { + for _, test := range []struct { + prec uint + mode RoundingMode + op byte + x, y, want string + acc Accuracy + }{ + {4, ToNearestEven, '+', "0", "0", "0", Exact}, // smoke test + {4, ToNearestEven, '+', "0x.8p0", "0x.8p0", "0x.8p1", Exact}, // smoke test + + {4, ToNearestEven, '+', "0", "0x.8p2147483647", "0x.8p2147483647", Exact}, + {4, ToNearestEven, '+', "0x.8p2147483500", "0x.8p2147483647", "0x.8p2147483647", Below}, // rounded to zero + {4, ToNearestEven, '+', "0x.8p2147483647", "0x.8p2147483647", "+Inf", Above}, // exponent overflow in + + {4, ToNearestEven, '+', "-0x.8p2147483647", "-0x.8p2147483647", "-Inf", Below}, // exponent overflow in + + {4, ToNearestEven, '-', "-0x.8p2147483647", "0x.8p2147483647", "-Inf", Below}, // exponent overflow in - + + {4, ToZero, '+', "0x.fp2147483647", "0x.8p2147483643", "0x.fp2147483647", Below}, // rounded to zero + {4, ToNearestEven, '+', "0x.fp2147483647", "0x.8p2147483643", "+Inf", Above}, // exponent overflow in rounding + {4, AwayFromZero, '+', "0x.fp2147483647", "0x.8p2147483643", "+Inf", Above}, // exponent overflow in rounding + + {4, AwayFromZero, '-', "-0x.fp2147483647", "0x.8p2147483644", "-Inf", Below}, // exponent overflow in rounding + {4, ToNearestEven, '-', "-0x.fp2147483647", "0x.8p2147483643", "-Inf", Below}, // exponent overflow in rounding + {4, ToZero, '-', "-0x.fp2147483647", "0x.8p2147483643", "-0x.fp2147483647", Above}, // rounded to zero + + {4, ToNearestEven, '+', "0", "0x.8p-2147483648", "0x.8p-2147483648", Exact}, + {4, ToNearestEven, '+', "0x.8p-2147483648", "0x.8p-2147483648", "0x.8p-2147483647", Exact}, + + {4, ToNearestEven, '*', "1", "0x.8p2147483647", "0x.8p2147483647", Exact}, + {4, ToNearestEven, '*', "2", "0x.8p2147483647", "+Inf", Above}, // exponent overflow in * + {4, ToNearestEven, '*', "-2", "0x.8p2147483647", "-Inf", Below}, // exponent overflow in * + + {4, ToNearestEven, '/', "0.5", "0x.8p2147483647", "0x.8p-2147483646", Exact}, + {4, ToNearestEven, '/', "0x.8p0", "0x.8p2147483647", "0x.8p-2147483646", Exact}, + {4, ToNearestEven, '/', "0x.8p-1", "0x.8p2147483647", "0x.8p-2147483647", Exact}, + {4, ToNearestEven, '/', "0x.8p-2", "0x.8p2147483647", "0x.8p-2147483648", Exact}, + {4, ToNearestEven, '/', "0x.8p-3", "0x.8p2147483647", "0", Below}, // exponent underflow in / + } { + x := makeFloat(test.x) + y := makeFloat(test.y) + z := new(Float).SetPrec(test.prec).SetMode(test.mode) + switch test.op { + case '+': + z.Add(x, y) + case '-': + z.Sub(x, y) + case '*': + z.Mul(x, y) + case '/': + z.Quo(x, y) + default: + panic("unreachable") + } + if got := z.Format('p', 0); got != test.want || z.Acc() != test.acc { + t.Errorf( + "prec = %d (%s): %s %c %s = %s (%s); want %s (%s)", + test.prec, test.mode, x.Format('p', 0), test.op, y.Format('p', 0), got, z.Acc(), test.want, test.acc, + ) + } + } +} + // TODO(gri) Add tests that check correctness in the presence of aliasing. // For rounding modes ToNegativeInf and ToPositiveInf, rounding is affected diff --git a/src/math/big/floatconv.go b/src/math/big/floatconv.go index a3d0ff97a8..f6a78b794c 100644 --- a/src/math/big/floatconv.go +++ b/src/math/big/floatconv.go @@ -101,8 +101,6 @@ func (z *Float) Scan(r io.ByteScanner, base int) (f *Float, b int, err error) { } // len(z.mant) > 0 - z.form = finite - // The mantissa may have a decimal point (fcount <= 0) and there // may be a nonzero exponent exp. The decimal point amounts to a // division by b**(-fcount). An exponent means multiplication by @@ -142,7 +140,14 @@ func (z *Float) Scan(r io.ByteScanner, base int) (f *Float, b int, err error) { // we don't need exp anymore // apply 2**exp2 - z.setExp(exp2) + if MinExp <= exp2 && exp2 <= MaxExp { + z.form = finite + z.exp = int32(exp2) + } else { + f = nil + err = fmt.Errorf("exponent overflow") + return + } if exp10 == 0 { // no decimal exponent to consider @@ -160,7 +165,6 @@ func (z *Float) Scan(r io.ByteScanner, base int) (f *Float, b int, err error) { fpowTen := new(Float).SetInt(new(Int).SetBits(powTen)) // apply 10**exp10 - // (uquo and umul do the rounding) if exp10 < 0 { z.uquo(z, fpowTen) } else { -- 2.48.1