]> Cypherpunks repositories - gostls13.git/commitdiff
go/printer: avoid allocating for every position printed
authorDaniel Martí <mvdan@mvdan.cc>
Wed, 15 Jun 2022 21:55:58 +0000 (22:55 +0100)
committerRobert Griesemer <gri@google.com>
Thu, 18 Aug 2022 16:35:13 +0000 (16:35 +0000)
printer.print is an overloaded method for multiple purposes.
When fed a position, it updates the current position.
When fed a string, it prints the string.
When fed a token, it prints the token. And so on.

However, this overloading comes at a significant cost.
Because the parameters are a list of the `any` interface type,
any type which is not of pointer or interface kind will allocate when
passed as an argument, as interfaces can only contain pointers.

A large portion of the arguments passed to the print method are of type
token.Pos, whose underlying type is int - so it allocates.
Removing those allocations has a significant benefit,
at the cost of some verbosity in the code:

name      old time/op    new time/op    delta
Print-16    6.10ms ± 2%    5.39ms ± 2%  -11.72%  (p=0.000 n=10+10)

name      old speed      new speed      delta
Print-16  8.50MB/s ± 2%  9.63MB/s ± 2%  +13.28%  (p=0.000 n=10+10)

name      old alloc/op   new alloc/op   delta
Print-16     443kB ± 0%     332kB ± 0%  -25.10%  (p=0.000 n=10+9)

name      old allocs/op  new allocs/op  delta
Print-16     17.3k ± 0%      3.5k ± 0%  -80.10%  (p=0.000 n=10+10)

There should be more significant speed-ups left, particularly for the
token.Token, string, and whiteSpace types fed to the same method.
They are left for a future CL, in case this kind of optimization is not
a path we want to take.

Change-Id: I3ff8387242c5a935bb003e60e0813b7b9c65402e
Reviewed-on: https://go-review.googlesource.com/c/go/+/412557
Run-TryBot: Daniel Martí <mvdan@mvdan.cc>
Reviewed-by: hopehook <hopehook@qq.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Robert Griesemer <gri@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/go/printer/nodes.go
src/go/printer/printer.go

index c167b5f137a6b5eb24287822b701c624bd1ea48d..f16d76acc7e87588855d204657aef089e84e7c89 100644 (file)
@@ -152,7 +152,8 @@ func (p *printer) exprList(prev0 token.Pos, list []ast.Expr, depth int, mode exp
                        if i > 0 {
                                // use position of expression following the comma as
                                // comma position for correct comment placement
-                               p.print(x.Pos(), token.COMMA, blank)
+                               p.printPos(x.Pos())
+                               p.print(token.COMMA, blank)
                        }
                        p.expr0(x, depth)
                }
@@ -243,7 +244,7 @@ func (p *printer) exprList(prev0 token.Pos, list []ast.Expr, depth int, mode exp
                        // comma position for correct comment placement, but
                        // only if the expression is on the same line.
                        if !needsLinebreak {
-                               p.print(x.Pos())
+                               p.printPos(x.Pos())
                        }
                        p.print(token.COMMA)
                        needsBlank := true
@@ -278,7 +279,8 @@ func (p *printer) exprList(prev0 token.Pos, list []ast.Expr, depth int, mode exp
                        // can align if possible.
                        // (needsLinebreak is set if we started a new line before)
                        p.expr(pair.Key)
-                       p.print(pair.Colon, token.COLON, vtab)
+                       p.printPos(pair.Colon)
+                       p.print(token.COLON, vtab)
                        p.expr(pair.Value)
                } else {
                        p.expr0(x, depth)
@@ -331,7 +333,8 @@ func (p *printer) parameters(fields *ast.FieldList, mode paramMode) {
        if mode != funcParam {
                openTok, closeTok = token.LBRACK, token.RBRACK
        }
-       p.print(fields.Opening, openTok)
+       p.printPos(fields.Opening)
+       p.print(openTok)
        if len(fields.List) > 0 {
                prevLine := p.lineFor(fields.Opening)
                ws := indent
@@ -348,7 +351,7 @@ func (p *printer) parameters(fields *ast.FieldList, mode paramMode) {
                                // comma position for correct comma placement, but
                                // only if the next parameter is on the same line
                                if !needsLinebreak {
-                                       p.print(par.Pos())
+                                       p.printPos(par.Pos())
                                }
                                p.print(token.COMMA)
                        }
@@ -394,7 +397,8 @@ func (p *printer) parameters(fields *ast.FieldList, mode paramMode) {
                }
        }
 
-       p.print(fields.Closing, closeTok)
+       p.printPos(fields.Closing)
+       p.print(closeTok)
 }
 
 // combinesWithName reports whether a name followed by the expression x
@@ -502,12 +506,16 @@ func (p *printer) fieldList(fields *ast.FieldList, isStruct, isIncomplete bool)
                // possibly a one-line struct/interface
                if len(list) == 0 {
                        // no blank between keyword and {} in this case
-                       p.print(lbrace, token.LBRACE, rbrace, token.RBRACE)
+                       p.printPos(lbrace)
+                       p.print(token.LBRACE)
+                       p.printPos(rbrace)
+                       p.print(token.RBRACE)
                        return
                } else if p.isOneLineFieldList(list) {
                        // small enough - print on one line
                        // (don't use identList and ignore source line breaks)
-                       p.print(lbrace, token.LBRACE, blank)
+                       p.printPos(lbrace)
+                       p.print(token.LBRACE, blank)
                        f := list[0]
                        if isStruct {
                                for i, x := range f.Names {
@@ -531,13 +539,17 @@ func (p *printer) fieldList(fields *ast.FieldList, isStruct, isIncomplete bool)
                                        p.expr(f.Type)
                                }
                        }
-                       p.print(blank, rbrace, token.RBRACE)
+                       p.print(blank)
+                       p.printPos(rbrace)
+                       p.print(token.RBRACE)
                        return
                }
        }
        // hasComments || !srcIsOneLine
 
-       p.print(blank, lbrace, token.LBRACE, indent)
+       p.print(blank)
+       p.printPos(lbrace)
+       p.print(token.LBRACE, indent)
        if hasComments || len(list) > 0 {
                p.print(formfeed)
        }
@@ -632,7 +644,9 @@ func (p *printer) fieldList(fields *ast.FieldList, isStruct, isIncomplete bool)
                }
 
        }
-       p.print(unindent, formfeed, rbrace, token.RBRACE)
+       p.print(unindent, formfeed)
+       p.printPos(rbrace)
+       p.print(token.RBRACE)
 }
 
 // ----------------------------------------------------------------------------
@@ -783,7 +797,8 @@ func (p *printer) binaryExpr(x *ast.BinaryExpr, prec1, cutoff, depth int) {
        }
        xline := p.pos.Line // before the operator (it may be on the next line!)
        yline := p.lineFor(x.Y.Pos())
-       p.print(x.OpPos, x.Op)
+       p.printPos(x.OpPos)
+       p.print(x.Op)
        if xline != yline && xline > 0 && yline > 0 {
                // at least one line break, but respect an extra empty line
                // in the source
@@ -807,7 +822,7 @@ func isBinary(expr ast.Expr) bool {
 }
 
 func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
-       p.print(expr.Pos())
+       p.printPos(expr.Pos())
 
        switch x := expr.(type) {
        case *ast.BadExpr:
@@ -825,7 +840,8 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
 
        case *ast.KeyValueExpr:
                p.expr(x.Key)
-               p.print(x.Colon, token.COLON, blank)
+               p.printPos(x.Colon)
+               p.print(token.COLON, blank)
                p.expr(x.Value)
 
        case *ast.StarExpr:
@@ -866,7 +882,8 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
                p.print(x)
 
        case *ast.FuncLit:
-               p.print(x.Type.Pos(), token.FUNC)
+               p.printPos(x.Type.Pos())
+               p.print(token.FUNC)
                // See the comment in funcDecl about how the header size is computed.
                startCol := p.out.Column - len("func")
                p.signature(x.Type)
@@ -880,7 +897,8 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
                } else {
                        p.print(token.LPAREN)
                        p.expr0(x.X, reduceDepth(depth)) // parentheses undo one level of depth
-                       p.print(x.Rparen, token.RPAREN)
+                       p.printPos(x.Rparen)
+                       p.print(token.RPAREN)
                }
 
        case *ast.SelectorExpr:
@@ -888,33 +906,41 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
 
        case *ast.TypeAssertExpr:
                p.expr1(x.X, token.HighestPrec, depth)
-               p.print(token.PERIOD, x.Lparen, token.LPAREN)
+               p.print(token.PERIOD)
+               p.printPos(x.Lparen)
+               p.print(token.LPAREN)
                if x.Type != nil {
                        p.expr(x.Type)
                } else {
                        p.print(token.TYPE)
                }
-               p.print(x.Rparen, token.RPAREN)
+               p.printPos(x.Rparen)
+               p.print(token.RPAREN)
 
        case *ast.IndexExpr:
                // TODO(gri): should treat[] like parentheses and undo one level of depth
                p.expr1(x.X, token.HighestPrec, 1)
-               p.print(x.Lbrack, token.LBRACK)
+               p.printPos(x.Lbrack)
+               p.print(token.LBRACK)
                p.expr0(x.Index, depth+1)
-               p.print(x.Rbrack, token.RBRACK)
+               p.printPos(x.Rbrack)
+               p.print(token.RBRACK)
 
        case *ast.IndexListExpr:
                // TODO(gri): as for IndexExpr, should treat [] like parentheses and undo
                // one level of depth
                p.expr1(x.X, token.HighestPrec, 1)
-               p.print(x.Lbrack, token.LBRACK)
+               p.printPos(x.Lbrack)
+               p.print(token.LBRACK)
                p.exprList(x.Lbrack, x.Indices, depth+1, commaTerm, x.Rbrack, false)
-               p.print(x.Rbrack, token.RBRACK)
+               p.printPos(x.Rbrack)
+               p.print(token.RBRACK)
 
        case *ast.SliceExpr:
                // TODO(gri): should treat[] like parentheses and undo one level of depth
                p.expr1(x.X, token.HighestPrec, 1)
-               p.print(x.Lbrack, token.LBRACK)
+               p.printPos(x.Lbrack)
+               p.print(token.LBRACK)
                indices := []ast.Expr{x.Low, x.High}
                if x.Max != nil {
                        indices = append(indices, x.Max)
@@ -950,7 +976,8 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
                                p.expr0(x, depth+1)
                        }
                }
-               p.print(x.Rbrack, token.RBRACK)
+               p.printPos(x.Rbrack)
+               p.print(token.RBRACK)
 
        case *ast.CallExpr:
                if len(x.Args) > 1 {
@@ -965,17 +992,20 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
                } else {
                        wasIndented = p.possibleSelectorExpr(x.Fun, token.HighestPrec, depth)
                }
-               p.print(x.Lparen, token.LPAREN)
+               p.printPos(x.Lparen)
+               p.print(token.LPAREN)
                if x.Ellipsis.IsValid() {
                        p.exprList(x.Lparen, x.Args, depth, 0, x.Ellipsis, false)
-                       p.print(x.Ellipsis, token.ELLIPSIS)
+                       p.printPos(x.Ellipsis)
+                       p.print(token.ELLIPSIS)
                        if x.Rparen.IsValid() && p.lineFor(x.Ellipsis) < p.lineFor(x.Rparen) {
                                p.print(token.COMMA, formfeed)
                        }
                } else {
                        p.exprList(x.Lparen, x.Args, depth, commaTerm, x.Rparen, false)
                }
-               p.print(x.Rparen, token.RPAREN)
+               p.printPos(x.Rparen)
+               p.print(token.RPAREN)
                if wasIndented {
                        p.print(unindent)
                }
@@ -986,7 +1016,8 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
                        p.expr1(x.Type, token.HighestPrec, depth)
                }
                p.level++
-               p.print(x.Lbrace, token.LBRACE)
+               p.printPos(x.Lbrace)
+               p.print(token.LBRACE)
                p.exprList(x.Lbrace, x.Elts, 1, commaTerm, x.Rbrace, x.Incomplete)
                // do not insert extra line break following a /*-style comment
                // before the closing '}' as it might break the code if there
@@ -999,7 +1030,9 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
                }
                // need the initial indent to print lone comments with
                // the proper level of indentation
-               p.print(indent, unindent, mode, x.Rbrace, token.RBRACE, mode)
+               p.print(indent, unindent, mode)
+               p.printPos(x.Rbrace)
+               p.print(token.RBRACE, mode)
                p.level--
 
        case *ast.Ellipsis:
@@ -1041,7 +1074,9 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
                case ast.RECV:
                        p.print(token.ARROW, token.CHAN) // x.Arrow and x.Pos() are the same
                case ast.SEND:
-                       p.print(token.CHAN, x.Arrow, token.ARROW)
+                       p.print(token.CHAN)
+                       p.printPos(x.Arrow)
+                       p.print(token.ARROW)
                }
                p.print(blank)
                p.expr(x.Value)
@@ -1125,13 +1160,16 @@ func (p *printer) selectorExpr(x *ast.SelectorExpr, depth int, isMethod bool) bo
        p.expr1(x.X, token.HighestPrec, depth)
        p.print(token.PERIOD)
        if line := p.lineFor(x.Sel.Pos()); p.pos.IsValid() && p.pos.Line < line {
-               p.print(indent, newline, x.Sel.Pos(), x.Sel)
+               p.print(indent, newline)
+               p.printPos(x.Sel.Pos())
+               p.print(x.Sel)
                if !isMethod {
                        p.print(unindent)
                }
                return true
        }
-       p.print(x.Sel.Pos(), x.Sel)
+       p.printPos(x.Sel.Pos())
+       p.print(x.Sel)
        return false
 }
 
@@ -1189,10 +1227,12 @@ func (p *printer) stmtList(list []ast.Stmt, nindent int, nextIsRBrace bool) {
 
 // block prints an *ast.BlockStmt; it always spans at least two lines.
 func (p *printer) block(b *ast.BlockStmt, nindent int) {
-       p.print(b.Lbrace, token.LBRACE)
+       p.printPos(b.Lbrace)
+       p.print(token.LBRACE)
        p.stmtList(b.List, nindent, true)
        p.linebreak(p.lineFor(b.Rbrace), 1, ignore, true)
-       p.print(b.Rbrace, token.RBRACE)
+       p.printPos(b.Rbrace)
+       p.print(token.RBRACE)
 }
 
 func isTypeName(x ast.Expr) bool {
@@ -1307,7 +1347,7 @@ func (p *printer) indentList(list []ast.Expr) bool {
 }
 
 func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
-       p.print(stmt.Pos())
+       p.printPos(stmt.Pos())
 
        switch s := stmt.(type) {
        case *ast.BadStmt:
@@ -1325,10 +1365,13 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
                // between (see writeWhitespace)
                p.print(unindent)
                p.expr(s.Label)
-               p.print(s.Colon, token.COLON, indent)
+               p.printPos(s.Colon)
+               p.print(token.COLON, indent)
                if e, isEmpty := s.Stmt.(*ast.EmptyStmt); isEmpty {
                        if !nextIsRBrace {
-                               p.print(newline, e.Pos(), token.SEMICOLON)
+                               p.print(newline)
+                               p.printPos(e.Pos())
+                               p.print(token.SEMICOLON)
                                break
                        }
                } else {
@@ -1343,13 +1386,16 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
        case *ast.SendStmt:
                const depth = 1
                p.expr0(s.Chan, depth)
-               p.print(blank, s.Arrow, token.ARROW, blank)
+               p.print(blank)
+               p.printPos(s.Arrow)
+               p.print(token.ARROW, blank)
                p.expr0(s.Value, depth)
 
        case *ast.IncDecStmt:
                const depth = 1
                p.expr0(s.X, depth+1)
-               p.print(s.TokPos, s.Tok)
+               p.printPos(s.TokPos)
+               p.print(s.Tok)
 
        case *ast.AssignStmt:
                var depth = 1
@@ -1357,7 +1403,9 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
                        depth++
                }
                p.exprList(s.Pos(), s.Lhs, depth, 0, s.TokPos, false)
-               p.print(blank, s.TokPos, s.Tok, blank)
+               p.print(blank)
+               p.printPos(s.TokPos)
+               p.print(s.Tok, blank)
                p.exprList(s.TokPos, s.Rhs, depth, 0, token.NoPos, false)
 
        case *ast.GoStmt:
@@ -1424,7 +1472,8 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
                } else {
                        p.print(token.DEFAULT)
                }
-               p.print(s.Colon, token.COLON)
+               p.printPos(s.Colon)
+               p.print(token.COLON)
                p.stmtList(s.Body, 1, nextIsRBrace)
 
        case *ast.SwitchStmt:
@@ -1451,7 +1500,8 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
                } else {
                        p.print(token.DEFAULT)
                }
-               p.print(s.Colon, token.COLON)
+               p.printPos(s.Colon)
+               p.print(token.COLON)
                p.stmtList(s.Body, 1, nextIsRBrace)
 
        case *ast.SelectStmt:
@@ -1459,7 +1509,10 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
                body := s.Body
                if len(body.List) == 0 && !p.commentBefore(p.posFor(body.Rbrace)) {
                        // print empty select statement w/o comments on one line
-                       p.print(body.Lbrace, token.LBRACE, body.Rbrace, token.RBRACE)
+                       p.printPos(body.Lbrace)
+                       p.print(token.LBRACE)
+                       p.printPos(body.Rbrace)
+                       p.print(token.RBRACE)
                } else {
                        p.block(body, 0)
                }
@@ -1476,10 +1529,13 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
                        if s.Value != nil {
                                // use position of value following the comma as
                                // comma position for correct comment placement
-                               p.print(s.Value.Pos(), token.COMMA, blank)
+                               p.printPos(s.Value.Pos())
+                               p.print(token.COMMA, blank)
                                p.expr(s.Value)
                        }
-                       p.print(blank, s.TokPos, s.Tok, blank)
+                       p.print(blank)
+                       p.printPos(s.TokPos)
+                       p.print(s.Tok, blank)
                }
                p.print(token.RANGE, blank)
                p.expr(stripParens(s.X))
@@ -1638,7 +1694,7 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool) {
                }
                p.expr(sanitizeImportPath(s.Path))
                p.setComment(s.Comment)
-               p.print(s.EndPos)
+               p.printPos(s.EndPos)
 
        case *ast.ValueSpec:
                if n != 1 {
@@ -1680,11 +1736,13 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool) {
 
 func (p *printer) genDecl(d *ast.GenDecl) {
        p.setComment(d.Doc)
-       p.print(d.Pos(), d.Tok, blank)
+       p.printPos(d.Pos())
+       p.print(d.Tok, blank)
 
        if d.Lparen.IsValid() || len(d.Specs) > 1 {
                // group of parenthesized declarations
-               p.print(d.Lparen, token.LPAREN)
+               p.printPos(d.Lparen)
+               p.print(token.LPAREN)
                if n := len(d.Specs); n > 0 {
                        p.print(indent, formfeed)
                        if n > 1 && (d.Tok == token.CONST || d.Tok == token.VAR) {
@@ -1711,7 +1769,8 @@ func (p *printer) genDecl(d *ast.GenDecl) {
                        }
                        p.print(unindent, formfeed)
                }
-               p.print(d.Rparen, token.RPAREN)
+               p.printPos(d.Rparen)
+               p.print(token.RPAREN)
 
        } else if len(d.Specs) > 0 {
                // single declaration
@@ -1825,7 +1884,9 @@ func (p *printer) funcBody(headerSize int, sep whiteSpace, b *ast.BlockStmt) {
 
        const maxSize = 100
        if headerSize+p.bodySize(b, maxSize) <= maxSize {
-               p.print(sep, b.Lbrace, token.LBRACE)
+               p.print(sep)
+               p.printPos(b.Lbrace)
+               p.print(token.LBRACE)
                if len(b.List) > 0 {
                        p.print(blank)
                        for i, s := range b.List {
@@ -1836,7 +1897,9 @@ func (p *printer) funcBody(headerSize int, sep whiteSpace, b *ast.BlockStmt) {
                        }
                        p.print(blank)
                }
-               p.print(noExtraLinebreak, b.Rbrace, token.RBRACE, noExtraLinebreak)
+               p.print(noExtraLinebreak)
+               p.printPos(b.Rbrace)
+               p.print(token.RBRACE, noExtraLinebreak)
                return
        }
 
@@ -1858,7 +1921,8 @@ func (p *printer) distanceFrom(startPos token.Pos, startOutCol int) int {
 
 func (p *printer) funcDecl(d *ast.FuncDecl) {
        p.setComment(d.Doc)
-       p.print(d.Pos(), token.FUNC, blank)
+       p.printPos(d.Pos())
+       p.print(token.FUNC, blank)
        // We have to save startCol only after emitting FUNC; otherwise it can be on a
        // different line (all whitespace preceding the FUNC is emitted only when the
        // FUNC is emitted).
@@ -1875,7 +1939,8 @@ func (p *printer) funcDecl(d *ast.FuncDecl) {
 func (p *printer) decl(decl ast.Decl) {
        switch d := decl.(type) {
        case *ast.BadDecl:
-               p.print(d.Pos(), "BadDecl")
+               p.printPos(d.Pos())
+               p.print("BadDecl")
        case *ast.GenDecl:
                p.genDecl(d)
        case *ast.FuncDecl:
@@ -1928,7 +1993,8 @@ func (p *printer) declList(list []ast.Decl) {
 
 func (p *printer) file(src *ast.File) {
        p.setComment(src.Doc)
-       p.print(src.Pos(), token.PACKAGE, blank)
+       p.printPos(src.Pos())
+       p.print(token.PACKAGE, blank)
        p.expr(src.Name)
        p.declList(src.Decls)
        p.print(newline)
index 244a19b2a7787ffc5370e81b5859c6000eb80b84..ec73eec34d8c19c3bd02de32a80c037fc0cc1560 100644 (file)
@@ -886,6 +886,12 @@ func mayCombine(prev token.Token, next byte) (b bool) {
        return
 }
 
+func (p *printer) printPos(pos token.Pos) {
+       if pos.IsValid() {
+               p.pos = p.posFor(pos) // accurate position of next item
+       }
+}
+
 // print prints a list of "items" (roughly corresponding to syntactic
 // tokens, but also including whitespace and formatting information).
 // It is the only print function that should be called directly from
@@ -982,12 +988,6 @@ func (p *printer) print(args ...any) {
                        }
                        p.lastTok = x
 
-               case token.Pos:
-                       if x.IsValid() {
-                               p.pos = p.posFor(x) // accurate position of next item
-                       }
-                       continue
-
                case string:
                        // incorrect AST - print error message
                        data = x