]> Cypherpunks repositories - gostls13.git/commitdiff
make interface to the flags themselves more public.
authorRob Pike <r@golang.org>
Tue, 17 Feb 2009 03:43:15 +0000 (19:43 -0800)
committerRob Pike <r@golang.org>
Tue, 17 Feb 2009 03:43:15 +0000 (19:43 -0800)
add visitor functions to scan the flags.
add a way to set a flag.
add a flag test.

R=rsc
DELTA=169  (99 added, 19 deleted, 51 changed)
OCL=25076
CL=25078

src/lib/flag.go
src/lib/flag_test.go [new file with mode: 0644]

index 35e18f9e6612d84ffba035f2e83167096e3c0a77..541966d87d39f855130ba17092e9ad47a33c7500 100644 (file)
@@ -115,11 +115,13 @@ func newBoolValue(val bool, p *bool) *boolValue {
        return &boolValue(p)
 }
 
-func (b *boolValue) set(val bool) {
-       *b.p = val;
+func (b *boolValue) set(s string) bool {
+       v, ok  := atob(s);
+       *b.p = v;
+       return ok
 }
 
-func (b *boolValue) str() string {
+func (b *boolValue) String() string {
        return fmt.Sprintf("%v", *b.p)
 }
 
@@ -133,11 +135,13 @@ func newIntValue(val int, p *int) *intValue {
        return &intValue(p)
 }
 
-func (i *intValue) set(val int) {
-       *i.p = val;
+func (i *intValue) set(s string) bool {
+       v, ok  := atoi(s);
+       *i.p = int(v);
+       return ok
 }
 
-func (i *intValue) str() string {
+func (i *intValue) String() string {
        return fmt.Sprintf("%v", *i.p)
 }
 
@@ -151,11 +155,13 @@ func newInt64Value(val int64, p *int64) *int64Value {
        return &int64Value(p)
 }
 
-func (i *int64Value) set(val int64) {
-       *i.p = val;
+func (i *int64Value) set(s string) bool {
+       v, ok := atoi(s);
+       *i.p = v;
+       return ok;
 }
 
-func (i *int64Value) str() string {
+func (i *int64Value) String() string {
        return fmt.Sprintf("%v", *i.p)
 }
 
@@ -169,11 +175,13 @@ func newUintValue(val uint, p *uint) *uintValue {
        return &uintValue(p)
 }
 
-func (i *uintValue) set(val uint) {
-       *i.p = val
+func (i *uintValue) set(s string) bool {
+       v, ok := atoi(s);       // TODO(r): want unsigned
+       *i.p = uint(v);
+       return ok;
 }
 
-func (i *uintValue) str() string {
+func (i *uintValue) String() string {
        return fmt.Sprintf("%v", *i.p)
 }
 
@@ -187,11 +195,13 @@ func newUint64Value(val uint64, p *uint64) *uint64Value {
        return &uint64Value(p)
 }
 
-func (i *uint64Value) set(val uint64) {
-       *i.p = val;
+func (i *uint64Value) set(s string) bool {
+       v, ok := atoi(s);       // TODO(r): want unsigned
+       *i.p = uint64(v);
+       return ok;
 }
 
-func (i *uint64Value) str() string {
+func (i *uint64Value) String() string {
        return fmt.Sprintf("%v", *i.p)
 }
 
@@ -205,25 +215,27 @@ func newStringValue(val string, p *string) *stringValue {
        return &stringValue(p)
 }
 
-func (s *stringValue) set(val string) {
+func (s *stringValue) set(val string) bool {
        *s.p = val;
+       return true;
 }
 
-func (s *stringValue) str() string {
+func (s *stringValue) String() string {
        return fmt.Sprintf("%#q", *s.p)
 }
 
-// -- Value interface
-type _Value interface {
-       str() string;
+// -- FlagValue interface
+type FlagValue interface {
+       String() string;
+       set(string) bool;
 }
 
-// -- Flag structure (internal)
+// -- Flag structure
 type Flag struct {
-       name    string; // name as it appears on command line
-       usage   string; // help message
-       value   _Value; // value as set
-       defvalue        string; // default value (as text); for usage message
+       Name    string; // name as it appears on command line
+       Usage   string; // help message
+       Value   FlagValue;      // value as set
+       DefValue        string; // default value (as text); for usage message
 }
 
 type allFlags struct {
@@ -234,10 +246,45 @@ type allFlags struct {
 
 var flags *allFlags = &allFlags(make(map[string] *Flag), make(map[string] *Flag), 1)
 
-func PrintDefaults() {
+// Visit all flags, including those defined but not set.
+func VisitAll(fn func(*Flag)) {
        for k, f := range flags.formal {
-               print("  -", f.name, "=", f.defvalue, ": ", f.usage, "\n");
+               fn(f)
+       }
+}
+
+// Visit only those flags that have been set
+func Visit(fn func(*Flag)) {
+       for k, f := range flags.actual {
+               fn(f)
+       }
+}
+
+func Lookup(name string) *Flag {
+       f, ok := flags.formal[name];
+       if !ok {
+               return nil
+       }
+       return f
+}
+
+func Set(name, value string) bool {
+       f, ok := flags.formal[name];
+       if !ok {
+               return false
+       }
+       ok = f.Value.set(value);
+       if !ok {
+               return false
        }
+       flags.actual[name] = f;
+       return true;
+}
+
+func PrintDefaults() {
+       VisitAll(func(f *Flag) {
+               print("  -", f.Name, "=", f.DefValue, ": ", f.Usage, "\n");
+       })
 }
 
 func Usage() {
@@ -266,12 +313,9 @@ func NArg() int {
        return len(sys.Args) - flags.first_arg
 }
 
-func add(name string, value _Value, usage string) {
-       f := new(Flag);
-       f.name = name;
-       f.usage = usage;
-       f.value = value;
-       f.defvalue = value.str();       // Remember the default value as a string; it won't change.
+func add(name string, value FlagValue, usage string) {
+       // Remember the default value as a string; it won't change.
+       f := &Flag(name, usage, value, value.String());
        dummy, alreadythere := flags.formal[name];
        if alreadythere {
                print("flag redefined: ", name, "\n");
@@ -388,16 +432,14 @@ func (f *allFlags) parseOne(index int) (ok bool, next int)
                print("flag provided but not defined: -", name, "\n");
                Usage();
        }
-       if f, ok := flag.value.(*boolValue); ok {
+       if f, ok := flag.Value.(*boolValue); ok {       // special case: doesn't need an arg
                if has_value {
-                       k, ok := atob(value);
-                       if !ok {
+                       if !f.set(value) {
                                print("invalid boolean value ", value, " for flag: -", name, "\n");
                                Usage();
                        }
-                       f.set(k)
                } else {
-                       f.set(true)
+                       f.set("true")
                }
        } else {
                // It must have a value, which might be the next argument.
@@ -411,24 +453,10 @@ func (f *allFlags) parseOne(index int) (ok bool, next int)
                        print("flag needs an argument: -", name, "\n");
                        Usage();
                }
-               if f, ok := flag.value.(*stringValue); ok {
-                       f.set(value)
-               } else {
-                       // It's an integer flag.  TODO(r): check for overflow?
-                       k, ok := atoi(value);
-                       if !ok {
-                               print("invalid integer value ", value, " for flag: -", name, "\n");
+               ok = flag.Value.set(value);
+               if !ok {
+                       print("invalid value ", value, " for flag: -", name, "\n");
                                Usage();
-                       }
-                       if f, ok := flag.value.(*intValue); ok {
-                               f.set(int(k));
-                       } else if f, ok := flag.value.(*int64Value); ok {
-                               f.set(k);
-                       } else if f, ok := flag.value.(*uintValue); ok {
-                               f.set(uint(k));
-                       } else if f, ok := flag.value.(*uint64Value); ok {
-                               f.set(uint64(k));
-                       }
                }
        }
        flags.actual[name] = flag;
diff --git a/src/lib/flag_test.go b/src/lib/flag_test.go
new file mode 100644 (file)
index 0000000..1212cf8
--- /dev/null
@@ -0,0 +1,56 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package flag
+
+import (
+       "flag";
+       "fmt";
+       "testing";
+)
+
+var (
+       test_bool = flag.Bool("test_bool", true, "bool value");
+       test_int = flag.Int("test_int", 1, "int value");
+       test_int64 = flag.Int64("test_int64", 1, "int64 value");
+       test_uint = flag.Uint("test_uint", 1, "uint value");
+       test_uint64 = flag.Uint64("test_uint64", 1, "uint64 value");
+       test_string = flag.String("test_string", "1", "string value");
+)
+
+// Because this calls flag.Parse, it needs to be the only Test* function
+func TestEverything(t *testing.T) {
+       flag.Parse();
+       m := make(map[string] *flag.Flag);
+       visitor := func(f *flag.Flag) {
+               if len(f.Name) > 5 && f.Name[0:5] == "test_" {
+                       m[f.Name] = f
+               }
+       };
+       flag.VisitAll(visitor);
+       if len(m) != 6 {
+               t.Error("flag.VisitAll misses some flags");
+               for k, v := range m {
+                       t.Log(k, *v)
+               }
+       }
+       m = make(map[string] *flag.Flag);
+       flag.Visit(visitor);
+       if len(m) != 0 {
+               t.Errorf("flag.Visit sees unset flags");
+               for k, v := range m {
+                       t.Log(k, *v)
+               }
+       }
+       // Now set some flags
+       flag.Set("test_bool", "false");
+       flag.Set("test_uint", "1234");
+       flag.Visit(visitor);
+       if len(m) != 2 {
+               t.Error("flag.Visit fails after set");
+               for k, v := range m {
+                       t.Log(k, *v)
+               }
+       }
+}