From f6b47110952ea1c19cbdc040489c83f306c36e73 Mon Sep 17 00:00:00 2001
From: David Chase
Date: Thu, 9 Oct 2025 15:12:47 -0400
Subject: [PATCH] [dev.simd] cmd/compile, simd: add rewrite to convert logical
 expression trees into TERNLOG instructions

includes tests of both rewrite application and rewrite correctness

Change-Id: I7983ccf87a8408af95bb6c447cb22f01beda9f61
Reviewed-on: https://go-review.googlesource.com/c/go/+/710697
LUCI-TryBot-Result: Go LUCI
Reviewed-by: Junyang Shao
---
 src/cmd/compile/internal/ssa/compile.go      |   1 +
 src/cmd/compile/internal/ssa/rewritetern.go  | 292 +++++++++++++++++++
 src/cmd/compile/internal/ssa/tern_helpers.go | 160 ++++++++++
 src/simd/genfiles.go                         | 155 ++++++++++
 src/simd/internal/simd_test/simd_test.go     |  78 +++++
 test/simd.go                                 |  12 +-
 6 files changed, 697 insertions(+), 1 deletion(-)
 create mode 100644 src/cmd/compile/internal/ssa/rewritetern.go
 create mode 100644 src/cmd/compile/internal/ssa/tern_helpers.go

diff --git a/src/cmd/compile/internal/ssa/compile.go b/src/cmd/compile/internal/ssa/compile.go
index be1a6f158e..372d238a1c 100644
--- a/src/cmd/compile/internal/ssa/compile.go
+++ b/src/cmd/compile/internal/ssa/compile.go
@@ -486,6 +486,7 @@ var passes = [...]pass{
 	{name: "insert resched checks", fn: insertLoopReschedChecks, disabled: !buildcfg.Experiment.PreemptibleLoops}, // insert resched checks in loops.
 	{name: "cpufeatures", fn: cpufeatures, required: buildcfg.Experiment.SIMD, disabled: !buildcfg.Experiment.SIMD},
+	{name: "rewrite tern", fn: rewriteTern, required: false, disabled: !buildcfg.Experiment.SIMD},
 	{name: "lower", fn: lower, required: true},
 	{name: "addressing modes", fn: addressingModes, required: false},
 	{name: "late lower", fn: lateLower, required: true},
diff --git a/src/cmd/compile/internal/ssa/rewritetern.go b/src/cmd/compile/internal/ssa/rewritetern.go
new file mode 100644
index 0000000000..5493e5f109
--- /dev/null
+++ b/src/cmd/compile/internal/ssa/rewritetern.go
@@ -0,0 +1,292 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssa
+
+import (
+	"fmt"
+	"internal/goarch"
+	"slices"
+)
+
+var truthTableValues [3]uint8 = [3]uint8{0b1111_0000, 0b1100_1100, 0b1010_1010}
+
+func (slop SIMDLogicalOP) String() string {
+	if slop == sloInterior {
+		return "leaf"
+	}
+	interior := ""
+	if slop&sloInterior != 0 {
+		interior = "+interior"
+	}
+	switch slop &^ sloInterior {
+	case sloAnd:
+		return "and" + interior
+	case sloXor:
+		return "xor" + interior
+	case sloOr:
+		return "or" + interior
+	case sloAndNot:
+		return "andNot" + interior
+	case sloNot:
+		return "not" + interior
+	}
+	return "wrong"
+}
+
+func rewriteTern(f *Func) {
+	if f.maxCPUFeatures == CPUNone {
+		return
+	}
+
+	arch := f.Config.Ctxt().Arch.Family
+	// TODO there are other SIMD architectures
+	if arch != goarch.AMD64 {
+		return
+	}
+
+	boolExprTrees := make(map[*Value]SIMDLogicalOP)
+
+	// Find logical expression trees, including leaves.
+	// Interior nodes will be marked sloInterior,
+	// root nodes will not be marked sloInterior,
+	// leaf nodes are only marked sloInterior.
+
+	for _, b := range f.Blocks {
+		for _, v := range b.Values {
+			slo := classifyBooleanSIMD(v)
+			switch slo {
+			case sloOr,
+				sloAndNot,
+				sloXor,
+				sloAnd:
+				boolExprTrees[v.Args[1]] |= sloInterior
+				fallthrough
+			case sloNot:
+				boolExprTrees[v.Args[0]] |= sloInterior
+				boolExprTrees[v] |= slo
+			}
+		}
+	}
+
+	// get a canonical sorted set of roots
+	var roots []*Value
+	for v, slo := range boolExprTrees {
+		if f.pass.debug > 1 {
+			f.Warnl(v.Pos, "%s has SLO %v", v.LongString(), slo)
+		}
+
+		if slo&sloInterior == 0 && v.Block.CPUfeatures.hasFeature(CPUavx512) {
+			roots = append(roots, v)
+		}
+	}
+	slices.SortFunc(roots, func(u, v *Value) int { return int(u.ID - v.ID) }) // IDs are small enough to not care about overflow.
+
+	// This rewrite works by iterating over the root set.
+	// For each boolean expression, it walks the expression
+	// bottom up, accumulating sets of variables mentioned in
+	// subexpressions, lazily and greedily finding the largest
+	// subexpressions of 3 inputs that can be rewritten to use
+	// ternary-truth-table instructions.
+
+	// rewrite recursively attempts to replace v and v's subexpressions with
+	// ternary-logic truth-table operations, returning a set of not more than 3
+	// subexpressions within v that may be combined into a parent's replacement.
+	// v need not have the CPU features that allow a ternary-logic operation;
+	// in that case, v will not be rewritten. Replacements also require
+	// exactly 3 different variable inputs to a boolean expression.
+	//
+	// Given the CPU feature and 3 inputs, v is replaced in the following
+	// cases:
+	//
+	// 1) v is a root
+	// 2) u = NOT(v) and u lacks the CPU feature
+	// 3) u = OP(v, w) and u lacks the CPU feature
+	// 4) u = OP(v, w) and u has more than 3 variable inputs.
+	var rewrite func(v *Value) [3]*Value
+
+	// computeTT returns the truth table for a boolean expression
+	// over the variables in vars, where vars[0] varies slowest in
+	// the truth table and vars[2] varies fastest.
+	// e.g. computeTT( "and(x, or(y, not(z)))", {x,y,z} ) returns
+	// (bit 0 first) 0 0 0 0 1 0 1 1 = (reversed) 1101_0000 = 0xD0
+	// x:            0 0 0 0 1 1 1 1
+	// y:            0 0 1 1 0 0 1 1
+	// z:            0 1 0 1 0 1 0 1
+	var computeTT func(v *Value, vars [3]*Value) uint8
+
+	// combine merges two sets of variables into one, returning ok (true)
+	// if the two sets contain 3 or fewer distinct elements. Combine
+	// ensures that the sets of Values never contain duplicates.
+	// (Duplicates would create less-efficient code, not incorrect code.)
+	combine := func(a, b [3]*Value) ([3]*Value, bool) {
+		var c [3]*Value
+		i := 0
+		for _, v := range a {
+			if v == nil {
+				break
+			}
+			c[i] = v
+			i++
+		}
+	bloop:
+		for _, v := range b {
+			if v == nil {
+				break
+			}
+			for _, u := range a {
+				if v == u {
+					continue bloop
+				}
+			}
+			if i == 3 {
+				return [3]*Value{}, false
+			}
+			c[i] = v
+			i++
+		}
+		return c, true
+	}
+
+	computeTT = func(v *Value, vars [3]*Value) uint8 {
+		i := 0
+		for ; i < len(vars); i++ {
+			if vars[i] == v {
+				return truthTableValues[i]
+			}
+		}
+		slo := boolExprTrees[v] &^ sloInterior
+		a := computeTT(v.Args[0], vars)
+		switch slo {
+		case sloNot:
+			return ^a
+		case sloAnd:
+			return a & computeTT(v.Args[1], vars)
+		case sloXor:
+			return a ^ computeTT(v.Args[1], vars)
+		case sloOr:
+			return a | computeTT(v.Args[1], vars)
+		case sloAndNot:
+			return a & ^computeTT(v.Args[1], vars)
+		}
+		panic("switch should have covered all cases, or unknown var in logical expression")
+	}
+
+	replace := func(a0 *Value, vars0 [3]*Value) {
+		imm := computeTT(a0, vars0)
+		op := ternOpForLogical(a0.Op)
+		if op == a0.Op {
+			panic(fmt.Errorf("should have mapped away from input op, a0 is %s", a0.LongString()))
+		}
+		if f.pass.debug > 0 {
+			f.Warnl(a0.Pos, "Rewriting %s into %v of 0b%b %v %v %v", a0.LongString(), op, imm,
+				vars0[0], vars0[1], vars0[2])
+		}
+		a0.reset(op)
+		a0.SetArgs3(vars0[0], vars0[1], vars0[2])
+		a0.AuxInt = int64(int8(imm))
+	}
+
+	// addOne adds a single value to a set that is not full, without
+	// introducing duplicates. It seems possible that a shared
+	// subexpression in tricky combination with blocks lacking the
+	// AVX512 feature might permit this.
+	addOne := func(vars [3]*Value, v *Value) [3]*Value {
+		if vars[2] != nil {
+			panic("rewriteTern.addOne, vars[2] should be nil")
+		}
+		if v == vars[0] || v == vars[1] {
+			return vars
+		}
+		if vars[1] == nil {
+			vars[1] = v
+		} else {
+			vars[2] = v
+		}
+		return vars
+	}
+
+	rewrite = func(v *Value) [3]*Value {
+		slo := boolExprTrees[v]
+		if slo == sloInterior { // leaf node, i.e., a "variable"
+			return [3]*Value{v, nil, nil}
+		}
+		var vars [3]*Value
+		hasFeature := v.Block.CPUfeatures.hasFeature(CPUavx512)
+		if slo&sloNot == sloNot {
+			vars = rewrite(v.Args[0])
+			if !hasFeature {
+				if vars[2] != nil {
+					replace(v.Args[0], vars)
+					return [3]*Value{v, nil, nil}
+				}
+				return vars
+			}
+		} else {
+			var ok bool
+			a0, a1 := v.Args[0], v.Args[1]
+			vars0 := rewrite(a0)
+			vars1 := rewrite(a1)
+			vars, ok = combine(vars0, vars1)
+
+			if f.pass.debug > 1 {
+				f.Warnl(a0.Pos, "combine(%v, %v) -> %v, %v", vars0, vars1, vars, ok)
+			}
+
+			if !(ok && v.Block.CPUfeatures.hasFeature(CPUavx512)) {
+				// too many variables, or cannot rewrite current values.
+				// rewrite one or both subtrees if possible
+				if vars0[2] != nil && a0.Block.CPUfeatures.hasFeature(CPUavx512) {
+					replace(a0, vars0)
+				}
+				if vars1[2] != nil && a1.Block.CPUfeatures.hasFeature(CPUavx512) {
+					replace(a1, vars1)
+				}
+
+				// 3-element var arrays are either rewritten, or unable to be rewritten
+				// because of the features in effect in their block. Either way, they
+				// are treated as a "new var" if 3 elements are present.
+
+				if vars0[2] == nil {
+					if vars1[2] == nil {
+						// both subtrees are 2-element and were not rewritten.
+						//
+						// TODO a clever person would look at subtrees of inputs,
+						// e.g. rewrite
+						// ((a AND b) XOR b) XOR (d XOR (c AND d))
+						// to (((a AND b) XOR b) XOR d) XOR (c AND d)
+						// to v = TERNLOG(truthtable, a, b, d) XOR (c AND d)
+						// and return the variable set {v, c, d}
+						//
+						// But for now, just restart with a0 and a1.
+						return [3]*Value{a0, a1, nil}
+					} else {
+						// a1 (maybe) rewrote, a0 has room for another var
+						vars = addOne(vars0, a1)
+					}
+				} else if vars1[2] == nil {
+					// a0 (maybe) rewrote, a1 has room for another var
+					vars = addOne(vars1, a0)
+				} else if !ok {
+					// both (maybe) rewrote
+					// a0 and a1 are different because otherwise their variable
+					// sets would have combined "ok".
+					return [3]*Value{a0, a1, nil}
+				}
+				// continue with either the vars from "ok" or the updated set of vars.
+			}
+		}
+		// if root and 3 vars and hasFeature, rewrite.
+		if slo&sloInterior == 0 && vars[2] != nil && hasFeature {
+			replace(v, vars)
+			return [3]*Value{v, nil, nil}
+		}
+		return vars
+	}
+
+	for _, v := range roots {
+		if f.pass.debug > 1 {
+			f.Warnl(v.Pos, "SLO root %s", v.LongString())
+		}
+		rewrite(v)
+	}
+}
diff --git a/src/cmd/compile/internal/ssa/tern_helpers.go b/src/cmd/compile/internal/ssa/tern_helpers.go
new file mode 100644
index 0000000000..3ffc980c33
--- /dev/null
+++ b/src/cmd/compile/internal/ssa/tern_helpers.go
@@ -0,0 +1,160 @@
+// Code generated by 'go run genfiles.go'; DO NOT EDIT.
+
+package ssa
+
+type SIMDLogicalOP uint8
+
+const (
+	// boolean simd operations, for reducing expression to VPTERNLOG* instructions
+	// sloInterior is set for non-root nodes in logical-op expression trees.
+	// the operations are even-numbered.
+	sloInterior SIMDLogicalOP = 1
+	sloNone     SIMDLogicalOP = 2 * iota
+	sloAnd
+	sloOr
+	sloAndNot
+	sloXor
+	sloNot
+)
+
+func classifyBooleanSIMD(v *Value) SIMDLogicalOP {
+	switch v.Op {
+	case OpAndInt8x16, OpAndInt16x8, OpAndInt32x4, OpAndInt64x2, OpAndInt8x32, OpAndInt16x16, OpAndInt32x8, OpAndInt64x4, OpAndInt8x64, OpAndInt16x32, OpAndInt32x16, OpAndInt64x8:
+		return sloAnd
+
+	case OpOrInt8x16, OpOrInt16x8, OpOrInt32x4, OpOrInt64x2, OpOrInt8x32, OpOrInt16x16, OpOrInt32x8, OpOrInt64x4, OpOrInt8x64, OpOrInt16x32, OpOrInt32x16, OpOrInt64x8:
+		return sloOr
+
+	case OpAndNotInt8x16, OpAndNotInt16x8, OpAndNotInt32x4, OpAndNotInt64x2, OpAndNotInt8x32, OpAndNotInt16x16, OpAndNotInt32x8, OpAndNotInt64x4, OpAndNotInt8x64, OpAndNotInt16x32, OpAndNotInt32x16, OpAndNotInt64x8:
+		return sloAndNot
+	case OpXorInt8x16:
+		if y := v.Args[1]; y.Op == OpEqualInt8x16 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt16x8:
+		if y := v.Args[1]; y.Op == OpEqualInt16x8 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt32x4:
+		if y := v.Args[1]; y.Op == OpEqualInt32x4 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt64x2:
+		if y := v.Args[1]; y.Op == OpEqualInt64x2 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt8x32:
+		if y := v.Args[1]; y.Op == OpEqualInt8x32 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt16x16:
+		if y := v.Args[1]; y.Op == OpEqualInt16x16 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt32x8:
+		if y := v.Args[1]; y.Op == OpEqualInt32x8 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt64x4:
+		if y := v.Args[1]; y.Op == OpEqualInt64x4 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt8x64:
+		if y := v.Args[1]; y.Op == OpEqualInt8x64 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt16x32:
+		if y := v.Args[1]; y.Op == OpEqualInt16x32 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt32x16:
+		if y := v.Args[1]; y.Op == OpEqualInt32x16 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+	case OpXorInt64x8:
+		if y := v.Args[1]; y.Op == OpEqualInt64x8 &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		return sloXor
+
+	}
+	return sloNone
+}
+
+func ternOpForLogical(op Op) Op {
+	switch op {
+	case OpAndInt8x16, OpOrInt8x16, OpXorInt8x16, OpAndNotInt8x16:
+		return OpternInt32x4
+	case OpAndUint8x16, OpOrUint8x16, OpXorUint8x16, OpAndNotUint8x16:
+		return OpternUint32x4
+	case OpAndInt16x8, OpOrInt16x8, OpXorInt16x8, OpAndNotInt16x8:
+		return OpternInt32x4
+	case OpAndUint16x8, OpOrUint16x8, OpXorUint16x8, OpAndNotUint16x8:
+		return OpternUint32x4
+	case OpAndInt32x4, OpOrInt32x4, OpXorInt32x4, OpAndNotInt32x4:
+		return OpternInt32x4
+	case OpAndUint32x4, OpOrUint32x4, OpXorUint32x4, OpAndNotUint32x4:
+		return OpternUint32x4
+	case OpAndInt64x2, OpOrInt64x2, OpXorInt64x2, OpAndNotInt64x2:
+		return OpternInt64x2
+	case OpAndUint64x2, OpOrUint64x2, OpXorUint64x2, OpAndNotUint64x2:
+		return OpternUint64x2
+	case OpAndInt8x32, OpOrInt8x32, OpXorInt8x32, OpAndNotInt8x32:
+		return OpternInt32x8
+	case OpAndUint8x32, OpOrUint8x32, OpXorUint8x32, OpAndNotUint8x32:
+		return OpternUint32x8
+	case OpAndInt16x16, OpOrInt16x16, OpXorInt16x16, OpAndNotInt16x16:
+		return OpternInt32x8
+	case OpAndUint16x16, OpOrUint16x16, OpXorUint16x16, OpAndNotUint16x16:
+		return OpternUint32x8
+	case OpAndInt32x8, OpOrInt32x8, OpXorInt32x8, OpAndNotInt32x8:
+		return OpternInt32x8
+	case OpAndUint32x8, OpOrUint32x8, OpXorUint32x8, OpAndNotUint32x8:
+		return OpternUint32x8
+	case OpAndInt64x4, OpOrInt64x4, OpXorInt64x4, OpAndNotInt64x4:
+		return OpternInt64x4
+	case OpAndUint64x4, OpOrUint64x4, OpXorUint64x4, OpAndNotUint64x4:
+		return OpternUint64x4
+	case OpAndInt8x64, OpOrInt8x64, OpXorInt8x64, OpAndNotInt8x64:
+		return OpternInt32x16
+	case OpAndUint8x64, OpOrUint8x64, OpXorUint8x64, OpAndNotUint8x64:
+		return OpternUint32x16
+	case OpAndInt16x32, OpOrInt16x32, OpXorInt16x32, OpAndNotInt16x32:
+		return OpternInt32x16
+	case OpAndUint16x32, OpOrUint16x32, OpXorUint16x32, OpAndNotUint16x32:
+		return OpternUint32x16
+	case OpAndInt32x16, OpOrInt32x16, OpXorInt32x16, OpAndNotInt32x16:
+		return OpternInt32x16
+	case OpAndUint32x16, OpOrUint32x16, OpXorUint32x16, OpAndNotUint32x16:
+		return OpternUint32x16
+	case OpAndInt64x8, OpOrInt64x8, OpXorInt64x8, OpAndNotInt64x8:
+		return OpternInt64x8
+	case OpAndUint64x8, OpOrUint64x8, OpXorUint64x8, OpAndNotUint64x8:
+		return OpternUint64x8
+
+	}
+	return op
+}
diff --git a/src/simd/genfiles.go b/src/simd/genfiles.go
index 80234ac9f8..be23b127c8 100644
--- a/src/simd/genfiles.go
+++ b/src/simd/genfiles.go
@@ -254,6 +254,15 @@ package simd
 `, s)
 }
 
+func ssaPrologue(s string, out io.Writer) {
+	fmt.Fprintf(out,
+		`// Code generated by '%s'; DO NOT EDIT.
+
+package ssa
+
+`, s)
+}
+
 func unsafePrologue(s string, out io.Writer) {
 	fmt.Fprintf(out,
 		`// Code generated by '%s'; DO NOT EDIT.
@@ -806,6 +815,7 @@ func (x {{.VType}}) String() string {
 `)
 
 const TD = "internal/simd_test/"
+const SSA = "../cmd/compile/internal/ssa/"
 
 func main() {
 	sl := flag.String("sl", "slice_gen_amd64.go", "file name for slice operations")
@@ -867,6 +877,115 @@ func main() {
 	if *cmh != "" {
 		one(*cmh, curryTestPrologue("simd methods that compare two operands under a mask"), compareMaskedTemplate)
 	}
+
+	nonTemplateRewrites(SSA+"tern_helpers.go", ssaPrologue, classifyBooleanSIMD, ternOpForLogical)
+
+}
+
+func ternOpForLogical(out io.Writer) {
+	fmt.Fprintf(out, `
+func ternOpForLogical(op Op) Op {
+	switch op {
+`)
+
+	intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
+		wt, ct := w, c
+		if wt < 32 {
+			wt = 32
+			ct = (w * c) / wt
+		}
+		fmt.Fprintf(out, "case OpAndInt%[1]dx%[2]d, OpOrInt%[1]dx%[2]d, OpXorInt%[1]dx%[2]d,OpAndNotInt%[1]dx%[2]d: return OpternInt%dx%d\n", w, c, wt, ct)
+		fmt.Fprintf(out, "case OpAndUint%[1]dx%[2]d, OpOrUint%[1]dx%[2]d, OpXorUint%[1]dx%[2]d,OpAndNotUint%[1]dx%[2]d: return OpternUint%dx%d\n", w, c, wt, ct)
+	}, out)
+
+	fmt.Fprintf(out, `
+	}
+	return op
+}
+`)
+
+}
+
+func classifyBooleanSIMD(out io.Writer) {
+	fmt.Fprintf(out, `
+type SIMDLogicalOP uint8
+const (
+	// boolean simd operations, for reducing expression to VPTERNLOG* instructions
+	// sloInterior is set for non-root nodes in logical-op expression trees.
+	// the operations are even-numbered.
+	sloInterior SIMDLogicalOP = 1
+	sloNone SIMDLogicalOP = 2 * iota
+	sloAnd
+	sloOr
+	sloAndNot
+	sloXor
+	sloNot
+)
+func classifyBooleanSIMD(v *Value) SIMDLogicalOP {
+	switch v.Op {
+	case `)
+	intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
+		op := "And"
+		if seq > 0 {
+			fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
+		} else {
+			fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
+		}
+		seq++
+	}, out)
+
+	fmt.Fprintf(out, `:
+	return sloAnd
+
+	case `)
+	intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
+		op := "Or"
+		if seq > 0 {
+			fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
+		} else {
+			fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
+		}
+		seq++
+	}, out)
+
+	fmt.Fprintf(out, `:
+	return sloOr
+
+	case `)
+	intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
+		op := "AndNot"
+		if seq > 0 {
+			fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
+		} else {
+			fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
+		}
+		seq++
+	}, out)
+
+	fmt.Fprintf(out, `:
+	return sloAndNot
+`)
+
+	// "Not" is encoded as x.Xor(x.Equal(x).AsInt8x16())
+	// i.e. xor.Args[0] == x, xor.Args[1].Op == As...
+	// but AsInt8x16 is a pun/passthrough.
+
+	intShapes.forAllShapes(
+		func(seq int, t, upperT string, w, c int, out io.Writer) {
+			fmt.Fprintf(out, "case OpXor%s%dx%d: ", upperT, w, c)
+			fmt.Fprintf(out, `
+		if y := v.Args[1]; y.Op == OpEqual%s%dx%d &&
+			y.Args[0] == y.Args[1] {
+			return sloNot
+		}
+		`, upperT, w, c)
+			fmt.Fprintf(out, "return sloXor\n")
+		}, out)
+
+	fmt.Fprintf(out, `
+	}
+	return sloNone
+}
+`)
 }
 
 // numberLines takes a slice of bytes, and returns a string where each line
@@ -881,6 +1000,42 @@ func numberLines(data []byte) string {
 	return buf.String()
 }
 
+func nonTemplateRewrites(filename string, prologue func(s string, out io.Writer), rewrites ...func(out io.Writer)) {
+	if filename == "" {
+		return
+	}
+
+	ofile := os.Stdout
+
+	if filename != "-" {
+		var err error
+		ofile, err = os.Create(filename)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "Could not create the output file %s for the generated code, %v", filename, err)
+			os.Exit(1)
+		}
+	}
+
+	out := new(bytes.Buffer)
+
+	prologue("go run genfiles.go", out)
+	for _, rewrite := range rewrites {
+		rewrite(out)
+	}
+
+	b, err := format.Source(out.Bytes())
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
+		fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
+		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
+		os.Exit(1)
+	} else {
+		ofile.Write(b)
+		ofile.Close()
+	}
+
+}
+
 func one(filename string, prologue func(s string, out io.Writer), sats ...shapeAndTemplate) {
 	if filename == "" {
 		return
diff --git a/src/simd/internal/simd_test/simd_test.go b/src/simd/internal/simd_test/simd_test.go
index 295f7bf9ce..c64ac0fcfd 100644
--- a/src/simd/internal/simd_test/simd_test.go
+++ b/src/simd/internal/simd_test/simd_test.go
@@ -1030,3 +1030,81 @@ func TestString(t *testing.T) {
 	t.Logf("y=%s", y)
 	t.Logf("z=%s", z)
 }
+
+// a returns a slice of 16 int32
+func a() []int32 {
+	return make([]int32, 16, 16)
+}
+
+// applyTo3 returns a 16-element slice of the results of
+// applying f to the respective elements of vectors x, y, and z.
+func applyTo3(x, y, z simd.Int32x16, f func(x, y, z int32) int32) []int32 {
+	ax, ay, az := a(), a(), a()
+	x.StoreSlice(ax)
+	y.StoreSlice(ay)
+	z.StoreSlice(az)
+
+	r := a()
+	for i := range r {
+		r[i] = f(ax[i], ay[i], az[i])
+	}
+	return r
+}
+
+// applyTo4 returns a 16-element slice of the results of
+// applying f to the respective elements of vectors x, y, z, and w.
+func applyTo4(x, y, z, w simd.Int32x16, f func(x, y, z, w int32) int32) []int32 { + ax, ay, az, aw := a(), a(), a(), a() + x.StoreSlice(ax) + y.StoreSlice(ay) + z.StoreSlice(az) + w.StoreSlice(aw) + + r := make([]int32, len(ax), len(ax)) + for i := range r { + r[i] = f(ax[i], ay[i], az[i], aw[i]) + } + return r +} + +func TestSelectTernOptInt32x16(t *testing.T) { + if !simd.HasAVX512() { + t.Skip("Test requires HasAVX512, not available on this hardware") + return + } + ax := []int32{0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1} + ay := []int32{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1} + az := []int32{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1} + aw := []int32{0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1} + am := []int32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + + x := simd.LoadInt32x16Slice(ax) + y := simd.LoadInt32x16Slice(ay) + z := simd.LoadInt32x16Slice(az) + w := simd.LoadInt32x16Slice(aw) + m := simd.LoadInt32x16Slice(am) + + foo := func(v simd.Int32x16, s []int32) { + r := make([]int32, 16, 16) + v.StoreSlice(r) + checkSlices[int32](t, r, s) + } + + t0 := w.Xor(y).Xor(z) + ft0 := func(w, y, z int32) int32 { + return w ^ y ^ z + } + foo(t0, applyTo3(w, y, z, ft0)) + + t1 := m.And(w.Xor(y).Xor(z.Not())) + ft1 := func(m, w, y, z int32) int32 { + return m & (w ^ y ^ ^z) + } + foo(t1, applyTo4(m, w, y, z, ft1)) + + t2 := x.Xor(y).Xor(z).And(x.Xor(y).Xor(z.Not())) + ft2 := func(x, y, z int32) int32 { + return (x ^ y ^ z) & (x ^ y ^ ^z) + } + foo(t2, applyTo3(x, y, z, ft2)) +} diff --git a/test/simd.go b/test/simd.go index b1695fa514..32ed70d39a 100644 --- a/test/simd.go +++ b/test/simd.go @@ -1,4 +1,4 @@ -// errorcheck -0 -d=ssa/cpufeatures/debug=1 +// errorcheck -0 -d=ssa/cpufeatures/debug=1,ssa/rewrite_tern/debug=1 //go:build goexperiment.simd && amd64 @@ -95,3 +95,13 @@ b: c: println("c") } + +func ternRewrite(m, w, x, y, z simd.Int32x16) (t0, t1, t2 simd.Int32x16) { + if !simd.HasAVX512() { // ERROR "has features avx[+]avx2[+]avx512$" + return // ERROR "has features avx[+]avx2[+]avx512$" // all blocks have it because of the vector size + } + t0 = w.Xor(y).Xor(z) // ERROR "Rewriting.*ternInt" + t1 = m.And(w.Xor(y).Xor(z.Not())) // ERROR "Rewriting.*ternInt" + t2 = x.Xor(y).Xor(z).And(x.Xor(y).Xor(z.Not())) // ERROR "Rewriting.*ternInt" + return // ERROR "has features avx[+]avx2[+]avx512$" +} -- 2.52.0
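A note on the immediate encoding, as an illustrative sketch rather than part of the change above: each entry of truthTableValues is the eight-row truth table of one of the three inputs, with bit i holding that input's value in row i. computeTT therefore amounts to evaluating the boolean expression directly on those three masks, and the 8-bit result is exactly the VPTERNLOG immediate. The following standalone program (the ttX/ttY/ttZ names are hypothetical) reproduces the 0xD0 example from the computeTT comment and the three-way XOR exercised by TestSelectTernOptInt32x16:

package main

import "fmt"

// Truth tables of the three inputs; these match truthTableValues in
// rewritetern.go. Bit i of each mask is that input's value in row i.
const (
	ttX uint8 = 0b1111_0000
	ttY uint8 = 0b1100_1100
	ttZ uint8 = 0b1010_1010
)

func main() {
	fmt.Printf("%#x\n", ttX&(ttY|^ttZ)) // and(x, or(y, not(z))) -> 0xd0
	fmt.Printf("%#x\n", ttX^ttY^ttZ)    // x ^ y ^ z -> 0x96
	fmt.Printf("%#x\n", ttX&^ttZ)       // x andnot z -> 0x50
}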
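The sloNot pattern matched by classifyBooleanSIMD relies on the identity noted in the genfiles.go comment: the simd package emits x.Not() as x.Xor(x.Equal(x).AsInt8x16()), and a vector compared equal to itself is all ones, so the XOR flips every bit. A scalar analogue, again only a sketch with a hypothetical helper name:

package main

import "fmt"

// notViaXor mirrors the vector Not idiom: comparing a value with itself
// yields all ones (every lane compares equal), and XOR with all ones is
// bitwise NOT.
func notViaXor(x int32) int32 {
	allOnes := int32(-1) // stands in for x.Equal(x).AsInt8x16()
	return x ^ allOnes   // identical to ^x
}

func main() {
	fmt.Println(notViaXor(5) == ^5) // prints true
}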