]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssa: (#2 of 5): core utilities
authorAlan Donovan <adonovan@google.com>
Mon, 28 Jan 2013 23:06:14 +0000 (18:06 -0500)
committerAlan Donovan <adonovan@google.com>
Mon, 28 Jan 2013 23:06:14 +0000 (18:06 -0500)
This CL includes the implementation of Literal, all the
Value.String and Instruction.String methods, the sanity
checker, and other misc utilities.

R=gri, iant, iant
CC=golang-dev
https://golang.org/cl/7199052

src/pkg/exp/ssa/literal.go [new file with mode: 0644]
src/pkg/exp/ssa/print.go [new file with mode: 0644]
src/pkg/exp/ssa/sanity.go [new file with mode: 0644]
src/pkg/exp/ssa/ssa.go
src/pkg/exp/ssa/util.go [new file with mode: 0644]

diff --git a/src/pkg/exp/ssa/literal.go b/src/pkg/exp/ssa/literal.go
new file mode 100644 (file)
index 0000000..fa26c47
--- /dev/null
@@ -0,0 +1,137 @@
+package ssa
+
+// This file defines the Literal SSA value type.
+
+import (
+       "fmt"
+       "go/types"
+       "math/big"
+       "strconv"
+)
+
+// newLiteral returns a new literal of the specified value and type.
+// val must be valid according to the specification of Literal.Value.
+//
+func newLiteral(val interface{}, typ types.Type) *Literal {
+       // This constructor exists to provide a single place to
+       // insert logging/assertions during debugging.
+       return &Literal{typ, val}
+}
+
+// intLiteral returns an untyped integer literal that evaluates to i.
+func intLiteral(i int64) *Literal {
+       return newLiteral(i, types.Typ[types.UntypedInt])
+}
+
+// nilLiteral returns a nil literal of the specified (reference) type.
+func nilLiteral(typ types.Type) *Literal {
+       return newLiteral(types.NilType{}, typ)
+}
+
+func (l *Literal) Name() string {
+       var s string
+       switch x := l.Value.(type) {
+       case bool:
+               s = fmt.Sprintf("%v", l.Value)
+       case int64:
+               s = fmt.Sprintf("%d", l.Value)
+       case *big.Int:
+               s = x.String()
+       case *big.Rat:
+               s = x.FloatString(20)
+       case string:
+               if len(x) > 20 {
+                       x = x[:17] + "..." // abbreviate
+               }
+               s = strconv.Quote(x)
+       case types.Complex:
+               r := x.Re.FloatString(20)
+               i := x.Im.FloatString(20)
+               s = fmt.Sprintf("%s+%si", r, i)
+       case types.NilType:
+               s = "nil"
+       default:
+               panic(fmt.Sprintf("unexpected literal value: %T", x))
+       }
+       return s + ":" + l.Type_.String()
+}
+
+func (l *Literal) Type() types.Type {
+       return l.Type_
+}
+
+// IsNil returns true if this literal represents a typed or untyped nil value.
+func (l *Literal) IsNil() bool {
+       _, ok := l.Value.(types.NilType)
+       return ok
+}
+
+// Int64 returns the numeric value of this literal truncated to fit
+// a signed 64-bit integer.
+//
+func (l *Literal) Int64() int64 {
+       switch x := l.Value.(type) {
+       case int64:
+               return x
+       case *big.Int:
+               return x.Int64()
+       case *big.Rat:
+               // TODO(adonovan): fix: is this the right rounding mode?
+               var q big.Int
+               return q.Quo(x.Num(), x.Denom()).Int64()
+       }
+       panic(fmt.Sprintf("unexpected literal value: %T", l.Value))
+}
+
+// Uint64 returns the numeric value of this literal truncated to fit
+// an unsigned 64-bit integer.
+//
+func (l *Literal) Uint64() uint64 {
+       switch x := l.Value.(type) {
+       case int64:
+               if x < 0 {
+                       return 0
+               }
+               return uint64(x)
+       case *big.Int:
+               return x.Uint64()
+       case *big.Rat:
+               // TODO(adonovan): fix: is this right?
+               var q big.Int
+               return q.Quo(x.Num(), x.Denom()).Uint64()
+       }
+       panic(fmt.Sprintf("unexpected literal value: %T", l.Value))
+}
+
+// Float64 returns the numeric value of this literal truncated to fit
+// a float64.
+//
+func (l *Literal) Float64() float64 {
+       switch x := l.Value.(type) {
+       case int64:
+               return float64(x)
+       case *big.Int:
+               var r big.Rat
+               f, _ := r.SetInt(x).Float64()
+               return f
+       case *big.Rat:
+               f, _ := x.Float64()
+               return f
+       }
+       panic(fmt.Sprintf("unexpected literal value: %T", l.Value))
+}
+
+// Complex128 returns the complex value of this literal truncated to
+// fit a complex128.
+//
+func (l *Literal) Complex128() complex128 {
+       switch x := l.Value.(type) {
+       case int64, *big.Int, *big.Rat:
+               return complex(l.Float64(), 0)
+       case types.Complex:
+               re64, _ := x.Re.Float64()
+               im64, _ := x.Im.Float64()
+               return complex(re64, im64)
+       }
+       panic(fmt.Sprintf("unexpected literal value: %T", l.Value))
+}
diff --git a/src/pkg/exp/ssa/print.go b/src/pkg/exp/ssa/print.go
new file mode 100644 (file)
index 0000000..b8708b6
--- /dev/null
@@ -0,0 +1,383 @@
+package ssa
+
+// This file implements the String() methods for all Value and
+// Instruction types.
+
+import (
+       "bytes"
+       "fmt"
+       "go/ast"
+       "go/types"
+)
+
+func (id Id) String() string {
+       if id.Pkg == nil {
+               return id.Name
+       }
+       return fmt.Sprintf("%s/%s", id.Pkg.Path, id.Name)
+}
+
+// relName returns the name of v relative to i.
+// In most cases, this is identical to v.Name(), but for cross-package
+// references to Functions (including methods) and Globals, the
+// package-qualified FullName is used instead.
+//
+func relName(v Value, i Instruction) string {
+       switch v := v.(type) {
+       case *Global:
+               if v.Pkg == i.Block().Func.Pkg {
+                       return v.Name()
+               }
+               return v.FullName()
+       case *Function:
+               if v.Pkg == nil || v.Pkg == i.Block().Func.Pkg {
+                       return v.Name()
+               }
+               return v.FullName()
+       }
+       return v.Name()
+}
+
+// Value.String()
+//
+// This method is provided only for debugging.
+// It never appears in disassembly, which uses Value.Name().
+
+func (v *Literal) String() string {
+       return fmt.Sprintf("literal %s rep=%T", v.Name(), v.Value)
+}
+
+func (v *Parameter) String() string {
+       return fmt.Sprintf("parameter %s : %s", v.Name(), v.Type())
+}
+
+func (v *Capture) String() string {
+       return fmt.Sprintf("capture %s : %s", v.Name(), v.Type())
+}
+
+func (v *Global) String() string {
+       return fmt.Sprintf("global %s : %s", v.Name(), v.Type())
+}
+
+func (v *Builtin) String() string {
+       return fmt.Sprintf("builtin %s : %s", v.Name(), v.Type())
+}
+
+func (r *Function) String() string {
+       return fmt.Sprintf("function %s : %s", r.Name(), r.Type())
+}
+
+// FullName returns the name of this function qualified by the
+// package name, unless it is anonymous or synthetic.
+//
+// TODO(adonovan): move to func.go when it's submitted.
+//
+func (f *Function) FullName() string {
+       if f.Enclosing != nil || f.Pkg == nil {
+               return f.Name_ // anonymous or synthetic
+       }
+       return fmt.Sprintf("%s.%s", f.Pkg.ImportPath, f.Name_)
+}
+
+// FullName returns g's package-qualified name.
+func (g *Global) FullName() string {
+       return fmt.Sprintf("%s.%s", g.Pkg.ImportPath, g.Name_)
+}
+
+// Instruction.String()
+
+func (v *Alloc) String() string {
+       op := "local"
+       if v.Heap {
+               op = "new"
+       }
+       return fmt.Sprintf("%s %s", op, indirectType(v.Type()))
+}
+
+func (v *Phi) String() string {
+       var b bytes.Buffer
+       b.WriteString("phi [")
+       for i, edge := range v.Edges {
+               if i > 0 {
+                       b.WriteString(", ")
+               }
+               // Be robust against malformed CFG.
+               blockname := "?"
+               if v.Block_ != nil && i < len(v.Block_.Preds) {
+                       blockname = v.Block_.Preds[i].Name
+               }
+               b.WriteString(blockname)
+               b.WriteString(": ")
+               b.WriteString(relName(edge, v))
+       }
+       b.WriteString("]")
+       return b.String()
+}
+
+func printCall(v *CallCommon, prefix string, instr Instruction) string {
+       var b bytes.Buffer
+       b.WriteString(prefix)
+       if v.Func != nil {
+               b.WriteString(relName(v.Func, instr))
+       } else {
+               name := underlyingType(v.Recv.Type()).(*types.Interface).Methods[v.Method].Name
+               fmt.Fprintf(&b, "invoke %s.%s [#%d]", relName(v.Recv, instr), name, v.Method)
+       }
+       b.WriteString("(")
+       for i, arg := range v.Args {
+               if i > 0 {
+                       b.WriteString(", ")
+               }
+               b.WriteString(relName(arg, instr))
+       }
+       if v.HasEllipsis {
+               b.WriteString("...")
+       }
+       b.WriteString(")")
+       return b.String()
+}
+
+func (v *Call) String() string {
+       return printCall(&v.CallCommon, "", v)
+}
+
+func (v *BinOp) String() string {
+       return fmt.Sprintf("%s %s %s", relName(v.X, v), v.Op.String(), relName(v.Y, v))
+}
+
+func (v *UnOp) String() string {
+       return fmt.Sprintf("%s%s%s", v.Op, relName(v.X, v), commaOk(v.CommaOk))
+}
+
+func (v *Conv) String() string {
+       return fmt.Sprintf("convert %s <- %s (%s)", v.Type(), v.X.Type(), relName(v.X, v))
+}
+
+func (v *ChangeInterface) String() string {
+       return fmt.Sprintf("change interface %s <- %s (%s)", v.Type(), v.X.Type(), relName(v.X, v))
+}
+
+func (v *MakeInterface) String() string {
+       return fmt.Sprintf("make interface %s <- %s (%s)", v.Type(), v.X.Type(), relName(v.X, v))
+}
+
+func (v *MakeClosure) String() string {
+       var b bytes.Buffer
+       fmt.Fprintf(&b, "make closure %s", relName(v.Fn, v))
+       if v.Bindings != nil {
+               b.WriteString(" [")
+               for i, c := range v.Bindings {
+                       if i > 0 {
+                               b.WriteString(", ")
+                       }
+                       b.WriteString(relName(c, v))
+               }
+               b.WriteString("]")
+       }
+       return b.String()
+}
+
+func (v *MakeSlice) String() string {
+       var b bytes.Buffer
+       b.WriteString("make slice ")
+       b.WriteString(v.Type().String())
+       b.WriteString(" ")
+       b.WriteString(relName(v.Len, v))
+       b.WriteString(" ")
+       b.WriteString(relName(v.Cap, v))
+       return b.String()
+}
+
+func (v *Slice) String() string {
+       var b bytes.Buffer
+       b.WriteString("slice ")
+       b.WriteString(relName(v.X, v))
+       b.WriteString("[")
+       if v.Low != nil {
+               b.WriteString(relName(v.Low, v))
+       }
+       b.WriteString(":")
+       if v.High != nil {
+               b.WriteString(relName(v.High, v))
+       }
+       b.WriteString("]")
+       return b.String()
+}
+
+func (v *MakeMap) String() string {
+       res := ""
+       if v.Reserve != nil {
+               res = relName(v.Reserve, v)
+       }
+       return fmt.Sprintf("make %s %s", v.Type(), res)
+}
+
+func (v *MakeChan) String() string {
+       return fmt.Sprintf("make %s %s", v.Type(), relName(v.Size, v))
+}
+
+func (v *FieldAddr) String() string {
+       fields := underlyingType(indirectType(v.X.Type())).(*types.Struct).Fields
+       // Be robust against a bad index.
+       name := "?"
+       if v.Field >= 0 && v.Field < len(fields) {
+               name = fields[v.Field].Name
+       }
+       return fmt.Sprintf("&%s.%s [#%d]", relName(v.X, v), name, v.Field)
+}
+
+func (v *Field) String() string {
+       fields := underlyingType(v.X.Type()).(*types.Struct).Fields
+       // Be robust against a bad index.
+       name := "?"
+       if v.Field >= 0 && v.Field < len(fields) {
+               name = fields[v.Field].Name
+       }
+       return fmt.Sprintf("%s.%s [#%d]", relName(v.X, v), name, v.Field)
+}
+
+func (v *IndexAddr) String() string {
+       return fmt.Sprintf("&%s[%s]", relName(v.X, v), relName(v.Index, v))
+}
+
+func (v *Index) String() string {
+       return fmt.Sprintf("%s[%s]", relName(v.X, v), relName(v.Index, v))
+}
+
+func (v *Lookup) String() string {
+       return fmt.Sprintf("%s[%s]%s", relName(v.X, v), relName(v.Index, v), commaOk(v.CommaOk))
+}
+
+func (v *Range) String() string {
+       return "range " + relName(v.X, v)
+}
+
+func (v *Next) String() string {
+       return "next " + relName(v.Iter, v)
+}
+
+func (v *TypeAssert) String() string {
+       return fmt.Sprintf("typeassert%s %s.(%s)", commaOk(v.CommaOk), relName(v.X, v), v.AssertedType)
+}
+
+func (v *Extract) String() string {
+       return fmt.Sprintf("extract %s #%d", relName(v.Tuple, v), v.Index)
+}
+
+func (s *Jump) String() string {
+       // Be robust against malformed CFG.
+       blockname := "?"
+       if s.Block_ != nil && len(s.Block_.Succs) == 1 {
+               blockname = s.Block_.Succs[0].Name
+       }
+       return fmt.Sprintf("jump %s", blockname)
+}
+
+func (s *If) String() string {
+       // Be robust against malformed CFG.
+       tblockname, fblockname := "?", "?"
+       if s.Block_ != nil && len(s.Block_.Succs) == 2 {
+               tblockname = s.Block_.Succs[0].Name
+               fblockname = s.Block_.Succs[1].Name
+       }
+       return fmt.Sprintf("if %s goto %s else %s", relName(s.Cond, s), tblockname, fblockname)
+}
+
+func (s *Go) String() string {
+       return printCall(&s.CallCommon, "go ", s)
+}
+
+func (s *Ret) String() string {
+       var b bytes.Buffer
+       b.WriteString("ret")
+       for i, r := range s.Results {
+               if i == 0 {
+                       b.WriteString(" ")
+               } else {
+                       b.WriteString(", ")
+               }
+               b.WriteString(relName(r, s))
+       }
+       return b.String()
+}
+
+func (s *Send) String() string {
+       return fmt.Sprintf("send %s <- %s", relName(s.Chan, s), relName(s.X, s))
+}
+
+func (s *Defer) String() string {
+       return printCall(&s.CallCommon, "defer ", s)
+}
+
+func (s *Select) String() string {
+       var b bytes.Buffer
+       for i, st := range s.States {
+               if i > 0 {
+                       b.WriteString(", ")
+               }
+               if st.Dir == ast.RECV {
+                       b.WriteString("<-")
+                       b.WriteString(relName(st.Chan, s))
+               } else {
+                       b.WriteString(relName(st.Chan, s))
+                       b.WriteString("<-")
+                       b.WriteString(relName(st.Send, s))
+               }
+       }
+       non := ""
+       if !s.Blocking {
+               non = "non"
+       }
+       return fmt.Sprintf("select %sblocking [%s]", non, b.String())
+}
+
+func (s *Store) String() string {
+       return fmt.Sprintf("*%s = %s", relName(s.Addr, s), relName(s.Val, s))
+}
+
+func (s *MapUpdate) String() string {
+       return fmt.Sprintf("%s[%s] = %s", relName(s.Map, s), relName(s.Key, s), relName(s.Value, s))
+}
+
+func (p *Package) String() string {
+       // TODO(adonovan): prettify output.
+       var b bytes.Buffer
+       fmt.Fprintf(&b, "Package %s at %s:\n", p.ImportPath, p.Prog.Files.File(p.Pos).Name())
+
+       // TODO(adonovan): make order deterministic.
+       maxname := 0
+       for name := range p.Members {
+               if l := len(name); l > maxname {
+                       maxname = l
+               }
+       }
+
+       for name, mem := range p.Members {
+               switch mem := mem.(type) {
+               case *Literal:
+                       fmt.Fprintf(&b, " const %-*s %s\n", maxname, name, mem.Name())
+
+               case *Function:
+                       fmt.Fprintf(&b, " func  %-*s %s\n", maxname, name, mem.Type())
+
+               case *Type:
+                       fmt.Fprintf(&b, " type  %-*s %s\n", maxname, name, mem.NamedType.Underlying)
+                       // TODO(adonovan): make order deterministic.
+                       for name, method := range mem.Methods {
+                               fmt.Fprintf(&b, "       method %s %s\n", name, method.Signature)
+                       }
+
+               case *Global:
+                       fmt.Fprintf(&b, " var   %-*s %s\n", maxname, name, mem.Type())
+
+               }
+       }
+       return b.String()
+}
+
+func commaOk(x bool) string {
+       if x {
+               return ",ok"
+       }
+       return ""
+}
diff --git a/src/pkg/exp/ssa/sanity.go b/src/pkg/exp/ssa/sanity.go
new file mode 100644 (file)
index 0000000..bbb30cf
--- /dev/null
@@ -0,0 +1,263 @@
+package ssa
+
+// An optional pass for sanity checking invariants of the SSA representation.
+// Currently it checks CFG invariants but little at the instruction level.
+
+import (
+       "bytes"
+       "fmt"
+       "io"
+       "os"
+)
+
+type sanity struct {
+       reporter io.Writer
+       fn       *Function
+       block    *BasicBlock
+       insane   bool
+}
+
+// SanityCheck performs integrity checking of the SSA representation
+// of the function fn and returns true if it was valid.  Diagnostics
+// are written to reporter if non-nil, os.Stderr otherwise.  Some
+// diagnostics are only warnings and do not imply a negative result.
+//
+// Sanity checking is intended to facilitate the debugging of code
+// transformation passes.
+//
+func SanityCheck(fn *Function, reporter io.Writer) bool {
+       if reporter == nil {
+               reporter = os.Stderr
+       }
+       return (&sanity{reporter: reporter}).checkFunction(fn)
+}
+
+// MustSanityCheck is like SanityCheck but panics instead of returning
+// a negative result.
+//
+func MustSanityCheck(fn *Function, reporter io.Writer) {
+       if !SanityCheck(fn, reporter) {
+               panic("SanityCheck failed")
+       }
+}
+
+// blockNames returns the names of the specified blocks as a
+// human-readable string.
+//
+func blockNames(blocks []*BasicBlock) string {
+       var buf bytes.Buffer
+       for i, b := range blocks {
+               if i > 0 {
+                       io.WriteString(&buf, ", ")
+               }
+               io.WriteString(&buf, b.Name)
+       }
+       return buf.String()
+}
+
+func (s *sanity) diagnostic(prefix, format string, args ...interface{}) {
+       fmt.Fprintf(s.reporter, "%s: function %s", prefix, s.fn.FullName())
+       if s.block != nil {
+               fmt.Fprintf(s.reporter, ", block %s", s.block.Name)
+       }
+       io.WriteString(s.reporter, ": ")
+       fmt.Fprintf(s.reporter, format, args...)
+       io.WriteString(s.reporter, "\n")
+}
+
+func (s *sanity) errorf(format string, args ...interface{}) {
+       s.insane = true
+       s.diagnostic("Error", format, args...)
+}
+
+func (s *sanity) warnf(format string, args ...interface{}) {
+       s.diagnostic("Warning", format, args...)
+}
+
+// findDuplicate returns an arbitrary basic block that appeared more
+// than once in blocks, or nil if all were unique.
+func findDuplicate(blocks []*BasicBlock) *BasicBlock {
+       if len(blocks) < 2 {
+               return nil
+       }
+       if blocks[0] == blocks[1] {
+               return blocks[0]
+       }
+       // Slow path:
+       m := make(map[*BasicBlock]bool)
+       for _, b := range blocks {
+               if m[b] {
+                       return b
+               }
+               m[b] = true
+       }
+       return nil
+}
+
+func (s *sanity) checkInstr(idx int, instr Instruction) {
+       switch instr := instr.(type) {
+       case *If, *Jump, *Ret:
+               s.errorf("control flow instruction not at end of block")
+       case *Phi:
+               if idx == 0 {
+                       // It suffices to apply this check to just the first phi node.
+                       if dup := findDuplicate(s.block.Preds); dup != nil {
+                               s.errorf("phi node in block with duplicate predecessor %s", dup.Name)
+                       }
+               } else {
+                       prev := s.block.Instrs[idx-1]
+                       if _, ok := prev.(*Phi); !ok {
+                               s.errorf("Phi instruction follows a non-Phi: %T", prev)
+                       }
+               }
+               if ne, np := len(instr.Edges), len(s.block.Preds); ne != np {
+                       s.errorf("phi node has %d edges but %d predecessors", ne, np)
+               }
+
+       case *Alloc:
+       case *Call:
+       case *BinOp:
+       case *UnOp:
+       case *MakeClosure:
+       case *MakeChan:
+       case *MakeMap:
+       case *MakeSlice:
+       case *Slice:
+       case *Field:
+       case *FieldAddr:
+       case *IndexAddr:
+       case *Index:
+       case *Select:
+       case *Range:
+       case *TypeAssert:
+       case *Extract:
+       case *Go:
+       case *Defer:
+       case *Send:
+       case *Store:
+       case *MapUpdate:
+       case *Next:
+       case *Lookup:
+       case *Conv:
+       case *ChangeInterface:
+       case *MakeInterface:
+               // TODO(adonovan): implement checks.
+       default:
+               panic(fmt.Sprintf("Unknown instruction type: %T", instr))
+       }
+}
+
+func (s *sanity) checkFinalInstr(idx int, instr Instruction) {
+       switch instr.(type) {
+       case *If:
+               if nsuccs := len(s.block.Succs); nsuccs != 2 {
+                       s.errorf("If-terminated block has %d successors; expected 2", nsuccs)
+                       return
+               }
+               if s.block.Succs[0] == s.block.Succs[1] {
+                       s.errorf("If-instruction has same True, False target blocks: %s", s.block.Succs[0].Name)
+                       return
+               }
+
+       case *Jump:
+               if nsuccs := len(s.block.Succs); nsuccs != 1 {
+                       s.errorf("Jump-terminated block has %d successors; expected 1", nsuccs)
+                       return
+               }
+
+       case *Ret:
+               if nsuccs := len(s.block.Succs); nsuccs != 0 {
+                       s.errorf("Ret-terminated block has %d successors; expected none", nsuccs)
+                       return
+               }
+               // TODO(adonovan): check number and types of results
+
+       default:
+               s.errorf("non-control flow instruction at end of block")
+       }
+}
+
+func (s *sanity) checkBlock(b *BasicBlock, isEntry bool) {
+       s.block = b
+
+       // Check all blocks are reachable.
+       // (The entry block is always implicitly reachable.)
+       if !isEntry && len(b.Preds) == 0 {
+               s.warnf("unreachable block")
+               if b.Instrs == nil {
+                       // Since this block is about to be pruned,
+                       // tolerating transient problems in it
+                       // simplifies other optimisations.
+                       return
+               }
+       }
+
+       // Check predecessor and successor relations are dual.
+       for _, a := range b.Preds {
+               found := false
+               for _, bb := range a.Succs {
+                       if bb == b {
+                               found = true
+                               break
+                       }
+               }
+               if !found {
+                       s.errorf("expected successor edge in predecessor %s; found only: %s", a.Name, blockNames(a.Succs))
+               }
+       }
+       for _, c := range b.Succs {
+               found := false
+               for _, bb := range c.Preds {
+                       if bb == b {
+                               found = true
+                               break
+                       }
+               }
+               if !found {
+                       s.errorf("expected predecessor edge in successor %s; found only: %s", c.Name, blockNames(c.Preds))
+               }
+       }
+
+       // Check each instruction is sane.
+       n := len(b.Instrs)
+       if n == 0 {
+               s.errorf("basic block contains no instructions")
+       }
+       for j, instr := range b.Instrs {
+               if b2 := instr.Block(); b2 == nil {
+                       s.errorf("nil Block() for instruction at index %d", j)
+                       continue
+               } else if b2 != b {
+                       s.errorf("wrong Block() (%s) for instruction at index %d ", b2.Name, j)
+                       continue
+               }
+               if j < n-1 {
+                       s.checkInstr(j, instr)
+               } else {
+                       s.checkFinalInstr(j, instr)
+               }
+       }
+}
+
+func (s *sanity) checkFunction(fn *Function) bool {
+       // TODO(adonovan): check Function invariants:
+       // - check owning Package (if any) contains this function.
+       // - check params match signature
+       // - check locals are all !Heap
+       // - check transient fields are nil
+       // - check block labels are unique (warning)
+       s.fn = fn
+       if fn.Prog == nil {
+               s.errorf("nil Prog")
+       }
+       for i, b := range fn.Blocks {
+               if b == nil {
+                       s.warnf("nil *BasicBlock at f.Blocks[%d]", i)
+                       continue
+               }
+               s.checkBlock(b, i == 0)
+       }
+       s.block = nil
+       s.fn = nil
+       return !s.insane
+}
index 904280ae2015dc101453d8526db3e173355473a7..8e503dc35b8ea7782f1018e7e2fa8dea5b9aca4c 100644 (file)
@@ -222,12 +222,12 @@ type Function struct {
 
        // The following fields are set transiently during building,
        // then cleared.
-       currentBlock *BasicBlock            // where to emit code
-       objects      map[types.Object]Value // addresses of local variables
-       results      []*Alloc               // tuple of named results
-       // syntax    *funcSyntax             // abstract syntax trees for Go source functions
-       // targets   *targets                // linked stack of branch targets
-       // lblocks   map[*ast.Object]*lblock // labelled blocks
+       currentBlock *BasicBlock             // where to emit code
+       objects      map[types.Object]Value  // addresses of local variables
+       results      []*Alloc                // tuple of named results
+       syntax       *funcSyntax             // abstract syntax trees for Go source functions
+       targets      *targets                // linked stack of branch targets
+       lblocks      map[*ast.Object]*lblock // labelled blocks
 }
 
 // An SSA basic block.
@@ -984,18 +984,9 @@ func (v *Capture) Name() string     { return v.Outer.Name() }
 
 func (v *Global) Type() types.Type { return v.Type_ }
 func (v *Global) Name() string     { return v.Name_ }
-func (v *Global) String() string   { return v.Name_ } // placeholder
 
 func (v *Function) Name() string     { return v.Name_ }
 func (v *Function) Type() types.Type { return v.Signature }
-func (v *Function) String() string   { return v.Name_ } // placeholder
-
-// FullName returns v's package-qualified name.
-func (v *Global) FullName() string { return fmt.Sprintf("%s.%s", v.Pkg.ImportPath, v.Name_) }
-
-func (v *Literal) Name() string     { return "Literal" } // placeholder
-func (v *Literal) String() string   { return "Literal" } // placeholder
-func (v *Literal) Type() types.Type { return v.Type_ }   // placeholder
 
 func (v *Parameter) Type() types.Type { return v.Type_ }
 func (v *Parameter) Name() string     { return v.Name_ }
diff --git a/src/pkg/exp/ssa/util.go b/src/pkg/exp/ssa/util.go
new file mode 100644 (file)
index 0000000..0d2ebde
--- /dev/null
@@ -0,0 +1,172 @@
+package ssa
+
+// This file defines a number of miscellaneous utility functions.
+
+import (
+       "fmt"
+       "go/ast"
+       "go/types"
+)
+
+func unreachable() {
+       panic("unreachable")
+}
+
+//// AST utilities
+
+// noparens returns e with any enclosing parentheses stripped.
+func noparens(e ast.Expr) ast.Expr {
+       for {
+               p, ok := e.(*ast.ParenExpr)
+               if !ok {
+                       break
+               }
+               e = p.X
+       }
+       return e
+}
+
+// isBlankIdent returns true iff e is an Ident with name "_".
+// They have no associated types.Object, and thus no type.
+//
+// TODO(gri): consider making typechecker not treat them differently.
+// It's one less thing for clients like us to worry about.
+//
+func isBlankIdent(e ast.Expr) bool {
+       id, ok := e.(*ast.Ident)
+       return ok && id.Name == "_"
+}
+
+//// Type utilities.  Some of these belong in go/types.
+
+// underlyingType returns the underlying type of typ.
+// TODO(gri): this is a copy of go/types.underlying; export that function.
+//
+func underlyingType(typ types.Type) types.Type {
+       if typ, ok := typ.(*types.NamedType); ok {
+               return typ.Underlying // underlying types are never NamedTypes
+       }
+       if typ == nil {
+               panic("underlyingType(nil)")
+       }
+       return typ
+}
+
+// isPointer returns true for types whose underlying type is a pointer.
+func isPointer(typ types.Type) bool {
+       if nt, ok := typ.(*types.NamedType); ok {
+               typ = nt.Underlying
+       }
+       _, ok := typ.(*types.Pointer)
+       return ok
+}
+
+// pointer(typ) returns the type that is a pointer to typ.
+func pointer(typ types.Type) *types.Pointer {
+       return &types.Pointer{Base: typ}
+}
+
+// indirect(typ) assumes that typ is a pointer type,
+// or named alias thereof, and returns its base type.
+// Panic ensures if it is not a pointer.
+//
+func indirectType(ptr types.Type) types.Type {
+       if v, ok := underlyingType(ptr).(*types.Pointer); ok {
+               return v.Base
+       }
+       // When debugging it is convenient to comment out this line
+       // and let it continue to print the (illegal) SSA form.
+       panic("indirect() of non-pointer type: " + ptr.String())
+       return nil
+}
+
+// deref returns a pointer's base type; otherwise it returns typ.
+func deref(typ types.Type) types.Type {
+       if typ, ok := underlyingType(typ).(*types.Pointer); ok {
+               return typ.Base
+       }
+       return typ
+}
+
+// methodIndex returns the method (and its index) named id within the
+// method table methods of named or interface type typ.  If not found,
+// panic ensues.
+//
+func methodIndex(typ types.Type, methods []*types.Method, id Id) (i int, m *types.Method) {
+       for i, m = range methods {
+               if IdFromQualifiedName(m.QualifiedName) == id {
+                       return
+               }
+       }
+       panic(fmt.Sprint("method not found: ", id, " in interface ", typ))
+}
+
+// objKind returns the syntactic category of the named entity denoted by obj.
+func objKind(obj types.Object) ast.ObjKind {
+       switch obj.(type) {
+       case *types.Package:
+               return ast.Pkg
+       case *types.TypeName:
+               return ast.Typ
+       case *types.Const:
+               return ast.Con
+       case *types.Var:
+               return ast.Var
+       case *types.Func:
+               return ast.Fun
+       }
+       panic(fmt.Sprintf("unexpected Object type: %T", obj))
+}
+
+// DefaultType returns the default "typed" type for an "untyped" type;
+// it returns the incoming type for all other types. If there is no
+// corresponding untyped type, the result is types.Typ[types.Invalid].
+//
+// Exported to exp/ssa/interp.
+//
+// TODO(gri): this is a copy of go/types.defaultType; export that function.
+//
+func DefaultType(typ types.Type) types.Type {
+       if t, ok := typ.(*types.Basic); ok {
+               k := types.Invalid
+               switch t.Kind {
+               // case UntypedNil:
+               //      There is no default type for nil. For a good error message,
+               //      catch this case before calling this function.
+               case types.UntypedBool:
+                       k = types.Bool
+               case types.UntypedInt:
+                       k = types.Int
+               case types.UntypedRune:
+                       k = types.Rune
+               case types.UntypedFloat:
+                       k = types.Float64
+               case types.UntypedComplex:
+                       k = types.Complex128
+               case types.UntypedString:
+                       k = types.String
+               }
+               typ = types.Typ[k]
+       }
+       return typ
+}
+
+// makeId returns the Id (name, pkg) if the name is exported or
+// (name, nil) otherwise.
+//
+func makeId(name string, pkg *types.Package) (id Id) {
+       id.Name = name
+       if !ast.IsExported(name) {
+               id.Pkg = pkg
+       }
+       return
+}
+
+// IdFromQualifiedName returns the Id (qn.Name, qn.Pkg) if qn is an
+// exported name or (qn.Name, nil) otherwise.
+//
+// Exported to exp/ssa/interp.
+//
+func IdFromQualifiedName(qn types.QualifiedName) Id {
+       return makeId(qn.Name, qn.Pkg)
+}