// Use unresolved identifiers to determine the imports used by this
// example. The heuristic assumes package names match base import
// paths for imports w/o renames (should be good enough most of the time).
- namedImports := make(map[string]string) // [name]path
- var blankImports []ast.Spec // _ imports
+ var namedImports []ast.Spec
+ var blankImports []ast.Spec // _ imports
+
+ // To preserve the blank lines between groups of imports, find the
+ // start position of each group, and assign that position to all
+ // imports from that group.
+ groupStarts := findImportGroupStarts(file.Imports)
+ groupStart := func(s *ast.ImportSpec) token.Pos {
+ for i, start := range groupStarts {
+ if s.Path.ValuePos < start {
+ return groupStarts[i-1]
+ }
+ }
+ return groupStarts[len(groupStarts)-1]
+ }
+
for _, s := range file.Imports {
p, err := strconv.Unquote(s.Path.Value)
if err != nil {
}
}
if unresolved[n] {
- namedImports[n] = p
+ // Copy the spec and its path to avoid modifying the original.
+ spec := *s
+ path := *s.Path
+ spec.Path = &path
+ spec.Path.ValuePos = groupStart(&spec)
+ namedImports = append(namedImports, &spec)
delete(unresolved, n)
}
}
Lparen: 1, // Need non-zero Lparen and Rparen so that printer
Rparen: 1, // treats this as a factored import.
}
- for n, p := range namedImports {
- s := &ast.ImportSpec{Path: &ast.BasicLit{Value: strconv.Quote(p)}}
- if path.Base(p) != n {
- s.Name = ast.NewIdent(n)
- }
- importDecl.Specs = append(importDecl.Specs, s)
- }
- importDecl.Specs = append(importDecl.Specs, blankImports...)
+ importDecl.Specs = append(namedImports, blankImports...)
// Synthesize main function.
funcDecl := &ast.FuncDecl{
sort.Slice(decls, func(i, j int) bool {
return decls[i].Pos() < decls[j].Pos()
})
-
sort.Slice(comments, func(i, j int) bool {
return comments[i].Pos() < comments[j].Pos()
})
}
}
+// findImportGroupStarts finds the start positions of each sequence of import
+// specs that are not separated by a blank line.
+func findImportGroupStarts(imps []*ast.ImportSpec) []token.Pos {
+ startImps := findImportGroupStarts1(imps)
+ groupStarts := make([]token.Pos, len(startImps))
+ for i, imp := range startImps {
+ groupStarts[i] = imp.Pos()
+ }
+ return groupStarts
+}
+
+// Helper for findImportGroupStarts to ease testing.
+func findImportGroupStarts1(origImps []*ast.ImportSpec) []*ast.ImportSpec {
+ // Copy to avoid mutation.
+ imps := make([]*ast.ImportSpec, len(origImps))
+ copy(imps, origImps)
+ // Assume the imports are sorted by position.
+ sort.Slice(imps, func(i, j int) bool { return imps[i].Pos() < imps[j].Pos() })
+ // Assume gofmt has been applied, so there is a blank line between adjacent imps
+ // if and only if they are more than 2 positions apart (newline, tab).
+ var groupStarts []*ast.ImportSpec
+ prevEnd := token.Pos(-2)
+ for _, imp := range imps {
+ if imp.Pos()-prevEnd > 2 {
+ groupStarts = append(groupStarts, imp)
+ }
+ prevEnd = imp.End()
+ // Account for end-of-line comments.
+ if imp.Comment != nil {
+ prevEnd = imp.Comment.End()
+ }
+ }
+ return groupStarts
+}
+
// playExampleFile takes a whole file example and synthesizes a new *ast.File
// such that the example is function main in package main.
func playExampleFile(file *ast.File) *ast.File {
--- /dev/null
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package doc
+
+import (
+ "go/parser"
+ "go/token"
+ "reflect"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+func TestImportGroupStarts(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ in string
+ want []string // paths of group-starting imports
+ }{
+ {
+ name: "one group",
+ in: `package p
+import (
+ "a"
+ "b"
+ "c"
+ "d"
+)
+`,
+ want: []string{"a"},
+ },
+ {
+ name: "several groups",
+ in: `package p
+import (
+ "a"
+
+ "b"
+ "c"
+
+ "d"
+)
+`,
+ want: []string{"a", "b", "d"},
+ },
+ {
+ name: "extra space",
+ in: `package p
+import (
+ "a"
+
+
+ "b"
+ "c"
+
+
+ "d"
+)
+`,
+ want: []string{"a", "b", "d"},
+ },
+ {
+ name: "line comment",
+ in: `package p
+import (
+ "a" // comment
+ "b" // comment
+
+ "c"
+)`,
+ want: []string{"a", "c"},
+ },
+ {
+ name: "named import",
+ in: `package p
+import (
+ "a"
+ n "b"
+
+ m "c"
+ "d"
+)`,
+ want: []string{"a", "c"},
+ },
+ {
+ name: "blank import",
+ in: `package p
+import (
+ "a"
+
+ _ "b"
+
+ _ "c"
+ "d"
+)`,
+ want: []string{"a", "b", "c"},
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ fset := token.NewFileSet()
+ file, err := parser.ParseFile(fset, "test.go", strings.NewReader(test.in), parser.ParseComments)
+ if err != nil {
+ t.Fatal(err)
+ }
+ imps := findImportGroupStarts1(file.Imports)
+ got := make([]string, len(imps))
+ for i, imp := range imps {
+ got[i], err = strconv.Unquote(imp.Path.Value)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ if !reflect.DeepEqual(got, test.want) {
+ t.Errorf("got %v, want %v", got, test.want)
+ }
+ })
+ }
+
+}