diff --git a/compiler/compiler.go b/compiler/compiler.go index ef2fd416..adf3fef5 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1060,32 +1060,47 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) { case *ssa.Defer: c.emitDefer(frame, instr) case *ssa.Go: - if instr.Call.IsInvoke() { - c.addError(instr.Pos(), "todo: go on method receiver") - return - } - callee := instr.Call.StaticCallee() - if callee == nil { - c.addError(instr.Pos(), "todo: go on non-direct function (function pointer, etc.)") - return - } - calleeFn := c.ir.GetFunction(callee) - // Get all function parameters to pass to the goroutine. var params []llvm.Value for _, param := range instr.Call.Args { params = append(params, c.getValue(frame, param)) } - if !calleeFn.IsExported() && c.selectScheduler() != "tasks" { - // For coroutine scheduling, this is only required when calling an - // external function. - // For tasks, because all params are stored in a single object, no - // unnecessary parameters should be stored anyway. - params = append(params, llvm.Undef(c.i8ptrType)) // context parameter - params = append(params, llvm.Undef(c.i8ptrType)) // parent coroutine handle - } - c.emitStartGoroutine(calleeFn.LLVMFn, params) + // Start a new goroutine. + if callee := instr.Call.StaticCallee(); callee != nil { + // Static callee is known. This makes it easier to start a new + // goroutine. + calleeFn := c.ir.GetFunction(callee) + if !calleeFn.IsExported() && c.selectScheduler() != "tasks" { + // For coroutine scheduling, this is only required when calling + // an external function. + // For tasks, because all params are stored in a single object, + // no unnecessary parameters should be stored anyway. + params = append(params, llvm.Undef(c.i8ptrType)) // context parameter + params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle + } + c.emitStartGoroutine(calleeFn.LLVMFn, params) + } else if !instr.Call.IsInvoke() { + // This is a function pointer. + // At the moment, two extra params are passed to the newly started + // goroutine: + // * The function context, for closures. + // * The parent handle (for coroutines) or the function pointer + // itself (for tasks). + funcPtr, context := c.decodeFuncValue(c.getValue(frame, instr.Call.Value), instr.Call.Value.Type().(*types.Signature)) + params = append(params, context) // context parameter + switch c.selectScheduler() { + case "coroutines": + params = append(params, llvm.ConstPointerNull(c.i8ptrType)) // parent coroutine handle + case "tasks": + params = append(params, funcPtr) + default: + panic("unknown scheduler type") + } + c.emitStartGoroutine(funcPtr, params) + } else { + c.addError(instr.Pos(), "todo: go on interface call") + } case *ssa.If: cond := c.getValue(frame, instr.Cond) block := instr.Block() diff --git a/compiler/func-lowering.go b/compiler/func-lowering.go index 2d67e2a5..9abb2963 100644 --- a/compiler/func-lowering.go +++ b/compiler/func-lowering.go @@ -152,11 +152,11 @@ func (c *Compiler) LowerFuncValues() { // There are multiple functions used in a func value that // implement this signature. // What we'll do is transform the following: - // rawPtr := runtime.getFuncPtr(fn) - // if func.rawPtr == nil { + // rawPtr := runtime.getFuncPtr(func.ptr) + // if rawPtr == nil { // runtime.nilPanic() // } - // result := func.rawPtr(...args, func.context) + // result := rawPtr(...args, func.context) // into this: // if false { // runtime.nilPanic() @@ -175,95 +175,111 @@ func (c *Compiler) LowerFuncValues() { // Remove some casts, checks, and the old call which we're going // to replace. - var funcCall llvm.Value - for _, inttoptr := range getUses(getFuncPtrCall) { - if inttoptr.IsAIntToPtrInst().IsNil() { + for _, callIntPtr := range getUses(getFuncPtrCall) { + if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "runtime.makeGoroutine" { + for _, inttoptr := range getUses(callIntPtr) { + if inttoptr.IsAIntToPtrInst().IsNil() { + panic("expected a inttoptr") + } + for _, use := range getUses(inttoptr) { + c.addFuncLoweringSwitch(funcID, use, c.emitStartGoroutine, functions) + use.EraseFromParentAsInstruction() + } + inttoptr.EraseFromParentAsInstruction() + } + callIntPtr.EraseFromParentAsInstruction() + continue + } + if callIntPtr.IsAIntToPtrInst().IsNil() { panic("expected inttoptr") } - for _, ptrUse := range getUses(inttoptr) { + for _, ptrUse := range getUses(callIntPtr) { if !ptrUse.IsABitCastInst().IsNil() { for _, bitcastUse := range getUses(ptrUse) { - if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().Name() != "runtime.isnil" { + if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().IsAFunction().IsNil() { + panic("expected a call instruction") + } + switch bitcastUse.CalledValue().Name() { + case "runtime.isnil": + bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false)) + bitcastUse.EraseFromParentAsInstruction() + default: panic("expected a call to runtime.isnil") } - bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false)) - bitcastUse.EraseFromParentAsInstruction() } - ptrUse.EraseFromParentAsInstruction() - } else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == inttoptr { - if !funcCall.IsNil() { - panic("multiple calls on a single runtime.getFuncPtr") - } - funcCall = ptrUse + } else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr { + c.addFuncLoweringSwitch(funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { + return c.builder.CreateCall(funcPtr, params, "") + }, functions) } else { panic("unexpected getFuncPtrCall") } + ptrUse.EraseFromParentAsInstruction() } - } - if funcCall.IsNil() { - panic("expected exactly one call use of a runtime.getFuncPtr") - } - - // The block that cannot be reached with correct funcValues (to - // help the optimizer). - c.builder.SetInsertPointBefore(funcCall) - defaultBlock := llvm.AddBasicBlock(funcCall.InstructionParent().Parent(), "func.default") - c.builder.SetInsertPointAtEnd(defaultBlock) - c.builder.CreateUnreachable() - - // Create the switch. - c.builder.SetInsertPointBefore(funcCall) - sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1) - - // Split right after the switch. We will need to insert a few - // basic blocks in this gap. - nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next") - - // The 0 case, which is actually a nil check. - nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil") - c.builder.SetInsertPointAtEnd(nilBlock) - c.createRuntimeCall("nilPanic", nil, "") - c.builder.CreateUnreachable() - sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock) - - // Gather the list of parameters for every call we're going to - // make. - callParams := make([]llvm.Value, funcCall.OperandsCount()-1) - for i := range callParams { - callParams[i] = funcCall.Operand(i) - } - - // If the call produces a value, we need to get it using a PHI - // node. - phiBlocks := make([]llvm.BasicBlock, len(functions)) - phiValues := make([]llvm.Value, len(functions)) - for i, fn := range functions { - // Insert a switch case. - bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id)) - c.builder.SetInsertPointAtEnd(bb) - result := c.builder.CreateCall(fn.funcPtr, callParams, "") - c.builder.CreateBr(nextBlock) - sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb) - phiBlocks[i] = bb - phiValues[i] = result - } - // Create the PHI node so that the call result flows into the - // next block (after the split). This is only necessary when the - // call produced a value. - if funcCall.Type().TypeKind() != llvm.VoidTypeKind { - c.builder.SetInsertPointBefore(nextBlock.FirstInstruction()) - phi := c.builder.CreatePHI(funcCall.Type(), "") - phi.AddIncoming(phiValues, phiBlocks) - funcCall.ReplaceAllUsesWith(phi) - } - - // Finally, remove the old instructions. - funcCall.EraseFromParentAsInstruction() - for _, inttoptr := range getUses(getFuncPtrCall) { - inttoptr.EraseFromParentAsInstruction() + callIntPtr.EraseFromParentAsInstruction() } getFuncPtrCall.EraseFromParentAsInstruction() } } } } + +// addFuncLoweringSwitch creates a new switch on a function ID and inserts calls +// to the newly created direct calls. The funcID is the number to switch on, +// call is the call instruction to replace, and createCall is the callback that +// actually creates the new call. By changing createCall to something other than +// c.builder.CreateCall, instead of calling a function it can start a new +// goroutine for example. +func (c *Compiler) addFuncLoweringSwitch(funcID, call llvm.Value, createCall func(funcPtr llvm.Value, params []llvm.Value) llvm.Value, functions funcWithUsesList) { + // The block that cannot be reached with correct funcValues (to help the + // optimizer). + c.builder.SetInsertPointBefore(call) + defaultBlock := llvm.AddBasicBlock(call.InstructionParent().Parent(), "func.default") + c.builder.SetInsertPointAtEnd(defaultBlock) + c.builder.CreateUnreachable() + + // Create the switch. + c.builder.SetInsertPointBefore(call) + sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1) + + // Split right after the switch. We will need to insert a few basic blocks + // in this gap. + nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next") + + // The 0 case, which is actually a nil check. + nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil") + c.builder.SetInsertPointAtEnd(nilBlock) + c.createRuntimeCall("nilPanic", nil, "") + c.builder.CreateUnreachable() + sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock) + + // Gather the list of parameters for every call we're going to make. + callParams := make([]llvm.Value, call.OperandsCount()-1) + for i := range callParams { + callParams[i] = call.Operand(i) + } + + // If the call produces a value, we need to get it using a PHI + // node. + phiBlocks := make([]llvm.BasicBlock, len(functions)) + phiValues := make([]llvm.Value, len(functions)) + for i, fn := range functions { + // Insert a switch case. + bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id)) + c.builder.SetInsertPointAtEnd(bb) + result := createCall(fn.funcPtr, callParams) + c.builder.CreateBr(nextBlock) + sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb) + phiBlocks[i] = bb + phiValues[i] = result + } + // Create the PHI node so that the call result flows into the + // next block (after the split). This is only necessary when the + // call produced a value. + if call.Type().TypeKind() != llvm.VoidTypeKind { + c.builder.SetInsertPointBefore(nextBlock.FirstInstruction()) + phi := c.builder.CreatePHI(call.Type(), "") + phi.AddIncoming(phiValues, phiBlocks) + call.ReplaceAllUsesWith(phi) + } +} diff --git a/compiler/func.go b/compiler/func.go index d9879d74..df364b13 100644 --- a/compiler/func.go +++ b/compiler/func.go @@ -32,10 +32,15 @@ const ( // funcImplementation picks an appropriate func value implementation for the // target. func (c *Compiler) funcImplementation() funcValueImplementation { - if c.GOARCH == "wasm" { + // Always pick the switch implementation, as it allows the use of blocking + // inside a function that is used as a func value. + switch c.selectScheduler() { + case "coroutines": return funcValueSwitch - } else { + case "tasks": return funcValueDoubleword + default: + panic("unknown scheduler type") } } diff --git a/compiler/goroutine.go b/compiler/goroutine.go index 85c5e7d8..9c8376c5 100644 --- a/compiler/goroutine.go +++ b/compiler/goroutine.go @@ -18,14 +18,12 @@ func (c *Compiler) emitStartGoroutine(funcPtr llvm.Value, params []llvm.Value) l calleeValue := c.createGoroutineStartWrapper(funcPtr) c.createRuntimeCall("startGoroutine", []llvm.Value{calleeValue, paramBundle}, "") case "coroutines": - // Mark this function as a 'go' invocation and break invalid - // interprocedural optimizations. For example, heap-to-stack - // transformations are not sound as goroutines can outlive their parent. - calleeType := funcPtr.Type() + // We roundtrip through runtime.makeGoroutine as a signal (to find these + // calls) and to break any optimizations LLVM will try to do: they are + // invalid if we called this as a regular function to be updated later. calleeValue := c.builder.CreatePtrToInt(funcPtr, c.uintptrType, "") calleeValue = c.createRuntimeCall("makeGoroutine", []llvm.Value{calleeValue}, "") - calleeValue = c.builder.CreateIntToPtr(calleeValue, calleeType, "") - + calleeValue = c.builder.CreateIntToPtr(calleeValue, funcPtr.Type(), "") c.createCall(calleeValue, params, "") default: panic("unreachable") @@ -52,32 +50,86 @@ func (c *Compiler) emitStartGoroutine(funcPtr llvm.Value, params []llvm.Value) l // ignores the return value because newly started goroutines do not have a // return value. func (c *Compiler) createGoroutineStartWrapper(fn llvm.Value) llvm.Value { - if fn.IsAFunction().IsNil() { - panic("todo: goroutine start wrapper for func value") + var wrapper llvm.Value + + if !fn.IsAFunction().IsNil() { + // See whether this wrapper has already been created. If so, return it. + name := fn.Name() + wrapper = c.mod.NamedFunction(name + "$gowrapper") + if !wrapper.IsNil() { + return c.builder.CreateIntToPtr(wrapper, c.uintptrType, "") + } + + // Save the current position in the IR builder. + currentBlock := c.builder.GetInsertBlock() + defer c.builder.SetInsertPointAtEnd(currentBlock) + + // Create the wrapper. + wrapperType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{c.i8ptrType}, false) + wrapper = llvm.AddFunction(c.mod, name+"$gowrapper", wrapperType) + wrapper.SetLinkage(llvm.PrivateLinkage) + wrapper.SetUnnamedAddr(true) + entry := llvm.AddBasicBlock(wrapper, "entry") + c.builder.SetInsertPointAtEnd(entry) + + // Create the list of params for the call. + paramTypes := fn.Type().ElementType().ParamTypes() + params := c.emitPointerUnpack(wrapper.Param(0), paramTypes[:len(paramTypes)-2]) + params = append(params, llvm.Undef(c.i8ptrType), llvm.ConstPointerNull(c.i8ptrType)) + + // Create the call. + c.builder.CreateCall(fn, params, "") + + } else { + // For a function pointer like this: + // + // var funcPtr func(x, y int) int + // + // A wrapper like the following is created: + // + // func .gowrapper(ptr *unsafe.Pointer) { + // args := (*struct{ + // x, y int + // fn func(x, y int) int + // })(ptr) + // args.fn(x, y) + // } + // + // With a bit of luck, identical wrapper functions like these can be + // merged into one. + + // Save the current position in the IR builder. + currentBlock := c.builder.GetInsertBlock() + defer c.builder.SetInsertPointAtEnd(currentBlock) + + // Create the wrapper. + wrapperType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{c.i8ptrType}, false) + wrapper = llvm.AddFunction(c.mod, ".gowrapper", wrapperType) + wrapper.SetLinkage(llvm.InternalLinkage) + wrapper.SetUnnamedAddr(true) + entry := llvm.AddBasicBlock(wrapper, "entry") + c.builder.SetInsertPointAtEnd(entry) + + // Get the list of parameters, with the extra parameters at the end. + paramTypes := fn.Type().ElementType().ParamTypes() + paramTypes[len(paramTypes)-1] = fn.Type() // the last element is the function pointer + params := c.emitPointerUnpack(wrapper.Param(0), paramTypes) + + // Get the function pointer. + fnPtr := params[len(params)-1] + + // Ignore the last param, which isn't used anymore. + // TODO: avoid this extra "parent handle" parameter in most functions. + params[len(params)-1] = llvm.Undef(c.i8ptrType) + + // Create the call. + c.builder.CreateCall(fnPtr, params, "") } - // See whether this wrapper has already been created. If so, return it. - name := fn.Name() - wrapper := c.mod.NamedFunction(name + "$gowrapper") - if !wrapper.IsNil() { - return c.builder.CreateIntToPtr(wrapper, c.uintptrType, "") - } - - // Save the current position in the IR builder. - currentBlock := c.builder.GetInsertBlock() - defer c.builder.SetInsertPointAtEnd(currentBlock) - - // Create the wrapper. - wrapperType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{c.i8ptrType}, false) - wrapper = llvm.AddFunction(c.mod, name+"$gowrapper", wrapperType) - wrapper.SetLinkage(llvm.PrivateLinkage) - wrapper.SetUnnamedAddr(true) - entry := llvm.AddBasicBlock(wrapper, "entry") - c.builder.SetInsertPointAtEnd(entry) - paramTypes := fn.Type().ElementType().ParamTypes() - params := c.emitPointerUnpack(wrapper.Param(0), paramTypes[:len(paramTypes)-2]) - params = append(params, llvm.Undef(c.i8ptrType), llvm.ConstPointerNull(c.i8ptrType)) - c.builder.CreateCall(fn, params, "") + // Finish the function. Every basic block must end in a terminator, and + // because goroutines never return a value we can simply return void. c.builder.CreateRetVoid() + + // Return a ptrtoint of the wrapper, not the function itself. return c.builder.CreatePtrToInt(wrapper, c.uintptrType, "") } diff --git a/testdata/coroutines.go b/testdata/coroutines.go index 082e3f63..47d02598 100644 --- a/testdata/coroutines.go +++ b/testdata/coroutines.go @@ -28,6 +28,19 @@ func main() { var printer Printer printer = &myPrinter{} printer.Print() + + sleepFuncValue(func(x int) { + time.Sleep(1 * time.Millisecond) + println("slept inside func pointer", x) + }) + time.Sleep(1 * time.Millisecond) + n := 20 + sleepFuncValue(func(x int) { + time.Sleep(1 * time.Millisecond) + println("slept inside closure, with value:", n, x) + }) + + time.Sleep(2 * time.Millisecond) } func sub() { @@ -47,6 +60,10 @@ func delayedValue() int { return 42 } +func sleepFuncValue(fn func(int)) { + go fn(8) +} + func nowait() { println("non-blocking goroutine") } @@ -55,7 +72,7 @@ type Printer interface { Print() } -type myPrinter struct{ +type myPrinter struct { } func (i *myPrinter) Print() { diff --git a/testdata/coroutines.txt b/testdata/coroutines.txt index b4f3c111..eea625ff 100644 --- a/testdata/coroutines.txt +++ b/testdata/coroutines.txt @@ -11,3 +11,5 @@ value produced after some time: 42 non-blocking goroutine done with non-blocking goroutine async interface method call +slept inside func pointer 8 +slept inside closure, with value: 20 8