]> Cypherpunks repositories - gostls13.git/commitdiff
gofmt: minor refactor to permit easy testing
authorRobert Griesemer <gri@golang.org>
Wed, 13 Apr 2011 20:57:53 +0000 (13:57 -0700)
committerRobert Griesemer <gri@golang.org>
Wed, 13 Apr 2011 20:57:53 +0000 (13:57 -0700)
R=rsc
CC=golang-dev
https://golang.org/cl/4397046

src/cmd/gofmt/gofmt.go

index 1ed5b9b993138f140d29b9914a1dc7ea13e3cdfb..ce274aa21b16c7a824efc0f7356ccb045abed328 100644 (file)
@@ -13,6 +13,7 @@ import (
        "go/printer"
        "go/scanner"
        "go/token"
+       "io"
        "io/ioutil"
        "os"
        "path/filepath"
@@ -86,14 +87,23 @@ func isGoFile(f *os.FileInfo) bool {
 }
 
 
-func processFile(f *os.File) os.Error {
-       src, err := ioutil.ReadAll(f)
+// If in == nil, the source is the contents of the file with the given filename.
+func processFile(filename string, in io.Reader, out io.Writer) os.Error {
+       if in == nil {
+               f, err := os.Open(filename)
+               if err != nil {
+                       return err
+               }
+               defer f.Close()
+               in = f
+       }
+
+       src, err := ioutil.ReadAll(in)
        if err != nil {
                return err
        }
 
-       file, err := parser.ParseFile(fset, f.Name(), src, parserMode)
-
+       file, err := parser.ParseFile(fset, filename, src, parserMode)
        if err != nil {
                return err
        }
@@ -116,10 +126,10 @@ func processFile(f *os.File) os.Error {
        if !bytes.Equal(src, res) {
                // formatting has changed
                if *list {
-                       fmt.Fprintln(os.Stdout, f.Name())
+                       fmt.Fprintln(out, filename)
                }
                if *write {
-                       err = ioutil.WriteFile(f.Name(), res, 0)
+                       err = ioutil.WriteFile(filename, res, 0)
                        if err != nil {
                                return err
                        }
@@ -127,23 +137,13 @@ func processFile(f *os.File) os.Error {
        }
 
        if !*list && !*write {
-               _, err = os.Stdout.Write(res)
+               _, err = out.Write(res)
        }
 
        return err
 }
 
 
-func processFileByName(filename string) os.Error {
-       file, err := os.Open(filename)
-       if err != nil {
-               return err
-       }
-       defer file.Close()
-       return processFile(file)
-}
-
-
 type fileVisitor chan os.Error
 
 func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool {
@@ -154,7 +154,7 @@ func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool {
 func (v fileVisitor) VisitFile(path string, f *os.FileInfo) {
        if isGoFile(f) {
                v <- nil // synchronize error handler
-               if err := processFileByName(path); err != nil {
+               if err := processFile(path, nil, os.Stdout); err != nil {
                        v <- err
                }
        }
@@ -210,9 +210,10 @@ func gofmtMain() {
        initRewrite()
 
        if flag.NArg() == 0 {
-               if err := processFile(os.Stdin); err != nil {
+               if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil {
                        report(err)
                }
+               return
        }
 
        for i := 0; i < flag.NArg(); i++ {
@@ -221,7 +222,7 @@ func gofmtMain() {
                case err != nil:
                        report(err)
                case dir.IsRegular():
-                       if err := processFileByName(path); err != nil {
+                       if err := processFile(path, nil, os.Stdout); err != nil {
                                report(err)
                        }
                case dir.IsDirectory():