]> Cypherpunks repositories - gostls13.git/commitdiff
go/parser, go/printer: fix parsing of ambiguous type parameter lists
authorRobert Findley <rfindley@google.com>
Mon, 14 Feb 2022 03:48:39 +0000 (22:48 -0500)
committerRobert Findley <rfindley@google.com>
Tue, 15 Feb 2022 01:01:15 +0000 (01:01 +0000)
This is a port of CL 370774 to go/parser and go/printer. It is adjusted
for the slightly different factoring of parameter list parsing and
printing in go/parser and go/printer.

For #49482

Change-Id: I1c5b1facddbfcb7f7b2be356c817fc7e608223f1
Reviewed-on: https://go-review.googlesource.com/c/go/+/385575
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
src/go/parser/parser.go
src/go/parser/short_test.go
src/go/parser/testdata/issue49482.go2 [new file with mode: 0644]
src/go/parser/testdata/typeparams.src
src/go/printer/nodes.go
src/go/printer/testdata/generics.golden
src/go/printer/testdata/generics.input

index 4479adb73273059dad8206e50e280bd59ccd6af8..51a3c3e67f9d8dbe9733a449c96f6475715530fe 100644 (file)
@@ -543,6 +543,13 @@ func (p *parser) parseArrayType(lbrack token.Pos, len ast.Expr) *ast.ArrayType {
                }
                p.exprLev--
        }
+       if p.tok == token.COMMA {
+               // Trailing commas are accepted in type parameter
+               // lists but not in array type declarations.
+               // Accept for better error handling but complain.
+               p.error(p.pos, "unexpected comma; expecting ]")
+               p.next()
+       }
        p.expect(token.RBRACK)
        elt := p.parseType()
        return &ast.ArrayType{Lbrack: lbrack, Len: len, Elt: elt}
@@ -797,7 +804,7 @@ func (p *parser) parseParamDecl(name *ast.Ident, typeSetsOK bool) (f field) {
        return
 }
 
-func (p *parser) parseParameterList(name0 *ast.Ident, closing token.Token) (params []*ast.Field) {
+func (p *parser) parseParameterList(name0 *ast.Ident, typ0 ast.Expr, closing token.Token) (params []*ast.Field) {
        if p.trace {
                defer un(trace(p, "ParameterList"))
        }
@@ -816,8 +823,17 @@ func (p *parser) parseParameterList(name0 *ast.Ident, closing token.Token) (para
        var named int // number of parameters that have an explicit name and type
 
        for name0 != nil || p.tok != closing && p.tok != token.EOF {
-               par := p.parseParamDecl(name0, typeSetsOK)
+               var par field
+               if typ0 != nil {
+                       if typeSetsOK {
+                               typ0 = p.embeddedElem(typ0)
+                       }
+                       par = field{name0, typ0}
+               } else {
+                       par = p.parseParamDecl(name0, typeSetsOK)
+               }
                name0 = nil // 1st name was consumed if present
+               typ0 = nil  // 1st typ was consumed if present
                if par.name != nil || par.typ != nil {
                        list = append(list, par)
                        if par.name != nil && par.typ != nil {
@@ -926,7 +942,7 @@ func (p *parser) parseParameters(acceptTParams bool) (tparams, params *ast.Field
                opening := p.pos
                p.next()
                // [T any](params) syntax
-               list := p.parseParameterList(nil, token.RBRACK)
+               list := p.parseParameterList(nil, nil, token.RBRACK)
                rbrack := p.expect(token.RBRACK)
                tparams = &ast.FieldList{Opening: opening, List: list, Closing: rbrack}
                // Type parameter lists must not be empty.
@@ -940,7 +956,7 @@ func (p *parser) parseParameters(acceptTParams bool) (tparams, params *ast.Field
 
        var fields []*ast.Field
        if p.tok != token.RPAREN {
-               fields = p.parseParameterList(nil, token.RPAREN)
+               fields = p.parseParameterList(nil, nil, token.RPAREN)
        }
 
        rparen := p.expect(token.RPAREN)
@@ -1007,7 +1023,7 @@ func (p *parser) parseMethodSpec() *ast.Field {
                                //
                                // Interface methods do not have type parameters. We parse them for a
                                // better error message and improved error recovery.
-                               _ = p.parseParameterList(name0, token.RBRACK)
+                               _ = p.parseParameterList(name0, nil, token.RBRACK)
                                _ = p.expect(token.RBRACK)
                                p.error(lbrack, "interface method must have no type parameters")
 
@@ -1784,7 +1800,12 @@ func (p *parser) tokPrec() (token.Token, int) {
        return tok, tok.Precedence()
 }
 
-func (p *parser) parseBinaryExpr(x ast.Expr, prec1 int) ast.Expr {
+// parseBinaryExpr parses a (possibly) binary expression.
+// If x is non-nil, it is used as the left operand.
+// If check is true, operands are checked to be valid expressions.
+//
+// TODO(rfindley): parseBinaryExpr has become overloaded. Consider refactoring.
+func (p *parser) parseBinaryExpr(x ast.Expr, prec1 int, check bool) ast.Expr {
        if p.trace {
                defer un(trace(p, "BinaryExpr"))
        }
@@ -1798,11 +1819,32 @@ func (p *parser) parseBinaryExpr(x ast.Expr, prec1 int) ast.Expr {
                        return x
                }
                pos := p.expect(op)
-               y := p.parseBinaryExpr(nil, oprec+1)
-               x = &ast.BinaryExpr{X: p.checkExpr(x), OpPos: pos, Op: op, Y: p.checkExpr(y)}
+               y := p.parseBinaryExpr(nil, oprec+1, check)
+               if check {
+                       x = p.checkExpr(x)
+                       y = p.checkExpr(y)
+               }
+               x = &ast.BinaryExpr{X: x, OpPos: pos, Op: op, Y: y}
        }
 }
 
+// checkBinaryExpr checks binary expressions that were not already checked by
+// parseBinaryExpr, because the latter was called with check=false.
+func (p *parser) checkBinaryExpr(x ast.Expr) {
+       bx, ok := x.(*ast.BinaryExpr)
+       if !ok {
+               return
+       }
+
+       bx.X = p.checkExpr(bx.X)
+       bx.Y = p.checkExpr(bx.Y)
+
+       // parseBinaryExpr checks x and y for each binary expr in a tree, so we
+       // traverse the tree of binary exprs starting from x.
+       p.checkBinaryExpr(bx.X)
+       p.checkBinaryExpr(bx.Y)
+}
+
 // The result may be a type or even a raw type ([...]int). Callers must
 // check the result (using checkExpr or checkExprOrType), depending on
 // context.
@@ -1811,7 +1853,7 @@ func (p *parser) parseExpr() ast.Expr {
                defer un(trace(p, "Expression"))
        }
 
-       return p.parseBinaryExpr(nil, token.LowestPrec+1)
+       return p.parseBinaryExpr(nil, token.LowestPrec+1, true)
 }
 
 func (p *parser) parseRhs() ast.Expr {
@@ -2534,12 +2576,12 @@ func (p *parser) parseValueSpec(doc *ast.CommentGroup, _ token.Pos, keyword toke
        return spec
 }
 
-func (p *parser) parseGenericType(spec *ast.TypeSpec, openPos token.Pos, name0 *ast.Ident) {
+func (p *parser) parseGenericType(spec *ast.TypeSpec, openPos token.Pos, name0 *ast.Ident, typ0 ast.Expr) {
        if p.trace {
                defer un(trace(p, "parseGenericType"))
        }
 
-       list := p.parseParameterList(name0, token.RBRACK)
+       list := p.parseParameterList(name0, typ0, token.RBRACK)
        closePos := p.expect(token.RBRACK)
        spec.TypeParams = &ast.FieldList{Opening: openPos, List: list, Closing: closePos}
        // Let the type checker decide whether to accept type parameters on aliases:
@@ -2564,31 +2606,85 @@ func (p *parser) parseTypeSpec(doc *ast.CommentGroup, _ token.Pos, _ token.Token
                lbrack := p.pos
                p.next()
                if p.tok == token.IDENT {
-                       // array type or generic type: [name0...
-                       name0 := p.parseIdent()
+                       // We may have an array type or a type parameter list.
+                       // In either case we expect an expression x (which may
+                       // just be a name, or a more complex expression) which
+                       // we can analyze further.
+                       //
+                       // A type parameter list may have a type bound starting
+                       // with a "[" as in: P []E. In that case, simply parsing
+                       // an expression would lead to an error: P[] is invalid.
+                       // But since index or slice expressions are never constant
+                       // and thus invalid array length expressions, if we see a
+                       // "[" following a name it must be the start of an array
+                       // or slice constraint. Only if we don't see a "[" do we
+                       // need to parse a full expression.
 
                        // Index or slice expressions are never constant and thus invalid
                        // array length expressions. Thus, if we see a "[" following name
                        // we can safely assume that "[" name starts a type parameter list.
-                       var x ast.Expr // x != nil means x is the array length expression
+                       var x ast.Expr = p.parseIdent()
                        if p.tok != token.LBRACK {
-                               // We may still have either an array type or generic type -- check if
-                               // name0 is the entire expr.
+                               // To parse the expression starting with name, expand
+                               // the call sequence we would get by passing in name
+                               // to parser.expr, and pass in name to parsePrimaryExpr.
                                p.exprLev++
-                               lhs := p.parsePrimaryExpr(name0)
-                               x = p.parseBinaryExpr(lhs, token.LowestPrec+1)
+                               lhs := p.parsePrimaryExpr(x)
+                               x = p.parseBinaryExpr(lhs, token.LowestPrec+1, false)
                                p.exprLev--
-                               if x == name0 && p.tok != token.RBRACK {
-                                       x = nil
+                       }
+
+                       // analyze the cases
+                       var pname *ast.Ident // pname != nil means pname is the type parameter name
+                       var ptype ast.Expr   // ptype != nil means ptype is the type parameter type; pname != nil in this case
+
+                       switch t := x.(type) {
+                       case *ast.Ident:
+                               // Unless we see a "]", we are at the start of a type parameter list.
+                               if p.tok != token.RBRACK {
+                                       // d.Name "[" name ...
+                                       pname = t
+                                       // no ptype
+                               }
+                       case *ast.BinaryExpr:
+                               // If we have an expression of the form name*T, and T is a (possibly
+                               // parenthesized) type literal or the next token is a comma, we are
+                               // at the start of a type parameter list.
+                               if name, _ := t.X.(*ast.Ident); name != nil {
+                                       if t.Op == token.MUL && (isTypeLit(t.Y) || p.tok == token.COMMA) {
+                                               // d.Name "[" name "*" t.Y
+                                               // d.Name "[" name "*" t.Y ","
+                                               // convert t into unary *t.Y
+                                               pname = name
+                                               ptype = &ast.StarExpr{Star: t.OpPos, X: t.Y}
+                                       }
+                               }
+                               if pname == nil {
+                                       // A normal binary expression. Since we passed check=false, we must
+                                       // now check its operands.
+                                       p.checkBinaryExpr(t)
+                               }
+                       case *ast.CallExpr:
+                               // If we have an expression of the form name(T), and T is a (possibly
+                               // parenthesized) type literal or the next token is a comma, we are
+                               // at the start of a type parameter list.
+                               if name, _ := t.Fun.(*ast.Ident); name != nil {
+                                       if len(t.Args) == 1 && !t.Ellipsis.IsValid() && (isTypeLit(t.Args[0]) || p.tok == token.COMMA) {
+                                               // d.Name "[" name "(" t.ArgList[0] ")"
+                                               // d.Name "[" name "(" t.ArgList[0] ")" ","
+                                               pname = name
+                                               ptype = t.Args[0]
+                                       }
                                }
                        }
 
-                       if x == nil {
-                               // generic type [T any];
-                               p.parseGenericType(spec, lbrack, name0)
+                       if pname != nil {
+                               // d.Name "[" pname ...
+                               // d.Name "[" pname ptype ...
+                               // d.Name "[" pname ptype "," ...
+                               p.parseGenericType(spec, lbrack, pname, ptype)
                        } else {
-                               // array type
-                               // TODO(rfindley) should resolve all identifiers in x.
+                               // d.Name "[" x ...
                                spec.Type = p.parseArrayType(lbrack, x)
                        }
                } else {
@@ -2611,6 +2707,21 @@ func (p *parser) parseTypeSpec(doc *ast.CommentGroup, _ token.Pos, _ token.Token
        return spec
 }
 
+// isTypeLit reports whether x is a (possibly parenthesized) type literal.
+func isTypeLit(x ast.Expr) bool {
+       switch x := x.(type) {
+       case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType:
+               return true
+       case *ast.StarExpr:
+               // *T may be a pointer dereferenciation.
+               // Only consider *T as type literal if T is a type literal.
+               return isTypeLit(x.X)
+       case *ast.ParenExpr:
+               return isTypeLit(x.X)
+       }
+       return false
+}
+
 func (p *parser) parseGenDecl(keyword token.Token, f parseSpecFunction) *ast.GenDecl {
        if p.trace {
                defer un(trace(p, "GenDecl("+keyword.String()+")"))
index cf4fa0a902e864da69d8001dc1c56de33349d72d..d117f0d3815aad7913b0eb71068e6276f84572a0 100644 (file)
@@ -74,7 +74,7 @@ var validWithTParamsOnly = []string{
        `package p; type T[P any /* ERROR "expected ']', found any" */ ] struct { P }`,
        `package p; type T[P comparable /* ERROR "expected ']', found comparable" */ ] struct { P }`,
        `package p; type T[P comparable /* ERROR "expected ']', found comparable" */ [P]] struct { P }`,
-       `package p; type T[P1, /* ERROR "expected ']', found ','" */ P2 any] struct { P1; f []P2 }`,
+       `package p; type T[P1, /* ERROR "unexpected comma" */ P2 any] struct { P1; f []P2 }`,
        `package p; func _[ /* ERROR "expected '\(', found '\['" */ T any]()()`,
        `package p; func _(T (P))`,
        `package p; func f[ /* ERROR "expected '\(', found '\['" */ A, B any](); func _() { _ = f[int, int] }`,
@@ -83,8 +83,8 @@ var validWithTParamsOnly = []string{
        `package p; func _(p.T[ /* ERROR "missing ',' in parameter list" */ Q])`,
        `package p; type _[A interface /* ERROR "expected ']', found 'interface'" */ {},] struct{}`,
        `package p; type _[A interface /* ERROR "expected ']', found 'interface'" */ {}] struct{}`,
-       `package p; type _[A, /* ERROR "expected ']', found ','" */  B any,] struct{}`,
-       `package p; type _[A, /* ERROR "expected ']', found ','" */ B any] struct{}`,
+       `package p; type _[A, /* ERROR "unexpected comma" */  B any,] struct{}`,
+       `package p; type _[A, /* ERROR "unexpected comma" */ B any] struct{}`,
        `package p; type _[A any /* ERROR "expected ']', found any" */,] struct{}`,
        `package p; type _[A any /* ERROR "expected ']', found any" */ ]struct{}`,
        `package p; type _[A any /* ERROR "expected ']', found any" */ ] struct{ A }`,
@@ -95,8 +95,8 @@ var validWithTParamsOnly = []string{
        `package p; func _[ /* ERROR "expected '\(', found '\['" */ A, B C](a A) B`,
        `package p; func _[ /* ERROR "expected '\(', found '\['" */ A, B C[A, B]](a A) B`,
 
-       `package p; type _[A, /* ERROR "expected ']', found ','" */ B any] interface { _(a A) B }`,
-       `package p; type _[A, /* ERROR "expected ']', found ','" */ B C[A, B]] interface { _(a A) B }`,
+       `package p; type _[A, /* ERROR "unexpected comma" */ B any] interface { _(a A) B }`,
+       `package p; type _[A, /* ERROR "unexpected comma" */ B C[A, B]] interface { _(a A) B }`,
        `package p; func _[ /* ERROR "expected '\(', found '\['" */ T1, T2 interface{}](x T1) T2`,
        `package p; func _[ /* ERROR "expected '\(', found '\['" */ T1 interface{ m() }, T2, T3 interface{}](x T1, y T3) T2`,
        `package p; var _ = [ /* ERROR "expected expression" */ ]T[int]{}`,
@@ -193,7 +193,7 @@ var invalids = []string{
        `package p; func f() { go func() { func() { f(x func /* ERROR "missing ','" */ (){}) } } }`,
        `package p; func _() (type /* ERROR "found 'type'" */ T)(T)`,
        `package p; func (type /* ERROR "found 'type'" */ T)(T) _()`,
-       `package p; type _[A+B, /* ERROR "expected ']'" */ ] int`,
+       `package p; type _[A+B, /* ERROR "unexpected comma" */ ] int`,
 
        // TODO(rfindley): this error should be positioned on the ':'
        `package p; var a = a[[]int:[ /* ERROR "expected expression" */ ]int];`,
@@ -231,7 +231,7 @@ var invalidNoTParamErrs = []string{
        `package p; type T[P any /* ERROR "expected ']', found any" */ ] = T0`,
        `package p; var _ func[ /* ERROR "expected '\(', found '\['" */ T any](T)`,
        `package p; func _[ /* ERROR "expected '\(', found '\['" */ ]()`,
-       `package p; type _[A, /* ERROR "expected ']', found ','" */] struct{ A }`,
+       `package p; type _[A, /* ERROR "unexpected comma" */] struct{ A }`,
        `package p; func _[ /* ERROR "expected '\(', found '\['" */ type P, *Q interface{}]()`,
 
        `package p; func (T) _[ /* ERROR "expected '\(', found '\['" */ A, B any](a A) B`,
diff --git a/src/go/parser/testdata/issue49482.go2 b/src/go/parser/testdata/issue49482.go2
new file mode 100644 (file)
index 0000000..50de651
--- /dev/null
@@ -0,0 +1,35 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package p
+
+type (
+        // these need a comma to disambiguate
+        _[P *T,] struct{}
+        _[P *T, _ any] struct{}
+        _[P (*T),] struct{}
+        _[P (*T), _ any] struct{}
+        _[P (T),] struct{}
+        _[P (T), _ any] struct{}
+
+        // these parse as name followed by type
+        _[P *struct{}] struct{}
+        _[P (*struct{})] struct{}
+        _[P ([]int)] struct{}
+
+        // array declarations
+        _ [P(T)]struct{}
+        _ [P((T))]struct{}
+        _ [P * *T]struct{}
+        _ [P * T]struct{}
+        _ [P(*T)]struct{}
+        _ [P(**T)]struct{}
+        _ [P * T - T]struct{}
+        _ [P*T-T, /* ERROR "unexpected comma" */ ]struct{}
+        _ [10, /* ERROR "unexpected comma" */ ]struct{}
+
+        // These should be parsed as generic type declarations.
+        _[P *struct /* ERROR "expected expression" */ {}|int] struct{}
+        _[P *struct /* ERROR "expected expression" */ {}|int|string] struct{}
+)
index 1fea23f51aa9d45f99b217042d954a6a14b69c29..479cb968714a250c18cfe8edbaadfaf7537a2602 100644 (file)
@@ -9,7 +9,7 @@ package p
 
 type List[E any /* ERROR "expected ']', found any" */ ] []E
 
-type Pair[L, /* ERROR "expected ']', found ','" */ R any] struct {
+type Pair[L, /* ERROR "unexpected comma" */ R any] struct {
        Left L
        Right R
 }
index 19d4ab6663d3d0f067a33cef99a375995d2dcfe9..f2170dbc4f31c1ecd57660dd93867729001ab1c8 100644 (file)
@@ -367,20 +367,48 @@ func (p *printer) parameters(fields *ast.FieldList, isTypeParam bool) {
                        p.expr(stripParensAlways(par.Type))
                        prevLine = parLineEnd
                }
+
                // if the closing ")" is on a separate line from the last parameter,
                // print an additional "," and line break
                if closing := p.lineFor(fields.Closing); 0 < prevLine && prevLine < closing {
                        p.print(token.COMMA)
                        p.linebreak(closing, 0, ignore, true)
+               } else if isTypeParam && fields.NumFields() == 1 {
+                       // Otherwise, if we are in a type parameter list that could be confused
+                       // with the constant array length expression [P*C], print a comma so that
+                       // parsing is unambiguous.
+                       //
+                       // Note that while ParenExprs can also be ambiguous (issue #49482), the
+                       // printed type is never parenthesized (stripParensAlways is used above).
+                       if t, _ := fields.List[0].Type.(*ast.StarExpr); t != nil && !isTypeLit(t.X) {
+                               p.print(token.COMMA)
+                       }
                }
+
                // unindent if we indented
                if ws == ignore {
                        p.print(unindent)
                }
        }
+
        p.print(fields.Closing, closeTok)
 }
 
+// isTypeLit reports whether x is a (possibly parenthesized) type literal.
+func isTypeLit(x ast.Expr) bool {
+       switch x := x.(type) {
+       case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType:
+               return true
+       case *ast.StarExpr:
+               // *T may be a pointer dereferenciation.
+               // Only consider *T as type literal if T is a type literal.
+               return isTypeLit(x.X)
+       case *ast.ParenExpr:
+               return isTypeLit(x.X)
+       }
+       return false
+}
+
 func (p *printer) signature(sig *ast.FuncType) {
        if sig.TypeParams != nil {
                p.parameters(sig.TypeParams, true)
index 3d95eda5b2ac4a56ce79d7837f0766e42d3fe6e0..4fac2c9c580b159d0a61c8e01859962e9c5d2a7e 100644 (file)
@@ -38,3 +38,29 @@ func _() {
 // type constraint literals with elided interfaces
 func _[P ~int, Q int | string]()       {}
 func _[P struct{ f int }, Q *P]()      {}
+
+// various potentially ambiguous type parameter lists (issue #49482)
+type _[P *T,] struct{}
+type _[P *T, _ any] struct{}
+type _[P *T,] struct{}
+type _[P *T, _ any] struct{}
+type _[P T] struct{}
+type _[P T, _ any] struct{}
+
+type _[P *struct{}] struct{}
+type _[P *struct{}] struct{}
+type _[P []int] struct{}
+
+// array type declarations
+type _ [P(T)]struct{}
+type _ [P((T))]struct{}
+type _ [P * *T]struct{}
+type _ [P * T]struct{}
+type _ [P(*T)]struct{}
+type _ [P(**T)]struct{}
+type _ [P * T]struct{}
+type _ [P*T - T]struct{}
+
+type _[
+       P *T,
+] struct{}
index 746dfdd2358ba987efc0847035ad6f759151e07f..fde9d32ef04844f1b46bfb229776932f29e21f7d 100644 (file)
@@ -35,3 +35,29 @@ func _() {
 // type constraint literals with elided interfaces
 func _[P ~int, Q int | string]() {}
 func _[P struct{f int}, Q *P]() {}
+
+// various potentially ambiguous type parameter lists (issue #49482)
+type _[P *T,] struct{}
+type _[P *T, _ any] struct{}
+type _[P (*T),] struct{}
+type _[P (*T), _ any] struct{}
+type _[P (T),] struct{}
+type _[P (T), _ any] struct{}
+
+type _[P *struct{}] struct{}
+type _[P (*struct{})] struct{}
+type _[P ([]int)] struct{}
+
+// array type declarations
+type _ [P(T)]struct{}
+type _ [P((T))]struct{}
+type _ [P * *T]struct{}
+type _ [P * T]struct{}
+type _ [P(*T)]struct{}
+type _ [P(**T)]struct{}
+type _ [P * T]struct{}
+type _ [P * T - T]struct{}
+
+type _[
+       P *T,
+] struct{}