]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: implemented decimal rounding for Float-to-string conversion
authorRobert Griesemer <gri@golang.org>
Tue, 3 Feb 2015 22:29:03 +0000 (14:29 -0800)
committerRobert Griesemer <gri@golang.org>
Wed, 4 Feb 2015 21:06:37 +0000 (21:06 +0000)
Change-Id: Id508ca2f6c087861e8c6bc536bc39e54dce09825
Reviewed-on: https://go-review.googlesource.com/3840
Reviewed-by: Alan Donovan <adonovan@google.com>
src/math/big/decimal.go
src/math/big/decimal_test.go

index f4c535acdbc61a0993943e33afd9acd159a71e18..670668baaf8b294f9923be17cdfb0d2c1509f0b2 100644 (file)
@@ -71,19 +71,18 @@ func (x *decimal) init(m nat, shift int) {
        x.exp = n
        // Trim trailing zeros; instead the exponent is tracking
        // the decimal point independent of the number of digits.
-       for n > 0 && s[n-1] == 0 {
+       for n > 0 && s[n-1] == '0' {
                n--
        }
-       x.mant = make([]byte, n)
-       copy(x.mant, s)
+       x.mant = append(x.mant[:0], s[:n]...)
 
        // Do any (remaining) shift right in decimal representation.
        if shift < 0 {
                for shift < -maxShift {
-                       x.shr(maxShift)
+                       shr(x, maxShift)
                        shift += maxShift
                }
-               x.shr(uint(-shift))
+               shr(x, uint(-shift))
        }
 }
 
@@ -94,7 +93,7 @@ func (x *decimal) init(m nat, shift int) {
 // single +'0' pass at the end).
 
 // shr implements x >> s, for s <= maxShift.
-func (x *decimal) shr(s uint) {
+func shr(x *decimal, s uint) {
        // Division by 1<<s using shift-and-subtract algorithm.
 
        // pick up enough leading digits to cover first shift
@@ -146,12 +145,7 @@ func (x *decimal) shr(s uint) {
                n = n * 10
        }
 
-       // remove trailing zeros
-       w = len(x.mant)
-       for w > 0 && x.mant[w-1] == '0' {
-               w--
-       }
-       x.mant = x.mant[:w]
+       trim(x)
 }
 
 func (x *decimal) String() string {
@@ -189,3 +183,73 @@ func appendZeros(buf []byte, n int) []byte {
        }
        return buf
 }
+
+// shouldRoundUp reports if x should be rounded up
+// if shortened to n digits. n must be a valid index
+// for x.mant.
+func shouldRoundUp(x *decimal, n int) bool {
+       if x.mant[n] == '5' && n+1 == len(x.mant) {
+               // exactly halfway - round to even
+               return n > 0 && (x.mant[n-1]-'0')&1 != 0
+       }
+       // not halfway - digit tells all (x.mant has no trailing zeros)
+       return x.mant[n] >= '5'
+}
+
+// round sets x to (at most) n mantissa digits by rounding it
+// to the nearest even value with n (or fever) mantissa digits.
+// If n < 0, x remains unchanged.
+func (x *decimal) round(n int) {
+       if n < 0 || n >= len(x.mant) {
+               return // nothing to do
+       }
+
+       if shouldRoundUp(x, n) {
+               x.roundUp(n)
+       } else {
+               x.roundDown(n)
+       }
+}
+
+func (x *decimal) roundUp(n int) {
+       if n < 0 || n >= len(x.mant) {
+               return // nothing to do
+       }
+       // 0 <= n < len(x.mant)
+
+       // find first digit < '9'
+       for n > 0 && x.mant[n-1] >= '9' {
+               n--
+       }
+
+       if n == 0 {
+               // all digits are '9's => round up to '1' and update exponent
+               x.mant[0] = '1' // ok since len(x.mant) > n
+               x.mant = x.mant[:1]
+               x.exp++
+               return
+       }
+
+       // n > 0 && x.mant[n-1] < '9'
+       x.mant[n-1]++
+       x.mant = x.mant[:n]
+       // x already trimmed
+}
+
+func (x *decimal) roundDown(n int) {
+       if n < 0 || n >= len(x.mant) {
+               return // nothing to do
+       }
+       x.mant = x.mant[:n]
+       trim(x)
+}
+
+// trim cuts off any trailing zeros from x's mantissa;
+// they are meaningless for the value of x.
+func trim(x *decimal) {
+       i := len(x.mant)
+       for i > 0 && x.mant[i-1] == '0' {
+               i--
+       }
+       x.mant = x.mant[:i]
+}
index ce20800ef06ec041a3c3810ed7b7a2ba43453957..81e022a47dd60848e06cf8b802ce7e8979520f9e 100644 (file)
@@ -49,3 +49,58 @@ func TestDecimalInit(t *testing.T) {
                }
        }
 }
+
+func TestDecimalRounding(t *testing.T) {
+       for _, test := range []struct {
+               x              uint64
+               n              int
+               down, even, up string
+       }{
+               {0, 0, "0", "0", "0"},
+               {0, 1, "0", "0", "0"},
+
+               {1, 0, "0", "0", "10"},
+               {5, 0, "0", "0", "10"},
+               {9, 0, "0", "10", "10"},
+
+               {15, 1, "10", "20", "20"},
+               {45, 1, "40", "40", "50"},
+               {95, 1, "90", "100", "100"},
+
+               {12344999, 4, "12340000", "12340000", "12350000"},
+               {12345000, 4, "12340000", "12340000", "12350000"},
+               {12345001, 4, "12340000", "12350000", "12350000"},
+               {23454999, 4, "23450000", "23450000", "23460000"},
+               {23455000, 4, "23450000", "23460000", "23460000"},
+               {23455001, 4, "23450000", "23460000", "23460000"},
+
+               {99994999, 4, "99990000", "99990000", "100000000"},
+               {99995000, 4, "99990000", "100000000", "100000000"},
+               {99999999, 4, "99990000", "100000000", "100000000"},
+
+               {12994999, 4, "12990000", "12990000", "13000000"},
+               {12995000, 4, "12990000", "13000000", "13000000"},
+               {12999999, 4, "12990000", "13000000", "13000000"},
+       } {
+               x := nat(nil).setUint64(test.x)
+
+               var d decimal
+               d.init(x, 0)
+               d.roundDown(test.n)
+               if got := d.String(); got != test.down {
+                       t.Errorf("roundDown(%d, %d) = %s; want %s", test.x, test.n, got, test.down)
+               }
+
+               d.init(x, 0)
+               d.round(test.n)
+               if got := d.String(); got != test.even {
+                       t.Errorf("round(%d, %d) = %s; want %s", test.x, test.n, got, test.even)
+               }
+
+               d.init(x, 0)
+               d.roundUp(test.n)
+               if got := d.String(); got != test.up {
+                       t.Errorf("roundUp(%d, %d) = %s; want %s", test.x, test.n, got, test.up)
+               }
+       }
+}