From 43b7e670401401b2e7536b4931df8b29a25994c7 Mon Sep 17 00:00:00 2001
From: Jakub Ciolek
Date: Sun, 2 Feb 2025 23:42:43 +0100
Subject: [PATCH] cmd/compile: lower x*z + y to FMA if FMA enabled
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

There is a generic opcode for FMA, but we don't use it in rewrite rules.
This is maybe because some archs, like WASM and MIPS, don't have a late
lowering rule for it.

Fixes #71204
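To illustrate the pattern this targets (an editor's sketch, not part of the
CL; the helper name is made up), a plain multiply-add like the one below may
now compile to a single VFMADD231SD on amd64 when built with GOAMD64=v3 and
the function is allowed to use FMA, instead of a separate MULSD and ADDSD:

package main

import "fmt"

// fusedAdd64 mirrors the FusedAdd64 codegen test touched below: with
// GOAMD64=v3 the compiler may emit VFMADD231SD for x*y + z.
func fusedAdd64(x, y, z float64) float64 {
	return x*y + z
}

func main() {
	fmt.Println(fusedAdd64(2, 3, 4)) // 10
}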
Intel Alder Lake 12600k (GOAMD64=v3):

math:

name old time/op new time/op delta
Acos-16 4.58ns ± 0% 3.36ns ± 0% -26.68% (p=0.008 n=5+5)
Acosh-16 8.04ns ± 1% 6.44ns ± 0% -19.95% (p=0.008 n=5+5)
Asin-16 4.28ns ± 0% 3.32ns ± 0% -22.24% (p=0.008 n=5+5)
Asinh-16 9.92ns ± 0% 8.62ns ± 0% -13.13% (p=0.008 n=5+5)
Atan-16 2.31ns ± 0% 1.84ns ± 0% -20.02% (p=0.008 n=5+5)
Atanh-16 7.79ns ± 0% 7.03ns ± 0% -9.67% (p=0.008 n=5+5)
Atan2-16 3.93ns ± 0% 3.52ns ± 0% -10.35% (p=0.000 n=5+4)
Cbrt-16 4.62ns ± 0% 4.41ns ± 0% -4.57% (p=0.016 n=4+5)
Ceil-16 0.14ns ± 1% 0.14ns ± 2% ~ (p=0.103 n=5+5)
Copysign-16 0.33ns ± 0% 0.33ns ± 0% +0.03% (p=0.029 n=4+4)
Cos-16 4.87ns ± 0% 4.75ns ± 0% -2.44% (p=0.016 n=5+4)
Cosh-16 4.86ns ± 0% 4.86ns ± 0% ~ (p=0.317 n=5+5)
Erf-16 2.71ns ± 0% 2.25ns ± 0% -16.69% (p=0.008 n=5+5)
Erfc-16 3.06ns ± 0% 2.67ns ± 0% -13.00% (p=0.016 n=5+4)
Erfinv-16 3.88ns ± 0% 2.84ns ± 3% -26.83% (p=0.008 n=5+5)
Erfcinv-16 4.08ns ± 0% 3.01ns ± 1% -26.27% (p=0.008 n=5+5)
Exp-16 3.29ns ± 0% 3.37ns ± 2% +2.64% (p=0.016 n=4+5)
ExpGo-16 8.44ns ± 0% 7.48ns ± 1% -11.37% (p=0.008 n=5+5)
Expm1-16 4.46ns ± 0% 3.69ns ± 2% -17.26% (p=0.016 n=4+5)
Exp2-16 8.20ns ± 0% 7.39ns ± 2% -9.94% (p=0.008 n=5+5)
Exp2Go-16 8.26ns ± 0% 7.23ns ± 0% -12.49% (p=0.016 n=4+5)
Abs-16 0.26ns ± 3% 0.22ns ± 1% -16.34% (p=0.008 n=5+5)
Dim-16 0.38ns ± 1% 0.40ns ± 2% +5.02% (p=0.008 n=5+5)
Floor-16 0.11ns ± 1% 0.17ns ± 4% +54.99% (p=0.008 n=5+5)
Max-16 1.24ns ± 0% 1.24ns ± 0% ~ (p=0.619 n=5+5)
Min-16 1.24ns ± 0% 1.24ns ± 0% ~ (p=0.484 n=5+5)
Mod-16 13.4ns ± 1% 12.8ns ± 0% -4.21% (p=0.016 n=5+4)
Frexp-16 1.70ns ± 0% 1.71ns ± 0% +0.46% (p=0.008 n=5+5)
Gamma-16 3.97ns ± 0% 3.97ns ± 0% ~ (p=0.643 n=5+5)
Hypot-16 2.11ns ± 0% 2.11ns ± 0% ~ (p=0.762 n=5+5)
HypotGo-16 2.48ns ± 4% 2.26ns ± 0% -8.94% (p=0.008 n=5+5)
Ilogb-16 1.67ns ± 0% 1.67ns ± 0% -0.07% (p=0.048 n=5+5)
J0-16 19.8ns ± 0% 19.3ns ± 0% ~ (p=0.079 n=4+5)
J1-16 19.4ns ± 0% 18.9ns ± 0% -2.63% (p=0.000 n=5+4)
Jn-16 41.5ns ± 0% 40.6ns ± 0% -2.32% (p=0.016 n=4+5)
Ldexp-16 2.26ns ± 0% 2.26ns ± 0% ~ (p=0.683 n=5+5)
Lgamma-16 4.40ns ± 0% 4.21ns ± 0% -4.21% (p=0.008 n=5+5)
Log-16 4.05ns ± 0% 4.05ns ± 0% ~ (all equal)
Logb-16 1.69ns ± 0% 1.69ns ± 0% ~ (p=0.429 n=5+5)
Log1p-16 5.00ns ± 0% 3.99ns ± 0% -20.14% (p=0.008 n=5+5)
Log10-16 4.22ns ± 0% 4.21ns ± 0% -0.15% (p=0.008 n=5+5)
Log2-16 2.27ns ± 0% 2.25ns ± 0% -0.94% (p=0.008 n=5+5)
Modf-16 1.44ns ± 0% 1.44ns ± 0% ~ (p=0.492 n=5+5)
Nextafter32-16 2.09ns ± 0% 2.09ns ± 0% ~ (p=0.079 n=4+5)
Nextafter64-16 2.09ns ± 0% 2.09ns ± 0% ~ (p=0.095 n=4+5)
PowInt-16 10.8ns ± 0% 10.8ns ± 0% ~ (all equal)
PowFrac-16 25.3ns ± 0% 25.3ns ± 0% -0.09% (p=0.000 n=5+4)
Pow10Pos-16 0.52ns ± 1% 0.52ns ± 0% ~ (p=0.810 n=5+5)
Pow10Neg-16 0.82ns ± 0% 0.82ns ± 0% ~ (p=0.381 n=5+5)
Round-16 0.93ns ± 0% 0.93ns ± 0% ~ (p=0.056 n=5+5)
RoundToEven-16 1.64ns ± 0% 1.64ns ± 0% ~ (all equal)
Remainder-16 12.4ns ± 2% 12.0ns ± 0% -3.27% (p=0.008 n=5+5)
Signbit-16 0.37ns ± 0% 0.37ns ± 0% -0.19% (p=0.008 n=5+5)
Sin-16 4.04ns ± 0% 3.92ns ± 0% -3.13% (p=0.000 n=4+5)
Sincos-16 5.99ns ± 0% 5.80ns ± 0% -3.03% (p=0.008 n=5+5)
Sinh-16 5.22ns ± 0% 5.22ns ± 0% ~ (p=0.651 n=5+4)
SqrtIndirect-16 0.41ns ± 0% 0.41ns ± 0% ~ (p=0.333 n=4+5)
SqrtLatency-16 2.66ns ± 0% 2.66ns ± 0% ~ (p=0.079 n=4+5)
SqrtIndirectLatency-16 2.66ns ± 0% 2.66ns ± 0% ~ (p=1.000 n=5+5)
SqrtGoLatency-16 30.1ns ± 0% 28.6ns ± 1% -4.84% (p=0.008 n=5+5)
SqrtPrime-16 645ns ± 0% 645ns ± 0% ~ (p=0.095 n=5+4)
Tan-16 4.21ns ± 0% 4.09ns ± 0% -2.76% (p=0.029 n=4+4)
Tanh-16 5.36ns ± 0% 5.36ns ± 0% ~ (p=0.444 n=5+5)
Trunc-16 0.12ns ± 6% 0.11ns ± 1% -6.79% (p=0.008 n=5+5)
Y0-16 19.2ns ± 0% 18.7ns ± 0% -2.52% (p=0.000 n=5+4)
Y1-16 19.1ns ± 0% 18.4ns ± 0% ~ (p=0.079 n=4+5)
Yn-16 40.7ns ± 0% 39.5ns ± 0% -2.82% (p=0.008 n=5+5)
Float64bits-16 0.21ns ± 0% 0.21ns ± 0% ~ (p=0.603 n=5+5)
Float64frombits-16 0.21ns ± 0% 0.21ns ± 0% ~ (p=0.984 n=4+5)
Float32bits-16 0.21ns ± 0% 0.21ns ± 0% ~ (p=0.778 n=4+5)
Float32frombits-16 0.21ns ± 0% 0.20ns ± 0% ~ (p=0.397 n=5+5)
FMA-16 0.82ns ± 0% 0.82ns ± 0% +0.02% (p=0.029 n=4+4)
[Geo mean] 2.87ns 2.74ns -4.61%

math/cmplx:

name old time/op new time/op delta
Abs-16 2.07ns ± 0% 2.05ns ± 0% -0.70% (p=0.016 n=5+4)
Acos-16 36.5ns ± 0% 35.7ns ± 0% -2.33% (p=0.029 n=4+4)
Acosh-16 37.0ns ± 0% 36.2ns ± 0% -2.20% (p=0.008 n=5+5)
Asin-16 36.5ns ± 0% 35.7ns ± 0% -2.29% (p=0.008 n=5+5)
Asinh-16 33.5ns ± 0% 31.6ns ± 0% -5.51% (p=0.008 n=5+5)
Atan-16 15.5ns ± 0% 13.9ns ± 0% -10.61% (p=0.008 n=5+5)
Atanh-16 15.0ns ± 0% 13.6ns ± 0% -9.73% (p=0.008 n=5+5)
Conj-16 0.11ns ± 5% 0.11ns ± 1% ~ (p=0.421 n=5+5)
Cos-16 12.3ns ± 0% 12.2ns ± 0% -0.60% (p=0.000 n=4+5)
Cosh-16 12.1ns ± 0% 12.0ns ± 0% ~ (p=0.079 n=4+5)
Exp-16 10.0ns ± 0% 9.8ns ± 0% -1.77% (p=0.008 n=5+5)
Log-16 14.5ns ± 0% 13.7ns ± 0% -5.67% (p=0.008 n=5+5)
Log10-16 14.5ns ± 0% 13.7ns ± 0% -5.55% (p=0.000 n=5+4)
Phase-16 5.11ns ± 0% 4.25ns ± 0% -16.90% (p=0.008 n=5+5)
Polar-16 7.12ns ± 0% 6.35ns ± 0% -10.90% (p=0.008 n=5+5)
Pow-16 64.3ns ± 0% 63.7ns ± 0% -0.97% (p=0.008 n=5+5)
Rect-16 5.74ns ± 0% 5.58ns ± 0% -2.73% (p=0.016 n=4+5)
Sin-16 12.2ns ± 0% 12.2ns ± 0% -0.54% (p=0.000 n=4+5)
Sinh-16 12.1ns ± 0% 12.0ns ± 0% -0.58% (p=0.000 n=5+4)
Sqrt-16 5.30ns ± 0% 5.18ns ± 0% -2.36% (p=0.008 n=5+5)
Tan-16 22.7ns ± 0% 22.6ns ± 0% -0.33% (p=0.008 n=5+5)
Tanh-16 21.2ns ± 0% 20.9ns ± 0% -1.32% (p=0.008 n=5+5)
[Geo mean] 11.3ns 10.8ns -3.97%

Change-Id: Idcc4b357ba68477929c126289e5095b27a827b1b
Reviewed-on: https://go-review.googlesource.com/c/go/+/646335
LUCI-TryBot-Result: Go LUCI
Auto-Submit: Keith Randall
Reviewed-by: Keith Randall
Reviewed-by: Dmitri Shuralyov
Reviewed-by: Keith Randall
---
 src/cmd/compile/internal/amd64/ssa.go         |  4 +-
 src/cmd/compile/internal/ssa/_gen/AMD64.rules |  5 +-
 src/cmd/compile/internal/ssa/_gen/AMD64Ops.go | 10 +++-
 src/cmd/compile/internal/ssa/opGen.go         | 47 +++++++++++++++++++
 src/cmd/compile/internal/ssa/rewriteAMD64.go  | 44 ++++++++++++++++-
 test/codegen/floats.go                        |  2 +
 test/codegen/math.go                          |  6 ++-
 7 files changed, 110 insertions(+), 8 deletions(-)

diff --git a/src/cmd/compile/internal/amd64/ssa.go b/src/cmd/compile/internal/amd64/ssa.go
index 493369af51..9eef71f760 100644
--- a/src/cmd/compile/internal/amd64/ssa.go
+++ b/src/cmd/compile/internal/amd64/ssa.go
@@ -202,7 +202,7 @@ func getgFromTLS(s *ssagen.State, r int16) {
 
 func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 	switch v.Op {
-	case ssa.OpAMD64VFMADD231SD:
+	case ssa.OpAMD64VFMADD231SD, ssa.OpAMD64VFMADD231SS:
 		p := s.Prog(v.Op.Asm())
 		p.From = obj.Addr{Type: obj.TYPE_REG, Reg: v.Args[2].Reg()}
 		p.To = obj.Addr{Type: obj.TYPE_REG, Reg: v.Reg()}
@@ -1170,6 +1170,8 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 		case ssa.OpAMD64BSFL, ssa.OpAMD64BSRL, ssa.OpAMD64SQRTSD, ssa.OpAMD64SQRTSS:
 			p.To.Reg = v.Reg()
 		}
+	case ssa.OpAMD64LoweredRound32F, ssa.OpAMD64LoweredRound64F:
+		// input is already rounded
 	case ssa.OpAMD64ROUNDSD:
 		p := s.Prog(v.Op.Asm())
 		val := v.AuxInt
diff --git a/src/cmd/compile/internal/ssa/_gen/AMD64.rules b/src/cmd/compile/internal/ssa/_gen/AMD64.rules
index 0e429b5be7..9177067e52 100644
--- a/src/cmd/compile/internal/ssa/_gen/AMD64.rules
+++ b/src/cmd/compile/internal/ssa/_gen/AMD64.rules
@@ -170,7 +170,7 @@
 (Cvt32Fto64F ...) => (CVTSS2SD ...)
 (Cvt64Fto32F ...) => (CVTSD2SS ...)
 
-(Round(32|64)F ...) => (Copy ...)
+(Round(32|64)F ...) => (LoweredRound(32|64)F ...)
 
 // Floating-point min is tricky, as the hardware op isn't right for various special
 // cases (-0 and NaN). We use two hardware ops organized just right to make the
@@ -1589,6 +1589,9 @@
 (MULSDload x [off] {sym} ptr (MOVQstore [off] {sym} ptr y _)) => (MULSD x (MOVQi2f y))
 (MULSSload x [off] {sym} ptr (MOVLstore [off] {sym} ptr y _)) => (MULSS x (MOVLi2f y))
 
+// Detect FMA
+(ADDS(S|D) (MULS(S|D) x y) z) && buildcfg.GOAMD64 >= 3 && z.Block.Func.useFMA(v) => (VFMADD231S(S|D) z x y)
+
 // Redirect stores to use the other register set.
 (MOVQstore [off] {sym} ptr (MOVQf2i val) mem) => (MOVSDstore [off] {sym} ptr val mem)
 (MOVLstore [off] {sym} ptr (MOVLf2i val) mem) => (MOVSSstore [off] {sym} ptr val mem)
diff --git a/src/cmd/compile/internal/ssa/_gen/AMD64Ops.go b/src/cmd/compile/internal/ssa/_gen/AMD64Ops.go
index 53df7af305..1cce32eba3 100644
--- a/src/cmd/compile/internal/ssa/_gen/AMD64Ops.go
+++ b/src/cmd/compile/internal/ssa/_gen/AMD64Ops.go
@@ -692,9 +692,15 @@ func init() {
 		// ROUNDSD instruction is only guaraneteed to be available if GOAMD64>=v2.
 		// For GOAMD64=v3.
+		// x==S for float32, x==D for float64
+		// arg0 + arg1*arg2, with no intermediate rounding.
+		{name: "VFMADD231SS", argLength: 3, reg: fp31, resultInArg0: true, asm: "VFMADD231SS"},
 		{name: "VFMADD231SD", argLength: 3, reg: fp31, resultInArg0: true, asm: "VFMADD231SD"},
 
 		// Note that these operations don't exactly match the semantics of Go's
+ {name: "VFMADD231SS", argLength: 3, reg: fp31, resultInArg0: true, asm: "VFMADD231SS"}, {name: "VFMADD231SD", argLength: 3, reg: fp31, resultInArg0: true, asm: "VFMADD231SD"}, // Note that these operations don't exactly match the semantics of Go's diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index 13ec9dc9e3..f4f648c53b 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -924,6 +924,9 @@ const ( OpAMD64SQRTSD OpAMD64SQRTSS OpAMD64ROUNDSD + OpAMD64LoweredRound32F + OpAMD64LoweredRound64F + OpAMD64VFMADD231SS OpAMD64VFMADD231SD OpAMD64MINSD OpAMD64MINSS @@ -12060,6 +12063,50 @@ var opcodeTable = [...]opInfo{ }, }, }, + { + name: "LoweredRound32F", + argLen: 1, + resultInArg0: true, + zeroWidth: true, + reg: regInfo{ + inputs: []inputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + outputs: []outputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + }, + }, + { + name: "LoweredRound64F", + argLen: 1, + resultInArg0: true, + zeroWidth: true, + reg: regInfo{ + inputs: []inputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + outputs: []outputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + }, + }, + { + name: "VFMADD231SS", + argLen: 3, + resultInArg0: true, + asm: x86.AVFMADD231SS, + reg: regInfo{ + inputs: []inputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + {1, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + {2, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + outputs: []outputInfo{ + {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 + }, + }, + }, { name: "VFMADD231SD", argLen: 3, diff --git a/src/cmd/compile/internal/ssa/rewriteAMD64.go b/src/cmd/compile/internal/ssa/rewriteAMD64.go index 9ea1114d45..63376dcb76 100644 --- a/src/cmd/compile/internal/ssa/rewriteAMD64.go +++ b/src/cmd/compile/internal/ssa/rewriteAMD64.go @@ -1007,10 +1007,10 @@ func rewriteValueAMD64(v *Value) bool { v.Op = OpAMD64ROLB return true case OpRound32F: - v.Op = OpCopy + v.Op = OpAMD64LoweredRound32F return true case OpRound64F: - v.Op = OpCopy + v.Op = OpAMD64LoweredRound64F return true case OpRoundToEven: return rewriteValueAMD64_OpRoundToEven(v) @@ -2430,6 +2430,26 @@ func rewriteValueAMD64_OpAMD64ADDSD(v *Value) bool { } break } + // match: (ADDSD (MULSD x y) z) + // cond: buildcfg.GOAMD64 >= 3 && z.Block.Func.useFMA(v) + // result: (VFMADD231SD z x y) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpAMD64MULSD { + continue + } + y := v_0.Args[1] + x := v_0.Args[0] + z := v_1 + if !(buildcfg.GOAMD64 >= 3 && z.Block.Func.useFMA(v)) { + continue + } + v.reset(OpAMD64VFMADD231SD) + v.AddArg3(z, x, y) + return true + } + break + } return false } func rewriteValueAMD64_OpAMD64ADDSDload(v *Value) bool { @@ -2533,6 +2553,26 @@ func rewriteValueAMD64_OpAMD64ADDSS(v *Value) bool { } break } + // match: (ADDSS (MULSS x y) z) + // cond: buildcfg.GOAMD64 >= 3 && z.Block.Func.useFMA(v) + // result: (VFMADD231SS z x y) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpAMD64MULSS { + continue + } + y := v_0.Args[1] + x := v_0.Args[0] + z := v_1 + if !(buildcfg.GOAMD64 >= 3 && z.Block.Func.useFMA(v)) { + continue + } + v.reset(OpAMD64VFMADD231SS) + v.AddArg3(z, x, y) + return true + } + break + } return false } func 
diff --git a/test/codegen/floats.go b/test/codegen/floats.go
index 1b85eba352..2a5cf39957 100644
--- a/test/codegen/floats.go
+++ b/test/codegen/floats.go
@@ -74,6 +74,7 @@ func FusedAdd32(x, y, z float32) float32 {
 	// arm64:"FMADDS"
 	// loong64:"FMADDF\t"
 	// riscv64:"FMADDS\t"
+	// amd64/v3:"VFMADD231SS\t"
 	return x*y + z
 }
 
@@ -98,6 +99,7 @@ func FusedAdd64(x, y, z float64) float64 {
 	// arm64:"FMADDD"
 	// loong64:"FMADDD\t"
 	// riscv64:"FMADDD\t"
+	// amd64/v3:"VFMADD231SD\t"
 	return x*y + z
 }
 
diff --git a/test/codegen/math.go b/test/codegen/math.go
index 4ce5fa419d..87d9cd7b27 100644
--- a/test/codegen/math.go
+++ b/test/codegen/math.go
@@ -240,10 +240,11 @@ func nanGenerate64() float64 {
 
 	// amd64:"DIVSD"
 	z0 := zero / zero
-	// amd64:"MULSD"
+	// amd64/v1,amd64/v2:"MULSD"
 	z1 := zero * inf
 	// amd64:"SQRTSD"
 	z2 := math.Sqrt(negone)
+	// amd64/v3:"VFMADD231SD"
 	return z0 + z1 + z2
 }
 
@@ -254,7 +255,8 @@ func nanGenerate32() float32 {
 
 	// amd64:"DIVSS"
 	z0 := zero / zero
-	// amd64:"MULSS"
+	// amd64/v1,amd64/v2:"MULSS"
 	z1 := zero * inf
+	// amd64/v3:"VFMADD231SS"
 	return z0 + z1
 }
-- 
2.50.0
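A minimal sketch of the semantics involved (editorial addition, not from the
patch): math.FMA always computes x*y + z with a single rounding, while a plain
x*y + z expression is fused only at the compiler's discretion, which with this
change now includes amd64 code built with GOAMD64=v3 when the function permits
FMA.

package main

import (
	"fmt"
	"math"
)

func main() {
	x, y, z := 2.0, 3.0, 4.0
	fmt.Println(math.FMA(x, y, z)) // always fused: no intermediate rounding
	fmt.Println(x*y + z)           // may be fused; both print 10 here
}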