]> Cypherpunks repositories - gostls13.git/commitdiff
go/ast: generalize ast.FilterFile
authorRobert Griesemer <gri@golang.org>
Tue, 23 Aug 2011 01:51:51 +0000 (18:51 -0700)
committerRobert Griesemer <gri@golang.org>
Tue, 23 Aug 2011 01:51:51 +0000 (18:51 -0700)
ast.FilterFile(src, ast.IsExported) has the same
effect as ast.FileExports(src) with this change.

1st step towards removing FileExports - it is
just a special case of FilterFile with this CL.

Added corresponding test.

R=r
CC=golang-dev
https://golang.org/cl/4938041

src/pkg/go/ast/filter.go
src/pkg/go/parser/filter_test.go [new file with mode: 0644]

index 4c96e71c0371cd0376b92f1776514cf43e278e41..d4b580e003ae0e9dc5780c0fa1ef3355ca7b0b0b 100644 (file)
@@ -20,24 +20,6 @@ func identListExports(list []*Ident) []*Ident {
        return list[0:j]
 }
 
-// fieldName assumes that x is the type of an anonymous field and
-// returns the corresponding field name. If x is not an acceptable
-// anonymous field, the result is nil.
-//
-func fieldName(x Expr) *Ident {
-       switch t := x.(type) {
-       case *Ident:
-               return t
-       case *SelectorExpr:
-               if _, ok := t.X.(*Ident); ok {
-                       return t.Sel
-               }
-       case *StarExpr:
-               return fieldName(t.X)
-       }
-       return nil
-}
-
 func fieldListExports(fields *FieldList) (removedFields bool) {
        if fields == nil {
                return
@@ -203,6 +185,24 @@ func filterIdentList(list []*Ident, f Filter) []*Ident {
        return list[0:j]
 }
 
+// fieldName assumes that x is the type of an anonymous field and
+// returns the corresponding field name. If x is not an acceptable
+// anonymous field, the result is nil.
+//
+func fieldName(x Expr) *Ident {
+       switch t := x.(type) {
+       case *Ident:
+               return t
+       case *SelectorExpr:
+               if _, ok := t.X.(*Ident); ok {
+                       return t.Sel
+               }
+       case *StarExpr:
+               return fieldName(t.X)
+       }
+       return nil
+}
+
 func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
        if fields == nil {
                return false
@@ -224,6 +224,7 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
                        keepField = len(f.Names) > 0
                }
                if keepField {
+                       filterType(f.Type, filter)
                        list[j] = f
                        j++
                }
@@ -235,26 +236,71 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
        return
 }
 
+func filterParamList(fields *FieldList, filter Filter) bool {
+       if fields == nil {
+               return false
+       }
+       var b bool
+       for _, f := range fields.List {
+               if filterType(f.Type, filter) {
+                       b = true
+               }
+       }
+       return b
+}
+
+func filterType(typ Expr, f Filter) bool {
+       switch t := typ.(type) {
+       case *Ident:
+               return f(t.Name)
+       case *ParenExpr:
+               return filterType(t.X, f)
+       case *ArrayType:
+               return filterType(t.Elt, f)
+       case *StructType:
+               if filterFieldList(t.Fields, f) {
+                       t.Incomplete = true
+               }
+               return len(t.Fields.List) > 0
+       case *FuncType:
+               b1 := filterParamList(t.Params, f)
+               b2 := filterParamList(t.Results, f)
+               return b1 || b2
+       case *InterfaceType:
+               if filterFieldList(t.Methods, f) {
+                       t.Incomplete = true
+               }
+               return len(t.Methods.List) > 0
+       case *MapType:
+               b1 := filterType(t.Key, f)
+               b2 := filterType(t.Value, f)
+               return b1 || b2
+       case *ChanType:
+               return filterType(t.Value, f)
+       }
+       return false
+}
+
 func filterSpec(spec Spec, f Filter) bool {
        switch s := spec.(type) {
        case *ValueSpec:
                s.Names = filterIdentList(s.Names, f)
-               return len(s.Names) > 0
+               if len(s.Names) > 0 {
+                       filterType(s.Type, f)
+                       return true
+               }
        case *TypeSpec:
                if f(s.Name.Name) {
+                       filterType(s.Type, f)
                        return true
                }
-               switch t := s.Type.(type) {
-               case *StructType:
-                       if filterFieldList(t.Fields, f) {
-                               t.Incomplete = true
-                       }
-                       return len(t.Fields.List) > 0
-               case *InterfaceType:
-                       if filterFieldList(t.Methods, f) {
-                               t.Incomplete = true
-                       }
-                       return len(t.Methods.List) > 0
+               if f != IsExported {
+                       // For general filtering (not just exports),
+                       // filter type even if name is not filtered
+                       // out.
+                       // If the type contains filtered elements,
+                       // keep the declaration.
+                       return filterType(s.Type, f)
                }
        }
        return false
@@ -284,6 +330,7 @@ func FilterDecl(decl Decl, f Filter) bool {
                d.Specs = filterSpecList(d.Specs, f)
                return len(d.Specs) > 0
        case *FuncDecl:
+               d.Body = nil // strip body
                return f(d.Name.Name)
        }
        return false
@@ -292,13 +339,16 @@ func FilterDecl(decl Decl, f Filter) bool {
 // FilterFile trims the AST for a Go file in place by removing all
 // names from top-level declarations (including struct field and
 // interface method names, but not from parameter lists) that don't
-// pass through the filter f. If the declaration is empty afterwards,
-// the declaration is removed from the AST.
-// The File.comments list is not changed.
+// pass through the filter f. Function bodies are set to nil.
+// If the declaration is empty afterwards, the declaration is
+// removed from the AST. The File.comments list is not changed.
 //
 // FilterFile returns true if there are any top-level declarations
 // left after filtering; it returns false otherwise.
 //
+// To trim an AST such that only exported nodes remain, call
+// FilterFile with IsExported as filter function.
+//
 func FilterFile(src *File, f Filter) bool {
        j := 0
        for _, d := range src.Decls {
@@ -314,14 +364,17 @@ func FilterFile(src *File, f Filter) bool {
 // FilterPackage trims the AST for a Go package in place by removing all
 // names from top-level declarations (including struct field and
 // interface method names, but not from parameter lists) that don't
-// pass through the filter f. If the declaration is empty afterwards,
-// the declaration is removed from the AST.
-// The pkg.Files list is not changed, so that file names and top-level
-// package comments don't get lost.
+// pass through the filter f. Function bodies are set to nil.
+// If the declaration is empty afterwards, the declaration is
+// removed from the AST. The pkg.Files list is not changed, so
+// that file names and top-level package comments don't get lost.
 //
 // FilterPackage returns true if there are any top-level declarations
 // left after filtering; it returns false otherwise.
 //
+// To trim an AST such that only exported nodes remain, call
+// FilterPackage with IsExported as filter function.
+//
 func FilterPackage(pkg *Package, f Filter) bool {
        hasDecls := false
        for _, src := range pkg.Files {
@@ -467,17 +520,16 @@ func MergePackageFiles(pkg *Package, mode MergeMode) *File {
                seen := make(map[string]bool)
                for _, f := range pkg.Files {
                        for _, imp := range f.Imports {
-                               path := imp.Path.Value
-                               if !seen[path] {
-                                       //TODO: consider handling cases where:
+                               if path := imp.Path.Value; !seen[path] {
+                                       // TODO: consider handling cases where:
                                        // - 2 imports exist with the same import path but
                                        //   have different local names (one should probably 
                                        //   keep both of them)
                                        // - 2 imports exist but only one has a comment
                                        // - 2 imports exist and they both have (possibly
                                        //   different) comments
-                                       seen[path] = true
                                        imports = append(imports, imp)
+                                       seen[path] = true
                                }
                        }
                }
diff --git a/src/pkg/go/parser/filter_test.go b/src/pkg/go/parser/filter_test.go
new file mode 100644 (file)
index 0000000..f1672a9
--- /dev/null
@@ -0,0 +1,91 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// A test that ensures that ast.FileExports(file) and
+// ast.FilterFile(file, ast.IsExported) produce the
+// same trimmed AST given the same input file for all
+// files under runtime.GOROOT().
+//
+// The test is here because it requires parsing, and the
+// parser imports AST already (avoid import cycle).
+
+package parser
+
+import (
+       "bytes"
+       "go/ast"
+       "go/printer"
+       "go/token"
+       "os"
+       "path/filepath"
+       "runtime"
+       "strings"
+       "testing"
+)
+
+// For a short test run, limit the number of files to a few.
+// Set to a large value to test all files under GOROOT.
+const maxFiles = 10
+
+type visitor struct {
+       t *testing.T
+       n int
+}
+
+func (v *visitor) VisitDir(path string, f *os.FileInfo) bool {
+       return true
+}
+
+func str(f *ast.File, fset *token.FileSet) string {
+       var buf bytes.Buffer
+       printer.Fprint(&buf, fset, f)
+       return buf.String()
+}
+
+func (v *visitor) VisitFile(path string, f *os.FileInfo) {
+       // exclude files that clearly don't make it
+       if !f.IsRegular() || len(f.Name) > 0 && f.Name[0] == '.' || !strings.HasSuffix(f.Name, ".go") {
+               return
+       }
+
+       // should we stop early for quick test runs?
+       if v.n <= 0 {
+               return
+       }
+       v.n--
+
+       fset := token.NewFileSet()
+
+       // get two ASTs f1, f2 of identical structure;
+       // parsing twice is easiest
+       f1, err := ParseFile(fset, path, nil, ParseComments)
+       if err != nil {
+               v.t.Logf("parse error (1): %s", err)
+               return
+       }
+
+       f2, err := ParseFile(fset, path, nil, ParseComments)
+       if err != nil {
+               v.t.Logf("parse error (2): %s", err)
+               return
+       }
+
+       b1 := ast.FileExports(f1)
+       b2 := ast.FilterFile(f2, ast.IsExported)
+       if b1 != b2 {
+               v.t.Errorf("filtering failed (a): %s", path)
+               return
+       }
+
+       s1 := str(f1, fset)
+       s2 := str(f2, fset)
+       if s1 != s2 {
+               v.t.Errorf("filtering failed (b): %s", path)
+               return
+       }
+}
+
+func TestFilter(t *testing.T) {
+       filepath.Walk(runtime.GOROOT(), &visitor{t, maxFiles}, nil)
+}