]> Cypherpunks repositories - gostls13.git/commitdiff
flag: add implicit boolFlag interface
authorRick Arnold <rickarnoldjr@gmail.com>
Sat, 22 Dec 2012 18:34:48 +0000 (13:34 -0500)
committerRuss Cox <rsc@golang.org>
Sat, 22 Dec 2012 18:34:48 +0000 (13:34 -0500)
Any flag.Value that has an IsBoolFlag method that returns true
will be treated as a bool flag type during parsing.

Fixes #4262.

R=bradfitz, rsc
CC=golang-dev
https://golang.org/cl/6944064

src/pkg/flag/flag.go
src/pkg/flag/flag_test.go

index bbabd88c8ca394b8b3da43e74f8478a6fb2da896..85dd8c3b37ac06e5c2eac2a1aca2bff617990557 100644 (file)
@@ -91,6 +91,15 @@ func (b *boolValue) Set(s string) error {
 
 func (b *boolValue) String() string { return fmt.Sprintf("%v", *b) }
 
+func (b *boolValue) IsBoolFlag() bool { return true }
+
+// optional interface to indicate boolean flags that can be
+// supplied without "=value" text
+type boolFlag interface {
+       Value
+       IsBoolFlag() bool
+}
+
 // -- int Value
 type intValue int
 
@@ -204,6 +213,10 @@ func (d *durationValue) String() string { return (*time.Duration)(d).String() }
 
 // Value is the interface to the dynamic value stored in a flag.
 // (The default value is represented as a string.)
+//
+// If a Value has an IsBoolFlag() bool method returning true,
+// the command-line parser makes -name equivalent to -name=true
+// rather than using the next command-line argument.
 type Value interface {
        String() string
        Set(string) error
@@ -704,7 +717,7 @@ func (f *FlagSet) parseOne() (bool, error) {
                }
                return false, f.failf("flag provided but not defined: -%s", name)
        }
-       if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg
+       if fv, ok := flag.Value.(boolFlag); ok && fv.IsBoolFlag() { // special case: doesn't need an arg
                if has_value {
                        if err := fv.Set(value); err != nil {
                                return false, f.failf("invalid boolean value %q for  -%s: %v", value, name, err)
index a9561f269f91ff32bfc31f1f820b28f626407985..7a26fffd8d6c4c43b5cf18bef553913808222568 100644 (file)
@@ -208,6 +208,47 @@ func TestUserDefined(t *testing.T) {
        }
 }
 
+// Declare a user-defined boolean flag type.
+type boolFlagVar struct {
+       count int
+}
+
+func (b *boolFlagVar) String() string {
+       return fmt.Sprintf("%d", b.count)
+}
+
+func (b *boolFlagVar) Set(value string) error {
+       if value == "true" {
+               b.count++
+       }
+       return nil
+}
+
+func (b *boolFlagVar) IsBoolFlag() bool {
+       return b.count < 4
+}
+
+func TestUserDefinedBool(t *testing.T) {
+       var flags FlagSet
+       flags.Init("test", ContinueOnError)
+       var b boolFlagVar
+       var err error
+       flags.Var(&b, "b", "usage")
+       if err = flags.Parse([]string{"-b", "-b", "-b", "-b=true", "-b=false", "-b", "barg", "-b"}); err != nil {
+               if b.count < 4 {
+                       t.Error(err)
+               }
+       }
+
+       if b.count != 4 {
+               t.Errorf("want: %d; got: %d", 4, b.count)
+       }
+
+       if err == nil {
+               t.Error("expected error; got none")
+       }
+}
+
 func TestSetOutput(t *testing.T) {
        var flags FlagSet
        var buf bytes.Buffer