]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: lower x*z + y to FMA if FMA enabled
authorJakub Ciolek <jakub@ciolek.dev>
Sun, 2 Feb 2025 22:42:43 +0000 (23:42 +0100)
committerGopher Robot <gobot@golang.org>
Thu, 13 Feb 2025 20:34:33 +0000 (12:34 -0800)
There is a generic opcode for FMA, but we don't use it in rewrite rules.
This is maybe because some archs, like WASM and MIPS don't have a late
lowering rule for it.

Fixes #71204

Intel Alder Lake 12600k (GOAMD64=v3):

math:

name                    old time/op  new time/op  delta
Acos-16                 4.58ns ± 0%  3.36ns ± 0%  -26.68%  (p=0.008 n=5+5)
Acosh-16                8.04ns ± 1%  6.44ns ± 0%  -19.95%  (p=0.008 n=5+5)
Asin-16                 4.28ns ± 0%  3.32ns ± 0%  -22.24%  (p=0.008 n=5+5)
Asinh-16                9.92ns ± 0%  8.62ns ± 0%  -13.13%  (p=0.008 n=5+5)
Atan-16                 2.31ns ± 0%  1.84ns ± 0%  -20.02%  (p=0.008 n=5+5)
Atanh-16                7.79ns ± 0%  7.03ns ± 0%   -9.67%  (p=0.008 n=5+5)
Atan2-16                3.93ns ± 0%  3.52ns ± 0%  -10.35%  (p=0.000 n=5+4)
Cbrt-16                 4.62ns ± 0%  4.41ns ± 0%   -4.57%  (p=0.016 n=4+5)
Ceil-16                 0.14ns ± 1%  0.14ns ± 2%     ~     (p=0.103 n=5+5)
Copysign-16             0.33ns ± 0%  0.33ns ± 0%   +0.03%  (p=0.029 n=4+4)
Cos-16                  4.87ns ± 0%  4.75ns ± 0%   -2.44%  (p=0.016 n=5+4)
Cosh-16                 4.86ns ± 0%  4.86ns ± 0%     ~     (p=0.317 n=5+5)
Erf-16                  2.71ns ± 0%  2.25ns ± 0%  -16.69%  (p=0.008 n=5+5)
Erfc-16                 3.06ns ± 0%  2.67ns ± 0%  -13.00%  (p=0.016 n=5+4)
Erfinv-16               3.88ns ± 0%  2.84ns ± 3%  -26.83%  (p=0.008 n=5+5)
Erfcinv-16              4.08ns ± 0%  3.01ns ± 1%  -26.27%  (p=0.008 n=5+5)
Exp-16                  3.29ns ± 0%  3.37ns ± 2%   +2.64%  (p=0.016 n=4+5)
ExpGo-16                8.44ns ± 0%  7.48ns ± 1%  -11.37%  (p=0.008 n=5+5)
Expm1-16                4.46ns ± 0%  3.69ns ± 2%  -17.26%  (p=0.016 n=4+5)
Exp2-16                 8.20ns ± 0%  7.39ns ± 2%   -9.94%  (p=0.008 n=5+5)
Exp2Go-16               8.26ns ± 0%  7.23ns ± 0%  -12.49%  (p=0.016 n=4+5)
Abs-16                  0.26ns ± 3%  0.22ns ± 1%  -16.34%  (p=0.008 n=5+5)
Dim-16                  0.38ns ± 1%  0.40ns ± 2%   +5.02%  (p=0.008 n=5+5)
Floor-16                0.11ns ± 1%  0.17ns ± 4%  +54.99%  (p=0.008 n=5+5)
Max-16                  1.24ns ± 0%  1.24ns ± 0%     ~     (p=0.619 n=5+5)
Min-16                  1.24ns ± 0%  1.24ns ± 0%     ~     (p=0.484 n=5+5)
Mod-16                  13.4ns ± 1%  12.8ns ± 0%   -4.21%  (p=0.016 n=5+4)
Frexp-16                1.70ns ± 0%  1.71ns ± 0%   +0.46%  (p=0.008 n=5+5)
Gamma-16                3.97ns ± 0%  3.97ns ± 0%     ~     (p=0.643 n=5+5)
Hypot-16                2.11ns ± 0%  2.11ns ± 0%     ~     (p=0.762 n=5+5)
HypotGo-16              2.48ns ± 4%  2.26ns ± 0%   -8.94%  (p=0.008 n=5+5)
Ilogb-16                1.67ns ± 0%  1.67ns ± 0%   -0.07%  (p=0.048 n=5+5)
J0-16                   19.8ns ± 0%  19.3ns ± 0%     ~     (p=0.079 n=4+5)
J1-16                   19.4ns ± 0%  18.9ns ± 0%   -2.63%  (p=0.000 n=5+4)
Jn-16                   41.5ns ± 0%  40.6ns ± 0%   -2.32%  (p=0.016 n=4+5)
Ldexp-16                2.26ns ± 0%  2.26ns ± 0%     ~     (p=0.683 n=5+5)
Lgamma-16               4.40ns ± 0%  4.21ns ± 0%   -4.21%  (p=0.008 n=5+5)
Log-16                  4.05ns ± 0%  4.05ns ± 0%     ~     (all equal)
Logb-16                 1.69ns ± 0%  1.69ns ± 0%     ~     (p=0.429 n=5+5)
Log1p-16                5.00ns ± 0%  3.99ns ± 0%  -20.14%  (p=0.008 n=5+5)
Log10-16                4.22ns ± 0%  4.21ns ± 0%   -0.15%  (p=0.008 n=5+5)
Log2-16                 2.27ns ± 0%  2.25ns ± 0%   -0.94%  (p=0.008 n=5+5)
Modf-16                 1.44ns ± 0%  1.44ns ± 0%     ~     (p=0.492 n=5+5)
Nextafter32-16          2.09ns ± 0%  2.09ns ± 0%     ~     (p=0.079 n=4+5)
Nextafter64-16          2.09ns ± 0%  2.09ns ± 0%     ~     (p=0.095 n=4+5)
PowInt-16               10.8ns ± 0%  10.8ns ± 0%     ~     (all equal)
PowFrac-16              25.3ns ± 0%  25.3ns ± 0%   -0.09%  (p=0.000 n=5+4)
Pow10Pos-16             0.52ns ± 1%  0.52ns ± 0%     ~     (p=0.810 n=5+5)
Pow10Neg-16             0.82ns ± 0%  0.82ns ± 0%     ~     (p=0.381 n=5+5)
Round-16                0.93ns ± 0%  0.93ns ± 0%     ~     (p=0.056 n=5+5)
RoundToEven-16          1.64ns ± 0%  1.64ns ± 0%     ~     (all equal)
Remainder-16            12.4ns ± 2%  12.0ns ± 0%   -3.27%  (p=0.008 n=5+5)
Signbit-16              0.37ns ± 0%  0.37ns ± 0%   -0.19%  (p=0.008 n=5+5)
Sin-16                  4.04ns ± 0%  3.92ns ± 0%   -3.13%  (p=0.000 n=4+5)
Sincos-16               5.99ns ± 0%  5.80ns ± 0%   -3.03%  (p=0.008 n=5+5)
Sinh-16                 5.22ns ± 0%  5.22ns ± 0%     ~     (p=0.651 n=5+4)
SqrtIndirect-16         0.41ns ± 0%  0.41ns ± 0%     ~     (p=0.333 n=4+5)
SqrtLatency-16          2.66ns ± 0%  2.66ns ± 0%     ~     (p=0.079 n=4+5)
SqrtIndirectLatency-16  2.66ns ± 0%  2.66ns ± 0%     ~     (p=1.000 n=5+5)
SqrtGoLatency-16        30.1ns ± 0%  28.6ns ± 1%   -4.84%  (p=0.008 n=5+5)
SqrtPrime-16             645ns ± 0%   645ns ± 0%     ~     (p=0.095 n=5+4)
Tan-16                  4.21ns ± 0%  4.09ns ± 0%   -2.76%  (p=0.029 n=4+4)
Tanh-16                 5.36ns ± 0%  5.36ns ± 0%     ~     (p=0.444 n=5+5)
Trunc-16                0.12ns ± 6%  0.11ns ± 1%   -6.79%  (p=0.008 n=5+5)
Y0-16                   19.2ns ± 0%  18.7ns ± 0%   -2.52%  (p=0.000 n=5+4)
Y1-16                   19.1ns ± 0%  18.4ns ± 0%     ~     (p=0.079 n=4+5)
Yn-16                   40.7ns ± 0%  39.5ns ± 0%   -2.82%  (p=0.008 n=5+5)
Float64bits-16          0.21ns ± 0%  0.21ns ± 0%     ~     (p=0.603 n=5+5)
Float64frombits-16      0.21ns ± 0%  0.21ns ± 0%     ~     (p=0.984 n=4+5)
Float32bits-16          0.21ns ± 0%  0.21ns ± 0%     ~     (p=0.778 n=4+5)
Float32frombits-16      0.21ns ± 0%  0.20ns ± 0%     ~     (p=0.397 n=5+5)
FMA-16                  0.82ns ± 0%  0.82ns ± 0%   +0.02%  (p=0.029 n=4+4)
[Geo mean]              2.87ns       2.74ns        -4.61%

math/cmplx:

name        old time/op  new time/op  delta
Abs-16      2.07ns ± 0%  2.05ns ± 0%   -0.70%  (p=0.016 n=5+4)
Acos-16     36.5ns ± 0%  35.7ns ± 0%   -2.33%  (p=0.029 n=4+4)
Acosh-16    37.0ns ± 0%  36.2ns ± 0%   -2.20%  (p=0.008 n=5+5)
Asin-16     36.5ns ± 0%  35.7ns ± 0%   -2.29%  (p=0.008 n=5+5)
Asinh-16    33.5ns ± 0%  31.6ns ± 0%   -5.51%  (p=0.008 n=5+5)
Atan-16     15.5ns ± 0%  13.9ns ± 0%  -10.61%  (p=0.008 n=5+5)
Atanh-16    15.0ns ± 0%  13.6ns ± 0%   -9.73%  (p=0.008 n=5+5)
Conj-16     0.11ns ± 5%  0.11ns ± 1%     ~     (p=0.421 n=5+5)
Cos-16      12.3ns ± 0%  12.2ns ± 0%   -0.60%  (p=0.000 n=4+5)
Cosh-16     12.1ns ± 0%  12.0ns ± 0%     ~     (p=0.079 n=4+5)
Exp-16      10.0ns ± 0%   9.8ns ± 0%   -1.77%  (p=0.008 n=5+5)
Log-16      14.5ns ± 0%  13.7ns ± 0%   -5.67%  (p=0.008 n=5+5)
Log10-16    14.5ns ± 0%  13.7ns ± 0%   -5.55%  (p=0.000 n=5+4)
Phase-16    5.11ns ± 0%  4.25ns ± 0%  -16.90%  (p=0.008 n=5+5)
Polar-16    7.12ns ± 0%  6.35ns ± 0%  -10.90%  (p=0.008 n=5+5)
Pow-16      64.3ns ± 0%  63.7ns ± 0%   -0.97%  (p=0.008 n=5+5)
Rect-16     5.74ns ± 0%  5.58ns ± 0%   -2.73%  (p=0.016 n=4+5)
Sin-16      12.2ns ± 0%  12.2ns ± 0%   -0.54%  (p=0.000 n=4+5)
Sinh-16     12.1ns ± 0%  12.0ns ± 0%   -0.58%  (p=0.000 n=5+4)
Sqrt-16     5.30ns ± 0%  5.18ns ± 0%   -2.36%  (p=0.008 n=5+5)
Tan-16      22.7ns ± 0%  22.6ns ± 0%   -0.33%  (p=0.008 n=5+5)
Tanh-16     21.2ns ± 0%  20.9ns ± 0%   -1.32%  (p=0.008 n=5+5)
[Geo mean]  11.3ns       10.8ns        -3.97%

Change-Id: Idcc4b357ba68477929c126289e5095b27a827b1b
Reviewed-on: https://go-review.googlesource.com/c/go/+/646335
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Keith Randall <khr@golang.org>
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
src/cmd/compile/internal/amd64/ssa.go
src/cmd/compile/internal/ssa/_gen/AMD64.rules
src/cmd/compile/internal/ssa/_gen/AMD64Ops.go
src/cmd/compile/internal/ssa/opGen.go
src/cmd/compile/internal/ssa/rewriteAMD64.go
test/codegen/floats.go
test/codegen/math.go

index 493369af51ce150cf6e59d5d8f6299cb7e7fd0fa..9eef71f760357a9a3502849d5baaf910fe17bbc5 100644 (file)
@@ -202,7 +202,7 @@ func getgFromTLS(s *ssagen.State, r int16) {
 
 func ssaGenValue(s *ssagen.State, v *ssa.Value) {
        switch v.Op {
-       case ssa.OpAMD64VFMADD231SD:
+       case ssa.OpAMD64VFMADD231SD, ssa.OpAMD64VFMADD231SS:
                p := s.Prog(v.Op.Asm())
                p.From = obj.Addr{Type: obj.TYPE_REG, Reg: v.Args[2].Reg()}
                p.To = obj.Addr{Type: obj.TYPE_REG, Reg: v.Reg()}
@@ -1170,6 +1170,8 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
                case ssa.OpAMD64BSFL, ssa.OpAMD64BSRL, ssa.OpAMD64SQRTSD, ssa.OpAMD64SQRTSS:
                        p.To.Reg = v.Reg()
                }
+       case ssa.OpAMD64LoweredRound32F, ssa.OpAMD64LoweredRound64F:
+               // input is already rounded
        case ssa.OpAMD64ROUNDSD:
                p := s.Prog(v.Op.Asm())
                val := v.AuxInt
index 0e429b5be74dcb05315b9c10ee24a9fe9c904fbd..9177067e522206d4594f664eb6aed283d87a8b16 100644 (file)
 (Cvt32Fto64F ...) => (CVTSS2SD ...)
 (Cvt64Fto32F ...) => (CVTSD2SS ...)
 
-(Round(32|64)F ...) => (Copy ...)
+(Round(32|64)F ...) => (LoweredRound(32|64)F ...)
 
 // Floating-point min is tricky, as the hardware op isn't right for various special
 // cases (-0 and NaN). We use two hardware ops organized just right to make the
 (MULSDload x [off] {sym} ptr (MOVQstore [off] {sym} ptr y _)) => (MULSD x (MOVQi2f y))
 (MULSSload x [off] {sym} ptr (MOVLstore [off] {sym} ptr y _)) => (MULSS x (MOVLi2f y))
 
+// Detect FMA
+(ADDS(S|D) (MULS(S|D) x y) z) && buildcfg.GOAMD64 >= 3 && z.Block.Func.useFMA(v) => (VFMADD231S(S|D) z x y)
+
 // Redirect stores to use the other register set.
 (MOVQstore  [off] {sym} ptr (MOVQf2i val) mem) => (MOVSDstore [off] {sym} ptr val mem)
 (MOVLstore  [off] {sym} ptr (MOVLf2i val) mem) => (MOVSSstore [off] {sym} ptr val mem)
index 53df7af3059a522a8f19a175f01fc6597a508e89..1cce32eba32523a3a7c41973fb783c72f748d6e1 100644 (file)
@@ -692,9 +692,15 @@ func init() {
                // ROUNDSD instruction is only guaraneteed to be available if GOAMD64>=v2.
                // For GOAMD64<v2, any use must be preceded by a successful check of runtime.x86HasSSE41.
                {name: "ROUNDSD", argLength: 1, reg: fp11, aux: "Int8", asm: "ROUNDSD"},
+               // See why we need those in issue #71204
+               {name: "LoweredRound32F", argLength: 1, reg: fp11, resultInArg0: true, zeroWidth: true},
+               {name: "LoweredRound64F", argLength: 1, reg: fp11, resultInArg0: true, zeroWidth: true},
 
-               // VFMADD231SD only exists on platforms with the FMA3 instruction set.
-               // Any use must be preceded by a successful check of runtime.support_fma.
+               // VFMADD231Sx only exist on platforms with the FMA3 instruction set.
+               // Any use must be preceded by a successful check of runtime.x86HasFMA or a check of GOAMD64>=v3.
+               // x==S for float32, x==D for float64
+               // arg0 + arg1*arg2, with no intermediate rounding.
+               {name: "VFMADD231SS", argLength: 3, reg: fp31, resultInArg0: true, asm: "VFMADD231SS"},
                {name: "VFMADD231SD", argLength: 3, reg: fp31, resultInArg0: true, asm: "VFMADD231SD"},
 
                // Note that these operations don't exactly match the semantics of Go's
index 13ec9dc9e3fc3abec04c319a8dcbf29ec700470f..f4f648c53b5fa6968b1a5eb28560699c1832f525 100644 (file)
@@ -924,6 +924,9 @@ const (
        OpAMD64SQRTSD
        OpAMD64SQRTSS
        OpAMD64ROUNDSD
+       OpAMD64LoweredRound32F
+       OpAMD64LoweredRound64F
+       OpAMD64VFMADD231SS
        OpAMD64VFMADD231SD
        OpAMD64MINSD
        OpAMD64MINSS
@@ -12060,6 +12063,50 @@ var opcodeTable = [...]opInfo{
                        },
                },
        },
+       {
+               name:         "LoweredRound32F",
+               argLen:       1,
+               resultInArg0: true,
+               zeroWidth:    true,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+                       outputs: []outputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+               },
+       },
+       {
+               name:         "LoweredRound64F",
+               argLen:       1,
+               resultInArg0: true,
+               zeroWidth:    true,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+                       outputs: []outputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+               },
+       },
+       {
+               name:         "VFMADD231SS",
+               argLen:       3,
+               resultInArg0: true,
+               asm:          x86.AVFMADD231SS,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                               {1, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                               {2, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+                       outputs: []outputInfo{
+                               {0, 2147418112}, // X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14
+                       },
+               },
+       },
        {
                name:         "VFMADD231SD",
                argLen:       3,
index 9ea1114d45e9b0cc8e4bf8df87680cb2c1bcb2de..63376dcb76f2d98668d8e4ad9aa3ac3ff75d08b1 100644 (file)
@@ -1007,10 +1007,10 @@ func rewriteValueAMD64(v *Value) bool {
                v.Op = OpAMD64ROLB
                return true
        case OpRound32F:
-               v.Op = OpCopy
+               v.Op = OpAMD64LoweredRound32F
                return true
        case OpRound64F:
-               v.Op = OpCopy
+               v.Op = OpAMD64LoweredRound64F
                return true
        case OpRoundToEven:
                return rewriteValueAMD64_OpRoundToEven(v)
@@ -2430,6 +2430,26 @@ func rewriteValueAMD64_OpAMD64ADDSD(v *Value) bool {
                }
                break
        }
+       // match: (ADDSD (MULSD x y) z)
+       // cond: buildcfg.GOAMD64 >= 3 && z.Block.Func.useFMA(v)
+       // result: (VFMADD231SD z x y)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       if v_0.Op != OpAMD64MULSD {
+                               continue
+                       }
+                       y := v_0.Args[1]
+                       x := v_0.Args[0]
+                       z := v_1
+                       if !(buildcfg.GOAMD64 >= 3 && z.Block.Func.useFMA(v)) {
+                               continue
+                       }
+                       v.reset(OpAMD64VFMADD231SD)
+                       v.AddArg3(z, x, y)
+                       return true
+               }
+               break
+       }
        return false
 }
 func rewriteValueAMD64_OpAMD64ADDSDload(v *Value) bool {
@@ -2533,6 +2553,26 @@ func rewriteValueAMD64_OpAMD64ADDSS(v *Value) bool {
                }
                break
        }
+       // match: (ADDSS (MULSS x y) z)
+       // cond: buildcfg.GOAMD64 >= 3 && z.Block.Func.useFMA(v)
+       // result: (VFMADD231SS z x y)
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       if v_0.Op != OpAMD64MULSS {
+                               continue
+                       }
+                       y := v_0.Args[1]
+                       x := v_0.Args[0]
+                       z := v_1
+                       if !(buildcfg.GOAMD64 >= 3 && z.Block.Func.useFMA(v)) {
+                               continue
+                       }
+                       v.reset(OpAMD64VFMADD231SS)
+                       v.AddArg3(z, x, y)
+                       return true
+               }
+               break
+       }
        return false
 }
 func rewriteValueAMD64_OpAMD64ADDSSload(v *Value) bool {
index 1b85eba35249c95e09bb94029f01e6f7cf20c693..2a5cf3995781e61d885fdeea8f8e007eb5d1c91c 100644 (file)
@@ -74,6 +74,7 @@ func FusedAdd32(x, y, z float32) float32 {
        // arm64:"FMADDS"
        // loong64:"FMADDF\t"
        // riscv64:"FMADDS\t"
+       // amd64/v3:"VFMADD231SS\t"
        return x*y + z
 }
 
@@ -98,6 +99,7 @@ func FusedAdd64(x, y, z float64) float64 {
        // arm64:"FMADDD"
        // loong64:"FMADDD\t"
        // riscv64:"FMADDD\t"
+       // amd64/v3:"VFMADD231SD\t"
        return x*y + z
 }
 
index 4ce5fa419d2b79e84cd12bbbbd83db78c9615733..87d9cd7b2715ba32c2a8a05e82d33074148285aa 100644 (file)
@@ -240,10 +240,11 @@ func nanGenerate64() float64 {
 
        // amd64:"DIVSD"
        z0 := zero / zero
-       // amd64:"MULSD"
+       // amd64/v1,amd64/v2:"MULSD"
        z1 := zero * inf
        // amd64:"SQRTSD"
        z2 := math.Sqrt(negone)
+       // amd64/v3:"VFMADD231SD"
        return z0 + z1 + z2
 }
 
@@ -254,7 +255,8 @@ func nanGenerate32() float32 {
 
        // amd64:"DIVSS"
        z0 := zero / zero
-       // amd64:"MULSS"
+       // amd64/v1,amd64/v2:"MULSS"
        z1 := zero * inf
+       // amd64/v3:"VFMADD231SS"
        return z0 + z1
 }