]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: intrinsify math/bits.Len on riscv64
authorJoel Sing <joel@sing.id.au>
Sun, 23 Feb 2025 13:27:34 +0000 (00:27 +1100)
committerJoel Sing <joel@sing.id.au>
Sat, 22 Mar 2025 01:21:44 +0000 (18:21 -0700)
For riscv64/rva22u64 and above, we can intrinsify math/bits.Len using the
CLZ/CLZW machine instructions.

On a StarFive VisionFive 2 with GORISCV64=rva22u64:

                 │   clz.b.1   │               clz.b.2               │
                 │   sec/op    │   sec/op     vs base                │
LeadingZeros-4     28.89n ± 0%   12.08n ± 0%  -58.19% (p=0.000 n=10)
LeadingZeros8-4    18.79n ± 0%   14.76n ± 0%  -21.45% (p=0.000 n=10)
LeadingZeros16-4   25.27n ± 0%   14.76n ± 0%  -41.59% (p=0.000 n=10)
LeadingZeros32-4   25.12n ± 0%   12.08n ± 0%  -51.92% (p=0.000 n=10)
LeadingZeros64-4   25.89n ± 0%   12.08n ± 0%  -53.35% (p=0.000 n=10)
geomean            24.55n        13.09n       -46.70%

Change-Id: I0dda684713dbdf5336af393f5ccbdae861c4f694
Reviewed-on: https://go-review.googlesource.com/c/go/+/652321
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Meng Zhuo <mengzhuo1203@gmail.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Mark Ryan <markdryan@rivosinc.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
src/cmd/compile/internal/riscv64/ssa.go
src/cmd/compile/internal/ssa/_gen/RISCV64.rules
src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
src/cmd/compile/internal/ssa/opGen.go
src/cmd/compile/internal/ssa/rewriteRISCV64.go
src/cmd/compile/internal/ssagen/intrinsics.go
src/cmd/compile/internal/ssagen/intrinsics_test.go
test/codegen/mathbits.go

index 4392081f6e3db051c0cb1d879173f1a918d31a50..952a2050a08e08e95743753e879ee27475b64301 100644 (file)
@@ -419,7 +419,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
                ssa.OpRISCV64FMVSX, ssa.OpRISCV64FMVDX,
                ssa.OpRISCV64FCVTSW, ssa.OpRISCV64FCVTSL, ssa.OpRISCV64FCVTWS, ssa.OpRISCV64FCVTLS,
                ssa.OpRISCV64FCVTDW, ssa.OpRISCV64FCVTDL, ssa.OpRISCV64FCVTWD, ssa.OpRISCV64FCVTLD, ssa.OpRISCV64FCVTDS, ssa.OpRISCV64FCVTSD,
-               ssa.OpRISCV64NOT, ssa.OpRISCV64NEG, ssa.OpRISCV64NEGW, ssa.OpRISCV64CTZ, ssa.OpRISCV64CTZW:
+               ssa.OpRISCV64NOT, ssa.OpRISCV64NEG, ssa.OpRISCV64NEGW, ssa.OpRISCV64CLZ, ssa.OpRISCV64CLZW, ssa.OpRISCV64CTZ, ssa.OpRISCV64CTZW:
                p := s.Prog(v.Op.Asm())
                p.From.Type = obj.TYPE_REG
                p.From.Reg = v.Args[0].Reg()
index 016eb53f04909d734216adfbd08cf6674573c679..96b9b11cf9d711bee700f7babe08c1950731ac3e 100644 (file)
 (Ctz16 x) => (CTZW (ORI <typ.UInt32> [1<<16] x))
 (Ctz8  x) => (CTZW (ORI <typ.UInt32> [1<<8]  x))
 
+// Bit length (note that these will only be emitted for rva22u64 and above).
+(BitLen64 <t> x) => (SUB (MOVDconst [64]) (CLZ  <t> x))
+(BitLen32 <t> x) => (SUB (MOVDconst [32]) (CLZW <t> x))
+(BitLen16 x) => (BitLen64 (ZeroExt16to64 x))
+(BitLen8  x) => (BitLen64 (ZeroExt8to64 x))
+
 (Less64  ...) => (SLT  ...)
 (Less32  x y) => (SLT  (SignExt32to64 x) (SignExt32to64 y))
 (Less16  x y) => (SLT  (SignExt16to64 x) (SignExt16to64 y))
index 85e9e47e822369d5ecaa34d6bf913bd94e3db7b4..cc2302ff374fa57084c93590600b3cdd70c18399 100644 (file)
@@ -229,6 +229,8 @@ func init() {
                {name: "AND", argLength: 2, reg: gp21, asm: "AND", commutative: true},   // arg0 & arg1
                {name: "ANDN", argLength: 2, reg: gp21, asm: "ANDN"},                    // ^arg0 & arg1
                {name: "ANDI", argLength: 1, reg: gp11, asm: "ANDI", aux: "Int64"},      // arg0 & auxint
+               {name: "CLZ", argLength: 1, reg: gp11, asm: "CLZ"},                      // count leading zeros
+               {name: "CLZW", argLength: 1, reg: gp11, asm: "CLZW"},                    // count leading zeros of least significant word
                {name: "CTZ", argLength: 1, reg: gp11, asm: "CTZ"},                      // count trailing zeros
                {name: "CTZW", argLength: 1, reg: gp11, asm: "CTZW"},                    // count trailing zeros of least significant word
                {name: "NOT", argLength: 1, reg: gp11, asm: "NOT"},                      // ^arg0
index 3fd5b310ac25fd48782b7b8788b1b46dbf56cee4..0ae88ab24699eb901c9199a4ec7f7b5b627403b7 100644 (file)
@@ -2523,6 +2523,8 @@ const (
        OpRISCV64AND
        OpRISCV64ANDN
        OpRISCV64ANDI
+       OpRISCV64CLZ
+       OpRISCV64CLZW
        OpRISCV64CTZ
        OpRISCV64CTZW
        OpRISCV64NOT
@@ -34004,6 +34006,32 @@ var opcodeTable = [...]opInfo{
                        },
                },
        },
+       {
+               name:   "CLZ",
+               argLen: 1,
+               asm:    riscv.ACLZ,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+                       },
+                       outputs: []outputInfo{
+                               {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+                       },
+               },
+       },
+       {
+               name:   "CLZW",
+               argLen: 1,
+               asm:    riscv.ACLZW,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+                       },
+                       outputs: []outputInfo{
+                               {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+                       },
+               },
+       },
        {
                name:   "CTZ",
                argLen: 1,
index ab93309680ca677dfd65e90394d2b38e98f23b7c..b2318e711b272f6802380d55b5e70f4c53f1c638 100644 (file)
@@ -102,6 +102,14 @@ func rewriteValueRISCV64(v *Value) bool {
                return true
        case OpAvg64u:
                return rewriteValueRISCV64_OpAvg64u(v)
+       case OpBitLen16:
+               return rewriteValueRISCV64_OpBitLen16(v)
+       case OpBitLen32:
+               return rewriteValueRISCV64_OpBitLen32(v)
+       case OpBitLen64:
+               return rewriteValueRISCV64_OpBitLen64(v)
+       case OpBitLen8:
+               return rewriteValueRISCV64_OpBitLen8(v)
        case OpClosureCall:
                v.Op = OpRISCV64CALLclosure
                return true
@@ -928,6 +936,72 @@ func rewriteValueRISCV64_OpAvg64u(v *Value) bool {
                return true
        }
 }
+func rewriteValueRISCV64_OpBitLen16(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       typ := &b.Func.Config.Types
+       // match: (BitLen16 x)
+       // result: (BitLen64 (ZeroExt16to64 x))
+       for {
+               x := v_0
+               v.reset(OpBitLen64)
+               v0 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64)
+               v0.AddArg(x)
+               v.AddArg(v0)
+               return true
+       }
+}
+func rewriteValueRISCV64_OpBitLen32(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       typ := &b.Func.Config.Types
+       // match: (BitLen32 <t> x)
+       // result: (SUB (MOVDconst [32]) (CLZW <t> x))
+       for {
+               t := v.Type
+               x := v_0
+               v.reset(OpRISCV64SUB)
+               v0 := b.NewValue0(v.Pos, OpRISCV64MOVDconst, typ.UInt64)
+               v0.AuxInt = int64ToAuxInt(32)
+               v1 := b.NewValue0(v.Pos, OpRISCV64CLZW, t)
+               v1.AddArg(x)
+               v.AddArg2(v0, v1)
+               return true
+       }
+}
+func rewriteValueRISCV64_OpBitLen64(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       typ := &b.Func.Config.Types
+       // match: (BitLen64 <t> x)
+       // result: (SUB (MOVDconst [64]) (CLZ <t> x))
+       for {
+               t := v.Type
+               x := v_0
+               v.reset(OpRISCV64SUB)
+               v0 := b.NewValue0(v.Pos, OpRISCV64MOVDconst, typ.UInt64)
+               v0.AuxInt = int64ToAuxInt(64)
+               v1 := b.NewValue0(v.Pos, OpRISCV64CLZ, t)
+               v1.AddArg(x)
+               v.AddArg2(v0, v1)
+               return true
+       }
+}
+func rewriteValueRISCV64_OpBitLen8(v *Value) bool {
+       v_0 := v.Args[0]
+       b := v.Block
+       typ := &b.Func.Config.Types
+       // match: (BitLen8 x)
+       // result: (BitLen64 (ZeroExt8to64 x))
+       for {
+               x := v_0
+               v.reset(OpBitLen64)
+               v0 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64)
+               v0.AddArg(x)
+               v.AddArg(v0)
+               return true
+       }
+}
 func rewriteValueRISCV64_OpConst16(v *Value) bool {
        // match: (Const16 [val])
        // result: (MOVDconst [int64(val)])
index f2b13045eb2d7dd73b46ffa7bc0e1ddbb03d0794..eaced0b2775e3a545faccb852374587b5da25ff3 100644 (file)
@@ -962,6 +962,30 @@ func initIntrinsics(cfg *intrinsicBuildConfig) {
                        return s.newValue1(ssa.OpBitLen8, types.Types[types.TINT], args[0])
                },
                sys.AMD64, sys.ARM, sys.ARM64, sys.Loong64, sys.MIPS, sys.PPC64, sys.S390X, sys.Wasm)
+
+       if cfg.goriscv64 >= 22 {
+               addF("math/bits", "Len64",
+                       func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+                               return s.newValue1(ssa.OpBitLen64, types.Types[types.TINT], args[0])
+                       },
+                       sys.RISCV64)
+               addF("math/bits", "Len32",
+                       func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+                               return s.newValue1(ssa.OpBitLen32, types.Types[types.TINT], args[0])
+                       },
+                       sys.RISCV64)
+               addF("math/bits", "Len16",
+                       func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+                               return s.newValue1(ssa.OpBitLen16, types.Types[types.TINT], args[0])
+                       },
+                       sys.RISCV64)
+               addF("math/bits", "Len8",
+                       func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
+                               return s.newValue1(ssa.OpBitLen8, types.Types[types.TINT], args[0])
+                       },
+                       sys.RISCV64)
+       }
+
        alias("math/bits", "Len", "math/bits", "Len64", p8...)
        alias("math/bits", "Len", "math/bits", "Len32", p4...)
 
index a06fdeedb2a6c59908c22273fc58f140cf380d3e..230a7bdf67d28e05630c3c1f1ced92f6742be602 100644 (file)
@@ -1110,6 +1110,8 @@ var wantIntrinsics = map[testIntrinsicKey]struct{}{
        {"riscv64", "internal/runtime/sys", "GetCallerPC"}:                 struct{}{},
        {"riscv64", "internal/runtime/sys", "GetCallerSP"}:                 struct{}{},
        {"riscv64", "internal/runtime/sys", "GetClosurePtr"}:               struct{}{},
+       {"riscv64", "internal/runtime/sys", "Len64"}:                       struct{}{},
+       {"riscv64", "internal/runtime/sys", "Len8"}:                        struct{}{},
        {"riscv64", "internal/runtime/sys", "TrailingZeros32"}:             struct{}{},
        {"riscv64", "internal/runtime/sys", "TrailingZeros64"}:             struct{}{},
        {"riscv64", "internal/runtime/sys", "TrailingZeros8"}:              struct{}{},
@@ -1120,6 +1122,11 @@ var wantIntrinsics = map[testIntrinsicKey]struct{}{
        {"riscv64", "math/big", "mulWW"}:                                   struct{}{},
        {"riscv64", "math/bits", "Add"}:                                    struct{}{},
        {"riscv64", "math/bits", "Add64"}:                                  struct{}{},
+       {"riscv64", "math/bits", "Len"}:                                    struct{}{},
+       {"riscv64", "math/bits", "Len16"}:                                  struct{}{},
+       {"riscv64", "math/bits", "Len32"}:                                  struct{}{},
+       {"riscv64", "math/bits", "Len64"}:                                  struct{}{},
+       {"riscv64", "math/bits", "Len8"}:                                   struct{}{},
        {"riscv64", "math/bits", "Mul"}:                                    struct{}{},
        {"riscv64", "math/bits", "Mul64"}:                                  struct{}{},
        {"riscv64", "math/bits", "RotateLeft"}:                             struct{}{},
index 786fad3bd9e058951b64c4daab92a804b4877554..a9cf466780773973d0b701ff1f2819125d13f5e9 100644 (file)
@@ -15,60 +15,70 @@ import "math/bits"
 func LeadingZeros(n uint) int {
        // amd64/v1,amd64/v2:"BSRQ"
        // amd64/v3:"LZCNTQ", -"BSRQ"
-       // s390x:"FLOGR"
-       // arm:"CLZ" arm64:"CLZ"
+       // arm64:"CLZ"
+       // arm:"CLZ"
        // loong64:"CLZV",-"SUB"
        // mips:"CLZ"
-       // wasm:"I64Clz"
        // ppc64x:"CNTLZD"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZ\t",-"SUB"
+       // s390x:"FLOGR"
+       // wasm:"I64Clz"
        return bits.LeadingZeros(n)
 }
 
 func LeadingZeros64(n uint64) int {
        // amd64/v1,amd64/v2:"BSRQ"
        // amd64/v3:"LZCNTQ", -"BSRQ"
-       // s390x:"FLOGR"
-       // arm:"CLZ" arm64:"CLZ"
+       // arm:"CLZ"
+       // arm64:"CLZ"
        // loong64:"CLZV",-"SUB"
        // mips:"CLZ"
-       // wasm:"I64Clz"
        // ppc64x:"CNTLZD"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZ\t",-"ADDI"
+       // s390x:"FLOGR"
+       // wasm:"I64Clz"
        return bits.LeadingZeros64(n)
 }
 
 func LeadingZeros32(n uint32) int {
        // amd64/v1,amd64/v2:"BSRQ","LEAQ",-"CMOVQEQ"
        // amd64/v3: "LZCNTL",- "BSRL"
-       // s390x:"FLOGR"
-       // arm:"CLZ" arm64:"CLZW"
+       // arm:"CLZ"
+       // arm64:"CLZW"
        // loong64:"CLZW",-"SUB"
        // mips:"CLZ"
-       // wasm:"I64Clz"
        // ppc64x:"CNTLZW"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZW",-"ADDI"
+       // s390x:"FLOGR"
+       // wasm:"I64Clz"
        return bits.LeadingZeros32(n)
 }
 
 func LeadingZeros16(n uint16) int {
        // amd64/v1,amd64/v2:"BSRL","LEAL",-"CMOVQEQ"
        // amd64/v3: "LZCNTL",- "BSRL"
-       // s390x:"FLOGR"
-       // arm:"CLZ" arm64:"CLZ"
+       // arm64:"CLZ"
+       // arm:"CLZ"
        // loong64:"CLZV"
        // mips:"CLZ"
-       // wasm:"I64Clz"
        // ppc64x:"CNTLZD"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZ\t","ADDI\t\\$-48",-"NEG"
+       // s390x:"FLOGR"
+       // wasm:"I64Clz"
        return bits.LeadingZeros16(n)
 }
 
 func LeadingZeros8(n uint8) int {
        // amd64/v1,amd64/v2:"BSRL","LEAL",-"CMOVQEQ"
        // amd64/v3: "LZCNTL",- "BSRL"
-       // s390x:"FLOGR"
-       // arm:"CLZ" arm64:"CLZ"
+       // arm64:"CLZ"
+       // arm:"CLZ"
        // loong64:"CLZV"
        // mips:"CLZ"
-       // wasm:"I64Clz"
        // ppc64x:"CNTLZD"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZ\t","ADDI\t\\$-56",-"NEG"
+       // s390x:"FLOGR"
+       // wasm:"I64Clz"
        return bits.LeadingZeros8(n)
 }
 
@@ -79,30 +89,35 @@ func LeadingZeros8(n uint8) int {
 func Len(n uint) int {
        // amd64/v1,amd64/v2:"BSRQ"
        // amd64/v3: "LZCNTQ"
-       // s390x:"FLOGR"
-       // arm:"CLZ" arm64:"CLZ"
+       // arm64:"CLZ"
+       // arm:"CLZ"
        // loong64:"CLZV"
        // mips:"CLZ"
-       // wasm:"I64Clz"
        // ppc64x:"SUBC","CNTLZD"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZ\t","ADDI\t\\$-64"
+       // s390x:"FLOGR"
+       // wasm:"I64Clz"
        return bits.Len(n)
 }
 
 func Len64(n uint64) int {
        // amd64/v1,amd64/v2:"BSRQ"
        // amd64/v3: "LZCNTQ"
-       // s390x:"FLOGR"
-       // arm:"CLZ" arm64:"CLZ"
+       // arm64:"CLZ"
+       // arm:"CLZ"
        // loong64:"CLZV"
        // mips:"CLZ"
-       // wasm:"I64Clz"
        // ppc64x:"SUBC","CNTLZD"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZ\t","ADDI\t\\$-64"
+       // s390x:"FLOGR"
+       // wasm:"I64Clz"
        return bits.Len64(n)
 }
 
 func SubFromLen64(n uint64) int {
        // loong64:"CLZV",-"ADD"
        // ppc64x:"CNTLZD",-"SUBC"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZ\t",-"ADDI",-"NEG"
        return 64 - bits.Len64(n)
 }
 
@@ -114,36 +129,42 @@ func CompareWithLen64(n uint64) bool {
 func Len32(n uint32) int {
        // amd64/v1,amd64/v2:"BSRQ","LEAQ",-"CMOVQEQ"
        // amd64/v3: "LZCNTL"
-       // s390x:"FLOGR"
-       // arm:"CLZ" arm64:"CLZ"
+       // arm64:"CLZ"
+       // arm:"CLZ"
        // loong64:"CLZW"
        // mips:"CLZ"
-       // wasm:"I64Clz"
        // ppc64x: "CNTLZW"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZW","ADDI\t\\$-32"
+       // s390x:"FLOGR"
+       // wasm:"I64Clz"
        return bits.Len32(n)
 }
 
 func Len16(n uint16) int {
        // amd64/v1,amd64/v2:"BSRL","LEAL",-"CMOVQEQ"
        // amd64/v3: "LZCNTL"
-       // s390x:"FLOGR"
-       // arm:"CLZ" arm64:"CLZ"
+       // arm64:"CLZ"
+       // arm:"CLZ"
        // loong64:"CLZV"
        // mips:"CLZ"
-       // wasm:"I64Clz"
        // ppc64x:"SUBC","CNTLZD"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZ\t","ADDI\t\\$-64"
+       // s390x:"FLOGR"
+       // wasm:"I64Clz"
        return bits.Len16(n)
 }
 
 func Len8(n uint8) int {
        // amd64/v1,amd64/v2:"BSRL","LEAL",-"CMOVQEQ"
        // amd64/v3: "LZCNTL"
-       // s390x:"FLOGR"
-       // arm:"CLZ" arm64:"CLZ"
+       // arm64:"CLZ"
+       // arm:"CLZ"
        // loong64:"CLZV"
        // mips:"CLZ"
-       // wasm:"I64Clz"
        // ppc64x:"SUBC","CNTLZD"
+       // riscv64/rva22u64,riscv64/rva23u64:"CLZ\t","ADDI\t\\$-64"
+       // s390x:"FLOGR"
+       // wasm:"I64Clz"
        return bits.Len8(n)
 }
 
@@ -451,6 +472,7 @@ func IterateBits64(n uint64) int {
        for n != 0 {
                // amd64/v1,amd64/v2:"BSFQ",-"CMOVEQ"
                // amd64/v3:"TZCNTQ"
+               // riscv64/rva22u64,riscv64/rva23u64: "CTZ\t"
                i += bits.TrailingZeros64(n)
                n &= n - 1
        }
@@ -462,6 +484,7 @@ func IterateBits32(n uint32) int {
        for n != 0 {
                // amd64/v1,amd64/v2:"BSFL",-"BTSQ"
                // amd64/v3:"TZCNTL"
+               // riscv64/rva22u64,riscv64/rva23u64: "CTZ\t"
                i += bits.TrailingZeros32(n)
                n &= n - 1
        }