--- /dev/null
+package ssa
+
+// Simple block optimisations to simplify the control flow graph.
+
+// TODO(adonovan): instead of creating several "unreachable" blocks
+// per function in the Builder, reuse a single one (e.g. at Blocks[1])
+// to reduce garbage.
+//
+// TODO(adonovan): in the absence of multiway branch instructions,
+// each BasicBlock has 0, 1, or 2 successors. We should preallocate
+// the backing array for the Succs slice inline in BasicBlock.
+
+import (
+ "fmt"
+ "os"
+)
+
+// If true, perform sanity checking and show progress at each
+// successive iteration of optimizeBlocks. Very verbose.
+const debugBlockOpt = false
+
+func hasPhi(b *BasicBlock) bool {
+ _, ok := b.Instrs[0].(*Phi)
+ return ok
+}
+
+// prune attempts to prune block b if it is unreachable (i.e. has no
+// predecessors other than itself), disconnecting it from the CFG.
+// The result is true if the optimisation was applied. i is the block
+// index within the function.
+//
+func prune(f *Function, i int, b *BasicBlock) bool {
+ if i == 0 {
+ return false // don't prune entry block
+ }
+ if len(b.Preds) == 0 || len(b.Preds) == 1 && b.Preds[0] == b {
+ // Disconnect it from its successors.
+ for _, c := range b.Succs {
+ c.removePred(b)
+ }
+ if debugBlockOpt {
+ fmt.Fprintln(os.Stderr, "prune", b.Name)
+ }
+
+ // Delete b.
+ f.Blocks[i] = nil
+ return true
+ }
+ return false
+}
+
+// jumpThreading attempts to apply simple jump-threading to block b,
+// in which a->b->c become a->c if b is just a Jump.
+// The result is true if the optimisation was applied.
+// i is the block index within the function.
+//
+func jumpThreading(f *Function, i int, b *BasicBlock) bool {
+ if i == 0 {
+ return false // don't apply to entry block
+ }
+ if b.Instrs == nil {
+ fmt.Println("empty block ", b.Name)
+ return false
+ }
+ if _, ok := b.Instrs[0].(*Jump); !ok {
+ return false // not just a jump
+ }
+ c := b.Succs[0]
+ if c == b {
+ return false // don't apply to degenerate jump-to-self.
+ }
+ if hasPhi(c) {
+ return false // not sound without more effort
+ }
+ for j, a := range b.Preds {
+ a.replaceSucc(b, c)
+
+ // If a now has two edges to c, replace its degenerate If by Jump.
+ if len(a.Succs) == 2 && a.Succs[0] == c && a.Succs[1] == c {
+ jump := new(Jump)
+ jump.SetBlock(a)
+ a.Instrs[len(a.Instrs)-1] = jump
+ a.Succs = a.Succs[:1]
+ c.removePred(b)
+ } else {
+ if j == 0 {
+ c.replacePred(b, a)
+ } else {
+ c.Preds = append(c.Preds, a)
+ }
+ }
+
+ if debugBlockOpt {
+ fmt.Fprintln(os.Stderr, "jumpThreading", a.Name, b.Name, c.Name)
+ }
+ }
+ f.Blocks[i] = nil
+ return true
+}
+
+// fuseBlocks attempts to apply the block fusion optimisation to block
+// a, in which a->b becomes ab if len(a.Succs)==len(b.Preds)==1.
+// The result is true if the optimisation was applied.
+//
+func fuseBlocks(f *Function, a *BasicBlock) bool {
+ if len(a.Succs) != 1 {
+ return false
+ }
+ b := a.Succs[0]
+ if len(b.Preds) != 1 {
+ return false
+ }
+ // Eliminate jump at end of A, then copy all of B across.
+ a.Instrs = append(a.Instrs[:len(a.Instrs)-1], b.Instrs...)
+ for _, instr := range b.Instrs {
+ instr.SetBlock(a)
+ }
+
+ // A inherits B's successors
+ a.Succs = b.Succs
+
+ // Fix up Preds links of all successors of B.
+ for _, c := range b.Succs {
+ c.replacePred(b, a)
+ }
+
+ if debugBlockOpt {
+ fmt.Fprintln(os.Stderr, "fuseBlocks", a.Name, b.Name)
+ }
+
+ // Make b unreachable. Subsequent pruning will reclaim it.
+ b.Preds = nil
+ return true
+}
+
+// optimizeBlocks() performs some simple block optimizations on a
+// completed function: dead block elimination, block fusion, jump
+// threading.
+//
+func optimizeBlocks(f *Function) {
+ // Loop until no further progress.
+ changed := true
+ for changed {
+ changed = false
+
+ if debugBlockOpt {
+ f.DumpTo(os.Stderr)
+ MustSanityCheck(f, nil)
+ }
+
+ for i, b := range f.Blocks {
+ // f.Blocks will temporarily contain nils to indicate
+ // deleted blocks; we remove them at the end.
+ if b == nil {
+ continue
+ }
+
+ // Prune unreachable blocks (including all empty blocks).
+ if prune(f, i, b) {
+ changed = true
+ continue // (b was pruned)
+ }
+
+ // Fuse blocks. b->c becomes bc.
+ if fuseBlocks(f, b) {
+ changed = true
+ }
+
+ // a->b->c becomes a->c if b contains only a Jump.
+ if jumpThreading(f, i, b) {
+ changed = true
+ continue // (b was disconnected)
+ }
+ }
+ }
+
+ // Eliminate nils from Blocks.
+ j := 0
+ for _, b := range f.Blocks {
+ if b != nil {
+ f.Blocks[j] = b
+ j++
+ }
+ }
+ // Nil out b.Blocks[j:] to aid GC.
+ for i := j; i < len(f.Blocks); i++ {
+ f.Blocks[i] = nil
+ }
+ f.Blocks = f.Blocks[:j]
+}
--- /dev/null
+package ssa
+
+// This file implements the Function and BasicBlock types.
+
+import (
+ "fmt"
+ "go/ast"
+ "go/types"
+ "io"
+ "os"
+)
+
+// addEdge adds a control-flow graph edge from from to to.
+func addEdge(from, to *BasicBlock) {
+ from.Succs = append(from.Succs, to)
+ to.Preds = append(to.Preds, from)
+}
+
+// emit appends an instruction to the current basic block.
+// If the instruction defines a Value, it is returned.
+//
+func (b *BasicBlock) emit(i Instruction) Value {
+ i.SetBlock(b)
+ b.Instrs = append(b.Instrs, i)
+ v, _ := i.(Value)
+ return v
+}
+
+// phis returns the prefix of b.Instrs containing all the block's φ-nodes.
+func (b *BasicBlock) phis() []Instruction {
+ for i, instr := range b.Instrs {
+ if _, ok := instr.(*Phi); !ok {
+ return b.Instrs[:i]
+ }
+ }
+ return nil // unreachable in well-formed blocks
+}
+
+// replacePred replaces all occurrences of p in b's predecessor list with q.
+// Ordinarily there should be at most one.
+//
+func (b *BasicBlock) replacePred(p, q *BasicBlock) {
+ for i, pred := range b.Preds {
+ if pred == p {
+ b.Preds[i] = q
+ }
+ }
+}
+
+// replaceSucc replaces all occurrences of p in b's successor list with q.
+// Ordinarily there should be at most one.
+//
+func (b *BasicBlock) replaceSucc(p, q *BasicBlock) {
+ for i, succ := range b.Succs {
+ if succ == p {
+ b.Succs[i] = q
+ }
+ }
+}
+
+// removePred removes all occurrences of p in b's
+// predecessor list and φ-nodes.
+// Ordinarily there should be at most one.
+//
+func (b *BasicBlock) removePred(p *BasicBlock) {
+ phis := b.phis()
+
+ // We must preserve edge order for φ-nodes.
+ j := 0
+ for i, pred := range b.Preds {
+ if pred != p {
+ b.Preds[j] = b.Preds[i]
+ // Strike out φ-edge too.
+ for _, instr := range phis {
+ phi := instr.(*Phi)
+ phi.Edges[j] = phi.Edges[i]
+ }
+ j++
+ }
+ }
+ // Nil out b.Preds[j:] and φ-edges[j:] to aid GC.
+ for i := j; i < len(b.Preds); i++ {
+ b.Preds[i] = nil
+ for _, instr := range phis {
+ instr.(*Phi).Edges[i] = nil
+ }
+ }
+ b.Preds = b.Preds[:j]
+ for _, instr := range phis {
+ phi := instr.(*Phi)
+ phi.Edges = phi.Edges[:j]
+ }
+}
+
+// Destinations associated with unlabelled for/switch/select stmts.
+// We push/pop one of these as we enter/leave each construct and for
+// each BranchStmt we scan for the innermost target of the right type.
+//
+type targets struct {
+ tail *targets // rest of stack
+ _break *BasicBlock
+ _continue *BasicBlock
+ _fallthrough *BasicBlock
+}
+
+// Destinations associated with a labelled block.
+// We populate these as labels are encountered in forward gotos or
+// labelled statements.
+//
+type lblock struct {
+ _goto *BasicBlock
+ _break *BasicBlock
+ _continue *BasicBlock
+}
+
+// funcSyntax holds the syntax tree for the function declaration and body.
+type funcSyntax struct {
+ recvField *ast.FieldList
+ paramFields *ast.FieldList
+ resultFields *ast.FieldList
+ body *ast.BlockStmt
+}
+
+// labelledBlock returns the branch target associated with the
+// specified label, creating it if needed.
+//
+func (f *Function) labelledBlock(label *ast.Ident) *lblock {
+ lb := f.lblocks[label.Obj]
+ if lb == nil {
+ lb = &lblock{_goto: f.newBasicBlock("label." + label.Name)}
+ f.lblocks[label.Obj] = lb
+ }
+ return lb
+}
+
+// addParam adds a (non-escaping) parameter to f.Params of the
+// specified name and type.
+//
+func (f *Function) addParam(name string, typ types.Type) *Parameter {
+ v := &Parameter{
+ Name_: name,
+ Type_: pointer(typ), // address of param
+ }
+ f.Params = append(f.Params, v)
+ return v
+}
+
+func (f *Function) addObjParam(obj types.Object) *Parameter {
+ p := f.addParam(obj.GetName(), obj.GetType())
+ f.objects[obj] = p
+ return p
+}
+
+// start initializes the function prior to generating SSA code for its body.
+// Precondition: f.Type() already set.
+//
+// If f.syntax != nil, f is a Go source function and idents must be a
+// mapping from syntactic identifiers to their canonical type objects;
+// Otherwise, idents is ignored and the usual set-up for Go source
+// functions is skipped.
+//
+func (f *Function) start(mode BuilderMode, idents map[*ast.Ident]types.Object) {
+ if mode&LogSource != 0 {
+ fmt.Fprintf(os.Stderr, "build function %s @ %s\n", f.FullName(), f.Prog.Files.Position(f.Pos))
+ }
+ f.currentBlock = f.newBasicBlock("entry")
+ f.objects = make(map[types.Object]Value) // needed for some synthetics, e.g. init
+ if f.syntax == nil {
+ return // synthetic function; no syntax tree
+ }
+ f.lblocks = make(map[*ast.Object]*lblock)
+
+ // Receiver (at most one inner iteration).
+ if f.syntax.recvField != nil {
+ for _, field := range f.syntax.recvField.List {
+ for _, n := range field.Names {
+ f.addObjParam(idents[n])
+ }
+ if field.Names == nil {
+ f.addParam(f.Signature.Recv.Name, f.Signature.Recv.Type)
+ }
+ }
+ }
+
+ // Parameters.
+ if f.syntax.paramFields != nil {
+ for _, field := range f.syntax.paramFields.List {
+ for _, n := range field.Names {
+ f.addObjParam(idents[n])
+ }
+ }
+ }
+
+ // Results.
+ if f.syntax.resultFields != nil {
+ for _, field := range f.syntax.resultFields.List {
+ // Implicit "var" decl of locals for named results.
+ for _, n := range field.Names {
+ f.results = append(f.results, f.addNamedLocal(idents[n]))
+ }
+ }
+ }
+}
+
+// finish() finalizes the function after SSA code generation of its body.
+func (f *Function) finish(mode BuilderMode) {
+ f.objects = nil
+ f.results = nil
+ f.currentBlock = nil
+ f.lblocks = nil
+ f.syntax = nil
+
+ // Remove any f.Locals that are now heap-allocated.
+ j := 0
+ for _, l := range f.Locals {
+ if !l.Heap {
+ f.Locals[j] = l
+ j++
+ }
+ }
+ // Nil out f.Locals[j:] to aid GC.
+ for i := j; i < len(f.Locals); i++ {
+ f.Locals[i] = nil
+ }
+ f.Locals = f.Locals[:j]
+
+ // Ensure all value-defining Instructions have register names.
+ // (Non-Instruction Values are named at construction.)
+ tmp := 0
+ for _, b := range f.Blocks {
+ for _, instr := range b.Instrs {
+ switch instr := instr.(type) {
+ case *Alloc:
+ // Local Allocs may already be named.
+ if instr.Name_ == "" {
+ instr.Name_ = fmt.Sprintf("t%d", tmp)
+ tmp++
+ }
+ case Value:
+ instr.(interface {
+ setNum(int)
+ }).setNum(tmp)
+ tmp++
+ }
+ }
+ }
+ optimizeBlocks(f)
+
+ if mode&LogFunctions != 0 {
+ f.DumpTo(os.Stderr)
+ }
+ if mode&SanityCheckFunctions != 0 {
+ MustSanityCheck(f, nil)
+ }
+ if mode&LogSource != 0 {
+ fmt.Fprintf(os.Stderr, "build function %s done\n", f.FullName())
+ }
+}
+
+// addNamedLocal creates a local variable, adds it to function f and
+// returns it. Its name and type are taken from obj. Subsequent
+// calls to f.lookup(obj) will return the same local.
+//
+// Precondition: f.syntax != nil (i.e. a Go source function).
+//
+func (f *Function) addNamedLocal(obj types.Object) *Alloc {
+ l := f.addLocal(obj.GetType())
+ l.Name_ = obj.GetName()
+ f.objects[obj] = l
+ return l
+}
+
+// addLocal creates an anonymous local variable of type typ, adds it
+// to function f and returns it.
+//
+func (f *Function) addLocal(typ types.Type) *Alloc {
+ v := &Alloc{Type_: pointer(typ)}
+ f.Locals = append(f.Locals, v)
+ f.emit(v)
+ return v
+}
+
+// lookup returns the address of the named variable identified by obj
+// that is local to function f or one of its enclosing functions.
+// If escaping, the reference comes from a potentially escaping pointer
+// expression and the referent must be heap-allocated.
+//
+func (f *Function) lookup(obj types.Object, escaping bool) Value {
+ if v, ok := f.objects[obj]; ok {
+ if escaping {
+ switch v := v.(type) {
+ case *Capture:
+ // TODO(adonovan): fix: we must support this case.
+ // Requires copying to a 'new' Alloc.
+ fmt.Fprintln(os.Stderr, "Error: escaping reference to Capture")
+ case *Parameter:
+ v.Heap = true
+ case *Alloc:
+ v.Heap = true
+ default:
+ panic(fmt.Sprintf("Unexpected Function.objects kind: %T", v))
+ }
+ }
+ return v // function-local var (address)
+ }
+
+ // Definition must be in an enclosing function;
+ // plumb it through intervening closures.
+ if f.Enclosing == nil {
+ panic("no Value for type.Object " + obj.GetName())
+ }
+ v := &Capture{f.Enclosing.lookup(obj, true)} // escaping
+ f.objects[obj] = v
+ f.FreeVars = append(f.FreeVars, v)
+ return v
+}
+
+// emit emits the specified instruction to function f, updating the
+// control-flow graph if required.
+//
+func (f *Function) emit(instr Instruction) Value {
+ return f.currentBlock.emit(instr)
+}
+
+// DumpTo prints to w a human readable "disassembly" of the SSA code of
+// all basic blocks of function f.
+//
+func (f *Function) DumpTo(w io.Writer) {
+ fmt.Fprintf(w, "# Name: %s\n", f.FullName())
+ fmt.Fprintf(w, "# Declared at %s\n", f.Prog.Files.Position(f.Pos))
+ fmt.Fprintf(w, "# Type: %s\n", f.Type())
+
+ if f.Enclosing != nil {
+ fmt.Fprintf(w, "# Parent: %s\n", f.Enclosing.Name())
+ }
+
+ if f.FreeVars != nil {
+ io.WriteString(w, "# Free variables:\n")
+ for i, fv := range f.FreeVars {
+ fmt.Fprintf(w, "# % 3d:\t%s %s\n", i, fv.Name(), fv.Type())
+ }
+ }
+
+ params := f.Params
+ if f.Signature.Recv != nil {
+ fmt.Fprintf(w, "func (%s) %s(", params[0].Name(), f.Name())
+ params = params[1:]
+ } else {
+ fmt.Fprintf(w, "func %s(", f.Name())
+ }
+ for i, v := range params {
+ if i > 0 {
+ io.WriteString(w, ", ")
+ }
+ io.WriteString(w, v.Name())
+ }
+ io.WriteString(w, "):\n")
+
+ for _, b := range f.Blocks {
+ if b == nil {
+ // Corrupt CFG.
+ fmt.Fprintf(w, ".nil:\n")
+ continue
+ }
+ fmt.Fprintf(w, ".%s:\t\t\t\t\t\t\t P:%d S:%d\n", b.Name, len(b.Preds), len(b.Succs))
+ if false { // CFG debugging
+ fmt.Fprintf(w, "\t# CFG: %s --> %s --> %s\n", blockNames(b.Preds), b.Name, blockNames(b.Succs))
+ }
+ for _, instr := range b.Instrs {
+ io.WriteString(w, "\t")
+ if v, ok := instr.(Value); ok {
+ l := 80 // for old time's sake.
+ // Left-align the instruction.
+ if name := v.Name(); name != "" {
+ n, _ := fmt.Fprintf(w, "%s = ", name)
+ l -= n
+ }
+ n, _ := io.WriteString(w, instr.String())
+ l -= n
+ // Right-align the type.
+ if t := v.Type(); t != nil {
+ fmt.Fprintf(w, "%*s", l-9, t)
+ }
+ } else {
+ io.WriteString(w, instr.String())
+ }
+ io.WriteString(w, "\n")
+ }
+ }
+ fmt.Fprintf(w, "\n")
+}
+
+// newBasicBlock adds to f a new basic block with a unique name and
+// returns it. It does not automatically become the current block for
+// subsequent calls to emit.
+//
+func (f *Function) newBasicBlock(name string) *BasicBlock {
+ b := &BasicBlock{
+ Name: fmt.Sprintf("%d.%s", len(f.Blocks), name),
+ Func: f,
+ }
+ f.Blocks = append(f.Blocks, b)
+ return b
+}