]> Cypherpunks repositories - gostls13.git/commitdiff
math/big: parsing of fractions and floats in mantissa bases other than 10
authorRobert Griesemer <gri@golang.org>
Thu, 29 Jan 2015 01:11:15 +0000 (17:11 -0800)
committerRobert Griesemer <gri@golang.org>
Fri, 30 Jan 2015 21:51:18 +0000 (21:51 +0000)
Change-Id: I1eaebf956a69e0958201cc5e0a9beefa062c71e1
Reviewed-on: https://go-review.googlesource.com/3454
Reviewed-by: Alan Donovan <adonovan@google.com>
src/math/big/float.go
src/math/big/float_test.go
src/math/big/int.go
src/math/big/nat.go
src/math/big/nat_test.go
src/math/big/rat.go
src/math/big/rat_test.go

index 0819eca729082529f2dee00115b0d1c979a27832..80b560f9b26a1370e85954e1ba82488f56639713 100644 (file)
@@ -405,13 +405,12 @@ func (z *Float) SetInt64(x int64) *Float {
 // TODO(gri) test denormals, +/-Inf, disallow NaN.
 func (z *Float) SetFloat64(x float64) *Float {
        z.prec = 53
+       z.neg = math.Signbit(x) // handle -0 correctly (-0 == 0)
        if x == 0 {
-               z.neg = false
                z.mant = z.mant[:0]
                z.exp = 0
                return z
        }
-       z.neg = x < 0
        fmant, exp := math.Frexp(x) // get normalized mantissa
        z.mant = z.mant.setUint64(1<<63 | math.Float64bits(fmant)<<11)
        z.exp = int32(exp)
@@ -473,15 +472,16 @@ func (z *Float) Set(x *Float) *Float {
 }
 
 func high64(x nat) uint64 {
-       i := len(x) - 1
-       if i < 0 {
+       i := len(x)
+       if i == 0 {
                return 0
        }
-       v := uint64(x[i])
+       // i > 0
+       v := uint64(x[i-1])
        if _W == 32 {
                v <<= 32
-               if i > 0 {
-                       v |= uint64(x[i-1])
+               if i > 1 {
+                       v |= uint64(x[i-2])
                }
        }
        return v
@@ -959,42 +959,13 @@ func (x *Float) Sign() int {
        return 1
 }
 
-// pstring returns x as a string in the format ["-"] "0." mantissa "p" exponent
-// with a hexadecimal mantissa and a decimal exponent, or ["-"] "0" if x is zero.
-func (x *Float) pstring() string {
-       // TODO(gri) handle Inf
-       var buf bytes.Buffer
-       if x.neg {
-               buf.WriteByte('-')
-       }
-       buf.WriteByte('0')
-       if len(x.mant) > 0 {
-               // non-zero value
-               buf.WriteByte('.')
-               buf.WriteString(strings.TrimRight(x.mant.string(lowercaseDigits[:16]), "0"))
-               fmt.Fprintf(&buf, "p%d", x.exp)
-       }
-       return buf.String()
-}
-
 // SetString sets z to the value of s and returns z and a boolean indicating
-// success. s must be a floating-point number of the form:
-//
-//     number   = [ sign ] mantissa [ exponent ] .
-//     mantissa = digits | digits "." [ digits ] | "." digits .
-//     exponent = ( "E" | "e" | "p" ) [ sign ] digits .
-//     sign     = "+" | "-" .
-//     digits   = digit { digit } .
-//     digit    = "0" ... "9" .
-//
-// A "p" exponent indicates power of 2 for the exponent; for instance 1.2p3
-// is 1.2 * 2**3. If the operation failed, the value of z is undefined but
-// the returned value is nil.
-//
+// success. s must be a floating-point number of the same format as accepted
+// by Scan, with number prefixes permitted.
 func (z *Float) SetString(s string) (*Float, bool) {
        r := strings.NewReader(s)
 
-       f, err := z.scan(r)
+       f, _, err := z.Scan(r, 0)
        if err != nil {
                return nil, false
        }
@@ -1007,22 +978,32 @@ func (z *Float) SetString(s string) (*Float, bool) {
        return f, true
 }
 
-// scan sets z to the value of the longest prefix of r representing
-// a floating-point number and returns z or an error, if any.
-// The number must be of the form:
+// Scan scans the number corresponding to the longest possible prefix
+// of r representing a floating-point number with a mantissa in the
+// given conversion base (the exponent is always a decimal number).
+// It returns the corresponding Float f, the actual base b, and an
+// error err, if any. The number must be of the form:
 //
-//     number   = [ sign ] mantissa [ exponent ] .
+//     number   = [ sign ] [ prefix ] mantissa [ exponent ] .
+//     sign     = "+" | "-" .
+//      prefix   = "0" ( "x" | "X" | "b" | "B" ) .
 //     mantissa = digits | digits "." [ digits ] | "." digits .
 //     exponent = ( "E" | "e" | "p" ) [ sign ] digits .
-//     sign     = "+" | "-" .
 //     digits   = digit { digit } .
-//     digit    = "0" ... "9" .
+//     digit    = "0" ... "9" | "a" ... "z" | "A" ... "Z" .
+//
+// The base argument must be 0 or a value between 2 through MaxBase.
+//
+// For base 0, the number prefix determines the actual base: A prefix of
+// ``0x'' or ``0X'' selects base 16, and a ``0b'' or ``0B'' prefix selects
+// base 2; otherwise, the actual base is 10 and no prefix is permitted.
+// The octal prefix ``0'' is not supported.
 //
-// A "p" exponent indicates power of 2 for the exponent; for instance 1.2p3
-// is 1.2 * 2**3. If the operation failed, the value of z is undefined but
-// the returned value is nil.
+// A "p" exponent indicates power of 2 for the exponent; for instance "1.2p3"
+// with base 0 or 10 corresponds to the value 1.2 * 2**3.
 //
-func (z *Float) scan(r io.ByteScanner) (f *Float, err error) {
+// BUG(gri) This signature conflicts with Scan(s fmt.ScanState, ch rune) error.
+func (z *Float) Scan(r io.ByteScanner, base int) (f *Float, b int, err error) {
        // sign
        z.neg, err = scanSign(r)
        if err != nil {
@@ -1031,7 +1012,7 @@ func (z *Float) scan(r io.ByteScanner) (f *Float, err error) {
 
        // mantissa
        var ecorr int // decimal exponent correction; valid if <= 0
-       z.mant, _, ecorr, err = z.mant.scan(r, 1)
+       z.mant, b, ecorr, err = z.mant.scan(r, base, true)
        if err != nil {
                return
        }
@@ -1046,7 +1027,8 @@ func (z *Float) scan(r io.ByteScanner) (f *Float, err error) {
        // special-case 0
        if len(z.mant) == 0 {
                z.exp = 0
-               return z, nil
+               f = z
+               return
        }
        // len(z.mant) > 0
 
@@ -1064,7 +1046,8 @@ func (z *Float) scan(r io.ByteScanner) (f *Float, err error) {
        if exp == 0 {
                // no decimal exponent
                z.round(0)
-               return z, nil
+               f = z
+               return
        }
        // exp != 0
 
@@ -1082,41 +1065,7 @@ func (z *Float) scan(r io.ByteScanner) (f *Float, err error) {
                z.umul(z, powTen)
        }
 
-       return z, nil
-}
-
-// Scan scans the number corresponding to the longest possible prefix
-// of r representing a floating-point number with a mantissa in the
-// given conversion base (the exponent is always a decimal number).
-// It returns the corresponding Float f, the actual base b, and an
-// error err, if any. The number must be of the form:
-//
-//     number   = [ prefix ] [ sign ] mantissa [ exponent ] .
-//     mantissa = digits | digits "." [ digits ] | "." digits .
-//      prefix   = prefix = "0" ( "x" | "X" | "b" | "B" ) .
-//     sign     = "+" | "-" .
-//     exponent = ( "E" | "e" | "p" ) [ sign ] digits .
-//     digits   = digit { digit } .
-//     digit    = digit  = "0" ... "9" | "a" ... "z" | "A" ... "Z" .
-//
-// The base argument must be 0 or a value between 2 and MaxBase, inclusive.
-//
-// For base 0, the number prefix determines the actual base: A prefix of
-// ``0x'' or ``0X'' selects base 16, and a ``0b'' or ``0B'' prefix selects
-// base 2; otherwise, the actual base is 10 and no prefix is permitted.
-// Note that the octal prefix ``0'' is not supported.
-//
-// A "p" exponent indicates power of 2 for the exponent; for instance "1.2p3"
-// with base 0 or 10 corresponds to the value 1.2 * 2**3.
-//
-// BUG(gri) Currently, Scan only accepts base 10.
-func (z *Float) Scan(r io.ByteScanner, base int) (f *Float, b int, err error) {
-       if base != 10 {
-               err = fmt.Errorf("base %d not supported yet", base)
-               return
-       }
-       b = 10
-       f, err = z.scan(r)
+       f = z
        return
 }
 
@@ -1157,16 +1106,20 @@ func ParseFloat(s string, base int, prec uint, mode RoundingMode) (f *Float, b i
 }
 
 // Format converts the floating-point number x to a string according
-// to the given format and precision prec.
+// to the given format and precision prec. The format is one of:
 //
-// The format is one of
-// 'e' (-d.dddde±dd, decimal exponent),
-// 'E' (-d.ddddE±dd, decimal exponent),
-// 'f' (-ddddd.dddd, no exponent),
-// 'g' ('e' for large exponents, 'f' otherwise),
-// 'G' ('E' for large exponents, 'f' otherwise),
-// 'b' (-ddddddp±dd, binary exponent), or
-// 'p' (-0.ddddp±dd, hexadecimal mantissa, binary exponent).
+//     'e'     -d.dddde±dd, decimal exponent
+//     'E'     -d.ddddE±dd, decimal exponent
+//     'f'     -ddddd.dddd, no exponent
+//     'g'     like 'e' for large exponents, like 'f' otherwise
+//     'G'     like 'E' for large exponents, like 'f' otherwise
+//     'b'     -ddddddp±dd, binary exponent
+//     'p'     -0x.dddp±dd, binary exponent, hexadecimal mantissa
+//
+// For the binary exponent formats, the mantissa is printed in normalized form:
+//
+//     'b'     decimal integer mantissa using x.Precision() bits, or -0
+//     'p'     hexadecimal fraction with 0.5 <= 0.mantissa < 1.0, or -0
 //
 // The precision prec controls the number of digits (excluding the exponent)
 // printed by the 'e', 'E', 'f', 'g', and 'G' formats. For 'e', 'E', and 'f'
@@ -1175,15 +1128,75 @@ func ParseFloat(s string, base int, prec uint, mode RoundingMode) (f *Float, b i
 // number of digits necessary such that ParseFloat will return f exactly.
 // The prec value is ignored for the 'b' or 'p' format.
 //
-// BUG(gri) Currently, Format only accepts the 'p' format.
+// BUG(gri) Currently, Format only accepts the 'b' and 'p' format.
 func (x *Float) Format(format byte, prec int) string {
-       if format != 'p' {
-               return fmt.Sprintf(`%c`, format)
+       switch format {
+       case 'b':
+               return x.bstring()
+       case 'p':
+               return x.pstring()
        }
-       return x.pstring()
+       return fmt.Sprintf(`%%!c(%s)`, format, x.pstring())
 }
 
 // BUG(gri): Currently, String uses the 'p' (rather than 'g') format.
 func (x *Float) String() string {
        return x.Format('p', 0)
 }
+
+// TODO(gri) The 'b' and 'p' formats have different meanings here than
+// in strconv: in strconv, the printed exponent is the biased (hardware)
+// exponent; here it is the unbiased exponent. Decide what to do.
+// (a strconv 'p' formatted float value can only be interpreted correctly
+// if the bias is known; i.e., we must know if it's a 32bit or 64bit number).
+
+// bstring returns x as a string in the format ["-"] mantissa "p" exponent
+// with a decimal mantissa and a binary exponent, or ["-"] "0" if x is zero.
+// The mantissa is normalized such that is uses x.Precision() bits in binary
+// representation.
+func (x *Float) bstring() string {
+       // TODO(gri) handle Inf
+       if len(x.mant) == 0 {
+               if x.neg {
+                       return "-0"
+               }
+               return "0"
+       }
+       // x != 0
+       // normalize mantissa
+       m := x.mant
+       t := uint(len(x.mant)*_W) - x.prec // 0 <= t < _W
+       if t > 0 {
+               m = nat(nil).shr(m, t)
+       }
+       var buf bytes.Buffer
+       if x.neg {
+               buf.WriteByte('-')
+       }
+       buf.WriteString(m.decimalString())
+       fmt.Fprintf(&buf, "p%d", x.exp)
+       return buf.String()
+}
+
+// pstring returns x as a string in the format ["-"] "0x." mantissa "p" exponent
+// with a hexadecimal mantissa and a binary exponent, or ["-"] "0" if x is zero.
+// The mantissa is normalized such that 0.5 <= 0.mantissa < 1.0.
+func (x *Float) pstring() string {
+       // TODO(gri) handle Inf
+       if len(x.mant) == 0 {
+               if x.neg {
+                       return "-0"
+               }
+               return "0"
+       }
+       // x != 0
+       // mantissa is stored in normalized form
+       var buf bytes.Buffer
+       if x.neg {
+               buf.WriteByte('-')
+       }
+       buf.WriteString("0x.")
+       buf.WriteString(strings.TrimRight(x.mant.hexString(), "0"))
+       fmt.Fprintf(&buf, "p%d", x.exp)
+       return buf.String()
+}
index e8a14bf87dc4327bdb1182881e227dc9b3331ff7..940cb6d353eb7d785c5ea6e055af6afd45cb4496 100644 (file)
@@ -79,7 +79,7 @@ func testFloatRound(t *testing.T, x, r int64, prec uint, mode RoundingMode) {
 
 // TestFloatRound tests basic rounding.
 func TestFloatRound(t *testing.T) {
-       var tests = []struct {
+       for _, test := range []struct {
                prec                        uint
                x, zero, neven, naway, away string // input, results rounded to prec bits
        }{
@@ -154,9 +154,7 @@ func TestFloatRound(t *testing.T) {
                {1, "1101001", "1000000", "10000000", "10000000", "10000000"},
                {1, "1110001", "1000000", "10000000", "10000000", "10000000"},
                {1, "1111001", "1000000", "10000000", "10000000", "10000000"},
-       }
-
-       for _, test := range tests {
+       } {
                x := fromBinary(test.x)
                z := fromBinary(test.zero)
                e := fromBinary(test.neven)
@@ -195,7 +193,7 @@ func TestFloatRound24(t *testing.T) {
 }
 
 func TestFloatSetUint64(t *testing.T) {
-       var tests = []uint64{
+       for _, want := range []uint64{
                0,
                1,
                2,
@@ -204,8 +202,7 @@ func TestFloatSetUint64(t *testing.T) {
                1<<32 - 1,
                1 << 32,
                1<<64 - 1,
-       }
-       for _, want := range tests {
+       } {
                f := new(Float).SetUint64(want)
                if got := f.Uint64(); got != want {
                        t.Errorf("got %d (%s); want %d", got, f.pstring(), want)
@@ -214,7 +211,7 @@ func TestFloatSetUint64(t *testing.T) {
 }
 
 func TestFloatSetInt64(t *testing.T) {
-       var tests = []int64{
+       for _, want := range []int64{
                0,
                1,
                2,
@@ -223,8 +220,7 @@ func TestFloatSetInt64(t *testing.T) {
                1<<32 - 1,
                1 << 32,
                1<<63 - 1,
-       }
-       for _, want := range tests {
+       } {
                for i := range [2]int{} {
                        if i&1 != 0 {
                                want = -want
@@ -238,7 +234,7 @@ func TestFloatSetInt64(t *testing.T) {
 }
 
 func TestFloatSetFloat64(t *testing.T) {
-       var tests = []float64{
+       for _, want := range []float64{
                0,
                1,
                2,
@@ -248,8 +244,7 @@ func TestFloatSetFloat64(t *testing.T) {
                3.14159265e10,
                2.718281828e-123,
                1.0 / 3,
-       }
-       for _, want := range tests {
+       } {
                for i := range [2]int{} {
                        if i&1 != 0 {
                                want = -want
@@ -396,7 +391,7 @@ func TestFloatMul(t *testing.T) {
 // TestFloatMul64 tests that Float.Mul/Quo of numbers with
 // 53bit mantissa behaves like float64 multiplication/division.
 func TestFloatMul64(t *testing.T) {
-       var tests = []struct {
+       for _, test := range []struct {
                x, y float64
        }{
                {0, 0},
@@ -407,8 +402,7 @@ func TestFloatMul64(t *testing.T) {
                {2.718281828, 3.14159265358979},
                {2.718281828e10, 3.14159265358979e-32},
                {1.0 / 3, 1e200},
-       }
-       for _, test := range tests {
+       } {
                for i := range [8]int{} {
                        x0, y0 := test.x, test.y
                        if i&1 != 0 {
@@ -552,7 +546,7 @@ func normBits(x []int) []int {
 }
 
 func TestNormBits(t *testing.T) {
-       var tests = []struct {
+       for _, test := range []struct {
                x, want []int
        }{
                {nil, nil},
@@ -561,9 +555,7 @@ func TestNormBits(t *testing.T) {
                {[]int{0, 0}, []int{1}},
                {[]int{3, 1, 1}, []int{2, 3}},
                {[]int{10, 9, 8, 7, 6, 6}, []int{11}},
-       }
-
-       for _, test := range tests {
+       } {
                got := fmt.Sprintf("%v", normBits(test.x))
                want := fmt.Sprintf("%v", test.want)
                if got != want {
@@ -665,27 +657,25 @@ func fromBits(bits ...int) *Float {
 }
 
 func TestFromBits(t *testing.T) {
-       var tests = []struct {
+       for _, test := range []struct {
                bits []int
                want string
        }{
                // all different bit numbers
                {nil, "0"},
-               {[]int{0}, "0.8p1"},
-               {[]int{1}, "0.8p2"},
-               {[]int{-1}, "0.8p0"},
-               {[]int{63}, "0.8p64"},
-               {[]int{33, -30}, "0.8000000000000001p34"},
-               {[]int{255, 0}, "0.8000000000000000000000000000000000000000000000000000000000000001p256"},
+               {[]int{0}, "0x.8p1"},
+               {[]int{1}, "0x.8p2"},
+               {[]int{-1}, "0x.8p0"},
+               {[]int{63}, "0x.8p64"},
+               {[]int{33, -30}, "0x.8000000000000001p34"},
+               {[]int{255, 0}, "0x.8000000000000000000000000000000000000000000000000000000000000001p256"},
 
                // multiple equal bit numbers
-               {[]int{0, 0}, "0.8p2"},
-               {[]int{0, 0, 0, 0}, "0.8p3"},
-               {[]int{0, 1, 0}, "0.8p3"},
-               {append([]int{2, 1, 0} /* 7 */, []int{3, 1} /* 10 */ ...), "0.88p5" /* 17 */},
-       }
-
-       for _, test := range tests {
+               {[]int{0, 0}, "0x.8p2"},
+               {[]int{0, 0, 0, 0}, "0x.8p3"},
+               {[]int{0, 1, 0}, "0x.8p3"},
+               {append([]int{2, 1, 0} /* 7 */, []int{3, 1} /* 10 */ ...), "0x.88p5" /* 17 */},
+       } {
                f := fromBits(test.bits...)
                if got := f.pstring(); got != test.want {
                        t.Errorf("setBits(%v) = %s; want %s", test.bits, got, test.want)
@@ -757,19 +747,39 @@ func TestFloatSetFloat64String(t *testing.T) {
        }
 }
 
-func TestFloatpstring(t *testing.T) {
-       var tests = []struct {
-               x    Float
-               want string
+func TestFloatFormat(t *testing.T) {
+       for _, test := range []struct {
+               x      string
+               format byte
+               prec   int
+               want   string
        }{
-               {Float{}, "0"},
-               {Float{neg: true}, "-0"},
-               {Float{mant: nat{0x87654321}}, "0.87654321p0"},
-               {Float{mant: nat{0x87654321}, exp: -10}, "0.87654321p-10"},
-       }
-       for _, test := range tests {
-               if got := test.x.pstring(); got != test.want {
-                       t.Errorf("%v: got %s; want %s", test.x, got, test.want)
+               {"0", 'b', 0, "0"},
+               {"-0", 'b', 0, "-0"},
+               {"1.0", 'b', 0, "4503599627370496p1"},
+               {"-1.0", 'b', 0, "-4503599627370496p1"},
+
+               {"0", 'p', 0, "0"},
+               {"-0", 'p', 0, "-0"},
+               {"1024.0", 'p', 0, "0x.8p11"},
+               {"-1024.0", 'p', 0, "-0x.8p11"},
+       } {
+               f64, err := strconv.ParseFloat(test.x, 64)
+               if err != nil {
+                       t.Error(err)
+                       continue
+               }
+               f := new(Float).SetFloat64(f64)
+               got := f.Format(test.format, test.prec)
+               if got != test.want {
+                       t.Errorf("%v: got %s", test, got)
+               }
+               if test.format == 'b' || test.format == 'p' {
+                       continue // 'b', 'p' format not supported or different in strconv.Format
+               }
+               want := strconv.FormatFloat(f64, test.format, test.prec, 64)
+               if got != want {
+                       t.Errorf("%v: got %s; want %s", test, got, want)
                }
        }
 }
index 716d5381d1246ce8ddbbe2f5e75f24c825012960..3a4d2273354f7427d68f070b2ec7a71a6cacebc5 100644 (file)
@@ -469,7 +469,7 @@ func (z *Int) scan(r io.ByteScanner, base int) (*Int, int, error) {
        }
 
        // determine mantissa
-       z.abs, base, _, err = z.abs.scan(r, base)
+       z.abs, base, _, err = z.abs.scan(r, base, false)
        if err != nil {
                return nil, base, err
        }
index e87c71101c1966318b790cce74593e87d57bcb7e..2258d7564e6d13af99e84f73525df27e80e21e5f 100644 (file)
@@ -649,28 +649,32 @@ func pow(x Word, n int) (p Word) {
 // It returns the corresponding natural number res, the actual base b,
 // a digit count, and an error err, if any.
 //
-//     number = [ prefix ] digits | digits "." [ digits ] | "." digits .
-//     prefix = "0" [ "x" | "X" | "b" | "B" ] .
-//     digits = digit { digit } .
-//     digit  = "0" ... "9" | "a" ... "z" | "A" ... "Z" .
+//     number   = [ prefix ] mantissa .
+//     prefix   = "0" [ "x" | "X" | "b" | "B" ] .
+//      mantissa = digits | digits "." [ digits ] | "." digits .
+//     digits   = digit { digit } .
+//     digit    = "0" ... "9" | "a" ... "z" | "A" ... "Z" .
+//
+// The base argument must be 0 or a value between 0 through MaxBase.
 //
-// The base argument must be a value between 0 and MaxBase (inclusive).
 // For base 0, the number prefix determines the actual base: A prefix of
-// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and
-// a ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base
-// is 10 and no prefix is permitted.
+// ``0x'' or ``0X'' selects base 16; if fracOk is not set, the ``0'' prefix
+// selects base 8, and a ``0b'' or ``0B'' prefix selects base 2. Otherwise
+// the selected base is 10 and no prefix is permitted.
 //
-// Base argument 1 selects actual base 10 but also enables scanning a number
-// with a decimal point.
+// If fracOk is set, an octal prefix is ignored (a leading ``0'' simply
+// stands for a zero digit), and a period followed by a fractional part
+// is permitted. The result value is computed as if there were no period
+// present; and the count value is used to determine the fractional part.
 //
 // A result digit count > 0 corresponds to the number of (non-prefix) digits
-// parsed. A digit count <= 0 indicates the presence of a decimal point (for
-// base == 1, only), and the number of fractional digits is -count. In this
-// case, the value of the scanned number is res * 10**count.
+// parsed. A digit count <= 0 indicates the presence of a period (if fracOk
+// is set, only), and -count is the number of fractional digits found.
+// In this case, the value of the scanned number is res * 10**count.
 //
-func (z nat) scan(r io.ByteScanner, base int) (res nat, b, count int, err error) {
+func (z nat) scan(r io.ByteScanner, base int, fracOk bool) (res nat, b, count int, err error) {
        // reject illegal bases
-       if base < 0 || base > MaxBase {
+       if base != 0 && base < 2 || base > MaxBase {
                err = errors.New("illegal number base")
                return
        }
@@ -682,31 +686,37 @@ func (z nat) scan(r io.ByteScanner, base int) (res nat, b, count int, err error)
        }
 
        // determine actual base
-       switch base {
-       case 0:
+       b = base
+       if base == 0 {
                // actual base is 10 unless there's a base prefix
                b = 10
                if ch == '0' {
+                       count = 1
                        switch ch, err = r.ReadByte(); err {
                        case nil:
                                // possibly one of 0x, 0X, 0b, 0B
-                               b = 8
+                               if !fracOk {
+                                       b = 8
+                               }
                                switch ch {
                                case 'x', 'X':
                                        b = 16
                                case 'b', 'B':
                                        b = 2
                                }
-                               if b == 2 || b == 16 {
+                               switch b {
+                               case 16, 2:
+                                       count = 0 // prefix is not counted
                                        if ch, err = r.ReadByte(); err != nil {
                                                // io.EOF is also an error in this case
                                                return
                                        }
+                               case 8:
+                                       count = 0 // prefix is not counted
                                }
                        case io.EOF:
                                // input is "0"
                                res = z[:0]
-                               count = 1
                                err = nil
                                return
                        default:
@@ -714,11 +724,6 @@ func (z nat) scan(r io.ByteScanner, base int) (res nat, b, count int, err error)
                                return
                        }
                }
-       case 1:
-               // actual base is 10 and decimal point is permitted
-               b = 10
-       default:
-               b = base
        }
 
        // convert string
@@ -732,8 +737,8 @@ func (z nat) scan(r io.ByteScanner, base int) (res nat, b, count int, err error)
        i := 0              // 0 <= i < n
        dp := -1            // position of decimal point
        for {
-               if base == 1 && ch == '.' {
-                       base = 10 // no 2nd decimal point permitted
+               if fracOk && ch == '.' {
+                       fracOk = false
                        dp = count
                        // advance
                        if ch, err = r.ReadByte(); err != nil {
index 42d86193e16bcf8ba6a1a1a65b94efbf7a346ac4..1e98118b63547c01922edfc45edbd3e269bdc13e 100644 (file)
@@ -88,7 +88,7 @@ var prodNN = []argNN{
 }
 
 func natFromString(s string) nat {
-       x, _, _, err := nat(nil).scan(strings.NewReader(s), 0)
+       x, _, _, err := nat(nil).scan(strings.NewReader(s), 0, false)
        if err != nil {
                panic(err)
        }
@@ -271,7 +271,7 @@ func TestString(t *testing.T) {
                        t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
                }
 
-               x, b, _, err := nat(nil).scan(strings.NewReader(a.s), len(a.c))
+               x, b, _, err := nat(nil).scan(strings.NewReader(a.s), len(a.c), false)
                if x.cmp(a.x) != 0 {
                        t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
                }
@@ -287,6 +287,7 @@ func TestString(t *testing.T) {
 var natScanTests = []struct {
        s     string // string to be scanned
        base  int    // input base
+       frac  bool   // fraction ok
        x     nat    // expected nat
        b     int    // expected base
        count int    // expected digit count
@@ -313,39 +314,39 @@ var natScanTests = []struct {
        {s: "0x.0"},
 
        // no errors
-       {"0", 0, nil, 10, 1, true, 0},
-       {"0", 10, nil, 10, 1, true, 0},
-       {"0", 36, nil, 36, 1, true, 0},
-       {"1", 0, nat{1}, 10, 1, true, 0},
-       {"1", 10, nat{1}, 10, 1, true, 0},
-       {"0 ", 0, nil, 10, 1, true, ' '},
-       {"08", 0, nil, 10, 1, true, '8'},
-       {"08", 10, nat{8}, 10, 2, true, 0},
-       {"018", 0, nat{1}, 8, 1, true, '8'},
-       {"0b1", 0, nat{1}, 2, 1, true, 0},
-       {"0b11000101", 0, nat{0xc5}, 2, 8, true, 0},
-       {"03271", 0, nat{03271}, 8, 4, true, 0},
-       {"10ab", 0, nat{10}, 10, 2, true, 'a'},
-       {"1234567890", 0, nat{1234567890}, 10, 10, true, 0},
-       {"xyz", 36, nat{(33*36+34)*36 + 35}, 36, 3, true, 0},
-       {"xyz?", 36, nat{(33*36+34)*36 + 35}, 36, 3, true, '?'},
-       {"0x", 16, nil, 16, 1, true, 'x'},
-       {"0xdeadbeef", 0, nat{0xdeadbeef}, 16, 8, true, 0},
-       {"0XDEADBEEF", 0, nat{0xdeadbeef}, 16, 8, true, 0},
+       {"0", 0, false, nil, 10, 1, true, 0},
+       {"0", 10, false, nil, 10, 1, true, 0},
+       {"0", 36, false, nil, 36, 1, true, 0},
+       {"1", 0, false, nat{1}, 10, 1, true, 0},
+       {"1", 10, false, nat{1}, 10, 1, true, 0},
+       {"0 ", 0, false, nil, 10, 1, true, ' '},
+       {"08", 0, false, nil, 10, 1, true, '8'},
+       {"08", 10, false, nat{8}, 10, 2, true, 0},
+       {"018", 0, false, nat{1}, 8, 1, true, '8'},
+       {"0b1", 0, false, nat{1}, 2, 1, true, 0},
+       {"0b11000101", 0, false, nat{0xc5}, 2, 8, true, 0},
+       {"03271", 0, false, nat{03271}, 8, 4, true, 0},
+       {"10ab", 0, false, nat{10}, 10, 2, true, 'a'},
+       {"1234567890", 0, false, nat{1234567890}, 10, 10, true, 0},
+       {"xyz", 36, false, nat{(33*36+34)*36 + 35}, 36, 3, true, 0},
+       {"xyz?", 36, false, nat{(33*36+34)*36 + 35}, 36, 3, true, '?'},
+       {"0x", 16, false, nil, 16, 1, true, 'x'},
+       {"0xdeadbeef", 0, false, nat{0xdeadbeef}, 16, 8, true, 0},
+       {"0XDEADBEEF", 0, false, nat{0xdeadbeef}, 16, 8, true, 0},
 
        // no errors, decimal point
-       {"0.", 0, nil, 10, 1, true, '.'},
-       {"0.", 1, nil, 10, 0, true, 0},
-       {"0.1.2", 1, nat{1}, 10, -1, true, '.'},
-       {".000", 1, nil, 10, -3, true, 0},
-       {"12.3", 1, nat{123}, 10, -1, true, 0},
-       {"012.345", 1, nat{12345}, 10, -3, true, 0},
+       {"0.", 0, false, nil, 10, 1, true, '.'},
+       {"0.", 10, true, nil, 10, 0, true, 0},
+       {"0.1.2", 10, true, nat{1}, 10, -1, true, '.'},
+       {".000", 10, true, nil, 10, -3, true, 0},
+       {"12.3", 10, true, nat{123}, 10, -1, true, 0},
+       {"012.345", 10, true, nat{12345}, 10, -3, true, 0},
 }
 
 func TestScanBase(t *testing.T) {
        for _, a := range natScanTests {
                r := strings.NewReader(a.s)
-               x, b, count, err := nat(nil).scan(r, a.base)
+               x, b, count, err := nat(nil).scan(r, a.base, a.frac)
                if err == nil && !a.ok {
                        t.Errorf("scan%+v\n\texpected error", a)
                }
@@ -431,7 +432,7 @@ var pi = "3" +
 // Test case for BenchmarkScanPi.
 func TestScanPi(t *testing.T) {
        var x nat
-       z, _, _, err := x.scan(strings.NewReader(pi), 10)
+       z, _, _, err := x.scan(strings.NewReader(pi), 10, false)
        if err != nil {
                t.Errorf("scanning pi: %s", err)
        }
@@ -457,13 +458,13 @@ func TestScanPiParallel(t *testing.T) {
 func BenchmarkScanPi(b *testing.B) {
        for i := 0; i < b.N; i++ {
                var x nat
-               x.scan(strings.NewReader(pi), 10)
+               x.scan(strings.NewReader(pi), 10, false)
        }
 }
 
 func BenchmarkStringPiParallel(b *testing.B) {
        var x nat
-       x, _, _, _ = x.scan(strings.NewReader(pi), 0)
+       x, _, _, _ = x.scan(strings.NewReader(pi), 0, false)
        if x.decimalString() != pi {
                panic("benchmark incorrect: conversion failed")
        }
@@ -511,7 +512,7 @@ func ScanHelper(b *testing.B, base int, x, y Word) {
        b.StartTimer()
 
        for i := 0; i < b.N; i++ {
-               z.scan(strings.NewReader(s), base)
+               z.scan(strings.NewReader(s), base, false)
        }
 }
 
index bd7ec73817111573e4bee51f8ee39f013c10ef2e..bc4029a721f0a701968363e161d8f37c3f6d8b8c 100644 (file)
@@ -546,12 +546,12 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
 
        // parse fraction a/b, if any
        if sep := strings.Index(s, "/"); sep >= 0 {
-               if _, ok := z.a.SetString(s[:sep], 10); !ok {
+               if _, ok := z.a.SetString(s[:sep], 0); !ok {
                        return nil, false
                }
                s = s[sep+1:]
                var err error
-               if z.b.abs, _, _, err = z.b.abs.scan(strings.NewReader(s), 10); err != nil {
+               if z.b.abs, _, _, err = z.b.abs.scan(strings.NewReader(s), 0, false); err != nil {
                        return nil, false
                }
                if len(z.b.abs) == 0 {
@@ -571,7 +571,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
 
        // mantissa
        var ecorr int
-       z.a.abs, _, ecorr, err = z.a.abs.scan(r, 1)
+       z.a.abs, _, ecorr, err = z.a.abs.scan(r, 10, true)
        if err != nil {
                return nil, false
        }
index a4fc610062abdb7cac86b3d27d460c0727329980..37f672ee3d2b32dc383f7e4ca49eba9e5c886f66 100644 (file)
@@ -56,10 +56,12 @@ func TestZeroRat(t *testing.T) {
        z.Quo(&x, &y)
 }
 
-var setStringTests = []struct {
+type StringTest struct {
        in, out string
        ok      bool
-}{
+}
+
+var setStringTests = []StringTest{
        {"0", "0", true},
        {"-0", "0", true},
        {"1", "1", true},
@@ -92,8 +94,22 @@ var setStringTests = []struct {
        {in: "1/0"},
 }
 
+// These are not supported by fmt.Fscanf.
+var setStringTests2 = []StringTest{
+       {"0x10", "16", true},
+       {"-010/1", "-8", true}, // TODO(gri) should we even permit octal here?
+       {"-010.", "-10", true},
+       {"0x10/0x20", "1/2", true},
+       {"0b1000/3", "8/3", true},
+       // TODO(gri) add more tests
+}
+
 func TestRatSetString(t *testing.T) {
-       for i, test := range setStringTests {
+       var tests []StringTest
+       tests = append(tests, setStringTests...)
+       tests = append(tests, setStringTests2...)
+
+       for i, test := range tests {
                x, ok := new(Rat).SetString(test.in)
 
                if ok {