compiler: add support for 'go' on func values

This commit allows starting a new goroutine directly from a func value,
not just when the static callee is known.

This is necessary to support the whole time package, not just the
commonly used subset that was compiled with the SimpleDCE pass enabled.
Этот коммит содержится в:
Ayke van Laethem 2019-08-15 15:52:26 +02:00 коммит произвёл Ron Evans
родитель e4fc3bb66a
коммит bbc3046687
6 изменённых файлов: 238 добавлений и 131 удалений

Просмотреть файл

@ -1060,32 +1060,47 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) {
case *ssa.Defer:
c.emitDefer(frame, instr)
case *ssa.Go:
if instr.Call.IsInvoke() {
c.addError(instr.Pos(), "todo: go on method receiver")
return
}
callee := instr.Call.StaticCallee()
if callee == nil {
c.addError(instr.Pos(), "todo: go on non-direct function (function pointer, etc.)")
return
}
calleeFn := c.ir.GetFunction(callee)
// Get all function parameters to pass to the goroutine.
var params []llvm.Value
for _, param := range instr.Call.Args {
params = append(params, c.getValue(frame, param))
}
if !calleeFn.IsExported() && c.selectScheduler() != "tasks" {
// For coroutine scheduling, this is only required when calling an
// external function.
// For tasks, because all params are stored in a single object, no
// unnecessary parameters should be stored anyway.
params = append(params, llvm.Undef(c.i8ptrType)) // context parameter
params = append(params, llvm.Undef(c.i8ptrType)) // parent coroutine handle
}
c.emitStartGoroutine(calleeFn.LLVMFn, params)
// Start a new goroutine.
if callee := instr.Call.StaticCallee(); callee != nil {
// Static callee is known. This makes it easier to start a new
// goroutine.
calleeFn := c.ir.GetFunction(callee)
if !calleeFn.IsExported() && c.selectScheduler() != "tasks" {
// For coroutine scheduling, this is only required when calling
// an external function.
// For tasks, because all params are stored in a single object,
// no unnecessary parameters should be stored anyway.
params = append(params, llvm.Undef(c.i8ptrType)) // context parameter
params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle
}
c.emitStartGoroutine(calleeFn.LLVMFn, params)
} else if !instr.Call.IsInvoke() {
// This is a function pointer.
// At the moment, two extra params are passed to the newly started
// goroutine:
// * The function context, for closures.
// * The parent handle (for coroutines) or the function pointer
// itself (for tasks).
funcPtr, context := c.decodeFuncValue(c.getValue(frame, instr.Call.Value), instr.Call.Value.Type().(*types.Signature))
params = append(params, context) // context parameter
switch c.selectScheduler() {
case "coroutines":
params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle
case "tasks":
params = append(params, funcPtr)
default:
panic("unknown scheduler type")
}
c.emitStartGoroutine(funcPtr, params)
} else {
c.addError(instr.Pos(), "todo: go on interface call")
}
case *ssa.If:
cond := c.getValue(frame, instr.Cond)
block := instr.Block()

Просмотреть файл

@ -152,11 +152,11 @@ func (c *Compiler) LowerFuncValues() {
// There are multiple functions used in a func value that
// implement this signature.
// What we'll do is transform the following:
// rawPtr := runtime.getFuncPtr(fn)
// if func.rawPtr == nil {
// rawPtr := runtime.getFuncPtr(func.ptr)
// if rawPtr == nil {
// runtime.nilPanic()
// }
// result := func.rawPtr(...args, func.context)
// result := rawPtr(...args, func.context)
// into this:
// if false {
// runtime.nilPanic()
@ -175,95 +175,111 @@ func (c *Compiler) LowerFuncValues() {
// Remove some casts, checks, and the old call which we're going
// to replace.
var funcCall llvm.Value
for _, inttoptr := range getUses(getFuncPtrCall) {
if inttoptr.IsAIntToPtrInst().IsNil() {
for _, callIntPtr := range getUses(getFuncPtrCall) {
if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "runtime.makeGoroutine" {
for _, inttoptr := range getUses(callIntPtr) {
if inttoptr.IsAIntToPtrInst().IsNil() {
panic("expected a inttoptr")
}
for _, use := range getUses(inttoptr) {
c.addFuncLoweringSwitch(funcID, use, c.emitStartGoroutine, functions)
use.EraseFromParentAsInstruction()
}
inttoptr.EraseFromParentAsInstruction()
}
callIntPtr.EraseFromParentAsInstruction()
continue
}
if callIntPtr.IsAIntToPtrInst().IsNil() {
panic("expected inttoptr")
}
for _, ptrUse := range getUses(inttoptr) {
for _, ptrUse := range getUses(callIntPtr) {
if !ptrUse.IsABitCastInst().IsNil() {
for _, bitcastUse := range getUses(ptrUse) {
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().Name() != "runtime.isnil" {
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().IsAFunction().IsNil() {
panic("expected a call instruction")
}
switch bitcastUse.CalledValue().Name() {
case "runtime.isnil":
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
bitcastUse.EraseFromParentAsInstruction()
default:
panic("expected a call to runtime.isnil")
}
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
bitcastUse.EraseFromParentAsInstruction()
}
ptrUse.EraseFromParentAsInstruction()
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == inttoptr {
if !funcCall.IsNil() {
panic("multiple calls on a single runtime.getFuncPtr")
}
funcCall = ptrUse
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr {
c.addFuncLoweringSwitch(funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
return c.builder.CreateCall(funcPtr, params, "")
}, functions)
} else {
panic("unexpected getFuncPtrCall")
}
ptrUse.EraseFromParentAsInstruction()
}
}
if funcCall.IsNil() {
panic("expected exactly one call use of a runtime.getFuncPtr")
}
// The block that cannot be reached with correct funcValues (to
// help the optimizer).
c.builder.SetInsertPointBefore(funcCall)
defaultBlock := llvm.AddBasicBlock(funcCall.InstructionParent().Parent(), "func.default")
c.builder.SetInsertPointAtEnd(defaultBlock)
c.builder.CreateUnreachable()
// Create the switch.
c.builder.SetInsertPointBefore(funcCall)
sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
// Split right after the switch. We will need to insert a few
// basic blocks in this gap.
nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
// The 0 case, which is actually a nil check.
nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil")
c.builder.SetInsertPointAtEnd(nilBlock)
c.createRuntimeCall("nilPanic", nil, "")
c.builder.CreateUnreachable()
sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
// Gather the list of parameters for every call we're going to
// make.
callParams := make([]llvm.Value, funcCall.OperandsCount()-1)
for i := range callParams {
callParams[i] = funcCall.Operand(i)
}
// If the call produces a value, we need to get it using a PHI
// node.
phiBlocks := make([]llvm.BasicBlock, len(functions))
phiValues := make([]llvm.Value, len(functions))
for i, fn := range functions {
// Insert a switch case.
bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
c.builder.SetInsertPointAtEnd(bb)
result := c.builder.CreateCall(fn.funcPtr, callParams, "")
c.builder.CreateBr(nextBlock)
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
phiBlocks[i] = bb
phiValues[i] = result
}
// Create the PHI node so that the call result flows into the
// next block (after the split). This is only necessary when the
// call produced a value.
if funcCall.Type().TypeKind() != llvm.VoidTypeKind {
c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
phi := c.builder.CreatePHI(funcCall.Type(), "")
phi.AddIncoming(phiValues, phiBlocks)
funcCall.ReplaceAllUsesWith(phi)
}
// Finally, remove the old instructions.
funcCall.EraseFromParentAsInstruction()
for _, inttoptr := range getUses(getFuncPtrCall) {
inttoptr.EraseFromParentAsInstruction()
callIntPtr.EraseFromParentAsInstruction()
}
getFuncPtrCall.EraseFromParentAsInstruction()
}
}
}
}
// addFuncLoweringSwitch creates a new switch on a function ID and inserts calls
// to the newly created direct calls. The funcID is the number to switch on,
// call is the call instruction to replace, and createCall is the callback that
// actually creates the new call. By changing createCall to something other than
// c.builder.CreateCall, instead of calling a function it can start a new
// goroutine for example.
func (c *Compiler) addFuncLoweringSwitch(funcID, call llvm.Value, createCall func(funcPtr llvm.Value, params []llvm.Value) llvm.Value, functions funcWithUsesList) {
// The block that cannot be reached with correct funcValues (to help the
// optimizer).
c.builder.SetInsertPointBefore(call)
defaultBlock := llvm.AddBasicBlock(call.InstructionParent().Parent(), "func.default")
c.builder.SetInsertPointAtEnd(defaultBlock)
c.builder.CreateUnreachable()
// Create the switch.
c.builder.SetInsertPointBefore(call)
sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
// Split right after the switch. We will need to insert a few basic blocks
// in this gap.
nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
// The 0 case, which is actually a nil check.
nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil")
c.builder.SetInsertPointAtEnd(nilBlock)
c.createRuntimeCall("nilPanic", nil, "")
c.builder.CreateUnreachable()
sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
// Gather the list of parameters for every call we're going to make.
callParams := make([]llvm.Value, call.OperandsCount()-1)
for i := range callParams {
callParams[i] = call.Operand(i)
}
// If the call produces a value, we need to get it using a PHI
// node.
phiBlocks := make([]llvm.BasicBlock, len(functions))
phiValues := make([]llvm.Value, len(functions))
for i, fn := range functions {
// Insert a switch case.
bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
c.builder.SetInsertPointAtEnd(bb)
result := createCall(fn.funcPtr, callParams)
c.builder.CreateBr(nextBlock)
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
phiBlocks[i] = bb
phiValues[i] = result
}
// Create the PHI node so that the call result flows into the
// next block (after the split). This is only necessary when the
// call produced a value.
if call.Type().TypeKind() != llvm.VoidTypeKind {
c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
phi := c.builder.CreatePHI(call.Type(), "")
phi.AddIncoming(phiValues, phiBlocks)
call.ReplaceAllUsesWith(phi)
}
}

Просмотреть файл

@ -32,10 +32,15 @@ const (
// funcImplementation picks an appropriate func value implementation for the
// target.
func (c *Compiler) funcImplementation() funcValueImplementation {
if c.GOARCH == "wasm" {
// Always pick the switch implementation, as it allows the use of blocking
// inside a function that is used as a func value.
switch c.selectScheduler() {
case "coroutines":
return funcValueSwitch
} else {
case "tasks":
return funcValueDoubleword
default:
panic("unknown scheduler type")
}
}

Просмотреть файл

@ -18,14 +18,12 @@ func (c *Compiler) emitStartGoroutine(funcPtr llvm.Value, params []llvm.Value) l
calleeValue := c.createGoroutineStartWrapper(funcPtr)
c.createRuntimeCall("startGoroutine", []llvm.Value{calleeValue, paramBundle}, "")
case "coroutines":
// Mark this function as a 'go' invocation and break invalid
// interprocedural optimizations. For example, heap-to-stack
// transformations are not sound as goroutines can outlive their parent.
calleeType := funcPtr.Type()
// We roundtrip through runtime.makeGoroutine as a signal (to find these
// calls) and to break any optimizations LLVM will try to do: they are
// invalid if we called this as a regular function to be updated later.
calleeValue := c.builder.CreatePtrToInt(funcPtr, c.uintptrType, "")
calleeValue = c.createRuntimeCall("makeGoroutine", []llvm.Value{calleeValue}, "")
calleeValue = c.builder.CreateIntToPtr(calleeValue, calleeType, "")
calleeValue = c.builder.CreateIntToPtr(calleeValue, funcPtr.Type(), "")
c.createCall(calleeValue, params, "")
default:
panic("unreachable")
@ -52,32 +50,86 @@ func (c *Compiler) emitStartGoroutine(funcPtr llvm.Value, params []llvm.Value) l
// ignores the return value because newly started goroutines do not have a
// return value.
func (c *Compiler) createGoroutineStartWrapper(fn llvm.Value) llvm.Value {
if fn.IsAFunction().IsNil() {
panic("todo: goroutine start wrapper for func value")
var wrapper llvm.Value
if !fn.IsAFunction().IsNil() {
// See whether this wrapper has already been created. If so, return it.
name := fn.Name()
wrapper = c.mod.NamedFunction(name + "$gowrapper")
if !wrapper.IsNil() {
return c.builder.CreateIntToPtr(wrapper, c.uintptrType, "")
}
// Save the current position in the IR builder.
currentBlock := c.builder.GetInsertBlock()
defer c.builder.SetInsertPointAtEnd(currentBlock)
// Create the wrapper.
wrapperType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{c.i8ptrType}, false)
wrapper = llvm.AddFunction(c.mod, name+"$gowrapper", wrapperType)
wrapper.SetLinkage(llvm.PrivateLinkage)
wrapper.SetUnnamedAddr(true)
entry := llvm.AddBasicBlock(wrapper, "entry")
c.builder.SetInsertPointAtEnd(entry)
// Create the list of params for the call.
paramTypes := fn.Type().ElementType().ParamTypes()
params := c.emitPointerUnpack(wrapper.Param(0), paramTypes[:len(paramTypes)-2])
params = append(params, llvm.Undef(c.i8ptrType), llvm.ConstPointerNull(c.i8ptrType))
// Create the call.
c.builder.CreateCall(fn, params, "")
} else {
// For a function pointer like this:
//
// var funcPtr func(x, y int) int
//
// A wrapper like the following is created:
//
// func .gowrapper(ptr *unsafe.Pointer) {
// args := (*struct{
// x, y int
// fn func(x, y int) int
// })(ptr)
// args.fn(x, y)
// }
//
// With a bit of luck, identical wrapper functions like these can be
// merged into one.
// Save the current position in the IR builder.
currentBlock := c.builder.GetInsertBlock()
defer c.builder.SetInsertPointAtEnd(currentBlock)
// Create the wrapper.
wrapperType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{c.i8ptrType}, false)
wrapper = llvm.AddFunction(c.mod, ".gowrapper", wrapperType)
wrapper.SetLinkage(llvm.InternalLinkage)
wrapper.SetUnnamedAddr(true)
entry := llvm.AddBasicBlock(wrapper, "entry")
c.builder.SetInsertPointAtEnd(entry)
// Get the list of parameters, with the extra parameters at the end.
paramTypes := fn.Type().ElementType().ParamTypes()
paramTypes[len(paramTypes)-1] = fn.Type() // the last element is the function pointer
params := c.emitPointerUnpack(wrapper.Param(0), paramTypes)
// Get the function pointer.
fnPtr := params[len(params)-1]
// Ignore the last param, which isn't used anymore.
// TODO: avoid this extra "parent handle" parameter in most functions.
params[len(params)-1] = llvm.Undef(c.i8ptrType)
// Create the call.
c.builder.CreateCall(fnPtr, params, "")
}
// See whether this wrapper has already been created. If so, return it.
name := fn.Name()
wrapper := c.mod.NamedFunction(name + "$gowrapper")
if !wrapper.IsNil() {
return c.builder.CreateIntToPtr(wrapper, c.uintptrType, "")
}
// Save the current position in the IR builder.
currentBlock := c.builder.GetInsertBlock()
defer c.builder.SetInsertPointAtEnd(currentBlock)
// Create the wrapper.
wrapperType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{c.i8ptrType}, false)
wrapper = llvm.AddFunction(c.mod, name+"$gowrapper", wrapperType)
wrapper.SetLinkage(llvm.PrivateLinkage)
wrapper.SetUnnamedAddr(true)
entry := llvm.AddBasicBlock(wrapper, "entry")
c.builder.SetInsertPointAtEnd(entry)
paramTypes := fn.Type().ElementType().ParamTypes()
params := c.emitPointerUnpack(wrapper.Param(0), paramTypes[:len(paramTypes)-2])
params = append(params, llvm.Undef(c.i8ptrType), llvm.ConstPointerNull(c.i8ptrType))
c.builder.CreateCall(fn, params, "")
// Finish the function. Every basic block must end in a terminator, and
// because goroutines never return a value we can simply return void.
c.builder.CreateRetVoid()
// Return a ptrtoint of the wrapper, not the function itself.
return c.builder.CreatePtrToInt(wrapper, c.uintptrType, "")
}

19
testdata/coroutines.go предоставленный
Просмотреть файл

@ -28,6 +28,19 @@ func main() {
var printer Printer
printer = &myPrinter{}
printer.Print()
sleepFuncValue(func(x int) {
time.Sleep(1 * time.Millisecond)
println("slept inside func pointer", x)
})
time.Sleep(1 * time.Millisecond)
n := 20
sleepFuncValue(func(x int) {
time.Sleep(1 * time.Millisecond)
println("slept inside closure, with value:", n, x)
})
time.Sleep(2 * time.Millisecond)
}
func sub() {
@ -47,6 +60,10 @@ func delayedValue() int {
return 42
}
func sleepFuncValue(fn func(int)) {
go fn(8)
}
func nowait() {
println("non-blocking goroutine")
}
@ -55,7 +72,7 @@ type Printer interface {
Print()
}
type myPrinter struct{
type myPrinter struct {
}
func (i *myPrinter) Print() {

2
testdata/coroutines.txt предоставленный
Просмотреть файл

@ -11,3 +11,5 @@ value produced after some time: 42
non-blocking goroutine
done with non-blocking goroutine
async interface method call
slept inside func pointer 8
slept inside closure, with value: 20 8