transform (coroutines): remove map iteration from coroutine lowering pass

The coroutine lowering pass had issues where it iterated over maps, sometimes resulting in non-deterministic output.
This change removes many of the maps and ensures that the transformations are deterministic.
Этот коммит содержится в:
Jaden Weiss 2020-04-10 09:51:08 -04:00 коммит произвёл Ayke
родитель 3862d6e8a2
коммит bb5f7534e5

Просмотреть файл

@ -115,13 +115,22 @@ type asyncFunc struct {
// callers is a set of all functions which call this async function. // callers is a set of all functions which call this async function.
callers map[llvm.Value]struct{} callers map[llvm.Value]struct{}
// returns is a map of terminal basic blocks to their return kinds. // returns is a list of returns in the function, along with metadata.
returns map[llvm.BasicBlock]returnKind returns []asyncReturn
// calls is the set of all calls in the asyncFunc. // calls is a list of all calls in the asyncFunc.
// normalCalls is the set of all intermideate suspending calls in the asyncFunc. // normalCalls is a list of all intermideate suspending calls in the asyncFunc.
// tailCalls is the set of all tail calls in the asyncFunc. // tailCalls is a list of all tail calls in the asyncFunc.
calls, normalCalls, tailCalls map[llvm.Value]struct{} calls, normalCalls, tailCalls []llvm.Value
}
// asyncReturn is a metadata container for a return from an asynchronous function.
type asyncReturn struct {
// block is the basic block terminated by the return.
block llvm.BasicBlock
// kind is the kind of the return.
kind returnKind
} }
// coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler. // coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler.
@ -135,6 +144,8 @@ type coroutineLoweringPass struct {
// The map keys are function pointers. // The map keys are function pointers.
asyncFuncs map[llvm.Value]*asyncFunc asyncFuncs map[llvm.Value]*asyncFunc
asyncFuncsOrdered []*asyncFunc
// calls is a slice of all of the async calls in the module. // calls is a slice of all of the async calls in the module.
calls []llvm.Value calls []llvm.Value
@ -159,14 +170,15 @@ type coroutineLoweringPass struct {
// A function is considered asynchronous if it calls an asynchronous function or intrinsic. // A function is considered asynchronous if it calls an asynchronous function or intrinsic.
func (c *coroutineLoweringPass) findAsyncFuncs() { func (c *coroutineLoweringPass) findAsyncFuncs() {
asyncs := map[llvm.Value]*asyncFunc{} asyncs := map[llvm.Value]*asyncFunc{}
asyncsOrdered := []llvm.Value{}
calls := []llvm.Value{} calls := []llvm.Value{}
// Use a breadth-first search to find all async functions. // Use a breadth-first search to find all async functions.
worklist := []llvm.Value{c.pause} worklist := []llvm.Value{c.pause}
for len(worklist) > 0 { for len(worklist) > 0 {
// Pop a function off the worklist. // Pop a function off the worklist.
fn := worklist[len(worklist)-1] fn := worklist[0]
worklist = worklist[:len(worklist)-1] worklist = worklist[1:]
// Get task pointer argument. // Get task pointer argument.
task := fn.LastParam() task := fn.LastParam()
@ -204,6 +216,7 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
// Mark the caller as async. // Mark the caller as async.
// Use nil as a temporary value. It will be replaced later. // Use nil as a temporary value. It will be replaced later.
asyncs[caller] = nil asyncs[caller] = nil
asyncsOrdered = append(asyncsOrdered, caller)
// Put the caller on the worklist. // Put the caller on the worklist.
worklist = append(worklist, caller) worklist = append(worklist, caller)
@ -216,7 +229,19 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
} }
} }
// Flip the order of the async functions so that the top ones are lowered first.
for i := 0; i < len(asyncsOrdered)/2; i++ {
asyncsOrdered[i], asyncsOrdered[len(asyncsOrdered)-(i+1)] = asyncsOrdered[len(asyncsOrdered)-(i+1)], asyncsOrdered[i]
}
// Map the elements of asyncsOrdered to *asyncFunc.
asyncFuncsOrdered := make([]*asyncFunc, len(asyncsOrdered))
for i, v := range asyncsOrdered {
asyncFuncsOrdered[i] = asyncs[v]
}
c.asyncFuncs = asyncs c.asyncFuncs = asyncs
c.asyncFuncsOrdered = asyncFuncsOrdered
c.calls = calls c.calls = calls
} }
@ -386,7 +411,7 @@ func (c *coroutineLoweringPass) isAsyncCall(call llvm.Value) bool {
// analyzeFuncReturns analyzes and classifies the returns of a function. // analyzeFuncReturns analyzes and classifies the returns of a function.
func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) { func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
returns := map[llvm.BasicBlock]returnKind{} returns := []asyncReturn{}
if fn.fn == c.pause { if fn.fn == c.pause {
// Skip pause. // Skip pause.
fn.returns = returns fn.returns = returns
@ -410,28 +435,49 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
case !c.isAsyncCall(prev): case !c.isAsyncCall(prev):
// This is not any form of asynchronous tail call. // This is not any form of asynchronous tail call.
if isVoid { if isVoid {
returns[bb] = returnVoid returns = append(returns, asyncReturn{
block: bb,
kind: returnVoid,
})
} else { } else {
returns[bb] = returnNormal returns = append(returns, asyncReturn{
block: bb,
kind: returnNormal,
})
} }
case isVoid: case isVoid:
if prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind { if prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind {
// This is a tail call to a void-returning function from a function with a void return. // This is a tail call to a void-returning function from a function with a void return.
returns[bb] = returnVoidTail returns = append(returns, asyncReturn{
block: bb,
kind: returnVoidTail,
})
} else { } else {
// This is a tail call to a value-returning function from a function with a void return. // This is a tail call to a value-returning function from a function with a void return.
// The returned value will be ditched. // The returned value will be ditched.
returns[bb] = returnDitchedTail returns = append(returns, asyncReturn{
block: bb,
kind: returnDitchedTail,
})
} }
case last.Operand(0) == prev: case last.Operand(0) == prev:
// This is a regular tail call. The return of the callee is returned to the parent. // This is a regular tail call. The return of the callee is returned to the parent.
returns[bb] = returnTail returns = append(returns, asyncReturn{
block: bb,
kind: returnTail,
})
case prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind: case prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind:
// This is a tail call that returns a previous value after waiting on a void function. // This is a tail call that returns a previous value after waiting on a void function.
returns[bb] = returnDelayedValue returns = append(returns, asyncReturn{
block: bb,
kind: returnDelayedValue,
})
default: default:
// This is a tail call that returns a value that is available before the function call. // This is a tail call that returns a value that is available before the function call.
returns[bb] = returnAlternateTail returns = append(returns, asyncReturn{
block: bb,
kind: returnAlternateTail,
})
} }
case llvm.Unreachable: case llvm.Unreachable:
prev := llvm.PrevInstruction(last) prev := llvm.PrevInstruction(last)
@ -442,7 +488,10 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
} }
// This is an asyncnhronous tail call to function that does not return. // This is an asyncnhronous tail call to function that does not return.
returns[bb] = returnDeadTail returns = append(returns, asyncReturn{
block: bb,
kind: returnDeadTail,
})
} }
} }
@ -451,7 +500,7 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
// returnAnalysisPass runs an analysis pass which classifies the returns of all async functions. // returnAnalysisPass runs an analysis pass which classifies the returns of all async functions.
func (c *coroutineLoweringPass) returnAnalysisPass() { func (c *coroutineLoweringPass) returnAnalysisPass() {
for _, async := range c.asyncFuncs { for _, async := range c.asyncFuncsOrdered {
c.analyzeFuncReturns(async) c.analyzeFuncReturns(async)
} }
} }
@ -459,38 +508,37 @@ func (c *coroutineLoweringPass) returnAnalysisPass() {
// categorizeCalls categorizes all asynchronous calls into regular vs. async and matches them to their callers. // categorizeCalls categorizes all asynchronous calls into regular vs. async and matches them to their callers.
func (c *coroutineLoweringPass) categorizeCalls() { func (c *coroutineLoweringPass) categorizeCalls() {
// Sort calls into their respective callers. // Sort calls into their respective callers.
for _, async := range c.asyncFuncs {
async.calls = map[llvm.Value]struct{}{}
}
for _, call := range c.calls { for _, call := range c.calls {
c.asyncFuncs[call.InstructionParent().Parent()].calls[call] = struct{}{} caller := c.asyncFuncs[call.InstructionParent().Parent()]
caller.calls = append(caller.calls, call)
} }
// Seperate regular and tail calls. // Seperate regular and tail calls.
for _, async := range c.asyncFuncs { for _, async := range c.asyncFuncsOrdered {
// Find all tail calls (of any kind). // Search returns for tail calls.
tails := map[llvm.Value]struct{}{} tails := map[llvm.Value]struct{}{}
for ret, kind := range async.returns { for _, ret := range async.returns {
switch kind { switch ret.kind {
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue: case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
// This is a tail return. The previous instruction is a tail call. // This is a tail return. The previous instruction is a tail call.
tails[llvm.PrevInstruction(ret.LastInstruction())] = struct{}{} tails[llvm.PrevInstruction(ret.block.LastInstruction())] = struct{}{}
} }
} }
// Find all regular calls. // Seperate tail calls and regular calls.
regulars := map[llvm.Value]struct{}{} normalCalls, tailCalls := []llvm.Value{}, []llvm.Value{}
for call := range async.calls { for _, call := range async.calls {
if _, ok := tails[call]; ok { if _, ok := tails[call]; ok {
// This is a tail call. // This is a tail call.
continue tailCalls = append(tailCalls, call)
} else {
// This is a regular call.
normalCalls = append(normalCalls, call)
}
} }
regulars[call] = struct{}{} async.normalCalls = normalCalls
} async.tailCalls = tailCalls
async.tailCalls = tails
async.normalCalls = regulars
} }
} }
@ -513,8 +561,8 @@ func (c *coroutineLoweringPass) lowerFuncsPass() {
} }
func (async *asyncFunc) hasValueStoreReturn() bool { func (async *asyncFunc) hasValueStoreReturn() bool {
for _, kind := range async.returns { for _, ret := range async.returns {
switch kind { switch ret.kind {
case returnNormal, returnAlternateTail, returnDelayedValue: case returnNormal, returnAlternateTail, returnDelayedValue:
return true return true
} }
@ -550,18 +598,18 @@ func (c *coroutineLoweringPass) lowerFuncFast(fn *asyncFunc) {
} }
// Lower returns. // Lower returns.
for ret, kind := range fn.returns { for _, ret := range fn.returns {
// Get terminator. // Get terminator.
terminator := ret.LastInstruction() terminator := ret.block.LastInstruction()
// Get tail call if applicable. // Get tail call if applicable.
var call llvm.Value var call llvm.Value
switch kind { switch ret.kind {
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue: case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
call = llvm.PrevInstruction(terminator) call = llvm.PrevInstruction(terminator)
} }
switch kind { switch ret.kind {
case returnNormal: case returnNormal:
c.builder.SetInsertPointBefore(terminator) c.builder.SetInsertPointBefore(terminator)
@ -718,8 +766,8 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
c.builder.CreateBr(suspend) c.builder.CreateBr(suspend)
// Restore old state before tail calls. // Restore old state before tail calls.
for call := range fn.tailCalls { for _, call := range fn.tailCalls {
if fn.returns[call.InstructionParent()] == returnDeadTail { if !llvm.NextInstruction(call).IsAUnreachableInst().IsNil() {
// Callee never returns, so the state restore is ineffectual. // Callee never returns, so the state restore is ineffectual.
continue continue
} }
@ -729,18 +777,18 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
} }
// Lower returns. // Lower returns.
for ret, kind := range fn.returns { for _, ret := range fn.returns {
// Get terminator instruction. // Get terminator instruction.
terminator := ret.LastInstruction() terminator := ret.block.LastInstruction()
// Get tail call if applicable. // Get tail call if applicable.
var call llvm.Value var call llvm.Value
switch kind { switch ret.kind {
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue: case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
call = llvm.PrevInstruction(terminator) call = llvm.PrevInstruction(terminator)
} }
switch kind { switch ret.kind {
case returnNormal: case returnNormal:
c.builder.SetInsertPointBefore(terminator) c.builder.SetInsertPointBefore(terminator)
@ -760,7 +808,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
c.builder.SetInsertPointBefore(call) c.builder.SetInsertPointBefore(call)
// Store return value. // Store return value.
c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr) c.builder.CreateStore(terminator.Operand(0), retPtr)
// Heap-allocate a return buffer for the discarded return. // Heap-allocate a return buffer for the discarded return.
alternateBuf := c.heapAlloc(call.Type(), "ret.alternate") alternateBuf := c.heapAlloc(call.Type(), "ret.alternate")
@ -775,7 +823,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
c.builder.SetInsertPointBefore(call) c.builder.SetInsertPointBefore(call)
// Store return value. // Store return value.
c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr) c.builder.CreateStore(terminator.Operand(0), retPtr)
} }
// Delete call if it is a pause, because it has already been lowered. // Delete call if it is a pause, because it has already been lowered.
@ -785,12 +833,12 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
// Replace terminator with branch to cleanup. // Replace terminator with branch to cleanup.
terminator.EraseFromParentAsInstruction() terminator.EraseFromParentAsInstruction()
c.builder.SetInsertPointAtEnd(ret) c.builder.SetInsertPointAtEnd(ret.block)
c.builder.CreateBr(cleanup) c.builder.CreateBr(cleanup)
} }
// Lower regular calls. // Lower regular calls.
for call := range fn.normalCalls { for _, call := range fn.normalCalls {
// Lower return value of call. // Lower return value of call.
c.lowerCallReturn(fn, call) c.lowerCallReturn(fn, call)
@ -882,8 +930,8 @@ func (c *coroutineLoweringPass) lowerStart(start llvm.Value) {
} else { } else {
// Check for any undead returns. // Check for any undead returns.
var undead bool var undead bool
for _, kind := range c.asyncFuncs[fn].returns { for _, ret := range c.asyncFuncs[fn].returns {
if kind != returnDeadTail { if ret.kind != returnDeadTail {
// This return results in a value being eventually stored. // This return results in a value being eventually stored.
undead = true undead = true
break break