-// Copyright 2020 The Go Authors. All rights reserved.
+// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build ignore
// +build ignore
-// Note: this program must be run with the GOROOT
-// environment variable set to the root of this tree.
-// GOROOT=...
-// cd $GOROOT/src/cmd/compile/internal/ir
-// ../../../../../bin/go run -mod=mod mknode.go
+// Note: this program must be run in this directory.
+// go run mknode.go
package main
import (
"bytes"
"fmt"
+ "go/ast"
"go/format"
- "go/types"
+ "go/parser"
+ "go/token"
+ "io/fs"
"io/ioutil"
"log"
- "reflect"
"sort"
"strings"
-
- "golang.org/x/tools/go/packages"
)
-var irPkg *types.Package
+var fset = token.NewFileSet()
+
var buf bytes.Buffer
-func main() {
- cfg := &packages.Config{
- Mode: packages.NeedSyntax | packages.NeedTypes,
+// concreteNodes contains all concrete types in the package that implement Node
+// (except for the mini* types).
+var concreteNodes []*ast.TypeSpec
+
+// interfaceNodes contains all interface types in the package that implement Node.
+var interfaceNodes []*ast.TypeSpec
+
+// mini contains the embeddable mini types (miniNode, miniExpr, and miniStmt).
+var mini = map[string]*ast.TypeSpec{}
+
+// implementsNode reports whether the type t is one which represents a Node
+// in the AST.
+func implementsNode(t ast.Expr) bool {
+ id, ok := t.(*ast.Ident)
+ if !ok {
+ return false // only named types
}
- pkgs, err := packages.Load(cfg, "cmd/compile/internal/ir")
- if err != nil {
- log.Fatal(err)
+ for _, ts := range interfaceNodes {
+ if ts.Name.Name == id.Name {
+ return true
+ }
+ }
+ for _, ts := range concreteNodes {
+ if ts.Name.Name == id.Name {
+ return true
+ }
+ }
+ return false
+}
+
+func isMini(t ast.Expr) bool {
+ id, ok := t.(*ast.Ident)
+ return ok && mini[id.Name] != nil
+}
+
+func isNamedType(t ast.Expr, name string) bool {
+ if id, ok := t.(*ast.Ident); ok {
+ if id.Name == name {
+ return true
+ }
}
- irPkg = pkgs[0].Types
+ return false
+}
+func main() {
fmt.Fprintln(&buf, "// Code generated by mknode.go. DO NOT EDIT.")
fmt.Fprintln(&buf)
fmt.Fprintln(&buf, "package ir")
fmt.Fprintln(&buf)
fmt.Fprintln(&buf, `import "fmt"`)
- scope := irPkg.Scope()
- for _, name := range scope.Names() {
- if strings.HasPrefix(name, "mini") {
- continue
- }
-
- obj, ok := scope.Lookup(name).(*types.TypeName)
- if !ok {
- continue
- }
- typ := obj.Type().(*types.Named)
- if !implementsNode(types.NewPointer(typ)) {
- continue
+ filter := func(file fs.FileInfo) bool {
+ return !strings.HasPrefix(file.Name(), "mknode")
+ }
+ pkgs, err := parser.ParseDir(fset, ".", filter, 0)
+ if err != nil {
+ panic(err)
+ }
+ pkg := pkgs["ir"]
+
+ // Find all the mini types. These let us determine which
+ // concrete types implement Node, so we need to find them first.
+ for _, f := range pkg.Files {
+ for _, d := range f.Decls {
+ g, ok := d.(*ast.GenDecl)
+ if !ok {
+ continue
+ }
+ for _, s := range g.Specs {
+ t, ok := s.(*ast.TypeSpec)
+ if !ok {
+ continue
+ }
+ if strings.HasPrefix(t.Name.Name, "mini") {
+ mini[t.Name.Name] = t
+ // Double-check that it is or embeds miniNode.
+ if t.Name.Name != "miniNode" {
+ s := t.Type.(*ast.StructType)
+ if !isNamedType(s.Fields.List[0].Type, "miniNode") {
+ panic(fmt.Sprintf("can't find miniNode in %s", t.Name.Name))
+ }
+ }
+ }
+ }
}
+ }
- fmt.Fprintf(&buf, "\n")
- fmt.Fprintf(&buf, "func (n *%s) Format(s fmt.State, verb rune) { fmtNode(n, s, verb) }\n", name)
-
- switch name {
- case "Name", "Func":
- // Too specialized to automate.
- continue
+ // Find all the declarations of concrete types that implement Node.
+ for _, f := range pkg.Files {
+ for _, d := range f.Decls {
+ g, ok := d.(*ast.GenDecl)
+ if !ok {
+ continue
+ }
+ for _, s := range g.Specs {
+ t, ok := s.(*ast.TypeSpec)
+ if !ok {
+ continue
+ }
+ if strings.HasPrefix(t.Name.Name, "mini") {
+ // We don't treat the mini types as
+ // concrete implementations of Node
+ // (even though they are) because
+ // we only use them by embedding them.
+ continue
+ }
+ if isConcreteNode(t) {
+ concreteNodes = append(concreteNodes, t)
+ }
+ if isInterfaceNode(t) {
+ interfaceNodes = append(interfaceNodes, t)
+ }
+ }
}
-
- forNodeFields(typ,
- "func (n *%[1]s) copy() Node { c := *n\n",
- "",
- "c.%[1]s = copy%[2]s(c.%[1]s)",
- "return &c }\n")
-
- forNodeFields(typ,
- "func (n *%[1]s) doChildren(do func(Node) bool) bool {\n",
- "if n.%[1]s != nil && do(n.%[1]s) { return true }",
- "if do%[2]s(n.%[1]s, do) { return true }",
- "return false }\n")
-
- forNodeFields(typ,
- "func (n *%[1]s) editChildren(edit func(Node) Node) {\n",
- "if n.%[1]s != nil { n.%[1]s = edit(n.%[1]s).(%[2]s) }",
- "edit%[2]s(n.%[1]s, edit)",
- "}\n")
}
+ // Sort for deterministic output.
+ sort.Slice(concreteNodes, func(i, j int) bool {
+ return concreteNodes[i].Name.Name < concreteNodes[j].Name.Name
+ })
+ // Generate code for each concrete type.
+ for _, t := range concreteNodes {
+ processType(t)
+ }
+ // Add some helpers.
+ generateHelpers()
- makeHelpers()
-
+ // Format and write output.
out, err := format.Source(buf.Bytes())
if err != nil {
// write out mangled source so we can see the bug.
out = buf.Bytes()
}
-
err = ioutil.WriteFile("node_gen.go", out, 0666)
if err != nil {
log.Fatal(err)
}
}
-// needHelper maps needed slice helpers from their base name to their
-// respective slice-element type.
-var needHelper = map[string]string{}
-
-func makeHelpers() {
- var names []string
- for name := range needHelper {
- names = append(names, name)
+// isConcreteNode reports whether the type t is a concrete type
+// implementing Node.
+func isConcreteNode(t *ast.TypeSpec) bool {
+ s, ok := t.Type.(*ast.StructType)
+ if !ok {
+ return false
}
- sort.Strings(names)
-
- for _, name := range names {
- fmt.Fprintf(&buf, sliceHelperTmpl, name, needHelper[name])
+ for _, f := range s.Fields.List {
+ if isMini(f.Type) {
+ return true
+ }
}
+ return false
}
-const sliceHelperTmpl = `
-func copy%[1]s(list []%[2]s) []%[2]s {
- if list == nil {
- return nil
+// isInterfaceNode reports whether the type t is an interface type
+// implementing Node (including Node itself).
+func isInterfaceNode(t *ast.TypeSpec) bool {
+ s, ok := t.Type.(*ast.InterfaceType)
+ if !ok {
+ return false
}
- c := make([]%[2]s, len(list))
- copy(c, list)
- return c
-}
-func do%[1]s(list []%[2]s, do func(Node) bool) bool {
- for _, x := range list {
- if x != nil && do(x) {
+ if t.Name.Name == "Node" {
+ return true
+ }
+ if t.Name.Name == "OrigNode" || t.Name.Name == "InitNode" {
+ // These we exempt from consideration (fields of
+ // this type don't need to be walked or copied).
+ return false
+ }
+
+ // Look for embedded Node type.
+ // Note that this doesn't handle multi-level embedding, but
+ // we have none of that at the moment.
+ for _, f := range s.Methods.List {
+ if len(f.Names) != 0 {
+ continue
+ }
+ if isNamedType(f.Type, "Node") {
return true
}
}
return false
}
-func edit%[1]s(list []%[2]s, edit func(Node) Node) {
- for i, x := range list {
- if x != nil {
- list[i] = edit(x).(%[2]s)
- }
- }
-}
-`
-func forNodeFields(named *types.Named, prologue, singleTmpl, sliceTmpl, epilogue string) {
- fmt.Fprintf(&buf, prologue, named.Obj().Name())
+func processType(t *ast.TypeSpec) {
+ name := t.Name.Name
+ fmt.Fprintf(&buf, "\n")
+ fmt.Fprintf(&buf, "func (n *%s) Format(s fmt.State, verb rune) { fmtNode(n, s, verb) }\n", name)
- anyField(named.Underlying().(*types.Struct), func(f *types.Var) bool {
- if f.Embedded() {
- return false
- }
- name, typ := f.Name(), f.Type()
+ switch name {
+ case "Name", "Func":
+ // Too specialized to automate.
+ return
+ }
- slice, _ := typ.Underlying().(*types.Slice)
- if slice != nil {
- typ = slice.Elem()
- }
+ s := t.Type.(*ast.StructType)
+ fields := s.Fields.List
- tmpl, what := singleTmpl, types.TypeString(typ, types.RelativeTo(irPkg))
- if what == "go/constant.Value" {
- return false
+ // Expand any embedded fields.
+ for i := 0; i < len(fields); i++ {
+ f := fields[i]
+ if len(f.Names) != 0 {
+ continue // not embedded
}
- if implementsNode(typ) {
- if slice != nil {
- helper := strings.TrimPrefix(what, "*") + "s"
- needHelper[helper] = what
- tmpl, what = sliceTmpl, helper
- }
- } else if what == "*Field" {
- // Special case for *Field.
- tmpl = sliceTmpl
- if slice != nil {
- what = "Fields"
- } else {
- what = "Field"
- }
+ if isMini(f.Type) {
+ // Insert the fields of the embedded type into the main type.
+ // (It would be easier just to append, but inserting in place
+ // matches the old mknode behavior.)
+ ss := mini[f.Type.(*ast.Ident).Name].Type.(*ast.StructType)
+ var f2 []*ast.Field
+ f2 = append(f2, fields[:i]...)
+ f2 = append(f2, ss.Fields.List...)
+ f2 = append(f2, fields[i+1:]...)
+ fields = f2
+ i--
+ continue
+ } else if isNamedType(f.Type, "origNode") {
+ // Ignore this field
+ copy(fields[i:], fields[i+1:])
+ fields = fields[:len(fields)-1]
+ i--
+ continue
} else {
- return false
- }
-
- if tmpl == "" {
- return false
- }
-
- // Allow template to not use all arguments without
- // upsetting fmt.Printf.
- s := fmt.Sprintf(tmpl+"\x00 %[1]s %[2]s", name, what)
- fmt.Fprintln(&buf, s[:strings.LastIndex(s, "\x00")])
- return false
- })
-
- fmt.Fprintf(&buf, epilogue)
-}
-
-func implementsNode(typ types.Type) bool {
- if _, ok := typ.Underlying().(*types.Interface); ok {
- // TODO(mdempsky): Check the interface implements Node.
- // Worst case, node_gen.go will fail to compile if we're wrong.
- return true
- }
-
- if ptr, ok := typ.(*types.Pointer); ok {
- if str, ok := ptr.Elem().Underlying().(*types.Struct); ok {
- return anyField(str, func(f *types.Var) bool {
- return f.Embedded() && f.Name() == "miniNode"
- })
+ panic("unknown embedded field " + fmt.Sprintf("%v", f.Type))
}
}
-
- return false
-}
-
-func anyField(typ *types.Struct, pred func(f *types.Var) bool) bool {
- for i, n := 0, typ.NumFields(); i < n; i++ {
- if value, ok := reflect.StructTag(typ.Tag(i)).Lookup("mknode"); ok {
- if value != "-" {
- panic(fmt.Sprintf("unexpected tag value: %q", value))
+ // Process fields.
+ var copyBody bytes.Buffer
+ var doChildrenBody bytes.Buffer
+ var editChildrenBody bytes.Buffer
+ for _, f := range fields {
+ if f.Tag != nil {
+ tag := f.Tag.Value[1 : len(f.Tag.Value)-1]
+ if strings.HasPrefix(tag, "mknode:") {
+ if tag[7:] == "\"-\"" {
+ continue
+ }
+ panic(fmt.Sprintf("unexpected tag value: %s", tag))
}
- continue
}
-
- f := typ.Field(i)
- if pred(f) {
- return true
+ names := f.Names
+ ft := f.Type
+ if isNamedType(ft, "Nodes") {
+ // Nodes == []Node
+ ft = &ast.ArrayType{Elt: &ast.Ident{Name: "Node"}}
+ }
+ isSlice := false
+ if a, ok := ft.(*ast.ArrayType); ok && a.Len == nil {
+ isSlice = true
+ ft = a.Elt
+ }
+ isPtr := false
+ if p, ok := ft.(*ast.StarExpr); ok {
+ isPtr = true
+ ft = p.X
}
- if f.Embedded() {
- if typ, ok := f.Type().Underlying().(*types.Struct); ok {
- if anyField(typ, pred) {
- return true
+ if !implementsNode(ft) {
+ continue
+ }
+ for _, name := range names {
+ if isSlice {
+ fmt.Fprintf(©Body, "c.%s = copy%ss(c.%s)\n", name, ft, name)
+ fmt.Fprintf(&doChildrenBody,
+ "if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name)
+ fmt.Fprintf(&editChildrenBody,
+ "edit%ss(n.%s, edit)\n", ft, name)
+ } else {
+ fmt.Fprintf(&doChildrenBody,
+ "if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name)
+ ptr := ""
+ if isPtr {
+ ptr = "*"
}
+ fmt.Fprintf(&editChildrenBody,
+ "if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
}
}
}
- return false
+ fmt.Fprintf(&buf, "func (n *%s) copy() Node {\nc := *n\n", name)
+ buf.WriteString(copyBody.String())
+ fmt.Fprintf(&buf, "return &c\n}\n")
+ fmt.Fprintf(&buf, "func (n *%s) doChildren(do func(Node) bool) bool {\n", name)
+ buf.WriteString(doChildrenBody.String())
+ fmt.Fprintf(&buf, "return false\n}\n")
+ fmt.Fprintf(&buf, "func (n *%s) editChildren(edit func(Node) Node) {\n", name)
+ buf.WriteString(editChildrenBody.String())
+ fmt.Fprintf(&buf, "}\n")
+}
+
+func generateHelpers() {
+ for _, typ := range []string{"CaseClause", "CommClause", "Name", "Node", "Ntype"} {
+ ptr := "*"
+ if typ == "Node" || typ == "Ntype" {
+ ptr = "" // interfaces don't need *
+ }
+ fmt.Fprintf(&buf, "\n")
+ fmt.Fprintf(&buf, "func copy%ss(list []%s%s) []%s%s {\n", typ, ptr, typ, ptr, typ)
+ fmt.Fprintf(&buf, "if list == nil { return nil }\n")
+ fmt.Fprintf(&buf, "c := make([]%s%s, len(list))\n", ptr, typ)
+ fmt.Fprintf(&buf, "copy(c, list)\n")
+ fmt.Fprintf(&buf, "return c\n")
+ fmt.Fprintf(&buf, "}\n")
+ fmt.Fprintf(&buf, "func do%ss(list []%s%s, do func(Node) bool) bool {\n", typ, ptr, typ)
+ fmt.Fprintf(&buf, "for _, x := range list {\n")
+ fmt.Fprintf(&buf, "if x != nil && do(x) {\n")
+ fmt.Fprintf(&buf, "return true\n")
+ fmt.Fprintf(&buf, "}\n")
+ fmt.Fprintf(&buf, "}\n")
+ fmt.Fprintf(&buf, "return false\n")
+ fmt.Fprintf(&buf, "}\n")
+ fmt.Fprintf(&buf, "func edit%ss(list []%s%s, edit func(Node) Node) {\n", typ, ptr, typ)
+ fmt.Fprintf(&buf, "for i, x := range list {\n")
+ fmt.Fprintf(&buf, "if x != nil {\n")
+ fmt.Fprintf(&buf, "list[i] = edit(x).(%s%s)\n", ptr, typ)
+ fmt.Fprintf(&buf, "}\n")
+ fmt.Fprintf(&buf, "}\n")
+ fmt.Fprintf(&buf, "}\n")
+ }
}