]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile/internal/syntax: print type parameters and type lists
authorRobert Griesemer <gri@golang.org>
Thu, 4 Mar 2021 02:33:45 +0000 (18:33 -0800)
committerRobert Griesemer <gri@golang.org>
Thu, 4 Mar 2021 22:20:29 +0000 (22:20 +0000)
types2 uses the syntax printer to print expressions (for tracing
or error messages), so we need to (at least) print type lists in
interfaces.

While at it, also implement the printing of type parameter lists.

Fixes #44766.

Change-Id: I36a4a7152d9bef7251af264b5c7890aca88d8dc3
Reviewed-on: https://go-review.googlesource.com/c/go/+/298549
Trust: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
src/cmd/compile/internal/syntax/printer.go
src/cmd/compile/internal/syntax/printer_test.go

index 9109ce2363c38bdacc0aa562dcd470df1605e947..e557f5d9247b521a67984f5b60d8963174b45a1d 100644 (file)
@@ -481,10 +481,10 @@ func (p *printer) printRawNode(n Node) {
                if len(n.FieldList) > 0 {
                        if p.linebreaks {
                                p.print(newline, indent)
-                               p.printFieldList(n.FieldList, n.TagList)
+                               p.printFieldList(n.FieldList, n.TagList, _Semi)
                                p.print(outdent, newline)
                        } else {
-                               p.printFieldList(n.FieldList, n.TagList)
+                               p.printFieldList(n.FieldList, n.TagList, _Semi)
                        }
                }
                p.print(_Rbrace)
@@ -494,20 +494,40 @@ func (p *printer) printRawNode(n Node) {
                p.printSignature(n)
 
        case *InterfaceType:
+               // separate type list and method list
+               var types []Expr
+               var methods []*Field
+               for _, f := range n.MethodList {
+                       if f.Name != nil && f.Name.Value == "type" {
+                               types = append(types, f.Type)
+                       } else {
+                               // method or embedded interface
+                               methods = append(methods, f)
+                       }
+               }
+
+               multiLine := len(n.MethodList) > 0 && p.linebreaks
                p.print(_Interface)
-               if len(n.MethodList) > 0 && p.linebreaks {
+               if multiLine {
                        p.print(blank)
                }
                p.print(_Lbrace)
-               if len(n.MethodList) > 0 {
-                       if p.linebreaks {
-                               p.print(newline, indent)
-                               p.printMethodList(n.MethodList)
-                               p.print(outdent, newline)
-                       } else {
-                               p.printMethodList(n.MethodList)
+               if multiLine {
+                       p.print(newline, indent)
+               }
+               if len(types) > 0 {
+                       p.print(_Type, blank)
+                       p.printExprList(types)
+                       if len(methods) > 0 {
+                               p.print(_Semi, blank)
                        }
                }
+               if len(methods) > 0 {
+                       p.printMethodList(methods)
+               }
+               if multiLine {
+                       p.print(outdent, newline)
+               }
                p.print(_Rbrace)
 
        case *MapType:
@@ -667,7 +687,13 @@ func (p *printer) printRawNode(n Node) {
                if n.Group == nil {
                        p.print(_Type, blank)
                }
-               p.print(n.Name, blank)
+               p.print(n.Name)
+               if n.TParamList != nil {
+                       p.print(_Lbrack)
+                       p.printFieldList(n.TParamList, nil, _Comma)
+                       p.print(_Rbrack)
+               }
+               p.print(blank)
                if n.Alias {
                        p.print(_Assign, blank)
                }
@@ -696,6 +722,11 @@ func (p *printer) printRawNode(n Node) {
                        p.print(_Rparen, blank)
                }
                p.print(n.Name)
+               if n.TParamList != nil {
+                       p.print(_Lbrack)
+                       p.printFieldList(n.TParamList, nil, _Comma)
+                       p.print(_Rbrack)
+               }
                p.printSignature(n.Type)
                if n.Body != nil {
                        p.print(blank, n.Body)
@@ -746,14 +777,14 @@ func (p *printer) printFields(fields []*Field, tags []*BasicLit, i, j int) {
        }
 }
 
-func (p *printer) printFieldList(fields []*Field, tags []*BasicLit) {
+func (p *printer) printFieldList(fields []*Field, tags []*BasicLit, sep token) {
        i0 := 0
        var typ Expr
        for i, f := range fields {
                if f.Name == nil || f.Type != typ {
                        if i0 < i {
                                p.printFields(fields, tags, i0, i)
-                               p.print(_Semi, newline)
+                               p.print(sep, newline)
                                i0 = i
                        }
                        typ = f.Type
index bcae815a4680aeede496d72842bf6bf40b590f0a..4890327595d9031e2f936fb7b90872f595f9c497 100644 (file)
@@ -61,6 +61,21 @@ var stringTests = []string{
        "package p",
        "package p; type _ int; type T1 = struct{}; type ( _ *struct{}; T2 = float32 )",
 
+       // generic type declarations
+       "package p; type _[T any] struct{}",
+       "package p; type _[A, B, C interface{m()}] struct{}",
+       "package p; type _[T any, A, B, C interface{m()}, X, Y, Z interface{type int}] struct{}",
+
+       // generic function declarations
+       "package p; func _[T any]()",
+       "package p; func _[A, B, C interface{m()}]()",
+       "package p; func _[T any, A, B, C interface{m()}, X, Y, Z interface{type int}]()",
+
+       // methods with generic receiver types
+       "package p; func (R[T]) _()",
+       "package p; func (*R[A, B, C]) _()",
+       "package p; func (_ *R[A, B, C]) _()",
+
        // channels
        "package p; type _ chan chan int",
        "package p; type _ chan (<-chan int)",
@@ -79,7 +94,7 @@ var stringTests = []string{
 
 func TestPrintString(t *testing.T) {
        for _, want := range stringTests {
-               ast, err := Parse(nil, strings.NewReader(want), nil, nil, 0)
+               ast, err := Parse(nil, strings.NewReader(want), nil, nil, AllowGenerics)
                if err != nil {
                        t.Error(err)
                        continue
@@ -116,6 +131,24 @@ var exprTests = [][2]string{
        {"func(x int) complex128 { return 0 }", "func(x int) complex128 {…}"},
        {"[]int{1, 2, 3}", "[]int{…}"},
 
+       // type expressions
+       dup("[1 << 10]byte"),
+       dup("[]int"),
+       dup("*int"),
+       dup("struct{x int}"),
+       dup("func()"),
+       dup("func(int, float32) string"),
+       dup("interface{m()}"),
+       dup("interface{m() string; n(x int)}"),
+       dup("interface{type int}"),
+       dup("interface{type int, float64, string}"),
+       dup("interface{type int; m()}"),
+       dup("interface{type int, float64, string; m() string; n(x int)}"),
+       dup("map[string]int"),
+       dup("chan E"),
+       dup("<-chan E"),
+       dup("chan<- E"),
+
        // non-type expressions
        dup("(x)"),
        dup("x.f"),
@@ -172,7 +205,7 @@ var exprTests = [][2]string{
 func TestShortString(t *testing.T) {
        for _, test := range exprTests {
                src := "package p; var _ = " + test[0]
-               ast, err := Parse(nil, strings.NewReader(src), nil, nil, 0)
+               ast, err := Parse(nil, strings.NewReader(src), nil, nil, AllowGenerics)
                if err != nil {
                        t.Errorf("%s: %s", test[0], err)
                        continue