]> Cypherpunks repositories - gostls13.git/commitdiff
plugin: properly handle recursively defined types
authorTodd Neal <todd@tneal.org>
Tue, 11 Apr 2017 02:33:27 +0000 (21:33 -0500)
committerTodd Neal <todd@tneal.org>
Wed, 12 Apr 2017 12:46:07 +0000 (12:46 +0000)
Prevent a crash if the same type in two plugins had a recursive
definition, either by referring to a pointer to itself or a map existing
with the type as a value type (which creates a recursive definition
through the overflow bucket type).

Fixes #19258

Change-Id: Iac1cbda4c5b6e8edd5e6859a4d5da3bad539a9c6
Reviewed-on: https://go-review.googlesource.com/40292
Run-TryBot: Todd Neal <todd@tneal.org>
Reviewed-by: David Crawshaw <crawshaw@golang.org>
misc/cgo/testplugin/unnamed1.go
misc/cgo/testplugin/unnamed2.go
src/runtime/type.go

index 102edaf3e29c97c3643a21a098fc3a17e19f077c..5c1df086d76df5fba39435cc7cab4d2befdeec32 100644 (file)
@@ -9,4 +9,15 @@ import "C"
 
 func FuncInt() int { return 1 }
 
+// Add a recursive type to to check that type equality across plugins doesn't
+// crash. See https://golang.org/issues/19258
+func FuncRecursive() X { return X{} }
+
+type Y struct {
+       X *X
+}
+type X struct {
+       Y Y
+}
+
 func main() {}
index 55070d5e9f79e428afe6ff19529c4232ef9ada5a..7ef66109c5ca6246e973cd92d757ff095be4d8cb 100644 (file)
@@ -9,4 +9,13 @@ import "C"
 
 func FuncInt() int { return 2 }
 
+func FuncRecursive() X { return X{} }
+
+type Y struct {
+       X *X
+}
+type X struct {
+       Y Y
+}
+
 func main() {}
index 10442eff695e5eee59b42217f3bd47e17dd46ee6..d001c5cd447fefa93d31a58d22c901f48fc2a716 100644 (file)
@@ -511,7 +511,8 @@ func typelinksinit() {
                        for _, tl := range md.typelinks {
                                t := (*_type)(unsafe.Pointer(md.types + uintptr(tl)))
                                for _, candidate := range typehash[t.hash] {
-                                       if typesEqual(t, candidate) {
+                                       seen := map[_typePair]struct{}{}
+                                       if typesEqual(t, candidate, seen) {
                                                t = candidate
                                                break
                                        }
@@ -524,6 +525,11 @@ func typelinksinit() {
        }
 }
 
+type _typePair struct {
+       t1 *_type
+       t2 *_type
+}
+
 // typesEqual reports whether two types are equal.
 //
 // Everywhere in the runtime and reflect packages, it is assumed that
@@ -536,7 +542,17 @@ func typelinksinit() {
 // back into earlier ones.
 //
 // Only typelinksinit needs this function.
-func typesEqual(t, v *_type) bool {
+func typesEqual(t, v *_type, seen map[_typePair]struct{}) bool {
+       tp := _typePair{t, v}
+       if _, ok := seen[tp]; ok {
+               return true
+       }
+
+       // mark these types as seen, and thus equivalent which prevents an infinite loop if
+       // the two types are identical, but recursively defined and loaded from
+       // different modules
+       seen[tp] = struct{}{}
+
        if t == v {
                return true
        }
@@ -568,11 +584,11 @@ func typesEqual(t, v *_type) bool {
        case kindArray:
                at := (*arraytype)(unsafe.Pointer(t))
                av := (*arraytype)(unsafe.Pointer(v))
-               return typesEqual(at.elem, av.elem) && at.len == av.len
+               return typesEqual(at.elem, av.elem, seen) && at.len == av.len
        case kindChan:
                ct := (*chantype)(unsafe.Pointer(t))
                cv := (*chantype)(unsafe.Pointer(v))
-               return ct.dir == cv.dir && typesEqual(ct.elem, cv.elem)
+               return ct.dir == cv.dir && typesEqual(ct.elem, cv.elem, seen)
        case kindFunc:
                ft := (*functype)(unsafe.Pointer(t))
                fv := (*functype)(unsafe.Pointer(v))
@@ -581,13 +597,13 @@ func typesEqual(t, v *_type) bool {
                }
                tin, vin := ft.in(), fv.in()
                for i := 0; i < len(tin); i++ {
-                       if !typesEqual(tin[i], vin[i]) {
+                       if !typesEqual(tin[i], vin[i], seen) {
                                return false
                        }
                }
                tout, vout := ft.out(), fv.out()
                for i := 0; i < len(tout); i++ {
-                       if !typesEqual(tout[i], vout[i]) {
+                       if !typesEqual(tout[i], vout[i], seen) {
                                return false
                        }
                }
@@ -616,7 +632,7 @@ func typesEqual(t, v *_type) bool {
                        }
                        tityp := resolveTypeOff(unsafe.Pointer(tm), tm.ityp)
                        vityp := resolveTypeOff(unsafe.Pointer(vm), vm.ityp)
-                       if !typesEqual(tityp, vityp) {
+                       if !typesEqual(tityp, vityp, seen) {
                                return false
                        }
                }
@@ -624,15 +640,15 @@ func typesEqual(t, v *_type) bool {
        case kindMap:
                mt := (*maptype)(unsafe.Pointer(t))
                mv := (*maptype)(unsafe.Pointer(v))
-               return typesEqual(mt.key, mv.key) && typesEqual(mt.elem, mv.elem)
+               return typesEqual(mt.key, mv.key, seen) && typesEqual(mt.elem, mv.elem, seen)
        case kindPtr:
                pt := (*ptrtype)(unsafe.Pointer(t))
                pv := (*ptrtype)(unsafe.Pointer(v))
-               return typesEqual(pt.elem, pv.elem)
+               return typesEqual(pt.elem, pv.elem, seen)
        case kindSlice:
                st := (*slicetype)(unsafe.Pointer(t))
                sv := (*slicetype)(unsafe.Pointer(v))
-               return typesEqual(st.elem, sv.elem)
+               return typesEqual(st.elem, sv.elem, seen)
        case kindStruct:
                st := (*structtype)(unsafe.Pointer(t))
                sv := (*structtype)(unsafe.Pointer(v))
@@ -648,7 +664,7 @@ func typesEqual(t, v *_type) bool {
                        if tf.name.pkgPath() != vf.name.pkgPath() {
                                return false
                        }
-                       if !typesEqual(tf.typ, vf.typ) {
+                       if !typesEqual(tf.typ, vf.typ, seen) {
                                return false
                        }
                        if tf.name.tag() != vf.name.tag() {