compiler: add support for 'go' on func values

This commit allows starting a new goroutine directly from a func value,
not just when the static callee is known.

This is necessary to support the whole time package, not just the
commonly used subset that was compiled with the SimpleDCE pass enabled.
Этот коммит содержится в:
Ayke van Laethem 2019-08-15 15:52:26 +02:00 коммит произвёл Ron Evans
родитель e4fc3bb66a
коммит bbc3046687
6 изменённых файлов: 238 добавлений и 131 удалений

Просмотреть файл

@ -1060,32 +1060,47 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) {
case *ssa.Defer: case *ssa.Defer:
c.emitDefer(frame, instr) c.emitDefer(frame, instr)
case *ssa.Go: case *ssa.Go:
if instr.Call.IsInvoke() {
c.addError(instr.Pos(), "todo: go on method receiver")
return
}
callee := instr.Call.StaticCallee()
if callee == nil {
c.addError(instr.Pos(), "todo: go on non-direct function (function pointer, etc.)")
return
}
calleeFn := c.ir.GetFunction(callee)
// Get all function parameters to pass to the goroutine. // Get all function parameters to pass to the goroutine.
var params []llvm.Value var params []llvm.Value
for _, param := range instr.Call.Args { for _, param := range instr.Call.Args {
params = append(params, c.getValue(frame, param)) params = append(params, c.getValue(frame, param))
} }
if !calleeFn.IsExported() && c.selectScheduler() != "tasks" {
// For coroutine scheduling, this is only required when calling an
// external function.
// For tasks, because all params are stored in a single object, no
// unnecessary parameters should be stored anyway.
params = append(params, llvm.Undef(c.i8ptrType)) // context parameter
params = append(params, llvm.Undef(c.i8ptrType)) // parent coroutine handle
}
// Start a new goroutine.
if callee := instr.Call.StaticCallee(); callee != nil {
// Static callee is known. This makes it easier to start a new
// goroutine.
calleeFn := c.ir.GetFunction(callee)
if !calleeFn.IsExported() && c.selectScheduler() != "tasks" {
// For coroutine scheduling, this is only required when calling
// an external function.
// For tasks, because all params are stored in a single object,
// no unnecessary parameters should be stored anyway.
params = append(params, llvm.Undef(c.i8ptrType)) // context parameter
params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle
}
c.emitStartGoroutine(calleeFn.LLVMFn, params) c.emitStartGoroutine(calleeFn.LLVMFn, params)
} else if !instr.Call.IsInvoke() {
// This is a function pointer.
// At the moment, two extra params are passed to the newly started
// goroutine:
// * The function context, for closures.
// * The parent handle (for coroutines) or the function pointer
// itself (for tasks).
funcPtr, context := c.decodeFuncValue(c.getValue(frame, instr.Call.Value), instr.Call.Value.Type().(*types.Signature))
params = append(params, context) // context parameter
switch c.selectScheduler() {
case "coroutines":
params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle
case "tasks":
params = append(params, funcPtr)
default:
panic("unknown scheduler type")
}
c.emitStartGoroutine(funcPtr, params)
} else {
c.addError(instr.Pos(), "todo: go on interface call")
}
case *ssa.If: case *ssa.If:
cond := c.getValue(frame, instr.Cond) cond := c.getValue(frame, instr.Cond)
block := instr.Block() block := instr.Block()

Просмотреть файл

@ -152,11 +152,11 @@ func (c *Compiler) LowerFuncValues() {
// There are multiple functions used in a func value that // There are multiple functions used in a func value that
// implement this signature. // implement this signature.
// What we'll do is transform the following: // What we'll do is transform the following:
// rawPtr := runtime.getFuncPtr(fn) // rawPtr := runtime.getFuncPtr(func.ptr)
// if func.rawPtr == nil { // if rawPtr == nil {
// runtime.nilPanic() // runtime.nilPanic()
// } // }
// result := func.rawPtr(...args, func.context) // result := rawPtr(...args, func.context)
// into this: // into this:
// if false { // if false {
// runtime.nilPanic() // runtime.nilPanic()
@ -175,48 +175,75 @@ func (c *Compiler) LowerFuncValues() {
// Remove some casts, checks, and the old call which we're going // Remove some casts, checks, and the old call which we're going
// to replace. // to replace.
var funcCall llvm.Value for _, callIntPtr := range getUses(getFuncPtrCall) {
for _, inttoptr := range getUses(getFuncPtrCall) { if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "runtime.makeGoroutine" {
for _, inttoptr := range getUses(callIntPtr) {
if inttoptr.IsAIntToPtrInst().IsNil() { if inttoptr.IsAIntToPtrInst().IsNil() {
panic("expected a inttoptr")
}
for _, use := range getUses(inttoptr) {
c.addFuncLoweringSwitch(funcID, use, c.emitStartGoroutine, functions)
use.EraseFromParentAsInstruction()
}
inttoptr.EraseFromParentAsInstruction()
}
callIntPtr.EraseFromParentAsInstruction()
continue
}
if callIntPtr.IsAIntToPtrInst().IsNil() {
panic("expected inttoptr") panic("expected inttoptr")
} }
for _, ptrUse := range getUses(inttoptr) { for _, ptrUse := range getUses(callIntPtr) {
if !ptrUse.IsABitCastInst().IsNil() { if !ptrUse.IsABitCastInst().IsNil() {
for _, bitcastUse := range getUses(ptrUse) { for _, bitcastUse := range getUses(ptrUse) {
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().Name() != "runtime.isnil" { if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().IsAFunction().IsNil() {
panic("expected a call to runtime.isnil") panic("expected a call instruction")
} }
switch bitcastUse.CalledValue().Name() {
case "runtime.isnil":
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false)) bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
bitcastUse.EraseFromParentAsInstruction() bitcastUse.EraseFromParentAsInstruction()
default:
panic("expected a call to runtime.isnil")
} }
ptrUse.EraseFromParentAsInstruction()
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == inttoptr {
if !funcCall.IsNil() {
panic("multiple calls on a single runtime.getFuncPtr")
} }
funcCall = ptrUse } else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr {
c.addFuncLoweringSwitch(funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
return c.builder.CreateCall(funcPtr, params, "")
}, functions)
} else { } else {
panic("unexpected getFuncPtrCall") panic("unexpected getFuncPtrCall")
} }
ptrUse.EraseFromParentAsInstruction()
}
callIntPtr.EraseFromParentAsInstruction()
}
getFuncPtrCall.EraseFromParentAsInstruction()
} }
} }
if funcCall.IsNil() {
panic("expected exactly one call use of a runtime.getFuncPtr")
} }
}
// The block that cannot be reached with correct funcValues (to // addFuncLoweringSwitch creates a new switch on a function ID and inserts calls
// help the optimizer). // to the newly created direct calls. The funcID is the number to switch on,
c.builder.SetInsertPointBefore(funcCall) // call is the call instruction to replace, and createCall is the callback that
defaultBlock := llvm.AddBasicBlock(funcCall.InstructionParent().Parent(), "func.default") // actually creates the new call. By changing createCall to something other than
// c.builder.CreateCall, instead of calling a function it can start a new
// goroutine for example.
func (c *Compiler) addFuncLoweringSwitch(funcID, call llvm.Value, createCall func(funcPtr llvm.Value, params []llvm.Value) llvm.Value, functions funcWithUsesList) {
// The block that cannot be reached with correct funcValues (to help the
// optimizer).
c.builder.SetInsertPointBefore(call)
defaultBlock := llvm.AddBasicBlock(call.InstructionParent().Parent(), "func.default")
c.builder.SetInsertPointAtEnd(defaultBlock) c.builder.SetInsertPointAtEnd(defaultBlock)
c.builder.CreateUnreachable() c.builder.CreateUnreachable()
// Create the switch. // Create the switch.
c.builder.SetInsertPointBefore(funcCall) c.builder.SetInsertPointBefore(call)
sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1) sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
// Split right after the switch. We will need to insert a few // Split right after the switch. We will need to insert a few basic blocks
// basic blocks in this gap. // in this gap.
nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next") nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
// The 0 case, which is actually a nil check. // The 0 case, which is actually a nil check.
@ -226,11 +253,10 @@ func (c *Compiler) LowerFuncValues() {
c.builder.CreateUnreachable() c.builder.CreateUnreachable()
sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock) sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
// Gather the list of parameters for every call we're going to // Gather the list of parameters for every call we're going to make.
// make. callParams := make([]llvm.Value, call.OperandsCount()-1)
callParams := make([]llvm.Value, funcCall.OperandsCount()-1)
for i := range callParams { for i := range callParams {
callParams[i] = funcCall.Operand(i) callParams[i] = call.Operand(i)
} }
// If the call produces a value, we need to get it using a PHI // If the call produces a value, we need to get it using a PHI
@ -241,7 +267,7 @@ func (c *Compiler) LowerFuncValues() {
// Insert a switch case. // Insert a switch case.
bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id)) bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
c.builder.SetInsertPointAtEnd(bb) c.builder.SetInsertPointAtEnd(bb)
result := c.builder.CreateCall(fn.funcPtr, callParams, "") result := createCall(fn.funcPtr, callParams)
c.builder.CreateBr(nextBlock) c.builder.CreateBr(nextBlock)
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb) sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
phiBlocks[i] = bb phiBlocks[i] = bb
@ -250,20 +276,10 @@ func (c *Compiler) LowerFuncValues() {
// Create the PHI node so that the call result flows into the // Create the PHI node so that the call result flows into the
// next block (after the split). This is only necessary when the // next block (after the split). This is only necessary when the
// call produced a value. // call produced a value.
if funcCall.Type().TypeKind() != llvm.VoidTypeKind { if call.Type().TypeKind() != llvm.VoidTypeKind {
c.builder.SetInsertPointBefore(nextBlock.FirstInstruction()) c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
phi := c.builder.CreatePHI(funcCall.Type(), "") phi := c.builder.CreatePHI(call.Type(), "")
phi.AddIncoming(phiValues, phiBlocks) phi.AddIncoming(phiValues, phiBlocks)
funcCall.ReplaceAllUsesWith(phi) call.ReplaceAllUsesWith(phi)
}
// Finally, remove the old instructions.
funcCall.EraseFromParentAsInstruction()
for _, inttoptr := range getUses(getFuncPtrCall) {
inttoptr.EraseFromParentAsInstruction()
}
getFuncPtrCall.EraseFromParentAsInstruction()
}
}
} }
} }

Просмотреть файл

@ -32,10 +32,15 @@ const (
// funcImplementation picks an appropriate func value implementation for the // funcImplementation picks an appropriate func value implementation for the
// target. // target.
func (c *Compiler) funcImplementation() funcValueImplementation { func (c *Compiler) funcImplementation() funcValueImplementation {
if c.GOARCH == "wasm" { // Always pick the switch implementation, as it allows the use of blocking
// inside a function that is used as a func value.
switch c.selectScheduler() {
case "coroutines":
return funcValueSwitch return funcValueSwitch
} else { case "tasks":
return funcValueDoubleword return funcValueDoubleword
default:
panic("unknown scheduler type")
} }
} }

Просмотреть файл

@ -18,14 +18,12 @@ func (c *Compiler) emitStartGoroutine(funcPtr llvm.Value, params []llvm.Value) l
calleeValue := c.createGoroutineStartWrapper(funcPtr) calleeValue := c.createGoroutineStartWrapper(funcPtr)
c.createRuntimeCall("startGoroutine", []llvm.Value{calleeValue, paramBundle}, "") c.createRuntimeCall("startGoroutine", []llvm.Value{calleeValue, paramBundle}, "")
case "coroutines": case "coroutines":
// Mark this function as a 'go' invocation and break invalid // We roundtrip through runtime.makeGoroutine as a signal (to find these
// interprocedural optimizations. For example, heap-to-stack // calls) and to break any optimizations LLVM will try to do: they are
// transformations are not sound as goroutines can outlive their parent. // invalid if we called this as a regular function to be updated later.
calleeType := funcPtr.Type()
calleeValue := c.builder.CreatePtrToInt(funcPtr, c.uintptrType, "") calleeValue := c.builder.CreatePtrToInt(funcPtr, c.uintptrType, "")
calleeValue = c.createRuntimeCall("makeGoroutine", []llvm.Value{calleeValue}, "") calleeValue = c.createRuntimeCall("makeGoroutine", []llvm.Value{calleeValue}, "")
calleeValue = c.builder.CreateIntToPtr(calleeValue, calleeType, "") calleeValue = c.builder.CreateIntToPtr(calleeValue, funcPtr.Type(), "")
c.createCall(calleeValue, params, "") c.createCall(calleeValue, params, "")
default: default:
panic("unreachable") panic("unreachable")
@ -52,13 +50,12 @@ func (c *Compiler) emitStartGoroutine(funcPtr llvm.Value, params []llvm.Value) l
// ignores the return value because newly started goroutines do not have a // ignores the return value because newly started goroutines do not have a
// return value. // return value.
func (c *Compiler) createGoroutineStartWrapper(fn llvm.Value) llvm.Value { func (c *Compiler) createGoroutineStartWrapper(fn llvm.Value) llvm.Value {
if fn.IsAFunction().IsNil() { var wrapper llvm.Value
panic("todo: goroutine start wrapper for func value")
}
if !fn.IsAFunction().IsNil() {
// See whether this wrapper has already been created. If so, return it. // See whether this wrapper has already been created. If so, return it.
name := fn.Name() name := fn.Name()
wrapper := c.mod.NamedFunction(name + "$gowrapper") wrapper = c.mod.NamedFunction(name + "$gowrapper")
if !wrapper.IsNil() { if !wrapper.IsNil() {
return c.builder.CreateIntToPtr(wrapper, c.uintptrType, "") return c.builder.CreateIntToPtr(wrapper, c.uintptrType, "")
} }
@ -74,10 +71,65 @@ func (c *Compiler) createGoroutineStartWrapper(fn llvm.Value) llvm.Value {
wrapper.SetUnnamedAddr(true) wrapper.SetUnnamedAddr(true)
entry := llvm.AddBasicBlock(wrapper, "entry") entry := llvm.AddBasicBlock(wrapper, "entry")
c.builder.SetInsertPointAtEnd(entry) c.builder.SetInsertPointAtEnd(entry)
// Create the list of params for the call.
paramTypes := fn.Type().ElementType().ParamTypes() paramTypes := fn.Type().ElementType().ParamTypes()
params := c.emitPointerUnpack(wrapper.Param(0), paramTypes[:len(paramTypes)-2]) params := c.emitPointerUnpack(wrapper.Param(0), paramTypes[:len(paramTypes)-2])
params = append(params, llvm.Undef(c.i8ptrType), llvm.ConstPointerNull(c.i8ptrType)) params = append(params, llvm.Undef(c.i8ptrType), llvm.ConstPointerNull(c.i8ptrType))
// Create the call.
c.builder.CreateCall(fn, params, "") c.builder.CreateCall(fn, params, "")
} else {
// For a function pointer like this:
//
// var funcPtr func(x, y int) int
//
// A wrapper like the following is created:
//
// func .gowrapper(ptr *unsafe.Pointer) {
// args := (*struct{
// x, y int
// fn func(x, y int) int
// })(ptr)
// args.fn(x, y)
// }
//
// With a bit of luck, identical wrapper functions like these can be
// merged into one.
// Save the current position in the IR builder.
currentBlock := c.builder.GetInsertBlock()
defer c.builder.SetInsertPointAtEnd(currentBlock)
// Create the wrapper.
wrapperType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{c.i8ptrType}, false)
wrapper = llvm.AddFunction(c.mod, ".gowrapper", wrapperType)
wrapper.SetLinkage(llvm.InternalLinkage)
wrapper.SetUnnamedAddr(true)
entry := llvm.AddBasicBlock(wrapper, "entry")
c.builder.SetInsertPointAtEnd(entry)
// Get the list of parameters, with the extra parameters at the end.
paramTypes := fn.Type().ElementType().ParamTypes()
paramTypes[len(paramTypes)-1] = fn.Type() // the last element is the function pointer
params := c.emitPointerUnpack(wrapper.Param(0), paramTypes)
// Get the function pointer.
fnPtr := params[len(params)-1]
// Ignore the last param, which isn't used anymore.
// TODO: avoid this extra "parent handle" parameter in most functions.
params[len(params)-1] = llvm.Undef(c.i8ptrType)
// Create the call.
c.builder.CreateCall(fnPtr, params, "")
}
// Finish the function. Every basic block must end in a terminator, and
// because goroutines never return a value we can simply return void.
c.builder.CreateRetVoid() c.builder.CreateRetVoid()
// Return a ptrtoint of the wrapper, not the function itself.
return c.builder.CreatePtrToInt(wrapper, c.uintptrType, "") return c.builder.CreatePtrToInt(wrapper, c.uintptrType, "")
} }

19
testdata/coroutines.go предоставленный
Просмотреть файл

@ -28,6 +28,19 @@ func main() {
var printer Printer var printer Printer
printer = &myPrinter{} printer = &myPrinter{}
printer.Print() printer.Print()
sleepFuncValue(func(x int) {
time.Sleep(1 * time.Millisecond)
println("slept inside func pointer", x)
})
time.Sleep(1 * time.Millisecond)
n := 20
sleepFuncValue(func(x int) {
time.Sleep(1 * time.Millisecond)
println("slept inside closure, with value:", n, x)
})
time.Sleep(2 * time.Millisecond)
} }
func sub() { func sub() {
@ -47,6 +60,10 @@ func delayedValue() int {
return 42 return 42
} }
func sleepFuncValue(fn func(int)) {
go fn(8)
}
func nowait() { func nowait() {
println("non-blocking goroutine") println("non-blocking goroutine")
} }
@ -55,7 +72,7 @@ type Printer interface {
Print() Print()
} }
type myPrinter struct{ type myPrinter struct {
} }
func (i *myPrinter) Print() { func (i *myPrinter) Print() {

2
testdata/coroutines.txt предоставленный
Просмотреть файл

@ -11,3 +11,5 @@ value produced after some time: 42
non-blocking goroutine non-blocking goroutine
done with non-blocking goroutine done with non-blocking goroutine
async interface method call async interface method call
slept inside func pointer 8
slept inside closure, with value: 20 8