"fmt"
"go/ast"
"go/parser"
- "go/printer"
"go/token"
"io"
"io/ioutil"
"log"
"os"
- "path/filepath"
"sort"
"strconv"
- "strings"
+ "cmd/internal/edit"
"cmd/internal/objabi"
)
var profile string // The profile to read; the value of -html or -func
-var counterStmt func(*File, ast.Expr) ast.Stmt
+var counterStmt func(*File, string) string
const (
atomicPackagePath = "sync/atomic"
// File is a wrapper for the state of a file used in the parser.
// The basic parse tree walker is a method of this type.
type File struct {
- fset *token.FileSet
- name string // Name of file.
- astFile *ast.File
- blocks []Block
- atomicPkg string // Package name for "sync/atomic" in this file.
+ fset *token.FileSet
+ name string // Name of file.
+ astFile *ast.File
+ blocks []Block
+ content []byte
+ edit *edit.Buffer
+}
+
+// findText finds text in the original source, starting at pos.
+// It correctly skips over comments and assumes it need not
+// handle quoted strings.
+// It returns a byte offset within f.src.
+func (f *File) findText(pos token.Pos, text string) int {
+ b := []byte(text)
+ start := f.offset(pos)
+ i := start
+ s := f.content
+ for i < len(s) {
+ if bytes.HasPrefix(s[i:], b) {
+ return i
+ }
+ if i+2 <= len(s) && s[i] == '/' && s[i+1] == '/' {
+ for i < len(s) && s[i] != '\n' {
+ i++
+ }
+ continue
+ }
+ if i+2 <= len(s) && s[i] == '/' && s[i+1] == '*' {
+ for i += 2; ; i++ {
+ if i+2 > len(s) {
+ return 0
+ }
+ if s[i] == '*' && s[i+1] == '/' {
+ i += 2
+ break
+ }
+ }
+ continue
+ }
+ i++
+ }
+ return -1
}
// Visit implements the ast.Visitor interface.
case *ast.CaseClause: // switch
for _, n := range n.List {
clause := n.(*ast.CaseClause)
- clause.Body = f.addCounters(clause.Colon+1, clause.End(), clause.Body, false)
+ f.addCounters(clause.Colon+1, clause.Colon+1, clause.End(), clause.Body, false)
}
return f
case *ast.CommClause: // select
for _, n := range n.List {
clause := n.(*ast.CommClause)
- clause.Body = f.addCounters(clause.Colon+1, clause.End(), clause.Body, false)
+ f.addCounters(clause.Colon+1, clause.Colon+1, clause.End(), clause.Body, false)
}
return f
}
}
- n.List = f.addCounters(n.Lbrace, n.Rbrace+1, n.List, true) // +1 to step past closing brace.
+ f.addCounters(n.Lbrace, n.Lbrace+1, n.Rbrace+1, n.List, true) // +1 to step past closing brace.
case *ast.IfStmt:
if n.Init != nil {
ast.Walk(f, n.Init)
// if y {
// }
// }
+ f.edit.Insert(f.offset(n.Body.End()), "else{")
+ elseOffset := f.findText(n.Body.End(), "else")
+ if elseOffset < 0 {
+ panic("lost else")
+ }
+ f.edit.Delete(elseOffset, elseOffset+4)
+ f.edit.Insert(f.offset(n.Else.End()), "}")
switch stmt := n.Else.(type) {
case *ast.IfStmt:
block := &ast.BlockStmt{
return f
}
-// fixDirectives identifies //go: comments (known as directives or pragmas) and
-// attaches them to the documentation group of the following top-level
-// definitions. All other comments are dropped since non-documentation comments
-// tend to get printed in the wrong place after the AST is modified.
-//
-// fixDirectives returns a list of unhandled directives. These comments could
-// not be attached to a top-level declaration and should be be printed at the
-// end of the file.
-func (f *File) fixDirectives() []*ast.Comment {
- // Scan comments in the file and collect directives. Detach all comments.
- var directives []*ast.Comment
- prev := token.NoPos // smaller than any valid token.Pos
- for _, cg := range f.astFile.Comments {
- for _, c := range cg.List {
- // Skip directives that will be included by initialComments, i.e., those
- // before the package declaration but not in the file doc comment group.
- if f.isDirective(c) && (c.Pos() >= f.astFile.Package || cg == f.astFile.Doc) {
- // Assume (but verify) that comments are sorted by position.
- pos := c.Pos()
- if !pos.IsValid() {
- log.Fatalf("compiler directive has no position: %q", c.Text)
- } else if pos < prev {
- log.Fatalf("compiler directives are out of order. %s was before %s.",
- f.fset.Position(prev), f.fset.Position(pos))
- }
- prev = pos
-
- directives = append(directives, c)
- }
- }
- cg.List = nil
- }
- f.astFile.Comments = nil // force printer to use node comments
- if len(directives) == 0 {
- // Common case: no directives to attach.
- return nil
- }
-
- // Iterate over top-level declarations and attach preceding directives.
- di := 0
- prev = token.NoPos
- for _, decl := range f.astFile.Decls {
- // Assume (but verify) that declarations are sorted by position.
- pos := decl.Pos()
- if !pos.IsValid() {
- // Synthetic decl. Don't add directives.
- continue
- }
- if pos < prev {
- log.Fatalf("declarations are out of order. %s was before %s.",
- f.fset.Position(prev), f.fset.Position(pos))
- }
- prev = pos
-
- var doc **ast.CommentGroup
- switch d := decl.(type) {
- case *ast.FuncDecl:
- doc = &d.Doc
- case *ast.GenDecl:
- // Limitation: for grouped declarations, we attach directives to the decl,
- // not individual specs. Directives must start in the first column, so
- // they are lost when the group is indented.
- doc = &d.Doc
- default:
- // *ast.BadDecls is the only other type we might see, but
- // we don't need to handle it here.
- continue
- }
-
- for di < len(directives) && directives[di].Pos() < pos {
- c := directives[di]
- if *doc == nil {
- *doc = new(ast.CommentGroup)
- }
- c.Slash = pos - 1 // must be strictly less than pos
- (*doc).List = append((*doc).List, c)
- di++
- }
- }
-
- // Return trailing directives. These cannot apply to a specific declaration
- // and may be printed at the end of the file.
- return directives[di:]
-}
-
// unquote returns the unquoted string.
func unquote(s string) string {
t, err := strconv.Unquote(s)
return t
}
-// addImport adds an import for the specified path, if one does not already exist, and returns
-// the local package name.
-func (f *File) addImport(path string) string {
- // Does the package already import it?
- for _, s := range f.astFile.Imports {
- if unquote(s.Path.Value) == path {
- if s.Name != nil {
- return s.Name.Name
- }
- return filepath.Base(path)
- }
- }
- newImport := &ast.ImportSpec{
- Name: ast.NewIdent(atomicPackageName),
- Path: &ast.BasicLit{
- Kind: token.STRING,
- Value: fmt.Sprintf("%q", path),
- },
- }
- impDecl := &ast.GenDecl{
- Tok: token.IMPORT,
- Specs: []ast.Spec{
- newImport,
- },
- }
- // Make the new import the first Decl in the file.
- astFile := f.astFile
- astFile.Decls = append(astFile.Decls, nil)
- copy(astFile.Decls[1:], astFile.Decls[0:])
- astFile.Decls[0] = impDecl
- astFile.Imports = append(astFile.Imports, newImport)
-
- // Now refer to the package, just in case it ends up unused.
- // That is, append to the end of the file the declaration
- // var _ = _cover_atomic_.AddUint32
- reference := &ast.GenDecl{
- Tok: token.VAR,
- Specs: []ast.Spec{
- &ast.ValueSpec{
- Names: []*ast.Ident{
- ast.NewIdent("_"),
- },
- Values: []ast.Expr{
- &ast.SelectorExpr{
- X: ast.NewIdent(atomicPackageName),
- Sel: ast.NewIdent("AddUint32"),
- },
- },
- },
- },
- }
- astFile.Decls = append(astFile.Decls, reference)
- return atomicPackageName
-}
-
var slashslash = []byte("//")
-// initialComments returns the prefix of content containing only
-// whitespace and line comments. Any +build directives must appear
-// within this region. This approach is more reliable than using
-// go/printer to print a modified AST containing comments.
-//
-func initialComments(content []byte) []byte {
- // Derived from go/build.Context.shouldBuild.
- end := 0
- p := content
- for len(p) > 0 {
- line := p
- if i := bytes.IndexByte(line, '\n'); i >= 0 {
- line, p = line[:i], p[i+1:]
- } else {
- p = p[len(p):]
- }
- line = bytes.TrimSpace(line)
- if len(line) == 0 { // Blank line.
- end = len(content) - len(p)
- continue
- }
- if !bytes.HasPrefix(line, slashslash) { // Not comment line.
- break
- }
- }
- return content[:end]
-}
-
func annotate(name string) {
fset := token.NewFileSet()
content, err := ioutil.ReadFile(name)
file := &File{
fset: fset,
name: name,
+ content: content,
+ edit: edit.NewBuffer(content),
astFile: parsedFile,
}
if *mode == "atomic" {
- file.atomicPkg = file.addImport(atomicPackagePath)
+ // Add import of sync/atomic immediately after package clause.
+ // We do this even if there is an existing import, because the
+ // existing import may be shadowed at any given place we want
+ // to refer to it, and our name (_cover_atomic_) is less likely to
+ // be shadowed.
+ file.edit.Insert(file.offset(file.astFile.Name.End()),
+ fmt.Sprintf("; import %s %q", atomicPackageName, atomicPackagePath))
}
- unhandledDirectives := file.fixDirectives()
ast.Walk(file, file.astFile)
+ newContent := file.edit.Bytes()
+
fd := os.Stdout
if *output != "" {
var err error
log.Fatalf("cover: %s", err)
}
}
- fd.Write(initialComments(content)) // Retain '// +build' directives.
- file.print(fd)
+ fd.Write(newContent)
+
// After printing the source tree, add some declarations for the counters etc.
// We could do this by adding to the tree, but it's easier just to print the text.
file.addVariables(fd)
-
- // Print directives that are not processed in fixDirectives. Some
- // directives are not attached to anything in the tree hence will not
- // be printed by printer. So, we have to explicitly print them here.
- for _, c := range unhandledDirectives {
- fmt.Fprintln(fd, c.Text)
- }
-}
-
-func (f *File) print(w io.Writer) {
- printer.Fprint(w, f.fset, f.astFile)
-}
-
-// isDirective reports whether a comment is a compiler directive.
-func (f *File) isDirective(c *ast.Comment) bool {
- return strings.HasPrefix(c.Text, "//go:") && f.fset.Position(c.Slash).Column == 1
-}
-
-// intLiteral returns an ast.BasicLit representing the integer value.
-func (f *File) intLiteral(i int) *ast.BasicLit {
- node := &ast.BasicLit{
- Kind: token.INT,
- Value: fmt.Sprint(i),
- }
- return node
-}
-
-// index returns an ast.BasicLit representing the number of counters present.
-func (f *File) index() *ast.BasicLit {
- return f.intLiteral(len(f.blocks))
}
// setCounterStmt returns the expression: __count[23] = 1.
-func setCounterStmt(f *File, counter ast.Expr) ast.Stmt {
- return &ast.AssignStmt{
- Lhs: []ast.Expr{counter},
- Tok: token.ASSIGN,
- Rhs: []ast.Expr{f.intLiteral(1)},
- }
+func setCounterStmt(f *File, counter string) string {
+ return fmt.Sprintf("%s = 1", counter)
}
// incCounterStmt returns the expression: __count[23]++.
-func incCounterStmt(f *File, counter ast.Expr) ast.Stmt {
- return &ast.IncDecStmt{
- X: counter,
- Tok: token.INC,
- }
+func incCounterStmt(f *File, counter string) string {
+ return fmt.Sprintf("%s++", counter)
}
// atomicCounterStmt returns the expression: atomic.AddUint32(&__count[23], 1)
-func atomicCounterStmt(f *File, counter ast.Expr) ast.Stmt {
- return &ast.ExprStmt{
- X: &ast.CallExpr{
- Fun: &ast.SelectorExpr{
- X: ast.NewIdent(f.atomicPkg),
- Sel: ast.NewIdent("AddUint32"),
- },
- Args: []ast.Expr{&ast.UnaryExpr{
- Op: token.AND,
- X: counter,
- },
- f.intLiteral(1),
- },
- },
- }
+func atomicCounterStmt(f *File, counter string) string {
+ return fmt.Sprintf("%s.AddUint32(&%s, 1)", atomicPackageName, counter)
}
// newCounter creates a new counter expression of the appropriate form.
-func (f *File) newCounter(start, end token.Pos, numStmt int) ast.Stmt {
- counter := &ast.IndexExpr{
- X: &ast.SelectorExpr{
- X: ast.NewIdent(*varVar),
- Sel: ast.NewIdent("Count"),
- },
- Index: f.index(),
- }
- stmt := counterStmt(f, counter)
+func (f *File) newCounter(start, end token.Pos, numStmt int) string {
+ stmt := counterStmt(f, fmt.Sprintf("%s.Count[%d]", *varVar, len(f.blocks)))
f.blocks = append(f.blocks, Block{start, end, numStmt})
return stmt
}
// counters will be added before S1 and before S3. The block containing S2
// will be visited in a separate call.
// TODO: Nested simple blocks get unnecessary (but correct) counters
-func (f *File) addCounters(pos, blockEnd token.Pos, list []ast.Stmt, extendToClosingBrace bool) []ast.Stmt {
+func (f *File) addCounters(pos, insertPos, blockEnd token.Pos, list []ast.Stmt, extendToClosingBrace bool) {
// Special case: make sure we add a counter to an empty block. Can't do this below
// or we will add a counter to an empty statement list after, say, a return statement.
if len(list) == 0 {
- return []ast.Stmt{f.newCounter(pos, blockEnd, 0)}
+ f.edit.Insert(f.offset(insertPos), f.newCounter(insertPos, blockEnd, 0)+";")
+ return
}
// We have a block (statement list), but it may have several basic blocks due to the
// appearance of statements that affect the flow of control.
- var newList []ast.Stmt
for {
// Find first statement that affects flow of control (break, continue, if, etc.).
// It will be the last statement of this basic block.
end = blockEnd
}
if pos != end { // Can have no source to cover if e.g. blocks abut.
- newList = append(newList, f.newCounter(pos, end, last))
+ f.edit.Insert(f.offset(insertPos), f.newCounter(pos, end, last)+";")
}
- newList = append(newList, list[0:last]...)
list = list[last:]
if len(list) == 0 {
break
}
pos = list[0].Pos()
+ insertPos = pos
}
- return newList
}
// hasFuncLiteral reports the existence and position of the first func literal
// Close the struct initialization.
fmt.Fprintf(w, "}\n")
+
+ // Emit a reference to the atomic package to avoid
+ // import and not used error when there's no code in a file.
+ if *mode == "atomic" {
+ fmt.Fprintf(w, "var _ = %s.LoadUint32\n", atomicPackageName)
+ }
}
--- /dev/null
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package edit implements buffered position-based editing of byte slices.
+package edit
+
+import (
+ "fmt"
+ "sort"
+)
+
+// A Buffer is a queue of edits to apply to a given byte slice.
+type Buffer struct {
+ old []byte
+ q edits
+}
+
+// An edit records a single text modification: change the bytes in [start,end) to new.
+type edit struct {
+ start int
+ end int
+ new string
+}
+
+// An edits is a list of edits that is sortable by start offset, breaking ties by end offset.
+type edits []edit
+
+func (x edits) Len() int { return len(x) }
+func (x edits) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
+func (x edits) Less(i, j int) bool {
+ if x[i].start != x[j].start {
+ return x[i].start < x[j].start
+ }
+ return x[i].end < x[j].end
+}
+
+// NewBuffer returns a new buffer to accumulate changes to an initial data slice.
+// The returned buffer maintains a reference to the data, so the caller must ensure
+// the data is not modified until after the Buffer is done being used.
+func NewBuffer(data []byte) *Buffer {
+ return &Buffer{old: data}
+}
+
+func (b *Buffer) Insert(pos int, new string) {
+ if pos < 0 || pos > len(b.old) {
+ panic("invalid edit position")
+ }
+ b.q = append(b.q, edit{pos, pos, new})
+}
+
+func (b *Buffer) Delete(start, end int) {
+ if end < start || start < 0 || end > len(b.old) {
+ panic("invalid edit position")
+ }
+ b.q = append(b.q, edit{start, end, ""})
+}
+
+func (b *Buffer) Replace(start, end int, new string) {
+ if end < start || start < 0 || end > len(b.old) {
+ panic("invalid edit position")
+ }
+ b.q = append(b.q, edit{start, end, new})
+}
+
+// Bytes returns a new byte slice containing the original data
+// with the queued edits applied.
+func (b *Buffer) Bytes() []byte {
+ // Sort edits by starting position and then by ending position.
+ // Breaking ties by ending position allows insertions at point x
+ // to be applied before a replacement of the text at [x, y).
+ sort.Stable(b.q)
+
+ var new []byte
+ offset := 0
+ for i, e := range b.q {
+ if e.start < offset {
+ e0 := b.q[i-1]
+ panic(fmt.Sprintf("overlapping edits: [%d,%d)->%q, [%d,%d)->%q", e0.start, e0.end, e0.new, e.start, e.end, e.new))
+ }
+ new = append(new, b.old[offset:e.start]...)
+ offset = e.end
+ new = append(new, e.new...)
+ }
+ new = append(new, b.old[offset:]...)
+ return new
+}
+
+// String returns a string containing the original data
+// with the queued edits applied.
+func (b *Buffer) String() string {
+ return string(b.Bytes())
+}