From be1b93065356e71362ca8469fc53c9ab102c4be5 Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Thu, 19 May 2016 13:31:58 -0400 Subject: [PATCH] reflect: hide unexported methods that do not satisfy interfaces Fixes #15673 Change-Id: Ib36d8db3299a93d92665dbde012d52c2c5332ac0 Reviewed-on: https://go-review.googlesource.com/23253 Reviewed-by: Russ Cox Run-TryBot: David Crawshaw TryBot-Result: Gobot Gobot --- src/reflect/all_test.go | 19 ++++---- src/reflect/type.go | 100 +++++++++++++++++++++++++++++----------- 2 files changed, 81 insertions(+), 38 deletions(-) diff --git a/src/reflect/all_test.go b/src/reflect/all_test.go index 9799fee357..f09ffeb566 100644 --- a/src/reflect/all_test.go +++ b/src/reflect/all_test.go @@ -2388,13 +2388,13 @@ type outer struct { inner } -func (*inner) m() {} -func (*outer) m() {} +func (*inner) M() {} +func (*outer) M() {} func TestNestedMethods(t *testing.T) { typ := TypeOf((*outer)(nil)) - if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != ValueOf((*outer).m).Pointer() { - t.Errorf("Wrong method table for outer: (m=%p)", (*outer).m) + if typ.NumMethod() != 1 || typ.Method(0).Func.Pointer() != ValueOf((*outer).M).Pointer() { + t.Errorf("Wrong method table for outer: (M=%p)", (*outer).M) for i := 0; i < typ.NumMethod(); i++ { m := typ.Method(i) t.Errorf("\t%d: %s %#x\n", i, m.Name, m.Func.Pointer()) @@ -2416,18 +2416,15 @@ var unexpi unexpI = new(unexp) func TestUnexportedMethods(t *testing.T) { typ := TypeOf(unexpi) + if got := typ.NumMethod(); got != 1 { + t.Error("NumMethod=%d, want 1 satisfied method", got) + } if typ.Method(0).Type == nil { t.Error("missing type for satisfied method 'f'") } if !typ.Method(0).Func.IsValid() { t.Error("missing func for satisfied method 'f'") } - if typ.Method(1).Type != nil { - t.Error("found type for unsatisfied method 'g'") - } - if typ.Method(1).Func.IsValid() { - t.Error("found func for unsatisfied method 'g'") - } } type InnerInt struct { @@ -5187,7 +5184,7 @@ func useStack(n int) { type Impl struct{} -func (Impl) f() {} +func (Impl) F() {} func TestValueString(t *testing.T) { rv := ValueOf(Impl{}) diff --git a/src/reflect/type.go b/src/reflect/type.go index dd7b797c04..c9389199d8 100644 --- a/src/reflect/type.go +++ b/src/reflect/type.go @@ -763,16 +763,63 @@ func (t *rtype) pointers() bool { return t.kind&kindNoPointers == 0 } func (t *rtype) common() *rtype { return t } +var methodCache struct { + sync.RWMutex + m map[*rtype][]method +} + +// satisfiedMethods returns methods of t that satisfy an interface. +// This may include unexported methods that satisfy an interface +// defined with unexported methods in the same package as t. +func (t *rtype) satisfiedMethods() []method { + methodCache.RLock() + methods, found := methodCache.m[t] + methodCache.RUnlock() + + if found { + return methods + } + + ut := t.uncommon() + if ut == nil { + return nil + } + allm := ut.methods() + allSatisfied := true + for _, m := range allm { + if m.mtyp == 0 { + allSatisfied = false + break + } + } + if allSatisfied { + methods = allm + } else { + methods = make([]method, 0, len(allm)) + for _, m := range allm { + if m.mtyp != 0 { + methods = append(methods, m) + } + } + methods = methods[:len(methods):len(methods)] + } + + methodCache.Lock() + if methodCache.m == nil { + methodCache.m = make(map[*rtype][]method) + } + methodCache.m[t] = methods + methodCache.Unlock() + + return methods +} + func (t *rtype) NumMethod() int { if t.Kind() == Interface { tt := (*interfaceType)(unsafe.Pointer(t)) return tt.NumMethod() } - ut := t.uncommon() - if ut == nil { - return 0 - } - return int(ut.mcount) + return len(t.satisfiedMethods()) } func (t *rtype) Method(i int) (m Method) { @@ -780,40 +827,39 @@ func (t *rtype) Method(i int) (m Method) { tt := (*interfaceType)(unsafe.Pointer(t)) return tt.Method(i) } - ut := t.uncommon() - - if ut == nil || i < 0 || i >= int(ut.mcount) { + methods := t.satisfiedMethods() + if i < 0 || i >= len(methods) { panic("reflect: Method index out of range") } - p := ut.methods()[i] + p := methods[i] pname := t.nameOff(p.name) m.Name = pname.name() fl := flag(Func) if !pname.isExported() { m.PkgPath = pname.pkgPath() if m.PkgPath == "" { + ut := t.uncommon() m.PkgPath = t.nameOff(ut.pkgPath).name() } fl |= flagStickyRO } - if p.mtyp != 0 { - mtyp := t.typeOff(p.mtyp) - ft := (*funcType)(unsafe.Pointer(mtyp)) - in := make([]Type, 0, 1+len(ft.in())) - in = append(in, t) - for _, arg := range ft.in() { - in = append(in, arg) - } - out := make([]Type, 0, len(ft.out())) - for _, ret := range ft.out() { - out = append(out, ret) - } - mt := FuncOf(in, out, ft.IsVariadic()) - m.Type = mt - tfn := t.textOff(p.tfn) - fn := unsafe.Pointer(&tfn) - m.Func = Value{mt.(*rtype), fn, fl} + mtyp := t.typeOff(p.mtyp) + ft := (*funcType)(unsafe.Pointer(mtyp)) + in := make([]Type, 0, 1+len(ft.in())) + in = append(in, t) + for _, arg := range ft.in() { + in = append(in, arg) } + out := make([]Type, 0, len(ft.out())) + for _, ret := range ft.out() { + out = append(out, ret) + } + mt := FuncOf(in, out, ft.IsVariadic()) + m.Type = mt + tfn := t.textOff(p.tfn) + fn := unsafe.Pointer(&tfn) + m.Func = Value{mt.(*rtype), fn, fl} + m.Index = i return m } @@ -831,7 +877,7 @@ func (t *rtype) MethodByName(name string) (m Method, ok bool) { for i := 0; i < int(ut.mcount); i++ { p := utmethods[i] pname := t.nameOff(p.name) - if pname.name() == name { + if pname.isExported() && pname.name() == name { return t.Method(i), true } } -- 2.48.1