From: Meng Zhuo Date: Wed, 28 Jun 2023 08:45:07 +0000 (+0800) Subject: cmd/compile: add single-precision FMA code generation for riscv64 X-Git-Tag: go1.22rc1~1158 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=63ab68ddc5f1307e552cf27ae7a6f0dfda2bb962;p=gostls13.git cmd/compile: add single-precision FMA code generation for riscv64 This CL adds FMADDS,FMSUBS,FNMADDS,FNMSUBS SSA support for riscv Change-Id: I1e7dd322b46b9e0f4923dbba256303d69ed12066 Reviewed-on: https://go-review.googlesource.com/c/go/+/506616 Reviewed-by: Joel Sing Reviewed-by: David Chase TryBot-Result: Gopher Robot Reviewed-by: Keith Randall Run-TryBot: M Zhuo --- diff --git a/src/cmd/compile/internal/riscv64/ssa.go b/src/cmd/compile/internal/riscv64/ssa.go index 143e7c525a..f8cf786920 100644 --- a/src/cmd/compile/internal/riscv64/ssa.go +++ b/src/cmd/compile/internal/riscv64/ssa.go @@ -332,7 +332,8 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) { p2.From.Reg = v.Reg1() p2.To.Type = obj.TYPE_REG p2.To.Reg = v.Reg1() - case ssa.OpRISCV64FMADDD, ssa.OpRISCV64FMSUBD, ssa.OpRISCV64FNMADDD, ssa.OpRISCV64FNMSUBD: + case ssa.OpRISCV64FMADDD, ssa.OpRISCV64FMSUBD, ssa.OpRISCV64FNMADDD, ssa.OpRISCV64FNMSUBD, + ssa.OpRISCV64FMADDS, ssa.OpRISCV64FMSUBS, ssa.OpRISCV64FNMADDS, ssa.OpRISCV64FNMSUBS: r := v.Reg() r1 := v.Args[0].Reg() r2 := v.Args[1].Reg() diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules index ac68dfed76..e0bf00d45d 100644 --- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules +++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules @@ -780,9 +780,10 @@ (Select0 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MULHU x y) (Select1 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MUL x y) -(FADDD a (FMULD x y)) && a.Block.Func.useFMA(v) => (FMADDD x y a) -(FSUBD a (FMULD x y)) && a.Block.Func.useFMA(v) => (FNMSUBD x y a) -(FSUBD (FMULD x y) a) && a.Block.Func.useFMA(v) => (FMSUBD x y a) +(FADD(S|D) a (FMUL(S|D) x y)) && 
a.Block.Func.useFMA(v) => (FMADD(S|D) x y a) +(FSUB(S|D) a (FMUL(S|D) x y)) && a.Block.Func.useFMA(v) => (FNMSUB(S|D) x y a) +(FSUB(S|D) (FMUL(S|D) x y) a) && a.Block.Func.useFMA(v) => (FMSUB(S|D) x y a) + // Merge negation into fused multiply-add and multiply-subtract. // // Key: @@ -793,5 +794,7 @@ // D B // // Note: multiplication commutativity handled by rule generator. +(F(MADD|NMADD|MSUB|NMSUB)S neg:(FNEGS x) y z) && neg.Uses == 1 => (F(NMSUB|MSUB|NMADD|MADD)S x y z) +(F(MADD|NMADD|MSUB|NMSUB)S x y neg:(FNEGS z)) && neg.Uses == 1 => (F(MSUB|NMSUB|MADD|NMADD)S x y z) (F(MADD|NMADD|MSUB|NMSUB)D neg:(FNEGD x) y z) && neg.Uses == 1 => (F(NMSUB|MSUB|NMADD|MADD)D x y z) (F(MADD|NMADD|MSUB|NMSUB)D x y neg:(FNEGD z)) && neg.Uses == 1 => (F(MSUB|NMSUB|MADD|NMADD)D x y z) diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go index 69f2950a88..317e9150c9 100644 --- a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go +++ b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go @@ -411,6 +411,10 @@ func init() { {name: "FSUBS", argLength: 2, reg: fp21, asm: "FSUBS", commutative: false, typ: "Float32"}, // arg0 - arg1 {name: "FMULS", argLength: 2, reg: fp21, asm: "FMULS", commutative: true, typ: "Float32"}, // arg0 * arg1 {name: "FDIVS", argLength: 2, reg: fp21, asm: "FDIVS", commutative: false, typ: "Float32"}, // arg0 / arg1 + {name: "FMADDS", argLength: 3, reg: fp31, asm: "FMADDS", commutative: true, typ: "Float32"}, // (arg0 * arg1) + arg2 + {name: "FMSUBS", argLength: 3, reg: fp31, asm: "FMSUBS", commutative: true, typ: "Float32"}, // (arg0 * arg1) - arg2 + {name: "FNMADDS", argLength: 3, reg: fp31, asm: "FNMADDS", commutative: true, typ: "Float32"}, // -(arg0 * arg1) - arg2 + {name: "FNMSUBS", argLength: 3, reg: fp31, asm: "FNMSUBS", commutative: true, typ: "Float32"}, // -(arg0 * arg1) + arg2 {name: "FSQRTS", argLength: 1, reg: fp11, asm: "FSQRTS", typ: "Float32"}, // sqrt(arg0) {name: "FNEGS", argLength: 1, reg: fp11, 
asm: "FNEGS", typ: "Float32"}, // -arg0 {name: "FMVSX", argLength: 1, reg: gpfp, asm: "FMVSX", typ: "Float32"}, // reinterpret arg0 as float diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index 12d8214ae1..11a6138357 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -2436,6 +2436,10 @@ const ( OpRISCV64FSUBS OpRISCV64FMULS OpRISCV64FDIVS + OpRISCV64FMADDS + OpRISCV64FMSUBS + OpRISCV64FNMADDS + OpRISCV64FNMSUBS OpRISCV64FSQRTS OpRISCV64FNEGS OpRISCV64FMVSX @@ -32673,6 +32677,70 @@ var opcodeTable = [...]opInfo{ }, }, }, + { + name: "FMADDS", + argLen: 3, + commutative: true, + asm: riscv.AFMADDS, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, + { + name: "FMSUBS", + argLen: 3, + commutative: true, + asm: riscv.AFMSUBS, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // 
F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, + { + name: "FNMADDS", + argLen: 3, + commutative: true, + asm: riscv.AFNMADDS, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, + { + name: "FNMSUBS", + argLen: 3, + commutative: true, + asm: riscv.AFNMSUBS, + reg: regInfo{ + inputs: []inputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + {2, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + outputs: []outputInfo{ + {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31 + }, + }, + }, { name: "FSQRTS", argLen: 1, diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go index 17af023db3..0ad6433bf4 100644 --- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go +++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go @@ -442,16 +442,28 @@ func rewriteValueRISCV64(v *Value) bool { 
return rewriteValueRISCV64_OpRISCV64ANDI(v) case OpRISCV64FADDD: return rewriteValueRISCV64_OpRISCV64FADDD(v) + case OpRISCV64FADDS: + return rewriteValueRISCV64_OpRISCV64FADDS(v) case OpRISCV64FMADDD: return rewriteValueRISCV64_OpRISCV64FMADDD(v) + case OpRISCV64FMADDS: + return rewriteValueRISCV64_OpRISCV64FMADDS(v) case OpRISCV64FMSUBD: return rewriteValueRISCV64_OpRISCV64FMSUBD(v) + case OpRISCV64FMSUBS: + return rewriteValueRISCV64_OpRISCV64FMSUBS(v) case OpRISCV64FNMADDD: return rewriteValueRISCV64_OpRISCV64FNMADDD(v) + case OpRISCV64FNMADDS: + return rewriteValueRISCV64_OpRISCV64FNMADDS(v) case OpRISCV64FNMSUBD: return rewriteValueRISCV64_OpRISCV64FNMSUBD(v) + case OpRISCV64FNMSUBS: + return rewriteValueRISCV64_OpRISCV64FNMSUBS(v) case OpRISCV64FSUBD: return rewriteValueRISCV64_OpRISCV64FSUBD(v) + case OpRISCV64FSUBS: + return rewriteValueRISCV64_OpRISCV64FSUBS(v) case OpRISCV64MOVBUload: return rewriteValueRISCV64_OpRISCV64MOVBUload(v) case OpRISCV64MOVBUreg: @@ -3364,6 +3376,31 @@ func rewriteValueRISCV64_OpRISCV64FADDD(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64FADDS(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (FADDS a (FMULS x y)) + // cond: a.Block.Func.useFMA(v) + // result: (FMADDS x y a) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + a := v_0 + if v_1.Op != OpRISCV64FMULS { + continue + } + y := v_1.Args[1] + x := v_1.Args[0] + if !(a.Block.Func.useFMA(v)) { + continue + } + v.reset(OpRISCV64FMADDS) + v.AddArg3(x, y, a) + return true + } + break + } + return false +} func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool { v_2 := v.Args[2] v_1 := v.Args[1] @@ -3409,6 +3446,51 @@ func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64FMADDS(v *Value) bool { + v_2 := v.Args[2] + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (FMADDS neg:(FNEGS x) y z) + // cond: neg.Uses == 1 + // result: (FNMSUBS x y z) + for { + for _i0 := 
0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + neg := v_0 + if neg.Op != OpRISCV64FNEGS { + continue + } + x := neg.Args[0] + y := v_1 + z := v_2 + if !(neg.Uses == 1) { + continue + } + v.reset(OpRISCV64FNMSUBS) + v.AddArg3(x, y, z) + return true + } + break + } + // match: (FMADDS x y neg:(FNEGS z)) + // cond: neg.Uses == 1 + // result: (FMSUBS x y z) + for { + x := v_0 + y := v_1 + neg := v_2 + if neg.Op != OpRISCV64FNEGS { + break + } + z := neg.Args[0] + if !(neg.Uses == 1) { + break + } + v.reset(OpRISCV64FMSUBS) + v.AddArg3(x, y, z) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64FMSUBD(v *Value) bool { v_2 := v.Args[2] v_1 := v.Args[1] @@ -3454,6 +3536,51 @@ func rewriteValueRISCV64_OpRISCV64FMSUBD(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64FMSUBS(v *Value) bool { + v_2 := v.Args[2] + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (FMSUBS neg:(FNEGS x) y z) + // cond: neg.Uses == 1 + // result: (FNMADDS x y z) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + neg := v_0 + if neg.Op != OpRISCV64FNEGS { + continue + } + x := neg.Args[0] + y := v_1 + z := v_2 + if !(neg.Uses == 1) { + continue + } + v.reset(OpRISCV64FNMADDS) + v.AddArg3(x, y, z) + return true + } + break + } + // match: (FMSUBS x y neg:(FNEGS z)) + // cond: neg.Uses == 1 + // result: (FMADDS x y z) + for { + x := v_0 + y := v_1 + neg := v_2 + if neg.Op != OpRISCV64FNEGS { + break + } + z := neg.Args[0] + if !(neg.Uses == 1) { + break + } + v.reset(OpRISCV64FMADDS) + v.AddArg3(x, y, z) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64FNMADDD(v *Value) bool { v_2 := v.Args[2] v_1 := v.Args[1] @@ -3499,6 +3626,51 @@ func rewriteValueRISCV64_OpRISCV64FNMADDD(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64FNMADDS(v *Value) bool { + v_2 := v.Args[2] + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (FNMADDS neg:(FNEGS x) y z) + // cond: neg.Uses == 1 + // result: (FMSUBS x y z) + for { + 
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + neg := v_0 + if neg.Op != OpRISCV64FNEGS { + continue + } + x := neg.Args[0] + y := v_1 + z := v_2 + if !(neg.Uses == 1) { + continue + } + v.reset(OpRISCV64FMSUBS) + v.AddArg3(x, y, z) + return true + } + break + } + // match: (FNMADDS x y neg:(FNEGS z)) + // cond: neg.Uses == 1 + // result: (FNMSUBS x y z) + for { + x := v_0 + y := v_1 + neg := v_2 + if neg.Op != OpRISCV64FNEGS { + break + } + z := neg.Args[0] + if !(neg.Uses == 1) { + break + } + v.reset(OpRISCV64FNMSUBS) + v.AddArg3(x, y, z) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64FNMSUBD(v *Value) bool { v_2 := v.Args[2] v_1 := v.Args[1] @@ -3544,6 +3716,51 @@ func rewriteValueRISCV64_OpRISCV64FNMSUBD(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64FNMSUBS(v *Value) bool { + v_2 := v.Args[2] + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (FNMSUBS neg:(FNEGS x) y z) + // cond: neg.Uses == 1 + // result: (FMADDS x y z) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + neg := v_0 + if neg.Op != OpRISCV64FNEGS { + continue + } + x := neg.Args[0] + y := v_1 + z := v_2 + if !(neg.Uses == 1) { + continue + } + v.reset(OpRISCV64FMADDS) + v.AddArg3(x, y, z) + return true + } + break + } + // match: (FNMSUBS x y neg:(FNEGS z)) + // cond: neg.Uses == 1 + // result: (FNMADDS x y z) + for { + x := v_0 + y := v_1 + neg := v_2 + if neg.Op != OpRISCV64FNEGS { + break + } + z := neg.Args[0] + if !(neg.Uses == 1) { + break + } + v.reset(OpRISCV64FNMADDS) + v.AddArg3(x, y, z) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64FSUBD(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] @@ -3583,6 +3800,45 @@ func rewriteValueRISCV64_OpRISCV64FSUBD(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64FSUBS(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (FSUBS a (FMULS x y)) + // cond: a.Block.Func.useFMA(v) + // result: (FNMSUBS x y a) + for { + a 
:= v_0 + if v_1.Op != OpRISCV64FMULS { + break + } + y := v_1.Args[1] + x := v_1.Args[0] + if !(a.Block.Func.useFMA(v)) { + break + } + v.reset(OpRISCV64FNMSUBS) + v.AddArg3(x, y, a) + return true + } + // match: (FSUBS (FMULS x y) a) + // cond: a.Block.Func.useFMA(v) + // result: (FMSUBS x y a) + for { + if v_0.Op != OpRISCV64FMULS { + break + } + y := v_0.Args[1] + x := v_0.Args[0] + a := v_1 + if !(a.Block.Func.useFMA(v)) { + break + } + v.reset(OpRISCV64FMSUBS) + v.AddArg3(x, y, a) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64MOVBUload(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] diff --git a/test/codegen/floats.go b/test/codegen/floats.go index 1c5fc8a31a..7991174b66 100644 --- a/test/codegen/floats.go +++ b/test/codegen/floats.go @@ -70,17 +70,20 @@ func FusedAdd32(x, y, z float32) float32 { // s390x:"FMADDS\t" // ppc64x:"FMADDS\t" // arm64:"FMADDS" + // riscv64:"FMADDS\t" return x*y + z } func FusedSub32_a(x, y, z float32) float32 { // s390x:"FMSUBS\t" // ppc64x:"FMSUBS\t" + // riscv64:"FMSUBS\t" return x*y - z } func FusedSub32_b(x, y, z float32) float32 { // arm64:"FMSUBS" + // riscv64:"FNMSUBS\t" return z - x*y }