From 82be43f4e68a99159e9e2c83b317cebc6a80ee71 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Mon, 22 Oct 2018 14:06:51 +0200 Subject: [PATCH] compiler: implement deferring of immediately-applied closures This is a common operation: freevar := ... defer func() { println("I am deferred:", freevar) }() The function is thus an immediately applied closure. Only this form is currently supported, support for regular (fat) function pointers should be trivial to add but is not currently implemented as it wasn't necessary to get fmt to compile. --- compiler/calls.go | 21 +++++- compiler/compiler.go | 172 ++++++++++++++++++++++++++++++++++--------- testdata/calls.go | 61 +++++++++++++++ testdata/calls.txt | 8 ++ 4 files changed, 228 insertions(+), 34 deletions(-) create mode 100644 testdata/calls.go create mode 100644 testdata/calls.txt diff --git a/compiler/calls.go b/compiler/calls.go index 050d6f19..05cba1c3 100644 --- a/compiler/calls.go +++ b/compiler/calls.go @@ -23,9 +23,28 @@ import ( // argument list. // * Blocking functions have a coroutine pointer prepended to the argument // list, see src/runtime/scheduler.go for details. +// +// Some further notes: +// * Function pointers are lowered to either a raw function pointer or a +// closure struct: { i8*, function pointer } +// The function pointer type depends on whether the exact same signature is +// used anywhere else in the program for a call that needs a context +// (closures, bound methods). If it isn't, it is lowered to a raw function +// pointer. +// * Defer statements are implemented by transforming the function in the +// following way: +// * Creating an alloca in the entry block that contains a pointer +// (initially null) to the linked list of defer frames. +// * Every time a defer statement is executed, a new defer frame is +// created using alloca with a pointer to the previous defer frame, and +// the head pointer in the entry block is replaced with a pointer to +// this defer frame. 
+// * On return, runtime.rundefers is called which calls all deferred +// functions from the head of the linked list until it has gone through +// all defer frames. // The maximum number of arguments that can be expanded from a single struct. If -// a struct contains more fields, it is passed as value. +// a struct contains more fields, it is passed as a struct without expanding. const MaxFieldsPerParam = 3 // Shortcut: create a call to runtime.<fnName> with the given arguments. diff --git a/compiler/compiler.go b/compiler/compiler.go index 347ff1bf..7e4319ae 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -62,6 +62,7 @@ type Compiler struct { coroFreeFunc llvm.Value initFuncs []llvm.Value deferFuncs []*ir.Function + ctxDeferFuncs []ContextDeferFunction ir *ir.Program } @@ -84,6 +85,13 @@ type Phi struct { llvm llvm.Value } +// A thunk for a defer that defers calling a function pointer with context. +type ContextDeferFunction struct { + fn llvm.Value + deferStruct []llvm.Type + signature *types.Signature +} + func NewCompiler(pkgName string, config Config) (*Compiler, error) { if config.Triple == "" { config.Triple = llvm.DefaultTargetTriple() @@ -325,12 +333,14 @@ func (c *Compiler) Compile(mainPath string) error { // Create deferred function wrappers. for _, fn := range c.deferFuncs { - // This function gets a single parameter which is a pointer to a struct. + // This function gets a single parameter which is a pointer to a struct + // (the defer frame). // This struct starts with the values of runtime._defer, but after that - // follow the real parameters. + // follow the real function parameters. // The job of this wrapper is to extract these parameters and to call // the real function with them. 
llvmFn := c.mod.NamedFunction(fn.LinkName() + "$defer") + llvmFn.SetLinkage(llvm.InternalLinkage) entry := c.ctx.AddBasicBlock(llvmFn, "entry") c.builder.SetInsertPointAtEnd(entry) deferRawPtr := llvmFn.Param(0) @@ -344,14 +354,14 @@ func (c *Compiler) Compile(mainPath string) error { } valueTypes = append(valueTypes, llvmType) } - contextType := c.ctx.StructType(valueTypes, false) - contextPtr := c.builder.CreateBitCast(deferRawPtr, llvm.PointerType(contextType, 0), "context") + deferFrameType := c.ctx.StructType(valueTypes, false) + deferFramePtr := c.builder.CreateBitCast(deferRawPtr, llvm.PointerType(deferFrameType, 0), "deferFrame") // Extract the params from the struct. forwardParams := []llvm.Value{} zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) for i := range fn.Params { - gep := c.builder.CreateGEP(contextPtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i+2), false)}, "gep") + gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i+2), false)}, "gep") forwardParam := c.builder.CreateLoad(gep, "param") forwardParams = append(forwardParams, forwardParam) } @@ -361,6 +371,64 @@ func (c *Compiler) Compile(mainPath string) error { c.builder.CreateRetVoid() } + // Create wrapper for deferred function pointer call. + for _, thunk := range c.ctxDeferFuncs { + // This function gets a single parameter which is a pointer to a struct + // (the defer frame). + // This struct starts with the values of runtime._defer, but after that + // follows the closure and then the real parameters. + // The job of this wrapper is to extract this closure and these + // parameters and to call the function pointer with them. + llvmFn := thunk.fn + llvmFn.SetLinkage(llvm.InternalLinkage) + entry := c.ctx.AddBasicBlock(llvmFn, "entry") + // TODO: set the debug location - perhaps the location of the rundefers + // call? 
+ c.builder.SetInsertPointAtEnd(entry) + deferRawPtr := llvmFn.Param(0) + + // Get the real param type and cast to it. + deferFrameType := c.ctx.StructType(thunk.deferStruct, false) + deferFramePtr := c.builder.CreateBitCast(deferRawPtr, llvm.PointerType(deferFrameType, 0), "defer.frame") + + // Extract the params from the struct. + forwardParams := []llvm.Value{} + zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) + for i := 3; i < len(thunk.deferStruct); i++ { + gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false)}, "") + forwardParam := c.builder.CreateLoad(gep, "param") + forwardParams = append(forwardParams, forwardParam) + } + + // Extract the closure from the struct. + fpGEP := c.builder.CreateGEP(deferFramePtr, []llvm.Value{ + zero, + llvm.ConstInt(c.ctx.Int32Type(), 2, false), + llvm.ConstInt(c.ctx.Int32Type(), 1, false), + }, "closure.fp.ptr") + fp := c.builder.CreateLoad(fpGEP, "closure.fp") + contextGEP := c.builder.CreateGEP(deferFramePtr, []llvm.Value{ + zero, + llvm.ConstInt(c.ctx.Int32Type(), 2, false), + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + }, "closure.context.ptr") + context := c.builder.CreateLoad(contextGEP, "closure.context") + forwardParams = append(forwardParams, context) + + // Cast the function pointer in the closure to the correct function + // pointer type. + closureType, err := c.getLLVMType(thunk.signature) + if err != nil { + return err + } + fpType := closureType.StructElementTypes()[1] + fpCast := c.builder.CreateBitCast(fp, fpType, "closure.fp.cast") + + // Call real function (of which this is a wrapper). + c.createCall(fpCast, forwardParams, "") + c.builder.CreateRetVoid() + } + // After all packages are imported, add a synthetic initializer function // that calls the initializer of each package. 
initFn := c.ir.GetFunction(c.ir.Program.ImportedPackage("runtime").Members["initAll"].(*ssa.Function)) @@ -1327,6 +1395,7 @@ func (c *Compiler) parseFunc(frame *Frame) error { panic("free variables on function without context") } context := frame.fn.LLVMFn.LastParam() + context.SetName("context") // Determine the context type. It's a struct containing all variables. freeVarTypes := make([]llvm.Type, 0, len(frame.fn.FreeVars)) @@ -1469,51 +1538,86 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { case *ssa.DebugRef: return nil // ignore case *ssa.Defer: - if _, ok := instr.Call.Value.(*ssa.Function); !ok || instr.Call.IsInvoke() { - return errors.New("todo: non-direct function calls in defer") - } - fn := c.ir.GetFunction(instr.Call.Value.(*ssa.Function)) - // The pointer to the previous defer struct, which we will replace to // make a linked list. next := c.builder.CreateLoad(frame.deferPtr, "defer.next") - // Try to find the wrapper $defer function. - deferName := fn.LinkName() + "$defer" - callback := c.mod.NamedFunction(deferName) - if callback.IsNil() { - // Not found, have to add it. - deferFuncType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{next.Type()}, false) - callback = llvm.AddFunction(c.mod, deferName, deferFuncType) - c.deferFuncs = append(c.deferFuncs, fn) - } + deferFuncType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{next.Type()}, false) - // Collect all values to be put in the struct (starting with - // runtime._defer fields). - values := []llvm.Value{callback, next} - valueTypes := []llvm.Type{callback.Type(), next.Type()} - for _, param := range instr.Call.Args { - llvmParam, err := c.parseExpr(frame, param) + var values []llvm.Value + var valueTypes []llvm.Type + if callee, ok := instr.Call.Value.(*ssa.Function); ok && !instr.Call.IsInvoke() { + // Regular function call. + fn := c.ir.GetFunction(callee) + + // Try to find the wrapper $defer function. 
+ deferName := fn.LinkName() + "$defer" + callback := c.mod.NamedFunction(deferName) + if callback.IsNil() { + // Not found, have to add it. + callback = llvm.AddFunction(c.mod, deferName, deferFuncType) + c.deferFuncs = append(c.deferFuncs, fn) + } + + // Collect all values to be put in the struct (starting with + // runtime._defer fields). + values = []llvm.Value{callback, next} + valueTypes = []llvm.Type{callback.Type(), next.Type()} + for _, param := range instr.Call.Args { + llvmParam, err := c.parseExpr(frame, param) + if err != nil { + return err + } + values = append(values, llvmParam) + valueTypes = append(valueTypes, llvmParam.Type()) + } + } else if makeClosure, ok := instr.Call.Value.(*ssa.MakeClosure); ok { + // Immediately applied function literal with free variables. + closure, err := c.parseExpr(frame, instr.Call.Value) if err != nil { return err } - values = append(values, llvmParam) - valueTypes = append(valueTypes, llvmParam.Type()) + + // Hopefully, LLVM will merge equivalent functions. + deferName := frame.fn.LinkName() + "$fpdefer" + callback := llvm.AddFunction(c.mod, deferName, deferFuncType) + + // Collect all values to be put in the struct (starting with + // runtime._defer fields, followed by the closure). + values = []llvm.Value{callback, next, closure} + valueTypes = []llvm.Type{callback.Type(), next.Type(), closure.Type()} + for _, param := range instr.Call.Args { + llvmParam, err := c.parseExpr(frame, param) + if err != nil { + return err + } + values = append(values, llvmParam) + valueTypes = append(valueTypes, llvmParam.Type()) + } + + thunk := ContextDeferFunction{ + callback, + valueTypes, + makeClosure.Fn.(*ssa.Function).Signature, + } + c.ctxDeferFuncs = append(c.ctxDeferFuncs, thunk) + } else { + return errors.New("todo: defer on uncommon function call type") } - // Make a struct out of it. 
- contextType := c.ctx.StructType(valueTypes, false) - context, err := c.getZeroValue(contextType) + // Make a struct out of the collected values to put in the defer frame. + deferFrameType := c.ctx.StructType(valueTypes, false) + deferFrame, err := c.getZeroValue(deferFrameType) if err != nil { return err } for i, value := range values { - context = c.builder.CreateInsertValue(context, value, i, "") + deferFrame = c.builder.CreateInsertValue(deferFrame, value, i, "") } // Put this struct in an alloca. - alloca := c.builder.CreateAlloca(contextType, "defer.alloca") - c.builder.CreateStore(context, alloca) + alloca := c.builder.CreateAlloca(deferFrameType, "defer.alloca") + c.builder.CreateStore(deferFrame, alloca) // Push it on top of the linked list by replacing deferPtr. allocaCast := c.builder.CreateBitCast(alloca, next.Type(), "defer.alloca.cast") @@ -2354,6 +2458,8 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { } case *ssa.MakeClosure: + // A closure returns a function pointer with context: + // {context, fp} return c.parseMakeClosure(frame, expr) case *ssa.MakeInterface: diff --git a/testdata/calls.go b/testdata/calls.go new file mode 100644 index 00000000..93982fb8 --- /dev/null +++ b/testdata/calls.go @@ -0,0 +1,61 @@ +package main + +type Thing struct { + name string +} + +func (t Thing) String() string { + return t.name +} + +func main() { + thing := &Thing{"foo"} + + // function pointers + runFunc(hello, 5) // must be indirect to avoid obvious inlining + + // deferred functions + testDefer() + + // Take a bound method and use it as a function pointer. + // This function pointer needs a context pointer. 
+ testBound(thing.String) + + // closures + func() { + println("thing inside closure:", thing.String()) + }() + runFunc(func(i int) { + println("inside fp closure:", thing.String(), i) + }, 3) +} + +func runFunc(f func(int), arg int) { + f(arg) +} + +func hello(n int) { + println("hello from function pointer:", n) +} + +func testDefer() { + i := 1 + defer deferred("...run as defer", i) + i++ + defer func() { + println("...run closure deferred:", i) + }() + i++ + defer deferred("...run as defer", i) + i++ + + println("deferring...") +} + +func deferred(msg string, i int) { + println(msg, i) +} + +func testBound(f func() string) { + println("bound method:", f()) +} diff --git a/testdata/calls.txt b/testdata/calls.txt new file mode 100644 index 00000000..902b5b4c --- /dev/null +++ b/testdata/calls.txt @@ -0,0 +1,8 @@ +hello from function pointer: 5 +deferring... +...run as defer 3 +...run closure deferred: 4 +...run as defer 1 +bound method: foo +thing inside closure: foo +inside fp closure: foo 3