]> Cypherpunks repositories - gostls13.git/commitdiff
flag: add Duration flag type.
authorDavid Symonds <dsymonds@golang.org>
Fri, 23 Dec 2011 05:29:38 +0000 (16:29 +1100)
committerDavid Symonds <dsymonds@golang.org>
Fri, 23 Dec 2011 05:29:38 +0000 (16:29 +1100)
This works in the expected way: flag.Duration returns a *time.Duration,
and uses time.ParseDuration for parsing the input.

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/5489113

src/pkg/flag/flag.go
src/pkg/flag/flag_test.go

index 406ea77799d4358f47ac5183a85073b95a78d50f..86e34eec125d35dcf43df757d692cd5cdc3e141d 100644 (file)
@@ -65,12 +65,13 @@ import (
        "os"
        "sort"
        "strconv"
+       "time"
 )
 
 // ErrHelp is the error returned if the flag -help is invoked but no such flag is defined.
 var ErrHelp = errors.New("flag: help requested")
 
-// -- Bool Value
+// -- bool Value
 type boolValue bool
 
 func newBoolValue(val bool, p *bool) *boolValue {
@@ -86,7 +87,7 @@ func (b *boolValue) Set(s string) bool {
 
 func (b *boolValue) String() string { return fmt.Sprintf("%v", *b) }
 
-// -- Int Value
+// -- int Value
 type intValue int
 
 func newIntValue(val int, p *int) *intValue {
@@ -102,7 +103,7 @@ func (i *intValue) Set(s string) bool {
 
 func (i *intValue) String() string { return fmt.Sprintf("%v", *i) }
 
-// -- Int64 Value
+// -- int64 Value
 type int64Value int64
 
 func newInt64Value(val int64, p *int64) *int64Value {
@@ -118,7 +119,7 @@ func (i *int64Value) Set(s string) bool {
 
 func (i *int64Value) String() string { return fmt.Sprintf("%v", *i) }
 
-// -- Uint Value
+// -- uint Value
 type uintValue uint
 
 func newUintValue(val uint, p *uint) *uintValue {
@@ -165,7 +166,7 @@ func (s *stringValue) Set(val string) bool {
 
 func (s *stringValue) String() string { return fmt.Sprintf("%s", *s) }
 
-// -- Float64 Value
+// -- float64 Value
 type float64Value float64
 
 func newFloat64Value(val float64, p *float64) *float64Value {
@@ -181,6 +182,22 @@ func (f *float64Value) Set(s string) bool {
 
 func (f *float64Value) String() string { return fmt.Sprintf("%v", *f) }
 
+// -- time.Duration Value
+type durationValue time.Duration
+
+func newDurationValue(val time.Duration, p *time.Duration) *durationValue {
+       *p = val
+       return (*durationValue)(p)
+}
+
+func (d *durationValue) Set(s string) bool {
+       v, err := time.ParseDuration(s)
+       *d = durationValue(v)
+       return err == nil
+}
+
+func (d *durationValue) String() string { return (*time.Duration)(d).String() }
+
 // Value is the interface to the dynamic value stored in a flag.
 // (The default value is represented as a string.)
 type Value interface {
@@ -543,12 +560,38 @@ func (f *FlagSet) Float64(name string, value float64, usage string) *float64 {
        return p
 }
 
-// Float64 defines an int flag with specified name, default value, and usage string.
+// Float64 defines a float64 flag with specified name, default value, and usage string.
 // The return value is the address of a float64 variable that stores the value of the flag.
 func Float64(name string, value float64, usage string) *float64 {
        return commandLine.Float64(name, value, usage)
 }
 
+// DurationVar defines a time.Duration flag with specified name, default value, and usage string.
+// The argument p points to a time.Duration variable in which to store the value of the flag.
+func (f *FlagSet) DurationVar(p *time.Duration, name string, value time.Duration, usage string) {
+       f.Var(newDurationValue(value, p), name, usage)
+}
+
+// DurationVar defines a time.Duration flag with specified name, default value, and usage string.
+// The argument p points to a time.Duration variable in which to store the value of the flag.
+func DurationVar(p *time.Duration, name string, value time.Duration, usage string) {
+       commandLine.Var(newDurationValue(value, p), name, usage)
+}
+
+// Duration defines a time.Duration flag with specified name, default value, and usage string.
+// The return value is the address of a time.Duration variable that stores the value of the flag.
+func (f *FlagSet) Duration(name string, value time.Duration, usage string) *time.Duration {
+       p := new(time.Duration)
+       f.DurationVar(p, name, value, usage)
+       return p
+}
+
+// Duration defines a time.Duration flag with specified name, default value, and usage string.
+// The return value is the address of a time.Duration variable that stores the value of the flag.
+func Duration(name string, value time.Duration, usage string) *time.Duration {
+       return commandLine.Duration(name, value, usage)
+}
+
 // Var defines a flag with the specified name and usage string. The type and
 // value of the flag are represented by the first argument, of type Value, which
 // typically holds a user-defined implementation of Value. For instance, the
index f13531669c1914dcd9ae9027c042ac01e4b76df2..06a7b5c195a887a6d678fd930493d07977a97301 100644 (file)
@@ -10,16 +10,18 @@ import (
        "os"
        "sort"
        "testing"
+       "time"
 )
 
 var (
-       test_bool    = Bool("test_bool", false, "bool value")
-       test_int     = Int("test_int", 0, "int value")
-       test_int64   = Int64("test_int64", 0, "int64 value")
-       test_uint    = Uint("test_uint", 0, "uint value")
-       test_uint64  = Uint64("test_uint64", 0, "uint64 value")
-       test_string  = String("test_string", "0", "string value")
-       test_float64 = Float64("test_float64", 0, "float64 value")
+       test_bool     = Bool("test_bool", false, "bool value")
+       test_int      = Int("test_int", 0, "int value")
+       test_int64    = Int64("test_int64", 0, "int64 value")
+       test_uint     = Uint("test_uint", 0, "uint value")
+       test_uint64   = Uint64("test_uint64", 0, "uint64 value")
+       test_string   = String("test_string", "0", "string value")
+       test_float64  = Float64("test_float64", 0, "float64 value")
+       test_duration = Duration("test_duration", 0, "time.Duration value")
 )
 
 func boolString(s string) string {
@@ -41,6 +43,8 @@ func TestEverything(t *testing.T) {
                                ok = true
                        case f.Name == "test_bool" && f.Value.String() == boolString(desired):
                                ok = true
+                       case f.Name == "test_duration" && f.Value.String() == desired+"s":
+                               ok = true
                        }
                        if !ok {
                                t.Error("Visit: bad value", f.Value.String(), "for", f.Name)
@@ -48,7 +52,7 @@ func TestEverything(t *testing.T) {
                }
        }
        VisitAll(visitor)
-       if len(m) != 7 {
+       if len(m) != 8 {
                t.Error("VisitAll misses some flags")
                for k, v := range m {
                        t.Log(k, *v)
@@ -70,9 +74,10 @@ func TestEverything(t *testing.T) {
        Set("test_uint64", "1")
        Set("test_string", "1")
        Set("test_float64", "1")
+       Set("test_duration", "1s")
        desired = "1"
        Visit(visitor)
-       if len(m) != 7 {
+       if len(m) != 8 {
                t.Error("Visit fails after set")
                for k, v := range m {
                        t.Log(k, *v)
@@ -109,6 +114,7 @@ func testParse(f *FlagSet, t *testing.T) {
        uint64Flag := f.Uint64("uint64", 0, "uint64 value")
        stringFlag := f.String("string", "0", "string value")
        float64Flag := f.Float64("float64", 0, "float64 value")
+       durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value")
        extra := "one-extra-argument"
        args := []string{
                "-bool",
@@ -119,6 +125,7 @@ func testParse(f *FlagSet, t *testing.T) {
                "--uint64", "25",
                "-string", "hello",
                "-float64", "2718e28",
+               "-duration", "2m",
                extra,
        }
        if err := f.Parse(args); err != nil {
@@ -151,6 +158,9 @@ func testParse(f *FlagSet, t *testing.T) {
        if *float64Flag != 2718e28 {
                t.Error("float64 flag should be 2718e28, is ", *float64Flag)
        }
+       if *durationFlag != 2*time.Minute {
+               t.Error("duration flag should be 2m, is ", *durationFlag)
+       }
        if len(f.Args()) != 1 {
                t.Error("expected one argument, got", len(f.Args()))
        } else if f.Args()[0] != extra {