]> Cypherpunks repositories - gostls13.git/commitdiff
runtime: panic on uncomparable map key, even if map is empty
authorKeith Randall <khr@google.com>
Fri, 28 Dec 2018 22:34:48 +0000 (14:34 -0800)
committerKeith Randall <khr@golang.org>
Sat, 29 Dec 2018 01:00:54 +0000 (01:00 +0000)
Reorg map flags a bit so we don't need any extra space for the extra flag.

Fixes #23734

Change-Id: I436812156240ae90de53d0943fe1aabf3ea37417
Reviewed-on: https://go-review.googlesource.com/c/155918
Run-TryBot: Keith Randall <khr@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/cmd/compile/internal/gc/reflect.go
src/reflect/type.go
src/runtime/map.go
src/runtime/type.go
test/fixedbugs/issue23734.go [new file with mode: 0644]

index 2863d4b5d0ca7862eb7ff0d499e997b0c6ed0ec1..7a93ece8b90365f74dd9318168ddf54091d120f5 100644 (file)
@@ -1095,6 +1095,28 @@ func needkeyupdate(t *types.Type) bool {
        }
 }
 
+// hashMightPanic reports whether the hash of a map key of type t might panic.
+func hashMightPanic(t *types.Type) bool {
+       switch t.Etype {
+       case TINTER:
+               return true
+
+       case TARRAY:
+               return hashMightPanic(t.Elem())
+
+       case TSTRUCT:
+               for _, t1 := range t.Fields().Slice() {
+                       if hashMightPanic(t1.Type) {
+                               return true
+                       }
+               }
+               return false
+
+       default:
+               return false
+       }
+}
+
 // formalType replaces byte and rune aliases with real types.
 // They've been separate internally to make error messages
 // better, but we have to merge them in the reflect tables.
@@ -1257,25 +1279,33 @@ func dtypesym(t *types.Type) *obj.LSym {
                ot = dsymptr(lsym, ot, s1, 0)
                ot = dsymptr(lsym, ot, s2, 0)
                ot = dsymptr(lsym, ot, s3, 0)
+               var flags uint32
+               // Note: flags must match maptype accessors in ../../../../runtime/type.go
+               // and maptype builder in ../../../../reflect/type.go:MapOf.
                if t.Key().Width > MAXKEYSIZE {
                        ot = duint8(lsym, ot, uint8(Widthptr))
-                       ot = duint8(lsym, ot, 1) // indirect
+                       flags |= 1 // indirect key
                } else {
                        ot = duint8(lsym, ot, uint8(t.Key().Width))
-                       ot = duint8(lsym, ot, 0) // not indirect
                }
 
                if t.Elem().Width > MAXVALSIZE {
                        ot = duint8(lsym, ot, uint8(Widthptr))
-                       ot = duint8(lsym, ot, 1) // indirect
+                       flags |= 2 // indirect value
                } else {
                        ot = duint8(lsym, ot, uint8(t.Elem().Width))
-                       ot = duint8(lsym, ot, 0) // not indirect
                }
-
                ot = duint16(lsym, ot, uint16(bmap(t).Width))
-               ot = duint8(lsym, ot, uint8(obj.Bool2int(isreflexive(t.Key()))))
-               ot = duint8(lsym, ot, uint8(obj.Bool2int(needkeyupdate(t.Key()))))
+               if isreflexive(t.Key()) {
+                       flags |= 4 // reflexive key
+               }
+               if needkeyupdate(t.Key()) {
+                       flags |= 8 // need key update
+               }
+               if hashMightPanic(t.Key()) {
+                       flags |= 16 // hash might panic
+               }
+               ot = duint32(lsym, ot, flags)
                ot = dextratype(lsym, ot, t, 0)
 
        case TPTR:
index f48f9cf09de4647567a412bdd51a2de2ab37ebeb..5ce80c61dcfdf6de39082c2edd0f63019631293f 100644 (file)
@@ -394,16 +394,13 @@ type interfaceType struct {
 // mapType represents a map type.
 type mapType struct {
        rtype
-       key           *rtype // map key type
-       elem          *rtype // map element (value) type
-       bucket        *rtype // internal bucket structure
-       keysize       uint8  // size of key slot
-       indirectkey   uint8  // store ptr to key instead of key itself
-       valuesize     uint8  // size of value slot
-       indirectvalue uint8  // store ptr to value instead of value itself
-       bucketsize    uint16 // size of bucket
-       reflexivekey  bool   // true if k==k for all keys
-       needkeyupdate bool   // true if we need to update key on an overwrite
+       key        *rtype // map key type
+       elem       *rtype // map element (value) type
+       bucket     *rtype // internal bucket structure
+       keysize    uint8  // size of key slot
+       valuesize  uint8  // size of value slot
+       bucketsize uint16 // size of bucket
+       flags      uint32
 }
 
 // ptrType represents a pointer type.
@@ -1859,6 +1856,8 @@ func MapOf(key, elem Type) Type {
        }
 
        // Make a map type.
+       // Note: flag values must match those used in the TMAP case
+       // in ../cmd/compile/internal/gc/reflect.go:dtypesym.
        var imap interface{} = (map[unsafe.Pointer]unsafe.Pointer)(nil)
        mt := **(**mapType)(unsafe.Pointer(&imap))
        mt.str = resolveReflectName(newName(s, "", false))
@@ -1867,23 +1866,29 @@ func MapOf(key, elem Type) Type {
        mt.key = ktyp
        mt.elem = etyp
        mt.bucket = bucketOf(ktyp, etyp)
+       mt.flags = 0
        if ktyp.size > maxKeySize {
                mt.keysize = uint8(ptrSize)
-               mt.indirectkey = 1
+               mt.flags |= 1 // indirect key
        } else {
                mt.keysize = uint8(ktyp.size)
-               mt.indirectkey = 0
        }
        if etyp.size > maxValSize {
                mt.valuesize = uint8(ptrSize)
-               mt.indirectvalue = 1
+               mt.flags |= 2 // indirect value
        } else {
                mt.valuesize = uint8(etyp.size)
-               mt.indirectvalue = 0
        }
        mt.bucketsize = uint16(mt.bucket.size)
-       mt.reflexivekey = isReflexive(ktyp)
-       mt.needkeyupdate = needKeyUpdate(ktyp)
+       if isReflexive(ktyp) {
+               mt.flags |= 4
+       }
+       if needKeyUpdate(ktyp) {
+               mt.flags |= 8
+       }
+       if hashMightPanic(ktyp) {
+               mt.flags |= 16
+       }
        mt.ptrToThis = 0
 
        ti, _ := lookupCache.LoadOrStore(ckey, &mt.rtype)
@@ -2122,6 +2127,27 @@ func needKeyUpdate(t *rtype) bool {
        }
 }
 
+// hashMightPanic reports whether the hash of a map key of type t might panic.
+func hashMightPanic(t *rtype) bool {
+       switch t.Kind() {
+       case Interface:
+               return true
+       case Array:
+               tt := (*arrayType)(unsafe.Pointer(t))
+               return hashMightPanic(tt.elem)
+       case Struct:
+               tt := (*structType)(unsafe.Pointer(t))
+               for _, f := range tt.fields {
+                       if hashMightPanic(f.typ) {
+                               return true
+                       }
+               }
+               return false
+       default:
+               return false
+       }
+}
+
 // Make sure these routines stay in sync with ../../runtime/map.go!
 // These types exist only for GC, so we only fill out GC relevant info.
 // Currently, that's just size and the GC program. We also fill in string
index d835cc831aafa975d41f3f5633e311b4306ece57..9c25b63348f5b6289b0caa2f74f153c4ea0e193e 100644 (file)
@@ -404,6 +404,9 @@ func mapaccess1(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
                msanread(key, t.key.size)
        }
        if h == nil || h.count == 0 {
+               if t.hashMightPanic() {
+                       t.key.alg.hash(key, 0) // see issue 23734
+               }
                return unsafe.Pointer(&zeroVal[0])
        }
        if h.flags&hashWriting != 0 {
@@ -434,12 +437,12 @@ bucketloop:
                                continue
                        }
                        k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
-                       if t.indirectkey {
+                       if t.indirectkey() {
                                k = *((*unsafe.Pointer)(k))
                        }
                        if alg.equal(key, k) {
                                v := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.valuesize))
-                               if t.indirectvalue {
+                               if t.indirectvalue() {
                                        v = *((*unsafe.Pointer)(v))
                                }
                                return v
@@ -460,6 +463,9 @@ func mapaccess2(t *maptype, h *hmap, key unsafe.Pointer) (unsafe.Pointer, bool)
                msanread(key, t.key.size)
        }
        if h == nil || h.count == 0 {
+               if t.hashMightPanic() {
+                       t.key.alg.hash(key, 0) // see issue 23734
+               }
                return unsafe.Pointer(&zeroVal[0]), false
        }
        if h.flags&hashWriting != 0 {
@@ -490,12 +496,12 @@ bucketloop:
                                continue
                        }
                        k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
-                       if t.indirectkey {
+                       if t.indirectkey() {
                                k = *((*unsafe.Pointer)(k))
                        }
                        if alg.equal(key, k) {
                                v := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.valuesize))
-                               if t.indirectvalue {
+                               if t.indirectvalue() {
                                        v = *((*unsafe.Pointer)(v))
                                }
                                return v, true
@@ -535,12 +541,12 @@ bucketloop:
                                continue
                        }
                        k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
-                       if t.indirectkey {
+                       if t.indirectkey() {
                                k = *((*unsafe.Pointer)(k))
                        }
                        if alg.equal(key, k) {
                                v := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.valuesize))
-                               if t.indirectvalue {
+                               if t.indirectvalue() {
                                        v = *((*unsafe.Pointer)(v))
                                }
                                return k, v
@@ -620,14 +626,14 @@ bucketloop:
                                continue
                        }
                        k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
-                       if t.indirectkey {
+                       if t.indirectkey() {
                                k = *((*unsafe.Pointer)(k))
                        }
                        if !alg.equal(key, k) {
                                continue
                        }
                        // already have a mapping for key. Update it.
-                       if t.needkeyupdate {
+                       if t.needkeyupdate() {
                                typedmemmove(t.key, k, key)
                        }
                        val = add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.valuesize))
@@ -658,12 +664,12 @@ bucketloop:
        }
 
        // store new key/value at insert position
-       if t.indirectkey {
+       if t.indirectkey() {
                kmem := newobject(t.key)
                *(*unsafe.Pointer)(insertk) = kmem
                insertk = kmem
        }
-       if t.indirectvalue {
+       if t.indirectvalue() {
                vmem := newobject(t.elem)
                *(*unsafe.Pointer)(val) = vmem
        }
@@ -676,7 +682,7 @@ done:
                throw("concurrent map writes")
        }
        h.flags &^= hashWriting
-       if t.indirectvalue {
+       if t.indirectvalue() {
                val = *((*unsafe.Pointer)(val))
        }
        return val
@@ -693,6 +699,9 @@ func mapdelete(t *maptype, h *hmap, key unsafe.Pointer) {
                msanread(key, t.key.size)
        }
        if h == nil || h.count == 0 {
+               if t.hashMightPanic() {
+                       t.key.alg.hash(key, 0) // see issue 23734
+               }
                return
        }
        if h.flags&hashWriting != 0 {
@@ -724,20 +733,20 @@ search:
                        }
                        k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
                        k2 := k
-                       if t.indirectkey {
+                       if t.indirectkey() {
                                k2 = *((*unsafe.Pointer)(k2))
                        }
                        if !alg.equal(key, k2) {
                                continue
                        }
                        // Only clear key if there are pointers in it.
-                       if t.indirectkey {
+                       if t.indirectkey() {
                                *(*unsafe.Pointer)(k) = nil
                        } else if t.key.kind&kindNoPointers == 0 {
                                memclrHasPointers(k, t.key.size)
                        }
                        v := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.valuesize))
-                       if t.indirectvalue {
+                       if t.indirectvalue() {
                                *(*unsafe.Pointer)(v) = nil
                        } else if t.elem.kind&kindNoPointers == 0 {
                                memclrHasPointers(v, t.elem.size)
@@ -897,7 +906,7 @@ next:
                        continue
                }
                k := add(unsafe.Pointer(b), dataOffset+uintptr(offi)*uintptr(t.keysize))
-               if t.indirectkey {
+               if t.indirectkey() {
                        k = *((*unsafe.Pointer)(k))
                }
                v := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+uintptr(offi)*uintptr(t.valuesize))
@@ -909,7 +918,7 @@ next:
                        // through the oldbucket, skipping any keys that will go
                        // to the other new bucket (each oldbucket expands to two
                        // buckets during a grow).
-                       if t.reflexivekey || alg.equal(k, k) {
+                       if t.reflexivekey() || alg.equal(k, k) {
                                // If the item in the oldbucket is not destined for
                                // the current new bucket in the iteration, skip it.
                                hash := alg.hash(k, uintptr(h.hash0))
@@ -930,13 +939,13 @@ next:
                        }
                }
                if (b.tophash[offi] != evacuatedX && b.tophash[offi] != evacuatedY) ||
-                       !(t.reflexivekey || alg.equal(k, k)) {
+                       !(t.reflexivekey() || alg.equal(k, k)) {
                        // This is the golden data, we can return it.
                        // OR
                        // key!=key, so the entry can't be deleted or updated, so we can just return it.
                        // That's lucky for us because when key!=key we can't look it up successfully.
                        it.key = k
-                       if t.indirectvalue {
+                       if t.indirectvalue() {
                                v = *((*unsafe.Pointer)(v))
                        }
                        it.value = v
@@ -1160,7 +1169,7 @@ func evacuate(t *maptype, h *hmap, oldbucket uintptr) {
                                        throw("bad map state")
                                }
                                k2 := k
-                               if t.indirectkey {
+                               if t.indirectkey() {
                                        k2 = *((*unsafe.Pointer)(k2))
                                }
                                var useY uint8
@@ -1168,7 +1177,7 @@ func evacuate(t *maptype, h *hmap, oldbucket uintptr) {
                                        // Compute hash to make our evacuation decision (whether we need
                                        // to send this key/value to bucket x or bucket y).
                                        hash := t.key.alg.hash(k2, uintptr(h.hash0))
-                                       if h.flags&iterator != 0 && !t.reflexivekey && !t.key.alg.equal(k2, k2) {
+                                       if h.flags&iterator != 0 && !t.reflexivekey() && !t.key.alg.equal(k2, k2) {
                                                // If key != key (NaNs), then the hash could be (and probably
                                                // will be) entirely different from the old hash. Moreover,
                                                // it isn't reproducible. Reproducibility is required in the
@@ -1203,12 +1212,12 @@ func evacuate(t *maptype, h *hmap, oldbucket uintptr) {
                                        dst.v = add(dst.k, bucketCnt*uintptr(t.keysize))
                                }
                                dst.b.tophash[dst.i&(bucketCnt-1)] = top // mask dst.i as an optimization, to avoid a bounds check
-                               if t.indirectkey {
+                               if t.indirectkey() {
                                        *(*unsafe.Pointer)(dst.k) = k2 // copy pointer
                                } else {
                                        typedmemmove(t.key, dst.k, k) // copy value
                                }
-                               if t.indirectvalue {
+                               if t.indirectvalue() {
                                        *(*unsafe.Pointer)(dst.v) = *(*unsafe.Pointer)(v)
                                } else {
                                        typedmemmove(t.elem, dst.v, v)
@@ -1274,12 +1283,12 @@ func reflect_makemap(t *maptype, cap int) *hmap {
        if !ismapkey(t.key) {
                throw("runtime.reflect_makemap: unsupported map key type")
        }
-       if t.key.size > maxKeySize && (!t.indirectkey || t.keysize != uint8(sys.PtrSize)) ||
-               t.key.size <= maxKeySize && (t.indirectkey || t.keysize != uint8(t.key.size)) {
+       if t.key.size > maxKeySize && (!t.indirectkey() || t.keysize != uint8(sys.PtrSize)) ||
+               t.key.size <= maxKeySize && (t.indirectkey() || t.keysize != uint8(t.key.size)) {
                throw("key size wrong")
        }
-       if t.elem.size > maxValueSize && (!t.indirectvalue || t.valuesize != uint8(sys.PtrSize)) ||
-               t.elem.size <= maxValueSize && (t.indirectvalue || t.valuesize != uint8(t.elem.size)) {
+       if t.elem.size > maxValueSize && (!t.indirectvalue() || t.valuesize != uint8(sys.PtrSize)) ||
+               t.elem.size <= maxValueSize && (t.indirectvalue() || t.valuesize != uint8(t.elem.size)) {
                throw("value size wrong")
        }
        if t.key.align > bucketCnt {
index 88a44a37ed3da17bfe4935cb69ca5b449354f0e2..f7f99924eaf759bf63cacd087eb16b0e508f317f 100644 (file)
@@ -361,17 +361,32 @@ type interfacetype struct {
 }
 
 type maptype struct {
-       typ           _type
-       key           *_type
-       elem          *_type
-       bucket        *_type // internal type representing a hash bucket
-       keysize       uint8  // size of key slot
-       indirectkey   bool   // store ptr to key instead of key itself
-       valuesize     uint8  // size of value slot
-       indirectvalue bool   // store ptr to value instead of value itself
-       bucketsize    uint16 // size of bucket
-       reflexivekey  bool   // true if k==k for all keys
-       needkeyupdate bool   // true if we need to update key on an overwrite
+       typ        _type
+       key        *_type
+       elem       *_type
+       bucket     *_type // internal type representing a hash bucket
+       keysize    uint8  // size of key slot
+       valuesize  uint8  // size of value slot
+       bucketsize uint16 // size of bucket
+       flags      uint32
+}
+
+// Note: flag values must match those used in the TMAP case
+// in ../cmd/compile/internal/gc/reflect.go:dtypesym.
+func (mt *maptype) indirectkey() bool { // store ptr to key instead of key itself
+       return mt.flags&1 != 0
+}
+func (mt *maptype) indirectvalue() bool { // store ptr to value instead of value itself
+       return mt.flags&2 != 0
+}
+func (mt *maptype) reflexivekey() bool { // true if k==k for all keys
+       return mt.flags&4 != 0
+}
+func (mt *maptype) needkeyupdate() bool { // true if we need to update key on an overwrite
+       return mt.flags&8 != 0
+}
+func (mt *maptype) hashMightPanic() bool { // true if hash function might panic
+       return mt.flags&16 != 0
 }
 
 type arraytype struct {
diff --git a/test/fixedbugs/issue23734.go b/test/fixedbugs/issue23734.go
new file mode 100644 (file)
index 0000000..dd5869b
--- /dev/null
@@ -0,0 +1,32 @@
+// run
+
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func main() {
+       m := map[interface{}]int{}
+       k := []int{}
+
+       mustPanic(func() {
+               _ = m[k]
+       })
+       mustPanic(func() {
+               _, _ = m[k]
+       })
+       mustPanic(func() {
+               delete(m, k)
+       })
+}
+
+func mustPanic(f func()) {
+       defer func() {
+               r := recover()
+               if r == nil {
+                       panic("didn't panic")
+               }
+       }()
+       f()
+}