]> Cypherpunks repositories - gostls13.git/commitdiff
gofmt, gofix: sort imports
authorRuss Cox <rsc@golang.org>
Wed, 2 Nov 2011 19:53:57 +0000 (15:53 -0400)
committerRuss Cox <rsc@golang.org>
Wed, 2 Nov 2011 19:53:57 +0000 (15:53 -0400)
Add ast.SortImports(fset, file) to go/ast, for use by both programs.

Fixes #346.

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/5330069

12 files changed:
src/cmd/gofix/fix.go
src/cmd/gofix/import_test.go
src/cmd/gofix/main.go
src/cmd/gofix/main_test.go
src/cmd/gofix/testdata/reflect.export.go.in
src/cmd/gofix/testdata/reflect.export.go.out
src/cmd/gofmt/gofmt.go
src/cmd/gofmt/gofmt_test.go
src/cmd/gofmt/testdata/import.golden [new file with mode: 0644]
src/cmd/gofmt/testdata/import.input [new file with mode: 0644]
src/pkg/go/ast/Makefile
src/pkg/go/ast/import.go [new file with mode: 0644]

index 9a51085dd19e1211f260b9f5ff8d0f4c9892995e..f7b55b073d4d7e17e608479faa940a646d25fa8e 100644 (file)
@@ -569,9 +569,9 @@ func renameTop(f *ast.File, old, new string) bool {
 }
 
 // addImport adds the import path to the file f, if absent.
-func addImport(f *ast.File, ipath string) {
+func addImport(f *ast.File, ipath string) (added bool) {
        if imports(f, ipath) {
-               return
+               return false
        }
 
        // Determine name of import.
@@ -637,10 +637,11 @@ func addImport(f *ast.File, ipath string) {
        impdecl.Specs[insertAt] = newImport
 
        f.Imports = append(f.Imports, newImport)
+       return true
 }
 
 // deleteImport deletes the import path from the file f, if present.
-func deleteImport(f *ast.File, path string) {
+func deleteImport(f *ast.File, path string) (deleted bool) {
        oldImport := importSpec(f, path)
 
        // Find the import node that imports path, if any.
@@ -657,6 +658,7 @@ func deleteImport(f *ast.File, path string) {
 
                        // We found an import spec that imports path.
                        // Delete it.
+                       deleted = true
                        copy(gen.Specs[j:], gen.Specs[j+1:])
                        gen.Specs = gen.Specs[:len(gen.Specs)-1]
 
@@ -687,6 +689,19 @@ func deleteImport(f *ast.File, path string) {
                        break
                }
        }
+
+       return
+}
+
+// rewriteImport rewrites any import of path oldPath to path newPath.
+func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
+       for _, imp := range f.Imports {
+               if importPath(imp) == oldPath {
+                       rewrote = true
+                       imp.Path.Value = strconv.Quote(newPath)
+               }
+       }
+       return
 }
 
 func usesImport(f *ast.File, path string) (used bool) {
index f878c0ccfb7e7f847ace769f319e43c7b416e05b..4a9259f40929b2acadd1dbc93f986f0c3a8ae7cb 100644 (file)
@@ -244,6 +244,26 @@ import (
        "io"
        "os"
 )
+`,
+       },
+       {
+               Name: "import.13",
+               Fn:   rewriteImportFn("utf8", "encoding/utf8"),
+               In: `package main
+
+import (
+       "io"
+       "os"
+       "utf8" // thanks ken
+)
+`,
+               Out: `package main
+
+import (
+       "encoding/utf8" // thanks ken
+       "io"
+       "os"
+)
 `,
        },
 }
@@ -267,3 +287,13 @@ func deleteImportFn(path string) func(*ast.File) bool {
                return false
        }
 }
+
+func rewriteImportFn(old, new string) func(*ast.File) bool {
+       return func(f *ast.File) bool {
+               if imports(f, old) {
+                       rewriteImport(f, old, new)
+                       return true
+               }
+               return false
+       }
+}
index f462c3dfb3da6757a6a3d6ff493435a6a8a38c02..1d0f4b0f0733f38ce11fe4cebf712818e61a604f 100644 (file)
@@ -9,6 +9,7 @@ import (
        "exec"
        "flag"
        "fmt"
+       "go/ast"
        "go/parser"
        "go/printer"
        "go/scanner"
@@ -102,11 +103,21 @@ var printConfig = &printer.Config{
        tabWidth,
 }
 
+func gofmtFile(f *ast.File) ([]byte, error) {
+       var buf bytes.Buffer
+
+       ast.SortImports(fset, f)
+       _, err := printConfig.Fprint(&buf, fset, f)
+       if err != nil {
+               return nil, err
+       }
+       return buf.Bytes(), nil
+}
+
 func processFile(filename string, useStdin bool) error {
        var f *os.File
        var err error
        var fixlog bytes.Buffer
-       var buf bytes.Buffer
 
        if useStdin {
                f = os.Stdin
@@ -142,12 +153,10 @@ func processFile(filename string, useStdin bool) error {
                        // AST changed.
                        // Print and parse, to update any missing scoping
                        // or position information for subsequent fixers.
-                       buf.Reset()
-                       _, err = printConfig.Fprint(&buf, fset, newFile)
+                       newSrc, err := gofmtFile(newFile)
                        if err != nil {
                                return err
                        }
-                       newSrc := buf.Bytes()
                        newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
                        if err != nil {
                                return err
@@ -165,12 +174,10 @@ func processFile(filename string, useStdin bool) error {
        // output of the printer run on a standard AST generated by the parser,
        // but the source we generated inside the loop above is the
        // output of the printer run on a mangled AST generated by a fixer.
-       buf.Reset()
-       _, err = printConfig.Fprint(&buf, fset, newFile)
+       newSrc, err := gofmtFile(newFile)
        if err != nil {
                return err
        }
-       newSrc := buf.Bytes()
 
        if *doDiff {
                data, err := diff(src, newSrc)
index 077a15e52a447e1583ecd8a7cc9bd6c5113c280a..94e63f05d39098df7878050f13513d9fb59364dd 100644 (file)
@@ -5,10 +5,8 @@
 package main
 
 import (
-       "bytes"
        "go/ast"
        "go/parser"
-       "go/printer"
        "strings"
        "testing"
 )
@@ -43,14 +41,12 @@ func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out
                return
        }
 
-       var buf bytes.Buffer
-       buf.Reset()
-       _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
+       outb, err := gofmtFile(file)
        if err != nil {
                t.Errorf("%s: printing: %v", desc, err)
                return
        }
-       if s := buf.String(); in != s && fn != fnop {
+       if s := string(outb); in != s && fn != fnop {
                t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
                        desc, desc, in, desc, s)
                tdiff(t, in, s)
@@ -67,14 +63,13 @@ func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out
                fixed = fn(file)
        }
 
-       buf.Reset()
-       _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file)
+       outb, err = gofmtFile(file)
        if err != nil {
                t.Errorf("%s: printing: %v", desc, err)
                return
        }
 
-       return buf.String(), fixed, true
+       return string(outb), fixed, true
 }
 
 func TestRewrite(t *testing.T) {
index 495fc46b6a3aaa80f9714124875babc31706757f..ce7940b2984fe021f04df6a4f954c712c11c8def 100644 (file)
@@ -22,8 +22,8 @@ package netchan
 // BUG: can't use range clause to receive when using ImportNValues to limit the count.
 
 import (
-       "log"
        "io"
+       "log"
        "net"
        "os"
        "reflect"
index 460edb40bfe45f2e5cc8748706cb59fc5d13a36b..7bd73c5e7fcad86ac8a5e4e78b5c26c482a30d69 100644 (file)
@@ -22,8 +22,8 @@ package netchan
 // BUG: can't use range clause to receive when using ImportNValues to limit the count.
 
 import (
-       "log"
        "io"
+       "log"
        "net"
        "os"
        "reflect"
index f5afa6f91b4ca6a8791c36851045e3e8a434c976..1ca47eccb8cf1e313e41d1020c44e28bfc20baf3 100644 (file)
@@ -114,6 +114,8 @@ func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error
                }
        }
 
+       ast.SortImports(fset, file)
+
        if *simplifyAST {
                simplify(file)
        }
index 6587f06a02dad8ac6b09c81c888f86f8e537713b..4432a178bcf5e075deef54ddce29c848e336055a 100644 (file)
@@ -78,6 +78,7 @@ var tests = []struct {
        {"testdata/rewrite2.input", "-r=int->bool"},
        {"testdata/stdin*.input", "-stdin"},
        {"testdata/comments.input", ""},
+       {"testdata/import.input", ""},
 }
 
 func TestRewrite(t *testing.T) {
diff --git a/src/cmd/gofmt/testdata/import.golden b/src/cmd/gofmt/testdata/import.golden
new file mode 100644 (file)
index 0000000..e8ee449
--- /dev/null
@@ -0,0 +1,108 @@
+package main
+
+import (
+       "errors"
+       "fmt"
+       "io"
+       "log"
+       "math"
+)
+
+import (
+       "fmt"
+
+       "math"
+
+       "log"
+
+       "errors"
+
+       "io"
+)
+
+import (
+       "errors"
+       "fmt"
+       "io"
+       "log"
+       "math"
+
+       "fmt"
+
+       "math"
+
+       "log"
+
+       "errors"
+
+       "io"
+)
+
+import (
+       // a block with comments
+       "errors"
+       "fmt" // for Printf
+       "io"  // for Reader
+       "log" // for Fatal
+       "math"
+)
+
+import (
+       "fmt" // for Printf
+
+       "math"
+
+       "log" // for Fatal
+
+       "errors"
+
+       "io" // for Reader
+)
+
+import (
+       // for Printf
+       "fmt"
+
+       "math"
+
+       // for Fatal
+       "log"
+
+       "errors"
+
+       // for Reader
+       "io"
+)
+
+import (
+       "errors"
+       "fmt" // for Printf
+       "io"  // for Reader
+       "log" // for Fatal
+       "math"
+
+       "fmt" // for Printf
+
+       "math"
+
+       "log" // for Fatal
+
+       "errors"
+
+       "io" // for Reader
+)
+
+import (
+       "fmt" // for Printf
+
+       "errors"
+       "io"  // for Reader
+       "log" // for Fatal
+       "math"
+
+       "errors"
+       "fmt" // for Printf
+       "io"  // for Reader
+       "log" // for Fatal
+       "math"
+)
diff --git a/src/cmd/gofmt/testdata/import.input b/src/cmd/gofmt/testdata/import.input
new file mode 100644 (file)
index 0000000..cc36c3e
--- /dev/null
@@ -0,0 +1,108 @@
+package main
+
+import (
+       "fmt"
+       "math"
+       "log"
+       "errors"
+       "io"
+)
+
+import (
+       "fmt"
+
+       "math"
+
+       "log"
+
+       "errors"
+
+       "io"
+)
+
+import (
+       "fmt"
+       "math"
+       "log"
+       "errors"
+       "io"
+
+       "fmt"
+
+       "math"
+
+       "log"
+
+       "errors"
+
+       "io"
+)
+
+import (
+       // a block with comments
+       "fmt" // for Printf
+       "math"
+       "log" // for Fatal
+       "errors"
+       "io" // for Reader
+)
+
+import (
+       "fmt" // for Printf
+
+       "math"
+
+       "log" // for Fatal
+
+       "errors"
+
+       "io" // for Reader
+)
+
+import (
+       // for Printf
+       "fmt"
+
+       "math"
+
+       // for Fatal
+       "log"
+
+       "errors"
+
+       // for Reader
+       "io"
+)
+
+import (
+       "fmt" // for Printf
+       "math"
+       "log" // for Fatal
+       "errors"
+       "io" // for Reader
+
+       "fmt" // for Printf
+
+       "math"
+
+       "log" // for Fatal
+
+       "errors"
+
+       "io" // for Reader
+)
+
+import (
+       "fmt" // for Printf
+
+       "math"
+       "log" // for Fatal
+       "errors"
+       "io" // for Reader
+
+       "fmt" // for Printf
+       "math"
+       "log" // for Fatal
+       "errors"
+       "io" // for Reader
+)
index 40be10208be117bd0dfa7026291b1195eef2643c..30c386cd84a06b040f59e8daff903b87a11252c1 100644 (file)
@@ -8,6 +8,7 @@ TARG=go/ast
 GOFILES=\
        ast.go\
        filter.go\
+       import.go\
        print.go\
        resolve.go\
        scope.go\
diff --git a/src/pkg/go/ast/import.go b/src/pkg/go/ast/import.go
new file mode 100644 (file)
index 0000000..c64e9bb
--- /dev/null
@@ -0,0 +1,131 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ast
+
+import (
+       "go/token"
+       "sort"
+       "strconv"
+)
+
+// SortImports sorts runs of consecutive import lines in import blocks in f.
+func SortImports(fset *token.FileSet, f *File) {
+       for _, d := range f.Decls {
+               d, ok := d.(*GenDecl)
+               if !ok || d.Tok != token.IMPORT {
+                       // Not an import declaration, so we're done.
+                       // Imports are always first.
+                       break
+               }
+
+               if d.Lparen == token.NoPos {
+                       // Not a block: sorted by default.
+                       continue
+               }
+
+               // Identify and sort runs of specs on successive lines.
+               i := 0
+               for j, s := range d.Specs {
+                       if j > i && fset.Position(s.Pos()).Line > 1+fset.Position(d.Specs[j-1].End()).Line {
+                               // j begins a new run.  End this one.
+                               sortSpecs(fset, f, d.Specs[i:j])
+                               i = j
+                       }
+               }
+               sortSpecs(fset, f, d.Specs[i:])
+       }
+}
+
+func importPath(s Spec) string {
+       t, err := strconv.Unquote(s.(*ImportSpec).Path.Value)
+       if err == nil {
+               return t
+       }
+       return ""
+}
+
+type posSpan struct {
+       Start token.Pos
+       End   token.Pos
+}
+
+func sortSpecs(fset *token.FileSet, f *File, specs []Spec) {
+       // Avoid work if already sorted (also catches < 2 entries).
+       sorted := true
+       for i, s := range specs {
+               if i > 0 && importPath(specs[i-1]) > importPath(s) {
+                       sorted = false
+                       break
+               }
+       }
+       if sorted {
+               return
+       }
+
+       // Record positions for specs.
+       pos := make([]posSpan, len(specs))
+       for i, s := range specs {
+               pos[i] = posSpan{s.Pos(), s.End()}
+       }
+
+       // Identify comments in this range.
+       // Any comment from pos[0].Start to the final line counts.
+       lastLine := fset.Position(pos[len(pos)-1].End).Line
+       cstart := len(f.Comments)
+       cend := len(f.Comments)
+       for i, g := range f.Comments {
+               if g.Pos() < pos[0].Start {
+                       continue
+               }
+               if i < cstart {
+                       cstart = i
+               }
+               if fset.Position(g.End()).Line > lastLine {
+                       cend = i
+                       break
+               }
+       }
+       comments := f.Comments[cstart:cend]
+
+       // Assign each comment to the import spec preceding it.
+       importComment := map[*ImportSpec][]*CommentGroup{}
+       specIndex := 0
+       for _, g := range comments {
+               for specIndex+1 < len(specs) && pos[specIndex+1].Start <= g.Pos() {
+                       specIndex++
+               }
+               s := specs[specIndex].(*ImportSpec)
+               importComment[s] = append(importComment[s], g)
+       }
+
+       // Sort the import specs by import path.
+       // Reassign the import paths to have the same position sequence.
+       // Reassign each comment to abut the end of its spec.
+       // Sort the comments by new position.
+       sort.Sort(byImportPath(specs))
+       for i, s := range specs {
+               s := s.(*ImportSpec)
+               s.Path.ValuePos = pos[i].Start
+               s.EndPos = pos[i].End
+               for _, g := range importComment[s] {
+                       for _, c := range g.List {
+                               c.Slash = pos[i].End
+                       }
+               }
+       }
+       sort.Sort(byCommentPos(comments))
+}
+
+type byImportPath []Spec // slice of *ImportSpec
+
+func (x byImportPath) Len() int           { return len(x) }
+func (x byImportPath) Swap(i, j int)      { x[i], x[j] = x[j], x[i] }
+func (x byImportPath) Less(i, j int) bool { return importPath(x[i]) < importPath(x[j]) }
+
+type byCommentPos []*CommentGroup
+
+func (x byCommentPos) Len() int           { return len(x) }
+func (x byCommentPos) Swap(i, j int)      { x[i], x[j] = x[j], x[i] }
+func (x byCommentPos) Less(i, j int) bool { return x[i].Pos() < x[j].Pos() }