]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: intrinsify math/bits.TrailingZeros on riscv64
authorJoel Sing <joel@sing.id.au>
Sun, 23 Feb 2025 11:17:53 +0000 (22:17 +1100)
committerJoel Sing <joel@sing.id.au>
Sun, 16 Mar 2025 02:07:53 +0000 (19:07 -0700)
For riscv64/rva22u64 and above, we can intrinsify math/bits.TrailingZeros
using the CTZ/CTZW machine instructions.

On a StarFive VisionFive 2 with GORISCV64=rva22u64:

                  │   ctz.b.1    │               ctz.b.2               │
                  │    sec/op    │   sec/op     vs base                │
TrailingZeros-4     25.500n ± 0%   8.052n ± 0%  -68.42% (p=0.000 n=10)
TrailingZeros8-4     14.76n ± 0%   10.74n ± 0%  -27.24% (p=0.000 n=10)
TrailingZeros16-4    26.84n ± 0%   10.74n ± 0%  -59.99% (p=0.000 n=10)
TrailingZeros32-4   25.500n ± 0%   8.052n ± 0%  -68.42% (p=0.000 n=10)
TrailingZeros64-4   25.500n ± 0%   8.052n ± 0%  -68.42% (p=0.000 n=10)
geomean              23.09n        9.035n       -60.88%

Change-Id: I71edf2b988acb7a68e797afda4ee66d7a57d587e
Reviewed-on: https://go-review.googlesource.com/c/go/+/652320
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Mark Ryan <markdryan@rivosinc.com>
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Meng Zhuo <mengzhuo1203@gmail.com>
src/cmd/compile/internal/riscv64/ssa.go
src/cmd/compile/internal/ssa/_gen/RISCV64.rules
src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
src/cmd/compile/internal/ssa/opGen.go
src/cmd/compile/internal/ssa/rewriteRISCV64.go
src/cmd/compile/internal/ssagen/intrinsics.go
src/cmd/compile/internal/ssagen/intrinsics_test.go
test/codegen/mathbits.go

index 636ef44d68e6a3edbcdba5dd604ff57c10994bb1..4392081f6e3db051c0cb1d879173f1a918d31a50 100644 (file)
@@ -419,7 +419,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
                ssa.OpRISCV64FMVSX, ssa.OpRISCV64FMVDX,
                ssa.OpRISCV64FCVTSW, ssa.OpRISCV64FCVTSL, ssa.OpRISCV64FCVTWS, ssa.OpRISCV64FCVTLS,
                ssa.OpRISCV64FCVTDW, ssa.OpRISCV64FCVTDL, ssa.OpRISCV64FCVTWD, ssa.OpRISCV64FCVTLD, ssa.OpRISCV64FCVTDS, ssa.OpRISCV64FCVTSD,
-               ssa.OpRISCV64NOT, ssa.OpRISCV64NEG, ssa.OpRISCV64NEGW:
+               ssa.OpRISCV64NOT, ssa.OpRISCV64NEG, ssa.OpRISCV64NEGW, ssa.OpRISCV64CTZ, ssa.OpRISCV64CTZW:
                p := s.Prog(v.Op.Asm())
                p.From.Type = obj.TYPE_REG
                p.From.Reg = v.Args[0].Reg()
index 770a9095f6d1c3f063bbb77fd39467f01168b3f6..016eb53f04909d734216adfbd08cf6674573c679 100644 (file)
 (RotateLeft32 ...) => (ROLW ...)
 (RotateLeft64 ...) => (ROL  ...)
 
+// Count trailing zeros (note that these will only be emitted for rva22u64 and above).
+(Ctz(64|32|16|8)NonZero ...) => (Ctz64 ...)
+(Ctz64 ...) => (CTZ  ...)
+(Ctz32 ...) => (CTZW ...)
+(Ctz16 x) => (CTZW (ORI <typ.UInt32> [1<<16] x))
+(Ctz8  x) => (CTZW (ORI <typ.UInt32> [1<<8]  x))
+
 (Less64  ...) => (SLT  ...)
 (Less32  x y) => (SLT  (SignExt32to64 x) (SignExt32to64 y))
 (Less16  x y) => (SLT  (SignExt16to64 x) (SignExt16to64 y))
index 85c74b46769a1f02a0bdda2fe3cefb34d144cba1..85e9e47e822369d5ecaa34d6bf913bd94e3db7b4 100644 (file)
@@ -229,6 +229,8 @@ func init() {
                {name: "AND", argLength: 2, reg: gp21, asm: "AND", commutative: true},   // arg0 & arg1
                {name: "ANDN", argLength: 2, reg: gp21, asm: "ANDN"},                    // ^arg0 & arg1
                {name: "ANDI", argLength: 1, reg: gp11, asm: "ANDI", aux: "Int64"},      // arg0 & auxint
+               {name: "CTZ", argLength: 1, reg: gp11, asm: "CTZ"},                      // count trailing zeros
+               {name: "CTZW", argLength: 1, reg: gp11, asm: "CTZW"},                    // count trailing zeros of least significant word
                {name: "NOT", argLength: 1, reg: gp11, asm: "NOT"},                      // ^arg0
                {name: "OR", argLength: 2, reg: gp21, asm: "OR", commutative: true},     // arg0 | arg1
                {name: "ORN", argLength: 2, reg: gp21, asm: "ORN"},                      // ^arg0 | arg1
index 8ceff3f449b0bc2c8ac1f9bf52a3cdd3e74236c8..3fd5b310ac25fd48782b7b8788b1b46dbf56cee4 100644 (file)
@@ -2523,6 +2523,8 @@ const (
        OpRISCV64AND
        OpRISCV64ANDN
        OpRISCV64ANDI
+       OpRISCV64CTZ
+       OpRISCV64CTZW
        OpRISCV64NOT
        OpRISCV64OR
        OpRISCV64ORN
@@ -34002,6 +34004,32 @@ var opcodeTable = [...]opInfo{
                        },
                },
        },
+       {
+               name:   "CTZ",
+               argLen: 1,
+               asm:    riscv.ACTZ,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+                       },
+                       outputs: []outputInfo{
+                               {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+                       },
+               },
+       },
+       {
+               name:   "CTZW",
+               argLen: 1,
+               asm:    riscv.ACTZW,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+                       },
+                       outputs: []outputInfo{
+                               {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+                       },
+               },
+       },
        {
                name:   "NOT",
                argLen: 1,
index e19e28ea239b625cf0f0ee4bc4b23f725491c45c..ab93309680ca677dfd65e90394d2b38e98f23b7c 100644 (file)
@@ -136,6 +136,28 @@ func rewriteValueRISCV64(v *Value) bool {
        case OpCopysign:
                v.Op = OpRISCV64FSGNJD
                return true
+       case OpCtz16:
+               return rewriteValueRISCV64_OpCtz16(v)
+       case OpCtz16NonZero:
+               v.Op = OpCtz64
+               return true
+       case OpCtz32:
+               v.Op = OpRISCV64CTZW
+               return true
+       case OpCtz32NonZero:
+               v.Op = OpCtz64
+               return true
+       case OpCtz64:
+               v.Op = OpRISCV64CTZ
+               return true
+       case OpCtz64NonZero:
+               v.Op = OpCtz64
+               return true
+       case OpCtz8:
+               return rewriteValueRISCV64_OpCtz8(v)
+       case OpCtz8NonZero:
+               v.Op = OpCtz64
+               return true
        case OpCvt32Fto32:
                v.Op = OpRISCV64FCVTWS
                return true
@@ -993,6 +1015,38 @@ func rewriteValueRISCV64_OpConstNil(v *Value) bool {
                return true
        }
 }
+func rewriteValueRISCV64_OpCtz16(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       typ := &b.Func.Config.Types
+       // match: (Ctz16 x)
+       // result: (CTZW (ORI <typ.UInt32> [1<<16] x))
+       for {
+               x := v_0
+               v.reset(OpRISCV64CTZW)
+               v0 := b.NewValue0(v.Pos, OpRISCV64ORI, typ.UInt32)
+               v0.AuxInt = int64ToAuxInt(1 << 16)
+               v0.AddArg(x)
+               v.AddArg(v0)
+               return true
+       }
+}
+func rewriteValueRISCV64_OpCtz8(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       typ := &b.Func.Config.Types
+       // match: (Ctz8 x)
+       // result: (CTZW (ORI <typ.UInt32> [1<<8] x))
+       for {
+               x := v_0
+               v.reset(OpRISCV64CTZW)
+               v0 := b.NewValue0(v.Pos, OpRISCV64ORI, typ.UInt32)
+               v0.AuxInt = int64ToAuxInt(1 << 8)
+               v0.AddArg(x)
+               v.AddArg(v0)
+               return true
+       }
+}
 func rewriteValueRISCV64_OpDiv16(v *Value) bool {
        v_1 := v.Args[1]
        v_0 := v.Args[0]
index abb63a99eb25924452853fca525579730153aa9a..e1e464955515f9059b51acf19d62b785331a2bfb 100644 (file)
@@ -900,6 +900,30 @@ func initIntrinsics(cfg *intrinsicBuildConfig) {
                        return s.newValue1(ssa.OpCtz8, types.Types[types.TINT], args[0])
                },
                sys.AMD64, sys.ARM, sys.ARM64, sys.I386, sys.MIPS, sys.Loong64, sys.PPC64, sys.S390X, sys.Wasm)
+
+       if cfg.goriscv64 >= 22 {
+               addF("math/bits", "TrailingZeros64",
+                       func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+                               return s.newValue1(ssa.OpCtz64, types.Types[types.TINT], args[0])
+                       },
+                       sys.RISCV64)
+               addF("math/bits", "TrailingZeros32",
+                       func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+                               return s.newValue1(ssa.OpCtz32, types.Types[types.TINT], args[0])
+                       },
+                       sys.RISCV64)
+               addF("math/bits", "TrailingZeros16",
+                       func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+                               return s.newValue1(ssa.OpCtz16, types.Types[types.TINT], args[0])
+                       },
+                       sys.RISCV64)
+               addF("math/bits", "TrailingZeros8",
+                       func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+                               return s.newValue1(ssa.OpCtz8, types.Types[types.TINT], args[0])
+                       },
+                       sys.RISCV64)
+       }
+
        alias("math/bits", "ReverseBytes64", "internal/runtime/sys", "Bswap64", all...)
        alias("math/bits", "ReverseBytes32", "internal/runtime/sys", "Bswap32", all...)
        addF("math/bits", "ReverseBytes16",
index 0f8a8a83b407ab00ddf0fec80c0dbeb9c3d1f184..192f91c1830cafdad7d561ade6555f47a06cdf4b 100644 (file)
@@ -1106,6 +1106,9 @@ var wantIntrinsics = map[testIntrinsicKey]struct{}{
        {"riscv64", "internal/runtime/sys", "GetCallerPC"}:                 struct{}{},
        {"riscv64", "internal/runtime/sys", "GetCallerSP"}:                 struct{}{},
        {"riscv64", "internal/runtime/sys", "GetClosurePtr"}:               struct{}{},
+       {"riscv64", "internal/runtime/sys", "TrailingZeros32"}:             struct{}{},
+       {"riscv64", "internal/runtime/sys", "TrailingZeros64"}:             struct{}{},
+       {"riscv64", "internal/runtime/sys", "TrailingZeros8"}:              struct{}{},
        {"riscv64", "math", "Abs"}:                                         struct{}{},
        {"riscv64", "math", "Copysign"}:                                    struct{}{},
        {"riscv64", "math", "FMA"}:                                         struct{}{},
@@ -1122,6 +1125,10 @@ var wantIntrinsics = map[testIntrinsicKey]struct{}{
        {"riscv64", "math/bits", "RotateLeft8"}:                            struct{}{},
        {"riscv64", "math/bits", "Sub"}:                                    struct{}{},
        {"riscv64", "math/bits", "Sub64"}:                                  struct{}{},
+       {"riscv64", "math/bits", "TrailingZeros16"}:                        struct{}{},
+       {"riscv64", "math/bits", "TrailingZeros32"}:                        struct{}{},
+       {"riscv64", "math/bits", "TrailingZeros64"}:                        struct{}{},
+       {"riscv64", "math/bits", "TrailingZeros8"}:                         struct{}{},
        {"riscv64", "runtime", "KeepAlive"}:                                struct{}{},
        {"riscv64", "runtime", "publicationBarrier"}:                       struct{}{},
        {"riscv64", "runtime", "slicebytetostringtmp"}:                     struct{}{},
@@ -1308,7 +1315,8 @@ var wantIntrinsics = map[testIntrinsicKey]struct{}{
 
 func TestIntrinsics(t *testing.T) {
        cfg := &intrinsicBuildConfig{
-               goppc64: 10,
+               goppc64:   10,
+               goriscv64: 23,
        }
        initIntrinsics(cfg)
 
index 1cee39283d73a1a12d7e82c96bcbcef99538700a..786fad3bd9e058951b64c4daab92a804b4877554 100644 (file)
@@ -356,28 +356,30 @@ func RotateLeftVariable32(n uint32, m int) uint32 {
 // ------------------------ //
 
 func TrailingZeros(n uint) int {
+       // 386:"BSFL"
        // amd64/v1,amd64/v2:"BSFQ","MOVL\t\\$64","CMOVQEQ"
        // amd64/v3:"TZCNTQ"
-       // 386:"BSFL"
        // arm:"CLZ"
        // arm64:"RBIT","CLZ"
        // loong64:"CTZV"
-       // s390x:"FLOGR"
        // ppc64x/power8:"ANDN","POPCNTD"
        // ppc64x/power9: "CNTTZD"
+       // riscv64/rva22u64,riscv64/rva23u64: "CTZ\t"
+       // s390x:"FLOGR"
        // wasm:"I64Ctz"
        return bits.TrailingZeros(n)
 }
 
 func TrailingZeros64(n uint64) int {
+       // 386:"BSFL","JNE"
        // amd64/v1,amd64/v2:"BSFQ","MOVL\t\\$64","CMOVQEQ"
        // amd64/v3:"TZCNTQ"
-       // 386:"BSFL","JNE"
        // arm64:"RBIT","CLZ"
        // loong64:"CTZV"
-       // s390x:"FLOGR"
        // ppc64x/power8:"ANDN","POPCNTD"
        // ppc64x/power9: "CNTTZD"
+       // riscv64/rva22u64,riscv64/rva23u64: "CTZ\t"
+       // s390x:"FLOGR"
        // wasm:"I64Ctz"
        return bits.TrailingZeros64(n)
 }
@@ -389,40 +391,43 @@ func TrailingZeros64Subtract(n uint64) int {
 }
 
 func TrailingZeros32(n uint32) int {
+       // 386:"BSFL"
        // amd64/v1,amd64/v2:"BTSQ\\t\\$32","BSFQ"
        // amd64/v3:"TZCNTL"
-       // 386:"BSFL"
        // arm:"CLZ"
        // arm64:"RBITW","CLZW"
        // loong64:"CTZW"
-       // s390x:"FLOGR","MOVWZ"
        // ppc64x/power8:"ANDN","POPCNTW"
        // ppc64x/power9: "CNTTZW"
+       // riscv64/rva22u64,riscv64/rva23u64: "CTZW"
+       // s390x:"FLOGR","MOVWZ"
        // wasm:"I64Ctz"
        return bits.TrailingZeros32(n)
 }
 
 func TrailingZeros16(n uint16) int {
-       // amd64:"BSFL","ORL\\t\\$65536"
        // 386:"BSFL\t"
+       // amd64:"BSFL","ORL\\t\\$65536"
        // arm:"ORR\t\\$65536","CLZ",-"MOVHU\tR"
        // arm64:"ORR\t\\$65536","RBITW","CLZW",-"MOVHU\tR",-"RBIT\t",-"CLZ\t"
        // loong64:"CTZV"
-       // s390x:"FLOGR","OR\t\\$65536"
        // ppc64x/power8:"POPCNTW","ADD\t\\$-1"
        // ppc64x/power9:"CNTTZD","ORIS\\t\\$1"
+       // riscv64/rva22u64,riscv64/rva23u64: "ORI\t\\$65536","CTZW"
+       // s390x:"FLOGR","OR\t\\$65536"
        // wasm:"I64Ctz"
        return bits.TrailingZeros16(n)
 }
 
 func TrailingZeros8(n uint8) int {
-       // amd64:"BSFL","ORL\\t\\$256"
        // 386:"BSFL"
+       // amd64:"BSFL","ORL\\t\\$256"
        // arm:"ORR\t\\$256","CLZ",-"MOVBU\tR"
        // arm64:"ORR\t\\$256","RBITW","CLZW",-"MOVBU\tR",-"RBIT\t",-"CLZ\t"
        // loong64:"CTZV"
        // ppc64x/power8:"POPCNTB","ADD\t\\$-1"
        // ppc64x/power9:"CNTTZD","OR\t\\$256"
+       // riscv64/rva22u64,riscv64/rva23u64: "ORI\t\\$256","CTZW"
        // s390x:"FLOGR","OR\t\\$256"
        // wasm:"I64Ctz"
        return bits.TrailingZeros8(n)
@@ -469,6 +474,7 @@ func IterateBits16(n uint16) int {
                // amd64/v1,amd64/v2:"BSFL",-"BTSL"
                // amd64/v3:"TZCNTL"
                // arm64:"RBITW","CLZW",-"ORR"
+               // riscv64/rva22u64,riscv64/rva23u64: "CTZ\t",-"ORR"
                i += bits.TrailingZeros16(n)
                n &= n - 1
        }
@@ -481,6 +487,7 @@ func IterateBits8(n uint8) int {
                // amd64/v1,amd64/v2:"BSFL",-"BTSL"
                // amd64/v3:"TZCNTL"
                // arm64:"RBITW","CLZW",-"ORR"
+               // riscv64/rva22u64,riscv64/rva23u64: "CTZ\t",-"ORR"
                i += bits.TrailingZeros8(n)
                n &= n - 1
        }