]> Cypherpunks repositories - gostls13.git/commitdiff
flag: recover panic when calling String on zero value in PrintDefaults
authorAndrew Gerrand <adg@golang.org>
Tue, 29 Mar 2022 23:37:07 +0000 (10:37 +1100)
committerGopher Robot <gobot@golang.org>
Fri, 1 Apr 2022 03:43:34 +0000 (03:43 +0000)
When printing the usage message, recover panics when calling String
methods on reflect-constructed flag.Value zero values. Collect the panic
messages and include them at the end of the PrintDefaults output so that
the programmer knows to fix the panic.

Fixes #28667

Change-Id: Ic4378a5813a2e26f063d5580d678add65ece8f97
Reviewed-on: https://go-review.googlesource.com/c/go/+/396574
Run-TryBot: Andrew Gerrand <adg@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Rob Pike <r@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Auto-Submit: Ian Lance Taylor <iant@golang.org>
Trust: Andrew Gerrand <adg@golang.org>

src/flag/flag.go
src/flag/flag_test.go

index cdea949a2f06fac2b87f9101eb474cbfa13da1de..15bcb6cea97e89474c79de68cdf9f890cd5ec3a1 100644 (file)
@@ -488,7 +488,7 @@ func Set(name, value string) error {
 
 // isZeroValue determines whether the string represents the zero
 // value for a flag.
-func isZeroValue(flag *Flag, value string) bool {
+func isZeroValue(flag *Flag, value string) (ok bool, err error) {
        // Build a zero value of the flag's Value type, and see if the
        // result of calling its String method equals the value passed in.
        // This works unless the Value type is itself an interface type.
@@ -499,7 +499,18 @@ func isZeroValue(flag *Flag, value string) bool {
        } else {
                z = reflect.Zero(typ)
        }
-       return value == z.Interface().(Value).String()
+       // Catch panics calling the String method, which shouldn't prevent the
+       // usage message from being printed, but that we should report to the
+       // user so that they know to fix their code.
+       defer func() {
+               if e := recover(); e != nil {
+                       if typ.Kind() == reflect.Pointer {
+                               typ = typ.Elem()
+                       }
+                       err = fmt.Errorf("panic calling String method on zero %v for flag %s: %v", typ, flag.Name, e)
+               }
+       }()
+       return value == z.Interface().(Value).String(), nil
 }
 
 // UnquoteUsage extracts a back-quoted name from the usage
@@ -545,6 +556,7 @@ func UnquoteUsage(flag *Flag) (name string, usage string) {
 // default values of all defined command-line flags in the set. See the
 // documentation for the global function PrintDefaults for more information.
 func (f *FlagSet) PrintDefaults() {
+       var isZeroValueErrs []error
        f.VisitAll(func(flag *Flag) {
                var b strings.Builder
                fmt.Fprintf(&b, "  -%s", flag.Name) // Two spaces before -; see next two comments.
@@ -564,7 +576,11 @@ func (f *FlagSet) PrintDefaults() {
                }
                b.WriteString(strings.ReplaceAll(usage, "\n", "\n    \t"))
 
-               if !isZeroValue(flag, flag.DefValue) {
+               // Print the default value only if it differs to the zero value
+               // for this flag type.
+               if isZero, err := isZeroValue(flag, flag.DefValue); err != nil {
+                       isZeroValueErrs = append(isZeroValueErrs, err)
+               } else if !isZero {
                        if _, ok := flag.Value.(*stringValue); ok {
                                // put quotes on the value
                                fmt.Fprintf(&b, " (default %q)", flag.DefValue)
@@ -574,6 +590,15 @@ func (f *FlagSet) PrintDefaults() {
                }
                fmt.Fprint(f.Output(), b.String(), "\n")
        })
+       // If calling String on any zero flag.Values triggered a panic, print
+       // the messages after the full set of defaults so that the programmer
+       // knows to fix the panic.
+       if errs := isZeroValueErrs; len(errs) > 0 {
+               fmt.Fprintln(f.Output())
+               for _, err := range errs {
+                       fmt.Fprintln(f.Output(), err)
+               }
+       }
 }
 
 // PrintDefaults prints, to standard error unless configured otherwise,
index d5c443d3c660d3f8cd89e7314e171790240e30af..ca6ba5d149c69ae8f10d37c9d19f739842c8421a 100644 (file)
@@ -432,6 +432,25 @@ func TestHelp(t *testing.T) {
        }
 }
 
+// zeroPanicker is a flag.Value whose String method panics if its dontPanic
+// field is false.
+type zeroPanicker struct {
+       dontPanic bool
+       v         string
+}
+
+func (f *zeroPanicker) Set(s string) error {
+       f.v = s
+       return nil
+}
+
+func (f *zeroPanicker) String() string {
+       if !f.dontPanic {
+               panic("panic!")
+       }
+       return f.v
+}
+
 const defaultOutput = `  -A    for bootstrapping, allow 'any' type
   -Alongflagname
        disable bounds checking
@@ -452,10 +471,19 @@ const defaultOutput = `  -A       for bootstrapping, allow 'any' type
        a non-zero int (default 27)
   -O   a flag
        multiline help string (default true)
+  -V list
+       a list of strings (default [a b])
   -Z int
        an int that defaults to zero
+  -ZP0 value
+       a flag whose String method panics when it is zero
+  -ZP1 value
+       a flag whose String method panics when it is zero
   -maxT timeout
        set timeout for dial
+
+panic calling String method on zero flag_test.zeroPanicker for flag ZP0: panic!
+panic calling String method on zero flag_test.zeroPanicker for flag ZP1: panic!
 `
 
 func TestPrintDefaults(t *testing.T) {
@@ -472,12 +500,15 @@ func TestPrintDefaults(t *testing.T) {
        fs.String("M", "", "a multiline\nhelp\nstring")
        fs.Int("N", 27, "a non-zero int")
        fs.Bool("O", true, "a flag\nmultiline help string")
+       fs.Var(&flagVar{"a", "b"}, "V", "a `list` of strings")
        fs.Int("Z", 0, "an int that defaults to zero")
+       fs.Var(&zeroPanicker{true, ""}, "ZP0", "a flag whose String method panics when it is zero")
+       fs.Var(&zeroPanicker{true, "something"}, "ZP1", "a flag whose String method panics when it is zero")
        fs.Duration("maxT", 0, "set `timeout` for dial")
        fs.PrintDefaults()
        got := buf.String()
        if got != defaultOutput {
-               t.Errorf("got %q want %q\n", got, defaultOutput)
+               t.Errorf("got:\n%q\nwant:\n%q", got, defaultOutput)
        }
 }