}
// If in == nil, the source is the contents of the file with the given filename.
-func processFile(filename string, in io.Reader, out io.Writer) os.Error {
+func processFile(filename string, in io.Reader, out io.Writer, stdin bool) os.Error {
if in == nil {
f, err := os.Open(filename)
if err != nil {
return err
}
- file, err := parser.ParseFile(fset, filename, src, parserMode)
+ file, adjust, err := parse(filename, src, stdin)
if err != nil {
return err
}
if err != nil {
return err
}
- res := buf.Bytes()
+ res := adjust(src, buf.Bytes())
if !bytes.Equal(src, res) {
// formatting has changed
func (v fileVisitor) VisitFile(path string, f *os.FileInfo) {
if isGoFile(f) {
v <- nil // synchronize error handler
- if err := processFile(path, nil, os.Stdout); err != nil {
+ if err := processFile(path, nil, os.Stdout, false); err != nil {
v <- err
}
}
initRewrite()
if flag.NArg() == 0 {
- if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil {
+ if err := processFile("<standard input>", os.Stdin, os.Stdout, true); err != nil {
report(err)
}
return
case err != nil:
report(err)
case dir.IsRegular():
- if err := processFile(path, nil, os.Stdout); err != nil {
+ if err := processFile(path, nil, os.Stdout, false); err != nil {
report(err)
}
case dir.IsDirectory():
return
}
+
+// parse parses src, which was read from filename,
+// as a Go source file or statement list.
+func parse(filename string, src []byte, stdin bool) (*ast.File, func(orig, src []byte) []byte, os.Error) {
+ // Try as whole source file.
+ file, err := parser.ParseFile(fset, filename, src, parserMode)
+ if err == nil {
+ adjust := func(orig, src []byte) []byte { return src }
+ return file, adjust, nil
+ }
+ // If the error is that the source file didn't begin with a
+ // package line and this is standard input, fall through to
+ // try as a source fragment. Stop and return on any other error.
+ if !stdin || !strings.Contains(err.String(), "expected 'package'") {
+ return nil, nil, err
+ }
+
+ // If this is a declaration list, make it a source file
+ // by inserting a package clause.
+ // Insert using a ;, not a newline, so that the line numbers
+ // in psrc match the ones in src.
+ psrc := append([]byte("package p;"), src...)
+ file, err = parser.ParseFile(fset, filename, psrc, parserMode)
+ if err == nil {
+ adjust := func(orig, src []byte) []byte {
+ // Remove the package clause.
+ // Gofmt has turned the ; into a \n.
+ src = src[len("package p\n"):]
+ return matchSpace(orig, src)
+ }
+ return file, adjust, nil
+ }
+ // If the error is that the source file didn't begin with a
+ // declaration, fall through to try as a statement list.
+ // Stop and return on any other error.
+ if !strings.Contains(err.String(), "expected declaration") {
+ return nil, nil, err
+ }
+
+ // If this is a statement list, make it a source file
+ // by inserting a package clause and turning the list
+ // into a function body. This handles expressions too.
+ // Insert using a ;, not a newline, so that the line numbers
+ // in fsrc match the ones in src.
+ fsrc := append(append([]byte("package p; func _() {"), src...), '}')
+ file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
+ if err == nil {
+ adjust := func(orig, src []byte) []byte {
+ // Remove the wrapping.
+ // Gofmt has turned the ; into a \n\n.
+ src = src[len("package p\n\nfunc _() {"):]
+ src = src[:len(src)-len("}\n")]
+ // Gofmt has also indented the function body one level.
+ // Remove that indent.
+ src = bytes.Replace(src, []byte("\n\t"), []byte("\n"), -1)
+ return matchSpace(orig, src)
+ }
+ return file, adjust, nil
+ }
+
+ // Failed, and out of options.
+ return nil, nil, err
+}
+
+func cutSpace(b []byte) (before, middle, after []byte) {
+ i := 0
+ for i < len(b) && (b[i] == ' ' || b[i] == '\t' || b[i] == '\n') {
+ i++
+ }
+ j := len(b)
+ for j > 0 && (b[j-1] == ' ' || b[j-1] == '\t' || b[j-1] == '\n') {
+ j--
+ }
+ return b[:i], b[i:j], b[j:]
+}
+
+// matchSpace reformats src to use the same space context as orig.
+// 1) If orig begins with blank lines, matchSpace inserts them at the beginning of src.
+// 2) matchSpace copies the indentation of the first non-blank line in orig
+// to every non-blank line in src.
+// 3) matchSpace copies the trailing space from orig and uses it in place
+// of src's trailing space.
+func matchSpace(orig []byte, src []byte) []byte {
+ before, _, after := cutSpace(orig)
+ i := bytes.LastIndex(before, []byte{'\n'})
+ before, indent := before[:i+1], before[i+1:]
+
+ _, src, _ = cutSpace(src)
+
+ var b bytes.Buffer
+ b.Write(before)
+ for len(src) > 0 {
+ line := src
+ if i := bytes.IndexByte(line, '\n'); i >= 0 {
+ line, src = line[:i+1], line[i+1:]
+ } else {
+ src = nil
+ }
+ if len(line) > 0 && line[0] != '\n' { // not blank
+ b.Write(indent)
+ }
+ b.Write(line)
+ }
+ b.Write(after)
+ return b.Bytes()
+}
"testing"
)
-func runTest(t *testing.T, dirname, in, out, flags string) {
- in = filepath.Join(dirname, in)
- out = filepath.Join(dirname, out)
-
+func runTest(t *testing.T, in, out, flags string) {
// process flags
*simplifyAST = false
*rewriteRule = ""
+ stdin := false
for _, flag := range strings.Split(flags, " ") {
elts := strings.SplitN(flag, "=", 2)
name := elts[0]
*rewriteRule = value
case "-s":
*simplifyAST = true
+ case "-stdin":
+ // fake flag - pretend input is from stdin
+ stdin = true
default:
t.Errorf("unrecognized flag name: %s", name)
}
initRewrite()
var buf bytes.Buffer
- err := processFile(in, nil, &buf)
+ err := processFile(in, nil, &buf, stdin)
if err != nil {
t.Error(err)
return
if got := buf.Bytes(); bytes.Compare(got, expected) != 0 {
t.Errorf("(gofmt %s) != %s (see %s.gofmt)", in, out, in)
+ d, err := diff(expected, got)
+ if err == nil {
+ t.Errorf("%s", d)
+ }
ioutil.WriteFile(in+".gofmt", got, 0666)
}
}
// TODO(gri) Add more test cases!
var tests = []struct {
- dirname, in, out, flags string
+ in, flags string
}{
- {".", "gofmt.go", "gofmt.go", ""},
- {".", "gofmt_test.go", "gofmt_test.go", ""},
- {"testdata", "composites.input", "composites.golden", "-s"},
- {"testdata", "rewrite1.input", "rewrite1.golden", "-r=Foo->Bar"},
- {"testdata", "rewrite2.input", "rewrite2.golden", "-r=int->bool"},
+ {"gofmt.go", ""},
+ {"gofmt_test.go", ""},
+ {"testdata/composites.input", "-s"},
+ {"testdata/rewrite1.input", "-r=Foo->Bar"},
+ {"testdata/rewrite2.input", "-r=int->bool"},
+ {"testdata/stdin*.input", "-stdin"},
}
func TestRewrite(t *testing.T) {
for _, test := range tests {
- runTest(t, test.dirname, test.in, test.out, test.flags)
+ match, err := filepath.Glob(test.in)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ for _, in := range match {
+ out := in
+ if strings.HasSuffix(in, ".input") {
+ out = in[:len(in)-len(".input")] + ".golden"
+ }
+ runTest(t, in, out, test.flags)
+ if in != out {
+ // Check idempotence.
+ runTest(t, out, out, test.flags)
+ }
+ }
}
}