From 33425ab8dbb03c355b2263b8250a1829e260d66f Mon Sep 17 00:00:00 2001 From: smasher164 Date: Wed, 29 Aug 2018 20:57:33 -0400 Subject: [PATCH] cmd/compile: introduce generic ssa intrinsic for fused-multiply-add In order to make math.FMA a compiler intrinsic for ISAs like ARM64, PPC64[le], and S390X, a generic 3-argument opcode "Fma" is provided and rewritten as ARM64: (Fma x y z) -> (FMADDD z x y) PPC64: (Fma x y z) -> (FMADD x y z) S390X: (Fma x y z) -> (FMADD z x y) Updates #25819. Change-Id: Ie5bc628311e6feeb28ddf9adaa6e702c8c291efa Reviewed-on: https://go-review.googlesource.com/c/go/+/131959 Run-TryBot: Akhil Indurti TryBot-Result: Gobot Gobot Reviewed-by: Keith Randall --- src/cmd/compile/internal/gc/ssa.go | 5 +++++ src/cmd/compile/internal/ssa/gen/ARM64.rules | 1 + src/cmd/compile/internal/ssa/gen/PPC64.rules | 1 + src/cmd/compile/internal/ssa/gen/S390X.rules | 1 + src/cmd/compile/internal/ssa/gen/genericOps.go | 14 ++++++++++++++ src/cmd/compile/internal/ssa/opGen.go | 6 ++++++ src/cmd/compile/internal/ssa/rewriteARM64.go | 17 +++++++++++++++++ src/cmd/compile/internal/ssa/rewritePPC64.go | 17 +++++++++++++++++ src/cmd/compile/internal/ssa/rewriteS390X.go | 17 +++++++++++++++++ test/codegen/math.go | 8 ++++++++ 10 files changed, 87 insertions(+) diff --git a/src/cmd/compile/internal/gc/ssa.go b/src/cmd/compile/internal/gc/ssa.go index dd8dacd149..0b76ad728c 100644 --- a/src/cmd/compile/internal/gc/ssa.go +++ b/src/cmd/compile/internal/gc/ssa.go @@ -3321,6 +3321,11 @@ func init() { return s.newValue2(ssa.OpCopysign, types.Types[TFLOAT64], args[0], args[1]) }, sys.PPC64, sys.Wasm) + addF("math", "Fma", + func(s *state, n *Node, args []*ssa.Value) *ssa.Value { + return s.newValue3(ssa.OpFma, types.Types[TFLOAT64], args[0], args[1], args[2]) + }, + sys.ARM64, sys.PPC64, sys.S390X) makeRoundAMD64 := func(op ssa.Op) func(s *state, n *Node, args []*ssa.Value) *ssa.Value { return func(s *state, n *Node, args []*ssa.Value) *ssa.Value { diff --git a/src/cmd/compile/internal/ssa/gen/ARM64.rules b/src/cmd/compile/internal/ssa/gen/ARM64.rules index 7edd19e9cc..26ae004572 100644 --- a/src/cmd/compile/internal/ssa/gen/ARM64.rules +++ b/src/cmd/compile/internal/ssa/gen/ARM64.rules @@ -90,6 +90,7 @@ (Round x) -> (FRINTAD x) (RoundToEven x) -> (FRINTND x) (Trunc x) -> (FRINTZD x) +(Fma x y z) -> (FMADDD z x y) // lowering rotates (RotateLeft8 x (MOVDconst [c])) -> (Or8 (Lsh8x64 x (MOVDconst [c&7])) (Rsh8Ux64 x (MOVDconst [-c&7]))) diff --git a/src/cmd/compile/internal/ssa/gen/PPC64.rules b/src/cmd/compile/internal/ssa/gen/PPC64.rules index 59cce4ed57..239414f01b 100644 --- a/src/cmd/compile/internal/ssa/gen/PPC64.rules +++ b/src/cmd/compile/internal/ssa/gen/PPC64.rules @@ -68,6 +68,7 @@ (Round x) -> (FROUND x) (Copysign x y) -> (FCPSGN y x) (Abs x) -> (FABS x) +(Fma x y z) -> (FMADD x y z) // Lowering constants (Const(64|32|16|8) [val]) -> (MOVDconst [val]) diff --git a/src/cmd/compile/internal/ssa/gen/S390X.rules b/src/cmd/compile/internal/ssa/gen/S390X.rules index 4e459043b1..d7cb972b81 100644 --- a/src/cmd/compile/internal/ssa/gen/S390X.rules +++ b/src/cmd/compile/internal/ssa/gen/S390X.rules @@ -139,6 +139,7 @@ (Trunc x) -> (FIDBR [5] x) (RoundToEven x) -> (FIDBR [4] x) (Round x) -> (FIDBR [1] x) +(Fma x y z) -> (FMADD z x y) // Atomic loads and stores. // The SYNC instruction (fast-BCR-serialization) prevents store-load diff --git a/src/cmd/compile/internal/ssa/gen/genericOps.go b/src/cmd/compile/internal/ssa/gen/genericOps.go index df0dd8cabc..7bd79312e3 100644 --- a/src/cmd/compile/internal/ssa/gen/genericOps.go +++ b/src/cmd/compile/internal/ssa/gen/genericOps.go @@ -302,6 +302,20 @@ var genericOps = []opData{ {name: "Abs", argLength: 1}, // absolute value arg0 {name: "Copysign", argLength: 2}, // copy sign from arg0 to arg1 + // 3-input opcode. + // Fused-multiply-add, float64 only. + // When a*b+c is exactly zero (before rounding), then the result is +0 or -0. + // The 0's sign is determined according to the standard rules for the + // addition (-0 if both a*b and c are -0, +0 otherwise). + // + // Otherwise, when a*b+c rounds to zero, then the resulting 0's sign is + // determined by the sign of the exact result a*b+c. + // See section 6.3 in ieee754. + // + // When the multiply is an infinity times a zero, the result is NaN. + // See section 7.2 in ieee754. + {name: "Fma", argLength: 3}, // compute (a*b)+c without intermediate rounding + // Data movement, max argument length for Phi is indefinite so just pick // a really large number {name: "Phi", argLength: -1, zeroWidth: true}, // select an argument based on which predecessor block we came from diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index b7e65174f9..7f9fb4e3ef 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -2419,6 +2419,7 @@ const ( OpRoundToEven OpAbs OpCopysign + OpFma OpPhi OpCopy OpConvert @@ -30593,6 +30594,11 @@ var opcodeTable = [...]opInfo{ argLen: 2, generic: true, }, + { + name: "Fma", + argLen: 3, + generic: true, + }, { name: "Phi", argLen: -1, diff --git a/src/cmd/compile/internal/ssa/rewriteARM64.go b/src/cmd/compile/internal/ssa/rewriteARM64.go index dfb5554f62..2aa38f574f 100644 --- a/src/cmd/compile/internal/ssa/rewriteARM64.go +++ b/src/cmd/compile/internal/ssa/rewriteARM64.go @@ -571,6 +571,8 @@ func rewriteValueARM64(v *Value) bool { return rewriteValueARM64_OpEqPtr_0(v) case OpFloor: return rewriteValueARM64_OpFloor_0(v) + case OpFma: + return rewriteValueARM64_OpFma_0(v) case OpGeq16: return rewriteValueARM64_OpGeq16_0(v) case OpGeq16U: @@ -28565,6 +28567,21 @@ func rewriteValueARM64_OpFloor_0(v *Value) bool { return true } } +func rewriteValueARM64_OpFma_0(v *Value) bool { + // match: (Fma x y z) + // cond: + // result: (FMADDD z x y) + for { + z := v.Args[2] + x := v.Args[0] + y := v.Args[1] + v.reset(OpARM64FMADDD) + v.AddArg(z) + v.AddArg(x) + v.AddArg(y) + return true + } +} func rewriteValueARM64_OpGeq16_0(v *Value) bool { b := v.Block typ := &b.Func.Config.Types diff --git a/src/cmd/compile/internal/ssa/rewritePPC64.go b/src/cmd/compile/internal/ssa/rewritePPC64.go index 7f49d98bd1..b09bd85ca1 100644 --- a/src/cmd/compile/internal/ssa/rewritePPC64.go +++ b/src/cmd/compile/internal/ssa/rewritePPC64.go @@ -181,6 +181,8 @@ func rewriteValuePPC64(v *Value) bool { return rewriteValuePPC64_OpEqPtr_0(v) case OpFloor: return rewriteValuePPC64_OpFloor_0(v) + case OpFma: + return rewriteValuePPC64_OpFma_0(v) case OpGeq16: return rewriteValuePPC64_OpGeq16_0(v) case OpGeq16U: @@ -1988,6 +1990,21 @@ func rewriteValuePPC64_OpFloor_0(v *Value) bool { return true } } +func rewriteValuePPC64_OpFma_0(v *Value) bool { + // match: (Fma x y z) + // cond: + // result: (FMADD x y z) + for { + z := v.Args[2] + x := v.Args[0] + y := v.Args[1] + v.reset(OpPPC64FMADD) + v.AddArg(x) + v.AddArg(y) + v.AddArg(z) + return true + } +} func rewriteValuePPC64_OpGeq16_0(v *Value) bool { b := v.Block typ := &b.Func.Config.Types diff --git a/src/cmd/compile/internal/ssa/rewriteS390X.go b/src/cmd/compile/internal/ssa/rewriteS390X.go index 72bbdc0e57..0c03fa2080 100644 --- a/src/cmd/compile/internal/ssa/rewriteS390X.go +++ b/src/cmd/compile/internal/ssa/rewriteS390X.go @@ -166,6 +166,8 @@ func rewriteValueS390X(v *Value) bool { return rewriteValueS390X_OpEqPtr_0(v) case OpFloor: return rewriteValueS390X_OpFloor_0(v) + case OpFma: + return rewriteValueS390X_OpFma_0(v) case OpGeq16: return rewriteValueS390X_OpGeq16_0(v) case OpGeq16U: @@ -1918,6 +1920,21 @@ func rewriteValueS390X_OpFloor_0(v *Value) bool { return true } } +func rewriteValueS390X_OpFma_0(v *Value) bool { + // match: (Fma x y z) + // cond: + // result: (FMADD z x y) + for { + z := v.Args[2] + x := v.Args[0] + y := v.Args[1] + v.reset(OpS390XFMADD) + v.AddArg(z) + v.AddArg(x) + v.AddArg(y) + return true + } +} func rewriteValueS390X_OpGeq16_0(v *Value) bool { b := v.Block typ := &b.Func.Config.Types diff --git a/test/codegen/math.go b/test/codegen/math.go index 8aa652595b..427f305c12 100644 --- a/test/codegen/math.go +++ b/test/codegen/math.go @@ -107,6 +107,14 @@ func copysign(a, b, c float64) { sink64[3] = math.Copysign(-1, c) } +func fma(x, y, z float64) float64 { + // arm64:"FMADDD" + // s390x:"FMADD" + // ppc64:"FMADD" + // ppc64le:"FMADD" + return math.Fma(x, y, z) +} + func fromFloat64(f64 float64) uint64 { // amd64:"MOVQ\tX.*, [^X].*" // arm64:"FMOVD\tF.*, R.*" -- 2.50.0