]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: intrinsify Add64 on riscv64
authorWayne Zuo <wdvxdr@golangcn.org>
Fri, 29 Jul 2022 06:24:26 +0000 (14:24 +0800)
committerJoel Sing <joel@sing.id.au>
Sat, 27 Aug 2022 05:43:32 +0000 (05:43 +0000)
According to RISCV instruction set manual v2.2 Sec 2.4, we can
implement overflowing check for unsigned addition cheaply using
SLTU instructions.

After this CL, the performance difference in crypto/elliptic
benchmarks on linux/riscv64 are:

name                 old time/op    new time/op    delta
ScalarBaseMult/P256    1.93ms ± 1%    1.64ms ± 1%  -14.96%  (p=0.008 n=5+5)
ScalarBaseMult/P224    1.80ms ± 2%    1.53ms ± 1%  -14.89%  (p=0.008 n=5+5)
ScalarBaseMult/P384    6.15ms ± 2%    5.12ms ± 2%  -16.73%  (p=0.008 n=5+5)
ScalarBaseMult/P521    25.9ms ± 1%    22.3ms ± 2%  -13.78%  (p=0.008 n=5+5)
ScalarMult/P256        5.59ms ± 1%    4.49ms ± 2%  -19.79%  (p=0.008 n=5+5)
ScalarMult/P224        5.42ms ± 1%    4.33ms ± 1%  -20.01%  (p=0.008 n=5+5)
ScalarMult/P384        19.9ms ± 2%    16.3ms ± 1%  -18.15%  (p=0.008 n=5+5)
ScalarMult/P521        97.3ms ± 1%   100.7ms ± 0%   +3.48%  (p=0.008 n=5+5)

Change-Id: Ic4c82ced4b072a4a6575343fa9f29dd09b0cabc4
Reviewed-on: https://go-review.googlesource.com/c/go/+/420094
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Run-TryBot: Wayne Zuo <wdvxdr@golangcn.org>
Reviewed-by: Joel Sing <joel@sing.id.au>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/cmd/compile/internal/ssa/gen/RISCV64.rules
src/cmd/compile/internal/ssa/rewriteRISCV64.go
src/cmd/compile/internal/ssagen/ssa.go
test/codegen/mathbits.go

index 5bc47ee1cc185c83220940adffca1abcddbdb8a7..9d2d785d0eab56c6a8d61224911c5ce9fe5ba8e5 100644 (file)
 (Hmul32 x y)  => (SRAI [32] (MUL  (SignExt32to64 x) (SignExt32to64 y)))
 (Hmul32u x y) => (SRLI [32] (MUL  (ZeroExt32to64 x) (ZeroExt32to64 y)))
 
+(Select0 (Add64carry x y c)) => (ADD (ADD <typ.UInt64> x y) c)
+(Select1 (Add64carry x y c)) =>
+       (OR (SLTU <typ.UInt64> s:(ADD <typ.UInt64> x y) x) (SLTU <typ.UInt64> (ADD <typ.UInt64> s c) s))
+
 // (x + y) / 2 => (x / 2) + (y / 2) + (x & y & 1)
 (Avg64u <t> x y) => (ADD (ADD <t> (SRLI <t> [1] x) (SRLI <t> [1] y)) (ANDI <t> [1] (AND <t> x y)))
 
 (SLTI  [x] (MOVDconst [y])) => (MOVDconst [b2i(int64(y) < int64(x))])
 (SLTIU [x] (MOVDconst [y])) => (MOVDconst [b2i(uint64(y) < uint64(x))])
 
+(SLT x x)  => (MOVDconst [0])
+(SLTU x x) => (MOVDconst [0])
+
 // deadcode for LoweredMuluhilo
 (Select0 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MULHU x y)
 (Select1 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MUL x y)
index 9253d2d7296c6388c2f48471f504493eb0ff02a4..e4e4003f34e2416eb37491aedd8d1410dfd7148b 100644 (file)
@@ -509,10 +509,14 @@ func rewriteValueRISCV64(v *Value) bool {
                return rewriteValueRISCV64_OpRISCV64SLL(v)
        case OpRISCV64SLLI:
                return rewriteValueRISCV64_OpRISCV64SLLI(v)
+       case OpRISCV64SLT:
+               return rewriteValueRISCV64_OpRISCV64SLT(v)
        case OpRISCV64SLTI:
                return rewriteValueRISCV64_OpRISCV64SLTI(v)
        case OpRISCV64SLTIU:
                return rewriteValueRISCV64_OpRISCV64SLTIU(v)
+       case OpRISCV64SLTU:
+               return rewriteValueRISCV64_OpRISCV64SLTU(v)
        case OpRISCV64SRA:
                return rewriteValueRISCV64_OpRISCV64SRA(v)
        case OpRISCV64SRAI:
@@ -4864,6 +4868,22 @@ func rewriteValueRISCV64_OpRISCV64SLLI(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64SLT(v *Value) bool {
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (SLT x x)
+       // result: (MOVDconst [0])
+       for {
+               x := v_0
+               if x != v_1 {
+                       break
+               }
+               v.reset(OpRISCV64MOVDconst)
+               v.AuxInt = int64ToAuxInt(0)
+               return true
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64SLTI(v *Value) bool {
        v_0 := v.Args[0]
        // match: (SLTI [x] (MOVDconst [y]))
@@ -4896,6 +4916,22 @@ func rewriteValueRISCV64_OpRISCV64SLTIU(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64SLTU(v *Value) bool {
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (SLTU x x)
+       // result: (MOVDconst [0])
+       for {
+               x := v_0
+               if x != v_1 {
+                       break
+               }
+               v.reset(OpRISCV64MOVDconst)
+               v.AuxInt = int64ToAuxInt(0)
+               return true
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool {
        v_1 := v.Args[1]
        v_0 := v.Args[0]
@@ -6036,6 +6072,23 @@ func rewriteValueRISCV64_OpRsh8x8(v *Value) bool {
 }
 func rewriteValueRISCV64_OpSelect0(v *Value) bool {
        v_0 := v.Args[0]
+       b := v.Block
+       typ := &b.Func.Config.Types
+       // match: (Select0 (Add64carry x y c))
+       // result: (ADD (ADD <typ.UInt64> x y) c)
+       for {
+               if v_0.Op != OpAdd64carry {
+                       break
+               }
+               c := v_0.Args[2]
+               x := v_0.Args[0]
+               y := v_0.Args[1]
+               v.reset(OpRISCV64ADD)
+               v0 := b.NewValue0(v.Pos, OpRISCV64ADD, typ.UInt64)
+               v0.AddArg2(x, y)
+               v.AddArg2(v0, c)
+               return true
+       }
        // match: (Select0 m:(LoweredMuluhilo x y))
        // cond: m.Uses == 1
        // result: (MULHU x y)
@@ -6057,6 +6110,29 @@ func rewriteValueRISCV64_OpSelect0(v *Value) bool {
 }
 func rewriteValueRISCV64_OpSelect1(v *Value) bool {
        v_0 := v.Args[0]
+       b := v.Block
+       typ := &b.Func.Config.Types
+       // match: (Select1 (Add64carry x y c))
+       // result: (OR (SLTU <typ.UInt64> s:(ADD <typ.UInt64> x y) x) (SLTU <typ.UInt64> (ADD <typ.UInt64> s c) s))
+       for {
+               if v_0.Op != OpAdd64carry {
+                       break
+               }
+               c := v_0.Args[2]
+               x := v_0.Args[0]
+               y := v_0.Args[1]
+               v.reset(OpRISCV64OR)
+               v0 := b.NewValue0(v.Pos, OpRISCV64SLTU, typ.UInt64)
+               s := b.NewValue0(v.Pos, OpRISCV64ADD, typ.UInt64)
+               s.AddArg2(x, y)
+               v0.AddArg2(s, x)
+               v2 := b.NewValue0(v.Pos, OpRISCV64SLTU, typ.UInt64)
+               v3 := b.NewValue0(v.Pos, OpRISCV64ADD, typ.UInt64)
+               v3.AddArg2(s, c)
+               v2.AddArg2(v3, s)
+               v.AddArg2(v0, v2)
+               return true
+       }
        // match: (Select1 m:(LoweredMuluhilo x y))
        // cond: m.Uses == 1
        // result: (MUL x y)
index dda813518a5f94d3e0361f3bee58723503401163..107944170fcaaa6137e046483e10c6599d92788f 100644 (file)
@@ -4726,8 +4726,8 @@ func InitTables() {
                func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
                        return s.newValue3(ssa.OpAdd64carry, types.NewTuple(types.Types[types.TUINT64], types.Types[types.TUINT64]), args[0], args[1], args[2])
                },
-               sys.AMD64, sys.ARM64, sys.PPC64, sys.S390X)
-       alias("math/bits", "Add", "math/bits", "Add64", sys.ArchAMD64, sys.ArchARM64, sys.ArchPPC64, sys.ArchPPC64LE, sys.ArchS390X)
+               sys.AMD64, sys.ARM64, sys.PPC64, sys.S390X, sys.RISCV64)
+       alias("math/bits", "Add", "math/bits", "Add64", sys.ArchAMD64, sys.ArchARM64, sys.ArchPPC64, sys.ArchPPC64LE, sys.ArchS390X, sys.ArchRISCV64)
        addF("math/bits", "Sub64",
                func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
                        return s.newValue3(ssa.OpSub64borrow, types.NewTuple(types.Types[types.TUINT64], types.Types[types.TUINT64]), args[0], args[1], args[2])
index a507d32843d6333105c76ab95453bf756e5a1511..f36916ad03a37e6ab895c204ea0b3c65e999fe43 100644 (file)
@@ -442,6 +442,7 @@ func Add(x, y, ci uint) (r, co uint) {
        // ppc64: "ADDC", "ADDE", "ADDZE"
        // ppc64le: "ADDC", "ADDE", "ADDZE"
        // s390x:"ADDE","ADDC\t[$]-1,"
+       // riscv64: "ADD","SLTU"
        return bits.Add(x, y, ci)
 }
 
@@ -451,6 +452,7 @@ func AddC(x, ci uint) (r, co uint) {
        // ppc64: "ADDC", "ADDE", "ADDZE"
        // ppc64le: "ADDC", "ADDE", "ADDZE"
        // s390x:"ADDE","ADDC\t[$]-1,"
+       // riscv64: "ADD","SLTU"
        return bits.Add(x, 7, ci)
 }
 
@@ -460,6 +462,7 @@ func AddZ(x, y uint) (r, co uint) {
        // ppc64: "ADDC", -"ADDE", "ADDZE"
        // ppc64le: "ADDC", -"ADDE", "ADDZE"
        // s390x:"ADDC",-"ADDC\t[$]-1,"
+       // riscv64: "ADD","SLTU"
        return bits.Add(x, y, 0)
 }
 
@@ -469,6 +472,7 @@ func AddR(x, y, ci uint) uint {
        // ppc64: "ADDC", "ADDE", -"ADDZE"
        // ppc64le: "ADDC", "ADDE", -"ADDZE"
        // s390x:"ADDE","ADDC\t[$]-1,"
+       // riscv64: "ADD",-"SLTU"
        r, _ := bits.Add(x, y, ci)
        return r
 }
@@ -489,6 +493,7 @@ func Add64(x, y, ci uint64) (r, co uint64) {
        // ppc64: "ADDC", "ADDE", "ADDZE"
        // ppc64le: "ADDC", "ADDE", "ADDZE"
        // s390x:"ADDE","ADDC\t[$]-1,"
+       // riscv64: "ADD","SLTU"
        return bits.Add64(x, y, ci)
 }
 
@@ -498,6 +503,7 @@ func Add64C(x, ci uint64) (r, co uint64) {
        // ppc64: "ADDC", "ADDE", "ADDZE"
        // ppc64le: "ADDC", "ADDE", "ADDZE"
        // s390x:"ADDE","ADDC\t[$]-1,"
+       // riscv64: "ADD","SLTU"
        return bits.Add64(x, 7, ci)
 }
 
@@ -507,6 +513,7 @@ func Add64Z(x, y uint64) (r, co uint64) {
        // ppc64: "ADDC", -"ADDE", "ADDZE"
        // ppc64le: "ADDC", -"ADDE", "ADDZE"
        // s390x:"ADDC",-"ADDC\t[$]-1,"
+       // riscv64: "ADD","SLTU"
        return bits.Add64(x, y, 0)
 }
 
@@ -516,6 +523,7 @@ func Add64R(x, y, ci uint64) uint64 {
        // ppc64: "ADDC", "ADDE", -"ADDZE"
        // ppc64le: "ADDC", "ADDE", -"ADDZE"
        // s390x:"ADDE","ADDC\t[$]-1,"
+       // riscv64: "ADD",-"SLTU"
        r, _ := bits.Add64(x, y, ci)
        return r
 }