From 86ab03c9998dca53ec80dc6c2ef644da13fdbdf1 Mon Sep 17 00:00:00 2001 From: Jaden Weiss Date: Fri, 18 Oct 2019 19:19:46 -0400 Subject: [PATCH] fix miscompile of static goroutine calls to closures --- compiler/compiler.go | 26 ++++++++++++++++---------- compiler/func-lowering.go | 10 +++++++++- compiler/goroutine.go | 9 ++++++--- testdata/coroutines.go | 10 ++++++++++ testdata/coroutines.txt | 1 + 5 files changed, 42 insertions(+), 14 deletions(-) diff --git a/compiler/compiler.go b/compiler/compiler.go index 89925804..1787a5e7 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1080,28 +1080,34 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) { // Static callee is known. This makes it easier to start a new // goroutine. calleeFn := c.ir.GetFunction(callee) - if !calleeFn.IsExported() && c.selectScheduler() != "tasks" { - // For coroutine scheduling, this is only required when calling - // an external function. - // For tasks, because all params are stored in a single object, - // no unnecessary parameters should be stored anyway. - params = append(params, llvm.Undef(c.i8ptrType)) // context parameter - params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle + var context llvm.Value + switch value := instr.Call.Value.(type) { + case *ssa.Function: + // Goroutine call is regular function call. No context is necessary. + context = llvm.Undef(c.i8ptrType) + case *ssa.MakeClosure: + // A goroutine call on a func value, but the callee is trivial to find. For + // example: immediately applied functions. + funcValue := c.getValue(frame, value) + context = c.extractFuncContext(funcValue) + default: + panic("StaticCallee returned an unexpected value") } + params = append(params, context) // context parameter c.emitStartGoroutine(calleeFn.LLVMFn, params) } else if !instr.Call.IsInvoke() { // This is a function pointer. // At the moment, two extra params are passed to the newly started // goroutine: // * The function context, for closures. - // * The parent handle (for coroutines) or the function pointer - // itself (for tasks). + // * The function pointer (for tasks). funcPtr, context := c.decodeFuncValue(c.getValue(frame, instr.Call.Value), instr.Call.Value.Type().(*types.Signature)) params = append(params, context) // context parameter switch c.selectScheduler() { case "coroutines": - params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle + // There are no additional parameters needed for the goroutine start operation. case "tasks": + // Add the function pointer as a parameter to start the goroutine. params = append(params, funcPtr) default: panic("unknown scheduler type") diff --git a/compiler/func-lowering.go b/compiler/func-lowering.go index f1af3ab3..a3123b59 100644 --- a/compiler/func-lowering.go +++ b/compiler/func-lowering.go @@ -98,6 +98,10 @@ func (c *Compiler) LowerFuncValues() { continue } for _, funcValueWithSignatureConstant := range getUses(ptrtoint) { + if !funcValueWithSignatureConstant.IsACallInst().IsNil() && funcValueWithSignatureConstant.CalledValue().Name() == "runtime.makeGoroutine" { + // makeGoroutine calls are handled seperately + continue + } for _, funcValueWithSignatureGlobal := range getUses(funcValueWithSignatureConstant) { for _, use := range getUses(funcValueWithSignatureGlobal) { if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt { @@ -182,7 +186,11 @@ func (c *Compiler) LowerFuncValues() { panic("expected a inttoptr") } for _, use := range getUses(inttoptr) { - c.addFuncLoweringSwitch(funcID, use, c.emitStartGoroutine, functions) + c.addFuncLoweringSwitch(funcID, use, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { + // The function lowering switch code passes in a parent handle value. + // Strip the parent handle off here because it is irrelevant to goroutine starts. + return c.emitStartGoroutine(funcPtr, params[:len(params)-1]) + }, functions) use.EraseFromParentAsInstruction() } inttoptr.EraseFromParentAsInstruction() diff --git a/compiler/goroutine.go b/compiler/goroutine.go index 59c5a383..23979357 100644 --- a/compiler/goroutine.go +++ b/compiler/goroutine.go @@ -7,6 +7,9 @@ import "tinygo.org/x/go-llvm" // emitStartGoroutine starts a new goroutine with the provided function pointer // and parameters. +// In general, you should pass all regular parameters plus the context parameter. +// There is one exception: the task-based scheduler needs to have the function +// pointer passed in as a parameter too in addition to the context. // // Because a go statement doesn't return anything, return undef. func (c *Compiler) emitStartGoroutine(funcPtr llvm.Value, params []llvm.Value) llvm.Value { @@ -24,7 +27,7 @@ func (c *Compiler) emitStartGoroutine(funcPtr llvm.Value, params []llvm.Value) l calleeValue := c.builder.CreatePtrToInt(funcPtr, c.uintptrType, "") calleeValue = c.createRuntimeCall("makeGoroutine", []llvm.Value{calleeValue}, "") calleeValue = c.builder.CreateIntToPtr(calleeValue, funcPtr.Type(), "") - c.createCall(calleeValue, params, "") + c.createCall(calleeValue, append(params, llvm.ConstPointerNull(c.i8ptrType)), "") default: panic("unreachable") } @@ -74,8 +77,8 @@ func (c *Compiler) createGoroutineStartWrapper(fn llvm.Value) llvm.Value { // Create the list of params for the call. paramTypes := fn.Type().ElementType().ParamTypes() - params := c.emitPointerUnpack(wrapper.Param(0), paramTypes[:len(paramTypes)-2]) - params = append(params, llvm.Undef(c.i8ptrType), llvm.ConstPointerNull(c.i8ptrType)) + params := c.emitPointerUnpack(wrapper.Param(0), paramTypes[:len(paramTypes)-1]) + params = append(params, llvm.Undef(c.i8ptrType)) // Create the call. c.builder.CreateCall(fn, params, "") diff --git a/testdata/coroutines.go b/testdata/coroutines.go index 47d02598..62d62ab7 100644 --- a/testdata/coroutines.go +++ b/testdata/coroutines.go @@ -41,6 +41,16 @@ func main() { }) time.Sleep(2 * time.Millisecond) + + var x int + go func() { + time.Sleep(2 * time.Millisecond) + x = 1 + }() + time.Sleep(time.Second/2) + println("closure go call result:", x) + + time.Sleep(2 * time.Millisecond) } func sub() { diff --git a/testdata/coroutines.txt b/testdata/coroutines.txt index eea625ff..e296f8e0 100644 --- a/testdata/coroutines.txt +++ b/testdata/coroutines.txt @@ -13,3 +13,4 @@ done with non-blocking goroutine async interface method call slept inside func pointer 8 slept inside closure, with value: 20 8 +closure go call result: 1