From 33a1a93a92804205eca89e2bb113ca68c1de5a4f Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Wed, 8 Dec 2021 20:32:29 -0800 Subject: [PATCH] cmd/compile/internal/syntax: fix parsing of type parameter lists The parser cannot distinguish a type parameter list of the form [P *T ] or [P (T)] where T is not a type literal from an array length specification P*T (product) or P(T) (constant-valued function call) and thus interprets these forms as the start of array types. This ambiguity must be resolved explicitly by placing *T inside an interface, adding a trailing comma, or by leaving parentheses away where possible. This CL adjusts the parser such that these forms are interpreted as (the beginning) of type parameter lists if the token after P*T or P(T) is a comma, or if T is a type literal. This CL also adjusts the printer to print a comma if necessary to avoid this ambiguity, and adds additional printer tests. Fixes #49482 Change-Id: I36328e2a7d9439c39ba0349837c445542549e84e Reviewed-on: https://go-review.googlesource.com/c/go/+/370774 Trust: Robert Griesemer Run-TryBot: Robert Griesemer Reviewed-by: Robert Findley TryBot-Result: Gopher Robot --- src/cmd/compile/internal/syntax/parser.go | 149 ++++++++++++++---- src/cmd/compile/internal/syntax/printer.go | 55 ++++--- .../compile/internal/syntax/printer_test.go | 77 +++++---- .../internal/syntax/testdata/issue49482.go2 | 31 ++++ 4 files changed, 227 insertions(+), 85 deletions(-) create mode 100644 src/cmd/compile/internal/syntax/testdata/issue49482.go2 diff --git a/src/cmd/compile/internal/syntax/parser.go b/src/cmd/compile/internal/syntax/parser.go index 770175fe54..40c5eca408 100644 --- a/src/cmd/compile/internal/syntax/parser.go +++ b/src/cmd/compile/internal/syntax/parser.go @@ -588,44 +588,81 @@ func (p *parser) typeDecl(group *Group) Decl { d.Name = p.name() if p.allowGenerics() && p.tok == _Lbrack { // d.Name "[" ... - // array/slice or type parameter list + // array/slice type or type parameter list pos := p.pos() p.next() switch p.tok { case _Name: - // d.Name "[" name ... - // array or type parameter list - name := p.name() - // Index or slice expressions are never constant and thus invalid - // array length expressions. Thus, if we see a "[" following name - // we can safely assume that "[" name starts a type parameter list. - var x Expr // x != nil means x is the array length expression + // We may have an array type or a type parameter list. + // In either case we expect an expression x (which may + // just be a name, or a more complex expression) which + // we can analyze further. + // + // A type parameter list may have a type bound starting + // with a "[" as in: P []E. In that case, simply parsing + // an expression would lead to an error: P[] is invalid. + // But since index or slice expressions are never constant + // and thus invalid array length expressions, if we see a + // "[" following a name it must be the start of an array + // or slice constraint. Only if we don't see a "[" do we + // need to parse a full expression. + var x Expr = p.name() if p.tok != _Lbrack { - // d.Name "[" name ... - // If we reach here, the next token is not a "[", and we need to - // parse the expression starting with name. If that expression is - // just that name, not followed by a "]" (in which case we might - // have the array length "[" name "]"), we can also safely assume - // a type parameter list. + // To parse the expression starting with name, expand + // the call sequence we would get by passing in name + // to parser.expr, and pass in name to parser.pexpr. p.xnest++ - // To parse the expression starting with name, expand the call - // sequence we would get by passing in name to parser.expr, and - // pass in name to parser.pexpr. - x = p.binaryExpr(p.pexpr(name, false), 0) + x = p.binaryExpr(p.pexpr(x, false), 0) p.xnest-- - if x == name && p.tok != _Rbrack { - x = nil + } + + // analyze the cases + var pname *Name // pname != nil means pname is the type parameter name + var ptype Expr // ptype != nil means ptype is the type parameter type; pname != nil in this case + switch t := x.(type) { + case *Name: + // Unless we see a "]", we are at the start of a type parameter list. + if p.tok != _Rbrack { + // d.Name "[" name ... + pname = t + // no ptype + } + case *Operation: + // If we have an expression of the form name*T, and T is a (possibly + // parenthesized) type literal or the next token is a comma, we are + // at the start of a type parameter list. + if name, _ := t.X.(*Name); name != nil { + if t.Op == Mul && (isTypeLit(t.Y) || p.tok == _Comma) { + // d.Name "[" name "*" t.Y + // d.Name "[" name "*" t.Y "," + t.X, t.Y = t.Y, nil // convert t into unary *t.Y + pname = name + ptype = t + } + } + case *CallExpr: + // If we have an expression of the form name(T), and T is a (possibly + // parenthesized) type literal or the next token is a comma, we are + // at the start of a type parameter list. + if name, _ := t.Fun.(*Name); name != nil { + if len(t.ArgList) == 1 && !t.HasDots && (isTypeLit(t.ArgList[0]) || p.tok == _Comma) { + // d.Name "[" name "(" t.ArgList[0] ")" + // d.Name "[" name "(" t.ArgList[0] ")" "," + pname = name + ptype = t.ArgList[0] + } } } - if x == nil { - // d.Name "[" name ... - // type parameter list - d.TParamList = p.paramList(name, _Rbrack, true) + + if pname != nil { + // d.Name "[" pname ... + // d.Name "[" pname ptype ... + // d.Name "[" pname ptype "," ... + d.TParamList = p.paramList(pname, ptype, _Rbrack, true) d.Alias = p.gotAssign() d.Type = p.typeOrNil() } else { - // d.Name "[" x "]" ... - // x is the array length expression + // d.Name "[" x ... d.Type = p.arrayType(pos, x) } case _Rbrack: @@ -650,6 +687,21 @@ func (p *parser) typeDecl(group *Group) Decl { return d } +// isTypeLit reports whether x is a (possibly parenthesized) type literal. +func isTypeLit(x Expr) bool { + switch x := x.(type) { + case *ArrayType, *StructType, *FuncType, *InterfaceType, *SliceType, *MapType, *ChanType: + return true + case *Operation: + // *T may be a pointer dereferenciation. + // Only consider *T as type literal if T is a type literal. + return x.Op == Mul && x.Y == nil && isTypeLit(x.X) + case *ParenExpr: + return isTypeLit(x.X) + } + return false +} + // VarSpec = IdentifierList ( Type [ "=" ExpressionList ] | "=" ExpressionList ) . func (p *parser) varDecl(group *Group) Decl { if trace { @@ -689,7 +741,7 @@ func (p *parser) funcDeclOrNil() *FuncDecl { f.Pragma = p.takePragma() if p.got(_Lparen) { - rcvr := p.paramList(nil, _Rparen, false) + rcvr := p.paramList(nil, nil, _Rparen, false) switch len(rcvr) { case 0: p.error("method has no receiver") @@ -1369,12 +1421,12 @@ func (p *parser) funcType(context string) ([]*Field, *FuncType) { p.syntaxError("empty type parameter list") p.next() } else { - tparamList = p.paramList(nil, _Rbrack, true) + tparamList = p.paramList(nil, nil, _Rbrack, true) } } p.want(_Lparen) - typ.ParamList = p.paramList(nil, _Rparen, false) + typ.ParamList = p.paramList(nil, nil, _Rparen, false) typ.ResultList = p.funcResult() return tparamList, typ @@ -1392,6 +1444,13 @@ func (p *parser) arrayType(pos Pos, len Expr) Expr { len = p.expr() p.xnest-- } + if p.tok == _Comma { + // Trailing commas are accepted in type parameter + // lists but not in array type declarations. + // Accept for better error handling but complain. + p.syntaxError("unexpected comma; expecting ]") + p.next() + } p.want(_Rbrack) t := new(ArrayType) t.pos = pos @@ -1516,7 +1575,7 @@ func (p *parser) funcResult() []*Field { } if p.got(_Lparen) { - return p.paramList(nil, _Rparen, false) + return p.paramList(nil, nil, _Rparen, false) } pos := p.pos() @@ -1742,7 +1801,7 @@ func (p *parser) methodDecl() *Field { // A type argument list looks like a parameter list with only // types. Parse a parameter list and decide afterwards. - list := p.paramList(nil, _Rbrack, false) + list := p.paramList(nil, nil, _Rbrack, false) if len(list) == 0 { // The type parameter list is not [] but we got nothing // due to other errors (reported by paramList). Treat @@ -1948,17 +2007,41 @@ func (p *parser) paramDeclOrNil(name *Name, follow token) *Field { // ParameterList = ParameterDecl { "," ParameterDecl } . // "(" or "[" has already been consumed. // If name != nil, it is the first name after "(" or "[". +// If typ != nil, name must be != nil, and (name, typ) is the first field in the list. // In the result list, either all fields have a name, or no field has a name. -func (p *parser) paramList(name *Name, close token, requireNames bool) (list []*Field) { +func (p *parser) paramList(name *Name, typ Expr, close token, requireNames bool) (list []*Field) { if trace { defer p.trace("paramList")() } + // p.list won't invoke its function argument if we're at the end of the + // parameter list. If we have a complete field, handle this case here. + if name != nil && typ != nil && p.tok == close { + p.next() + par := new(Field) + par.pos = name.pos + par.Name = name + par.Type = typ + return []*Field{par} + } + var named int // number of parameters that have an explicit name and type var typed int // number of parameters that have an explicit type end := p.list(_Comma, close, func() bool { - par := p.paramDeclOrNil(name, close) + var par *Field + if typ != nil { + if debug && name == nil { + panic("initial type provided without name") + } + par = new(Field) + par.pos = name.pos + par.Name = name + par.Type = typ + } else { + par = p.paramDeclOrNil(name, close) + } name = nil // 1st name was consumed if present + typ = nil // 1st type was consumed if present if par != nil { if debug && par.Name == nil && par.Type == nil { panic("parameter without name or type") diff --git a/src/cmd/compile/internal/syntax/printer.go b/src/cmd/compile/internal/syntax/printer.go index c8d31799af..11190ab287 100644 --- a/src/cmd/compile/internal/syntax/printer.go +++ b/src/cmd/compile/internal/syntax/printer.go @@ -666,9 +666,7 @@ func (p *printer) printRawNode(n Node) { } p.print(n.Name) if n.TParamList != nil { - p.print(_Lbrack) - p.printFieldList(n.TParamList, nil, _Comma) - p.print(_Rbrack) + p.printParameterList(n.TParamList, true) } p.print(blank) if n.Alias { @@ -700,9 +698,7 @@ func (p *printer) printRawNode(n Node) { } p.print(n.Name) if n.TParamList != nil { - p.print(_Lbrack) - p.printFieldList(n.TParamList, nil, _Comma) - p.print(_Rbrack) + p.printParameterList(n.TParamList, true) } p.printSignature(n.Type) if n.Body != nil { @@ -887,38 +883,47 @@ func (p *printer) printDeclList(list []Decl) { } func (p *printer) printSignature(sig *FuncType) { - p.printParameterList(sig.ParamList) + p.printParameterList(sig.ParamList, false) if list := sig.ResultList; list != nil { p.print(blank) if len(list) == 1 && list[0].Name == nil { p.printNode(list[0].Type) } else { - p.printParameterList(list) + p.printParameterList(list, false) } } } -func (p *printer) printParameterList(list []*Field) { - p.print(_Lparen) - if len(list) > 0 { - for i, f := range list { - if i > 0 { - p.print(_Comma, blank) - } - if f.Name != nil { - p.printNode(f.Name) - if i+1 < len(list) { - f1 := list[i+1] - if f1.Name != nil && f1.Type == f.Type { - continue // no need to print type - } +func (p *printer) printParameterList(list []*Field, types bool) { + open, close := _Lparen, _Rparen + if types { + open, close = _Lbrack, _Rbrack + } + p.print(open) + for i, f := range list { + if i > 0 { + p.print(_Comma, blank) + } + if f.Name != nil { + p.printNode(f.Name) + if i+1 < len(list) { + f1 := list[i+1] + if f1.Name != nil && f1.Type == f.Type { + continue // no need to print type } - p.print(blank) } - p.printNode(f.Type) + p.print(blank) + } + p.printNode(unparen(f.Type)) // no need for (extra) parentheses around parameter types + } + // A type parameter list [P *T] where T is not a type literal requires a comma as in [P *T,] + // so that it's not parsed as [P*T]. + if types && len(list) == 1 { + if t, _ := list[0].Type.(*Operation); t != nil && t.Op == Mul && t.Y == nil && !isTypeLit(t.X) { + p.print(_Comma) } } - p.print(_Rparen) + p.print(close) } func (p *printer) printStmtList(list []Stmt, braces bool) { diff --git a/src/cmd/compile/internal/syntax/printer_test.go b/src/cmd/compile/internal/syntax/printer_test.go index 604f1fc1ca..941af0aeb4 100644 --- a/src/cmd/compile/internal/syntax/printer_test.go +++ b/src/cmd/compile/internal/syntax/printer_test.go @@ -53,54 +53,77 @@ func TestPrintError(t *testing.T) { } } -var stringTests = []string{ - "package p", - "package p; type _ int; type T1 = struct{}; type ( _ *struct{}; T2 = float32 )", +var stringTests = [][2]string{ + dup("package p"), + dup("package p; type _ int; type T1 = struct{}; type ( _ *struct{}; T2 = float32 )"), // generic type declarations - "package p; type _[T any] struct{}", - "package p; type _[A, B, C interface{m()}] struct{}", - "package p; type _[T any, A, B, C interface{m()}, X, Y, Z interface{~int}] struct{}", + dup("package p; type _[T any] struct{}"), + dup("package p; type _[A, B, C interface{m()}] struct{}"), + dup("package p; type _[T any, A, B, C interface{m()}, X, Y, Z interface{~int}] struct{}"), + + dup("package p; type _[P *T,] struct{}"), + dup("package p; type _[P *T, _ any] struct{}"), + {"package p; type _[P (*T),] struct{}", "package p; type _[P *T,] struct{}"}, + {"package p; type _[P (*T), _ any] struct{}", "package p; type _[P *T, _ any] struct{}"}, + {"package p; type _[P (T),] struct{}", "package p; type _[P T] struct{}"}, + {"package p; type _[P (T), _ any] struct{}", "package p; type _[P T, _ any] struct{}"}, + + dup("package p; type _[P *struct{}] struct{}"), + {"package p; type _[P (*struct{})] struct{}", "package p; type _[P *struct{}] struct{}"}, + {"package p; type _[P ([]int)] struct{}", "package p; type _[P []int] struct{}"}, + + dup("package p; type _ [P(T)]struct{}"), + dup("package p; type _ [P((T))]struct{}"), + dup("package p; type _ [P * *T]struct{}"), + dup("package p; type _ [P * T]struct{}"), + dup("package p; type _ [P(*T)]struct{}"), + dup("package p; type _ [P(**T)]struct{}"), + dup("package p; type _ [P * T - T]struct{}"), + + // array type declarations + dup("package p; type _ [P * T]struct{}"), + dup("package p; type _ [P * T - T]struct{}"), // generic function declarations - "package p; func _[T any]()", - "package p; func _[A, B, C interface{m()}]()", - "package p; func _[T any, A, B, C interface{m()}, X, Y, Z interface{~int}]()", + dup("package p; func _[T any]()"), + dup("package p; func _[A, B, C interface{m()}]()"), + dup("package p; func _[T any, A, B, C interface{m()}, X, Y, Z interface{~int}]()"), // methods with generic receiver types - "package p; func (R[T]) _()", - "package p; func (*R[A, B, C]) _()", - "package p; func (_ *R[A, B, C]) _()", + dup("package p; func (R[T]) _()"), + dup("package p; func (*R[A, B, C]) _()"), + dup("package p; func (_ *R[A, B, C]) _()"), // type constraint literals with elided interfaces - "package p; func _[P ~int, Q int | string]() {}", - "package p; func _[P struct{f int}, Q *P]() {}", + dup("package p; func _[P ~int, Q int | string]() {}"), + dup("package p; func _[P struct{f int}, Q *P]() {}"), // channels - "package p; type _ chan chan int", - "package p; type _ chan (<-chan int)", - "package p; type _ chan chan<- int", + dup("package p; type _ chan chan int"), + dup("package p; type _ chan (<-chan int)"), + dup("package p; type _ chan chan<- int"), - "package p; type _ <-chan chan int", - "package p; type _ <-chan <-chan int", - "package p; type _ <-chan chan<- int", + dup("package p; type _ <-chan chan int"), + dup("package p; type _ <-chan <-chan int"), + dup("package p; type _ <-chan chan<- int"), - "package p; type _ chan<- chan int", - "package p; type _ chan<- <-chan int", - "package p; type _ chan<- chan<- int", + dup("package p; type _ chan<- chan int"), + dup("package p; type _ chan<- <-chan int"), + dup("package p; type _ chan<- chan<- int"), // TODO(gri) expand } func TestPrintString(t *testing.T) { - for _, want := range stringTests { - ast, err := Parse(nil, strings.NewReader(want), nil, nil, AllowGenerics) + for _, test := range stringTests { + ast, err := Parse(nil, strings.NewReader(test[0]), nil, nil, AllowGenerics) if err != nil { t.Error(err) continue } - if got := String(ast); got != want { - t.Errorf("%q: got %q", want, got) + if got := String(ast); got != test[1] { + t.Errorf("%q: got %q", test[1], got) } } } diff --git a/src/cmd/compile/internal/syntax/testdata/issue49482.go2 b/src/cmd/compile/internal/syntax/testdata/issue49482.go2 new file mode 100644 index 0000000000..1fc303d169 --- /dev/null +++ b/src/cmd/compile/internal/syntax/testdata/issue49482.go2 @@ -0,0 +1,31 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package p + +type ( + // these need a comma to disambiguate + _[P *T,] struct{} + _[P *T, _ any] struct{} + _[P (*T),] struct{} + _[P (*T), _ any] struct{} + _[P (T),] struct{} + _[P (T), _ any] struct{} + + // these parse as name followed by type + _[P *struct{}] struct{} + _[P (*struct{})] struct{} + _[P ([]int)] struct{} + + // array declarations + _ [P(T)]struct{} + _ [P((T))]struct{} + _ [P * *T] struct{} // this could be a name followed by a type but it makes the rules more complicated + _ [P * T]struct{} + _ [P(*T)]struct{} + _ [P(**T)]struct{} + _ [P * T - T]struct{} + _ [P*T-T /* ERROR unexpected comma */ ,]struct{} + _ [10 /* ERROR unexpected comma */ ,]struct{} +) -- 2.50.0