From 98e719f677e401e65e0cf41a630fc859ae0b28b8 Mon Sep 17 00:00:00 2001 From: Keith Randall Date: Tue, 19 Nov 2024 09:15:19 -0800 Subject: [PATCH] cmd/compile: let MADD/MSUB combination happen more often on arm64 We have a single-instruction x+y*z op. Unfortunately x can't be a constant, so the rule that builds them doesn't apply in that case. This CL handles x+(c+y*z) by reordering to c+(x+y*z) so x is in the right place. Change-Id: Ibed621607d49da70474128e20991e0c4630ebfad Reviewed-on: https://go-review.googlesource.com/c/go/+/629858 LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase Reviewed-by: Keith Randall --- src/cmd/compile/internal/ssa/_gen/ARM64.rules | 6 + src/cmd/compile/internal/ssa/rewriteARM64.go | 361 ++++++++++++++++++ 2 files changed, 367 insertions(+) diff --git a/src/cmd/compile/internal/ssa/_gen/ARM64.rules b/src/cmd/compile/internal/ssa/_gen/ARM64.rules index 29dc258d9e..070329a539 100644 --- a/src/cmd/compile/internal/ssa/_gen/ARM64.rules +++ b/src/cmd/compile/internal/ssa/_gen/ARM64.rules @@ -1147,6 +1147,12 @@ (ADD a l:(MNEGW x y)) && v.Type.Size() <= 4 && l.Uses==1 && clobber(l) => (MSUBW a x y) (SUB a l:(MNEGW x y)) && v.Type.Size() <= 4 && l.Uses==1 && clobber(l) => (MADDW a x y) +// madd/msub can't take constant arguments, so do a bit of reordering if a non-constant is available. +(ADD a p:(ADDconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (ADDconst [c] (ADD a m)) +(ADD a p:(SUBconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (SUBconst [c] (ADD a m)) +(SUB a p:(ADDconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (ADDconst [c] (SUB a m)) +(SUB a p:(SUBconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (SUBconst [c] (SUB a m)) + // optimize ADCSflags, SBCSflags and friends (ADCSflags x y (Select1 (ADDSconstflags [-1] (ADCzerocarry c)))) => (ADCSflags x y c) (ADCSflags x y (Select1 (ADDSconstflags [-1] (MOVDconst [0])))) => (ADDSflags x y) diff --git a/src/cmd/compile/internal/ssa/rewriteARM64.go b/src/cmd/compile/internal/ssa/rewriteARM64.go index edb17cedf3..ab838e6635 100644 --- a/src/cmd/compile/internal/ssa/rewriteARM64.go +++ b/src/cmd/compile/internal/ssa/rewriteARM64.go @@ -1225,6 +1225,7 @@ func rewriteValueARM64_OpARM64ADCSflags(v *Value) bool { func rewriteValueARM64_OpARM64ADD(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (ADD x (MOVDconst [c])) // cond: !t.IsPtr() // result: (ADDconst [c] x) @@ -1330,6 +1331,198 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool { } break } + // match: (ADD a p:(ADDconst [c] m:(MUL _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (ADDconst [c] (ADD a m)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + a := v_0 + p := v_1 + if p.Op != OpARM64ADDconst { + continue + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) { + continue + } + v.reset(OpARM64ADDconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + break + } + // match: (ADD a p:(ADDconst [c] m:(MULW _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (ADDconst [c] (ADD a m)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + a := v_0 + p := v_1 + if p.Op != OpARM64ADDconst { + continue + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) { + continue + } + v.reset(OpARM64ADDconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + break + } + // match: (ADD a p:(ADDconst [c] m:(MNEG _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (ADDconst [c] (ADD a m)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + a := v_0 + p := v_1 + if p.Op != OpARM64ADDconst { + continue + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) { + continue + } + v.reset(OpARM64ADDconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + break + } + // match: (ADD a p:(ADDconst [c] m:(MNEGW _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (ADDconst [c] (ADD a m)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + a := v_0 + p := v_1 + if p.Op != OpARM64ADDconst { + continue + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) { + continue + } + v.reset(OpARM64ADDconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + break + } + // match: (ADD a p:(SUBconst [c] m:(MUL _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (SUBconst [c] (ADD a m)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + a := v_0 + p := v_1 + if p.Op != OpARM64SUBconst { + continue + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) { + continue + } + v.reset(OpARM64SUBconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + break + } + // match: (ADD a p:(SUBconst [c] m:(MULW _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (SUBconst [c] (ADD a m)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + a := v_0 + p := v_1 + if p.Op != OpARM64SUBconst { + continue + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) { + continue + } + v.reset(OpARM64SUBconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + break + } + // match: (ADD a p:(SUBconst [c] m:(MNEG _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (SUBconst [c] (ADD a m)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + a := v_0 + p := v_1 + if p.Op != OpARM64SUBconst { + continue + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) { + continue + } + v.reset(OpARM64SUBconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + break + } + // match: (ADD a p:(SUBconst [c] m:(MNEGW _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (SUBconst [c] (ADD a m)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + a := v_0 + p := v_1 + if p.Op != OpARM64SUBconst { + continue + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) { + continue + } + v.reset(OpARM64SUBconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + break + } // match: (ADD x (NEG y)) // result: (SUB x y) for { @@ -16411,6 +16604,174 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool { v.AddArg3(a, x, y) return true } + // match: (SUB a p:(ADDconst [c] m:(MUL _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (ADDconst [c] (SUB a m)) + for { + a := v_0 + p := v_1 + if p.Op != OpARM64ADDconst { + break + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) { + break + } + v.reset(OpARM64ADDconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + // match: (SUB a p:(ADDconst [c] m:(MULW _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (ADDconst [c] (SUB a m)) + for { + a := v_0 + p := v_1 + if p.Op != OpARM64ADDconst { + break + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) { + break + } + v.reset(OpARM64ADDconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + // match: (SUB a p:(ADDconst [c] m:(MNEG _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (ADDconst [c] (SUB a m)) + for { + a := v_0 + p := v_1 + if p.Op != OpARM64ADDconst { + break + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) { + break + } + v.reset(OpARM64ADDconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + // match: (SUB a p:(ADDconst [c] m:(MNEGW _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (ADDconst [c] (SUB a m)) + for { + a := v_0 + p := v_1 + if p.Op != OpARM64ADDconst { + break + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) { + break + } + v.reset(OpARM64ADDconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + // match: (SUB a p:(SUBconst [c] m:(MUL _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (SUBconst [c] (SUB a m)) + for { + a := v_0 + p := v_1 + if p.Op != OpARM64SUBconst { + break + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) { + break + } + v.reset(OpARM64SUBconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + // match: (SUB a p:(SUBconst [c] m:(MULW _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (SUBconst [c] (SUB a m)) + for { + a := v_0 + p := v_1 + if p.Op != OpARM64SUBconst { + break + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) { + break + } + v.reset(OpARM64SUBconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + // match: (SUB a p:(SUBconst [c] m:(MNEG _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (SUBconst [c] (SUB a m)) + for { + a := v_0 + p := v_1 + if p.Op != OpARM64SUBconst { + break + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) { + break + } + v.reset(OpARM64SUBconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } + // match: (SUB a p:(SUBconst [c] m:(MNEGW _ _))) + // cond: p.Uses==1 && m.Uses==1 + // result: (SUBconst [c] (SUB a m)) + for { + a := v_0 + p := v_1 + if p.Op != OpARM64SUBconst { + break + } + c := auxIntToInt64(p.AuxInt) + m := p.Args[0] + if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) { + break + } + v.reset(OpARM64SUBconst) + v.AuxInt = int64ToAuxInt(c) + v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type) + v0.AddArg2(a, m) + v.AddArg(v0) + return true + } // match: (SUB x x) // result: (MOVDconst [0]) for { -- 2.48.1