diff --git a/compiler/compiler.go b/compiler/compiler.go index 0cd9fe91..818abe93 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -62,26 +62,27 @@ type Compiler struct { coroEndFunc llvm.Value coroFreeFunc llvm.Value initFuncs []llvm.Value - deferFuncs []*ir.Function - deferInvokeFuncs []InvokeDeferFunction - ctxDeferFuncs []ContextDeferFunction interfaceInvokeWrappers []interfaceInvokeWrapper ir *ir.Program } type Frame struct { - fn *ir.Function - locals map[ssa.Value]llvm.Value // local variables - blockEntries map[*ssa.BasicBlock]llvm.BasicBlock // a *ssa.BasicBlock may be split up - blockExits map[*ssa.BasicBlock]llvm.BasicBlock // these are the exit blocks - currentBlock *ssa.BasicBlock - phis []Phi - blocking bool - taskHandle llvm.Value - cleanupBlock llvm.BasicBlock - suspendBlock llvm.BasicBlock - deferPtr llvm.Value - difunc llvm.Metadata + fn *ir.Function + locals map[ssa.Value]llvm.Value // local variables + blockEntries map[*ssa.BasicBlock]llvm.BasicBlock // a *ssa.BasicBlock may be split up + blockExits map[*ssa.BasicBlock]llvm.BasicBlock // these are the exit blocks + currentBlock *ssa.BasicBlock + phis []Phi + blocking bool + taskHandle llvm.Value + cleanupBlock llvm.BasicBlock + suspendBlock llvm.BasicBlock + deferPtr llvm.Value + difunc llvm.Metadata + allDeferFuncs []interface{} + deferFuncs map[*ir.Function]int + deferInvokeFuncs map[string]int + deferClosureFuncs map[*ir.Function]int } type Phi struct { @@ -342,12 +343,6 @@ func (c *Compiler) Compile(mainPath string) error { } } - // Create thunks for deferred functions. - err = c.finalizeDefers() - if err != nil { - return err - } - // Define the already declared functions that wrap methods for use in // interfaces. for _, state := range c.interfaceInvokeWrappers { diff --git a/compiler/defer.go b/compiler/defer.go index 44fe9045..db8f5769 100644 --- a/compiler/defer.go +++ b/compiler/defer.go @@ -1,30 +1,35 @@ package compiler -// This file implements the 'defer' keyword in Go. See src/runtime/defer.go for -// details. +// This file implements the 'defer' keyword in Go. +// Defer statements are implemented by transforming the function in the +// following way: +// * Creating an alloca in the entry block that contains a pointer (initially +// null) to the linked list of defer frames. +// * Every time a defer statement is executed, a new defer frame is created +// using alloca with a pointer to the previous defer frame, and the head +// pointer in the entry block is replaced with a pointer to this defer +// frame. +// * On return, runtime.rundefers is called which calls all deferred functions +// from the head of the linked list until it has gone through all defer +// frames. import ( "go/types" "github.com/aykevl/go-llvm" + "github.com/aykevl/tinygo/ir" "golang.org/x/tools/go/ssa" ) -// A thunk for a defer that defers calling a function pointer with context. -type ContextDeferFunction struct { - fn llvm.Value - deferStruct []llvm.Type - signature *types.Signature -} - -// A thunk for a defer that defers calling an interface method. -type InvokeDeferFunction struct { - method *types.Func - valueTypes []llvm.Type -} - -// deferInitFunc sets up this function for future deferred calls. +// deferInitFunc sets up this function for future deferred calls. It must be +// called from within the entry block when this function contains deferred +// calls. func (c *Compiler) deferInitFunc(frame *Frame) { + // Some setup. + frame.deferFuncs = make(map[*ir.Function]int) + frame.deferInvokeFuncs = make(map[string]int) + frame.deferClosureFuncs = make(map[*ir.Function]int) + // Create defer list pointer. deferType := llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0) frame.deferPtr = c.builder.CreateAlloca(deferType, "deferPtr") @@ -38,57 +43,50 @@ func (c *Compiler) emitDefer(frame *Frame, instr *ssa.Defer) error { // make a linked list. next := c.builder.CreateLoad(frame.deferPtr, "defer.next") - deferFuncType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{next.Type()}, false) - var values []llvm.Value - var valueTypes []llvm.Type + valueTypes := []llvm.Type{c.uintptrType, next.Type()} if instr.Call.IsInvoke() { - // Function call on an interface. - fnPtr, args, err := c.getInvokeCall(frame, &instr.Call) + // Method call on an interface. + + // Get callback type number. + methodName := instr.Call.Method.FullName() + if _, ok := frame.deferInvokeFuncs[methodName]; !ok { + frame.deferInvokeFuncs[methodName] = len(frame.allDeferFuncs) + frame.allDeferFuncs = append(frame.allDeferFuncs, &instr.Call) + } + callback := llvm.ConstInt(c.uintptrType, uint64(frame.deferInvokeFuncs[methodName]), false) + + // Collect all values to be put in the struct (starting with + // runtime._defer fields, followed by the call parameters). + itf, err := c.parseExpr(frame, instr.Call.Value) // interface if err != nil { return err } - - valueTypes = []llvm.Type{llvm.PointerType(deferFuncType, 0), next.Type(), fnPtr.Type()} - for _, param := range args { - valueTypes = append(valueTypes, param.Type()) - } - - // Create a thunk. - deferName := instr.Call.Method.FullName() + "$defer" - callback := c.mod.NamedFunction(deferName) - if callback.IsNil() { - // Not found, have to add it. - callback = llvm.AddFunction(c.mod, deferName, deferFuncType) - thunk := InvokeDeferFunction{ - method: instr.Call.Method, - valueTypes: valueTypes, + receiverValue := c.builder.CreateExtractValue(itf, 1, "invoke.func.receiver") + values = []llvm.Value{callback, next, receiverValue} + valueTypes = append(valueTypes, c.i8ptrType) + for _, arg := range instr.Call.Args { + val, err := c.parseExpr(frame, arg) + if err != nil { + return err } - c.deferInvokeFuncs = append(c.deferInvokeFuncs, thunk) + values = append(values, val) + valueTypes = append(valueTypes, val.Type()) } - // Collect all values to be put in the struct (starting with - // runtime._defer fields, followed by the function pointer to be - // called). - values = append([]llvm.Value{callback, next, fnPtr}, args...) - } else if callee, ok := instr.Call.Value.(*ssa.Function); ok { // Regular function call. fn := c.ir.GetFunction(callee) - // Try to find the wrapper $defer function. - deferName := fn.LinkName() + "$defer" - callback := c.mod.NamedFunction(deferName) - if callback.IsNil() { - // Not found, have to add it. - callback = llvm.AddFunction(c.mod, deferName, deferFuncType) - c.deferFuncs = append(c.deferFuncs, fn) + if _, ok := frame.deferFuncs[fn]; !ok { + frame.deferFuncs[fn] = len(frame.allDeferFuncs) + frame.allDeferFuncs = append(frame.allDeferFuncs, fn) } + callback := llvm.ConstInt(c.uintptrType, uint64(frame.deferFuncs[fn]), false) // Collect all values to be put in the struct (starting with // runtime._defer fields). values = []llvm.Value{callback, next} - valueTypes = []llvm.Type{callback.Type(), next.Type()} for _, param := range instr.Call.Args { llvmParam, err := c.parseExpr(frame, param) if err != nil { @@ -100,19 +98,29 @@ func (c *Compiler) emitDefer(frame *Frame, instr *ssa.Defer) error { } else if makeClosure, ok := instr.Call.Value.(*ssa.MakeClosure); ok { // Immediately applied function literal with free variables. + + // Extract the context from the closure. We won't need the function + // pointer. + // TODO: ignore this closure entirely and put pointers to the free + // variables directly in the defer struct, avoiding a memory allocation. closure, err := c.parseExpr(frame, instr.Call.Value) if err != nil { return err } + context := c.builder.CreateExtractValue(closure, 0, "") - // Hopefully, LLVM will merge equivalent functions. - deferName := frame.fn.LinkName() + "$fpdefer" - callback := llvm.AddFunction(c.mod, deferName, deferFuncType) + // Get the callback number. + fn := c.ir.GetFunction(makeClosure.Fn.(*ssa.Function)) + if _, ok := frame.deferClosureFuncs[fn]; !ok { + frame.deferClosureFuncs[fn] = len(frame.allDeferFuncs) + frame.allDeferFuncs = append(frame.allDeferFuncs, makeClosure) + } + callback := llvm.ConstInt(c.uintptrType, uint64(frame.deferClosureFuncs[fn]), false) // Collect all values to be put in the struct (starting with - // runtime._defer fields, followed by the closure). - values = []llvm.Value{callback, next, closure} - valueTypes = []llvm.Type{callback.Type(), next.Type(), closure.Type()} + // runtime._defer fields, followed by all parameters including the + // context pointer). + values = []llvm.Value{callback, next} for _, param := range instr.Call.Args { llvmParam, err := c.parseExpr(frame, param) if err != nil { @@ -121,13 +129,8 @@ func (c *Compiler) emitDefer(frame *Frame, instr *ssa.Defer) error { values = append(values, llvmParam) valueTypes = append(valueTypes, llvmParam.Type()) } - - thunk := ContextDeferFunction{ - callback, - valueTypes, - makeClosure.Fn.(*ssa.Function).Signature, - } - c.ctxDeferFuncs = append(c.ctxDeferFuncs, thunk) + values = append(values, context) + valueTypes = append(valueTypes, context.Type()) } else { return c.makeError(instr.Pos(), "todo: defer on uncommon function call type") @@ -155,147 +158,172 @@ func (c *Compiler) emitDefer(frame *Frame, instr *ssa.Defer) error { // emitRunDefers emits code to run all deferred functions. func (c *Compiler) emitRunDefers(frame *Frame) error { + // Add a loop like the following: + // for stack != nil { + // _stack := stack + // stack = stack.next + // switch _stack.callback { + // case 0: + // // run first deferred call + // case 1: + // // run second deferred call + // // etc. + // default: + // unreachable + // } + // } + + // Create loop. + loophead := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.loophead") + loop := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.loop") + unreachable := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.default") + end := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.end") + c.builder.CreateBr(loophead) + + // Create loop head: + // for stack != nil { + c.builder.SetInsertPointAtEnd(loophead) deferData := c.builder.CreateLoad(frame.deferPtr, "") - c.createRuntimeCall("rundefers", []llvm.Value{deferData}, "") - return nil -} + stackIsNil := c.builder.CreateICmp(llvm.IntEQ, deferData, llvm.ConstPointerNull(deferData.Type()), "stackIsNil") + c.builder.CreateCondBr(stackIsNil, end, loop) -// finalizeDefers creates thunks for deferred functions. -func (c *Compiler) finalizeDefers() error { - // Create deferred function wrappers. - for _, fn := range c.deferFuncs { - // This function gets a single parameter which is a pointer to a struct - // (the defer frame). - // This struct starts with the values of runtime._defer, but after that - // follow the real function parameters. - // The job of this wrapper is to extract these parameters and to call - // the real function with them. - llvmFn := c.mod.NamedFunction(fn.LinkName() + "$defer") - llvmFn.SetLinkage(llvm.InternalLinkage) - llvmFn.SetUnnamedAddr(true) - entry := c.ctx.AddBasicBlock(llvmFn, "entry") - c.builder.SetInsertPointAtEnd(entry) - deferRawPtr := llvmFn.Param(0) + // Create loop body: + // _stack := stack + // stack = stack.next + // switch stack.callback { + c.builder.SetInsertPointAtEnd(loop) + nextStackGEP := c.builder.CreateGEP(deferData, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), 1, false), // .next field + }, "stack.next.gep") + nextStack := c.builder.CreateLoad(nextStackGEP, "stack.next") + c.builder.CreateStore(nextStack, frame.deferPtr) + gep := c.builder.CreateGEP(deferData, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), 0, false), // .callback field + }, "callback.gep") + callback := c.builder.CreateLoad(gep, "callback") + sw := c.builder.CreateSwitch(callback, unreachable, len(frame.allDeferFuncs)) - // Get the real param type and cast to it. - valueTypes := []llvm.Type{llvmFn.Type(), llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)} - for _, param := range fn.Params { - llvmType, err := c.getLLVMType(param.Type()) + for i, callback := range frame.allDeferFuncs { + // Create switch case, for example: + // case 0: + // // run first deferred call + block := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.callback") + sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(i), false), block) + c.builder.SetInsertPointAtEnd(block) + switch callback := callback.(type) { + case *ssa.CallCommon: + // Call on an interface value. + if !callback.IsInvoke() { + panic("expected an invoke call, not a direct call") + } + + // Get the real defer struct type and cast to it. + valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0), c.i8ptrType} + for _, arg := range callback.Args { + llvmType, err := c.getLLVMType(arg.Type()) + if err != nil { + return err + } + valueTypes = append(valueTypes, llvmType) + } + deferFrameType := c.ctx.StructType(valueTypes, false) + deferFramePtr := c.builder.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame") + + // Extract the params from the struct (including receiver). + forwardParams := []llvm.Value{} + zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) + for i := 2; i < len(valueTypes); i++ { + gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false)}, "gep") + forwardParam := c.builder.CreateLoad(gep, "param") + forwardParams = append(forwardParams, forwardParam) + } + + if c.ir.SignatureNeedsContext(callback.Method.Type().(*types.Signature)) { + // This function takes an extra context parameter. An interface call + // cannot also be a closure but we have to supply the parameter + // anyway for platforms with a strict calling convention. + forwardParams = append(forwardParams, llvm.Undef(c.i8ptrType)) + } + + fnPtr, _, err := c.getInvokeCall(frame, callback) if err != nil { return err } - valueTypes = append(valueTypes, llvmType) - } - deferFrameType := c.ctx.StructType(valueTypes, false) - deferFramePtr := c.builder.CreateBitCast(deferRawPtr, llvm.PointerType(deferFrameType, 0), "deferFrame") + c.createCall(fnPtr, forwardParams, "") - // Extract the params from the struct. - forwardParams := []llvm.Value{} - zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) - for i := range fn.Params { - gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i+2), false)}, "gep") - forwardParam := c.builder.CreateLoad(gep, "param") - forwardParams = append(forwardParams, forwardParam) + case *ir.Function: + // Direct call. + + // Get the real defer struct type and cast to it. + valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)} + for _, param := range callback.Params { + llvmType, err := c.getLLVMType(param.Type()) + if err != nil { + return err + } + valueTypes = append(valueTypes, llvmType) + } + deferFrameType := c.ctx.StructType(valueTypes, false) + deferFramePtr := c.builder.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame") + + // Extract the params from the struct. + forwardParams := []llvm.Value{} + zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) + for i := range callback.Params { + gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i+2), false)}, "gep") + forwardParam := c.builder.CreateLoad(gep, "param") + forwardParams = append(forwardParams, forwardParam) + } + + // Call real function. + c.createCall(callback.LLVMFn, forwardParams, "") + + case *ssa.MakeClosure: + // Get the real defer struct type and cast to it. + fn := c.ir.GetFunction(callback.Fn.(*ssa.Function)) + valueTypes := []llvm.Type{c.uintptrType, llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)} + params := fn.Signature.Params() + for i := 0; i < params.Len(); i++ { + llvmType, err := c.getLLVMType(params.At(i).Type()) + if err != nil { + return err + } + valueTypes = append(valueTypes, llvmType) + } + valueTypes = append(valueTypes, c.i8ptrType) // closure + deferFrameType := c.ctx.StructType(valueTypes, false) + deferFramePtr := c.builder.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame") + + // Extract the params from the struct. + forwardParams := []llvm.Value{} + zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) + for i := 2; i < len(valueTypes); i++ { + gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false)}, "") + forwardParam := c.builder.CreateLoad(gep, "param") + forwardParams = append(forwardParams, forwardParam) + } + + // Call deferred function. + c.createCall(fn.LLVMFn, forwardParams, "") + + default: + panic("unknown deferred function type") } - // Call real function (of which this is a wrapper). - c.createCall(fn.LLVMFn, forwardParams, "") - c.builder.CreateRetVoid() + // Branch back to the start of the loop. + c.builder.CreateBr(loophead) } - // Create wrapper for deferred interface call. - for _, thunk := range c.deferInvokeFuncs { - // This function gets a single parameter which is a pointer to a struct - // (the defer frame). - // This struct starts with the values of runtime._defer, but after that - // follow the real function parameters. - // The job of this wrapper is to extract these parameters and to call - // the real function with them. - llvmFn := c.mod.NamedFunction(thunk.method.FullName() + "$defer") - llvmFn.SetLinkage(llvm.InternalLinkage) - llvmFn.SetUnnamedAddr(true) - entry := c.ctx.AddBasicBlock(llvmFn, "entry") - c.builder.SetInsertPointAtEnd(entry) - deferRawPtr := llvmFn.Param(0) - - // Get the real param type and cast to it. - deferFrameType := c.ctx.StructType(thunk.valueTypes, false) - deferFramePtr := c.builder.CreateBitCast(deferRawPtr, llvm.PointerType(deferFrameType, 0), "deferFrame") - - // Extract the params from the struct. - forwardParams := []llvm.Value{} - zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) - for i := range thunk.valueTypes[3:] { - gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i+3), false)}, "gep") - forwardParam := c.builder.CreateLoad(gep, "param") - forwardParams = append(forwardParams, forwardParam) - } - - // Call real function (of which this is a wrapper). - fnGEP := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), 2, false)}, "fn.gep") - fn := c.builder.CreateLoad(fnGEP, "fn") - c.createCall(fn, forwardParams, "") - c.builder.CreateRetVoid() - } - - // Create wrapper for deferred function pointer call. - for _, thunk := range c.ctxDeferFuncs { - // This function gets a single parameter which is a pointer to a struct - // (the defer frame). - // This struct starts with the values of runtime._defer, but after that - // follows the closure and then the real parameters. - // The job of this wrapper is to extract this closure and these - // parameters and to call the function pointer with them. - llvmFn := thunk.fn - llvmFn.SetLinkage(llvm.InternalLinkage) - llvmFn.SetUnnamedAddr(true) - entry := c.ctx.AddBasicBlock(llvmFn, "entry") - // TODO: set the debug location - perhaps the location of the rundefers - // call? - c.builder.SetInsertPointAtEnd(entry) - deferRawPtr := llvmFn.Param(0) - - // Get the real param type and cast to it. - deferFrameType := c.ctx.StructType(thunk.deferStruct, false) - deferFramePtr := c.builder.CreateBitCast(deferRawPtr, llvm.PointerType(deferFrameType, 0), "defer.frame") - - // Extract the params from the struct. - forwardParams := []llvm.Value{} - zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) - for i := 3; i < len(thunk.deferStruct); i++ { - gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false)}, "") - forwardParam := c.builder.CreateLoad(gep, "param") - forwardParams = append(forwardParams, forwardParam) - } - - // Extract the closure from the struct. - fpGEP := c.builder.CreateGEP(deferFramePtr, []llvm.Value{ - zero, - llvm.ConstInt(c.ctx.Int32Type(), 2, false), - llvm.ConstInt(c.ctx.Int32Type(), 1, false), - }, "closure.fp.ptr") - fp := c.builder.CreateLoad(fpGEP, "closure.fp") - contextGEP := c.builder.CreateGEP(deferFramePtr, []llvm.Value{ - zero, - llvm.ConstInt(c.ctx.Int32Type(), 2, false), - llvm.ConstInt(c.ctx.Int32Type(), 0, false), - }, "closure.context.ptr") - context := c.builder.CreateLoad(contextGEP, "closure.context") - forwardParams = append(forwardParams, context) - - // Cast the function pointer in the closure to the correct function - // pointer type. - closureType, err := c.getLLVMType(thunk.signature) - if err != nil { - return err - } - fpType := closureType.StructElementTypes()[1] - fpCast := c.builder.CreateBitCast(fp, fpType, "closure.fp.cast") - - // Call real function (of which this is a wrapper). - c.createCall(fpCast, forwardParams, "") - c.builder.CreateRetVoid() - } + // Create default unreachable block: + // default: + // unreachable + // } + c.builder.SetInsertPointAtEnd(unreachable) + c.builder.CreateUnreachable() + // End of loop. + c.builder.SetInsertPointAtEnd(end) return nil } diff --git a/src/runtime/defer.go b/src/runtime/defer.go index 3d8cafbf..941cafd6 100644 --- a/src/runtime/defer.go +++ b/src/runtime/defer.go @@ -1,29 +1,9 @@ package runtime -// Defer statements are implemented by transforming the function in the -// following way: -// * Creating an alloca in the entry block that contains a pointer (initially -// null) to the linked list of defer frames. -// * Every time a defer statement is executed, a new defer frame is created -// using alloca with a pointer to the previous defer frame, and the head -// pointer in the entry block is replaced with a pointer to this defer -// frame. -// * On return, runtime.rundefers is called which calls all deferred functions -// from the head of the linked list until it has gone through all defer -// frames. - -import "unsafe" - -type deferContext unsafe.Pointer +// Some helper types for the defer statement. +// See compiler/defer.go for details. type _defer struct { - callback func(*_defer) + callback uintptr // callback number next *_defer } - -func rundefers(stack *_defer) { - for stack != nil { - stack.callback(stack) - stack = stack.next - } -}