diff --git a/compiler/goroutine-lowering.go b/compiler/goroutine-lowering.go index 7e7be7ad..7ebb194f 100644 --- a/compiler/goroutine-lowering.go +++ b/compiler/goroutine-lowering.go @@ -213,6 +213,10 @@ func (c *Compiler) lowerCoroutines() error { // still exported. Make sure it is optimized away. go_scheduler.SetLinkage(llvm.InternalLinkage) } + } else { + // Eliminate unnecessary fake coroutines. + // This is necessary to prevent infinite recursion in runtime.getFakeCoroutine. + c.eliminateFakeCoroutines() } // main.main was set to external linkage during IR construction. Set it to @@ -917,7 +921,7 @@ func (c *Compiler) lowerMakeGoroutineCalls(sched bool) error { params = append(params, realCall.Operand(i)) } c.builder.SetInsertPointBefore(realCall) - if (!sched) || goroutine.InstructionParent().Parent() == c.mod.NamedFunction("runtime.getFakeCoroutine") { + if !sched { params[len(params)-1] = llvm.Undef(c.i8ptrType) } else { params[len(params)-1] = c.createRuntimeCall("getFakeCoroutine", []llvm.Value{}, "") // parent coroutine handle (must not be nil) @@ -934,3 +938,80 @@ func (c *Compiler) lowerMakeGoroutineCalls(sched bool) error { return nil } + +// internalArgumentValue finds the LLVM value inside the function which corresponds to the provided argument of the provided call. +func (c *Compiler) internalArgumentValue(call llvm.Value, arg llvm.Value) llvm.Value { + n := call.OperandsCount() + for i := 0; i < n; i++ { + if call.Operand(i) == arg { + return call.CalledValue().Param(i) + } + } + panic("no corresponding argument") +} + +// specialCoroFuncs are functions in the runtime which accept coroutines as arguments but act as a no-op if these are nil. +// Calls to these functions do not require a fake coroutine. +var specialCoroFuncs = map[string]bool{ + "runtime.runqueuePushBack": true, + "runtime.activateTask": true, +} + +// isCoroNecessary checks if a coroutine pointer value must be non-nil for the program to function. +// This returns true if replacing a fake coroutine value with nil will result in equivalent behavior. +func (c *Compiler) isCoroNecessary(coro llvm.Value, scanned map[llvm.Value]struct{}) (necessary bool) { + // avoid infinite recursion + if _, ok := scanned[coro]; ok { + return false + } + scanned[coro] = struct{}{} + + for use := coro.FirstUse(); !use.IsNil(); use = use.NextUse() { + user := use.User() + switch { + case !user.IsACallInst().IsNil(): + switch { + case !user.CalledValue().IsConstant(): + // This is passed into an unknown function, so we do not know what is happening to it. + coroDebugPrintln("found unoptimizable dynamic call") + return true + case specialCoroFuncs[user.CalledValue().Name()]: + // Pushing nil to the runqueue is valid and acts as a no-op. + // This use does not require a non-nil coroutine. + case c.isCoroNecessary(c.internalArgumentValue(user, coro), scanned): + // The function we called depends on the coroutine value being non-nil. + coroDebugPrintln("call to function depending on non-nil coroutine") + return true + default: + // This call does not depend upon a non-nil coroutine. + } + default: + if coroDebug { + fmt.Printf("unoptimizable usage of coroutine in %q: ", user.InstructionParent().Parent().Name()) + user.Dump() + fmt.Println() + } + return true + } + } + + // Nothing we found needed this coroutine value. + return false +} + +// eliminateFakeCoroutines replaces unnecessary calls to runtime.getFakeCoroutine. +// This is not considered an optimization, because it is necessary to avoid infinite recursion inside of runtime.getFakeCoroutine. +func (c *Compiler) eliminateFakeCoroutines() { + coroDebugPrintln("eliminating fake coroutines") + for _, v := range getUses(c.mod.NamedFunction("runtime.getFakeCoroutine")) { + if !c.isCoroNecessary(v, map[llvm.Value]struct{}{}) { + // This use of a fake coroutine is not necessary. + coroDebugPrintln("eliminating fake coroutine for", getUses(v)[0].CalledValue().Name()) + v.ReplaceAllUsesWith(llvm.ConstNull(c.i8ptrType)) + v.EraseFromParentAsInstruction() + } else { + // This use of a fake coroutine is necessary. + coroDebugPrintln("failed to eliminate fake coroutine for", getUses(v)[0].CalledValue().Name()) + } + } +}