transform (coroutines): remove map iteration from coroutine lowering pass
The coroutine lowering pass had issues where it iterated over maps, sometimes resulting in non-deterministic output. This change removes many of the maps and ensures that the transformations are deterministic.
Этот коммит содержится в:
родитель
3862d6e8a2
коммит
bb5f7534e5
1 изменённых файлов: 101 добавлений и 53 удалений
|
@ -115,13 +115,22 @@ type asyncFunc struct {
|
|||
// callers is a set of all functions which call this async function.
|
||||
callers map[llvm.Value]struct{}
|
||||
|
||||
// returns is a map of terminal basic blocks to their return kinds.
|
||||
returns map[llvm.BasicBlock]returnKind
|
||||
// returns is a list of returns in the function, along with metadata.
|
||||
returns []asyncReturn
|
||||
|
||||
// calls is the set of all calls in the asyncFunc.
|
||||
// normalCalls is the set of all intermideate suspending calls in the asyncFunc.
|
||||
// tailCalls is the set of all tail calls in the asyncFunc.
|
||||
calls, normalCalls, tailCalls map[llvm.Value]struct{}
|
||||
// calls is a list of all calls in the asyncFunc.
|
||||
// normalCalls is a list of all intermideate suspending calls in the asyncFunc.
|
||||
// tailCalls is a list of all tail calls in the asyncFunc.
|
||||
calls, normalCalls, tailCalls []llvm.Value
|
||||
}
|
||||
|
||||
// asyncReturn is a metadata container for a return from an asynchronous function.
|
||||
type asyncReturn struct {
|
||||
// block is the basic block terminated by the return.
|
||||
block llvm.BasicBlock
|
||||
|
||||
// kind is the kind of the return.
|
||||
kind returnKind
|
||||
}
|
||||
|
||||
// coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler.
|
||||
|
@ -135,6 +144,8 @@ type coroutineLoweringPass struct {
|
|||
// The map keys are function pointers.
|
||||
asyncFuncs map[llvm.Value]*asyncFunc
|
||||
|
||||
asyncFuncsOrdered []*asyncFunc
|
||||
|
||||
// calls is a slice of all of the async calls in the module.
|
||||
calls []llvm.Value
|
||||
|
||||
|
@ -159,14 +170,15 @@ type coroutineLoweringPass struct {
|
|||
// A function is considered asynchronous if it calls an asynchronous function or intrinsic.
|
||||
func (c *coroutineLoweringPass) findAsyncFuncs() {
|
||||
asyncs := map[llvm.Value]*asyncFunc{}
|
||||
asyncsOrdered := []llvm.Value{}
|
||||
calls := []llvm.Value{}
|
||||
|
||||
// Use a breadth-first search to find all async functions.
|
||||
worklist := []llvm.Value{c.pause}
|
||||
for len(worklist) > 0 {
|
||||
// Pop a function off the worklist.
|
||||
fn := worklist[len(worklist)-1]
|
||||
worklist = worklist[:len(worklist)-1]
|
||||
fn := worklist[0]
|
||||
worklist = worklist[1:]
|
||||
|
||||
// Get task pointer argument.
|
||||
task := fn.LastParam()
|
||||
|
@ -204,6 +216,7 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
|
|||
// Mark the caller as async.
|
||||
// Use nil as a temporary value. It will be replaced later.
|
||||
asyncs[caller] = nil
|
||||
asyncsOrdered = append(asyncsOrdered, caller)
|
||||
|
||||
// Put the caller on the worklist.
|
||||
worklist = append(worklist, caller)
|
||||
|
@ -216,7 +229,19 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
|
|||
}
|
||||
}
|
||||
|
||||
// Flip the order of the async functions so that the top ones are lowered first.
|
||||
for i := 0; i < len(asyncsOrdered)/2; i++ {
|
||||
asyncsOrdered[i], asyncsOrdered[len(asyncsOrdered)-(i+1)] = asyncsOrdered[len(asyncsOrdered)-(i+1)], asyncsOrdered[i]
|
||||
}
|
||||
|
||||
// Map the elements of asyncsOrdered to *asyncFunc.
|
||||
asyncFuncsOrdered := make([]*asyncFunc, len(asyncsOrdered))
|
||||
for i, v := range asyncsOrdered {
|
||||
asyncFuncsOrdered[i] = asyncs[v]
|
||||
}
|
||||
|
||||
c.asyncFuncs = asyncs
|
||||
c.asyncFuncsOrdered = asyncFuncsOrdered
|
||||
c.calls = calls
|
||||
}
|
||||
|
||||
|
@ -386,7 +411,7 @@ func (c *coroutineLoweringPass) isAsyncCall(call llvm.Value) bool {
|
|||
|
||||
// analyzeFuncReturns analyzes and classifies the returns of a function.
|
||||
func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
|
||||
returns := map[llvm.BasicBlock]returnKind{}
|
||||
returns := []asyncReturn{}
|
||||
if fn.fn == c.pause {
|
||||
// Skip pause.
|
||||
fn.returns = returns
|
||||
|
@ -410,28 +435,49 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
|
|||
case !c.isAsyncCall(prev):
|
||||
// This is not any form of asynchronous tail call.
|
||||
if isVoid {
|
||||
returns[bb] = returnVoid
|
||||
returns = append(returns, asyncReturn{
|
||||
block: bb,
|
||||
kind: returnVoid,
|
||||
})
|
||||
} else {
|
||||
returns[bb] = returnNormal
|
||||
returns = append(returns, asyncReturn{
|
||||
block: bb,
|
||||
kind: returnNormal,
|
||||
})
|
||||
}
|
||||
case isVoid:
|
||||
if prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind {
|
||||
// This is a tail call to a void-returning function from a function with a void return.
|
||||
returns[bb] = returnVoidTail
|
||||
returns = append(returns, asyncReturn{
|
||||
block: bb,
|
||||
kind: returnVoidTail,
|
||||
})
|
||||
} else {
|
||||
// This is a tail call to a value-returning function from a function with a void return.
|
||||
// The returned value will be ditched.
|
||||
returns[bb] = returnDitchedTail
|
||||
returns = append(returns, asyncReturn{
|
||||
block: bb,
|
||||
kind: returnDitchedTail,
|
||||
})
|
||||
}
|
||||
case last.Operand(0) == prev:
|
||||
// This is a regular tail call. The return of the callee is returned to the parent.
|
||||
returns[bb] = returnTail
|
||||
returns = append(returns, asyncReturn{
|
||||
block: bb,
|
||||
kind: returnTail,
|
||||
})
|
||||
case prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind:
|
||||
// This is a tail call that returns a previous value after waiting on a void function.
|
||||
returns[bb] = returnDelayedValue
|
||||
returns = append(returns, asyncReturn{
|
||||
block: bb,
|
||||
kind: returnDelayedValue,
|
||||
})
|
||||
default:
|
||||
// This is a tail call that returns a value that is available before the function call.
|
||||
returns[bb] = returnAlternateTail
|
||||
returns = append(returns, asyncReturn{
|
||||
block: bb,
|
||||
kind: returnAlternateTail,
|
||||
})
|
||||
}
|
||||
case llvm.Unreachable:
|
||||
prev := llvm.PrevInstruction(last)
|
||||
|
@ -442,7 +488,10 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
|
|||
}
|
||||
|
||||
// This is an asyncnhronous tail call to function that does not return.
|
||||
returns[bb] = returnDeadTail
|
||||
returns = append(returns, asyncReturn{
|
||||
block: bb,
|
||||
kind: returnDeadTail,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -451,7 +500,7 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
|
|||
|
||||
// returnAnalysisPass runs an analysis pass which classifies the returns of all async functions.
|
||||
func (c *coroutineLoweringPass) returnAnalysisPass() {
|
||||
for _, async := range c.asyncFuncs {
|
||||
for _, async := range c.asyncFuncsOrdered {
|
||||
c.analyzeFuncReturns(async)
|
||||
}
|
||||
}
|
||||
|
@ -459,38 +508,37 @@ func (c *coroutineLoweringPass) returnAnalysisPass() {
|
|||
// categorizeCalls categorizes all asynchronous calls into regular vs. async and matches them to their callers.
|
||||
func (c *coroutineLoweringPass) categorizeCalls() {
|
||||
// Sort calls into their respective callers.
|
||||
for _, async := range c.asyncFuncs {
|
||||
async.calls = map[llvm.Value]struct{}{}
|
||||
}
|
||||
for _, call := range c.calls {
|
||||
c.asyncFuncs[call.InstructionParent().Parent()].calls[call] = struct{}{}
|
||||
caller := c.asyncFuncs[call.InstructionParent().Parent()]
|
||||
caller.calls = append(caller.calls, call)
|
||||
}
|
||||
|
||||
// Seperate regular and tail calls.
|
||||
for _, async := range c.asyncFuncs {
|
||||
// Find all tail calls (of any kind).
|
||||
for _, async := range c.asyncFuncsOrdered {
|
||||
// Search returns for tail calls.
|
||||
tails := map[llvm.Value]struct{}{}
|
||||
for ret, kind := range async.returns {
|
||||
switch kind {
|
||||
for _, ret := range async.returns {
|
||||
switch ret.kind {
|
||||
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
|
||||
// This is a tail return. The previous instruction is a tail call.
|
||||
tails[llvm.PrevInstruction(ret.LastInstruction())] = struct{}{}
|
||||
tails[llvm.PrevInstruction(ret.block.LastInstruction())] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Find all regular calls.
|
||||
regulars := map[llvm.Value]struct{}{}
|
||||
for call := range async.calls {
|
||||
// Seperate tail calls and regular calls.
|
||||
normalCalls, tailCalls := []llvm.Value{}, []llvm.Value{}
|
||||
for _, call := range async.calls {
|
||||
if _, ok := tails[call]; ok {
|
||||
// This is a tail call.
|
||||
continue
|
||||
tailCalls = append(tailCalls, call)
|
||||
} else {
|
||||
// This is a regular call.
|
||||
normalCalls = append(normalCalls, call)
|
||||
}
|
||||
|
||||
regulars[call] = struct{}{}
|
||||
}
|
||||
|
||||
async.tailCalls = tails
|
||||
async.normalCalls = regulars
|
||||
async.normalCalls = normalCalls
|
||||
async.tailCalls = tailCalls
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -513,8 +561,8 @@ func (c *coroutineLoweringPass) lowerFuncsPass() {
|
|||
}
|
||||
|
||||
func (async *asyncFunc) hasValueStoreReturn() bool {
|
||||
for _, kind := range async.returns {
|
||||
switch kind {
|
||||
for _, ret := range async.returns {
|
||||
switch ret.kind {
|
||||
case returnNormal, returnAlternateTail, returnDelayedValue:
|
||||
return true
|
||||
}
|
||||
|
@ -550,18 +598,18 @@ func (c *coroutineLoweringPass) lowerFuncFast(fn *asyncFunc) {
|
|||
}
|
||||
|
||||
// Lower returns.
|
||||
for ret, kind := range fn.returns {
|
||||
for _, ret := range fn.returns {
|
||||
// Get terminator.
|
||||
terminator := ret.LastInstruction()
|
||||
terminator := ret.block.LastInstruction()
|
||||
|
||||
// Get tail call if applicable.
|
||||
var call llvm.Value
|
||||
switch kind {
|
||||
switch ret.kind {
|
||||
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
|
||||
call = llvm.PrevInstruction(terminator)
|
||||
}
|
||||
|
||||
switch kind {
|
||||
switch ret.kind {
|
||||
case returnNormal:
|
||||
c.builder.SetInsertPointBefore(terminator)
|
||||
|
||||
|
@ -718,8 +766,8 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
|
|||
c.builder.CreateBr(suspend)
|
||||
|
||||
// Restore old state before tail calls.
|
||||
for call := range fn.tailCalls {
|
||||
if fn.returns[call.InstructionParent()] == returnDeadTail {
|
||||
for _, call := range fn.tailCalls {
|
||||
if !llvm.NextInstruction(call).IsAUnreachableInst().IsNil() {
|
||||
// Callee never returns, so the state restore is ineffectual.
|
||||
continue
|
||||
}
|
||||
|
@ -729,18 +777,18 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
|
|||
}
|
||||
|
||||
// Lower returns.
|
||||
for ret, kind := range fn.returns {
|
||||
for _, ret := range fn.returns {
|
||||
// Get terminator instruction.
|
||||
terminator := ret.LastInstruction()
|
||||
terminator := ret.block.LastInstruction()
|
||||
|
||||
// Get tail call if applicable.
|
||||
var call llvm.Value
|
||||
switch kind {
|
||||
switch ret.kind {
|
||||
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
|
||||
call = llvm.PrevInstruction(terminator)
|
||||
}
|
||||
|
||||
switch kind {
|
||||
switch ret.kind {
|
||||
case returnNormal:
|
||||
c.builder.SetInsertPointBefore(terminator)
|
||||
|
||||
|
@ -760,7 +808,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
|
|||
c.builder.SetInsertPointBefore(call)
|
||||
|
||||
// Store return value.
|
||||
c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr)
|
||||
c.builder.CreateStore(terminator.Operand(0), retPtr)
|
||||
|
||||
// Heap-allocate a return buffer for the discarded return.
|
||||
alternateBuf := c.heapAlloc(call.Type(), "ret.alternate")
|
||||
|
@ -775,7 +823,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
|
|||
c.builder.SetInsertPointBefore(call)
|
||||
|
||||
// Store return value.
|
||||
c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr)
|
||||
c.builder.CreateStore(terminator.Operand(0), retPtr)
|
||||
}
|
||||
|
||||
// Delete call if it is a pause, because it has already been lowered.
|
||||
|
@ -785,12 +833,12 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
|
|||
|
||||
// Replace terminator with branch to cleanup.
|
||||
terminator.EraseFromParentAsInstruction()
|
||||
c.builder.SetInsertPointAtEnd(ret)
|
||||
c.builder.SetInsertPointAtEnd(ret.block)
|
||||
c.builder.CreateBr(cleanup)
|
||||
}
|
||||
|
||||
// Lower regular calls.
|
||||
for call := range fn.normalCalls {
|
||||
for _, call := range fn.normalCalls {
|
||||
// Lower return value of call.
|
||||
c.lowerCallReturn(fn, call)
|
||||
|
||||
|
@ -882,8 +930,8 @@ func (c *coroutineLoweringPass) lowerStart(start llvm.Value) {
|
|||
} else {
|
||||
// Check for any undead returns.
|
||||
var undead bool
|
||||
for _, kind := range c.asyncFuncs[fn].returns {
|
||||
if kind != returnDeadTail {
|
||||
for _, ret := range c.asyncFuncs[fn].returns {
|
||||
if ret.kind != returnDeadTail {
|
||||
// This return results in a value being eventually stored.
|
||||
undead = true
|
||||
break
|
||||
|
|
Загрузка…
Создание таблицы
Сослаться в новой задаче