Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: implement float min/max in hardware for riscv64
author Meng Zhuo <mzh@golangcn.org>
Tue, 1 Aug 2023 14:17:02 +0000 (22:17 +0800)
committer M Zhuo <mzh@golangcn.org>
Fri, 26 Jan 2024 01:41:50 +0000 (01:41 +0000)
CL 514596 adds float min/max for amd64, this CL adds it for riscv64.

The behavior of the RISC-V FMIN/FMAX instructions almost matches Go's
requirements.

However, according to the RISC-V spec, section 8.3 "NaN Generation and Propagation":
>> if at least one input is a signaling NaN, or if both inputs are quiet
>> NaNs, the result is the canonical NaN. If one operand is a quiet NaN
>> and the other is not a NaN, the result is the non-NaN operand.

Go uses a quiet NaN as its NaN and, according to the Go spec:
>> if any argument is a NaN, the result is a NaN

This requires the float min/max implementation to check whether either
operand is a qNaN before the float min/max instruction actually executes.

This CL also fixes a typo in the minmax test.

Benchmark on VisionFive 2:
goos: linux
goarch: riscv64
pkg: runtime
         │ float_minmax.old.bench │       float_minmax.new.bench        │
         │         sec/op         │   sec/op     vs base                │
MinFloat             158.20n ± 0%   28.13n ± 0%  -82.22% (p=0.000 n=10)
MaxFloat             158.10n ± 0%   28.12n ± 0%  -82.21% (p=0.000 n=10)
geomean               158.1n        28.12n       -82.22%

Update #59488

Change-Id: Iab48be6d32b8882044fb8c821438ca8840e5493d
Reviewed-on: https://go-review.googlesource.com/c/go/+/514775
Reviewed-by: Mauri de Souza Meneguzzo <mauri870@gmail.com>
Run-TryBot: M Zhuo <mengzhuo1203@gmail.com>
Reviewed-by: Joel Sing <joel@sing.id.au>
Reviewed-by: Cherry Mui <cherryyz@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Keith Randall <khr@google.com>
src/cmd/compile/internal/riscv64/ssa.go
src/cmd/compile/internal/ssa/_gen/RISCV64.rules
src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
src/cmd/compile/internal/ssa/opGen.go
src/cmd/compile/internal/ssa/rewriteRISCV64.go
src/cmd/compile/internal/ssagen/ssa.go
src/runtime/minmax_test.go

index 22338188e5202fd8698adf396ed6875bb84d96a2..caca504d28460618b8c3a07cdadb5e633b2c0873 100644 (file)
@@ -297,6 +297,72 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
                p.Reg = r1
                p.To.Type = obj.TYPE_REG
                p.To.Reg = r
+
+       case ssa.OpRISCV64LoweredFMAXD, ssa.OpRISCV64LoweredFMIND, ssa.OpRISCV64LoweredFMAXS, ssa.OpRISCV64LoweredFMINS:
+               // Most of FMIN/FMAX result match Go's required behaviour, unless one of the
+               // inputs is a NaN. As such, we need to explicitly test for NaN
+               // before using FMIN/FMAX.
+
+               // FADD Rarg0, Rarg1, Rout // FADD is used to propagate a NaN to the result in these cases.
+               // FEQ  Rarg0, Rarg0, Rtmp
+               // BEQZ Rtmp, end
+               // FEQ  Rarg1, Rarg1, Rtmp
+               // BEQZ Rtmp, end
+               // F(MIN | MAX)
+
+               r0 := v.Args[0].Reg()
+               r1 := v.Args[1].Reg()
+               out := v.Reg()
+               add, feq := riscv.AFADDD, riscv.AFEQD
+               if v.Op == ssa.OpRISCV64LoweredFMAXS || v.Op == ssa.OpRISCV64LoweredFMINS {
+                       add = riscv.AFADDS
+                       feq = riscv.AFEQS
+               }
+
+               p1 := s.Prog(add)
+               p1.From.Type = obj.TYPE_REG
+               p1.From.Reg = r0
+               p1.Reg = r1
+               p1.To.Type = obj.TYPE_REG
+               p1.To.Reg = out
+
+               p2 := s.Prog(feq)
+               p2.From.Type = obj.TYPE_REG
+               p2.From.Reg = r0
+               p2.Reg = r0
+               p2.To.Type = obj.TYPE_REG
+               p2.To.Reg = riscv.REG_TMP
+
+               p3 := s.Prog(riscv.ABEQ)
+               p3.From.Type = obj.TYPE_REG
+               p3.From.Reg = riscv.REG_ZERO
+               p3.Reg = riscv.REG_TMP
+               p3.To.Type = obj.TYPE_BRANCH
+
+               p4 := s.Prog(feq)
+               p4.From.Type = obj.TYPE_REG
+               p4.From.Reg = r1
+               p4.Reg = r1
+               p4.To.Type = obj.TYPE_REG
+               p4.To.Reg = riscv.REG_TMP
+
+               p5 := s.Prog(riscv.ABEQ)
+               p5.From.Type = obj.TYPE_REG
+               p5.From.Reg = riscv.REG_ZERO
+               p5.Reg = riscv.REG_TMP
+               p5.To.Type = obj.TYPE_BRANCH
+
+               p6 := s.Prog(v.Op.Asm())
+               p6.From.Type = obj.TYPE_REG
+               p6.From.Reg = r1
+               p6.Reg = r0
+               p6.To.Type = obj.TYPE_REG
+               p6.To.Reg = out
+
+               nop := s.Prog(obj.ANOP)
+               p3.To.SetTarget(nop)
+               p5.To.SetTarget(nop)
+
        case ssa.OpRISCV64LoweredMuluhilo:
                r0 := v.Args[0].Reg()
                r1 := v.Args[1].Reg()
index fc206c42d3d7fe6d33f8ee5fbc2d4ca1eeb66140..4fef20a565137d66d4962d4efe9674c597a51752 100644 (file)
@@ -72,6 +72,9 @@
 
 (FMA ...) => (FMADDD ...)
 
+(Min(64|32)F ...) => (LoweredFMIN(D|S) ...)
+(Max(64|32)F ...) => (LoweredFMAX(D|S) ...)
+
 // Sign and zero extension.
 
 (SignExt8to16  ...) => (MOVBreg ...)
index 93f20f8a99a138f2ed0a47ac85d39bd60600558c..9ce6450166ead401a702da5c668cd8333c65326d 100644 (file)
@@ -429,6 +429,8 @@ func init() {
                {name: "FNES", argLength: 2, reg: fp2gp, asm: "FNES", commutative: true},                                                            // arg0 != arg1
                {name: "FLTS", argLength: 2, reg: fp2gp, asm: "FLTS"},                                                                               // arg0 < arg1
                {name: "FLES", argLength: 2, reg: fp2gp, asm: "FLES"},                                                                               // arg0 <= arg1
+               {name: "LoweredFMAXS", argLength: 2, reg: fp21, resultNotInArgs: true, asm: "FMAXS", commutative: true, typ: "Float32"},             // max(arg0, arg1)
+               {name: "LoweredFMINS", argLength: 2, reg: fp21, resultNotInArgs: true, asm: "FMINS", commutative: true, typ: "Float32"},             // min(arg0, arg1)
 
                // D extension.
                {name: "FADDD", argLength: 2, reg: fp21, asm: "FADDD", commutative: true, typ: "Float64"},                                           // arg0 + arg1
@@ -456,6 +458,8 @@ func init() {
                {name: "FNED", argLength: 2, reg: fp2gp, asm: "FNED", commutative: true},                                                            // arg0 != arg1
                {name: "FLTD", argLength: 2, reg: fp2gp, asm: "FLTD"},                                                                               // arg0 < arg1
                {name: "FLED", argLength: 2, reg: fp2gp, asm: "FLED"},                                                                               // arg0 <= arg1
+               {name: "LoweredFMIND", argLength: 2, reg: fp21, resultNotInArgs: true, asm: "FMIND", commutative: true, typ: "Float64"},             // min(arg0, arg1)
+               {name: "LoweredFMAXD", argLength: 2, reg: fp21, resultNotInArgs: true, asm: "FMAXD", commutative: true, typ: "Float64"},             // max(arg0, arg1)
        }
 
        RISCV64blocks := []blockData{
index c552832520ec436a82b1669f24591dd98879bbd5..5a2ca1a42472e5fa823a84a2dfe54c9f24d28479 100644 (file)
@@ -2464,6 +2464,8 @@ const (
        OpRISCV64FNES
        OpRISCV64FLTS
        OpRISCV64FLES
+       OpRISCV64LoweredFMAXS
+       OpRISCV64LoweredFMINS
        OpRISCV64FADDD
        OpRISCV64FSUBD
        OpRISCV64FMULD
@@ -2489,6 +2491,8 @@ const (
        OpRISCV64FNED
        OpRISCV64FLTD
        OpRISCV64FLED
+       OpRISCV64LoweredFMIND
+       OpRISCV64LoweredFMAXD
 
        OpS390XFADDS
        OpS390XFADD
@@ -33072,6 +33076,38 @@ var opcodeTable = [...]opInfo{
                        },
                },
        },
+       {
+               name:            "LoweredFMAXS",
+               argLen:          2,
+               commutative:     true,
+               resultNotInArgs: true,
+               asm:             riscv.AFMAXS,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
+       {
+               name:            "LoweredFMINS",
+               argLen:          2,
+               commutative:     true,
+               resultNotInArgs: true,
+               asm:             riscv.AFMINS,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
        {
                name:        "FADDD",
                argLen:      2,
@@ -33426,6 +33462,38 @@ var opcodeTable = [...]opInfo{
                        },
                },
        },
+       {
+               name:            "LoweredFMIND",
+               argLen:          2,
+               commutative:     true,
+               resultNotInArgs: true,
+               asm:             riscv.AFMIND,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
+       {
+               name:            "LoweredFMAXD",
+               argLen:          2,
+               commutative:     true,
+               resultNotInArgs: true,
+               asm:             riscv.AFMAXD,
+               reg: regInfo{
+                       inputs: []inputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                               {1, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+                       outputs: []outputInfo{
+                               {0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+                       },
+               },
+       },
 
        {
                name:         "FADDS",
index 52ddca1c7d5e9ed099b7c433d465d6804497135d..cf86572b8d67000ec472d1ee6200e430c258ea82 100644 (file)
@@ -326,6 +326,18 @@ func rewriteValueRISCV64(v *Value) bool {
                return rewriteValueRISCV64_OpLsh8x64(v)
        case OpLsh8x8:
                return rewriteValueRISCV64_OpLsh8x8(v)
+       case OpMax32F:
+               v.Op = OpRISCV64LoweredFMAXS
+               return true
+       case OpMax64F:
+               v.Op = OpRISCV64LoweredFMAXD
+               return true
+       case OpMin32F:
+               v.Op = OpRISCV64LoweredFMINS
+               return true
+       case OpMin64F:
+               v.Op = OpRISCV64LoweredFMIND
+               return true
        case OpMod16:
                return rewriteValueRISCV64_OpMod16(v)
        case OpMod16u:
index df933ec1cfa1fe23385afa39a702f70ec1d159d5..3e72a275542246deb253d4c5f80f8f04ebac85d2 100644 (file)
@@ -3700,7 +3700,7 @@ func (s *state) minMax(n *ir.CallExpr) *ssa.Value {
 
                if typ.IsFloat() {
                        switch Arch.LinkArch.Family {
-                       case sys.AMD64, sys.ARM64:
+                       case sys.AMD64, sys.ARM64, sys.RISCV64:
                                var op ssa.Op
                                switch {
                                case typ.Kind() == types.TFLOAT64 && n.Op() == ir.OMIN:
index e0bc28fbf62b608bc40eca80259ab7fa107a380a..1f815a84c313dfaef2b6250eb3c0e18d5686b31f 100644 (file)
@@ -66,10 +66,10 @@ func TestMaxFloat(t *testing.T) {
        }
        for _, x := range all {
                if z := max(nan, x); !math.IsNaN(z) {
-                       t.Errorf("min(%v, %v) = %v, want %v", nan, x, z, nan)
+                       t.Errorf("max(%v, %v) = %v, want %v", nan, x, z, nan)
                }
                if z := max(x, nan); !math.IsNaN(z) {
-                       t.Errorf("min(%v, %v) = %v, want %v", nan, x, z, nan)
+                       t.Errorf("max(%v, %v) = %v, want %v", nan, x, z, nan)
                }
        }
 }
@@ -127,3 +127,21 @@ func TestMinMaxStringTies(t *testing.T) {
        test(2, 0, 1)
        test(2, 1, 0)
 }
+
+func BenchmarkMinFloat(b *testing.B) {
+       var m float64 = 0
+       for i := 0; i < b.N; i++ {
+               for _, f := range all {
+                       m = min(m, f)
+               }
+       }
+}
+
+func BenchmarkMaxFloat(b *testing.B) {
+       var m float64 = 0
+       for i := 0; i < b.N; i++ {
+               for _, f := range all {
+                       m = max(m, f)
+               }
+       }
+}