}
ir.EditChildren(m, edit)
- if x.Typecheck() == 3 {
- // These are nodes whose transforms were delayed until
- // their instantiated type was known.
- m.SetTypecheck(1)
- if typecheck.IsCmp(x.Op()) {
- transformCompare(m.(*ir.BinaryExpr))
- } else {
- switch x.Op() {
- case ir.OSLICE, ir.OSLICE3:
- transformSlice(m.(*ir.SliceExpr))
-
- case ir.OADD:
- m = transformAdd(m.(*ir.BinaryExpr))
-
- case ir.OINDEX:
- transformIndex(m.(*ir.IndexExpr))
-
- case ir.OAS2:
- as2 := m.(*ir.AssignListStmt)
- transformAssign(as2, as2.Lhs, as2.Rhs)
-
- case ir.OAS:
- as := m.(*ir.AssignStmt)
+ m.SetTypecheck(1)
+ if typecheck.IsCmp(x.Op()) {
+ transformCompare(m.(*ir.BinaryExpr))
+ } else {
+ switch x.Op() {
+ case ir.OSLICE, ir.OSLICE3:
+ transformSlice(m.(*ir.SliceExpr))
+
+ case ir.OADD:
+ m = transformAdd(m.(*ir.BinaryExpr))
+
+ case ir.OINDEX:
+ transformIndex(m.(*ir.IndexExpr))
+
+ case ir.OAS2:
+ as2 := m.(*ir.AssignListStmt)
+ transformAssign(as2, as2.Lhs, as2.Rhs)
+
+ case ir.OAS:
+ as := m.(*ir.AssignStmt)
+ if as.Y != nil {
+ // transformAssign doesn't handle the case
+ // of zeroing assignment of a dcl (rhs[0] is nil).
lhs, rhs := []ir.Node{as.X}, []ir.Node{as.Y}
transformAssign(as, lhs, rhs)
+ }
- case ir.OASOP:
- as := m.(*ir.AssignOpStmt)
- transformCheckAssign(as, as.X)
+ case ir.OASOP:
+ as := m.(*ir.AssignOpStmt)
+ transformCheckAssign(as, as.X)
- case ir.ORETURN:
- transformReturn(m.(*ir.ReturnStmt))
+ case ir.ORETURN:
+ transformReturn(m.(*ir.ReturnStmt))
- case ir.OSEND:
- transformSend(m.(*ir.SendStmt))
+ case ir.OSEND:
+ transformSend(m.(*ir.SendStmt))
- default:
- base.Fatalf("Unexpected node with Typecheck() == 3")
- }
}
}
--- /dev/null
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package a
+
+// SliceEqual reports whether two slices are equal: the same length and all
+// elements equal. All floating point NaNs are considered equal.
+func SliceEqual[Elem comparable](s1, s2 []Elem) bool {
+ if len(s1) != len(s2) {
+ return false
+ }
+ for i, v1 := range s1 {
+ v2 := s2[i]
+ if v1 != v2 {
+ isNaN := func(f Elem) bool { return f != f }
+ if !isNaN(v1) || !isNaN(v2) {
+ return false
+ }
+ }
+ }
+ return true
+}
+
+// Keys returns the keys of the map m.
+// The keys will be an indeterminate order.
+func Keys[K comparable, V any](m map[K]V) []K {
+ r := make([]K, 0, len(m))
+ for k := range m {
+ r = append(r, k)
+ }
+ return r
+}
+
+// Values returns the values of the map m.
+// The values will be in an indeterminate order.
+func Values[K comparable, V any](m map[K]V) []V {
+ r := make([]V, 0, len(m))
+ for _, v := range m {
+ r = append(r, v)
+ }
+ return r
+}
+
+// Equal reports whether two maps contain the same key/value pairs.
+// Values are compared using ==.
+func Equal[K, V comparable](m1, m2 map[K]V) bool {
+ if len(m1) != len(m2) {
+ return false
+ }
+ for k, v1 := range m1 {
+ if v2, ok := m2[k]; !ok || v1 != v2 {
+ return false
+ }
+ }
+ return true
+}
+
+// Copy returns a copy of m.
+func Copy[K comparable, V any](m map[K]V) map[K]V {
+ r := make(map[K]V, len(m))
+ for k, v := range m {
+ r[k] = v
+ }
+ return r
+}
+
+// Add adds all key/value pairs in m2 to m1. Keys in m2 that are already
+// present in m1 will be overwritten with the value in m2.
+func Add[K comparable, V any](m1, m2 map[K]V) {
+ for k, v := range m2 {
+ m1[k] = v
+ }
+}
+
+// Sub removes all keys in m2 from m1. Keys in m2 that are not present
+// in m1 are ignored. The values in m2 are ignored.
+func Sub[K comparable, V any](m1, m2 map[K]V) {
+ for k := range m2 {
+ delete(m1, k)
+ }
+}
+
+// Intersect removes all keys from m1 that are not present in m2.
+// Keys in m2 that are not in m1 are ignored. The values in m2 are ignored.
+func Intersect[K comparable, V any](m1, m2 map[K]V) {
+ for k := range m1 {
+ if _, ok := m2[k]; !ok {
+ delete(m1, k)
+ }
+ }
+}
+
+// Filter deletes any key/value pairs from m for which f returns false.
+func Filter[K comparable, V any](m map[K]V, f func(K, V) bool) {
+ for k, v := range m {
+ if !f(k, v) {
+ delete(m, k)
+ }
+ }
+}
+
+// TransformValues applies f to each value in m. The keys remain unchanged.
+func TransformValues[K comparable, V any](m map[K]V, f func(V) V) {
+ for k, v := range m {
+ m[k] = f(v)
+ }
+}
--- /dev/null
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "a"
+ "fmt"
+ "math"
+ "sort"
+)
+
+var m1 = map[int]int{1: 2, 2: 4, 4: 8, 8: 16}
+var m2 = map[int]string{1: "2", 2: "4", 4: "8", 8: "16"}
+
+func TestKeys() {
+ want := []int{1, 2, 4, 8}
+
+ got1 := a.Keys(m1)
+ sort.Ints(got1)
+ if !a.SliceEqual(got1, want) {
+ panic(fmt.Sprintf("a.Keys(%v) = %v, want %v", m1, got1, want))
+ }
+
+ got2 := a.Keys(m2)
+ sort.Ints(got2)
+ if !a.SliceEqual(got2, want) {
+ panic(fmt.Sprintf("a.Keys(%v) = %v, want %v", m2, got2, want))
+ }
+}
+
+func TestValues() {
+ got1 := a.Values(m1)
+ want1 := []int{2, 4, 8, 16}
+ sort.Ints(got1)
+ if !a.SliceEqual(got1, want1) {
+ panic(fmt.Sprintf("a.Values(%v) = %v, want %v", m1, got1, want1))
+ }
+
+ got2 := a.Values(m2)
+ want2 := []string{"16", "2", "4", "8"}
+ sort.Strings(got2)
+ if !a.SliceEqual(got2, want2) {
+ panic(fmt.Sprintf("a.Values(%v) = %v, want %v", m2, got2, want2))
+ }
+}
+
+func TestEqual() {
+ if !a.Equal(m1, m1) {
+ panic(fmt.Sprintf("a.Equal(%v, %v) = false, want true", m1, m1))
+ }
+ if a.Equal(m1, nil) {
+ panic(fmt.Sprintf("a.Equal(%v, nil) = true, want false", m1))
+ }
+ if a.Equal(nil, m1) {
+ panic(fmt.Sprintf("a.Equal(nil, %v) = true, want false", m1))
+ }
+ if !a.Equal[int, int](nil, nil) {
+ panic("a.Equal(nil, nil) = false, want true")
+ }
+ if ms := map[int]int{1: 2}; a.Equal(m1, ms) {
+ panic(fmt.Sprintf("a.Equal(%v, %v) = true, want false", m1, ms))
+ }
+
+ // Comparing NaN for equality is expected to fail.
+ mf := map[int]float64{1: 0, 2: math.NaN()}
+ if a.Equal(mf, mf) {
+ panic(fmt.Sprintf("a.Equal(%v, %v) = true, want false", mf, mf))
+ }
+}
+
+func TestCopy() {
+ m2 := a.Copy(m1)
+ if !a.Equal(m1, m2) {
+ panic(fmt.Sprintf("a.Copy(%v) = %v, want %v", m1, m2, m1))
+ }
+ m2[16] = 32
+ if a.Equal(m1, m2) {
+ panic(fmt.Sprintf("a.Equal(%v, %v) = true, want false", m1, m2))
+ }
+}
+
+func TestAdd() {
+ mc := a.Copy(m1)
+ a.Add(mc, mc)
+ if !a.Equal(mc, m1) {
+ panic(fmt.Sprintf("a.Add(%v, %v) = %v, want %v", m1, m1, mc, m1))
+ }
+ a.Add(mc, map[int]int{16: 32})
+ want := map[int]int{1: 2, 2: 4, 4: 8, 8: 16, 16: 32}
+ if !a.Equal(mc, want) {
+ panic(fmt.Sprintf("a.Add result = %v, want %v", mc, want))
+ }
+}
+
+func TestSub() {
+ mc := a.Copy(m1)
+ a.Sub(mc, mc)
+ if len(mc) > 0 {
+ panic(fmt.Sprintf("a.Sub(%v, %v) = %v, want empty map", m1, m1, mc))
+ }
+ mc = a.Copy(m1)
+ a.Sub(mc, map[int]int{1: 0})
+ want := map[int]int{2: 4, 4: 8, 8: 16}
+ if !a.Equal(mc, want) {
+ panic(fmt.Sprintf("a.Sub result = %v, want %v", mc, want))
+ }
+}
+
+func TestIntersect() {
+ mc := a.Copy(m1)
+ a.Intersect(mc, mc)
+ if !a.Equal(mc, m1) {
+ panic(fmt.Sprintf("a.Intersect(%v, %v) = %v, want %v", m1, m1, mc, m1))
+ }
+ a.Intersect(mc, map[int]int{1: 0, 2: 0})
+ want := map[int]int{1: 2, 2: 4}
+ if !a.Equal(mc, want) {
+ panic(fmt.Sprintf("a.Intersect result = %v, want %v", mc, want))
+ }
+}
+
+func TestFilter() {
+ mc := a.Copy(m1)
+ a.Filter(mc, func(int, int) bool { return true })
+ if !a.Equal(mc, m1) {
+ panic(fmt.Sprintf("a.Filter(%v, true) = %v, want %v", m1, mc, m1))
+ }
+ a.Filter(mc, func(k, v int) bool { return k < 3 })
+ want := map[int]int{1: 2, 2: 4}
+ if !a.Equal(mc, want) {
+ panic(fmt.Sprintf("a.Filter result = %v, want %v", mc, want))
+ }
+}
+
+func TestTransformValues() {
+ mc := a.Copy(m1)
+ a.TransformValues(mc, func(i int) int { return i / 2 })
+ want := map[int]int{1: 1, 2: 2, 4: 4, 8: 8}
+ if !a.Equal(mc, want) {
+ panic(fmt.Sprintf("a.TransformValues result = %v, want %v", mc, want))
+ }
+}
+
+func main() {
+ TestKeys()
+ TestValues()
+ TestEqual()
+ TestCopy()
+ TestAdd()
+ TestSub()
+ TestIntersect()
+ TestFilter()
+ TestTransformValues()
+}