]> Cypherpunks repositories - gostls13.git/commitdiff
go/ast: consider anonymous fields and set Incomplete bit when filtering ASTs
authorRobert Griesemer <gri@golang.org>
Thu, 12 May 2011 16:01:32 +0000 (09:01 -0700)
committerRobert Griesemer <gri@golang.org>
Thu, 12 May 2011 16:01:32 +0000 (09:01 -0700)
Also:
- fieldListExports: don't require internal pointer to StructType/InterfaceType node
- filterFieldLists: make structure match fieldListExports

R=rsc
CC=golang-dev
https://golang.org/cl/4527050

src/pkg/go/ast/filter.go

index 97320b90ec59f4bf9ae3c1d9983bec7d76a0475c..0907fd53da8db86d02d4b76bc225e0fab06c328f 100644 (file)
@@ -21,24 +21,26 @@ func identListExports(list []*Ident) []*Ident {
 }
 
 
-// isExportedType assumes that typ is a correct type.
-func isExportedType(typ Expr) bool {
-       switch t := typ.(type) {
+// fieldName assumes that x is the type of an anonymous field and
+// returns the corresponding field name. If x is not an acceptable
+// anonymous field, the result is nil.
+//
+func fieldName(x Expr) *Ident {
+       switch t := x.(type) {
        case *Ident:
-               return t.IsExported()
-       case *ParenExpr:
-               return isExportedType(t.X)
+               return t
        case *SelectorExpr:
-               // assume t.X is a typename
-               return t.Sel.IsExported()
+               if _, ok := t.X.(*Ident); ok {
+                       return t.Sel
+               }
        case *StarExpr:
-               return isExportedType(t.X)
+               return fieldName(t.X)
        }
-       return false
+       return nil
 }
 
 
-func fieldListExports(fields *FieldList, incomplete *bool) {
+func fieldListExports(fields *FieldList) (removedFields bool) {
        if fields == nil {
                return
        }
@@ -53,12 +55,13 @@ func fieldListExports(fields *FieldList, incomplete *bool) {
                        // fields, so this is not absolutely correct.
                        // However, this cannot be done w/o complete
                        // type information.)
-                       exported = isExportedType(f.Type)
+                       name := fieldName(f.Type)
+                       exported = name != nil && name.IsExported()
                } else {
                        n := len(f.Names)
                        f.Names = identListExports(f.Names)
                        if len(f.Names) < n {
-                               *incomplete = true
+                               removedFields = true
                        }
                        exported = len(f.Names) > 0
                }
@@ -69,9 +72,10 @@ func fieldListExports(fields *FieldList, incomplete *bool) {
                }
        }
        if j < len(list) {
-               *incomplete = true
+               removedFields = true
        }
        fields.List = list[0:j]
+       return
 }
 
 
@@ -90,12 +94,16 @@ func typeExports(typ Expr) {
        case *ArrayType:
                typeExports(t.Elt)
        case *StructType:
-               fieldListExports(t.Fields, &t.Incomplete)
+               if fieldListExports(t.Fields) {
+                       t.Incomplete = true
+               }
        case *FuncType:
                paramListExports(t.Params)
                paramListExports(t.Results)
        case *InterfaceType:
-               fieldListExports(t.Methods, &t.Incomplete)
+               if fieldListExports(t.Methods) {
+                       t.Incomplete = true
+               }
        case *MapType:
                typeExports(t.Key)
                typeExports(t.Value)
@@ -206,25 +214,36 @@ func filterIdentList(list []*Ident, f Filter) []*Ident {
 }
 
 
-func filterFieldList(list []*Field, f Filter) []*Field {
+func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
+       if fields == nil {
+               return false
+       }
+       list := fields.List
        j := 0
-       for _, field := range list {
-               field.Names = filterIdentList(field.Names, f)
-               if len(field.Names) > 0 {
-                       list[j] = field
+       for _, f := range list {
+               keepField := false
+               if len(f.Names) == 0 {
+                       // anonymous field
+                       name := fieldName(f.Type)
+                       keepField = name != nil && filter(name.Name)
+               } else {
+                       n := len(f.Names)
+                       f.Names = filterIdentList(f.Names, filter)
+                       if len(f.Names) < n {
+                               removedFields = true
+                       }
+                       keepField = len(f.Names) > 0
+               }
+               if keepField {
+                       list[j] = f
                        j++
                }
        }
-       return list[0:j]
-}
-
-
-func filterFields(fields *FieldList, f Filter) bool {
-       if fields == nil {
-               return false
+       if j < len(list) {
+               removedFields = true
        }
-       fields.List = filterFieldList(fields.List, f)
-       return len(fields.List) > 0
+       fields.List = list[0:j]
+       return
 }
 
 
@@ -239,9 +258,15 @@ func filterSpec(spec Spec, f Filter) bool {
                }
                switch t := s.Type.(type) {
                case *StructType:
-                       return filterFields(t.Fields, f)
+                       if filterFieldList(t.Fields, f) {
+                               t.Incomplete = true
+                       }
+                       return len(t.Fields.List) > 0
                case *InterfaceType:
-                       return filterFields(t.Methods, f)
+                       if filterFieldList(t.Methods, f) {
+                               t.Incomplete = true
+                       }
+                       return len(t.Methods.List) > 0
                }
        }
        return false