if f.Ref == nil {
f.Ref = make([]*Ref, 0, 8)
}
- f.walk(ast2, "prog", (*File).saveExprs)
+ f.walk(ast2, ctxProg, (*File).saveExprs)
// Accumulate exported functions.
// The comments are only on ast1 but we need to
// The first walk fills in ExpFunc, and the
// second walk changes the entries to
// refer to ast2 instead.
- f.walk(ast1, "prog", (*File).saveExport)
- f.walk(ast2, "prog", (*File).saveExport2)
+ f.walk(ast1, ctxProg, (*File).saveExport)
+ f.walk(ast2, ctxProg, (*File).saveExport2)
f.Comments = ast1.Comments
f.AST = ast2
}
// Save various references we are going to need later.
-func (f *File) saveExprs(x interface{}, context string) {
+func (f *File) saveExprs(x interface{}, context astContext) {
switch x := x.(type) {
case *ast.Expr:
switch (*x).(type) {
}
// Save references to C.xxx for later processing.
-func (f *File) saveRef(n *ast.Expr, context string) {
+func (f *File) saveRef(n *ast.Expr, context astContext) {
sel := (*n).(*ast.SelectorExpr)
// For now, assume that the only instance of capital C is when
// used as the imported package identifier.
if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
return
}
- if context == "as2" {
- context = "expr"
+ if context == ctxAssign2 {
+ context = ctxExpr
}
- if context == "embed-type" {
+ if context == ctxEmbedType {
error_(sel.Pos(), "cannot embed C type")
}
goname := sel.Sel.Name
}
// Save calls to C.xxx for later processing.
-func (f *File) saveCall(call *ast.CallExpr, context string) {
+func (f *File) saveCall(call *ast.CallExpr, context astContext) {
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return
if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
return
}
- c := &Call{Call: call, Deferred: context == "defer"}
+ c := &Call{Call: call, Deferred: context == ctxDefer}
f.Calls = append(f.Calls, c)
}
// If a function should be exported add it to ExpFunc.
-func (f *File) saveExport(x interface{}, context string) {
+func (f *File) saveExport(x interface{}, context astContext) {
n, ok := x.(*ast.FuncDecl)
if !ok {
return
}
// Make f.ExpFunc[i] point at the Func from this AST instead of the other one.
-func (f *File) saveExport2(x interface{}, context string) {
+func (f *File) saveExport2(x interface{}, context astContext) {
n, ok := x.(*ast.FuncDecl)
if !ok {
return
}
}
+type astContext int
+
+const (
+ ctxProg astContext = iota
+ ctxEmbedType
+ ctxType
+ ctxStmt
+ ctxExpr
+ ctxField
+ ctxParam
+ ctxAssign2 // assignment of a single expression to two variables
+ ctxSwitch
+ ctxTypeSwitch
+ ctxFile
+ ctxDecl
+ ctxSpec
+ ctxDefer
+ ctxCall // any function call other than ctxCall2
+ ctxCall2 // function call whose result is assigned to two variables
+ ctxSelector
+)
+
// walk walks the AST x, calling visit(f, x, context) for each node.
-func (f *File) walk(x interface{}, context string, visit func(*File, interface{}, string)) {
+func (f *File) walk(x interface{}, context astContext, visit func(*File, interface{}, astContext)) {
visit(f, x, context)
switch n := x.(type) {
case *ast.Expr:
// These are ordered and grouped to match ../../go/ast/ast.go
case *ast.Field:
- if len(n.Names) == 0 && context == "field" {
- f.walk(&n.Type, "embed-type", visit)
+ if len(n.Names) == 0 && context == ctxField {
+ f.walk(&n.Type, ctxEmbedType, visit)
} else {
- f.walk(&n.Type, "type", visit)
+ f.walk(&n.Type, ctxType, visit)
}
case *ast.FieldList:
for _, field := range n.List {
case *ast.Ellipsis:
case *ast.BasicLit:
case *ast.FuncLit:
- f.walk(n.Type, "type", visit)
- f.walk(n.Body, "stmt", visit)
+ f.walk(n.Type, ctxType, visit)
+ f.walk(n.Body, ctxStmt, visit)
case *ast.CompositeLit:
- f.walk(&n.Type, "type", visit)
- f.walk(n.Elts, "expr", visit)
+ f.walk(&n.Type, ctxType, visit)
+ f.walk(n.Elts, ctxExpr, visit)
case *ast.ParenExpr:
f.walk(&n.X, context, visit)
case *ast.SelectorExpr:
- f.walk(&n.X, "selector", visit)
+ f.walk(&n.X, ctxSelector, visit)
case *ast.IndexExpr:
- f.walk(&n.X, "expr", visit)
- f.walk(&n.Index, "expr", visit)
+ f.walk(&n.X, ctxExpr, visit)
+ f.walk(&n.Index, ctxExpr, visit)
case *ast.SliceExpr:
- f.walk(&n.X, "expr", visit)
+ f.walk(&n.X, ctxExpr, visit)
if n.Low != nil {
- f.walk(&n.Low, "expr", visit)
+ f.walk(&n.Low, ctxExpr, visit)
}
if n.High != nil {
- f.walk(&n.High, "expr", visit)
+ f.walk(&n.High, ctxExpr, visit)
}
if n.Max != nil {
- f.walk(&n.Max, "expr", visit)
+ f.walk(&n.Max, ctxExpr, visit)
}
case *ast.TypeAssertExpr:
- f.walk(&n.X, "expr", visit)
- f.walk(&n.Type, "type", visit)
+ f.walk(&n.X, ctxExpr, visit)
+ f.walk(&n.Type, ctxType, visit)
case *ast.CallExpr:
- if context == "as2" {
- f.walk(&n.Fun, "call2", visit)
+ if context == ctxAssign2 {
+ f.walk(&n.Fun, ctxCall2, visit)
} else {
- f.walk(&n.Fun, "call", visit)
+ f.walk(&n.Fun, ctxCall, visit)
}
- f.walk(n.Args, "expr", visit)
+ f.walk(n.Args, ctxExpr, visit)
case *ast.StarExpr:
f.walk(&n.X, context, visit)
case *ast.UnaryExpr:
- f.walk(&n.X, "expr", visit)
+ f.walk(&n.X, ctxExpr, visit)
case *ast.BinaryExpr:
- f.walk(&n.X, "expr", visit)
- f.walk(&n.Y, "expr", visit)
+ f.walk(&n.X, ctxExpr, visit)
+ f.walk(&n.Y, ctxExpr, visit)
case *ast.KeyValueExpr:
- f.walk(&n.Key, "expr", visit)
- f.walk(&n.Value, "expr", visit)
+ f.walk(&n.Key, ctxExpr, visit)
+ f.walk(&n.Value, ctxExpr, visit)
case *ast.ArrayType:
- f.walk(&n.Len, "expr", visit)
- f.walk(&n.Elt, "type", visit)
+ f.walk(&n.Len, ctxExpr, visit)
+ f.walk(&n.Elt, ctxType, visit)
case *ast.StructType:
- f.walk(n.Fields, "field", visit)
+ f.walk(n.Fields, ctxField, visit)
case *ast.FuncType:
- f.walk(n.Params, "param", visit)
+ f.walk(n.Params, ctxParam, visit)
if n.Results != nil {
- f.walk(n.Results, "param", visit)
+ f.walk(n.Results, ctxParam, visit)
}
case *ast.InterfaceType:
- f.walk(n.Methods, "field", visit)
+ f.walk(n.Methods, ctxField, visit)
case *ast.MapType:
- f.walk(&n.Key, "type", visit)
- f.walk(&n.Value, "type", visit)
+ f.walk(&n.Key, ctxType, visit)
+ f.walk(&n.Value, ctxType, visit)
case *ast.ChanType:
- f.walk(&n.Value, "type", visit)
+ f.walk(&n.Value, ctxType, visit)
case *ast.BadStmt:
case *ast.DeclStmt:
- f.walk(n.Decl, "decl", visit)
+ f.walk(n.Decl, ctxDecl, visit)
case *ast.EmptyStmt:
case *ast.LabeledStmt:
- f.walk(n.Stmt, "stmt", visit)
+ f.walk(n.Stmt, ctxStmt, visit)
case *ast.ExprStmt:
- f.walk(&n.X, "expr", visit)
+ f.walk(&n.X, ctxExpr, visit)
case *ast.SendStmt:
- f.walk(&n.Chan, "expr", visit)
- f.walk(&n.Value, "expr", visit)
+ f.walk(&n.Chan, ctxExpr, visit)
+ f.walk(&n.Value, ctxExpr, visit)
case *ast.IncDecStmt:
- f.walk(&n.X, "expr", visit)
+ f.walk(&n.X, ctxExpr, visit)
case *ast.AssignStmt:
- f.walk(n.Lhs, "expr", visit)
+ f.walk(n.Lhs, ctxExpr, visit)
if len(n.Lhs) == 2 && len(n.Rhs) == 1 {
- f.walk(n.Rhs, "as2", visit)
+ f.walk(n.Rhs, ctxAssign2, visit)
} else {
- f.walk(n.Rhs, "expr", visit)
+ f.walk(n.Rhs, ctxExpr, visit)
}
case *ast.GoStmt:
- f.walk(n.Call, "expr", visit)
+ f.walk(n.Call, ctxExpr, visit)
case *ast.DeferStmt:
- f.walk(n.Call, "defer", visit)
+ f.walk(n.Call, ctxDefer, visit)
case *ast.ReturnStmt:
- f.walk(n.Results, "expr", visit)
+ f.walk(n.Results, ctxExpr, visit)
case *ast.BranchStmt:
case *ast.BlockStmt:
f.walk(n.List, context, visit)
case *ast.IfStmt:
- f.walk(n.Init, "stmt", visit)
- f.walk(&n.Cond, "expr", visit)
- f.walk(n.Body, "stmt", visit)
- f.walk(n.Else, "stmt", visit)
+ f.walk(n.Init, ctxStmt, visit)
+ f.walk(&n.Cond, ctxExpr, visit)
+ f.walk(n.Body, ctxStmt, visit)
+ f.walk(n.Else, ctxStmt, visit)
case *ast.CaseClause:
- if context == "typeswitch" {
- context = "type"
+ if context == ctxTypeSwitch {
+ context = ctxType
} else {
- context = "expr"
+ context = ctxExpr
}
f.walk(n.List, context, visit)
- f.walk(n.Body, "stmt", visit)
+ f.walk(n.Body, ctxStmt, visit)
case *ast.SwitchStmt:
- f.walk(n.Init, "stmt", visit)
- f.walk(&n.Tag, "expr", visit)
- f.walk(n.Body, "switch", visit)
+ f.walk(n.Init, ctxStmt, visit)
+ f.walk(&n.Tag, ctxExpr, visit)
+ f.walk(n.Body, ctxSwitch, visit)
case *ast.TypeSwitchStmt:
- f.walk(n.Init, "stmt", visit)
- f.walk(n.Assign, "stmt", visit)
- f.walk(n.Body, "typeswitch", visit)
+ f.walk(n.Init, ctxStmt, visit)
+ f.walk(n.Assign, ctxStmt, visit)
+ f.walk(n.Body, ctxTypeSwitch, visit)
case *ast.CommClause:
- f.walk(n.Comm, "stmt", visit)
- f.walk(n.Body, "stmt", visit)
+ f.walk(n.Comm, ctxStmt, visit)
+ f.walk(n.Body, ctxStmt, visit)
case *ast.SelectStmt:
- f.walk(n.Body, "stmt", visit)
+ f.walk(n.Body, ctxStmt, visit)
case *ast.ForStmt:
- f.walk(n.Init, "stmt", visit)
- f.walk(&n.Cond, "expr", visit)
- f.walk(n.Post, "stmt", visit)
- f.walk(n.Body, "stmt", visit)
+ f.walk(n.Init, ctxStmt, visit)
+ f.walk(&n.Cond, ctxExpr, visit)
+ f.walk(n.Post, ctxStmt, visit)
+ f.walk(n.Body, ctxStmt, visit)
case *ast.RangeStmt:
- f.walk(&n.Key, "expr", visit)
- f.walk(&n.Value, "expr", visit)
- f.walk(&n.X, "expr", visit)
- f.walk(n.Body, "stmt", visit)
+ f.walk(&n.Key, ctxExpr, visit)
+ f.walk(&n.Value, ctxExpr, visit)
+ f.walk(&n.X, ctxExpr, visit)
+ f.walk(n.Body, ctxStmt, visit)
case *ast.ImportSpec:
case *ast.ValueSpec:
- f.walk(&n.Type, "type", visit)
+ f.walk(&n.Type, ctxType, visit)
if len(n.Names) == 2 && len(n.Values) == 1 {
- f.walk(&n.Values[0], "as2", visit)
+ f.walk(&n.Values[0], ctxAssign2, visit)
} else {
- f.walk(n.Values, "expr", visit)
+ f.walk(n.Values, ctxExpr, visit)
}
case *ast.TypeSpec:
- f.walk(&n.Type, "type", visit)
+ f.walk(&n.Type, ctxType, visit)
case *ast.BadDecl:
case *ast.GenDecl:
- f.walk(n.Specs, "spec", visit)
+ f.walk(n.Specs, ctxSpec, visit)
case *ast.FuncDecl:
if n.Recv != nil {
- f.walk(n.Recv, "param", visit)
+ f.walk(n.Recv, ctxParam, visit)
}
- f.walk(n.Type, "type", visit)
+ f.walk(n.Type, ctxType, visit)
if n.Body != nil {
- f.walk(n.Body, "stmt", visit)
+ f.walk(n.Body, ctxStmt, visit)
}
case *ast.File:
- f.walk(n.Decls, "decl", visit)
+ f.walk(n.Decls, ctxDecl, visit)
case *ast.Package:
for _, file := range n.Files {
- f.walk(file, "file", visit)
+ f.walk(file, ctxFile, visit)
}
case []ast.Decl: