transform (coroutines): fix memory corruption for tail calls that reference stack allocations

This change fixes a bug in which `alloca` memory lifetimes would not extend past the suspend of an asynchronous tail call.
This would typically manifest as memory corruption, and could happen with or without normal suspending calls within the function.
Этот коммит содержится в:
Nia Waldvogel 2021-09-15 14:08:33 -04:00 коммит произвёл Ayke
родитель a116fd0dc6
коммит ecd8c2d902
3 изменённых файлов: 188 добавлений и 12 удалений

Просмотреть файл

@ -600,11 +600,11 @@ func (c *coroutineLoweringPass) lowerFuncsPass() {
continue
}
if len(fn.normalCalls) == 0 {
// No suspend points. Lower without turning it into a coroutine.
if len(fn.normalCalls) == 0 && fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() {
// No suspend points or stack allocations. Lower without turning it into a coroutine.
c.lowerFuncFast(fn)
} else {
// There are suspend points, so it is necessary to turn this into a coroutine.
// There are suspend points or stack allocations, so it is necessary to turn this into a coroutine.
c.lowerFuncCoro(fn)
}
}
@ -827,6 +827,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
}
// Lower returns.
var postTail llvm.BasicBlock
for _, ret := range fn.returns {
// Get terminator instruction.
terminator := ret.block.LastInstruction()
@ -886,10 +887,37 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
call.EraseFromParentAsInstruction()
}
// Replace terminator with branch to cleanup.
// Replace terminator with a branch to the exit.
var exit llvm.BasicBlock
if ret.kind == returnNormal || ret.kind == returnVoid || fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() {
// Exit through the cleanup path.
exit = cleanup
} else {
if postTail.IsNil() {
// Create a path with a suspend that never reawakens.
postTail = c.ctx.AddBasicBlock(fn.fn, "post.tail")
c.builder.SetInsertPointAtEnd(postTail)
// %coro.save = call token @llvm.coro.save(i8* %coro.state)
save := c.builder.CreateCall(c.coroSave, []llvm.Value{coroState}, "coro.save")
// %call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
// switch i8 %call.suspend, label %suspend [i8 0, label %unreachable
// i8 1, label %cleanup]
suspendValue := c.builder.CreateCall(c.coroSuspend, []llvm.Value{save, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "call.suspend")
sw := c.builder.CreateSwitch(suspendValue, suspend, 2)
unreachableBlock := c.ctx.AddBasicBlock(fn.fn, "unreachable")
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), unreachableBlock)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), cleanup)
c.builder.SetInsertPointAtEnd(unreachableBlock)
c.builder.CreateUnreachable()
}
// Exit through a permanent suspend.
exit = postTail
}
terminator.EraseFromParentAsInstruction()
c.builder.SetInsertPointAtEnd(ret.block)
c.builder.CreateBr(cleanup)
c.builder.CreateBr(exit)
}
// Lower regular calls.

34
transform/testdata/coroutines.ll предоставленный
Просмотреть файл

@ -86,11 +86,43 @@ entry:
}
; Normal function which should not be transformed.
define void @doNothing(i8*, i8*) {
define void @doNothing(i8*, i8* %parentHandle) {
entry:
ret void
}
; Regression test: ensure that a tail call does not destroy the frame while it is still in use.
; Previously, the tail-call lowering transform would branch to the cleanup block after the call to usePtr.
; This caused the lifetime of %a to be incorrectly reduced, and allowed the coroutine lowering transform to keep %a on the stack.
; usePtr would then read %a after a suspend, resulting in memory corruption.
; NOTE(review): this function has no suspending call before the tail call, so it looks like the
; case that previously was not lowered to a coroutine at all; this description and the one on
; @allocaTailRegression may be swapped - confirm.
define i8 @coroutineTailRegression(i8*, i8* %parentHandle) {
entry:
; Stack allocation whose address is handed to the tail call below.
%a = alloca i8
store i8 5, i8* %a
; The result of this call is returned directly, which makes it a tail call.
%val = call i8 @usePtr(i8* %a, i8* undef, i8* null)
ret i8 %val
}
; Regression test: ensure that stack allocations alive during a suspend end up on the heap.
; Previously this function was not transformed into a coroutine, which kept %a on the stack.
; usePtr would then read %a after a suspend, resulting in memory corruption.
; NOTE(review): this function contains a call to sleep before the tail call, so it looks like the
; case where the tail call used to branch to the cleanup block; this description and the one on
; @coroutineTailRegression may be swapped - confirm.
define i8 @allocaTailRegression(i8*, i8* %parentHandle) {
entry:
; Stack allocation that is written and passed to usePtr only after the call to sleep.
%a = alloca i8
call void @sleep(i64 1000000, i8* undef, i8* null)
store i8 5, i8* %a
; The result of this call is returned directly, which makes it a tail call.
%val = call i8 @usePtr(i8* %a, i8* undef, i8* null)
ret i8 %val
}
; usePtr uses a pointer after a suspend.
; The pointer is the first (unnamed) argument, %0. It is loaded only after the call to sleep,
; and the loaded byte is returned.
define i8 @usePtr(i8*, i8*, i8* %parentHandle) {
entry:
call void @sleep(i64 1000000, i8* undef, i8* null)
; Read through the caller-provided pointer after the sleep call.
%val = load i8, i8* %0
ret i8 %val
}
; Goroutine that sleeps and does nothing.
; Should be a void tail call.
define void @sleepGoroutine(i8*, i8* %parentHandle) {

128
transform/testdata/coroutines.out.ll предоставленный
Просмотреть файл

@ -45,7 +45,7 @@ entry:
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
%ret.ptr.bitcast = bitcast i8* %ret.ptr to i32*
store i32 %0, i32* %ret.ptr.bitcast
store i32 %0, i32* %ret.ptr.bitcast, align 4
call void @sleep(i64 %1, i8* undef, i8* %parentHandle)
ret i32 undef
}
@ -84,7 +84,7 @@ entry:
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
%ret.ptr.bitcast = bitcast i8* %ret.ptr to i32*
store i32 %0, i32* %ret.ptr.bitcast
store i32 %0, i32* %ret.ptr.bitcast, align 4
%ret.alternate = call i8* @runtime.alloc(i32 4, i8* undef, i8* undef)
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %ret.alternate, i8* undef, i8* undef)
%4 = call i32 @delayedValue(i32 %1, i64 %2, i8* undef, i8* %parentHandle)
@ -93,7 +93,7 @@ entry:
define i1 @coroutine(i32 %0, i64 %1, i8* %2, i8* %parentHandle) {
entry:
%call.return = alloca i32
%call.return = alloca i32, align 4
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
@ -116,10 +116,10 @@ entry:
]
wakeup: ; preds = %entry
%4 = load i32, i32* %call.return
%4 = load i32, i32* %call.return, align 4
call void @llvm.lifetime.end.p0i8(i64 4, i8* %call.return.bitcast)
%5 = icmp eq i32 %4, 0
store i1 %5, i1* %task.retPtr.bitcast
store i1 %5, i1* %task.retPtr.bitcast, align 1
call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current2, i8* %task.state.parent, i8* undef, i8* undef)
br label %cleanup
@ -133,11 +133,127 @@ cleanup: ; preds = %entry, %wakeup
br label %suspend
}
define void @doNothing(i8* %0, i8* %1) {
define void @doNothing(i8* %0, i8* %parentHandle) {
entry:
ret void
}
define i8 @coroutineTailRegression(i8* %0, i8* %parentHandle) {
entry:
%a = alloca i8, align 1
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
store i8 5, i8* %a, align 1
%coro.state.restore = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef)
%val = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle)
br label %post.tail
suspend: ; preds = %post.tail, %cleanup
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
ret i8 undef
cleanup: ; preds = %post.tail
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
br label %suspend
post.tail: ; preds = %entry
%coro.save = call token @llvm.coro.save(i8* %coro.state)
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
switch i8 %call.suspend, label %suspend [
i8 0, label %unreachable
i8 1, label %cleanup
]
unreachable: ; preds = %post.tail
unreachable
}
define i8 @allocaTailRegression(i8* %0, i8* %parentHandle) {
entry:
%a = alloca i8, align 1
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)
%coro.save1 = call token @llvm.coro.save(i8* %coro.state)
%call.suspend2 = call i8 @llvm.coro.suspend(token %coro.save1, i1 false)
switch i8 %call.suspend2, label %suspend [
i8 0, label %wakeup
i8 1, label %cleanup
]
wakeup: ; preds = %entry
store i8 5, i8* %a, align 1
%1 = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef)
%2 = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle)
br label %post.tail
suspend: ; preds = %entry, %post.tail, %cleanup
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
ret i8 undef
cleanup: ; preds = %entry, %post.tail
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
br label %suspend
post.tail: ; preds = %wakeup
%coro.save = call token @llvm.coro.save(i8* %coro.state)
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
switch i8 %call.suspend, label %suspend [
i8 0, label %unreachable
i8 1, label %cleanup
]
unreachable: ; preds = %post.tail
unreachable
}
define i8 @usePtr(i8* %0, i8* %1, i8* %parentHandle) {
entry:
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%coro.size = call i32 @llvm.coro.size.i32()
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)
%coro.save = call token @llvm.coro.save(i8* %coro.state)
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
switch i8 %call.suspend, label %suspend [
i8 0, label %wakeup
i8 1, label %cleanup
]
wakeup: ; preds = %entry
%2 = load i8, i8* %0, align 1
store i8 %2, i8* %task.retPtr, align 1
call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
br label %cleanup
suspend: ; preds = %entry, %cleanup
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
ret i8 undef
cleanup: ; preds = %entry, %wakeup
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
br label %suspend
}
define void @sleepGoroutine(i8* %0, i8* %parentHandle) {
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)