}
// addImport adds the import path to the file f, if absent.
-func addImport(f *ast.File, ipath string) {
+func addImport(f *ast.File, ipath string) (added bool) {
if imports(f, ipath) {
- return
+ return false
}
// Determine name of import.
impdecl.Specs[insertAt] = newImport
f.Imports = append(f.Imports, newImport)
+ return true
}
// deleteImport deletes the import path from the file f, if present.
-func deleteImport(f *ast.File, path string) {
+func deleteImport(f *ast.File, path string) (deleted bool) {
oldImport := importSpec(f, path)
// Find the import node that imports path, if any.
// We found an import spec that imports path.
// Delete it.
+ deleted = true
copy(gen.Specs[j:], gen.Specs[j+1:])
gen.Specs = gen.Specs[:len(gen.Specs)-1]
break
}
}
+
+ return
+}
+
+// rewriteImport rewrites any import of path oldPath to path newPath.
+func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
+ for _, imp := range f.Imports {
+ if importPath(imp) == oldPath {
+ rewrote = true
+ imp.Path.Value = strconv.Quote(newPath)
+ }
+ }
+ return
}
func usesImport(f *ast.File, path string) (used bool) {
"io"
"os"
)
+`,
+ },
+ {
+ Name: "import.13",
+ Fn: rewriteImportFn("utf8", "encoding/utf8"),
+ In: `package main
+
+import (
+ "io"
+ "os"
+ "utf8" // thanks ken
+)
+`,
+ Out: `package main
+
+import (
+ "encoding/utf8" // thanks ken
+ "io"
+ "os"
+)
`,
},
}
return false
}
}
+
+func rewriteImportFn(old, new string) func(*ast.File) bool {
+ return func(f *ast.File) bool {
+ if imports(f, old) {
+ rewriteImport(f, old, new)
+ return true
+ }
+ return false
+ }
+}
"exec"
"flag"
"fmt"
+ "go/ast"
"go/parser"
"go/printer"
"go/scanner"
tabWidth,
}
+func gofmtFile(f *ast.File) ([]byte, error) {
+ var buf bytes.Buffer
+
+ ast.SortImports(fset, f)
+ _, err := printConfig.Fprint(&buf, fset, f)
+ if err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
func processFile(filename string, useStdin bool) error {
var f *os.File
var err error
var fixlog bytes.Buffer
- var buf bytes.Buffer
if useStdin {
f = os.Stdin
// AST changed.
// Print and parse, to update any missing scoping
// or position information for subsequent fixers.
- buf.Reset()
- _, err = printConfig.Fprint(&buf, fset, newFile)
+ newSrc, err := gofmtFile(newFile)
if err != nil {
return err
}
- newSrc := buf.Bytes()
newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
if err != nil {
return err
// output of the printer run on a standard AST generated by the parser,
// but the source we generated inside the loop above is the
// output of the printer run on a mangled AST generated by a fixer.
- buf.Reset()
- _, err = printConfig.Fprint(&buf, fset, newFile)
+ newSrc, err := gofmtFile(newFile)
if err != nil {
return err
}
- newSrc := buf.Bytes()
if *doDiff {
data, err := diff(src, newSrc)
package main
import (
- "bytes"
"go/ast"
"go/parser"
- "go/printer"
"strings"
"testing"
)
return
}
- var buf bytes.Buffer
- buf.Reset()
- _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
+ outb, err := gofmtFile(file)
if err != nil {
t.Errorf("%s: printing: %v", desc, err)
return
}
- if s := buf.String(); in != s && fn != fnop {
+ if s := string(outb); in != s && fn != fnop {
t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
desc, desc, in, desc, s)
tdiff(t, in, s)
fixed = fn(file)
}
- buf.Reset()
- _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
+ outb, err = gofmtFile(file)
if err != nil {
t.Errorf("%s: printing: %v", desc, err)
return
}
- return buf.String(), fixed, true
+ return string(outb), fixed, true
}
func TestRewrite(t *testing.T) {
// BUG: can't use range clause to receive when using ImportNValues to limit the count.
import (
- "log"
"io"
+ "log"
"net"
"os"
"reflect"
// BUG: can't use range clause to receive when using ImportNValues to limit the count.
import (
- "log"
"io"
+ "log"
"net"
"os"
"reflect"
}
}
+ ast.SortImports(fset, file)
+
if *simplifyAST {
simplify(file)
}
{"testdata/rewrite2.input", "-r=int->bool"},
{"testdata/stdin*.input", "-stdin"},
{"testdata/comments.input", ""},
+ {"testdata/import.input", ""},
}
func TestRewrite(t *testing.T) {
--- /dev/null
+package main
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "math"
+)
+
+import (
+ "fmt"
+
+ "math"
+
+ "log"
+
+ "errors"
+
+ "io"
+)
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "math"
+
+ "fmt"
+
+ "math"
+
+ "log"
+
+ "errors"
+
+ "io"
+)
+
+import (
+ // a block with comments
+ "errors"
+ "fmt" // for Printf
+ "io" // for Reader
+ "log" // for Fatal
+ "math"
+)
+
+import (
+ "fmt" // for Printf
+
+ "math"
+
+ "log" // for Fatal
+
+ "errors"
+
+ "io" // for Reader
+)
+
+import (
+ // for Printf
+ "fmt"
+
+ "math"
+
+ // for Fatal
+ "log"
+
+ "errors"
+
+ // for Reader
+ "io"
+)
+
+import (
+ "errors"
+ "fmt" // for Printf
+ "io" // for Reader
+ "log" // for Fatal
+ "math"
+
+ "fmt" // for Printf
+
+ "math"
+
+ "log" // for Fatal
+
+ "errors"
+
+ "io" // for Reader
+)
+
+import (
+ "fmt" // for Printf
+
+ "errors"
+ "io" // for Reader
+ "log" // for Fatal
+ "math"
+
+ "errors"
+ "fmt" // for Printf
+ "io" // for Reader
+ "log" // for Fatal
+ "math"
+)
--- /dev/null
+package main
+
+import (
+ "fmt"
+ "math"
+ "log"
+ "errors"
+ "io"
+)
+
+import (
+ "fmt"
+
+ "math"
+
+ "log"
+
+ "errors"
+
+ "io"
+)
+
+import (
+ "fmt"
+ "math"
+ "log"
+ "errors"
+ "io"
+
+ "fmt"
+
+ "math"
+
+ "log"
+
+ "errors"
+
+ "io"
+)
+
+import (
+ // a block with comments
+ "fmt" // for Printf
+ "math"
+ "log" // for Fatal
+ "errors"
+ "io" // for Reader
+)
+
+import (
+ "fmt" // for Printf
+
+ "math"
+
+ "log" // for Fatal
+
+ "errors"
+
+ "io" // for Reader
+)
+
+import (
+ // for Printf
+ "fmt"
+
+ "math"
+
+ // for Fatal
+ "log"
+
+ "errors"
+
+ // for Reader
+ "io"
+)
+
+import (
+ "fmt" // for Printf
+ "math"
+ "log" // for Fatal
+ "errors"
+ "io" // for Reader
+
+ "fmt" // for Printf
+
+ "math"
+
+ "log" // for Fatal
+
+ "errors"
+
+ "io" // for Reader
+)
+
+import (
+ "fmt" // for Printf
+
+ "math"
+ "log" // for Fatal
+ "errors"
+ "io" // for Reader
+
+ "fmt" // for Printf
+ "math"
+ "log" // for Fatal
+ "errors"
+ "io" // for Reader
+)
GOFILES=\
ast.go\
filter.go\
+ import.go\
print.go\
resolve.go\
scope.go\
--- /dev/null
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ast
+
+import (
+ "go/token"
+ "sort"
+ "strconv"
+)
+
+// SortImports sorts runs of consecutive import lines in import blocks in f.
+func SortImports(fset *token.FileSet, f *File) {
+ for _, d := range f.Decls {
+ d, ok := d.(*GenDecl)
+ if !ok || d.Tok != token.IMPORT {
+ // Not an import declaration, so we're done.
+ // Imports are always first.
+ break
+ }
+
+ if d.Lparen == token.NoPos {
+ // Not a block: sorted by default.
+ continue
+ }
+
+ // Identify and sort runs of specs on successive lines.
+ i := 0
+ for j, s := range d.Specs {
+ if j > i && fset.Position(s.Pos()).Line > 1+fset.Position(d.Specs[j-1].End()).Line {
+ // j begins a new run. End this one.
+ sortSpecs(fset, f, d.Specs[i:j])
+ i = j
+ }
+ }
+ sortSpecs(fset, f, d.Specs[i:])
+ }
+}
+
+func importPath(s Spec) string {
+ t, err := strconv.Unquote(s.(*ImportSpec).Path.Value)
+ if err == nil {
+ return t
+ }
+ return ""
+}
+
+type posSpan struct {
+ Start token.Pos
+ End token.Pos
+}
+
+func sortSpecs(fset *token.FileSet, f *File, specs []Spec) {
+ // Avoid work if already sorted (also catches < 2 entries).
+ sorted := true
+ for i, s := range specs {
+ if i > 0 && importPath(specs[i-1]) > importPath(s) {
+ sorted = false
+ break
+ }
+ }
+ if sorted {
+ return
+ }
+
+ // Record positions for specs.
+ pos := make([]posSpan, len(specs))
+ for i, s := range specs {
+ pos[i] = posSpan{s.Pos(), s.End()}
+ }
+
+ // Identify comments in this range.
+ // Any comment from pos[0].Start to the final line counts.
+ lastLine := fset.Position(pos[len(pos)-1].End).Line
+ cstart := len(f.Comments)
+ cend := len(f.Comments)
+ for i, g := range f.Comments {
+ if g.Pos() < pos[0].Start {
+ continue
+ }
+ if i < cstart {
+ cstart = i
+ }
+ if fset.Position(g.End()).Line > lastLine {
+ cend = i
+ break
+ }
+ }
+ comments := f.Comments[cstart:cend]
+
+ // Assign each comment to the import spec preceding it.
+ importComment := map[*ImportSpec][]*CommentGroup{}
+ specIndex := 0
+ for _, g := range comments {
+ for specIndex+1 < len(specs) && pos[specIndex+1].Start <= g.Pos() {
+ specIndex++
+ }
+ s := specs[specIndex].(*ImportSpec)
+ importComment[s] = append(importComment[s], g)
+ }
+
+ // Sort the import specs by import path.
+ // Reassign the import paths to have the same position sequence.
+ // Reassign each comment to abut the end of its spec.
+ // Sort the comments by new position.
+ sort.Sort(byImportPath(specs))
+ for i, s := range specs {
+ s := s.(*ImportSpec)
+ s.Path.ValuePos = pos[i].Start
+ s.EndPos = pos[i].End
+ for _, g := range importComment[s] {
+ for _, c := range g.List {
+ c.Slash = pos[i].End
+ }
+ }
+ }
+ sort.Sort(byCommentPos(comments))
+}
+
+type byImportPath []Spec // slice of *ImportSpec
+
+func (x byImportPath) Len() int { return len(x) }
+func (x byImportPath) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
+func (x byImportPath) Less(i, j int) bool { return importPath(x[i]) < importPath(x[j]) }
+
+type byCommentPos []*CommentGroup
+
+func (x byCommentPos) Len() int { return len(x) }
+func (x byCommentPos) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
+func (x byCommentPos) Less(i, j int) bool { return x[i].Pos() < x[j].Pos() }