From 6431b984db2cf902e344285173cad793512923da Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Tue, 6 Apr 2010 16:46:52 -0700 Subject: [PATCH] flags: better tests. R=rsc CC=golang-dev https://golang.org/cl/864044 --- src/pkg/flag/flag.go | 56 ++++++++++++++++----- src/pkg/flag/flag_test.go | 100 +++++++++++++++++++++++++++++++++++--- 2 files changed, 138 insertions(+), 18 deletions(-) diff --git a/src/pkg/flag/flag.go b/src/pkg/flag/flag.go index a0cb4f5cae..9457e9bfc5 100644 --- a/src/pkg/flag/flag.go +++ b/src/pkg/flag/flag.go @@ -221,7 +221,7 @@ type allFlags struct { first_arg int // 0 is the program name, 1 is first arg } -var flags *allFlags = &allFlags{make(map[string]*Flag), make(map[string]*Flag), 1} +var flags *allFlags // VisitAll visits the flags, calling fn for each. It visits all flags, even those not set. func VisitAll(fn func(*Flag)) { @@ -276,6 +276,16 @@ var Usage = func() { PrintDefaults() } +var panicOnError = false + +func fail() { + Usage() + if panicOnError { + panic("flag parse error") + } + os.Exit(2) +} + func NFlag() int { return len(flags.actual) } // Arg returns the i'th command-line argument. Arg(0) is the first remaining argument @@ -442,8 +452,7 @@ func (f *allFlags) parseOne(index int) (ok bool, next int) { name := s[num_minuses:] if len(name) == 0 || name[0] == '-' || name[0] == '=' { fmt.Fprintln(os.Stderr, "bad flag syntax:", s) - Usage() - os.Exit(2) + fail() } // it's a flag. does it have an argument? @@ -461,15 +470,13 @@ func (f *allFlags) parseOne(index int) (ok bool, next int) { flag, alreadythere := m[name] // BUG if !alreadythere { fmt.Fprintf(os.Stderr, "flag provided but not defined: -%s\n", name) - Usage() - os.Exit(2) + fail() } if f, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg if has_value { if !f.Set(value) { fmt.Fprintf(os.Stderr, "invalid boolean value %t for flag: -%s\n", value, name) - Usage() - os.Exit(2) + fail() } } else { f.Set("true") @@ -484,14 +491,12 @@ func (f *allFlags) parseOne(index int) (ok bool, next int) { } if !has_value { fmt.Fprintf(os.Stderr, "flag needs an argument: -%s\n", name) - Usage() - os.Exit(2) + fail() } ok = flag.Value.Set(value) if !ok { fmt.Fprintf(os.Stderr, "invalid value %s for flag: -%s\n", value, name) - Usage() - os.Exit(2) + fail() } } flags.actual[name] = flag @@ -512,3 +517,32 @@ func Parse() { } } } + +// ResetForTesting clears all flag state and sets the usage function as directed. +// After calling ResetForTesting, parse errors in flag handling will panic rather +// than exit the program. +// For testing only! +func ResetForTesting(usage func()) { + flags = &allFlags{make(map[string]*Flag), make(map[string]*Flag), 1} + Usage = usage + panicOnError = true +} + +// ParseForTesting parses the flag state using the provided arguments. It +// should be called after 1) ResetForTesting and 2) setting up the new flags. +// The return value reports whether the parse was error-free. +// For testing only! +func ParseForTesting(args []string) (result bool) { + defer func() { + if recover() != nil { + result = false + } + }() + os.Args = args + Parse() + return true +} + +func init() { + flags = &allFlags{make(map[string]*Flag), make(map[string]*Flag), 1} +} diff --git a/src/pkg/flag/flag_test.go b/src/pkg/flag/flag_test.go index 03e8a3e229..83bf7eebf8 100644 --- a/src/pkg/flag/flag_test.go +++ b/src/pkg/flag/flag_test.go @@ -6,6 +6,7 @@ package flag_test import ( . "flag" + "fmt" "testing" ) @@ -79,21 +80,106 @@ func TestEverything(t *testing.T) { } } +func TestUsage(t *testing.T) { + called := false + ResetForTesting(func() { called = true }) + if ParseForTesting([]string{"a.out", "-x"}) { + t.Error("parse did not fail for unknown flag") + } + if !called { + t.Error("did not call Usage for unknown flag") + } +} + +func TestParse(t *testing.T) { + ResetForTesting(func() { t.Error("bad parse") }) + boolFlag := Bool("bool", false, "bool value") + bool2Flag := Bool("bool2", false, "bool2 value") + intFlag := Int("int", 0, "int value") + int64Flag := Int64("int64", 0, "int64 value") + uintFlag := Uint("uint", 0, "uint value") + uint64Flag := Uint64("uint64", 0, "uint64 value") + stringFlag := String("string", "0", "string value") + floatFlag := Float("float", 0, "float value") + float64Flag := Float("float64", 0, "float64 value") + extra := "one-extra-argument" + args := []string{ + "a.out", + "-bool", + "-bool2=true", + "--int", "22", + "--int64", "23", + "-uint", "24", + "--uint64", "25", + "-string", "hello", + "--float", "3141.5", + "-float64", "2718e28", + extra, + } + if !ParseForTesting(args) { + t.Fatal("parse failed") + } + if *boolFlag != true { + t.Error("bool flag should be true, is ", *boolFlag) + } + if *bool2Flag != true { + t.Error("bool2 flag should be true, is ", *bool2Flag) + } + if *intFlag != 22 { + t.Error("int flag should be 22, is ", *intFlag) + } + if *int64Flag != 23 { + t.Error("int64 flag should be 23, is ", *int64Flag) + } + if *uintFlag != 24 { + t.Error("uint flag should be 24, is ", *uintFlag) + } + if *uint64Flag != 25 { + t.Error("uint64 flag should be 25, is ", *uint64Flag) + } + if *stringFlag != "hello" { + t.Error("string flag should be `hello`, is ", *stringFlag) + } + if *floatFlag != 3141.5 { + t.Error("float flag should be 3141.5, is ", *floatFlag) + } + if *float64Flag != 2718e28 { + t.Error("float64 flag should be 2718e28, is ", *float64Flag) + } + if len(Args()) != 1 { + t.Error("expected one argument, got", len(Args())) + } else if Args()[0] != extra { + t.Errorf("expected argument %q got %q", extra, Args()[0]) + } +} + // Declare a user-defined flag. -// TODO: do the work to make this test better by resetting flag state -// and manipulating os.Args. type flagVar []string func (f *flagVar) String() string { - return "foo" + return fmt.Sprint([]string(*f)) } func (f *flagVar) Set(value string) bool { + n := make(flagVar, len(*f)+1) + copy(n, *f) + *f = n + (*f)[len(*f)-1] = value return true } -var v flagVar - -func init() { - Var(&v, "testV", "usage") +func TestUserDefined(t *testing.T) { + ResetForTesting(func() { t.Fatal("bad parse") }) + var v flagVar + Var(&v, "v", "usage") + if !ParseForTesting([]string{"a.out", "-v", "1", "-v", "2", "-v=3"}) { + t.Error("parse failed") + } + if len(v) != 3 { + t.Fatal("expected 3 args; got ", len(v)) + } + expect := "[1 2 3]" + if v.String() != expect { + t.Errorf("expected value %q got %q", expect, v.String()) + } } -- 2.50.0