From 997636760e2d981bb2f5ba486e0702e60a07ba16 Mon Sep 17 00:00:00 2001 From: Joel Sing Date: Thu, 8 Feb 2024 13:54:10 +1100 Subject: [PATCH] cmd/compile,cmd/internal/obj: provide rotation pseudo-instructions for riscv64 Provide and use rotation pseudo-instructions for riscv64. The RISC-V bitmanip extension adds support for hardware rotation instructions in the form of ROL, ROLW, ROR, RORI, RORIW and RORW. These are easily implemented in the assembler as pseudo-instructions for CPUs that do not support the bitmanip extension. This approach provides a number of advantages, including reducing the rewrite rules needed in the compiler, simplifying codegen tests and most importantly, allowing these instructions to be used in assembly (for example, riscv64 optimised versions of SHA-256 and SHA-512). When bitmanip support is added, these instruction sequences can simply be replaced with a single instruction if permitted by the GORISCV64 profile. Change-Id: Ia23402e1a82f211ac760690deb063386056ae1fa Reviewed-on: https://go-review.googlesource.com/c/go/+/565015 TryBot-Result: Gopher Robot Reviewed-by: Michael Knyszek Reviewed-by: M Zhuo Reviewed-by: Carlos Amedee LUCI-TryBot-Result: Go LUCI Run-TryBot: Joel Sing --- src/cmd/asm/internal/asm/testdata/riscv64.s | 18 ++ src/cmd/compile/internal/riscv64/ssa.go | 4 +- .../compile/internal/ssa/_gen/RISCV64.rules | 15 +- .../compile/internal/ssa/_gen/RISCV64Ops.go | 14 +- src/cmd/compile/internal/ssa/opGen.go | 124 ++++++++++-- .../compile/internal/ssa/rewriteRISCV64.go | 185 +++++++++++++----- src/cmd/internal/obj/riscv/anames.go | 6 + src/cmd/internal/obj/riscv/cpu.go | 6 + src/cmd/internal/obj/riscv/obj.go | 51 ++++- src/crypto/sha512/sha512block_riscv64.s | 25 +-- test/codegen/rotate.go | 24 +-- 11 files changed, 376 insertions(+), 96 deletions(-) diff --git a/src/cmd/asm/internal/asm/testdata/riscv64.s b/src/cmd/asm/internal/asm/testdata/riscv64.s index a5ab254eaa..f944072c6e 100644 --- a/src/cmd/asm/internal/asm/testdata/riscv64.s +++ b/src/cmd/asm/internal/asm/testdata/riscv64.s @@ -417,6 +417,24 @@ start: NEGW X5 // bb025040 NEGW X5, X6 // 3b035040 + // Bitwise rotation pseudo-instructions + ROL X5, X6, X7 // b30f5040b35ff301b3135300b3e37f00 + ROL X5, X6 // b30f5040b35ff3013313530033e36f00 + ROLW X5, X6, X7 // b30f5040bb5ff301bb135300b3e37f00 + ROLW X5, X6 // b30f5040bb5ff3013b13530033e36f00 + ROR X5, X6, X7 // b30f5040b31ff301b3535300b3e37f00 + ROR X5, X6 // b30f5040b31ff3013353530033e36f00 + RORW X5, X6, X7 // b30f5040bb1ff301bb535300b3e37f00 + RORW X5, X6 // b30f5040bb1ff3013b53530033e36f00 + RORI $5, X6, X7 // 935f53009313b303b3e37f00 + RORI $5, X6 // 935f53001313b30333e36f00 + RORIW $5, X6, X7 // 9b5f53009b13b301b3e37f00 + RORIW $5, X6 // 9b5f53001b13b30133e36f00 + ROR $5, X6, X7 // 935f53009313b303b3e37f00 + ROR $5, X6 // 935f53001313b30333e36f00 + RORW $5, X6, X7 // 9b5f53009b13b301b3e37f00 + RORW $5, X6 // 9b5f53001b13b30133e36f00 + // This jumps to the second instruction in the function (the // first instruction is an invisible stack pointer adjustment). JMP start // JMP 2 diff --git a/src/cmd/compile/internal/riscv64/ssa.go b/src/cmd/compile/internal/riscv64/ssa.go index 17f0d98532..c9e75b2180 100644 --- a/src/cmd/compile/internal/riscv64/ssa.go +++ b/src/cmd/compile/internal/riscv64/ssa.go @@ -283,6 +283,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) { ssa.OpRISCV64MULHU, ssa.OpRISCV64DIV, ssa.OpRISCV64DIVU, ssa.OpRISCV64DIVW, ssa.OpRISCV64DIVUW, ssa.OpRISCV64REM, ssa.OpRISCV64REMU, ssa.OpRISCV64REMW, ssa.OpRISCV64REMUW, + ssa.OpRISCV64ROL, ssa.OpRISCV64ROLW, ssa.OpRISCV64ROR, ssa.OpRISCV64RORW, ssa.OpRISCV64FADDS, ssa.OpRISCV64FSUBS, ssa.OpRISCV64FMULS, ssa.OpRISCV64FDIVS, ssa.OpRISCV64FEQS, ssa.OpRISCV64FNES, ssa.OpRISCV64FLTS, ssa.OpRISCV64FLES, ssa.OpRISCV64FADDD, ssa.OpRISCV64FSUBD, ssa.OpRISCV64FMULD, ssa.OpRISCV64FDIVD, @@ -423,7 +424,8 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) { p.To.Reg = v.Reg() case ssa.OpRISCV64ADDI, ssa.OpRISCV64ADDIW, ssa.OpRISCV64XORI, ssa.OpRISCV64ORI, ssa.OpRISCV64ANDI, ssa.OpRISCV64SLLI, ssa.OpRISCV64SLLIW, ssa.OpRISCV64SRAI, ssa.OpRISCV64SRAIW, - ssa.OpRISCV64SRLI, ssa.OpRISCV64SRLIW, ssa.OpRISCV64SLTI, ssa.OpRISCV64SLTIU: + ssa.OpRISCV64SRLI, ssa.OpRISCV64SRLIW, ssa.OpRISCV64SLTI, ssa.OpRISCV64SLTIU, + ssa.OpRISCV64RORI, ssa.OpRISCV64RORIW: p := s.Prog(v.Op.Asm()) p.From.Type = obj.TYPE_CONST p.From.Offset = v.AuxInt diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules index 135d70bc47..c2df433315 100644 --- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules +++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules @@ -216,8 +216,8 @@ // Rotates. (RotateLeft8 x y) => (OR (SLL x (ANDI [7] y)) (SRL (ZeroExt8to64 x) (ANDI [7] (NEG y)))) (RotateLeft16 x y) => (OR (SLL x (ANDI [15] y)) (SRL (ZeroExt16to64 x) (ANDI [15] (NEG y)))) -(RotateLeft32 x y) => (OR (SLLW x y) (SRLW x (NEG y))) -(RotateLeft64 x y) => (OR (SLL x y) (SRL x (NEG y))) +(RotateLeft32 ...) => (ROLW ...) +(RotateLeft64 ...) => (ROL ...) (Less64 ...) => (SLT ...) (Less32 x y) => (SLT (SignExt32to64 x) (SignExt32to64 y)) @@ -665,6 +665,9 @@ (MOVWreg x:(DIVUW _ _)) => (MOVDreg x) (MOVWreg x:(REMW _ _)) => (MOVDreg x) (MOVWreg x:(REMUW _ _)) => (MOVDreg x) +(MOVWreg x:(ROLW _ _)) => (MOVDreg x) +(MOVWreg x:(RORW _ _)) => (MOVDreg x) +(MOVWreg x:(RORIW _)) => (MOVDreg x) // Fold double extensions. (MOVBreg x:(MOVBreg _)) => (MOVDreg x) @@ -731,6 +734,10 @@ (AND (MOVDconst [val]) x) && is32Bit(val) => (ANDI [val] x) (OR (MOVDconst [val]) x) && is32Bit(val) => (ORI [val] x) (XOR (MOVDconst [val]) x) && is32Bit(val) => (XORI [val] x) +(ROL x (MOVDconst [val])) => (RORI [int64(int8(-val)&63)] x) +(ROLW x (MOVDconst [val])) => (RORIW [int64(int8(-val)&31)] x) +(ROR x (MOVDconst [val])) => (RORI [int64(val&63)] x) +(RORW x (MOVDconst [val])) => (RORIW [int64(val&31)] x) (SLL x (MOVDconst [val])) => (SLLI [int64(val&63)] x) (SRL x (MOVDconst [val])) => (SRLI [int64(val&63)] x) (SLLW x (MOVDconst [val])) => (SLLIW [int64(val&31)] x) @@ -740,6 +747,10 @@ (SLT x (MOVDconst [val])) && val >= -2048 && val <= 2047 => (SLTI [val] x) (SLTU x (MOVDconst [val])) && val >= -2048 && val <= 2047 => (SLTIU [val] x) +// Replace negated left rotation with right rotation. +(ROL x (NEG y)) => (ROR x y) +(ROLW x (NEG y)) => (RORW x y) + // Convert const subtraction into ADDI with negative immediate, where possible. (SUB x (MOVDconst [val])) && is32Bit(-val) => (ADDI [-val] x) (SUB (MOVDconst [val]) y) && is32Bit(-val) => (NEG (ADDI [-val] y)) diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go index e9f1df0d58..13fa91864b 100644 --- a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go +++ b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go @@ -221,13 +221,19 @@ func init() { {name: "SRLIW", argLength: 1, reg: gp11, asm: "SRLIW", aux: "Int64"}, // arg0 >> auxint, shift amount 0-31, logical right shift of 32 bit value, sign extended to 64 bits // Bitwise ops - {name: "XOR", argLength: 2, reg: gp21, asm: "XOR", commutative: true}, // arg0 ^ arg1 - {name: "XORI", argLength: 1, reg: gp11, asm: "XORI", aux: "Int64"}, // arg0 ^ auxint - {name: "OR", argLength: 2, reg: gp21, asm: "OR", commutative: true}, // arg0 | arg1 - {name: "ORI", argLength: 1, reg: gp11, asm: "ORI", aux: "Int64"}, // arg0 | auxint {name: "AND", argLength: 2, reg: gp21, asm: "AND", commutative: true}, // arg0 & arg1 {name: "ANDI", argLength: 1, reg: gp11, asm: "ANDI", aux: "Int64"}, // arg0 & auxint {name: "NOT", argLength: 1, reg: gp11, asm: "NOT"}, // ^arg0 + {name: "OR", argLength: 2, reg: gp21, asm: "OR", commutative: true}, // arg0 | arg1 + {name: "ORI", argLength: 1, reg: gp11, asm: "ORI", aux: "Int64"}, // arg0 | auxint + {name: "ROL", argLength: 2, reg: gp21, asm: "ROL"}, // rotate left arg0 by (arg1 & 63) + {name: "ROLW", argLength: 2, reg: gp21, asm: "ROLW"}, // rotate left least significant word of arg0 by (arg1 & 31), sign extended + {name: "ROR", argLength: 2, reg: gp21, asm: "ROR"}, // rotate right arg0 by (arg1 & 63) + {name: "RORI", argLength: 1, reg: gp11, asm: "RORI", aux: "Int64"}, // rotate right arg0 by auxint, shift amount 0-63 + {name: "RORIW", argLength: 1, reg: gp11, asm: "RORIW", aux: "Int64"}, // rotate right least significant word of arg0 by auxint, shift amount 0-31, sign extended + {name: "RORW", argLength: 2, reg: gp21, asm: "RORW"}, // rotate right least significant word of arg0 by (arg1 & 31), sign extended + {name: "XOR", argLength: 2, reg: gp21, asm: "XOR", commutative: true}, // arg0 ^ arg1 + {name: "XORI", argLength: 1, reg: gp11, asm: "XORI", aux: "Int64"}, // arg0 ^ auxint // Generate boolean values {name: "SEQZ", argLength: 1, reg: gp11, asm: "SEQZ"}, // arg0 == 0, result is 0 or 1 diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index 2378c7abc2..aa896784f3 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -2399,13 +2399,19 @@ const ( OpRISCV64SRAIW OpRISCV64SRLI OpRISCV64SRLIW - OpRISCV64XOR - OpRISCV64XORI - OpRISCV64OR - OpRISCV64ORI OpRISCV64AND OpRISCV64ANDI OpRISCV64NOT + OpRISCV64OR + OpRISCV64ORI + OpRISCV64ROL + OpRISCV64ROLW + OpRISCV64ROR + OpRISCV64RORI + OpRISCV64RORIW + OpRISCV64RORW + OpRISCV64XOR + OpRISCV64XORI OpRISCV64SEQZ OpRISCV64SNEZ OpRISCV64SLT @@ -32202,10 +32208,10 @@ var opcodeTable = [...]opInfo{ }, }, { - name: "XOR", + name: "AND", argLen: 2, commutative: true, - asm: riscv.AXOR, + asm: riscv.AAND, reg: regInfo{ inputs: []inputInfo{ {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 @@ -32217,10 +32223,23 @@ var opcodeTable = [...]opInfo{ }, }, { - name: "XORI", + name: "ANDI", auxType: auxInt64, argLen: 1, - asm: riscv.AXORI, + asm: riscv.AANDI, + reg: regInfo{ + inputs: []inputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + outputs: []outputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + }, + }, + { + name: "NOT", + argLen: 1, + asm: riscv.ANOT, reg: regInfo{ inputs: []inputInfo{ {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 @@ -32260,10 +32279,9 @@ var opcodeTable = [...]opInfo{ }, }, { - name: "AND", - argLen: 2, - commutative: true, - asm: riscv.AAND, + name: "ROL", + argLen: 2, + asm: riscv.AROL, reg: regInfo{ inputs: []inputInfo{ {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 @@ -32275,10 +32293,38 @@ var opcodeTable = [...]opInfo{ }, }, { - name: "ANDI", + name: "ROLW", + argLen: 2, + asm: riscv.AROLW, + reg: regInfo{ + inputs: []inputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + {1, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + outputs: []outputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + }, + }, + { + name: "ROR", + argLen: 2, + asm: riscv.AROR, + reg: regInfo{ + inputs: []inputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + {1, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + outputs: []outputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + }, + }, + { + name: "RORI", auxType: auxInt64, argLen: 1, - asm: riscv.AANDI, + asm: riscv.ARORI, reg: regInfo{ inputs: []inputInfo{ {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 @@ -32289,9 +32335,53 @@ var opcodeTable = [...]opInfo{ }, }, { - name: "NOT", - argLen: 1, - asm: riscv.ANOT, + name: "RORIW", + auxType: auxInt64, + argLen: 1, + asm: riscv.ARORIW, + reg: regInfo{ + inputs: []inputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + outputs: []outputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + }, + }, + { + name: "RORW", + argLen: 2, + asm: riscv.ARORW, + reg: regInfo{ + inputs: []inputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + {1, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + outputs: []outputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + }, + }, + { + name: "XOR", + argLen: 2, + commutative: true, + asm: riscv.AXOR, + reg: regInfo{ + inputs: []inputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + {1, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + outputs: []outputInfo{ + {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 + }, + }, + }, + { + name: "XORI", + auxType: auxInt64, + argLen: 1, + asm: riscv.AXORI, reg: regInfo{ inputs: []inputInfo{ {0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30 diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go index 9b81676001..f033b25bdd 100644 --- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go +++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go @@ -530,6 +530,14 @@ func rewriteValueRISCV64(v *Value) bool { return rewriteValueRISCV64_OpRISCV64OR(v) case OpRISCV64ORI: return rewriteValueRISCV64_OpRISCV64ORI(v) + case OpRISCV64ROL: + return rewriteValueRISCV64_OpRISCV64ROL(v) + case OpRISCV64ROLW: + return rewriteValueRISCV64_OpRISCV64ROLW(v) + case OpRISCV64ROR: + return rewriteValueRISCV64_OpRISCV64ROR(v) + case OpRISCV64RORW: + return rewriteValueRISCV64_OpRISCV64RORW(v) case OpRISCV64SEQZ: return rewriteValueRISCV64_OpRISCV64SEQZ(v) case OpRISCV64SLL: @@ -569,9 +577,11 @@ func rewriteValueRISCV64(v *Value) bool { case OpRotateLeft16: return rewriteValueRISCV64_OpRotateLeft16(v) case OpRotateLeft32: - return rewriteValueRISCV64_OpRotateLeft32(v) + v.Op = OpRISCV64ROLW + return true case OpRotateLeft64: - return rewriteValueRISCV64_OpRotateLeft64(v) + v.Op = OpRISCV64ROL + return true case OpRotateLeft8: return rewriteValueRISCV64_OpRotateLeft8(v) case OpRound32F: @@ -5624,6 +5634,39 @@ func rewriteValueRISCV64_OpRISCV64MOVWreg(v *Value) bool { v.AddArg(x) return true } + // match: (MOVWreg x:(ROLW _ _)) + // result: (MOVDreg x) + for { + x := v_0 + if x.Op != OpRISCV64ROLW { + break + } + v.reset(OpRISCV64MOVDreg) + v.AddArg(x) + return true + } + // match: (MOVWreg x:(RORW _ _)) + // result: (MOVDreg x) + for { + x := v_0 + if x.Op != OpRISCV64RORW { + break + } + v.reset(OpRISCV64MOVDreg) + v.AddArg(x) + return true + } + // match: (MOVWreg x:(RORIW _)) + // result: (MOVDreg x) + for { + x := v_0 + if x.Op != OpRISCV64RORIW { + break + } + v.reset(OpRISCV64MOVDreg) + v.AddArg(x) + return true + } // match: (MOVWreg x:(MOVBreg _)) // result: (MOVDreg x) for { @@ -5997,6 +6040,102 @@ func rewriteValueRISCV64_OpRISCV64ORI(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64ROL(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (ROL x (MOVDconst [val])) + // result: (RORI [int64(int8(-val)&63)] x) + for { + x := v_0 + if v_1.Op != OpRISCV64MOVDconst { + break + } + val := auxIntToInt64(v_1.AuxInt) + v.reset(OpRISCV64RORI) + v.AuxInt = int64ToAuxInt(int64(int8(-val) & 63)) + v.AddArg(x) + return true + } + // match: (ROL x (NEG y)) + // result: (ROR x y) + for { + x := v_0 + if v_1.Op != OpRISCV64NEG { + break + } + y := v_1.Args[0] + v.reset(OpRISCV64ROR) + v.AddArg2(x, y) + return true + } + return false +} +func rewriteValueRISCV64_OpRISCV64ROLW(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (ROLW x (MOVDconst [val])) + // result: (RORIW [int64(int8(-val)&31)] x) + for { + x := v_0 + if v_1.Op != OpRISCV64MOVDconst { + break + } + val := auxIntToInt64(v_1.AuxInt) + v.reset(OpRISCV64RORIW) + v.AuxInt = int64ToAuxInt(int64(int8(-val) & 31)) + v.AddArg(x) + return true + } + // match: (ROLW x (NEG y)) + // result: (RORW x y) + for { + x := v_0 + if v_1.Op != OpRISCV64NEG { + break + } + y := v_1.Args[0] + v.reset(OpRISCV64RORW) + v.AddArg2(x, y) + return true + } + return false +} +func rewriteValueRISCV64_OpRISCV64ROR(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (ROR x (MOVDconst [val])) + // result: (RORI [int64(val&63)] x) + for { + x := v_0 + if v_1.Op != OpRISCV64MOVDconst { + break + } + val := auxIntToInt64(v_1.AuxInt) + v.reset(OpRISCV64RORI) + v.AuxInt = int64ToAuxInt(int64(val & 63)) + v.AddArg(x) + return true + } + return false +} +func rewriteValueRISCV64_OpRISCV64RORW(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (RORW x (MOVDconst [val])) + // result: (RORIW [int64(val&31)] x) + for { + x := v_0 + if v_1.Op != OpRISCV64MOVDconst { + break + } + val := auxIntToInt64(v_1.AuxInt) + v.reset(OpRISCV64RORIW) + v.AuxInt = int64ToAuxInt(int64(val & 31)) + v.AddArg(x) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64SEQZ(v *Value) bool { v_0 := v.Args[0] // match: (SEQZ (NEG x)) @@ -6689,48 +6828,6 @@ func rewriteValueRISCV64_OpRotateLeft16(v *Value) bool { return true } } -func rewriteValueRISCV64_OpRotateLeft32(v *Value) bool { - v_1 := v.Args[1] - v_0 := v.Args[0] - b := v.Block - // match: (RotateLeft32 x y) - // result: (OR (SLLW x y) (SRLW x (NEG y))) - for { - t := v.Type - x := v_0 - y := v_1 - v.reset(OpRISCV64OR) - v0 := b.NewValue0(v.Pos, OpRISCV64SLLW, t) - v0.AddArg2(x, y) - v1 := b.NewValue0(v.Pos, OpRISCV64SRLW, t) - v2 := b.NewValue0(v.Pos, OpRISCV64NEG, y.Type) - v2.AddArg(y) - v1.AddArg2(x, v2) - v.AddArg2(v0, v1) - return true - } -} -func rewriteValueRISCV64_OpRotateLeft64(v *Value) bool { - v_1 := v.Args[1] - v_0 := v.Args[0] - b := v.Block - // match: (RotateLeft64 x y) - // result: (OR (SLL x y) (SRL x (NEG y))) - for { - t := v.Type - x := v_0 - y := v_1 - v.reset(OpRISCV64OR) - v0 := b.NewValue0(v.Pos, OpRISCV64SLL, t) - v0.AddArg2(x, y) - v1 := b.NewValue0(v.Pos, OpRISCV64SRL, t) - v2 := b.NewValue0(v.Pos, OpRISCV64NEG, y.Type) - v2.AddArg(y) - v1.AddArg2(x, v2) - v.AddArg2(v0, v1) - return true - } -} func rewriteValueRISCV64_OpRotateLeft8(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] diff --git a/src/cmd/internal/obj/riscv/anames.go b/src/cmd/internal/obj/riscv/anames.go index d2c41971b8..e547c6d5e9 100644 --- a/src/cmd/internal/obj/riscv/anames.go +++ b/src/cmd/internal/obj/riscv/anames.go @@ -246,6 +246,12 @@ var Anames = []string{ "NEG", "NEGW", "NOT", + "ROL", + "ROLW", + "ROR", + "RORI", + "RORIW", + "RORW", "SEQZ", "SNEZ", "LAST", diff --git a/src/cmd/internal/obj/riscv/cpu.go b/src/cmd/internal/obj/riscv/cpu.go index 919f07b1a9..00513a4a79 100644 --- a/src/cmd/internal/obj/riscv/cpu.go +++ b/src/cmd/internal/obj/riscv/cpu.go @@ -605,6 +605,12 @@ const ( ANEG ANEGW ANOT + AROL + AROLW + AROR + ARORI + ARORIW + ARORW ASEQZ ASNEZ diff --git a/src/cmd/internal/obj/riscv/obj.go b/src/cmd/internal/obj/riscv/obj.go index 3ec740f85a..f508adafed 100644 --- a/src/cmd/internal/obj/riscv/obj.go +++ b/src/cmd/internal/obj/riscv/obj.go @@ -59,7 +59,8 @@ func progedit(ctxt *obj.Link, p *obj.Prog, newprog obj.ProgAlloc) { AADDIW, ASLLIW, ASRLIW, ASRAIW, AADDW, ASUBW, ASLLW, ASRLW, ASRAW, AADD, AAND, AOR, AXOR, ASLL, ASRL, ASUB, ASRA, AMUL, AMULH, AMULHU, AMULHSU, AMULW, ADIV, ADIVU, ADIVW, ADIVUW, - AREM, AREMU, AREMW, AREMUW: + AREM, AREMU, AREMW, AREMUW, + AROL, AROLW, AROR, ARORW, ARORI, ARORIW: p.Reg = p.To.Reg } } @@ -90,6 +91,10 @@ func progedit(ctxt *obj.Link, p *obj.Prog, newprog obj.ProgAlloc) { p.As = ASRAI case AADDW: p.As = AADDIW + case AROR: + p.As = ARORI + case ARORW: + p.As = ARORIW case ASUBW: p.As, p.From.Offset = AADDIW, -p.From.Offset case ASLLW: @@ -2193,6 +2198,47 @@ func instructionsForMOV(p *obj.Prog) []*instruction { return inss } +// instructionsForRotate returns the machine instructions for a bitwise rotation. +func instructionsForRotate(p *obj.Prog, ins *instruction) []*instruction { + switch ins.as { + case AROL, AROLW, AROR, ARORW: + // ROL -> OR (SLL x y) (SRL x (NEG y)) + // ROR -> OR (SRL x y) (SLL x (NEG y)) + sllOp, srlOp := ASLL, ASRL + if ins.as == AROLW || ins.as == ARORW { + sllOp, srlOp = ASLLW, ASRLW + } + shift1, shift2 := sllOp, srlOp + if ins.as == AROR || ins.as == ARORW { + shift1, shift2 = shift2, shift1 + } + return []*instruction{ + &instruction{as: ASUB, rs1: REG_ZERO, rs2: ins.rs2, rd: REG_TMP}, + &instruction{as: shift2, rs1: ins.rs1, rs2: REG_TMP, rd: REG_TMP}, + &instruction{as: shift1, rs1: ins.rs1, rs2: ins.rs2, rd: ins.rd}, + &instruction{as: AOR, rs1: REG_TMP, rs2: ins.rd, rd: ins.rd}, + } + + case ARORI, ARORIW: + // ROR -> OR (SLLI -x y) (SRLI x y) + sllOp, srlOp := ASLLI, ASRLI + sllImm := int64(int8(-ins.imm) & 63) + if ins.as == ARORIW { + sllOp, srlOp = ASLLIW, ASRLIW + sllImm = int64(int8(-ins.imm) & 31) + } + return []*instruction{ + &instruction{as: srlOp, rs1: ins.rs1, rd: REG_TMP, imm: ins.imm}, + &instruction{as: sllOp, rs1: ins.rs1, rd: ins.rd, imm: sllImm}, + &instruction{as: AOR, rs1: REG_TMP, rs2: ins.rd, rd: ins.rd}, + } + + default: + p.Ctxt.Diag("%v: unknown rotation", p) + return nil + } +} + // instructionsForProg returns the machine instructions for an *obj.Prog. func instructionsForProg(p *obj.Prog) []*instruction { ins := instructionForProg(p) @@ -2363,6 +2409,9 @@ func instructionsForProg(p *obj.Prog) []*instruction { ins.as = AFSGNJND ins.rs1 = uint32(p.From.Reg) + case AROL, AROLW, AROR, ARORW, ARORI, ARORIW: + inss = instructionsForRotate(p, ins) + case ASLLI, ASRLI, ASRAI: if ins.imm < 0 || ins.imm > 63 { p.Ctxt.Diag("%v: shift amount out of range 0 to 63", p) diff --git a/src/crypto/sha512/sha512block_riscv64.s b/src/crypto/sha512/sha512block_riscv64.s index 6fbd524a31..7dcb0f80d0 100644 --- a/src/crypto/sha512/sha512block_riscv64.s +++ b/src/crypto/sha512/sha512block_riscv64.s @@ -46,11 +46,6 @@ // H6 = g + H6 // H7 = h + H7 -#define ROR(s, r, d, t1, t2) \ - SLL $(64-s), r, t1; \ - SRL $(s), r, t2; \ - OR t1, t2, d - // Wt = Mt; for 0 <= t <= 15 #define MSGSCHEDULE0(index) \ MOVBU ((index*8)+0)(X29), X5; \ @@ -85,14 +80,14 @@ MOV (((index-15)&0xf)*8)(X19), X6; \ MOV (((index-7)&0xf)*8)(X19), X9; \ MOV (((index-16)&0xf)*8)(X19), X21; \ - ROR(19, X5, X7, X23, X24); \ - ROR(61, X5, X8, X23, X24); \ + ROR $19, X5, X7; \ + ROR $61, X5, X8; \ SRL $6, X5; \ XOR X7, X5; \ XOR X8, X5; \ ADD X9, X5; \ - ROR(1, X6, X7, X23, X24); \ - ROR(8, X6, X8, X23, X24); \ + ROR $1, X6, X7; \ + ROR $8, X6, X8; \ SRL $7, X6; \ XOR X7, X6; \ XOR X8, X6; \ @@ -108,11 +103,11 @@ #define SHA512T1(index, e, f, g, h) \ MOV (index*8)(X18), X8; \ ADD X5, h; \ - ROR(14, e, X6, X23, X24); \ + ROR $14, e, X6; \ ADD X8, h; \ - ROR(18, e, X7, X23, X24); \ + ROR $18, e, X7; \ XOR X7, X6; \ - ROR(41, e, X8, X23, X24); \ + ROR $41, e, X8; \ XOR X8, X6; \ ADD X6, h; \ AND e, f, X5; \ @@ -126,10 +121,10 @@ // BIGSIGMA0(x) = ROTR(28,x) XOR ROTR(34,x) XOR ROTR(39,x) // Maj(x, y, z) = (x AND y) XOR (x AND z) XOR (y AND z) #define SHA512T2(a, b, c) \ - ROR(28, a, X6, X23, X24); \ - ROR(34, a, X7, X23, X24); \ + ROR $28, a, X6; \ + ROR $34, a, X7; \ XOR X7, X6; \ - ROR(39, a, X8, X23, X24); \ + ROR $39, a, X8; \ XOR X8, X6; \ AND a, b, X7; \ AND a, c, X8; \ diff --git a/test/codegen/rotate.go b/test/codegen/rotate.go index 109e55763c..121ce4cc0a 100644 --- a/test/codegen/rotate.go +++ b/test/codegen/rotate.go @@ -18,7 +18,7 @@ func rot64(x uint64) uint64 { // amd64:"ROLQ\t[$]7" // ppc64x:"ROTL\t[$]7" // loong64: "ROTRV\t[$]57" - // riscv64: "OR","SLLI","SRLI",-"AND" + // riscv64: "RORI\t[$]57" a += x<<7 | x>>57 // amd64:"ROLQ\t[$]8" @@ -26,7 +26,7 @@ func rot64(x uint64) uint64 { // s390x:"RISBGZ\t[$]0, [$]63, [$]8, " // ppc64x:"ROTL\t[$]8" // loong64: "ROTRV\t[$]56" - // riscv64: "OR","SLLI","SRLI",-"AND" + // riscv64: "RORI\t[$]56" a += x<<8 + x>>56 // amd64:"ROLQ\t[$]9" @@ -34,7 +34,7 @@ func rot64(x uint64) uint64 { // s390x:"RISBGZ\t[$]0, [$]63, [$]9, " // ppc64x:"ROTL\t[$]9" // loong64: "ROTRV\t[$]55" - // riscv64: "OR","SLLI","SRLI",-"AND" + // riscv64: "RORI\t[$]55" a += x<<9 ^ x>>55 // amd64:"ROLQ\t[$]10" @@ -44,7 +44,7 @@ func rot64(x uint64) uint64 { // arm64:"ROR\t[$]54" // s390x:"RISBGZ\t[$]0, [$]63, [$]10, " // loong64: "ROTRV\t[$]54" - // riscv64: "OR","SLLI","SRLI",-"AND" + // riscv64: "RORI\t[$]54" a += bits.RotateLeft64(x, 10) return a @@ -57,7 +57,7 @@ func rot32(x uint32) uint32 { // arm:"MOVW\tR\\d+@>25" // ppc64x:"ROTLW\t[$]7" // loong64: "ROTR\t[$]25" - // riscv64: "OR","SLLIW","SRLIW",-"AND" + // riscv64: "RORIW\t[$]25" a += x<<7 | x>>25 // amd64:`ROLL\t[$]8` @@ -66,7 +66,7 @@ func rot32(x uint32) uint32 { // s390x:"RLL\t[$]8" // ppc64x:"ROTLW\t[$]8" // loong64: "ROTR\t[$]24" - // riscv64: "OR","SLLIW","SRLIW",-"AND" + // riscv64: "RORIW\t[$]24" a += x<<8 + x>>24 // amd64:"ROLL\t[$]9" @@ -75,7 +75,7 @@ func rot32(x uint32) uint32 { // s390x:"RLL\t[$]9" // ppc64x:"ROTLW\t[$]9" // loong64: "ROTR\t[$]23" - // riscv64: "OR","SLLIW","SRLIW",-"AND" + // riscv64: "RORIW\t[$]23" a += x<<9 ^ x>>23 // amd64:"ROLL\t[$]10" @@ -86,7 +86,7 @@ func rot32(x uint32) uint32 { // arm64:"RORW\t[$]22" // s390x:"RLL\t[$]10" // loong64: "ROTR\t[$]22" - // riscv64: "OR","SLLIW","SRLIW",-"AND" + // riscv64: "RORIW\t[$]22" a += bits.RotateLeft32(x, 10) return a @@ -141,14 +141,14 @@ func rot64nc(x uint64, z uint) uint64 { // arm64:"ROR","NEG",-"AND" // ppc64x:"ROTL",-"NEG",-"AND" // loong64: "ROTRV", -"AND" - // riscv64: "OR","SLL","SRL",-"AND" + // riscv64: "ROL",-"AND" a += x<>(64-z) // amd64:"RORQ",-"AND" // arm64:"ROR",-"NEG",-"AND" // ppc64x:"ROTL","NEG",-"AND" // loong64: "ROTRV", -"AND" - // riscv64: "OR","SLL","SRL",-"AND" + // riscv64: "ROR",-"AND" a += x>>z | x<<(64-z) return a @@ -163,14 +163,14 @@ func rot32nc(x uint32, z uint) uint32 { // arm64:"ROR","NEG",-"AND" // ppc64x:"ROTLW",-"NEG",-"AND" // loong64: "ROTR", -"AND" - // riscv64: "OR","SLLW","SRLW",-"AND" + // riscv64: "ROLW",-"AND" a += x<>(32-z) // amd64:"RORL",-"AND" // arm64:"ROR",-"NEG",-"AND" // ppc64x:"ROTLW","NEG",-"AND" // loong64: "ROTR", -"AND" - // riscv64: "OR","SLLW","SRLW",-"AND" + // riscv64: "RORW",-"AND" a += x>>z | x<<(32-z) return a -- 2.48.1