]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.typeparams] cmd/compile/internal/types2: implement type terms
authorRobert Griesemer <gri@golang.org>
Wed, 28 Jul 2021 02:14:30 +0000 (19:14 -0700)
committerRobert Griesemer <gri@golang.org>
Thu, 29 Jul 2021 19:33:27 +0000 (19:33 +0000)
Type terms will be used to represent a type set as a list
of type terms. Eventually, a type term may also include
a method set.

Groundwork for the implementation of lazily computed
type sets for union expressions.

Change-Id: Ic88750af21f697ce0b52a2259eff40bee115964c
Reviewed-on: https://go-review.googlesource.com/c/go/+/338049
Trust: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
src/cmd/compile/internal/types2/typeterm.go [new file with mode: 0644]
src/cmd/compile/internal/types2/typeterm_test.go [new file with mode: 0644]

diff --git a/src/cmd/compile/internal/types2/typeterm.go b/src/cmd/compile/internal/types2/typeterm.go
new file mode 100644 (file)
index 0000000..59a89cb
--- /dev/null
@@ -0,0 +1,166 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package types2
+
+// TODO(gri) use a different symbol instead of ⊤ for the set of all types
+//           (⊤ is hard to distinguish from T in some fonts)
+
+// A term describes elementary type sets:
+//
+//   ∅:  (*term)(nil)     == ∅                      // set of no types (empty set)
+//   ⊤:  &term{}          == ⊤                      // set of all types
+//   T:  &term{false, T}  == {T}                    // set of type T
+//  ~t:  &term{true, t}   == {t' | under(t') == t}  // set of types with underlying type t
+//
+type term struct {
+       tilde bool // valid if typ != nil
+       typ   Type
+}
+
+func (x *term) String() string {
+       switch {
+       case x == nil:
+               return "∅"
+       case x.typ == nil:
+               return "⊤"
+       case x.tilde:
+               return "~" + x.typ.String()
+       default:
+               return x.typ.String()
+       }
+}
+
+// equal reports whether x and y represent the same type set.
+func (x *term) equal(y *term) bool {
+       // easy cases
+       switch {
+       case x == nil || y == nil:
+               return x == y
+       case x.typ == nil || y.typ == nil:
+               return x.typ == y.typ
+       }
+       // ∅ ⊂ x, y ⊂ ⊤
+
+       return x.tilde == y.tilde && Identical(x.typ, y.typ)
+}
+
+// union returns the union x ∪ y: zero, one, or two non-nil terms.
+func (x *term) union(y *term) (_, _ *term) {
+       // easy cases
+       switch {
+       case x == nil && y == nil:
+               return nil, nil // ∅ ∪ ∅ == ∅
+       case x == nil:
+               return y, nil // ∅ ∪ y == y
+       case y == nil:
+               return x, nil // x ∪ ∅ == x
+       case x.typ == nil:
+               return x, nil // ⊤ ∪ y == ⊤
+       case y.typ == nil:
+               return y, nil // x ∪ ⊤ == ⊤
+       }
+       // ∅ ⊂ x, y ⊂ ⊤
+
+       if x.disjoint(y) {
+               return x, y // x ∪ y == (x, y) if x ∩ y == ∅
+       }
+       // x.typ == y.typ
+
+       // ~t ∪ ~t == ~t
+       // ~t ∪  T == ~t
+       //  T ∪ ~t == ~t
+       //  T ∪  T ==  T
+       if x.tilde || !y.tilde {
+               return x, nil
+       }
+       return y, nil
+}
+
+// intersect returns the intersection x ∩ y.
+func (x *term) intersect(y *term) *term {
+       // easy cases
+       switch {
+       case x == nil || y == nil:
+               return nil // ∅ ∩ y == ∅ and ∩ ∅ == ∅
+       case x.typ == nil:
+               return y // ⊤ ∩ y == y
+       case y.typ == nil:
+               return x // x ∩ ⊤ == x
+       }
+       // ∅ ⊂ x, y ⊂ ⊤
+
+       if x.disjoint(y) {
+               return nil // x ∩ y == ∅ if x ∩ y == ∅
+       }
+       // x.typ == y.typ
+
+       // ~t ∩ ~t == ~t
+       // ~t ∩  T ==  T
+       //  T ∩ ~t ==  T
+       //  T ∩  T ==  T
+       if !x.tilde || y.tilde {
+               return x
+       }
+       return y
+}
+
+// includes reports whether t ∈ x.
+func (x *term) includes(t Type) bool {
+       // easy cases
+       switch {
+       case x == nil:
+               return false // t ∈ ∅ == false
+       case x.typ == nil:
+               return true // t ∈ ⊤ == true
+       }
+       // ∅ ⊂ x ⊂ ⊤
+
+       u := t
+       if x.tilde {
+               u = under(u)
+       }
+       return Identical(x.typ, u)
+}
+
+// subsetOf reports whether x ⊆ y.
+func (x *term) subsetOf(y *term) bool {
+       // easy cases
+       switch {
+       case x == nil:
+               return true // ∅ ⊆ y == true
+       case y == nil:
+               return false // x ⊆ ∅ == false since x != ∅
+       case y.typ == nil:
+               return true // x ⊆ ⊤ == true
+       case x.typ == nil:
+               return false // ⊤ ⊆ y == false since y != ⊤
+       }
+       // ∅ ⊂ x, y ⊂ ⊤
+
+       if x.disjoint(y) {
+               return false // x ⊆ y == false if x ∩ y == ∅
+       }
+       // x.typ == y.typ
+
+       // ~t ⊆ ~t == true
+       // ~t ⊆ T == false
+       //  T ⊆ ~t == true
+       //  T ⊆  T == true
+       return !x.tilde || y.tilde
+}
+
+// disjoint reports whether x ∩ y == ∅.
+// x.typ and y.typ must not be nil.
+func (x *term) disjoint(y *term) bool {
+       ux := x.typ
+       if y.tilde {
+               ux = under(ux)
+       }
+       uy := y.typ
+       if x.tilde {
+               uy = under(uy)
+       }
+       return !Identical(ux, uy)
+}
diff --git a/src/cmd/compile/internal/types2/typeterm_test.go b/src/cmd/compile/internal/types2/typeterm_test.go
new file mode 100644 (file)
index 0000000..4676fb0
--- /dev/null
@@ -0,0 +1,205 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package types2
+
+import (
+       "strings"
+       "testing"
+)
+
+var testTerms = map[string]*term{
+       "∅":       nil,
+       "⊤":       &term{},
+       "int":     &term{false, Typ[Int]},
+       "~int":    &term{true, Typ[Int]},
+       "string":  &term{false, Typ[String]},
+       "~string": &term{true, Typ[String]},
+       // TODO(gri) add a defined type
+}
+
+func TestTermString(t *testing.T) {
+       for want, x := range testTerms {
+               if got := x.String(); got != want {
+                       t.Errorf("%v.String() == %v; want %v", x, got, want)
+               }
+       }
+}
+
+func split(s string, n int) []string {
+       r := strings.Split(s, " ")
+       if len(r) != n {
+               panic("invalid test case: " + s)
+       }
+       return r
+}
+
+func testTerm(name string) *term {
+       r, ok := testTerms[name]
+       if !ok {
+               panic("invalid test argument: " + name)
+       }
+       return r
+}
+
+func TestTermEqual(t *testing.T) {
+       for _, test := range []string{
+               "∅ ∅ T",
+               "⊤ ⊤ T",
+               "int int T",
+               "~int ~int T",
+               "∅ ⊤ F",
+               "∅ int F",
+               "∅ ~int F",
+               "⊤ int F",
+               "⊤ ~int F",
+               "int ~int F",
+       } {
+               args := split(test, 3)
+               x := testTerm(args[0])
+               y := testTerm(args[1])
+               want := args[2] == "T"
+               if got := x.equal(y); got != want {
+                       t.Errorf("%v.equal(%v) = %v; want %v", x, y, got, want)
+               }
+               // equal is symmetric
+               x, y = y, x
+               if got := x.equal(y); got != want {
+                       t.Errorf("%v.equal(%v) = %v; want %v", x, y, got, want)
+               }
+       }
+}
+
+func TestTermUnion(t *testing.T) {
+       for _, test := range []string{
+               "∅ ∅ ∅ ∅",
+               "∅ ⊤ ⊤ ∅",
+               "∅ int int ∅",
+               "∅ ~int ~int ∅",
+               "⊤ ⊤ ⊤ ∅",
+               "⊤ int ⊤ ∅",
+               "⊤ ~int ⊤ ∅",
+               "int int int ∅",
+               "int ~int ~int ∅",
+               "int string int string",
+               "int ~string int ~string",
+               "~int ~string ~int ~string",
+
+               // union is symmetric, but the result order isn't - repeat symmetric cases explictly
+               "⊤ ∅ ⊤ ∅",
+               "int ∅ int ∅",
+               "~int ∅ ~int ∅",
+               "int ⊤ ⊤ ∅",
+               "~int ⊤ ⊤ ∅",
+               "~int int ~int ∅",
+               "string int string int",
+               "~string int ~string int",
+               "~string ~int ~string ~int",
+       } {
+               args := split(test, 4)
+               x := testTerm(args[0])
+               y := testTerm(args[1])
+               want1 := testTerm(args[2])
+               want2 := testTerm(args[3])
+               if got1, got2 := x.union(y); !got1.equal(want1) || !got2.equal(want2) {
+                       t.Errorf("%v.union(%v) = %v, %v; want %v, %v", x, y, got1, got2, want1, want2)
+               }
+       }
+}
+
+func TestTermIntersection(t *testing.T) {
+       for _, test := range []string{
+               "∅ ∅ ∅",
+               "∅ ⊤ ∅",
+               "∅ int ∅",
+               "∅ ~int ∅",
+               "⊤ ⊤ ⊤",
+               "⊤ int int",
+               "⊤ ~int ~int",
+               "int int int",
+               "int ~int int",
+               "int string ∅",
+               "int ~string ∅",
+               "~int ~string ∅",
+       } {
+               args := split(test, 3)
+               x := testTerm(args[0])
+               y := testTerm(args[1])
+               want := testTerm(args[2])
+               if got := x.intersect(y); !got.equal(want) {
+                       t.Errorf("%v.intersect(%v) = %v; want %v", x, y, got, want)
+               }
+               // intersect is symmetric
+               x, y = y, x
+               if got := x.intersect(y); !got.equal(want) {
+                       t.Errorf("%v.intersect(%v) = %v; want %v", x, y, got, want)
+               }
+       }
+}
+
+func TestTermIncludes(t *testing.T) {
+       for _, test := range []string{
+               "∅ int F",
+               "⊤ int T",
+               "int int T",
+               "~int int T",
+               "string int F",
+               "~string int F",
+       } {
+               args := split(test, 3)
+               x := testTerm(args[0])
+               y := testTerm(args[1]).typ
+               want := args[2] == "T"
+               if got := x.includes(y); got != want {
+                       t.Errorf("%v.includes(%v) = %v; want %v", x, y, got, want)
+               }
+       }
+}
+
+func TestTermSubsetOf(t *testing.T) {
+       for _, test := range []string{
+               "∅ ∅ T",
+               "⊤ ⊤ T",
+               "int int T",
+               "~int ~int T",
+               "∅ ⊤ T",
+               "∅ int T",
+               "∅ ~int T",
+               "⊤ int F",
+               "⊤ ~int F",
+               "int ~int T",
+       } {
+               args := split(test, 3)
+               x := testTerm(args[0])
+               y := testTerm(args[1])
+               want := args[2] == "T"
+               if got := x.subsetOf(y); got != want {
+                       t.Errorf("%v.subsetOf(%v) = %v; want %v", x, y, got, want)
+               }
+       }
+}
+
+func TestTermDisjoint(t *testing.T) {
+       for _, test := range []string{
+               "int int F",
+               "~int ~int F",
+               "int ~int F",
+               "int string T",
+               "int ~string T",
+               "~int ~string T",
+       } {
+               args := split(test, 3)
+               x := testTerm(args[0])
+               y := testTerm(args[1])
+               want := args[2] == "T"
+               if got := x.disjoint(y); got != want {
+                       t.Errorf("%v.disjoint(%v) = %v; want %v", x, y, got, want)
+               }
+               // disjoint is symmetric
+               x, y = y, x
+               if got := x.disjoint(y); got != want {
+                       t.Errorf("%v.disjoint(%v) = %v; want %v", x, y, got, want)
+               }
+       }
+}