From 16328513bfb12d96e8f33fc37f816e1441027135 Mon Sep 17 00:00:00 2001 From: Carl Johnson Date: Thu, 27 Aug 2020 13:08:44 +0000 Subject: [PATCH] flag: add Func Fixes #39557 Change-Id: Ida578f7484335e8c6bf927255f75377eda63b563 GitHub-Last-Rev: b97294f7669c24011e5b093179d65636512a84cd GitHub-Pull-Request: golang/go#39880 Reviewed-on: https://go-review.googlesource.com/c/go/+/240014 Reviewed-by: Russ Cox Trust: Ian Lance Taylor --- src/flag/example_func_test.go | 41 ++++++++++++++++++++++++++++ src/flag/flag.go | 22 ++++++++++++++- src/flag/flag_test.go | 50 +++++++++++++++++++++++++++++++++-- 3 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 src/flag/example_func_test.go diff --git a/src/flag/example_func_test.go b/src/flag/example_func_test.go new file mode 100644 index 0000000000..7c30c5e713 --- /dev/null +++ b/src/flag/example_func_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flag_test + +import ( + "errors" + "flag" + "fmt" + "net" + "os" +) + +func ExampleFunc() { + fs := flag.NewFlagSet("ExampleFunc", flag.ContinueOnError) + fs.SetOutput(os.Stdout) + var ip net.IP + fs.Func("ip", "`IP address` to parse", func(s string) error { + ip = net.ParseIP(s) + if ip == nil { + return errors.New("could not parse IP") + } + return nil + }) + fs.Parse([]string{"-ip", "127.0.0.1"}) + fmt.Printf("{ip: %v, loopback: %t}\n\n", ip, ip.IsLoopback()) + + // 256 is not a valid IPv4 component + fs.Parse([]string{"-ip", "256.0.0.1"}) + fmt.Printf("{ip: %v, loopback: %t}\n\n", ip, ip.IsLoopback()) + + // Output: + // {ip: 127.0.0.1, loopback: true} + // + // invalid value "256.0.0.1" for flag -ip: could not parse IP + // Usage of ExampleFunc: + // -ip IP address + // IP address to parse + // {ip: , loopback: false} +} diff --git a/src/flag/flag.go b/src/flag/flag.go index 286bba6873..a8485f034f 100644 --- a/src/flag/flag.go +++ b/src/flag/flag.go @@ -278,6 +278,12 @@ func (d *durationValue) Get() interface{} { return time.Duration(*d) } func (d *durationValue) String() string { return (*time.Duration)(d).String() } +type funcValue func(string) error + +func (f funcValue) Set(s string) error { return f(s) } + +func (f funcValue) String() string { return "" } + // Value is the interface to the dynamic value stored in a flag. // (The default value is represented as a string.) // @@ -296,7 +302,7 @@ type Value interface { // Getter is an interface that allows the contents of a Value to be retrieved. // It wraps the Value interface, rather than being part of it, because it // appeared after Go 1 and its compatibility rules. All Value types provided -// by this package satisfy the Getter interface. +// by this package satisfy the Getter interface, except the type used by Func. type Getter interface { Value Get() interface{} @@ -830,6 +836,20 @@ func Duration(name string, value time.Duration, usage string) *time.Duration { return CommandLine.Duration(name, value, usage) } +// Func defines a flag with the specified name and usage string. +// Each time the flag is seen, fn is called with the value of the flag. +// If fn returns a non-nil error, it will be treated as a flag value parsing error. +func (f *FlagSet) Func(name, usage string, fn func(string) error) { + f.Var(funcValue(fn), name, usage) +} + +// Func defines a flag with the specified name and usage string. +// Each time the flag is seen, fn is called with the value of the flag. +// If fn returns a non-nil error, it will be treated as a flag value parsing error. +func Func(name, usage string, fn func(string) error) { + CommandLine.Func(name, usage, fn) +} + // Var defines a flag with the specified name and usage string. The type and // value of the flag are represented by the first argument, of type Value, which // typically holds a user-defined implementation of Value. For instance, the diff --git a/src/flag/flag_test.go b/src/flag/flag_test.go index a01a5e4cea..2793064511 100644 --- a/src/flag/flag_test.go +++ b/src/flag/flag_test.go @@ -38,6 +38,7 @@ func TestEverything(t *testing.T) { String("test_string", "0", "string value") Float64("test_float64", 0, "float64 value") Duration("test_duration", 0, "time.Duration value") + Func("test_func", "func value", func(string) error { return nil }) m := make(map[string]*Flag) desired := "0" @@ -52,6 +53,8 @@ func TestEverything(t *testing.T) { ok = true case f.Name == "test_duration" && f.Value.String() == desired+"s": ok = true + case f.Name == "test_func" && f.Value.String() == "": + ok = true } if !ok { t.Error("Visit: bad value", f.Value.String(), "for", f.Name) @@ -59,7 +62,7 @@ func TestEverything(t *testing.T) { } } VisitAll(visitor) - if len(m) != 8 { + if len(m) != 9 { t.Error("VisitAll misses some flags") for k, v := range m { t.Log(k, *v) @@ -82,9 +85,10 @@ func TestEverything(t *testing.T) { Set("test_string", "1") Set("test_float64", "1") Set("test_duration", "1s") + Set("test_func", "1") desired = "1" Visit(visitor) - if len(m) != 8 { + if len(m) != 9 { t.Error("Visit fails after set") for k, v := range m { t.Log(k, *v) @@ -257,6 +261,48 @@ func TestUserDefined(t *testing.T) { } } +func TestUserDefinedFunc(t *testing.T) { + var flags FlagSet + flags.Init("test", ContinueOnError) + var ss []string + flags.Func("v", "usage", func(s string) error { + ss = append(ss, s) + return nil + }) + if err := flags.Parse([]string{"-v", "1", "-v", "2", "-v=3"}); err != nil { + t.Error(err) + } + if len(ss) != 3 { + t.Fatal("expected 3 args; got ", len(ss)) + } + expect := "[1 2 3]" + if got := fmt.Sprint(ss); got != expect { + t.Errorf("expected value %q got %q", expect, got) + } + // test usage + var buf strings.Builder + flags.SetOutput(&buf) + flags.Parse([]string{"-h"}) + if usage := buf.String(); !strings.Contains(usage, "usage") { + t.Errorf("usage string not included: %q", usage) + } + // test Func error + flags = *NewFlagSet("test", ContinueOnError) + flags.Func("v", "usage", func(s string) error { + return fmt.Errorf("test error") + }) + // flag not set, so no error + if err := flags.Parse(nil); err != nil { + t.Error(err) + } + // flag set, expect error + if err := flags.Parse([]string{"-v", "1"}); err == nil { + t.Error("expected error; got none") + } else if errMsg := err.Error(); !strings.Contains(errMsg, "test error") { + t.Errorf(`error should contain "test error"; got %q`, errMsg) + } +} + func TestUserDefinedForCommandLine(t *testing.T) { const help = "HELP" var result string -- 2.48.1