]> Cypherpunks repositories - gostls13.git/commitdiff
flag: allow a FlagSet to not write to os.Stderr
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 27 Jan 2012 17:23:06 +0000 (09:23 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 27 Jan 2012 17:23:06 +0000 (09:23 -0800)
Fixes #2747

R=golang-dev, gri, r, rogpeppe, r
CC=golang-dev
https://golang.org/cl/5564065

src/pkg/flag/flag.go
src/pkg/flag/flag_test.go

index 964f5541b86c2f53d76c36aee21d3e355562b7ee..1719af89a1237dfb78338aeb0b433debb74152e6 100644 (file)
@@ -62,6 +62,7 @@ package flag
 import (
        "errors"
        "fmt"
+       "io"
        "os"
        "sort"
        "strconv"
@@ -228,6 +229,7 @@ type FlagSet struct {
        args          []string // arguments after flags
        exitOnError   bool     // does the program exit if there's an error?
        errorHandling ErrorHandling
+       output        io.Writer // nil means stderr; use out() accessor
 }
 
 // A Flag represents the state of a flag.
@@ -254,6 +256,19 @@ func sortFlags(flags map[string]*Flag) []*Flag {
        return result
 }
 
+func (f *FlagSet) out() io.Writer {
+       if f.output == nil {
+               return os.Stderr
+       }
+       return f.output
+}
+
+// SetOutput sets the destination for usage and error messages.
+// If output is nil, os.Stderr is used.
+func (f *FlagSet) SetOutput(output io.Writer) {
+       f.output = output
+}
+
 // VisitAll visits the flags in lexicographical order, calling fn for each.
 // It visits all flags, even those not set.
 func (f *FlagSet) VisitAll(fn func(*Flag)) {
@@ -315,15 +330,16 @@ func Set(name, value string) error {
        return commandLine.Set(name, value)
 }
 
-// PrintDefaults prints to standard error the default values of all defined flags in the set.
+// PrintDefaults prints, to standard error unless configured
+// otherwise, the default values of all defined flags in the set.
 func (f *FlagSet) PrintDefaults() {
-       f.VisitAll(func(f *Flag) {
+       f.VisitAll(func(flag *Flag) {
                format := "  -%s=%s: %s\n"
-               if _, ok := f.Value.(*stringValue); ok {
+               if _, ok := flag.Value.(*stringValue); ok {
                        // put quotes on the value
                        format = "  -%s=%q: %s\n"
                }
-               fmt.Fprintf(os.Stderr, format, f.Name, f.DefValue, f.Usage)
+               fmt.Fprintf(f.out(), format, flag.Name, flag.DefValue, flag.Usage)
        })
 }
 
@@ -334,7 +350,7 @@ func PrintDefaults() {
 
 // defaultUsage is the default function to print a usage message.
 func defaultUsage(f *FlagSet) {
-       fmt.Fprintf(os.Stderr, "Usage of %s:\n", f.name)
+       fmt.Fprintf(f.out(), "Usage of %s:\n", f.name)
        f.PrintDefaults()
 }
 
@@ -601,7 +617,7 @@ func (f *FlagSet) Var(value Value, name string, usage string) {
        flag := &Flag{name, usage, value, value.String()}
        _, alreadythere := f.formal[name]
        if alreadythere {
-               fmt.Fprintf(os.Stderr, "%s flag redefined: %s\n", f.name, name)
+               fmt.Fprintf(f.out(), "%s flag redefined: %s\n", f.name, name)
                panic("flag redefinition") // Happens only if flags are declared with identical names
        }
        if f.formal == nil {
@@ -624,7 +640,7 @@ func Var(value Value, name string, usage string) {
 // returns the error.
 func (f *FlagSet) failf(format string, a ...interface{}) error {
        err := fmt.Errorf(format, a...)
-       fmt.Fprintln(os.Stderr, err)
+       fmt.Fprintln(f.out(), err)
        f.usage()
        return err
 }
index 698c15f2c58f681c34befbf4481411af0211a8a7..a9561f269f91ff32bfc31f1f820b28f626407985 100644 (file)
@@ -5,10 +5,12 @@
 package flag_test
 
 import (
+       "bytes"
        . "flag"
        "fmt"
        "os"
        "sort"
+       "strings"
        "testing"
        "time"
 )
@@ -206,6 +208,17 @@ func TestUserDefined(t *testing.T) {
        }
 }
 
+func TestSetOutput(t *testing.T) {
+       var flags FlagSet
+       var buf bytes.Buffer
+       flags.SetOutput(&buf)
+       flags.Init("test", ContinueOnError)
+       flags.Parse([]string{"-unknown"})
+       if out := buf.String(); !strings.Contains(out, "-unknown") {
+               t.Logf("expected output mentioning unknown; got %q", out)
+       }
+}
+
 // This tests that one can reset the flags. This still works but not well, and is
 // superseded by FlagSet.
 func TestChangingArgs(t *testing.T) {