From ea51e223c28babc530df475497de0be4579b5e86 Mon Sep 17 00:00:00 2001 From: Michael Munday Date: Wed, 17 Feb 2021 15:00:34 +0000 Subject: [PATCH] cmd/{asm,compile}: add fused multiply-add support on riscv64 Add support to the assembler for F[N]M{ADD,SUB}[SD] instructions. Argument order is: OP RS1, RS2, RS3, RD Also, add support for the FMA intrinsic to the compiler. Automatic FMA matching is left to a future CL. Change-Id: I47166c7393b2ab6bfc2e42aa8c1a8997c3a071b3 Reviewed-on: https://go-review.googlesource.com/c/go/+/293030 Trust: Michael Munday Run-TryBot: Michael Munday TryBot-Result: Go Bot Reviewed-by: Joel Sing --- src/cmd/asm/internal/asm/asm.go | 7 + src/cmd/asm/internal/asm/testdata/riscv64.s | 8 + src/cmd/compile/internal/riscv64/ssa.go | 13 +- .../compile/internal/ssa/gen/RISCV64.rules | 15 ++ .../compile/internal/ssa/gen/RISCV64Ops.go | 5 + src/cmd/compile/internal/ssa/opGen.go | 68 +++++++ .../compile/internal/ssa/rewriteRISCV64.go | 191 ++++++++++++++++++ src/cmd/compile/internal/ssagen/ssa.go | 2 +- src/cmd/internal/obj/riscv/obj.go | 104 ++++++++-- test/codegen/math.go | 16 ++ 10 files changed, 406 insertions(+), 23 deletions(-) diff --git a/src/cmd/asm/internal/asm/asm.go b/src/cmd/asm/internal/asm/asm.go index cf0d1550f9..d0cb6328f1 100644 --- a/src/cmd/asm/internal/asm/asm.go +++ b/src/cmd/asm/internal/asm/asm.go @@ -793,6 +793,13 @@ func (p *Parser) asmInstruction(op obj.As, cond string, a []obj.Addr) { return } } + if p.arch.Family == sys.RISCV64 { + prog.From = a[0] + prog.Reg = p.getRegister(prog, op, &a[1]) + prog.SetRestArgs([]obj.Addr{a[2]}) + prog.To = a[3] + break + } if p.arch.Family == sys.S390X { if a[1].Type != obj.TYPE_REG { p.errorf("second operand must be a register in %s instruction", op) diff --git a/src/cmd/asm/internal/asm/testdata/riscv64.s b/src/cmd/asm/internal/asm/testdata/riscv64.s index 628a8d91cd..173c50f2e1 100644 --- a/src/cmd/asm/internal/asm/testdata/riscv64.s +++ b/src/cmd/asm/internal/asm/testdata/riscv64.s @@ -214,6 +214,10 @@ start: FMVSX X5, F0 // 538002f0 FMVXW F0, X5 // d30200e0 FMVWX X5, F0 // 538002f0 + FMADDS F1, F2, F3, F4 // 43822018 + FMSUBS F1, F2, F3, F4 // 47822018 + FNMSUBS F1, F2, F3, F4 // 4b822018 + FNMADDS F1, F2, F3, F4 // 4f822018 // 11.8: Single-Precision Floating-Point Compare Instructions FEQS F0, F1, X7 // d3a300a0 @@ -254,6 +258,10 @@ start: FSGNJXD F1, F0, F2 // 53211022 FMVXD F0, X5 // d30200e2 FMVDX X5, F0 // 538002f2 + FMADDD F1, F2, F3, F4 // 4382201a + FMSUBD F1, F2, F3, F4 // 4782201a + FNMSUBD F1, F2, F3, F4 // 4b82201a + FNMADDD F1, F2, F3, F4 // 4f82201a // 12.6: Double-Precision Floating-Point Classify Instruction FCLASSD F0, X5 // d31200e2 diff --git a/src/cmd/compile/internal/riscv64/ssa.go b/src/cmd/compile/internal/riscv64/ssa.go index d3cbb4ec24..30b6d96a89 100644 --- a/src/cmd/compile/internal/riscv64/ssa.go +++ b/src/cmd/compile/internal/riscv64/ssa.go @@ -317,7 +317,18 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) { p2.From.Reg = v.Reg1() p2.To.Type = obj.TYPE_REG p2.To.Reg = v.Reg1() - + case ssa.OpRISCV64FMADDD, ssa.OpRISCV64FMSUBD, ssa.OpRISCV64FNMADDD, ssa.OpRISCV64FNMSUBD: + r := v.Reg() + r1 := v.Args[0].Reg() + r2 := v.Args[1].Reg() + r3 := v.Args[2].Reg() + p := s.Prog(v.Op.Asm()) + p.From.Type = obj.TYPE_REG + p.From.Reg = r2 + p.Reg = r1 + p.SetRestArgs([]obj.Addr{{Type: obj.TYPE_REG, Reg: r3}}) + p.To.Type = obj.TYPE_REG + p.To.Reg = r case ssa.OpRISCV64FSQRTS, ssa.OpRISCV64FNEGS, ssa.OpRISCV64FSQRTD, ssa.OpRISCV64FNEGD, ssa.OpRISCV64FMVSX, ssa.OpRISCV64FMVDX, ssa.OpRISCV64FCVTSW, ssa.OpRISCV64FCVTSL, ssa.OpRISCV64FCVTWS, ssa.OpRISCV64FCVTLS, diff --git a/src/cmd/compile/internal/ssa/gen/RISCV64.rules b/src/cmd/compile/internal/ssa/gen/RISCV64.rules index 4eb48e3928..b711550186 100644 --- a/src/cmd/compile/internal/ssa/gen/RISCV64.rules +++ b/src/cmd/compile/internal/ssa/gen/RISCV64.rules @@ -96,6 +96,8 @@ (Sqrt ...) => (FSQRTD ...) (Sqrt32 ...) => (FSQRTS ...) +(FMA ...) => (FMADDD ...) + // Sign and zero extension. (SignExt8to16 ...) => (MOVBreg ...) @@ -713,3 +715,16 @@ // Addition of zero. (ADDI [0] x) => x + +// Merge negation into fused multiply-add and multiply-subtract. +// +// Key: +// +// [+ -](x * y) [+ -] z. +// _ N A S +// D U +// D B +// +// Note: multiplication commutativity handled by rule generator. +(F(MADD|NMADD|MSUB|NMSUB)D neg:(FNEGD x) y z) && neg.Uses == 1 => (F(NMADD|MADD|NMSUB|MSUB)D x y z) +(F(MADD|NMADD|MSUB|NMSUB)D x y neg:(FNEGD z)) && neg.Uses == 1 => (F(MSUB|NMSUB|MADD|NMADD)D x y z) diff --git a/src/cmd/compile/internal/ssa/gen/RISCV64Ops.go b/src/cmd/compile/internal/ssa/gen/RISCV64Ops.go index d36daa8b83..de189e4c60 100644 --- a/src/cmd/compile/internal/ssa/gen/RISCV64Ops.go +++ b/src/cmd/compile/internal/ssa/gen/RISCV64Ops.go @@ -132,6 +132,7 @@ func init() { fp11 = regInfo{inputs: []regMask{fpMask}, outputs: []regMask{fpMask}} fp21 = regInfo{inputs: []regMask{fpMask, fpMask}, outputs: []regMask{fpMask}} + fp31 = regInfo{inputs: []regMask{fpMask, fpMask, fpMask}, outputs: []regMask{fpMask}} gpfp = regInfo{inputs: []regMask{gpMask}, outputs: []regMask{fpMask}} fpgp = regInfo{inputs: []regMask{fpMask}, outputs: []regMask{gpMask}} fpstore = regInfo{inputs: []regMask{gpspsbMask, fpMask, 0}} @@ -425,6 +426,10 @@ func init() { {name: "FSUBD", argLength: 2, reg: fp21, asm: "FSUBD", commutative: false, typ: "Float64"}, // arg0 - arg1 {name: "FMULD", argLength: 2, reg: fp21, asm: "FMULD", commutative: true, typ: "Float64"}, // arg0 * arg1 {name: "FDIVD", argLength: 2, reg: fp21, asm: "FDIVD", commutative: false, typ: "Float64"}, // arg0 / arg1 + {name: "FMADDD", argLength: 3, reg: fp31, asm: "FMADDD", commutative: true, typ: "Float64"}, // (arg0 * arg1) + arg2 + {name: "FMSUBD", argLength: 3, reg: fp31, asm: "FMSUBD", commutative: true, typ: "Float64"}, // (arg0 * arg1) - arg2 + {name: "FNMADDD", argLength: 3, reg: fp31, asm: "FNMADDD", commutative: true, typ: "Float64"}, // -(arg0 * arg1) + arg2 + {name: "FNMSUBD", argLength: 3, reg: fp31, asm: "FNMSUBD", commutative: true, typ: "Float64"}, // -(arg0 * arg1) - arg2 {name: "FSQRTD", argLength: 1, reg: fp11, asm: "FSQRTD", typ: "Float64"}, // sqrt(arg0) {name: "FNEGD", argLength: 1, reg: fp11, asm: "FNEGD", typ: "Float64"}, // -arg0 {name: "FMVDX", argLength: 1, reg: gpfp, asm: "FMVDX", typ: "Float64"}, // reinterpret arg0 as float diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index 737afc6087..672528aefe 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -2174,6 +2174,10 @@ const ( OpRISCV64FSUBD OpRISCV64FMULD OpRISCV64FDIVD + OpRISCV64FMADDD + OpRISCV64FMSUBD + OpRISCV64FNMADDD + OpRISCV64FNMSUBD OpRISCV64FSQRTD OpRISCV64FNEGD OpRISCV64FMVDX @@ -29054,6 +29058,70 @@ var opcodeTable = [...]opInfo{ }, }, }, + { + name: "FMADDD", + argLen: 3, + commutative: true, + asm: riscv.AFMADDD, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, + { + name: "FMSUBD", + argLen: 3, + commutative: true, + asm: riscv.AFMSUBD, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, + { + name: "FNMADDD", + argLen: 3, + commutative: true, + asm: riscv.AFNMADDD, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, + { + name: "FNMSUBD", + argLen: 3, + commutative: true, + asm: riscv.AFNMSUBD, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, { name: "FSQRTD", argLen: 1, diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go index 641be038db..743ff50b0c 100644 --- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go +++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go @@ -209,6 +209,9 @@ func rewriteValueRISCV64(v *Value) bool { return rewriteValueRISCV64_OpEqB(v) case OpEqPtr: return rewriteValueRISCV64_OpEqPtr(v) + case OpFMA: + v.Op = OpRISCV64FMADDD + return true case OpGetCallerPC: v.Op = OpRISCV64LoweredGetCallerPC return true @@ -432,6 +435,14 @@ func rewriteValueRISCV64(v *Value) bool { return rewriteValueRISCV64_OpRISCV64ADDI(v) case OpRISCV64AND: return rewriteValueRISCV64_OpRISCV64AND(v) + case OpRISCV64FMADDD: + return rewriteValueRISCV64_OpRISCV64FMADDD(v) + case OpRISCV64FMSUBD: + return rewriteValueRISCV64_OpRISCV64FMSUBD(v) + case OpRISCV64FNMADDD: + return rewriteValueRISCV64_OpRISCV64FNMADDD(v) + case OpRISCV64FNMSUBD: + return rewriteValueRISCV64_OpRISCV64FNMSUBD(v) case OpRISCV64MOVBUload: return rewriteValueRISCV64_OpRISCV64MOVBUload(v) case OpRISCV64MOVBUreg: @@ -2829,6 +2840,186 @@ func rewriteValueRISCV64_OpRISCV64AND(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool { + v_2 := v.Args[2] + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (FMADDD neg:(FNEGD x) y z) + // cond: neg.Uses == 1 + // result: (FNMADDD x y z) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + neg := v_0 + if neg.Op != OpRISCV64FNEGD { + continue + } + x := neg.Args[0] + y := v_1 + z := v_2 + if !(neg.Uses == 1) { + continue + } + v.reset(OpRISCV64FNMADDD) + v.AddArg3(x, y, z) + return true + } + break + } + // match: (FMADDD x y neg:(FNEGD z)) + // cond: neg.Uses == 1 + // result: (FMSUBD x y z) + for { + x := v_0 + y := v_1 + neg := v_2 + if neg.Op != OpRISCV64FNEGD { + break + } + z := neg.Args[0] + if !(neg.Uses == 1) { + break + } + v.reset(OpRISCV64FMSUBD) + v.AddArg3(x, y, z) + return true + } + return false +} +func rewriteValueRISCV64_OpRISCV64FMSUBD(v *Value) bool { + v_2 := v.Args[2] + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (FMSUBD neg:(FNEGD x) y z) + // cond: neg.Uses == 1 + // result: (FNMSUBD x y z) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + neg := v_0 + if neg.Op != OpRISCV64FNEGD { + continue + } + x := neg.Args[0] + y := v_1 + z := v_2 + if !(neg.Uses == 1) { + continue + } + v.reset(OpRISCV64FNMSUBD) + v.AddArg3(x, y, z) + return true + } + break + } + // match: (FMSUBD x y neg:(FNEGD z)) + // cond: neg.Uses == 1 + // result: (FMADDD x y z) + for { + x := v_0 + y := v_1 + neg := v_2 + if neg.Op != OpRISCV64FNEGD { + break + } + z := neg.Args[0] + if !(neg.Uses == 1) { + break + } + v.reset(OpRISCV64FMADDD) + v.AddArg3(x, y, z) + return true + } + return false +} +func rewriteValueRISCV64_OpRISCV64FNMADDD(v *Value) bool { + v_2 := v.Args[2] + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (FNMADDD neg:(FNEGD x) y z) + // cond: neg.Uses == 1 + // result: (FMADDD x y z) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + neg := v_0 + if neg.Op != OpRISCV64FNEGD { + continue + } + x := neg.Args[0] + y := v_1 + z := v_2 + if !(neg.Uses == 1) { + continue + } + v.reset(OpRISCV64FMADDD) + v.AddArg3(x, y, z) + return true + } + break + } + // match: (FNMADDD x y neg:(FNEGD z)) + // cond: neg.Uses == 1 + // result: (FNMSUBD x y z) + for { + x := v_0 + y := v_1 + neg := v_2 + if neg.Op != OpRISCV64FNEGD { + break + } + z := neg.Args[0] + if !(neg.Uses == 1) { + break + } + v.reset(OpRISCV64FNMSUBD) + v.AddArg3(x, y, z) + return true + } + return false +} +func rewriteValueRISCV64_OpRISCV64FNMSUBD(v *Value) bool { + v_2 := v.Args[2] + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (FNMSUBD neg:(FNEGD x) y z) + // cond: neg.Uses == 1 + // result: (FMSUBD x y z) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + neg := v_0 + if neg.Op != OpRISCV64FNEGD { + continue + } + x := neg.Args[0] + y := v_1 + z := v_2 + if !(neg.Uses == 1) { + continue + } + v.reset(OpRISCV64FMSUBD) + v.AddArg3(x, y, z) + return true + } + break + } + // match: (FNMSUBD x y neg:(FNEGD z)) + // cond: neg.Uses == 1 + // result: (FNMADDD x y z) + for { + x := v_0 + y := v_1 + neg := v_2 + if neg.Op != OpRISCV64FNEGD { + break + } + z := neg.Args[0] + if !(neg.Uses == 1) { + break + } + v.reset(OpRISCV64FNMADDD) + v.AddArg3(x, y, z) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64MOVBUload(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] diff --git a/src/cmd/compile/internal/ssagen/ssa.go b/src/cmd/compile/internal/ssagen/ssa.go index a64901305f..176e6438dc 100644 --- a/src/cmd/compile/internal/ssagen/ssa.go +++ b/src/cmd/compile/internal/ssagen/ssa.go @@ -4168,7 +4168,7 @@ func InitTables() { func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return s.newValue3(ssa.OpFMA, types.Types[types.TFLOAT64], args[0], args[1], args[2]) }, - sys.ARM64, sys.PPC64, sys.S390X) + sys.ARM64, sys.PPC64, sys.RISCV64, sys.S390X) addF("math", "FMA", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { if !s.config.UseFMA { diff --git a/src/cmd/internal/obj/riscv/obj.go b/src/cmd/internal/obj/riscv/obj.go index f89e13d81c..73f62c007d 100644 --- a/src/cmd/internal/obj/riscv/obj.go +++ b/src/cmd/internal/obj/riscv/obj.go @@ -1214,60 +1214,77 @@ func validateRIII(ctxt *obj.Link, ins *instruction) { wantIntReg(ctxt, ins.as, "rd", ins.rd) wantIntReg(ctxt, ins.as, "rs1", ins.rs1) wantIntReg(ctxt, ins.as, "rs2", ins.rs2) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateRFFF(ctxt *obj.Link, ins *instruction) { wantFloatReg(ctxt, ins.as, "rd", ins.rd) wantFloatReg(ctxt, ins.as, "rs1", ins.rs1) wantFloatReg(ctxt, ins.as, "rs2", ins.rs2) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) +} + +func validateRFFFF(ctxt *obj.Link, ins *instruction) { + wantFloatReg(ctxt, ins.as, "rd", ins.rd) + wantFloatReg(ctxt, ins.as, "rs1", ins.rs1) + wantFloatReg(ctxt, ins.as, "rs2", ins.rs2) + wantFloatReg(ctxt, ins.as, "rs3", ins.rs3) } func validateRFFI(ctxt *obj.Link, ins *instruction) { wantIntReg(ctxt, ins.as, "rd", ins.rd) wantFloatReg(ctxt, ins.as, "rs1", ins.rs1) wantFloatReg(ctxt, ins.as, "rs2", ins.rs2) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateRFI(ctxt *obj.Link, ins *instruction) { wantIntReg(ctxt, ins.as, "rd", ins.rd) wantNoneReg(ctxt, ins.as, "rs1", ins.rs1) wantFloatReg(ctxt, ins.as, "rs2", ins.rs2) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateRIF(ctxt *obj.Link, ins *instruction) { wantFloatReg(ctxt, ins.as, "rd", ins.rd) wantNoneReg(ctxt, ins.as, "rs1", ins.rs1) wantIntReg(ctxt, ins.as, "rs2", ins.rs2) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateRFF(ctxt *obj.Link, ins *instruction) { wantFloatReg(ctxt, ins.as, "rd", ins.rd) wantNoneReg(ctxt, ins.as, "rs1", ins.rs1) wantFloatReg(ctxt, ins.as, "rs2", ins.rs2) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateII(ctxt *obj.Link, ins *instruction) { wantImmI(ctxt, ins.as, ins.imm, 12) wantIntReg(ctxt, ins.as, "rd", ins.rd) wantIntReg(ctxt, ins.as, "rs1", ins.rs1) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateIF(ctxt *obj.Link, ins *instruction) { wantImmI(ctxt, ins.as, ins.imm, 12) wantFloatReg(ctxt, ins.as, "rd", ins.rd) wantIntReg(ctxt, ins.as, "rs1", ins.rs1) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateSI(ctxt *obj.Link, ins *instruction) { wantImmI(ctxt, ins.as, ins.imm, 12) wantIntReg(ctxt, ins.as, "rd", ins.rd) wantIntReg(ctxt, ins.as, "rs1", ins.rs1) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateSF(ctxt *obj.Link, ins *instruction) { wantImmI(ctxt, ins.as, ins.imm, 12) wantIntReg(ctxt, ins.as, "rd", ins.rd) wantFloatReg(ctxt, ins.as, "rs1", ins.rs1) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateB(ctxt *obj.Link, ins *instruction) { @@ -1278,6 +1295,7 @@ func validateB(ctxt *obj.Link, ins *instruction) { wantNoneReg(ctxt, ins.as, "rd", ins.rd) wantIntReg(ctxt, ins.as, "rs1", ins.rs1) wantIntReg(ctxt, ins.as, "rs2", ins.rs2) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateU(ctxt *obj.Link, ins *instruction) { @@ -1285,6 +1303,7 @@ func validateU(ctxt *obj.Link, ins *instruction) { wantIntReg(ctxt, ins.as, "rd", ins.rd) wantNoneReg(ctxt, ins.as, "rs1", ins.rs1) wantNoneReg(ctxt, ins.as, "rs2", ins.rs2) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateJ(ctxt *obj.Link, ins *instruction) { @@ -1295,6 +1314,7 @@ func validateJ(ctxt *obj.Link, ins *instruction) { wantIntReg(ctxt, ins.as, "rd", ins.rd) wantNoneReg(ctxt, ins.as, "rs1", ins.rs1) wantNoneReg(ctxt, ins.as, "rs2", ins.rs2) + wantNoneReg(ctxt, ins.as, "rs3", ins.rs3) } func validateRaw(ctxt *obj.Link, ins *instruction) { @@ -1317,6 +1337,22 @@ func encodeR(as obj.As, rs1, rs2, rd, funct3, funct7 uint32) uint32 { return funct7<<25 | enc.funct7<<25 | enc.rs2<<20 | rs2<<20 | rs1<<15 | enc.funct3<<12 | funct3<<12 | rd<<7 | enc.opcode } +// encodeR4 encodes an R4-type RISC-V instruction. +func encodeR4(as obj.As, rs1, rs2, rs3, rd, funct3, funct2 uint32) uint32 { + enc := encode(as) + if enc == nil { + panic("encodeR4: could not encode instruction") + } + if enc.rs2 != 0 { + panic("encodeR4: instruction uses rs2") + } + funct2 |= enc.funct7 + if funct2&^3 != 0 { + panic("encodeR4: funct2 requires more than 2 bits") + } + return rs3<<27 | funct2<<25 | rs2<<20 | rs1<<15 | enc.funct3<<12 | funct3<<12 | rd<<7 | enc.opcode +} + func encodeRIII(ins *instruction) uint32 { return encodeR(ins.as, regI(ins.rs1), regI(ins.rs2), regI(ins.rd), ins.funct3, ins.funct7) } @@ -1325,6 +1361,10 @@ func encodeRFFF(ins *instruction) uint32 { return encodeR(ins.as, regF(ins.rs1), regF(ins.rs2), regF(ins.rd), ins.funct3, ins.funct7) } +func encodeRFFFF(ins *instruction) uint32 { + return encodeR4(ins.as, regF(ins.rs1), regF(ins.rs2), regF(ins.rs3), regF(ins.rd), ins.funct3, ins.funct7) +} + func encodeRFFI(ins *instruction) uint32 { return encodeR(ins.as, regF(ins.rs1), regF(ins.rs2), regI(ins.rd), ins.funct3, ins.funct7) } @@ -1462,12 +1502,13 @@ var ( // integer register inputs and an integer register output; sFEncoding // indicates an S-type instruction with rs2 being a float register. - rIIIEncoding = encoding{encode: encodeRIII, validate: validateRIII, length: 4} - rFFFEncoding = encoding{encode: encodeRFFF, validate: validateRFFF, length: 4} - rFFIEncoding = encoding{encode: encodeRFFI, validate: validateRFFI, length: 4} - rFIEncoding = encoding{encode: encodeRFI, validate: validateRFI, length: 4} - rIFEncoding = encoding{encode: encodeRIF, validate: validateRIF, length: 4} - rFFEncoding = encoding{encode: encodeRFF, validate: validateRFF, length: 4} + rIIIEncoding = encoding{encode: encodeRIII, validate: validateRIII, length: 4} + rFFFEncoding = encoding{encode: encodeRFFF, validate: validateRFFF, length: 4} + rFFFFEncoding = encoding{encode: encodeRFFFF, validate: validateRFFFF, length: 4} + rFFIEncoding = encoding{encode: encodeRFFI, validate: validateRFFI, length: 4} + rFIEncoding = encoding{encode: encodeRFI, validate: validateRFI, length: 4} + rIFEncoding = encoding{encode: encodeRIF, validate: validateRIF, length: 4} + rFFEncoding = encoding{encode: encodeRFF, validate: validateRFF, length: 4} iIEncoding = encoding{encode: encodeII, validate: validateII, length: 4} iFEncoding = encoding{encode: encodeIF, validate: validateIF, length: 4} @@ -1609,13 +1650,17 @@ var encodings = [ALAST & obj.AMask]encoding{ AFSW & obj.AMask: sFEncoding, // 11.6: Single-Precision Floating-Point Computational Instructions - AFADDS & obj.AMask: rFFFEncoding, - AFSUBS & obj.AMask: rFFFEncoding, - AFMULS & obj.AMask: rFFFEncoding, - AFDIVS & obj.AMask: rFFFEncoding, - AFMINS & obj.AMask: rFFFEncoding, - AFMAXS & obj.AMask: rFFFEncoding, - AFSQRTS & obj.AMask: rFFFEncoding, + AFADDS & obj.AMask: rFFFEncoding, + AFSUBS & obj.AMask: rFFFEncoding, + AFMULS & obj.AMask: rFFFEncoding, + AFDIVS & obj.AMask: rFFFEncoding, + AFMINS & obj.AMask: rFFFEncoding, + AFMAXS & obj.AMask: rFFFEncoding, + AFSQRTS & obj.AMask: rFFFEncoding, + AFMADDS & obj.AMask: rFFFFEncoding, + AFMSUBS & obj.AMask: rFFFFEncoding, + AFNMSUBS & obj.AMask: rFFFFEncoding, + AFNMADDS & obj.AMask: rFFFFEncoding, // 11.7: Single-Precision Floating-Point Conversion and Move Instructions AFCVTWS & obj.AMask: rFIEncoding, @@ -1647,13 +1692,17 @@ var encodings = [ALAST & obj.AMask]encoding{ AFSD & obj.AMask: sFEncoding, // 12.4: Double-Precision Floating-Point Computational Instructions - AFADDD & obj.AMask: rFFFEncoding, - AFSUBD & obj.AMask: rFFFEncoding, - AFMULD & obj.AMask: rFFFEncoding, - AFDIVD & obj.AMask: rFFFEncoding, - AFMIND & obj.AMask: rFFFEncoding, - AFMAXD & obj.AMask: rFFFEncoding, - AFSQRTD & obj.AMask: rFFFEncoding, + AFADDD & obj.AMask: rFFFEncoding, + AFSUBD & obj.AMask: rFFFEncoding, + AFMULD & obj.AMask: rFFFEncoding, + AFDIVD & obj.AMask: rFFFEncoding, + AFMIND & obj.AMask: rFFFEncoding, + AFMAXD & obj.AMask: rFFFEncoding, + AFSQRTD & obj.AMask: rFFFEncoding, + AFMADDD & obj.AMask: rFFFFEncoding, + AFMSUBD & obj.AMask: rFFFFEncoding, + AFNMSUBD & obj.AMask: rFFFFEncoding, + AFNMADDD & obj.AMask: rFFFFEncoding, // 12.5: Double-Precision Floating-Point Conversion and Move Instructions AFCVTWD & obj.AMask: rFIEncoding, @@ -1719,9 +1768,10 @@ type instruction struct { rd uint32 // Destination register rs1 uint32 // Source register 1 rs2 uint32 // Source register 2 + rs3 uint32 // Source register 3 imm int64 // Immediate funct3 uint32 // Function 3 - funct7 uint32 // Function 7 + funct7 uint32 // Function 7 (or Function 2) } func (ins *instruction) encode() (uint32, error) { @@ -1762,6 +1812,12 @@ func instructionsForProg(p *obj.Prog) []*instruction { imm: p.From.Offset, } + if len(p.RestArgs) == 1 { + ins.rs3 = uint32(p.RestArgs[0].Reg) + } else if len(p.RestArgs) > 0 { + p.Ctxt.Diag("too many source registers") + } + inss := []*instruction{ins} switch ins.as { case AJAL, AJALR: @@ -1899,6 +1955,12 @@ func instructionsForProg(p *obj.Prog) []*instruction { ins.rs1 = uint32(p.From.Reg) ins.rs2 = REG_F0 + case AFMADDS, AFMSUBS, AFNMADDS, AFNMSUBS, + AFMADDD, AFMSUBD, AFNMADDD, AFNMSUBD: + // Swap the first two operands so that the operands are in the same + // order as they are in the specification: RS1, RS2, RS3, RD. + ins.rs1, ins.rs2 = ins.rs2, ins.rs1 + case ANEG, ANEGW: // NEG rs, rd -> SUB rs, X0, rd ins.as = ASUB diff --git a/test/codegen/math.go b/test/codegen/math.go index 04cb4e577d..cd573db7b3 100644 --- a/test/codegen/math.go +++ b/test/codegen/math.go @@ -125,9 +125,25 @@ func fma(x, y, z float64) float64 { // s390x:"FMADD" // ppc64:"FMADD" // ppc64le:"FMADD" + // riscv64:"FMADDD" return math.FMA(x, y, z) } +func fms(x, y, z float64) float64 { + // riscv64:"FMSUBD" + return math.FMA(x, y, -z) +} + +func fnma(x, y, z float64) float64 { + // riscv64:"FNMADDD" + return math.FMA(-x, y, z) +} + +func fnms(x, y, z float64) float64 { + // riscv64:"FNMSUBD" + return math.FMA(x, -y, -z) +} + func fromFloat64(f64 float64) uint64 { // amd64:"MOVQ\tX.*, [^X].*" // arm64:"FMOVD\tF.*, R.*" -- 2.50.0