]> Cypherpunks repositories - gostls13.git/commitdiff
flag: add (*FlagSet).Name, (*FlagSet).ErrorHandling, export (*FlagSet).Output
authorTim Cooper <tim.cooper@layeh.com>
Thu, 12 Oct 2017 19:51:01 +0000 (16:51 -0300)
committerIan Lance Taylor <iant@golang.org>
Tue, 31 Oct 2017 03:54:16 +0000 (03:54 +0000)
Allows code that operates on a FlagSet to know the name and error
handling behavior of the FlagSet without having to call FlagSet.Init.

Fixes #17628
Fixes #21888

Change-Id: Ib0fe4c8885f9ccdacf5a7fb761d5ecb23f3bb055
Reviewed-on: https://go-review.googlesource.com/70391
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/flag/flag.go
src/flag/flag_test.go

index fde7411f82ce1e418423964e42d7b75b60850e02..d638e49b42618f84dc706e99a286b60fa195c039 100644 (file)
@@ -308,13 +308,25 @@ func sortFlags(flags map[string]*Flag) []*Flag {
        return result
 }
 
-func (f *FlagSet) out() io.Writer {
+// Output returns the destination for usage and error messages. os.Stderr is returned if
+// output was not set or was set to nil.
+func (f *FlagSet) Output() io.Writer {
        if f.output == nil {
                return os.Stderr
        }
        return f.output
 }
 
+// Name returns the name of the flag set.
+func (f *FlagSet) Name() string {
+       return f.name
+}
+
+// ErrorHandling returns the error handling behavior of the flag set.
+func (f *FlagSet) ErrorHandling() ErrorHandling {
+       return f.errorHandling
+}
+
 // SetOutput sets the destination for usage and error messages.
 // If output is nil, os.Stderr is used.
 func (f *FlagSet) SetOutput(output io.Writer) {
@@ -474,7 +486,7 @@ func (f *FlagSet) PrintDefaults() {
                                s += fmt.Sprintf(" (default %v)", flag.DefValue)
                        }
                }
-               fmt.Fprint(f.out(), s, "\n")
+               fmt.Fprint(f.Output(), s, "\n")
        })
 }
 
@@ -504,9 +516,9 @@ func PrintDefaults() {
 // defaultUsage is the default function to print a usage message.
 func (f *FlagSet) defaultUsage() {
        if f.name == "" {
-               fmt.Fprintf(f.out(), "Usage:\n")
+               fmt.Fprintf(f.Output(), "Usage:\n")
        } else {
-               fmt.Fprintf(f.out(), "Usage of %s:\n", f.name)
+               fmt.Fprintf(f.Output(), "Usage of %s:\n", f.name)
        }
        f.PrintDefaults()
 }
@@ -525,7 +537,7 @@ func (f *FlagSet) defaultUsage() {
 // happens anyway as the command line's error handling strategy is set to
 // ExitOnError.
 var Usage = func() {
-       fmt.Fprintf(CommandLine.out(), "Usage of %s:\n", os.Args[0])
+       fmt.Fprintf(CommandLine.Output(), "Usage of %s:\n", os.Args[0])
        PrintDefaults()
 }
 
@@ -793,7 +805,7 @@ func (f *FlagSet) Var(value Value, name string, usage string) {
                } else {
                        msg = fmt.Sprintf("%s flag redefined: %s", f.name, name)
                }
-               fmt.Fprintln(f.out(), msg)
+               fmt.Fprintln(f.Output(), msg)
                panic(msg) // Happens only if flags are declared with identical names
        }
        if f.formal == nil {
@@ -816,7 +828,7 @@ func Var(value Value, name string, usage string) {
 // returns the error.
 func (f *FlagSet) failf(format string, a ...interface{}) error {
        err := fmt.Errorf(format, a...)
-       fmt.Fprintln(f.out(), err)
+       fmt.Fprintln(f.Output(), err)
        f.usage()
        return err
 }
index 4c6db96ba0673331e2325fd2bfbd7f13b4478ce7..67c409f29ba2f64f3ed147087884dfef34773cc8 100644 (file)
@@ -8,6 +8,7 @@ import (
        "bytes"
        . "flag"
        "fmt"
+       "io"
        "os"
        "sort"
        "strconv"
@@ -454,3 +455,36 @@ func TestUsageOutput(t *testing.T) {
                t.Errorf("output = %q; want %q", got, want)
        }
 }
+
+func TestGetters(t *testing.T) {
+       expectedName := "flag set"
+       expectedErrorHandling := ContinueOnError
+       expectedOutput := io.Writer(os.Stderr)
+       fs := NewFlagSet(expectedName, expectedErrorHandling)
+
+       if fs.Name() != expectedName {
+               t.Errorf("unexpected name: got %s, expected %s", fs.Name(), expectedName)
+       }
+       if fs.ErrorHandling() != expectedErrorHandling {
+               t.Errorf("unexpected ErrorHandling: got %d, expected %d", fs.ErrorHandling(), expectedErrorHandling)
+       }
+       if fs.Output() != expectedOutput {
+               t.Errorf("unexpected output: got %#v, expected %#v", fs.Output(), expectedOutput)
+       }
+
+       expectedName = "gopher"
+       expectedErrorHandling = ExitOnError
+       expectedOutput = os.Stdout
+       fs.Init(expectedName, expectedErrorHandling)
+       fs.SetOutput(expectedOutput)
+
+       if fs.Name() != expectedName {
+               t.Errorf("unexpected name: got %s, expected %s", fs.Name(), expectedName)
+       }
+       if fs.ErrorHandling() != expectedErrorHandling {
+               t.Errorf("unexpected ErrorHandling: got %d, expected %d", fs.ErrorHandling(), expectedErrorHandling)
+       }
+       if fs.Output() != expectedOutput {
+               t.Errorf("unexpected output: got %v, expected %v", fs.Output(), expectedOutput)
+       }
+}