]> Cypherpunks repositories - gostls13.git/commitdiff
reflect: hide unexported methods that do not satisfy interfaces
authorDavid Crawshaw <crawshaw@golang.org>
Thu, 19 May 2016 17:31:58 +0000 (13:31 -0400)
committerDavid Crawshaw <crawshaw@golang.org>
Fri, 20 May 2016 14:36:14 +0000 (14:36 +0000)
Fixes #15673

Change-Id: Ib36d8db3299a93d92665dbde012d52c2c5332ac0
Reviewed-on: https://go-review.googlesource.com/23253
Reviewed-by: Russ Cox <rsc@golang.org>
Run-TryBot: David Crawshaw <crawshaw@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/reflect/all_test.go
src/reflect/type.go

index 9799fee35735576cc04532a312546db54e92cbec..f09ffeb56615c531b7407417880ebdc5dacd4eed 100644 (file)
@@ -2388,13 +2388,13 @@ type outer struct {
        inner
 }
 
-func (*inner) m() {}
-func (*outer) m() {}
+func (*inner) M() {}
+func (*outer) M() {}
 
 func TestNestedMethods(t *testing.T) {
        typ := TypeOf((*outer)(nil))
-       if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != ValueOf((*outer).m).Pointer() {
-               t.Errorf("Wrong method table for outer: (m=%p)", (*outer).m)
+       if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != ValueOf((*outer).M).Pointer() {
+               t.Errorf("Wrong method table for outer: (M=%p)", (*outer).M)
                for i := 0; i < typ.NumMethod(); i++ {
                        m := typ.Method(i)
                        t.Errorf("\t%d: %s %#x\n", i, m.Name, m.Func.Pointer())
@@ -2416,18 +2416,15 @@ var unexpi unexpI = new(unexp)
 func TestUnexportedMethods(t *testing.T) {
        typ := TypeOf(unexpi)
 
+       if got := typ.NumMethod(); got != 1 {
+               t.Error("NumMethod=%d, want 1 satisfied method", got)
+       }
        if typ.Method(0).Type == nil {
                t.Error("missing type for satisfied method 'f'")
        }
        if !typ.Method(0).Func.IsValid() {
                t.Error("missing func for satisfied method 'f'")
        }
-       if typ.Method(1).Type != nil {
-               t.Error("found type for unsatisfied method 'g'")
-       }
-       if typ.Method(1).Func.IsValid() {
-               t.Error("found func for unsatisfied method 'g'")
-       }
 }
 
 type InnerInt struct {
@@ -5187,7 +5184,7 @@ func useStack(n int) {
 
 type Impl struct{}
 
-func (Impl) f() {}
+func (Impl) F() {}
 
 func TestValueString(t *testing.T) {
        rv := ValueOf(Impl{})
index dd7b797c048230842026236be67d1128d4e4620f..c9389199d8b96341edfe711a608a54928ef1f785 100644 (file)
@@ -763,16 +763,63 @@ func (t *rtype) pointers() bool { return t.kind&kindNoPointers == 0 }
 
 func (t *rtype) common() *rtype { return t }
 
+var methodCache struct {
+       sync.RWMutex
+       m map[*rtype][]method
+}
+
+// satisfiedMethods returns methods of t that satisfy an interface.
+// This may include unexported methods that satisfy an interface
+// defined with unexported methods in the same package as t.
+func (t *rtype) satisfiedMethods() []method {
+       methodCache.RLock()
+       methods, found := methodCache.m[t]
+       methodCache.RUnlock()
+
+       if found {
+               return methods
+       }
+
+       ut := t.uncommon()
+       if ut == nil {
+               return nil
+       }
+       allm := ut.methods()
+       allSatisfied := true
+       for _, m := range allm {
+               if m.mtyp == 0 {
+                       allSatisfied = false
+                       break
+               }
+       }
+       if allSatisfied {
+               methods = allm
+       } else {
+               methods = make([]method, 0, len(allm))
+               for _, m := range allm {
+                       if m.mtyp != 0 {
+                               methods = append(methods, m)
+                       }
+               }
+               methods = methods[:len(methods):len(methods)]
+       }
+
+       methodCache.Lock()
+       if methodCache.m == nil {
+               methodCache.m = make(map[*rtype][]method)
+       }
+       methodCache.m[t] = methods
+       methodCache.Unlock()
+
+       return methods
+}
+
 func (t *rtype) NumMethod() int {
        if t.Kind() == Interface {
                tt := (*interfaceType)(unsafe.Pointer(t))
                return tt.NumMethod()
        }
-       ut := t.uncommon()
-       if ut == nil {
-               return 0
-       }
-       return int(ut.mcount)
+       return len(t.satisfiedMethods())
 }
 
 func (t *rtype) Method(i int) (m Method) {
@@ -780,40 +827,39 @@ func (t *rtype) Method(i int) (m Method) {
                tt := (*interfaceType)(unsafe.Pointer(t))
                return tt.Method(i)
        }
-       ut := t.uncommon()
-
-       if ut == nil || i < 0 || i >= int(ut.mcount) {
+       methods := t.satisfiedMethods()
+       if i < 0 || i >= len(methods) {
                panic("reflect: Method index out of range")
        }
-       p := ut.methods()[i]
+       p := methods[i]
        pname := t.nameOff(p.name)
        m.Name = pname.name()
        fl := flag(Func)
        if !pname.isExported() {
                m.PkgPath = pname.pkgPath()
                if m.PkgPath == "" {
+                       ut := t.uncommon()
                        m.PkgPath = t.nameOff(ut.pkgPath).name()
                }
                fl |= flagStickyRO
        }
-       if p.mtyp != 0 {
-               mtyp := t.typeOff(p.mtyp)
-               ft := (*funcType)(unsafe.Pointer(mtyp))
-               in := make([]Type, 0, 1+len(ft.in()))
-               in = append(in, t)
-               for _, arg := range ft.in() {
-                       in = append(in, arg)
-               }
-               out := make([]Type, 0, len(ft.out()))
-               for _, ret := range ft.out() {
-                       out = append(out, ret)
-               }
-               mt := FuncOf(in, out, ft.IsVariadic())
-               m.Type = mt
-               tfn := t.textOff(p.tfn)
-               fn := unsafe.Pointer(&tfn)
-               m.Func = Value{mt.(*rtype), fn, fl}
+       mtyp := t.typeOff(p.mtyp)
+       ft := (*funcType)(unsafe.Pointer(mtyp))
+       in := make([]Type, 0, 1+len(ft.in()))
+       in = append(in, t)
+       for _, arg := range ft.in() {
+               in = append(in, arg)
        }
+       out := make([]Type, 0, len(ft.out()))
+       for _, ret := range ft.out() {
+               out = append(out, ret)
+       }
+       mt := FuncOf(in, out, ft.IsVariadic())
+       m.Type = mt
+       tfn := t.textOff(p.tfn)
+       fn := unsafe.Pointer(&tfn)
+       m.Func = Value{mt.(*rtype), fn, fl}
+
        m.Index = i
        return m
 }
@@ -831,7 +877,7 @@ func (t *rtype) MethodByName(name string) (m Method, ok bool) {
        for i := 0; i < int(ut.mcount); i++ {
                p := utmethods[i]
                pname := t.nameOff(p.name)
-               if pname.name() == name {
+               if pname.isExported() && pname.name() == name {
                        return t.Method(i), true
                }
        }