]> Cypherpunks repositories - gostls13.git/commitdiff
go/ast: fixed bug in NotNilFilter, added test
authorRobert Griesemer <gri@golang.org>
Wed, 13 Apr 2011 16:37:13 +0000 (09:37 -0700)
committerRobert Griesemer <gri@golang.org>
Wed, 13 Apr 2011 16:37:13 +0000 (09:37 -0700)
- fixed a couple of comments
- cleanups after reflect change

R=rsc
CC=golang-dev
https://golang.org/cl/4389041

src/pkg/go/ast/print.go
src/pkg/go/ast/print_test.go [new file with mode: 0644]

index f4b2fc8f444034d4bc408fcf12f66a3473d5864b..e6d4e838d836f88b92874302fa9a1cbf4b646837 100644 (file)
@@ -26,7 +26,7 @@ func NotNilFilter(_ string, v reflect.Value) bool {
        case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
                return !v.IsNil()
        }
-       return false
+       return true
 }
 
 
@@ -80,7 +80,7 @@ type printer struct {
        output  io.Writer
        fset    *token.FileSet
        filter  FieldFilter
-       ptrmap  map[interface{}]int // *reflect.PtrValue -> line number
+       ptrmap  map[interface{}]int // *T -> line number
        written int                 // number of bytes written to output
        indent  int                 // current indentation level
        last    byte                // the last byte processed by Write
@@ -141,6 +141,11 @@ func (p *printer) printf(format string, args ...interface{}) {
 // Implementation note: Print is written for AST nodes but could be
 // used to print arbitrary data structures; such a version should
 // probably be in a different package.
+//
+// Note: This code detects (some) cycles created via pointers but
+// not cycles that are created via slices or maps containing the
+// same slice or map. Code for general data structures probably
+// should catch those as well.
 
 func (p *printer) print(x reflect.Value) {
        if !NotNilFilter("", x) {
@@ -148,17 +153,17 @@ func (p *printer) print(x reflect.Value) {
                return
        }
 
-       switch v := x; v.Kind() {
+       switch x.Kind() {
        case reflect.Interface:
-               p.print(v.Elem())
+               p.print(x.Elem())
 
        case reflect.Map:
-               p.printf("%s (len = %d) {\n", x.Type().String(), v.Len())
+               p.printf("%s (len = %d) {\n", x.Type().String(), x.Len())
                p.indent++
-               for _, key := range v.MapKeys() {
+               for _, key := range x.MapKeys() {
                        p.print(key)
                        p.printf(": ")
-                       p.print(v.MapIndex(key))
+                       p.print(x.MapIndex(key))
                        p.printf("\n")
                }
                p.indent--
@@ -169,24 +174,24 @@ func (p *printer) print(x reflect.Value) {
                // type-checked ASTs may contain cycles - use ptrmap
                // to keep track of objects that have been printed
                // already and print the respective line number instead
-               ptr := v.Interface()
+               ptr := x.Interface()
                if line, exists := p.ptrmap[ptr]; exists {
                        p.printf("(obj @ %d)", line)
                } else {
                        p.ptrmap[ptr] = p.line
-                       p.print(v.Elem())
+                       p.print(x.Elem())
                }
 
        case reflect.Slice:
-               if s, ok := v.Interface().([]byte); ok {
+               if s, ok := x.Interface().([]byte); ok {
                        p.printf("%#q", s)
                        return
                }
-               p.printf("%s (len = %d) {\n", x.Type().String(), v.Len())
+               p.printf("%s (len = %d) {\n", x.Type().String(), x.Len())
                p.indent++
-               for i, n := 0, v.Len(); i < n; i++ {
+               for i, n := 0, x.Len(); i < n; i++ {
                        p.printf("%d: ", i)
-                       p.print(v.Index(i))
+                       p.print(x.Index(i))
                        p.printf("\n")
                }
                p.indent--
@@ -195,10 +200,10 @@ func (p *printer) print(x reflect.Value) {
        case reflect.Struct:
                p.printf("%s {\n", x.Type().String())
                p.indent++
-               t := v.Type()
+               t := x.Type()
                for i, n := 0, t.NumField(); i < n; i++ {
                        name := t.Field(i).Name
-                       value := v.Field(i)
+                       value := x.Field(i)
                        if p.filter == nil || p.filter(name, value) {
                                p.printf("%s: ", name)
                                p.print(value)
@@ -209,11 +214,20 @@ func (p *printer) print(x reflect.Value) {
                p.printf("}")
 
        default:
-               value := x.Interface()
-               // position values can be printed nicely if we have a file set
-               if pos, ok := value.(token.Pos); ok && p.fset != nil {
-                       value = p.fset.Position(pos)
+               v := x.Interface()
+               switch v := v.(type) {
+               case string:
+                       // print strings in quotes
+                       p.printf("%q", v)
+                       return
+               case token.Pos:
+                       // position values can be printed nicely if we have a file set
+                       if p.fset != nil {
+                               p.printf("%s", p.fset.Position(v))
+                               return
+                       }
                }
-               p.printf("%v", value)
+               // default
+               p.printf("%v", v)
        }
 }
diff --git a/src/pkg/go/ast/print_test.go b/src/pkg/go/ast/print_test.go
new file mode 100644 (file)
index 0000000..0820dcf
--- /dev/null
@@ -0,0 +1,80 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ast
+
+import (
+       "bytes"
+       "strings"
+       "testing"
+)
+
+
+var tests = []struct {
+       x interface{} // x is printed as s
+       s string
+}{
+       // basic types
+       {nil, "0  nil"},
+       {true, "0  true"},
+       {42, "0  42"},
+       {3.14, "0  3.14"},
+       {1 + 2.718i, "0  (1+2.718i)"},
+       {"foobar", "0  \"foobar\""},
+
+       // maps
+       {map[string]int{"a": 1, "b": 2},
+               `0  map[string] int (len = 2) {
+               1  .  "a": 1
+               2  .  "b": 2
+               3  }`},
+
+       // pointers
+       {new(int), "0  *0"},
+
+       // slices
+       {[]int{1, 2, 3},
+               `0  []int (len = 3) {
+               1  .  0: 1
+               2  .  1: 2
+               3  .  2: 3
+               4  }`},
+
+       // structs
+       {struct{ x, y int }{42, 991},
+               `0  struct { x int; y int } {
+               1  .  x: 42
+               2  .  y: 991
+               3  }`},
+}
+
+
+// Split s into lines, trim whitespace from all lines, and return
+// the concatenated non-empty lines.
+func trim(s string) string {
+       lines := strings.Split(s, "\n", -1)
+       i := 0
+       for _, line := range lines {
+               line = strings.TrimSpace(line)
+               if line != "" {
+                       lines[i] = line
+                       i++
+               }
+       }
+       return strings.Join(lines[0:i], "\n")
+}
+
+
+func TestPrint(t *testing.T) {
+       var buf bytes.Buffer
+       for _, test := range tests {
+               buf.Reset()
+               if _, err := Fprint(&buf, nil, test.x, nil); err != nil {
+                       t.Errorf("Fprint failed: %s", err)
+               }
+               if s, ts := trim(buf.String()), trim(test.s); s != ts {
+                       t.Errorf("got:\n%s\nexpected:\n%s\n", s, ts)
+               }
+       }
+}