]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: prevent untyped types from reaching walk
authorMatthew Dempsky <mdempsky@google.com>
Fri, 2 Mar 2018 23:20:49 +0000 (15:20 -0800)
committerMatthew Dempsky <mdempsky@google.com>
Wed, 7 Mar 2018 18:14:22 +0000 (18:14 +0000)
We already require expressions to have already been typechecked before
reaching walk. Moreover, all untyped expressions should have been
converted to their default type by walk.

However, in practice, we've been somewhat sloppy and inconsistent
about ensuring this. In particular, a lot of AST rewrites ended up
leaving untyped bool expressions scattered around. These likely aren't
harmful in practice, but it seems worth cleaning up.

The two most common cases addressed by this CL are:

1) When generating OIF and OFOR nodes, we would often typecheck the
conditional expression, but not apply defaultlit to force it to the
expression's default type.

2) When rewriting string comparisons into more fundamental primitives,
we were simply overwriting r.Type with the desired type, which didn't
propagate the type to nested subexpressions. These are fixed by
utilizing finishcompare, which correctly handles this (and is already
used by other comparison lowering rewrites).

Lastly, walkexpr is extended to assert that it's not called on untyped
expressions.

Fixes #23834.

Change-Id: Icbd29648a293555e4015d3b06a95a24ccbd3f790
Reviewed-on: https://go-review.googlesource.com/98337
Reviewed-by: Robert Griesemer <gri@golang.org>
src/cmd/compile/internal/gc/range.go
src/cmd/compile/internal/gc/select.go
src/cmd/compile/internal/gc/swt.go
src/cmd/compile/internal/gc/walk.go

index db852e83a2b4c75a2f91e65da098ded7b77d04d8..91f0cd363e3f7d6338e821e45b010fc129cd69c0 100644 (file)
@@ -434,6 +434,7 @@ func walkrange(n *Node) *Node {
        typecheckslice(n.Left.Ninit.Slice(), Etop)
 
        n.Left = typecheck(n.Left, Erv)
+       n.Left = defaultlit(n.Left, nil)
        n.Right = typecheck(n.Right, Etop)
        typecheckslice(body, Etop)
        n.Nbody.Prepend(body...)
@@ -529,6 +530,7 @@ func memclrrange(n, v1, v2, a *Node) bool {
        n.Nbody.Append(v1)
 
        n.Left = typecheck(n.Left, Erv)
+       n.Left = defaultlit(n.Left, nil)
        typecheckslice(n.Nbody.Slice(), Etop)
        n = walkstmt(n)
        return true
index 38eaaccfd27a64f249fe5bf7b28092125b39f213..a74677d5600a24b6aae843d287c20899d29e07e0 100644 (file)
@@ -308,6 +308,7 @@ func walkselectcases(cases *Nodes) []*Node {
 
                cond := nod(OEQ, chosen, nodintconst(int64(i)))
                cond = typecheck(cond, Erv)
+               cond = defaultlit(cond, nil)
 
                r = nod(OIF, cond, nil)
                r.Nbody.AppendNodes(&cas.Nbody)
index 725268ba5c8dd256d4fbc3817dd4dd89eaef46f0..c9fb67e916a4e155adcac45c85597c2eb3ddd039 100644 (file)
@@ -217,6 +217,7 @@ func walkswitch(sw *Node) {
        if sw.Left == nil {
                sw.Left = nodbool(true)
                sw.Left = typecheck(sw.Left, Erv)
+               sw.Left = defaultlit(sw.Left, nil)
        }
 
        if sw.Left.Op == OTYPESW {
@@ -314,21 +315,16 @@ func (s *exprSwitch) walkCases(cc []caseClause) *Node {
                                low := nod(OGE, s.exprname, rng[0])
                                high := nod(OLE, s.exprname, rng[1])
                                a.Left = nod(OANDAND, low, high)
-                               a.Left = typecheck(a.Left, Erv)
-                               a.Left = defaultlit(a.Left, nil)
-                               a.Left = walkexpr(a.Left, nil) // give walk the opportunity to optimize the range check
                        } else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE {
                                a.Left = nod(OEQ, s.exprname, n.Left) // if name == val
-                               a.Left = typecheck(a.Left, Erv)
-                               a.Left = defaultlit(a.Left, nil)
                        } else if s.kind == switchKindTrue {
                                a.Left = n.Left // if val
                        } else {
                                // s.kind == switchKindFalse
                                a.Left = nod(ONOT, n.Left, nil) // if !val
-                               a.Left = typecheck(a.Left, Erv)
-                               a.Left = defaultlit(a.Left, nil)
                        }
+                       a.Left = typecheck(a.Left, Erv)
+                       a.Left = defaultlit(a.Left, nil)
                        a.Nbody.Set1(n.Right) // goto l
 
                        cas = append(cas, a)
@@ -750,6 +746,7 @@ func (s *typeSwitch) walk(sw *Node) {
                def = blk
        }
        i.Left = typecheck(i.Left, Erv)
+       i.Left = defaultlit(i.Left, nil)
        cas = append(cas, i)
 
        // Load hash from type or itab.
@@ -869,6 +866,7 @@ func (s *typeSwitch) walkCases(cc []caseClause) *Node {
                        a := nod(OIF, nil, nil)
                        a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash)))
                        a.Left = typecheck(a.Left, Erv)
+                       a.Left = defaultlit(a.Left, nil)
                        a.Nbody.Set1(n.Right)
                        cas = append(cas, a)
                }
@@ -880,6 +878,7 @@ func (s *typeSwitch) walkCases(cc []caseClause) *Node {
        a := nod(OIF, nil, nil)
        a.Left = nod(OLE, s.hashname, nodintconst(int64(cc[half-1].hash)))
        a.Left = typecheck(a.Left, Erv)
+       a.Left = defaultlit(a.Left, nil)
        a.Nbody.Set1(s.walkCases(cc[:half]))
        a.Rlist.Set1(s.walkCases(cc[half:]))
        return a
index bdfda78061b3753a50d0f55dc53f9c053aba5289..a2dfdb5abc3d4812361be3a3e9e1cd27a8a18b37 100644 (file)
@@ -476,6 +476,10 @@ func walkexpr(n *Node, init *Nodes) *Node {
                Fatalf("missed typecheck: %+v", n)
        }
 
+       if n.Type.IsUntyped() {
+               Fatalf("expression has untyped type: %+v", n)
+       }
+
        if n.Op == ONAME && n.Class() == PAUTOHEAP {
                nn := nod(OIND, n.Name.Param.Heapaddr, nil)
                nn = typecheck(nn, Erv)
@@ -1234,10 +1238,7 @@ opswitch:
                if (Op(n.Etype) == OEQ || Op(n.Etype) == ONE) && Isconst(n.Right, CTSTR) && n.Left.Op == OADDSTR && n.Left.List.Len() == 2 && Isconst(n.Left.List.Second(), CTSTR) && strlit(n.Right) == strlit(n.Left.List.Second()) {
                        // TODO(marvin): Fix Node.EType type union.
                        r := nod(Op(n.Etype), nod(OLEN, n.Left.List.First(), nil), nodintconst(0))
-                       r = typecheck(r, Erv)
-                       r = walkexpr(r, init)
-                       r.Type = n.Type
-                       n = r
+                       n = finishcompare(n, r, init)
                        break
                }
 
@@ -1337,10 +1338,7 @@ opswitch:
                                        remains -= step
                                        i += step
                                }
-                               r = typecheck(r, Erv)
-                               r = walkexpr(r, init)
-                               r.Type = n.Type
-                               n = r
+                               n = finishcompare(n, r, init)
                                break
                        }
                }
@@ -1374,9 +1372,6 @@ opswitch:
                                r = nod(ONOT, r, nil)
                                r = nod(OOROR, nod(ONE, llen, rlen), r)
                        }
-
-                       r = typecheck(r, Erv)
-                       r = walkexpr(r, nil)
                } else {
                        // sys_cmpstring(s1, s2) :: 0
                        r = mkcall("cmpstring", types.Types[TINT], init, conv(n.Left, types.Types[TSTRING]), conv(n.Right, types.Types[TSTRING]))
@@ -1384,12 +1379,7 @@ opswitch:
                        r = nod(Op(n.Etype), r, nodintconst(0))
                }
 
-               r = typecheck(r, Erv)
-               if !n.Type.IsBoolean() {
-                       Fatalf("cmp %v", n.Type)
-               }
-               r.Type = n.Type
-               n = r
+               n = finishcompare(n, r, init)
 
        case OADDSTR:
                n = addstr(n, init)