--- /dev/null
+module simd/_gen
+
+go 1.24
+
+require (
+ golang.org/x/arch v0.20.0
+ gopkg.in/yaml.v3 v3.0.1
+)
--- /dev/null
+golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c=
+golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
--- /dev/null
+testdata/*
+.gemini/*
+.gemini*
--- /dev/null
+# Hand-written toy input like what -xedPath would generate.
+# This input can be used in place of the -xedPath XED data.
+!sum
+- asm: ADDPS
+ goarch: amd64
+ feature: "SSE2"
+ in:
+ - asmPos: 0
+ class: vreg
+ base: float
+ elemBits: 32
+ bits: 128
+ - asmPos: 1
+ class: vreg
+ base: float
+ elemBits: 32
+ bits: 128
+ out:
+ - asmPos: 0
+ class: vreg
+ base: float
+ elemBits: 32
+ bits: 128
+
+- asm: ADDPD
+ goarch: amd64
+ feature: "SSE2"
+ in:
+ - asmPos: 0
+ class: vreg
+ base: float
+ elemBits: 64
+ bits: 128
+ - asmPos: 1
+ class: vreg
+ base: float
+ elemBits: 64
+ bits: 128
+ out:
+ - asmPos: 0
+ class: vreg
+ base: float
+ elemBits: 64
+ bits: 128
+
+- asm: PADDB
+ goarch: amd64
+ feature: "SSE2"
+ in:
+ - asmPos: 0
+ class: vreg
+ base: int|uint
+ elemBits: 8
+ bits: 128
+ - asmPos: 1
+ class: vreg
+ base: int|uint
+ elemBits: 8
+ bits: 128
+ out:
+ - asmPos: 0
+ class: vreg
+ base: int|uint
+ elemBits: 8
+ bits: 128
+
+- asm: VPADDB
+ goarch: amd64
+ feature: "AVX"
+ in:
+ - asmPos: 1
+ class: vreg
+ base: int|uint
+ elemBits: 8
+ bits: 128
+ - asmPos: 2
+ class: vreg
+ base: int|uint
+ elemBits: 8
+ bits: 128
+ out:
+ - asmPos: 0
+ class: vreg
+ base: int|uint
+ elemBits: 8
+ bits: 128
+
+- asm: VPADDB
+ goarch: amd64
+ feature: "AVX2"
+ in:
+ - asmPos: 1
+ class: vreg
+ base: int|uint
+ elemBits: 8
+ bits: 256
+ - asmPos: 2
+ class: vreg
+ base: int|uint
+ elemBits: 8
+ bits: 256
+ out:
+ - asmPos: 0
+ class: vreg
+ base: int|uint
+ elemBits: 8
+ bits: 256
--- /dev/null
+!import ops/*/categories.yaml
--- /dev/null
+#!/bin/bash -x
+
+cat <<'EOF'
+
+This is an end-to-end test of Go SIMD. It checks out a fresh Go
+repository from the dev.simd branch, then generates the SIMD input
+files and runs simdgen, writing into the fresh repository.
+
+After that it generates the modified ssa pattern matching files, then
+builds the compiler.
+
+EOF
+
+rm -rf go-test
+git clone https://go.googlesource.com/go -b dev.simd go-test
+go run . -xedPath xeddata -o godefs -goroot ./go-test go.yaml types.yaml categories.yaml
+(cd go-test/src/cmd/compile/internal/ssa/_gen ; go run *.go )
+(cd go-test/src ; GOEXPERIMENT=simd ./make.bash )
+(cd go-test/bin; b=`pwd` ; cd ../src/simd/testdata; GOARCH=amd64 $b/go run .)
+(cd go-test/bin; b=`pwd` ; cd ../src ;
+GOEXPERIMENT=simd GOARCH=amd64 $b/go test -v simd
+GOEXPERIMENT=simd $b/go test go/doc
+GOEXPERIMENT=simd $b/go test go/build
+GOEXPERIMENT=simd $b/go test cmd/api -v -check
+$b/go test go/doc
+$b/go test go/build
+$b/go test cmd/api -v -check
+
+$b/go test cmd/compile/internal/ssagen -simd=0
+GOEXPERIMENT=simd $b/go test cmd/compile/internal/ssagen -simd=0
+)
+
+# next, add some tests of SIMD itself
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+)
+
+const simdGenericOpsTmpl = `
+package main
+
+func simdGenericOps() []opData {
+ return []opData{
+{{- range .Ops }}
+ {name: "{{.OpName}}", argLength: {{.OpInLen}}, commutative: {{.Comm}}},
+{{- end }}
+{{- range .OpsImm }}
+ {name: "{{.OpName}}", argLength: {{.OpInLen}}, commutative: {{.Comm}}, aux: "UInt8"},
+{{- end }}
+ }
+}
+`
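+
+// For illustration, one entry rendered by the template above might look like
+// this (hypothetical generic op name):
+//
+//	{name: "AddInt32x4", argLength: 2, commutative: true},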
+
+// writeSIMDGenericOps generates the generic ops table and returns it in a
+// buffer, which the caller writes to simdgenericOps.go.
+func writeSIMDGenericOps(ops []Operation) *bytes.Buffer {
+ t := templateOf(simdGenericOpsTmpl, "simdgenericOps")
+ buffer := new(bytes.Buffer)
+ buffer.WriteString(generatedHeader)
+
+ type genericOpsData struct {
+ OpName string
+ OpInLen int
+ Comm bool
+ }
+ type opData struct {
+ Ops []genericOpsData
+ OpsImm []genericOpsData
+ }
+ var opsData opData
+ for _, op := range ops {
+ if op.NoGenericOps != nil && *op.NoGenericOps == "true" {
+ continue
+ }
+ _, _, _, immType, gOp := op.shape()
+ gOpData := genericOpsData{gOp.GenericName(), len(gOp.In), op.Commutative}
+ if immType == VarImm || immType == ConstVarImm {
+ opsData.OpsImm = append(opsData.OpsImm, gOpData)
+ } else {
+ opsData.Ops = append(opsData.Ops, gOpData)
+ }
+ }
+ sort.Slice(opsData.Ops, func(i, j int) bool {
+ return compareNatural(opsData.Ops[i].OpName, opsData.Ops[j].OpName) < 0
+ })
+ sort.Slice(opsData.OpsImm, func(i, j int) bool {
+ return compareNatural(opsData.OpsImm[i].OpName, opsData.OpsImm[j].OpName) < 0
+ })
+
+ err := t.Execute(buffer, opsData)
+ if err != nil {
+ panic(fmt.Errorf("failed to execute template: %w", err))
+ }
+
+ return buffer
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "slices"
+)
+
+const simdIntrinsicsTmpl = `
+{{define "header"}}
+package ssagen
+
+import (
+ "cmd/compile/internal/ir"
+ "cmd/compile/internal/ssa"
+ "cmd/compile/internal/types"
+ "cmd/internal/sys"
+)
+
+const simdPackage = "` + simdPackage + `"
+
+func simdIntrinsics(addF func(pkg, fn string, b intrinsicBuilder, archFamilies ...sys.ArchFamily)) {
+{{end}}
+
+{{define "op1"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen1(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op2"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen2(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op2_21"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen2_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op2_21Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen2_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op3"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen3(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op3_21"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen3_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op3_21Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3_21(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op3_231Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3_231(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op3_31"}} addF(simdPackage, "{{(index .In 2).Go}}.{{.Go}}", opLen3_31(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op4"}} addF(simdPackage, "{{(index .In 0).Go}}.{{.Go}}", opLen4(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op4_231Type1"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen4_231(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op4_31"}} addF(simdPackage, "{{(index .In 2).Go}}.{{.Go}}", opLen4_31(ssa.Op{{.GenericName}}, {{.SSAType}}), sys.AMD64)
+{{end}}
+{{define "op1Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen1Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
+{{end}}
+{{define "op2Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen2Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
+{{end}}
+{{define "op2Imm8_2I"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen2Imm8_2I(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
+{{end}}
+{{define "op3Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
+{{end}}
+{{define "op3Imm8_2I"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen3Imm8_2I(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
+{{end}}
+{{define "op4Imm8"}} addF(simdPackage, "{{(index .In 1).Go}}.{{.Go}}", opLen4Imm8(ssa.Op{{.GenericName}}, {{.SSAType}}, {{(index .In 0).ImmOffset}}), sys.AMD64)
+{{end}}
+
+{{define "vectorConversion"}} addF(simdPackage, "{{.Tsrc.Name}}.As{{.Tdst.Name}}", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return args[0] }, sys.AMD64)
+{{end}}
+
+{{define "loadStore"}} addF(simdPackage, "Load{{.Name}}", simdLoad(), sys.AMD64)
+ addF(simdPackage, "{{.Name}}.Store", simdStore(), sys.AMD64)
+{{end}}
+
+{{define "maskedLoadStore"}} addF(simdPackage, "LoadMasked{{.Name}}", simdMaskedLoad(ssa.OpLoadMasked{{.ElemBits}}), sys.AMD64)
+ addF(simdPackage, "{{.Name}}.StoreMasked", simdMaskedStore(ssa.OpStoreMasked{{.ElemBits}}), sys.AMD64)
+{{end}}
+
+{{define "mask"}} addF(simdPackage, "{{.Name}}.As{{.VectorCounterpart}}", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return args[0] }, sys.AMD64)
+ addF(simdPackage, "{{.VectorCounterpart}}.As{{.Name}}", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return args[0] }, sys.AMD64)
+ addF(simdPackage, "{{.Name}}.And", opLen2(ssa.OpAnd{{.ReshapedVectorWithAndOr}}, types.TypeVec{{.Size}}), sys.AMD64)
+ addF(simdPackage, "{{.Name}}.Or", opLen2(ssa.OpOr{{.ReshapedVectorWithAndOr}}, types.TypeVec{{.Size}}), sys.AMD64)
+ addF(simdPackage, "Load{{.Name}}FromBits", simdLoadMask({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
+ addF(simdPackage, "{{.Name}}.StoreToBits", simdStoreMask({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
+ addF(simdPackage, "{{.Name}}FromBits", simdCvtVToMask({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
+ addF(simdPackage, "{{.Name}}.ToBits", simdCvtMaskToV({{.ElemBits}}, {{.Lanes}}), sys.AMD64)
+{{end}}
+
+{{define "footer"}}}
+{{end}}
+`
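+
+// For illustration, the "op2" clause above renders a registration like this
+// (hypothetical op and type):
+//
+//	addF(simdPackage, "Int32x4.Add", opLen2(ssa.OpAddInt32x4, types.TypeVec128), sys.AMD64)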
+
+// writeSIMDIntrinsics generates the intrinsic mappings and returns them in a
+// buffer, which the caller writes to simdintrinsics.go.
+func writeSIMDIntrinsics(ops []Operation, typeMap simdTypeMap) *bytes.Buffer {
+ t := templateOf(simdIntrinsicsTmpl, "simdintrinsics")
+ buffer := new(bytes.Buffer)
+ buffer.WriteString(generatedHeader)
+
+ if err := t.ExecuteTemplate(buffer, "header", nil); err != nil {
+ panic(fmt.Errorf("failed to execute header template: %w", err))
+ }
+
+ slices.SortFunc(ops, compareOperations)
+
+ for _, op := range ops {
+ if op.NoTypes != nil && *op.NoTypes == "true" {
+ continue
+ }
+ if s, op, err := classifyOp(op); err == nil {
+ if err := t.ExecuteTemplate(buffer, s, op); err != nil {
+ panic(fmt.Errorf("failed to execute template %s for op %s: %w", s, op.Go, err))
+ }
+
+ } else {
+ panic(fmt.Errorf("failed to classify op %v: %w", op.Go, err))
+ }
+ }
+
+ for _, conv := range vConvertFromTypeMap(typeMap) {
+ if err := t.ExecuteTemplate(buffer, "vectorConversion", conv); err != nil {
+ panic(fmt.Errorf("failed to execute vectorConversion template: %w", err))
+ }
+ }
+
+ for _, typ := range typesFromTypeMap(typeMap) {
+ if typ.Type != "mask" {
+ if err := t.ExecuteTemplate(buffer, "loadStore", typ); err != nil {
+ panic(fmt.Errorf("failed to execute loadStore template: %w", err))
+ }
+ }
+ }
+
+ for _, typ := range typesFromTypeMap(typeMap) {
+ if typ.MaskedLoadStoreFilter() {
+ if err := t.ExecuteTemplate(buffer, "maskedLoadStore", typ); err != nil {
+ panic(fmt.Errorf("failed to execute maskedLoadStore template: %w", err))
+ }
+ }
+ }
+
+ for _, mask := range masksFromTypeMap(typeMap) {
+ if err := t.ExecuteTemplate(buffer, "mask", mask); err != nil {
+ panic(fmt.Errorf("failed to execute mask template: %w", err))
+ }
+ }
+
+ if err := t.ExecuteTemplate(buffer, "footer", nil); err != nil {
+ panic(fmt.Errorf("failed to execute footer template: %w", err))
+ }
+
+ return buffer
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strings"
+)
+
+const simdMachineOpsTmpl = `
+package main
+
+func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vfpkv, w11, w21, w2k, wkw, w2kw, w2kk, w31, w3kw, wgpw, wgp, wfpw, wfpkw regInfo) []opData {
+ return []opData{
+{{- range .OpsData }}
+ {name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
+{{- end }}
+{{- range .OpsDataImm }}
+ {name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
+{{- end }}
+ }
+}
+`
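+
+// For illustration, one entry rendered by the template above might look like
+// this (hypothetical machine op):
+//
+//	{name: "VPADDD128", argLength: 2, reg: v21, asm: "VPADDD", commutative: true, typ: "Vec128", resultInArg0: false},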
+
+// writeSIMDMachineOps generates the machine ops table and returns it in a
+// buffer, which the caller writes to simdAMD64ops.go.
+func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
+ t := templateOf(simdMachineOpsTmpl, "simdAMD64Ops")
+ buffer := new(bytes.Buffer)
+ buffer.WriteString(generatedHeader)
+
+ type opData struct {
+ OpName string
+ Asm string
+ OpInLen int
+ RegInfo string
+ Comm bool
+ Type string
+ ResultInArg0 bool
+ }
+ type machineOpsData struct {
+ OpsData []opData
+ OpsDataImm []opData
+ }
+ seen := map[string]struct{}{}
+ regInfoSet := map[string]bool{
+ "v11": true, "v21": true, "v2k": true, "v2kv": true, "v2kk": true, "vkv": true, "v31": true, "v3kv": true, "vgpv": true, "vgp": true, "vfpv": true, "vfpkv": true,
+ "w11": true, "w21": true, "w2k": true, "w2kw": true, "w2kk": true, "wkw": true, "w31": true, "w3kw": true, "wgpw": true, "wgp": true, "wfpw": true, "wfpkw": true}
+ opsData := make([]opData, 0)
+ opsDataImm := make([]opData, 0)
+ for _, op := range ops {
+ shapeIn, shapeOut, maskType, _, gOp := op.shape()
+ asm := machineOpName(maskType, gOp)
+
+ // TODO: all our masked operations are currently zeroing; we need to generate machine ops with merging masks,
+ // maybe by copying each one here with a "Merging" name suffix. The rewrite rules will need them.
+ if _, ok := seen[asm]; ok {
+ continue
+ }
+ seen[asm] = struct{}{}
+ regInfo, err := op.regShape()
+ if err != nil {
+ panic(err)
+ }
+ idx, err := checkVecAsScalar(op)
+ if err != nil {
+ panic(err)
+ }
+ if idx != -1 {
+ if regInfo == "v21" {
+ regInfo = "vfpv"
+ } else if regInfo == "v2kv" {
+ regInfo = "vfpkv"
+ } else {
+ panic(fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regInfo, op))
+ }
+ }
+ // Makes AVX512 operations use upper registers
+ if strings.Contains(op.CPUFeature, "AVX512") {
+ regInfo = strings.ReplaceAll(regInfo, "v", "w")
+ }
+ if _, ok := regInfoSet[regInfo]; !ok {
+ panic(fmt.Errorf("unsupported register constraint, please update the template and AMD64Ops.go: %s. Op is %s", regInfo, op))
+ }
+ var outType string
+ if shapeOut == OneVregOut || shapeOut == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
+ // If class overwrite is happening, that's not really a mask but a vreg.
+ outType = fmt.Sprintf("Vec%d", *gOp.Out[0].Bits)
+ } else if shapeOut == OneGregOut {
+ outType = gOp.GoType() // this is a straight Go type, not a VecNNN type
+ } else if shapeOut == OneKmaskOut {
+ outType = "Mask"
+ } else {
+ panic(fmt.Errorf("simdgen does not recognize this output shape: %d", shapeOut))
+ }
+ resultInArg0 := false
+ if shapeOut == OneVregOutAtIn {
+ resultInArg0 = true
+ }
+ if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
+ opsDataImm = append(opsDataImm, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
+ } else {
+ opsData = append(opsData, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
+ }
+ }
+ sort.Slice(opsData, func(i, j int) bool {
+ return compareNatural(opsData[i].OpName, opsData[j].OpName) < 0
+ })
+ sort.Slice(opsDataImm, func(i, j int) bool {
+ return compareNatural(opsDataImm[i].OpName, opsDataImm[j].OpName) < 0
+ })
+ err := t.Execute(buffer, machineOpsData{opsData, opsDataImm})
+ if err != nil {
+ panic(fmt.Errorf("failed to execute template: %w", err))
+ }
+
+ return buffer
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "bytes"
+ "cmp"
+ "fmt"
+ "maps"
+ "slices"
+ "sort"
+ "strings"
+)
+
+type simdType struct {
+ Name string // The go type name of this simd type, for example Int32x4.
+ Lanes int // The number of elements in this vector/mask.
+ Base string // The element's type; for Int32x4 it is int32.
+ Fields string // The struct fields, already formatted.
+ Type string // Either "mask" or "vreg"
+ VectorCounterpart string // For mask use only: just replacing the "Mask" in [simdType.Name] with "Int"
+ ReshapedVectorWithAndOr string // For mask use only: vector AND and OR are only available in shapes with element width 32.
+ Size int // The size of the vector type, in bits
+}
+
+func (x simdType) ElemBits() int {
+ return x.Size / x.Lanes
+}
+
+// LanesContainer returns the smallest int/uint bit size that is
+// large enough to hold one bit for each lane. E.g., Mask32x4
+// is 4 lanes, and a uint8 is the smallest uint that has 4 bits.
+func (x simdType) LanesContainer() int {
+ if x.Lanes > 64 {
+ panic("too many lanes")
+ }
+ if x.Lanes > 32 {
+ return 64
+ }
+ if x.Lanes > 16 {
+ return 32
+ }
+ if x.Lanes > 8 {
+ return 16
+ }
+ return 8
+}
+
+// MaskedLoadStoreFilter encodes which simd types currently get masked
+// loads/stores generated. It is used in two places, which forces them
+// to stay coordinated.
+func (x simdType) MaskedLoadStoreFilter() bool {
+ return x.Size == 512 || (x.ElemBits() >= 32 && x.Type != "mask")
+}
+
+func (x simdType) IntelSizeSuffix() string {
+ switch x.ElemBits() {
+ case 8:
+ return "B"
+ case 16:
+ return "W"
+ case 32:
+ return "D"
+ case 64:
+ return "Q"
+ }
+ panic("oops")
+}
+
+func (x simdType) MaskedLoadDoc() string {
+ if x.Size == 512 || x.ElemBits() < 32 {
+ return fmt.Sprintf("// Asm: VMOVDQU%d.Z, CPU Feature: AVX512", x.ElemBits())
+ } else {
+ return fmt.Sprintf("// Asm: VMASKMOV%s, CPU Feature: AVX2", x.IntelSizeSuffix())
+ }
+}
+
+func (x simdType) MaskedStoreDoc() string {
+ if x.Size == 512 || x.ElemBits() < 32 {
+ return fmt.Sprintf("// Asm: VMOVDQU%d, CPU Feature: AVX512", x.ElemBits())
+ } else {
+ return fmt.Sprintf("// Asm: VMASKMOV%s, CPU Feature: AVX2", x.IntelSizeSuffix())
+ }
+}
+
+func compareSimdTypes(x, y simdType) int {
+ // "vreg" then "mask"
+ if c := -compareNatural(x.Type, y.Type); c != 0 {
+ return c
+ }
+ // want "flo" < "int" < "uin" (and then 8 < 16 < 32 < 64),
+ // not "int16" < "int32" < "int64" < "int8")
+ // so limit comparison to first 3 bytes in string.
+ if c := compareNatural(x.Base[:3], y.Base[:3]); c != 0 {
+ return c
+ }
+ // base type size, 8 < 16 < 32 < 64
+ if c := x.ElemBits() - y.ElemBits(); c != 0 {
+ return c
+ }
+ // vector size last
+ return x.Size - y.Size
+}
+
+type simdTypeMap map[int][]simdType
+
+type simdTypePair struct {
+ Tsrc simdType
+ Tdst simdType
+}
+
+func compareSimdTypePairs(x, y simdTypePair) int {
+ c := compareSimdTypes(x.Tsrc, y.Tsrc)
+ if c != 0 {
+ return c
+ }
+ return compareSimdTypes(x.Tdst, y.Tdst)
+}
+
+const simdPackageHeader = generatedHeader + `
+//go:build goexperiment.simd
+
+package simd
+`
+
+const simdTypesTemplates = `
+{{define "sizeTmpl"}}
+// v{{.}} is a tag type that tells the compiler that this is really {{.}}-bit SIMD
+type v{{.}} struct {
+ _{{.}} struct{}
+}
+{{end}}
+
+{{define "typeTmpl"}}
+// {{.Name}} is a {{.Size}}-bit SIMD vector of {{.Lanes}} {{.Base}}
+type {{.Name}} struct {
+{{.Fields}}
+}
+
+{{end}}
+`
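+
+// For illustration, "typeTmpl" renders a 128-bit int32 vector roughly as
+// follows (the field layout comes from parseSIMDTypes below):
+//
+//	// Int32x4 is a 128-bit SIMD vector of 4 int32
+//	type Int32x4 struct {
+//		int32x4 v128
+//		vals    [4]int32
+//	}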
+
+const simdFeaturesTemplate = `
+import "internal/cpu"
+
+{{range .}}
+{{- if eq .Feature "AVX512"}}
+// Has{{.Feature}} returns whether the CPU supports the AVX512F+CD+BW+DQ+VL features.
+//
+// These five CPU features are bundled together, and no use of AVX-512
+// is allowed unless all of these features are supported together.
+// Nearly every CPU that has shipped with any support for AVX-512 has
+// supported all five of these features.
+{{- else -}}
+// Has{{.Feature}} returns whether the CPU supports the {{.Feature}} feature.
+{{- end}}
+//
+// Has{{.Feature}} is defined on all GOARCHes, but will only return true on
+// GOARCH {{.GoArch}}.
+func Has{{.Feature}}() bool {
+ return cpu.X86.Has{{.Feature}}
+}
+{{end}}
+`
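+
+// For illustration, a non-AVX512 entry such as AVX2 renders as:
+//
+//	// HasAVX2 returns whether the CPU supports the AVX2 feature.
+//	//
+//	// HasAVX2 is defined on all GOARCHes, but will only return true on
+//	// GOARCH amd64.
+//	func HasAVX2() bool {
+//		return cpu.X86.HasAVX2
+//	}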
+
+const simdLoadStoreTemplate = `
+// Len returns the number of elements in a {{.Name}}
+func (x {{.Name}}) Len() int { return {{.Lanes}} }
+
+// Load{{.Name}} loads a {{.Name}} from an array
+//
+//go:noescape
+func Load{{.Name}}(y *[{{.Lanes}}]{{.Base}}) {{.Name}}
+
+// Store stores a {{.Name}} to an array
+//
+//go:noescape
+func (x {{.Name}}) Store(y *[{{.Lanes}}]{{.Base}})
+`
+
+const simdMaskFromBitsTemplate = `
+// Load{{.Name}}FromBits constructs a {{.Name}} from a bitmap, where 1 means set for the indexed element, 0 means unset.
+// Only the lower {{.Lanes}} bits of y are used.
+//
+// CPU Features: AVX512
+//go:noescape
+func Load{{.Name}}FromBits(y *uint64) {{.Name}}
+
+// StoreToBits stores a {{.Name}} as a bitmap, where 1 means set for the indexed element, 0 means unset.
+// Only the lower {{.Lanes}} bits of *y are written.
+//
+// CPU Features: AVX512
+//go:noescape
+func (x {{.Name}}) StoreToBits(y *uint64)
+`
+
+const simdMaskFromValTemplate = `
+// {{.Name}}FromBits constructs a {{.Name}} from a bitmap value, where 1 means set for the indexed element, 0 means unset.
+// Only the lower {{.Lanes}} bits of y are used.
+//
+// Asm: KMOV{{.IntelSizeSuffix}}, CPU Feature: AVX512
+func {{.Name}}FromBits(y uint{{.LanesContainer}}) {{.Name}}
+
+// ToBits constructs a bitmap from a {{.Name}}, where 1 means set for the indexed element, 0 means unset.
+// Only the lower {{.Lanes}} bits of the result are set.
+//
+// Asm: KMOV{{.IntelSizeSuffix}}, CPU Features: AVX512
+func (x {{.Name}}) ToBits() uint{{.LanesContainer}}
+`
+
+const simdMaskedLoadStoreTemplate = `
+// LoadMasked{{.Name}} loads a {{.Name}} from an array,
+// at those elements enabled by mask
+//
+{{.MaskedLoadDoc}}
+//
+//go:noescape
+func LoadMasked{{.Name}}(y *[{{.Lanes}}]{{.Base}}, mask Mask{{.ElemBits}}x{{.Lanes}}) {{.Name}}
+
+// StoreMasked stores a {{.Name}} to an array,
+// at those elements enabled by mask
+//
+{{.MaskedStoreDoc}}
+//
+//go:noescape
+func (x {{.Name}}) StoreMasked(y *[{{.Lanes}}]{{.Base}}, mask Mask{{.ElemBits}}x{{.Lanes}})
+`
+
+const simdStubsTmpl = `
+{{define "op1"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op0NameAndType "x"}}) {{.Go}}() {{.GoType}}
+{{end}}
+
+{{define "op2"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}) {{.GoType}}
+{{end}}
+
+{{define "op2_21"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}) {{.GoType}}
+{{end}}
+
+{{define "op2_21Type1"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}) {{.GoType}}
+{{end}}
+
+{{define "op3"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
+{{end}}
+
+{{define "op3_31"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op2NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op0NameAndType "z"}}) {{.GoType}}
+{{end}}
+
+{{define "op3_21"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
+{{end}}
+
+{{define "op3_21Type1"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op0NameAndType "y"}}, {{.Op2NameAndType "z"}}) {{.GoType}}
+{{end}}
+
+{{define "op3_231Type1"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.Op0NameAndType "z"}}) {{.GoType}}
+{{end}}
+
+{{define "op2VecAsScalar"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op0NameAndType "x"}}) {{.Go}}(y uint{{(index .In 1).TreatLikeAScalarOfSize}}) {{(index .Out 0).Go}}
+{{end}}
+
+{{define "op3VecAsScalar"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op0NameAndType "x"}}) {{.Go}}(y uint{{(index .In 1).TreatLikeAScalarOfSize}}, {{.Op2NameAndType "z"}}) {{(index .Out 0).Go}}
+{{end}}
+
+{{define "op4"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op0NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op2NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
+{{end}}
+
+{{define "op4_231Type1"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.Op0NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
+{{end}}
+
+{{define "op4_31"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op2NameAndType "x"}}) {{.Go}}({{.Op1NameAndType "y"}}, {{.Op0NameAndType "z"}}, {{.Op3NameAndType "u"}}) {{.GoType}}
+{{end}}
+
+{{define "op1Imm8"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
+//
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8) {{.GoType}}
+{{end}}
+
+{{define "op2Imm8"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
+//
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}) {{.GoType}}
+{{end}}
+
+{{define "op2Imm8_2I"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
+//
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.ImmName}} uint8) {{.GoType}}
+{{end}}
+
+
+{{define "op3Imm8"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
+//
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}, {{.Op3NameAndType "z"}}) {{.GoType}}
+{{end}}
+
+{{define "op3Imm8_2I"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
+//
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.Op2NameAndType "y"}}, {{.ImmName}} uint8, {{.Op3NameAndType "z"}}) {{.GoType}}
+{{end}}
+
+
+{{define "op4Imm8"}}
+{{if .Documentation}}{{.Documentation}}
+//{{end}}
+// {{.ImmName}} results in better performance when it's a constant; a non-constant value will be translated into a jump table.
+//
+// Asm: {{.Asm}}, CPU Feature: {{.CPUFeature}}
+func ({{.Op1NameAndType "x"}}) {{.Go}}({{.ImmName}} uint8, {{.Op2NameAndType "y"}}, {{.Op3NameAndType "z"}}, {{.Op4NameAndType "u"}}) {{.GoType}}
+{{end}}
+
+{{define "vectorConversion"}}
+// As{{.Tdst.Name}} converts from {{.Tsrc.Name}} to {{.Tdst.Name}}
+func (from {{.Tsrc.Name}}) As{{.Tdst.Name}}() (to {{.Tdst.Name}})
+{{end}}
+
+{{define "mask"}}
+// As{{.VectorCounterpart}} converts from {{.Name}} to {{.VectorCounterpart}}
+func (from {{.Name}}) As{{.VectorCounterpart}}() (to {{.VectorCounterpart}})
+
+// As{{.Name}} converts from {{.VectorCounterpart}} to {{.Name}}
+func (from {{.VectorCounterpart}}) As{{.Name}}() (to {{.Name}})
+
+func (x {{.Name}}) And(y {{.Name}}) {{.Name}}
+
+func (x {{.Name}}) Or(y {{.Name}}) {{.Name}}
+{{end}}
+`
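+
+// For illustration, the "op2" stub for a hypothetical Add method on Int32x4
+// renders roughly as:
+//
+//	// Add adds corresponding elements of two vectors.
+//	//
+//	// Asm: VPADDD, CPU Feature: AVX2
+//	func (x Int32x4) Add(y Int32x4) Int32x4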
+
+// parseSIMDTypes groups the go simd types by their vector sizes and returns
+// a map from vector size to the simd types of that size.
+func parseSIMDTypes(ops []Operation) simdTypeMap {
+ // TODO: maybe instead of going over ops, let's try go over types.yaml.
+ ret := map[int][]simdType{}
+ seen := map[string]struct{}{}
+ processArg := func(arg Operand) {
+ if arg.Class == "immediate" || arg.Class == "greg" {
+ // Immediates and general registers are not encoded as vector types.
+ return
+ }
+ if _, ok := seen[*arg.Go]; ok {
+ return
+ }
+ seen[*arg.Go] = struct{}{}
+
+ lanes := *arg.Lanes
+ base := fmt.Sprintf("%s%d", *arg.Base, *arg.ElemBits)
+ tagFieldNameS := fmt.Sprintf("%sx%d", base, lanes)
+ tagFieldS := fmt.Sprintf("%s v%d", tagFieldNameS, *arg.Bits)
+ valFieldS := fmt.Sprintf("vals%s[%d]%s", strings.Repeat(" ", len(tagFieldNameS)-3), lanes, base)
+ fields := fmt.Sprintf("\t%s\n\t%s", tagFieldS, valFieldS)
+ if arg.Class == "mask" {
+ vectorCounterpart := strings.ReplaceAll(*arg.Go, "Mask", "Int")
+ reshapedVectorWithAndOr := fmt.Sprintf("Int32x%d", *arg.Bits/32)
+ ret[*arg.Bits] = append(ret[*arg.Bits], simdType{*arg.Go, lanes, base, fields, arg.Class, vectorCounterpart, reshapedVectorWithAndOr, *arg.Bits})
+ // In case the vector counterpart of a mask is not present, put its vector counterpart typedef into the map as well.
+ if _, ok := seen[vectorCounterpart]; !ok {
+ seen[vectorCounterpart] = struct{}{}
+ ret[*arg.Bits] = append(ret[*arg.Bits], simdType{vectorCounterpart, lanes, base, fields, "vreg", "", "", *arg.Bits})
+ }
+ } else {
+ ret[*arg.Bits] = append(ret[*arg.Bits], simdType{*arg.Go, lanes, base, fields, arg.Class, "", "", *arg.Bits})
+ }
+ }
+ for _, op := range ops {
+ for _, arg := range op.In {
+ processArg(arg)
+ }
+ for _, arg := range op.Out {
+ processArg(arg)
+ }
+ }
+ return ret
+}
+
+func vConvertFromTypeMap(typeMap simdTypeMap) []simdTypePair {
+ v := []simdTypePair{}
+ for _, ts := range typeMap {
+ for i, tsrc := range ts {
+ for j, tdst := range ts {
+ if i != j && tsrc.Type == tdst.Type && tsrc.Type == "vreg" &&
+ tsrc.Lanes > 1 && tdst.Lanes > 1 {
+ v = append(v, simdTypePair{tsrc, tdst})
+ }
+ }
+ }
+ }
+ slices.SortFunc(v, compareSimdTypePairs)
+ return v
+}
+
+func masksFromTypeMap(typeMap simdTypeMap) []simdType {
+ m := []simdType{}
+ for _, ts := range typeMap {
+ for _, tsrc := range ts {
+ if tsrc.Type == "mask" {
+ m = append(m, tsrc)
+ }
+ }
+ }
+ slices.SortFunc(m, compareSimdTypes)
+ return m
+}
+
+func typesFromTypeMap(typeMap simdTypeMap) []simdType {
+ m := []simdType{}
+ for _, ts := range typeMap {
+ for _, tsrc := range ts {
+ if tsrc.Lanes > 1 {
+ m = append(m, tsrc)
+ }
+ }
+ }
+ slices.SortFunc(m, compareSimdTypes)
+ return m
+}
+
+// writeSIMDTypes generates the simd vector types into a bytes.Buffer
+func writeSIMDTypes(typeMap simdTypeMap) *bytes.Buffer {
+ t := templateOf(simdTypesTemplates, "types_amd64")
+ loadStore := templateOf(simdLoadStoreTemplate, "loadstore_amd64")
+ maskedLoadStore := templateOf(simdMaskedLoadStoreTemplate, "maskedloadstore_amd64")
+ maskFromBits := templateOf(simdMaskFromBitsTemplate, "maskFromBits_amd64")
+ maskFromVal := templateOf(simdMaskFromValTemplate, "maskFromVal_amd64")
+
+ buffer := new(bytes.Buffer)
+ buffer.WriteString(simdPackageHeader)
+
+ sizes := make([]int, 0, len(typeMap))
+ for size, types := range typeMap {
+ slices.SortFunc(types, compareSimdTypes)
+ sizes = append(sizes, size)
+ }
+ sort.Ints(sizes)
+
+ for _, size := range sizes {
+ if size <= 64 {
+ // these are scalar
+ continue
+ }
+ if err := t.ExecuteTemplate(buffer, "sizeTmpl", size); err != nil {
+ panic(fmt.Errorf("failed to execute size template for size %d: %w", size, err))
+ }
+ for _, typeDef := range typeMap[size] {
+ if typeDef.Lanes == 1 {
+ continue
+ }
+ if err := t.ExecuteTemplate(buffer, "typeTmpl", typeDef); err != nil {
+ panic(fmt.Errorf("failed to execute type template for type %s: %w", typeDef.Name, err))
+ }
+ if typeDef.Type != "mask" {
+ if err := loadStore.ExecuteTemplate(buffer, "loadstore_amd64", typeDef); err != nil {
+ panic(fmt.Errorf("failed to execute loadstore template for type %s: %w", typeDef.Name, err))
+ }
+ // Restrict masked loads/stores to the types selected by MaskedLoadStoreFilter.
+ if typeDef.MaskedLoadStoreFilter() {
+ if err := maskedLoadStore.ExecuteTemplate(buffer, "maskedloadstore_amd64", typeDef); err != nil {
+ panic(fmt.Errorf("failed to execute maskedloadstore template for type %s: %w", typeDef.Name, err))
+ }
+ }
+ } else {
+ if err := maskFromBits.ExecuteTemplate(buffer, "maskFromBits_amd64", typeDef); err != nil {
+ panic(fmt.Errorf("failed to execute maskFromBits template for type %s: %w", typeDef.Name, err))
+ }
+ if err := maskFromVal.ExecuteTemplate(buffer, "maskFromVal_amd64", typeDef); err != nil {
+ panic(fmt.Errorf("failed to execute maskFromVal template for type %s: %w", typeDef.Name, err))
+ }
+ }
+ }
+ }
+
+ return buffer
+}
+
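+// writeSIMDFeatures generates the Has<feature> CPU feature queries for the
+// simd package and returns them in a buffer.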
+func writeSIMDFeatures(ops []Operation) *bytes.Buffer {
+ // Gather all features
+ type featureKey struct {
+ GoArch string
+ Feature string
+ }
+ featureSet := make(map[featureKey]struct{})
+ for _, op := range ops {
+ featureSet[featureKey{op.GoArch, op.CPUFeature}] = struct{}{}
+ }
+ features := slices.SortedFunc(maps.Keys(featureSet), func(a, b featureKey) int {
+ if c := cmp.Compare(a.GoArch, b.GoArch); c != 0 {
+ return c
+ }
+ return compareNatural(a.Feature, b.Feature)
+ })
+
+ // If we ever have the same feature name on more than one GOARCH, we'll have
+ // to be more careful about this.
+ t := templateOf(simdFeaturesTemplate, "features")
+
+ buffer := new(bytes.Buffer)
+ buffer.WriteString(simdPackageHeader)
+
+ if err := t.Execute(buffer, features); err != nil {
+ panic(fmt.Errorf("failed to execute features template: %w", err))
+ }
+
+ return buffer
+}
+
+// writeSIMDStubs generates the simd vector intrinsic stubs and returns them in
+// a buffer, which the caller writes to ops_amd64.go and ops_internal_amd64.go.
+func writeSIMDStubs(ops []Operation, typeMap simdTypeMap) *bytes.Buffer {
+ t := templateOf(simdStubsTmpl, "simdStubs")
+ buffer := new(bytes.Buffer)
+ buffer.WriteString(simdPackageHeader)
+
+ slices.SortFunc(ops, compareOperations)
+
+ for i, op := range ops {
+ if op.NoTypes != nil && *op.NoTypes == "true" {
+ continue
+ }
+ idxVecAsScalar, err := checkVecAsScalar(op)
+ if err != nil {
+ panic(err)
+ }
+ if s, op, err := classifyOp(op); err == nil {
+ if idxVecAsScalar != -1 {
+ if s == "op2" || s == "op3" {
+ s += "VecAsScalar"
+ } else {
+ panic(fmt.Errorf("simdgen only supports op2 or op3 with TreatLikeAScalarOfSize"))
+ }
+ }
+ if i == 0 || op.Go != ops[i-1].Go {
+ fmt.Fprintf(buffer, "\n/* %s */\n", op.Go)
+ }
+ if err := t.ExecuteTemplate(buffer, s, op); err != nil {
+ panic(fmt.Errorf("failed to execute template %s for op %v: %w", s, op, err))
+ }
+ } else {
+ panic(fmt.Errorf("failed to classify op %v: %w", op.Go, err))
+ }
+ }
+
+ vectorConversions := vConvertFromTypeMap(typeMap)
+ for _, conv := range vectorConversions {
+ if err := t.ExecuteTemplate(buffer, "vectorConversion", conv); err != nil {
+ panic(fmt.Errorf("failed to execute vectorConversion template: %w", err))
+ }
+ }
+
+ masks := masksFromTypeMap(typeMap)
+ for _, mask := range masks {
+ if err := t.ExecuteTemplate(buffer, "mask", mask); err != nil {
+ panic(fmt.Errorf("failed to execute mask template for mask %s: %w", mask.Name, err))
+ }
+ }
+
+ return buffer
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "slices"
+ "text/template"
+)
+
+type tplRuleData struct {
+ tplName string // e.g. "sftimm"
+ GoOp string // e.g. "ShiftAllLeft"
+ GoType string // e.g. "Uint32x8"
+ Args string // e.g. "x y"
+ Asm string // e.g. "VPSLLD256"
+ ArgsOut string // e.g. "x y"
+ MaskInConvert string // e.g. "VPMOVVec32x8ToM"
+ MaskOutConvert string // e.g. "VPMOVMToVec32x8"
+}
+
+var (
+ ruleTemplates = template.Must(template.New("simdRules").Parse(`
+{{define "pureVreg"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ArgsOut}})
+{{end}}
+{{define "maskIn"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask))
+{{end}}
+{{define "maskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}}))
+{{end}}
+{{define "maskInMaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask)))
+{{end}}
+{{define "sftimm"}}({{.Asm}} x (MOVQconst [c])) => ({{.Asm}}const [uint8(c)] x)
+{{end}}
+{{define "masksftimm"}}({{.Asm}} x (MOVQconst [c]) mask) => ({{.Asm}}const [uint8(c)] x mask)
+{{end}}
+`))
+)
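+
+// For illustration, "pureVreg" renders a rule like (hypothetical op names)
+//
+//	(AddInt32x4 x y) => (VPADDD128 x y)
+//
+// while "maskIn" wraps the trailing mask argument in a conversion:
+//
+//	(AddMaskedInt32x4 x y mask) => (VPADDDMasked128 x y (VPMOVVec32x4ToM <types.TypeMask> mask))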
+
+// SSA rewrite rules need to appear in most-to-least-specific order; this ordering table enforces that.
+var tmplOrder = map[string]int{
+ "masksftimm": 0,
+ "sftimm": 1,
+ "maskInMaskOut": 2,
+ "maskOut": 3,
+ "maskIn": 4,
+ "pureVreg": 5,
+}
+
+func compareTplRuleData(x, y tplRuleData) int {
+ if c := compareNatural(x.GoOp, y.GoOp); c != 0 {
+ return c
+ }
+ if c := compareNatural(x.GoType, y.GoType); c != 0 {
+ return c
+ }
+ if c := compareNatural(x.Args, y.Args); c != 0 {
+ return c
+ }
+ if x.tplName == y.tplName {
+ return 0
+ }
+ xo, xok := tmplOrder[x.tplName]
+ yo, yok := tmplOrder[y.tplName]
+ if !xok {
+ panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", x.tplName))
+ }
+ if !yok {
+ panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", y.tplName))
+ }
+ return xo - yo
+}
+
+// writeSIMDRules generates the lowering and rewrite rules for ssa and returns
+// them in a buffer, which the caller writes to simdAMD64.rules.
+func writeSIMDRules(ops []Operation) *bytes.Buffer {
+ buffer := new(bytes.Buffer)
+ buffer.WriteString(generatedHeader + "\n")
+
+ var allData []tplRuleData
+
+ for _, opr := range ops {
+ if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" {
+ continue
+ }
+ opInShape, opOutShape, maskType, immType, gOp := opr.shape()
+ asm := machineOpName(maskType, gOp)
+ vregInCnt := len(gOp.In)
+ if maskType == OneMask {
+ vregInCnt--
+ }
+
+ data := tplRuleData{
+ GoOp: gOp.Go,
+ Asm: asm,
+ }
+
+ if vregInCnt == 1 {
+ data.Args = "x"
+ data.ArgsOut = data.Args
+ } else if vregInCnt == 2 {
+ data.Args = "x y"
+ data.ArgsOut = data.Args
+ } else if vregInCnt == 3 {
+ data.Args = "x y z"
+ data.ArgsOut = data.Args
+ } else {
+ panic(fmt.Errorf("simdgen does not support more than 3 vreg in inputs"))
+ }
+ if immType == ConstImm {
+ data.ArgsOut = fmt.Sprintf("[%s] %s", *opr.In[0].Const, data.ArgsOut)
+ } else if immType == VarImm {
+ data.Args = fmt.Sprintf("[a] %s", data.Args)
+ data.ArgsOut = fmt.Sprintf("[a] %s", data.ArgsOut)
+ } else if immType == ConstVarImm {
+ data.Args = fmt.Sprintf("[a] %s", data.Args)
+ data.ArgsOut = fmt.Sprintf("[a+%s] %s", *opr.In[0].Const, data.ArgsOut)
+ }
+
+ goType := func(op Operation) string {
+ if op.OperandOrder != nil {
+ switch *op.OperandOrder {
+ case "21Type1", "231Type1":
+ // Permute uses operand[1] for method receiver.
+ return *op.In[1].Go
+ }
+ }
+ return *op.In[0].Go
+ }
+ var tplName string
+ // If class overwrite is happening, that's not really a mask but a vreg.
+ if opOutShape == OneVregOut || opOutShape == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
+ switch opInShape {
+ case OneImmIn:
+ tplName = "pureVreg"
+ data.GoType = goType(gOp)
+ case PureVregIn:
+ tplName = "pureVreg"
+ data.GoType = goType(gOp)
+ case OneKmaskImmIn:
+ fallthrough
+ case OneKmaskIn:
+ tplName = "maskIn"
+ data.GoType = goType(gOp)
+ rearIdx := len(gOp.In) - 1
+ // Mask is at the end.
+ data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
+ case PureKmaskIn:
+ panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
+ }
+ } else if opOutShape == OneGregOut {
+ tplName = "pureVreg" // TODO this will be wrong
+ data.GoType = goType(gOp)
+ } else {
+ // OneKmaskOut case
+ data.MaskOutConvert = fmt.Sprintf("VPMOVMToVec%dx%d", *gOp.Out[0].ElemBits, *gOp.In[0].Lanes)
+ switch opInShape {
+ case OneImmIn:
+ fallthrough
+ case PureVregIn:
+ tplName = "maskOut"
+ data.GoType = goType(gOp)
+ case OneKmaskImmIn:
+ fallthrough
+ case OneKmaskIn:
+ tplName = "maskInMaskOut"
+ data.GoType = goType(gOp)
+ rearIdx := len(gOp.In) - 1
+ data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
+ case PureKmaskIn:
+ panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
+ }
+ }
+
+ if gOp.SpecialLower != nil {
+ if *gOp.SpecialLower == "sftimm" {
+ if data.GoType[0] == 'I' {
+ // only do these for signed types; it is a duplicate rewrite for unsigned
+ sftImmData := data
+ if tplName == "maskIn" {
+ sftImmData.tplName = "masksftimm"
+ } else {
+ sftImmData.tplName = "sftimm"
+ }
+ allData = append(allData, sftImmData)
+ }
+ } else {
+ panic("simdgen sees unknwon special lower " + *gOp.SpecialLower + ", maybe implement it?")
+ }
+ }
+
+ if tplName == "pureVreg" && data.Args == data.ArgsOut {
+ data.Args = "..."
+ data.ArgsOut = "..."
+ }
+ data.tplName = tplName
+ allData = append(allData, data)
+ }
+
+ slices.SortFunc(allData, compareTplRuleData)
+
+ for _, data := range allData {
+ if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
+ panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err))
+ }
+ }
+
+ return buffer
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "strings"
+ "text/template"
+)
+
+var (
+ ssaTemplates = template.Must(template.New("simdSSA").Parse(`
+{{define "header"}}// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
+
+package amd64
+
+import (
+ "cmd/compile/internal/ssa"
+ "cmd/compile/internal/ssagen"
+ "cmd/internal/obj"
+ "cmd/internal/obj/x86"
+)
+
+func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
+ var p *obj.Prog
+ switch v.Op {{"{"}}{{end}}
+{{define "case"}}
+ case {{.Cases}}:
+ p = {{.Helper}}(s, v)
+{{end}}
+{{define "footer"}}
+ default:
+ // Unknown reg shape
+ return false
+ }
+{{end}}
+{{define "zeroing"}}
+ // Masked operations are always compiled with zeroing.
+ switch v.Op {
+ case {{.}}:
+ x86.ParseSuffix(p, "Z")
+ }
+{{end}}
+{{define "ending"}}
+ return true
+}
+{{end}}`))
+)
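+
+// For illustration, one rendered "case" clause looks like (hypothetical ops):
+//
+//	case ssa.OpAMD64VPADDD128,
+//		ssa.OpAMD64VPADDQ128:
+//		p = simdV21(s, v)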
+
+type tplSSAData struct {
+ Cases string
+ Helper string
+}
+
+// writeSIMDSSA generates the ssa-to-prog lowering code and returns it in a
+// buffer, which the caller writes to simdssa.go.
+func writeSIMDSSA(ops []Operation) *bytes.Buffer {
+ var ZeroingMask []string
+ regInfoKeys := []string{
+ "v11",
+ "v21",
+ "v2k",
+ "v2kv",
+ "v2kk",
+ "vkv",
+ "v31",
+ "v3kv",
+ "v11Imm8",
+ "vkvImm8",
+ "v21Imm8",
+ "v2kImm8",
+ "v2kkImm8",
+ "v31ResultInArg0",
+ "v3kvResultInArg0",
+ "vfpv",
+ "vfpkv",
+ "vgpvImm8",
+ "vgpImm8",
+ "v2kvImm8",
+ }
+ regInfoSet := map[string][]string{}
+ for _, key := range regInfoKeys {
+ regInfoSet[key] = []string{}
+ }
+
+ seen := map[string]struct{}{}
+ allUnseen := make(map[string][]Operation)
+ for _, op := range ops {
+ shapeIn, shapeOut, maskType, _, gOp := op.shape()
+ asm := machineOpName(maskType, gOp)
+
+ if _, ok := seen[asm]; ok {
+ continue
+ }
+ seen[asm] = struct{}{}
+ caseStr := fmt.Sprintf("ssa.OpAMD64%s", asm)
+ if shapeIn == OneKmaskIn || shapeIn == OneKmaskImmIn {
+ if gOp.Zeroing == nil {
+ ZeroingMask = append(ZeroingMask, caseStr)
+ }
+ }
+ regShape, err := op.regShape()
+ if err != nil {
+ panic(err)
+ }
+ if shapeOut == OneVregOutAtIn {
+ regShape += "ResultInArg0"
+ }
+ if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
+ regShape += "Imm8"
+ }
+ idx, err := checkVecAsScalar(op)
+ if err != nil {
+ panic(err)
+ }
+ if idx != -1 {
+ if regShape == "v21" {
+ regShape = "vfpv"
+ } else if regShape == "v2kv" {
+ regShape = "vfpkv"
+ } else {
+ panic(fmt.Errorf("simdgen does not recognize uses of treatLikeAScalarOfSize with op regShape %s in op: %s", regShape, op))
+ }
+ }
+ if _, ok := regInfoSet[regShape]; !ok {
+ allUnseen[regShape] = append(allUnseen[regShape], op)
+ }
+ regInfoSet[regShape] = append(regInfoSet[regShape], caseStr)
+ }
+ if len(allUnseen) != 0 {
+ panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v", allUnseen))
+ }
+
+ buffer := new(bytes.Buffer)
+
+ if err := ssaTemplates.ExecuteTemplate(buffer, "header", nil); err != nil {
+ panic(fmt.Errorf("failed to execute header template: %w", err))
+ }
+
+ for _, regShape := range regInfoKeys {
+ // Stable traversal of regInfoSet
+ cases := regInfoSet[regShape]
+ if len(cases) == 0 {
+ continue
+ }
+ data := tplSSAData{
+ Cases: strings.Join(cases, ",\n\t\t"),
+ Helper: "simd" + capitalizeFirst(regShape),
+ }
+ if err := ssaTemplates.ExecuteTemplate(buffer, "case", data); err != nil {
+ panic(fmt.Errorf("failed to execute case template for %s: %w", regShape, err))
+ }
+ }
+
+ if err := ssaTemplates.ExecuteTemplate(buffer, "footer", nil); err != nil {
+ panic(fmt.Errorf("failed to execute footer template: %w", err))
+ }
+
+ if len(ZeroingMask) != 0 {
+ if err := ssaTemplates.ExecuteTemplate(buffer, "zeroing", strings.Join(ZeroingMask, ",\n\t\t")); err != nil {
+ panic(fmt.Errorf("failed to execute footer template: %w", err))
+ }
+ }
+
+ if err := ssaTemplates.ExecuteTemplate(buffer, "ending", nil); err != nil {
+ panic(fmt.Errorf("failed to execute footer template: %w", err))
+ }
+
+ return buffer
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "go/format"
+ "log"
+ "os"
+ "path/filepath"
+ "reflect"
+ "slices"
+ "sort"
+ "strings"
+ "text/template"
+ "unicode"
+)
+
+func templateOf(temp, name string) *template.Template {
+ t, err := template.New(name).Parse(temp)
+ if err != nil {
+ panic(fmt.Errorf("failed to parse template %s: %w", name, err))
+ }
+ return t
+}
+
+func createPath(goroot string, file string) (*os.File, error) {
+ fp := filepath.Join(goroot, file)
+ dir := filepath.Dir(fp)
+ err := os.MkdirAll(dir, 0755)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create directory %s: %w", dir, err)
+ }
+ f, err := os.Create(fp)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create file %s: %w", fp, err)
+ }
+ return f, nil
+}
+
+func formatWriteAndClose(out *bytes.Buffer, goroot string, file string) {
+ b, err := format.Source(out.Bytes())
+ if err != nil {
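+ // The error is printed before and after the numbered listing so that it
+ // is visible at either end of long output.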
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ panic(err)
+ } else {
+ writeAndClose(b, goroot, file)
+ }
+}
+
+func writeAndClose(b []byte, goroot string, file string) {
+ ofile, err := createPath(goroot, file)
+ if err != nil {
+ panic(err)
+ }
+ ofile.Write(b)
+ ofile.Close()
+}
+
+// numberLines takes a slice of bytes, and returns a string where each line
+// is numbered, starting from 1.
+func numberLines(data []byte) string {
+ var buf bytes.Buffer
+ r := bytes.NewReader(data)
+ s := bufio.NewScanner(r)
+ for i := 1; s.Scan(); i++ {
+ fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
+ }
+ return buf.String()
+}
+
+type inShape uint8
+type outShape uint8
+type maskShape uint8
+type immShape uint8
+
+const (
+ InvalidIn inShape = iota
+ PureVregIn // vector register input only
+ OneKmaskIn // vector and kmask input
+ OneImmIn // vector and immediate input
+ OneKmaskImmIn // vector, kmask, and immediate inputs
+ PureKmaskIn // only mask inputs.
+)
+
+const (
+ InvalidOut outShape = iota
+ NoOut // no output
+ OneVregOut // (one) vector register output
+ OneGregOut // (one) general register output
+ OneKmaskOut // mask output
+ OneVregOutAtIn // the first input is also the output
+)
+
+const (
+ InvalidMask maskShape = iota
+ NoMask // no mask
+ OneMask // with mask (K1 to K7)
+ AllMasks // a K mask instruction (K0-K7)
+)
+
+const (
+ InvalidImm immShape = iota
+ NoImm // no immediate
+ ConstImm // const only immediate
+ VarImm // pure imm argument provided by the users
+ ConstVarImm // a combination of user arg and const
+)
+
+// shape returns several enums describing the shape of the operation,
+// and a modified version of the op:
+//
+// opNoImm is op with the immediate operand removed from its inputs.
+//
+// This function does not modify op.
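+//
+// For example (a sketch): a masked compare-with-immediate op, whose inputs
+// are an imm8, two vregs, and one mask, and whose output is a mask,
+// classifies as shapeIn=OneKmaskImmIn, shapeOut=OneKmaskOut, and
+// maskType=OneMask; if the immediate is user-provided (ImmOffset set,
+// Const unset), immType=VarImm.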
+func (op *Operation) shape() (shapeIn inShape, shapeOut outShape, maskType maskShape, immType immShape,
+ opNoImm Operation) {
+ if len(op.Out) > 1 {
+ panic(fmt.Errorf("simdgen only supports 1 output: %s", op))
+ }
+ var outputReg int
+ if len(op.Out) == 1 {
+ outputReg = op.Out[0].AsmPos
+ if op.Out[0].Class == "vreg" {
+ shapeOut = OneVregOut
+ } else if op.Out[0].Class == "greg" {
+ shapeOut = OneGregOut
+ } else if op.Out[0].Class == "mask" {
+ shapeOut = OneKmaskOut
+ } else {
+ panic(fmt.Errorf("simdgen only supports output of class vreg or mask: %s", op))
+ }
+ } else {
+ shapeOut = NoOut
+ // TODO: are these only Load/Stores?
+ // We support Load and Store manually; are those enough?
+ panic(fmt.Errorf("simdgen requires exactly 1 output: %s", op))
+ }
+ hasImm := false
+ maskCount := 0
+ hasVreg := false
+ for _, in := range op.In {
+ if in.AsmPos == outputReg {
+ if shapeOut != OneVregOutAtIn && in.AsmPos == 0 && in.Class == "vreg" {
+ shapeOut = OneVregOutAtIn
+ } else {
+ panic(fmt.Errorf("simdgen only support output and input sharing the same position case of \"the first input is vreg and the only output\": %s", op))
+ }
+ }
+ if in.Class == "immediate" {
+ // A manual check on XED data found that AMD64 SIMD instructions have
+ // at most 1 immediate, so we don't need to check that here.
+ if *in.Bits != 8 {
+ panic(fmt.Errorf("simdgen only supports immediates of 8 bits: %s", op))
+ }
+ hasImm = true
+ } else if in.Class == "mask" {
+ maskCount++
+ } else {
+ hasVreg = true
+ }
+ }
+ opNoImm = *op
+
+ removeImm := func(o *Operation) {
+ o.In = o.In[1:]
+ }
+ if hasImm {
+ removeImm(&opNoImm)
+ if op.In[0].Const != nil {
+ if op.In[0].ImmOffset != nil {
+ immType = ConstVarImm
+ } else {
+ immType = ConstImm
+ }
+ } else if op.In[0].ImmOffset != nil {
+ immType = VarImm
+ } else {
+ panic(fmt.Errorf("simdgen requires imm to have at least one of ImmOffset or Const set: %s", op))
+ }
+ } else {
+ immType = NoImm
+ }
+ if maskCount == 0 {
+ maskType = NoMask
+ } else {
+ maskType = OneMask
+ }
+ checkPureMask := func() bool {
+ if hasImm {
+ panic(fmt.Errorf("simdgen does not support immediates in pure mask operations: %s", op))
+ }
+ if hasVreg {
+ panic(fmt.Errorf("simdgen does not support more than 1 masks in non-pure mask operations: %s", op))
+ }
+ return false
+ }
+ if !hasImm && maskCount == 0 {
+ shapeIn = PureVregIn
+ } else if !hasImm && maskCount > 0 {
+ if maskCount == 1 {
+ shapeIn = OneKmaskIn
+ } else {
+ if checkPureMask() {
+ return
+ }
+ shapeIn = PureKmaskIn
+ maskType = AllMasks
+ }
+ } else if hasImm && maskCount == 0 {
+ shapeIn = OneImmIn
+ } else {
+ if maskCount == 1 {
+ shapeIn = OneKmaskImmIn
+ } else {
+ checkPureMask()
+ return
+ }
+ }
+ return
+}
+
+// regShape returns a string representation of the register shape.
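+// For example, two vreg inputs and one vreg output abbreviate to "v21",
+// while two vreg inputs plus a mask input and a vreg output yield "v2kv".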
+func (op *Operation) regShape() (string, error) {
+ _, _, _, _, gOp := op.shape()
+ var regInfo string
+ var vRegInCnt, gRegInCnt, kMaskInCnt, vRegOutCnt, gRegOutCnt, kMaskOutCnt int
+ for _, in := range gOp.In {
+ if in.Class == "vreg" {
+ vRegInCnt++
+ } else if in.Class == "greg" {
+ gRegInCnt++
+ } else if in.Class == "mask" {
+ kMaskInCnt++
+ }
+ }
+ for _, out := range gOp.Out {
+ // If class overwrite is happening, that's not really a mask but a vreg.
+ if out.Class == "vreg" || out.OverwriteClass != nil {
+ vRegOutCnt++
+ } else if out.Class == "greg" {
+ gRegOutCnt++
+ } else if out.Class == "mask" {
+ kMaskOutCnt++
+ }
+ }
+ var inRegs, inMasks, outRegs, outMasks string
+
+ rmAbbrev := func(s string, i int) string {
+ if i == 0 {
+ return ""
+ }
+ if i == 1 {
+ return s
+ }
+ return fmt.Sprintf("%s%d", s, i)
+
+ }
+
+ inRegs = rmAbbrev("v", vRegInCnt)
+ inRegs += rmAbbrev("gp", gRegInCnt)
+ inMasks = rmAbbrev("k", kMaskInCnt)
+
+ outRegs = rmAbbrev("v", vRegOutCnt)
+ outRegs += rmAbbrev("gp", gRegOutCnt)
+ outMasks = rmAbbrev("k", kMaskOutCnt)
+
+ if kMaskInCnt == 0 && kMaskOutCnt == 0 && gRegInCnt == 0 && gRegOutCnt == 0 {
+ // For pure v we can abbreviate it as v%d%d.
+ regInfo = fmt.Sprintf("v%d%d", vRegInCnt, vRegOutCnt)
+ } else if kMaskInCnt == 0 && kMaskOutCnt == 0 {
+ regInfo = fmt.Sprintf("%s%s", inRegs, outRegs)
+ } else {
+ regInfo = fmt.Sprintf("%s%s%s%s", inRegs, inMasks, outRegs, outMasks)
+ }
+ return regInfo, nil
+}
+
+// sortOperand sorts op.In by putting immediates first, then vregs, with masks last.
+// TODO: verify that this is a safe assumption about the prog structure.
+// From observation, in asm the immediates always come first and the masks
+// always come last, with vregs in between.
+func (op *Operation) sortOperand() {
+ priority := map[string]int{"immediate": 0, "vreg": 1, "greg": 1, "mask": 2}
+ sort.SliceStable(op.In, func(i, j int) bool {
+ pi := priority[op.In[i].Class]
+ pj := priority[op.In[j].Class]
+ if pi != pj {
+ return pi < pj
+ }
+ return op.In[i].AsmPos < op.In[j].AsmPos
+ })
+}
+
+// goNormalType returns the Go type name for the result of an Op that
+// does not return a vector, i.e., that returns a result in a general
+// register. Currently there's only one family of Ops in Go's simd library
+// that does this (GetElem), and so this is specialized to work for that,
+// but the problem (mismatch between hardware register width and Go type
+// width) seems likely to recur if there are any other cases.
+func (op Operation) goNormalType() string {
+ if op.Go == "GetElem" {
+ // GetElem returns an element of the vector into a general register
+ // but as far as the hardware is concerned, that result is either 32
+ // or 64 bits wide, no matter what the vector element width is.
+ // This is not "wrong" but it is not the right answer for Go source code.
+ // To get the Go type right, combine the base type ("int", "uint", "float"),
+ // with the input vector element width in bits (8,16,32,64).
+
+ at := 0 // proper value of at depends on whether immediate was stripped or not
+ if op.In[at].Class == "immediate" {
+ at++
+ }
+ return fmt.Sprintf("%s%d", *op.Out[0].Base, *op.In[at].ElemBits)
+ }
+ panic(fmt.Errorf("Implement goNormalType for %v", op))
+}
+
+// SSAType returns the string for the type reference in SSA generation,
+// for example in the intrinsics generating template.
+func (op Operation) SSAType() string {
+ if op.Out[0].Class == "greg" {
+ return fmt.Sprintf("types.Types[types.T%s]", strings.ToUpper(op.goNormalType()))
+ }
+ return fmt.Sprintf("types.TypeVec%d", *op.Out[0].Bits)
+}
+
+// GoType returns the Go type returned by this operation (relative to the simd package),
+// for example "int32" or "Int8x16". This is used in a template.
+func (op Operation) GoType() string {
+ if op.Out[0].Class == "greg" {
+ return op.goNormalType()
+ }
+ return *op.Out[0].Go
+}
+
+// ImmName returns the name to use for an operation's immediate operand.
+// This can be overridden in the yaml with "name" on an operand;
+// otherwise, for now, it is "constant".
+func (op Operation) ImmName() string {
+ return op.Op0Name("constant")
+}
+
+func (o Operand) OpName(s string) string {
+ if n := o.Name; n != nil {
+ return *n
+ }
+ if o.Class == "mask" {
+ return "mask"
+ }
+ return s
+}
+
+func (o Operand) OpNameAndType(s string) string {
+ return o.OpName(s) + " " + *o.Go
+}
+
+// GoExported returns [Go] with first character capitalized.
+func (op Operation) GoExported() string {
+ return capitalizeFirst(op.Go)
+}
+
+// DocumentationExported returns [Documentation] with method name capitalized.
+func (op Operation) DocumentationExported() string {
+ return strings.ReplaceAll(op.Documentation, op.Go, op.GoExported())
+}
+
+// Op0Name returns the name to use for operand 0;
+// if no name is specified, the parameter is used as the default.
+func (op Operation) Op0Name(s string) string {
+ return op.In[0].OpName(s)
+}
+
+// Op1Name returns the name to use for operand 1;
+// if no name is specified, the parameter is used as the default.
+func (op Operation) Op1Name(s string) string {
+ return op.In[1].OpName(s)
+}
+
+// Op2Name returns the name to use for operand 2;
+// if no name is specified, the parameter is used as the default.
+func (op Operation) Op2Name(s string) string {
+ return op.In[2].OpName(s)
+}
+
+// Op3Name returns the name to use for operand 3;
+// if no name is specified, the parameter is used as the default.
+func (op Operation) Op3Name(s string) string {
+ return op.In[3].OpName(s)
+}
+
+// Op0NameAndType returns the name and type to use for
+// operand 0; if no name is provided, the parameter
+// value is used as the default.
+func (op Operation) Op0NameAndType(s string) string {
+ return op.In[0].OpNameAndType(s)
+}
+
+// Op1NameAndType returns the name and type to use for
+// operand 1; if no name is provided, the parameter
+// value is used as the default.
+func (op Operation) Op1NameAndType(s string) string {
+ return op.In[1].OpNameAndType(s)
+}
+
+// Op2NameAndType returns the name and type to use for
+// operand 2; if no name is provided, the parameter
+// value is used as the default.
+func (op Operation) Op2NameAndType(s string) string {
+ return op.In[2].OpNameAndType(s)
+}
+
+// Op3NameAndType returns the name and type to use for
+// operand 3; if no name is provided, the parameter
+// value is used as the default.
+func (op Operation) Op3NameAndType(s string) string {
+ return op.In[3].OpNameAndType(s)
+}
+
+// Op4NameAndType returns the name and type to use for
+// operand 4; if no name is provided, the parameter
+// value is used as the default.
+func (op Operation) Op4NameAndType(s string) string {
+ return op.In[4].OpNameAndType(s)
+}
+
+var immClasses []string = []string{"BAD0Imm", "BAD1Imm", "op1Imm8", "op2Imm8", "op3Imm8", "op4Imm8"}
+var classes []string = []string{"BAD0", "op1", "op2", "op3", "op4"}
+
+// classifyOp returns a classification string, a possibly modified operation, and
+// possibly an error, based on the stub and intrinsic shape of the operation.
+// The classification string matches the regular expression "op[1234](Imm8)?(_<order>)?",
+// where the "<order>" suffix comes from an optional operandOrder attached to the
+// Operation in its input yaml.
+// The classification string is used to select a template, or a clause of a template,
+// for the intrinsics declaration and the ssagen intrinsics glue code in the compiler.
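+// For example, a two-vreg operation is classified as "op2"; with a variable
+// immediate appended (len(op.In) == 3) it becomes "op2Imm8", and an
+// operandOrder of "21" would yield "op2Imm8_21".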
+func classifyOp(op Operation) (string, Operation, error) {
+ _, _, _, immType, gOp := op.shape()
+
+ var class string
+
+ if immType == VarImm || immType == ConstVarImm {
+ switch l := len(op.In); l {
+ case 1:
+ return "", op, fmt.Errorf("simdgen does not recognize this operation of only immediate input: %s", op)
+ case 2, 3, 4, 5:
+ class = immClasses[l]
+ default:
+ return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
+ }
+ if order := op.OperandOrder; order != nil {
+ class += "_" + *order
+ }
+ return class, op, nil
+ } else {
+ switch l := len(gOp.In); l {
+ case 1, 2, 3, 4:
+ class = classes[l]
+ default:
+ return "", op, fmt.Errorf("simdgen does not recognize this operation of input length %d: %s", len(op.In), op)
+ }
+ if order := op.OperandOrder; order != nil {
+ class += "_" + *order
+ }
+ return class, gOp, nil
+ }
+}
+
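+// checkVecAsScalar returns the index of the input (at most one) that has
+// TreatLikeAScalarOfSize set, or -1 if there is none. It reports an error
+// unless that input is the 2nd argument and the size is 8, 16, 32, or 64.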
+func checkVecAsScalar(op Operation) (idx int, err error) {
+ idx = -1
+ sSize := 0
+ for i, o := range op.In {
+ if o.TreatLikeAScalarOfSize != nil {
+ if idx == -1 {
+ idx = i
+ sSize = *o.TreatLikeAScalarOfSize
+ } else {
+ err = fmt.Errorf("simdgen only supports one TreatLikeAScalarOfSize in the arg list: %s", op)
+ return
+ }
+ }
+ }
+ if idx >= 0 {
+ if idx != 1 {
+ err = fmt.Errorf("simdgen only supports TreatLikeAScalarOfSize at the 2nd arg of the arg list: %s", op)
+ return
+ }
+ if sSize != 8 && sSize != 16 && sSize != 32 && sSize != 64 {
+ err = fmt.Errorf("simdgen does not recognize this uint size: %d, %s", sSize, op)
+ return
+ }
+ }
+ return
+}
+
+// dedup removes duplicate operations, comparing them at the full structure level.
+func dedup(ops []Operation) (deduped []Operation) {
+ for _, op := range ops {
+ seen := false
+ for _, dop := range deduped {
+ if reflect.DeepEqual(op, dop) {
+ seen = true
+ break
+ }
+ }
+ if !seen {
+ deduped = append(deduped, op)
+ }
+ }
+ return
+}
+
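+// GenericName returns the name used for the generic ssa op: the Go method
+// name concatenated with the Go type of the operand that serves as the
+// method receiver, e.g. "AddInt8x16".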
+func (op Operation) GenericName() string {
+ if op.OperandOrder != nil {
+ switch *op.OperandOrder {
+ case "21Type1", "231Type1":
+ // Permute uses operand[1] for method receiver.
+ return op.Go + *op.In[1].Go
+ }
+ }
+ if op.In[0].Class == "immediate" {
+ return op.Go + *op.In[1].Go
+ }
+ return op.Go + *op.In[0].Go
+}
+
+// dedupGodef dedups operations at the [Op.Go]+[*Op.In[0].Go] level.
+// Deduping means picking the least advanced architecture that satisfies
+// the requirement; AVX512 is least preferred.
+// If FlagReportDup is set, it instead reports the duplicates to the console.
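+// For example, if both an AVX and an AVX512 definition produce the same
+// generic name (say, AddInt8x16), the AVX definition is the one kept.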
+func dedupGodef(ops []Operation) ([]Operation, error) {
+ seen := map[string][]Operation{}
+ for _, op := range ops {
+ _, _, _, _, gOp := op.shape()
+
+ gN := gOp.GenericName()
+ seen[gN] = append(seen[gN], op)
+ }
+ if *FlagReportDup {
+ for gName, dup := range seen {
+ if len(dup) > 1 {
+ log.Printf("Duplicate for %s:\n", gName)
+ for _, op := range dup {
+ log.Printf("%s\n", op)
+ }
+ }
+ }
+ return ops, nil
+ }
+ isAVX512 := func(op Operation) bool {
+ return strings.Contains(op.CPUFeature, "AVX512")
+ }
+ deduped := []Operation{}
+ for _, dup := range seen {
+ if len(dup) > 1 {
+ slices.SortFunc(dup, func(i, j Operation) int {
+ // Put non-AVX512 candidates at the beginning
+ if !isAVX512(i) && isAVX512(j) {
+ return -1
+ }
+ if isAVX512(i) && !isAVX512(j) {
+ return 1
+ }
+ return strings.Compare(i.CPUFeature, j.CPUFeature)
+ })
+ }
+ deduped = append(deduped, dup[0])
+ }
+ slices.SortFunc(deduped, compareOperations)
+ return deduped, nil
+}
+
+// copyConstImm copies op.ConstImm to op.In[0].Const.
+// This is a hack to reduce the size of the defs we need for const-immediate operations.
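+// For example, the Greater category sets constImm 14; porting copies "14"
+// into the immediate operand's Const so each asm def need not repeat it.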
+func copyConstImm(ops []Operation) error {
+ for _, op := range ops {
+ if op.ConstImm == nil {
+ continue
+ }
+ _, _, _, immType, _ := op.shape()
+
+ if immType == ConstImm || immType == ConstVarImm {
+ op.In[0].Const = op.ConstImm
+ }
+ // Otherwise, just don't port it. For example, {VPCMP[BWDQ] imm=0} and {VPCMPEQ[BWDQ]} are
+ // the same operation "Equal"; [dedupGodef] should be able to distinguish them.
+ }
+ return nil
+}
+
+func capitalizeFirst(s string) string {
+ if s == "" {
+ return ""
+ }
+ // Convert the string to a slice of runes to handle multi-byte characters correctly.
+ r := []rune(s)
+ r[0] = unicode.ToUpper(r[0])
+ return string(r)
+}
+
+// overwrite corrects some errors due to:
+// - the XED data being wrong, or
+// - Go's SIMD API requirements, for example AVX2 compares should also produce masks.
+// This rewrite has strict constraints; please see the error messages.
+// These constraints are also exploited in [writeSIMDRules], [writeSIMDMachineOps],
+// and [writeSIMDSSA], so please be careful when updating them.
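+// One example is the AVX2 compares in the ops yaml, whose [OverwriteClass]
+// turns a vreg result into a mask (see the V?PCMPEQ[BWDQ] entries in go.yaml).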
+func overwrite(ops []Operation) error {
+ hasClassOverwrite := false
+ overwrite := func(op []Operand, idx int, o Operation) error {
+ if op[idx].OverwriteElementBits != nil {
+ if op[idx].ElemBits == nil {
+ panic(fmt.Errorf("ElemBits is nil at operand %d of %v", idx, o))
+ }
+ *op[idx].ElemBits = *op[idx].OverwriteElementBits
+ *op[idx].Lanes = *op[idx].Bits / *op[idx].ElemBits
+ *op[idx].Go = fmt.Sprintf("%s%dx%d", capitalizeFirst(*op[idx].Base), *op[idx].ElemBits, *op[idx].Lanes)
+ }
+ if op[idx].OverwriteClass != nil {
+ if op[idx].OverwriteBase == nil {
+ panic(fmt.Errorf("simdgen: [OverwriteClass] must be set together with [OverwriteBase]: %s", op[idx]))
+ }
+ oBase := *op[idx].OverwriteBase
+ oClass := *op[idx].OverwriteClass
+ if oClass != "mask" {
+ panic(fmt.Errorf("simdgen: [Class] overwrite only supports overwritting to mask: %s", op[idx]))
+ }
+ if oBase != "int" {
+ panic(fmt.Errorf("simdgen: [Class] overwrite must set [OverwriteBase] to int: %s", op[idx]))
+ }
+ if op[idx].Class != "vreg" {
+ panic(fmt.Errorf("simdgen: [Class] overwrite must be overwriting [Class] from vreg: %s", op[idx]))
+ }
+ hasClassOverwrite = true
+ *op[idx].Base = oBase
+ op[idx].Class = oClass
+ *op[idx].Go = fmt.Sprintf("Mask%dx%d", *op[idx].ElemBits, *op[idx].Lanes)
+ } else if op[idx].OverwriteBase != nil {
+ oBase := *op[idx].OverwriteBase
+ *op[idx].Go = strings.ReplaceAll(*op[idx].Go, capitalizeFirst(*op[idx].Base), capitalizeFirst(oBase))
+ if op[idx].Class == "greg" {
+ *op[idx].Go = strings.ReplaceAll(*op[idx].Go, *op[idx].Base, oBase)
+ }
+ *op[idx].Base = oBase
+ }
+ return nil
+ }
+ for i, o := range ops {
+ hasClassOverwrite = false
+ for j := range ops[i].In {
+ if err := overwrite(ops[i].In, j, o); err != nil {
+ return err
+ }
+ if hasClassOverwrite {
+ return fmt.Errorf("simdgen does not support [OverwriteClass] in inputs: %s", ops[i])
+ }
+ }
+ for j := range ops[i].Out {
+ if err := overwrite(ops[i].Out, j, o); err != nil {
+ return err
+ }
+ }
+ if hasClassOverwrite {
+ for _, in := range ops[i].In {
+ if in.Class == "mask" {
+ return fmt.Errorf("simdgen only supports [OverwriteClass] for operations without mask inputs")
+ }
+ }
+ }
+ }
+ return nil
+}
+
+// reportXEDInconsistency reports potential XED inconsistencies.
+// We can add more fields to [Operation] to enable more checks and implement it here.
+// Supported checks:
+// [NameAndSizeCheck]: NAME[BWDQ] should set the elemBits accordingly.
+// This check is useful to find inconsistencies, then we can add overwrite fields to
+// those defs to correct them manually.
+func reportXEDInconsistency(ops []Operation) error {
+ for _, o := range ops {
+ if o.NameAndSizeCheck != nil {
+ suffixSizeMap := map[byte]int{'B': 8, 'W': 16, 'D': 32, 'Q': 64}
+ checkOperand := func(opr Operand) error {
+ if opr.ElemBits == nil {
+ return fmt.Errorf("simdgen expects elemBits to be set when performing NameAndSizeCheck")
+ }
+ if v, ok := suffixSizeMap[o.Asm[len(o.Asm)-1]]; !ok {
+ return fmt.Errorf("simdgen expects asm to end with [BWDQ] when performing NameAndSizeCheck")
+ } else {
+ if v != *opr.ElemBits {
+ return fmt.Errorf("simdgen finds NameAndSizeCheck inconsistency in def: %s", o)
+ }
+ }
+ return nil
+ }
+ for _, in := range o.In {
+ if in.Class != "vreg" && in.Class != "mask" {
+ continue
+ }
+ if in.TreatLikeAScalarOfSize != nil {
+ // This is an irregular operand, don't check it.
+ continue
+ }
+ if err := checkOperand(in); err != nil {
+ return err
+ }
+ }
+ for _, out := range o.Out {
+ if err := checkOperand(out); err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
+
+func (o Operation) String() string {
+ return pprints(o)
+}
+
+func (op Operand) String() string {
+ return pprints(op)
+}
--- /dev/null
+!import ops/*/go.yaml
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "fmt"
+ "log"
+ "regexp"
+ "slices"
+ "strconv"
+ "strings"
+
+ "simd/_gen/unify"
+)
+
+type Operation struct {
+ rawOperation
+
+ // Go is the Go method name of this operation.
+ //
+ // It is derived from the raw Go method name by adding optional suffixes.
+ // Currently, "Masked" is the only suffix.
+ Go string
+
+ // Documentation is the doc string for this API.
+ //
+ // It is computed from the raw documentation:
+ //
+ // - "NAME" is replaced by the Go method name.
+ //
+ // - For masked operations, a sentence about masking is added.
+ Documentation string
+
+ // In is the sequence of parameters to the Go method.
+ //
+ // For masked operations, this will have the mask operand appended.
+ In []Operand
+}
+
+// rawOperation is the unifier representation of an [Operation]. It is
+// translated into a more parsed form after unifier decoding.
+type rawOperation struct {
+ Go string // Base Go method name
+
+ GoArch string // GOARCH for this definition
+ Asm string // Assembly mnemonic
+ OperandOrder *string // optional Operand order for better Go declarations
+ // Optional tag to indicate this operation is paired with special generic->machine ssa lowering rules.
+ // Should be paired with special templates in gen_simdrules.go
+ SpecialLower *string
+
+ In []Operand // Parameters
+ InVariant []Operand // Optional parameters
+ Out []Operand // Results
+ Commutative bool // Commutativity
+ CPUFeature string // CPUID/Has* feature name
+ Zeroing *bool // nil => use asm suffix ".Z"; false => do not use asm suffix ".Z"
+ Documentation *string // Documentation will be appended to the stubs comments.
+ // ConstImm is a hack to reduce the size of the defs the user writes for const-immediate
+ // operations. If present, it will be copied to [In[0].Const].
+ ConstImm *string
+ // NameAndSizeCheck is used to check [BWDQ] maps to (8|16|32|64) elemBits.
+ NameAndSizeCheck *bool
+ // If non-nil, all generation in gen_simdTypes.go and gen_intrinsics will be skipped.
+ NoTypes *string
+ // If non-nil, all generation in gen_simdGenericOps and gen_simdrules will be skipped.
+ NoGenericOps *string
+ // If non-nil, this string will be attached to the machine ssa op name.
+ SSAVariant *string
+}
+
+func (o *Operation) DecodeUnified(v *unify.Value) error {
+ if err := v.Decode(&o.rawOperation); err != nil {
+ return err
+ }
+
+ isMasked := false
+ if len(o.InVariant) == 0 {
+ // No variant
+ } else if len(o.InVariant) == 1 && o.InVariant[0].Class == "mask" {
+ isMasked = true
+ } else {
+ return fmt.Errorf("unknown inVariant")
+ }
+
+ // Compute full Go method name.
+ o.Go = o.rawOperation.Go
+ if isMasked {
+ o.Go += "Masked"
+ }
+
+ // Compute doc string.
+ if o.rawOperation.Documentation != nil {
+ o.Documentation = *o.rawOperation.Documentation
+ } else {
+ o.Documentation = "// UNDOCUMENTED"
+ }
+ o.Documentation = regexp.MustCompile(`\bNAME\b`).ReplaceAllString(o.Documentation, o.Go)
+ if isMasked {
+ o.Documentation += "\n//\n// This operation is applied selectively under a write mask."
+ }
+
+ o.In = append(o.rawOperation.In, o.rawOperation.InVariant...)
+
+ return nil
+}
+
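+// VectorWidth returns the vector width of this operation in bits: the width
+// of the output vector, or, for mask and greg outputs, the width of the
+// first vreg input.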
+func (o *Operation) VectorWidth() int {
+ out := o.Out[0]
+ if out.Class == "vreg" {
+ return *out.Bits
+ } else if out.Class == "greg" || out.Class == "mask" {
+ for i := range o.In {
+ if o.In[i].Class == "vreg" {
+ return *o.In[i].Bits
+ }
+ }
+ }
+ panic(fmt.Errorf("Figure out what the vector width is for %v and implement it", *o))
+}
+
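+// machineOpName returns the machine ssa op name for gOp: the assembly
+// mnemonic, a "Masked" suffix for masked shapes, and the vector width,
+// plus any SSAVariant suffix; e.g. a masked 256-bit VPADDB becomes
+// "VPADDBMasked256".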
+func machineOpName(maskType maskShape, gOp Operation) string {
+ asm := gOp.Asm
+ if maskType == 2 {
+ asm += "Masked"
+ }
+ asm = fmt.Sprintf("%s%d", asm, gOp.VectorWidth())
+ if gOp.SSAVariant != nil {
+ asm += *gOp.SSAVariant
+ }
+ return asm
+}
+
+func compareStringPointers(x, y *string) int {
+ if x != nil && y != nil {
+ return compareNatural(*x, *y)
+ }
+ if x == nil && y == nil {
+ return 0
+ }
+ if x == nil {
+ return -1
+ }
+ return 1
+}
+
+func compareIntPointers(x, y *int) int {
+ if x != nil && y != nil {
+ return *x - *y
+ }
+ if x == nil && y == nil {
+ return 0
+ }
+ if x == nil {
+ return -1
+ }
+ return 1
+}
+
+func compareOperations(x, y Operation) int {
+ if c := compareNatural(x.Go, y.Go); c != 0 {
+ return c
+ }
+ xIn, yIn := x.In, y.In
+
+ if len(xIn) > len(yIn) && xIn[len(xIn)-1].Class == "mask" {
+ xIn = xIn[:len(xIn)-1]
+ } else if len(xIn) < len(yIn) && yIn[len(yIn)-1].Class == "mask" {
+ yIn = yIn[:len(yIn)-1]
+ }
+
+ if len(xIn) < len(yIn) {
+ return -1
+ }
+ if len(xIn) > len(yIn) {
+ return 1
+ }
+ if len(x.Out) < len(y.Out) {
+ return -1
+ }
+ if len(x.Out) > len(y.Out) {
+ return 1
+ }
+ for i := range xIn {
+ ox, oy := &xIn[i], &yIn[i]
+ if c := compareOperands(ox, oy); c != 0 {
+ return c
+ }
+ }
+ return 0
+}
+
+func compareOperands(x, y *Operand) int {
+ if c := compareNatural(x.Class, y.Class); c != 0 {
+ return c
+ }
+ if x.Class == "immediate" {
+ return compareStringPointers(x.ImmOffset, y.ImmOffset)
+ } else {
+ if c := compareStringPointers(x.Base, y.Base); c != 0 {
+ return c
+ }
+ if c := compareIntPointers(x.ElemBits, y.ElemBits); c != 0 {
+ return c
+ }
+ if c := compareIntPointers(x.Bits, y.Bits); c != 0 {
+ return c
+ }
+ return 0
+ }
+}
+
+type Operand struct {
+ Class string // One of "mask", "immediate", "vreg", "greg", and "mem"
+
+ Go *string // Go type of this operand
+ AsmPos int // Position of this operand in the assembly instruction
+
+ Base *string // Base Go type ("int", "uint", "float")
+ ElemBits *int // Element bit width
+ Bits *int // Total vector bit width
+
+ Const *string // Optional constant value for immediates.
+ // Optional immediate arg offset. If this field is non-nil,
+ // this operand will be an immediate operand:
+ // the compiler will right-shift the user-passed value by ImmOffset and set it as the AuxInt
+ // field of the operation.
+ ImmOffset *string
+ Name *string // optional name in the Go intrinsic declaration
+ Lanes *int // *Lanes equals Bits/ElemBits except for scalars, when *Lanes == 1
+ // TreatLikeAScalarOfSize means only the lower $TreatLikeAScalarOfSize bits of the vector
+ // is used, so at the API level we can make it just a scalar value of this size; Then we
+ // can overwrite it to a vector of the right size during intrinsics stage.
+ TreatLikeAScalarOfSize *int
+ // If non-nil, it means the [Class] field is overwritten here, right now this is used to
+ // overwrite the results of AVX2 compares to masks.
+ OverwriteClass *string
+ // If non-nil, the [Base] field is overwritten here. This field exists solely
+ // because Intel's XED data is inconsistent, e.g. VANDNP[SD] marks its operands int.
+ OverwriteBase *string
+ // If non-nil, the [ElemBits] field is overwritten. This field exists solely
+ // because Intel's XED data is inconsistent, e.g. AVX512 VPMADDUBSW marks its operands
+ // elemBits 16 when they should be 8.
+ OverwriteElementBits *int
+}
+
+// isDigit returns true if the byte is an ASCII digit.
+func isDigit(b byte) bool {
+ return b >= '0' && b <= '9'
+}
+
+// compareNatural performs a "natural sort" comparison of two strings.
+// It compares non-digit sections lexicographically and digit sections
+// numerically. In the case of string-unequal "equal" strings like
+// "a01b" and "a1b", strings.Compare breaks the tie.
+//
+// It returns:
+//
+// -1 if s1 < s2
+// 0 if s1 == s2
+// +1 if s1 > s2
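+//
+// For example, compareNatural("F32x8", "F32x16") returns -1 because the
+// digit segments compare numerically (8 < 16), whereas plain
+// strings.Compare would order "F32x16" first.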
+func compareNatural(s1, s2 string) int {
+ i, j := 0, 0
+ len1, len2 := len(s1), len(s2)
+
+ for i < len1 && j < len2 {
+ // Find a non-digit segment or a number segment in both strings.
+ if isDigit(s1[i]) && isDigit(s2[j]) {
+ // Number segment comparison.
+ numStart1 := i
+ for i < len1 && isDigit(s1[i]) {
+ i++
+ }
+ num1, _ := strconv.Atoi(s1[numStart1:i])
+
+ numStart2 := j
+ for j < len2 && isDigit(s2[j]) {
+ j++
+ }
+ num2, _ := strconv.Atoi(s2[numStart2:j])
+
+ if num1 < num2 {
+ return -1
+ }
+ if num1 > num2 {
+ return 1
+ }
+ // If numbers are equal, continue to the next segment.
+ } else {
+ // Non-digit comparison.
+ if s1[i] < s2[j] {
+ return -1
+ }
+ if s1[i] > s2[j] {
+ return 1
+ }
+ i++
+ j++
+ }
+ }
+
+ // deal with a01b vs a1b; there needs to be an order.
+ return strings.Compare(s1, s2)
+}
+
+const generatedHeader = `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
+`
+
+func writeGoDefs(path string, cl unify.Closure) error {
+ // TODO: Merge operations with the same signature but multiple
+ // implementations (e.g., SSE vs AVX)
+ var ops []Operation
+ for def := range cl.All() {
+ var op Operation
+ if !def.Exact() {
+ continue
+ }
+ if err := def.Decode(&op); err != nil {
+ log.Println(err.Error())
+ log.Println(def)
+ continue
+ }
+ // TODO: verify that this is safe.
+ op.sortOperand()
+ ops = append(ops, op)
+ }
+ slices.SortFunc(ops, compareOperations)
+ // The parsed XED data might contain duplicates, like
+ // 512-bit VPADDP.
+ deduped := dedup(ops)
+ slices.SortFunc(deduped, compareOperations)
+
+ if *Verbose {
+ log.Printf("dedup len: %d\n", len(ops))
+ }
+ var err error
+ if err = overwrite(deduped); err != nil {
+ return err
+ }
+ if *Verbose {
+ log.Printf("dedup len: %d\n", len(deduped))
+ }
+ if !*FlagNoDedup {
+ // TODO: This can hide mistakes in the API definitions, especially when
+ // multiple patterns result in the same API unintentionally. Make it stricter.
+ if deduped, err = dedupGodef(deduped); err != nil {
+ return err
+ }
+ }
+ if *Verbose {
+ log.Printf("dedup len: %d\n", len(deduped))
+ }
+ if !*FlagNoConstImmPorting {
+ if err = copyConstImm(deduped); err != nil {
+ return err
+ }
+ }
+ if *Verbose {
+ log.Printf("dedup len: %d\n", len(deduped))
+ }
+ if err = reportXEDInconsistency(deduped); err != nil {
+ return err
+ }
+ typeMap := parseSIMDTypes(deduped)
+
+ formatWriteAndClose(writeSIMDTypes(typeMap), path, "src/"+simdPackage+"/types_amd64.go")
+ formatWriteAndClose(writeSIMDFeatures(deduped), path, "src/"+simdPackage+"/cpu.go")
+ formatWriteAndClose(writeSIMDStubs(deduped, typeMap), path, "src/"+simdPackage+"/ops_amd64.go")
+ formatWriteAndClose(writeSIMDIntrinsics(deduped, typeMap), path, "src/cmd/compile/internal/ssagen/simdintrinsics.go")
+ formatWriteAndClose(writeSIMDGenericOps(deduped), path, "src/cmd/compile/internal/ssa/_gen/simdgenericOps.go")
+ formatWriteAndClose(writeSIMDMachineOps(deduped), path, "src/cmd/compile/internal/ssa/_gen/simdAMD64ops.go")
+ formatWriteAndClose(writeSIMDSSA(deduped), path, "src/cmd/compile/internal/amd64/simdssa.go")
+ writeAndClose(writeSIMDRules(deduped).Bytes(), path, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules")
+
+ return nil
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// simdgen is an experiment in generating Go <-> asm SIMD mappings.
+//
+// Usage: simdgen [-xedPath=path] [-q=query] input.yaml...
+//
+// If -xedPath is provided, one of the inputs is a sum of op-code definitions
+// generated from the Intel XED data at path.
+//
+// If input YAML files are provided, each file is read as an input value. See
+// [unify.Closure.UnmarshalYAML] or "go doc unify.Closure.UnmarshalYAML" for the
+// format of these files.
+//
+// TODO: Example definitions and values.
+//
+// The command unifies across all of the inputs and prints all possible results
+// of this unification.
+//
+// If the -q flag is provided, its string value is parsed as a value and treated
+// as another input to unification. This is intended as a way to "query" the
+// result, typically by narrowing it down to a small subset of results.
+//
+// Typical usage:
+//
+// go run . -xedPath $XEDPATH *.yaml
+//
+// To see just the definitions generated from XED, run:
+//
+// go run . -xedPath $XEDPATH
+//
+// (This works because if there's only one input, there's nothing to unify it
+// with, so the result is simply itself.)
+//
+// To see just the definitions for VPADDQ:
+//
+// go run . -xedPath $XEDPATH -q '{asm: VPADDQ}'
+//
+// simdgen can also generate Go definitions of SIMD mappings:
+// To generate go files to the go root, run:
+//
+// go run . -xedPath $XEDPATH -o godefs -goroot $PATH/TO/go go.yaml categories.yaml types.yaml
+//
+// types.yaml is already written; it specifies the shapes of vectors.
+// categories.yaml and go.yaml contain definitions that unify with types.yaml and the XED
+// data; you can find an example in ops/AddSub/.
+//
+// When generating Go definitions, simdgen performs 3 "magics":
+// - It splits masked operations (those with the op's [Masked] field set) into const and non-const:
+// - One is a normal masked operation, the original.
+// - The other has its mask operand's [Const] field set to "K0".
+// - This way the user does not need to provide a separate "K0"-masked operation def.
+//
+// - It deduplicates intrinsic names that have duplicates:
+// - If two operations share the same signature and one is AVX512 while the other
+// predates AVX512, the pre-AVX512 one is selected.
+// - This happens often when operations are defined both before and after AVX512.
+// This way the user does not need to provide a separate "K0" operation for the
+// AVX512 counterpart.
+//
+// - It copies the op's [ConstImm] field to its immediate operand's [Const] field.
+// - This way the user does not need to provide verbose op definitions that differ
+// only in the const immediate field. This is useful to reduce the verbosity of
+// compares with imm control predicates.
+//
+// These 3 magics can be disabled with the -nosplitmask, -nodedup, or
+// -noconstimmporting flags.
+//
+// simdgen currently supports only amd64; -arch=$OTHERARCH will trigger a fatal error.
+package main
+
+// Big TODOs:
+//
+// - This can produce duplicates, which can also lead to less efficient
+// environment merging. Add hashing and use it for deduplication. Be careful
+// about how this shows up in debug traces, since it could make things
+// confusing if we don't show it happening.
+//
+// - Do I need Closure, Value, and Domain? It feels like I should only need two
+// types.
+
+import (
+ "cmp"
+ "flag"
+ "fmt"
+ "log"
+ "maps"
+ "os"
+ "path/filepath"
+ "runtime/pprof"
+ "slices"
+ "strings"
+
+ "gopkg.in/yaml.v3"
+ "simd/_gen/unify"
+)
+
+var (
+ xedPath = flag.String("xedPath", "", "load XED datafiles from `path`")
+ flagQ = flag.String("q", "", "query: read `def` as another input (skips final validation)")
+ flagO = flag.String("o", "yaml", "output type: yaml or godefs (generate definitions into a Go source tree)")
+ flagGoDefRoot = flag.String("goroot", ".", "the path to the Go dev directory that will receive the generated files")
+ FlagNoDedup = flag.Bool("nodedup", false, "disable deduplicating godefs of 2 qualifying operations from different extensions")
+ FlagNoConstImmPorting = flag.Bool("noconstimmporting", false, "disable const immediate porting from op to imm operand")
+ FlagArch = flag.String("arch", "amd64", "the target architecture")
+
+ Verbose = flag.Bool("v", false, "verbose")
+
+ flagDebugXED = flag.Bool("debug-xed", false, "show XED instructions")
+ flagDebugUnify = flag.Bool("debug-unify", false, "print unification trace")
+ flagDebugHTML = flag.String("debug-html", "", "write unification trace to `file.html`")
+ FlagReportDup = flag.Bool("reportdup", false, "report the duplicate godefs")
+
+ flagCPUProfile = flag.String("cpuprofile", "", "write CPU profile to `file`")
+ flagMemProfile = flag.String("memprofile", "", "write memory profile to `file`")
+)
+
+const simdPackage = "simd"
+
+func main() {
+ flag.Parse()
+
+ if *flagCPUProfile != "" {
+ f, err := os.Create(*flagCPUProfile)
+ if err != nil {
+ log.Fatalf("-cpuprofile: %s", err)
+ }
+ defer f.Close()
+ pprof.StartCPUProfile(f)
+ defer pprof.StopCPUProfile()
+ }
+ if *flagMemProfile != "" {
+ f, err := os.Create(*flagMemProfile)
+ if err != nil {
+ log.Fatalf("-memprofile: %s", err)
+ }
+ defer func() {
+ pprof.WriteHeapProfile(f)
+ f.Close()
+ }()
+ }
+
+ var inputs []unify.Closure
+
+ if *FlagArch != "amd64" {
+ log.Fatalf("simdgen only supports amd64")
+ }
+
+ // Load XED into a defs set.
+ if *xedPath != "" {
+ xedDefs := loadXED(*xedPath)
+ inputs = append(inputs, unify.NewSum(xedDefs...))
+ }
+
+ // Load query.
+ if *flagQ != "" {
+ r := strings.NewReader(*flagQ)
+ def, err := unify.Read(r, "<query>", unify.ReadOpts{})
+ if err != nil {
+ log.Fatalf("parsing -q: %s", err)
+ }
+ inputs = append(inputs, def)
+ }
+
+ // Load defs files.
+ must := make(map[*unify.Value]struct{})
+ for _, path := range flag.Args() {
+ defs, err := unify.ReadFile(path, unify.ReadOpts{})
+ if err != nil {
+ log.Fatal(err)
+ }
+ inputs = append(inputs, defs)
+
+ if filepath.Base(path) == "go.yaml" {
+ // These must all be used in the final result
+ for def := range defs.Summands() {
+ must[def] = struct{}{}
+ }
+ }
+ }
+
+ // Prepare for unification
+ if *flagDebugUnify {
+ unify.Debug.UnifyLog = os.Stderr
+ }
+ if *flagDebugHTML != "" {
+ f, err := os.Create(*flagDebugHTML)
+ if err != nil {
+ log.Fatal(err)
+ }
+ unify.Debug.HTML = f
+ defer f.Close()
+ }
+
+ // Unify!
+ unified, err := unify.Unify(inputs...)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Print results.
+ switch *flagO {
+ case "yaml":
+ // Produce a result that looks like encoding a slice, but stream it.
+ fmt.Println("!sum")
+ var val1 [1]*unify.Value
+ for val := range unified.All() {
+ val1[0] = val
+ // We have to make a new encoder each time or it'll print a document
+ // separator between each object.
+ enc := yaml.NewEncoder(os.Stdout)
+ if err := enc.Encode(val1); err != nil {
+ log.Fatal(err)
+ }
+ enc.Close()
+ }
+ case "godefs":
+ if err := writeGoDefs(*flagGoDefRoot, unified); err != nil {
+ log.Fatalf("Failed writing godefs: %+v", err)
+ }
+ }
+
+ if !*Verbose && *xedPath != "" {
+ if operandRemarks == 0 {
+ fmt.Fprintf(os.Stderr, "XED decoding generated no errors, which is unusual.\n")
+ } else {
+ fmt.Fprintf(os.Stderr, "XED decoding generated %d \"errors\" which is not cause for alarm, use -v for details.\n", operandRemarks)
+ }
+ }
+
+ // Validate results.
+ //
+ // Don't validate if this is a command-line query because that tends to
+ // eliminate lots of required defs and is used in cases where maybe defs
+ // aren't enumerable anyway.
+ if *flagQ == "" && len(must) > 0 {
+ validate(unified, must)
+ }
+}
+
+func validate(cl unify.Closure, required map[*unify.Value]struct{}) {
+ // Validate that:
+ // 1. All final defs are exact
+ // 2. All required defs are used
+ for def := range cl.All() {
+ if _, ok := def.Domain.(unify.Def); !ok {
+ fmt.Fprintf(os.Stderr, "%s: expected Def, got %T\n", def.PosString(), def.Domain)
+ continue
+ }
+
+ if !def.Exact() {
+ fmt.Fprintf(os.Stderr, "%s: def not reduced to an exact value, why is %s:\n", def.PosString(), def.WhyNotExact())
+ fmt.Fprintf(os.Stderr, "\t%s\n", strings.ReplaceAll(def.String(), "\n", "\n\t"))
+ }
+
+ for root := range def.Provenance() {
+ delete(required, root)
+ }
+ }
+ // Report unused defs
+ unused := slices.SortedFunc(maps.Keys(required),
+ func(a, b *unify.Value) int {
+ return cmp.Or(
+ cmp.Compare(a.Pos().Path, b.Pos().Path),
+ cmp.Compare(a.Pos().Line, b.Pos().Line),
+ )
+ })
+ for _, def := range unused {
+ // TODO: Can we say anything more actionable? This is always a problem
+ // with unification: if it fails, it's very hard to point a finger at
+ // any particular reason. We could go back and try unifying this again
+ // with each subset of the inputs (starting with individual inputs) to
+ // at least say "it doesn't unify with anything in x.yaml". That's a lot
+ // of work, but if we have trouble debugging unification failure it may
+ // be worth it.
+ fmt.Fprintf(os.Stderr, "%s: def required, but did not unify (%v)\n",
+ def.PosString(), def)
+ }
+}
--- /dev/null
+!sum
+- go: Add
+ commutative: true
+ documentation: !string |-
+ // NAME adds corresponding elements of two vectors.
+- go: AddSaturated
+ commutative: true
+ documentation: !string |-
+ // NAME adds corresponding elements of two vectors with saturation.
+- go: Sub
+ commutative: false
+ documentation: !string |-
+ // NAME subtracts corresponding elements of two vectors.
+- go: SubSaturated
+ commutative: false
+ documentation: !string |-
+ // NAME subtracts corresponding elements of two vectors with saturation.
+- go: AddPairs
+ commutative: false
+ documentation: !string |-
+ // NAME horizontally adds adjacent pairs of elements.
+ // For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0+y1, y2+y3, ..., x0+x1, x2+x3, ...].
+- go: SubPairs
+ commutative: false
+ documentation: !string |-
+ // NAME horizontally subtracts adjacent pairs of elements.
+ // For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0-y1, y2-y3, ..., x0-x1, x2-x3, ...].
+- go: AddPairsSaturated
+ commutative: false
+ documentation: !string |-
+ // NAME horizontally adds adjacent pairs of elements with saturation.
+ // For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0+y1, y2+y3, ..., x0+x1, x2+x3, ...].
+- go: SubPairsSaturated
+ commutative: false
+ documentation: !string |-
+ // NAME horizontally subtracts adjacent pairs of elements with saturation.
+ // For x = [x0, x1, x2, x3, ...] and y = [y0, y1, y2, y3, ...], the result is [y0-y1, y2-y3, ..., x0-x1, x2-x3, ...].
--- /dev/null
+!sum
+# Add
+- go: Add
+ asm: "VPADD[BWDQ]|VADDP[SD]"
+ in:
+ - &any
+ go: $t
+ - *any
+ out:
+ - *any
+# Add Saturated
+- go: AddSaturated
+ asm: "VPADDS[BWDQ]"
+ in:
+ - &int
+ go: $t
+ base: int
+ - *int
+ out:
+ - *int
+- go: AddSaturated
+ asm: "VPADDUS[BWDQ]"
+ in:
+ - &uint
+ go: $t
+ base: uint
+ - *uint
+ out:
+ - *uint
+
+# Sub
+- go: Sub
+ asm: "VPSUB[BWDQ]|VSUBP[SD]"
+ in: &2any
+ - *any
+ - *any
+ out: &1any
+ - *any
+# Sub Saturated
+- go: SubSaturated
+ asm: "VPSUBS[BWDQ]"
+ in: &2int
+ - *int
+ - *int
+ out: &1int
+ - *int
+- go: SubSaturated
+ asm: "VPSUBUS[BWDQ]"
+ in:
+ - *uint
+ - *uint
+ out:
+ - *uint
+- go: AddPairs
+ asm: "VPHADD[DW]"
+ in: *2any
+ out: *1any
+- go: SubPairs
+ asm: "VPHSUB[DW]"
+ in: *2any
+ out: *1any
+- go: AddPairs
+ asm: "VHADDP[SD]" # floats
+ in: *2any
+ out: *1any
+- go: SubPairs
+ asm: "VHSUBP[SD]" # floats
+ in: *2any
+ out: *1any
+- go: AddPairsSaturated
+ asm: "VPHADDS[DW]"
+ in: *2int
+ out: *1int
+- go: SubPairsSaturated
+ asm: "VPHSUBS[DW]"
+ in: *2int
+ out: *1int
--- /dev/null
+!sum
+- go: And
+ commutative: true
+ documentation: !string |-
+ // NAME performs a bitwise AND operation between two vectors.
+- go: Or
+ commutative: true
+ documentation: !string |-
+ // NAME performs a bitwise OR operation between two vectors.
+- go: AndNot
+ commutative: false
+ documentation: !string |-
+ // NAME performs a bitwise x &^ y.
+- go: Xor
+ commutative: true
+ documentation: !string |-
+ // NAME performs a bitwise XOR operation between two vectors.
+
+# We also have PTEST and VPTERNLOG, those should be hidden from the users
+# and only appear in rewrite rules.
--- /dev/null
+!sum
+# In the XED data, *all* floating point bitwise logic operations have their
+# operand types marked as uint. We are not trying to understand why Intel
+# decided that they want FP bit-wise logic operations, but this irregularity
+# has to be dealt with in separate rules with some overwrites.
+
+# For many bit-wise operations, we have the following non-orthogonal
+# choices:
+#
+# - Non-masked AVX operations have no element width (because it
+# doesn't matter), but only cover 128 and 256 bit vectors.
+#
+# - Masked AVX-512 operations have an element width (because it needs
+# to know how to interpret the mask), and cover 128, 256, and 512 bit
+# vectors. These only cover 32- and 64-bit element widths.
+#
+# - Non-masked AVX-512 operations still have an element width (because
+# they're just the masked operations with an implicit K0 mask) but it
+# doesn't matter! This is the only option for non-masked 512 bit
+# operations, and we can pick any of the element widths.
+#
+# We unify with ALL of these operations and the compiler generator
+# picks when there are multiple options.
+
+# TODO: We don't currently generate unmasked bit-wise operations on 512 bit
+# vectors of 8- or 16-bit elements. AVX-512 only has *masked* bit-wise
+# operations for 32- and 64-bit elements; while the element width doesn't matter
+# for unmasked operations, right now we don't realize that we can just use the
+# 32- or 64-bit version for the unmasked form. Maybe in the XED decoder we
+# should recognize bit-wise operations when generating unmasked versions and
+# omit the element width.
+
+# For binary operations, we constrain their two inputs and one output to the
+# same Go type using a variable.
+
+- go: And
+ asm: "VPAND[DQ]?"
+ in:
+ - &any
+ go: $t
+ - *any
+ out:
+ - *any
+
+- go: And
+ asm: "VPANDD" # Fill in the gap, And is missing for Uint8x64 and Int8x64
+ inVariant: []
+ in: &twoI8x64
+ - &i8x64
+ go: $t
+ overwriteElementBits: 8
+ - *i8x64
+ out: &oneI8x64
+ - *i8x64
+
+- go: And
+ asm: "VPANDD" # Fill in the gap, And is missing for Uint16x32 and Int16x32
+ inVariant: []
+ in: &twoI16x32
+ - &i16x32
+ go: $t
+ overwriteElementBits: 16
+ - *i16x32
+ out: &oneI16x32
+ - *i16x32
+
+- go: AndNot
+ asm: "VPANDN[DQ]?"
+ operandOrder: "21" # switch the arg order
+ in:
+ - *any
+ - *any
+ out:
+ - *any
+
+- go: AndNot
+ asm: "VPANDND" # Fill in the gap, AndNot is missing for Uint8x64 and Int8x64
+ operandOrder: "21" # switch the arg order
+ inVariant: []
+ in: *twoI8x64
+ out: *oneI8x64
+
+- go: AndNot
+ asm: "VPANDND" # Fill in the gap, AndNot is missing for Uint16x32 and Int16x32
+ operandOrder: "21" # switch the arg order
+ inVariant: []
+ in: *twoI16x32
+ out: *oneI16x32
+
+- go: Or
+ asm: "VPOR[DQ]?"
+ in:
+ - *any
+ - *any
+ out:
+ - *any
+
+- go: Or
+ asm: "VPORD" # Fill in the gap, Or is missing for Uint8x64 and Int8x64
+ inVariant: []
+ in: *twoI8x64
+ out: *oneI8x64
+
+- go: Or
+ asm: "VPORD" # Fill in the gap, Or is missing for Uint16x32 and Int16x32
+ inVariant: []
+ in: *twoI16x32
+ out: *oneI16x32
+
+- go: Xor
+ asm: "VPXOR[DQ]?"
+ in:
+ - *any
+ - *any
+ out:
+ - *any
+
+- go: Xor
+ asm: "VPXORD" # Fill in the gap, Or is missing for Uint8x64 and Int8x64
+ inVariant: []
+ in: *twoI8x64
+ out: *oneI8x64
+
+- go: Xor
+ asm: "VPXORD" # Fill in the gap, Or is missing for Uint16x32 and Int16x32
+ inVariant: []
+ in: *twoI16x32
+ out: *oneI16x32
\ No newline at end of file
--- /dev/null
+!sum
+# const imm predicates (hold for both float and int|uint):
+# 0: Equal
+# 1: Less
+# 2: LessEqual
+# 3: IsNan (float only)
+# 4: NotEqual
+# 13: GreaterEqual
+# 14: Greater
+- go: Equal
+ constImm: 0
+ commutative: true
+ documentation: !string |-
+ // NAME compares for equality.
+- go: Less
+ constImm: 1
+ commutative: false
+ documentation: !string |-
+ // NAME compares for less than.
+- go: LessEqual
+ constImm: 2
+ commutative: false
+ documentation: !string |-
+ // NAME compares for less than or equal.
+- go: IsNan # For float only.
+ constImm: 3
+ commutative: true
+ documentation: !string |-
+ // NAME checks if elements are NaN. Use as x.IsNan(x).
+- go: NotEqual
+ constImm: 4
+ commutative: true
+ documentation: !string |-
+ // NAME compares for inequality.
+- go: GreaterEqual
+ constImm: 13
+ commutative: false
+ documentation: !string |-
+ // NAME compares for greater than or equal.
+- go: Greater
+ constImm: 14
+ commutative: false
+ documentation: !string |-
+ // NAME compares for greater than.
--- /dev/null
+!sum
+# Ints
+- go: Equal
+ asm: "V?PCMPEQ[BWDQ]"
+ in:
+ - &any
+ go: $t
+ - *any
+ out:
+ - &anyvregToMask
+ go: $t
+ overwriteBase: int
+ overwriteClass: mask
+- go: Greater
+ asm: "V?PCMPGT[BWDQ]"
+ in:
+ - &int
+ go: $t
+ base: int
+ - *int
+ out:
+ - *anyvregToMask
+# 256-bit VPCMPGTQ's output elemBits is marked 32-bit in the XED data; we
+# believe this is an error, so add this definition to overwrite it.
+- go: Greater
+ asm: "VPCMPGTQ"
+ in:
+ - &int64
+ go: $t
+ base: int
+ elemBits: 64
+ - *int64
+ out:
+ - base: int
+ elemBits: 32
+ overwriteElementBits: 64
+ overwriteClass: mask
+ overwriteBase: int
+
+# TODO these are redundant with VPCMP operations.
+# AVX-512 compares produce masks.
+- go: Equal
+ asm: "V?PCMPEQ[BWDQ]"
+ in:
+ - *any
+ - *any
+ out:
+ - class: mask
+- go: Greater
+ asm: "V?PCMPGT[BWDQ]"
+ in:
+ - *int
+ - *int
+ out:
+ - class: mask
+
+# MASKED signed comparisons for X/Y registers
+# unmasked would clash with emulations on AVX2
+- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
+ asm: "VPCMP[BWDQ]"
+ in:
+ - &int
+ bits: (128|256)
+ go: $t
+ base: int
+ - *int
+ - class: immediate
+ const: 0 # Just a placeholder, will be overwritten by const imm porting.
+ inVariant:
+ - class: mask
+ out:
+ - class: mask
+
+# MASKED unsigned comparisons for X/Y registers
+# unmasked would clash with emulations on AVX2
+- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
+ asm: "VPCMPU[BWDQ]"
+ in:
+ - &uint
+ bits: (128|256)
+ go: $t
+ base: uint
+ - *uint
+ - class: immediate
+ const: 0
+ inVariant:
+ - class: mask
+ out:
+ - class: mask
+
+# masked/unmasked signed comparisons for Z registers
+- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
+ asm: "VPCMP[BWDQ]"
+ in:
+ - &int
+ bits: 512
+ go: $t
+ base: int
+ - *int
+ - class: immediate
+ const: 0 # Just a placeholder, will be overwritten by const imm porting.
+ out:
+ - class: mask
+
+# masked/unmasked unsigned comparisons for Z registers
+- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual)
+ asm: "VPCMPU[BWDQ]"
+ in:
+ - &uint
+ bits: 512
+ go: $t
+ base: uint
+ - *uint
+ - class: immediate
+ const: 0
+ out:
+ - class: mask
+
+# Floats
+- go: Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual|IsNan
+ asm: "VCMPP[SD]"
+ in:
+ - &float
+ go: $t
+ base: float
+ - *float
+ - class: immediate
+ const: 0
+ out:
+ - go: $t
+ overwriteBase: int
+ overwriteClass: mask
+- go: (Equal|Greater|Less|LessEqual|GreaterEqual|NotEqual|IsNan)
+ asm: "VCMPP[SD]"
+ in:
+ - *float
+ - *float
+ - class: immediate
+ const: 0
+ out:
+ - class: mask
\ No newline at end of file
--- /dev/null
+!sum
+- go: ConvertToInt32
+ commutative: false
+ documentation: !string |-
+ // ConvertToInt32 converts element values to int32.
+
+- go: ConvertToUint32
+ commutative: false
+ documentation: !string |-
+ // ConvertToUint32 converts element values to uint32.
--- /dev/null
+!sum
+- go: ConvertToInt32
+ asm: "VCVTTPS2DQ"
+ in:
+ - &fp
+ go: $t
+ base: float
+ out:
+ - &i32
+ go: $u
+ base: int
+ elemBits: 32
+- go: ConvertToUint32
+ asm: "VCVTPS2UDQ"
+ in:
+ - *fp
+ out:
+ - &u32
+ go: $u
+ base: uint
+ elemBits: 32
--- /dev/null
+!sum
+- go: Div
+ commutative: false
+ documentation: !string |-
+ // NAME divides elements of two vectors.
+- go: Sqrt
+ commutative: false
+ documentation: !string |-
+ // NAME computes the square root of each element.
+- go: Reciprocal
+ commutative: false
+ documentation: !string |-
+ // NAME computes an approximate reciprocal of each element.
+- go: ReciprocalSqrt
+ commutative: false
+ documentation: !string |-
+ // NAME computes an approximate reciprocal of the square root of each element.
+- go: Scale
+ commutative: false
+ documentation: !string |-
+ // NAME multiplies elements by a power of 2.
+- go: RoundToEven
+ commutative: false
+ constImm: 0
+ documentation: !string |-
+ // NAME rounds elements to the nearest integer.
+- go: RoundToEvenScaled
+ commutative: false
+ constImm: 0
+ documentation: !string |-
+ // NAME rounds elements with specified precision.
+- go: RoundToEvenScaledResidue
+ commutative: false
+ constImm: 0
+ documentation: !string |-
+ // NAME computes the difference after rounding with specified precision.
+- go: Floor
+ commutative: false
+ constImm: 1
+ documentation: !string |-
+ // NAME rounds elements down to the nearest integer.
+- go: FloorScaled
+ commutative: false
+ constImm: 1
+ documentation: !string |-
+ // NAME rounds elements down with specified precision.
+- go: FloorScaledResidue
+ commutative: false
+ constImm: 1
+ documentation: !string |-
+ // NAME computes the difference after flooring with specified precision.
+- go: Ceil
+ commutative: false
+ constImm: 2
+ documentation: !string |-
+ // NAME rounds elements up to the nearest integer.
+- go: CeilScaled
+ commutative: false
+ constImm: 2
+ documentation: !string |-
+ // NAME rounds elements up with specified precision.
+- go: CeilScaledResidue
+ commutative: false
+ constImm: 2
+ documentation: !string |-
+ // NAME computes the difference after ceiling with specified precision.
+- go: Trunc
+ commutative: false
+ constImm: 3
+ documentation: !string |-
+ // NAME truncates elements towards zero.
+- go: TruncScaled
+ commutative: false
+ constImm: 3
+ documentation: !string |-
+ // NAME truncates elements with specified precision.
+- go: TruncScaledResidue
+ commutative: false
+ constImm: 3
+ documentation: !string |-
+ // NAME computes the difference after truncating with specified precision.
+- go: AddSub
+ commutative: false
+ documentation: !string |-
+ // NAME subtracts even elements and adds odd elements of two vectors.
--- /dev/null
+!sum
+- go: Div
+ asm: "V?DIVP[SD]"
+ in: &2fp
+ - &fp
+ go: $t
+ base: float
+ - *fp
+ out: &1fp
+ - *fp
+- go: Sqrt
+ asm: "V?SQRTP[SD]"
+ in: *1fp
+ out: *1fp
+# TODO: Provide separate methods for 12-bit precision and 14-bit precision?
+- go: Reciprocal
+ asm: "VRCP(14)?P[SD]"
+ in: *1fp
+ out: *1fp
+- go: ReciprocalSqrt
+ asm: "V?RSQRT(14)?P[SD]"
+ in: *1fp
+ out: *1fp
+- go: Scale
+ asm: "VSCALEFP[SD]"
+ in: *2fp
+ out: *1fp
+
+- go: "RoundToEven|Ceil|Floor|Trunc"
+ asm: "VROUNDP[SD]"
+ in:
+ - *fp
+ - class: immediate
+ const: 0 # place holder
+ out: *1fp
+
+- go: "(RoundToEven|Ceil|Floor|Trunc)Scaled"
+ asm: "VRNDSCALEP[SD]"
+ in:
+ - *fp
+ - class: immediate
+ const: 0 # place holder
+ immOffset: 4 # "M": round to numbers with M binary digits after the dot.
+ name: prec
+ out: *1fp
+- go: "(RoundToEven|Ceil|Floor|Trunc)ScaledResidue"
+ asm: "VREDUCEP[SD]"
+ in:
+ - *fp
+ - class: immediate
+ const: 0 # place holder
+ immOffset: 4 # "M": round to numbers with M binary digits after the dot.
+ name: prec
+ out: *1fp
+
+- go: "AddSub"
+ asm: "VADDSUBP[SD]"
+ in:
+ - *fp
+ - *fp
+ out:
+ - *fp
--- /dev/null
+!sum
+- go: GaloisFieldAffineTransform
+ commutative: false
+ documentation: !string |-
+ // NAME computes an affine transformation in GF(2^8):
+ // x is a vector of 8-bit elements, with each adjacent 8 as a group; y is a vector of 8x8 1-bit matrices;
+ // b is an 8-bit vector. The affine transformation is y * x + b, with each element of y
+ // corresponding to a group of 8 elements in x.
+- go: GaloisFieldAffineTransformInverse
+ commutative: false
+ documentation: !string |-
+ // NAME computes an affine transformation in GF(2^8),
+ // with x inverted with respect to reduction polynomial x^8 + x^4 + x^3 + x + 1:
+ // x is a vector of 8-bit elements, with each adjacent 8 as a group; y is a vector of 8x8 1-bit matrices;
+ // b is an 8-bit vector. The affine transformation is y * x + b, with each element of y
+ // corresponding to a group of 8 elements in x.
+- go: GaloisFieldMul
+ commutative: false
+ documentation: !string |-
+ // NAME computes element-wise GF(2^8) multiplication with
+ // reduction polynomial x^8 + x^4 + x^3 + x + 1.
--- /dev/null
+!sum
+- go: GaloisFieldAffineTransform
+ asm: VGF2P8AFFINEQB
+ operandOrder: 2I # 2nd operand, then immediate
+ in: &AffineArgs
+ - &uint8
+ go: $t
+ base: uint
+ - &uint8x8
+ go: $t2
+ base: uint
+ - &pureImmVar
+ class: immediate
+ immOffset: 0
+ name: b
+ out:
+ - *uint8
+
+- go: GaloisFieldAffineTransformInverse
+ asm: VGF2P8AFFINEINVQB
+ operandOrder: 2I # 2nd operand, then immediate
+ in: *AffineArgs
+ out:
+ - *uint8
+
+- go: GaloisFieldMul
+ asm: VGF2P8MULB
+ in:
+ - *uint8
+ - *uint8
+ out:
+ - *uint8
--- /dev/null
+!sum
+- go: Average
+ commutative: true
+ documentation: !string |-
+ // NAME computes the rounded average of corresponding elements.
+- go: Abs
+ commutative: false
+ # Unary operation, not commutative
+ documentation: !string |-
+ // NAME computes the absolute value of each element.
+- go: CopySign
+ # Applies sign of second operand to first: sign(val, sign_src)
+ commutative: false
+ documentation: !string |-
+ // NAME returns the product of the first operand with -1, 0, or 1,
+ // whichever constant is nearest to the value of the second operand.
+ # Sign does not have masked version
+- go: OnesCount
+ commutative: false
+ documentation: !string |-
+ // NAME counts the number of set bits in each element.
--- /dev/null
+!sum
+# Average (unsigned byte, unsigned word)
+# Instructions: VPAVGB, VPAVGW
+- go: Average
+ asm: "VPAVG[BW]" # Matches VPAVGB (byte) and VPAVGW (word)
+ in:
+ - &uint_t # $t will be Uint8xN for VPAVGB, Uint16xN for VPAVGW
+ go: $t
+ base: uint
+ - *uint_t
+ out:
+ - *uint_t
+
+# Absolute Value (signed byte, word, dword, qword)
+# Instructions: VPABSB, VPABSW, VPABSD, VPABSQ
+- go: Abs
+ asm: "VPABS[BWDQ]" # Matches VPABSB, VPABSW, VPABSD, VPABSQ
+ in:
+ - &int_t # $t will be Int8xN, Int16xN, Int32xN, Int64xN
+ go: $t
+ base: int
+ out:
+ - *int_t # Output is magnitude, fits in the same signed type
+
+# Sign Operation (signed byte, word, dword)
+# Applies sign of second operand to the first.
+# Instructions: VPSIGNB, VPSIGNW, VPSIGND
+- go: CopySign
+ asm: "VPSIGN[BWD]" # Matches VPSIGNB, VPSIGNW, VPSIGND
+ in:
+ - *int_t # value to apply sign to
+ - *int_t # value from which to take the sign
+ out:
+ - *int_t
+
+# Population Count (count set bits in each element)
+# Instructions: VPOPCNTB, VPOPCNTW (AVX512_BITALG)
+# VPOPCNTD, VPOPCNTQ (AVX512_VPOPCNTDQ)
+- go: OnesCount
+ asm: "VPOPCNT[BWDQ]"
+ in:
+ - &any
+ go: $t
+ out:
+ - *any
--- /dev/null
+!sum
+- go: DotProdPairs
+ commutative: false
+ documentation: !string |-
+ // NAME multiplies the elements and adds the pairs together,
+ // yielding a vector of half as many elements with twice the input element size.
+# TODO: maybe simplify this name within the receiver-type + method-naming scheme we use.
+- go: DotProdPairsSaturated
+ commutative: false
+ documentation: !string |-
+ // NAME multiplies the elements and adds the pairs together with saturation,
+ // yielding a vector of half as many elements with twice the input element size.
+# QuadDotProd, i.e. VPDPBUSD(S), are operations with src/dst on the same register; we are not supporting this as of now.
+# - go: DotProdBroadcast
+# commutative: true
+# # documentation: !string |-
+# // NAME multiplies all elements and broadcasts the sum.
+- go: AddDotProdQuadruple
+ commutative: false
+ documentation: !string |-
+ // NAME performs dot products on groups of 4 elements of x and y and then adds z.
+- go: AddDotProdQuadrupleSaturated
+ commutative: false
+ documentation: !string |-
+ // NAME performs dot products on groups of 4 elements of x and y and then adds z, with saturation.
+- go: AddDotProdPairs
+ commutative: false
+ noTypes: "true"
+ noGenericOps: "true"
+ documentation: !string |-
+ // NAME performs dot products on pairs of elements of y and z and then adds x.
+- go: AddDotProdPairsSaturated
+ commutative: false
+ documentation: !string |-
+ // NAME performs dot products on pairs of elements of y and z and then adds x, with saturation.
+- go: MulAdd
+ commutative: false
+ documentation: !string |-
+ // NAME performs a fused (x * y) + z.
+- go: MulAddSub
+ commutative: false
+ documentation: !string |-
+ // NAME performs a fused (x * y) - z for odd-indexed elements, and (x * y) + z for even-indexed elements.
+- go: MulSubAdd
+ commutative: false
+ documentation: !string |-
+ // NAME performs a fused (x * y) + z for odd-indexed elements, and (x * y) - z for even-indexed elements.
--- /dev/null
+!sum
+- go: DotProdPairs
+ asm: VPMADDWD
+ in:
+ - &int
+ go: $t
+ base: int
+ - *int
+ out:
+ - &int2 # The elemBits are different
+ go: $t2
+ base: int
+- go: DotProdPairsSaturated
+ asm: VPMADDUBSW
+ in:
+ - &uint
+ go: $t
+ base: uint
+ overwriteElementBits: 8
+ - &int3
+ go: $t3
+ base: int
+ overwriteElementBits: 8
+ out:
+ - *int2
+# - go: DotProdBroadcast
+# asm: VDPP[SD]
+# in:
+# - &dpb_src
+# go: $t
+# - *dpb_src
+# - class: immediate
+# const: 127
+# out:
+# - *dpb_src
+- go: AddDotProdQuadruple
+ asm: "VPDPBUSD"
+ operandOrder: "31" # switch operand 3 and 1
+ in:
+ - &qdpa_acc
+ go: $t_acc
+ base: int
+ elemBits: 32
+ - &qdpa_src1
+ go: $t_src1
+ base: uint
+ overwriteElementBits: 8
+ - &qdpa_src2
+ go: $t_src2
+ base: int
+ overwriteElementBits: 8
+ out:
+ - *qdpa_acc
+- go: AddDotProdQuadrupleSaturated
+ asm: "VPDPBUSDS"
+ operandOrder: "31" # switch operand 3 and 1
+ in:
+ - *qdpa_acc
+ - *qdpa_src1
+ - *qdpa_src2
+ out:
+ - *qdpa_acc
+- go: AddDotProdPairs
+ asm: "VPDPWSSD"
+ in:
+ - &pdpa_acc
+ go: $t_acc
+ base: int
+ elemBits: 32
+ - &pdpa_src1
+ go: $t_src1
+ base: int
+ overwriteElementBits: 16
+ - &pdpa_src2
+ go: $t_src2
+ base: int
+ overwriteElementBits: 16
+ out:
+ - *pdpa_acc
+- go: AddDotProdPairsSaturated
+ asm: "VPDPWSSDS"
+ in:
+ - *pdpa_acc
+ - *pdpa_src1
+ - *pdpa_src2
+ out:
+ - *pdpa_acc
+- go: MulAdd
+ asm: "VFMADD213PS|VFMADD213PD"
+ in:
+ - &fma_op
+ go: $t
+ base: float
+ - *fma_op
+ - *fma_op
+ out:
+ - *fma_op
+- go: MulAddSub
+ asm: "VFMADDSUB213PS|VFMADDSUB213PD"
+ in:
+ - *fma_op
+ - *fma_op
+ - *fma_op
+ out:
+ - *fma_op
+- go: MulSubAdd
+ asm: "VFMSUBADD213PS|VFMSUBADD213PD"
+ in:
+ - *fma_op
+ - *fma_op
+ - *fma_op
+ out:
+ - *fma_op
\ No newline at end of file
--- /dev/null
+!sum
+- go: Max
+ commutative: true
+ documentation: !string |-
+ // NAME computes the maximum of corresponding elements.
+- go: Min
+ commutative: true
+ documentation: !string |-
+ // NAME computes the minimum of corresponding elements.
--- /dev/null
+!sum
+- go: Max
+ asm: "V?PMAXS[BWDQ]"
+ in: &2int
+ - &int
+ go: $t
+ base: int
+ - *int
+ out: &1int
+ - *int
+- go: Max
+ asm: "V?PMAXU[BWDQ]"
+ in: &2uint
+ - &uint
+ go: $t
+ base: uint
+ - *uint
+ out: &1uint
+ - *uint
+
+- go: Min
+ asm: "V?PMINS[BWDQ]"
+ in: *2int
+ out: *1int
+- go: Min
+ asm: "V?PMINU[BWDQ]"
+ in: *2uint
+ out: *1uint
+
+- go: Max
+ asm: "V?MAXP[SD]"
+ in: &2float
+ - &float
+ go: $t
+ base: float
+ - *float
+ out: &1float
+ - *float
+- go: Min
+ asm: "V?MINP[SD]"
+ in: *2float
+ out: *1float
--- /dev/null
+!sum
+- go: SetElem
+ commutative: false
+ documentation: !string |-
+ // NAME sets a single constant-indexed element's value.
+- go: GetElem
+ commutative: false
+ documentation: !string |-
+ // NAME retrieves a single constant-indexed element's value.
+- go: SetLo
+ commutative: false
+ constImm: 0
+ documentation: !string |-
+ // NAME returns x with its lower half set to y.
+- go: GetLo
+ commutative: false
+ constImm: 0
+ documentation: !string |-
+ // NAME returns the lower half of x.
+- go: SetHi
+ commutative: false
+ constImm: 1
+ documentation: !string |-
+ // NAME returns x with its upper half set to y.
+- go: GetHi
+ commutative: false
+ constImm: 1
+ documentation: !string |-
+ // NAME returns the upper half of x.
+- go: Permute
+ commutative: false
+ documentation: !string |-
+ // NAME performs a full permutation of vector x using indices:
+ // result := {x[indices[0]], x[indices[1]], ..., x[indices[n]]}
+ // Only the bits needed to represent x's index are used in indices' elements.
+- go: Permute2 # Permute2 is only available on or after AVX512
+ commutative: false
+ documentation: !string |-
+ // NAME performs a full permutation of vectors x and y using indices:
+ // result := {xy[indices[0]], xy[indices[1]], ..., xy[indices[n]]}
+ // where xy is the concatenation of x and y.
+ // Only the bits needed to represent xy's index are used in indices' elements.
+- go: Compress
+ commutative: false
+ documentation: !string |-
+ // NAME performs a compression on vector x using mask,
+ // selecting elements as indicated by mask and packing them into the lower-indexed elements.
+- go: blend
+ commutative: false
+ documentation: !string |-
+ // NAME blends two vectors based on mask values, choosing either
+ // the first or the second depending on whether the third is false or true.
+- go: Expand
+ commutative: false
+ documentation: !string |-
+ // NAME performs an expansion on a vector x whose elements are packed into the lower part.
+ // The expansion distributes elements as indexed by mask, from lower mask elements to upper, in order.
+- go: Broadcast128
+ commutative: false
+ documentation: !string |-
+ // NAME copies element zero of its (128-bit) input to all elements of
+ // the 128-bit output vector.
+- go: Broadcast256
+ commutative: false
+ documentation: !string |-
+ // NAME copies element zero of its (128-bit) input to all elements of
+ // the 256-bit output vector.
+- go: Broadcast512
+ commutative: false
+ documentation: !string |-
+ // NAME copies element zero of its (128-bit) input to all elements of
+ // the 512-bit output vector.
--- /dev/null
+!sum
+- go: SetElem
+ asm: "VPINSR[BWDQ]"
+ in:
+ - &t
+ class: vreg
+ base: $b
+ - class: greg
+ base: $b
+ lanes: 1 # Scalar, darn it!
+ - &imm
+ class: immediate
+ immOffset: 0
+ name: index
+ out:
+ - *t
+
+- go: SetElem
+ asm: "VPINSR[DQ]"
+ in:
+ - &t
+ class: vreg
+ base: int
+ OverwriteBase: float
+ - class: greg
+ base: int
+ OverwriteBase: float
+ lanes: 1 # Scalar, darn it!
+ - &imm
+ class: immediate
+ immOffset: 0
+ name: index
+ out:
+ - *t
+
+- go: GetElem
+ asm: "VPEXTR[BWDQ]"
+ in:
+ - class: vreg
+ base: $b
+ elemBits: $e
+ - *imm
+ out:
+ - class: greg
+ base: $b
+ bits: $e
+
+- go: "SetHi|SetLo"
+ asm: "VINSERTI128|VINSERTI64X4"
+ inVariant: []
+ in:
+ - &i8x2N
+ class: vreg
+ base: $t
+ OverwriteElementBits: 8
+ - &i8xN
+ class: vreg
+ base: $t
+ OverwriteElementBits: 8
+ - &imm01 # This immediate should be only 0 or 1
+ class: immediate
+ const: 0 # placeholder
+ name: index
+ out:
+ - *i8x2N
+
+- go: "GetHi|GetLo"
+ asm: "VEXTRACTI128|VEXTRACTI64X4"
+ inVariant: []
+ in:
+ - *i8x2N
+ - *imm01
+ out:
+ - *i8xN
+
+- go: "SetHi|SetLo"
+ asm: "VINSERTI128|VINSERTI64X4"
+ inVariant: []
+ in:
+ - &i16x2N
+ class: vreg
+ base: $t
+ OverwriteElementBits: 16
+ - &i16xN
+ class: vreg
+ base: $t
+ OverwriteElementBits: 16
+ - *imm01
+ out:
+ - *i16x2N
+
+- go: "GetHi|GetLo"
+ asm: "VEXTRACTI128|VEXTRACTI64X4"
+ inVariant: []
+ in:
+ - *i16x2N
+ - *imm01
+ out:
+ - *i16xN
+
+- go: "SetHi|SetLo"
+ asm: "VINSERTI128|VINSERTI64X4"
+ inVariant: []
+ in:
+ - &i32x2N
+ class: vreg
+ base: $t
+ OverwriteElementBits: 32
+ - &i32xN
+ class: vreg
+ base: $t
+ OverwriteElementBits: 32
+ - *imm01
+ out:
+ - *i32x2N
+
+- go: "GetHi|GetLo"
+ asm: "VEXTRACTI128|VEXTRACTI64X4"
+ inVariant: []
+ in:
+ - *i32x2N
+ - *imm01
+ out:
+ - *i32xN
+
+- go: "SetHi|SetLo"
+ asm: "VINSERTI128|VINSERTI64X4"
+ inVariant: []
+ in:
+ - &i64x2N
+ class: vreg
+ base: $t
+ OverwriteElementBits: 64
+ - &i64xN
+ class: vreg
+ base: $t
+ OverwriteElementBits: 64
+ - *imm01
+ out:
+ - *i64x2N
+
+- go: "GetHi|GetLo"
+ asm: "VEXTRACTI128|VEXTRACTI64X4"
+ inVariant: []
+ in:
+ - *i64x2N
+ - *imm01
+ out:
+ - *i64xN
+
+- go: "SetHi|SetLo"
+ asm: "VINSERTF128|VINSERTF64X4"
+ inVariant: []
+ in:
+ - &f32x2N
+ class: vreg
+ base: $t
+ OverwriteElementBits: 32
+ - &f32xN
+ class: vreg
+ base: $t
+ OverwriteElementBits: 32
+ - *imm01
+ out:
+ - *f32x2N
+
+- go: "GetHi|GetLo"
+ asm: "VEXTRACTF128|VEXTRACTF64X4"
+ inVariant: []
+ in:
+ - *f32x2N
+ - *imm01
+ out:
+ - *f32xN
+
+- go: "SetHi|SetLo"
+ asm: "VINSERTF128|VINSERTF64X4"
+ inVariant: []
+ in:
+ - &f64x2N
+ class: vreg
+ base: $t
+ OverwriteElementBits: 64
+ - &f64xN
+ class: vreg
+ base: $t
+ OverwriteElementBits: 64
+ - *imm01
+ out:
+ - *f64x2N
+
+- go: "GetHi|GetLo"
+ asm: "VEXTRACTF128|VEXTRACTF64X4"
+ inVariant: []
+ in:
+ - *f64x2N
+ - *imm01
+ out:
+ - *f64xN
+
+- go: Permute
+ asm: "VPERM[BWDQ]|VPERMP[SD]"
+ operandOrder: "21Type1"
+ in:
+ - &anyindices
+ go: $t
+ name: indices
+ overwriteBase: uint
+ - &any
+ go: $t
+ out:
+ - *any
+
+- go: Permute2
+ asm: "VPERMI2[BWDQ]|VPERMI2P[SD]"
+ # Because we are overwriting the receiver's type, we
+ # have to move the receiver to a parameter position so
+ # that we avoid duplicate signatures.
+ operandOrder: "231Type1"
+ in:
+ - *anyindices # result in arg 0
+ - *any
+ - *any
+ out:
+ - *any
+
+- go: Compress
+ asm: "VPCOMPRESS[BWDQ]|VCOMPRESSP[SD]"
+ in:
+ # The mask in Compress is a control mask rather than a write mask, so it's not optional.
+ - class: mask
+ - *any
+ out:
+ - *any
+
+# For now a non-public method, because
+# (1) [OverwriteClass] must be set together with [OverwriteBase], and
+# (2) "simdgen does not support [OverwriteClass] in inputs",
+# which means the generated signature would be wrong.
+- go: blend
+ asm: VPBLENDVB
+ in:
+ - &v
+ go: $t
+ class: vreg
+ base: int
+ - *v
+ -
+ class: vreg
+ base: int
+ name: mask
+ out:
+ - *v
+
+# For AVX512
+- go: blend
+ asm: VPBLENDM[BWDQ]
+ in:
+ - &v
+ go: $t
+ bits: 512
+ class: vreg
+ base: int
+ - *v
+ inVariant:
+ -
+ class: mask
+ out:
+ - *v
+
+- go: Expand
+ asm: "VPEXPAND[BWDQ]|VEXPANDP[SD]"
+ in:
+ # The mask in Expand is a control mask rather than a write mask, so it's not optional.
+ - class: mask
+ - *any
+ out:
+ - *any
+
+- go: Broadcast128
+ asm: VPBROADCAST[BWDQ]
+ in:
+ - class: vreg
+ bits: 128
+ elemBits: $e
+ base: $b
+ out:
+ - class: vreg
+ bits: 128
+ elemBits: $e
+ base: $b
+
+# weirdly, this one case on AVX2 is memory-operand-only
+- go: Broadcast128
+ asm: VPBROADCASTQ
+ in:
+ - class: vreg
+ bits: 128
+ elemBits: 64
+ base: int
+ OverwriteBase: float
+ out:
+ - class: vreg
+ bits: 128
+ elemBits: 64
+ base: int
+ OverwriteBase: float
+
+- go: Broadcast256
+ asm: VPBROADCAST[BWDQ]
+ in:
+ - class: vreg
+ bits: 128
+ elemBits: $e
+ base: $b
+ out:
+ - class: vreg
+ bits: 256
+ elemBits: $e
+ base: $b
+
+- go: Broadcast512
+ asm: VPBROADCAST[BWDQ]
+ in:
+ - class: vreg
+ bits: 128
+ elemBits: $e
+ base: $b
+ out:
+ - class: vreg
+ bits: 512
+ elemBits: $e
+ base: $b
+
+- go: Broadcast128
+ asm: VBROADCASTS[SD]
+ in:
+ - class: vreg
+ bits: 128
+ elemBits: $e
+ base: $b
+ out:
+ - class: vreg
+ bits: 128
+ elemBits: $e
+ base: $b
+
+- go: Broadcast256
+ asm: VBROADCASTS[SD]
+ in:
+ - class: vreg
+ bits: 128
+ elemBits: $e
+ base: $b
+ out:
+ - class: vreg
+ bits: 256
+ elemBits: $e
+ base: $b
+
+- go: Broadcast512
+ asm: VBROADCASTS[SD]
+ in:
+ - class: vreg
+ bits: 128
+ elemBits: $e
+ base: $b
+ out:
+ - class: vreg
+ bits: 512
+ elemBits: $e
+ base: $b
--- /dev/null
+!sum
+- go: Mul
+ commutative: true
+ documentation: !string |-
+ // NAME multiplies corresponding elements of two vectors.
+- go: MulEvenWiden
+ commutative: true
+ documentation: !string |-
+ // NAME multiplies even-indexed elements, widening the result.
+ // Result[i] = v1.Even[i] * v2.Even[i].
+- go: MulHigh
+ commutative: true
+ documentation: !string |-
+ // NAME multiplies elements and stores the high part of the result.
--- /dev/null
+!sum
+# "Normal" multiplication is only available for floats.
+# This covers only single and double precision.
+- go: Mul
+ asm: "VMULP[SD]"
+ in:
+ - &fp
+ go: $t
+ base: float
+ - *fp
+ out:
+ - *fp
+
+# Integer multiplications.
+
+# MulEvenWiden
+# Dword only.
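+# For example, Int32x4.MulEvenWiden is expected to multiply elements 0 and 2
+# of each input, yielding an Int64x2 of the two 64-bit products.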
+- go: MulEvenWiden
+ asm: "VPMULDQ"
+ in:
+ - &intNot64
+ go: $t
+ elemBits: 8|16|32
+ base: int
+ - *intNot64
+ out:
+ - &int2
+ go: $t2
+ base: int
+- go: MulEvenWiden
+ asm: "VPMULUDQ"
+ in:
+ - &uintNot64
+ go: $t
+ elemBits: 8|16|32
+ base: uint
+ - *uintNot64
+ out:
+ - &uint2
+ go: $t2
+ base: uint
+
+# MulHigh
+# Word only.
+- go: MulHigh
+ asm: "VPMULHW"
+ in:
+ - &int
+ go: $t
+ base: int
+ - *int
+ out:
+ - *int
+- go: MulHigh
+ asm: "VPMULHUW"
+ in:
+ - &uint
+ go: $t
+ base: uint
+ - *uint
+ out:
+ - *uint
+
+# MulLow
+# Signed and unsigned multiplication produce the same low-order bits.
+- go: Mul
+ asm: "VPMULL[WDQ]"
+ in:
+ - &any
+ go: $t
+ - *any
+ out:
+ - *any
--- /dev/null
+!sum
+- go: ShiftAllLeft
+ nameAndSizeCheck: true
+ specialLower: sftimm
+ commutative: false
+ documentation: !string |-
+ // NAME shifts each element to the left by the specified number of bits. Emptied lower bits are zeroed.
+- go: ShiftAllRight
+ signed: false
+ nameAndSizeCheck: true
+ specialLower: sftimm
+ commutative: false
+ documentation: !string |-
+ // NAME shifts each element to the right by the specified number of bits. Emptied upper bits are zeroed.
+- go: ShiftAllRight
+ signed: true
+ specialLower: sftimm
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME shifts each element to the right by the specified number of bits. Emptied upper bits are filled with the sign bit.
+- go: shiftAllLeftConst # no APIs, only ssa ops.
+ noTypes: "true"
+ noGenericOps: "true"
+ SSAVariant: "const" # added to its SSA op name to avoid colliding with the register version of this instruction.
+ nameAndSizeCheck: true
+ commutative: false
+- go: shiftAllRightConst # no APIs, only ssa ops.
+ noTypes: "true"
+ noGenericOps: "true"
+ SSAVariant: "const"
+ signed: false
+ nameAndSizeCheck: true
+ commutative: false
+- go: shiftAllRightConst # no APIs, only ssa ops.
+ noTypes: "true"
+ noGenericOps: "true"
+ SSAVariant: "const"
+ signed: true
+ nameAndSizeCheck: true
+ commutative: false
+
+- go: ShiftLeft
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME shifts each element in x to the left by the number of bits specified in y's corresponding elements. Emptied lower bits are zeroed.
+- go: ShiftRight
+ signed: false
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME shifts each element in x to the right by the number of bits specified in y's corresponding elements. Emptied upper bits are zeroed.
+- go: ShiftRight
+ signed: true
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME shifts each element in x to the right by the number of bits specified in y's corresponding elements. Emptied upper bits are filled with the sign bit.
+- go: RotateAllLeft
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME rotates each element to the left by the number of bits specified by the immediate.
+- go: RotateLeft
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME rotates each element in x to the left by the number of bits specified by y's corresponding elements.
+- go: RotateAllRight
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME rotates each element to the right by the number of bits specified by the immediate.
+- go: RotateRight
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME rotates each element in x to the right by the number of bits specified by y's corresponding elements.
+- go: ShiftAllLeftConcat
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME shifts each element of x to the left by the number of bits specified by the
+ // immediate (only the lower 5 bits are used), and then copies the upper bits of y to the emptied lower bits of the shifted x.
+- go: ShiftAllRightConcat
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME shifts each element of x to the right by the number of bits specified by the
+ // immediate (only the lower 5 bits are used), and then copies the lower bits of y to the emptied upper bits of the shifted x.
+- go: ShiftLeftConcat
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME shifts each element of x to the left by the number of bits specified by the
+ // corresponding elements in y (only the lower 5 bits are used), and then copies the upper bits of z to the emptied lower bits of the shifted x.
+- go: ShiftRightConcat
+ nameAndSizeCheck: true
+ commutative: false
+ documentation: !string |-
+ // NAME shifts each element of x to the right by the number of bits specified by the
+ // corresponding elements in y (only the lower 5 bits are used), and then copies the lower bits of z to the emptied upper bits of the shifted x.
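+# A sketch of the intended concat-shift semantics, assuming 32-bit elements
+# and a shift count of 8: ShiftAllLeftConcat(x, y, 8)[i] == x[i]<<8 | y[i]>>24.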
--- /dev/null
+!sum
+# Integers
+# ShiftAll*
+- go: ShiftAllLeft
+ asm: "VPSLL[WDQ]"
+ in:
+ - &any
+ go: $t
+ - &vecAsScalar64
+ go: "Uint.*"
+ treatLikeAScalarOfSize: 64
+ out:
+ - *any
+- go: ShiftAllRight
+ signed: false
+ asm: "VPSRL[WDQ]"
+ in:
+ - &uint
+ go: $t
+ base: uint
+ - *vecAsScalar64
+ out:
+ - *uint
+- go: ShiftAllRight
+ signed: true
+ asm: "VPSRA[WDQ]"
+ in:
+ - &int
+ go: $t
+ base: int
+ - *vecAsScalar64
+ out:
+ - *int
+
+- go: shiftAllLeftConst
+ asm: "VPSLL[WDQ]"
+ in:
+ - *any
+ - &imm
+ class: immediate
+ immOffset: 0
+ out:
+ - *any
+- go: shiftAllRightConst
+ asm: "VPSRL[WDQ]"
+ in:
+ - *uint
+ - *imm
+ out:
+ - *uint
+- go: shiftAllRightConst
+ asm: "VPSRA[WDQ]"
+ in:
+ - *int
+ - *imm
+ out:
+ - *int
+
+# Shift* (variable)
+- go: ShiftLeft
+ asm: "VPSLLV[WD]"
+ in:
+ - *any
+ - *any
+ out:
+ - *any
+# The XED data for VPSLLVQ marks the element bits as 32, which does not match
+# the actual semantics; we need to overwrite it to 64.
+- go: ShiftLeft
+ asm: "VPSLLVQ"
+ in:
+ - &anyOverwriteElemBits
+ go: $t
+ overwriteElementBits: 64
+ - *anyOverwriteElemBits
+ out:
+ - *anyOverwriteElemBits
+- go: ShiftRight
+ signed: false
+ asm: "VPSRLV[WD]"
+ in:
+ - *uint
+ - *uint
+ out:
+ - *uint
+# The XED data for VPSRLVQ needs the same overwrite as VPSLLVQ.
+- go: ShiftRight
+ signed: false
+ asm: "VPSRLVQ"
+ in:
+ - &uintOverwriteElemBits
+ go: $t
+ base: uint
+ overwriteElementBits: 64
+ - *uintOverwriteElemBits
+ out:
+ - *uintOverwriteElemBits
+- go: ShiftRight
+ signed: true
+ asm: "VPSRAV[WDQ]"
+ in:
+ - *int
+ - *int
+ out:
+ - *int
+
+# Rotate
+- go: RotateAllLeft
+ asm: "VPROL[DQ]"
+ in:
+ - *any
+ - &pureImm
+ class: immediate
+ immOffset: 0
+ name: shift
+ out:
+ - *any
+- go: RotateAllRight
+ asm: "VPROR[DQ]"
+ in:
+ - *any
+ - *pureImm
+ out:
+ - *any
+- go: RotateLeft
+ asm: "VPROLV[DQ]"
+ in:
+ - *any
+ - *any
+ out:
+ - *any
+- go: RotateRight
+ asm: "VPRORV[DQ]"
+ in:
+ - *any
+ - *any
+ out:
+ - *any
+
+# Bizarre shifts.
+- go: ShiftAllLeftConcat
+ asm: "VPSHLD[WDQ]"
+ in:
+ - *any
+ - *any
+ - *pureImm
+ out:
+ - *any
+- go: ShiftAllRightConcat
+ asm: "VPSHRD[WDQ]"
+ in:
+ - *any
+ - *any
+ - *pureImm
+ out:
+ - *any
+- go: ShiftLeftConcat
+ asm: "VPSHLDV[WDQ]"
+ in:
+ - *any
+ - *any
+ - *any
+ out:
+ - *any
+- go: ShiftRightConcat
+ asm: "VPSHRDV[WDQ]"
+ in:
+ - *any
+ - *any
+ - *any
+ out:
+ - *any
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "fmt"
+ "reflect"
+ "strconv"
+)
+
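+// pprints pretty-prints v as an indented, multi-line string using
+// reflection. It is intended for debugging output.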
+func pprints(v any) string {
+ var pp pprinter
+ pp.val(reflect.ValueOf(v), 0)
+ return string(pp.buf)
+}
+
+type pprinter struct {
+ buf []byte
+}
+
+func (p *pprinter) indent(by int) {
+ for range by {
+ p.buf = append(p.buf, '\t')
+ }
+}
+
+func (p *pprinter) val(v reflect.Value, indent int) {
+ switch v.Kind() {
+ default:
+ p.buf = fmt.Appendf(p.buf, "unsupported kind %v", v.Kind())
+
+ case reflect.Bool:
+ p.buf = strconv.AppendBool(p.buf, v.Bool())
+
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ p.buf = strconv.AppendInt(p.buf, v.Int(), 10)
+
+ case reflect.String:
+ p.buf = strconv.AppendQuote(p.buf, v.String())
+
+ case reflect.Pointer:
+ if v.IsNil() {
+ p.buf = append(p.buf, "nil"...)
+ } else {
+ p.buf = append(p.buf, "&"...)
+ p.val(v.Elem(), indent)
+ }
+
+ case reflect.Slice, reflect.Array:
+ p.buf = append(p.buf, "[\n"...)
+ for i := range v.Len() {
+ p.indent(indent + 1)
+ p.val(v.Index(i), indent+1)
+ p.buf = append(p.buf, ",\n"...)
+ }
+ p.indent(indent)
+ p.buf = append(p.buf, ']')
+
+ case reflect.Struct:
+ vt := v.Type()
+ p.buf = append(append(p.buf, vt.String()...), "{\n"...)
+ for f := range v.NumField() {
+ p.indent(indent + 1)
+ p.buf = append(append(p.buf, vt.Field(f).Name...), ": "...)
+ p.val(v.Field(f), indent+1)
+ p.buf = append(p.buf, ",\n"...)
+ }
+ p.indent(indent)
+ p.buf = append(p.buf, '}')
+ }
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "testing"
+
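+// TestSort exercises compareNatural (defined elsewhere in this package),
+// which is assumed to order strings "naturally": runs of digits compare by
+// numeric value, so "a2" sorts before "a10".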
+func TestSort(t *testing.T) {
+ testCases := []struct {
+ s1, s2 string
+ want int
+ }{
+ {"a1", "a2", -1},
+ {"a11a", "a11b", -1},
+ {"a01a1", "a1a01", -1},
+ {"a2", "a1", 1},
+ {"a10", "a2", 1},
+ {"a1", "a10", -1},
+ {"z11", "z2", 1},
+ {"z2", "z11", -1},
+ {"abc", "abd", -1},
+ {"123", "45", 1},
+ {"file1", "file1", 0},
+ {"file", "file1", -1},
+ {"file1", "file", 1},
+ {"a01", "a1", -1},
+ {"a1a", "a1b", -1},
+ }
+
+ for _, tc := range testCases {
+ got := compareNatural(tc.s1, tc.s2)
+ result := "✅"
+ if got != tc.want {
+ result = "❌"
+ t.Errorf("%s compareNatural(\"%s\", \"%s\") -> got %2d, want %2d\n", result, tc.s1, tc.s2, got, tc.want)
+ } else {
+ t.Logf("%s compareNatural(\"%s\", \"%s\") -> got %2d, want %2d\n", result, tc.s1, tc.s2, got, tc.want)
+ }
+ }
+}
--- /dev/null
+# This file defines the possible types of each operand and result.
+#
+# In general, we're able to narrow this down on some attributes directly from
+# the machine instruction descriptions, but the Go mappings need to further
+# constrain them and how they relate. For example, on x86 we can't distinguish
+# int and uint, though we can distinguish these from float.
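+#
+# For example, a packed integer add decodes with base "int|uint"; unifying
+# it against the shapes below yields both the signed and unsigned Go types.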
+
+in: !repeat
+- !sum &types
+ - {class: vreg, go: Int8x16, base: "int", elemBits: 8, bits: 128, lanes: 16}
+ - {class: vreg, go: Uint8x16, base: "uint", elemBits: 8, bits: 128, lanes: 16}
+ - {class: vreg, go: Int16x8, base: "int", elemBits: 16, bits: 128, lanes: 8}
+ - {class: vreg, go: Uint16x8, base: "uint", elemBits: 16, bits: 128, lanes: 8}
+ - {class: vreg, go: Int32x4, base: "int", elemBits: 32, bits: 128, lanes: 4}
+ - {class: vreg, go: Uint32x4, base: "uint", elemBits: 32, bits: 128, lanes: 4}
+ - {class: vreg, go: Int64x2, base: "int", elemBits: 64, bits: 128, lanes: 2}
+ - {class: vreg, go: Uint64x2, base: "uint", elemBits: 64, bits: 128, lanes: 2}
+ - {class: vreg, go: Float32x4, base: "float", elemBits: 32, bits: 128, lanes: 4}
+ - {class: vreg, go: Float64x2, base: "float", elemBits: 64, bits: 128, lanes: 2}
+ - {class: vreg, go: Int8x32, base: "int", elemBits: 8, bits: 256, lanes: 32}
+ - {class: vreg, go: Uint8x32, base: "uint", elemBits: 8, bits: 256, lanes: 32}
+ - {class: vreg, go: Int16x16, base: "int", elemBits: 16, bits: 256, lanes: 16}
+ - {class: vreg, go: Uint16x16, base: "uint", elemBits: 16, bits: 256, lanes: 16}
+ - {class: vreg, go: Int32x8, base: "int", elemBits: 32, bits: 256, lanes: 8}
+ - {class: vreg, go: Uint32x8, base: "uint", elemBits: 32, bits: 256, lanes: 8}
+ - {class: vreg, go: Int64x4, base: "int", elemBits: 64, bits: 256, lanes: 4}
+ - {class: vreg, go: Uint64x4, base: "uint", elemBits: 64, bits: 256, lanes: 4}
+ - {class: vreg, go: Float32x8, base: "float", elemBits: 32, bits: 256, lanes: 8}
+ - {class: vreg, go: Float64x4, base: "float", elemBits: 64, bits: 256, lanes: 4}
+ - {class: vreg, go: Int8x64, base: "int", elemBits: 8, bits: 512, lanes: 64}
+ - {class: vreg, go: Uint8x64, base: "uint", elemBits: 8, bits: 512, lanes: 64}
+ - {class: vreg, go: Int16x32, base: "int", elemBits: 16, bits: 512, lanes: 32}
+ - {class: vreg, go: Uint16x32, base: "uint", elemBits: 16, bits: 512, lanes: 32}
+ - {class: vreg, go: Int32x16, base: "int", elemBits: 32, bits: 512, lanes: 16}
+ - {class: vreg, go: Uint32x16, base: "uint", elemBits: 32, bits: 512, lanes: 16}
+ - {class: vreg, go: Int64x8, base: "int", elemBits: 64, bits: 512, lanes: 8}
+ - {class: vreg, go: Uint64x8, base: "uint", elemBits: 64, bits: 512, lanes: 8}
+ - {class: vreg, go: Float32x16, base: "float", elemBits: 32, bits: 512, lanes: 16}
+ - {class: vreg, go: Float64x8, base: "float", elemBits: 64, bits: 512, lanes: 8}
+
+ - {class: mask, go: Mask8x16, base: "int", elemBits: 8, bits: 128, lanes: 16}
+ - {class: mask, go: Mask16x8, base: "int", elemBits: 16, bits: 128, lanes: 8}
+ - {class: mask, go: Mask32x4, base: "int", elemBits: 32, bits: 128, lanes: 4}
+ - {class: mask, go: Mask64x2, base: "int", elemBits: 64, bits: 128, lanes: 2}
+ - {class: mask, go: Mask8x32, base: "int", elemBits: 8, bits: 256, lanes: 32}
+ - {class: mask, go: Mask16x16, base: "int", elemBits: 16, bits: 256, lanes: 16}
+ - {class: mask, go: Mask32x8, base: "int", elemBits: 32, bits: 256, lanes: 8}
+ - {class: mask, go: Mask64x4, base: "int", elemBits: 64, bits: 256, lanes: 4}
+ - {class: mask, go: Mask8x64, base: "int", elemBits: 8, bits: 512, lanes: 64}
+ - {class: mask, go: Mask16x32, base: "int", elemBits: 16, bits: 512, lanes: 32}
+ - {class: mask, go: Mask32x16, base: "int", elemBits: 32, bits: 512, lanes: 16}
+ - {class: mask, go: Mask64x8, base: "int", elemBits: 64, bits: 512, lanes: 8}
+
+ - {class: greg, go: float64, base: "float", bits: 64, lanes: 1}
+ - {class: greg, go: float32, base: "float", bits: 32, lanes: 1}
+ - {class: greg, go: int64, base: "int", bits: 64, lanes: 1}
+ - {class: greg, go: int32, base: "int", bits: 32, lanes: 1}
+ - {class: greg, go: int16, base: "int", bits: 16, lanes: 1}
+ - {class: greg, go: int8, base: "int", bits: 8, lanes: 1}
+ - {class: greg, go: uint64, base: "uint", bits: 64, lanes: 1}
+ - {class: greg, go: uint32, base: "uint", bits: 32, lanes: 1}
+ - {class: greg, go: uint16, base: "uint", bits: 16, lanes: 1}
+ - {class: greg, go: uint8, base: "uint", bits: 8, lanes: 1}
+
+# Special shapes just to make INSERT[IF]128 work.
+# The elemBits field of these shapes is wrong; it will be overwritten by overwriteElemBits.
+ - {class: vreg, go: Int8x16, base: "int", elemBits: 128, bits: 128, lanes: 16}
+ - {class: vreg, go: Uint8x16, base: "uint", elemBits: 128, bits: 128, lanes: 16}
+ - {class: vreg, go: Int16x8, base: "int", elemBits: 128, bits: 128, lanes: 8}
+ - {class: vreg, go: Uint16x8, base: "uint", elemBits: 128, bits: 128, lanes: 8}
+ - {class: vreg, go: Int32x4, base: "int", elemBits: 128, bits: 128, lanes: 4}
+ - {class: vreg, go: Uint32x4, base: "uint", elemBits: 128, bits: 128, lanes: 4}
+ - {class: vreg, go: Int64x2, base: "int", elemBits: 128, bits: 128, lanes: 2}
+ - {class: vreg, go: Uint64x2, base: "uint", elemBits: 128, bits: 128, lanes: 2}
+
+ - {class: vreg, go: Int8x32, base: "int", elemBits: 128, bits: 256, lanes: 32}
+ - {class: vreg, go: Uint8x32, base: "uint", elemBits: 128, bits: 256, lanes: 32}
+ - {class: vreg, go: Int16x16, base: "int", elemBits: 128, bits: 256, lanes: 16}
+ - {class: vreg, go: Uint16x16, base: "uint", elemBits: 128, bits: 256, lanes: 16}
+ - {class: vreg, go: Int32x8, base: "int", elemBits: 128, bits: 256, lanes: 8}
+ - {class: vreg, go: Uint32x8, base: "uint", elemBits: 128, bits: 256, lanes: 8}
+ - {class: vreg, go: Int64x4, base: "int", elemBits: 128, bits: 256, lanes: 4}
+ - {class: vreg, go: Uint64x4, base: "uint", elemBits: 128, bits: 256, lanes: 4}
+
+ - {class: immediate, go: Immediate} # TODO: we only support immediates that are not used as values -- currently they usually serve as instruction semantic predicates, as in VPCMP.
+inVariant: !repeat
+- *types
+out: !repeat
+- *types
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "cmp"
+ "fmt"
+ "log"
+ "maps"
+ "regexp"
+ "slices"
+ "strconv"
+ "strings"
+
+ "golang.org/x/arch/x86/xeddata"
+ "gopkg.in/yaml.v3"
+ "simd/_gen/unify"
+)
+
+const (
+ NOT_REG_CLASS = 0 // not a register
+ VREG_CLASS = 1 // classify as a vector register
+ GREG_CLASS = 2 // classify as a general register
+)
+
+// instVariant is a bitmap indicating a variant of an instruction that has
+// optional parameters.
+type instVariant uint8
+
+const (
+ instVariantNone instVariant = 0
+
+ // instVariantMasked indicates that this is the masked variant of an
+ // optionally-masked instruction.
+ instVariantMasked instVariant = 1 << iota
+)
+
+var operandRemarks int
+
+// TODO: Doc. Returns Values with Def domains.
+func loadXED(xedPath string) []*unify.Value {
+ // TODO: Obviously a bunch more to do here.
+
+ db, err := xeddata.NewDatabase(xedPath)
+ if err != nil {
+ log.Fatalf("open database: %v", err)
+ }
+
+ var defs []*unify.Value
+ err = xeddata.WalkInsts(xedPath, func(inst *xeddata.Inst) {
+ inst.Pattern = xeddata.ExpandStates(db, inst.Pattern)
+
+ switch {
+ case inst.RealOpcode == "N":
+ return // Skip unstable instructions
+ case !strings.HasPrefix(inst.Extension, "AVX"):
+ // We're only interested in AVX instructions.
+ return
+ }
+
+ if *flagDebugXED {
+ fmt.Printf("%s:\n%+v\n", inst.Pos, inst)
+ }
+
+ ops, err := decodeOperands(db, strings.Fields(inst.Operands))
+ if err != nil {
+ operandRemarks++
+ if *Verbose {
+ log.Printf("%s: [%s] %s", inst.Pos, inst.Opcode(), err)
+ }
+ return
+ }
+
+ applyQuirks(inst, ops)
+
+ defsPos := len(defs)
+ defs = append(defs, instToUVal(inst, ops)...)
+
+ if *flagDebugXED {
+ for i := defsPos; i < len(defs); i++ {
+ y, _ := yaml.Marshal(defs[i])
+ fmt.Printf("==>\n%s\n", y)
+ }
+ }
+ })
+ if err != nil {
+ log.Fatalf("walk insts: %v", err)
+ }
+
+ if len(unknownFeatures) > 0 {
+ if !*Verbose {
+ nInst := 0
+ for _, insts := range unknownFeatures {
+ nInst += len(insts)
+ }
+ log.Printf("%d unhandled CPU features for %d instructions (use -v for details)", len(unknownFeatures), nInst)
+ } else {
+ keys := slices.SortedFunc(maps.Keys(unknownFeatures), func(a, b cpuFeatureKey) int {
+ return cmp.Or(cmp.Compare(a.Extension, b.Extension),
+ cmp.Compare(a.ISASet, b.ISASet))
+ })
+ for _, key := range keys {
+ if key.ISASet == "" || key.ISASet == key.Extension {
+ log.Printf("unhandled Extension %s", key.Extension)
+ } else {
+ log.Printf("unhandled Extension %s and ISASet %s", key.Extension, key.ISASet)
+ }
+ log.Printf(" opcodes: %s", slices.Sorted(maps.Keys(unknownFeatures[key])))
+ }
+ }
+ }
+
+ return defs
+}
+
+var (
+ maskRequiredRe = regexp.MustCompile(`VPCOMPRESS[BWDQ]|VCOMPRESSP[SD]|VPEXPAND[BWDQ]|VEXPANDP[SD]`)
+ maskOptionalRe = regexp.MustCompile(`VPCMP(EQ|GT|U)?[BWDQ]|VCMPP[SD]`)
+)
+
+func applyQuirks(inst *xeddata.Inst, ops []operand) {
+ opc := inst.Opcode()
+ switch {
+ case maskRequiredRe.MatchString(opc):
+ // The mask on these instructions is marked optional, but the
+ // instruction is pointless without the mask.
+ for i, op := range ops {
+ if op, ok := op.(operandMask); ok {
+ op.optional = false
+ ops[i] = op
+ }
+ }
+
+ case maskOptionalRe.MatchString(opc):
+ // Conversely, these masks should be marked optional and aren't.
+ for i, op := range ops {
+ if op, ok := op.(operandMask); ok && op.action.r {
+ op.optional = true
+ ops[i] = op
+ }
+ }
+ }
+}
+
+type operandCommon struct {
+ action operandAction
+}
+
+// operandAction defines whether this operand is read and/or written.
+//
+// TODO: Should this live in [xeddata.Operand]?
+type operandAction struct {
+ r bool // Read
+ w bool // Written
+ cr bool // Read is conditional (implies r==true)
+ cw bool // Write is conditional (implies w==true)
+}
+
+type operandMem struct {
+ operandCommon
+ // TODO
+}
+
+type vecShape struct {
+ elemBits int // Element size in bits
+ bits int // Register width in bits (total vector bits)
+}
+
+type operandVReg struct { // Vector register
+ operandCommon
+ vecShape
+ elemBaseType scalarBaseType
+}
+
+type operandGReg struct { // General-purpose register
+ operandCommon
+ vecShape
+ elemBaseType scalarBaseType
+}
+
+// operandMask is a vector mask.
+//
+// Regardless of the actual mask representation, the [vecShape] of this operand
+// corresponds to the "bit for bit" type of mask. That is, elemBits gives the
+// element width covered by each mask element, and bits/elemBits gives the total
+// number of mask elements. (bits gives the total number of bits as if this were
+// a bit-for-bit mask, which may be meaningless on its own.)
+type operandMask struct {
+ operandCommon
+ vecShape
+ // The number of elements in the mask is bits/elemBits.
+
+ allMasks bool // If set, size cannot be inferred because all operands are masks.
+
+ // Mask can be omitted, in which case it defaults to K0/"no mask"
+ optional bool
+}
+
+type operandImm struct {
+ operandCommon
+ bits int // Immediate size in bits
+}
+
+type operand interface {
+ common() operandCommon
+ addToDef(b *unify.DefBuilder)
+}
+
+func strVal(s any) *unify.Value {
+ return unify.NewValue(unify.NewStringExact(fmt.Sprint(s)))
+}
+
+func (o operandCommon) common() operandCommon {
+ return o
+}
+
+func (o operandMem) addToDef(b *unify.DefBuilder) {
+ // TODO: w, base
+ b.Add("class", strVal("memory"))
+}
+
+func (o operandVReg) addToDef(b *unify.DefBuilder) {
+ baseDomain, err := unify.NewStringRegex(o.elemBaseType.regex())
+ if err != nil {
+ panic("parsing baseRe: " + err.Error())
+ }
+ b.Add("class", strVal("vreg"))
+ b.Add("bits", strVal(o.bits))
+ b.Add("base", unify.NewValue(baseDomain))
+ // If elemBits == bits, then the vector can be ANY shape. This happens with,
+ // for example, logical ops.
+ if o.elemBits != o.bits {
+ b.Add("elemBits", strVal(o.elemBits))
+ }
+}
+
+func (o operandGReg) addToDef(b *unify.DefBuilder) {
+ baseDomain, err := unify.NewStringRegex(o.elemBaseType.regex())
+ if err != nil {
+ panic("parsing baseRe: " + err.Error())
+ }
+ b.Add("class", strVal("greg"))
+ b.Add("bits", strVal(o.bits))
+ b.Add("base", unify.NewValue(baseDomain))
+ if o.elemBits != o.bits {
+ b.Add("elemBits", strVal(o.elemBits))
+ }
+}
+
+func (o operandMask) addToDef(b *unify.DefBuilder) {
+ b.Add("class", strVal("mask"))
+ if o.allMasks {
+ // If all operands are masks, omit sizes and let unification determine mask sizes.
+ return
+ }
+ b.Add("elemBits", strVal(o.elemBits))
+ b.Add("bits", strVal(o.bits))
+}
+
+func (o operandImm) addToDef(b *unify.DefBuilder) {
+ b.Add("class", strVal("immediate"))
+ b.Add("bits", strVal(o.bits))
+}
+
+var actionEncoding = map[string]operandAction{
+ "r": {r: true},
+ "cr": {r: true, cr: true},
+ "w": {w: true},
+ "cw": {w: true, cw: true},
+ "rw": {r: true, w: true},
+ "crw": {r: true, w: true, cr: true},
+ "rcw": {r: true, w: true, cw: true},
+}
+
+func decodeOperand(db *xeddata.Database, operand string) (operand, error) {
+ op, err := xeddata.NewOperand(db, operand)
+ if err != nil {
+ log.Fatalf("parsing operand %q: %v", operand, err)
+ }
+ if *flagDebugXED {
+ fmt.Printf(" %+v\n", op)
+ }
+
+ if strings.HasPrefix(op.Name, "EMX_BROADCAST") {
+ // This refers to a set of macros defined in all-state.txt that set a
+ // BCAST operand to various fixed values. But the BCAST operand is
+ // itself suppressed and "internal", so I think we can just ignore this
+ // operand.
+ return nil, nil
+ }
+
+ // TODO: See xed_decoded_inst_operand_action. This might need to be more
+ // complicated.
+ action, ok := actionEncoding[op.Action]
+ if !ok {
+ return nil, fmt.Errorf("unknown action %q", op.Action)
+ }
+ common := operandCommon{action: action}
+
+ lhs := op.NameLHS()
+ if strings.HasPrefix(lhs, "MEM") {
+ // TODO: Width, base type
+ return operandMem{
+ operandCommon: common,
+ }, nil
+ } else if strings.HasPrefix(lhs, "REG") {
+ if op.Width == "mskw" {
+ // The mask operand doesn't specify a width. We have to infer it.
+ //
+ // XED uses the marker ZEROSTR to indicate that a mask operand is
+ // optional and, if omitted, implies K0, aka "no mask".
+ return operandMask{
+ operandCommon: common,
+ optional: op.Attributes["TXT=ZEROSTR"],
+ }, nil
+ } else {
+ class, regBits := decodeReg(op)
+ if class == NOT_REG_CLASS {
+ return nil, fmt.Errorf("failed to decode register %q", operand)
+ }
+ baseType, elemBits, ok := decodeType(op)
+ if !ok {
+ return nil, fmt.Errorf("failed to decode register width %q", operand)
+ }
+ shape := vecShape{elemBits: elemBits, bits: regBits}
+ if class == VREG_CLASS {
+ return operandVReg{
+ operandCommon: common,
+ vecShape: shape,
+ elemBaseType: baseType,
+ }, nil
+ }
+ // general register
+ m := min(shape.bits, shape.elemBits)
+ shape.bits, shape.elemBits = m, m
+ return operandGReg{
+ operandCommon: common,
+ vecShape: shape,
+ elemBaseType: baseType,
+ }, nil
+
+ } else if strings.HasPrefix(lhs, "IMM") {
+ _, bits, ok := decodeType(op)
+ if !ok {
+ return nil, fmt.Errorf("failed to decode register width %q", operand)
+ }
+ return operandImm{
+ operandCommon: common,
+ bits: bits,
+ }, nil
+ }
+
+ // TODO: BASE and SEG
+ return nil, fmt.Errorf("unknown operand LHS %q in %q", lhs, operand)
+}
+
+func decodeOperands(db *xeddata.Database, operands []string) (ops []operand, err error) {
+ // Decode the XED operand descriptions.
+ for _, o := range operands {
+ op, err := decodeOperand(db, o)
+ if err != nil {
+ return nil, err
+ }
+ if op != nil {
+ ops = append(ops, op)
+ }
+ }
+
+ // XED doesn't encode the size of mask operands. If there are mask operands,
+ // try to infer their sizes from other operands.
+ if err := inferMaskSizes(ops); err != nil {
+ return nil, fmt.Errorf("%w in operands %+v", err, operands)
+ }
+
+ return ops, nil
+}
+
+func inferMaskSizes(ops []operand) error {
+ // This is a heuristic and it falls apart in some cases:
+ //
+ // - Mask operations like KAND[BWDQ] have *nothing* in the XED to indicate
+ // mask size.
+ //
+ // - VINSERT*, VPSLL*, VPSRA*, VPSRL*, and some others naturally have
+ // mixed input sizes and the XED doesn't indicate which operands the mask
+ // applies to.
+ //
+ // - VPDP* and VP4DP* have really complex mixed operand patterns.
+ //
+ // I think for these we may just have to hand-write a table of which
+ // operands each mask applies to.
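+ //
+ // As a concrete sketch: for an optionally-masked 512-bit VPADDD, the
+ // lone read mask would be inferred to have elemBits=32 and bits=512,
+ // taken from the shapes of the vector register operands.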
+ inferMask := func(r, w bool) error {
+ var masks []int
+ var rSizes, wSizes, sizes []vecShape
+ allMasks := true
+ hasWMask := false
+ for i, op := range ops {
+ action := op.common().action
+ if _, ok := op.(operandMask); ok {
+ if action.r && action.w {
+ return fmt.Errorf("unexpected rw mask")
+ }
+ if action.r == r || action.w == w {
+ masks = append(masks, i)
+ }
+ if action.w {
+ hasWMask = true
+ }
+ } else {
+ allMasks = false
+ if reg, ok := op.(operandVReg); ok {
+ if action.r {
+ rSizes = append(rSizes, reg.vecShape)
+ }
+ if action.w {
+ wSizes = append(wSizes, reg.vecShape)
+ }
+ }
+ }
+ }
+ if len(masks) == 0 {
+ return nil
+ }
+
+ if r {
+ sizes = rSizes
+ if len(sizes) == 0 {
+ sizes = wSizes
+ }
+ }
+ if w {
+ sizes = wSizes
+ if len(sizes) == 0 {
+ sizes = rSizes
+ }
+ }
+
+ if len(sizes) == 0 {
+ // If all operands are masks, leave the mask inference to the users.
+ if allMasks {
+ for _, i := range masks {
+ m := ops[i].(operandMask)
+ m.allMasks = true
+ ops[i] = m
+ }
+ return nil
+ }
+ return fmt.Errorf("cannot infer mask size: no register operands")
+ }
+ shape, ok := singular(sizes)
+ if !ok {
+ if !hasWMask && len(wSizes) == 1 && len(masks) == 1 {
+ // This pattern looks like a predicate mask, so its shape should align with the
+ // output. TODO: verify this is a safe assumption.
+ shape = wSizes[0]
+ } else {
+ return fmt.Errorf("cannot infer mask size: multiple register sizes %v", sizes)
+ }
+ }
+ for _, i := range masks {
+ m := ops[i].(operandMask)
+ m.vecShape = shape
+ ops[i] = m
+ }
+ return nil
+ }
+ if err := inferMask(true, false); err != nil {
+ return err
+ }
+ if err := inferMask(false, true); err != nil {
+ return err
+ }
+ return nil
+}
+
+// addOperandsToDef adds "in", "inVariant", and "out" to an instruction Def.
+//
+// Optional mask input operands are added to the inVariant field if
+// variant&instVariantMasked, and omitted otherwise.
+func addOperandsToDef(ops []operand, instDB *unify.DefBuilder, variant instVariant) {
+ var inVals, inVar, outVals []*unify.Value
+ asmPos := 0
+ for _, op := range ops {
+ var db unify.DefBuilder
+ op.addToDef(&db)
+ db.Add("asmPos", unify.NewValue(unify.NewStringExact(fmt.Sprint(asmPos))))
+
+ action := op.common().action
+ asmCount := 1 // # of assembly operands; 0 or 1
+ if action.r {
+ inVal := unify.NewValue(db.Build())
+ // If this is an optional mask, put it in the input variant tuple.
+ if mask, ok := op.(operandMask); ok && mask.optional {
+ if variant&instVariantMasked != 0 {
+ inVar = append(inVar, inVal)
+ } else {
+ // This operand doesn't appear in the assembly at all.
+ asmCount = 0
+ }
+ } else {
+ // Just a regular input operand.
+ inVals = append(inVals, inVal)
+ }
+ }
+ if action.w {
+ outVal := unify.NewValue(db.Build())
+ outVals = append(outVals, outVal)
+ }
+
+ asmPos += asmCount
+ }
+
+ instDB.Add("in", unify.NewValue(unify.NewTuple(inVals...)))
+ instDB.Add("inVariant", unify.NewValue(unify.NewTuple(inVar...)))
+ instDB.Add("out", unify.NewValue(unify.NewTuple(outVals...)))
+}
+
+func instToUVal(inst *xeddata.Inst, ops []operand) []*unify.Value {
+ feature, ok := decodeCPUFeature(inst)
+ if !ok {
+ return nil
+ }
+
+ var vals []*unify.Value
+ vals = append(vals, instToUVal1(inst, ops, feature, instVariantNone))
+ if hasOptionalMask(ops) {
+ vals = append(vals, instToUVal1(inst, ops, feature, instVariantMasked))
+ }
+ return vals
+}
+
+func instToUVal1(inst *xeddata.Inst, ops []operand, feature string, variant instVariant) *unify.Value {
+ var db unify.DefBuilder
+ db.Add("goarch", unify.NewValue(unify.NewStringExact("amd64")))
+ db.Add("asm", unify.NewValue(unify.NewStringExact(inst.Opcode())))
+ addOperandsToDef(ops, &db, variant)
+ db.Add("cpuFeature", unify.NewValue(unify.NewStringExact(feature)))
+
+ if strings.Contains(inst.Pattern, "ZEROING=0") {
+ // This is an EVEX instruction, but the ".Z" (zero-merging)
+ // instruction flag is NOT valid. EVEX.z must be zero.
+ //
+ // This can mean a few things:
+ //
+ // - The output of an instruction is a mask, so merging modes don't
+ // make any sense. E.g., VCMPPS.
+ //
+ // - There are no masks involved anywhere. (Maybe MASK=0 is also set
+ // in this case?) E.g., VINSERTPS.
+ //
+ // - The operation inherently performs merging. E.g., VCOMPRESSPS
+ // with a mem operand.
+ //
+ // There may be other reasons.
+ db.Add("zeroing", unify.NewValue(unify.NewStringExact("false")))
+ }
+ pos := unify.Pos{Path: inst.Pos.Path, Line: inst.Pos.Line}
+ return unify.NewValuePos(db.Build(), pos)
+}
+
+// decodeCPUFeature returns the CPU feature name required by inst. These match
+// the names of the "Has*" feature checks in the simd package.
+func decodeCPUFeature(inst *xeddata.Inst) (string, bool) {
+ key := cpuFeatureKey{
+ Extension: inst.Extension,
+ ISASet: isaSetStrip.ReplaceAllLiteralString(inst.ISASet, ""),
+ }
+ feat, ok := cpuFeatureMap[key]
+ if !ok {
+ imap := unknownFeatures[key]
+ if imap == nil {
+ imap = make(map[string]struct{})
+ unknownFeatures[key] = imap
+ }
+ imap[inst.Opcode()] = struct{}{}
+ return "", false
+ }
+ if feat == "ignore" {
+ return "", false
+ }
+ return feat, true
+}
+
+var isaSetStrip = regexp.MustCompile("_(128N?|256N?|512)$")
+
+type cpuFeatureKey struct {
+ Extension, ISASet string
+}
+
+// cpuFeatureMap maps from XED's "EXTENSION" and "ISA_SET" to a CPU feature name
+// that can be used in the SIMD API.
+var cpuFeatureMap = map[cpuFeatureKey]string{
+ {"AVX", ""}: "AVX",
+ {"AVX_VNNI", "AVX_VNNI"}: "AVXVNNI",
+ {"AVX2", ""}: "AVX2",
+
+ // AVX-512 foundational features. We combine all of these into one "AVX512" feature.
+ {"AVX512EVEX", "AVX512F"}: "AVX512",
+ {"AVX512EVEX", "AVX512CD"}: "AVX512",
+ {"AVX512EVEX", "AVX512BW"}: "AVX512",
+ {"AVX512EVEX", "AVX512DQ"}: "AVX512",
+ // AVX512VL doesn't appear explicitly in the ISASet. I guess it's implied by
+ // the vector length suffix.
+
+ // AVX-512 extension features
+ {"AVX512EVEX", "AVX512_BITALG"}: "AVX512BITALG",
+ {"AVX512EVEX", "AVX512_GFNI"}: "AVX512GFNI",
+ {"AVX512EVEX", "AVX512_VBMI2"}: "AVX512VBMI2",
+ {"AVX512EVEX", "AVX512_VBMI"}: "AVX512VBMI",
+ {"AVX512EVEX", "AVX512_VNNI"}: "AVX512VNNI",
+ {"AVX512EVEX", "AVX512_VPOPCNTDQ"}: "AVX512VPOPCNTDQ",
+
+ // AVX 10.2 (not yet supported)
+ {"AVX512EVEX", "AVX10_2_RC"}: "ignore",
+}
+
+var unknownFeatures = map[cpuFeatureKey]map[string]struct{}{}
+
+// hasOptionalMask returns whether there is an optional mask operand in ops.
+func hasOptionalMask(ops []operand) bool {
+ for _, op := range ops {
+ if op, ok := op.(operandMask); ok && op.optional {
+ return true
+ }
+ }
+ return false
+}
+
+func singular[T comparable](xs []T) (T, bool) {
+ if len(xs) == 0 {
+ return *new(T), false
+ }
+ for _, x := range xs[1:] {
+ if x != xs[0] {
+ return *new(T), false
+ }
+ }
+ return xs[0], true
+}
+
+// decodeReg returns class (NOT_REG_CLASS, VREG_CLASS, GREG_CLASS),
+// and width in bits. If the operand cannot be decoded as a register,
+// then the class is NOT_REG_CLASS.
+func decodeReg(op *xeddata.Operand) (class, width int) {
+ // op.Width tells us the total width, e.g.,:
+ //
+ // dq => 128 bits (XMM)
+ // qq => 256 bits (YMM)
+ // mskw => K
+ // z[iuf?](8|16|32|...) => 512 bits (ZMM)
+ //
+ // But the encoding is really weird and it's not clear if these *always*
+ // mean XMM/YMM/ZMM or if other irregular things can use these large widths.
+ // Hence, we dig into the register sets themselves.
+
+ if !strings.HasPrefix(op.NameLHS(), "REG") {
+ return NOT_REG_CLASS, 0
+ }
+ // TODO: We shouldn't be relying on the macro naming conventions. We should
+ // use all-dec-patterns.txt, but xeddata doesn't support that table right now.
+ rhs := op.NameRHS()
+ if !strings.HasSuffix(rhs, "()") {
+ return NOT_REG_CLASS, 0
+ }
+ switch {
+ case strings.HasPrefix(rhs, "XMM_"):
+ return VREG_CLASS, 128
+ case strings.HasPrefix(rhs, "YMM_"):
+ return VREG_CLASS, 256
+ case strings.HasPrefix(rhs, "ZMM_"):
+ return VREG_CLASS, 512
+ case strings.HasPrefix(rhs, "GPR64_"), strings.HasPrefix(rhs, "VGPR64_"):
+ return GREG_CLASS, 64
+ case strings.HasPrefix(rhs, "GPR32_"), strings.HasPrefix(rhs, "VGPR32_"):
+ return GREG_CLASS, 32
+ }
+ return NOT_REG_CLASS, 0
+}
+
+var xtypeRe = regexp.MustCompile(`^([iuf])([0-9]+)$`)
+
+// scalarBaseType describes the base type of a scalar element. This is a Go
+// type, but without the bit width suffix (with the exception of
+// scalarBaseIntOrUint).
+type scalarBaseType int
+
+const (
+ scalarBaseInt scalarBaseType = iota
+ scalarBaseUint
+ scalarBaseIntOrUint // Signed or unsigned is unspecified
+ scalarBaseFloat
+ scalarBaseComplex
+ scalarBaseBFloat
+ scalarBaseHFloat
+)
+
+func (s scalarBaseType) regex() string {
+ switch s {
+ case scalarBaseInt:
+ return "int"
+ case scalarBaseUint:
+ return "uint"
+ case scalarBaseIntOrUint:
+ return "int|uint"
+ case scalarBaseFloat:
+ return "float"
+ case scalarBaseComplex:
+ return "complex"
+ case scalarBaseBFloat:
+ return "BFloat"
+ case scalarBaseHFloat:
+ return "HFloat"
+ }
+ panic(fmt.Sprintf("unknown scalar base type %d", s))
+}
+
+func decodeType(op *xeddata.Operand) (base scalarBaseType, bits int, ok bool) {
+ // The xtype tells you the element type. i8, i16, i32, i64, f32, etc.
+ //
+ // TODO: Things like AVX2 VPAND have an xtype of u256 because they're
+ // element-width agnostic. Do I map that to all widths, or just omit the
+ // element width and let unification flesh it out? There's no u512
+ // (presumably those are all masked, so elem width matters). These are all
+ // Category: LOGICAL, so maybe we could use that info?
+
+ // Handle some weird ones.
+ switch op.Xtype {
+ // 8-bit float formats as defined by Open Compute Project "OCP 8-bit
+ // Floating Point Specification (OFP8)".
+ case "bf8": // E5M2 float
+ return scalarBaseBFloat, 8, true
+ case "hf8": // E4M3 float
+ return scalarBaseHFloat, 8, true
+ case "bf16": // bfloat16 float
+ return scalarBaseBFloat, 16, true
+ case "2f16":
+ // Complex consisting of 2 float16s. Doesn't exist in Go, but we can say
+ // what it would be.
+ return scalarBaseComplex, 32, true
+ case "2i8", "2I8":
+ // These just use the lower INT8 in each 16 bit field.
+ // As far as I can tell, "2I8" is a typo.
+ return scalarBaseInt, 8, true
+ case "2u16", "2U16":
+ // some VPDP* has it
+ // TODO: does "z" mean it has zeroing?
+ return scalarBaseUint, 16, true
+ case "2i16", "2I16":
+ // some VPDP* has it
+ return scalarBaseInt, 16, true
+ case "4u8", "4U8":
+ // some VPDP* has it
+ return scalarBaseUint, 8, true
+ case "4i8", "4I8":
+ // some VPDP* has it
+ return scalarBaseInt, 8, true
+ }
+
+ // The rest follow a simple pattern.
+ m := xtypeRe.FindStringSubmatch(op.Xtype)
+ if m == nil {
+ // TODO: Report unrecognized xtype
+ return 0, 0, false
+ }
+ bits, _ = strconv.Atoi(m[2])
+ switch m[1] {
+ case "i", "u":
+ // XED is rather inconsistent about what's signed, unsigned, or doesn't
+ // matter, so merge them together and let the Go definitions narrow as
+ // appropriate. Maybe there's a better way to do this.
+ return scalarBaseIntOrUint, bits, true
+ case "f":
+ return scalarBaseFloat, bits, true
+ default:
+ panic("unreachable")
+ }
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "fmt"
+ "iter"
+ "maps"
+ "slices"
+)
+
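+// A Closure pairs a Value with the environment of variable bindings it was
+// constructed under, so that its variables can later be enumerated or
+// substituted.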
+type Closure struct {
+ val *Value
+ env envSet
+}
+
+func NewSum(vs ...*Value) Closure {
+ id := &ident{name: "sum"}
+ return Closure{NewValue(Var{id}), topEnv.bind(id, vs...)}
+}
+
+// IsBottom returns whether c consists of no values.
+func (c Closure) IsBottom() bool {
+ return c.val.Domain == nil
+}
+
+// Summands returns the top-level Values of c. This assumes the top-level of c
+// was constructed as a sum, and is mostly useful for debugging.
+func (c Closure) Summands() iter.Seq[*Value] {
+ return func(yield func(*Value) bool) {
+ var rec func(v *Value, env envSet) bool
+ rec = func(v *Value, env envSet) bool {
+ switch d := v.Domain.(type) {
+ case Var:
+ parts := env.partitionBy(d.id)
+ for _, part := range parts {
+ // It may be a sum of sums. Walk into this value.
+ if !rec(part.value, part.env) {
+ return false
+ }
+ }
+ return true
+ default:
+ return yield(v)
+ }
+ }
+ rec(c.val, c.env)
+ }
+}
+
+// All enumerates all possible concrete values of c by substituting variables
+// from the environment.
+//
+// E.g., enumerating this Value
+//
+// a: !sum [1, 2]
+// b: !sum [3, 4]
+//
+// results in
+//
+// - {a: 1, b: 3}
+// - {a: 1, b: 4}
+// - {a: 2, b: 3}
+// - {a: 2, b: 4}
+func (c Closure) All() iter.Seq[*Value] {
+ // To enumerate all concrete values under all possible variable bindings,
+ // we use a "non-deterministic continuation passing style": we traverse
+ // the Value tree in CPS, threading the (possibly narrowing) environment
+ // through the continuations along an Euler tour. Where the environment
+ // permits multiple choices, we invoke the same continuation for each
+ // choice. Similar to a yield function, the continuation can return false
+ // to stop the non-deterministic walk.
+ return func(yield func(*Value) bool) {
+ c.val.all1(c.env, func(v *Value, e envSet) bool {
+ return yield(v)
+ })
+ }
+}
+
+func (v *Value) all1(e envSet, cont func(*Value, envSet) bool) bool {
+ switch d := v.Domain.(type) {
+ default:
+ panic(fmt.Sprintf("unknown domain type %T", d))
+
+ case nil:
+ return true
+
+ case Top, String:
+ return cont(v, e)
+
+ case Def:
+ fields := d.keys()
+ // We can reuse this parts slice because we're doing a DFS through the
+ // state space. (Otherwise, we'd have to do some messy threading of an
+ // immutable slice-like value through allElt.)
+ parts := make(map[string]*Value, len(fields))
+
+ // TODO: If there are no Vars or Sums under this Def, then nothing can
+ // change the Value or env, so we could just cont(v, e).
+ var allElt func(elt int, e envSet) bool
+ allElt = func(elt int, e envSet) bool {
+ if elt == len(fields) {
+ // Build a new Def from the concrete parts. Clone parts because
+ // we may reuse it on other non-deterministic branches.
+ nVal := newValueFrom(Def{maps.Clone(parts)}, v)
+ return cont(nVal, e)
+ }
+
+ return d.fields[fields[elt]].all1(e, func(v *Value, e envSet) bool {
+ parts[fields[elt]] = v
+ return allElt(elt+1, e)
+ })
+ }
+ return allElt(0, e)
+
+ case Tuple:
+ // Essentially the same as Def.
+ if d.repeat != nil {
+ // There's nothing we can do with this.
+ return cont(v, e)
+ }
+ parts := make([]*Value, len(d.vs))
+ var allElt func(elt int, e envSet) bool
+ allElt = func(elt int, e envSet) bool {
+ if elt == len(d.vs) {
+ // Build a new tuple from the concrete parts. Clone parts because
+ // we may reuse it on other non-deterministic branches.
+ nVal := newValueFrom(Tuple{vs: slices.Clone(parts)}, v)
+ return cont(nVal, e)
+ }
+
+ return d.vs[elt].all1(e, func(v *Value, e envSet) bool {
+ parts[elt] = v
+ return allElt(elt+1, e)
+ })
+ }
+ return allElt(0, e)
+
+ case Var:
+ // Go each way this variable can be bound.
+ for _, ePart := range e.partitionBy(d.id) {
+ // d.id is no longer bound in this environment partition. We may
+ // need it later in the Euler tour, so bind it back to this single
+ // value.
+ env := ePart.env.bind(d.id, ePart.value)
+ if !ePart.value.all1(env, cont) {
+ return false
+ }
+ }
+ return true
+ }
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "fmt"
+ "iter"
+ "maps"
+ "reflect"
+ "regexp"
+ "slices"
+ "strconv"
+ "strings"
+)
+
+// A Domain is a non-empty set of values, all of the same kind.
+//
+// Domain may be a scalar:
+//
+// - [String] - Represents string-typed values.
+//
+// Or a composite:
+//
+// - [Def] - A mapping from fixed keys to [Domain]s.
+//
+// - [Tuple] - A fixed-length sequence of [Domain]s or
+// all possible lengths repeating a [Domain].
+//
+// Or top or bottom:
+//
+// - [Top] - Represents all possible values of all kinds.
+//
+// - nil - Represents no values.
+//
+// Or a variable:
+//
+// - [Var] - A value captured in the environment.
+type Domain interface {
+ Exact() bool
+ WhyNotExact() string
+
+ // decode stores this value in a Go value. If this value is not exact, this
+ // returns a potentially wrapped *inexactError.
+ decode(reflect.Value) error
+}
+
+type inexactError struct {
+ valueType string
+ goType string
+}
+
+func (e *inexactError) Error() string {
+ return fmt.Sprintf("cannot store inexact %s value in %s", e.valueType, e.goType)
+}
+
+type decodeError struct {
+ path string
+ err error
+}
+
+func newDecodeError(path string, err error) *decodeError {
+ if err, ok := err.(*decodeError); ok {
+ return &decodeError{path: path + "." + err.path, err: err.err}
+ }
+ return &decodeError{path: path, err: err}
+}
+
+func (e *decodeError) Unwrap() error {
+ return e.err
+}
+
+func (e *decodeError) Error() string {
+ return fmt.Sprintf("%s: %s", e.path, e.err)
+}
+
+// Top represents all possible values of all possible types.
+type Top struct{}
+
+func (t Top) Exact() bool { return false }
+func (t Top) WhyNotExact() string { return "is top" }
+
+func (t Top) decode(rv reflect.Value) error {
+ // We can decode Top into a pointer-typed value as nil.
+ if rv.Kind() != reflect.Pointer {
+ return &inexactError{"top", rv.Type().String()}
+ }
+ rv.SetZero()
+ return nil
+}
+
+// A Def is a mapping from field names to [Value]s. Any fields not explicitly
+// listed have [Value] [Top].
+type Def struct {
+ fields map[string]*Value
+}
+
+// A DefBuilder builds a [Def] one field at a time. The zero value is an empty
+// [Def].
+type DefBuilder struct {
+ fields map[string]*Value
+}
+
+func (b *DefBuilder) Add(name string, v *Value) {
+ if b.fields == nil {
+ b.fields = make(map[string]*Value)
+ }
+ if _, ok := b.fields[name]; ok {
+ panic(fmt.Sprintf("duplicate field %q", name))
+ }
+ b.fields[name] = v
+}
+
+// Build constructs a [Def] from the fields added to this builder.
+func (b *DefBuilder) Build() Def {
+ return Def{maps.Clone(b.fields)}
+}
+
+// Exact returns true if all field Values are exact.
+func (d Def) Exact() bool {
+ for _, v := range d.fields {
+ if !v.Exact() {
+ return false
+ }
+ }
+ return true
+}
+
+// WhyNotExact returns why the value is not exact
+func (d Def) WhyNotExact() string {
+ for s, v := range d.fields {
+ if !v.Exact() {
+ w := v.WhyNotExact()
+ return "field " + s + ": " + w
+ }
+ }
+ return ""
+}
+
+func (d Def) decode(rv reflect.Value) error {
+ if rv.Kind() != reflect.Struct {
+ return fmt.Errorf("cannot decode Def into %s", rv.Type())
+ }
+
+ var lowered map[string]string // Lower case -> canonical for d.fields.
+ rt := rv.Type()
+ for fi := range rv.NumField() {
+ fType := rt.Field(fi)
+ if fType.PkgPath != "" {
+ continue
+ }
+ v := d.fields[fType.Name]
+ if v == nil {
+ v = topValue
+
+ // Try a case-insensitive match
+ canon, ok := d.fields[strings.ToLower(fType.Name)]
+ if ok {
+ v = canon
+ } else {
+ if lowered == nil {
+ lowered = make(map[string]string, len(d.fields))
+ for k := range d.fields {
+ l := strings.ToLower(k)
+ if k != l {
+ lowered[l] = k
+ }
+ }
+ }
+ canon, ok := lowered[strings.ToLower(fType.Name)]
+ if ok {
+ v = d.fields[canon]
+ }
+ }
+ }
+ if err := decodeReflect(v, rv.Field(fi)); err != nil {
+ return newDecodeError(fType.Name, err)
+ }
+ }
+ return nil
+}
+
+func (d Def) keys() []string {
+ return slices.Sorted(maps.Keys(d.fields))
+}
+
+func (d Def) All() iter.Seq2[string, *Value] {
+ // TODO: We call All fairly often. It's probably bad to sort this every
+ // time.
+ keys := slices.Sorted(maps.Keys(d.fields))
+ return func(yield func(string, *Value) bool) {
+ for _, k := range keys {
+ if !yield(k, d.fields[k]) {
+ return
+ }
+ }
+ }
+}
+
+// A Tuple is a sequence of Values in one of two forms: 1. a fixed-length tuple,
+// where each Value can be different or 2. a "repeated tuple", which is a Value
+// repeated 0 or more times.
+type Tuple struct {
+ vs []*Value
+
+ // repeat, if non-nil, means this Tuple consists of an element repeated 0 or
+ // more times. If repeat is non-nil, vs must be nil. This is a generator
+ // function because we don't necessarily want *exactly* the same Value
+ // repeated. For example, in YAML encoding, a !sum in a repeated tuple needs
+ // a fresh variable in each instance.
+ repeat []func(envSet) (*Value, envSet)
+}
+
+func NewTuple(vs ...*Value) Tuple {
+ return Tuple{vs: vs}
+}
+
+func NewRepeat(gens ...func(envSet) (*Value, envSet)) Tuple {
+ return Tuple{repeat: gens}
+}
+
+func (d Tuple) Exact() bool {
+ if d.repeat != nil {
+ return false
+ }
+ for _, v := range d.vs {
+ if !v.Exact() {
+ return false
+ }
+ }
+ return true
+}
+
+func (d Tuple) WhyNotExact() string {
+ if d.repeat != nil {
+ return "d.repeat is not nil"
+ }
+ for i, v := range d.vs {
+ if !v.Exact() {
+ w := v.WhyNotExact()
+ return "index " + strconv.FormatInt(int64(i), 10) + ": " + w
+ }
+ }
+ return ""
+}
+
+func (d Tuple) decode(rv reflect.Value) error {
+ if d.repeat != nil {
+ return &inexactError{"repeated tuple", rv.Type().String()}
+ }
+ // TODO: We could also do arrays.
+ if rv.Kind() != reflect.Slice {
+ return fmt.Errorf("cannot decode Tuple into %s", rv.Type())
+ }
+ if rv.IsNil() || rv.Cap() < len(d.vs) {
+ rv.Set(reflect.MakeSlice(rv.Type(), len(d.vs), len(d.vs)))
+ } else {
+ rv.SetLen(len(d.vs))
+ }
+ for i, v := range d.vs {
+ if err := decodeReflect(v, rv.Index(i)); err != nil {
+ return newDecodeError(fmt.Sprintf("%d", i), err)
+ }
+ }
+ return nil
+}
+
+// A String represents a set of strings. It can represent the intersection of a
+// set of regexps, or a single exact string. In general, the domain of a String
+// is non-empty, but we do not attempt to prove emptiness of a regexp value.
+type String struct {
+ kind stringKind
+ re []*regexp.Regexp // Intersection of regexps
+ exact string
+}
+
+type stringKind int
+
+const (
+ stringRegex stringKind = iota
+ stringExact
+)
+
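+// NewStringRegex returns a String that constrains values to match all of
+// exprs. An expression that is a pure literal (its compiled regexp has a
+// complete literal prefix) collapses the whole domain to that exact string.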
+func NewStringRegex(exprs ...string) (String, error) {
+ if len(exprs) == 0 {
+ exprs = []string{""}
+ }
+ v := String{kind: -1} // sentinel: no kind determined yet
+ for _, expr := range exprs {
+ if expr == "" {
+ // Skip constructing the regexp. It won't have a "literal prefix"
+ // and so we wind up thinking this is a regexp instead of an exact
+ // (empty) string.
+ v = String{kind: stringExact, exact: ""}
+ continue
+ }
+
+ re, err := regexp.Compile(`\A(?:` + expr + `)\z`)
+ if err != nil {
+ return String{}, fmt.Errorf("parsing value: %s", err)
+ }
+
+ // An exact value narrows the whole domain to exact, so the result is
+ // settled, but we keep parsing to validate the remaining expressions.
+ if v.kind == stringExact {
+ continue
+ }
+
+ if exact, complete := re.LiteralPrefix(); complete {
+ v = String{kind: stringExact, exact: exact}
+ } else {
+ v.kind = stringRegex
+ v.re = append(v.re, re)
+ }
+ }
+ return v, nil
+}
+
+func NewStringExact(s string) String {
+ return String{kind: stringExact, exact: s}
+}
+
+// Exact returns whether this Value is known to consist of a single string.
+func (d String) Exact() bool {
+ return d.kind == stringExact
+}
+
+func (d String) WhyNotExact() string {
+ if d.kind == stringExact {
+ return ""
+ }
+ return "string is not exact"
+}
+
+func (d String) decode(rv reflect.Value) error {
+ if d.kind != stringExact {
+ return &inexactError{"regex", rv.Type().String()}
+ }
+ switch rv.Kind() {
+ default:
+ return fmt.Errorf("cannot decode String into %s", rv.Type())
+ case reflect.String:
+ rv.SetString(d.exact)
+ case reflect.Int:
+ i, err := strconv.Atoi(d.exact)
+ if err != nil {
+ return fmt.Errorf("cannot decode String into %s: %s", rv.Type(), err)
+ }
+ rv.SetInt(int64(i))
+ case reflect.Bool:
+ b, err := strconv.ParseBool(d.exact)
+ if err != nil {
+ return fmt.Errorf("cannot decode String into %s: %s", rv.Type(), err)
+ }
+ rv.SetBool(b)
+ }
+ return nil
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "bytes"
+ "fmt"
+ "html"
+ "io"
+ "os"
+ "os/exec"
+ "strings"
+)
+
+const maxNodes = 30
+
+type dotEncoder struct {
+ w *bytes.Buffer
+
+ idGen int // Node name generation
+ valLimit int // Limit the number of Values in a subgraph
+
+ idp identPrinter
+}
+
+func newDotEncoder() *dotEncoder {
+ return &dotEncoder{
+ w: new(bytes.Buffer),
+ }
+}
+
+func (enc *dotEncoder) clear() {
+ enc.w.Reset()
+ enc.idGen = 0
+}
+
+func (enc *dotEncoder) writeTo(w io.Writer) {
+ fmt.Fprintln(w, "digraph {")
+ // Use the "new" ranking algorithm, which lets us put nodes from different
+ // clusters in the same rank.
+ fmt.Fprintln(w, "newrank=true;")
+ fmt.Fprintln(w, "node [shape=box, ordering=out];")
+
+ w.Write(enc.w.Bytes())
+ fmt.Fprintln(w, "}")
+}
+
+func (enc *dotEncoder) writeSvg(w io.Writer) error {
+ cmd := exec.Command("dot", "-Tsvg")
+ in, err := cmd.StdinPipe()
+ if err != nil {
+ return err
+ }
+ var out bytes.Buffer
+ cmd.Stdout = &out
+ cmd.Stderr = os.Stderr
+ if err := cmd.Start(); err != nil {
+ return err
+ }
+ enc.writeTo(in)
+ in.Close()
+ if err := cmd.Wait(); err != nil {
+ return err
+ }
+ // Trim SVG header so the result can be embedded
+ //
+ // TODO: In Graphviz 10.0.1, we could use -Tsvg_inline.
+ svg := out.Bytes()
+ if i := bytes.Index(svg, []byte("<svg ")); i >= 0 {
+ svg = svg[i:]
+ }
+ _, err = w.Write(svg)
+ return err
+}
+
+func (enc *dotEncoder) newID(f string) string {
+ id := fmt.Sprintf(f, enc.idGen)
+ enc.idGen++
+ return id
+}
+
+func (enc *dotEncoder) node(label, sublabel string) string {
+ id := enc.newID("n%d")
+ l := html.EscapeString(label)
+ if sublabel != "" {
+ l += fmt.Sprintf("<BR ALIGN=\"CENTER\"/><FONT POINT-SIZE=\"10\">%s</FONT>", html.EscapeString(sublabel))
+ }
+ fmt.Fprintf(enc.w, "%s [label=<%s>];\n", id, l)
+ return id
+}
+
+func (enc *dotEncoder) edge(from, to string, label string, args ...any) {
+ l := fmt.Sprintf(label, args...)
+ fmt.Fprintf(enc.w, "%s -> %s [label=%q];\n", from, to, l)
+}
+
+func (enc *dotEncoder) valueSubgraph(v *Value) {
+ enc.valLimit = maxNodes
+ cID := enc.newID("cluster_%d")
+ fmt.Fprintf(enc.w, "subgraph %s {\n", cID)
+ fmt.Fprintf(enc.w, "style=invis;")
+ vID := enc.value(v)
+ fmt.Fprintf(enc.w, "}\n")
+ // We don't need the IDs right now.
+ _, _ = cID, vID
+}
+
+func (enc *dotEncoder) value(v *Value) string {
+ if enc.valLimit <= 0 {
+ id := enc.newID("n%d")
+ fmt.Fprintf(enc.w, "%s [label=\"...\", shape=triangle];\n", id)
+ return id
+ }
+ enc.valLimit--
+
+ switch vd := v.Domain.(type) {
+ default:
+ panic(fmt.Sprintf("unknown domain type %T", vd))
+
+ case nil:
+ return enc.node("_|_", "")
+
+ case Top:
+ return enc.node("_", "")
+
+ // TODO: Like in YAML, figure out if this is just a sum. In dot, we
+ // could say any unentangled variable is a sum, and if it has more than
+ // one reference just share the node.
+
+ // case Sum:
+ // node := enc.node("Sum", "")
+ // for i, elt := range vd.vs {
+ // enc.edge(node, enc.value(elt), "%d", i)
+ // if enc.valLimit <= 0 {
+ // break
+ // }
+ // }
+ // return node
+
+ case Def:
+ node := enc.node("Def", "")
+ for k, v := range vd.All() {
+ enc.edge(node, enc.value(v), "%s", k)
+ if enc.valLimit <= 0 {
+ break
+ }
+ }
+ return node
+
+ case Tuple:
+ if vd.repeat == nil {
+ label := "Tuple"
+ node := enc.node(label, "")
+ for i, elt := range vd.vs {
+ enc.edge(node, enc.value(elt), "%d", i)
+ if enc.valLimit <= 0 {
+ break
+ }
+ }
+ return node
+ } else {
+ // TODO
+ return enc.node("TODO: Repeat", "")
+ }
+
+ case String:
+ switch vd.kind {
+ case stringExact:
+ return enc.node(fmt.Sprintf("%q", vd.exact), "")
+ case stringRegex:
+ var parts []string
+ for _, re := range vd.re {
+ parts = append(parts, fmt.Sprintf("%q", re))
+ }
+ return enc.node(strings.Join(parts, "&"), "")
+ }
+ panic("bad String kind")
+
+ case Var:
+ return enc.node(fmt.Sprintf("Var %s", enc.idp.unique(vd.id)), "")
+ }
+}
+
+func (enc *dotEncoder) envSubgraph(e envSet) {
+ enc.valLimit = maxNodes
+ cID := enc.newID("cluster_%d")
+ fmt.Fprintf(enc.w, "subgraph %s {\n", cID)
+ fmt.Fprintf(enc.w, "style=invis;")
+ vID := enc.env(e.root)
+ fmt.Fprintf(enc.w, "}\n")
+ _, _ = cID, vID
+}
+
+func (enc *dotEncoder) env(e *envExpr) string {
+ switch e.kind {
+ default:
+ panic("bad kind")
+ case envZero:
+ return enc.node("0", "")
+ case envUnit:
+ return enc.node("1", "")
+ case envBinding:
+ node := enc.node(fmt.Sprintf("%q :", enc.idp.unique(e.id)), "")
+ enc.edge(node, enc.value(e.val), "")
+ return node
+ case envProduct:
+ node := enc.node("⨯", "")
+ for _, op := range e.operands {
+ enc.edge(node, enc.env(op), "")
+ }
+ return node
+ case envSum:
+ node := enc.node("+", "")
+ for _, op := range e.operands {
+ enc.edge(node, enc.env(op), "")
+ }
+ return node
+ }
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "fmt"
+ "iter"
+ "reflect"
+ "strings"
+)
+
+// An envSet is an immutable set of environments, where each environment is a
+// mapping from [ident]s to [Value]s.
+//
+// To keep this compact, we use an algebraic representation similar to
+// relational algebra. The atoms are zero, unit, or a singular binding:
+//
+// - Zero is the empty set.
+//
+// - Unit is an environment set consisting of a single, empty environment (no
+// bindings).
+//
+// - A singular binding is an environment set consisting of a single environment
+// that binds a single ident to a single value.
+//
+// From these, we build up more complex sets of environments using sums and
+// cross products:
+//
+// - A sum is simply the union of the two environment sets.
+//
+// - A cross product is the Cartesian product of the two environment sets,
+// followed by combining each pair of environments. Combining simply merges the
+// two mappings, but fails if the mappings overlap.
+//
+// For example, to represent {{x: 1, y: 1}, {x: 2, y: 2}}, we build the two
+// environments and sum them:
+//
+// ({x: 1} ⨯ {y: 1}) + ({x: 2} ⨯ {y: 2})
+//
+// If we add a third variable z that can be 1 or 2, independent of x and y, we
+// get four logical environments:
+//
+// {x: 1, y: 1, z: 1}
+// {x: 2, y: 2, z: 1}
+// {x: 1, y: 1, z: 2}
+// {x: 2, y: 2, z: 2}
+//
+// This could be represented as a sum of all four environments, but because z is
+// independent, we can use a more compact representation:
+//
+// (({x: 1} ⨯ {y: 1}) + ({x: 2} ⨯ {y: 2})) ⨯ ({z: 1} + {z: 2})
+//
+// Environment sets obey commutative algebra rules:
+//
+// e + 0 = e
+// e ⨯ 0 = 0
+// e ⨯ 1 = e
+// e + f = f + e
+// e ⨯ f = f ⨯ e
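+//
+// As a rough sketch of how the helpers below compose (assuming idents x, y,
+// and z and Values one and two are in scope), the compact representation
+// above could be built as:
+//
+//	exy := unionEnvs(
+//		topEnv.bind(x, one).bind(y, one),
+//		topEnv.bind(x, two).bind(y, two))
+//	e := crossEnvs(exy, topEnv.bind(z, one, two))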
+type envSet struct {
+ root *envExpr
+}
+
+type envExpr struct {
+ // TODO: A tree-based data structure for this may not be ideal, since it
+ // involves a lot of walking to find things and we often have to do deep
+ // rewrites anyway for partitioning. Would some flattened array-style
+ // representation be better, possibly combined with an index of ident uses?
+ // We could even combine that with an immutable array abstraction (ala
+ // Clojure) that could enable more efficient construction operations.
+
+ kind envExprKind
+
+ // For envBinding
+ id *ident
+ val *Value
+
+ // For sum or product. Len must be >= 2 and none of the elements can have
+ // the same kind as this node.
+ operands []*envExpr
+}
+
+type envExprKind byte
+
+const (
+ envZero envExprKind = iota
+ envUnit
+ envProduct
+ envSum
+ envBinding
+)
+
+var (
+ // topEnv is the unit value (multiplicative identity) of a [envSet].
+ topEnv = envSet{envExprUnit}
+ // bottomEnv is the zero value (additive identity) of a [envSet].
+ bottomEnv = envSet{envExprZero}
+
+ envExprZero = &envExpr{kind: envZero}
+ envExprUnit = &envExpr{kind: envUnit}
+)
+
+// bind binds id to each of vals in e.
+//
+// It panics if id is already bound in e.
+//
+// Environments are typically initially constructed by starting with [topEnv]
+// and calling bind one or more times.
+func (e envSet) bind(id *ident, vals ...*Value) envSet {
+ if e.isEmpty() {
+ return bottomEnv
+ }
+
+ // TODO: If any of vals are _, should we just drop that val? We're kind of
+ // inconsistent about whether an id missing from e means id is invalid or
+ // means id is _.
+
+ // Check that id isn't present in e.
+ for range e.root.bindings(id) {
+ panic("id " + id.name + " already present in environment")
+ }
+
+ // Create a sum of all the values.
+ bindings := make([]*envExpr, 0, len(vals))
+ for _, val := range vals {
+ bindings = append(bindings, &envExpr{kind: envBinding, id: id, val: val})
+ }
+
+ // Multiply it in.
+ return envSet{newEnvExprProduct(e.root, newEnvExprSum(bindings...))}
+}
+
+func (e envSet) isEmpty() bool {
+ return e.root.kind == envZero
+}
+
+// bindings yields all [envBinding] nodes in e with the given id. If id is nil,
+// it yields all binding nodes.
+func (e *envExpr) bindings(id *ident) iter.Seq[*envExpr] {
+ // This is just a pre-order walk and it happens this is the only thing we
+ // need a pre-order walk for.
+ return func(yield func(*envExpr) bool) {
+ var rec func(e *envExpr) bool
+ rec = func(e *envExpr) bool {
+ if e.kind == envBinding && (id == nil || e.id == id) {
+ if !yield(e) {
+ return false
+ }
+ }
+ for _, o := range e.operands {
+ if !rec(o) {
+ return false
+ }
+ }
+ return true
+ }
+ rec(e)
+ }
+}
+
+// newEnvExprProduct constructs a product node from exprs, performing
+// simplifications. It does NOT check that bindings are disjoint.
+func newEnvExprProduct(exprs ...*envExpr) *envExpr {
+ factors := make([]*envExpr, 0, 2)
+ for _, expr := range exprs {
+ switch expr.kind {
+ case envZero:
+ return envExprZero
+ case envUnit:
+ // No effect on product
+ case envProduct:
+ factors = append(factors, expr.operands...)
+ default:
+ factors = append(factors, expr)
+ }
+ }
+
+ if len(factors) == 0 {
+ return envExprUnit
+ } else if len(factors) == 1 {
+ return factors[0]
+ }
+ return &envExpr{kind: envProduct, operands: factors}
+}
+
+// newEnvExprSum constructs a sum node from exprs, performing simplifications.
+func newEnvExprSum(exprs ...*envExpr) *envExpr {
+ // TODO: If all of envs are products (or bindings), factor any common terms.
+ // E.g., x * y + x * z ==> x * (y + z). This is easy to do for binding
+ // terms, but harder to do for more general terms.
+
+ var have smallSet[*envExpr]
+ terms := make([]*envExpr, 0, 2)
+ for _, expr := range exprs {
+ switch expr.kind {
+ case envZero:
+ // No effect on sum
+ case envSum:
+ for _, expr1 := range expr.operands {
+ if have.Add(expr1) {
+ terms = append(terms, expr1)
+ }
+ }
+ default:
+ if have.Add(expr) {
+ terms = append(terms, expr)
+ }
+ }
+ }
+
+ if len(terms) == 0 {
+ return envExprZero
+ } else if len(terms) == 1 {
+ return terms[0]
+ }
+ return &envExpr{kind: envSum, operands: terms}
+}
+
+func crossEnvs(env1, env2 envSet) envSet {
+ // Confirm that envs have disjoint idents.
+ var ids1 smallSet[*ident]
+ for e := range env1.root.bindings(nil) {
+ ids1.Add(e.id)
+ }
+ for e := range env2.root.bindings(nil) {
+ if ids1.Has(e.id) {
+ panic(fmt.Sprintf("%s bound on both sides of cross-product", e.id.name))
+ }
+ }
+
+ return envSet{newEnvExprProduct(env1.root, env2.root)}
+}
+
+func unionEnvs(envs ...envSet) envSet {
+ exprs := make([]*envExpr, len(envs))
+ for i := range envs {
+ exprs[i] = envs[i].root
+ }
+ return envSet{newEnvExprSum(exprs...)}
+}
+
+// envPartition is a subset of an env where id is bound to value in all
+// deterministic environments.
+type envPartition struct {
+ id *ident
+ value *Value
+ env envSet
+}
+
+// partitionBy splits e by distinct bindings of id and removes id from each
+// partition.
+//
+// If there are environments in e where id is not bound, they will not be
+// reflected in any partition.
+//
+// It panics if e is bottom, since attempting to partition an empty environment
+// set almost certainly indicates a bug.
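+//
+// For example, partitioning ({x: 1} ⨯ {y: 1}) + ({x: 2} ⨯ {y: 2}) by x yields
+// two partitions: one binding x to 1 with remaining environment {y: 1}, and
+// one binding x to 2 with remaining environment {y: 2}.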
+func (e envSet) partitionBy(id *ident) []envPartition {
+ if e.isEmpty() {
+ // We could return zero partitions, but getting here at all almost
+ // certainly indicates a bug.
+ panic("cannot partition empty environment set")
+ }
+
+ // Emit a partition for each value of id.
+ var seen smallSet[*Value]
+ var parts []envPartition
+ for n := range e.root.bindings(id) {
+ if !seen.Add(n.val) {
+ // Already emitted a partition for this value.
+ continue
+ }
+
+ parts = append(parts, envPartition{
+ id: id,
+ value: n.val,
+ env: envSet{e.root.substitute(id, n.val)},
+ })
+ }
+
+ return parts
+}
+
+// substitute replaces bindings of id to val with 1 and bindings of id to any
+// other value with 0 and simplifies the result.
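+//
+// For example, substituting x=1 into ({x: 1} ⨯ {y: 1}) + ({x: 2} ⨯ {y: 2})
+// yields (1 ⨯ {y: 1}) + (0 ⨯ {y: 2}), which simplifies to {y: 1}.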
+func (e *envExpr) substitute(id *ident, val *Value) *envExpr {
+ switch e.kind {
+ default:
+ panic("bad kind")
+
+ case envZero, envUnit:
+ return e
+
+ case envBinding:
+ if e.id != id {
+ return e
+ } else if e.val != val {
+ return envExprZero
+ } else {
+ return envExprUnit
+ }
+
+ case envProduct, envSum:
+ // Substitute each operand. Sometimes, this won't change anything, so we
+ // build the new operands list lazily.
+ var nOperands []*envExpr
+ for i, op := range e.operands {
+ nOp := op.substitute(id, val)
+ if nOperands == nil && op != nOp {
+ // Operand diverged; initialize nOperands.
+ nOperands = make([]*envExpr, 0, len(e.operands))
+ nOperands = append(nOperands, e.operands[:i]...)
+ }
+ if nOperands != nil {
+ nOperands = append(nOperands, nOp)
+ }
+ }
+ if nOperands == nil {
+ // Nothing changed.
+ return e
+ }
+ if e.kind == envProduct {
+ return newEnvExprProduct(nOperands...)
+ } else {
+ return newEnvExprSum(nOperands...)
+ }
+ }
+}
+
+// A smallSet is a set optimized for stack allocation when small.
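+//
+// A minimal usage sketch:
+//
+//	var seen smallSet[string]
+//	seen.Add("a")     // true: newly added
+//	seen.Add("a")     // false: already present
+//	_ = seen.Has("a") // true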
+type smallSet[T comparable] struct {
+ array [32]T
+ n int
+
+ m map[T]struct{}
+}
+
+// Has returns whether val is in set.
+func (s *smallSet[T]) Has(val T) bool {
+ arr := s.array[:s.n]
+ for i := range arr {
+ if arr[i] == val {
+ return true
+ }
+ }
+ _, ok := s.m[val]
+ return ok
+}
+
+// Add adds val to the set and returns true if it was added (not already
+// present).
+func (s *smallSet[T]) Add(val T) bool {
+ // Test for presence.
+ if s.Has(val) {
+ return false
+ }
+
+ // Add it
+ if s.n < len(s.array) {
+ s.array[s.n] = val
+ s.n++
+ } else {
+ if s.m == nil {
+ s.m = make(map[T]struct{})
+ }
+ s.m[val] = struct{}{}
+ }
+ return true
+}
+
+type ident struct {
+ _ [0]func() // Not comparable (only compare *ident)
+ name string
+}
+
+type Var struct {
+ id *ident
+}
+
+func (d Var) Exact() bool {
+ // These can't appear in concrete Values.
+ panic("Exact called on non-concrete Value")
+}
+
+func (d Var) WhyNotExact() string {
+ // These can't appear in concrete Values.
+ return "WhyNotExact called on non-concrete Value"
+}
+
+func (d Var) decode(rv reflect.Value) error {
+ return &inexactError{"var", rv.Type().String()}
+}
+
+func (d Var) unify(w *Value, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
+ // TODO: Vars from !sums in the input can have a huge number of values.
+ // Unifying these could be way more efficient with some indexes over any
+ // exact values we can pull out, like Def fields that are exact Strings.
+ // Maybe we try to produce an array of yes/no/maybe matches and then we only
+ // have to do deeper evaluation of the maybes. We could probably cache this
+ // on an envTerm. It may also help to special-case Var/Var unification to
+ // pick which one to index versus enumerate.
+
+ if vd, ok := w.Domain.(Var); ok && d.id == vd.id {
+ // Unifying $x with $x results in $x. If we descend into this we'll have
+ // problems because we strip $x out of the environment to keep ourselves
+ // honest and then can't find it on the other side.
+ //
+ // TODO: I'm not positive this is the right fix.
+ return vd, e, nil
+ }
+
+ // We need to unify w with the value of d in each possible environment. We
+ // can save some work by grouping environments by the value of d, since
+ // there will be a lot of redundancy here.
+ var nEnvs []envSet
+ envParts := e.partitionBy(d.id)
+ for i, envPart := range envParts {
+ exit := uf.enterVar(d.id, i)
+ // Each branch logically gets its own copy of the initial environment
+ // (narrowed down to just this binding of the variable), and each branch
+ // may result in different changes to that starting environment.
+ res, e2, err := w.unify(envPart.value, envPart.env, swap, uf)
+ exit.exit()
+ if err != nil {
+ return nil, envSet{}, err
+ }
+ if res.Domain == nil {
+ // This branch entirely failed to unify, so it's gone.
+ continue
+ }
+ nEnv := e2.bind(d.id, res)
+ nEnvs = append(nEnvs, nEnv)
+ }
+
+ if len(nEnvs) == 0 {
+ // All branches failed
+ return nil, bottomEnv, nil
+ }
+
+ // The effect of this is entirely captured in the environment. We can return
+ // back the same Bind node.
+ return d, unionEnvs(nEnvs...), nil
+}
+
+// An identPrinter maps [ident]s to unique string names.
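+// Distinct idents that share a name are disambiguated with a "#N" suffix, so
+// two different idents both named "x" print as "x" and "x#1".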
+type identPrinter struct {
+ ids map[*ident]string
+ idGen map[string]int
+}
+
+func (p *identPrinter) unique(id *ident) string {
+ if p.ids == nil {
+ p.ids = make(map[*ident]string)
+ p.idGen = make(map[string]int)
+ }
+
+ name, ok := p.ids[id]
+ if !ok {
+ gen := p.idGen[id.name]
+ p.idGen[id.name]++
+ if gen == 0 {
+ name = id.name
+ } else {
+ name = fmt.Sprintf("%s#%d", id.name, gen)
+ }
+ p.ids[id] = name
+ }
+
+ return name
+}
+
+func (p *identPrinter) slice(ids []*ident) string {
+ var strs []string
+ for _, id := range ids {
+ strs = append(strs, p.unique(id))
+ }
+ return fmt.Sprintf("[%s]", strings.Join(strs, ", "))
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "fmt"
+ "html"
+ "io"
+ "strings"
+)
+
+func (t *tracer) writeHTML(w io.Writer) {
+ if !t.saveTree {
+ panic("writeHTML called without tracer.saveTree")
+ }
+
+ fmt.Fprintf(w, "<html><head><style>%s</style></head>", htmlCSS)
+ for _, root := range t.trees {
+ dot := newDotEncoder()
+ ht := htmlTracer{w: w, dot: dot}
+ ht.writeTree(root)
+ }
+ fmt.Fprintf(w, "</html>\n")
+}
+
+const htmlCSS = `
+.unify {
+ display: grid;
+ grid-auto-columns: min-content;
+ text-align: center;
+}
+
+.header {
+ grid-row: 1;
+ font-weight: bold;
+ padding: 0.25em;
+ position: sticky;
+ top: 0;
+ background: white;
+}
+
+.envFactor {
+ display: grid;
+ grid-auto-rows: min-content;
+ grid-template-columns: subgrid;
+ text-align: center;
+}
+`
+
+type htmlTracer struct {
+ w io.Writer
+ dot *dotEncoder
+ svgs map[any]string
+}
+
+func (t *htmlTracer) writeTree(node *traceTree) {
+ // TODO: This could be really nice.
+ //
+ // - Put nodes that were unified on the same rank with {rank=same; a; b}
+ //
+ // - On hover, highlight nodes that node was unified with and the result. If
+ // it's a variable, highlight it in the environment, too.
+ //
+ // - On click, show the details of unifying that node.
+ //
+ // This could be the only way to navigate, without necessarily needing the
+ // whole nest of <detail> nodes.
+
+ // TODO: It might be possible to write this out on the fly.
+
+ t.emit([]*Value{node.v, node.w}, []string{"v", "w"}, node.envIn)
+
+ // Render children.
+ for i, child := range node.children {
+ if i >= 10 {
+ fmt.Fprintf(t.w, `<div style="margin-left: 4em">...</div>`)
+ break
+ }
+ fmt.Fprintf(t.w, `<details style="margin-left: 4em"><summary>%s</summary>`, html.EscapeString(child.label))
+ t.writeTree(child)
+ fmt.Fprintf(t.w, "</details>\n")
+ }
+
+ // Render result.
+ if node.err != nil {
+ fmt.Fprintf(t.w, "Error: %s\n", html.EscapeString(node.err.Error()))
+ } else {
+ t.emit([]*Value{node.res}, []string{"res"}, node.env)
+ }
+}
+
+func htmlSVG[Key comparable](t *htmlTracer, f func(Key), arg Key) string {
+ if s, ok := t.svgs[arg]; ok {
+ return s
+ }
+ var buf strings.Builder
+ f(arg)
+ if err := t.dot.writeSvg(&buf); err != nil {
+ // Surface the failure in the output rather than silently dropping it.
+ buf.Reset()
+ fmt.Fprintf(&buf, "<p>rendering dot failed: %s</p>", html.EscapeString(err.Error()))
+ }
+ t.dot.clear()
+ svg := buf.String()
+ if t.svgs == nil {
+ t.svgs = make(map[any]string)
+ }
+ t.svgs[arg] = svg
+ return svg
+}
+
+func (t *htmlTracer) emit(vs []*Value, labels []string, env envSet) {
+ fmt.Fprintf(t.w, `<div class="unify">`)
+ for i, v := range vs {
+ fmt.Fprintf(t.w, `<div class="header" style="grid-column: %d">%s</div>`, i+1, html.EscapeString(labels[i]))
+ fmt.Fprintf(t.w, `<div style="grid-area: 2 / %d">%s</div>`, i+1, htmlSVG(t, t.dot.valueSubgraph, v))
+ }
+ col := len(vs)
+
+ fmt.Fprintf(t.w, `<div class="header" style="grid-column: %d">in</div>`, col+1)
+ fmt.Fprintf(t.w, `<div style="grid-area: 2 / %d">%s</div>`, col+1, htmlSVG(t, t.dot.envSubgraph, env))
+
+ fmt.Fprintf(t.w, `</div>`)
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "fmt"
+)
+
+type Pos struct {
+ Path string
+ Line int
+}
+
+func (p Pos) String() string {
+ var b []byte
+ b, _ = p.AppendText(b)
+ return string(b)
+}
+
+func (p Pos) AppendText(b []byte) ([]byte, error) {
+ if p.Line == 0 {
+ if p.Path == "" {
+ return append(b, "?:?"...), nil
+ } else {
+ return append(b, p.Path...), nil
+ }
+ } else if p.Path == "" {
+ return fmt.Appendf(b, "?:%d", p.Line), nil
+ }
+ return fmt.Appendf(b, "%s:%d", p.Path, p.Line), nil
+}
--- /dev/null
+# In the original representation of environments, this caused an exponential
+# blowup in time and allocation. With that representation, this took about 20
+# seconds on my laptop and had a max RSS of ~12 GB. Big enough to be really
+# noticeable, but not so big it's likely to crash a developer machine. With the
+# better environment representation, it runs almost instantly and has an RSS of
+# ~90 MB.
+unify:
+- !sum
+ - !sum [1, 2]
+ - !sum [3, 4]
+ - !sum [5, 6]
+ - !sum [7, 8]
+ - !sum [9, 10]
+ - !sum [11, 12]
+ - !sum [13, 14]
+ - !sum [15, 16]
+ - !sum [17, 18]
+ - !sum [19, 20]
+ - !sum [21, 22]
+- !sum
+ - !sum [1, 2]
+ - !sum [3, 4]
+ - !sum [5, 6]
+ - !sum [7, 8]
+ - !sum [9, 10]
+ - !sum [11, 12]
+ - !sum [13, 14]
+ - !sum [15, 16]
+ - !sum [17, 18]
+ - !sum [19, 20]
+ - !sum [21, 22]
+all:
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
--- /dev/null
+# Basic tests of unification
+
+#
+# Terminals
+#
+
+unify:
+- _
+- _
+want:
+ _
+---
+unify:
+- _
+- test
+want:
+ test
+---
+unify:
+- test
+- t?est
+want:
+ test
+---
+unify:
+- 1
+- 1
+want:
+ 1
+---
+unify:
+- test
+- foo
+want:
+ _|_
+
+#
+# Tuple
+#
+
+---
+unify:
+- [a, b]
+- [a, b]
+want:
+ [a, b]
+---
+unify:
+- [a, _]
+- [_, b]
+want:
+ [a, b]
+---
+unify:
+- ["ab?c", "de?f"]
+- [ac, def]
+want:
+ [ac, def]
+
+#
+# Repeats
+#
+
+---
+unify:
+- !repeat [a]
+- [_]
+want:
+ [a]
+---
+unify:
+- !repeat [a]
+- [_, _]
+want:
+ [a, a]
+---
+unify:
+- !repeat [a]
+- [b]
+want:
+ _|_
+---
+unify:
+- !repeat [xy*]
+- [x, xy, xyy]
+want:
+ [x, xy, xyy]
+---
+unify:
+- !repeat [xy*]
+- !repeat ["xz?y*"]
+- [x, xy, xyy]
+want:
+ [x, xy, xyy]
+---
+unify:
+- !repeat [!sum [a, b]]
+- [a, b, a]
+all:
+- [a, b, a]
+---
+unify:
+- !repeat [!sum [a, b]]
+- !repeat [!sum [b, c]]
+- [b, b, b]
+all:
+- [b, b, b]
+---
+unify:
+- !repeat [!sum [a, b]]
+- !repeat [!sum [b, c]]
+- [a]
+all: []
+
+#
+# Def
+#
+
+---
+unify:
+- {a: a, b: b}
+- {a: a, b: b}
+want:
+ {a: a, b: b}
+---
+unify:
+- {a: a}
+- {b: b}
+want:
+ {a: a, b: b}
+
+#
+# Sum
+#
+
+---
+unify:
+- !sum [1, 2]
+- !sum [2, 3]
+all:
+- 2
+---
+unify:
+- !sum [{label: a, value: abc}, {label: b, value: def}]
+- !sum [{value: "ab?c", extra: d}, {value: "def?", extra: g}]
+all:
+- {extra: d, label: a, value: abc}
+- {extra: g, label: b, value: def}
+---
+# A sum of repeats must deal with different dynamically-created variables in
+# each branch.
+unify:
+- !sum [!repeat [a], !repeat [b]]
+- [a, a, a]
+all:
+- [a, a, a]
+---
+unify:
+- !sum [!repeat [a], !repeat [b]]
+- [a, a, b]
+all: []
+---
+# Exercise sumEnvs with more than one result
+unify:
+- !sum
+ - [a|b, c|d]
+ - [e, g]
+- [!sum [a, b, e, f], !sum [c, d, g, h]]
+all:
+- [a, c]
+- [a, d]
+- [b, c]
+- [b, d]
+- [e, g]
--- /dev/null
+#
+# Basic tests
+#
+
+name: "basic string"
+unify:
+- $x
+- test
+all:
+- test
+---
+name: "basic tuple"
+unify:
+- [$x, $x]
+- [test, test]
+all:
+- [test, test]
+---
+name: "three tuples"
+unify:
+- [$x, $x]
+- [test, _]
+- [_, test]
+all:
+- [test, test]
+---
+name: "basic def"
+unify:
+- {a: $x, b: $x}
+- {a: test, b: test}
+all:
+- {a: test, b: test}
+---
+name: "three defs"
+unify:
+- {a: $x, b: $x}
+- {a: test}
+- {b: test}
+all:
+- {a: test, b: test}
+
+#
+# Bottom tests
+#
+
+---
+name: "basic bottom"
+unify:
+- [$x, $x]
+- [test, foo]
+all: []
+---
+name: "three-way bottom"
+unify:
+- [$x, $x]
+- [test, _]
+- [_, foo]
+all: []
+
+#
+# Basic sum tests
+#
+
+---
+name: "basic sum"
+unify:
+- $x
+- !sum [a, b]
+all:
+- a
+- b
+---
+name: "sum of tuples"
+unify:
+- [$x]
+- !sum [[a], [b]]
+all:
+- [a]
+- [b]
+---
+name: "acausal sum"
+unify:
+- [_, !sum [a, b]]
+- [$x, $x]
+all:
+- [a, a]
+- [b, b]
+
+#
+# Transitivity tests
+#
+
+---
+name: "transitivity"
+unify:
+- [_, _, _, test]
+- [$x, $x, _, _]
+- [ _, $x, $x, _]
+- [ _, _, $x, $x]
+all:
+- [test, test, test, test]
+
+#
+# Multiple vars
+#
+
+---
+name: "basic uncorrelated vars"
+unify:
+- - !sum [1, 2]
+ - !sum [3, 4]
+- - $a
+ - $b
+all:
+- [1, 3]
+- [1, 4]
+- [2, 3]
+- [2, 4]
+---
+name: "uncorrelated vars"
+unify:
+- - !sum [1, 2]
+ - !sum [3, 4]
+ - !sum [1, 2]
+- - $a
+ - $b
+ - $a
+all:
+- [1, 3, 1]
+- [1, 4, 1]
+- [2, 3, 2]
+- [2, 4, 2]
+---
+name: "entangled vars"
+unify:
+- - !sum [[1,2],[3,4]]
+ - !sum [[2,1],[3,4],[4,3]]
+- - [$a, $b]
+ - [$b, $a]
+all:
+- - [1, 2]
+ - [2, 1]
+- - [3, 4]
+ - [4, 3]
+
+#
+# End-to-end examples
+#
+
+---
+name: "end-to-end"
+unify:
+- go: Add
+ in:
+ - go: $t
+ - go: $t
+- in: !repeat
+ - !sum
+ - go: Int32x4
+ base: int
+ - go: Uint32x4
+ base: uint
+all:
+- go: Add
+ in:
+ - base: int
+ go: Int32x4
+ - base: int
+ go: Int32x4
+- go: Add
+ in:
+ - base: uint
+ go: Uint32x4
+ - base: uint
+ go: Uint32x4
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "fmt"
+ "io"
+ "strings"
+
+ "gopkg.in/yaml.v3"
+)
+
+// debugDotInHTML, if true, includes dot code for all graphs in the HTML. Useful
+// for debugging the dot output itself.
+const debugDotInHTML = false
+
+var Debug struct {
+ // UnifyLog, if non-nil, receives a streaming text trace of unification.
+ UnifyLog io.Writer
+
+ // HTML, if non-nil, receives an HTML trace of unification.
+ HTML io.Writer
+}
+
+type tracer struct {
+ logw io.Writer
+
+ enc yamlEncoder // Print consistent idents throughout
+
+ saveTree bool // if set, record tree; required for HTML output
+
+ path []string
+
+ node *traceTree
+ trees []*traceTree
+}
+
+type traceTree struct {
+ label string // Identifies this node as a child of parent
+ v, w *Value // Unification inputs
+ envIn envSet
+ res *Value // Unification result
+ env envSet
+ err error // Error, if unification failed
+
+ parent *traceTree
+ children []*traceTree
+}
+
+type tracerExit struct {
+ t *tracer
+ len int
+ node *traceTree
+}
+
+func (t *tracer) enter(pat string, vals ...any) tracerExit {
+ if t == nil {
+ return tracerExit{}
+ }
+
+ label := fmt.Sprintf(pat, vals...)
+
+ var p *traceTree
+ if t.saveTree {
+ p = t.node
+ if p != nil {
+ t.node = &traceTree{label: label, parent: p}
+ p.children = append(p.children, t.node)
+ }
+ }
+
+ t.path = append(t.path, label)
+ return tracerExit{t, len(t.path) - 1, p}
+}
+
+func (t *tracer) enterVar(id *ident, branch int) tracerExit {
+ if t == nil {
+ return tracerExit{}
+ }
+
+ // Use the tracer's ident printer
+ return t.enter("Var %s br %d", t.enc.idp.unique(id), branch)
+}
+
+func (te tracerExit) exit() {
+ if te.t == nil {
+ return
+ }
+ te.t.path = te.t.path[:te.len]
+ te.t.node = te.node
+}
+
+func indentf(prefix string, pat string, vals ...any) string {
+ s := fmt.Sprintf(pat, vals...)
+ if len(prefix) == 0 {
+ return s
+ }
+ if !strings.Contains(s, "\n") {
+ return prefix + s
+ }
+
+ indent := prefix
+ if strings.TrimLeft(prefix, " ") != "" {
+ // Prefix has non-space characters in it. Construct an all space-indent.
+ indent = strings.Repeat(" ", len(prefix))
+ }
+ return prefix + strings.ReplaceAll(s, "\n", "\n"+indent)
+}
+
+func yamlf(prefix string, node *yaml.Node) string {
+ b, err := yaml.Marshal(node)
+ if err != nil {
+ return fmt.Sprintf("<marshal failed: %s>", err)
+ }
+ return strings.TrimRight(indentf(prefix, "%s", b), " \n")
+}
+
+func (t *tracer) logf(pat string, vals ...any) {
+ if t == nil || t.logw == nil {
+ return
+ }
+ prefix := fmt.Sprintf("[%s] ", strings.Join(t.path, "/"))
+ s := indentf(prefix, pat, vals...)
+ s = strings.TrimRight(s, " \n")
+ fmt.Fprintf(t.logw, "%s\n", s)
+}
+
+func (t *tracer) traceUnify(v, w *Value, e envSet) {
+ if t == nil {
+ return
+ }
+
+ t.logf("Unify\n%s\nwith\n%s\nin\n%s",
+ yamlf(" ", t.enc.value(v)),
+ yamlf(" ", t.enc.value(w)),
+ yamlf(" ", t.enc.env(e)))
+
+ if t.saveTree {
+ if t.node == nil {
+ t.node = &traceTree{}
+ t.trees = append(t.trees, t.node)
+ }
+ t.node.v, t.node.w, t.node.envIn = v, w, e
+ }
+}
+
+func (t *tracer) traceDone(res *Value, e envSet, err error) {
+ if t == nil {
+ return
+ }
+
+ if err != nil {
+ t.logf("==> %s", err)
+ } else {
+ t.logf("==>\n%s", yamlf(" ", t.enc.closure(Closure{res, e})))
+ }
+
+ if t.saveTree {
+ node := t.node
+ if node == nil {
+ panic("popped top of trace stack")
+ }
+ node.res, node.err = res, err
+ node.env = e
+ }
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package unify implements unification of structured values.
+//
+// A [Value] represents a possibly infinite set of concrete values, where a
+// value is either a string ([String]), a tuple of values ([Tuple]), or a
+// string-keyed map of values called a "def" ([Def]). These sets can be further
+// constrained by variables ([Var]). A [Value] combined with bindings of
+// variables is a [Closure].
+//
+// [Unify] finds a [Closure] that satisfies two or more other [Closure]s. This
+// can be thought of as intersecting the sets represented by these Closures'
+// values, or as the greatest lower bound/infimum of these Closures. If no such
+// Closure exists, the result of unification is "bottom", or the empty set.
+//
+// # Examples
+//
+// The regular expression "a*" is the infinite set of strings of zero or more
+// "a"s. "a*" can be unified with "a" or "aa" or "aaa", and the result is just
+// "a", "aa", or "aaa", respectively. However, unifying "a*" with "b" fails
+// because there are no values that satisfy both.
+//
+// Sums express sets directly. For example, !sum [a, b] is the set consisting of
+// "a" and "b". Unifying this with !sum [b, c] results in just "b". This also
+// makes it easy to demonstrate that unification isn't necessarily a single
+// concrete value. For example, unifying !sum [a, b, c] with !sum [b, c, d]
+// results in two concrete values: "b" and "c".
+//
+// The special value _ or "top" represents all possible values. Unifying _ with
+// any value x results in x.
+//
+// Unifying composite values—tuples and defs—unifies their elements.
+//
+// The value [a*, aa] is an infinite set of tuples. If we unify that with the
+// value [aaa, a*], the only possible value that satisfies both is [aaa, aa].
+// Likewise, this is the intersection of the sets described by these two values.
+//
+// Defs are similar to tuples, but they are indexed by strings and don't have a
+// fixed length. For example, {x: a, y: b} is a def with two fields. Any field
+// not mentioned in a def is implicitly top. Thus, unifying this with {y: b, z:
+// c} results in {x: a, y: b, z: c}.
+//
+// Variables constrain values. For example, the value [$x, $x] represents all
+// tuples whose first and second values are the same, but doesn't otherwise
+// constrain that value. Thus, this set includes [a, a] as well as [[b, c, d],
+// [b, c, d]], but it doesn't include [a, b].
+//
+// Sums are internally implemented as fresh variables that are simultaneously
+// bound to all values of the sum. That is !sum [a, b] is actually $var (where
+// var is some fresh name), closed under the environment $var=a | $var=b.
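+//
+// # Usage
+//
+// A minimal sketch of typical use (file names are illustrative; error
+// handling elided):
+//
+//	a, _ := unify.ReadFile("a.yaml", unify.ReadOpts{})
+//	b, _ := unify.ReadFile("b.yaml", unify.ReadOpts{})
+//	c, _ := unify.Unify(a, b)
+//	for v := range c.All() {
+//		fmt.Println(v)
+//	}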
+package unify
+
+import (
+ "errors"
+ "fmt"
+ "slices"
+)
+
+// Unify computes a Closure that satisfies each input Closure. If no such
+// Closure exists, it returns bottom.
+func Unify(closures ...Closure) (Closure, error) {
+ if len(closures) == 0 {
+ return Closure{topValue, topEnv}, nil
+ }
+
+ var trace *tracer
+ if Debug.UnifyLog != nil || Debug.HTML != nil {
+ trace = &tracer{
+ logw: Debug.UnifyLog,
+ saveTree: Debug.HTML != nil,
+ }
+ }
+
+ unified := closures[0]
+ for _, c := range closures[1:] {
+ var err error
+ uf := newUnifier()
+ uf.tracer = trace
+ e := crossEnvs(unified.env, c.env)
+ unified.val, unified.env, err = unified.val.unify(c.val, e, false, uf)
+ if Debug.HTML != nil {
+ uf.writeHTML(Debug.HTML)
+ }
+ if err != nil {
+ return Closure{}, err
+ }
+ }
+
+ return unified, nil
+}
+
+type unifier struct {
+ *tracer
+}
+
+func newUnifier() *unifier {
+ return &unifier{}
+}
+
+// errDomains is a sentinel error used between unify and unify1 to indicate that
+// unify1 could not unify the domains of the two values.
+var errDomains = errors.New("cannot unify domains")
+
+func (v *Value) unify(w *Value, e envSet, swap bool, uf *unifier) (*Value, envSet, error) {
+ if swap {
+ // Put the values in order. This just happens to be a handy choke-point
+ // to do this at.
+ v, w = w, v
+ }
+
+ uf.traceUnify(v, w, e)
+
+ d, e2, err := v.unify1(w, e, false, uf)
+ if err == errDomains {
+ // Try the other order.
+ d, e2, err = w.unify1(v, e, true, uf)
+ if err == errDomains {
+ // Okay, we really can't unify these.
+ err = fmt.Errorf("cannot unify %T (%s) and %T (%s): kind mismatch", v.Domain, v.PosString(), w.Domain, w.PosString())
+ }
+ }
+ if err != nil {
+ uf.traceDone(nil, envSet{}, err)
+ return nil, envSet{}, err
+ }
+ res := unified(d, v, w)
+ uf.traceDone(res, e2, nil)
+ if d == nil {
+ // Double check that a bottom Value also has a bottom env.
+ if !e2.isEmpty() {
+ panic("bottom Value has non-bottom environment")
+ }
+ }
+
+ return res, e2, nil
+}
+
+func (v *Value) unify1(w *Value, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
+ // TODO: If there's an error, attach position information to it.
+
+ vd, wd := v.Domain, w.Domain
+
+ // Bottom returns bottom, and eliminates all possible environments.
+ if vd == nil || wd == nil {
+ return nil, bottomEnv, nil
+ }
+
+ // Top always returns the other.
+ if _, ok := vd.(Top); ok {
+ return wd, e, nil
+ }
+
+ // Variables
+ if vd, ok := vd.(Var); ok {
+ return vd.unify(w, e, swap, uf)
+ }
+
+ // Composite values
+ if vd, ok := vd.(Def); ok {
+ if wd, ok := wd.(Def); ok {
+ return vd.unify(wd, e, swap, uf)
+ }
+ }
+ if vd, ok := vd.(Tuple); ok {
+ if wd, ok := wd.(Tuple); ok {
+ return vd.unify(wd, e, swap, uf)
+ }
+ }
+
+ // Scalar values
+ if vd, ok := vd.(String); ok {
+ if wd, ok := wd.(String); ok {
+ res := vd.unify(wd)
+ if res == nil {
+ e = bottomEnv
+ }
+ return res, e, nil
+ }
+ }
+
+ return nil, envSet{}, errDomains
+}
+
+func (d Def) unify(o Def, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
+ out := Def{fields: make(map[string]*Value)}
+
+ // Check keys of d against o.
+ for key, dv := range d.All() {
+ ov, ok := o.fields[key]
+ if !ok {
+ // ov is implicitly Top. Bypass unification.
+ out.fields[key] = dv
+ continue
+ }
+ exit := uf.enter("%s", key)
+ res, e2, err := dv.unify(ov, e, swap, uf)
+ exit.exit()
+ if err != nil {
+ return nil, envSet{}, err
+ } else if res.Domain == nil {
+ // No match.
+ return nil, bottomEnv, nil
+ }
+ out.fields[key] = res
+ e = e2
+ }
+ // Check keys of o that we didn't already check. These all implicitly match
+ // because we know the corresponding fields in d are all Top.
+ for key, dv := range o.All() {
+ if _, ok := d.fields[key]; !ok {
+ out.fields[key] = dv
+ }
+ }
+ return out, e, nil
+}
+
+func (v Tuple) unify(w Tuple, e envSet, swap bool, uf *unifier) (Domain, envSet, error) {
+ if v.repeat != nil && w.repeat != nil {
+ // Since we generate the content of these lazily, there's not much we
+ // can do but just stick them on a list to unify later.
+ return Tuple{repeat: concat(v.repeat, w.repeat)}, e, nil
+ }
+
+ // Expand any repeated tuples.
+ tuples := make([]Tuple, 0, 2)
+ if v.repeat == nil {
+ tuples = append(tuples, v)
+ } else {
+ v2, e2 := v.doRepeat(e, len(w.vs))
+ tuples = append(tuples, v2...)
+ e = e2
+ }
+ if w.repeat == nil {
+ tuples = append(tuples, w)
+ } else {
+ w2, e2 := w.doRepeat(e, len(v.vs))
+ tuples = append(tuples, w2...)
+ e = e2
+ }
+
+ // Now unify all of the tuples (usually this will be just 2 tuples)
+ out := tuples[0]
+ for _, t := range tuples[1:] {
+ if len(out.vs) != len(t.vs) {
+ uf.logf("tuple length mismatch")
+ return nil, bottomEnv, nil
+ }
+ zs := make([]*Value, len(out.vs))
+ for i, v1 := range out.vs {
+ exit := uf.enter("%d", i)
+ z, e2, err := v1.unify(t.vs[i], e, swap, uf)
+ exit.exit()
+ if err != nil {
+ return nil, envSet{}, err
+ } else if z.Domain == nil {
+ return nil, bottomEnv, nil
+ }
+ zs[i] = z
+ e = e2
+ }
+ out = Tuple{vs: zs}
+ }
+
+ return out, e, nil
+}
+
+// doRepeat expands a repeated tuple into fixed-length tuples of n elements
+// each. The caller is expected to unify the returned tuples.
+func (v Tuple) doRepeat(e envSet, n int) ([]Tuple, envSet) {
+ res := make([]Tuple, len(v.repeat))
+ for i, gen := range v.repeat {
+ res[i].vs = make([]*Value, n)
+ for j := range n {
+ res[i].vs[j], e = gen(e)
+ }
+ }
+ return res, e
+}
+
+// unify intersects the domains of two [String]s. If it can prove that this
+// domain is empty, it returns nil (bottom).
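+//
+// For example, unifying the regex a* with the exact string "aa" checks "aa"
+// against each compiled pattern and, since it matches, returns the exact
+// string; unifying a* with "b" finds no match and returns nil.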
+//
+// TODO: Consider splitting literals and regexps into two domains.
+func (v String) unify(w String) Domain {
+ // Unification is symmetric, so put them in order of string kind so we only
+ // have to deal with half the cases.
+ if v.kind > w.kind {
+ v, w = w, v
+ }
+
+ switch v.kind {
+ case stringRegex:
+ switch w.kind {
+ case stringRegex:
+ // Construct a match against all of the regexps
+ return String{kind: stringRegex, re: slices.Concat(v.re, w.re)}
+ case stringExact:
+ for _, re := range v.re {
+ if !re.MatchString(w.exact) {
+ return nil
+ }
+ }
+ return w
+ }
+ case stringExact:
+ if v.exact != w.exact {
+ return nil
+ }
+ return v
+ }
+ panic("bad string kind")
+}
+
+func concat[T any](s1, s2 []T) []T {
+ // Reuse s1 or s2 if possible.
+ if len(s1) == 0 {
+ return s2
+ }
+ return append(s1[:len(s1):len(s1)], s2...)
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+ "testing"
+
+ "gopkg.in/yaml.v3"
+)
+
+func TestUnify(t *testing.T) {
+ paths, err := filepath.Glob("testdata/*")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(paths) == 0 {
+ t.Fatal("no testdata found")
+ }
+ for _, path := range paths {
+ // Skip paths starting with _ so experimental files can be added.
+ base := filepath.Base(path)
+ if base[0] == '_' {
+ continue
+ }
+ if !strings.HasSuffix(base, ".yaml") {
+ t.Errorf("non-.yaml file in testdata: %s", base)
+ continue
+ }
+ base = strings.TrimSuffix(base, ".yaml")
+
+ t.Run(base, func(t *testing.T) {
+ testUnify(t, path)
+ })
+ }
+}
+
+func testUnify(t *testing.T, path string) {
+ f, err := os.Open(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer f.Close()
+
+ type testCase struct {
+ Skip bool
+ Name string
+ Unify []Closure
+ Want yaml.Node
+ All yaml.Node
+ }
+ dec := yaml.NewDecoder(f)
+
+ for i := 0; ; i++ {
+ var tc testCase
+ err := dec.Decode(&tc)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ name := tc.Name
+ if name == "" {
+ name = fmt.Sprint(i)
+ }
+
+ t.Run(name, func(t *testing.T) {
+ if tc.Skip {
+ t.Skip("skip: true set in test case")
+ }
+
+ defer func() {
+ p := recover()
+ if p != nil || t.Failed() {
+ // Redo with a trace
+ //
+ // TODO: Use t.Output() in Go 1.25.
+ var buf bytes.Buffer
+ Debug.UnifyLog = &buf
+ func() {
+ defer func() {
+ // If the original unify panicked, the second one
+ // probably will, too. Ignore it and let the first panic
+ // bubble.
+ recover()
+ }()
+ Unify(tc.Unify...)
+ }()
+ Debug.UnifyLog = nil
+ t.Logf("Trace:\n%s", buf.String())
+ }
+ if p != nil {
+ panic(p)
+ }
+ }()
+
+ // Unify the test cases
+ //
+ // TODO: Try reordering the inputs also
+ c, err := Unify(tc.Unify...)
+ if err != nil {
+ // TODO: Tests of errors
+ t.Fatal(err)
+ }
+
+ // Encode the result back to YAML so we can check if it's structurally
+ // equal.
+ clean := func(val any) *yaml.Node {
+ var node yaml.Node
+ node.Encode(val)
+ for n := range allYamlNodes(&node) {
+ // Canonicalize the style. There may be other style flags we need to
+ // muck with.
+ n.Style &^= yaml.FlowStyle
+ n.HeadComment = ""
+ n.LineComment = ""
+ n.FootComment = ""
+ }
+ return &node
+ }
+ check := func(gotVal any, wantNode *yaml.Node) {
+ got, err := yaml.Marshal(clean(gotVal))
+ if err != nil {
+ t.Fatalf("Encoding Value back to yaml failed: %s", err)
+ }
+ want, err := yaml.Marshal(clean(wantNode))
+ if err != nil {
+ t.Fatalf("Encoding Want back to yaml failed: %s", err)
+ }
+
+ if !bytes.Equal(got, want) {
+ t.Errorf("%s:%d:\nwant:\n%sgot\n%s", f.Name(), wantNode.Line, want, got)
+ }
+ }
+ if tc.Want.Kind != 0 {
+ check(c.val, &tc.Want)
+ }
+ if tc.All.Kind != 0 {
+ fVal := slices.Collect(c.All())
+ check(fVal, &tc.All)
+ }
+ })
+ }
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "fmt"
+ "iter"
+ "reflect"
+)
+
+// A Value represents a structured, non-deterministic value consisting of
+// strings, tuples of Values, and string-keyed maps of Values. A
+// non-deterministic Value may also contain variables, which are resolved via
+// an environment as part of a [Closure].
+//
+// For debugging, a Value can also track the source position it was read from in
+// an input file, and its provenance from other Values.
+type Value struct {
+ Domain Domain
+
+ // A Value has either a pos or parents (or neither).
+ pos *Pos
+ parents *[2]*Value
+}
+
+var (
+ topValue = &Value{Domain: Top{}}
+ bottomValue = &Value{Domain: nil}
+)
+
+// NewValue returns a new [Value] with the given domain and no position
+// information.
+func NewValue(d Domain) *Value {
+ return &Value{Domain: d}
+}
+
+// NewValuePos returns a new [Value] with the given domain at position p.
+func NewValuePos(d Domain, p Pos) *Value {
+ return &Value{Domain: d, pos: &p}
+}
+
+// newValueFrom returns a new [Value] with the given domain that copies the
+// position information of p.
+func newValueFrom(d Domain, p *Value) *Value {
+ return &Value{Domain: d, pos: p.pos, parents: p.parents}
+}
+
+func unified(d Domain, p1, p2 *Value) *Value {
+ return &Value{Domain: d, parents: &[2]*Value{p1, p2}}
+}
+
+func (v *Value) Pos() Pos {
+ if v.pos == nil {
+ return Pos{}
+ }
+ return *v.pos
+}
+
+func (v *Value) PosString() string {
+ var b []byte
+ for root := range v.Provenance() {
+ if len(b) > 0 {
+ b = append(b, ' ')
+ }
+ b, _ = root.pos.AppendText(b)
+ }
+ return string(b)
+}
+
+func (v *Value) WhyNotExact() string {
+ if v.Domain == nil {
+ return "v.Domain is nil"
+ }
+ return v.Domain.WhyNotExact()
+}
+
+func (v *Value) Exact() bool {
+ if v.Domain == nil {
+ return false
+ }
+ return v.Domain.Exact()
+}
+
+// Decode decodes v into a Go value.
+//
+// v must be exact, except that it can include Top. into must be a pointer.
+// [Def]s are decoded into structs. [Tuple]s are decoded into slices. [String]s
+// are decoded into strings or ints. Any field can itself be a pointer to one of
+// these types. Top can be decoded into a pointer-typed field and will set the
+// field to nil. Anything else will allocate a value if necessary.
+//
+// Any type may implement [Decoder], in which case its DecodeUnified method will
+// be called instead of using the default decoding scheme.
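+//
+// A minimal sketch (the struct and field names here are illustrative):
+//
+//	type op struct {
+//		Go   string
+//		Bits int
+//	}
+//	var o op
+//	if err := v.Decode(&o); err != nil {
+//		// v was not exact, or its shape didn't match op.
+//	}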
+func (v *Value) Decode(into any) error {
+ rv := reflect.ValueOf(into)
+ if rv.Kind() != reflect.Pointer {
+ return fmt.Errorf("cannot decode into non-pointer %T", into)
+ }
+ return decodeReflect(v, rv.Elem())
+}
+
+func decodeReflect(v *Value, rv reflect.Value) error {
+ var ptr reflect.Value
+ if rv.Kind() == reflect.Pointer {
+ if rv.IsNil() {
+ // Transparently allocate through pointers, *except* for Top, which
+ // wants to set the pointer to nil.
+ //
+ // TODO: Drop this condition if I switch to an explicit Optional[T]
+ // or move the Top logic into Def.
+ if _, ok := v.Domain.(Top); !ok {
+ // Allocate the value to fill in, but don't actually store it in
+ // the pointer until we successfully decode.
+ ptr = rv
+ rv = reflect.New(rv.Type().Elem()).Elem()
+ }
+ } else {
+ rv = rv.Elem()
+ }
+ }
+
+ var err error
+ if reflect.PointerTo(rv.Type()).Implements(decoderType) {
+ // Use the custom decoder.
+ err = rv.Addr().Interface().(Decoder).DecodeUnified(v)
+ } else {
+ err = v.Domain.decode(rv)
+ }
+ if err == nil && ptr.IsValid() {
+ ptr.Set(rv.Addr())
+ }
+ return err
+}
+
+// Decoder can be implemented by types as a custom implementation of [Decode]
+// for that type.
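+//
+// A minimal sketch (the CSV type here is illustrative):
+//
+//	type CSV []string
+//
+//	func (c *CSV) DecodeUnified(v *Value) error {
+//		var s string
+//		if err := v.Decode(&s); err != nil {
+//			return err
+//		}
+//		*c = strings.Split(s, ",")
+//		return nil
+//	}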
+type Decoder interface {
+ DecodeUnified(v *Value) error
+}
+
+var decoderType = reflect.TypeOf((*Decoder)(nil)).Elem()
+
+// Provenance iterates over all of the source Values that have contributed to
+// this Value.
+func (v *Value) Provenance() iter.Seq[*Value] {
+ return func(yield func(*Value) bool) {
+ var rec func(d *Value) bool
+ rec = func(d *Value) bool {
+ if d.pos != nil {
+ if !yield(d) {
+ return false
+ }
+ }
+ if d.parents != nil {
+ for _, p := range d.parents {
+ if !rec(p) {
+ return false
+ }
+ }
+ }
+ return true
+ }
+ rec(v)
+ }
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "reflect"
+ "slices"
+ "testing"
+)
+
+func ExampleClosure_All_tuple() {
+ v := mustParse(`
+- !sum [1, 2]
+- !sum [3, 4]
+`)
+ printYaml(slices.Collect(v.All()))
+
+ // Output:
+ // - [1, 3]
+ // - [1, 4]
+ // - [2, 3]
+ // - [2, 4]
+}
+
+func ExampleClosure_All_def() {
+ v := mustParse(`
+a: !sum [1, 2]
+b: !sum [3, 4]
+c: 5
+`)
+ printYaml(slices.Collect(v.All()))
+
+ // Output:
+ // - {a: 1, b: 3, c: 5}
+ // - {a: 1, b: 4, c: 5}
+ // - {a: 2, b: 3, c: 5}
+ // - {a: 2, b: 4, c: 5}
+}
+
+func checkDecode[T any](t *testing.T, got *Value, want T) {
+ var gotT T
+ if err := got.Decode(&gotT); err != nil {
+ t.Fatalf("Decode failed: %v", err)
+ }
+ if !reflect.DeepEqual(&gotT, &want) {
+ t.Fatalf("got:\n%s\nwant:\n%s", prettyYaml(gotT), prettyYaml(want))
+ }
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+
+ "gopkg.in/yaml.v3"
+)
+
+// ReadOpts provides options to [Read] and related functions. The zero value is
+// the default options.
+type ReadOpts struct {
+ // FS, if non-nil, is the file system from which to resolve !import file
+ // names.
+ FS fs.FS
+}
+
+// Read reads a [Closure] in YAML format from r, using path for error messages.
+//
+// It maps YAML nodes into terminal Values as follows:
+//
+// - "_" or !top _ is the top value ([Top]).
+//
+// - "_|_" or !bottom _ is the bottom value. This is an error during
+// unmarshaling, but can appear in marshaled values.
+//
+// - "$<name>" or !var <name> is a variable ([Var]). Everywhere the same name
+// appears within a single unmarshal operation, it is mapped to the same
+// variable. Different unmarshal operations get different variables, even if
+// they have the same string name.
+//
+// - !regex "x" is a regular expression ([String]), as is any string that
+// doesn't match "_", "_|_", or "$...". Regular expressions are implicitly
+// anchored at the beginning and end. If the string doesn't contain any
+// meta-characters (that is, it's a "literal" regular expression), then it's
+// treated as an exact string.
+//
+// - !string "x", or any int, float, bool, or binary value is an exact string
+// ([String]).
+//
+// - !regex [x, y, ...] is an intersection of regular expressions ([String]).
+//
+// It maps YAML nodes into non-terminal Values as follows:
+//
+// - Sequence nodes like [x, y, z] are tuples ([Tuple]).
+//
+// - !repeat [x] is a repeated tuple ([Tuple]), which is 0 or more instances of
+// x. There must be exactly one element in the list.
+//
+// - Mapping nodes like {a: x, b: y} are defs ([Def]). Any fields not listed are
+// implicitly top.
+//
+// - !sum [x, y, z] is a sum of its children. This can be thought of as a union
+// of the values x, y, and z, or as a non-deterministic choice between x, y, and
+// z. If a variable appears both inside the sum and outside of it, only the
+// non-deterministic choice view really works. The unifier does not directly
+// implement sums; instead, this is decoded as a fresh variable that's
+// simultaneously bound to x, y, and z.
+//
+// - !import glob is like a !sum, but its children are read from all files
+// matching the given glob pattern, which is interpreted relative to the current
+// file path. Each file gets its own variable scope.
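+//
+// For example, this document (mirroring the testdata conventions) describes a
+// def whose "in" field is a tuple, of any length, of either of two defs:
+//
+//	in: !repeat
+//	- !sum
+//	  - {go: Int32x4, base: int}
+//	  - {go: Uint32x4, base: uint}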
+func Read(r io.Reader, path string, opts ReadOpts) (Closure, error) {
+ dec := yamlDecoder{opts: opts, path: path, env: topEnv}
+ v, err := dec.read(r)
+ if err != nil {
+ return Closure{}, err
+ }
+ return dec.close(v), nil
+}
+
+// ReadFile reads a [Closure] in YAML format from a file.
+//
+// The file must consist of a single YAML document.
+//
+// If opts.FS is not set, this sets it to a FS rooted at path's directory.
+//
+// See [Read] for details.
+func ReadFile(path string, opts ReadOpts) (Closure, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ return Closure{}, err
+ }
+ defer f.Close()
+
+ if opts.FS == nil {
+ opts.FS = os.DirFS(filepath.Dir(path))
+ }
+
+ return Read(f, path, opts)
+}
+
+// UnmarshalYAML implements [yaml.Unmarshaler].
+//
+// Since there is no way to pass [ReadOpts] to this function, it assumes default
+// options.
+func (c *Closure) UnmarshalYAML(node *yaml.Node) error {
+ dec := yamlDecoder{path: "<yaml.Node>", env: topEnv}
+ v, err := dec.root(node)
+ if err != nil {
+ return err
+ }
+ *c = dec.close(v)
+ return nil
+}
+
+type yamlDecoder struct {
+ opts ReadOpts
+ path string
+
+ vars map[string]*ident
+ nSums int
+
+ env envSet
+}
+
+func (dec *yamlDecoder) read(r io.Reader) (*Value, error) {
+ n, err := readOneNode(r)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %w", dec.path, err)
+ }
+
+ // Decode YAML node to a Value
+ v, err := dec.root(n)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %w", dec.path, err)
+ }
+
+ return v, nil
+}
+
+// readOneNode reads a single YAML document from r and returns an error if there
+// are more documents in r.
+func readOneNode(r io.Reader) (*yaml.Node, error) {
+ yd := yaml.NewDecoder(r)
+
+ // Decode as a YAML node
+ var node yaml.Node
+ if err := yd.Decode(&node); err != nil {
+ return nil, err
+ }
+ np := &node
+ if np.Kind == yaml.DocumentNode {
+ np = node.Content[0]
+ }
+
+ // Ensure there are no more YAML docs in this file
+ if err := yd.Decode(nil); err == nil {
+ return nil, fmt.Errorf("must not contain multiple documents")
+ } else if err != io.EOF {
+ return nil, err
+ }
+
+ return np, nil
+}
+
+// root parses the root of a file.
+func (dec *yamlDecoder) root(node *yaml.Node) (*Value, error) {
+ // Prepare for variable name resolution in this file. This may be a nested
+ // root, so restore the current values when we're done.
+ oldVars, oldNSums := dec.vars, dec.nSums
+ defer func() {
+ dec.vars, dec.nSums = oldVars, oldNSums
+ }()
+ dec.vars = make(map[string]*ident)
+ dec.nSums = 0
+
+ return dec.value(node)
+}
+
+// close wraps a decoded [Value] into a [Closure].
+func (dec *yamlDecoder) close(v *Value) Closure {
+ return Closure{v, dec.env}
+}
+
+func (dec *yamlDecoder) value(node *yaml.Node) (vOut *Value, errOut error) {
+ pos := &Pos{Path: dec.path, Line: node.Line}
+
+ // Resolve alias nodes.
+ if node.Kind == yaml.AliasNode {
+ node = node.Alias
+ }
+
+ mk := func(d Domain) (*Value, error) {
+ v := &Value{Domain: d, pos: pos}
+ return v, nil
+ }
+ mk2 := func(d Domain, err error) (*Value, error) {
+ if err != nil {
+ return nil, err
+ }
+ return mk(d)
+ }
+
+ // is tests the kind and long tag of node.
+ is := func(kind yaml.Kind, tag string) bool {
+ return node.Kind == kind && node.LongTag() == tag
+ }
+ isExact := func() bool {
+ if node.Kind != yaml.ScalarNode {
+ return false
+ }
+ // We treat any string-ish YAML node as a string.
+ switch node.LongTag() {
+ case "!string", "tag:yaml.org,2002:int", "tag:yaml.org,2002:float", "tag:yaml.org,2002:bool", "tag:yaml.org,2002:binary":
+ return true
+ }
+ return false
+ }
+
+ // !!str nodes provide a short-hand syntax for several leaf domains that are
+ // also available under explicit tags. To simplify checking below, we set
+ // strVal to non-"" only for !!str nodes.
+ strVal := ""
+ isStr := is(yaml.ScalarNode, "tag:yaml.org,2002:str")
+ if isStr {
+ strVal = node.Value
+ }
+
+ switch {
+ case is(yaml.ScalarNode, "!var"):
+ strVal = "$" + node.Value
+ fallthrough
+ case strings.HasPrefix(strVal, "$"):
+ id, ok := dec.vars[strVal]
+ if !ok {
+ // We encode different idents with the same string name by adding a
+ // #N suffix. Strip that off so it doesn't accumulate. This isn't
+ // meant to be used in user-written input, though nothing stops that.
+ name, _, _ := strings.Cut(strVal, "#")
+ id = &ident{name: name}
+ dec.vars[strVal] = id
+ dec.env = dec.env.bind(id, topValue)
+ }
+ return mk(Var{id: id})
+
+ case strVal == "_" || is(yaml.ScalarNode, "!top"):
+ return mk(Top{})
+
+ case strVal == "_|_" || is(yaml.ScalarNode, "!bottom"):
+ return nil, errors.New("found bottom")
+
+ case isExact():
+ val := node.Value
+ return mk(NewStringExact(val))
+
+ case isStr || is(yaml.ScalarNode, "!regex"):
+ // Any other string we treat as a regex. This will produce an exact
+ // string anyway if the regex is literal.
+ val := node.Value
+ return mk2(NewStringRegex(val))
+
+ case is(yaml.SequenceNode, "!regex"):
+ var vals []string
+ if err := node.Decode(&vals); err != nil {
+ return nil, err
+ }
+ return mk2(NewStringRegex(vals...))
+
+ case is(yaml.MappingNode, "tag:yaml.org,2002:map"):
+ var db DefBuilder
+ for i := 0; i < len(node.Content); i += 2 {
+ key := node.Content[i]
+ if key.Kind != yaml.ScalarNode {
+ return nil, fmt.Errorf("non-scalar key %q", key.Value)
+ }
+ val, err := dec.value(node.Content[i+1])
+ if err != nil {
+ return nil, err
+ }
+ db.Add(key.Value, val)
+ }
+ return mk(db.Build())
+
+ case is(yaml.SequenceNode, "tag:yaml.org,2002:seq"):
+ elts := node.Content
+ vs := make([]*Value, 0, len(elts))
+ for _, elt := range elts {
+ v, err := dec.value(elt)
+ if err != nil {
+ return nil, err
+ }
+ vs = append(vs, v)
+ }
+ return mk(NewTuple(vs...))
+
+ case is(yaml.SequenceNode, "!repeat") || is(yaml.SequenceNode, "!repeat-unify"):
+ // !repeat must have one child. !repeat-unify is used internally for
+ // delayed unification and is identical, except that it may have more
+ // than one child.
+ if node.LongTag() == "!repeat" && len(node.Content) != 1 {
+ return nil, fmt.Errorf("!repeat must have exactly one child")
+ }
+
+ // Decode the children to make sure they're well-formed, but otherwise
+ // discard that decoding and do it again every time we need a new
+ // element.
+ var gen []func(e envSet) (*Value, envSet)
+ origEnv := dec.env
+ elts := node.Content
+ for i, elt := range elts {
+ _, err := dec.value(elt)
+ if err != nil {
+ return nil, err
+ }
+ // Undo any effects on the environment. We *do* keep any named
+ // variables that were added to the vars map in case they were
+ // introduced within the element.
+ dec.env = origEnv
+ // Add a generator function
+ gen = append(gen, func(e envSet) (*Value, envSet) {
+ dec.env = e
+ // TODO: If this is in a sum, this tends to generate a ton of
+ // fresh variables that are different on each branch of the
+ // parent sum. Does it make sense to hold on to the i'th value
+ // of the tuple after we've generated it?
+ v, err := dec.value(elts[i])
+ if err != nil {
+ // It worked the first time, so this really shouldn't happen.
+ panic("decoding repeat element failed")
+ }
+ return v, dec.env
+ })
+ }
+ return mk(NewRepeat(gen...))
+
+ case is(yaml.SequenceNode, "!sum"):
+ vs := make([]*Value, 0, len(node.Content))
+ for _, elt := range node.Content {
+ v, err := dec.value(elt)
+ if err != nil {
+ return nil, err
+ }
+ vs = append(vs, v)
+ }
+ if len(vs) == 1 {
+ return vs[0], nil
+ }
+
+ // A sum is implemented as a fresh variable that's simultaneously bound
+ // to each of the descendants.
+ id := &ident{name: fmt.Sprintf("sum%d", dec.nSums)}
+ dec.nSums++
+ dec.env = dec.env.bind(id, vs...)
+ return mk(Var{id: id})
+
+ case is(yaml.ScalarNode, "!import"):
+ if dec.opts.FS == nil {
+ return nil, fmt.Errorf("!import not allowed (ReadOpts.FS not set)")
+ }
+ pat := node.Value
+
+ if !fs.ValidPath(pat) {
+ // This will result in Glob returning no results. Give a more useful
+ // error message for this case.
+ return nil, fmt.Errorf("!import path must not contain '.' or '..'")
+ }
+
+ ms, err := fs.Glob(dec.opts.FS, pat)
+ if err != nil {
+ return nil, fmt.Errorf("resolving !import: %w", err)
+ }
+ if len(ms) == 0 {
+ return nil, fmt.Errorf("!import did not match any files")
+ }
+
+ // Parse each file
+ vs := make([]*Value, 0, len(ms))
+ for _, m := range ms {
+ v, err := dec.import1(m)
+ if err != nil {
+ return nil, err
+ }
+ vs = append(vs, v)
+ }
+
+ // Create a sum.
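+ // As with !sum, a fresh var is simultaneously bound to each file's value.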
+ if len(vs) == 1 {
+ return vs[0], nil
+ }
+ id := &ident{name: "import"}
+ dec.env = dec.env.bind(id, vs...)
+ return mk(Var{id: id})
+ }
+
+ return nil, fmt.Errorf("unknown node kind %d %v", node.Kind, node.Tag)
+}
+
+func (dec *yamlDecoder) import1(path string) (*Value, error) {
+ // Make sure we can open the path first.
+ f, err := dec.opts.FS.Open(path)
+ if err != nil {
+ return nil, fmt.Errorf("!import failed: %w", err)
+ }
+ defer f.Close()
+
+ // Save the current FS and path so they can be restored when this import finishes.
+ oldFS, oldPath := dec.opts.FS, dec.path
+ defer func() {
+ dec.opts.FS, dec.path = oldFS, oldPath
+ }()
+
+ // Enter path, which is relative to the current path's directory.
+ newPath := filepath.Join(filepath.Dir(dec.path), path)
+ subFS, err := fs.Sub(dec.opts.FS, filepath.Dir(path))
+ if err != nil {
+ return nil, err
+ }
+ dec.opts.FS, dec.path = subFS, newPath
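+ // From here on, nested !imports resolve relative to the imported file's directory.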
+
+ // Parse the file.
+ return dec.read(f)
+}
+
+type yamlEncoder struct {
+ idp identPrinter
+ e envSet // We track the environment for !repeat nodes.
+}
+
+// TODO: Switch some Value marshaling to Closure?
+
+func (c Closure) MarshalYAML() (any, error) {
+ // TODO: If the environment is trivial, just marshal the value.
+ enc := &yamlEncoder{}
+ return enc.closure(c), nil
+}
+
+func (c Closure) String() string {
+ b, err := yaml.Marshal(c)
+ if err != nil {
+ return fmt.Sprintf("marshal failed: %s", err)
+ }
+ return string(b)
+}
+
+func (v *Value) MarshalYAML() (any, error) {
+ enc := &yamlEncoder{}
+ return enc.value(v), nil
+}
+
+func (v *Value) String() string {
+ b, err := yaml.Marshal(v)
+ if err != nil {
+ return fmt.Sprintf("marshal failed: %s", err)
+ }
+ return string(b)
+}
+
+func (enc *yamlEncoder) closure(c Closure) *yaml.Node {
+ enc.e = c.env
+ var n yaml.Node
+ n.Kind = yaml.MappingNode
+ n.Tag = "!closure"
+ n.Content = make([]*yaml.Node, 4)
+ n.Content[0] = new(yaml.Node)
+ n.Content[0].SetString("env")
+ n.Content[2] = new(yaml.Node)
+ n.Content[2].SetString("in")
+ n.Content[3] = enc.value(c.val)
+ // Fill in the env after we've written the value in case value encoding
+ // affects the env.
+ n.Content[1] = enc.env(enc.e)
+ enc.e = envSet{} // Allow GC'ing the env
+ return &n
+}
+
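+ // env encodes an environment's expression tree: envZero and envUnit
+ // become the scalars "0" and "1", a binding becomes a one-entry
+ // {id: value} mapping, and products and sums become "!product" and
+ // "!sum" sequences.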
+func (enc *yamlEncoder) env(e envSet) *yaml.Node {
+ var encode func(e *envExpr) *yaml.Node
+ encode = func(e *envExpr) *yaml.Node {
+ var n yaml.Node
+ switch e.kind {
+ default:
+ panic("bad kind")
+ case envZero:
+ n.SetString("0")
+ case envUnit:
+ n.SetString("1")
+ case envBinding:
+ var id yaml.Node
+ id.SetString(enc.idp.unique(e.id))
+ n.Kind = yaml.MappingNode
+ n.Content = []*yaml.Node{&id, enc.value(e.val)}
+ case envProduct, envSum:
+ n.Kind = yaml.SequenceNode
+ if e.kind == envProduct {
+ n.Tag = "!product"
+ } else {
+ n.Tag = "!sum"
+ }
+ for _, e2 := range e.operands {
+ n.Content = append(n.Content, encode(e2))
+ }
+ }
+ return &n
+ }
+ return encode(e.root)
+}
+
+var yamlIntRe = regexp.MustCompile(`^-?[0-9]+$`)
+
+func (enc *yamlEncoder) value(v *Value) *yaml.Node {
+ var n yaml.Node
+ switch d := v.Domain.(type) {
+ case nil:
+ // Not allowed by the unmarshaler, but useful for understanding when
+ // something goes horribly wrong.
+ //
+ // TODO: We might be able to track useful provenance for this, which
+ // would really help with debugging unexpected bottoms.
+ n.SetString("_|_")
+ return &n
+
+ case Top:
+ n.SetString("_")
+ return &n
+
+ case Def:
+ n.Kind = yaml.MappingNode
+ for k, elt := range d.All() {
+ var kn yaml.Node
+ kn.SetString(k)
+ n.Content = append(n.Content, &kn, enc.value(elt))
+ }
+ n.HeadComment = v.PosString()
+ return &n
+
+ case Tuple:
+ n.Kind = yaml.SequenceNode
+ if d.repeat == nil {
+ for _, elt := range d.vs {
+ n.Content = append(n.Content, enc.value(elt))
+ }
+ } else {
+ if len(d.repeat) == 1 {
+ n.Tag = "!repeat"
+ } else {
+ n.Tag = "!repeat-unify"
+ }
+ // TODO: I'm not positive this will round-trip everything correctly.
+ for _, gen := range d.repeat {
+ v, e := gen(enc.e)
+ enc.e = e
+ n.Content = append(n.Content, enc.value(v))
+ }
+ }
+ return &n
+
+ case String:
+ switch d.kind {
+ case stringExact:
+ n.SetString(d.exact)
+ switch {
+ // Make this into a "nice" !!int node if I can.
+ case yamlIntRe.MatchString(d.exact):
+ n.Tag = "tag:yaml.org,2002:int"
+
+ // Or a "nice" !!bool node.
+ case d.exact == "false" || d.exact == "true":
+ n.Tag = "tag:yaml.org,2002:bool"
+
+ // If this doesn't require escaping, leave it as a str node to avoid
+ // the annoying YAML tags. Otherwise, mark it as an exact string.
+ // Alternatively, we could always emit a str node with regexp
+ // quoting.
+ case d.exact != regexp.QuoteMeta(d.exact):
+ n.Tag = "!string"
+ }
+ return &n
+ case stringRegex:
+ o := make([]string, 0, len(d.re))
+ for _, re := range d.re {
+ s := re.String()
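+ // Strip the "\A(?:" and ")\z" anchoring that was added when the regex
+ // was compiled, so the emitted pattern matches what the user wrote.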
+ s = strings.TrimSuffix(strings.TrimPrefix(s, `\A(?:`), `)\z`)
+ o = append(o, s)
+ }
+ if len(o) == 1 {
+ n.SetString(o[0])
+ return &n
+ }
+ n.Encode(o)
+ n.Tag = "!regex"
+ return &n
+ }
+ panic("bad String kind")
+
+ case Var:
+ // TODO: If Var only appears once in the whole Value and is independent
+ // in the environment (part of a term that is only over Var), then emit
+ // this as a !sum instead.
+ if false {
+ var vs []*Value // TODO: Get values of this var.
+ if len(vs) == 1 {
+ return enc.value(vs[0])
+ }
+ n.Kind = yaml.SequenceNode
+ n.Tag = "!sum"
+ for _, elt := range vs {
+ n.Content = append(n.Content, enc.value(elt))
+ }
+ return &n
+ }
+ n.SetString(enc.idp.unique(d.id))
+ if !strings.HasPrefix(d.id.name, "$") {
+ n.Tag = "!var"
+ }
+ return &n
+ }
+ panic(fmt.Sprintf("unknown domain type %T", v.Domain))
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unify
+
+import (
+ "bytes"
+ "fmt"
+ "iter"
+ "log"
+ "strings"
+ "testing"
+ "testing/fstest"
+
+ "gopkg.in/yaml.v3"
+)
+
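+ // mustParse parses a YAML expression into a Closure, panicking on error.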
+func mustParse(expr string) Closure {
+ var c Closure
+ if err := yaml.Unmarshal([]byte(expr), &c); err != nil {
+ panic(err)
+ }
+ return c
+}
+
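+ // oneValue requires that c yield exactly one Value and returns it.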
+func oneValue(t *testing.T, c Closure) *Value {
+ t.Helper()
+ var v *Value
+ var i int
+ for v = range c.All() {
+ i++
+ }
+ if i != 1 {
+ t.Fatalf("expected 1 value, got %d", i)
+ }
+ return v
+}
+
+func printYaml(val any) {
+ fmt.Println(prettyYaml(val))
+}
+
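+ // prettyYaml marshals val to YAML, then re-renders it with comments
+ // stripped and small nodes switched to inline (flow) style.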
+func prettyYaml(val any) string {
+ b, err := yaml.Marshal(val)
+ if err != nil {
+ panic(err)
+ }
+ var node yaml.Node
+ if err := yaml.Unmarshal(b, &node); err != nil {
+ panic(err)
+ }
+
+ // Map lines to start offsets. We'll use this to figure out when nodes are
+ // "small" and should use inline style.
+ lines := []int{-1, 0}
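+ // (yaml.Node positions are 1-based, so lines[0] is a placeholder.)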
+ for pos := 0; pos < len(b); {
+ next := bytes.IndexByte(b[pos:], '\n')
+ if next == -1 {
+ break
+ }
+ pos += next + 1
+ lines = append(lines, pos)
+ }
+ lines = append(lines, len(b))
+
+ // Strip comments and switch small nodes to inline style
+ cleanYaml(&node, lines, len(b))
+
+ b, err = yaml.Marshal(&node)
+ if err != nil {
+ panic(err)
+ }
+ return string(b)
+}
+
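+ // cleanYaml strips all comments and switches nodes that span fewer than
+ // 40 bytes of output to inline (flow) style.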
+func cleanYaml(node *yaml.Node, lines []int, endPos int) {
+ node.HeadComment = ""
+ node.FootComment = ""
+ node.LineComment = ""
+
+ for i, n2 := range node.Content {
+ end2 := endPos
+ if i < len(node.Content)-1 {
+ end2 = lines[node.Content[i+1].Line]
+ }
+ cleanYaml(n2, lines, end2)
+ }
+
+ // Use inline style?
+ switch node.Kind {
+ case yaml.MappingNode, yaml.SequenceNode:
+ if endPos-lines[node.Line] < 40 {
+ node.Style = yaml.FlowStyle
+ }
+ }
+}
+
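+ // allYamlNodes returns an iterator over n and all of its descendants in
+ // depth-first order.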
+func allYamlNodes(n *yaml.Node) iter.Seq[*yaml.Node] {
+ return func(yield func(*yaml.Node) bool) {
+ if !yield(n) {
+ return
+ }
+ for _, n2 := range n.Content {
+ for n3 := range allYamlNodes(n2) {
+ if !yield(n3) {
+ return
+ }
+ }
+ }
+ }
+}
+
+func TestRoundTripString(t *testing.T) {
+ // Check that we can round-trip a string with regexp meta-characters in it.
+ const y = `!string test*`
+ t.Logf("input:\n%s", y)
+
+ v1 := oneValue(t, mustParse(y))
+ var buf1 strings.Builder
+ enc := yaml.NewEncoder(&buf1)
+ if err := enc.Encode(v1); err != nil {
+ t.Fatal(err)
+ }
+ enc.Close()
+ t.Logf("after parse 1:\n%s", buf1.String())
+
+ v2 := oneValue(t, mustParse(buf1.String()))
+ var buf2 strings.Builder
+ enc = yaml.NewEncoder(&buf2)
+ if err := enc.Encode(v2); err != nil {
+ t.Fatal(err)
+ }
+ enc.Close()
+ t.Logf("after parse 2:\n%s", buf2.String())
+
+ if buf1.String() != buf2.String() {
+ t.Fatal("parse 1 and parse 2 differ")
+ }
+}
+
+func TestEmptyString(t *testing.T) {
+ // Regression test. Make sure an empty string is parsed as an exact string,
+ // not a regexp.
+ const y = `""`
+ t.Logf("input:\n%s", y)
+
+ v1 := oneValue(t, mustParse(y))
+ if !v1.Exact() {
+ t.Fatal("expected exact string")
+ }
+}
+
+func TestImport(t *testing.T) {
+ // Test a basic import
+ main := strings.NewReader("!import x/y.yaml")
+ fs := fstest.MapFS{
+ // Test a glob import with a relative path
+ "x/y.yaml": {Data: []byte("!import y/*.yaml")},
+ "x/y/z.yaml": {Data: []byte("42")},
+ }
+ cl, err := Read(main, "x.yaml", ReadOpts{FS: fs})
+ if err != nil {
+ t.Fatal(err)
+ }
+ x := 42
+ checkDecode(t, oneValue(t, cl), &x)
+}
+
+func TestImportEscape(t *testing.T) {
+ // Make sure an import can't escape its subdirectory.
+ main := strings.NewReader("!import x/y.yaml")
+ fs := fstest.MapFS{
+ "x/y.yaml": {Data: []byte("!import ../y/*.yaml")},
+ "y/z.yaml": {Data: []byte("42")},
+ }
+ _, err := Read(main, "x.yaml", ReadOpts{FS: fs})
+ if err == nil {
+ t.Fatal("relative !import should have failed")
+ }
+ if !strings.Contains(err.Error(), "must not contain") {
+ t.Fatalf("unexpected error %v", err)
+ }
+}
+
+func TestImportScope(t *testing.T) {
+ // Test that imports have different variable scopes.
+ main := strings.NewReader("[!import y.yaml, !import y.yaml]")
+ fs := fstest.MapFS{
+ "y.yaml": {Data: []byte("$v")},
+ }
+ cl1, err := Read(main, "x.yaml", ReadOpts{FS: fs})
+ if err != nil {
+ t.Fatal(err)
+ }
+ cl2 := mustParse("[1, 2]")
+ res, err := Unify(cl1, cl2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ checkDecode(t, oneValue(t, res), []int{1, 2})
+}