]> Cypherpunks repositories - gostls13.git/commitdiff
go/ast, gofmt: facility for printing AST nodes
authorRobert Griesemer <gri@golang.org>
Thu, 19 Aug 2010 16:39:35 +0000 (09:39 -0700)
committerRobert Griesemer <gri@golang.org>
Thu, 19 Aug 2010 16:39:35 +0000 (09:39 -0700)
go/ast: implement Fprint and print functions to
print AST nodes

gofmt: print AST nodes by setting -ast flag

R=rsc, r
CC=golang-dev
https://golang.org/cl/1981044

src/cmd/gofmt/doc.go
src/cmd/gofmt/gofmt.go
src/pkg/go/ast/Makefile
src/pkg/go/ast/print.go [new file with mode: 0644]
src/pkg/go/token/token.go

index 2e4c40c216d751c75864b569f3ae9a9058990185..6fee2278368a5a01984127f4f536288756de2c38 100644 (file)
@@ -33,6 +33,8 @@ Debugging flags:
 
        -trace
                print parse trace.
+       -ast
+               print AST (before rewrites).
        -comments=true
                print comments; if false, all comments are elided from the output.
 
index a0163b75fbe14db00021c160d3e83d1f8c3fd6ff..88c9f197ce524bb2bb05d98e76dedbc8f0583c78 100644 (file)
@@ -28,6 +28,7 @@ var (
        // debugging support
        comments = flag.Bool("comments", true, "print comments")
        trace    = flag.Bool("trace", false, "print parse trace")
+       printAST = flag.Bool("ast", false, "print AST (before rewrites)")
 
        // layout control
        tabWidth  = flag.Int("tabwidth", 8, "tab width")
@@ -97,6 +98,10 @@ func processFile(f *os.File) os.Error {
                return err
        }
 
+       if *printAST {
+               ast.Print(file)
+       }
+
        if rewrite != nil {
                file = rewrite(file)
        }
index d95210b27106d284f5004684b6021ce3d5351b68..e9b885c7052bba18dd946ce262686be1148f2a9c 100644 (file)
@@ -8,6 +8,7 @@ TARG=go/ast
 GOFILES=\
        ast.go\
        filter.go\
+       print.go\
        scope.go\
        walk.go\
 
diff --git a/src/pkg/go/ast/print.go b/src/pkg/go/ast/print.go
new file mode 100644 (file)
index 0000000..b4b3ed6
--- /dev/null
@@ -0,0 +1,197 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file contains printing suppport for ASTs.
+
+package ast
+
+import (
+       "fmt"
+       "go/token"
+       "io"
+       "os"
+       "reflect"
+)
+
+
+// A FieldFilter may be provided to Fprint to control the output.
+type FieldFilter func(name string, value reflect.Value) bool
+
+
+// NotNilFilter returns true for field values that are not nil;
+// it returns false otherwise.
+func NotNilFilter(_ string, value reflect.Value) bool {
+       v, ok := value.(interface {
+               IsNil() bool
+       })
+       return !ok || !v.IsNil()
+}
+
+
+// Fprint prints the (sub-)tree starting at AST node x to w.
+//
+// A non-nil FieldFilter f may be provided to control the output:
+// struct fields for which f(fieldname, fieldvalue) is true are
+// are printed; all others are filtered from the output.
+//
+func Fprint(w io.Writer, x interface{}, f FieldFilter) (n int, err os.Error) {
+       // setup printer
+       p := printer{output: w, filter: f}
+
+       // install error handler
+       defer func() {
+               n = p.written
+               if e := recover(); e != nil {
+                       err = e.(localError).err // re-panics if it's not a localError
+               }
+       }()
+
+       // print x
+       if x == nil {
+               p.printf("nil\n")
+               return
+       }
+       p.print(reflect.NewValue(x))
+       p.printf("\n")
+
+       return
+}
+
+
+// Print prints x to standard output, skipping nil fields.
+// Print(x) is the same as Fprint(os.Stdout, x, NotNilFilter).
+func Print(x interface{}) (int, os.Error) {
+       return Fprint(os.Stdout, x, NotNilFilter)
+}
+
+
+type printer struct {
+       output  io.Writer
+       filter  FieldFilter
+       written int  // number of bytes written to output
+       indent  int  // current indentation level
+       last    byte // the last byte processed by Write
+}
+
+
+var indent = []byte(".  ")
+
+func (p *printer) Write(data []byte) (n int, err os.Error) {
+       var m int
+       for i, b := range data {
+               // invariant: data[0:n] has been written
+               if b == '\n' {
+                       m, err = p.output.Write(data[n : i+1])
+                       n += m
+                       if err != nil {
+                               return
+                       }
+               } else if p.last == '\n' {
+                       for j := p.indent; j > 0; j-- {
+                               _, err = p.output.Write(indent)
+                               if err != nil {
+                                       return
+                               }
+                       }
+               }
+               p.last = b
+       }
+       m, err = p.output.Write(data[n:])
+       n += m
+       return
+}
+
+
+// localError wraps locally caught os.Errors so we can distinguish
+// them from genuine panics which we don't want to return as errors.
+type localError struct {
+       err os.Error
+}
+
+
+// printf is a convenience wrapper that takes care of print errors.
+func (p *printer) printf(format string, args ...interface{}) {
+       n, err := fmt.Fprintf(p, format, args)
+       p.written += n
+       if err != nil {
+               panic(localError{err})
+       }
+}
+
+
+// Implementation note: Print is written for AST nodes but could be
+// used to print any acyclic data structure. It would also be easy
+// to generalize it to arbitrary data structures; such a version
+// should probably be in a different package.
+
+func (p *printer) print(x reflect.Value) {
+       // Note: This test is only needed because AST nodes
+       //       embed a token.Position, and thus all of them
+       //       understand the String() method (but it only
+       //       applies to the Position field).
+       // TODO: Should reconsider this AST design decision.
+       if pos, ok := x.Interface().(token.Position); ok {
+               p.printf("%s", pos)
+               return
+       }
+
+       if !NotNilFilter("", x) {
+               p.printf("nil")
+               return
+       }
+
+       switch v := x.(type) {
+       case *reflect.InterfaceValue:
+               p.print(v.Elem())
+
+       case *reflect.MapValue:
+               p.printf("%s (len = %d) {\n", x.Type().String(), v.Len())
+               p.indent++
+               for _, key := range v.Keys() {
+                       p.print(key)
+                       p.printf(": ")
+                       p.print(v.Elem(key))
+               }
+               p.indent--
+               p.printf("}")
+
+       case *reflect.PtrValue:
+               p.printf("*")
+               p.print(v.Elem())
+
+       case *reflect.SliceValue:
+               if s, ok := v.Interface().([]byte); ok {
+                       p.printf("%#q", s)
+                       return
+               }
+               p.printf("%s (len = %d) {\n", x.Type().String(), v.Len())
+               p.indent++
+               for i, n := 0, v.Len(); i < n; i++ {
+                       p.printf("%d: ", i)
+                       p.print(v.Elem(i))
+                       p.printf("\n")
+               }
+               p.indent--
+               p.printf("}")
+
+       case *reflect.StructValue:
+               p.printf("%s {\n", x.Type().String())
+               p.indent++
+               t := v.Type().(*reflect.StructType)
+               for i, n := 0, t.NumField(); i < n; i++ {
+                       name := t.Field(i).Name
+                       value := v.Field(i)
+                       if p.filter == nil || p.filter(name, value) {
+                               p.printf("%s: ", name)
+                               p.print(value)
+                               p.printf("\n")
+                       }
+               }
+               p.indent--
+               p.printf("}")
+
+       default:
+               p.printf("%v", x.Interface())
+       }
+}
index 70c2501e9cba0a8185ed488840df9e95478856e9..bc6c6a865b2a62d0d1f2f9653a3069a3279a051e 100644 (file)
@@ -353,7 +353,7 @@ func (pos Position) String() string {
                s += fmt.Sprintf("%d:%d", pos.Line, pos.Column)
        }
        if s == "" {
-               s = "???"
+               s = "-"
        }
        return s
 }