diff --git a/compiler/compiler.go b/compiler/compiler.go index 1da298b5..d04f883a 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -68,8 +68,9 @@ type Compiler struct { type Frame struct { fn *ir.Function - locals map[ssa.Value]llvm.Value // local variables - blocks map[*ssa.BasicBlock]llvm.BasicBlock + locals map[ssa.Value]llvm.Value // local variables + blockEntries map[*ssa.BasicBlock]llvm.BasicBlock // a *ssa.BasicBlock may be split up + blockExits map[*ssa.BasicBlock]llvm.BasicBlock // these are the exit blocks currentBlock *ssa.BasicBlock phis []Phi blocking bool @@ -917,10 +918,11 @@ func (c *Compiler) wrapInterfaceInvoke(f *ir.Function) (llvm.Value, error) { func (c *Compiler) parseFuncDecl(f *ir.Function) (*Frame, error) { frame := &Frame{ - fn: f, - locals: make(map[ssa.Value]llvm.Value), - blocks: make(map[*ssa.BasicBlock]llvm.BasicBlock), - blocking: c.ir.IsBlocking(f), + fn: f, + locals: make(map[ssa.Value]llvm.Value), + blockEntries: make(map[*ssa.BasicBlock]llvm.BasicBlock), + blockExits: make(map[*ssa.BasicBlock]llvm.BasicBlock), + blocking: c.ir.IsBlocking(f), } var retType llvm.Type @@ -1346,13 +1348,14 @@ func (c *Compiler) parseFunc(frame *Frame) error { // Pre-create all basic blocks in the function. for _, block := range frame.fn.DomPreorder() { llvmBlock := c.ctx.AddBasicBlock(frame.fn.LLVMFn, block.Comment) - frame.blocks[block] = llvmBlock + frame.blockEntries[block] = llvmBlock + frame.blockExits[block] = llvmBlock } if frame.blocking { frame.cleanupBlock = c.ctx.AddBasicBlock(frame.fn.LLVMFn, "task.cleanup") frame.suspendBlock = c.ctx.AddBasicBlock(frame.fn.LLVMFn, "task.suspend") } - entryBlock := frame.blocks[frame.fn.Blocks[0]] + entryBlock := frame.blockEntries[frame.fn.Blocks[0]] c.builder.SetInsertPointAtEnd(entryBlock) // Load function parameters @@ -1479,9 +1482,9 @@ func (c *Compiler) parseFunc(frame *Frame) error { // Fill blocks with instructions. for _, block := range frame.fn.DomPreorder() { if c.DumpSSA { - fmt.Printf("%s:\n", block.Comment) + fmt.Printf("%d: %s:\n", block.Index, block.Comment) } - c.builder.SetInsertPointAtEnd(frame.blocks[block]) + c.builder.SetInsertPointAtEnd(frame.blockEntries[block]) frame.currentBlock = block for _, instr := range block.Instrs { if _, ok := instr.(*ssa.DebugRef); ok { @@ -1512,7 +1515,7 @@ func (c *Compiler) parseFunc(frame *Frame) error { if err != nil { return err } - llvmBlock := frame.blocks[block.Preds[i]] + llvmBlock := frame.blockExits[block.Preds[i]] phi.llvm.AddIncoming([]llvm.Value{llvmVal}, []llvm.BasicBlock{llvmBlock}) } } @@ -1651,12 +1654,12 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { return err } block := instr.Block() - blockThen := frame.blocks[block.Succs[0]] - blockElse := frame.blocks[block.Succs[1]] + blockThen := frame.blockEntries[block.Succs[0]] + blockElse := frame.blockEntries[block.Succs[1]] c.builder.CreateCondBr(cond, blockThen, blockElse) return nil case *ssa.Jump: - blockJump := frame.blocks[instr.Block().Succs[0]] + blockJump := frame.blockEntries[instr.Block().Succs[0]] c.builder.CreateBr(blockJump) return nil case *ssa.MapUpdate: @@ -2793,7 +2796,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { prevBlock := c.builder.GetInsertBlock() okBlock := c.ctx.AddBasicBlock(frame.fn.LLVMFn, "typeassert.ok") nextBlock := c.ctx.AddBasicBlock(frame.fn.LLVMFn, "typeassert.next") - frame.blocks[frame.currentBlock] = nextBlock // adjust outgoing block for phi nodes + frame.blockExits[frame.currentBlock] = nextBlock // adjust outgoing block for phi nodes c.builder.CreateCondBr(commaOk, okBlock, nextBlock) // Retrieve the value from the interface if the type assert was diff --git a/testdata/interface.go b/testdata/interface.go index 63e3d5f2..e931934a 100644 --- a/testdata/interface.go +++ b/testdata/interface.go @@ -20,6 +20,8 @@ func main() { println("Stringer.String():", s.String()) var itf interface{} = s println("Stringer.(*Thing).String():", itf.(Stringer).String()) + + println("nested switch:", nestedSwitch('v', 3)) } func printItf(val interface{}) { @@ -46,6 +48,17 @@ func printItf(val interface{}) { } } +func nestedSwitch(verb rune, arg interface{}) bool { + switch verb { + case 'v', 's': + switch arg.(type) { + case int: + return true + } + } + return false +} + type Thing struct { name string } diff --git a/testdata/interface.txt b/testdata/interface.txt index c9bce6be..3f398974 100644 --- a/testdata/interface.txt +++ b/testdata/interface.txt @@ -16,3 +16,4 @@ is Tuple: 0 8 16 24 SmallPair.Print: 3 5 Stringer.String(): foo Stringer.(*Thing).String(): foo +nested switch: true