From ecd8c2d902498c9b3c39417b9e52e8844adcb1cd Mon Sep 17 00:00:00 2001 From: Nia Waldvogel Date: Wed, 15 Sep 2021 14:08:33 -0400 Subject: [PATCH] transform (coroutines): fix memory corruption for tail calls that reference stack allocations This change fixes a bug in which `alloca` memory lifetimes would not extend past the suspend of an asynchronous tail call. This would typically manifest as memory corruption, and could happen with or without normal suspending calls within the function. --- transform/coroutines.go | 38 ++++++-- transform/testdata/coroutines.ll | 34 ++++++- transform/testdata/coroutines.out.ll | 128 +++++++++++++++++++++++++-- 3 files changed, 188 insertions(+), 12 deletions(-) diff --git a/transform/coroutines.go b/transform/coroutines.go index 2aeff43c..9a382c3a 100644 --- a/transform/coroutines.go +++ b/transform/coroutines.go @@ -600,11 +600,11 @@ func (c *coroutineLoweringPass) lowerFuncsPass() { continue } - if len(fn.normalCalls) == 0 { - // No suspend points. Lower without turning it into a coroutine. + if len(fn.normalCalls) == 0 && fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() { + // No suspend points or stack allocations. Lower without turning it into a coroutine. c.lowerFuncFast(fn) } else { - // There are suspend points, so it is necessary to turn this into a coroutine. + // There are suspend points or stack allocations, so it is necessary to turn this into a coroutine. c.lowerFuncCoro(fn) } } @@ -827,6 +827,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) { } // Lower returns. + var postTail llvm.BasicBlock for _, ret := range fn.returns { // Get terminator instruction. terminator := ret.block.LastInstruction() @@ -886,10 +887,37 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) { call.EraseFromParentAsInstruction() } - // Replace terminator with branch to cleanup. + // Replace terminator with a branch to the exit. + var exit llvm.BasicBlock + if ret.kind == returnNormal || ret.kind == returnVoid || fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() { + // Exit through the cleanup path. + exit = cleanup + } else { + if postTail.IsNil() { + // Create a path with a suspend that never reawakens. + postTail = c.ctx.AddBasicBlock(fn.fn, "post.tail") + c.builder.SetInsertPointAtEnd(postTail) + // %coro.save = call token @llvm.coro.save(i8* %coro.state) + save := c.builder.CreateCall(c.coroSave, []llvm.Value{coroState}, "coro.save") + // %call.suspend = llvm.coro.suspend(token %coro.save, i1 false) + // switch i8 %call.suspend, label %suspend [i8 0, label %wakeup + // i8 1, label %cleanup] + suspendValue := c.builder.CreateCall(c.coroSuspend, []llvm.Value{save, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "call.suspend") + sw := c.builder.CreateSwitch(suspendValue, suspend, 2) + unreachableBlock := c.ctx.AddBasicBlock(fn.fn, "unreachable") + sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), unreachableBlock) + sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), cleanup) + c.builder.SetInsertPointAtEnd(unreachableBlock) + c.builder.CreateUnreachable() + } + + // Exit through a permanent suspend. + exit = postTail + } + terminator.EraseFromParentAsInstruction() c.builder.SetInsertPointAtEnd(ret.block) - c.builder.CreateBr(cleanup) + c.builder.CreateBr(exit) } // Lower regular calls. diff --git a/transform/testdata/coroutines.ll b/transform/testdata/coroutines.ll index b51abb5e..94462bbc 100644 --- a/transform/testdata/coroutines.ll +++ b/transform/testdata/coroutines.ll @@ -86,11 +86,43 @@ entry: } ; Normal function which should not be transformed. -define void @doNothing(i8*, i8*) { +define void @doNothing(i8*, i8* %parentHandle) { entry: ret void } +; Regression test: ensure that a tail call does not destroy the frame while it is still in use. +; Previously, the tail-call lowering transform would branch to the cleanup block after usePtr. +; This caused the lifetime of %a to be incorrectly reduced, and allowed the coroutine lowering transform to keep %a on the stack. +; After a suspend %a would be used, resulting in memory corruption. +define i8 @coroutineTailRegression(i8*, i8* %parentHandle) { +entry: + %a = alloca i8 + store i8 5, i8* %a + %val = call i8 @usePtr(i8* %a, i8* undef, i8* null) + ret i8 %val +} + +; Regression test: ensure that stack allocations alive during a suspend end up on the heap. +; This used to not be transformed to a coroutine, keeping %a on the stack. +; After a suspend %a would be used, resulting in memory corruption. +define i8 @allocaTailRegression(i8*, i8* %parentHandle) { +entry: + %a = alloca i8 + call void @sleep(i64 1000000, i8* undef, i8* null) + store i8 5, i8* %a + %val = call i8 @usePtr(i8* %a, i8* undef, i8* null) + ret i8 %val +} + +; usePtr uses a pointer after a suspend. +define i8 @usePtr(i8*, i8*, i8* %parentHandle) { +entry: + call void @sleep(i64 1000000, i8* undef, i8* null) + %val = load i8, i8* %0 + ret i8 %val +} + ; Goroutine that sleeps and does nothing. ; Should be a void tail call. define void @sleepGoroutine(i8*, i8* %parentHandle) { diff --git a/transform/testdata/coroutines.out.ll b/transform/testdata/coroutines.out.ll index af91beac..fa6eb287 100644 --- a/transform/testdata/coroutines.out.ll +++ b/transform/testdata/coroutines.out.ll @@ -45,7 +45,7 @@ entry: %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* %ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) %ret.ptr.bitcast = bitcast i8* %ret.ptr to i32* - store i32 %0, i32* %ret.ptr.bitcast + store i32 %0, i32* %ret.ptr.bitcast, align 4 call void @sleep(i64 %1, i8* undef, i8* %parentHandle) ret i32 undef } @@ -84,7 +84,7 @@ entry: %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* %ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) %ret.ptr.bitcast = bitcast i8* %ret.ptr to i32* - store i32 %0, i32* %ret.ptr.bitcast + store i32 %0, i32* %ret.ptr.bitcast, align 4 %ret.alternate = call i8* @runtime.alloc(i32 4, i8* undef, i8* undef) call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %ret.alternate, i8* undef, i8* undef) %4 = call i32 @delayedValue(i32 %1, i64 %2, i8* undef, i8* %parentHandle) @@ -93,7 +93,7 @@ entry: define i1 @coroutine(i32 %0, i64 %1, i8* %2, i8* %parentHandle) { entry: - %call.return = alloca i32 + %call.return = alloca i32, align 4 %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) %coro.size = call i32 @llvm.coro.size.i32() %coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef) @@ -116,10 +116,10 @@ entry: ] wakeup: ; preds = %entry - %4 = load i32, i32* %call.return + %4 = load i32, i32* %call.return, align 4 call void @llvm.lifetime.end.p0i8(i64 4, i8* %call.return.bitcast) %5 = icmp eq i32 %4, 0 - store i1 %5, i1* %task.retPtr.bitcast + store i1 %5, i1* %task.retPtr.bitcast, align 1 call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current2, i8* %task.state.parent, i8* undef, i8* undef) br label %cleanup @@ -133,11 +133,127 @@ cleanup: ; preds = %entry, %wakeup br label %suspend } -define void @doNothing(i8* %0, i8* %1) { +define void @doNothing(i8* %0, i8* %parentHandle) { entry: ret void } +define i8 @coroutineTailRegression(i8* %0, i8* %parentHandle) { +entry: + %a = alloca i8, align 1 + %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %coro.size = call i32 @llvm.coro.size.i32() + %coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef) + %coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc) + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef) + %task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) + store i8 5, i8* %a, align 1 + %coro.state.restore = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef) + call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef) + %val = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle) + br label %post.tail + +suspend: ; preds = %post.tail, %cleanup + %unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false) + ret i8 undef + +cleanup: ; preds = %post.tail + %coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state) + call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef) + br label %suspend + +post.tail: ; preds = %entry + %coro.save = call token @llvm.coro.save(i8* %coro.state) + %call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false) + switch i8 %call.suspend, label %suspend [ + i8 0, label %unreachable + i8 1, label %cleanup + ] + +unreachable: ; preds = %post.tail + unreachable +} + +define i8 @allocaTailRegression(i8* %0, i8* %parentHandle) { +entry: + %a = alloca i8, align 1 + %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %coro.size = call i32 @llvm.coro.size.i32() + %coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef) + %coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc) + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef) + %task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) + call void @sleep(i64 1000000, i8* undef, i8* %parentHandle) + %coro.save1 = call token @llvm.coro.save(i8* %coro.state) + %call.suspend2 = call i8 @llvm.coro.suspend(token %coro.save1, i1 false) + switch i8 %call.suspend2, label %suspend [ + i8 0, label %wakeup + i8 1, label %cleanup + ] + +wakeup: ; preds = %entry + store i8 5, i8* %a, align 1 + %1 = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef) + call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef) + %2 = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle) + br label %post.tail + +suspend: ; preds = %entry, %post.tail, %cleanup + %unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false) + ret i8 undef + +cleanup: ; preds = %entry, %post.tail + %coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state) + call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef) + br label %suspend + +post.tail: ; preds = %wakeup + %coro.save = call token @llvm.coro.save(i8* %coro.state) + %call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false) + switch i8 %call.suspend, label %suspend [ + i8 0, label %unreachable + i8 1, label %cleanup + ] + +unreachable: ; preds = %post.tail + unreachable +} + +define i8 @usePtr(i8* %0, i8* %1, i8* %parentHandle) { +entry: + %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %coro.size = call i32 @llvm.coro.size.i32() + %coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef) + %coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc) + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef) + %task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) + call void @sleep(i64 1000000, i8* undef, i8* %parentHandle) + %coro.save = call token @llvm.coro.save(i8* %coro.state) + %call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false) + switch i8 %call.suspend, label %suspend [ + i8 0, label %wakeup + i8 1, label %cleanup + ] + +wakeup: ; preds = %entry + %2 = load i8, i8* %0, align 1 + store i8 %2, i8* %task.retPtr, align 1 + call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef) + br label %cleanup + +suspend: ; preds = %entry, %cleanup + %unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false) + ret i8 undef + +cleanup: ; preds = %entry, %wakeup + %coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state) + call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef) + br label %suspend +} + define void @sleepGoroutine(i8* %0, i8* %parentHandle) { %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)