From 2220fd36368c96da3dd833bdc2bbd13be291216a Mon Sep 17 00:00:00 2001 From: Michael Pratt Date: Tue, 17 Sep 2024 16:30:04 -0400 Subject: [PATCH] runtime: add concurrent write checks to swissmap This is the same design as existing maps. For #54766. Cq-Include-Trybots: luci.golang.try:gotip-linux-amd64-longtest-swissmap Change-Id: I5f6ef5fea1e0f0616bcd90eaae7faee4cdac58c6 Reviewed-on: https://go-review.googlesource.com/c/go/+/616460 Reviewed-by: Keith Randall LUCI-TryBot-Result: Go LUCI Auto-Submit: Michael Pratt Reviewed-by: Keith Randall --- .../compile/internal/reflectdata/map_swiss.go | 5 +- src/internal/runtime/maps/map.go | 88 +++++++++++++++---- src/internal/runtime/maps/runtime_swiss.go | 38 ++++++-- src/internal/runtime/maps/table.go | 4 + src/runtime/map_swiss.go | 4 - 5 files changed, 111 insertions(+), 28 deletions(-) diff --git a/src/cmd/compile/internal/reflectdata/map_swiss.go b/src/cmd/compile/internal/reflectdata/map_swiss.go index 2b79b22235..b531d785d3 100644 --- a/src/cmd/compile/internal/reflectdata/map_swiss.go +++ b/src/cmd/compile/internal/reflectdata/map_swiss.go @@ -144,6 +144,8 @@ func SwissMapType() *types.Type { // // globalDepth uint8 // globalShift uint8 + // + // writing uint8 // // N.B Padding // // clearSeq uint64 @@ -156,6 +158,7 @@ func SwissMapType() *types.Type { makefield("dirLen", types.Types[types.TINT]), makefield("globalDepth", types.Types[types.TUINT8]), makefield("globalShift", types.Types[types.TUINT8]), + makefield("writing", types.Types[types.TUINT8]), makefield("clearSeq", types.Types[types.TUINT64]), } @@ -169,7 +172,7 @@ func SwissMapType() *types.Type { // The size of Map should be 48 bytes on 64 bit // and 32 bytes on 32 bit platforms. - if size := int64(2*8 + 4*types.PtrSize /* one extra for globalDepth + padding */); m.Size() != size { + if size := int64(2*8 + 4*types.PtrSize /* one extra for globalDepth/globalShift/writing + padding */); m.Size() != size { base.Fatalf("internal/runtime/maps.Map size not correct: got %d, want %d", m.Size(), size) } diff --git a/src/internal/runtime/maps/map.go b/src/internal/runtime/maps/map.go index ae8afc3ea7..543340f10c 100644 --- a/src/internal/runtime/maps/map.go +++ b/src/internal/runtime/maps/map.go @@ -229,6 +229,12 @@ type Map struct { // On 64-bit systems, this is 64 - globalDepth. globalShift uint8 + // writing is a flag that is toggled (XOR 1) while the map is being + // written. Normally it is set to 1 when writing, but if there are + // multiple concurrent writers, then toggling increases the probability + // that both sides will detect the race. + writing uint8 + // clearSeq is a sequence counter of calls to Clear. It is used to // detect map clears during iteration. clearSeq uint64 @@ -386,6 +392,10 @@ func (m *Map) getWithKey(typ *abi.SwissMapType, key unsafe.Pointer) (unsafe.Poin return nil, nil, false } + if m.writing != 0 { + fatal("concurrent map read and map write") + } + hash := typ.Hasher(key, m.seed) if m.dirLen == 0 { @@ -401,6 +411,10 @@ func (m *Map) getWithoutKey(typ *abi.SwissMapType, key unsafe.Pointer) (unsafe.P return nil, false } + if m.writing != 0 { + fatal("concurrent map read and map write") + } + hash := typ.Hasher(key, m.seed) if m.dirLen == 0 { @@ -446,15 +460,30 @@ func (m *Map) Put(typ *abi.SwissMapType, key, elem unsafe.Pointer) { // // PutSlot never returns nil. func (m *Map) PutSlot(typ *abi.SwissMapType, key unsafe.Pointer) unsafe.Pointer { + if m.writing != 0 { + fatal("concurrent map writes") + } + hash := typ.Hasher(key, m.seed) + // Set writing after calling Hasher, since Hasher may panic, in which + // case we have not actually done a write. + m.writing ^= 1 // toggle, see comment on writing + if m.dirPtr == nil { m.growToSmall(typ) } if m.dirLen == 0 { if m.used < abi.SwissMapGroupSlots { - return m.putSlotSmall(typ, hash, key) + elem := m.putSlotSmall(typ, hash, key) + + if m.writing == 0 { + fatal("concurrent map writes") + } + m.writing ^= 1 + + return elem } // Can't fit another entry, grow to full size map. @@ -470,6 +499,12 @@ func (m *Map) PutSlot(typ *abi.SwissMapType, key unsafe.Pointer) unsafe.Pointer if !ok { continue } + + if m.writing == 0 { + fatal("concurrent map writes") + } + m.writing ^= 1 + return elem } } @@ -563,15 +598,27 @@ func (m *Map) Delete(typ *abi.SwissMapType, key unsafe.Pointer) { return } + if m.writing != 0 { + fatal("concurrent map writes") + } + hash := typ.Hasher(key, m.seed) + // Set writing after calling Hasher, since Hasher may panic, in which + // case we have not actually done a write. + m.writing ^= 1 // toggle, see comment on writing + if m.dirLen == 0 { m.deleteSmall(typ, hash, key) - return + } else { + idx := m.directoryIndex(hash) + m.directoryAt(idx).Delete(typ, m, key) } - idx := m.directoryIndex(hash) - m.directoryAt(idx).Delete(typ, m, key) + if m.writing == 0 { + fatal("concurrent map writes") + } + m.writing ^= 1 } func (m *Map) deleteSmall(typ *abi.SwissMapType, hash uintptr, key unsafe.Pointer) { @@ -605,23 +652,32 @@ func (m *Map) Clear(typ *abi.SwissMapType) { return } + if m.writing != 0 { + fatal("concurrent map writes") + } + m.writing ^= 1 // toggle, see comment on writing + if m.dirLen == 0 { m.clearSmall(typ) - return + } else { + var lastTab *table + for i := range m.dirLen { + t := m.directoryAt(uintptr(i)) + if t == lastTab { + continue + } + t.Clear(typ) + lastTab = t + } + m.used = 0 + m.clearSeq++ + // TODO: shrink directory? } - var lastTab *table - for i := range m.dirLen { - t := m.directoryAt(uintptr(i)) - if t == lastTab { - continue - } - t.Clear(typ) - lastTab = t + if m.writing == 0 { + fatal("concurrent map writes") } - m.used = 0 - m.clearSeq++ - // TODO: shrink directory? + m.writing ^= 1 } func (m *Map) clearSmall(typ *abi.SwissMapType) { diff --git a/src/internal/runtime/maps/runtime_swiss.go b/src/internal/runtime/maps/runtime_swiss.go index b8bc8de0c3..88042500bc 100644 --- a/src/internal/runtime/maps/runtime_swiss.go +++ b/src/internal/runtime/maps/runtime_swiss.go @@ -41,7 +41,6 @@ var zeroVal [abi.ZeroValSize]byte // //go:linkname runtime_mapaccess1 runtime.mapaccess1 func runtime_mapaccess1(typ *abi.SwissMapType, m *Map, key unsafe.Pointer) unsafe.Pointer { - // TODO: concurrent checks. if race.Enabled && m != nil { callerpc := sys.GetCallerPC() pc := abi.FuncPCABIInternal(runtime_mapaccess1) @@ -62,6 +61,10 @@ func runtime_mapaccess1(typ *abi.SwissMapType, m *Map, key unsafe.Pointer) unsaf return unsafe.Pointer(&zeroVal[0]) } + if m.writing != 0 { + fatal("concurrent map read and map write") + } + hash := typ.Hasher(key, m.seed) if m.dirLen <= 0 { @@ -104,7 +107,6 @@ func runtime_mapaccess1(typ *abi.SwissMapType, m *Map, key unsafe.Pointer) unsaf //go:linkname runtime_mapassign runtime.mapassign func runtime_mapassign(typ *abi.SwissMapType, m *Map, key unsafe.Pointer) unsafe.Pointer { - // TODO: concurrent checks. if m == nil { panic(errNilAssign) } @@ -120,22 +122,37 @@ func runtime_mapassign(typ *abi.SwissMapType, m *Map, key unsafe.Pointer) unsafe if asan.Enabled { asan.Read(key, typ.Key.Size_) } + if m.writing != 0 { + fatal("concurrent map writes") + } hash := typ.Hasher(key, m.seed) + // Set writing after calling Hasher, since Hasher may panic, in which + // case we have not actually done a write. + m.writing ^= 1 // toggle, see comment on writing + if m.dirPtr == nil { m.growToSmall(typ) } if m.dirLen == 0 { if m.used < abi.SwissMapGroupSlots { - return m.putSlotSmall(typ, hash, key) + elem := m.putSlotSmall(typ, hash, key) + + if m.writing == 0 { + fatal("concurrent map writes") + } + m.writing ^= 1 + + return elem } // Can't fit another entry, grow to full size map. m.growToTable(typ) } + var slotElem unsafe.Pointer outer: for { // Select table. @@ -164,10 +181,10 @@ outer: typedmemmove(typ.Key, slotKey, key) } - slotElem := g.elem(typ, i) + slotElem = g.elem(typ, i) t.checkInvariants(typ) - return slotElem + break outer } match = match.removeFirst() } @@ -196,7 +213,7 @@ outer: if t.growthLeft > 0 { slotKey := g.key(typ, i) typedmemmove(typ.Key, slotKey, key) - slotElem := g.elem(typ, i) + slotElem = g.elem(typ, i) g.ctrls().set(i, ctrl(h2(hash))) t.growthLeft-- @@ -204,7 +221,7 @@ outer: m.used++ t.checkInvariants(typ) - return slotElem + break outer } t.rehash(typ, m) @@ -227,4 +244,11 @@ outer: } } } + + if m.writing == 0 { + fatal("concurrent map writes") + } + m.writing ^= 1 + + return slotElem } diff --git a/src/internal/runtime/maps/table.go b/src/internal/runtime/maps/table.go index 797d510269..ac5271ea06 100644 --- a/src/internal/runtime/maps/table.go +++ b/src/internal/runtime/maps/table.go @@ -553,6 +553,10 @@ func (it *Iter) Next() { return } + if it.m.writing != 0 { + fatal("concurrent map iteration and map write") + } + if it.dirIdx < 0 { // Map was small at Init. g := it.groupSmall diff --git a/src/runtime/map_swiss.go b/src/runtime/map_swiss.go index 42b964da24..2f48d29ac6 100644 --- a/src/runtime/map_swiss.go +++ b/src/runtime/map_swiss.go @@ -70,7 +70,6 @@ func makemap(t *abi.SwissMapType, hint int, m *maps.Map) *maps.Map { func mapaccess1(t *abi.SwissMapType, m *maps.Map, key unsafe.Pointer) unsafe.Pointer func mapaccess2(t *abi.SwissMapType, m *maps.Map, key unsafe.Pointer) (unsafe.Pointer, bool) { - // TODO: concurrent checks. if raceenabled && m != nil { callerpc := sys.GetCallerPC() pc := abi.FuncPCABIInternal(mapaccess2) @@ -121,7 +120,6 @@ func mapaccess2_fat(t *abi.SwissMapType, m *maps.Map, key, zero unsafe.Pointer) func mapassign(t *abi.SwissMapType, m *maps.Map, key unsafe.Pointer) unsafe.Pointer func mapdelete(t *abi.SwissMapType, m *maps.Map, key unsafe.Pointer) { - // TODO: concurrent checks. if raceenabled && m != nil { callerpc := sys.GetCallerPC() pc := abi.FuncPCABIInternal(mapdelete) @@ -153,7 +151,6 @@ func mapiterinit(t *abi.SwissMapType, m *maps.Map, it *maps.Iter) { } func mapiternext(it *maps.Iter) { - // TODO: concurrent checks. if raceenabled { callerpc := sys.GetCallerPC() racereadpc(unsafe.Pointer(it.Map()), callerpc, abi.FuncPCABIInternal(mapiternext)) @@ -164,7 +161,6 @@ func mapiternext(it *maps.Iter) { // mapclear deletes all keys from a map. func mapclear(t *abi.SwissMapType, m *maps.Map) { - // TODO: concurrent checks. if raceenabled && m != nil { callerpc := sys.GetCallerPC() pc := abi.FuncPCABIInternal(mapclear) -- 2.48.1