]> Cypherpunks repositories - gostls13.git/commitdiff
gofix: test import insertion, deletion
authorRuss Cox <rsc@golang.org>
Wed, 26 Oct 2011 21:04:07 +0000 (14:04 -0700)
committerRuss Cox <rsc@golang.org>
Wed, 26 Oct 2011 21:04:07 +0000 (14:04 -0700)
Small change to go/ast, go/parser, go/printer so that
gofix can delete the blank line left from deleting an import.

R=golang-dev, bradfitz, adg
CC=golang-dev
https://golang.org/cl/5321046

src/cmd/gofix/fix.go
src/cmd/gofix/import_test.go [new file with mode: 0644]
src/pkg/go/ast/ast.go
src/pkg/go/parser/parser.go
src/pkg/go/printer/nodes.go

index 9e4fd56a6ef8826db521bcc0a0eb501194f8da41..4eaadac2b4ac525bce5cfdec81d3da525fcc9003 100644 (file)
@@ -19,7 +19,7 @@ type fix struct {
        desc string
 }
 
-// main runs sort.Sort(fixes) after init process is done.
+// main runs sort.Sort(fixes) before printing list of fixes.
 type fixlist []fix
 
 func (f fixlist) Len() int           { return len(f) }
@@ -316,6 +316,20 @@ func importPath(s *ast.ImportSpec) string {
        return ""
 }
 
+// declImports reports whether gen contains an import of path.
+func declImports(gen *ast.GenDecl, path string) bool {
+       if gen.Tok != token.IMPORT {
+               return false
+       }
+       for _, spec := range gen.Specs {
+               impspec := spec.(*ast.ImportSpec)
+               if importPath(impspec) == path {
+                       return true
+               }
+       }
+       return false
+}
+
 // isPkgDot returns true if t is the expression "pkg.name"
 // where pkg is an imported identifier.
 func isPkgDot(t ast.Expr, pkg, name string) bool {
@@ -486,12 +500,18 @@ func addImport(f *ast.File, path string) {
        var impdecl *ast.GenDecl
 
        // Find an import decl to add to.
-       for _, decl := range f.Decls {
+       var lastImport int = -1
+       for i, decl := range f.Decls {
                gen, ok := decl.(*ast.GenDecl)
 
                if ok && gen.Tok == token.IMPORT {
-                       impdecl = gen
-                       break
+                       lastImport = i
+                       // Do not add to import "C", to avoid disrupting the
+                       // association with its doc comment, breaking cgo.
+                       if !declImports(gen, "C") {
+                               impdecl = gen
+                               break
+                       }
                }
        }
 
@@ -501,8 +521,8 @@ func addImport(f *ast.File, path string) {
                        Tok: token.IMPORT,
                }
                f.Decls = append(f.Decls, nil)
-               copy(f.Decls[1:], f.Decls)
-               f.Decls[0] = impdecl
+               copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
+               f.Decls[lastImport+1] = impdecl
        }
 
        // Ensure the import decl has parentheses, if needed.
@@ -540,7 +560,6 @@ func deleteImport(f *ast.File, path string) {
                }
                for j, spec := range gen.Specs {
                        impspec := spec.(*ast.ImportSpec)
-
                        if oldImport != impspec {
                                continue
                        }
@@ -558,7 +577,13 @@ func deleteImport(f *ast.File, path string) {
                        } else if len(gen.Specs) == 1 {
                                gen.Lparen = token.NoPos // drop parens
                        }
-
+                       if j > 0 {
+                               // We deleted an entry but now there will be
+                               // a blank line-sized hole where the import was.
+                               // Close the hole by making the previous
+                               // import appear to "end" where this one did.
+                               gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End()
+                       }
                        break
                }
        }
diff --git a/src/cmd/gofix/import_test.go b/src/cmd/gofix/import_test.go
new file mode 100644 (file)
index 0000000..f878c0c
--- /dev/null
@@ -0,0 +1,269 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "go/ast"
+
+func init() {
+       addTestCases(importTests, nil)
+}
+
+var importTests = []testCase{
+       {
+               Name: "import.0",
+               Fn:   addImportFn("os"),
+               In: `package main
+
+import (
+       "os"
+)
+`,
+               Out: `package main
+
+import (
+       "os"
+)
+`,
+       },
+       {
+               Name: "import.1",
+               Fn:   addImportFn("os"),
+               In: `package main
+`,
+               Out: `package main
+
+import "os"
+`,
+       },
+       {
+               Name: "import.2",
+               Fn:   addImportFn("os"),
+               In: `package main
+
+// Comment
+import "C"
+`,
+               Out: `package main
+
+// Comment
+import "C"
+import "os"
+`,
+       },
+       {
+               Name: "import.3",
+               Fn:   addImportFn("os"),
+               In: `package main
+
+// Comment
+import "C"
+
+import (
+       "io"
+       "utf8"
+)
+`,
+               Out: `package main
+
+// Comment
+import "C"
+
+import (
+       "io"
+       "os"
+       "utf8"
+)
+`,
+       },
+       {
+               Name: "import.4",
+               Fn:   deleteImportFn("os"),
+               In: `package main
+
+import (
+       "os"
+)
+`,
+               Out: `package main
+`,
+       },
+       {
+               Name: "import.5",
+               Fn:   deleteImportFn("os"),
+               In: `package main
+
+// Comment
+import "C"
+import "os"
+`,
+               Out: `package main
+
+// Comment
+import "C"
+`,
+       },
+       {
+               Name: "import.6",
+               Fn:   deleteImportFn("os"),
+               In: `package main
+
+// Comment
+import "C"
+
+import (
+       "io"
+       "os"
+       "utf8"
+)
+`,
+               Out: `package main
+
+// Comment
+import "C"
+
+import (
+       "io"
+       "utf8"
+)
+`,
+       },
+       {
+               Name: "import.7",
+               Fn:   deleteImportFn("io"),
+               In: `package main
+
+import (
+       "io"   // a
+       "os"   // b
+       "utf8" // c
+)
+`,
+               Out: `package main
+
+import (
+       // a
+       "os"   // b
+       "utf8" // c
+)
+`,
+       },
+       {
+               Name: "import.8",
+               Fn:   deleteImportFn("os"),
+               In: `package main
+
+import (
+       "io"   // a
+       "os"   // b
+       "utf8" // c
+)
+`,
+               Out: `package main
+
+import (
+       "io" // a
+       // b
+       "utf8" // c
+)
+`,
+       },
+       {
+               Name: "import.9",
+               Fn:   deleteImportFn("utf8"),
+               In: `package main
+
+import (
+       "io"   // a
+       "os"   // b
+       "utf8" // c
+)
+`,
+               Out: `package main
+
+import (
+       "io" // a
+       "os" // b
+       // c
+)
+`,
+       },
+       {
+               Name: "import.10",
+               Fn:   deleteImportFn("io"),
+               In: `package main
+
+import (
+       "io"
+       "os"
+       "utf8"
+)
+`,
+               Out: `package main
+
+import (
+       "os"
+       "utf8"
+)
+`,
+       },
+       {
+               Name: "import.11",
+               Fn:   deleteImportFn("os"),
+               In: `package main
+
+import (
+       "io"
+       "os"
+       "utf8"
+)
+`,
+               Out: `package main
+
+import (
+       "io"
+       "utf8"
+)
+`,
+       },
+       {
+               Name: "import.12",
+               Fn:   deleteImportFn("utf8"),
+               In: `package main
+
+import (
+       "io"
+       "os"
+       "utf8"
+)
+`,
+               Out: `package main
+
+import (
+       "io"
+       "os"
+)
+`,
+       },
+}
+
+func addImportFn(path string) func(*ast.File) bool {
+       return func(f *ast.File) bool {
+               if !imports(f, path) {
+                       addImport(f, path)
+                       return true
+               }
+               return false
+       }
+}
+
+func deleteImportFn(path string) func(*ast.File) bool {
+       return func(f *ast.File) bool {
+               if imports(f, path) {
+                       deleteImport(f, path)
+                       return true
+               }
+               return false
+       }
+}
index 22bd5ee2266ad4836eadf682d992207d1ea35e14..f8caafc179a62ec7ece999d9f6737a1c91fd9ab0 100644 (file)
@@ -752,6 +752,7 @@ type (
                Name    *Ident        // local package name (including "."); or nil
                Path    *BasicLit     // import path
                Comment *CommentGroup // line comments; or nil
+               EndPos  token.Pos     // end of spec (overrides Path.Pos if nonzero)
        }
 
        // A ValueSpec node represents a constant or variable declaration
@@ -785,7 +786,13 @@ func (s *ImportSpec) Pos() token.Pos {
 func (s *ValueSpec) Pos() token.Pos { return s.Names[0].Pos() }
 func (s *TypeSpec) Pos() token.Pos  { return s.Name.Pos() }
 
-func (s *ImportSpec) End() token.Pos { return s.Path.End() }
+func (s *ImportSpec) End() token.Pos {
+       if s.EndPos != 0 {
+               return s.EndPos
+       }
+       return s.Path.End()
+}
+
 func (s *ValueSpec) End() token.Pos {
        if n := len(s.Values); n > 0 {
                return s.Values[n-1].End()
index be82b2f801abaf99d1860dc966cff9ee11cc525b..c78c6b56ecb2cc5c7efcbf14d32b464a7fca649d 100644 (file)
@@ -1909,7 +1909,7 @@ func parseImportSpec(p *parser, doc *ast.CommentGroup, _ int) ast.Spec {
        p.expectSemi() // call before accessing p.linecomment
 
        // collect imports
-       spec := &ast.ImportSpec{doc, ident, path, p.lineComment}
+       spec := &ast.ImportSpec{doc, ident, path, p.lineComment, token.NoPos}
        p.imports = append(p.imports, spec)
 
        return spec
index 364530634aaeb21763155d63cdede2bdcd902893..248e43d4e72993e93c4f1db9ac5321d395d242ce 100644 (file)
@@ -1278,6 +1278,7 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool, multiLine *bool) {
                }
                p.expr(s.Path, multiLine)
                p.setComment(s.Comment)
+               p.print(s.EndPos)
 
        case *ast.ValueSpec:
                if n != 1 {