diff --git a/compiler/compiler.go b/compiler/compiler.go index dfe3dd77..0e459be3 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -40,24 +40,22 @@ const tinygoPath = "github.com/tinygo-org/tinygo" var functionsUsedInTransforms = []string{ "runtime.alloc", "runtime.free", - "runtime.scheduler", "runtime.nilPanic", } -var taskFunctionsUsedInTransforms = []string{ - "runtime.startGoroutine", -} +var taskFunctionsUsedInTransforms = []string{} var coroFunctionsUsedInTransforms = []string{ - "runtime.avrSleep", - "runtime.getFakeCoroutine", - "runtime.setTaskStatePtr", - "runtime.getTaskStatePtr", - "runtime.activateTask", - "runtime.noret", - "runtime.getParentHandle", - "runtime.getCoroutine", - "runtime.llvmCoroRefHolder", + "internal/task.start", + "internal/task.Pause", + "internal/task.fake", + "internal/task.Current", + "internal/task.createTask", + "(*internal/task.Task).setState", + "(*internal/task.Task).returnTo", + "(*internal/task.Task).returnCurrent", + "(*internal/task.Task).setReturnPtr", + "(*internal/task.Task).getReturnPtr", } type Compiler struct { @@ -162,6 +160,7 @@ func (c *Compiler) Module() llvm.Module { func (c *Compiler) getFunctionsUsedInTransforms() []string { fnused := functionsUsedInTransforms switch c.Scheduler() { + case "none": case "coroutines": fnused = append(append([]string{}, fnused...), coroFunctionsUsedInTransforms...) 
case "tasks": @@ -218,7 +217,7 @@ func (c *Compiler) Compile(mainPath string) []error { path = path[len(tinygoPath+"/src/"):] } switch path { - case "machine", "os", "reflect", "runtime", "runtime/interrupt", "runtime/volatile", "sync", "testing", "internal/reflectlite": + case "machine", "os", "reflect", "runtime", "runtime/interrupt", "runtime/volatile", "sync", "testing", "internal/reflectlite", "internal/task": return path default: if strings.HasPrefix(path, "device/") || strings.HasPrefix(path, "examples/") { @@ -1095,7 +1094,7 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) { funcPtr, context := c.decodeFuncValue(c.getValue(frame, instr.Call.Value), instr.Call.Value.Type().(*types.Signature)) params = append(params, context) // context parameter switch c.Scheduler() { - case "coroutines": + case "none", "coroutines": // There are no additional parameters needed for the goroutine start operation. case "tasks": // Add the function pointer as a parameter to start the goroutine. diff --git a/compiler/func.go b/compiler/func.go index 14f1416b..c7c9f9ff 100644 --- a/compiler/func.go +++ b/compiler/func.go @@ -35,7 +35,7 @@ func (c *Compiler) funcImplementation() funcValueImplementation { // Always pick the switch implementation, as it allows the use of blocking // inside a function that is used as a func value. switch c.Scheduler() { - case "coroutines": + case "none", "coroutines": return funcValueSwitch case "tasks": return funcValueDoubleword diff --git a/compiler/goroutine-lowering.go b/compiler/goroutine-lowering.go deleted file mode 100644 index 4d2b7212..00000000 --- a/compiler/goroutine-lowering.go +++ /dev/null @@ -1,1023 +0,0 @@ -package compiler - -// This file implements lowering for the goroutine scheduler. There are two -// scheduler implementations, one based on tasks (like RTOSes and the main Go -// runtime) and one based on a coroutine compiler transformation. 
The task based -// implementation requires very little work from the compiler but is not very -// portable (in particular, it is very hard if not impossible to support on -// WebAssembly). The coroutine based one requires a lot of work by the compiler -// to implement, but can run virtually anywhere with a single scheduler -// implementation. -// -// The below description is for the coroutine based scheduler. -// -// This file lowers goroutine pseudo-functions into coroutines scheduled by a -// scheduler at runtime. It uses coroutine support in LLVM for this -// transformation: https://llvm.org/docs/Coroutines.html -// -// For example, take the following code: -// -// func main() { -// go foo() -// time.Sleep(2 * time.Second) -// println("some other operation") -// i := bar() -// println("done", *i) -// } -// -// func foo() { -// for { -// println("foo!") -// time.Sleep(time.Second) -// } -// } -// -// func bar() *int { -// time.Sleep(time.Second) -// println("blocking operation completed) -// return new(int) -// } -// -// It is transformed by the IR generator in compiler.go into the following -// pseudo-Go code: -// -// func main() { -// fn := runtime.makeGoroutine(foo) -// fn() -// time.Sleep(2 * time.Second) -// println("some other operation") -// i := bar() // imagine an 'await' keyword in front of this call -// println("done", *i) -// } -// -// func foo() { -// for { -// println("foo!") -// time.Sleep(time.Second) -// } -// } -// -// func bar() *int { -// time.Sleep(time.Second) -// println("blocking operation completed) -// return new(int) -// } -// -// The pass in this file transforms this code even further, to the following -// async/await style pseudocode: -// -// func main(parent) { -// hdl := llvm.makeCoroutine() -// foo(nil) // do not pass the parent coroutine: this is an independent goroutine -// runtime.sleepTask(hdl, 2 * time.Second) // ask the scheduler to re-activate this coroutine at the right time -// llvm.suspend(hdl) // suspend point -// 
println("some other operation") -// var i *int // allocate space on the stack for the return value -// runtime.setTaskStatePtr(hdl, &i) // store return value alloca in our coroutine promise -// bar(hdl) // await, pass a continuation (hdl) to bar -// llvm.suspend(hdl) // suspend point, wait for the callee to re-activate -// println("done", *i) -// runtime.activateTask(parent) // re-activate the parent (nop, there is no parent) -// } -// -// func foo(parent) { -// hdl := llvm.makeCoroutine() -// for { -// println("foo!") -// runtime.sleepTask(hdl, time.Second) // ask the scheduler to re-activate this coroutine at the right time -// llvm.suspend(hdl) // suspend point -// } -// } -// -// func bar(parent) { -// hdl := llvm.makeCoroutine() -// runtime.sleepTask(hdl, time.Second) // ask the scheduler to re-activate this coroutine at the right time -// llvm.suspend(hdl) // suspend point -// println("blocking operation completed) -// runtime.activateTask(parent) // re-activate the parent coroutine before returning -// } -// -// The real LLVM code is more complicated, but this is the general idea. -// -// The LLVM coroutine passes will then process this file further transforming -// these three functions into coroutines. Most of the actual work is done by the -// scheduler, which runs in the background scheduling all coroutines. - -import ( - "fmt" - "strings" - - "github.com/tinygo-org/tinygo/compiler/llvmutil" - "tinygo.org/x/go-llvm" -) - -// setting this to true will cause the compiler to spew tons of information about coroutine transformations -// this can be useful when debugging coroutine lowering or looking for potential missed optimizations -const coroDebug = false - -type asyncFunc struct { - taskHandle llvm.Value - cleanupBlock llvm.BasicBlock - suspendBlock llvm.BasicBlock -} - -// LowerGoroutines performs some IR transformations necessary to support -// goroutines. 
It does something different based on whether it uses the -// coroutine or the tasks implementation of goroutines, and whether goroutines -// are necessary at all. -func (c *Compiler) LowerGoroutines() error { - switch c.Scheduler() { - case "coroutines": - return c.lowerCoroutines() - case "tasks": - return c.lowerTasks() - default: - panic("unknown scheduler type") - } -} - -// lowerTasks starts the main goroutine and then runs the scheduler. -// This is enough compiler-level transformation for the task-based scheduler. -func (c *Compiler) lowerTasks() error { - uses := getUses(c.mod.NamedFunction("runtime.callMain")) - if len(uses) != 1 || uses[0].IsACallInst().IsNil() { - panic("expected exactly 1 call of runtime.callMain, check the entry point") - } - mainCall := uses[0] - - realMain := c.mod.NamedFunction(c.ir.MainPkg().Pkg.Path() + ".main") - if len(getUses(c.mod.NamedFunction("runtime.startGoroutine"))) != 0 || len(getUses(c.mod.NamedFunction("runtime.yield"))) != 0 { - // Program needs a scheduler. Start main.main as a goroutine and start - // the scheduler. - realMainWrapper := c.createGoroutineStartWrapper(realMain) - c.builder.SetInsertPointBefore(mainCall) - zero := llvm.ConstInt(c.uintptrType, 0, false) - c.createRuntimeCall("startGoroutine", []llvm.Value{realMainWrapper, zero}, "") - c.createRuntimeCall("scheduler", nil, "") - } else { - // Program doesn't need a scheduler. Call main.main directly. - c.builder.SetInsertPointBefore(mainCall) - params := []llvm.Value{ - llvm.Undef(c.i8ptrType), // unused context parameter - llvm.Undef(c.i8ptrType), // unused coroutine handle - } - c.createCall(realMain, params, "") - } - mainCall.EraseFromParentAsInstruction() - - // main.main was set to external linkage during IR construction. Set it to - // internal linkage to enable interprocedural optimizations. 
- realMain.SetLinkage(llvm.InternalLinkage) - - return nil -} - -// lowerCoroutines transforms the IR into one where all blocking functions are -// turned into goroutines and blocking calls into await calls. It also makes -// sure that the first coroutine is started and the coroutine scheduler will be -// run. -func (c *Compiler) lowerCoroutines() error { - needsScheduler, err := c.markAsyncFunctions() - if err != nil { - return err - } - - uses := getUses(c.mod.NamedFunction("runtime.callMain")) - if len(uses) != 1 || uses[0].IsACallInst().IsNil() { - panic("expected exactly 1 call of runtime.callMain, check the entry point") - } - mainCall := uses[0] - - // Replace call of runtime.callMain() with a real call to main.main(), - // optionally followed by a call to runtime.scheduler(). - c.builder.SetInsertPointBefore(mainCall) - realMain := c.mod.NamedFunction(c.ir.MainPkg().Pkg.Path() + ".main") - var ph llvm.Value - if needsScheduler { - ph = c.createRuntimeCall("getFakeCoroutine", []llvm.Value{}, "") - } else { - ph = llvm.Undef(c.i8ptrType) - } - c.builder.CreateCall(realMain, []llvm.Value{llvm.Undef(c.i8ptrType), ph}, "") - if needsScheduler { - c.createRuntimeCall("scheduler", nil, "") - } - mainCall.EraseFromParentAsInstruction() - - if !needsScheduler { - go_scheduler := c.mod.NamedFunction("go_scheduler") - if !go_scheduler.IsNil() { - // This is the WebAssembly backend. - // There is no need to export the go_scheduler function, but it is - // still exported. Make sure it is optimized away. - go_scheduler.SetLinkage(llvm.InternalLinkage) - } - } else { - // Eliminate unnecessary fake coroutines. - // This is necessary to prevent infinite recursion in runtime.getFakeCoroutine. - c.eliminateFakeCoroutines() - } - - // main.main was set to external linkage during IR construction. Set it to - // internal linkage to enable interprocedural optimizations. 
- realMain.SetLinkage(llvm.InternalLinkage) - - return nil -} - -func coroDebugPrintln(s ...interface{}) { - if coroDebug { - fmt.Println(s...) - } -} - -// markAsyncFunctions does the bulk of the work of lowering goroutines. It -// determines whether a scheduler is needed, and if it is, it transforms -// blocking operations into goroutines and blocking calls into await calls. -// -// It does the following operations: -// * Find all blocking functions. -// * Determine whether a scheduler is necessary. If not, it skips the -// following operations. -// * Transform call instructions into await calls. -// * Transform return instructions into final suspends. -// * Set up the coroutine frames for async functions. -// * Transform blocking calls into their async equivalents. -func (c *Compiler) markAsyncFunctions() (needsScheduler bool, err error) { - var worklist []llvm.Value - - yield := c.mod.NamedFunction("runtime.yield") - if !yield.IsNil() { - worklist = append(worklist, yield) - } - - if len(worklist) == 0 { - // There are no blocking operations, so no need to transform anything. - return false, c.lowerMakeGoroutineCalls(false) - } - - // Find all async functions. - // Keep reducing this worklist by marking a function as recursively async - // from the worklist and pushing all its parents that are non-async. - // This is somewhat similar to a worklist in a mark-sweep garbage collector: - // the work items are then grey objects. - asyncFuncs := make(map[llvm.Value]*asyncFunc) - asyncList := make([]llvm.Value, 0, 4) - for len(worklist) != 0 { - // Pick the topmost. - f := worklist[len(worklist)-1] - worklist = worklist[:len(worklist)-1] - if _, ok := asyncFuncs[f]; ok { - continue // already processed - } - if f.Name() == "resume" { - continue - } - // Add to set of async functions. - asyncFuncs[f] = &asyncFunc{} - asyncList = append(asyncList, f) - - // Add all callees to the worklist. 
- for _, use := range getUses(f) { - if use.IsConstant() && use.Opcode() == llvm.PtrToInt { - for _, call := range getUses(use) { - if call.IsACallInst().IsNil() || call.CalledValue().Name() != "runtime.makeGoroutine" { - return false, errorAt(call, "async function incorrectly used in ptrtoint, expected runtime.makeGoroutine") - } - } - // This is a go statement. Do not mark the parent as async, as - // starting a goroutine is not a blocking operation. - continue - } - if use.IsConstant() && use.Opcode() == llvm.BitCast { - // Not sure why this const bitcast is here but as long as it - // has no uses it can be ignored, I guess? - // I think it was created for the runtime.isnil check but - // somehow wasn't removed when all these checks are removed. - if len(getUses(use)) == 0 { - continue - } - } - if use.IsACallInst().IsNil() { - // Not a call instruction. Maybe a store to a global? In any - // case, this requires support for async calls across function - // pointers which is not yet supported. - at := use - if use.IsAInstruction().IsNil() { - // The use might not be an instruction (for example, in the - // case of a const bitcast). Fall back to reporting the - // location of the function instead. - at = f - } - return false, errorAt(at, "async function "+f.Name()+" used as function pointer") - } - parent := use.InstructionParent().Parent() - for i := 0; i < use.OperandsCount()-1; i++ { - if use.Operand(i) == f { - return false, errorAt(use, "async function "+f.Name()+" used as function pointer") - } - } - worklist = append(worklist, parent) - } - } - - // Check whether a scheduler is needed. 
- makeGoroutine := c.mod.NamedFunction("runtime.makeGoroutine") - if strings.HasPrefix(c.Triple(), "avr") { - needsScheduler = false - getCoroutine := c.mod.NamedFunction("runtime.getCoroutine") - for _, inst := range getUses(getCoroutine) { - inst.ReplaceAllUsesWith(llvm.Undef(inst.Type())) - inst.EraseFromParentAsInstruction() - } - yield := c.mod.NamedFunction("runtime.yield") - for _, inst := range getUses(yield) { - inst.EraseFromParentAsInstruction() - } - sleep := c.mod.NamedFunction("time.Sleep") - for _, inst := range getUses(sleep) { - c.builder.SetInsertPointBefore(inst) - c.createRuntimeCall("avrSleep", []llvm.Value{inst.Operand(0)}, "") - inst.EraseFromParentAsInstruction() - } - } else { - // Only use a scheduler when an async goroutine is started. When the - // goroutine is not async (does not do any blocking operation), no - // scheduler is necessary as it can be called directly. - for _, use := range getUses(makeGoroutine) { - // Input param must be const ptrtoint of function. - ptrtoint := use.Operand(0) - if !ptrtoint.IsConstant() || ptrtoint.Opcode() != llvm.PtrToInt { - panic("expected const ptrtoint operand of runtime.makeGoroutine") - } - goroutine := ptrtoint.Operand(0) - if goroutine.Name() == "runtime.fakeCoroutine" { - continue - } - if _, ok := asyncFuncs[goroutine]; ok { - needsScheduler = true - break - } - } - if _, ok := asyncFuncs[c.mod.NamedFunction(c.ir.MainPkg().Pkg.Path()+".main")]; ok { - needsScheduler = true - } - } - - if !needsScheduler { - // on wasm, we may still have calls to deadlock - // replace these with an abort - abort := c.mod.NamedFunction("runtime.abort") - if deadlock := c.mod.NamedFunction("runtime.deadlock"); !deadlock.IsNil() { - deadlock.ReplaceAllUsesWith(abort) - } - - // No scheduler is needed. Do not transform all functions here. - // However, make sure that all go calls (which are all non-async) are - // transformed into regular calls. 
- return false, c.lowerMakeGoroutineCalls(false) - } - - if noret := c.mod.NamedFunction("runtime.noret"); noret.IsNil() { - panic("missing noret") - } - - // replace indefinitely blocking yields - getCoroutine := c.mod.NamedFunction("runtime.getCoroutine") - coroDebugPrintln("replace indefinitely blocking yields") - nonReturning := map[llvm.Value]bool{} - for _, f := range asyncList { - if f == yield { - continue - } - coroDebugPrintln("scanning", f.Name()) - - var callsAsyncNotYield bool - var callsYield bool - var getsCoroutine bool - for bb := f.EntryBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) { - for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { - if !inst.IsACallInst().IsNil() { - callee := inst.CalledValue() - if callee == yield { - callsYield = true - } else if callee == getCoroutine { - getsCoroutine = true - } else if _, ok := asyncFuncs[callee]; ok { - callsAsyncNotYield = true - } - } - } - } - - coroDebugPrintln("result", f.Name(), callsYield, getsCoroutine, callsAsyncNotYield) - - if callsYield && !getsCoroutine && !callsAsyncNotYield { - coroDebugPrintln("optimizing", f.Name()) - // calls yield without registering for a wakeup - // this actually could otherwise wake up, but only in the case of really messed up undefined behavior - // so everything after a yield is unreachable, so we can just inject a fake return - delQueue := []llvm.Value{} - for bb := f.EntryBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) { - var broken bool - - for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { - if !broken && !inst.IsACallInst().IsNil() && inst.CalledValue() == yield { - coroDebugPrintln("broke", f.Name(), bb.AsValue().Name()) - broken = true - c.builder.SetInsertPointBefore(inst) - c.createRuntimeCall("noret", []llvm.Value{}, "") - if f.Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind { - c.builder.CreateRetVoid() - } else { - 
c.builder.CreateRet(llvm.Undef(f.Type().ElementType().ReturnType())) - } - } - if broken { - if inst.Type().TypeKind() != llvm.VoidTypeKind { - inst.ReplaceAllUsesWith(llvm.Undef(inst.Type())) - } - delQueue = append(delQueue, inst) - } - } - if !broken { - coroDebugPrintln("did not break", f.Name(), bb.AsValue().Name()) - } - } - - for _, v := range delQueue { - v.EraseFromParentAsInstruction() - } - - nonReturning[f] = true - } - } - - // convert direct calls into an async call followed by a yield operation - coroDebugPrintln("convert direct calls into an async call followed by a yield operation") - for _, f := range asyncList { - if f == yield { - continue - } - coroDebugPrintln("scanning", f.Name()) - - // Rewrite async calls - for bb := f.EntryBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) { - for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { - if !inst.IsACallInst().IsNil() { - callee := inst.CalledValue() - if _, ok := asyncFuncs[callee]; !ok || callee == yield { - continue - } - - uses := getUses(inst) - next := llvm.NextInstruction(inst) - switch { - case nonReturning[callee]: - // callee blocks forever - coroDebugPrintln("optimizing indefinitely blocking call", f.Name(), callee.Name()) - - // never calls getCoroutine - coroutine handle is irrelevant - inst.SetOperand(inst.OperandsCount()-2, llvm.Undef(c.i8ptrType)) - - // insert return - c.builder.SetInsertPointBefore(next) - c.createRuntimeCall("noret", []llvm.Value{}, "") - var retInst llvm.Value - if f.Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind { - retInst = c.builder.CreateRetVoid() - } else { - retInst = c.builder.CreateRet(llvm.Undef(f.Type().ElementType().ReturnType())) - } - - // delete everything after return - for next := llvm.NextInstruction(retInst); !next.IsNil(); next = llvm.NextInstruction(retInst) { - if next.Type().TypeKind() != llvm.VoidTypeKind { - next.ReplaceAllUsesWith(llvm.Undef(next.Type())) - } - 
next.EraseFromParentAsInstruction() - } - - continue - case next.IsAReturnInst().IsNil(): - // not a return instruction - coroDebugPrintln("not a return instruction", f.Name(), callee.Name()) - case callee.Type().ElementType().ReturnType() != f.Type().ElementType().ReturnType(): - // return types do not match - coroDebugPrintln("return types do not match", f.Name(), callee.Name()) - case callee.Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind: - fallthrough - case next.Operand(0) == inst: - // async tail call optimization - just pass parent handle - coroDebugPrintln("doing async tail call opt", f.Name()) - - // insert before call - c.builder.SetInsertPointBefore(inst) - - // get parent handle - parentHandle := c.createRuntimeCall("getParentHandle", []llvm.Value{}, "") - - // pass parent handle directly into function - inst.SetOperand(inst.OperandsCount()-2, parentHandle) - - if callee.Type().ElementType().ReturnType().TypeKind() != llvm.VoidTypeKind { - // delete return value - uses[0].SetOperand(0, llvm.Undef(callee.Type().ElementType().ReturnType())) - } - - c.builder.SetInsertPointBefore(next) - c.createRuntimeCall("yield", []llvm.Value{}, "") - c.createRuntimeCall("noret", []llvm.Value{}, "") - - continue - } - - coroDebugPrintln("inserting regular call", f.Name(), callee.Name()) - c.builder.SetInsertPointBefore(inst) - - // insert call to getCoroutine, this will be lowered later - coro := c.createRuntimeCall("getCoroutine", []llvm.Value{}, "") - - // provide coroutine handle to function - inst.SetOperand(inst.OperandsCount()-2, coro) - - // Allocate space for the return value. 
- var retvalAlloca llvm.Value - if callee.Type().ElementType().ReturnType().TypeKind() != llvm.VoidTypeKind { - // allocate return value buffer - retvalAlloca = llvmutil.CreateInstructionAlloca(c.builder, c.mod, callee.Type().ElementType().ReturnType(), inst, "coro.retvalAlloca") - - // call before function - c.builder.SetInsertPointBefore(inst) - - // cast buffer pointer to *i8 - data := c.builder.CreateBitCast(retvalAlloca, c.i8ptrType, "") - - // set state pointer to return value buffer so it can be written back - c.createRuntimeCall("setTaskStatePtr", []llvm.Value{coro, data}, "") - } - - // insert yield after starting function - c.builder.SetInsertPointBefore(llvm.NextInstruction(inst)) - yieldCall := c.createRuntimeCall("yield", []llvm.Value{}, "") - - if !retvalAlloca.IsNil() && !inst.FirstUse().IsNil() { - // Load the return value from the alloca. - // The callee has written the return value to it. - c.builder.SetInsertPointBefore(llvm.NextInstruction(yieldCall)) - retval := c.builder.CreateLoad(retvalAlloca, "coro.retval") - inst.ReplaceAllUsesWith(retval) - } - } - } - } - } - - // ditch unnecessary tail yields - coroDebugPrintln("ditch unnecessary tail yields") - noret := c.mod.NamedFunction("runtime.noret") - for _, f := range asyncList { - if f == yield { - continue - } - coroDebugPrintln("scanning", f.Name()) - - // we can only ditch a yield if we can ditch all yields - var yields []llvm.Value - var canDitch bool - scanYields: - for bb := f.EntryBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) { - for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { - if inst.IsACallInst().IsNil() || inst.CalledValue() != yield { - continue - } - - yields = append(yields, inst) - - // we can only ditch the yield if the next instruction is a void return *or* noret - next := llvm.NextInstruction(inst) - ditchable := false - switch { - case !next.IsACallInst().IsNil() && next.CalledValue() == noret: - coroDebugPrintln("ditching 
yield with noret", f.Name()) - ditchable = true - case !next.IsAReturnInst().IsNil() && f.Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind: - coroDebugPrintln("ditching yield with void return", f.Name()) - ditchable = true - case !next.IsAReturnInst().IsNil(): - coroDebugPrintln("not ditching because return is not void", f.Name(), f.Type().ElementType().ReturnType().String()) - default: - coroDebugPrintln("not ditching", f.Name()) - } - if !ditchable { - // unditchable yield - canDitch = false - break scanYields - } - - // ditchable yield - canDitch = true - } - } - - if canDitch { - coroDebugPrintln("ditching all in", f.Name()) - for _, inst := range yields { - if !llvm.NextInstruction(inst).IsAReturnInst().IsNil() { - // insert noret - coroDebugPrintln("insering noret", f.Name()) - c.builder.SetInsertPointBefore(inst) - c.createRuntimeCall("noret", []llvm.Value{}, "") - } - - // delete original yield - inst.EraseFromParentAsInstruction() - } - } - } - - // generate return reactivations - coroDebugPrintln("generate return reactivations") - for _, f := range asyncList { - if f == yield { - continue - } - coroDebugPrintln("scanning", f.Name()) - - var retPtr llvm.Value - for bb := f.EntryBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) { - block: - for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { - switch { - case !inst.IsACallInst().IsNil() && inst.CalledValue() == noret: - // does not return normally - skip this basic block - coroDebugPrintln("noret found - skipping", f.Name(), bb.AsValue().Name()) - break block - case !inst.IsAReturnInst().IsNil(): - // return instruction - rewrite to reactivation - coroDebugPrintln("adding return reactivation", f.Name(), bb.AsValue().Name()) - if f.Type().ElementType().ReturnType().TypeKind() != llvm.VoidTypeKind { - // returns something - if retPtr.IsNil() { - coroDebugPrintln("adding return pointer get", f.Name()) - - // get return pointer in entry block - 
c.builder.SetInsertPointBefore(f.EntryBasicBlock().FirstInstruction()) - parentHandle := c.createRuntimeCall("getParentHandle", []llvm.Value{}, "") - ptr := c.createRuntimeCall("getTaskStatePtr", []llvm.Value{parentHandle}, "") - retPtr = c.builder.CreateBitCast(ptr, llvm.PointerType(f.Type().ElementType().ReturnType(), 0), "retPtr") - } - - coroDebugPrintln("adding return store", f.Name(), bb.AsValue().Name()) - - // store result into return pointer - c.builder.SetInsertPointBefore(inst) - c.builder.CreateStore(inst.Operand(0), retPtr) - - // delete return value - inst.SetOperand(0, llvm.Undef(f.Type().ElementType().ReturnType())) - } - - // insert reactivation call - c.builder.SetInsertPointBefore(inst) - parentHandle := c.createRuntimeCall("getParentHandle", []llvm.Value{}, "") - c.createRuntimeCall("activateTask", []llvm.Value{parentHandle}, "") - - // mark as noret - c.builder.SetInsertPointBefore(inst) - c.createRuntimeCall("noret", []llvm.Value{}, "") - break block - - // DO NOT ERASE THE RETURN!!!!!!! - } - } - } - } - - // Create a few LLVM intrinsics for coroutine support. 
- - coroIdType := llvm.FunctionType(c.ctx.TokenType(), []llvm.Type{c.ctx.Int32Type(), c.i8ptrType, c.i8ptrType, c.i8ptrType}, false) - coroIdFunc := llvm.AddFunction(c.mod, "llvm.coro.id", coroIdType) - - coroSizeType := llvm.FunctionType(c.ctx.Int32Type(), nil, false) - coroSizeFunc := llvm.AddFunction(c.mod, "llvm.coro.size.i32", coroSizeType) - - coroBeginType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.ctx.TokenType(), c.i8ptrType}, false) - coroBeginFunc := llvm.AddFunction(c.mod, "llvm.coro.begin", coroBeginType) - - coroSuspendType := llvm.FunctionType(c.ctx.Int8Type(), []llvm.Type{c.ctx.TokenType(), c.ctx.Int1Type()}, false) - coroSuspendFunc := llvm.AddFunction(c.mod, "llvm.coro.suspend", coroSuspendType) - - coroEndType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptrType, c.ctx.Int1Type()}, false) - coroEndFunc := llvm.AddFunction(c.mod, "llvm.coro.end", coroEndType) - - coroFreeType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.ctx.TokenType(), c.i8ptrType}, false) - coroFreeFunc := llvm.AddFunction(c.mod, "llvm.coro.free", coroFreeType) - - // split blocks and add LLVM coroutine intrinsics - coroDebugPrintln("split blocks and add LLVM coroutine intrinsics") - for _, f := range asyncList { - if f == yield { - continue - } - - // find calls to yield - var yieldCalls []llvm.Value - for bb := f.EntryBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) { - for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { - if !inst.IsACallInst().IsNil() && inst.CalledValue() == yield { - yieldCalls = append(yieldCalls, inst) - } - } - } - - if len(yieldCalls) == 0 { - // no yields - we do not have to LLVM-ify this - coroDebugPrintln("skipping", f.Name()) - deleteQueue := []llvm.Value{} - for bb := f.EntryBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) { - for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { - if !inst.IsACallInst().IsNil() && inst.CalledValue() == getCoroutine { 
- // no seperate local task - replace getCoroutine with getParentHandle - c.builder.SetInsertPointBefore(inst) - inst.ReplaceAllUsesWith(c.createRuntimeCall("getParentHandle", []llvm.Value{}, "")) - deleteQueue = append(deleteQueue, inst) - } - } - } - for _, v := range deleteQueue { - v.EraseFromParentAsInstruction() - } - continue - } - - coroDebugPrintln("converting", f.Name()) - - // get frame data to mess with - frame := asyncFuncs[f] - - // add basic blocks to put cleanup and suspend code - frame.cleanupBlock = c.ctx.AddBasicBlock(f, "task.cleanup") - frame.suspendBlock = c.ctx.AddBasicBlock(f, "task.suspend") - - // at start of function - c.builder.SetInsertPointBefore(f.EntryBasicBlock().FirstInstruction()) - taskState := c.builder.CreateAlloca(c.getLLVMRuntimeType("taskState"), "task.state") - stateI8 := c.builder.CreateBitCast(taskState, c.i8ptrType, "task.state.i8") - - // get LLVM-assigned coroutine ID - id := c.builder.CreateCall(coroIdFunc, []llvm.Value{ - llvm.ConstInt(c.ctx.Int32Type(), 0, false), - stateI8, - llvm.ConstNull(c.i8ptrType), - llvm.ConstNull(c.i8ptrType), - }, "task.token") - - // allocate buffer for task struct - size := c.builder.CreateCall(coroSizeFunc, nil, "task.size") - if c.targetData.TypeAllocSize(size.Type()) > c.targetData.TypeAllocSize(c.uintptrType) { - size = c.builder.CreateTrunc(size, c.uintptrType, "task.size.uintptr") - } else if c.targetData.TypeAllocSize(size.Type()) < c.targetData.TypeAllocSize(c.uintptrType) { - size = c.builder.CreateZExt(size, c.uintptrType, "task.size.uintptr") - } - data := c.createRuntimeCall("alloc", []llvm.Value{size}, "task.data") - if c.NeedsStackObjects() { - c.trackPointer(data) - } - - // invoke llvm.coro.begin intrinsic and save task pointer - frame.taskHandle = c.builder.CreateCall(coroBeginFunc, []llvm.Value{id, data}, "task.handle") - - // Coroutine cleanup. Free resources associated with this coroutine. 
- c.builder.SetInsertPointAtEnd(frame.cleanupBlock) - mem := c.builder.CreateCall(coroFreeFunc, []llvm.Value{id, frame.taskHandle}, "task.data.free") - c.createRuntimeCall("free", []llvm.Value{mem}, "") - c.builder.CreateBr(frame.suspendBlock) - - // Coroutine suspend. A call to llvm.coro.suspend() will branch here. - c.builder.SetInsertPointAtEnd(frame.suspendBlock) - c.builder.CreateCall(coroEndFunc, []llvm.Value{frame.taskHandle, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "unused") - returnType := f.Type().ElementType().ReturnType() - if returnType.TypeKind() == llvm.VoidTypeKind { - c.builder.CreateRetVoid() - } else { - c.builder.CreateRet(llvm.Undef(returnType)) - } - - for _, inst := range yieldCalls { - // Replace call to yield with a suspension of the coroutine. - c.builder.SetInsertPointBefore(inst) - continuePoint := c.builder.CreateCall(coroSuspendFunc, []llvm.Value{ - llvm.ConstNull(c.ctx.TokenType()), - llvm.ConstInt(c.ctx.Int1Type(), 0, false), - }, "") - wakeup := llvmutil.SplitBasicBlock(c.builder, inst, llvm.NextBasicBlock(c.builder.GetInsertBlock()), "task.wakeup") - c.builder.SetInsertPointBefore(inst) - sw := c.builder.CreateSwitch(continuePoint, frame.suspendBlock, 2) - sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), wakeup) - sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), frame.cleanupBlock) - inst.EraseFromParentAsInstruction() - } - ditchQueue := []llvm.Value{} - for bb := f.EntryBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) { - for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { - if !inst.IsACallInst().IsNil() && inst.CalledValue() == getCoroutine { - // replace getCoroutine calls with the task handle - inst.ReplaceAllUsesWith(frame.taskHandle) - ditchQueue = append(ditchQueue, inst) - } - if !inst.IsACallInst().IsNil() && inst.CalledValue() == noret { - // replace tail yield with jump to cleanup, otherwise we end up with undefined behavior - c.builder.SetInsertPointBefore(inst) - 
c.builder.CreateBr(frame.cleanupBlock) - ditchQueue = append(ditchQueue, inst, llvm.NextInstruction(inst)) - } - } - } - for _, v := range ditchQueue { - v.EraseFromParentAsInstruction() - } - } - - // check for leftover calls to getCoroutine - if uses := getUses(getCoroutine); len(uses) > 0 { - useNames := make([]string, 0, len(uses)) - for _, u := range uses { - if u.InstructionParent().Parent().Name() == "runtime.llvmCoroRefHolder" { - continue - } - useNames = append(useNames, u.InstructionParent().Parent().Name()) - } - if len(useNames) > 0 { - panic("bad use of getCoroutine: " + strings.Join(useNames, ",")) - } - } - - // rewrite calls to getParentHandle - for _, inst := range getUses(c.mod.NamedFunction("runtime.getParentHandle")) { - f := inst.InstructionParent().Parent() - var parentHandle llvm.Value - parentHandle = f.LastParam() - if parentHandle.IsNil() || parentHandle.Name() != "parentHandle" { - // sanity check - panic("trying to make exported function async: " + f.Name()) - } - inst.ReplaceAllUsesWith(parentHandle) - inst.EraseFromParentAsInstruction() - } - - // ditch invalid function attributes - bads := []llvm.Value{c.mod.NamedFunction("runtime.setTaskStatePtr")} - for _, f := range append(bads, asyncList...) { - // These properties were added by the functionattrs pass. Remove - // them, because now we start using the parameter. - // https://llvm.org/docs/Passes.html#functionattrs-deduce-function-attributes - for _, kind := range []string{"nocapture", "readnone"} { - kindID := llvm.AttributeKindID(kind) - n := f.ParamsCount() - for i := 0; i <= n; i++ { - f.RemoveEnumAttributeAtIndex(i, kindID) - } - } - } - - // eliminate noret - for _, inst := range getUses(noret) { - inst.EraseFromParentAsInstruction() - } - - return true, c.lowerMakeGoroutineCalls(true) -} - -// Lower runtime.makeGoroutine calls to regular call instructions. This is done -// after the regular goroutine transformations. 
The started goroutines are -// either non-blocking (in which case they can be called directly) or blocking, -// in which case they will ask the scheduler themselves to be rescheduled. -func (c *Compiler) lowerMakeGoroutineCalls(sched bool) error { - // The following Go code: - // go startedGoroutine() - // - // Is translated to the following during IR construction, to preserve the - // fact that this function should be called as a new goroutine. - // %0 = call i8* @runtime.makeGoroutine(i8* bitcast (void (i8*, i8*)* @main.startedGoroutine to i8*), i8* undef, i8* null) - // %1 = bitcast i8* %0 to void (i8*, i8*)* - // call void %1(i8* undef, i8* undef) - // - // This function rewrites it to a direct call: - // call void @main.startedGoroutine(i8* undef, i8* null) - - makeGoroutine := c.mod.NamedFunction("runtime.makeGoroutine") - for _, goroutine := range getUses(makeGoroutine) { - ptrtointIn := goroutine.Operand(0) - origFunc := ptrtointIn.Operand(0) - uses := getUses(goroutine) - if len(uses) != 1 || uses[0].IsAIntToPtrInst().IsNil() { - return errorAt(makeGoroutine, "expected exactly 1 inttoptr use of runtime.makeGoroutine") - } - inttoptrOut := uses[0] - uses = getUses(inttoptrOut) - if len(uses) != 1 || uses[0].IsACallInst().IsNil() { - return errorAt(inttoptrOut, "expected exactly 1 call use of runtime.makeGoroutine bitcast") - } - realCall := uses[0] - - // Create call instruction. 
- var params []llvm.Value - for i := 0; i < realCall.OperandsCount()-1; i++ { - params = append(params, realCall.Operand(i)) - } - c.builder.SetInsertPointBefore(realCall) - if !sched { - params[len(params)-1] = llvm.Undef(c.i8ptrType) - } else { - params[len(params)-1] = c.createRuntimeCall("getFakeCoroutine", []llvm.Value{}, "") // parent coroutine handle (must not be nil) - } - c.builder.CreateCall(origFunc, params, "") - realCall.EraseFromParentAsInstruction() - inttoptrOut.EraseFromParentAsInstruction() - goroutine.EraseFromParentAsInstruction() - } - - if !sched && len(getUses(c.mod.NamedFunction("runtime.getFakeCoroutine"))) > 0 { - panic("getFakeCoroutine used without scheduler") - } - - return nil -} - -// internalArgumentValue finds the LLVM value inside the function which corresponds to the provided argument of the provided call. -func (c *Compiler) internalArgumentValue(call llvm.Value, arg llvm.Value) llvm.Value { - n := call.OperandsCount() - for i := 0; i < n; i++ { - if call.Operand(i) == arg { - return call.CalledValue().Param(i) - } - } - panic("no corresponding argument") -} - -// specialCoroFuncs are functions in the runtime which accept coroutines as arguments but act as a no-op if these are nil. -// Calls to these functions do not require a fake coroutine. -var specialCoroFuncs = map[string]bool{ - "runtime.runqueuePushBack": true, - "runtime.activateTask": true, -} - -// isCoroNecessary checks if a coroutine pointer value must be non-nil for the program to function. -// This returns true if replacing a fake coroutine value with nil will result in equivalent behavior. 
-func (c *Compiler) isCoroNecessary(coro llvm.Value, scanned map[llvm.Value]struct{}) (necessary bool) { - // avoid infinite recursion - if _, ok := scanned[coro]; ok { - return false - } - scanned[coro] = struct{}{} - - for use := coro.FirstUse(); !use.IsNil(); use = use.NextUse() { - user := use.User() - switch { - case !user.IsACallInst().IsNil(): - switch { - case !user.CalledValue().IsConstant(): - // This is passed into an unknown function, so we do not know what is happening to it. - coroDebugPrintln("found unoptimizable dynamic call") - return true - case specialCoroFuncs[user.CalledValue().Name()]: - // Pushing nil to the runqueue is valid and acts as a no-op. - // This use does not require a non-nil coroutine. - case c.isCoroNecessary(c.internalArgumentValue(user, coro), scanned): - // The function we called depends on the coroutine value being non-nil. - coroDebugPrintln("call to function depending on non-nil coroutine") - return true - default: - // This call does not depend upon a non-nil coroutine. - } - default: - if coroDebug { - fmt.Printf("unoptimizable usage of coroutine in %q: ", user.InstructionParent().Parent().Name()) - user.Dump() - fmt.Println() - } - return true - } - } - - // Nothing we found needed this coroutine value. - return false -} - -// eliminateFakeCoroutines replaces unnecessary calls to runtime.getFakeCoroutine. -// This is not considered an optimization, because it is necessary to avoid infinite recursion inside of runtime.getFakeCoroutine. -func (c *Compiler) eliminateFakeCoroutines() { - coroDebugPrintln("eliminating fake coroutines") - for _, v := range getUses(c.mod.NamedFunction("runtime.getFakeCoroutine")) { - if !c.isCoroNecessary(v, map[llvm.Value]struct{}{}) { - // This use of a fake coroutine is not necessary. 
- coroDebugPrintln("eliminating fake coroutine for", getUses(v)[0].CalledValue().Name()) - v.ReplaceAllUsesWith(llvm.ConstNull(c.i8ptrType)) - v.EraseFromParentAsInstruction() - } else { - // This use of a fake coroutine is necessary. - coroDebugPrintln("failed to eliminate fake coroutine for", getUses(v)[0].CalledValue().Name()) - } - } -} diff --git a/compiler/goroutine.go b/compiler/goroutine.go index 972aaed4..2abb8bb0 100644 --- a/compiler/goroutine.go +++ b/compiler/goroutine.go @@ -13,24 +13,17 @@ import "tinygo.org/x/go-llvm" // // Because a go statement doesn't return anything, return undef. func (c *Compiler) emitStartGoroutine(funcPtr llvm.Value, params []llvm.Value) llvm.Value { + paramBundle := c.emitPointerPack(params) + var callee llvm.Value switch c.Scheduler() { - case "tasks": - paramBundle := c.emitPointerPack(params) - paramBundle = c.builder.CreatePtrToInt(paramBundle, c.uintptrType, "") - - calleeValue := c.createGoroutineStartWrapper(funcPtr) - c.createRuntimeCall("startGoroutine", []llvm.Value{calleeValue, paramBundle}, "") + case "none", "tasks": + callee = c.createGoroutineStartWrapper(funcPtr) case "coroutines": - // We roundtrip through runtime.makeGoroutine as a signal (to find these - // calls) and to break any optimizations LLVM will try to do: they are - // invalid if we called this as a regular function to be updated later. 
- calleeValue := c.builder.CreatePtrToInt(funcPtr, c.uintptrType, "") - calleeValue = c.createRuntimeCall("makeGoroutine", []llvm.Value{calleeValue}, "") - calleeValue = c.builder.CreateIntToPtr(calleeValue, funcPtr.Type(), "") - c.createCall(calleeValue, append(params, llvm.ConstPointerNull(c.i8ptrType)), "") + callee = c.builder.CreatePtrToInt(funcPtr, c.uintptrType, "") default: panic("unreachable") } + c.createCall(c.mod.NamedFunction("internal/task.start"), []llvm.Value{callee, paramBundle, llvm.Undef(c.i8ptrType), llvm.ConstPointerNull(c.i8ptrType)}, "") return llvm.Undef(funcPtr.Type().ElementType().ReturnType()) } diff --git a/compiler/optimizer.go b/compiler/optimizer.go index 60395edc..5402fa18 100644 --- a/compiler/optimizer.go +++ b/compiler/optimizer.go @@ -31,6 +31,9 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) []er } } + // Replace callMain placeholder with actual main function. + c.mod.NamedFunction("runtime.callMain").ReplaceAllUsesWith(c.mod.NamedFunction(c.ir.MainPkg().Pkg.Path() + ".main")) + // Run function passes for each function. funcPasses := llvm.NewFunctionPassManagerForModule(c.mod) defer funcPasses.Dispose() @@ -92,25 +95,41 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) []er } } - err := c.LowerGoroutines() - if err != nil { - return []error{err} - } } else { // Must be run at any optimization level. transform.LowerInterfaces(c.mod) if c.funcImplementation() == funcValueSwitch { transform.LowerFuncValues(c.mod) } - err := c.LowerGoroutines() - if err != nil { - return []error{err} - } errs := transform.LowerInterrupts(c.mod) if len(errs) > 0 { return errs } } + + // Lower async implementations. + switch c.Scheduler() { + case "coroutines": + // Lower async as coroutines. + err := transform.LowerCoroutines(c.mod, c.NeedsStackObjects()) + if err != nil { + return []error{err} + } + case "tasks": + // No transformations necessary. 
+ case "none": + // Check for any goroutine starts. + if start := c.mod.NamedFunction("internal/task.start"); !start.IsNil() && len(getUses(start)) > 0 { + errs := []error{} + for _, call := range getUses(start) { + errs = append(errs, errorAt(call, "attempted to start a goroutine without a scheduler")) + } + return errs + } + default: + return []error{errors.New("invalid scheduler")} + } + if c.VerifyIR() { if errs := c.checkModule(); errs != nil { return errs diff --git a/ir/passes.go b/ir/passes.go index 0fa7a4fd..55390cad 100644 --- a/ir/passes.go +++ b/ir/passes.go @@ -69,10 +69,11 @@ func (p *Program) SimpleDCE() { main := p.mainPkg.Members["main"].(*ssa.Function) runtimePkg := p.Program.ImportedPackage("runtime") mathPkg := p.Program.ImportedPackage("math") + taskPkg := p.Program.ImportedPackage("internal/task") p.GetFunction(main).flag = true worklist := []*ssa.Function{main} for _, f := range p.Functions { - if f.exported || f.Synthetic == "package initializer" || f.Pkg == runtimePkg || (f.Pkg == mathPkg && f.Pkg != nil) { + if f.exported || f.Synthetic == "package initializer" || f.Pkg == runtimePkg || f.Pkg == taskPkg || (f.Pkg == mathPkg && f.Pkg != nil) { if f.flag { continue } diff --git a/src/internal/task/queue.go b/src/internal/task/queue.go new file mode 100644 index 00000000..c86bc596 --- /dev/null +++ b/src/internal/task/queue.go @@ -0,0 +1,98 @@ +package task + +const asserts = false + +// Queue is a FIFO container of tasks. +// The zero value is an empty queue. +type Queue struct { + head, tail *Task +} + +// Push a task onto the queue. +func (q *Queue) Push(t *Task) { + if asserts && t.Next != nil { + panic("runtime: pushing a task to a queue with a non-nil Next pointer") + } + if q.tail != nil { + q.tail.Next = t + } + q.tail = t + t.Next = nil + if q.head == nil { + q.head = t + } +} + +// Pop a task off of the queue. 
+func (q *Queue) Pop() *Task { + t := q.head + if t == nil { + return nil + } + q.head = t.Next + if q.tail == t { + q.tail = nil + } + t.Next = nil + return t +} + +// Append pops the contents of another queue and pushes them onto the end of this queue. +func (q *Queue) Append(other *Queue) { + if q.head == nil { + q.head = other.head + } else { + q.tail.Next = other.head + } + q.tail = other.tail + other.head, other.tail = nil, nil +} + +// Stack is a LIFO container of tasks. +// The zero value is an empty stack. +// This is slightly cheaper than a queue, so it can be preferable when strict ordering is not necessary. +type Stack struct { + top *Task +} + +// Push a task onto the stack. +func (s *Stack) Push(t *Task) { + if asserts && t.Next != nil { + panic("runtime: pushing a task to a stack with a non-nil Next pointer") + } + s.top, t.Next = t, s.top +} + +// Pop a task off of the stack. +func (s *Stack) Pop() *Task { + t := s.top + if t != nil { + s.top = t.Next + } + t.Next = nil + return t +} + +// tail follows the chain of tasks. +// If t is nil, returns nil. +// Otherwise, returns the task in the chain where the Next field is nil. +func (t *Task) tail() *Task { + if t == nil { + return nil + } + for t.Next != nil { + t = t.Next + } + return t +} + +// Queue moves the contents of the stack into a queue. +// Elements can be popped from the queue in the same order that they would be popped from the stack. +func (s *Stack) Queue() Queue { + head := s.top + s.top = nil + return Queue{ + head: head, + tail: head.tail(), + } +} diff --git a/src/internal/task/task.go b/src/internal/task/task.go new file mode 100644 index 00000000..57b29eb3 --- /dev/null +++ b/src/internal/task/task.go @@ -0,0 +1,20 @@ +package task + +import ( + "unsafe" +) + +// Task is a state of goroutine for scheduling purposes. +type Task struct { + // Next is a field which can be used to make a linked list of tasks. + Next *Task + + // Ptr is a field which can be used for storing a pointer. 
+ Ptr unsafe.Pointer + + // Data is a field which can be used for storing state information. + Data uint + + // state is the underlying running state of the task. + state state +} diff --git a/src/internal/task/task_coroutine.go b/src/internal/task/task_coroutine.go new file mode 100644 index 00000000..374bfe11 --- /dev/null +++ b/src/internal/task/task_coroutine.go @@ -0,0 +1,97 @@ +// +build scheduler.coroutines + +package task + +import ( + "unsafe" +) + +// rawState is an underlying coroutine state exposed by llvm.coro. +// This matches *i8 in LLVM. +type rawState uint8 + +//go:export llvm.coro.resume +func (s *rawState) resume() + +type state struct{ *rawState } + +//go:export llvm.coro.noop +func noopState() *rawState + +// Resume the task until it pauses or completes. +func (t *Task) Resume() { + t.state.resume() +} + +// setState is used by the compiler to set the state of the function at the beginning of a function call. +// Returns the state of the caller. +func (t *Task) setState(s *rawState) *rawState { + caller := t.state + t.state = state{s} + return caller.rawState +} + +// returnTo is used by the compiler to return to the state of the caller. +func (t *Task) returnTo(parent *rawState) { + t.state = state{parent} + t.returnCurrent() +} + +// returnCurrent is used by the compiler to return to the state of the caller in a case where the state is not replaced. +func (t *Task) returnCurrent() { + scheduleTask(t) +} + +//go:linkname scheduleTask runtime.runqueuePushBack +func scheduleTask(*Task) + +// setReturnPtr is used by the compiler to store the return buffer into the task. +// This buffer is where the return value of a function that is about to be called will be stored. +func (t *Task) setReturnPtr(buf unsafe.Pointer) { + t.Ptr = buf +} + +// getReturnPtr is used by the compiler to get the return buffer stored into the task. 
+// This is called at the beginning of an async function, and the return is stored into this buffer immediately before resuming the caller. +func (t *Task) getReturnPtr() unsafe.Pointer { + return t.Ptr +} + +// createTask returns a new task struct initialized with a no-op state. +func createTask() *Task { + return &Task{ + state: state{noopState()}, + } +} + +// start invokes a function in a new goroutine. Calls to this are inserted by the compiler. +// The created goroutine starts running immediately. +// This is implemented inside the compiler. +func start(fn uintptr, args unsafe.Pointer) + +// Current returns the current active task. +// This is implemented inside the compiler. +func Current() *Task + +// Pause suspends the current running task. +// This is implemented inside the compiler. +func Pause() + +type taskHolder interface { + setState(*rawState) *rawState + returnTo(*rawState) + returnCurrent() + setReturnPtr(unsafe.Pointer) + getReturnPtr() unsafe.Pointer +} + +// If there are no direct references to the task methods, they will not be discovered by the compiler, and this will trigger a compiler error. +// Instantiating this interface forces discovery of these methods. +var _ = taskHolder((*Task)(nil)) + +func fake() { + // Hack to ensure intrinsics are discovered. + Current() + go func() {}() + Pause() +} diff --git a/src/internal/task/task_none.go b/src/internal/task/task_none.go new file mode 100644 index 00000000..31835892 --- /dev/null +++ b/src/internal/task/task_none.go @@ -0,0 +1,29 @@ +// +build scheduler.none + +package task + +import "unsafe" + +//go:linkname runtimePanic runtime.runtimePanic +func runtimePanic(str string) + +func Pause() { + runtimePanic("scheduler is disabled") +} + +func Current() *Task { + runtimePanic("scheduler is disabled") + return nil +} + +//go:noinline +func start(fn uintptr, args unsafe.Pointer) { + // The compiler will error if this is reachable. 
+	runtimePanic("scheduler is disabled")
+}
+
+type state struct{}
+
+func (t *Task) Resume() {
+	runtimePanic("scheduler is disabled")
+}
diff --git a/src/internal/task/task_stack.go b/src/internal/task/task_stack.go
new file mode 100644
index 00000000..0bd2f4c2
--- /dev/null
+++ b/src/internal/task/task_stack.go
@@ -0,0 +1,74 @@
+// +build scheduler.tasks
+
+package task
+
+import "unsafe"
+
+//go:linkname runtimePanic runtime.runtimePanic
+func runtimePanic(str string)
+
+// Stack canary, to detect a stack overflow. The number is a random number
+// generated by random.org. The bit fiddling dance is necessary because
+// otherwise Go wouldn't allow the cast to a smaller integer size.
+const stackCanary = uintptr(uint64(0x670c1333b83bf575) & uint64(^uintptr(0)))
+
+// state is a structure which holds a reference to the state of the task.
+// When the task is suspended, the registers are stored onto the stack and the stack pointer is stored into sp.
+type state struct {
+	// sp is the stack pointer of the saved state.
+	// When the task is inactive, the saved registers are stored at the top of the stack.
+	sp uintptr
+
+	// canaryPtr points to the top word of the stack (the lowest address).
+	// This is used to detect stack overflows.
+	// When initializing the goroutine, the stackCanary constant is stored there.
+	// If the stack overflowed, the word will likely no longer equal stackCanary.
+	canaryPtr *uintptr
+}
+
+// currentTask is the current running task, or nil if currently in the scheduler.
+var currentTask *Task
+
+// Current returns the current active task.
+func Current() *Task {
+	return currentTask
+}
+
+// Pause suspends the current task and returns to the scheduler.
+// This function may only be called when running on a goroutine stack, not when running on the system stack or in an interrupt.
+func Pause() {
+	// Check whether the canary (the lowest address of the stack) is still
+	// valid. If it is not, a stack overflow has occurred.
+	if *currentTask.state.canaryPtr != stackCanary {
+		runtimePanic("goroutine stack overflow")
+	}
+	currentTask.state.pause()
+}
+
+// Resume the task until it pauses or completes.
+// This may only be called from the scheduler.
+func (t *Task) Resume() {
+	currentTask = t
+	t.state.resume()
+	currentTask = nil
+}
+
+// initialize the state and prepare to call the specified function with the specified argument bundle.
+func (s *state) initialize(fn uintptr, args unsafe.Pointer) {
+	// Create a stack.
+	stack := make([]uintptr, stackSize/unsafe.Sizeof(uintptr(0)))
+
+	// Invoke architecture-specific initialization.
+	s.archInit(stack, fn, args)
+}
+
+//go:linkname runqueuePushBack runtime.runqueuePushBack
+func runqueuePushBack(*Task)
+
+// start creates and starts a new goroutine with the given function and arguments.
+// The new goroutine is scheduled to run later.
+func start(fn uintptr, args unsafe.Pointer) {
+	t := &Task{}
+	t.state.initialize(fn, args)
+	runqueuePushBack(t)
+}
diff --git a/src/internal/task/task_stack_cortexm.go b/src/internal/task/task_stack_cortexm.go
new file mode 100644
index 00000000..b91e5f04
--- /dev/null
+++ b/src/internal/task/task_stack_cortexm.go
@@ -0,0 +1,83 @@
+// +build scheduler.tasks,cortexm
+
+package task
+
+import "unsafe"
+
+const stackSize = 1024
+
+// calleeSavedRegs is the list of registers that must be saved and restored when
+// switching between tasks. Also see scheduler_cortexm.S that relies on the
+// exact layout of this struct.
+type calleeSavedRegs struct {
+	r4  uintptr
+	r5  uintptr
+	r6  uintptr
+	r7  uintptr
+	r8  uintptr
+	r9  uintptr
+	r10 uintptr
+	r11 uintptr
+
+	pc uintptr
+}
+
+// registers gets a pointer to the registers stored at the top of the stack.
+func (s *state) registers() *calleeSavedRegs { + return (*calleeSavedRegs)(unsafe.Pointer(s.sp)) +} + +// startTask is a small wrapper function that sets up the first (and only) +// argument to the new goroutine and makes sure it is exited when the goroutine +// finishes. +//go:extern tinygo_startTask +var startTask [0]uint8 + +// archInit runs architecture-specific setup for the goroutine startup. +func (s *state) archInit(stack []uintptr, fn uintptr, args unsafe.Pointer) { + // Set up the stack canary, a random number that should be checked when + // switching from the task back to the scheduler. The stack canary pointer + // points to the first word of the stack. If it has changed between now and + // the next stack switch, there was a stack overflow. + s.canaryPtr = &stack[0] + *s.canaryPtr = stackCanary + + // Store the initial sp for the startTask function (implemented in assembly). + s.sp = uintptr(unsafe.Pointer(&stack[uintptr(len(stack))-(unsafe.Sizeof(calleeSavedRegs{})/unsafe.Sizeof(uintptr(0)))])) + + // Initialize the registers. + // These will be popped off of the stack on the first resume of the goroutine. + r := s.registers() + + // Start the function at tinygo_startTask (defined in src/runtime/scheduler_cortexm.S). + // This assembly code calls a function (passed in r4) with a single argument (passed in r5). + // After the function returns, it calls Pause(). + r.pc = uintptr(unsafe.Pointer(&startTask)) + + // Pass the function to call in r4. + // This function is a compiler-generated wrapper which loads arguments out of a struct pointer. + // See createGoroutineStartWrapper (defined in compiler/goroutine.go) for more information. + r.r4 = fn + + // Pass the pointer to the arguments struct in r5. 
+ r.r5 = uintptr(args) +} + +func (s *state) resume() { + switchToTask(s.sp) +} + +//export tinygo_switchToTask +func switchToTask(uintptr) + +//export tinygo_switchToScheduler +func switchToScheduler(*uintptr) + +func (s *state) pause() { + switchToScheduler(&s.sp) +} + +//export tinygo_pause +func pause() { + Pause() +} diff --git a/src/runtime/chan.go b/src/runtime/chan.go index a4cb4748..95243136 100644 --- a/src/runtime/chan.go +++ b/src/runtime/chan.go @@ -24,6 +24,7 @@ package runtime // element of the receiving coroutine and setting the 'comma-ok' value to false. import ( + "internal/task" "unsafe" ) @@ -46,7 +47,7 @@ type channelBlockedList struct { // If this channel operation is not part of a select, then the pointer field of the state holds the data buffer. // If this channel operation is part of a select, then the pointer field of the state holds the recieve buffer. // If this channel operation is a receive, then the data field should be set to zero when resuming due to channel closure. - t *task + t *task.Task // s is a pointer to the channel select state corresponding to this operation. // This will be nil if and only if this channel operation is not part of a select statement. 
@@ -141,24 +142,24 @@ func (ch *channel) resumeRX(ok bool) unsafe.Pointer { b, ch.blocked = ch.blocked, ch.blocked.next // get destination pointer - dst := b.t.state().ptr + dst := b.t.Ptr if !ok { // the result value is zero memzero(dst, ch.elementSize) - b.t.state().data = 0 + b.t.Data = 0 } if b.s != nil { // tell the select op which case resumed - b.t.state().ptr = unsafe.Pointer(b.s) + b.t.Ptr = unsafe.Pointer(b.s) // detach associated operations b.detach() } // push task onto runqueue - runqueuePushBack(b.t) + runqueue.Push(b.t) return dst } @@ -171,21 +172,21 @@ func (ch *channel) resumeTX() unsafe.Pointer { b, ch.blocked = ch.blocked, ch.blocked.next // get source pointer - src := b.t.state().ptr + src := b.t.Ptr if b.s != nil { // use state's source pointer src = b.s.value // tell the select op which case resumed - b.t.state().ptr = unsafe.Pointer(b.s) + b.t.Ptr = unsafe.Pointer(b.s) // detach associated operations b.detach() } // push task onto runqueue - runqueuePushBack(b.t) + runqueue.Push(b.t) return src } @@ -424,17 +425,16 @@ func chanSend(ch *channel, value unsafe.Pointer) { } // wait for reciever - sender := getCoroutine() + sender := task.Current() ch.state = chanStateSend - senderState := sender.state() - senderState.ptr = value + sender.Ptr = value ch.blocked = &channelBlockedList{ next: ch.blocked, t: sender, } chanDebug(ch) - yield() - senderState.ptr = nil + task.Pause() + sender.Ptr = nil } // chanRecv receives a single value over a channel. 
@@ -454,18 +454,17 @@ func chanRecv(ch *channel, value unsafe.Pointer) bool { } // wait for a value - receiver := getCoroutine() + receiver := task.Current() ch.state = chanStateRecv - receiverState := receiver.state() - receiverState.ptr, receiverState.data = value, 1 + receiver.Ptr, receiver.Data = value, 1 ch.blocked = &channelBlockedList{ next: ch.blocked, t: receiver, } chanDebug(ch) - yield() - ok := receiverState.data == 1 - receiverState.ptr, receiverState.data = nil, 0 + task.Pause() + ok := receiver.Data == 1 + receiver.Ptr, receiver.Data = nil, 0 return ok } @@ -515,7 +514,7 @@ func chanSelect(recvbuf unsafe.Pointer, states []chanSelectState, ops []channelB for i, v := range states { ops[i] = channelBlockedList{ next: v.ch.blocked, - t: getCoroutine(), + t: task.Current(), s: &states[i], allSelectOps: ops, } @@ -547,14 +546,15 @@ func chanSelect(recvbuf unsafe.Pointer, states []chanSelectState, ops []channelB } // expose rx buffer - getCoroutine().state().ptr = recvbuf - getCoroutine().state().data = 1 + t := task.Current() + t.Ptr = recvbuf + t.Data = 1 // wait for one case to fire - yield() + task.Pause() // figure out which one fired and return the ok value - return (uintptr(getCoroutine().state().ptr) - uintptr(unsafe.Pointer(&states[0]))) / unsafe.Sizeof(chanSelectState{}), getCoroutine().state().data != 0 + return (uintptr(t.Ptr) - uintptr(unsafe.Pointer(&states[0]))) / unsafe.Sizeof(chanSelectState{}), t.Data != 0 } // tryChanSelect is like chanSelect, but it does a non-blocking select operation. diff --git a/src/runtime/runtime.go b/src/runtime/runtime.go index 13f3ab3c..ae9ab665 100644 --- a/src/runtime/runtime.go +++ b/src/runtime/runtime.go @@ -10,17 +10,8 @@ const Compiler = "tinygo" // package. 
func initAll() -// A function call to this function is replaced with one of the following, -// depending on whether the scheduler is necessary: -// -// Without scheduler: -// -// main.main() -// -// With scheduler: -// -// main.main() -// scheduler() +// callMain is a placeholder for the program main function. +// All references to this are replaced with references to the program main function by the compiler. func callMain() func GOMAXPROCS(n int) int { diff --git a/src/runtime/runtime_arm7tdmi.go b/src/runtime/runtime_arm7tdmi.go index cff5bca3..cf2ff681 100644 --- a/src/runtime/runtime_arm7tdmi.go +++ b/src/runtime/runtime_arm7tdmi.go @@ -40,7 +40,10 @@ func main() { initAll() // Compiler-generated call to main.main(). - callMain() + go callMain() + + // Run the scheduler. + scheduler() } func preinit() { diff --git a/src/runtime/runtime_atsamd21.go b/src/runtime/runtime_atsamd21.go index 440cd148..a3aa6c97 100644 --- a/src/runtime/runtime_atsamd21.go +++ b/src/runtime/runtime_atsamd21.go @@ -17,7 +17,8 @@ type timeUnit int64 func main() { preinit() initAll() - callMain() + go callMain() + scheduler() abort() } diff --git a/src/runtime/runtime_atsamd51.go b/src/runtime/runtime_atsamd51.go index 7cbcad04..23d91109 100644 --- a/src/runtime/runtime_atsamd51.go +++ b/src/runtime/runtime_atsamd51.go @@ -16,7 +16,8 @@ type timeUnit int64 func main() { preinit() initAll() - callMain() + go callMain() + scheduler() abort() } diff --git a/src/runtime/runtime_cortexm.go b/src/runtime/runtime_cortexm.go index 147a3b39..04b34039 100644 --- a/src/runtime/runtime_cortexm.go +++ b/src/runtime/runtime_cortexm.go @@ -40,29 +40,6 @@ func preinit() { } } -// calleeSavedRegs is the list of registers that must be saved and restored when -// switching between tasks. Also see scheduler_cortexm.S that relies on the -// exact layout of this struct. 
-type calleeSavedRegs struct { - r4 uintptr - r5 uintptr - r6 uintptr - r7 uintptr - r8 uintptr - r9 uintptr - r10 uintptr - r11 uintptr -} - -// prepareStartTask stores fn and args in some callee-saved registers that can -// then be used by the startTask function (implemented in assembly) to set up -// the initial stack pointer and initial argument with the pointer to the object -// with the goroutine start arguments. -func (r *calleeSavedRegs) prepareStartTask(fn, args uintptr) { - r.r4 = fn - r.r5 = args -} - func abort() { // disable all interrupts arm.DisableInterrupts() diff --git a/src/runtime/runtime_cortexm_qemu.go b/src/runtime/runtime_cortexm_qemu.go index f1f95caa..973c20ac 100644 --- a/src/runtime/runtime_cortexm_qemu.go +++ b/src/runtime/runtime_cortexm_qemu.go @@ -21,7 +21,8 @@ var timestamp timeUnit func main() { preinit() initAll() - callMain() + go callMain() + scheduler() arm.SemihostingCall(arm.SemihostingReportException, arm.SemihostingApplicationExit) abort() } diff --git a/src/runtime/runtime_fe310.go b/src/runtime/runtime_fe310.go index 92cb9ad1..88295c9c 100644 --- a/src/runtime/runtime_fe310.go +++ b/src/runtime/runtime_fe310.go @@ -52,7 +52,8 @@ func main() { preinit() initPeripherals() initAll() - callMain() + go callMain() + scheduler() abort() } diff --git a/src/runtime/runtime_nrf.go b/src/runtime/runtime_nrf.go index 50d3e5f2..9b01ded7 100644 --- a/src/runtime/runtime_nrf.go +++ b/src/runtime/runtime_nrf.go @@ -22,7 +22,8 @@ func main() { systemInit() preinit() initAll() - callMain() + go callMain() + scheduler() abort() } diff --git a/src/runtime/runtime_stm32.go b/src/runtime/runtime_stm32.go index 0717e7cf..11e686f0 100644 --- a/src/runtime/runtime_stm32.go +++ b/src/runtime/runtime_stm32.go @@ -8,6 +8,7 @@ type timeUnit int64 func main() { preinit() initAll() - callMain() + go callMain() + scheduler() abort() } diff --git a/src/runtime/runtime_unix.go b/src/runtime/runtime_unix.go index 63d78917..a88f243f 100644 --- 
a/src/runtime/runtime_unix.go +++ b/src/runtime/runtime_unix.go @@ -52,7 +52,10 @@ func main() int { initAll() // Compiler-generated call to main.main(). - callMain() + go callMain() + + // Run scheduler. + scheduler() // For libc compatibility. return 0 diff --git a/src/runtime/runtime_wasm.go b/src/runtime/runtime_wasm.go index 09216cc3..c01fa932 100644 --- a/src/runtime/runtime_wasm.go +++ b/src/runtime/runtime_wasm.go @@ -21,7 +21,8 @@ func fd_write(id uint32, iovs *wasiIOVec, iovs_len uint, nwritten *uint) (errno //export _start func _start() { initAll() - callMain() + go callMain() + scheduler() } // Using global variables to avoid heap allocation. @@ -50,7 +51,9 @@ func setEventHandler(fn func()) { //go:export resume func resume() { - handleEvent() + go func() { + handleEvent() + }() } //go:export go_scheduler diff --git a/src/runtime/scheduler.go b/src/runtime/scheduler.go index 7dc2b790..49cf8e40 100644 --- a/src/runtime/scheduler.go +++ b/src/runtime/scheduler.go @@ -14,27 +14,16 @@ package runtime // of the coroutine-based scheduler, it is the coroutine pointer (a *i8 in // LLVM). -import "unsafe" +import ( + "internal/task" +) const schedulerDebug = false -// State of a task. Internally represented as: -// -// {i8* next, i8* ptr, i32/i64 data} -type taskState struct { - next *task - ptr unsafe.Pointer - data uint -} - // Queues used by the scheduler. -// -// TODO: runqueueFront can be removed by making the run queue a circular linked -// list. The runqueueBack will simply refer to the front in the 'next' pointer. var ( - runqueueFront *task - runqueueBack *task - sleepQueue *task + runqueue task.Queue + sleepQueue *task.Task sleepQueueBaseTime timeUnit ) @@ -46,14 +35,14 @@ func scheduleLog(msg string) { } // Simple logging with a task pointer, for debugging. -func scheduleLogTask(msg string, t *task) { +func scheduleLogTask(msg string, t *task.Task) { if schedulerDebug { println("---", msg, t) } } // Simple logging with a channel and task pointer. 
-func scheduleLogChan(msg string, ch *channel, t *task) { +func scheduleLogChan(msg string, ch *channel, t *task.Task) { if schedulerDebug { println("---", msg, ch, t) } @@ -67,7 +56,7 @@ func scheduleLogChan(msg string, ch *channel, t *task) { //go:noinline func deadlock() { // call yield without requesting a wakeup - yield() + task.Pause() panic("unreachable") } @@ -80,122 +69,20 @@ func Goexit() { deadlock() } -// unblock unblocks a task and returns the next value -func unblock(t *task) *task { - state := t.state() - next := state.next - state.next = nil - activateTask(t) - return next -} - -// unblockChain unblocks the next task on the stack/queue, returning it -// also updates the chain, putting the next element into the chain pointer -// if the chain is used as a queue, tail is used as a pointer to the final insertion point -// if the chain is used as a stack, tail should be nil -func unblockChain(chain **task, tail ***task) *task { - t := *chain - if t == nil { - return nil - } - *chain = unblock(t) - if tail != nil && *chain == nil { - *tail = chain - } - return t -} - -// dropChain drops a task from the given stack or queue -// if the chain is used as a queue, tail is used as a pointer to the field containing a pointer to the next insertion point -// if the chain is used as a stack, tail should be nil -func dropChain(t *task, chain **task, tail ***task) { - for c := chain; *c != nil; c = &((*c).state().next) { - if *c == t { - next := (*c).state().next - if next == nil && tail != nil { - *tail = c - } - *c = next - return - } - } - panic("runtime: task not in chain") -} - -// Pause the current task for a given time. -//go:linkname sleep time.Sleep -func sleep(duration int64) { - addSleepTask(getCoroutine(), duration) - yield() -} - -func avrSleep(duration int64) { - sleepTicks(timeUnit(duration / tickMicros)) -} - -// Add a non-queued task to the run queue. -// -// This is a compiler intrinsic, and is called from a callee to reactivate the -// caller. 
-func activateTask(t *task) { - if t == nil { - return - } - scheduleLogTask(" set runnable:", t) - runqueuePushBack(t) -} - -// getTaskStateData is a helper function to get the current .data field of the -// goroutine state. -//go:inline -func getTaskStateData(t *task) uint { - return t.state().data -} - -// Add this task to the end of the run queue. May also destroy the task if it's -// done. -func runqueuePushBack(t *task) { - if schedulerDebug { - scheduleLogTask(" pushing back:", t) - if t.state().next != nil { - panic("runtime: runqueuePushBack: expected next task to be nil") - } - } - if runqueueBack == nil { // empty runqueue - runqueueBack = t - runqueueFront = t - } else { - lastTaskState := runqueueBack.state() - lastTaskState.next = t - runqueueBack = t - } -} - -// Get a task from the front of the run queue. Returns nil if there is none. -func runqueuePopFront() *task { - t := runqueueFront - if t == nil { - return nil - } - state := t.state() - runqueueFront = state.next - if runqueueFront == nil { - // Runqueue is empty now. - runqueueBack = nil - } - state.next = nil - return t +// Add this task to the end of the run queue. +func runqueuePushBack(t *task.Task) { + runqueue.Push(t) } // Add this task to the sleep queue, assuming its state is set to sleeping. -func addSleepTask(t *task, duration int64) { +func addSleepTask(t *task.Task, duration int64) { if schedulerDebug { println(" set sleep:", t, uint(duration/tickMicros)) - if t.state().next != nil { + if t.Next != nil { panic("runtime: addSleepTask: expected next task to be nil") } } - t.state().data = uint(duration / tickMicros) // TODO: longer durations + t.Data = uint(duration / tickMicros) // TODO: longer durations now := ticks() if sleepQueue == nil { scheduleLog(" -> sleep new queue") @@ -206,20 +93,20 @@ func addSleepTask(t *task, duration int64) { // Add to sleep queue. 
q := &sleepQueue - for ; *q != nil; q = &((*q).state()).next { - if t.state().data < (*q).state().data { + for ; *q != nil; q = &(*q).Next { + if t.Data < (*q).Data { // this will finish earlier than the next - insert here break } else { // this will finish later - adjust delay - t.state().data -= (*q).state().data + t.Data -= (*q).Data } } if *q != nil { // cut delay time between this sleep task and the next - (*q).state().data -= t.state().data + (*q).Data -= t.Data } - t.state().next = *q + t.Next = *q *q = t } @@ -236,17 +123,16 @@ func scheduler() { // Add tasks that are done sleeping to the end of the runqueue so they // will be executed soon. - if sleepQueue != nil && now-sleepQueueBaseTime >= timeUnit(sleepQueue.state().data) { + if sleepQueue != nil && now-sleepQueueBaseTime >= timeUnit(sleepQueue.Data) { t := sleepQueue scheduleLogTask(" awake:", t) - state := t.state() - sleepQueueBaseTime += timeUnit(state.data) - sleepQueue = state.next - state.next = nil - runqueuePushBack(t) + sleepQueueBaseTime += timeUnit(t.Data) + sleepQueue = t.Next + t.Next = nil + runqueue.Push(t) } - t := runqueuePopFront() + t := runqueue.Pop() if t == nil { if sleepQueue == nil { // No more tasks to execute. @@ -256,11 +142,11 @@ func scheduler() { scheduleLog(" no tasks left!") return } - timeLeft := timeUnit(sleepQueue.state().data) - (now - sleepQueueBaseTime) + timeLeft := timeUnit(sleepQueue.Data) - (now - sleepQueueBaseTime) if schedulerDebug { println(" sleeping...", sleepQueue, uint(timeLeft)) - for t := sleepQueue; t != nil; t = t.state().next { - println(" task sleeping:", t, timeUnit(t.state().data)) + for t := sleepQueue; t != nil; t = t.Next { + println(" task sleeping:", t, timeUnit(t.Data)) } } sleepTicks(timeLeft) @@ -275,11 +161,11 @@ func scheduler() { // Run the given task. 
scheduleLogTask(" run:", t) - t.resume() + t.Resume() } } func Gosched() { - runqueuePushBack(getCoroutine()) - yield() + runqueue.Push(task.Current()) + task.Pause() } diff --git a/src/runtime/scheduler_any.go b/src/runtime/scheduler_any.go new file mode 100644 index 00000000..d541dc6a --- /dev/null +++ b/src/runtime/scheduler_any.go @@ -0,0 +1,12 @@ +// +build !scheduler.none + +package runtime + +import "internal/task" + +// Pause the current task for a given time. +//go:linkname sleep time.Sleep +func sleep(duration int64) { + addSleepTask(task.Current(), duration) + task.Pause() +} diff --git a/src/runtime/scheduler_coroutines.go b/src/runtime/scheduler_coroutines.go index 1d0996d9..5e871a64 100644 --- a/src/runtime/scheduler_coroutines.go +++ b/src/runtime/scheduler_coroutines.go @@ -2,103 +2,8 @@ package runtime -// This file implements the Go scheduler using coroutines. -// A goroutine contains a whole stack. A coroutine is just a single function. -// How do we use coroutines for goroutines, then? -// * Every function that contains a blocking call (like sleep) is marked -// blocking, and all it's parents (callers) are marked blocking as well -// transitively until the root (main.main or a go statement). -// * A blocking function that calls a non-blocking function is called as -// usual. -// * A blocking function that calls a blocking function passes its own -// coroutine handle as a parameter to the subroutine. When the subroutine -// returns, it will re-insert the parent into the scheduler. -// Note that we use the type 'task' to refer to a coroutine, for compatibility -// with the task-based scheduler. A task type here does not represent the whole -// task, but just the topmost coroutine. For most of the scheduler, this -// difference doesn't matter. -// -// For more background on coroutines in LLVM: -// https://llvm.org/docs/Coroutines.html - -import "unsafe" - -// A coroutine instance, wrapped here to provide some type safety. 
The value -// must not be used directly, it is meant to be used as an opaque *i8 in LLVM. -type task uint8 - -//go:export llvm.coro.resume -func (t *task) resume() - -//go:export llvm.coro.destroy -func (t *task) destroy() - -//go:export llvm.coro.done -func (t *task) done() bool - -//go:export llvm.coro.promise -func (t *task) _promise(alignment int32, from bool) unsafe.Pointer - -// Get the state belonging to a task. -func (t *task) state() *taskState { - return (*taskState)(t._promise(int32(unsafe.Alignof(taskState{})), false)) -} - -func makeGoroutine(uintptr) uintptr - -// Compiler stub to get the current goroutine. Calls to this function are -// removed in the goroutine lowering pass. -func getCoroutine() *task - -// setTaskStatePtr is a helper function to set the current .ptr field of a -// coroutine promise. -func setTaskStatePtr(t *task, value unsafe.Pointer) { - t.state().ptr = value -} - -// getTaskStatePtr is a helper function to get the current .ptr field from a -// coroutine promise. -func getTaskStatePtr(t *task) unsafe.Pointer { - if t == nil { - blockingPanic() - } - return t.state().ptr -} - -// yield suspends execution of the current goroutine -// any wakeups must be configured before calling yield -func yield() - // getSystemStackPointer returns the current stack pointer of the system stack. // This is always the current stack pointer. 
func getSystemStackPointer() uintptr { return getCurrentStackPointer() } - -func fakeCoroutine(dst **task) { - *dst = getCoroutine() - for { - yield() - } -} - -func getFakeCoroutine() *task { - // this isnt defined behavior, but this is what our implementation does - // this is really a horrible hack - var t *task - go fakeCoroutine(&t) - - // the first line of fakeCoroutine will have completed by now - return t -} - -// noret is a placeholder that can be used to indicate that an async function is not going to directly return here -func noret() - -func getParentHandle() *task - -func llvmCoroRefHolder() { - noret() - getParentHandle() - getCoroutine() -} diff --git a/src/runtime/scheduler_cortexm.S b/src/runtime/scheduler_cortexm.S index 6d83ba4f..121992e7 100644 --- a/src/runtime/scheduler_cortexm.S +++ b/src/runtime/scheduler_cortexm.S @@ -17,7 +17,7 @@ tinygo_startTask: blx r4 // After return, exit this goroutine. This is a tail call. - bl runtime.yield + bl tinygo_pause .section .text.tinygo_getSystemStackPointer .global tinygo_getSystemStackPointer @@ -35,24 +35,23 @@ tinygo_getSystemStackPointer: .global tinygo_switchToScheduler .type tinygo_switchToScheduler, %function tinygo_switchToScheduler: - // r0 = oldTask *task + // r0 = sp *uintptr // Currently on the task stack (SP=PSP). We need to store the position on // the stack where the in-use registers will be stored. mov r1, sp subs r1, #36 - str r1, [r0, #36] + str r1, [r0] b tinygo_swapTask .global tinygo_switchToTask .type tinygo_switchToTask, %function tinygo_switchToTask: - // r0 = newTask *task + // r0 = sp uintptr // Currently on the scheduler stack (SP=MSP). We'll have to update the PSP, // and then we can invoke swapTask. 
- ldr r0, [r0, #36] msr PSP, r0 // Continue executing in the swapTask function, which swaps the stack diff --git a/src/runtime/scheduler_none.go b/src/runtime/scheduler_none.go new file mode 100644 index 00000000..222867dd --- /dev/null +++ b/src/runtime/scheduler_none.go @@ -0,0 +1,14 @@ +// +build scheduler.none + +package runtime + +//go:linkname sleep time.Sleep +func sleep(duration int64) { + sleepTicks(timeUnit(duration / tickMicros)) +} + +// getSystemStackPointer returns the current stack pointer of the system stack. +// This is always the current stack pointer. +func getSystemStackPointer() uintptr { + return getCurrentStackPointer() +} diff --git a/src/runtime/scheduler_tasks.go b/src/runtime/scheduler_tasks.go index b2142892..9be63d5d 100644 --- a/src/runtime/scheduler_tasks.go +++ b/src/runtime/scheduler_tasks.go @@ -2,110 +2,6 @@ package runtime -import "unsafe" - -const stackSize = 1024 - -// Stack canary, to detect a stack overflow. The number is a random number -// generated by random.org. The bit fiddling dance is necessary because -// otherwise Go wouldn't allow the cast to a smaller integer size. -const stackCanary = uintptr(uint64(0x670c1333b83bf575) & uint64(^uintptr(0))) - -var ( - currentTask *task // currently running goroutine, or nil -) - -// This type points to the bottom of the goroutine stack and contains some state -// that must be kept with the task. The last field is a canary, which is -// necessary to make sure that no stack overflow occured when switching tasks. -type task struct { - // The order of fields in this structs must be kept in sync with assembly! - calleeSavedRegs - pc uintptr - sp uintptr - taskState - canaryPtr *uintptr // used to detect stack overflows -} - -// getCoroutine returns the currently executing goroutine. It is used as an -// intrinsic when compiling channel operations, but is not necessary with the -// task-based scheduler. 
-//go:inline -func getCoroutine() *task { - return currentTask -} - -// state is a small helper that returns the task state, and is provided for -// compatibility with the coroutine implementation. -//go:inline -func (t *task) state() *taskState { - return &t.taskState -} - -// resume is a small helper that resumes this task until this task switches back -// to the scheduler. -func (t *task) resume() { - currentTask = t - switchToTask(t) - currentTask = nil -} - -// switchToScheduler saves the current state on the stack, saves the current -// stack pointer in the task, and switches to the scheduler. It must only be -// called when actually running on this task. -// When it returns, the scheduler has switched back to this task (for example, -// after a blocking operation completed). -//export tinygo_switchToScheduler -func switchToScheduler(t *task) - -// switchToTask switches from the scheduler to the task. It must only be called -// from the scheduler. -// When this function returns, the task just yielded control back to the -// scheduler. -//export tinygo_switchToTask -func switchToTask(t *task) - -// startTask is a small wrapper function that sets up the first (and only) -// argument to the new goroutine and makes sure it is exited when the goroutine -// finishes. -//go:extern tinygo_startTask -var startTask [0]uint8 - -// startGoroutine starts a new goroutine with the given function pointer and -// argument. It creates a new goroutine stack, prepares it for execution, and -// adds it to the runqueue. -func startGoroutine(fn, args uintptr) { - stack := alloc(stackSize) - t := (*task)(unsafe.Pointer(uintptr(stack) + stackSize - unsafe.Sizeof(task{}))) - - // Set up the stack canary, a random number that should be checked when - // switching from the task back to the scheduler. The stack canary pointer - // points to the first word of the stack. If it has changed between now and - // the next stack switch, there was a stack overflow. 
- t.canaryPtr = (*uintptr)(unsafe.Pointer(stack)) - *t.canaryPtr = stackCanary - - // Store the initial sp/pc for the startTask function (implemented in - // assembly). - t.sp = uintptr(stack) + stackSize - unsafe.Sizeof(task{}) - t.pc = uintptr(unsafe.Pointer(&startTask)) - t.prepareStartTask(fn, args) - scheduleLogTask(" start goroutine:", t) - runqueuePushBack(t) -} - -// yield suspends execution of the current goroutine -// any wakeups must be configured before calling yield -//export runtime.yield -func yield() { - // Check whether the canary (the lowest address of the stack) is still - // valid. If it is not, a stack overflow has occured. - if *currentTask.canaryPtr != stackCanary { - runtimePanic("goroutine stack overflow") - } - switchToScheduler(currentTask) -} - // getSystemStackPointer returns the current stack pointer of the system stack. // This is not necessarily the same as the current stack pointer. //export tinygo_getSystemStackPointer diff --git a/targets/avr.json b/targets/avr.json index 00ab6378..d3f97a5b 100644 --- a/targets/avr.json +++ b/targets/avr.json @@ -5,6 +5,7 @@ "compiler": "avr-gcc", "gc": "conservative", "linker": "avr-gcc", + "scheduler": "none", "ldflags": [ "-T", "targets/avr.ld", "-Wl,--gc-sections" diff --git a/testdata/coroutines.go b/testdata/coroutines.go index 62d62ab7..77e14d0e 100644 --- a/testdata/coroutines.go +++ b/testdata/coroutines.go @@ -47,7 +47,7 @@ func main() { time.Sleep(2 * time.Millisecond) x = 1 }() - time.Sleep(time.Second/2) + time.Sleep(time.Second / 2) println("closure go call result:", x) time.Sleep(2 * time.Millisecond) diff --git a/transform/coroutines.go b/transform/coroutines.go new file mode 100644 index 00000000..5df2d746 --- /dev/null +++ b/transform/coroutines.go @@ -0,0 +1,929 @@ +package transform + +// This file lowers asynchronous functions and goroutine starts when using the coroutines scheduler. 
+// This is accomplished by inserting LLVM intrinsics which are used in order to save the states of functions. + +import ( + "errors" + "github.com/tinygo-org/tinygo/compiler/llvmutil" + "strconv" + "tinygo.org/x/go-llvm" +) + +// LowerCoroutines turns async functions into coroutines. +// This must be run with the coroutines scheduler. +// +// Before this pass, goroutine starts are expressed as a call to an intrinsic called "internal/task.start". +// This intrinsic accepts the function pointer and a pointer to a struct containing the function's arguments. +// +// Before this pass, an intrinsic called "internal/task.Pause" is used to express suspensions of the current goroutine. +// +// This pass first accumulates a list of blocking functions. +// A function is considered "blocking" if it calls "internal/task.Pause" or any other blocking function. +// +// Blocking calls are implemented by turning blocking functions into a coroutine. +// The body of each blocking function is modified to start a new coroutine, and to return after the first suspend. +// After calling a blocking function, the caller coroutine suspends. +// The caller also provides a buffer to store the return value into. +// When a blocking function returns, the return value is written into this buffer and then the caller is queued to run. +// +// Goroutine starts which invoke non-blocking functions are implemented as direct calls. +// Goroutine starts are replaced with the creation of a new task data structure followed by a call to the start of the blocking function. +// The task structure is populated with a "noop" coroutine before invoking the blocking function. +// When the blocking function returns, it resumes this "noop" coroutine which does nothing. +// The goroutine starter is able to continue after the first suspend of the started goroutine. +// +// The transformation of a function to a coroutine is accomplished using LLVM's coroutines system (https://llvm.org/docs/Coroutines.html). 
+// The simplest implementation of a coroutine inserts a suspend point after every blocking call. +// +// Transforming blocking functions into coroutines and calls into suspend points is extremely expensive. +// In many cases, a blocking call is followed immediately by a function terminator (a return or an "unreachable" instruction). +// This is a blocking "tail call". +// In a non-returning tail call (call to a non-returning function, such as an infinite loop), the coroutine can exit without any extra work. +// In a returning tail call, the returned value must either be the return of the call or a value known before the call. +// If the return value of the caller is the return of the callee, the coroutine can exit without any extra work and the tailed call will instead return to the caller of the caller. +// If the return value is known in advance, this result can be stored into the parent's return buffer before the call so that a suspend is unnecessary. +// If the callee returns an unnecessary value, a return buffer can be allocated on the heap so that it will outlive the coroutine. +// +// In the implementation of time.Sleep, the current task is pushed onto a timer queue and then suspended. +// Since the only suspend point is a call to "internal/task.Pause" followed by a return, there is no need to transform this into a coroutine. +// This generalizes to all blocking functions in which all suspend points can be elided. +// This optimization saves a substantial amount of binary size. +func LowerCoroutines(mod llvm.Module, needStackSlots bool) error { + ctx := mod.Context() + + builder := ctx.NewBuilder() + defer builder.Dispose() + + target := llvm.NewTargetData(mod.DataLayout()) + defer target.Dispose() + + pass := &coroutineLoweringPass{ + mod: mod, + ctx: ctx, + builder: builder, + target: target, + } + + err := pass.load() + if err != nil { + return err + } + + // Supply task operands to async calls. + pass.supplyTaskOperands() + + // Analyze async returns. 
+	pass.returnAnalysisPass()
+
+	// Categorize async calls.
+	pass.categorizeCalls()
+
+	// Lower async functions.
+	pass.lowerFuncsPass()
+
+	// Lower calls to internal/task.Current.
+	pass.lowerCurrent()
+
+	// Lower goroutine starts.
+	pass.lowerStartsPass()
+
+	// Fix annotations on async call params.
+	pass.fixAnnotations()
+
+	if needStackSlots {
+		// Set up garbage collector tracking of tasks at start.
+		err = pass.trackGoroutines()
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// asyncFunc is a metadata container for an asynchronous function.
+type asyncFunc struct {
+	// fn is the underlying function pointer.
+	fn llvm.Value
+
+	// rawTask is the parameter where the task pointer is passed in.
+	rawTask llvm.Value
+
+	// callers is a set of all functions which call this async function.
+	callers map[llvm.Value]struct{}
+
+	// returns is a map of terminal basic blocks to their return kinds.
+	returns map[llvm.BasicBlock]returnKind
+
+	// calls is the set of all calls in the asyncFunc.
+	// normalCalls is the set of all intermediate suspending calls in the asyncFunc.
+	// tailCalls is the set of all tail calls in the asyncFunc.
+	calls, normalCalls, tailCalls map[llvm.Value]struct{}
+}
+
+// coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler.
+type coroutineLoweringPass struct {
+	mod     llvm.Module
+	ctx     llvm.Context
+	builder llvm.Builder
+	target  llvm.TargetData
+
+	// asyncFuncs is a map of all asyncFuncs.
+	// The map keys are function pointers.
+	asyncFuncs map[llvm.Value]*asyncFunc
+
+	// calls is a slice of all of the async calls in the module.
+ calls []llvm.Value + + i8ptr llvm.Type + + // memory management functions from the runtime + alloc, free llvm.Value + + // coroutine intrinsics + start, pause, current llvm.Value + setState, setRetPtr, getRetPtr, returnTo, returnCurrent llvm.Value + createTask llvm.Value + + // llvm.coro intrinsics + coroId, coroSize, coroBegin, coroSuspend, coroEnd, coroFree, coroSave llvm.Value +} + +// findAsyncFuncs finds all asynchronous functions. +// A function is considered asynchronous if it calls an asynchronous function or intrinsic. +func (c *coroutineLoweringPass) findAsyncFuncs() { + asyncs := map[llvm.Value]*asyncFunc{} + calls := []llvm.Value{} + + // Use a breadth-first search to find all async functions. + worklist := []llvm.Value{c.pause} + for len(worklist) > 0 { + // Pop a function off the worklist. + fn := worklist[len(worklist)-1] + worklist = worklist[:len(worklist)-1] + + // Get task pointer argument. + task := fn.LastParam() + if fn != c.pause && (task.IsNil() || task.Name() != "parentHandle") { + panic("trying to make exported function async: " + fn.Name()) + } + + // Search all uses of the function while collecting callers. + callers := map[llvm.Value]struct{}{} + for use := fn.FirstUse(); !use.IsNil(); use = use.NextUse() { + user := use.User() + if user.IsACallInst().IsNil() { + // User is not a call instruction, so this is irrelevant. + continue + } + if user.CalledValue() != fn { + // Not the called value. + continue + } + + // Add to calls list. + calls = append(calls, user) + + // Get the caller. + caller := user.InstructionParent().Parent() + + // Add as caller. + callers[caller] = struct{}{} + + if _, ok := asyncs[caller]; ok { + // Already marked caller as async. + continue + } + + // Mark the caller as async. + // Use nil as a temporary value. It will be replaced later. + asyncs[caller] = nil + + // Put the caller on the worklist. 
+ worklist = append(worklist, caller) + } + + asyncs[fn] = &asyncFunc{ + fn: fn, + rawTask: task, + callers: callers, + } + } + + c.asyncFuncs = asyncs + c.calls = calls +} + +func (c *coroutineLoweringPass) load() error { + // Find memory management functions from the runtime. + c.alloc = c.mod.NamedFunction("runtime.alloc") + if c.alloc.IsNil() { + return ErrMissingIntrinsic{"runtime.alloc"} + } + c.free = c.mod.NamedFunction("runtime.free") + if c.free.IsNil() { + return ErrMissingIntrinsic{"runtime.free"} + } + + // Find intrinsics. + c.pause = c.mod.NamedFunction("internal/task.Pause") + if c.pause.IsNil() { + return ErrMissingIntrinsic{"internal/task.Pause"} + } + c.start = c.mod.NamedFunction("internal/task.start") + if c.start.IsNil() { + return ErrMissingIntrinsic{"internal/task.start"} + } + c.current = c.mod.NamedFunction("internal/task.Current") + if c.current.IsNil() { + return ErrMissingIntrinsic{"internal/task.Current"} + } + c.setState = c.mod.NamedFunction("(*internal/task.Task).setState") + if c.setState.IsNil() { + return ErrMissingIntrinsic{"(*internal/task.Task).setState"} + } + c.setRetPtr = c.mod.NamedFunction("(*internal/task.Task).setReturnPtr") + if c.setRetPtr.IsNil() { + return ErrMissingIntrinsic{"(*internal/task.Task).setReturnPtr"} + } + c.getRetPtr = c.mod.NamedFunction("(*internal/task.Task).getReturnPtr") + if c.getRetPtr.IsNil() { + return ErrMissingIntrinsic{"(*internal/task.Task).getReturnPtr"} + } + c.returnTo = c.mod.NamedFunction("(*internal/task.Task).returnTo") + if c.returnTo.IsNil() { + return ErrMissingIntrinsic{"(*internal/task.Task).returnTo"} + } + c.returnCurrent = c.mod.NamedFunction("(*internal/task.Task).returnCurrent") + if c.returnCurrent.IsNil() { + return ErrMissingIntrinsic{"(*internal/task.Task).returnCurrent"} + } + c.createTask = c.mod.NamedFunction("internal/task.createTask") + if c.createTask.IsNil() { + return ErrMissingIntrinsic{"internal/task.createTask"} + } + + // Find async functions. 
+ c.findAsyncFuncs() + + // Get i8* type. + c.i8ptr = llvm.PointerType(c.ctx.Int8Type(), 0) + + // Build LLVM coroutine intrinsic. + coroIdType := llvm.FunctionType(c.ctx.TokenType(), []llvm.Type{c.ctx.Int32Type(), c.i8ptr, c.i8ptr, c.i8ptr}, false) + c.coroId = llvm.AddFunction(c.mod, "llvm.coro.id", coroIdType) + + sizeT := c.alloc.Param(0).Type() + coroSizeType := llvm.FunctionType(sizeT, nil, false) + c.coroSize = llvm.AddFunction(c.mod, "llvm.coro.size.i"+strconv.Itoa(sizeT.IntTypeWidth()), coroSizeType) + + coroBeginType := llvm.FunctionType(c.i8ptr, []llvm.Type{c.ctx.TokenType(), c.i8ptr}, false) + c.coroBegin = llvm.AddFunction(c.mod, "llvm.coro.begin", coroBeginType) + + coroSuspendType := llvm.FunctionType(c.ctx.Int8Type(), []llvm.Type{c.ctx.TokenType(), c.ctx.Int1Type()}, false) + c.coroSuspend = llvm.AddFunction(c.mod, "llvm.coro.suspend", coroSuspendType) + + coroEndType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptr, c.ctx.Int1Type()}, false) + c.coroEnd = llvm.AddFunction(c.mod, "llvm.coro.end", coroEndType) + + coroFreeType := llvm.FunctionType(c.i8ptr, []llvm.Type{c.ctx.TokenType(), c.i8ptr}, false) + c.coroFree = llvm.AddFunction(c.mod, "llvm.coro.free", coroFreeType) + + coroSaveType := llvm.FunctionType(c.ctx.TokenType(), []llvm.Type{c.i8ptr}, false) + c.coroSave = llvm.AddFunction(c.mod, "llvm.coro.save", coroSaveType) + + return nil +} + +// lowerStartSync lowers a goroutine start of a synchronous function to a synchronous call. +func (c *coroutineLoweringPass) lowerStartSync(start llvm.Value) { + c.builder.SetInsertPointBefore(start) + + // Get function to call. + fn := start.Operand(0).Operand(0) + + // Create the list of params for the call. + paramTypes := fn.Type().ElementType().ParamTypes() + params := llvmutil.EmitPointerUnpack(c.builder, c.mod, start.Operand(1), paramTypes[:len(paramTypes)-1]) + params = append(params, llvm.Undef(c.i8ptr)) + + // Generate call to function. 
+	c.builder.CreateCall(fn, params, "")
+
+	// Remove start call.
+	start.EraseFromParentAsInstruction()
+}
+
+// supplyTaskOperands fills in the task operands of async calls.
+func (c *coroutineLoweringPass) supplyTaskOperands() {
+	var curCalls []llvm.Value
+	for use := c.current.FirstUse(); !use.IsNil(); use = use.NextUse() {
+		curCalls = append(curCalls, use.User())
+	}
+	for _, call := range append(curCalls, c.calls...) {
+		c.builder.SetInsertPointBefore(call)
+		task := c.asyncFuncs[call.InstructionParent().Parent()].rawTask
+		call.SetOperand(call.OperandsCount()-2, task)
+	}
+}
+
+// returnKind is a classification of a type of function terminator.
+type returnKind uint8
+
+const (
+	// returnNormal is a terminator that returns a value normally from a function.
+	returnNormal returnKind = iota
+
+	// returnVoid is a terminator that exits normally without returning a value.
+	returnVoid
+
+	// returnVoidTail is a terminator which is a tail call to a void-returning function in a void-returning function.
+	returnVoidTail
+
+	// returnTail is a terminator which is a tail call to a value-returning function where the value is returned by the callee.
+	returnTail
+
+	// returnDeadTail is a terminator which is a call to a non-returning asynchronous function.
+	returnDeadTail
+
+	// returnAlternateTail is a terminator which is a tail call to a value-returning function where a previously acquired value is returned by the callee.
+	returnAlternateTail
+
+	// returnDitchedTail is a terminator which is a tail call to a value-returning function, where the callee returns void.
+	returnDitchedTail
+
+	// returnDelayedValue is a terminator in which a void-returning tail call is followed by a return of a previous value.
+	returnDelayedValue
+)
+
+// isAsyncCall returns whether the specified call is async.
+func (c *coroutineLoweringPass) isAsyncCall(call llvm.Value) bool { + _, ok := c.asyncFuncs[call.CalledValue()] + return ok +} + +// analyzeFuncReturns analyzes and classifies the returns of a function. +func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) { + returns := map[llvm.BasicBlock]returnKind{} + if fn.fn == c.pause { + // Skip pause. + fn.returns = returns + return + } + + for _, bb := range fn.fn.BasicBlocks() { + last := bb.LastInstruction() + switch last.InstructionOpcode() { + case llvm.Ret: + // Check if it is a void return. + isVoid := fn.fn.Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind + + // Analyze previous instruction. + prev := llvm.PrevInstruction(last) + switch { + case prev.IsNil(): + fallthrough + case prev.IsACallInst().IsNil(): + fallthrough + case !c.isAsyncCall(prev): + // This is not any form of asynchronous tail call. + if isVoid { + returns[bb] = returnVoid + } else { + returns[bb] = returnNormal + } + case isVoid: + if prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind { + // This is a tail call to a void-returning function from a function with a void return. + returns[bb] = returnVoidTail + } else { + // This is a tail call to a value-returning function from a function with a void return. + // The returned value will be ditched. + returns[bb] = returnDitchedTail + } + case last.Operand(0) == prev: + // This is a regular tail call. The return of the callee is returned to the parent. + returns[bb] = returnTail + case prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind: + // This is a tail call that returns a previous value after waiting on a void function. + returns[bb] = returnDelayedValue + default: + // This is a tail call that returns a value that is available before the function call. 
+				returns[bb] = returnAlternateTail
+			}
+		case llvm.Unreachable:
+			prev := llvm.PrevInstruction(last)
+
+			if prev.IsNil() || prev.IsACallInst().IsNil() || !c.isAsyncCall(prev) {
+				// This unreachable instruction does not behave as an asynchronous return.
+				continue
+			}
+
+			// This is an asynchronous tail call to a function that does not return.
+			returns[bb] = returnDeadTail
+		}
+	}
+
+	fn.returns = returns
+}
+
+// returnAnalysisPass runs an analysis pass which classifies the returns of all async functions.
+func (c *coroutineLoweringPass) returnAnalysisPass() {
+	for _, async := range c.asyncFuncs {
+		c.analyzeFuncReturns(async)
+	}
+}
+
+// categorizeCalls categorizes all asynchronous calls into regular vs. async and matches them to their callers.
+func (c *coroutineLoweringPass) categorizeCalls() {
+	// Sort calls into their respective callers.
+	for _, async := range c.asyncFuncs {
+		async.calls = map[llvm.Value]struct{}{}
+	}
+	for _, call := range c.calls {
+		c.asyncFuncs[call.InstructionParent().Parent()].calls[call] = struct{}{}
+	}
+
+	// Separate regular and tail calls.
+	for _, async := range c.asyncFuncs {
+		// Find all tail calls (of any kind).
+		tails := map[llvm.Value]struct{}{}
+		for ret, kind := range async.returns {
+			switch kind {
+			case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
+				// This is a tail return. The previous instruction is a tail call.
+				tails[llvm.PrevInstruction(ret.LastInstruction())] = struct{}{}
+			}
+		}
+
+		// Find all regular calls.
+		regulars := map[llvm.Value]struct{}{}
+		for call := range async.calls {
+			if _, ok := tails[call]; ok {
+				// This is a tail call.
+				continue
+			}
+
+			regulars[call] = struct{}{}
+		}
+
+		async.tailCalls = tails
+		async.normalCalls = regulars
+	}
+}
+
+// lowerFuncsPass lowers all functions, turning them into coroutines if necessary.
+func (c *coroutineLoweringPass) lowerFuncsPass() { + for _, fn := range c.asyncFuncs { + if fn.fn == c.pause { + // Skip. It is an intrinsic. + continue + } + + if len(fn.normalCalls) == 0 { + // No suspend points. Lower without turning it into a coroutine. + c.lowerFuncFast(fn) + } else { + // There are suspend points, so it is necessary to turn this into a coroutine. + c.lowerFuncCoro(fn) + } + } +} + +func (async *asyncFunc) hasValueStoreReturn() bool { + for _, kind := range async.returns { + switch kind { + case returnNormal, returnAlternateTail, returnDelayedValue: + return true + } + } + + return false +} + +// heapAlloc creates a heap allocation large enough to hold the supplied type. +// The allocation is returned as a raw i8* pointer. +// This allocation is not automatically tracked by the garbage collector, and should thus be stored into a tracked memory object immediately. +func (c *coroutineLoweringPass) heapAlloc(t llvm.Type, name string) llvm.Value { + sizeT := c.alloc.FirstParam().Type() + size := llvm.ConstInt(sizeT, c.target.TypeAllocSize(t), false) + return c.builder.CreateCall(c.alloc, []llvm.Value{size, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, name) +} + +// lowerFuncFast lowers an async function that has no suspend points. +func (c *coroutineLoweringPass) lowerFuncFast(fn *asyncFunc) { + // Get return type. + retType := fn.fn.Type().ElementType().ReturnType() + + // Get task value. + c.insertPointAfterAllocas(fn.fn) + task := c.builder.CreateCall(c.current, []llvm.Value{llvm.Undef(c.i8ptr), fn.rawTask}, "task") + + // Get return pointer if applicable. + var rawRetPtr, retPtr llvm.Value + if fn.hasValueStoreReturn() { + rawRetPtr = c.builder.CreateCall(c.getRetPtr, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "ret.ptr") + retType = fn.fn.Type().ElementType().ReturnType() + retPtr = c.builder.CreateBitCast(rawRetPtr, llvm.PointerType(retType, 0), "ret.ptr.bitcast") + } + + // Lower returns. 
+ for ret, kind := range fn.returns { + // Get terminator. + terminator := ret.LastInstruction() + + // Get tail call if applicable. + var call llvm.Value + switch kind { + case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue: + call = llvm.PrevInstruction(terminator) + } + + switch kind { + case returnNormal: + c.builder.SetInsertPointBefore(terminator) + + // Store value into return pointer. + c.builder.CreateStore(terminator.Operand(0), retPtr) + + // Resume caller. + c.builder.CreateCall(c.returnCurrent, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + + // Erase return argument. + terminator.SetOperand(0, llvm.Undef(retType)) + case returnVoid: + c.builder.SetInsertPointBefore(terminator) + + // Resume caller. + c.builder.CreateCall(c.returnCurrent, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + case returnVoidTail: + // Nothing to do. There is already a tail call followed by a void return. + case returnTail: + // Erase return argument. + terminator.SetOperand(0, llvm.Undef(retType)) + case returnDeadTail: + // Replace unreachable with immediate return, without resuming the caller. + c.builder.SetInsertPointBefore(terminator) + if retType.TypeKind() == llvm.VoidTypeKind { + c.builder.CreateRetVoid() + } else { + c.builder.CreateRet(llvm.Undef(retType)) + } + terminator.EraseFromParentAsInstruction() + case returnAlternateTail: + c.builder.SetInsertPointBefore(call) + + // Store return value. + c.builder.CreateStore(terminator.Operand(0), retPtr) + + // Heap-allocate a return buffer for the discarded return. + alternateBuf := c.heapAlloc(call.Type(), "ret.alternate") + c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, alternateBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + + // Erase return argument. 
+ terminator.SetOperand(0, llvm.Undef(retType)) + case returnDitchedTail: + c.builder.SetInsertPointBefore(call) + + // Heap-allocate a return buffer for the discarded return. + ditchBuf := c.heapAlloc(call.Type(), "ret.ditch") + c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, ditchBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + case returnDelayedValue: + c.builder.SetInsertPointBefore(call) + + // Store value into return pointer. + c.builder.CreateStore(terminator.Operand(0), retPtr) + + // Erase return argument. + terminator.SetOperand(0, llvm.Undef(retType)) + } + + // Delete call if it is a pause, because it has already been lowered. + if !call.IsNil() && call.CalledValue() == c.pause { + call.EraseFromParentAsInstruction() + } + } +} + +// insertPointAfterAllocas sets the insert point of the builder to be immediately after the last alloca in the entry block. +func (c *coroutineLoweringPass) insertPointAfterAllocas(fn llvm.Value) { + inst := fn.EntryBasicBlock().FirstInstruction() + for !inst.IsAAllocaInst().IsNil() { + inst = llvm.NextInstruction(inst) + } + c.builder.SetInsertPointBefore(inst) +} + +// lowerCallReturn lowers the return value of an async call by creating a return buffer and loading the returned value from it. +func (c *coroutineLoweringPass) lowerCallReturn(caller *asyncFunc, call llvm.Value) { + // Get return type. + retType := call.Type() + if retType.TypeKind() == llvm.VoidTypeKind { + // Void return. Nothing to do. + return + } + + // Create alloca for return buffer. + alloca := llvmutil.CreateInstructionAlloca(c.builder, c.mod, retType, call, "call.return") + + // Store new return buffer into task before call. 
+ c.builder.SetInsertPointBefore(call) + task := c.builder.CreateCall(c.current, []llvm.Value{llvm.Undef(c.i8ptr), caller.rawTask}, "call.task") + retPtr := c.builder.CreateBitCast(alloca, c.i8ptr, "call.return.bitcast") + c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, retPtr, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + + // Load return value after call. + c.builder.SetInsertPointBefore(llvm.NextInstruction(call)) + ret := c.builder.CreateLoad(alloca, "call.return.load") + + // Replace call value with loaded return. + call.ReplaceAllUsesWith(ret) +} + +// lowerFuncCoro transforms an async function into a coroutine by lowering async operations to `llvm.coro` intrinsics. +// See https://llvm.org/docs/Coroutines.html for more information on these intrinsics. +func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) { + returnType := fn.fn.Type().ElementType().ReturnType() + + // Prepare coroutine state. + c.insertPointAfterAllocas(fn.fn) + // %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + coroId := c.builder.CreateCall(c.coroId, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstNull(c.i8ptr), + llvm.ConstNull(c.i8ptr), + llvm.ConstNull(c.i8ptr), + }, "coro.id") + // %coro.size = call i32 @llvm.coro.size.i32() + coroSize := c.builder.CreateCall(c.coroSize, []llvm.Value{}, "coro.size") + // %coro.alloc = call i8* runtime.alloc(i32 %coro.size) + coroAlloc := c.builder.CreateCall(c.alloc, []llvm.Value{coroSize, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "coro.alloc") + // %coro.state = call noalias i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc) + coroState := c.builder.CreateCall(c.coroBegin, []llvm.Value{coroId, coroAlloc}, "coro.state") + // Store state into task. 
+ task := c.builder.CreateCall(c.current, []llvm.Value{llvm.Undef(c.i8ptr), fn.rawTask}, "task") + parentState := c.builder.CreateCall(c.setState, []llvm.Value{task, coroState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "task.state.parent") + // Get return pointer if needed. + var retPtr llvm.Value + if fn.hasValueStoreReturn() { + retPtr = c.builder.CreateCall(c.getRetPtr, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "task.retPtr") + retPtr = c.builder.CreateBitCast(retPtr, llvm.PointerType(fn.fn.Type().ElementType().ReturnType(), 0), "task.retPtr.bitcast") + } + + // Build suspend block. + // This is executed when the coroutine is about to suspend. + suspend := c.ctx.AddBasicBlock(fn.fn, "suspend") + c.builder.SetInsertPointAtEnd(suspend) + // %unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false) + c.builder.CreateCall(c.coroEnd, []llvm.Value{coroState, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "unused") + // Insert return. + if returnType.TypeKind() == llvm.VoidTypeKind { + c.builder.CreateRetVoid() + } else { + c.builder.CreateRet(llvm.Undef(returnType)) + } + + // Build cleanup block. + // This is executed before the function returns in order to clean up resources. + cleanup := c.ctx.AddBasicBlock(fn.fn, "cleanup") + c.builder.SetInsertPointAtEnd(cleanup) + // %coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state) + coroMemFree := c.builder.CreateCall(c.coroFree, []llvm.Value{coroId, coroState}, "coro.memFree") + // call i8* runtime.free(i8* %coro.memFree) + c.builder.CreateCall(c.free, []llvm.Value{coroMemFree, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + // Branch to suspend block. + c.builder.CreateBr(suspend) + + // Restore old state before tail calls. + for call := range fn.tailCalls { + if fn.returns[call.InstructionParent()] == returnDeadTail { + // Callee never returns, so the state restore is ineffectual. 
+ continue + } + + c.builder.SetInsertPointBefore(call) + c.builder.CreateCall(c.setState, []llvm.Value{task, parentState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "coro.state.restore") + } + + // Lower returns. + for ret, kind := range fn.returns { + // Get terminator instruction. + terminator := ret.LastInstruction() + + // Get tail call if applicable. + var call llvm.Value + switch kind { + case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue: + call = llvm.PrevInstruction(terminator) + } + + switch kind { + case returnNormal: + c.builder.SetInsertPointBefore(terminator) + + // Store value into return pointer. + c.builder.CreateStore(terminator.Operand(0), retPtr) + + // Resume caller. + c.builder.CreateCall(c.returnTo, []llvm.Value{task, parentState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + case returnVoid: + c.builder.SetInsertPointBefore(terminator) + + // Resume caller. + c.builder.CreateCall(c.returnTo, []llvm.Value{task, parentState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + case returnVoidTail, returnTail, returnDeadTail: + // Nothing to do. + case returnAlternateTail: + c.builder.SetInsertPointBefore(call) + + // Store return value. + c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr) + + // Heap-allocate a return buffer for the discarded return. + alternateBuf := c.heapAlloc(call.Type(), "ret.alternate") + c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, alternateBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + case returnDitchedTail: + c.builder.SetInsertPointBefore(call) + + // Heap-allocate a return buffer for the discarded return. + ditchBuf := c.heapAlloc(call.Type(), "ret.ditch") + c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, ditchBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + case returnDelayedValue: + c.builder.SetInsertPointBefore(call) + + // Store return value. 
+ c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr) + } + + // Delete call if it is a pause, because it has already been lowered. + if !call.IsNil() && call.CalledValue() == c.pause { + call.EraseFromParentAsInstruction() + } + + // Replace terminator with branch to cleanup. + terminator.EraseFromParentAsInstruction() + c.builder.SetInsertPointAtEnd(ret) + c.builder.CreateBr(cleanup) + } + + // Lower regular calls. + for call := range fn.normalCalls { + // Lower return value of call. + c.lowerCallReturn(fn, call) + + // Get originating basic block. + bb := call.InstructionParent() + + // Split block. + wakeup := llvmutil.SplitBasicBlock(c.builder, call, llvm.NextBasicBlock(bb), "wakeup") + + // Insert suspension and switch. + c.builder.SetInsertPointAtEnd(bb) + // %coro.save = call token @llvm.coro.save(i8* %coro.state) + save := c.builder.CreateCall(c.coroSave, []llvm.Value{coroState}, "coro.save") + // %call.suspend = llvm.coro.suspend(token %coro.save, i1 false) + // switch i8 %call.suspend, label %suspend [i8 0, label %wakeup + // i8 1, label %cleanup] + suspendValue := c.builder.CreateCall(c.coroSuspend, []llvm.Value{save, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "call.suspend") + sw := c.builder.CreateSwitch(suspendValue, suspend, 2) + sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), wakeup) + sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), cleanup) + + // Delete call if it is a pause, because it has already been lowered. + if call.CalledValue() == c.pause { + call.EraseFromParentAsInstruction() + } + } +} + +// lowerCurrent lowers calls to internal/task.Current to bitcasts. +func (c *coroutineLoweringPass) lowerCurrent() error { + taskType := c.current.Type().ElementType().ReturnType() + deleteQueue := []llvm.Value{} + for use := c.current.FirstUse(); !use.IsNil(); use = use.NextUse() { + // Get user. 
+ user := use.User() + + if user.IsACallInst().IsNil() || user.CalledValue() != c.current { + return errorAt(user, "unexpected non-call use of task.Current") + } + + // Replace with bitcast. + c.builder.SetInsertPointBefore(user) + raw := user.Operand(1) + if !raw.IsAUndefValue().IsNil() || raw.IsNull() { + return errors.New("undefined task") + } + task := c.builder.CreateBitCast(raw, taskType, "task.current") + user.ReplaceAllUsesWith(task) + deleteQueue = append(deleteQueue, user) + } + + // Delete calls. + for _, inst := range deleteQueue { + inst.EraseFromParentAsInstruction() + } + + return nil +} + +// lowerStart lowers a goroutine start into a task creation and call or a synchronous call. +func (c *coroutineLoweringPass) lowerStart(start llvm.Value) { + c.builder.SetInsertPointBefore(start) + + // Get function to call. + fn := start.Operand(0).Operand(0) + + if _, ok := c.asyncFuncs[fn]; !ok { + // Turn into synchronous call. + c.lowerStartSync(start) + return + } + + // Create the list of params for the call. + paramTypes := fn.Type().ElementType().ParamTypes() + params := llvmutil.EmitPointerUnpack(c.builder, c.mod, start.Operand(1), paramTypes[:len(paramTypes)-1]) + + // Create task. + task := c.builder.CreateCall(c.createTask, []llvm.Value{llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "start.task") + rawTask := c.builder.CreateBitCast(task, c.i8ptr, "start.task.bitcast") + params = append(params, rawTask) + + // Generate a return buffer if necessary. + returnType := fn.Type().ElementType().ReturnType() + if returnType.TypeKind() == llvm.VoidTypeKind { + // No return buffer necessary for a void return. + } else { + // Check for any undead returns. + var undead bool + for _, kind := range c.asyncFuncs[fn].returns { + if kind != returnDeadTail { + // This return results in a value being eventually stored. + undead = true + break + } + } + if undead { + // The function stores a value into a return buffer, so we need to create one. 
+ retBuf := c.heapAlloc(returnType, "ret.ditch") + c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, retBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + } + } + + // Generate call to function. + c.builder.CreateCall(fn, params, "") + + // Erase start call. + start.EraseFromParentAsInstruction() +} + +// lowerStartsPass lowers all goroutine starts. +func (c *coroutineLoweringPass) lowerStartsPass() { + starts := []llvm.Value{} + for use := c.start.FirstUse(); !use.IsNil(); use = use.NextUse() { + starts = append(starts, use.User()) + } + for _, start := range starts { + c.lowerStart(start) + } +} + +func (c *coroutineLoweringPass) fixAnnotations() { + for f := range c.asyncFuncs { + // These properties were added by the functionattrs pass. Remove + // them, because now we start using the parameter. + // https://llvm.org/docs/Passes.html#functionattrs-deduce-function-attributes + for _, kind := range []string{"nocapture", "readnone"} { + kindID := llvm.AttributeKindID(kind) + n := f.ParamsCount() + for i := 0; i <= n; i++ { + f.RemoveEnumAttributeAtIndex(i, kindID) + } + } + } +} + +// trackGoroutines adds runtime.trackPointer calls to track goroutine starts and data. 
+func (c *coroutineLoweringPass) trackGoroutines() error { + trackPointer := c.mod.NamedFunction("runtime.trackPointer") + if trackPointer.IsNil() { + return ErrMissingIntrinsic{"runtime.trackPointer"} + } + + trackFunctions := []llvm.Value{c.createTask, c.setState, c.getRetPtr} + for _, fn := range trackFunctions { + for use := fn.FirstUse(); !use.IsNil(); use = use.NextUse() { + call := use.User() + + c.builder.SetInsertPointBefore(llvm.NextInstruction(call)) + ptr := call + if ptr.Type() != c.i8ptr { + ptr = c.builder.CreateBitCast(call, c.i8ptr, "") + } + c.builder.CreateCall(trackPointer, []llvm.Value{ptr, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "") + } + } + + return nil +} diff --git a/transform/errors.go b/transform/errors.go index 226ead28..73ed0970 100644 --- a/transform/errors.go +++ b/transform/errors.go @@ -46,3 +46,12 @@ func getPosition(val llvm.Value) token.Position { return token.Position{} } } + +// ErrMissingIntrinsic is an error indicating that a required intrinsic was not found in the module. +type ErrMissingIntrinsic struct { + Name string +} + +func (err ErrMissingIntrinsic) Error() string { + return "missing intrinsic: " + err.Name +} diff --git a/transform/func-lowering.go b/transform/func-lowering.go index fa6eb0e4..9687cc76 100644 --- a/transform/func-lowering.go +++ b/transform/func-lowering.go @@ -181,29 +181,15 @@ func LowerFuncValues(mod llvm.Module) { // Remove some casts, checks, and the old call which we're going // to replace. for _, callIntPtr := range getUses(getFuncPtrCall) { - if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "runtime.makeGoroutine" { - // Special case for runtime.makeGoroutine. 
- for _, inttoptr := range getUses(callIntPtr) { - if inttoptr.IsAIntToPtrInst().IsNil() { - panic("expected a inttoptr") - } - for _, use := range getUses(inttoptr) { - addFuncLoweringSwitch(mod, builder, funcID, use, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { - // The function lowering switch code passes in a parent handle value. - // Set the parent handle to null here because it is irrelevant to goroutine starts. - i8ptrType := llvm.PointerType(ctx.Int8Type(), 0) - params[len(params)-1] = llvm.ConstPointerNull(i8ptrType) - calleeValue := builder.CreatePtrToInt(funcPtr, uintptrType, "") - makeGoroutine := mod.NamedFunction("runtime.makeGoroutine") - calleeValue = builder.CreateCall(makeGoroutine, []llvm.Value{calleeValue, llvm.Undef(i8ptrType), llvm.ConstNull(i8ptrType)}, "") - calleeValue = builder.CreateIntToPtr(calleeValue, funcPtr.Type(), "") - builder.CreateCall(calleeValue, params, "") - return llvm.Value{} // void so no return value - }, functions) - use.EraseFromParentAsInstruction() - } - inttoptr.EraseFromParentAsInstruction() - } + if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "internal/task.start" { + // Special case for goroutine starts. 
+ addFuncLoweringSwitch(mod, builder, funcID, callIntPtr, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { + i8ptrType := llvm.PointerType(ctx.Int8Type(), 0) + calleeValue := builder.CreatePtrToInt(funcPtr, uintptrType, "") + start := mod.NamedFunction("internal/task.start") + builder.CreateCall(start, []llvm.Value{calleeValue, callIntPtr.Operand(1), llvm.Undef(i8ptrType), llvm.ConstNull(i8ptrType)}, "") + return llvm.Value{} // void so no return value + }, functions) callIntPtr.EraseFromParentAsInstruction() continue } diff --git a/transform/goroutine_test.go b/transform/goroutine_test.go new file mode 100644 index 00000000..e63b4923 --- /dev/null +++ b/transform/goroutine_test.go @@ -0,0 +1,16 @@ +package transform + +import ( + "testing" + "tinygo.org/x/go-llvm" +) + +func TestGoroutineLowering(t *testing.T) { + t.Parallel() + testTransform(t, "testdata/coroutines", func(mod llvm.Module) { + err := LowerCoroutines(mod, false) + if err != nil { + panic(err) + } + }) +} diff --git a/transform/testdata/coroutines.ll b/transform/testdata/coroutines.ll new file mode 100644 index 00000000..f2ace5a0 --- /dev/null +++ b/transform/testdata/coroutines.ll @@ -0,0 +1,120 @@ +target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" +target triple = "armv7m-none-eabi" + +%"internal/task.state" = type { i8* } +%"internal/task.Task" = type { %"internal/task.Task", i8*, i32, %"internal/task.state" } + +declare void @"internal/task.start"(i32, i8*, i8*, i8*) +declare void @"internal/task.Pause"(i8*, i8*) + +declare void @runtime.scheduler(i8*, i8*) + +declare i8* @runtime.alloc(i32, i8*, i8*) +declare void @runtime.free(i8*, i8*, i8*) + +declare %"internal/task.Task"* @"internal/task.Current"(i8*, i8*) + +declare i8* @"(*internal/task.Task).setState"(%"internal/task.Task"*, i8*, i8*, i8*) +declare void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"*, i8*, i8*, i8*) +declare i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"*, i8*, 
i8*) +declare void @"(*internal/task.Task).returnTo"(%"internal/task.Task"*, i8*, i8*, i8*) +declare void @"(*internal/task.Task).returnCurrent"(%"internal/task.Task"*, i8*, i8*) +declare %"internal/task.Task"* @"internal/task.createTask"(i8*, i8*) + +declare void @callMain(i8*, i8*) + +; Test a simple sleep-like scenario. +declare void @enqueueTimer(%"internal/task.Task"*, i64, i8*, i8*) + +define void @sleep(i64, i8*, i8* %parentHandle) { +entry: + %2 = call %"internal/task.Task"* @"internal/task.Current"(i8* undef, i8* null) + call void @enqueueTimer(%"internal/task.Task"* %2, i64 %0, i8* undef, i8* null) + call void @"internal/task.Pause"(i8* undef, i8* null) + ret void +} + +; Test a delayed value return. +define i32 @delayedValue(i32, i64, i8*, i8* %parentHandle) { +entry: + call void @sleep(i64 %1, i8* undef, i8* null) + ret i32 %0 +} + +; Test a deadlocking async func. +define void @deadlock(i8*, i8* %parentHandle) { +entry: + call void @"internal/task.Pause"(i8* undef, i8* null) + unreachable +} + +; Test a regular tail call. +define i32 @tail(i32, i64, i8*, i8* %parentHandle) { +entry: + %3 = call i32 @delayedValue(i32 %0, i64 %1, i8* undef, i8* null) + ret i32 %3 +} + +; Test a ditching tail call. +define void @ditchTail(i32, i64, i8*, i8* %parentHandle) { +entry: + %3 = call i32 @delayedValue(i32 %0, i64 %1, i8* undef, i8* null) + ret void +} + +; Test a void tail call. +define void @voidTail(i32, i64, i8*, i8* %parentHandle) { +entry: + call void @ditchTail(i32 %0, i64 %1, i8* undef, i8* null) + ret void +} + +; Test a tail call returning an alternate value. +define i32 @alternateTail(i32, i32, i64, i8*, i8* %parentHandle) { +entry: + %4 = call i32 @delayedValue(i32 %1, i64 %2, i8* undef, i8* null) + ret i32 %0 +} + +; Test a normal return from a coroutine. +; This must be turned into a coroutine. 
+define i1 @coroutine(i32, i64, i8*, i8* %parentHandle) { +entry: + %3 = call i32 @delayedValue(i32 %0, i64 %1, i8* undef, i8* null) + %4 = icmp eq i32 %3, 0 + ret i1 %4 +} + +; Normal function which should not be transformed. +define void @doNothing(i8*, i8*) { +entry: + ret void +} + +; Goroutine that sleeps and does nothing. +; Should be a void tail call. +define void @sleepGoroutine(i8*, i8* %parentHandle) { + call void @sleep(i64 1000000, i8* undef, i8* null) + ret void +} + +; Program main function. +define void @progMain(i8*, i8* %parentHandle) { +entry: + ; Call a sync func in a goroutine. + call void @"internal/task.start"(i32 ptrtoint (void (i8*, i8*)* @doNothing to i32), i8* undef, i8* undef, i8* null) + ; Call an async func in a goroutine. + call void @"internal/task.start"(i32 ptrtoint (void (i8*, i8*)* @sleepGoroutine to i32), i8* undef, i8* undef, i8* null) + ; Sleep a bit. + call void @sleep(i64 2000000, i8* undef, i8* null) + ; Done. + ret void +} + +; Entrypoint of runtime. 
+define void @main() { +entry: + call void @"internal/task.start"(i32 ptrtoint (void (i8*, i8*)* @progMain to i32), i8* undef, i8* undef, i8* null) + call void @runtime.scheduler(i8* undef, i8* null) + ret void +} diff --git a/transform/testdata/coroutines.out.ll b/transform/testdata/coroutines.out.ll new file mode 100644 index 00000000..ec03650b --- /dev/null +++ b/transform/testdata/coroutines.out.ll @@ -0,0 +1,176 @@ +target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" +target triple = "armv7m-none-eabi" + +%"internal/task.Task" = type { %"internal/task.Task", i8*, i32, %"internal/task.state" } +%"internal/task.state" = type { i8* } + +declare void @"internal/task.start"(i32, i8*, i8*, i8*) +declare void @"internal/task.Pause"(i8*, i8*) + +declare void @runtime.scheduler(i8*, i8*) + +declare i8* @runtime.alloc(i32, i8*, i8*) +declare void @runtime.free(i8*, i8*, i8*) + +declare %"internal/task.Task"* @"internal/task.Current"(i8*, i8*) + +declare i8* @"(*internal/task.Task).setState"(%"internal/task.Task"*, i8*, i8*, i8*) + +declare void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"*, i8*, i8*, i8*) +declare i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"*, i8*, i8*) + +declare void @"(*internal/task.Task).returnTo"(%"internal/task.Task"*, i8*, i8*, i8*) +declare void @"(*internal/task.Task).returnCurrent"(%"internal/task.Task"*, i8*, i8*) + +declare %"internal/task.Task"* @"internal/task.createTask"(i8*, i8*) + +declare void @callMain(i8*, i8*) + +declare void @enqueueTimer(%"internal/task.Task"*, i64, i8*, i8*) +define void @sleep(i64, i8*, i8* %parentHandle) { +entry: + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %task.current1 = bitcast i8* %parentHandle to %"internal/task.Task"* + call void @enqueueTimer(%"internal/task.Task"* %task.current1, i64 %0, i8* undef, i8* null) + ret void +} + +define i32 @delayedValue(i32, i64, i8*, i8* %parentHandle) { +entry: + %task.current = bitcast i8* 
%parentHandle to %"internal/task.Task"* + %ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) + %ret.ptr.bitcast = bitcast i8* %ret.ptr to i32* + store i32 %0, i32* %ret.ptr.bitcast + call void @sleep(i64 %1, i8* undef, i8* %parentHandle) + ret i32 undef +} + +define void @deadlock(i8*, i8* %parentHandle) { +entry: + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + ret void +} + +define i32 @tail(i32, i64, i8*, i8* %parentHandle) { +entry: + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %3 = call i32 @delayedValue(i32 %0, i64 %1, i8* undef, i8* %parentHandle) + ret i32 undef +} + +define void @ditchTail(i32, i64, i8*, i8* %parentHandle) { +entry: + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %ret.ditch = call i8* @runtime.alloc(i32 4, i8* undef, i8* undef) + call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %ret.ditch, i8* undef, i8* undef) + %3 = call i32 @delayedValue(i32 %0, i64 %1, i8* undef, i8* %parentHandle) + ret void +} + +define void @voidTail(i32, i64, i8*, i8* %parentHandle) { +entry: + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + call void @ditchTail(i32 %0, i64 %1, i8* undef, i8* %parentHandle) + ret void +} + +define i32 @alternateTail(i32, i32, i64, i8*, i8* %parentHandle) { +entry: + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef) + %ret.ptr.bitcast = bitcast i8* %ret.ptr to i32* + store i32 %0, i32* %ret.ptr.bitcast + %ret.alternate = call i8* @runtime.alloc(i32 4, i8* undef, i8* undef) + call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %ret.alternate, i8* undef, i8* undef) + %4 = call i32 @delayedValue(i32 %1, i64 %2, i8* undef, i8* %parentHandle) + ret i32 undef +} + 
+define i1 @coroutine(i32, i64, i8*, i8* %parentHandle) { +entry: + %call.return = alloca i32 + %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %coro.size = call i32 @llvm.coro.size.i32() + %coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef) + %coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc) + %task.current2 = bitcast i8* %parentHandle to %"internal/task.Task"* + %task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current2, i8* %coro.state, i8* undef, i8* undef) + %task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current2, i8* undef, i8* undef) + %task.retPtr.bitcast = bitcast i8* %task.retPtr to i1* + %call.return.bitcast = bitcast i32* %call.return to i8* + call void @llvm.lifetime.start.p0i8(i64 4, i8* %call.return.bitcast) + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + %call.return.bitcast1 = bitcast i32* %call.return to i8* + call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %call.return.bitcast1, i8* undef, i8* undef) + %3 = call i32 @delayedValue(i32 %0, i64 %1, i8* undef, i8* %parentHandle) + %coro.save = call token @llvm.coro.save(i8* %coro.state) + %call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false) + switch i8 %call.suspend, label %suspend [ + i8 0, label %wakeup + i8 1, label %cleanup + ] + +wakeup: ; preds = %entry + %4 = load i32, i32* %call.return + call void @llvm.lifetime.end.p0i8(i64 4, i8* %call.return.bitcast) + %5 = icmp eq i32 %4, 0 + store i1 %5, i1* %task.retPtr.bitcast + call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current2, i8* %task.state.parent, i8* undef, i8* undef) + br label %cleanup + +suspend: ; preds = %entry, %cleanup + %unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false) + ret i1 undef + +cleanup: ; preds = %entry, %wakeup + %coro.memFree = call i8* 
@llvm.coro.free(token %coro.id, i8* %coro.state) + call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef) + br label %suspend +} + +define void @doNothing(i8*, i8*) { +entry: + ret void +} + +define void @sleepGoroutine(i8*, i8* %parentHandle) { + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + call void @sleep(i64 1000000, i8* undef, i8* %parentHandle) + ret void +} + +define void @progMain(i8*, i8* %parentHandle) { +entry: + %task.current = bitcast i8* %parentHandle to %"internal/task.Task"* + call void @doNothing(i8* undef, i8* undef) + %start.task = call %"internal/task.Task"* @"internal/task.createTask"(i8* undef, i8* undef) + %start.task.bitcast = bitcast %"internal/task.Task"* %start.task to i8* + call void @sleepGoroutine(i8* undef, i8* %start.task.bitcast) + call void @sleep(i64 2000000, i8* undef, i8* %parentHandle) + ret void +} + +define void @main() { +entry: + %start.task = call %"internal/task.Task"* @"internal/task.createTask"(i8* undef, i8* undef) + %start.task.bitcast = bitcast %"internal/task.Task"* %start.task to i8* + call void @progMain(i8* undef, i8* %start.task.bitcast) + call void @runtime.scheduler(i8* undef, i8* null) + ret void +} + +declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #0 +declare i32 @llvm.coro.size.i32() #1 +declare i8* @llvm.coro.begin(token, i8* writeonly) #2 +declare i8 @llvm.coro.suspend(token, i1) #2 +declare i1 @llvm.coro.end(i8*, i1) #2 +declare i8* @llvm.coro.free(token, i8* nocapture readonly) #0 +declare token @llvm.coro.save(i8*) #2 + +declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #3 +declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #3 + +attributes #0 = { argmemonly nounwind readonly } +attributes #1 = { nounwind readnone } +attributes #2 = { nounwind } +attributes #3 = { argmemonly nounwind } diff --git a/transform/testdata/func-lowering.ll b/transform/testdata/func-lowering.ll index 8bcc0a40..b9aea6c0 100644 --- 
a/transform/testdata/func-lowering.ll +++ b/transform/testdata/func-lowering.ll @@ -15,7 +15,7 @@ target triple = "wasm32-unknown-unknown-wasm" declare i32 @runtime.getFuncPtr(i8*, i32, %runtime.typecodeID*, i8*, i8*) -declare i32 @runtime.makeGoroutine(i32, i8*, i8*) +declare void @"internal/task.start"(i32, i8*, i8*, i8*) declare void @runtime.nilPanic(i8*, i8*) @@ -71,13 +71,10 @@ fpcall.next: ret void } -; Special case for runtime.makeGoroutine. +; Special case for internal/task.start. define void @sleepFuncValue(i8*, i32, i8* nocapture readnone %context, i8* nocapture readnone %parentHandle) { entry: %2 = call i32 @runtime.getFuncPtr(i8* %0, i32 %1, %runtime.typecodeID* @"reflect/types.type:func:{basic:int}{}", i8* undef, i8* null) - %3 = call i32 @runtime.makeGoroutine(i32 %2, i8* undef, i8* null) - %4 = inttoptr i32 %3 to void (i32, i8*, i8*)* - call void %4(i32 8, i8* %0, i8* null) + call void @"internal/task.start"(i32 %2, i8* null, i8* undef, i8* null) ret void } - diff --git a/transform/testdata/func-lowering.out.ll b/transform/testdata/func-lowering.out.ll index 01d66835..af63fdc1 100644 --- a/transform/testdata/func-lowering.out.ll +++ b/transform/testdata/func-lowering.out.ll @@ -15,7 +15,7 @@ target triple = "wasm32-unknown-unknown-wasm" declare i32 @runtime.getFuncPtr(i8*, i32, %runtime.typecodeID*, i8*, i8*) -declare i32 @runtime.makeGoroutine(i32, i8*, i8*) +declare void @"internal/task.start"(i32, i8*, i8*, i8*) declare void @runtime.nilPanic(i8*, i8*) @@ -102,15 +102,11 @@ func.nil: unreachable func.call1: - %2 = call i32 @runtime.makeGoroutine(i32 ptrtoint (void (i32, i8*, i8*)* @"main$1" to i32), i8* undef, i8* null) - %3 = inttoptr i32 %2 to void (i32, i8*, i8*)* - call void %3(i32 8, i8* %0, i8* null) + call void @"internal/task.start"(i32 ptrtoint (void (i32, i8*, i8*)* @"main$1" to i32), i8* null, i8* undef, i8* null) br label %func.next func.call2: - %4 = call i32 @runtime.makeGoroutine(i32 ptrtoint (void (i32, i8*, i8*)* @"main$2" to 
i32), i8* undef, i8* null) - %5 = inttoptr i32 %4 to void (i32, i8*, i8*)* - call void %5(i32 8, i8* %0, i8* null) + call void @"internal/task.start"(i32 ptrtoint (void (i32, i8*, i8*)* @"main$2" to i32), i8* null, i8* undef, i8* null) br label %func.next func.next: