diff --git a/transform/coroutines.go b/transform/coroutines.go index 2aeff43c..9a382c3a 100644 --- a/transform/coroutines.go +++ b/transform/coroutines.go @@ -600,11 +600,11 @@ func (c *coroutineLoweringPass) lowerFuncsPass() { continue } - if len(fn.normalCalls) == 0 { - // No suspend points. Lower without turning it into a coroutine. + if len(fn.normalCalls) == 0 && fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() { + // No suspend points or stack allocations. Lower without turning it into a coroutine. c.lowerFuncFast(fn) } else { - // There are suspend points, so it is necessary to turn this into a coroutine. + // There are suspend points or stack allocations, so it is necessary to turn this into a coroutine. c.lowerFuncCoro(fn) } } @@ -827,6 +827,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) { } // Lower returns. + var postTail llvm.BasicBlock for _, ret := range fn.returns { // Get terminator instruction. terminator := ret.block.LastInstruction() @@ -886,10 +887,37 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) { call.EraseFromParentAsInstruction() } - // Replace terminator with branch to cleanup. + // Replace terminator with a branch to the exit. + var exit llvm.BasicBlock + if ret.kind == returnNormal || ret.kind == returnVoid || fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() { + // Exit through the cleanup path. + exit = cleanup + } else { + if postTail.IsNil() { + // Create a path with a suspend that never reawakens. + postTail = c.ctx.AddBasicBlock(fn.fn, "post.tail") + c.builder.SetInsertPointAtEnd(postTail) + // %coro.save = call token @llvm.coro.save(i8* %coro.state) + save := c.builder.CreateCall(c.coroSave, []llvm.Value{coroState}, "coro.save") + // %call.suspend = llvm.coro.suspend(token %coro.save, i1 false) + // switch i8 %call.suspend, label %suspend [i8 0, label %wakeup + // i8 1, label %cleanup] + suspendValue := c.builder.CreateCall(c.coroSuspend, []llvm.Value{save, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "call.suspend") + sw := c.builder.CreateSwitch(suspendValue, suspend, 2) + unreachableBlock := c.ctx.AddBasicBlock(fn.fn, "unreachable") + sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), unreachableBlock) + sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), cleanup) + c.builder.SetInsertPointAtEnd(unreachableBlock) + c.builder.CreateUnreachable() + } + + // Exit through a permanent suspend. + exit = postTail + } + terminator.EraseFromParentAsInstruction() c.builder.SetInsertPointAtEnd(ret.block) - c.builder.CreateBr(cleanup) + c.builder.CreateBr(exit) } // Lower regular calls. diff --git a/transform/testdata/coroutines.ll b/transform/testdata/coroutines.ll index b51abb5e..94462bbc 100644 --- a/transform/testdata/coroutines.ll +++ b/transform/testdata/coroutines.ll @@ -86,11 +86,43 @@ entry: } ; Normal function which should not be transformed. -define void @doNothing(i8*, i8*) { +define void @doNothing(i8*, i8* %parentHandle) { entry: ret void } +; Regression test: ensure that a tail call does not destroy the frame while it is still in use. +; Previously, the tail-call lowering transform would branch to the cleanup block after usePtr. +; This caused the lifetime of %a to be incorrectly reduced, and allowed the coroutine lowering transform to keep %a on the stack. +; After a suspend %a would be used, resulting in memory corruption. +define i8 @coroutineTailRegression(i8*, i8* %parentHandle) { +entry: + %a = alloca i8 + store i8 5, i8* %a + %val = call i8 @usePtr(i8* %a, i8* undef, i8* null) + ret i8 %val +} + +; Regression test: ensure that stack allocations alive during a suspend end up on the heap. +; This used to not be transformed to a coroutine, keeping %a on the stack. +; After a suspend %a would be used, resulting in memory corruption. +define i8 @allocaTailRegression(i8*, i8* %parentHandle) { +entry: + %a = alloca i8 + call void @sleep(i64 1000000, i8* undef, i8* null) + store i8 5, i8* %a + %val = call i8 @usePtr(i8* %a, i8* undef, i8* null) + ret i8 %val +} + +; usePtr uses a pointer after a suspend. +define i8 @usePtr(i8*, i8*, i8* %parentHandle) { +entry: + call void @sleep(i64 1000000, i8* undef, i8* null) + %val = load i8, i8* %0 + ret i8 %val +} + ; Goroutine that sleeps and does nothing. ; Should be a void tail call. define void @sleepGoroutine(i8*, i8* %parentHandle) { diff --git a/transform/testdata/coroutines.out.ll b/transform/testdata/coroutines.out.ll index af91beac..fa6eb287 100644 --- a/transform/testdata/coroutines.out.ll +++ b/transform/testdata/coroutines.out.ll @@ -45,7 +45,7 @@ entry: %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* %ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) %ret.ptr.bitcast = bitcast i8* %ret.ptr to i32* - store i32 %0, i32* %ret.ptr.bitcast + store i32 %0, i32* %ret.ptr.bitcast, align 4 call void @sleep(i64 %1, i8* undef, i8* %parentHandle) ret i32 undef } @@ -84,7 +84,7 @@ entry: %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* %ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) %ret.ptr.bitcast = bitcast i8* %ret.ptr to i32* - store i32 %0, i32* %ret.ptr.bitcast + store i32 %0, i32* %ret.ptr.bitcast, align 4 %ret.alternate = call i8* @runtime.alloc(i32 4, i8* undef, i8* undef) call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %ret.alternate, i8* undef, i8* undef) %4 = call i32 @delayedValue(i32 %1, i64 %2, i8* undef, i8* %parentHandle) @@ -93,7 +93,7 @@ entry: define i1 @coroutine(i32 %0, i64 %1, i8* %2, i8* %parentHandle) { entry: - %call.return = alloca i32 + %call.return = alloca i32, align 4 %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) %coro.size = call i32 @llvm.coro.size.i32() %coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef) @@ -116,10 +116,10 @@ entry: ] wakeup: ; preds = %entry - %4 = load i32, i32* %call.return + %4 = load i32, i32* %call.return, align 4 call void @llvm.lifetime.end.p0i8(i64 4, i8* %call.return.bitcast) %5 = icmp eq i32 %4, 0 - store i1 %5, i1* %task.retPtr.bitcast + store i1 %5, i1* %task.retPtr.bitcast, align 1 call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current2, i8* %task.state.parent, i8* undef, i8* undef) br label %cleanup @@ -133,11 +133,127 @@ cleanup: ; preds = %entry, %wakeup br label %suspend } -define void @doNothing(i8* %0, i8* %1) { +define void @doNothing(i8* %0, i8* %parentHandle) { entry: ret void } +define i8 @coroutineTailRegression(i8* %0, i8* %parentHandle) { +entry: + %a = alloca i8, align 1 + %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %coro.size = call i32 @llvm.coro.size.i32() + %coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef) + %coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc) + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef) + %task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) + store i8 5, i8* %a, align 1 + %coro.state.restore = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef) + call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef) + %val = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle) + br label %post.tail + +suspend: ; preds = %post.tail, %cleanup + %unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false) + ret i8 undef + +cleanup: ; preds = %post.tail + %coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state) + call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef) + br label %suspend + +post.tail: ; preds = %entry + %coro.save = call token @llvm.coro.save(i8* %coro.state) + %call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false) + switch i8 %call.suspend, label %suspend [ + i8 0, label %unreachable + i8 1, label %cleanup + ] + +unreachable: ; preds = %post.tail + unreachable +} + +define i8 @allocaTailRegression(i8* %0, i8* %parentHandle) { +entry: + %a = alloca i8, align 1 + %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %coro.size = call i32 @llvm.coro.size.i32() + %coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef) + %coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc) + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef) + %task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) + call void @sleep(i64 1000000, i8* undef, i8* %parentHandle) + %coro.save1 = call token @llvm.coro.save(i8* %coro.state) + %call.suspend2 = call i8 @llvm.coro.suspend(token %coro.save1, i1 false) + switch i8 %call.suspend2, label %suspend [ + i8 0, label %wakeup + i8 1, label %cleanup + ] + +wakeup: ; preds = %entry + store i8 5, i8* %a, align 1 + %1 = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef) + call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef) + %2 = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle) + br label %post.tail + +suspend: ; preds = %entry, %post.tail, %cleanup + %unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false) + ret i8 undef + +cleanup: ; preds = %entry, %post.tail + %coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state) + call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef) + br label %suspend + +post.tail: ; preds = %wakeup + %coro.save = call token @llvm.coro.save(i8* %coro.state) + %call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false) + switch i8 %call.suspend, label %suspend [ + i8 0, label %unreachable + i8 1, label %cleanup + ] + +unreachable: ; preds = %post.tail + unreachable +} + +define i8 @usePtr(i8* %0, i8* %1, i8* %parentHandle) { +entry: + %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %coro.size = call i32 @llvm.coro.size.i32() + %coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef) + %coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc) + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef) + %task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) + call void @sleep(i64 1000000, i8* undef, i8* %parentHandle) + %coro.save = call token @llvm.coro.save(i8* %coro.state) + %call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false) + switch i8 %call.suspend, label %suspend [ + i8 0, label %wakeup + i8 1, label %cleanup + ] + +wakeup: ; preds = %entry + %2 = load i8, i8* %0, align 1 + store i8 %2, i8* %task.retPtr, align 1 + call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef) + br label %cleanup + +suspend: ; preds = %entry, %cleanup + %unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false) + ret i8 undef + +cleanup: ; preds = %entry, %wakeup + %coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state) + call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef) + br label %suspend +} + define void @sleepGoroutine(i8* %0, i8* %parentHandle) { %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)