]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: let MADD/MSUB combination happen more often on arm64
authorKeith Randall <khr@golang.org>
Tue, 19 Nov 2024 17:15:19 +0000 (09:15 -0800)
committerKeith Randall <khr@golang.org>
Wed, 20 Nov 2024 16:30:01 +0000 (16:30 +0000)
We have a single-instruction x+y*z op. Unfortunately x can't be
a constant, so the rule that builds them doesn't apply in that case.

This CL handles x+(c+y*z) by reordering to c+(x+y*z) so x is
in the right place.

Change-Id: Ibed621607d49da70474128e20991e0c4630ebfad
Reviewed-on: https://go-review.googlesource.com/c/go/+/629858
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Keith Randall <khr@google.com>
src/cmd/compile/internal/ssa/_gen/ARM64.rules
src/cmd/compile/internal/ssa/rewriteARM64.go

index 29dc258d9e37674a7bc13701601c66715dad2544..070329a53989ef9de7cde4d7fa0d20871c91d0e2 100644 (file)
 (ADD a l:(MNEGW x y)) && v.Type.Size() <= 4 && l.Uses==1 && clobber(l) => (MSUBW a x y)
 (SUB a l:(MNEGW x y)) && v.Type.Size() <= 4 && l.Uses==1 && clobber(l) => (MADDW a x y)
 
+// madd/msub can't take constant arguments, so do a bit of reordering if a non-constant is available.
+(ADD a p:(ADDconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (ADDconst [c] (ADD <v.Type> a m))
+(ADD a p:(SUBconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (SUBconst [c] (ADD <v.Type> a m))
+(SUB a p:(ADDconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (ADDconst [c] (SUB <v.Type> a m))
+(SUB a p:(SUBconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (SUBconst [c] (SUB <v.Type> a m))
+
 // optimize ADCSflags, SBCSflags and friends
 (ADCSflags x y (Select1 <types.TypeFlags> (ADDSconstflags [-1] (ADCzerocarry <typ.UInt64> c)))) => (ADCSflags x y c)
 (ADCSflags x y (Select1 <types.TypeFlags> (ADDSconstflags [-1] (MOVDconst [0])))) => (ADDSflags x y)
index edb17cedf3875fd46cfe538dadd473f50e386d30..ab838e66350bffdbda2b5b095b4a20a0fcaf02c2 100644 (file)
@@ -1225,6 +1225,7 @@ func rewriteValueARM64_OpARM64ADCSflags(v *Value) bool {
 func rewriteValueARM64_OpARM64ADD(v *Value) bool {
        v_1 := v.Args[1]
        v_0 := v.Args[0]
+       b := v.Block
        // match: (ADD x (MOVDconst <t> [c]))
        // cond: !t.IsPtr()
        // result: (ADDconst [c] x)
@@ -1330,6 +1331,198 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                }
                break
        }
+       // match: (ADD a p:(ADDconst [c] m:(MUL _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (ADDconst [c] (ADD <v.Type> a m))
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       a := v_0
+                       p := v_1
+                       if p.Op != OpARM64ADDconst {
+                               continue
+                       }
+                       c := auxIntToInt64(p.AuxInt)
+                       m := p.Args[0]
+                       if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpARM64ADDconst)
+                       v.AuxInt = int64ToAuxInt(c)
+                       v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type)
+                       v0.AddArg2(a, m)
+                       v.AddArg(v0)
+                       return true
+               }
+               break
+       }
+       // match: (ADD a p:(ADDconst [c] m:(MULW _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (ADDconst [c] (ADD <v.Type> a m))
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       a := v_0
+                       p := v_1
+                       if p.Op != OpARM64ADDconst {
+                               continue
+                       }
+                       c := auxIntToInt64(p.AuxInt)
+                       m := p.Args[0]
+                       if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpARM64ADDconst)
+                       v.AuxInt = int64ToAuxInt(c)
+                       v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type)
+                       v0.AddArg2(a, m)
+                       v.AddArg(v0)
+                       return true
+               }
+               break
+       }
+       // match: (ADD a p:(ADDconst [c] m:(MNEG _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (ADDconst [c] (ADD <v.Type> a m))
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       a := v_0
+                       p := v_1
+                       if p.Op != OpARM64ADDconst {
+                               continue
+                       }
+                       c := auxIntToInt64(p.AuxInt)
+                       m := p.Args[0]
+                       if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpARM64ADDconst)
+                       v.AuxInt = int64ToAuxInt(c)
+                       v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type)
+                       v0.AddArg2(a, m)
+                       v.AddArg(v0)
+                       return true
+               }
+               break
+       }
+       // match: (ADD a p:(ADDconst [c] m:(MNEGW _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (ADDconst [c] (ADD <v.Type> a m))
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       a := v_0
+                       p := v_1
+                       if p.Op != OpARM64ADDconst {
+                               continue
+                       }
+                       c := auxIntToInt64(p.AuxInt)
+                       m := p.Args[0]
+                       if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpARM64ADDconst)
+                       v.AuxInt = int64ToAuxInt(c)
+                       v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type)
+                       v0.AddArg2(a, m)
+                       v.AddArg(v0)
+                       return true
+               }
+               break
+       }
+       // match: (ADD a p:(SUBconst [c] m:(MUL _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (SUBconst [c] (ADD <v.Type> a m))
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       a := v_0
+                       p := v_1
+                       if p.Op != OpARM64SUBconst {
+                               continue
+                       }
+                       c := auxIntToInt64(p.AuxInt)
+                       m := p.Args[0]
+                       if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpARM64SUBconst)
+                       v.AuxInt = int64ToAuxInt(c)
+                       v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type)
+                       v0.AddArg2(a, m)
+                       v.AddArg(v0)
+                       return true
+               }
+               break
+       }
+       // match: (ADD a p:(SUBconst [c] m:(MULW _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (SUBconst [c] (ADD <v.Type> a m))
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       a := v_0
+                       p := v_1
+                       if p.Op != OpARM64SUBconst {
+                               continue
+                       }
+                       c := auxIntToInt64(p.AuxInt)
+                       m := p.Args[0]
+                       if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpARM64SUBconst)
+                       v.AuxInt = int64ToAuxInt(c)
+                       v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type)
+                       v0.AddArg2(a, m)
+                       v.AddArg(v0)
+                       return true
+               }
+               break
+       }
+       // match: (ADD a p:(SUBconst [c] m:(MNEG _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (SUBconst [c] (ADD <v.Type> a m))
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       a := v_0
+                       p := v_1
+                       if p.Op != OpARM64SUBconst {
+                               continue
+                       }
+                       c := auxIntToInt64(p.AuxInt)
+                       m := p.Args[0]
+                       if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpARM64SUBconst)
+                       v.AuxInt = int64ToAuxInt(c)
+                       v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type)
+                       v0.AddArg2(a, m)
+                       v.AddArg(v0)
+                       return true
+               }
+               break
+       }
+       // match: (ADD a p:(SUBconst [c] m:(MNEGW _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (SUBconst [c] (ADD <v.Type> a m))
+       for {
+               for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+                       a := v_0
+                       p := v_1
+                       if p.Op != OpARM64SUBconst {
+                               continue
+                       }
+                       c := auxIntToInt64(p.AuxInt)
+                       m := p.Args[0]
+                       if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) {
+                               continue
+                       }
+                       v.reset(OpARM64SUBconst)
+                       v.AuxInt = int64ToAuxInt(c)
+                       v0 := b.NewValue0(v.Pos, OpARM64ADD, v.Type)
+                       v0.AddArg2(a, m)
+                       v.AddArg(v0)
+                       return true
+               }
+               break
+       }
        // match: (ADD x (NEG y))
        // result: (SUB x y)
        for {
@@ -16411,6 +16604,174 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                v.AddArg3(a, x, y)
                return true
        }
+       // match: (SUB a p:(ADDconst [c] m:(MUL _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (ADDconst [c] (SUB <v.Type> a m))
+       for {
+               a := v_0
+               p := v_1
+               if p.Op != OpARM64ADDconst {
+                       break
+               }
+               c := auxIntToInt64(p.AuxInt)
+               m := p.Args[0]
+               if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) {
+                       break
+               }
+               v.reset(OpARM64ADDconst)
+               v.AuxInt = int64ToAuxInt(c)
+               v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type)
+               v0.AddArg2(a, m)
+               v.AddArg(v0)
+               return true
+       }
+       // match: (SUB a p:(ADDconst [c] m:(MULW _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (ADDconst [c] (SUB <v.Type> a m))
+       for {
+               a := v_0
+               p := v_1
+               if p.Op != OpARM64ADDconst {
+                       break
+               }
+               c := auxIntToInt64(p.AuxInt)
+               m := p.Args[0]
+               if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) {
+                       break
+               }
+               v.reset(OpARM64ADDconst)
+               v.AuxInt = int64ToAuxInt(c)
+               v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type)
+               v0.AddArg2(a, m)
+               v.AddArg(v0)
+               return true
+       }
+       // match: (SUB a p:(ADDconst [c] m:(MNEG _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (ADDconst [c] (SUB <v.Type> a m))
+       for {
+               a := v_0
+               p := v_1
+               if p.Op != OpARM64ADDconst {
+                       break
+               }
+               c := auxIntToInt64(p.AuxInt)
+               m := p.Args[0]
+               if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) {
+                       break
+               }
+               v.reset(OpARM64ADDconst)
+               v.AuxInt = int64ToAuxInt(c)
+               v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type)
+               v0.AddArg2(a, m)
+               v.AddArg(v0)
+               return true
+       }
+       // match: (SUB a p:(ADDconst [c] m:(MNEGW _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (ADDconst [c] (SUB <v.Type> a m))
+       for {
+               a := v_0
+               p := v_1
+               if p.Op != OpARM64ADDconst {
+                       break
+               }
+               c := auxIntToInt64(p.AuxInt)
+               m := p.Args[0]
+               if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) {
+                       break
+               }
+               v.reset(OpARM64ADDconst)
+               v.AuxInt = int64ToAuxInt(c)
+               v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type)
+               v0.AddArg2(a, m)
+               v.AddArg(v0)
+               return true
+       }
+       // match: (SUB a p:(SUBconst [c] m:(MUL _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (SUBconst [c] (SUB <v.Type> a m))
+       for {
+               a := v_0
+               p := v_1
+               if p.Op != OpARM64SUBconst {
+                       break
+               }
+               c := auxIntToInt64(p.AuxInt)
+               m := p.Args[0]
+               if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) {
+                       break
+               }
+               v.reset(OpARM64SUBconst)
+               v.AuxInt = int64ToAuxInt(c)
+               v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type)
+               v0.AddArg2(a, m)
+               v.AddArg(v0)
+               return true
+       }
+       // match: (SUB a p:(SUBconst [c] m:(MULW _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (SUBconst [c] (SUB <v.Type> a m))
+       for {
+               a := v_0
+               p := v_1
+               if p.Op != OpARM64SUBconst {
+                       break
+               }
+               c := auxIntToInt64(p.AuxInt)
+               m := p.Args[0]
+               if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) {
+                       break
+               }
+               v.reset(OpARM64SUBconst)
+               v.AuxInt = int64ToAuxInt(c)
+               v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type)
+               v0.AddArg2(a, m)
+               v.AddArg(v0)
+               return true
+       }
+       // match: (SUB a p:(SUBconst [c] m:(MNEG _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (SUBconst [c] (SUB <v.Type> a m))
+       for {
+               a := v_0
+               p := v_1
+               if p.Op != OpARM64SUBconst {
+                       break
+               }
+               c := auxIntToInt64(p.AuxInt)
+               m := p.Args[0]
+               if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) {
+                       break
+               }
+               v.reset(OpARM64SUBconst)
+               v.AuxInt = int64ToAuxInt(c)
+               v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type)
+               v0.AddArg2(a, m)
+               v.AddArg(v0)
+               return true
+       }
+       // match: (SUB a p:(SUBconst [c] m:(MNEGW _ _)))
+       // cond: p.Uses==1 && m.Uses==1
+       // result: (SUBconst [c] (SUB <v.Type> a m))
+       for {
+               a := v_0
+               p := v_1
+               if p.Op != OpARM64SUBconst {
+                       break
+               }
+               c := auxIntToInt64(p.AuxInt)
+               m := p.Args[0]
+               if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) {
+                       break
+               }
+               v.reset(OpARM64SUBconst)
+               v.AuxInt = int64ToAuxInt(c)
+               v0 := b.NewValue0(v.Pos, OpARM64SUB, v.Type)
+               v0.AddArg2(a, m)
+               v.AddArg(v0)
+               return true
+       }
        // match: (SUB x x)
        // result: (MOVDconst [0])
        for {