]> Cypherpunks repositories - gostls13.git/commitdiff
flag: return a consistent parse error if the flag value is invalid
authorRob Pike <r@golang.org>
Thu, 18 Oct 2018 23:48:53 +0000 (10:48 +1100)
committerRob Pike <r@golang.org>
Fri, 19 Oct 2018 03:48:38 +0000 (03:48 +0000)
Return a consistently formatted error string that reports either
a parse error or a range error.

Before:
invalid boolean value "3" for -debug: strconv.ParseBool: parsing "3": invalid syntax

After:
invalid boolean value "3" for -debug: parse error

Fixes #26822

Change-Id: I60992bf23da32a4c0cf32472a8af486a3c9674ad
Reviewed-on: https://go-review.googlesource.com/c/143257
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/flag/flag.go
src/flag/flag_test.go

index ae84e1f775e926c32143174c7afe8b378d122bcb..2eef9d6ab9ae2fc945d35719278bdf194f1183b3 100644 (file)
@@ -83,6 +83,28 @@ import (
 // but no such flag is defined.
 var ErrHelp = errors.New("flag: help requested")
 
+// errParse is returned by Set if a flag's value fails to parse, such as with an invalid integer for Int.
+// It then gets wrapped through failf to provide more information.
+var errParse = errors.New("parse error")
+
+// errRange is returned by Set if a flag's value is out of range.
+// It then gets wrapped through failf to provide more information.
+var errRange = errors.New("value out of range")
+
+func numError(err error) error {
+       ne, ok := err.(*strconv.NumError)
+       if !ok {
+               return err
+       }
+       if ne.Err == strconv.ErrSyntax {
+               return errParse
+       }
+       if ne.Err == strconv.ErrRange {
+               return errRange
+       }
+       return err
+}
+
 // -- bool Value
 type boolValue bool
 
@@ -93,6 +115,9 @@ func newBoolValue(val bool, p *bool) *boolValue {
 
 func (b *boolValue) Set(s string) error {
        v, err := strconv.ParseBool(s)
+       if err != nil {
+               err = errParse
+       }
        *b = boolValue(v)
        return err
 }
@@ -120,6 +145,9 @@ func newIntValue(val int, p *int) *intValue {
 
 func (i *intValue) Set(s string) error {
        v, err := strconv.ParseInt(s, 0, strconv.IntSize)
+       if err != nil {
+               err = numError(err)
+       }
        *i = intValue(v)
        return err
 }
@@ -138,6 +166,9 @@ func newInt64Value(val int64, p *int64) *int64Value {
 
 func (i *int64Value) Set(s string) error {
        v, err := strconv.ParseInt(s, 0, 64)
+       if err != nil {
+               err = numError(err)
+       }
        *i = int64Value(v)
        return err
 }
@@ -156,6 +187,9 @@ func newUintValue(val uint, p *uint) *uintValue {
 
 func (i *uintValue) Set(s string) error {
        v, err := strconv.ParseUint(s, 0, strconv.IntSize)
+       if err != nil {
+               err = numError(err)
+       }
        *i = uintValue(v)
        return err
 }
@@ -174,6 +208,9 @@ func newUint64Value(val uint64, p *uint64) *uint64Value {
 
 func (i *uint64Value) Set(s string) error {
        v, err := strconv.ParseUint(s, 0, 64)
+       if err != nil {
+               err = numError(err)
+       }
        *i = uint64Value(v)
        return err
 }
@@ -209,6 +246,9 @@ func newFloat64Value(val float64, p *float64) *float64Value {
 
 func (f *float64Value) Set(s string) error {
        v, err := strconv.ParseFloat(s, 64)
+       if err != nil {
+               err = numError(err)
+       }
        *f = float64Value(v)
        return err
 }
@@ -227,6 +267,9 @@ func newDurationValue(val time.Duration, p *time.Duration) *durationValue {
 
 func (d *durationValue) Set(s string) error {
        v, err := time.ParseDuration(s)
+       if err != nil {
+               err = errParse
+       }
        *d = durationValue(v)
        return err
 }
index c7f0c07d4489b3b6ff515e1cd4fea636a6d44973..0d9491c02089de4a59d06a8ec4156a956c6588d4 100644 (file)
@@ -9,6 +9,7 @@ import (
        . "flag"
        "fmt"
        "io"
+       "io/ioutil"
        "os"
        "sort"
        "strconv"
@@ -491,3 +492,55 @@ func TestGetters(t *testing.T) {
                t.Errorf("unexpected output: got %v, expected %v", fs.Output(), expectedOutput)
        }
 }
+
+func TestParseError(t *testing.T) {
+       for _, typ := range []string{"bool", "int", "int64", "uint", "uint64", "float64", "duration"} {
+               fs := NewFlagSet("parse error test", ContinueOnError)
+               fs.SetOutput(ioutil.Discard)
+               _ = fs.Bool("bool", false, "")
+               _ = fs.Int("int", 0, "")
+               _ = fs.Int64("int64", 0, "")
+               _ = fs.Uint("uint", 0, "")
+               _ = fs.Uint64("uint64", 0, "")
+               _ = fs.Float64("float64", 0, "")
+               _ = fs.Duration("duration", 0, "")
+               // Strings cannot give errors.
+               args := []string{"-" + typ + "=x"}
+               err := fs.Parse(args) // x is not a valid setting for any flag.
+               if err == nil {
+                       t.Errorf("Parse(%q)=%v; expected parse error", args, err)
+                       continue
+               }
+               if !strings.Contains(err.Error(), "invalid") || !strings.Contains(err.Error(), "parse error") {
+                       t.Errorf("Parse(%q)=%v; expected parse error", args, err)
+               }
+       }
+}
+
+func TestRangeError(t *testing.T) {
+       bad := []string{
+               "-int=123456789012345678901",
+               "-int64=123456789012345678901",
+               "-uint=123456789012345678901",
+               "-uint64=123456789012345678901",
+               "-float64=1e1000",
+       }
+       for _, arg := range bad {
+               fs := NewFlagSet("parse error test", ContinueOnError)
+               fs.SetOutput(ioutil.Discard)
+               _ = fs.Int("int", 0, "")
+               _ = fs.Int64("int64", 0, "")
+               _ = fs.Uint("uint", 0, "")
+               _ = fs.Uint64("uint64", 0, "")
+               _ = fs.Float64("float64", 0, "")
+               // Strings cannot give errors, and bools and durations do not return strconv.NumError.
+               err := fs.Parse([]string{arg})
+               if err == nil {
+                       t.Errorf("Parse(%q)=%v; expected range error", arg, err)
+                       continue
+               }
+               if !strings.Contains(err.Error(), "invalid") || !strings.Contains(err.Error(), "value out of range") {
+                       t.Errorf("Parse(%q)=%v; expected range error", arg, err)
+               }
+       }
+}