]> Cypherpunks repositories - gostls13.git/commitdiff
go/ast: fix ast.Walk
authorRobert Griesemer <gri@golang.org>
Thu, 9 Dec 2010 18:22:01 +0000 (10:22 -0800)
committerRobert Griesemer <gri@golang.org>
Thu, 9 Dec 2010 18:22:01 +0000 (10:22 -0800)
- change Walk signature to use an ast.Node instead of interface{}
- add Pos functions to a couple of ast types to make them proper nodes
- explicit nil checks where a node can be nil; incorrect ASTs cause Walk to crash

For now ast.Walk is exercised extensively as part of godoc's indexer;
so we have some confidence in its correctness. But this needs a test,
eventually.

Fixes #1326.

R=rsc, r
CC=golang-dev
https://golang.org/cl/3481043

src/cmd/godoc/index.go
src/cmd/gofmt/simplify.go
src/cmd/govet/govet.go
src/pkg/go/ast/ast.go
src/pkg/go/ast/walk.go
src/pkg/go/printer/nodes.go

index b0bb8cef3bbfb8659c592d6611dd4e363e3da79b..6f41f1819dba4f91ef6979b1cd123f2ad8cb3347 100644 (file)
@@ -509,7 +509,7 @@ func (x *Indexer) visitSpec(spec ast.Spec, isVarDecl bool) {
 }
 
 
-func (x *Indexer) Visit(node interface{}) ast.Visitor {
+func (x *Indexer) Visit(node ast.Node) ast.Visitor {
        // TODO(gri): methods in interface types are categorized as VarDecl
        switch n := node.(type) {
        case nil:
index b166f04d8eb6bef7e7dbf2f87f3f72f7f33d6f71..bcc67c4a6e93baf10a4aabc8db3cdb3e0fcaa345 100644 (file)
@@ -12,7 +12,7 @@ import (
 
 type simplifier struct{}
 
-func (s *simplifier) Visit(node interface{}) ast.Visitor {
+func (s *simplifier) Visit(node ast.Node) ast.Visitor {
        switch n := node.(type) {
        case *ast.CompositeLit:
                // array, slice, and map composite literals may be simplified
@@ -61,7 +61,7 @@ func (s *simplifier) Visit(node interface{}) ast.Visitor {
 }
 
 
-func simplify(node interface{}) {
+func simplify(node ast.Node) {
        var s simplifier
        ast.Walk(&s, node)
 }
index c748c8018f1647e0ba9f12152e7008e88d05baf1..4ab908ae296708fa0ab4b574d009a26a928fc4ed 100644 (file)
@@ -139,7 +139,7 @@ func (f *File) checkFile(name string, file *ast.File) {
 }
 
 // Visit implements the ast.Visitor interface.
-func (f *File) Visit(node interface{}) ast.Visitor {
+func (f *File) Visit(node ast.Node) ast.Visitor {
        // TODO: could return nil for nodes that cannot contain a CallExpr -
        // will shortcut traversal.  Worthwhile?
        switch n := node.(type) {
index da1f428e32f5459ae537a9cfbff3c5dfab62b6ff..dfd950a943ffb9914e7604f783c8607971a111a5 100644 (file)
@@ -70,19 +70,20 @@ type Comment struct {
 }
 
 
-func (c *Comment) Pos() token.Pos {
-       return c.Slash
-}
+func (c *Comment) Pos() token.Pos { return c.Slash }
 
 
 // A CommentGroup represents a sequence of comments
 // with no other tokens and no empty lines between.
 //
 type CommentGroup struct {
-       List []*Comment
+       List []*Comment // len(List) > 0
 }
 
 
+func (g *CommentGroup) Pos() token.Pos { return g.List[0].Pos() }
+
+
 // ----------------------------------------------------------------------------
 // Expressions and types
 
@@ -115,6 +116,9 @@ type FieldList struct {
 }
 
 
+func (list *FieldList) Pos() token.Pos { return list.Opening }
+
+
 // NumFields returns the number of (named and anonymous fields) in a FieldList.
 func (f *FieldList) NumFields() int {
        n := 0
@@ -294,7 +298,7 @@ type (
        FuncType struct {
                Func    token.Pos  // position of "func" keyword
                Params  *FieldList // (incoming) parameters
-               Results *FieldList // (outgoing) results
+               Results *FieldList // (outgoing) results; or nil
        }
 
        // An InterfaceType node represents an interface type.
@@ -494,7 +498,7 @@ type (
        BranchStmt struct {
                TokPos token.Pos   // position of Tok
                Tok    token.Token // keyword token (BREAK, CONTINUE, GOTO, FALLTHROUGH)
-               Label  *Ident
+               Label  *Ident      // label name; or nil
        }
 
        // A BlockStmt node represents a braced statement list.
@@ -507,10 +511,10 @@ type (
        // An IfStmt node represents an if statement.
        IfStmt struct {
                If   token.Pos // position of "if" keyword
-               Init Stmt
-               Cond Expr
+               Init Stmt      // initalization statement; or nil
+               Cond Expr      // condition; or nil
                Body *BlockStmt
-               Else Stmt
+               Else Stmt // else branch; or nil
        }
 
        // A CaseClause represents a case of an expression switch statement.
@@ -523,9 +527,9 @@ type (
 
        // A SwitchStmt node represents an expression switch statement.
        SwitchStmt struct {
-               Switch token.Pos // position of "switch" keyword
-               Init   Stmt
-               Tag    Expr
+               Switch token.Pos  // position of "switch" keyword
+               Init   Stmt       // initalization statement; or nil
+               Tag    Expr       // tag expression; or nil
                Body   *BlockStmt // CaseClauses only
        }
 
@@ -539,8 +543,8 @@ type (
 
        // An TypeSwitchStmt node represents a type switch statement.
        TypeSwitchStmt struct {
-               Switch token.Pos // position of "switch" keyword
-               Init   Stmt
+               Switch token.Pos  // position of "switch" keyword
+               Init   Stmt       // initalization statement; or nil
                Assign Stmt       // x := y.(type)
                Body   *BlockStmt // TypeCaseClauses only
        }
@@ -563,9 +567,9 @@ type (
        // A ForStmt represents a for statement.
        ForStmt struct {
                For  token.Pos // position of "for" keyword
-               Init Stmt
-               Cond Expr
-               Post Stmt
+               Init Stmt      // initalization statement; or nil
+               Cond Expr      // condition; or nil
+               Post Stmt      // post iteration statement; or nil
                Body *BlockStmt
        }
 
@@ -780,3 +784,13 @@ type Package struct {
        Scope *Scope           // package scope; or nil
        Files map[string]*File // Go source files by filename
 }
+
+
+func (p *Package) Pos() (pos token.Pos) {
+       // get the position of the package clause of the first file, if any
+       for _, f := range p.Files {
+               pos = f.Pos()
+               break
+       }
+       return
+}
index 296da5652de700eb8726853de37fc3535188adfb..eb4780942293233773757ba3d14a6c96684ce943 100644 (file)
@@ -10,51 +10,57 @@ import "fmt"
 // If the result visitor w is not nil, Walk visits each of the children
 // of node with the visitor w, followed by a call of w.Visit(nil).
 type Visitor interface {
-       Visit(node interface{}) (w Visitor)
+       Visit(node Node) (w Visitor)
 }
 
 
-func walkIdent(v Visitor, x *Ident) {
-       if x != nil {
+// Helper functions for common node lists. They may be empty.
+
+func walkIdentList(v Visitor, list []*Ident) {
+       for _, x := range list {
                Walk(v, x)
        }
 }
 
 
-func walkCommentGroup(v Visitor, g *CommentGroup) {
-       if g != nil {
-               Walk(v, g)
+func walkExprList(v Visitor, list []Expr) {
+       for _, x := range list {
+               Walk(v, x)
        }
 }
 
 
-func walkBlockStmt(v Visitor, b *BlockStmt) {
-       if b != nil {
-               Walk(v, b)
+func walkStmtList(v Visitor, list []Stmt) {
+       for _, x := range list {
+               Walk(v, x)
        }
 }
 
 
-// Walk traverses an AST in depth-first order: If node != nil, it
-// invokes v.Visit(node). If the visitor w returned by v.Visit(node) is
-// not nil, Walk visits each of the children of node with the visitor w,
-// followed by a call of w.Visit(nil).
-//
-// Walk may be called with any of the named ast node types. It also
-// accepts arguments of type []*Field, []*Ident, []Expr, []Stmt and []Decl;
-// the respective children are the slice elements.
-//
-func Walk(v Visitor, node interface{}) {
-       if node == nil {
-               return
+func walkDeclList(v Visitor, list []Decl) {
+       for _, x := range list {
+               Walk(v, x)
        }
+}
+
+
+// TODO(gri): Investigate if providing a closure to Walk leads to
+//            simpler use (and may help eliminate Inspect in turn).
+
+// Walk traverses an AST in depth-first order: It starts by calling
+// v.Visit(node); node must not be nil. If the visitor w returned by
+// v.Visit(node) is not nil, Walk is invoked recursively with visitor
+// w for each of the non-nil children of node, followed by a call of
+// w.Visit(nil).
+//
+func Walk(v Visitor, node Node) {
        if v = v.Visit(node); v == nil {
                return
        }
 
        // walk children
        // (the order of the cases matches the order
-       // of the corresponding declaration in ast.go)
+       // of the corresponding node types in ast.go)
        switch n := node.(type) {
        // Comments and fields
        case *Comment:
@@ -66,11 +72,17 @@ func Walk(v Visitor, node interface{}) {
                }
 
        case *Field:
-               walkCommentGroup(v, n.Doc)
-               Walk(v, n.Names)
+               if n.Doc != nil {
+                       Walk(v, n.Doc)
+               }
+               walkIdentList(v, n.Names)
                Walk(v, n.Type)
-               Walk(v, n.Tag)
-               walkCommentGroup(v, n.Comment)
+               if n.Tag != nil {
+                       Walk(v, n.Tag)
+               }
+               if n.Comment != nil {
+                       Walk(v, n.Comment)
+               }
 
        case *FieldList:
                for _, f := range n.List {
@@ -82,21 +94,21 @@ func Walk(v Visitor, node interface{}) {
                // nothing to do
 
        case *FuncLit:
-               if n != nil {
-                       Walk(v, n.Type)
-               }
-               walkBlockStmt(v, n.Body)
+               Walk(v, n.Type)
+               Walk(v, n.Body)
 
        case *CompositeLit:
-               Walk(v, n.Type)
-               Walk(v, n.Elts)
+               if n.Type != nil {
+                       Walk(v, n.Type)
+               }
+               walkExprList(v, n.Elts)
 
        case *ParenExpr:
                Walk(v, n.X)
 
        case *SelectorExpr:
                Walk(v, n.X)
-               walkIdent(v, n.Sel)
+               Walk(v, n.Sel)
 
        case *IndexExpr:
                Walk(v, n.X)
@@ -104,16 +116,22 @@ func Walk(v Visitor, node interface{}) {
 
        case *SliceExpr:
                Walk(v, n.X)
-               Walk(v, n.Index)
-               Walk(v, n.End)
+               if n.Index != nil {
+                       Walk(v, n.Index)
+               }
+               if n.End != nil {
+                       Walk(v, n.End)
+               }
 
        case *TypeAssertExpr:
                Walk(v, n.X)
-               Walk(v, n.Type)
+               if n.Type != nil {
+                       Walk(v, n.Type)
+               }
 
        case *CallExpr:
                Walk(v, n.Fun)
-               Walk(v, n.Args)
+               walkExprList(v, n.Args)
 
        case *StarExpr:
                Walk(v, n.X)
@@ -131,7 +149,9 @@ func Walk(v Visitor, node interface{}) {
 
        // Types
        case *ArrayType:
-               Walk(v, n.Len)
+               if n.Len != nil {
+                       Walk(v, n.Len)
+               }
                Walk(v, n.Elt)
 
        case *StructType:
@@ -164,7 +184,7 @@ func Walk(v Visitor, node interface{}) {
                // nothing to do
 
        case *LabeledStmt:
-               walkIdent(v, n.Label)
+               Walk(v, n.Label)
                Walk(v, n.Stmt)
 
        case *ExprStmt:
@@ -174,148 +194,177 @@ func Walk(v Visitor, node interface{}) {
                Walk(v, n.X)
 
        case *AssignStmt:
-               Walk(v, n.Lhs)
-               Walk(v, n.Rhs)
+               walkExprList(v, n.Lhs)
+               walkExprList(v, n.Rhs)
 
        case *GoStmt:
-               if n.Call != nil {
-                       Walk(v, n.Call)
-               }
+               Walk(v, n.Call)
 
        case *DeferStmt:
-               if n.Call != nil {
-                       Walk(v, n.Call)
-               }
+               Walk(v, n.Call)
 
        case *ReturnStmt:
-               Walk(v, n.Results)
+               walkExprList(v, n.Results)
 
        case *BranchStmt:
-               walkIdent(v, n.Label)
+               if n.Label != nil {
+                       Walk(v, n.Label)
+               }
 
        case *BlockStmt:
-               Walk(v, n.List)
+               walkStmtList(v, n.List)
 
        case *IfStmt:
-               Walk(v, n.Init)
-               Walk(v, n.Cond)
-               walkBlockStmt(v, n.Body)
-               Walk(v, n.Else)
+               if n.Init != nil {
+                       Walk(v, n.Init)
+               }
+               if n.Cond != nil {
+                       Walk(v, n.Cond)
+               }
+               Walk(v, n.Body)
+               if n.Else != nil {
+                       Walk(v, n.Else)
+               }
 
        case *CaseClause:
-               Walk(v, n.Values)
-               Walk(v, n.Body)
+               walkExprList(v, n.Values)
+               walkStmtList(v, n.Body)
 
        case *SwitchStmt:
-               Walk(v, n.Init)
-               Walk(v, n.Tag)
-               walkBlockStmt(v, n.Body)
+               if n.Init != nil {
+                       Walk(v, n.Init)
+               }
+               if n.Tag != nil {
+                       Walk(v, n.Tag)
+               }
+               Walk(v, n.Body)
 
        case *TypeCaseClause:
-               Walk(v, n.Types)
-               Walk(v, n.Body)
+               for _, x := range n.Types {
+                       Walk(v, x)
+               }
+               walkStmtList(v, n.Body)
 
        case *TypeSwitchStmt:
-               Walk(v, n.Init)
+               if n.Init != nil {
+                       Walk(v, n.Init)
+               }
                Walk(v, n.Assign)
-               walkBlockStmt(v, n.Body)
+               Walk(v, n.Body)
 
        case *CommClause:
-               Walk(v, n.Lhs)
-               Walk(v, n.Rhs)
-               Walk(v, n.Body)
+               if n.Lhs != nil {
+                       Walk(v, n.Lhs)
+               }
+               if n.Rhs != nil {
+                       Walk(v, n.Rhs)
+               }
+               walkStmtList(v, n.Body)
 
        case *SelectStmt:
-               walkBlockStmt(v, n.Body)
+               Walk(v, n.Body)
 
        case *ForStmt:
-               Walk(v, n.Init)
-               Walk(v, n.Cond)
-               Walk(v, n.Post)
-               walkBlockStmt(v, n.Body)
+               if n.Init != nil {
+                       Walk(v, n.Init)
+               }
+               if n.Cond != nil {
+                       Walk(v, n.Cond)
+               }
+               if n.Post != nil {
+                       Walk(v, n.Post)
+               }
+               Walk(v, n.Body)
 
        case *RangeStmt:
                Walk(v, n.Key)
-               Walk(v, n.Value)
+               if n.Value != nil {
+                       Walk(v, n.Value)
+               }
                Walk(v, n.X)
-               walkBlockStmt(v, n.Body)
+               Walk(v, n.Body)
 
        // Declarations
        case *ImportSpec:
-               walkCommentGroup(v, n.Doc)
-               walkIdent(v, n.Name)
+               if n.Doc != nil {
+                       Walk(v, n.Doc)
+               }
+               if n.Name != nil {
+                       Walk(v, n.Name)
+               }
                Walk(v, n.Path)
-               walkCommentGroup(v, n.Comment)
+               if n.Comment != nil {
+                       Walk(v, n.Comment)
+               }
 
        case *ValueSpec:
-               walkCommentGroup(v, n.Doc)
-               Walk(v, n.Names)
-               Walk(v, n.Type)
-               Walk(v, n.Values)
-               walkCommentGroup(v, n.Comment)
+               if n.Doc != nil {
+                       Walk(v, n.Doc)
+               }
+               walkIdentList(v, n.Names)
+               if n.Type != nil {
+                       Walk(v, n.Type)
+               }
+               walkExprList(v, n.Values)
+               if n.Comment != nil {
+                       Walk(v, n.Comment)
+               }
 
        case *TypeSpec:
-               walkCommentGroup(v, n.Doc)
-               walkIdent(v, n.Name)
+               if n.Doc != nil {
+                       Walk(v, n.Doc)
+               }
+               Walk(v, n.Name)
                Walk(v, n.Type)
-               walkCommentGroup(v, n.Comment)
+               if n.Comment != nil {
+                       Walk(v, n.Comment)
+               }
 
        case *BadDecl:
                // nothing to do
 
        case *GenDecl:
-               walkCommentGroup(v, n.Doc)
+               if n.Doc != nil {
+                       Walk(v, n.Doc)
+               }
                for _, s := range n.Specs {
                        Walk(v, s)
                }
 
        case *FuncDecl:
-               walkCommentGroup(v, n.Doc)
+               if n.Doc != nil {
+                       Walk(v, n.Doc)
+               }
                if n.Recv != nil {
                        Walk(v, n.Recv)
                }
-               walkIdent(v, n.Name)
-               if n.Type != nil {
-                       Walk(v, n.Type)
+               Walk(v, n.Name)
+               Walk(v, n.Type)
+               if n.Body != nil {
+                       Walk(v, n.Body)
                }
-               walkBlockStmt(v, n.Body)
 
        // Files and packages
        case *File:
-               walkCommentGroup(v, n.Doc)
-               walkIdent(v, n.Name)
-               Walk(v, n.Decls)
+               if n.Doc != nil {
+                       Walk(v, n.Doc)
+               }
+               Walk(v, n.Name)
+               walkDeclList(v, n.Decls)
                for _, g := range n.Comments {
                        Walk(v, g)
                }
+               // don't walk n.Comments - they have been
+               // visited already through the individual
+               // nodes
 
        case *Package:
                for _, f := range n.Files {
                        Walk(v, f)
                }
 
-       case []*Ident:
-               for _, x := range n {
-                       Walk(v, x)
-               }
-
-       case []Expr:
-               for _, x := range n {
-                       Walk(v, x)
-               }
-
-       case []Stmt:
-               for _, x := range n {
-                       Walk(v, x)
-               }
-
-       case []Decl:
-               for _, x := range n {
-                       Walk(v, x)
-               }
-
        default:
-               fmt.Printf("ast.Walk: unexpected type %T", n)
+               fmt.Printf("ast.Walk: unexpected node type %T", n)
                panic("ast.Walk")
        }
 
@@ -323,20 +372,20 @@ func Walk(v Visitor, node interface{}) {
 }
 
 
-type inspector func(node interface{}) bool
+type inspector func(Node) bool
 
-func (f inspector) Visit(node interface{}) Visitor {
-       if node != nil && f(node) {
+func (f inspector) Visit(node Node) Visitor {
+       if f(node) {
                return f
        }
        return nil
 }
 
 
-// Inspect traverses an AST in depth-first order: If node != nil, it
-// invokes f(node). If f returns true, inspect invokes f for all the
-// non-nil children of node, recursively.
+// Inspect traverses an AST in depth-first order: It starts by calling
+// f(node); node must not be nil. If f returns true, Inspect invokes f
+// for all the non-nil children of node, recursively.
 //
-func Inspect(ast interface{}, f func(node interface{}) bool) {
-       Walk(inspector(f), ast)
+func Inspect(node Node, f func(Node) bool) {
+       Walk(inspector(f), node)
 }
index 7ae7b54b5e648b64b0a2aa274e17a88d01167732..dc311644ffecc57d0dc5fc7c38a456f7d80206e1 100644 (file)
@@ -994,7 +994,7 @@ func stripParens(x ast.Expr) ast.Expr {
                // parentheses must not be stripped if there are any
                // unparenthesized composite literals starting with
                // a type name
-               ast.Inspect(px.X, func(node interface{}) bool {
+               ast.Inspect(px.X, func(node ast.Node) bool {
                        switch x := node.(type) {
                        case *ast.ParenExpr:
                                // parentheses protect enclosed composite literals