]> Cypherpunks repositories - gostls13.git/commitdiff
flags: better tests.
authorRob Pike <r@golang.org>
Tue, 6 Apr 2010 23:46:52 +0000 (16:46 -0700)
committerRob Pike <r@golang.org>
Tue, 6 Apr 2010 23:46:52 +0000 (16:46 -0700)
R=rsc
CC=golang-dev
https://golang.org/cl/864044

src/pkg/flag/flag.go
src/pkg/flag/flag_test.go

index a0cb4f5cae9d5d4896531f69503d338ba8454ebd..9457e9bfc58755d2bab4143245eb7d686aa88207 100644 (file)
@@ -221,7 +221,7 @@ type allFlags struct {
        first_arg int // 0 is the program name, 1 is first arg
 }
 
-var flags *allFlags = &allFlags{make(map[string]*Flag), make(map[string]*Flag), 1}
+var flags *allFlags
 
 // VisitAll visits the flags, calling fn for each. It visits all flags, even those not set.
 func VisitAll(fn func(*Flag)) {
@@ -276,6 +276,16 @@ var Usage = func() {
        PrintDefaults()
 }
 
+var panicOnError = false
+
+func fail() {
+       Usage()
+       if panicOnError {
+               panic("flag parse error")
+       }
+       os.Exit(2)
+}
+
 func NFlag() int { return len(flags.actual) }
 
 // Arg returns the i'th command-line argument.  Arg(0) is the first remaining argument
@@ -442,8 +452,7 @@ func (f *allFlags) parseOne(index int) (ok bool, next int) {
        name := s[num_minuses:]
        if len(name) == 0 || name[0] == '-' || name[0] == '=' {
                fmt.Fprintln(os.Stderr, "bad flag syntax:", s)
-               Usage()
-               os.Exit(2)
+               fail()
        }
 
        // it's a flag. does it have an argument?
@@ -461,15 +470,13 @@ func (f *allFlags) parseOne(index int) (ok bool, next int) {
        flag, alreadythere := m[name] // BUG
        if !alreadythere {
                fmt.Fprintf(os.Stderr, "flag provided but not defined: -%s\n", name)
-               Usage()
-               os.Exit(2)
+               fail()
        }
        if f, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg
                if has_value {
                        if !f.Set(value) {
                                fmt.Fprintf(os.Stderr, "invalid boolean value %t for flag: -%s\n", value, name)
-                               Usage()
-                               os.Exit(2)
+                               fail()
                        }
                } else {
                        f.Set("true")
@@ -484,14 +491,12 @@ func (f *allFlags) parseOne(index int) (ok bool, next int) {
                }
                if !has_value {
                        fmt.Fprintf(os.Stderr, "flag needs an argument: -%s\n", name)
-                       Usage()
-                       os.Exit(2)
+                       fail()
                }
                ok = flag.Value.Set(value)
                if !ok {
                        fmt.Fprintf(os.Stderr, "invalid value %s for flag: -%s\n", value, name)
-                       Usage()
-                       os.Exit(2)
+                       fail()
                }
        }
        flags.actual[name] = flag
@@ -512,3 +517,32 @@ func Parse() {
                }
        }
 }
+
+// ResetForTesting clears all flag state and sets the usage function as directed.
+// After calling ResetForTesting, parse errors in flag handling will panic rather
+// than exit the program.
+// For testing only!
+func ResetForTesting(usage func()) {
+       flags = &allFlags{make(map[string]*Flag), make(map[string]*Flag), 1}
+       Usage = usage
+       panicOnError = true
+}
+
+// ParseForTesting parses the flag state using the provided arguments. It
+// should be called after 1) ResetForTesting and 2) setting up the new flags.
+// The return value reports whether the parse was error-free.
+// For testing only!
+func ParseForTesting(args []string) (result bool) {
+       defer func() {
+               if recover() != nil {
+                       result = false
+               }
+       }()
+       os.Args = args
+       Parse()
+       return true
+}
+
+func init() {
+       flags = &allFlags{make(map[string]*Flag), make(map[string]*Flag), 1}
+}
index 03e8a3e229eaecf6e29d22f7b144687739a41f0f..83bf7eebf85c9d3b7a939abb871a4504604de81d 100644 (file)
@@ -6,6 +6,7 @@ package flag_test
 
 import (
        . "flag"
+       "fmt"
        "testing"
 )
 
@@ -79,21 +80,106 @@ func TestEverything(t *testing.T) {
        }
 }
 
+func TestUsage(t *testing.T) {
+       called := false
+       ResetForTesting(func() { called = true })
+       if ParseForTesting([]string{"a.out", "-x"}) {
+               t.Error("parse did not fail for unknown flag")
+       }
+       if !called {
+               t.Error("did not call Usage for unknown flag")
+       }
+}
+
+func TestParse(t *testing.T) {
+       ResetForTesting(func() { t.Error("bad parse") })
+       boolFlag := Bool("bool", false, "bool value")
+       bool2Flag := Bool("bool2", false, "bool2 value")
+       intFlag := Int("int", 0, "int value")
+       int64Flag := Int64("int64", 0, "int64 value")
+       uintFlag := Uint("uint", 0, "uint value")
+       uint64Flag := Uint64("uint64", 0, "uint64 value")
+       stringFlag := String("string", "0", "string value")
+       floatFlag := Float("float", 0, "float value")
+       float64Flag := Float("float64", 0, "float64 value")
+       extra := "one-extra-argument"
+       args := []string{
+               "a.out",
+               "-bool",
+               "-bool2=true",
+               "--int", "22",
+               "--int64", "23",
+               "-uint", "24",
+               "--uint64", "25",
+               "-string", "hello",
+               "--float", "3141.5",
+               "-float64", "2718e28",
+               extra,
+       }
+       if !ParseForTesting(args) {
+               t.Fatal("parse failed")
+       }
+       if *boolFlag != true {
+               t.Error("bool flag should be true, is ", *boolFlag)
+       }
+       if *bool2Flag != true {
+               t.Error("bool2 flag should be true, is ", *bool2Flag)
+       }
+       if *intFlag != 22 {
+               t.Error("int flag should be 22, is ", *intFlag)
+       }
+       if *int64Flag != 23 {
+               t.Error("int64 flag should be 23, is ", *int64Flag)
+       }
+       if *uintFlag != 24 {
+               t.Error("uint flag should be 24, is ", *uintFlag)
+       }
+       if *uint64Flag != 25 {
+               t.Error("uint64 flag should be 25, is ", *uint64Flag)
+       }
+       if *stringFlag != "hello" {
+               t.Error("string flag should be `hello`, is ", *stringFlag)
+       }
+       if *floatFlag != 3141.5 {
+               t.Error("float flag should be 3141.5, is ", *floatFlag)
+       }
+       if *float64Flag != 2718e28 {
+               t.Error("float64 flag should be 2718e28, is ", *float64Flag)
+       }
+       if len(Args()) != 1 {
+               t.Error("expected one argument, got", len(Args()))
+       } else if Args()[0] != extra {
+               t.Errorf("expected argument %q got %q", extra, Args()[0])
+       }
+}
+
 // Declare a user-defined flag.
-// TODO: do the work to make this test better by resetting flag state
-// and manipulating os.Args.
 type flagVar []string
 
 func (f *flagVar) String() string {
-       return "foo"
+       return fmt.Sprint([]string(*f))
 }
 
 func (f *flagVar) Set(value string) bool {
+       n := make(flagVar, len(*f)+1)
+       copy(n, *f)
+       *f = n
+       (*f)[len(*f)-1] = value
        return true
 }
 
-var v flagVar
-
-func init() {
-       Var(&v, "testV", "usage")
+func TestUserDefined(t *testing.T) {
+       ResetForTesting(func() { t.Fatal("bad parse") })
+       var v flagVar
+       Var(&v, "v", "usage")
+       if !ParseForTesting([]string{"a.out", "-v", "1", "-v", "2", "-v=3"}) {
+               t.Error("parse failed")
+       }
+       if len(v) != 3 {
+               t.Fatal("expected 3 args; got ", len(v))
+       }
+       expect := "[1 2 3]"
+       if v.String() != expect {
+               t.Errorf("expected value %q got %q", expect, v.String())
+       }
 }