]> Cypherpunks repositories - gostls13.git/commitdiff
strconv: check bitsize range in ParseInt and ParseUint
authorMartin Möhrmann <moehrmann@google.com>
Sat, 12 Aug 2017 19:02:43 +0000 (21:02 +0200)
committerRobert Griesemer <gri@golang.org>
Tue, 22 Aug 2017 13:37:40 +0000 (13:37 +0000)
Return an error when a bitSize below 0 or above 64 is specified.

Move bitSize 0 handling in ParseInt after the call to ParseUint
to avoid a spill.

AMD64:
name       old time/op  new time/op  delta
Atoi       28.9ns ± 6%  27.4ns ± 6%  -5.21%  (p=0.002 n=20+20)
AtoiNeg    24.6ns ± 2%  23.1ns ± 1%  -6.04%  (p=0.000 n=19+18)
Atoi64     38.8ns ± 1%  38.0ns ± 1%  -2.03%  (p=0.000 n=17+20)
Atoi64Neg  35.5ns ± 1%  34.3ns ± 1%  -3.42%  (p=0.000 n=19+20)

Updates #21275

Change-Id: I70f0e4a16fa003f7ea929ca4ef56bd1a4181660b
Reviewed-on: https://go-review.googlesource.com/55139
Reviewed-by: Robert Griesemer <gri@golang.org>
Run-TryBot: Robert Griesemer <gri@golang.org>

src/strconv/atoi.go
src/strconv/atoi_test.go
src/strconv/export_test.go [new file with mode: 0644]

index 2d6c1dec356825903589b5255bc21b6a3588252a..e1ac42716c299c0a3abf99021ffc6f9d48544459 100644 (file)
@@ -35,6 +35,10 @@ func baseError(fn, str string, base int) *NumError {
        return &NumError{fn, str, errors.New("invalid base " + Itoa(base))}
 }
 
+func bitSizeError(fn, str string, bitSize int) *NumError {
+       return &NumError{fn, str, errors.New("invalid bit size " + Itoa(bitSize))}
+}
+
 const intSize = 32 << (^uint(0) >> 63)
 
 // IntSize is the size in bits of an int or uint value.
@@ -46,10 +50,6 @@ const maxUint64 = (1<<64 - 1)
 func ParseUint(s string, base int, bitSize int) (uint64, error) {
        const fnParseUint = "ParseUint"
 
-       if bitSize == 0 {
-               bitSize = int(IntSize)
-       }
-
        if len(s) == 0 {
                return 0, syntaxError(fnParseUint, s)
        }
@@ -79,6 +79,12 @@ func ParseUint(s string, base int, bitSize int) (uint64, error) {
                return 0, baseError(fnParseUint, s0, base)
        }
 
+       if bitSize == 0 {
+               bitSize = int(IntSize)
+       } else if bitSize < 0 || bitSize > 64 {
+               return 0, bitSizeError(fnParseUint, s0, bitSize)
+       }
+
        // Cutoff is the smallest number such that cutoff*base > maxUint64.
        // Use compile-time constants for common cases.
        var cutoff uint64
@@ -128,14 +134,17 @@ func ParseUint(s string, base int, bitSize int) (uint64, error) {
        return n, nil
 }
 
-// ParseInt interprets a string s in the given base (2 to 36) and
-// returns the corresponding value i. If base == 0, the base is
-// implied by the string's prefix: base 16 for "0x", base 8 for
-// "0", and base 10 otherwise.
+// ParseInt interprets a string s in the given base (0, 2 to 36) and
+// bit size (0 to 64) and returns the corresponding value i.
+//
+// If base == 0, the base is implied by the string's prefix:
+// base 16 for "0x", base 8 for "0", and base 10 otherwise.
+// For bases 1, below 0 or above 36 an error is returned.
 //
 // The bitSize argument specifies the integer type
 // that the result must fit into. Bit sizes 0, 8, 16, 32, and 64
 // correspond to int, int8, int16, int32, and int64.
+// For a bitSize below 0 or above 64 an error is returned.
 //
 // The errors that ParseInt returns have concrete type *NumError
 // and include err.Num = s. If s is empty or contains invalid
@@ -147,10 +156,6 @@ func ParseUint(s string, base int, bitSize int) (uint64, error) {
 func ParseInt(s string, base int, bitSize int) (i int64, err error) {
        const fnParseInt = "ParseInt"
 
-       if bitSize == 0 {
-               bitSize = int(IntSize)
-       }
-
        // Empty string bad.
        if len(s) == 0 {
                return 0, syntaxError(fnParseInt, s)
@@ -174,6 +179,11 @@ func ParseInt(s string, base int, bitSize int) (i int64, err error) {
                err.(*NumError).Num = s0
                return 0, err
        }
+
+       if bitSize == 0 {
+               bitSize = int(IntSize)
+       }
+
        cutoff := uint64(1 << uint(bitSize-1))
        if !neg && un >= cutoff {
                return int64(cutoff - 1), rangeError(fnParseInt, s0)
index 527cc406c181d812196851c08071715e5e5f31c6..94844c7e103ce073fb945f57a67b7c0ff2293aba 100644 (file)
@@ -354,6 +354,87 @@ func TestParseInt(t *testing.T) {
        }
 }
 
+func bitSizeErrStub(name string, bitSize int) error {
+       return BitSizeError(name, "0", bitSize)
+}
+
+func baseErrStub(name string, base int) error {
+       return BaseError(name, "0", base)
+}
+
+func noErrStub(name string, arg int) error {
+       return nil
+}
+
+type parseErrorTest struct {
+       arg     int
+       errStub func(name string, arg int) error
+}
+
+var parseBitSizeTests = []parseErrorTest{
+       {-1, bitSizeErrStub},
+       {0, noErrStub},
+       {64, noErrStub},
+       {65, bitSizeErrStub},
+}
+
+var parseBaseTests = []parseErrorTest{
+       {-1, baseErrStub},
+       {0, noErrStub},
+       {1, baseErrStub},
+       {2, noErrStub},
+       {36, noErrStub},
+       {37, baseErrStub},
+}
+
+func TestParseIntBitSize(t *testing.T) {
+       for i := range parseBitSizeTests {
+               test := &parseBitSizeTests[i]
+               testErr := test.errStub("ParseInt", test.arg)
+               _, err := ParseInt("0", 0, test.arg)
+               if !reflect.DeepEqual(testErr, err) {
+                       t.Errorf("ParseInt(\"0\", 0, %v) = 0, %v want 0, %v",
+                               test.arg, err, testErr)
+               }
+       }
+}
+
+func TestParseUintBitSize(t *testing.T) {
+       for i := range parseBitSizeTests {
+               test := &parseBitSizeTests[i]
+               testErr := test.errStub("ParseUint", test.arg)
+               _, err := ParseUint("0", 0, test.arg)
+               if !reflect.DeepEqual(testErr, err) {
+                       t.Errorf("ParseUint(\"0\", 0, %v) = 0, %v want 0, %v",
+                               test.arg, err, testErr)
+               }
+       }
+}
+
+func TestParseIntBase(t *testing.T) {
+       for i := range parseBaseTests {
+               test := &parseBaseTests[i]
+               testErr := test.errStub("ParseInt", test.arg)
+               _, err := ParseInt("0", test.arg, 0)
+               if !reflect.DeepEqual(testErr, err) {
+                       t.Errorf("ParseInt(\"0\", %v, 0) = 0, %v want 0, %v",
+                               test.arg, err, testErr)
+               }
+       }
+}
+
+func TestParseUintBase(t *testing.T) {
+       for i := range parseBaseTests {
+               test := &parseBaseTests[i]
+               testErr := test.errStub("ParseUint", test.arg)
+               _, err := ParseUint("0", test.arg, 0)
+               if !reflect.DeepEqual(testErr, err) {
+                       t.Errorf("ParseUint(\"0\", %v, 0) = 0, %v want 0, %v",
+                               test.arg, err, testErr)
+               }
+       }
+}
+
 func TestNumError(t *testing.T) {
        for _, test := range numErrorTests {
                err := &NumError{
diff --git a/src/strconv/export_test.go b/src/strconv/export_test.go
new file mode 100644 (file)
index 0000000..8c03a7f
--- /dev/null
@@ -0,0 +1,10 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package strconv
+
+var (
+       BitSizeError = bitSizeError
+       BaseError    = baseError
+)