]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: sketched out complete set of Float/string conversion functions
authorRobert Griesemer <gri@golang.org>
Wed, 28 Jan 2015 00:51:59 +0000 (16:51 -0800)
committerRobert Griesemer <gri@golang.org>
Wed, 28 Jan 2015 23:22:16 +0000 (23:22 +0000)
Also:
- use io.ByteScanner rather than io.RuneScanner internally
- minor simplifications in Float.Add/Sub

Change-Id: Iae0e99384128dba9eccf68592c4fd389e2bd3b4f
Reviewed-on: https://go-review.googlesource.com/3380
Reviewed-by: Rob Pike <r@golang.org>
src/math/big/float.go
src/math/big/float_test.go
src/math/big/int.go
src/math/big/nat.go
src/math/big/rat.go

index 24fdacbe8802b0c60efe05d33e9b6fbd025adaf9..0819eca729082529f2dee00115b0d1c979a27832 100644 (file)
@@ -179,7 +179,7 @@ func (x *Float) validate() {
        const msb = 1 << (_W - 1)
        m := len(x.mant)
        if x.mant[m-1]&msb == 0 {
-               panic(fmt.Sprintf("msb not set in last word %#x of %s", x.mant[m-1], x.PString()))
+               panic(fmt.Sprintf("msb not set in last word %#x of %s", x.mant[m-1], x.pstring()))
        }
        if x.prec <= 0 {
                panic(fmt.Sprintf("invalid precision %d", x.prec))
@@ -566,10 +566,6 @@ func (z *Float) Neg(x *Float) *Float {
 // z = x + y, ignoring signs of x and y.
 // x and y must not be 0.
 func (z *Float) uadd(x, y *Float) {
-       if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
-               panic("uadd called with 0 argument")
-       }
-
        // Note: This implementation requires 2 shifts most of the
        // time. It is also inefficient if exponents or precisions
        // differ by wide margins. The following article describes
@@ -580,60 +576,61 @@ func (z *Float) uadd(x, y *Float) {
        // Point Addition With Exact Rounding (as in the MPFR Library)"
        // http://www.vinc17.net/research/papers/rnc6.pdf
 
+       if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
+               panic("uadd called with 0 argument")
+       }
+
+       // compute exponents ex, ey for mantissa with "binary point"
+       // on the right (mantissa.0) - use int64 to avoid overflow
        ex := int64(x.exp) - int64(len(x.mant))*_W
        ey := int64(y.exp) - int64(len(y.mant))*_W
 
-       var e int64
+       // TODO(gri) having a combined add-and-shift primitive
+       //           could make this code significantly faster
        switch {
        case ex < ey:
                t := z.mant.shl(y.mant, uint(ey-ex))
                z.mant = t.add(x.mant, t)
-               e = ex
        default:
-               // ex == ey
+               // ex == ey, no shift needed
                z.mant = z.mant.add(x.mant, y.mant)
-               e = ex
        case ex > ey:
                t := z.mant.shl(x.mant, uint(ex-ey))
                z.mant = t.add(t, y.mant)
-               e = ey
+               ex = ey
        }
        // len(z.mant) > 0
 
-       z.setExp(e + int64(len(z.mant))*_W - int64(fnorm(z.mant)))
+       z.setExp(ex + int64(len(z.mant))*_W - int64(fnorm(z.mant)))
        z.round(0)
 }
 
 // z = x - y for x >= y, ignoring signs of x and y.
 // x and y must not be zero.
 func (z *Float) usub(x, y *Float) {
+       // This code is symmetric to uadd.
+       // We have not factored the common code out because
+       // eventually uadd (and usub) should be optimized
+       // by special-casing, and the code will diverge.
+
        if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
                panic("usub called with 0 argument")
        }
 
-       if x.exp < y.exp {
-               panic("underflow")
-       }
-
-       // This code is symmetric to uadd.
-
        ex := int64(x.exp) - int64(len(x.mant))*_W
        ey := int64(y.exp) - int64(len(y.mant))*_W
 
-       var e int64
        switch {
        case ex < ey:
                t := z.mant.shl(y.mant, uint(ey-ex))
                z.mant = t.sub(x.mant, t)
-               e = ex
        default:
-               // ex == ey
+               // ex == ey, no shift needed
                z.mant = z.mant.sub(x.mant, y.mant)
-               e = ex
        case ex > ey:
                t := z.mant.shl(x.mant, uint(ex-ey))
                z.mant = t.sub(t, y.mant)
-               e = ey
+               ex = ey
        }
 
        // operands may have cancelled each other out
@@ -644,7 +641,7 @@ func (z *Float) usub(x, y *Float) {
        }
        // len(z.mant) > 0
 
-       z.setExp(e + int64(len(z.mant))*_W - int64(fnorm(z.mant)))
+       z.setExp(ex + int64(len(z.mant))*_W - int64(fnorm(z.mant)))
        z.round(0)
 }
 
@@ -962,13 +959,9 @@ func (x *Float) Sign() int {
        return 1
 }
 
-func (x *Float) String() string {
-       return x.PString() // TODO(gri) fix this
-}
-
-// PString returns x as a string in the format ["-"] "0." mantissa "p" exponent
+// pstring returns x as a string in the format ["-"] "0." mantissa "p" exponent
 // with a hexadecimal mantissa and a decimal exponent, or ["-"] "0" if x is zero.
-func (x *Float) PString() string {
+func (x *Float) pstring() string {
        // TODO(gri) handle Inf
        var buf bytes.Buffer
        if x.neg {
@@ -1029,7 +1022,7 @@ func (z *Float) SetString(s string) (*Float, bool) {
 // is 1.2 * 2**3. If the operation failed, the value of z is undefined but
 // the returned value is nil.
 //
-func (z *Float) scan(r io.RuneScanner) (f *Float, err error) {
+func (z *Float) scan(r io.ByteScanner) (f *Float, err error) {
        // sign
        z.neg, err = scanSign(r)
        if err != nil {
@@ -1091,3 +1084,106 @@ func (z *Float) scan(r io.RuneScanner) (f *Float, err error) {
 
        return z, nil
 }
+
+// Scan scans the number corresponding to the longest possible prefix
+// of r representing a floating-point number with a mantissa in the
+// given conversion base (the exponent is always a decimal number).
+// It returns the corresponding Float f, the actual base b, and an
+// error err, if any. The number must be of the form:
+//
+//     number   = [ prefix ] [ sign ] mantissa [ exponent ] .
+//     mantissa = digits | digits "." [ digits ] | "." digits .
+//      prefix   = prefix = "0" ( "x" | "X" | "b" | "B" ) .
+//     sign     = "+" | "-" .
+//     exponent = ( "E" | "e" | "p" ) [ sign ] digits .
+//     digits   = digit { digit } .
+//     digit    = digit  = "0" ... "9" | "a" ... "z" | "A" ... "Z" .
+//
+// The base argument must be 0 or a value between 2 and MaxBase, inclusive.
+//
+// For base 0, the number prefix determines the actual base: A prefix of
+// ``0x'' or ``0X'' selects base 16, and a ``0b'' or ``0B'' prefix selects
+// base 2; otherwise, the actual base is 10 and no prefix is permitted.
+// Note that the octal prefix ``0'' is not supported.
+//
+// A "p" exponent indicates power of 2 for the exponent; for instance "1.2p3"
+// with base 0 or 10 corresponds to the value 1.2 * 2**3.
+//
+// BUG(gri) Currently, Scan only accepts base 10.
+func (z *Float) Scan(r io.ByteScanner, base int) (f *Float, b int, err error) {
+       if base != 10 {
+               err = fmt.Errorf("base %d not supported yet", base)
+               return
+       }
+       b = 10
+       f, err = z.scan(r)
+       return
+}
+
+// Parse is like z.Scan(r, base), but instead of reading from an
+// io.ByteScanner, it parses the string s. An error is returned if the
+// string contains invalid or trailing characters not belonging to the
+// number.
+//
+// TODO(gri) define possible errors more precisely
+func (z *Float) Parse(s string, base int) (f *Float, b int, err error) {
+       r := strings.NewReader(s)
+
+       if f, b, err = z.Scan(r, base); err != nil {
+               return
+       }
+
+       // entire string must have been consumed
+       var ch byte
+       if ch, err = r.ReadByte(); err != io.EOF {
+               if err == nil {
+                       err = fmt.Errorf("expected end of string, found %q", ch)
+               }
+       }
+
+       return
+}
+
+// ScanFloat is like f.Scan(r, base) with f set to the given precision
+// and rounding mode.
+func ScanFloat(r io.ByteScanner, base int, prec uint, mode RoundingMode) (f *Float, b int, err error) {
+       return NewFloat(0, prec, mode).Scan(r, base)
+}
+
+// ParseFloat is like f.Parse(s, base) with f set to the given precision
+// and rounding mode.
+func ParseFloat(s string, base int, prec uint, mode RoundingMode) (f *Float, b int, err error) {
+       return NewFloat(0, prec, mode).Parse(s, base)
+}
+
+// Format converts the floating-point number x to a string according
+// to the given format and precision prec.
+//
+// The format is one of
+// 'e' (-d.dddde±dd, decimal exponent),
+// 'E' (-d.ddddE±dd, decimal exponent),
+// 'f' (-ddddd.dddd, no exponent),
+// 'g' ('e' for large exponents, 'f' otherwise),
+// 'G' ('E' for large exponents, 'f' otherwise),
+// 'b' (-ddddddp±dd, binary exponent), or
+// 'p' (-0.ddddp±dd, hexadecimal mantissa, binary exponent).
+//
+// The precision prec controls the number of digits (excluding the exponent)
+// printed by the 'e', 'E', 'f', 'g', and 'G' formats. For 'e', 'E', and 'f'
+// it is the number of digits after the decimal point. For 'g' and 'G' it is
+// the total number of digits. A negative precision selects the smallest
+// number of digits necessary such that ParseFloat will return f exactly.
+// The prec value is ignored for the 'b' or 'p' format.
+//
+// BUG(gri) Currently, Format only accepts the 'p' format.
+func (x *Float) Format(format byte, prec int) string {
+       if format != 'p' {
+               return fmt.Sprintf(`%c`, format)
+       }
+       return x.pstring()
+}
+
+// BUG(gri): Currently, String uses the 'p' (rather than 'g') format.
+func (x *Float) String() string {
+       return x.Format('p', 0)
+}
index 5c46e72c6f0866a4fe50a349c776602adf78d815..e8a14bf87dc4327bdb1182881e227dc9b3331ff7 100644 (file)
@@ -208,7 +208,7 @@ func TestFloatSetUint64(t *testing.T) {
        for _, want := range tests {
                f := new(Float).SetUint64(want)
                if got := f.Uint64(); got != want {
-                       t.Errorf("got %d (%s); want %d", got, f.PString(), want)
+                       t.Errorf("got %d (%s); want %d", got, f.pstring(), want)
                }
        }
 }
@@ -231,7 +231,7 @@ func TestFloatSetInt64(t *testing.T) {
                        }
                        f := new(Float).SetInt64(want)
                        if got := f.Int64(); got != want {
-                               t.Errorf("got %d (%s); want %d", got, f.PString(), want)
+                               t.Errorf("got %d (%s); want %d", got, f.pstring(), want)
                        }
                }
        }
@@ -256,7 +256,7 @@ func TestFloatSetFloat64(t *testing.T) {
                        }
                        f := new(Float).SetFloat64(want)
                        if got, _ := f.Float64(); got != want {
-                               t.Errorf("got %g (%s); want %g", got, f.PString(), want)
+                               t.Errorf("got %g (%s); want %g", got, f.pstring(), want)
                        }
                }
        }
@@ -687,7 +687,7 @@ func TestFromBits(t *testing.T) {
 
        for _, test := range tests {
                f := fromBits(test.bits...)
-               if got := f.PString(); got != test.want {
+               if got := f.pstring(); got != test.want {
                        t.Errorf("setBits(%v) = %s; want %s", test.bits, got, test.want)
                }
        }
@@ -757,7 +757,7 @@ func TestFloatSetFloat64String(t *testing.T) {
        }
 }
 
-func TestFloatPString(t *testing.T) {
+func TestFloatpstring(t *testing.T) {
        var tests = []struct {
                x    Float
                want string
@@ -768,7 +768,7 @@ func TestFloatPString(t *testing.T) {
                {Float{mant: nat{0x87654321}, exp: -10}, "0.87654321p-10"},
        }
        for _, test := range tests {
-               if got := test.x.PString(); got != test.want {
+               if got := test.x.pstring(); got != test.want {
                        t.Errorf("%v: got %s; want %s", test.x, got, test.want)
                }
        }
index e574cd08f6e288c8503ae5c7141facfafcc7a37c..716d5381d1246ce8ddbbe2f5e75f24c825012960 100644 (file)
@@ -461,7 +461,7 @@ func (x *Int) Format(s fmt.State, ch rune) {
 // ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
 // ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
 //
-func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, error) {
+func (z *Int) scan(r io.ByteScanner, base int) (*Int, int, error) {
        // determine sign
        neg, err := scanSign(r)
        if err != nil {
@@ -478,9 +478,9 @@ func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, error) {
        return z, base, nil
 }
 
-func scanSign(r io.RuneScanner) (neg bool, err error) {
-       var ch rune
-       if ch, _, err = r.ReadRune(); err != nil {
+func scanSign(r io.ByteScanner) (neg bool, err error) {
+       var ch byte
+       if ch, err = r.ReadByte(); err != nil {
                return false, err
        }
        switch ch {
@@ -489,11 +489,29 @@ func scanSign(r io.RuneScanner) (neg bool, err error) {
        case '+':
                // nothing to do
        default:
-               r.UnreadRune()
+               r.UnreadByte()
        }
        return
 }
 
+// byteReader is a local wrapper around fmt.ScanState;
+// it implements the ByteReader interface.
+type byteReader struct {
+       fmt.ScanState
+}
+
+func (r byteReader) ReadByte() (byte, error) {
+       ch, size, err := r.ReadRune()
+       if size != 1 && err == nil {
+               err = fmt.Errorf("invalid rune %#U", ch)
+       }
+       return byte(ch), err
+}
+
+func (r byteReader) UnreadByte() error {
+       return r.UnreadRune()
+}
+
 // Scan is a support routine for fmt.Scanner; it sets z to the value of
 // the scanned number. It accepts the formats 'b' (binary), 'o' (octal),
 // 'd' (decimal), 'x' (lowercase hexadecimal), and 'X' (uppercase hexadecimal).
@@ -514,7 +532,7 @@ func (z *Int) Scan(s fmt.ScanState, ch rune) error {
        default:
                return errors.New("Int.Scan: invalid verb")
        }
-       _, _, err := z.scan(s, base)
+       _, _, err := z.scan(byteReader{s}, base)
        return err
 }
 
@@ -569,7 +587,7 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
        if err != nil {
                return nil, false
        }
-       _, _, err = r.ReadRune()
+       _, err = r.ReadByte()
        if err != io.EOF {
                return nil, false
        }
index 6ef376c668dc99e4d211b3f6972586ed5415c234..e87c71101c1966318b790cce74593e87d57bcb7e 100644 (file)
@@ -668,7 +668,7 @@ func pow(x Word, n int) (p Word) {
 // base == 1, only), and the number of fractional digits is -count. In this
 // case, the value of the scanned number is res * 10**count.
 //
-func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error) {
+func (z nat) scan(r io.ByteScanner, base int) (res nat, b, count int, err error) {
        // reject illegal bases
        if base < 0 || base > MaxBase {
                err = errors.New("illegal number base")
@@ -676,7 +676,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
        }
 
        // one char look-ahead
-       ch, _, err := r.ReadRune()
+       ch, err := r.ReadByte()
        if err != nil {
                return
        }
@@ -687,7 +687,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
                // actual base is 10 unless there's a base prefix
                b = 10
                if ch == '0' {
-                       switch ch, _, err = r.ReadRune(); err {
+                       switch ch, err = r.ReadByte(); err {
                        case nil:
                                // possibly one of 0x, 0X, 0b, 0B
                                b = 8
@@ -698,7 +698,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
                                        b = 2
                                }
                                if b == 2 || b == 16 {
-                                       if ch, _, err = r.ReadRune(); err != nil {
+                                       if ch, err = r.ReadByte(); err != nil {
                                                // io.EOF is also an error in this case
                                                return
                                        }
@@ -736,7 +736,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
                        base = 10 // no 2nd decimal point permitted
                        dp = count
                        // advance
-                       if ch, _, err = r.ReadRune(); err != nil {
+                       if ch, err = r.ReadByte(); err != nil {
                                if err == io.EOF {
                                        err = nil
                                        break
@@ -758,7 +758,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
                        d1 = MaxBase + 1
                }
                if d1 >= b1 {
-                       r.UnreadRune() // ch does not belong to number anymore
+                       r.UnreadByte() // ch does not belong to number anymore
                        break
                }
                count++
@@ -775,7 +775,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
                }
 
                // advance
-               if ch, _, err = r.ReadRune(); err != nil {
+               if ch, err = r.ReadByte(); err != nil {
                        if err == io.EOF {
                                err = nil
                                break
index dec310064bb37e5bb068bf7213b7899e0870a840..bd7ec73817111573e4bee51f8ee39f013c10ef2e 100644 (file)
@@ -585,7 +585,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
        }
 
        // there should be no unread characters left
-       if _, _, err = r.ReadRune(); err != io.EOF {
+       if _, err = r.ReadByte(); err != io.EOF {
                return nil, false
        }
 
@@ -615,11 +615,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
        return z, true
 }
 
-func scanExponent(r io.RuneScanner) (exp int64, base int, err error) {
+func scanExponent(r io.ByteScanner) (exp int64, base int, err error) {
        base = 10
 
-       var ch rune
-       if ch, _, err = r.ReadRune(); err != nil {
+       var ch byte
+       if ch, err = r.ReadByte(); err != nil {
                if err == io.EOF {
                        err = nil // no exponent; same as e0
                }
@@ -632,7 +632,7 @@ func scanExponent(r io.RuneScanner) (exp int64, base int, err error) {
        case 'p':
                base = 2
        default:
-               r.UnreadRune()
+               r.UnreadByte()
                return // no exponent; same as e0
        }
 
@@ -650,7 +650,7 @@ func scanExponent(r io.RuneScanner) (exp int64, base int, err error) {
        // since we only care about int64 values - the
        // from-scratch scan is easy enough and faster
        for i := 0; ; i++ {
-               if ch, _, err = r.ReadRune(); err != nil {
+               if ch, err = r.ReadByte(); err != nil {
                        if err != io.EOF || i == 0 {
                                return
                        }
@@ -659,7 +659,7 @@ func scanExponent(r io.RuneScanner) (exp int64, base int, err error) {
                }
                if ch < '0' || '9' < ch {
                        if i == 0 {
-                               r.UnreadRune()
+                               r.UnreadByte()
                                err = fmt.Errorf("invalid exponent (missing digits)")
                                return
                        }