]> Cypherpunks repositories - gostls13.git/commitdiff
flag: panic if flag name begins with - or contains =
authorKimMachineGun <geon0250@gmail.com>
Wed, 10 Mar 2021 12:55:56 +0000 (12:55 +0000)
committerEmmanuel Odeke <emmanuel@orijtech.com>
Wed, 10 Mar 2021 18:33:52 +0000 (18:33 +0000)
Fixes #41792

Change-Id: I9b4aae8a899e3c3ac9532d27932d275cfb1fab48
GitHub-Last-Rev: f06b1e17674bf77bdc2d3e798df4c4379748c8d2
GitHub-Pull-Request: golang/go#42737
Reviewed-on: https://go-review.googlesource.com/c/go/+/271788
Reviewed-by: Emmanuel Odeke <emmanuel@orijtech.com>
Reviewed-by: Rob Pike <r@golang.org>
Run-TryBot: Emmanuel Odeke <emmanuel@orijtech.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Rob Pike <r@golang.org>

src/flag/flag.go
src/flag/flag_test.go

index a8485f034f30c13cee232c0d5ffcfa15bb36c8e2..f7598a6758f2237d007047e799214708136583cd 100644 (file)
@@ -857,17 +857,23 @@ func Func(name, usage string, fn func(string) error) {
 // of strings by giving the slice the methods of Value; in particular, Set would
 // decompose the comma-separated string into the slice.
 func (f *FlagSet) Var(value Value, name string, usage string) {
+       // Flag must not begin "-" or contain "=".
+       if strings.HasPrefix(name, "-") {
+               panic(f.sprintf("flag %q begins with -", name))
+       } else if strings.Contains(name, "=") {
+               panic(f.sprintf("flag %q contains =", name))
+       }
+
        // Remember the default value as a string; it won't change.
        flag := &Flag{name, usage, value, value.String()}
        _, alreadythere := f.formal[name]
        if alreadythere {
                var msg string
                if f.name == "" {
-                       msg = fmt.Sprintf("flag redefined: %s", name)
+                       msg = f.sprintf("flag redefined: %s", name)
                } else {
-                       msg = fmt.Sprintf("%s flag redefined: %s", f.name, name)
+                       msg = f.sprintf("%s flag redefined: %s", f.name, name)
                }
-               fmt.Fprintln(f.Output(), msg)
                panic(msg) // Happens only if flags are declared with identical names
        }
        if f.formal == nil {
@@ -886,13 +892,19 @@ func Var(value Value, name string, usage string) {
        CommandLine.Var(value, name, usage)
 }
 
+// sprintf formats the message, prints it to output, and returns it.
+func (f *FlagSet) sprintf(format string, a ...interface{}) string {
+       msg := fmt.Sprintf(format, a...)
+       fmt.Fprintln(f.Output(), msg)
+       return msg
+}
+
 // failf prints to standard error a formatted error and usage message and
 // returns the error.
 func (f *FlagSet) failf(format string, a ...interface{}) error {
-       err := fmt.Errorf(format, a...)
-       fmt.Fprintln(f.Output(), err)
+       msg := f.sprintf(format, a...)
        f.usage()
-       return err
+       return errors.New(msg)
 }
 
 // usage calls the Usage method for the flag set if one is specified,
index 06cab79405a7bab928b557c51335951ff83bdfa3..5835fcf22cd7a9ff2c1eddfc3980d92cd7d5c13b 100644 (file)
@@ -655,3 +655,86 @@ func TestExitCode(t *testing.T) {
                }
        }
 }
+
+func mustPanic(t *testing.T, testName string, expected string, f func()) {
+       t.Helper()
+       defer func() {
+               switch msg := recover().(type) {
+               case nil:
+                       t.Errorf("%s\n: expected panic(%q), but did not panic", testName, expected)
+               case string:
+                       if msg != expected {
+                               t.Errorf("%s\n: expected panic(%q), but got panic(%q)", testName, expected, msg)
+                       }
+               default:
+                       t.Errorf("%s\n: expected panic(%q), but got panic(%T%v)", testName, expected, msg, msg)
+               }
+       }()
+       f()
+}
+
+func TestInvalidFlags(t *testing.T) {
+       tests := []struct {
+               flag     string
+               errorMsg string
+       }{
+               {
+                       flag:     "-foo",
+                       errorMsg: "flag \"-foo\" begins with -",
+               },
+               {
+                       flag:     "foo=bar",
+                       errorMsg: "flag \"foo=bar\" contains =",
+               },
+       }
+
+       for _, test := range tests {
+               testName := fmt.Sprintf("FlagSet.Var(&v, %q, \"\")", test.flag)
+
+               fs := NewFlagSet("", ContinueOnError)
+               buf := bytes.NewBuffer(nil)
+               fs.SetOutput(buf)
+
+               mustPanic(t, testName, test.errorMsg, func() {
+                       var v flagVar
+                       fs.Var(&v, test.flag, "")
+               })
+               if msg := test.errorMsg + "\n"; msg != buf.String() {
+                       t.Errorf("%s\n: unexpected output: expected %q, bug got %q", testName, msg, buf)
+               }
+       }
+}
+
+func TestRedefinedFlags(t *testing.T) {
+       tests := []struct {
+               flagSetName string
+               errorMsg    string
+       }{
+               {
+                       flagSetName: "",
+                       errorMsg:    "flag redefined: foo",
+               },
+               {
+                       flagSetName: "fs",
+                       errorMsg:    "fs flag redefined: foo",
+               },
+       }
+
+       for _, test := range tests {
+               testName := fmt.Sprintf("flag redefined in FlagSet(%q)", test.flagSetName)
+
+               fs := NewFlagSet(test.flagSetName, ContinueOnError)
+               buf := bytes.NewBuffer(nil)
+               fs.SetOutput(buf)
+
+               var v flagVar
+               fs.Var(&v, "foo", "")
+
+               mustPanic(t, testName, test.errorMsg, func() {
+                       fs.Var(&v, "foo", "")
+               })
+               if msg := test.errorMsg + "\n"; msg != buf.String() {
+                       t.Errorf("%s\n: unexpected output: expected %q, bug got %q", testName, msg, buf)
+               }
+       }
+}