From af668c689c66588f8adb9f5cd6db812706536338 Mon Sep 17 00:00:00 2001
From: Wayne Zuo
Date: Tue, 6 Sep 2022 11:43:28 +0800
Subject: [PATCH] cmd/compile: fold constant shift with extension on riscv64

For example:

  movb a0, a0
  srai $1, a0, a0

the assembler will expand this to:

  slli $56, a0, a0
  srai $56, a0, a0
  srai $1, a0, a0

this CL optimizes it to:

  slli $56, a0, a0
  srai $57, a0, a0

This removes 270+ instructions from the Go binary on linux/riscv64.

Change-Id: I375e19f9d3bd54f2781791d8cbe5970191297dc8
Reviewed-on: https://go-review.googlesource.com/c/go/+/428496
Reviewed-by: Keith Randall
Run-TryBot: Wayne Zuo
Reviewed-by: Joel Sing
Reviewed-by: Cherry Mui
TryBot-Result: Gopher Robot
---
 .../internal/ssa/_gen/RISCV64latelower.rules  |  19 ++
 src/cmd/compile/internal/ssa/config.go        |   1 +
 .../internal/ssa/rewriteRISCV64latelower.go   | 253 ++++++++++++++++++
 test/codegen/shift.go                         |   6 +-
 4 files changed, 276 insertions(+), 3 deletions(-)
 create mode 100644 src/cmd/compile/internal/ssa/_gen/RISCV64latelower.rules
 create mode 100644 src/cmd/compile/internal/ssa/rewriteRISCV64latelower.go

diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64latelower.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64latelower.rules
new file mode 100644
index 0000000000..c44a837bbf
--- /dev/null
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64latelower.rules
@@ -0,0 +1,19 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Fold constant shift with extension.
+(SRAI [c] (MOVBreg x)) && c < 8 => (SRAI [56+c] (SLLI [56] x))
+(SRAI [c] (MOVHreg x)) && c < 16 => (SRAI [48+c] (SLLI [48] x))
+(SRAI [c] (MOVWreg x)) && c < 32 => (SRAI [32+c] (SLLI [32] x))
+(SRLI [c] (MOVBUreg x)) && c < 8 => (SRLI [56+c] (SLLI [56] x))
+(SRLI [c] (MOVHUreg x)) && c < 16 => (SRLI [48+c] (SLLI [48] x))
+(SRLI [c] (MOVWUreg x)) && c < 32 => (SRLI [32+c] (SLLI [32] x))
+(SLLI [c] (MOVBUreg x)) && c <= 56 => (SRLI [56-c] (SLLI [56] x))
+(SLLI [c] (MOVHUreg x)) && c <= 48 => (SRLI [48-c] (SLLI [48] x))
+(SLLI [c] (MOVWUreg x)) && c <= 32 => (SRLI [32-c] (SLLI [32] x))
+
+// Shift by zero.
+(SRAI [0] x) => x
+(SRLI [0] x) => x
+(SLLI [0] x) => x
diff --git a/src/cmd/compile/internal/ssa/config.go b/src/cmd/compile/internal/ssa/config.go
index 0ad2d94dce..5f39a6dfb3 100644
--- a/src/cmd/compile/internal/ssa/config.go
+++ b/src/cmd/compile/internal/ssa/config.go
@@ -309,6 +309,7 @@ func NewConfig(arch string, types Types, ctxt *obj.Link, optimize, softfloat boo
 		c.RegSize = 8
 		c.lowerBlock = rewriteBlockRISCV64
 		c.lowerValue = rewriteValueRISCV64
+		c.lateLowerValue = rewriteValueRISCV64latelower
 		c.registers = registersRISCV64[:]
 		c.gpRegMask = gpRegMaskRISCV64
 		c.fpRegMask = fpRegMaskRISCV64
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64latelower.go b/src/cmd/compile/internal/ssa/rewriteRISCV64latelower.go
new file mode 100644
index 0000000000..bde0164644
--- /dev/null
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64latelower.go
@@ -0,0 +1,253 @@
+// Code generated from gen/RISCV64latelower.rules; DO NOT EDIT.
+// generated with: cd gen; go run *.go
+
+package ssa
+
+func rewriteValueRISCV64latelower(v *Value) bool {
+	switch v.Op {
+	case OpRISCV64SLLI:
+		return rewriteValueRISCV64latelower_OpRISCV64SLLI(v)
+	case OpRISCV64SRAI:
+		return rewriteValueRISCV64latelower_OpRISCV64SRAI(v)
+	case OpRISCV64SRLI:
+		return rewriteValueRISCV64latelower_OpRISCV64SRLI(v)
+	}
+	return false
+}
+func rewriteValueRISCV64latelower_OpRISCV64SLLI(v *Value) bool {
+	v_0 := v.Args[0]
+	b := v.Block
+	// match: (SLLI [c] (MOVBUreg x))
+	// cond: c <= 56
+	// result: (SRLI [56-c] (SLLI [56] x))
+	for {
+		t := v.Type
+		c := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVBUreg {
+			break
+		}
+		x := v_0.Args[0]
+		if !(c <= 56) {
+			break
+		}
+		v.reset(OpRISCV64SRLI)
+		v.AuxInt = int64ToAuxInt(56 - c)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(56)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SLLI [c] (MOVHUreg x))
+	// cond: c <= 48
+	// result: (SRLI [48-c] (SLLI [48] x))
+	for {
+		t := v.Type
+		c := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVHUreg {
+			break
+		}
+		x := v_0.Args[0]
+		if !(c <= 48) {
+			break
+		}
+		v.reset(OpRISCV64SRLI)
+		v.AuxInt = int64ToAuxInt(48 - c)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(48)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SLLI [c] (MOVWUreg x))
+	// cond: c <= 32
+	// result: (SRLI [32-c] (SLLI [32] x))
+	for {
+		t := v.Type
+		c := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVWUreg {
+			break
+		}
+		x := v_0.Args[0]
+		if !(c <= 32) {
+			break
+		}
+		v.reset(OpRISCV64SRLI)
+		v.AuxInt = int64ToAuxInt(32 - c)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(32)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SLLI [0] x)
+	// result: x
+	for {
+		if auxIntToInt64(v.AuxInt) != 0 {
+			break
+		}
+		x := v_0
+		v.copyOf(x)
+		return true
+	}
+	return false
+}
+func rewriteValueRISCV64latelower_OpRISCV64SRAI(v *Value) bool {
+	v_0 := v.Args[0]
+	b := v.Block
+	// match: (SRAI [c] (MOVBreg x))
+	// cond: c < 8
+	// result: (SRAI [56+c] (SLLI [56] x))
+	for {
+		t := v.Type
+		c := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVBreg {
+			break
+		}
+		x := v_0.Args[0]
+		if !(c < 8) {
+			break
+		}
+		v.reset(OpRISCV64SRAI)
+		v.AuxInt = int64ToAuxInt(56 + c)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(56)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SRAI [c] (MOVHreg x))
+	// cond: c < 16
+	// result: (SRAI [48+c] (SLLI [48] x))
+	for {
+		t := v.Type
+		c := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVHreg {
+			break
+		}
+		x := v_0.Args[0]
+		if !(c < 16) {
+			break
+		}
+		v.reset(OpRISCV64SRAI)
+		v.AuxInt = int64ToAuxInt(48 + c)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(48)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SRAI [c] (MOVWreg x))
+	// cond: c < 32
+	// result: (SRAI [32+c] (SLLI [32] x))
+	for {
+		t := v.Type
+		c := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVWreg {
+			break
+		}
+		x := v_0.Args[0]
+		if !(c < 32) {
+			break
+		}
+		v.reset(OpRISCV64SRAI)
+		v.AuxInt = int64ToAuxInt(32 + c)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(32)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SRAI [0] x)
+	// result: x
+	for {
+		if auxIntToInt64(v.AuxInt) != 0 {
+			break
+		}
+		x := v_0
+		v.copyOf(x)
+		return true
+	}
+	return false
+}
+func rewriteValueRISCV64latelower_OpRISCV64SRLI(v *Value) bool {
+	v_0 := v.Args[0]
+	b := v.Block
+	// match: (SRLI [c] (MOVBUreg x))
+	// cond: c < 8
+	// result: (SRLI [56+c] (SLLI [56] x))
+	for {
+		t := v.Type
+		c := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVBUreg {
+			break
+		}
+		x := v_0.Args[0]
+		if !(c < 8) {
+			break
+		}
+		v.reset(OpRISCV64SRLI)
+		v.AuxInt = int64ToAuxInt(56 + c)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(56)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SRLI [c] (MOVHUreg x))
+	// cond: c < 16
+	// result: (SRLI [48+c] (SLLI [48] x))
+	for {
+		t := v.Type
+		c := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVHUreg {
+			break
+		}
+		x := v_0.Args[0]
+		if !(c < 16) {
+			break
+		}
+		v.reset(OpRISCV64SRLI)
+		v.AuxInt = int64ToAuxInt(48 + c)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(48)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SRLI [c] (MOVWUreg x))
+	// cond: c < 32
+	// result: (SRLI [32+c] (SLLI [32] x))
+	for {
+		t := v.Type
+		c := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVWUreg {
+			break
+		}
+		x := v_0.Args[0]
+		if !(c < 32) {
+			break
+		}
+		v.reset(OpRISCV64SRLI)
+		v.AuxInt = int64ToAuxInt(32 + c)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(32)
+		v0.AddArg(x)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SRLI [0] x)
+	// result: x
+	for {
+		if auxIntToInt64(v.AuxInt) != 0 {
+			break
+		}
+		x := v_0
+		v.copyOf(x)
+		return true
+	}
+	return false
+}
+func rewriteBlockRISCV64latelower(b *Block) bool {
+	return false
+}
diff --git a/test/codegen/shift.go b/test/codegen/shift.go
index c82566bb10..4a9f5d4356 100644
--- a/test/codegen/shift.go
+++ b/test/codegen/shift.go
@@ -34,21 +34,21 @@ func rshConst64x64(v int64) int64 {
 func lshConst32x64(v int32) int32 {
 	// ppc64:"SLW"
 	// ppc64le:"SLW"
-	// riscv64:"SLLI",-"AND",-"SLTIU"
+	// riscv64:"SLLI",-"AND",-"SLTIU",-"MOVW"
 	return v << uint64(29)
 }
 
 func rshConst32Ux64(v uint32) uint32 {
 	// ppc64:"SRW"
 	// ppc64le:"SRW"
-	// riscv64:"SRLI",-"AND",-"SLTIU"
+	// riscv64:"SRLI",-"AND",-"SLTIU",-"MOVW"
 	return v >> uint64(29)
 }
 
 func rshConst32x64(v int32) int32 {
 	// ppc64:"SRAW"
 	// ppc64le:"SRAW"
-	// riscv64:"SRAI",-"OR",-"SLTIU"
+	// riscv64:"SRAI",-"OR",-"SLTIU",-"MOVW"
 	return v >> uint64(29)
 }
 
-- 
2.50.0
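
A note on the rewrite rules above: the extension+shift folds rely on plain
64-bit shift identities. Sign-extending a byte is arithmetically (x<<56)>>56,
so a following SRAI by c merges into a single SRAI by 56+c; zero extension
pairs the same way with SRLI; and a left shift of a zero-extended value
becomes SLLI $56 followed by SRLI $(56-c). The side conditions (c < 8/16/32
for the right shifts, c <= 56/48/32 for the left shifts) keep the folded
shift amounts in range so no significant bits are lost. The standalone
program below is a quick sanity check of the byte-sized identities; it is
my own sketch, not part of the CL, and the test values and the constant
c = 3 are arbitrary.

	package main

	import "fmt"

	func main() {
		const c = 3 // any shift in the range the byte rules allow
		for _, x := range []int64{0x7f, -0x80, 0x1234, -1} {
			// SRAI after sign extension:
			// MOVB + SRAI $c  ==  SLLI $56 + SRAI $(56+c)
			fmt.Println(int64(int8(x))>>c == (x<<56)>>(56+c))

			// SRLI after zero extension:
			// MOVBU + SRLI $c  ==  SLLI $56 + SRLI $(56+c)
			fmt.Println(int64(uint8(x))>>c == int64(uint64(x<<56)>>(56+c)))

			// SLLI of a zero-extended value:
			// MOVBU + SLLI $c  ==  SLLI $56 + SRLI $(56-c)
			fmt.Println(int64(uint8(x))<<c == int64(uint64(x<<56)>>(56-c)))
		}
	}

Every comparison prints true; the halfword and word cases are the same with
48 and 32 in place of 56.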