]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/{asm,compile}: add fused multiply-add support on riscv64
authorMichael Munday <mike.munday@lowrisc.org>
Wed, 17 Feb 2021 15:00:34 +0000 (15:00 +0000)
committerMichael Munday <mike.munday@lowrisc.org>
Wed, 1 Sep 2021 21:17:04 +0000 (21:17 +0000)
Add support to the assembler for F[N]M{ADD,SUB}[SD] instructions.
Argument order is:

  OP RS1, RS2, RS3, RD

Also, add support for the FMA intrinsic to the compiler. Automatic
FMA matching is left to a future CL.

Change-Id: I47166c7393b2ab6bfc2e42aa8c1a8997c3a071b3
Reviewed-on: https://go-review.googlesource.com/c/go/+/293030
Trust: Michael Munday <mike.munday@lowrisc.org>
Run-TryBot: Michael Munday <mike.munday@lowrisc.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Joel Sing <joel@sing.id.au>
src/cmd/asm/internal/asm/asm.go
src/cmd/asm/internal/asm/testdata/riscv64.s
src/cmd/compile/internal/riscv64/ssa.go
src/cmd/compile/internal/ssa/gen/RISCV64.rules
src/cmd/compile/internal/ssa/gen/RISCV64Ops.go
src/cmd/compile/internal/ssa/opGen.go
src/cmd/compile/internal/ssa/rewriteRISCV64.go
src/cmd/compile/internal/ssagen/ssa.go
src/cmd/internal/obj/riscv/obj.go
test/codegen/math.go

index cf0d1550f99f0e3d6f014b1b4782553692348bb2..d0cb6328f16b6e6936b826a7e11e09b914549b76 100644 (file)
@@ -793,6 +793,13 @@ func (p *Parser) asmInstruction(op obj.As, cond string, a []obj.Addr) {
                                return
                        }
                }
+               if p.arch.Family == sys.RISCV64 {
+                       prog.From = a[0]
+                       prog.Reg = p.getRegister(prog, op, &a[1])
+                       prog.SetRestArgs([]obj.Addr{a[2]})
+                       prog.To = a[3]
+                       break
+               }
                if p.arch.Family == sys.S390X {
                        if a[1].Type != obj.TYPE_REG {
                                p.errorf("second operand must be a register in %s instruction", op)
index 628a8d91cd8afeb096c2dae6e7f82094e8ed6257..173c50f2e18a0a7540776947d6574febfa790022 100644 (file)
@@ -214,6 +214,10 @@ start:
        FMVSX   X5, F0                                  // 538002f0
        FMVXW   F0, X5                                  // d30200e0
        FMVWX   X5, F0                                  // 538002f0
+       FMADDS  F1, F2, F3, F4                          // 43822018
+       FMSUBS  F1, F2, F3, F4                          // 47822018
+       FNMSUBS F1, F2, F3, F4                          // 4b822018
+       FNMADDS F1, F2, F3, F4                          // 4f822018
 
        // 11.8: Single-Precision Floating-Point Compare Instructions
        FEQS    F0, F1, X7                              // d3a300a0
@@ -254,6 +258,10 @@ start:
        FSGNJXD F1, F0, F2                              // 53211022
        FMVXD   F0, X5                                  // d30200e2
        FMVDX   X5, F0                                  // 538002f2
+       FMADDD  F1, F2, F3, F4                          // 4382201a
+       FMSUBD  F1, F2, F3, F4                          // 4782201a
+       FNMSUBD F1, F2, F3, F4                          // 4b82201a
+       FNMADDD F1, F2, F3, F4                          // 4f82201a
 
        // 12.6: Double-Precision Floating-Point Classify Instruction
        FCLASSD F0, X5                                  // d31200e2
index d3cbb4ec24b17530e49e656484f66e5cc86b322c..30b6d96a893d34fb7ea9526b78403ef91ec65067 100644 (file)
@@ -317,7 +317,18 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
                p2.From.Reg = v.Reg1()
                p2.To.Type = obj.TYPE_REG
                p2.To.Reg = v.Reg1()
-
+       case ssa.OpRISCV64FMADDD, ssa.OpRISCV64FMSUBD, ssa.OpRISCV64FNMADDD, ssa.OpRISCV64FNMSUBD:
+               r := v.Reg()
+               r1 := v.Args[0].Reg()
+               r2 := v.Args[1].Reg()
+               r3 := v.Args[2].Reg()
+               p := s.Prog(v.Op.Asm())
+               p.From.Type = obj.TYPE_REG
+               p.From.Reg = r2
+               p.Reg = r1
+               p.SetRestArgs([]obj.Addr{{Type: obj.TYPE_REG, Reg: r3}})
+               p.To.Type = obj.TYPE_REG
+               p.To.Reg = r
        case ssa.OpRISCV64FSQRTS, ssa.OpRISCV64FNEGS, ssa.OpRISCV64FSQRTD, ssa.OpRISCV64FNEGD,
                ssa.OpRISCV64FMVSX, ssa.OpRISCV64FMVDX,
                ssa.OpRISCV64FCVTSW, ssa.OpRISCV64FCVTSL, ssa.OpRISCV64FCVTWS, ssa.OpRISCV64FCVTLS,
index 4eb48e3928f587f5537fd984e779eb8da2aa497a..b711550186d596c0b1dde02a62ba0cfcd016f233 100644 (file)
@@ -96,6 +96,8 @@
 (Sqrt ...) => (FSQRTD ...)
 (Sqrt32 ...) => (FSQRTS ...)
 
+(FMA ...) => (FMADDD ...)
+
 // Sign and zero extension.
 
 (SignExt8to16  ...) => (MOVBreg ...)
 
 // Addition of zero.
 (ADDI [0] x) => x
+
+// Merge negation into fused multiply-add and multiply-subtract.
+//
+// Key:
+//
+//   [+ -](x * y) [+ -] z.
+//    _ N          A S
+//                 D U
+//                 D B
+//
+// Note: multiplication commutativity handled by rule generator.
+(F(MADD|NMADD|MSUB|NMSUB)D neg:(FNEGD x) y z) && neg.Uses == 1 => (F(NMADD|MADD|NMSUB|MSUB)D x y z)
+(F(MADD|NMADD|MSUB|NMSUB)D x y neg:(FNEGD z)) && neg.Uses == 1 => (F(MSUB|NMSUB|MADD|NMADD)D x y z)
index d36daa8b839a6f0fb73021a0d0b856f886b4bfe4..de189e4c601e199d855c2c57a0821b06c3c806ed 100644 (file)
@@ -132,6 +132,7 @@ func init() {
 
                fp11    = regInfo{inputs: []regMask{fpMask}, outputs: []regMask{fpMask}}
                fp21    = regInfo{inputs: []regMask{fpMask, fpMask}, outputs: []regMask{fpMask}}
+               fp31    = regInfo{inputs: []regMask{fpMask, fpMask, fpMask}, outputs: []regMask{fpMask}}
                gpfp    = regInfo{inputs: []regMask{gpMask}, outputs: []regMask{fpMask}}
                fpgp    = regInfo{inputs: []regMask{fpMask}, outputs: []regMask{gpMask}}
                fpstore = regInfo{inputs: []regMask{gpspsbMask, fpMask, 0}}
@@ -425,6 +426,10 @@ func init() {
                {name: "FSUBD", argLength: 2, reg: fp21, asm: "FSUBD", commutative: false, typ: "Float64"},                                          // arg0 - arg1
                {name: "FMULD", argLength: 2, reg: fp21, asm: "FMULD", commutative: true, typ: "Float64"},                                           // arg0 * arg1
                {name: "FDIVD", argLength: 2, reg: fp21, asm: "FDIVD", commutative: false, typ: "Float64"},                                          // arg0 / arg1
+               {name: "FMADDD", argLength: 3, reg: fp31, asm: "FMADDD", commutative: true, typ: "Float64"},                                         // (arg0 * arg1) + arg2
+               {name: "FMSUBD", argLength: 3, reg: fp31, asm: "FMSUBD", commutative: true, typ: "Float64"},                                         // (arg0 * arg1) - arg2
+               {name: "FNMADDD", argLength: 3, reg: fp31, asm: "FNMADDD", commutative: true, typ: "Float64"},                                       // -(arg0 * arg1) + arg2
+               {name: "FNMSUBD", argLength: 3, reg: fp31, asm: "FNMSUBD", commutative: true, typ: "Float64"},                                       // -(arg0 * arg1) - arg2
                {name: "FSQRTD", argLength: 1, reg: fp11, asm: "FSQRTD", typ: "Float64"},                                                            // sqrt(arg0)
                {name: "FNEGD", argLength: 1, reg: fp11, asm: "FNEGD", typ: "Float64"},                                                              // -arg0
                {name: "FMVDX", argLength: 1, reg: gpfp, asm: "FMVDX", typ: "Float64"},                                                              // reinterpret arg0 as float
index 737afc6087f0afd7bb782e6c7beb90ac06cb3ab5..672528aefeb1462adc48b9622b9490b73c316acb 100644 (file)
@@ -2174,6 +2174,10 @@ const (
        OpRISCV64FSUBD
        OpRISCV64FMULD
        OpRISCV64FDIVD
+       OpRISCV64FMADDD
+       OpRISCV64FMSUBD
+       OpRISCV64FNMADDD
+       OpRISCV64FNMSUBD
        OpRISCV64FSQRTD
        OpRISCV64FNEGD
        OpRISCV64FMVDX
@@ -29054,6 +29058,70 @@ var opcodeTable = [...]opInfo{
                        },
                },
        },
+       {
+               name:        "FMADDD",
+               argLen:      3,
+               commutative: true,
+               asm:         riscv.AFMADDD,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
+       {
+               name:        "FMSUBD",
+               argLen:      3,
+               commutative: true,
+               asm:         riscv.AFMSUBD,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
+       {
+               name:        "FNMADDD",
+               argLen:      3,
+               commutative: true,
+               asm:         riscv.AFNMADDD,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
+       {
+               name:        "FNMSUBD",
+               argLen:      3,
+               commutative: true,
+               asm:         riscv.AFNMSUBD,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
        {
                name:   "FSQRTD",
                argLen: 1,
index 641be038dba02800393c8835afcacfbf5c25a805..743ff50b0cf5ea92d7f47253ac320709d3339726 100644 (file)
@@ -209,6 +209,9 @@ func rewriteValueRISCV64(v *Value) bool {
                return rewriteValueRISCV64_OpEqB(v)
        case OpEqPtr:
                return rewriteValueRISCV64_OpEqPtr(v)
+       case OpFMA:
+               v.Op = OpRISCV64FMADDD
+               return true
        case OpGetCallerPC:
                v.Op = OpRISCV64LoweredGetCallerPC
                return true
@@ -432,6 +435,14 @@ func rewriteValueRISCV64(v *Value) bool {
                return rewriteValueRISCV64_OpRISCV64ADDI(v)
        case OpRISCV64AND:
                return rewriteValueRISCV64_OpRISCV64AND(v)
+       case OpRISCV64FMADDD:
+               return rewriteValueRISCV64_OpRISCV64FMADDD(v)
+       case OpRISCV64FMSUBD:
+               return rewriteValueRISCV64_OpRISCV64FMSUBD(v)
+       case OpRISCV64FNMADDD:
+               return rewriteValueRISCV64_OpRISCV64FNMADDD(v)
+       case OpRISCV64FNMSUBD:
+               return rewriteValueRISCV64_OpRISCV64FNMSUBD(v)
        case OpRISCV64MOVBUload:
                return rewriteValueRISCV64_OpRISCV64MOVBUload(v)
        case OpRISCV64MOVBUreg:
@@ -2829,6 +2840,186 @@ func rewriteValueRISCV64_OpRISCV64AND(v *Value) bool {
        }
        return false
 }
+func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool {
+       v_2 := v.Args[2]
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FMADDD neg:(FNEGD x) y z)
+       // cond: neg.Uses == 1
+       // result: (FNMADDD x y z)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       neg := v_0
+                       if neg.Op != OpRISCV64FNEGD {
+                               continue
+                       }
+                       x := neg.Args[0]
+                       y := v_1
+                       z := v_2
+                       if !(neg.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpRISCV64FNMADDD)
+                       v.AddArg3(x, y, z)
+                       return true
+               }
+               break
+       }
+       // match: (FMADDD x y neg:(FNEGD z))
+       // cond: neg.Uses == 1
+       // result: (FMSUBD x y z)
+       for {
+               x := v_0
+               y := v_1
+               neg := v_2
+               if neg.Op != OpRISCV64FNEGD {
+                       break
+               }
+               z := neg.Args[0]
+               if !(neg.Uses == 1) {
+                       break
+               }
+               v.reset(OpRISCV64FMSUBD)
+               v.AddArg3(x, y, z)
+               return true
+       }
+       return false
+}
+func rewriteValueRISCV64_OpRISCV64FMSUBD(v *Value) bool {
+       v_2 := v.Args[2]
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FMSUBD neg:(FNEGD x) y z)
+       // cond: neg.Uses == 1
+       // result: (FNMSUBD x y z)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       neg := v_0
+                       if neg.Op != OpRISCV64FNEGD {
+                               continue
+                       }
+                       x := neg.Args[0]
+                       y := v_1
+                       z := v_2
+                       if !(neg.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpRISCV64FNMSUBD)
+                       v.AddArg3(x, y, z)
+                       return true
+               }
+               break
+       }
+       // match: (FMSUBD x y neg:(FNEGD z))
+       // cond: neg.Uses == 1
+       // result: (FMADDD x y z)
+       for {
+               x := v_0
+               y := v_1
+               neg := v_2
+               if neg.Op != OpRISCV64FNEGD {
+                       break
+               }
+               z := neg.Args[0]
+               if !(neg.Uses == 1) {
+                       break
+               }
+               v.reset(OpRISCV64FMADDD)
+               v.AddArg3(x, y, z)
+               return true
+       }
+       return false
+}
+func rewriteValueRISCV64_OpRISCV64FNMADDD(v *Value) bool {
+       v_2 := v.Args[2]
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FNMADDD neg:(FNEGD x) y z)
+       // cond: neg.Uses == 1
+       // result: (FMADDD x y z)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       neg := v_0
+                       if neg.Op != OpRISCV64FNEGD {
+                               continue
+                       }
+                       x := neg.Args[0]
+                       y := v_1
+                       z := v_2
+                       if !(neg.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpRISCV64FMADDD)
+                       v.AddArg3(x, y, z)
+                       return true
+               }
+               break
+       }
+       // match: (FNMADDD x y neg:(FNEGD z))
+       // cond: neg.Uses == 1
+       // result: (FNMSUBD x y z)
+       for {
+               x := v_0
+               y := v_1
+               neg := v_2
+               if neg.Op != OpRISCV64FNEGD {
+                       break
+               }
+               z := neg.Args[0]
+               if !(neg.Uses == 1) {
+                       break
+               }
+               v.reset(OpRISCV64FNMSUBD)
+               v.AddArg3(x, y, z)
+               return true
+       }
+       return false
+}
+func rewriteValueRISCV64_OpRISCV64FNMSUBD(v *Value) bool {
+       v_2 := v.Args[2]
+       v_1 := v.Args[1]
+       v_0 := v.Args[0]
+       // match: (FNMSUBD neg:(FNEGD x) y z)
+       // cond: neg.Uses == 1
+       // result: (FMSUBD x y z)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       neg := v_0
+                       if neg.Op != OpRISCV64FNEGD {
+                               continue
+                       }
+                       x := neg.Args[0]
+                       y := v_1
+                       z := v_2
+                       if !(neg.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpRISCV64FMSUBD)
+                       v.AddArg3(x, y, z)
+                       return true
+               }
+               break
+       }
+       // match: (FNMSUBD x y neg:(FNEGD z))
+       // cond: neg.Uses == 1
+       // result: (FNMADDD x y z)
+       for {
+               x := v_0
+               y := v_1
+               neg := v_2
+               if neg.Op != OpRISCV64FNEGD {
+                       break
+               }
+               z := neg.Args[0]
+               if !(neg.Uses == 1) {
+                       break
+               }
+               v.reset(OpRISCV64FNMADDD)
+               v.AddArg3(x, y, z)
+               return true
+       }
+       return false
+}
 func rewriteValueRISCV64_OpRISCV64MOVBUload(v *Value) bool {
        v_1 := v.Args[1]
        v_0 := v.Args[0]
index a64901305fe72587b444ab574256e9a28658427a..176e6438dc5c8dc2345a75787a09e5f28afb70da 100644 (file)
@@ -4168,7 +4168,7 @@ func InitTables() {
                func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
                        return s.newValue3(ssa.OpFMA, types.Types[types.TFLOAT64], args[0], args[1], args[2])
                },
-               sys.ARM64, sys.PPC64, sys.S390X)
+               sys.ARM64, sys.PPC64, sys.RISCV64, sys.S390X)
        addF("math", "FMA",
                func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
                        if !s.config.UseFMA {
index f89e13d81c95b817b39ea598bce3b6652766addc..73f62c007dd8ac7a4008f809f7877a2332b4a52e 100644 (file)
@@ -1214,60 +1214,77 @@ func validateRIII(ctxt *obj.Link, ins *instruction) {
        wantIntReg(ctxt, ins.as, "rd", ins.rd)
        wantIntReg(ctxt, ins.as, "rs1", ins.rs1)
        wantIntReg(ctxt, ins.as, "rs2", ins.rs2)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateRFFF(ctxt *obj.Link, ins *instruction) {
        wantFloatReg(ctxt, ins.as, "rd", ins.rd)
        wantFloatReg(ctxt, ins.as, "rs1", ins.rs1)
        wantFloatReg(ctxt, ins.as, "rs2", ins.rs2)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
+}
+
+func validateRFFFF(ctxt *obj.Link, ins *instruction) {
+       wantFloatReg(ctxt, ins.as, "rd", ins.rd)
+       wantFloatReg(ctxt, ins.as, "rs1", ins.rs1)
+       wantFloatReg(ctxt, ins.as, "rs2", ins.rs2)
+       wantFloatReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateRFFI(ctxt *obj.Link, ins *instruction) {
        wantIntReg(ctxt, ins.as, "rd", ins.rd)
        wantFloatReg(ctxt, ins.as, "rs1", ins.rs1)
        wantFloatReg(ctxt, ins.as, "rs2", ins.rs2)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateRFI(ctxt *obj.Link, ins *instruction) {
        wantIntReg(ctxt, ins.as, "rd", ins.rd)
        wantNoneReg(ctxt, ins.as, "rs1", ins.rs1)
        wantFloatReg(ctxt, ins.as, "rs2", ins.rs2)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateRIF(ctxt *obj.Link, ins *instruction) {
        wantFloatReg(ctxt, ins.as, "rd", ins.rd)
        wantNoneReg(ctxt, ins.as, "rs1", ins.rs1)
        wantIntReg(ctxt, ins.as, "rs2", ins.rs2)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateRFF(ctxt *obj.Link, ins *instruction) {
        wantFloatReg(ctxt, ins.as, "rd", ins.rd)
        wantNoneReg(ctxt, ins.as, "rs1", ins.rs1)
        wantFloatReg(ctxt, ins.as, "rs2", ins.rs2)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateII(ctxt *obj.Link, ins *instruction) {
        wantImmI(ctxt, ins.as, ins.imm, 12)
        wantIntReg(ctxt, ins.as, "rd", ins.rd)
        wantIntReg(ctxt, ins.as, "rs1", ins.rs1)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateIF(ctxt *obj.Link, ins *instruction) {
        wantImmI(ctxt, ins.as, ins.imm, 12)
        wantFloatReg(ctxt, ins.as, "rd", ins.rd)
        wantIntReg(ctxt, ins.as, "rs1", ins.rs1)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateSI(ctxt *obj.Link, ins *instruction) {
        wantImmI(ctxt, ins.as, ins.imm, 12)
        wantIntReg(ctxt, ins.as, "rd", ins.rd)
        wantIntReg(ctxt, ins.as, "rs1", ins.rs1)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateSF(ctxt *obj.Link, ins *instruction) {
        wantImmI(ctxt, ins.as, ins.imm, 12)
        wantIntReg(ctxt, ins.as, "rd", ins.rd)
        wantFloatReg(ctxt, ins.as, "rs1", ins.rs1)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateB(ctxt *obj.Link, ins *instruction) {
@@ -1278,6 +1295,7 @@ func validateB(ctxt *obj.Link, ins *instruction) {
        wantNoneReg(ctxt, ins.as, "rd", ins.rd)
        wantIntReg(ctxt, ins.as, "rs1", ins.rs1)
        wantIntReg(ctxt, ins.as, "rs2", ins.rs2)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateU(ctxt *obj.Link, ins *instruction) {
@@ -1285,6 +1303,7 @@ func validateU(ctxt *obj.Link, ins *instruction) {
        wantIntReg(ctxt, ins.as, "rd", ins.rd)
        wantNoneReg(ctxt, ins.as, "rs1", ins.rs1)
        wantNoneReg(ctxt, ins.as, "rs2", ins.rs2)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateJ(ctxt *obj.Link, ins *instruction) {
@@ -1295,6 +1314,7 @@ func validateJ(ctxt *obj.Link, ins *instruction) {
        wantIntReg(ctxt, ins.as, "rd", ins.rd)
        wantNoneReg(ctxt, ins.as, "rs1", ins.rs1)
        wantNoneReg(ctxt, ins.as, "rs2", ins.rs2)
+       wantNoneReg(ctxt, ins.as, "rs3", ins.rs3)
 }
 
 func validateRaw(ctxt *obj.Link, ins *instruction) {
@@ -1317,6 +1337,22 @@ func encodeR(as obj.As, rs1, rs2, rd, funct3, funct7 uint32) uint32 {
        return funct7<<25 | enc.funct7<<25 | enc.rs2<<20 | rs2<<20 | rs1<<15 | enc.funct3<<12 | funct3<<12 | rd<<7 | enc.opcode
 }
 
+// encodeR4 encodes an R4-type RISC-V instruction.
+func encodeR4(as obj.As, rs1, rs2, rs3, rd, funct3, funct2 uint32) uint32 {
+       enc := encode(as)
+       if enc == nil {
+               panic("encodeR4: could not encode instruction")
+       }
+       if enc.rs2 != 0 {
+               panic("encodeR4: instruction uses rs2")
+       }
+       funct2 |= enc.funct7
+       if funct2&^3 != 0 {
+               panic("encodeR4: funct2 requires more than 2 bits")
+       }
+       return rs3<<27 | funct2<<25 | rs2<<20 | rs1<<15 | enc.funct3<<12 | funct3<<12 | rd<<7 | enc.opcode
+}
+
 func encodeRIII(ins *instruction) uint32 {
        return encodeR(ins.as, regI(ins.rs1), regI(ins.rs2), regI(ins.rd), ins.funct3, ins.funct7)
 }
@@ -1325,6 +1361,10 @@ func encodeRFFF(ins *instruction) uint32 {
        return encodeR(ins.as, regF(ins.rs1), regF(ins.rs2), regF(ins.rd), ins.funct3, ins.funct7)
 }
 
+func encodeRFFFF(ins *instruction) uint32 {
+       return encodeR4(ins.as, regF(ins.rs1), regF(ins.rs2), regF(ins.rs3), regF(ins.rd), ins.funct3, ins.funct7)
+}
+
 func encodeRFFI(ins *instruction) uint32 {
        return encodeR(ins.as, regF(ins.rs1), regF(ins.rs2), regI(ins.rd), ins.funct3, ins.funct7)
 }
@@ -1462,12 +1502,13 @@ var (
        // integer register inputs and an integer register output; sFEncoding
        // indicates an S-type instruction with rs2 being a float register.
 
-       rIIIEncoding = encoding{encode: encodeRIII, validate: validateRIII, length: 4}
-       rFFFEncoding = encoding{encode: encodeRFFF, validate: validateRFFF, length: 4}
-       rFFIEncoding = encoding{encode: encodeRFFI, validate: validateRFFI, length: 4}
-       rFIEncoding  = encoding{encode: encodeRFI, validate: validateRFI, length: 4}
-       rIFEncoding  = encoding{encode: encodeRIF, validate: validateRIF, length: 4}
-       rFFEncoding  = encoding{encode: encodeRFF, validate: validateRFF, length: 4}
+       rIIIEncoding  = encoding{encode: encodeRIII, validate: validateRIII, length: 4}
+       rFFFEncoding  = encoding{encode: encodeRFFF, validate: validateRFFF, length: 4}
+       rFFFFEncoding = encoding{encode: encodeRFFFF, validate: validateRFFFF, length: 4}
+       rFFIEncoding  = encoding{encode: encodeRFFI, validate: validateRFFI, length: 4}
+       rFIEncoding   = encoding{encode: encodeRFI, validate: validateRFI, length: 4}
+       rIFEncoding   = encoding{encode: encodeRIF, validate: validateRIF, length: 4}
+       rFFEncoding   = encoding{encode: encodeRFF, validate: validateRFF, length: 4}
 
        iIEncoding = encoding{encode: encodeII, validate: validateII, length: 4}
        iFEncoding = encoding{encode: encodeIF, validate: validateIF, length: 4}
@@ -1609,13 +1650,17 @@ var encodings = [ALAST & obj.AMask]encoding{
        AFSW & obj.AMask: sFEncoding,
 
        // 11.6: Single-Precision Floating-Point Computational Instructions
-       AFADDS & obj.AMask:  rFFFEncoding,
-       AFSUBS & obj.AMask:  rFFFEncoding,
-       AFMULS & obj.AMask:  rFFFEncoding,
-       AFDIVS & obj.AMask:  rFFFEncoding,
-       AFMINS & obj.AMask:  rFFFEncoding,
-       AFMAXS & obj.AMask:  rFFFEncoding,
-       AFSQRTS & obj.AMask: rFFFEncoding,
+       AFADDS & obj.AMask:   rFFFEncoding,
+       AFSUBS & obj.AMask:   rFFFEncoding,
+       AFMULS & obj.AMask:   rFFFEncoding,
+       AFDIVS & obj.AMask:   rFFFEncoding,
+       AFMINS & obj.AMask:   rFFFEncoding,
+       AFMAXS & obj.AMask:   rFFFEncoding,
+       AFSQRTS & obj.AMask:  rFFFEncoding,
+       AFMADDS & obj.AMask:  rFFFFEncoding,
+       AFMSUBS & obj.AMask:  rFFFFEncoding,
+       AFNMSUBS & obj.AMask: rFFFFEncoding,
+       AFNMADDS & obj.AMask: rFFFFEncoding,
 
        // 11.7: Single-Precision Floating-Point Conversion and Move Instructions
        AFCVTWS & obj.AMask:  rFIEncoding,
@@ -1647,13 +1692,17 @@ var encodings = [ALAST & obj.AMask]encoding{
        AFSD & obj.AMask: sFEncoding,
 
        // 12.4: Double-Precision Floating-Point Computational Instructions
-       AFADDD & obj.AMask:  rFFFEncoding,
-       AFSUBD & obj.AMask:  rFFFEncoding,
-       AFMULD & obj.AMask:  rFFFEncoding,
-       AFDIVD & obj.AMask:  rFFFEncoding,
-       AFMIND & obj.AMask:  rFFFEncoding,
-       AFMAXD & obj.AMask:  rFFFEncoding,
-       AFSQRTD & obj.AMask: rFFFEncoding,
+       AFADDD & obj.AMask:   rFFFEncoding,
+       AFSUBD & obj.AMask:   rFFFEncoding,
+       AFMULD & obj.AMask:   rFFFEncoding,
+       AFDIVD & obj.AMask:   rFFFEncoding,
+       AFMIND & obj.AMask:   rFFFEncoding,
+       AFMAXD & obj.AMask:   rFFFEncoding,
+       AFSQRTD & obj.AMask:  rFFFEncoding,
+       AFMADDD & obj.AMask:  rFFFFEncoding,
+       AFMSUBD & obj.AMask:  rFFFFEncoding,
+       AFNMSUBD & obj.AMask: rFFFFEncoding,
+       AFNMADDD & obj.AMask: rFFFFEncoding,
 
        // 12.5: Double-Precision Floating-Point Conversion and Move Instructions
        AFCVTWD & obj.AMask:  rFIEncoding,
@@ -1719,9 +1768,10 @@ type instruction struct {
        rd     uint32 // Destination register
        rs1    uint32 // Source register 1
        rs2    uint32 // Source register 2
+       rs3    uint32 // Source register 3
        imm    int64  // Immediate
        funct3 uint32 // Function 3
-       funct7 uint32 // Function 7
+       funct7 uint32 // Function 7 (or Function 2)
 }
 
 func (ins *instruction) encode() (uint32, error) {
@@ -1762,6 +1812,12 @@ func instructionsForProg(p *obj.Prog) []*instruction {
                imm: p.From.Offset,
        }
 
+       if len(p.RestArgs) == 1 {
+               ins.rs3 = uint32(p.RestArgs[0].Reg)
+       } else if len(p.RestArgs) > 0 {
+               p.Ctxt.Diag("too many source registers")
+       }
+
        inss := []*instruction{ins}
        switch ins.as {
        case AJAL, AJALR:
@@ -1899,6 +1955,12 @@ func instructionsForProg(p *obj.Prog) []*instruction {
                ins.rs1 = uint32(p.From.Reg)
                ins.rs2 = REG_F0
 
+       case AFMADDS, AFMSUBS, AFNMADDS, AFNMSUBS,
+               AFMADDD, AFMSUBD, AFNMADDD, AFNMSUBD:
+               // Swap the first two operands so that the operands are in the same
+               // order as they are in the specification: RS1, RS2, RS3, RD.
+               ins.rs1, ins.rs2 = ins.rs2, ins.rs1
+
        case ANEG, ANEGW:
                // NEG rs, rd -> SUB rs, X0, rd
                ins.as = ASUB
index 04cb4e577da6acd6fe0cbd540a6ac93eca0aee7a..cd573db7b32e879ca24898f8597b0e2101a17b29 100644 (file)
@@ -125,9 +125,25 @@ func fma(x, y, z float64) float64 {
        // s390x:"FMADD"
        // ppc64:"FMADD"
        // ppc64le:"FMADD"
+       // riscv64:"FMADDD"
        return math.FMA(x, y, z)
 }
 
+func fms(x, y, z float64) float64 {
+       // riscv64:"FMSUBD"
+       return math.FMA(x, y, -z)
+}
+
+func fnma(x, y, z float64) float64 {
+       // riscv64:"FNMADDD"
+       return math.FMA(-x, y, z)
+}
+
+func fnms(x, y, z float64) float64 {
+       // riscv64:"FNMSUBD"
+       return math.FMA(x, -y, -z)
+}
+
 func fromFloat64(f64 float64) uint64 {
        // amd64:"MOVQ\tX.*, [^X].*"
        // arm64:"FMOVD\tF.*, R.*"