From: Joel Sing
Date: Sun, 23 Feb 2025 11:17:53 +0000 (+1100)
Subject: cmd/compile: intrinsify math/bits.TrailingZeros on riscv64
X-Git-Tag: go1.25rc1~694
X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=6fb7bdc96d0398fab313586fba6fdc89cc14c679;p=gostls13.git

cmd/compile: intrinsify math/bits.TrailingZeros on riscv64

For riscv64/rva22u64 and above, we can intrinsify math/bits.TrailingZeros
using the CTZ/CTZW machine instructions.

On a StarFive VisionFive 2 with GORISCV64=rva22u64:

                   │   ctz.b.1    │              ctz.b.2                │
                   │    sec/op    │   sec/op     vs base                │
TrailingZeros-4      25.500n ± 0%    8.052n ± 0%  -68.42% (p=0.000 n=10)
TrailingZeros8-4      14.76n ± 0%    10.74n ± 0%  -27.24% (p=0.000 n=10)
TrailingZeros16-4     26.84n ± 0%    10.74n ± 0%  -59.99% (p=0.000 n=10)
TrailingZeros32-4    25.500n ± 0%    8.052n ± 0%  -68.42% (p=0.000 n=10)
TrailingZeros64-4    25.500n ± 0%    8.052n ± 0%  -68.42% (p=0.000 n=10)
geomean               23.09n         9.035n       -60.88%

Change-Id: I71edf2b988acb7a68e797afda4ee66d7a57d587e
Reviewed-on: https://go-review.googlesource.com/c/go/+/652320
Reviewed-by: Cherry Mui
Reviewed-by: Mark Ryan
Reviewed-by: David Chase
LUCI-TryBot-Result: Go LUCI
Reviewed-by: Meng Zhuo
---

diff --git a/src/cmd/compile/internal/riscv64/ssa.go b/src/cmd/compile/internal/riscv64/ssa.go
index 636ef44d68..4392081f6e 100644
--- a/src/cmd/compile/internal/riscv64/ssa.go
+++ b/src/cmd/compile/internal/riscv64/ssa.go
@@ -419,7 +419,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 		ssa.OpRISCV64FMVSX, ssa.OpRISCV64FMVDX,
 		ssa.OpRISCV64FCVTSW, ssa.OpRISCV64FCVTSL, ssa.OpRISCV64FCVTWS, ssa.OpRISCV64FCVTLS,
 		ssa.OpRISCV64FCVTDW, ssa.OpRISCV64FCVTDL, ssa.OpRISCV64FCVTWD, ssa.OpRISCV64FCVTLD, ssa.OpRISCV64FCVTDS, ssa.OpRISCV64FCVTSD,
-		ssa.OpRISCV64NOT, ssa.OpRISCV64NEG, ssa.OpRISCV64NEGW:
+		ssa.OpRISCV64NOT, ssa.OpRISCV64NEG, ssa.OpRISCV64NEGW, ssa.OpRISCV64CTZ, ssa.OpRISCV64CTZW:
 		p := s.Prog(v.Op.Asm())
 		p.From.Type = obj.TYPE_REG
 		p.From.Reg = v.Args[0].Reg()
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
index 770a9095f6..016eb53f04 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
@@ -218,6 +218,13 @@
 (RotateLeft32 ...) => (ROLW ...)
 (RotateLeft64 ...) => (ROL ...)
 
+// Count trailing zeros (note that these will only be emitted for rva22u64 and above).
+(Ctz(64|32|16|8)NonZero ...) => (Ctz64 ...)
+(Ctz64 ...) => (CTZ ...)
+(Ctz32 ...) => (CTZW ...)
+(Ctz16 x) => (CTZW (ORI [1<<16] x))
+(Ctz8 x) => (CTZW (ORI [1<<8] x))
+
 (Less64 ...) => (SLT ...)
 (Less32 x y) => (SLT (SignExt32to64 x) (SignExt32to64 y))
 (Less16 x y) => (SLT (SignExt16to64 x) (SignExt16to64 y))
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
index 85c74b4676..85e9e47e82 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
@@ -229,6 +229,8 @@ func init() {
 		{name: "AND", argLength: 2, reg: gp21, asm: "AND", commutative: true},   // arg0 & arg1
 		{name: "ANDN", argLength: 2, reg: gp21, asm: "ANDN"},                    // ^arg0 & arg1
 		{name: "ANDI", argLength: 1, reg: gp11, asm: "ANDI", aux: "Int64"},      // arg0 & auxint
+		{name: "CTZ", argLength: 1, reg: gp11, asm: "CTZ"},                      // count trailing zeros
+		{name: "CTZW", argLength: 1, reg: gp11, asm: "CTZW"},                    // count trailing zeros of least significant word
 		{name: "NOT", argLength: 1, reg: gp11, asm: "NOT"},                      // ^arg0
 		{name: "OR", argLength: 2, reg: gp21, asm: "OR", commutative: true},     // arg0 | arg1
 		{name: "ORN", argLength: 2, reg: gp21, asm: "ORN"},                      // ^arg0 | arg1
diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go
index 8ceff3f449..3fd5b310ac 100644
--- a/src/cmd/compile/internal/ssa/opGen.go
+++ b/src/cmd/compile/internal/ssa/opGen.go
@@ -2523,6 +2523,8 @@ const (
 	OpRISCV64AND
 	OpRISCV64ANDN
 	OpRISCV64ANDI
+	OpRISCV64CTZ
+	OpRISCV64CTZW
 	OpRISCV64NOT
 	OpRISCV64OR
 	OpRISCV64ORN
@@ -34002,6 +34004,32 @@ var opcodeTable = [...]opInfo{
 			},
 		},
 	},
+	{
+		name:   "CTZ",
+		argLen: 1,
+		asm:    riscv.ACTZ,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
+	{
+		name:   "CTZW",
+		argLen: 1,
+		asm:    riscv.ACTZW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
 	{
 		name:   "NOT",
 		argLen: 1,
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
index e19e28ea23..ab93309680 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
@@ -136,6 +136,28 @@ func rewriteValueRISCV64(v *Value) bool {
 	case OpCopysign:
 		v.Op = OpRISCV64FSGNJD
 		return true
+	case OpCtz16:
+		return rewriteValueRISCV64_OpCtz16(v)
+	case OpCtz16NonZero:
+		v.Op = OpCtz64
+		return true
+	case OpCtz32:
+		v.Op = OpRISCV64CTZW
+		return true
+	case OpCtz32NonZero:
+		v.Op = OpCtz64
+		return true
+	case OpCtz64:
+		v.Op = OpRISCV64CTZ
+		return true
+	case OpCtz64NonZero:
+		v.Op = OpCtz64
+		return true
+	case OpCtz8:
+		return rewriteValueRISCV64_OpCtz8(v)
+	case OpCtz8NonZero:
+		v.Op = OpCtz64
+		return true
 	case OpCvt32Fto32:
 		v.Op = OpRISCV64FCVTWS
 		return true
@@ -993,6 +1015,38 @@ func rewriteValueRISCV64_OpConstNil(v *Value) bool {
 		return true
 	}
 }
+func rewriteValueRISCV64_OpCtz16(v *Value) bool {
+	v_0 := v.Args[0]
+	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Ctz16 x)
+	// result: (CTZW (ORI [1<<16] x))
+	for {
+		x := v_0
+		v.reset(OpRISCV64CTZW)
+		v0 := b.NewValue0(v.Pos, OpRISCV64ORI, typ.UInt32)
+		v0.AuxInt = int64ToAuxInt(1 << 16)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+}
+func rewriteValueRISCV64_OpCtz8(v *Value) bool {
+	v_0 := v.Args[0]
+	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Ctz8 x)
+	// result: (CTZW (ORI [1<<8] x))
+	for {
+		x := v_0
+		v.reset(OpRISCV64CTZW)
+		v0 := b.NewValue0(v.Pos, OpRISCV64ORI, typ.UInt32)
+		v0.AuxInt = int64ToAuxInt(1 << 8)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+}
 func rewriteValueRISCV64_OpDiv16(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
diff --git a/src/cmd/compile/internal/ssagen/intrinsics.go b/src/cmd/compile/internal/ssagen/intrinsics.go
index abb63a99eb..e1e4649555 100644
--- a/src/cmd/compile/internal/ssagen/intrinsics.go
+++ b/src/cmd/compile/internal/ssagen/intrinsics.go
@@ -900,6 +900,30 @@ func initIntrinsics(cfg *intrinsicBuildConfig) {
 			return s.newValue1(ssa.OpCtz8, types.Types[types.TINT], args[0])
 		},
 		sys.AMD64, sys.ARM, sys.ARM64, sys.I386, sys.MIPS, sys.Loong64, sys.PPC64, sys.S390X, sys.Wasm)
+
+	if cfg.goriscv64 >= 22 {
+		addF("math/bits", "TrailingZeros64",
+			func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+				return s.newValue1(ssa.OpCtz64, types.Types[types.TINT], args[0])
+			},
+			sys.RISCV64)
+		addF("math/bits", "TrailingZeros32",
+			func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+				return s.newValue1(ssa.OpCtz32, types.Types[types.TINT], args[0])
+			},
+			sys.RISCV64)
+		addF("math/bits", "TrailingZeros16",
+			func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+				return s.newValue1(ssa.OpCtz16, types.Types[types.TINT], args[0])
+			},
+			sys.RISCV64)
+		addF("math/bits", "TrailingZeros8",
+			func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+				return s.newValue1(ssa.OpCtz8, types.Types[types.TINT], args[0])
+			},
+			sys.RISCV64)
+	}
+
 	alias("math/bits", "ReverseBytes64", "internal/runtime/sys", "Bswap64", all...)
 	alias("math/bits", "ReverseBytes32", "internal/runtime/sys", "Bswap32", all...)
addF("math/bits", "ReverseBytes16", diff --git a/src/cmd/compile/internal/ssagen/intrinsics_test.go b/src/cmd/compile/internal/ssagen/intrinsics_test.go index 0f8a8a83b4..192f91c183 100644 --- a/src/cmd/compile/internal/ssagen/intrinsics_test.go +++ b/src/cmd/compile/internal/ssagen/intrinsics_test.go @@ -1106,6 +1106,9 @@ var wantIntrinsics = map[testIntrinsicKey]struct{}{ {"riscv64", "internal/runtime/sys", "GetCallerPC"}: struct{}{}, {"riscv64", "internal/runtime/sys", "GetCallerSP"}: struct{}{}, {"riscv64", "internal/runtime/sys", "GetClosurePtr"}: struct{}{}, + {"riscv64", "internal/runtime/sys", "TrailingZeros32"}: struct{}{}, + {"riscv64", "internal/runtime/sys", "TrailingZeros64"}: struct{}{}, + {"riscv64", "internal/runtime/sys", "TrailingZeros8"}: struct{}{}, {"riscv64", "math", "Abs"}: struct{}{}, {"riscv64", "math", "Copysign"}: struct{}{}, {"riscv64", "math", "FMA"}: struct{}{}, @@ -1122,6 +1125,10 @@ var wantIntrinsics = map[testIntrinsicKey]struct{}{ {"riscv64", "math/bits", "RotateLeft8"}: struct{}{}, {"riscv64", "math/bits", "Sub"}: struct{}{}, {"riscv64", "math/bits", "Sub64"}: struct{}{}, + {"riscv64", "math/bits", "TrailingZeros16"}: struct{}{}, + {"riscv64", "math/bits", "TrailingZeros32"}: struct{}{}, + {"riscv64", "math/bits", "TrailingZeros64"}: struct{}{}, + {"riscv64", "math/bits", "TrailingZeros8"}: struct{}{}, {"riscv64", "runtime", "KeepAlive"}: struct{}{}, {"riscv64", "runtime", "publicationBarrier"}: struct{}{}, {"riscv64", "runtime", "slicebytetostringtmp"}: struct{}{}, @@ -1308,7 +1315,8 @@ var wantIntrinsics = map[testIntrinsicKey]struct{}{ func TestIntrinsics(t *testing.T) { cfg := &intrinsicBuildConfig{ - goppc64: 10, + goppc64: 10, + goriscv64: 23, } initIntrinsics(cfg) diff --git a/test/codegen/mathbits.go b/test/codegen/mathbits.go index 1cee39283d..786fad3bd9 100644 --- a/test/codegen/mathbits.go +++ b/test/codegen/mathbits.go @@ -356,28 +356,30 @@ func RotateLeftVariable32(n uint32, m int) uint32 { // ------------------------ // func TrailingZeros(n uint) int { + // 386:"BSFL" // amd64/v1,amd64/v2:"BSFQ","MOVL\t\\$64","CMOVQEQ" // amd64/v3:"TZCNTQ" - // 386:"BSFL" // arm:"CLZ" // arm64:"RBIT","CLZ" // loong64:"CTZV" - // s390x:"FLOGR" // ppc64x/power8:"ANDN","POPCNTD" // ppc64x/power9: "CNTTZD" + // riscv64/rva22u64,riscv64/rva23u64: "CTZ\t" + // s390x:"FLOGR" // wasm:"I64Ctz" return bits.TrailingZeros(n) } func TrailingZeros64(n uint64) int { + // 386:"BSFL","JNE" // amd64/v1,amd64/v2:"BSFQ","MOVL\t\\$64","CMOVQEQ" // amd64/v3:"TZCNTQ" - // 386:"BSFL","JNE" // arm64:"RBIT","CLZ" // loong64:"CTZV" - // s390x:"FLOGR" // ppc64x/power8:"ANDN","POPCNTD" // ppc64x/power9: "CNTTZD" + // riscv64/rva22u64,riscv64/rva23u64: "CTZ\t" + // s390x:"FLOGR" // wasm:"I64Ctz" return bits.TrailingZeros64(n) } @@ -389,40 +391,43 @@ func TrailingZeros64Subtract(n uint64) int { } func TrailingZeros32(n uint32) int { + // 386:"BSFL" // amd64/v1,amd64/v2:"BTSQ\\t\\$32","BSFQ" // amd64/v3:"TZCNTL" - // 386:"BSFL" // arm:"CLZ" // arm64:"RBITW","CLZW" // loong64:"CTZW" - // s390x:"FLOGR","MOVWZ" // ppc64x/power8:"ANDN","POPCNTW" // ppc64x/power9: "CNTTZW" + // riscv64/rva22u64,riscv64/rva23u64: "CTZW" + // s390x:"FLOGR","MOVWZ" // wasm:"I64Ctz" return bits.TrailingZeros32(n) } func TrailingZeros16(n uint16) int { - // amd64:"BSFL","ORL\\t\\$65536" // 386:"BSFL\t" + // amd64:"BSFL","ORL\\t\\$65536" // arm:"ORR\t\\$65536","CLZ",-"MOVHU\tR" // arm64:"ORR\t\\$65536","RBITW","CLZW",-"MOVHU\tR",-"RBIT\t",-"CLZ\t" // loong64:"CTZV" - // s390x:"FLOGR","OR\t\\$65536" // 
ppc64x/power8:"POPCNTW","ADD\t\\$-1" // ppc64x/power9:"CNTTZD","ORIS\\t\\$1" + // riscv64/rva22u64,riscv64/rva23u64: "ORI\t\\$65536","CTZW" + // s390x:"FLOGR","OR\t\\$65536" // wasm:"I64Ctz" return bits.TrailingZeros16(n) } func TrailingZeros8(n uint8) int { - // amd64:"BSFL","ORL\\t\\$256" // 386:"BSFL" + // amd64:"BSFL","ORL\\t\\$256" // arm:"ORR\t\\$256","CLZ",-"MOVBU\tR" // arm64:"ORR\t\\$256","RBITW","CLZW",-"MOVBU\tR",-"RBIT\t",-"CLZ\t" // loong64:"CTZV" // ppc64x/power8:"POPCNTB","ADD\t\\$-1" // ppc64x/power9:"CNTTZD","OR\t\\$256" + // riscv64/rva22u64,riscv64/rva23u64: "ORI\t\\$256","CTZW" // s390x:"FLOGR","OR\t\\$256" // wasm:"I64Ctz" return bits.TrailingZeros8(n) @@ -469,6 +474,7 @@ func IterateBits16(n uint16) int { // amd64/v1,amd64/v2:"BSFL",-"BTSL" // amd64/v3:"TZCNTL" // arm64:"RBITW","CLZW",-"ORR" + // riscv64/rva22u64,riscv64/rva23u64: "CTZ\t",-"ORR" i += bits.TrailingZeros16(n) n &= n - 1 } @@ -481,6 +487,7 @@ func IterateBits8(n uint8) int { // amd64/v1,amd64/v2:"BSFL",-"BTSL" // amd64/v3:"TZCNTL" // arm64:"RBITW","CLZW",-"ORR" + // riscv64/rva22u64,riscv64/rva23u64: "CTZ\t",-"ORR" i += bits.TrailingZeros8(n) n &= n - 1 }