]> Cypherpunks repositories - gostls13.git/commitdiff
internal/runtime/maps: initial swiss table map implementation
authorMichael Pratt <mpratt@google.com>
Mon, 22 Apr 2024 19:48:57 +0000 (15:48 -0400)
committerGopher Robot <gobot@golang.org>
Tue, 8 Oct 2024 16:43:52 +0000 (16:43 +0000)
Add a new package that will contain a new "Swiss Table"
(https://abseil.io/about/design/swisstables) map implementation, which
is intended to eventually replace the existing runtime map
implementation.

This implementation is based on the fabulous
github.com/cockroachdb/swiss package contributed by Peter Mattis.

This CL adds an hash map implementation. It supports all the core
operations, but does not have incremental growth.

For #54766.

Change-Id: I52cf371448c3817d471ddb1f5a78f3513565db41
Reviewed-on: https://go-review.googlesource.com/c/go/+/582415
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
Auto-Submit: Michael Pratt <mpratt@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
13 files changed:
src/go/build/deps_test.go
src/internal/runtime/maps/export_test.go [new file with mode: 0644]
src/internal/runtime/maps/fuzz_test.go [new file with mode: 0644]
src/internal/runtime/maps/group.go [new file with mode: 0644]
src/internal/runtime/maps/internal/abi/map_swiss.go [new file with mode: 0644]
src/internal/runtime/maps/map.go [new file with mode: 0644]
src/internal/runtime/maps/map_test.go [new file with mode: 0644]
src/internal/runtime/maps/runtime.go [new file with mode: 0644]
src/internal/runtime/maps/table.go [new file with mode: 0644]
src/internal/runtime/maps/table_debug.go [new file with mode: 0644]
src/runtime/malloc.go
src/runtime/mbarrier.go
src/runtime/rand.go

index 3adc26ae2b6e299d01765e05668f1a9b514556ee..894cf1bd2c0e04b84afa833263479ed2112bc398 100644 (file)
@@ -87,6 +87,8 @@ var depsRules = `
        < internal/runtime/syscall
        < internal/runtime/atomic
        < internal/runtime/exithook
+       < internal/runtime/maps/internal/abi
+       < internal/runtime/maps
        < internal/runtime/math
        < runtime
        < sync/atomic
diff --git a/src/internal/runtime/maps/export_test.go b/src/internal/runtime/maps/export_test.go
new file mode 100644 (file)
index 0000000..e2512d3
--- /dev/null
@@ -0,0 +1,56 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package maps
+
+import (
+       "internal/abi"
+       sabi "internal/runtime/maps/internal/abi"
+       "unsafe"
+)
+
+type CtrlGroup = ctrlGroup
+
+const DebugLog = debugLog
+
+var AlignUpPow2 = alignUpPow2
+
+type instantiatedGroup[K comparable, V any] struct {
+       ctrls ctrlGroup
+       slots [sabi.SwissMapGroupSlots]instantiatedSlot[K, V]
+}
+
+type instantiatedSlot[K comparable, V any] struct {
+       key  K
+       elem V
+}
+
+func NewTestTable[K comparable, V any](length uint64) *table {
+       var m map[K]V
+       mTyp := abi.TypeOf(m)
+       omt := (*abi.OldMapType)(unsafe.Pointer(mTyp))
+
+       var grp instantiatedGroup[K, V]
+       var slot instantiatedSlot[K, V]
+
+       mt := &sabi.SwissMapType{
+               Key:      omt.Key,
+               Elem:     omt.Elem,
+               Group:    abi.TypeOf(grp),
+               Hasher:   omt.Hasher,
+               SlotSize: unsafe.Sizeof(slot),
+               ElemOff:  unsafe.Offsetof(slot.elem),
+       }
+       if omt.NeedKeyUpdate() {
+               mt.Flags |= sabi.SwissMapNeedKeyUpdate
+       }
+       if omt.HashMightPanic() {
+               mt.Flags |= sabi.SwissMapHashMightPanic
+       }
+       return newTable(mt, length)
+}
+
+func (t *table) Type() *sabi.SwissMapType {
+       return t.typ
+}
diff --git a/src/internal/runtime/maps/fuzz_test.go b/src/internal/runtime/maps/fuzz_test.go
new file mode 100644 (file)
index 0000000..40a5010
--- /dev/null
@@ -0,0 +1,212 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package maps implements Go's builtin map type.
+package maps_test
+
+import (
+       "bytes"
+       "encoding/binary"
+       "fmt"
+       "internal/runtime/maps"
+       "reflect"
+       "testing"
+       "unsafe"
+)
+
+// The input to FuzzTable is a binary-encoded array of fuzzCommand structs.
+//
+// Each fuzz call begins with an empty table[uint16, uint32].
+//
+// Each command is then executed on the table in sequence. Operations with
+// output (e.g., Get) are verified against a reference map.
+type fuzzCommand struct {
+       Op fuzzOp
+
+       // Used for Get, Put, Delete.
+       Key uint16
+
+       // Used for Put.
+       Elem uint32
+}
+
+// Encoded size of fuzzCommand.
+var fuzzCommandSize = binary.Size(fuzzCommand{})
+
+type fuzzOp uint8
+
+const (
+       fuzzOpGet fuzzOp = iota
+       fuzzOpPut
+       fuzzOpDelete
+)
+
+func encode(fc []fuzzCommand) []byte {
+       var buf bytes.Buffer
+       if err := binary.Write(&buf, binary.LittleEndian, fc); err != nil {
+               panic(fmt.Sprintf("error writing %v: %v", fc, err))
+       }
+       return buf.Bytes()
+}
+
+func decode(b []byte) []fuzzCommand {
+       // Round b down to a multiple of fuzzCommand size. i.e., ignore extra
+       // bytes of input.
+       entries := len(b) / fuzzCommandSize
+       usefulSize := entries * fuzzCommandSize
+       b = b[:usefulSize]
+
+       fc := make([]fuzzCommand, entries)
+       buf := bytes.NewReader(b)
+       if err := binary.Read(buf, binary.LittleEndian, &fc); err != nil {
+               panic(fmt.Sprintf("error reading %v: %v", b, err))
+       }
+
+       return fc
+}
+
+func TestEncodeDecode(t *testing.T) {
+       fc := []fuzzCommand{
+               {
+                       Op:   fuzzOpPut,
+                       Key:  123,
+                       Elem: 456,
+               },
+               {
+                       Op:  fuzzOpGet,
+                       Key: 123,
+               },
+       }
+
+       b := encode(fc)
+       got := decode(b)
+       if !reflect.DeepEqual(fc, got) {
+               t.Errorf("encode-decode roundtrip got %+v want %+v", got, fc)
+       }
+
+       // Extra trailing bytes ignored.
+       b = append(b, 42)
+       got = decode(b)
+       if !reflect.DeepEqual(fc, got) {
+               t.Errorf("encode-decode (extra byte) roundtrip got %+v want %+v", got, fc)
+       }
+}
+
+func FuzzTable(f *testing.F) {
+       // All of the ops.
+       f.Add(encode([]fuzzCommand{
+               {
+                       Op:   fuzzOpPut,
+                       Key:  123,
+                       Elem: 456,
+               },
+               {
+                       Op:  fuzzOpDelete,
+                       Key: 123,
+               },
+               {
+                       Op:  fuzzOpGet,
+                       Key: 123,
+               },
+       }))
+
+       // Add enough times to trigger grow.
+       f.Add(encode([]fuzzCommand{
+               {
+                       Op:   fuzzOpPut,
+                       Key:  1,
+                       Elem: 101,
+               },
+               {
+                       Op:   fuzzOpPut,
+                       Key:  2,
+                       Elem: 102,
+               },
+               {
+                       Op:   fuzzOpPut,
+                       Key:  3,
+                       Elem: 103,
+               },
+               {
+                       Op:   fuzzOpPut,
+                       Key:  4,
+                       Elem: 104,
+               },
+               {
+                       Op:   fuzzOpPut,
+                       Key:  5,
+                       Elem: 105,
+               },
+               {
+                       Op:   fuzzOpPut,
+                       Key:  6,
+                       Elem: 106,
+               },
+               {
+                       Op:   fuzzOpPut,
+                       Key:  7,
+                       Elem: 107,
+               },
+               {
+                       Op:   fuzzOpPut,
+                       Key:  8,
+                       Elem: 108,
+               },
+               {
+                       Op:  fuzzOpGet,
+                       Key: 1,
+               },
+               {
+                       Op:  fuzzOpDelete,
+                       Key: 2,
+               },
+               {
+                       Op:   fuzzOpPut,
+                       Key:  2,
+                       Elem: 42,
+               },
+               {
+                       Op:  fuzzOpGet,
+                       Key: 2,
+               },
+       }))
+
+       f.Fuzz(func(t *testing.T, in []byte) {
+               fc := decode(in)
+               if len(fc) == 0 {
+                       return
+               }
+
+               tab := maps.NewTestTable[uint16, uint32](8)
+               ref := make(map[uint16]uint32)
+               for _, c := range fc {
+                       switch c.Op {
+                       case fuzzOpGet:
+                               elemPtr, ok := tab.Get(unsafe.Pointer(&c.Key))
+                               refElem, refOK := ref[c.Key]
+
+                               if ok != refOK {
+                                       t.Errorf("Get(%d) got ok %v want ok %v", c.Key, ok, refOK)
+                               }
+                               if !ok {
+                                       continue
+                               }
+                               gotElem := *(*uint32)(elemPtr)
+                               if gotElem != refElem {
+                                       t.Errorf("Get(%d) got %d want %d", c.Key, gotElem, refElem)
+                               }
+                       case fuzzOpPut:
+                               tab.Put(unsafe.Pointer(&c.Key), unsafe.Pointer(&c.Elem))
+                               ref[c.Key] = c.Elem
+                       case fuzzOpDelete:
+                               tab.Delete(unsafe.Pointer(&c.Key))
+                               delete(ref, c.Key)
+                       default:
+                               // Just skip this command to keep the fuzzer
+                               // less constrained.
+                               continue
+                       }
+               }
+       })
+}
diff --git a/src/internal/runtime/maps/group.go b/src/internal/runtime/maps/group.go
new file mode 100644 (file)
index 0000000..d6e0630
--- /dev/null
@@ -0,0 +1,249 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package maps
+
+import (
+       "internal/runtime/maps/internal/abi"
+       "internal/runtime/sys"
+       "unsafe"
+)
+
+const (
+       // Maximum load factor prior to growing.
+       //
+       // 7/8 is the same load factor used by Abseil, but Abseil defaults to
+       // 16 slots per group, so they get two empty slots vs our one empty
+       // slot. We may want to reevaluate if this is best for us.
+       maxAvgGroupLoad = 7
+
+       ctrlEmpty   ctrl = 0b10000000
+       ctrlDeleted ctrl = 0b11111110
+
+       bitsetLSB     = 0x0101010101010101
+       bitsetMSB     = 0x8080808080808080
+       bitsetEmpty   = bitsetLSB * uint64(ctrlEmpty)
+       bitsetDeleted = bitsetLSB * uint64(ctrlDeleted)
+)
+
+// bitset represents a set of slots within a group.
+//
+// The underlying representation uses one byte per slot, where each byte is
+// either 0x80 if the slot is part of the set or 0x00 otherwise. This makes it
+// convenient to calculate for an entire group at once (e.g. see matchEmpty).
+type bitset uint64
+
+// first assumes that only the MSB of each control byte can be set (e.g. bitset
+// is the result of matchEmpty or similar) and returns the relative index of the
+// first control byte in the group that has the MSB set.
+//
+// Returns abi.SwissMapGroupSlots if the bitset is empty.
+func (b bitset) first() uint32 {
+       return uint32(sys.TrailingZeros64(uint64(b))) >> 3
+}
+
+// removeFirst removes the first set bit (that is, resets the least significant set bit to 0).
+func (b bitset) removeFirst() bitset {
+       return b & (b - 1)
+}
+
+// Each slot in the hash table has a control byte which can have one of three
+// states: empty, deleted, and full. They have the following bit patterns:
+//
+//       empty: 1 0 0 0 0 0 0 0
+//     deleted: 1 1 1 1 1 1 1 0
+//        full: 0 h h h h h h h  // h represents the H1 hash bits
+//
+// TODO(prattmic): Consider inverting the top bit so that the zero value is empty.
+type ctrl uint8
+
+// ctrlGroup is a fixed size array of abi.SwissMapGroupSlots control bytes
+// stored in a uint64.
+type ctrlGroup uint64
+
+// get returns the i-th control byte.
+func (g *ctrlGroup) get(i uint32) ctrl {
+       return *(*ctrl)(unsafe.Add(unsafe.Pointer(g), i))
+}
+
+// set sets the i-th control byte.
+func (g *ctrlGroup) set(i uint32, c ctrl) {
+       *(*ctrl)(unsafe.Add(unsafe.Pointer(g), i)) = c
+}
+
+// setEmpty sets all the control bytes to empty.
+func (g *ctrlGroup) setEmpty() {
+       *g = ctrlGroup(bitsetEmpty)
+}
+
+// matchH2 returns the set of slots which are full and for which the 7-bit hash
+// matches the given value. May return false positives.
+func (g ctrlGroup) matchH2(h uintptr) bitset {
+       // NB: This generic matching routine produces false positive matches when
+       // h is 2^N and the control bytes have a seq of 2^N followed by 2^N+1. For
+       // example: if ctrls==0x0302 and h=02, we'll compute v as 0x0100. When we
+       // subtract off 0x0101 the first 2 bytes we'll become 0xffff and both be
+       // considered matches of h. The false positive matches are not a problem,
+       // just a rare inefficiency. Note that they only occur if there is a real
+       // match and never occur on ctrlEmpty, or ctrlDeleted. The subsequent key
+       // comparisons ensure that there is no correctness issue.
+       v := uint64(g) ^ (bitsetLSB * uint64(h))
+       return bitset(((v - bitsetLSB) &^ v) & bitsetMSB)
+}
+
+// matchEmpty returns the set of slots in the group that are empty.
+func (g ctrlGroup) matchEmpty() bitset {
+       // An empty slot is   1000 0000
+       // A deleted slot is  1111 1110
+       // A full slot is     0??? ????
+       //
+       // A slot is empty iff bit 7 is set and bit 1 is not. We could select any
+       // of the other bits here (e.g. v << 1 would also work).
+       v := uint64(g)
+       return bitset((v &^ (v << 6)) & bitsetMSB)
+}
+
+// matchEmptyOrDeleted returns the set of slots in the group that are empty or
+// deleted.
+func (g ctrlGroup) matchEmptyOrDeleted() bitset {
+       // An empty slot is  1000 0000
+       // A deleted slot is 1111 1110
+       // A full slot is    0??? ????
+       //
+       // A slot is empty or deleted iff bit 7 is set and bit 0 is not.
+       v := uint64(g)
+       return bitset((v &^ (v << 7)) & bitsetMSB)
+}
+
+// convertNonFullToEmptyAndFullToDeleted converts deleted control bytes in a
+// group to empty control bytes, and control bytes indicating full slots to
+// deleted control bytes.
+func (g *ctrlGroup) convertNonFullToEmptyAndFullToDeleted() {
+       // An empty slot is     1000 0000
+       // A deleted slot is    1111 1110
+       // A full slot is       0??? ????
+       //
+       // We select the MSB, invert, add 1 if the MSB was set and zero out the low
+       // bit.
+       //
+       //  - if the MSB was set (i.e. slot was empty, or deleted):
+       //     v:             1000 0000
+       //     ^v:            0111 1111
+       //     ^v + (v >> 7): 1000 0000
+       //     &^ bitsetLSB:  1000 0000 = empty slot.
+       //
+       // - if the MSB was not set (i.e. full slot):
+       //     v:             0000 0000
+       //     ^v:            1111 1111
+       //     ^v + (v >> 7): 1111 1111
+       //     &^ bitsetLSB:  1111 1110 = deleted slot.
+       //
+       v := uint64(*g) & bitsetMSB
+       *g = ctrlGroup((^v + (v >> 7)) &^ bitsetLSB)
+}
+
+// groupReference is a wrapper type representing a single slot group stored at
+// data.
+//
+// A group holds abi.SwissMapGroupSlots slots (key/elem pairs) plus their
+// control word.
+type groupReference struct {
+       typ *abi.SwissMapType
+
+       // data points to the group, which is described by typ.Group and has
+       // layout:
+       //
+       // type group struct {
+       //      ctrls ctrlGroup
+       //      slots [abi.SwissMapGroupSlots]slot
+       // }
+       //
+       // type slot struct {
+       //      key  typ.Key
+       //      elem typ.Elem
+       // }
+       data unsafe.Pointer // data *typ.Group
+}
+
+const (
+       ctrlGroupsSize   = unsafe.Sizeof(ctrlGroup(0))
+       groupSlotsOffset = ctrlGroupsSize
+)
+
+// alignUp rounds n up to a multiple of a. a must be a power of 2.
+func alignUp(n, a uintptr) uintptr {
+       return (n + a - 1) &^ (a - 1)
+}
+
+// alignUpPow2 rounds n up to the next power of 2.
+//
+// Returns true if round up causes overflow.
+func alignUpPow2(n uint64) (uint64, bool) {
+       if n == 0 {
+               return 0, false
+       }
+       v := (uint64(1) << sys.Len64(n-1))
+       if v == 0 {
+               return 0, true
+       }
+       return v, false
+}
+
+// ctrls returns the group control word.
+func (g *groupReference) ctrls() *ctrlGroup {
+       return (*ctrlGroup)(g.data)
+}
+
+// key returns a pointer to the key at index i.
+func (g *groupReference) key(i uint32) unsafe.Pointer {
+       offset := groupSlotsOffset + uintptr(i)*g.typ.SlotSize
+
+       return unsafe.Pointer(uintptr(g.data) + offset)
+}
+
+// elem returns a pointer to the element at index i.
+func (g *groupReference) elem(i uint32) unsafe.Pointer {
+       offset := groupSlotsOffset + uintptr(i)*g.typ.SlotSize + g.typ.ElemOff
+
+       return unsafe.Pointer(uintptr(g.data) + offset)
+}
+
+// groupsReference is a wrapper type describing an array of groups stored at
+// data.
+type groupsReference struct {
+       typ *abi.SwissMapType
+
+       // data points to an array of groups. See groupReference above for the
+       // definition of group.
+       data unsafe.Pointer // data *[length]typ.Group
+
+       // lengthMask is the number of groups in data minus one (note that
+       // length must be a power of two). This allows computing i%length
+       // quickly using bitwise AND.
+       lengthMask uint64
+}
+
+// newGroups allocates a new array of length groups.
+//
+// Length must be a power of two.
+func newGroups(typ *abi.SwissMapType, length uint64) groupsReference {
+       return groupsReference{
+               typ: typ,
+               // TODO: make the length type the same throughout.
+               data:       newarray(typ.Group, int(length)),
+               lengthMask: length - 1,
+       }
+}
+
+// group returns the group at index i.
+func (g *groupsReference) group(i uint64) groupReference {
+       // TODO(prattmic): Do something here about truncation on cast to
+       // uintptr on 32-bit systems?
+       offset := uintptr(i) * g.typ.Group.Size_
+
+       return groupReference{
+               typ:  g.typ,
+               data: unsafe.Pointer(uintptr(g.data) + offset),
+       }
+}
diff --git a/src/internal/runtime/maps/internal/abi/map_swiss.go b/src/internal/runtime/maps/internal/abi/map_swiss.go
new file mode 100644 (file)
index 0000000..caa0827
--- /dev/null
@@ -0,0 +1,44 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package abi is a temporary copy of the swissmap abi. It will be eliminated
+// once swissmaps are integrated into the runtime.
+package abi
+
+import (
+       "internal/abi"
+       "unsafe"
+)
+
+// Map constants common to several packages
+// runtime/runtime-gdb.py:MapTypePrinter contains its own copy
+const (
+       // Number of slots in a group.
+       SwissMapGroupSlots = 8
+)
+
+type SwissMapType struct {
+       abi.Type
+       Key   *abi.Type
+       Elem  *abi.Type
+       Group *abi.Type // internal type representing a slot group
+       // function for hashing keys (ptr to key, seed) -> hash
+       Hasher     func(unsafe.Pointer, uintptr) uintptr
+       SlotSize   uintptr // size of key/elem slot
+       ElemOff    uintptr // offset of elem in key/elem slot
+       Flags      uint32
+}
+
+// Flag values
+const (
+       SwissMapNeedKeyUpdate = 1 << iota
+       SwissMapHashMightPanic
+)
+
+func (mt *SwissMapType) NeedKeyUpdate() bool { // true if we need to update key on an overwrite
+       return mt.Flags&SwissMapNeedKeyUpdate != 0
+}
+func (mt *SwissMapType) HashMightPanic() bool { // true if hash function might panic
+       return mt.Flags&SwissMapHashMightPanic != 0
+}
diff --git a/src/internal/runtime/maps/map.go b/src/internal/runtime/maps/map.go
new file mode 100644 (file)
index 0000000..0309e12
--- /dev/null
@@ -0,0 +1,128 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package maps implements Go's builtin map type.
+package maps
+
+// This package contains the implementation of Go's builtin map type.
+//
+// The map design is based on Abseil's "Swiss Table" map design
+// (https://abseil.io/about/design/swisstables), with additional modifications
+// to cover Go's additional requirements, discussed below.
+//
+// Terminology:
+// - Slot: A storage location of a single key/element pair.
+// - Group: A group of abi.SwissMapGroupSlots (8) slots, plus a control word.
+// - Control word: An 8-byte word which denotes whether each slot is empty,
+//   deleted, or used. If a slot is used, its control byte also contains the
+//   lower 7 bits of the hash (H2).
+// - H1: Upper 57 bits of a hash.
+// - H2: Lower 7 bits of a hash.
+// - Table: A complete "Swiss Table" hash table. A table consists of one or
+//   more groups for storage plus metadata to handle operation and determining
+//   when to grow.
+//
+// At its core, the table design is similar to a traditional open-addressed
+// hash table. Storage consists of an array of groups, which effectively means
+// an array of key/elem slots with some control words interspersed. Lookup uses
+// the hash to determine an initial group to check. If, due to collisions, this
+// group contains no match, the probe sequence selects the next group to check
+// (see below for more detail about the probe sequence).
+//
+// The key difference occurs within a group. In a standard open-addressed
+// linear probed hash table, we would check each slot one at a time to find a
+// match. A swiss table utilizes the extra control word to check all 8 slots in
+// parallel.
+//
+// Each byte in the control word corresponds to one of the slots in the group.
+// In each byte, 1 bit is used to indicate whether the slot is in use, or if it
+// is empty/deleted. The other 7 bits contain the lower 7 bits of the hash for
+// the key in that slot. See [ctrl] for the exact encoding.
+//
+// During lookup, we can use some clever bitwise manipulation to compare all 8
+// 7-bit hashes against the input hash in parallel (see [ctrlGroup.matchH2]).
+// That is, we effectively perform 8 steps of probing in a single operation.
+// With SIMD instructions, this could be extended to 16 slots with a 16-byte
+// control word.
+//
+// Since we only use 7 bits of the 64 bit hash, there is a 1 in 128 (~0.7%)
+// probability of false positive on each slot, but that's fine: we always need
+// double check each match with a standard key comparison regardless.
+//
+// Probing
+//
+// Probing is done using the upper 57 bits (H1) of the hash as an index into
+// the groups array. Probing walks through the groups using quadratic probing
+// until it finds a group with a match or a group with an empty slot. See
+// [probeSeq] for specifics about the probe sequence. Note the probe
+// invariants: the number of groups must be a power of two, and the end of a
+// probe sequence must be a group with an empty slot (the table can never be
+// 100% full).
+//
+// Deletion
+//
+// Probing stops when it finds a group with an empty slot. This affects
+// deletion: when deleting from a completely full group, we must not mark the
+// slot as empty, as there could be more slots used later in a probe sequence
+// and this deletion would cause probing to stop too early. Instead, we mark
+// such slots as "deleted" with a tombstone. If the group still has an empty
+// slot, we don't need a tombstone and directly mark the slot empty. Currently,
+// tombstone are only cleared during grow, as an in-place cleanup complicates
+// iteration.
+//
+// Growth
+//
+// When the table reaches the maximum load factor, it grows by allocating a new
+// groups array twice as big as before and reinserting all keys (the probe
+// sequence will differ with a larger array).
+// NOTE: Spoiler alert: A later CL supporting incremental growth will make each
+// table instance have an immutable group count. Growth will allocate a
+// completely new (bigger) table instance.
+//
+// Iteration
+//
+// Iteration is the most complex part of the map due to Go's generous iteration
+// semantics. A summary of semantics from the spec:
+// 1. Adding and/or deleting entries during iteration MUST NOT cause iteration
+//    to return the same entry more than once.
+// 2. Entries added during iteration MAY be returned by iteration.
+// 3. Entries modified during iteration MUST return their latest value.
+// 4. Entries deleted during iteration MUST NOT be returned by iteration.
+// 5. Iteration order is unspecified. In the implementation, it is explicitly
+//    randomized.
+//
+// If the map never grows, these semantics are straightforward: just iterate
+// over every group and every slot and these semantics all land as expected.
+//
+// If the map grows during iteration, things complicate significantly. First
+// and foremost, we need to track which entries we already returned to satisfy
+// (1), but the larger table has a completely different probe sequence and thus
+// different entry layout.
+//
+// We handle that by having the iterator keep a reference to the original table
+// groups array even after the table grows. We keep iterating over the original
+// groups to maintain the iteration order and avoid violating (1). Any new
+// entries added only to the new groups will be skipped (allowed by (2)). To
+// avoid violating (3) or (4), while we use the original groups to select the
+// keys, we must look them up again in the new groups to determine if they have
+// been modified or deleted. There is yet another layer of complexity if the
+// key does not compare equal itself. See [Iter.Next] for the gory details.
+//
+// NOTE: Spoiler alert: A later CL supporting incremental growth will make this
+// even more complicated. Yay!
+
+// Extracts the H1 portion of a hash: the 57 upper bits.
+// TODO(prattmic): what about 32-bit systems?
+func h1(h uintptr) uintptr {
+       return h >> 7
+}
+
+// Extracts the H2 portion of a hash: the 7 bits not used for h1.
+//
+// These are used as an occupied control byte.
+func h2(h uintptr) uintptr {
+       return h & 0x7f
+}
+
+type Map = table
diff --git a/src/internal/runtime/maps/map_test.go b/src/internal/runtime/maps/map_test.go
new file mode 100644 (file)
index 0000000..53b4a62
--- /dev/null
@@ -0,0 +1,447 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package maps_test
+
+import (
+       "fmt"
+       "internal/runtime/maps"
+       "internal/runtime/maps/internal/abi"
+       "math"
+       "testing"
+       "unsafe"
+)
+
+func TestCtrlSize(t *testing.T) {
+       cs := unsafe.Sizeof(maps.CtrlGroup(0))
+       if cs != abi.SwissMapGroupSlots {
+               t.Errorf("ctrlGroup size got %d want abi.SwissMapGroupSlots %d", cs, abi.SwissMapGroupSlots)
+       }
+}
+
+func TestTablePut(t *testing.T) {
+       tab := maps.NewTestTable[uint32, uint64](8)
+
+       key := uint32(0)
+       elem := uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+               tab.Put(unsafe.Pointer(&key), unsafe.Pointer(&elem))
+
+               if maps.DebugLog {
+                       fmt.Printf("After put %d: %v\n", key, tab)
+               }
+       }
+
+       key = uint32(0)
+       elem = uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+               got, ok := tab.Get(unsafe.Pointer(&key))
+               if !ok {
+                       t.Errorf("Get(%d) got ok false want true", key)
+               }
+               gotElem := *(*uint64)(got)
+               if gotElem != elem {
+                       t.Errorf("Get(%d) got elem %d want %d", key, gotElem, elem)
+               }
+       }
+}
+
+func TestTableDelete(t *testing.T) {
+       tab := maps.NewTestTable[uint32, uint64](32)
+
+       key := uint32(0)
+       elem := uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+               tab.Put(unsafe.Pointer(&key), unsafe.Pointer(&elem))
+
+               if maps.DebugLog {
+                       fmt.Printf("After put %d: %v\n", key, tab)
+               }
+       }
+
+       key = uint32(0)
+       elem = uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               tab.Delete(unsafe.Pointer(&key))
+       }
+
+       key = uint32(0)
+       elem = uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+               _, ok := tab.Get(unsafe.Pointer(&key))
+               if ok {
+                       t.Errorf("Get(%d) got ok true want false", key)
+               }
+       }
+}
+
+func TestTableClear(t *testing.T) {
+       tab := maps.NewTestTable[uint32, uint64](32)
+
+       key := uint32(0)
+       elem := uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+               tab.Put(unsafe.Pointer(&key), unsafe.Pointer(&elem))
+
+               if maps.DebugLog {
+                       fmt.Printf("After put %d: %v\n", key, tab)
+               }
+       }
+
+       tab.Clear()
+
+       if tab.Used() != 0 {
+               t.Errorf("Clear() used got %d want 0", tab.Used())
+       }
+
+       key = uint32(0)
+       elem = uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+               _, ok := tab.Get(unsafe.Pointer(&key))
+               if ok {
+                       t.Errorf("Get(%d) got ok true want false", key)
+               }
+       }
+}
+
+// +0.0 and -0.0 compare equal, but we must still must update the key slot when
+// overwriting.
+func TestTableKeyUpdate(t *testing.T) {
+       tab := maps.NewTestTable[float64, uint64](8)
+
+       zero := float64(0.0)
+       negZero := math.Copysign(zero, -1.0)
+       elem := uint64(0)
+
+       tab.Put(unsafe.Pointer(&zero), unsafe.Pointer(&elem))
+       if maps.DebugLog {
+               fmt.Printf("After put %f: %v\n", zero, tab)
+       }
+
+       elem = 1
+       tab.Put(unsafe.Pointer(&negZero), unsafe.Pointer(&elem))
+       if maps.DebugLog {
+               fmt.Printf("After put %f: %v\n", negZero, tab)
+       }
+
+       if tab.Used() != 1 {
+               t.Errorf("Used() used got %d want 1", tab.Used())
+       }
+
+       it := new(maps.Iter)
+       it.Init(tab.Type(), tab)
+       it.Next()
+       keyPtr, elemPtr := it.Key(), it.Elem()
+       if keyPtr == nil {
+               t.Fatal("it.Key() got nil want key")
+       }
+
+       key := *(*float64)(keyPtr)
+       elem = *(*uint64)(elemPtr)
+       if math.Copysign(1.0, key) > 0 {
+               t.Errorf("map key %f has positive sign", key)
+       }
+       if elem != 1 {
+               t.Errorf("map elem got %d want 1", elem)
+       }
+}
+
+func TestTableIteration(t *testing.T) {
+       tab := maps.NewTestTable[uint32, uint64](8)
+
+       key := uint32(0)
+       elem := uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+               tab.Put(unsafe.Pointer(&key), unsafe.Pointer(&elem))
+
+               if maps.DebugLog {
+                       fmt.Printf("After put %d: %v\n", key, tab)
+               }
+       }
+
+       got := make(map[uint32]uint64)
+
+       it := new(maps.Iter)
+       it.Init(tab.Type(), tab)
+       for {
+               it.Next()
+               keyPtr, elemPtr := it.Key(), it.Elem()
+               if keyPtr == nil {
+                       break
+               }
+
+               key := *(*uint32)(keyPtr)
+               elem := *(*uint64)(elemPtr)
+               got[key] = elem
+       }
+
+       if len(got) != 31 {
+               t.Errorf("Iteration got %d entries, want 31: %+v", len(got), got)
+       }
+
+       key = uint32(0)
+       elem = uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+               gotElem, ok := got[key]
+               if !ok {
+                       t.Errorf("Iteration missing key %d", key)
+                       continue
+               }
+               if gotElem != elem {
+                       t.Errorf("Iteration key %d got elem %d want %d", key, gotElem, elem)
+               }
+       }
+}
+
+// Deleted keys shouldn't be visible in iteration.
+func TestTableIterationDelete(t *testing.T) {
+       tab := maps.NewTestTable[uint32, uint64](8)
+
+       key := uint32(0)
+       elem := uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+               tab.Put(unsafe.Pointer(&key), unsafe.Pointer(&elem))
+
+               if maps.DebugLog {
+                       fmt.Printf("After put %d: %v\n", key, tab)
+               }
+       }
+
+       got := make(map[uint32]uint64)
+       first := true
+       deletedKey := uint32(1)
+       it := new(maps.Iter)
+       it.Init(tab.Type(), tab)
+       for {
+               it.Next()
+               keyPtr, elemPtr := it.Key(), it.Elem()
+               if keyPtr == nil {
+                       break
+               }
+
+               key := *(*uint32)(keyPtr)
+               elem := *(*uint64)(elemPtr)
+               got[key] = elem
+
+               if first {
+                       first = false
+
+                       // If the key we intended to delete was the one we just
+                       // saw, pick another to delete.
+                       if key == deletedKey {
+                               deletedKey++
+                       }
+                       tab.Delete(unsafe.Pointer(&deletedKey))
+               }
+       }
+
+       if len(got) != 30 {
+               t.Errorf("Iteration got %d entries, want 30: %+v", len(got), got)
+       }
+
+       key = uint32(0)
+       elem = uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+
+               wantOK := true
+               if key == deletedKey {
+                       wantOK = false
+               }
+
+               gotElem, gotOK := got[key]
+               if gotOK != wantOK {
+                       t.Errorf("Iteration key %d got ok %v want ok %v", key, gotOK, wantOK)
+                       continue
+               }
+               if wantOK && gotElem != elem {
+                       t.Errorf("Iteration key %d got elem %d want %d", key, gotElem, elem)
+               }
+       }
+}
+
+// Deleted keys shouldn't be visible in iteration even after a grow.
+func TestTableIterationGrowDelete(t *testing.T) {
+       tab := maps.NewTestTable[uint32, uint64](8)
+
+       key := uint32(0)
+       elem := uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+               tab.Put(unsafe.Pointer(&key), unsafe.Pointer(&elem))
+
+               if maps.DebugLog {
+                       fmt.Printf("After put %d: %v\n", key, tab)
+               }
+       }
+
+       got := make(map[uint32]uint64)
+       first := true
+       deletedKey := uint32(1)
+       it := new(maps.Iter)
+       it.Init(tab.Type(), tab)
+       for {
+               it.Next()
+               keyPtr, elemPtr := it.Key(), it.Elem()
+               if keyPtr == nil {
+                       break
+               }
+
+               key := *(*uint32)(keyPtr)
+               elem := *(*uint64)(elemPtr)
+               got[key] = elem
+
+               if first {
+                       first = false
+
+                       // If the key we intended to delete was the one we just
+                       // saw, pick another to delete.
+                       if key == deletedKey {
+                               deletedKey++
+                       }
+
+                       // Double the number of elements to force a grow.
+                       key := uint32(32)
+                       elem := uint64(256 + 32)
+
+                       for i := 0; i < 31; i++ {
+                               key += 1
+                               elem += 1
+                               tab.Put(unsafe.Pointer(&key), unsafe.Pointer(&elem))
+
+                               if maps.DebugLog {
+                                       fmt.Printf("After put %d: %v\n", key, tab)
+                               }
+                       }
+
+                       // Then delete from the grown map.
+                       tab.Delete(unsafe.Pointer(&deletedKey))
+               }
+       }
+
+       // Don't check length: the number of new elements we'll see is
+       // unspecified.
+
+       // Check values only of the original pre-iteration entries.
+       key = uint32(0)
+       elem = uint64(256 + 0)
+
+       for i := 0; i < 31; i++ {
+               key += 1
+               elem += 1
+
+               wantOK := true
+               if key == deletedKey {
+                       wantOK = false
+               }
+
+               gotElem, gotOK := got[key]
+               if gotOK != wantOK {
+                       t.Errorf("Iteration key %d got ok %v want ok %v", key, gotOK, wantOK)
+                       continue
+               }
+               if wantOK && gotElem != elem {
+                       t.Errorf("Iteration key %d got elem %d want %d", key, gotElem, elem)
+               }
+       }
+}
+
+func TestAlignUpPow2(t *testing.T) {
+       tests := []struct {
+               in       uint64
+               want     uint64
+               overflow bool
+       }{
+               {
+                       in:   0,
+                       want: 0,
+               },
+               {
+                       in:   3,
+                       want: 4,
+               },
+               {
+                       in:   4,
+                       want: 4,
+               },
+               {
+                       in:   1 << 63,
+                       want: 1 << 63,
+               },
+               {
+                       in:   (1 << 63) - 1,
+                       want: 1 << 63,
+               },
+               {
+                       in:       (1 << 63) + 1,
+                       overflow: true,
+               },
+       }
+
+       for _, tc := range tests {
+               got, overflow := maps.AlignUpPow2(tc.in)
+               if got != tc.want {
+                       t.Errorf("alignUpPow2(%d) got %d, want %d", tc.in, got, tc.want)
+               }
+               if overflow != tc.overflow {
+                       t.Errorf("alignUpPow2(%d) got overflow %v, want %v", tc.in, overflow, tc.overflow)
+               }
+       }
+}
+
+// Verify that a table with zero-size slot is safe to use.
+func TestTableZeroSizeSlot(t *testing.T) {
+       tab := maps.NewTestTable[struct{}, struct{}](8)
+
+       key := struct{}{}
+       elem := struct{}{}
+
+       tab.Put(unsafe.Pointer(&key), unsafe.Pointer(&elem))
+
+       if maps.DebugLog {
+               fmt.Printf("After put %d: %v\n", key, tab)
+       }
+
+       got, ok := tab.Get(unsafe.Pointer(&key))
+       if !ok {
+               t.Errorf("Get(%d) got ok false want true", key)
+       }
+       gotElem := *(*struct{})(got)
+       if gotElem != elem {
+               t.Errorf("Get(%d) got elem %d want %d", key, gotElem, elem)
+       }
+}
diff --git a/src/internal/runtime/maps/runtime.go b/src/internal/runtime/maps/runtime.go
new file mode 100644 (file)
index 0000000..9ebfb34
--- /dev/null
@@ -0,0 +1,24 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package maps
+
+import (
+       "internal/abi"
+       "unsafe"
+)
+
+// Functions below pushed from runtime.
+
+//go:linkname rand
+func rand() uint64
+
+//go:linkname typedmemmove
+func typedmemmove(typ *abi.Type, dst, src unsafe.Pointer)
+
+//go:linkname typedmemclr
+func typedmemclr(typ *abi.Type, ptr unsafe.Pointer)
+
+//go:linkname newarray
+func newarray(typ *abi.Type, n int) unsafe.Pointer
diff --git a/src/internal/runtime/maps/table.go b/src/internal/runtime/maps/table.go
new file mode 100644 (file)
index 0000000..3516b92
--- /dev/null
@@ -0,0 +1,669 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package maps implements Go's builtin map type.
+package maps
+
+import (
+       "internal/runtime/maps/internal/abi"
+       "unsafe"
+)
+
+// table is a Swiss table hash table structure.
+//
+// Each table is a complete hash table implementation.
+type table struct {
+       // The number of filled slots (i.e. the number of elements in the table).
+       used uint64
+
+       // TODO(prattmic): Old maps pass this into every call instead of
+       // keeping a reference in the map header. This is probably more
+       // efficient and arguably more robust (crafty users can't reach into to
+       // the map to change its type), but I leave it here for now for
+       // simplicity.
+       typ *abi.SwissMapType
+
+       // seed is the hash seed, computed as a unique random number per table.
+       // TODO(prattmic): Populate this on table initialization.
+       seed uintptr
+
+       // groups is an array of slot groups. Each group holds abi.SwissMapGroupSlots
+       // key/elem slots and their control bytes.
+       //
+       // TODO(prattmic): keys and elements are interleaved to maximize
+       // locality, but it comes at the expense of wasted space for some types
+       // (consider uint8 key, uint64 element). Consider placing all keys
+       // together in these cases to save space.
+       //
+       // TODO(prattmic): Support indirect keys/values? This means storing
+       // keys/values as pointers rather than inline in the slot. This avoid
+       // bloating the table size if either type is very large.
+       groups groupsReference
+
+       // The total number of slots (always 2^N). Equal to
+       // `(groups.lengthMask+1)*abi.SwissMapGroupSlots`.
+       capacity uint64
+
+       // The number of slots we can still fill without needing to rehash.
+       //
+       // We rehash when used + tombstones > loadFactor*capacity, including
+       // tombstones so the table doesn't overfill with tombstones. This field
+       // counts down remaining empty slots before the next rehash.
+       growthLeft uint64
+
+       // clearSeq is a sequence counter of calls to Clear. It is used to
+       // detect map clears during iteration.
+       clearSeq uint64
+}
+
+func NewTable(mt *abi.SwissMapType, capacity uint64) *table {
+       return newTable(mt, capacity)
+}
+
+func newTable(mt *abi.SwissMapType, capacity uint64) *table {
+       if capacity < abi.SwissMapGroupSlots {
+               // TODO: temporary until we have a real map type.
+               capacity = abi.SwissMapGroupSlots
+       }
+
+       t := &table{
+               typ: mt,
+       }
+
+       // N.B. group count must be a power of two for probeSeq to visit every
+       // group.
+       capacity, overflow := alignUpPow2(capacity)
+       if overflow {
+               panic("rounded-up capacity overflows uint64")
+       }
+
+       t.reset(capacity)
+
+       return t
+}
+
+// reset resets the table with new, empty groups with the specified new total
+// capacity.
+func (t *table) reset(capacity uint64) {
+       ac, overflow := alignUpPow2(capacity)
+       if capacity != ac || overflow {
+               panic("capacity must be a power of two")
+       }
+
+       groupCount := capacity / abi.SwissMapGroupSlots
+       t.groups = newGroups(t.typ, groupCount)
+       t.capacity = capacity
+       t.resetGrowthLeft()
+
+       for i := uint64(0); i <= t.groups.lengthMask; i++ {
+               g := t.groups.group(i)
+               g.ctrls().setEmpty()
+       }
+}
+
+// Preconditions: table must be empty.
+func (t *table) resetGrowthLeft() {
+       var growthLeft uint64
+       if t.capacity == 0 {
+               // No real reason to support zero capacity table, since an
+               // empty Map simply won't have a table.
+               panic("table must have positive capacity")
+       } else if t.capacity <= abi.SwissMapGroupSlots {
+               // If the map fits in a single group then we're able to fill all of
+               // the slots except 1 (an empty slot is needed to terminate find
+               // operations).
+               //
+               // TODO(go.dev/issue/54766): With a special case in probing for
+               // single-group tables, we could fill all slots.
+               growthLeft = t.capacity - 1
+       } else {
+               if t.capacity*maxAvgGroupLoad < t.capacity {
+                       // TODO(prattmic): Do something cleaner.
+                       panic("overflow")
+               }
+               growthLeft = (t.capacity * maxAvgGroupLoad) / abi.SwissMapGroupSlots
+       }
+       t.growthLeft = growthLeft
+}
+
+func (t *table) Used() uint64 {
+       return t.used
+}
+
+// Get performs a lookup of the key that key points to. It returns a pointer to
+// the element, or false if the key doesn't exist.
+func (t *table) Get(key unsafe.Pointer) (unsafe.Pointer, bool) {
+       _, elem, ok := t.getWithKey(key)
+       return elem, ok
+}
+
+// getWithKey performs a lookup of key, returning a pointer to the version of
+// the key in the map in addition to the element.
+//
+// This is relevant when multiple different key values compare equal (e.g.,
+// +0.0 and -0.0). When a grow occurs during iteration, iteration perform a
+// lookup of keys from the old group in the new group in order to correctly
+// expose updated elements. For NeedsKeyUpdate keys, iteration also must return
+// the new key value, not the old key value.
+func (t *table) getWithKey(key unsafe.Pointer) (unsafe.Pointer, unsafe.Pointer, bool) {
+       // TODO(prattmic): We could avoid hashing in a variety of special
+       // cases.
+       //
+       // - One group maps with simple keys could iterate over all keys and
+       //   compare them directly.
+       // - One entry maps could just directly compare the single entry
+       //   without hashing.
+       // - String keys could do quick checks of a few bytes before hashing.
+       hash := t.typ.Hasher(key, t.seed)
+
+       // To find the location of a key in the table, we compute hash(key). From
+       // h1(hash(key)) and the capacity, we construct a probeSeq that visits
+       // every group of slots in some interesting order. See [probeSeq].
+       //
+       // We walk through these indices. At each index, we select the entire
+       // group starting with that index and extract potential candidates:
+       // occupied slots with a control byte equal to h2(hash(key)). The key
+       // at candidate slot i is compared with key; if key == g.slot(i).key
+       // we are done and return the slot; if there is an empty slot in the
+       // group, we stop and return an error; otherwise we continue to the
+       // next probe index. Tombstones (ctrlDeleted) effectively behave like
+       // full slots that never match the value we're looking for.
+       //
+       // The h2 bits ensure when we compare a key we are likely to have
+       // actually found the object. That is, the chance is low that keys
+       // compare false. Thus, when we search for an object, we are unlikely
+       // to call Equal many times. This likelihood can be analyzed as follows
+       // (assuming that h2 is a random enough hash function).
+       //
+       // Let's assume that there are k "wrong" objects that must be examined
+       // in a probe sequence. For example, when doing a find on an object
+       // that is in the table, k is the number of objects between the start
+       // of the probe sequence and the final found object (not including the
+       // final found object). The expected number of objects with an h2 match
+       // is then k/128. Measurements and analysis indicate that even at high
+       // load factors, k is less than 32, meaning that the number of false
+       // positive comparisons we must perform is less than 1/8 per find.
+       seq := makeProbeSeq(h1(hash), t.groups.lengthMask)
+       for ; ; seq = seq.next() {
+               g := t.groups.group(seq.offset)
+
+               match := g.ctrls().matchH2(h2(hash))
+
+               for match != 0 {
+                       i := match.first()
+
+                       slotKey := g.key(i)
+                       if t.typ.Key.Equal(key, slotKey) {
+                               return slotKey, g.elem(i), true
+                       }
+                       match = match.removeFirst()
+               }
+
+               match = g.ctrls().matchEmpty()
+               if match != 0 {
+                       // Finding an empty slot means we've reached the end of
+                       // the probe sequence.
+                       return nil, nil, false
+               }
+       }
+}
+
+func (t *table) Put(key, elem unsafe.Pointer) {
+       slotElem := t.PutSlot(key)
+       typedmemmove(t.typ.Elem, slotElem, elem)
+}
+
+// PutSlot returns a pointer to the element slot where an inserted element
+// should be written.
+//
+// PutSlot never returns nil.
+func (t *table) PutSlot(key unsafe.Pointer) unsafe.Pointer {
+       hash := t.typ.Hasher(key, t.seed)
+
+       seq := makeProbeSeq(h1(hash), t.groups.lengthMask)
+
+       for ; ; seq = seq.next() {
+               g := t.groups.group(seq.offset)
+               match := g.ctrls().matchH2(h2(hash))
+
+               // Look for an existing slot containing this key.
+               for match != 0 {
+                       i := match.first()
+
+                       slotKey := g.key(i)
+                       if t.typ.Key.Equal(key, slotKey) {
+                               if t.typ.NeedKeyUpdate() {
+                                       typedmemmove(t.typ.Key, slotKey, key)
+                               }
+
+                               slotElem := g.elem(i)
+
+                               t.checkInvariants()
+                               return slotElem
+                       }
+                       match = match.removeFirst()
+               }
+
+               match = g.ctrls().matchEmpty()
+               if match != 0 {
+                       // Finding an empty slot means we've reached the end of
+                       // the probe sequence.
+
+                       // If there is room left to grow, just insert the new entry.
+                       if t.growthLeft > 0 {
+                               i := match.first()
+
+                               slotKey := g.key(i)
+                               typedmemmove(t.typ.Key, slotKey, key)
+                               slotElem := g.elem(i)
+
+                               g.ctrls().set(i, ctrl(h2(hash)))
+                               t.growthLeft--
+                               t.used++
+
+                               t.checkInvariants()
+                               return slotElem
+                       }
+
+                       // TODO(prattmic): While searching the probe sequence,
+                       // we may have passed deleted slots which we could use
+                       // for this entry.
+                       //
+                       // At the moment, we leave this behind for
+                       // rehash to free up.
+                       //
+                       // cockroachlabs/swiss restarts search of the probe
+                       // sequence for a deleted slot.
+                       //
+                       // TODO(go.dev/issue/54766): We want this optimization
+                       // back. We could search for the first deleted slot
+                       // during the main search, but only use it if we don't
+                       // find an existing entry.
+
+                       t.rehash()
+
+                       // Note that we don't have to restart the entire Put process as we
+                       // know the key doesn't exist in the map.
+                       slotElem := t.uncheckedPutSlot(hash, key)
+                       t.used++
+                       t.checkInvariants()
+                       return slotElem
+               }
+       }
+}
+
+// uncheckedPutSlot inserts an entry known not to be in the table, returning an
+// entry to the element slot where the element should be written. Used by
+// PutSlot after it has failed to find an existing entry to overwrite duration
+// insertion.
+//
+// Updates growthLeft if necessary, but does not update used.
+//
+// Requires that the entry does not exist in the table, and that the table has
+// room for another element without rehashing.
+//
+// Never returns nil.
+func (t *table) uncheckedPutSlot(hash uintptr, key unsafe.Pointer) unsafe.Pointer {
+       if t.growthLeft == 0 {
+               panic("invariant failed: growthLeft is unexpectedly 0")
+       }
+
+       // Given key and its hash hash(key), to insert it, we construct a
+       // probeSeq, and use it to find the first group with an unoccupied (empty
+       // or deleted) slot. We place the key/value into the first such slot in
+       // the group and mark it as full with key's H2.
+       seq := makeProbeSeq(h1(hash), t.groups.lengthMask)
+       for ; ; seq = seq.next() {
+               g := t.groups.group(seq.offset)
+
+               match := g.ctrls().matchEmpty()
+               if match != 0 {
+                       i := match.first()
+
+                       slotKey := g.key(i)
+                       typedmemmove(t.typ.Key, slotKey, key)
+                       slotElem := g.elem(i)
+
+                       if g.ctrls().get(i) == ctrlEmpty {
+                               t.growthLeft--
+                       }
+                       g.ctrls().set(i, ctrl(h2(hash)))
+                       return slotElem
+               }
+       }
+}
+
+func (t *table) Delete(key unsafe.Pointer) {
+       hash := t.typ.Hasher(key, t.seed)
+
+       seq := makeProbeSeq(h1(hash), t.groups.lengthMask)
+       for ; ; seq = seq.next() {
+               g := t.groups.group(seq.offset)
+               match := g.ctrls().matchH2(h2(hash))
+
+               for match != 0 {
+                       i := match.first()
+                       slotKey := g.key(i)
+                       if t.typ.Key.Equal(key, slotKey) {
+                               t.used--
+
+                               typedmemclr(t.typ.Key, slotKey)
+                               typedmemclr(t.typ.Elem, g.elem(i))
+
+                               // Only a full group can appear in the middle
+                               // of a probe sequence (a group with at least
+                               // one empty slot terminates probing). Once a
+                               // group becomes full, it stays full until
+                               // rehashing/resizing. So if the group isn't
+                               // full now, we can simply remove the element.
+                               // Otherwise, we create a tombstone to mark the
+                               // slot as deleted.
+                               if g.ctrls().matchEmpty() != 0 {
+                                       g.ctrls().set(i, ctrlEmpty)
+                                       t.growthLeft++
+                               } else {
+                                       g.ctrls().set(i, ctrlDeleted)
+                               }
+
+                               t.checkInvariants()
+                               return
+                       }
+                       match = match.removeFirst()
+               }
+
+               match = g.ctrls().matchEmpty()
+               if match != 0 {
+                       // Finding an empty slot means we've reached the end of
+                       // the probe sequence.
+                       return
+               }
+       }
+}
+
+// tombstones returns the number of deleted (tombstone) entries in the table. A
+// tombstone is a slot that has been deleted but is still considered occupied
+// so as not to violate the probing invariant.
+func (t *table) tombstones() uint64 {
+       return (t.capacity*maxAvgGroupLoad)/abi.SwissMapGroupSlots - t.used - t.growthLeft
+}
+
+// Clear deletes all entries from the map resulting in an empty map.
+func (t *table) Clear() {
+       for i := uint64(0); i <= t.groups.lengthMask; i++ {
+               g := t.groups.group(i)
+               typedmemclr(t.typ.Group, g.data)
+               g.ctrls().setEmpty()
+       }
+
+       t.clearSeq++
+       t.used = 0
+       t.resetGrowthLeft()
+
+       // Reset the hash seed to make it more difficult for attackers to
+       // repeatedly trigger hash collisions. See issue
+       // https://github.com/golang/go/issues/25237.
+       // TODO
+       //t.seed = uintptr(rand())
+}
+
+type Iter struct {
+       key  unsafe.Pointer // Must be in first position.  Write nil to indicate iteration end (see cmd/compile/internal/walk/range.go).
+       elem unsafe.Pointer // Must be in second position (see cmd/compile/internal/walk/range.go).
+       typ  *abi.SwissMapType
+       tab  *table
+
+       // Snapshot of the groups at iteration initialization time. If the
+       // table resizes during iteration, we continue to iterate over the old
+       // groups.
+       //
+       // If the table grows we must consult the updated table to observe
+       // changes, though we continue to use the snapshot to determine order
+       // and avoid duplicating results.
+       groups groupsReference
+
+       // Copy of Table.clearSeq at iteration initialization time. Used to
+       // detect clear during iteration.
+       clearSeq uint64
+
+       // Randomize iteration order by starting iteration at a random slot
+       // offset.
+       offset uint64
+
+       // TODO: these could be merged into a single counter (and pre-offset
+       // with offset).
+       groupIdx uint64
+       slotIdx  uint32
+
+       // 4 bytes of padding on 64-bit arches.
+}
+
+// Init initializes Iter for iteration.
+func (it *Iter) Init(typ *abi.SwissMapType, t *table) {
+       it.typ = typ
+       if t == nil || t.used == 0 {
+               return
+       }
+
+       it.typ = t.typ
+       it.tab = t
+       it.offset = rand()
+       it.groups = t.groups
+       it.clearSeq = t.clearSeq
+}
+
+func (it *Iter) Initialized() bool {
+       return it.typ != nil
+}
+
+// Map returns the map this iterator is iterating over.
+func (it *Iter) Map() *Map {
+       return it.tab
+}
+
+// Key returns a pointer to the current key. nil indicates end of iteration.
+//
+// Must not be called prior to Next.
+func (it *Iter) Key() unsafe.Pointer {
+       return it.key
+}
+
+// Key returns a pointer to the current element. nil indicates end of
+// iteration.
+//
+// Must not be called prior to Next.
+func (it *Iter) Elem() unsafe.Pointer {
+       return it.elem
+}
+
+// Next proceeds to the next element in iteration, which can be accessed via
+// the Key and Elem methods.
+//
+// The table can be mutated during iteration, though there is no guarantee that
+// the mutations will be visible to the iteration.
+//
+// Init must be called prior to Next.
+func (it *Iter) Next() {
+       if it.tab == nil {
+               // Map was empty at Iter.Init.
+               it.key = nil
+               it.elem = nil
+               return
+       }
+
+       // Continue iteration until we find a full slot.
+       for ; it.groupIdx <= it.groups.lengthMask; it.groupIdx++ {
+               g := it.groups.group((it.groupIdx + it.offset) & it.groups.lengthMask)
+
+               // TODO(prattmic): Skip over groups that are composed of only empty
+               // or deleted slots using matchEmptyOrDeleted() and counting the
+               // number of bits set.
+               for ; it.slotIdx < abi.SwissMapGroupSlots; it.slotIdx++ {
+                       k := (it.slotIdx + uint32(it.offset)) % abi.SwissMapGroupSlots
+
+                       if (g.ctrls().get(k) & ctrlEmpty) == ctrlEmpty {
+                               // Empty or deleted.
+                               continue
+                       }
+
+                       key := g.key(k)
+
+                       // If groups.data has changed, then the table
+                       // has grown. If the table has grown, then
+                       // further mutations (changes to key->elem or
+                       // deletions) will not be visible in our
+                       // snapshot of groups. Instead we must consult
+                       // the new groups by doing a full lookup.
+                       //
+                       // We still use our old snapshot of groups to
+                       // decide which keys to lookup in order to
+                       // avoid returning the same key twice.
+                       //
+                       // TODO(prattmic): Rather than growing t.groups
+                       // directly, a cleaner design may be to always
+                       // create a new table on grow or split, leaving
+                       // behind 1 or 2 forwarding pointers. This lets
+                       // us handle this update after grow problem the
+                       // same way both within a single table and
+                       // across split.
+                       grown := it.groups.data != it.tab.groups.data
+                       var elem unsafe.Pointer
+                       if grown {
+                               var ok bool
+                               newKey, newElem, ok := it.tab.getWithKey(key)
+                               if !ok {
+                                       // Key has likely been deleted, and
+                                       // should be skipped.
+                                       //
+                                       // One exception is keys that don't
+                                       // compare equal to themselves (e.g.,
+                                       // NaN). These keys cannot be looked
+                                       // up, so getWithKey will fail even if
+                                       // the key exists.
+                                       //
+                                       // However, we are in luck because such
+                                       // keys cannot be updated and they
+                                       // cannot be deleted except with clear.
+                                       // Thus if no clear has occurted, the
+                                       // key/elem must still exist exactly as
+                                       // in the old groups, so we can return
+                                       // them from there.
+                                       //
+                                       // TODO(prattmic): Consider checking
+                                       // clearSeq early. If a clear occurred,
+                                       // Next could always return
+                                       // immediately, as iteration doesn't
+                                       // need to return anything added after
+                                       // clear.
+                                       if it.clearSeq == it.tab.clearSeq && !it.tab.typ.Key.Equal(key, key) {
+                                               elem = g.elem(k)
+                                       } else {
+                                               continue
+                                       }
+                               } else {
+                                       key = newKey
+                                       elem = newElem
+                               }
+                       } else {
+                               elem = g.elem(k)
+                       }
+
+                       it.slotIdx++
+                       if it.slotIdx >= abi.SwissMapGroupSlots {
+                               it.groupIdx++
+                               it.slotIdx = 0
+                       }
+                       it.key = key
+                       it.elem = elem
+                       return
+               }
+               it.slotIdx = 0
+       }
+
+       it.key = nil
+       it.elem = nil
+       return
+}
+
+func (t *table) rehash() {
+       // TODO(prattmic): SwissTables typically perform a "rehash in place"
+       // operation which recovers capacity consumed by tombstones without growing
+       // the table by reordering slots as necessary to maintain the probe
+       // invariant while eliminating all tombstones.
+       //
+       // However, it is unclear how to make rehash in place work with
+       // iteration. Since iteration simply walks through all slots in order
+       // (with random start offset), reordering the slots would break
+       // iteration.
+       //
+       // As an alternative, we could do a "resize" to new groups allocation
+       // of the same size. This would eliminate the tombstones, but using a
+       // new allocation, so the existing grow support in iteration would
+       // continue to work.
+
+       // TODO(prattmic): split table
+       // TODO(prattmic): Avoid overflow (splitting the table will achieve this)
+
+       newCapacity := 2 * t.capacity
+       t.resize(newCapacity)
+}
+
+// resize the capacity of the table by allocating a bigger array and
+// uncheckedPutting each element of the table into the new array (we know that
+// no insertion here will Put an already-present value), and discard the old
+// backing array.
+func (t *table) resize(newCapacity uint64) {
+       oldGroups := t.groups
+       oldCapacity := t.capacity
+       t.reset(newCapacity)
+
+       if oldCapacity > 0 {
+               for i := uint64(0); i <= oldGroups.lengthMask; i++ {
+                       g := oldGroups.group(i)
+                       for j := uint32(0); j < abi.SwissMapGroupSlots; j++ {
+                               if (g.ctrls().get(j) & ctrlEmpty) == ctrlEmpty {
+                                       // Empty or deleted
+                                       continue
+                               }
+                               key := g.key(j)
+                               elem := g.elem(j)
+                               hash := t.typ.Hasher(key, t.seed)
+                               slotElem := t.uncheckedPutSlot(hash, key)
+                               typedmemmove(t.typ.Elem, slotElem, elem)
+                       }
+               }
+       }
+
+       t.checkInvariants()
+}
+
+// probeSeq maintains the state for a probe sequence that iterates through the
+// groups in a table. The sequence is a triangular progression of the form
+//
+//     p(i) := (i^2 + i)/2 + hash (mod mask+1)
+//
+// The sequence effectively outputs the indexes of *groups*. The group
+// machinery allows us to check an entire group with minimal branching.
+//
+// It turns out that this probe sequence visits every group exactly once if
+// the number of groups is a power of two, since (i^2+i)/2 is a bijection in
+// Z/(2^m). See https://en.wikipedia.org/wiki/Quadratic_probing
+type probeSeq struct {
+       mask   uint64
+       offset uint64
+       index  uint64
+}
+
+func makeProbeSeq(hash uintptr, mask uint64) probeSeq {
+       return probeSeq{
+               mask:   mask,
+               offset: uint64(hash) & mask,
+               index:  0,
+       }
+}
+
+func (s probeSeq) next() probeSeq {
+       s.index++
+       s.offset = (s.offset + s.index) & s.mask
+       return s
+}
diff --git a/src/internal/runtime/maps/table_debug.go b/src/internal/runtime/maps/table_debug.go
new file mode 100644 (file)
index 0000000..7170fb6
--- /dev/null
@@ -0,0 +1,128 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package maps implements Go's builtin map type.
+package maps
+
+import (
+       sabi "internal/runtime/maps/internal/abi"
+       "unsafe"
+)
+
+const debugLog = false
+
+func (t *table) checkInvariants() {
+       if !debugLog {
+               return
+       }
+
+       // For every non-empty slot, verify we can retrieve the key using Get.
+       // Count the number of used and deleted slots.
+       var used uint64
+       var deleted uint64
+       var empty uint64
+       for i := uint64(0); i <= t.groups.lengthMask; i++ {
+               g := t.groups.group(i)
+               for j := uint32(0); j < sabi.SwissMapGroupSlots; j++ {
+                       c := g.ctrls().get(j)
+                       switch {
+                       case c == ctrlDeleted:
+                               deleted++
+                       case c == ctrlEmpty:
+                               empty++
+                       default:
+                               used++
+
+                               key := g.key(j)
+
+                               // Can't lookup keys that don't compare equal
+                               // to themselves (e.g., NaN).
+                               if !t.typ.Key.Equal(key, key) {
+                                       continue
+                               }
+
+                               if _, ok := t.Get(key); !ok {
+                                       hash := t.typ.Hasher(key, t.seed)
+                                       print("invariant failed: slot(", i, "/", j, "): key ")
+                                       dump(key, t.typ.Key.Size_)
+                                       print(" not found [hash=", hash, ", h2=", h2(hash), " h1=", h1(hash), "]\n")
+                                       t.Print()
+                                       panic("invariant failed: slot: key not found")
+                               }
+                       }
+               }
+       }
+
+       if used != t.used {
+               print("invariant failed: found ", used, " used slots, but used count is ", t.used, "\n")
+               t.Print()
+               panic("invariant failed: found mismatched used slot count")
+       }
+
+       growthLeft := (t.capacity*maxAvgGroupLoad)/sabi.SwissMapGroupSlots - t.used - deleted
+       if growthLeft != t.growthLeft {
+               print("invariant failed: found ", t.growthLeft, " growthLeft, but expected ", growthLeft, "\n")
+               t.Print()
+               panic("invariant failed: found mismatched growthLeft")
+       }
+       if deleted != t.tombstones() {
+               print("invariant failed: found ", deleted, " tombstones, but expected ", t.tombstones(), "\n")
+               t.Print()
+               panic("invariant failed: found mismatched tombstones")
+       }
+
+       if empty == 0 {
+               print("invariant failed: found no empty slots (violates probe invariant)\n")
+               t.Print()
+               panic("invariant failed: found no empty slots (violates probe invariant)")
+       }
+}
+
+func (t *table) Print() {
+       print(`table{
+       seed: `, t.seed, `
+       capacity: `, t.capacity, `
+       used: `, t.used, `
+       growthLeft: `, t.growthLeft, `
+       groups:
+`)
+
+       for i := uint64(0); i <= t.groups.lengthMask; i++ {
+               print("\t\tgroup ", i, "\n")
+
+               g := t.groups.group(i)
+               ctrls := g.ctrls()
+               for j := uint32(0); j < sabi.SwissMapGroupSlots; j++ {
+                       print("\t\t\tslot ", j, "\n")
+
+                       c := ctrls.get(j)
+                       print("\t\t\t\tctrl ", c)
+                       switch c {
+                       case ctrlEmpty:
+                               print(" (empty)\n")
+                       case ctrlDeleted:
+                               print(" (deleted)\n")
+                       default:
+                               print("\n")
+                       }
+
+                       print("\t\t\t\tkey  ")
+                       dump(g.key(j), t.typ.Key.Size_)
+                       println("")
+                       print("\t\t\t\telem ")
+                       dump(g.elem(j), t.typ.Elem.Size_)
+                       println("")
+               }
+       }
+}
+
+// TODO(prattmic): not in hex because print doesn't have a way to print in hex
+// outside the runtime.
+func dump(ptr unsafe.Pointer, size uintptr) {
+       for size > 0 {
+               print(*(*byte)(ptr), " ")
+               ptr = unsafe.Pointer(uintptr(ptr) + 1)
+               size--
+       }
+}
index a35f806aa3e2c9c0beea4d0e24f3cedd0834294f..7076ced4537202fe0e6606a2812b93106d8bf726 100644 (file)
@@ -1450,6 +1450,11 @@ func reflect_unsafe_NewArray(typ *_type, n int) unsafe.Pointer {
        return newarray(typ, n)
 }
 
+//go:linkname maps_newarray internal/runtime/maps.newarray
+func maps_newarray(typ *_type, n int) unsafe.Pointer {
+       return newarray(typ, n)
+}
+
 func profilealloc(mp *m, x unsafe.Pointer, size uintptr) {
        c := getMCache(mp)
        if c == nil {
index 054d493f35ae6a159c49a54e78c330084a282462..dd99bf3a6aa7e6e8194fcdda72f12f5157e5547c 100644 (file)
@@ -244,6 +244,11 @@ func reflectlite_typedmemmove(typ *_type, dst, src unsafe.Pointer) {
        reflect_typedmemmove(typ, dst, src)
 }
 
+//go:linkname maps_typedmemmove internal/runtime/maps.typedmemmove
+func maps_typedmemmove(typ *_type, dst, src unsafe.Pointer) {
+       typedmemmove(typ, dst, src)
+}
+
 // reflectcallmove is invoked by reflectcall to copy the return values
 // out of the stack and into the heap, invoking the necessary write
 // barriers. dst, src, and size describe the return value area to
@@ -389,6 +394,11 @@ func reflect_typedmemclr(typ *_type, ptr unsafe.Pointer) {
        typedmemclr(typ, ptr)
 }
 
+//go:linkname maps_typedmemclr internal/runtime/maps.typedmemclr
+func maps_typedmemclr(typ *_type, ptr unsafe.Pointer) {
+       typedmemclr(typ, ptr)
+}
+
 //go:linkname reflect_typedmemclrpartial reflect.typedmemclrpartial
 func reflect_typedmemclrpartial(typ *_type, ptr unsafe.Pointer, off, size uintptr) {
        if writeBarrier.enabled && typ.Pointers() {
index 2e44858ee21d24e7bbcdb7a9d2ef331e4b6d6733..0d1d2fe5ba74ce3805a983988d13ebc79c487632 100644 (file)
@@ -177,6 +177,11 @@ func rand() uint64 {
        }
 }
 
+//go:linkname maps_rand internal/runtime/maps.rand
+func maps_rand() uint64 {
+       return rand()
+}
+
 // mrandinit initializes the random state of an m.
 func mrandinit(mp *m) {
        var seed [4]uint64