From 39263f34a307814a74823b280a313829dad374e5 Mon Sep 17 00:00:00 2001
From: Keith Randall
Date: Wed, 6 Sep 2023 13:59:35 -0700
Subject: [PATCH] cmd/compile: add a cache to interface type switches
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

That way we don't need to call into the runtime when the type being
switched on has been seen many times before.

The cache is just a hash table of a sample of all the concrete types
that have been switched on at that source location. We record the
matching case number and the resulting itab for each concrete input
type.

The caches seldom get large. The only two in a run of all.bash that
get more than 100 entries, even with the sampling rate set to 1, are

test/fixedbugs/issue29264.go, with 101
test/fixedbugs/issue29312.go, with 254

Both happen at the type switch in fmt.(*pp).handleMethods, perhaps
unsurprisingly.

name                                 old time/op  new time/op  delta
SwitchInterfaceTypePredictable-24    25.8ns ± 2%   2.5ns ± 3%  -90.43%  (p=0.000 n=10+10)
SwitchInterfaceTypeUnpredictable-24  37.5ns ± 2%  11.2ns ± 1%  -70.02%  (p=0.000 n=10+10)

Change-Id: I4961ac9547b7f15b03be6f55cdcb972d176955eb
Reviewed-on: https://go-review.googlesource.com/c/go/+/526658
Reviewed-by: Cuong Manh Le
LUCI-TryBot-Result: Go LUCI
Reviewed-by: Matthew Dempsky
Reviewed-by: Keith Randall
---
 src/cmd/compile/internal/ssagen/ssa.go    | 104 +++++++++++++++++++++-
 src/cmd/compile/internal/test/inl_test.go |   3 +
 src/cmd/compile/internal/walk/switch.go   |   9 +-
 src/internal/abi/switch.go                |  31 +++++++
 src/runtime/iface.go                      |  98 +++++++++++++++++++-
 test/codegen/switch.go                    |   4 +-
 6 files changed, 239 insertions(+), 10 deletions(-)

diff --git a/src/cmd/compile/internal/ssagen/ssa.go b/src/cmd/compile/internal/ssagen/ssa.go
index aa2c962de0..df6a5357f2 100644
--- a/src/cmd/compile/internal/ssagen/ssa.go
+++ b/src/cmd/compile/internal/ssagen/ssa.go
@@ -980,6 +980,7 @@ var (
 	typVar       = ssaMarker("typ")
 	okVar        = ssaMarker("ok")
 	deferBitsVar = ssaMarker("deferBits")
+	hashVar      = ssaMarker("hash")
 )
 
 // startBlock sets the current block we're generating code in to b.
@@ -2020,13 +2021,112 @@ func (s *state) stmt(n ir.Node) {
 
 	case ir.OINTERFACESWITCH:
 		n := n.(*ir.InterfaceSwitchStmt)
+		typs := s.f.Config.Types
 
 		t := s.expr(n.RuntimeType)
-		d := s.newValue1A(ssa.OpAddr, s.f.Config.Types.BytePtr, n.Descriptor, s.sb)
-		r := s.rtcall(ir.Syms.InterfaceSwitch, true, []*types.Type{s.f.Config.Types.Int, s.f.Config.Types.BytePtr}, d, t)
+		d := s.newValue1A(ssa.OpAddr, typs.BytePtr, n.Descriptor, s.sb)
+
+		// Check the cache first.
+		var merge *ssa.Block
+		if base.Flag.N == 0 && rtabi.UseInterfaceSwitchCache(Arch.LinkArch.Name) {
+			// Note: we can only use the cache if we have the right atomic load instruction.
+			// Double-check that here.
+			if _, ok := intrinsics[intrinsicKey{Arch.LinkArch.Arch, "runtime/internal/atomic", "Loadp"}]; !ok {
+				s.Fatalf("atomic load not available")
+			}
+			merge = s.f.NewBlock(ssa.BlockPlain)
+			cacheHit := s.f.NewBlock(ssa.BlockPlain)
+			cacheMiss := s.f.NewBlock(ssa.BlockPlain)
+			loopHead := s.f.NewBlock(ssa.BlockPlain)
+			loopBody := s.f.NewBlock(ssa.BlockPlain)
+
+			// Pick right size ops.
+			var mul, and, add, zext ssa.Op
+			if s.config.PtrSize == 4 {
+				mul = ssa.OpMul32
+				and = ssa.OpAnd32
+				add = ssa.OpAdd32
+				zext = ssa.OpCopy
+			} else {
+				mul = ssa.OpMul64
+				and = ssa.OpAnd64
+				add = ssa.OpAdd64
+				zext = ssa.OpZeroExt32to64
+			}
+
+			// Load cache pointer out of descriptor, with an atomic load so
+			// we ensure that we see a fully written cache.
+			atomicLoad := s.newValue2(ssa.OpAtomicLoadPtr, types.NewTuple(typs.BytePtr, types.TypeMem), d, s.mem())
+			cache := s.newValue1(ssa.OpSelect0, typs.BytePtr, atomicLoad)
+			s.vars[memVar] = s.newValue1(ssa.OpSelect1, types.TypeMem, atomicLoad)
+
+			// Load hash from type.
+			hash := s.newValue2(ssa.OpLoad, typs.UInt32, s.newValue1I(ssa.OpOffPtr, typs.UInt32Ptr, 2*s.config.PtrSize, t), s.mem())
+			hash = s.newValue1(zext, typs.Uintptr, hash)
+			s.vars[hashVar] = hash
+			// Load mask from cache.
+			mask := s.newValue2(ssa.OpLoad, typs.Uintptr, cache, s.mem())
+			// Jump to loop head.
+			b := s.endBlock()
+			b.AddEdgeTo(loopHead)
+
+			// At loop head, get pointer to the cache entry.
+			// e := &cache.Entries[hash&mask]
+			s.startBlock(loopHead)
+			entries := s.newValue2(ssa.OpAddPtr, typs.UintptrPtr, cache, s.uintptrConstant(uint64(s.config.PtrSize)))
+			idx := s.newValue2(and, typs.Uintptr, s.variable(hashVar, typs.Uintptr), mask)
+			idx = s.newValue2(mul, typs.Uintptr, idx, s.uintptrConstant(uint64(3*s.config.PtrSize)))
+			e := s.newValue2(ssa.OpAddPtr, typs.UintptrPtr, entries, idx)
+			// hash++
+			s.vars[hashVar] = s.newValue2(add, typs.Uintptr, s.variable(hashVar, typs.Uintptr), s.uintptrConstant(1))
+
+			// Look for a cache hit.
+			// if e.Typ == t { goto hit }
+			eTyp := s.newValue2(ssa.OpLoad, typs.Uintptr, e, s.mem())
+			cmp1 := s.newValue2(ssa.OpEqPtr, typs.Bool, t, eTyp)
+			b = s.endBlock()
+			b.Kind = ssa.BlockIf
+			b.SetControl(cmp1)
+			b.AddEdgeTo(cacheHit)
+			b.AddEdgeTo(loopBody)
+
+			// Look for an empty entry, the tombstone for this hash table.
+			// if e.Typ == nil { goto miss }
+			s.startBlock(loopBody)
+			cmp2 := s.newValue2(ssa.OpEqPtr, typs.Bool, eTyp, s.constNil(typs.BytePtr))
+			b = s.endBlock()
+			b.Kind = ssa.BlockIf
+			b.SetControl(cmp2)
+			b.AddEdgeTo(cacheMiss)
+			b.AddEdgeTo(loopHead)
+
+			// On a hit, load the data fields of the cache entry.
+			// Case = e.Case
+			// Itab = e.Itab
+			s.startBlock(cacheHit)
+			eCase := s.newValue2(ssa.OpLoad, typs.Int, s.newValue1I(ssa.OpOffPtr, typs.IntPtr, s.config.PtrSize, e), s.mem())
+			eItab := s.newValue2(ssa.OpLoad, typs.BytePtr, s.newValue1I(ssa.OpOffPtr, typs.BytePtrPtr, 2*s.config.PtrSize, e), s.mem())
+			s.assign(n.Case, eCase, false, 0)
+			s.assign(n.Itab, eItab, false, 0)
+			b = s.endBlock()
+			b.AddEdgeTo(merge)
+
+			// On a miss, call into the runtime to get the answer.
+			s.startBlock(cacheMiss)
+		}
+
+		r := s.rtcall(ir.Syms.InterfaceSwitch, true, []*types.Type{typs.Int, typs.BytePtr}, d, t)
 		s.assign(n.Case, r[0], false, 0)
 		s.assign(n.Itab, r[1], false, 0)
 
+		if merge != nil {
+			// Cache hits merge in here.
+			b := s.endBlock()
+			b.Kind = ssa.BlockPlain
+			b.AddEdgeTo(merge)
+			s.startBlock(merge)
+		}
+
 	case ir.OCHECKNIL:
 		n := n.(*ir.UnaryExpr)
 		p := s.expr(n.X)
diff --git a/src/cmd/compile/internal/test/inl_test.go b/src/cmd/compile/internal/test/inl_test.go
index 4e34631d9b..f93d23de8b 100644
--- a/src/cmd/compile/internal/test/inl_test.go
+++ b/src/cmd/compile/internal/test/inl_test.go
@@ -108,6 +108,9 @@ func TestIntendedInlining(t *testing.T) {
 			"(*Buffer).UnreadByte",
 			"(*Buffer).tryGrowByReslice",
 		},
+		"internal/abi": {
+			"UseInterfaceSwitchCache",
+		},
 		"compress/flate": {
 			"byLiteral.Len",
 			"byLiteral.Less",
diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go
index 2f7eb5486c..80c956f654 100644
--- a/src/cmd/compile/internal/walk/switch.go
+++ b/src/cmd/compile/internal/walk/switch.go
@@ -527,12 +527,15 @@ func walkSwitchType(sw *ir.SwitchStmt) {
 		lsym := types.LocalPkg.Lookup(fmt.Sprintf(".interfaceSwitch.%d", interfaceSwitchGen)).LinksymABI(obj.ABI0)
 		interfaceSwitchGen++
 		off := 0
+		off = objw.SymPtr(lsym, off, typecheck.LookupRuntimeVar("emptyInterfaceSwitchCache"), 0)
 		off = objw.Uintptr(lsym, off, uint64(len(interfaceCases)))
 		for _, c := range interfaceCases {
 			off = objw.SymPtr(lsym, off, reflectdata.TypeSym(c.typ.Type()).Linksym(), 0)
 		}
-		// Note: it has pointers, just not ones the GC cares about.
-		objw.Global(lsym, int32(off), obj.LOCAL|obj.NOPTR)
+		objw.Global(lsym, int32(off), obj.LOCAL)
+		// Set the type to be just a single pointer, as the cache pointer is the
+		// only one that GC needs to see.
+		lsym.Gotype = reflectdata.TypeLinksym(types.Types[types.TUINT8].PtrTo())
 
 		// Call runtime to do switch
 		// case, itab = runtime.interfaceSwitch(&descriptor, typeof(arg))
@@ -546,7 +549,7 @@ func walkSwitchType(sw *ir.SwitchStmt) {
 		isw := ir.NewInterfaceSwitchStmt(base.Pos, caseVar, s.itabName, typeArg, lsym)
 		sw.Compiled.Append(isw)
 
-		// Switch on the result of the call.
+		// Switch on the result of the call (or cache lookup).
 		var newCases []*ir.CaseClause
 		for i, c := range interfaceCases {
 			newCases = append(newCases, &ir.CaseClause{
diff --git a/src/internal/abi/switch.go b/src/internal/abi/switch.go
index 62d75852f1..5c1171c2f4 100644
--- a/src/internal/abi/switch.go
+++ b/src/internal/abi/switch.go
@@ -5,9 +5,40 @@
 package abi
 
 type InterfaceSwitch struct {
+	Cache  *InterfaceSwitchCache
 	NCases int
 
 	// Array of NCases elements.
 	// Each case must be a non-empty interface type.
 	Cases [1]*InterfaceType
 }
+
+type InterfaceSwitchCache struct {
+	Mask    uintptr                      // mask for index. Must be a power of 2 minus 1
+	Entries [1]InterfaceSwitchCacheEntry // Mask+1 entries total
+}
+
+type InterfaceSwitchCacheEntry struct {
+	// type of source value (a *Type)
+	Typ uintptr
+	// case # to dispatch to
+	Case int
+	// itab to use for resulting case variable (a *runtime.itab)
+	Itab uintptr
+}
+
+const go122InterfaceSwitchCache = true
+
+func UseInterfaceSwitchCache(goarch string) bool {
+	if !go122InterfaceSwitchCache {
+		return false
+	}
+	// We need an atomic load instruction to make the cache multithreaded-safe.
+	// (AtomicLoadPtr needs to be implemented in cmd/compile/internal/ssa/_gen/ARCH.rules.)
+	switch goarch {
+	case "amd64", "arm64", "loong64", "mips", "mipsle", "mips64", "mips64le", "ppc64", "ppc64le", "riscv64", "s390x":
+		return true
+	default:
+		return false
+	}
+}
diff --git a/src/runtime/iface.go b/src/runtime/iface.go
index ecf673aa93..99ac3eb461 100644
--- a/src/runtime/iface.go
+++ b/src/runtime/iface.go
@@ -8,6 +8,7 @@ import (
 	"internal/abi"
 	"internal/goarch"
 	"runtime/internal/atomic"
+	"runtime/internal/sys"
 	"unsafe"
 )
 
@@ -475,15 +476,106 @@ func assertE2I2(inter *interfacetype, e eface) (r iface) {
 // of cases.
 func interfaceSwitch(s *abi.InterfaceSwitch, t *_type) (int, *itab) {
 	cases := unsafe.Slice(&s.Cases[0], s.NCases)
+
+	// Results if we don't find a match.
+	case_ := len(cases)
+	var tab *itab
+
+	// Look through each case in order.
 	for i, c := range cases {
-		tab := getitab(c, t, true)
+		tab = getitab(c, t, true)
 		if tab != nil {
-			return i, tab
+			case_ = i
+			break
 		}
 	}
-	return len(cases), nil
+
+	if !abi.UseInterfaceSwitchCache(GOARCH) {
+		return case_, tab
+	}
+
+	// Maybe update the cache, so the next time the generated code
+	// doesn't need to call into the runtime.
+	if fastrand()&1023 != 0 {
+		// Only bother updating the cache ~1 in 1000 times.
+		// This ensures we don't waste memory on switches, or
+		// switch arguments, that only happen a few times.
+		return case_, tab
+	}
+	// Load the current cache.
+	oldC := (*abi.InterfaceSwitchCache)(atomic.Loadp(unsafe.Pointer(&s.Cache)))
+
+	if fastrand()&uint32(oldC.Mask) != 0 {
+		// As cache gets larger, choose to update it less often
+		// so we can amortize the cost of building a new cache
+		// (that cost is linear in oldc.Mask).
+		return case_, tab
+	}
+
+	// Make a new cache.
+	newC := buildInterfaceSwitchCache(oldC, t, case_, tab)
+
+	// Update cache. Use compare-and-swap so if multiple threads
+	// are fighting to update the cache, at least one of their
+	// updates will stick.
+	atomic_casPointer((*unsafe.Pointer)(unsafe.Pointer(&s.Cache)), unsafe.Pointer(oldC), unsafe.Pointer(newC))
+
+	return case_, tab
 }
 
+// buildInterfaceSwitchCache constructs a interface switch cache
+// containing all the entries from oldC plus the new entry
+// (typ,case_,tab).
+func buildInterfaceSwitchCache(oldC *abi.InterfaceSwitchCache, typ *_type, case_ int, tab *itab) *abi.InterfaceSwitchCache {
+	oldEntries := unsafe.Slice(&oldC.Entries[0], oldC.Mask+1)
+
+	// Count the number of entries we need.
+	n := 1
+	for _, e := range oldEntries {
+		if e.Typ != 0 {
+			n++
+		}
+	}
+
+	// Figure out how big a table we need.
+	// We need at least one more slot than the number of entries
+	// so that we are guaranteed an empty slot (for termination).
+	newN := n * 2                         // make it at most 50% full
+	newN = 1 << sys.Len64(uint64(newN-1)) // round up to a power of 2
+
+	// Allocate the new table.
+	newSize := unsafe.Sizeof(abi.InterfaceSwitchCache{}) + uintptr(newN-1)*unsafe.Sizeof(abi.InterfaceSwitchCacheEntry{})
+	newC := (*abi.InterfaceSwitchCache)(mallocgc(newSize, nil, true))
+	newC.Mask = uintptr(newN - 1)
+	newEntries := unsafe.Slice(&newC.Entries[0], newN)
+
+	// Fill the new table.
+	addEntry := func(typ *_type, case_ int, tab *itab) {
+		h := int(typ.Hash) & (newN - 1)
+		for {
+			if newEntries[h].Typ == 0 {
+				newEntries[h].Typ = uintptr(unsafe.Pointer(typ))
+				newEntries[h].Case = case_
+				newEntries[h].Itab = uintptr(unsafe.Pointer(tab))
+				return
+			}
+			h = (h + 1) & (newN - 1)
+		}
+	}
+	for _, e := range oldEntries {
+		if e.Typ != 0 {
+			addEntry((*_type)(unsafe.Pointer(e.Typ)), e.Case, (*itab)(unsafe.Pointer(e.Itab)))
+		}
+	}
+	addEntry(typ, case_, tab)
+
+	return newC
+}
+
+// Empty interface switch cache. Contains one entry with a nil Typ (which
+// causes a cache lookup to fail immediately.)
+var emptyInterfaceSwitchCache = abi.InterfaceSwitchCache{Mask: 0}
+
 //go:linkname reflect_ifaceE2I reflect.ifaceE2I
 func reflect_ifaceE2I(inter *interfacetype, e eface, dst *iface) {
 	*dst = iface{assertE2I(inter, e._type), e.data}
diff --git a/test/codegen/switch.go b/test/codegen/switch.go
index 63b0dce8a6..6778c65ab3 100644
--- a/test/codegen/switch.go
+++ b/test/codegen/switch.go
@@ -128,8 +128,8 @@ type J interface {
 
 // use a runtime call for type switches to interface types.
 func interfaceSwitch(x any) int {
-	// amd64:`CALL\truntime.interfaceSwitch`
-	// arm64:`CALL\truntime.interfaceSwitch`
+	// amd64:`CALL\truntime.interfaceSwitch`,`MOVL\t16\(.*\)`,`MOVQ\t8\(.*\)(.*\*8)`
+	// arm64:`CALL\truntime.interfaceSwitch`,`LDAR`,`MOVWU`,`MOVD\t\(R.*\)\(R.*\)`
 	switch x.(type) {
 	case I:
 		return 1
-- 
2.50.0
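
For readers coming to the patch cold: the construct being optimized is a type switch whose cases are themselves interface types, which the compiler cannot resolve statically. An earlier change in this series already made such switches compile to a call to runtime.interfaceSwitch, which tries getitab against each case in turn; this patch lets the generated code skip that call when the concrete type has already been seen at the same switch. A minimal, self-contained illustration of such a switch (classify and its cases are invented for this example, not taken from the patch):

package main

import "fmt"

// classify is the kind of switch this patch targets: its cases are
// interface types, so each case is an implements-check rather than a
// simple type identity test.
func classify(x any) int {
	switch x.(type) {
	case fmt.Stringer:
		return 1
	case error:
		return 2
	default:
		return 0
	}
}

func main() {
	fmt.Println(classify(fmt.Errorf("boom"))) // 2: implements error, not fmt.Stringer
	fmt.Println(classify(42))                 // 0: implements neither interface
}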
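
The cache attached to each switch descriptor is an open-addressed hash table keyed by the concrete type: the generated code hashes the type and probes linearly until it finds either a matching entry (a hit, reusing the recorded case index and itab) or an empty slot (a miss, falling back to interfaceSwitch), and the runtime occasionally rebuilds the table one entry larger, at most 50% full, sized to the next power of two. The sketch below models that policy in ordinary Go so it can be read and run outside the runtime; typeCache, cacheEntry, and the use of the key itself as the hash are simplifications invented here, standing in for abi.InterfaceSwitchCache, *_type pointers, and real type hash values.

package main

import (
	"fmt"
	"math/bits"
)

// cacheEntry mirrors abi.InterfaceSwitchCacheEntry: a key (the concrete
// type), the case index to dispatch to, and the itab to install.
// Key 0 means "empty slot", just as Typ == 0 does in the runtime.
type cacheEntry struct {
	key       uintptr
	caseIndex int
	itab      uintptr
}

// typeCache mirrors abi.InterfaceSwitchCache: a power-of-two table with
// mask = len(entries)-1.
type typeCache struct {
	mask    uintptr
	entries []cacheEntry
}

// lookup is the fast path the compiler inlines: start at hash&mask and
// probe linearly until we hit the key or an empty slot.
func (c *typeCache) lookup(key, hash uintptr) (caseIndex int, itab uintptr, ok bool) {
	for i := hash & c.mask; ; i = (i + 1) & c.mask {
		e := &c.entries[i]
		if e.key == key {
			return e.caseIndex, e.itab, true
		}
		if e.key == 0 {
			return 0, 0, false // miss: caller falls back to the slow path
		}
	}
}

// withEntry mirrors buildInterfaceSwitchCache: copy the old entries plus
// one new entry into a fresh table that is at most half full.
func (c *typeCache) withEntry(key, hash uintptr, caseIndex int, itab uintptr) *typeCache {
	n := 1
	for _, e := range c.entries {
		if e.key != 0 {
			n++
		}
	}
	size := 1 << bits.Len64(uint64(2*n-1)) // next power of two >= 2*n
	nc := &typeCache{mask: uintptr(size - 1), entries: make([]cacheEntry, size)}
	insert := func(key, hash uintptr, caseIndex int, itab uintptr) {
		for i := hash & nc.mask; ; i = (i + 1) & nc.mask {
			if nc.entries[i].key == 0 {
				nc.entries[i] = cacheEntry{key, caseIndex, itab}
				return
			}
		}
	}
	for _, e := range c.entries {
		if e.key != 0 {
			insert(e.key, e.key, e.caseIndex, e.itab) // key doubles as hash in this toy model
		}
	}
	insert(key, hash, caseIndex, itab)
	return nc
}

func main() {
	// One-slot empty cache, analogous to emptyInterfaceSwitchCache.
	c := &typeCache{mask: 0, entries: make([]cacheEntry, 1)}
	c = c.withEntry(0x1000, 0x1000, 1, 0xa0)
	c = c.withEntry(0x2000, 0x2000, 0, 0xb0)
	fmt.Println(c.lookup(0x2000, 0x2000)) // 0 176 true
	fmt.Println(c.lookup(0x3000, 0x3000)) // 0 0 false
}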
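
The commit message quotes SwitchInterfaceTypePredictable/Unpredictable results; those benchmarks are not part of this diff. A stand-alone benchmark in the same spirit might look like the following sketch, where the interfaces, concrete types, and rotating input slice are all invented for illustration: the predictable variant feeds the switch a single concrete type, while the unpredictable one rotates through several so no one path stays hot.

package main

import (
	"fmt"
	"testing"
)

type stringer interface{ String() string }
type named interface{ Name() string }

type typeA struct{}

func (typeA) String() string { return "A" }

type typeB struct{}

func (typeB) Name() string { return "B" }

// switchKind is the measured operation: a type switch over interface cases.
func switchKind(x any) int {
	switch x.(type) {
	case stringer:
		return 1
	case named:
		return 2
	default:
		return 0
	}
}

var sink int // keeps the switch result live so it is not optimized away

func benchPredictable(b *testing.B) {
	for i := 0; i < b.N; i++ {
		sink += switchKind(typeA{}) // same concrete type every iteration
	}
}

func benchUnpredictable(b *testing.B) {
	vals := []any{typeA{}, typeB{}, 3, "s"} // mix of concrete types
	for i := 0; i < b.N; i++ {
		sink += switchKind(vals[i%len(vals)])
	}
}

func main() {
	fmt.Println(testing.Benchmark(benchPredictable))
	fmt.Println(testing.Benchmark(benchUnpredictable))
}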