diff --git a/compiler/compiler.go b/compiler/compiler.go index 8b23e971..e863508d 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1338,7 +1338,10 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon) (llvm.Value, e } // This is a func value, which cannot be called directly. We have to // extract the function pointer and context first from the func value. - funcPtr, context := c.decodeFuncValue(value) + funcPtr, context, err := c.decodeFuncValue(value, instr.Value.Type().(*types.Signature)) + if err != nil { + return llvm.Value{}, err + } c.emitNilCheck(frame, funcPtr, "fpcall") return c.parseFunctionCall(frame, instr.Args, funcPtr, context, false) } @@ -1508,7 +1511,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if fn.IsExported() { return llvm.Value{}, c.makeError(expr.Pos(), "cannot use an exported function as value") } - return c.createFuncValue(fn.LLVMFn) + return c.createFuncValue(fn.LLVMFn, llvm.Undef(c.i8ptrType), fn.Signature) case *ssa.Global: value := c.ir.GetGlobal(expr).LLVMGlobal if value.IsNil() { diff --git a/compiler/func-lowering.go b/compiler/func-lowering.go new file mode 100644 index 00000000..02363f9b --- /dev/null +++ b/compiler/func-lowering.go @@ -0,0 +1,269 @@ +package compiler + +// This file lowers func values into their final form. This is necessary for +// funcValueSwitch, which needs full program analysis. + +import ( + "sort" + "strconv" + + "tinygo.org/x/go-llvm" +) + +// funcSignatureInfo keeps information about a single signature and its uses. +type funcSignatureInfo struct { + sig llvm.Value // *uint8 to identify the signature + funcValueWithSignatures []llvm.Value // slice of runtime.funcValueWithSignature +} + +// funcWithUses keeps information about a single function used as func value and +// the assigned function ID. More commonly used functions are assigned a lower +// ID. +type funcWithUses struct { + funcPtr llvm.Value + useCount int // how often this function is used in a func value + id int // assigned ID +} + +// Slice to sort functions by their use counts, or else their name if they're +// used equally often. +type funcWithUsesList []*funcWithUses + +func (l funcWithUsesList) Len() int { return len(l) } +func (l funcWithUsesList) Less(i, j int) bool { + if l[i].useCount != l[j].useCount { + // return the reverse: we want the highest use counts sorted first + return l[i].useCount > l[j].useCount + } + iName := l[i].funcPtr.Name() + jName := l[j].funcPtr.Name() + return iName < jName +} +func (l funcWithUsesList) Swap(i, j int) { + l[i], l[j] = l[j], l[i] +} + +// LowerFuncValue lowers the runtime.funcValueWithSignature type and +// runtime.getFuncPtr function to their final form. +func (c *Compiler) LowerFuncValues() { + if c.funcImplementation() != funcValueSwitch { + return + } + + // Find all func values used in the program with their signatures. + funcValueWithSignaturePtr := llvm.PointerType(c.mod.GetTypeByName("runtime.funcValueWithSignature"), 0) + signatures := map[string]*funcSignatureInfo{} + for global := c.mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) { + if global.Type() != funcValueWithSignaturePtr { + continue + } + sig := llvm.ConstExtractValue(global.Initializer(), []uint32{1}) + name := sig.Name() + if info, ok := signatures[name]; ok { + info.funcValueWithSignatures = append(info.funcValueWithSignatures, global) + } else { + signatures[name] = &funcSignatureInfo{ + sig: sig, + funcValueWithSignatures: []llvm.Value{global}, + } + } + } + + // Sort the signatures, for deterministic execution. + names := make([]string, 0, len(signatures)) + for name := range signatures { + names = append(names, name) + } + sort.Strings(names) + + for _, name := range names { + info := signatures[name] + functions := make(funcWithUsesList, len(info.funcValueWithSignatures)) + for i, use := range info.funcValueWithSignatures { + var useCount int + for _, use2 := range getUses(use) { + useCount += len(getUses(use2)) + } + functions[i] = &funcWithUses{ + funcPtr: llvm.ConstExtractValue(use.Initializer(), []uint32{0}).Operand(0), + useCount: useCount, + } + } + sort.Sort(functions) + + for i, fn := range functions { + fn.id = i + 1 + for _, ptrtoint := range getUses(fn.funcPtr) { + if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt { + continue + } + for _, funcValueWithSignatureConstant := range getUses(ptrtoint) { + for _, funcValueWithSignatureGlobal := range getUses(funcValueWithSignatureConstant) { + for _, use := range getUses(funcValueWithSignatureGlobal) { + if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt { + panic("expected const ptrtoint") + } + use.ReplaceAllUsesWith(llvm.ConstInt(c.uintptrType, uint64(fn.id), false)) + } + } + } + } + } + + for _, getFuncPtrCall := range getUses(info.sig) { + if getFuncPtrCall.IsACallInst().IsNil() { + continue + } + if getFuncPtrCall.CalledValue().Name() != "runtime.getFuncPtr" { + panic("expected all call uses to be runtime.getFuncPtr") + } + funcID := getFuncPtrCall.Operand(1) + switch len(functions) { + case 0: + // There are no functions used in a func value that implement + // this signature. The only possible value is a nil value. + for _, inttoptr := range getUses(getFuncPtrCall) { + if inttoptr.IsAIntToPtrInst().IsNil() { + panic("expected inttoptr") + } + nilptr := llvm.ConstPointerNull(inttoptr.Type()) + inttoptr.ReplaceAllUsesWith(nilptr) + inttoptr.EraseFromParentAsInstruction() + } + getFuncPtrCall.EraseFromParentAsInstruction() + case 1: + // There is exactly one function with this signature that is + // used in a func value. The func value itself can be either nil + // or this one function. + c.builder.SetInsertPointBefore(getFuncPtrCall) + zero := llvm.ConstInt(c.uintptrType, 0, false) + isnil := c.builder.CreateICmp(llvm.IntEQ, funcID, zero, "") + funcPtrNil := llvm.ConstPointerNull(functions[0].funcPtr.Type()) + funcPtr := c.builder.CreateSelect(isnil, funcPtrNil, functions[0].funcPtr, "") + for _, inttoptr := range getUses(getFuncPtrCall) { + if inttoptr.IsAIntToPtrInst().IsNil() { + panic("expected inttoptr") + } + inttoptr.ReplaceAllUsesWith(funcPtr) + inttoptr.EraseFromParentAsInstruction() + } + getFuncPtrCall.EraseFromParentAsInstruction() + default: + // There are multiple functions used in a func value that + // implement this signature. + // What we'll do is transform the following: + // rawPtr := runtime.getFuncPtr(fn) + // if func.rawPtr == nil { + // runtime.nilpanic() + // } + // result := func.rawPtr(...args, func.context) + // into this: + // if false { + // runtime.nilpanic() + // } + // var result // Phi + // switch fn.id { + // case 0: + // runtime.nilpanic() + // case 1: + // result = call first implementation... + // case 2: + // result = call second implementation... + // default: + // unreachable + // } + + // Remove some casts, checks, and the old call which we're going + // to replace. + var funcCall llvm.Value + for _, inttoptr := range getUses(getFuncPtrCall) { + if inttoptr.IsAIntToPtrInst().IsNil() { + panic("expected inttoptr") + } + for _, ptrUse := range getUses(inttoptr) { + if !ptrUse.IsABitCastInst().IsNil() { + for _, bitcastUse := range getUses(ptrUse) { + if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().Name() != "runtime.isnil" { + panic("expected a call to runtime.isnil") + } + bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false)) + bitcastUse.EraseFromParentAsInstruction() + } + ptrUse.EraseFromParentAsInstruction() + } else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == inttoptr { + if !funcCall.IsNil() { + panic("multiple calls on a single runtime.getFuncPtr") + } + funcCall = ptrUse + } else { + panic("unexpected getFuncPtrCall") + } + } + } + if funcCall.IsNil() { + panic("expected exactly one call use of a runtime.getFuncPtr") + } + + // The block that cannot be reached with correct funcValues (to + // help the optimizer). + c.builder.SetInsertPointBefore(funcCall) + defaultBlock := llvm.AddBasicBlock(funcCall.InstructionParent().Parent(), "func.default") + c.builder.SetInsertPointAtEnd(defaultBlock) + c.builder.CreateUnreachable() + + // Create the switch. + c.builder.SetInsertPointBefore(funcCall) + sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1) + + // Split right after the switch. We will need to insert a few + // basic blocks in this gap. + nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next") + + // The 0 case, which is actually a nil check. + nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil") + c.builder.SetInsertPointAtEnd(nilBlock) + c.createRuntimeCall("nilpanic", nil, "") + c.builder.CreateUnreachable() + sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock) + + // Gather the list of parameters for every call we're going to + // make. + callParams := make([]llvm.Value, funcCall.OperandsCount()-1) + for i := range callParams { + callParams[i] = funcCall.Operand(i) + } + + // If the call produces a value, we need to get it using a PHI + // node. + phiBlocks := make([]llvm.BasicBlock, len(functions)) + phiValues := make([]llvm.Value, len(functions)) + for i, fn := range functions { + // Insert a switch case. + bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id)) + c.builder.SetInsertPointAtEnd(bb) + result := c.builder.CreateCall(fn.funcPtr, callParams, "") + c.builder.CreateBr(nextBlock) + sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb) + phiBlocks[i] = bb + phiValues[i] = result + } + // Create the PHI node so that the call result flows into the + // next block (after the split). This is only necessary when the + // call produced a value. + if funcCall.Type().TypeKind() != llvm.VoidTypeKind { + c.builder.SetInsertPointBefore(nextBlock.FirstInstruction()) + phi := c.builder.CreatePHI(funcCall.Type(), "") + phi.AddIncoming(phiValues, phiBlocks) + funcCall.ReplaceAllUsesWith(phi) + } + + // Finally, remove the old instructions. + funcCall.EraseFromParentAsInstruction() + for _, inttoptr := range getUses(getFuncPtrCall) { + inttoptr.EraseFromParentAsInstruction() + } + getFuncPtrCall.EraseFromParentAsInstruction() + } + } + } +} diff --git a/compiler/func.go b/compiler/func.go index 68db1fc8..29ceb613 100644 --- a/compiler/func.go +++ b/compiler/func.go @@ -1,10 +1,7 @@ package compiler -// This file implements function values and closures. A func value is -// implemented as a pair of pointers: {context, function pointer}, where the -// context may be a pointer to a heap-allocated struct containing the free -// variables, or it may be undef if the function being pointed to doesn't need a -// context. +// This file implements function values and closures. It may need some lowering +// in a later step, see func-lowering.go. import ( "go/types" @@ -13,14 +10,86 @@ import ( "tinygo.org/x/go-llvm" ) +type funcValueImplementation int + +const ( + funcValueNone funcValueImplementation = iota + + // A func value is implemented as a pair of pointers: + // {context, function pointer} + // where the context may be a pointer to a heap-allocated struct containing + // the free variables, or it may be undef if the function being pointed to + // doesn't need a context. The function pointer is a regular function + // pointer. + funcValueDoubleword + + // As funcValueDoubleword, but with the function pointer replaced by a + // unique ID per function signature. Function values are called by using a + // switch statement and choosing which function to call. + funcValueSwitch +) + +// funcImplementation picks an appropriate func value implementation for the +// target. +func (c *Compiler) funcImplementation() funcValueImplementation { + if c.GOARCH == "wasm" { + return funcValueSwitch + } else { + return funcValueDoubleword + } +} + // createFuncValue creates a function value from a raw function pointer with no // context. -func (c *Compiler) createFuncValue(funcPtr llvm.Value) (llvm.Value, error) { - // Closure is: {context, function pointer} - return c.ctx.ConstStruct([]llvm.Value{ - llvm.Undef(c.i8ptrType), - funcPtr, - }, false), nil +func (c *Compiler) createFuncValue(funcPtr, context llvm.Value, sig *types.Signature) (llvm.Value, error) { + var funcValueScalar llvm.Value + switch c.funcImplementation() { + case funcValueDoubleword: + // Closure is: {context, function pointer} + funcValueScalar = funcPtr + case funcValueSwitch: + sigGlobal := c.getFuncSignature(sig) + funcValueWithSignatureGlobalName := funcPtr.Name() + "$withSignature" + funcValueWithSignatureGlobal := c.mod.NamedGlobal(funcValueWithSignatureGlobalName) + if funcValueWithSignatureGlobal.IsNil() { + funcValueWithSignatureType := c.mod.GetTypeByName("runtime.funcValueWithSignature") + funcValueWithSignature := llvm.ConstNamedStruct(funcValueWithSignatureType, []llvm.Value{ + llvm.ConstPtrToInt(funcPtr, c.uintptrType), + sigGlobal, + }) + funcValueWithSignatureGlobal = llvm.AddGlobal(c.mod, funcValueWithSignatureType, funcValueWithSignatureGlobalName) + funcValueWithSignatureGlobal.SetInitializer(funcValueWithSignature) + funcValueWithSignatureGlobal.SetGlobalConstant(true) + funcValueWithSignatureGlobal.SetLinkage(llvm.InternalLinkage) + } + funcValueScalar = llvm.ConstPtrToInt(funcValueWithSignatureGlobal, c.uintptrType) + default: + panic("unimplemented func value variant") + } + funcValueType, err := c.getFuncType(sig) + if err != nil { + return llvm.Value{}, err + } + funcValue := llvm.Undef(funcValueType) + funcValue = c.builder.CreateInsertValue(funcValue, context, 0, "") + funcValue = c.builder.CreateInsertValue(funcValue, funcValueScalar, 1, "") + return funcValue, nil +} + +// getFuncSignature returns a global for identification of a particular function +// signature. It is used in runtime.funcValueWithSignature and in calls to +// getFuncPtr. +func (c *Compiler) getFuncSignature(sig *types.Signature) llvm.Value { + typeCodeName := getTypeCodeName(sig) + sigGlobalName := "reflect/types.type:" + typeCodeName + sigGlobal := c.mod.NamedGlobal(sigGlobalName) + if sigGlobal.IsNil() { + sigGlobal = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), sigGlobalName) + sigGlobal.SetInitializer(llvm.Undef(c.ctx.Int8Type())) + sigGlobal.SetGlobalConstant(true) + sigGlobal.SetLinkage(llvm.InternalLinkage) + } + return sigGlobal } // extractFuncScalar returns some scalar that can be used in comparisons. It is @@ -37,19 +106,39 @@ func (c *Compiler) extractFuncContext(funcValue llvm.Value) llvm.Value { // decodeFuncValue extracts the context and the function pointer from this func // value. This may be an expensive operation. -func (c *Compiler) decodeFuncValue(funcValue llvm.Value) (funcPtr, context llvm.Value) { +func (c *Compiler) decodeFuncValue(funcValue llvm.Value, sig *types.Signature) (funcPtr, context llvm.Value, err error) { context = c.builder.CreateExtractValue(funcValue, 0, "") - funcPtr = c.builder.CreateExtractValue(funcValue, 1, "") + switch c.funcImplementation() { + case funcValueDoubleword: + funcPtr = c.builder.CreateExtractValue(funcValue, 1, "") + case funcValueSwitch: + llvmSig, err := c.getRawFuncType(sig) + if err != nil { + return llvm.Value{}, llvm.Value{}, err + } + sigGlobal := c.getFuncSignature(sig) + funcPtr = c.createRuntimeCall("getFuncPtr", []llvm.Value{funcValue, sigGlobal}, "") + funcPtr = c.builder.CreateIntToPtr(funcPtr, llvmSig, "") + default: + panic("unimplemented func value variant") + } return } // getFuncType returns the type of a func value given a signature. func (c *Compiler) getFuncType(typ *types.Signature) (llvm.Type, error) { - rawPtr, err := c.getRawFuncType(typ) - if err != nil { - return llvm.Type{}, err + switch c.funcImplementation() { + case funcValueDoubleword: + rawPtr, err := c.getRawFuncType(typ) + if err != nil { + return llvm.Type{}, err + } + return c.ctx.StructType([]llvm.Type{c.i8ptrType, rawPtr}, false), nil + case funcValueSwitch: + return c.mod.GetTypeByName("runtime.funcValue"), nil + default: + panic("unimplemented func value variant") } - return c.ctx.StructType([]llvm.Type{c.i8ptrType, rawPtr}, false), nil } // getRawFuncType returns a LLVM function pointer type for a given signature. @@ -171,19 +260,6 @@ func (c *Compiler) parseMakeClosure(frame *Frame, expr *ssa.MakeClosure) (llvm.V context = contextHeapAlloc } - // Get the function signature type, which is a closure type. - // A closure is a tuple of {context, function pointer}. - typ, err := c.getFuncType(f.Signature) - if err != nil { - return llvm.Value{}, err - } - - // Create the closure, which is a struct: {context, function pointer}. - closure, err := c.getZeroValue(typ) - if err != nil { - return llvm.Value{}, err - } - closure = c.builder.CreateInsertValue(closure, f.LLVMFn, 1, "") - closure = c.builder.CreateInsertValue(closure, context, 0, "") - return closure, nil + // Create the closure. + return c.createFuncValue(f.LLVMFn, context, f.Signature) } diff --git a/compiler/interface.go b/compiler/interface.go index ad55945f..46ebee0a 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -396,14 +396,10 @@ func (c *Compiler) getInvokeCall(frame *Frame, instr *ssa.CallCommon) (llvm.Valu return llvm.Value{}, nil, err } - llvmFnType, err := c.getLLVMType(instr.Method.Type()) + llvmFnType, err := c.getRawFuncType(instr.Method.Type().(*types.Signature)) if err != nil { return llvm.Value{}, nil, err } - // getLLVMType() has created a closure type for us, but we don't actually - // want a closure type as an interface call can never be a closure call. So - // extract the function pointer type from the closure. - llvmFnType = llvmFnType.Subtypes()[1] typecode := c.builder.CreateExtractValue(itf, 0, "invoke.typecode") values := []llvm.Value{ diff --git a/compiler/optimizer.go b/compiler/optimizer.go index 0a6f8f59..dfa005da 100644 --- a/compiler/optimizer.go +++ b/compiler/optimizer.go @@ -43,6 +43,7 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) erro c.OptimizeStringToBytes() c.OptimizeAllocs() c.LowerInterfaces() + c.LowerFuncValues() // After interfaces are lowered, there are many more opportunities for // interprocedural optimizations. To get them to work, function @@ -76,6 +77,7 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) erro } else { // Must be run at any optimization level. c.LowerInterfaces() + c.LowerFuncValues() err := c.LowerGoroutines() if err != nil { return err diff --git a/src/runtime/func.go b/src/runtime/func.go new file mode 100644 index 00000000..b05c2a27 --- /dev/null +++ b/src/runtime/func.go @@ -0,0 +1,28 @@ +package runtime + +// This file implements some data types that may be useful for some +// implementations of func values. + +import ( + "unsafe" +) + +// funcValue is the underlying type of func values, depending on which func +// value representation was used. +type funcValue struct { + context unsafe.Pointer // function context, for closures and bound methods + id uintptr // ptrtoint of *funcValueWithSignature before lowering, opaque index (non-0) after lowering +} + +// funcValueWithSignature is used before the func lowering pass. +type funcValueWithSignature struct { + funcPtr uintptr // ptrtoint of the actual function pointer + signature *uint8 // pointer to identify this signature (the value is undef) +} + +// getFuncPtr is a dummy function that may be used if the func lowering pass is +// not used. It is generally too slow but may be a useful fallback to debug the +// func lowering pass. +func getFuncPtr(val funcValue, signature *uint8) uintptr { + return (*funcValueWithSignature)(unsafe.Pointer(val.id)).funcPtr +}