]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: reorg equality functions a bit
authorKeith Randall <khr@golang.org>
Mon, 1 Dec 2025 23:58:57 +0000 (15:58 -0800)
committerKeith Randall <khr@golang.org>
Sat, 24 Jan 2026 04:58:03 +0000 (20:58 -0800)
Use signature for closure name instead of type.
Use signature instead of type to decide to use a runtime builtin comparator.
Remove trailing skips from signatures.

Change-Id: I73b2dcd3c6e2f1b2857985e14c24b290941b3ca3
Reviewed-on: https://go-review.googlesource.com/c/go/+/725604
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cuong Manh Le <cuong.manhle.vn@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Keith Randall <khr@google.com>
src/cmd/compile/internal/reflectdata/alg.go

index 2eee55790174260b679f3d21c783f83a52eee075..bf52fb4b718a39a9e7656f12239c12e12fddd204 100644 (file)
@@ -299,61 +299,83 @@ func geneq(t *types.Type) *obj.LSym {
                // The runtime will panic if it tries to compare
                // a type with a nil equality function.
                return nil
-       case types.AMEM0:
+       }
+       return geneqSig(eqSignature(t))
+}
+
+// geneqSig returns a symbol which is the closure used to compute
+// equality for two objects with equality signature sig.
+func geneqSig(sig string) *obj.LSym {
+       align := int64(types.PtrSize)
+       if len(sig) > 0 && sig[0] == sigAlign {
+               align, sig = parseNum(sig[1:])
+       }
+       if base.Ctxt.Arch.CanMergeLoads {
+               align = 8
+       }
+       switch sig {
+       case "":
                return sysClosure("memequal0")
-       case types.AMEM8:
+       case string(sigMemory) + "1":
                return sysClosure("memequal8")
-       case types.AMEM16:
-               return sysClosure("memequal16")
-       case types.AMEM32:
-               return sysClosure("memequal32")
-       case types.AMEM64:
-               return sysClosure("memequal64")
-       case types.AMEM128:
-               return sysClosure("memequal128")
-       case types.ASTRING:
+       case string(sigMemory) + "2":
+               if align >= 2 {
+                       return sysClosure("memequal16")
+               }
+       case string(sigMemory) + "4":
+               if align >= 4 {
+                       return sysClosure("memequal32")
+               }
+       case string(sigMemory) + "8":
+               if align >= 8 {
+                       return sysClosure("memequal64")
+               }
+       case string(sigMemory) + "16":
+               if align >= 8 {
+                       return sysClosure("memequal128")
+               }
+       case string(sigString):
                return sysClosure("strequal")
-       case types.AINTER:
+       case string(sigIface):
                return sysClosure("interequal")
-       case types.ANILINTER:
+       case string(sigEface):
                return sysClosure("nilinterequal")
-       case types.AFLOAT32:
+       case string(sigFloat32):
                return sysClosure("f32equal")
-       case types.AFLOAT64:
+       case string(sigFloat64):
                return sysClosure("f64equal")
-       case types.ACPLX64:
+       case string(sigFloat32) + string(sigFloat32):
                return sysClosure("c64equal")
-       case types.ACPLX128:
+       case string(sigFloat64) + string(sigFloat64):
                return sysClosure("c128equal")
-       case types.AMEM:
-               // make equality closure. The size of the type
-               // is encoded in the closure.
-               closure := TypeLinksymLookup(fmt.Sprintf(".eqfunc%d", t.Size()))
-               if len(closure.P) != 0 {
-                       return closure
-               }
-               if memequalvarlen == nil {
-                       memequalvarlen = typecheck.LookupRuntimeFunc("memequal_varlen")
-               }
-               ot := 0
-               ot = objw.SymPtr(closure, ot, memequalvarlen, 0)
-               ot = objw.Uintptr(closure, ot, uint64(t.Size()))
-               objw.Global(closure, int32(ot), obj.DUPOK|obj.RODATA)
-               return closure
-       case types.ASPECIAL:
-               break
        }
 
-       closure := TypeLinksymPrefix(".eqfunc", t)
+       closure := TypeLinksymLookup(".eqfunc." + sig)
        if len(closure.P) > 0 { // already generated
                return closure
        }
 
+       if sig[0] == sigMemory {
+               n, rest := parseNum(sig[1:])
+               if rest == "" {
+                       // Just M%d. We can make a memequal_varlen closure.
+                       // The size of the memory region to compare is encoded in the closure.
+                       if memequalvarlen == nil {
+                               memequalvarlen = typecheck.LookupRuntimeFunc("memequal_varlen")
+                       }
+                       ot := 0
+                       ot = objw.SymPtr(closure, ot, memequalvarlen, 0)
+                       ot = objw.Uintptr(closure, ot, uint64(n))
+                       objw.Global(closure, int32(ot), obj.DUPOK|obj.RODATA)
+                       return closure
+               }
+       }
+
        if base.Flag.LowerR != 0 {
-               fmt.Printf("geneq %v\n", t)
+               fmt.Printf("geneqSig %s\n", sig)
        }
 
-       fn := eqFunc(eqSignature(t))
+       fn := eqFunc(sig)
 
        // Generate a closure which points at the function we just generated.
        objw.SymPtr(closure, 0, fn.Linksym(), 0)
@@ -572,7 +594,7 @@ func eqFunc(sig string) *ir.Func {
                        //     for i := off; i < off + N*elemSize; i += elemSize {
                        //         if !eqfn(p+i, q+i) { goto neq }
                        //     }
-                       elemFn := eqFunc(elemSig).Nname
+                       elemFn := eqFunc(sigTrimSkip(elemSig)).Nname
                        idx := typecheck.TempAt(pos, ir.CurFunc, types.Types[types.TUINTPTR])
                        init := ir.NewAssignStmt(pos, idx, ir.NewInt(pos, off))
                        cond := ir.NewBinaryExpr(pos, ir.OLT, idx, ir.NewInt(pos, off+n*elemSize))
@@ -702,6 +724,7 @@ func hashmem(t *types.Type) ir.Node {
 // An alignment directive is only needed on platforms that can't do
 // unaligned loads.
 // If an alignment directive is present, it must be first.
+// Signatures can end early; a K%d is not required at the end.
 func eqSignature(t *types.Type) string {
        var e eqSigBuilder
        if !base.Ctxt.Arch.CanMergeLoads { // alignment only matters if we can't use unaligned loads
@@ -710,7 +733,7 @@ func eqSignature(t *types.Type) string {
                }
        }
        e.build(t)
-       e.flush()
+       e.flush(true)
        return e.r.String()
 }
 
@@ -733,46 +756,48 @@ type eqSigBuilder struct {
        skipMem int64 // queued up region of memory to skip
 }
 
-func (e *eqSigBuilder) flush() {
+func (e *eqSigBuilder) flush(atEnd bool) {
        if e.regMem > 0 {
                e.r.WriteString(fmt.Sprintf("%c%d", sigMemory, e.regMem))
                e.regMem = 0
        }
        if e.skipMem > 0 {
-               e.r.WriteString(fmt.Sprintf("%c%d", sigSkip, e.skipMem))
+               if !atEnd {
+                       e.r.WriteString(fmt.Sprintf("%c%d", sigSkip, e.skipMem))
+               }
                e.skipMem = 0
        }
 }
 func (e *eqSigBuilder) regular(n int64) {
        if e.regMem == 0 {
-               e.flush()
+               e.flush(false)
        }
        e.regMem += n
 }
 func (e *eqSigBuilder) skip(n int64) {
        if e.skipMem == 0 {
-               e.flush()
+               e.flush(false)
        }
        e.skipMem += n
 }
 func (e *eqSigBuilder) float32() {
-       e.flush()
+       e.flush(false)
        e.r.WriteByte(sigFloat32)
 }
 func (e *eqSigBuilder) float64() {
-       e.flush()
+       e.flush(false)
        e.r.WriteByte(sigFloat64)
 }
 func (e *eqSigBuilder) string() {
-       e.flush()
+       e.flush(false)
        e.r.WriteByte(sigString)
 }
 func (e *eqSigBuilder) eface() {
-       e.flush()
+       e.flush(false)
        e.r.WriteByte(sigEface)
 }
 func (e *eqSigBuilder) iface() {
-       e.flush()
+       e.flush(false)
        e.r.WriteByte(sigIface)
 }
 
@@ -865,12 +890,12 @@ func (e *eqSigBuilder) build(t *types.Type) {
                                }
                                break
                        }
-                       e.flush()
+                       e.flush(false)
                        e.r.WriteString(fmt.Sprintf("%c%d", sigArrayStart, n/unroll))
                        for range unroll {
                                e.build(et)
                        }
-                       e.flush()
+                       e.flush(false)
                        e.r.WriteByte(sigArrayEnd)
                }
        default:
@@ -937,3 +962,16 @@ func sigSize(sig string) int64 {
        }
        return size
 }
+
+func sigTrimSkip(s string) string {
+       i := strings.LastIndexByte(s, sigSkip)
+       if i < 0 {
+               return s
+       }
+       for j := i + 1; j < len(s); j++ {
+               if s[j] < '0' || s[j] > '9' {
+                       return s
+               }
+       }
+       return s[:i]
+}