compiler: lower func values to switch + direct call

This has several advantages, among them:
  - Many passes (heap-to-stack, dead arg elimination, inlining) do not
    work with function pointer calls. Making them normal function calls
    improves their effectiveness.
  - Goroutine lowering to LLVM coroutines does not currently support
    function pointers. By eliminating function pointers, coroutine
    lowering gets support for them for free.
    This is especially useful for WebAssembly.
Because of the second point, this work is currently only enabled for the
WebAssembly target.
Этот коммит содержится в:
Ayke van Laethem 2019-04-15 15:46:26 +02:00 коммит произвёл Ron Evans
родитель 1460877c28
коммит 2a0a7722f9
6 изменённых файлов: 413 добавлений и 39 удалений

Просмотреть файл

@ -1338,7 +1338,10 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon) (llvm.Value, e
}
// This is a func value, which cannot be called directly. We have to
// extract the function pointer and context first from the func value.
funcPtr, context := c.decodeFuncValue(value)
funcPtr, context, err := c.decodeFuncValue(value, instr.Value.Type().(*types.Signature))
if err != nil {
return llvm.Value{}, err
}
c.emitNilCheck(frame, funcPtr, "fpcall")
return c.parseFunctionCall(frame, instr.Args, funcPtr, context, false)
}
@ -1508,7 +1511,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
if fn.IsExported() {
return llvm.Value{}, c.makeError(expr.Pos(), "cannot use an exported function as value")
}
return c.createFuncValue(fn.LLVMFn)
return c.createFuncValue(fn.LLVMFn, llvm.Undef(c.i8ptrType), fn.Signature)
case *ssa.Global:
value := c.ir.GetGlobal(expr).LLVMGlobal
if value.IsNil() {

269
compiler/func-lowering.go Обычный файл
Просмотреть файл

@ -0,0 +1,269 @@
package compiler
// This file lowers func values into their final form. This is necessary for
// funcValueSwitch, which needs full program analysis.
import (
"sort"
"strconv"
"tinygo.org/x/go-llvm"
)
// funcSignatureInfo keeps information about a single signature and its uses.
type funcSignatureInfo struct {
sig llvm.Value // *uint8 to identify the signature
funcValueWithSignatures []llvm.Value // slice of runtime.funcValueWithSignature
}
// funcWithUses keeps information about a single function used as func value and
// the assigned function ID. More commonly used functions are assigned a lower
// ID.
type funcWithUses struct {
funcPtr llvm.Value
useCount int // how often this function is used in a func value
id int // assigned ID
}
// Slice to sort functions by their use counts, or else their name if they're
// used equally often.
type funcWithUsesList []*funcWithUses
func (l funcWithUsesList) Len() int { return len(l) }
func (l funcWithUsesList) Less(i, j int) bool {
if l[i].useCount != l[j].useCount {
// return the reverse: we want the highest use counts sorted first
return l[i].useCount > l[j].useCount
}
iName := l[i].funcPtr.Name()
jName := l[j].funcPtr.Name()
return iName < jName
}
func (l funcWithUsesList) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}
// LowerFuncValue lowers the runtime.funcValueWithSignature type and
// runtime.getFuncPtr function to their final form.
func (c *Compiler) LowerFuncValues() {
if c.funcImplementation() != funcValueSwitch {
return
}
// Find all func values used in the program with their signatures.
funcValueWithSignaturePtr := llvm.PointerType(c.mod.GetTypeByName("runtime.funcValueWithSignature"), 0)
signatures := map[string]*funcSignatureInfo{}
for global := c.mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) {
if global.Type() != funcValueWithSignaturePtr {
continue
}
sig := llvm.ConstExtractValue(global.Initializer(), []uint32{1})
name := sig.Name()
if info, ok := signatures[name]; ok {
info.funcValueWithSignatures = append(info.funcValueWithSignatures, global)
} else {
signatures[name] = &funcSignatureInfo{
sig: sig,
funcValueWithSignatures: []llvm.Value{global},
}
}
}
// Sort the signatures, for deterministic execution.
names := make([]string, 0, len(signatures))
for name := range signatures {
names = append(names, name)
}
sort.Strings(names)
for _, name := range names {
info := signatures[name]
functions := make(funcWithUsesList, len(info.funcValueWithSignatures))
for i, use := range info.funcValueWithSignatures {
var useCount int
for _, use2 := range getUses(use) {
useCount += len(getUses(use2))
}
functions[i] = &funcWithUses{
funcPtr: llvm.ConstExtractValue(use.Initializer(), []uint32{0}).Operand(0),
useCount: useCount,
}
}
sort.Sort(functions)
for i, fn := range functions {
fn.id = i + 1
for _, ptrtoint := range getUses(fn.funcPtr) {
if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt {
continue
}
for _, funcValueWithSignatureConstant := range getUses(ptrtoint) {
for _, funcValueWithSignatureGlobal := range getUses(funcValueWithSignatureConstant) {
for _, use := range getUses(funcValueWithSignatureGlobal) {
if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt {
panic("expected const ptrtoint")
}
use.ReplaceAllUsesWith(llvm.ConstInt(c.uintptrType, uint64(fn.id), false))
}
}
}
}
}
for _, getFuncPtrCall := range getUses(info.sig) {
if getFuncPtrCall.IsACallInst().IsNil() {
continue
}
if getFuncPtrCall.CalledValue().Name() != "runtime.getFuncPtr" {
panic("expected all call uses to be runtime.getFuncPtr")
}
funcID := getFuncPtrCall.Operand(1)
switch len(functions) {
case 0:
// There are no functions used in a func value that implement
// this signature. The only possible value is a nil value.
for _, inttoptr := range getUses(getFuncPtrCall) {
if inttoptr.IsAIntToPtrInst().IsNil() {
panic("expected inttoptr")
}
nilptr := llvm.ConstPointerNull(inttoptr.Type())
inttoptr.ReplaceAllUsesWith(nilptr)
inttoptr.EraseFromParentAsInstruction()
}
getFuncPtrCall.EraseFromParentAsInstruction()
case 1:
// There is exactly one function with this signature that is
// used in a func value. The func value itself can be either nil
// or this one function.
c.builder.SetInsertPointBefore(getFuncPtrCall)
zero := llvm.ConstInt(c.uintptrType, 0, false)
isnil := c.builder.CreateICmp(llvm.IntEQ, funcID, zero, "")
funcPtrNil := llvm.ConstPointerNull(functions[0].funcPtr.Type())
funcPtr := c.builder.CreateSelect(isnil, funcPtrNil, functions[0].funcPtr, "")
for _, inttoptr := range getUses(getFuncPtrCall) {
if inttoptr.IsAIntToPtrInst().IsNil() {
panic("expected inttoptr")
}
inttoptr.ReplaceAllUsesWith(funcPtr)
inttoptr.EraseFromParentAsInstruction()
}
getFuncPtrCall.EraseFromParentAsInstruction()
default:
// There are multiple functions used in a func value that
// implement this signature.
// What we'll do is transform the following:
// rawPtr := runtime.getFuncPtr(fn)
// if func.rawPtr == nil {
// runtime.nilpanic()
// }
// result := func.rawPtr(...args, func.context)
// into this:
// if false {
// runtime.nilpanic()
// }
// var result // Phi
// switch fn.id {
// case 0:
// runtime.nilpanic()
// case 1:
// result = call first implementation...
// case 2:
// result = call second implementation...
// default:
// unreachable
// }
// Remove some casts, checks, and the old call which we're going
// to replace.
var funcCall llvm.Value
for _, inttoptr := range getUses(getFuncPtrCall) {
if inttoptr.IsAIntToPtrInst().IsNil() {
panic("expected inttoptr")
}
for _, ptrUse := range getUses(inttoptr) {
if !ptrUse.IsABitCastInst().IsNil() {
for _, bitcastUse := range getUses(ptrUse) {
if bitcastUse.IsACallInst().IsNil() || bitcastUse.CalledValue().Name() != "runtime.isnil" {
panic("expected a call to runtime.isnil")
}
bitcastUse.ReplaceAllUsesWith(llvm.ConstInt(c.ctx.Int1Type(), 0, false))
bitcastUse.EraseFromParentAsInstruction()
}
ptrUse.EraseFromParentAsInstruction()
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == inttoptr {
if !funcCall.IsNil() {
panic("multiple calls on a single runtime.getFuncPtr")
}
funcCall = ptrUse
} else {
panic("unexpected getFuncPtrCall")
}
}
}
if funcCall.IsNil() {
panic("expected exactly one call use of a runtime.getFuncPtr")
}
// The block that cannot be reached with correct funcValues (to
// help the optimizer).
c.builder.SetInsertPointBefore(funcCall)
defaultBlock := llvm.AddBasicBlock(funcCall.InstructionParent().Parent(), "func.default")
c.builder.SetInsertPointAtEnd(defaultBlock)
c.builder.CreateUnreachable()
// Create the switch.
c.builder.SetInsertPointBefore(funcCall)
sw := c.builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
// Split right after the switch. We will need to insert a few
// basic blocks in this gap.
nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
// The 0 case, which is actually a nil check.
nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil")
c.builder.SetInsertPointAtEnd(nilBlock)
c.createRuntimeCall("nilpanic", nil, "")
c.builder.CreateUnreachable()
sw.AddCase(llvm.ConstInt(c.uintptrType, 0, false), nilBlock)
// Gather the list of parameters for every call we're going to
// make.
callParams := make([]llvm.Value, funcCall.OperandsCount()-1)
for i := range callParams {
callParams[i] = funcCall.Operand(i)
}
// If the call produces a value, we need to get it using a PHI
// node.
phiBlocks := make([]llvm.BasicBlock, len(functions))
phiValues := make([]llvm.Value, len(functions))
for i, fn := range functions {
// Insert a switch case.
bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
c.builder.SetInsertPointAtEnd(bb)
result := c.builder.CreateCall(fn.funcPtr, callParams, "")
c.builder.CreateBr(nextBlock)
sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(fn.id), false), bb)
phiBlocks[i] = bb
phiValues[i] = result
}
// Create the PHI node so that the call result flows into the
// next block (after the split). This is only necessary when the
// call produced a value.
if funcCall.Type().TypeKind() != llvm.VoidTypeKind {
c.builder.SetInsertPointBefore(nextBlock.FirstInstruction())
phi := c.builder.CreatePHI(funcCall.Type(), "")
phi.AddIncoming(phiValues, phiBlocks)
funcCall.ReplaceAllUsesWith(phi)
}
// Finally, remove the old instructions.
funcCall.EraseFromParentAsInstruction()
for _, inttoptr := range getUses(getFuncPtrCall) {
inttoptr.EraseFromParentAsInstruction()
}
getFuncPtrCall.EraseFromParentAsInstruction()
}
}
}
}

Просмотреть файл

@ -1,10 +1,7 @@
package compiler
// This file implements function values and closures. A func value is
// implemented as a pair of pointers: {context, function pointer}, where the
// context may be a pointer to a heap-allocated struct containing the free
// variables, or it may be undef if the function being pointed to doesn't need a
// context.
// This file implements function values and closures. It may need some lowering
// in a later step, see func-lowering.go.
import (
"go/types"
@ -13,14 +10,86 @@ import (
"tinygo.org/x/go-llvm"
)
type funcValueImplementation int
const (
funcValueNone funcValueImplementation = iota
// A func value is implemented as a pair of pointers:
// {context, function pointer}
// where the context may be a pointer to a heap-allocated struct containing
// the free variables, or it may be undef if the function being pointed to
// doesn't need a context. The function pointer is a regular function
// pointer.
funcValueDoubleword
// As funcValueDoubleword, but with the function pointer replaced by a
// unique ID per function signature. Function values are called by using a
// switch statement and choosing which function to call.
funcValueSwitch
)
// funcImplementation picks an appropriate func value implementation for the
// target.
func (c *Compiler) funcImplementation() funcValueImplementation {
if c.GOARCH == "wasm" {
return funcValueSwitch
} else {
return funcValueDoubleword
}
}
// createFuncValue creates a function value from a raw function pointer with no
// context.
func (c *Compiler) createFuncValue(funcPtr llvm.Value) (llvm.Value, error) {
// Closure is: {context, function pointer}
return c.ctx.ConstStruct([]llvm.Value{
llvm.Undef(c.i8ptrType),
funcPtr,
}, false), nil
func (c *Compiler) createFuncValue(funcPtr, context llvm.Value, sig *types.Signature) (llvm.Value, error) {
var funcValueScalar llvm.Value
switch c.funcImplementation() {
case funcValueDoubleword:
// Closure is: {context, function pointer}
funcValueScalar = funcPtr
case funcValueSwitch:
sigGlobal := c.getFuncSignature(sig)
funcValueWithSignatureGlobalName := funcPtr.Name() + "$withSignature"
funcValueWithSignatureGlobal := c.mod.NamedGlobal(funcValueWithSignatureGlobalName)
if funcValueWithSignatureGlobal.IsNil() {
funcValueWithSignatureType := c.mod.GetTypeByName("runtime.funcValueWithSignature")
funcValueWithSignature := llvm.ConstNamedStruct(funcValueWithSignatureType, []llvm.Value{
llvm.ConstPtrToInt(funcPtr, c.uintptrType),
sigGlobal,
})
funcValueWithSignatureGlobal = llvm.AddGlobal(c.mod, funcValueWithSignatureType, funcValueWithSignatureGlobalName)
funcValueWithSignatureGlobal.SetInitializer(funcValueWithSignature)
funcValueWithSignatureGlobal.SetGlobalConstant(true)
funcValueWithSignatureGlobal.SetLinkage(llvm.InternalLinkage)
}
funcValueScalar = llvm.ConstPtrToInt(funcValueWithSignatureGlobal, c.uintptrType)
default:
panic("unimplemented func value variant")
}
funcValueType, err := c.getFuncType(sig)
if err != nil {
return llvm.Value{}, err
}
funcValue := llvm.Undef(funcValueType)
funcValue = c.builder.CreateInsertValue(funcValue, context, 0, "")
funcValue = c.builder.CreateInsertValue(funcValue, funcValueScalar, 1, "")
return funcValue, nil
}
// getFuncSignature returns a global for identification of a particular function
// signature. It is used in runtime.funcValueWithSignature and in calls to
// getFuncPtr.
func (c *Compiler) getFuncSignature(sig *types.Signature) llvm.Value {
typeCodeName := getTypeCodeName(sig)
sigGlobalName := "reflect/types.type:" + typeCodeName
sigGlobal := c.mod.NamedGlobal(sigGlobalName)
if sigGlobal.IsNil() {
sigGlobal = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), sigGlobalName)
sigGlobal.SetInitializer(llvm.Undef(c.ctx.Int8Type()))
sigGlobal.SetGlobalConstant(true)
sigGlobal.SetLinkage(llvm.InternalLinkage)
}
return sigGlobal
}
// extractFuncScalar returns some scalar that can be used in comparisons. It is
@ -37,19 +106,39 @@ func (c *Compiler) extractFuncContext(funcValue llvm.Value) llvm.Value {
// decodeFuncValue extracts the context and the function pointer from this func
// value. This may be an expensive operation.
func (c *Compiler) decodeFuncValue(funcValue llvm.Value) (funcPtr, context llvm.Value) {
func (c *Compiler) decodeFuncValue(funcValue llvm.Value, sig *types.Signature) (funcPtr, context llvm.Value, err error) {
context = c.builder.CreateExtractValue(funcValue, 0, "")
funcPtr = c.builder.CreateExtractValue(funcValue, 1, "")
switch c.funcImplementation() {
case funcValueDoubleword:
funcPtr = c.builder.CreateExtractValue(funcValue, 1, "")
case funcValueSwitch:
llvmSig, err := c.getRawFuncType(sig)
if err != nil {
return llvm.Value{}, llvm.Value{}, err
}
sigGlobal := c.getFuncSignature(sig)
funcPtr = c.createRuntimeCall("getFuncPtr", []llvm.Value{funcValue, sigGlobal}, "")
funcPtr = c.builder.CreateIntToPtr(funcPtr, llvmSig, "")
default:
panic("unimplemented func value variant")
}
return
}
// getFuncType returns the type of a func value given a signature.
func (c *Compiler) getFuncType(typ *types.Signature) (llvm.Type, error) {
rawPtr, err := c.getRawFuncType(typ)
if err != nil {
return llvm.Type{}, err
switch c.funcImplementation() {
case funcValueDoubleword:
rawPtr, err := c.getRawFuncType(typ)
if err != nil {
return llvm.Type{}, err
}
return c.ctx.StructType([]llvm.Type{c.i8ptrType, rawPtr}, false), nil
case funcValueSwitch:
return c.mod.GetTypeByName("runtime.funcValue"), nil
default:
panic("unimplemented func value variant")
}
return c.ctx.StructType([]llvm.Type{c.i8ptrType, rawPtr}, false), nil
}
// getRawFuncType returns a LLVM function pointer type for a given signature.
@ -171,19 +260,6 @@ func (c *Compiler) parseMakeClosure(frame *Frame, expr *ssa.MakeClosure) (llvm.V
context = contextHeapAlloc
}
// Get the function signature type, which is a closure type.
// A closure is a tuple of {context, function pointer}.
typ, err := c.getFuncType(f.Signature)
if err != nil {
return llvm.Value{}, err
}
// Create the closure, which is a struct: {context, function pointer}.
closure, err := c.getZeroValue(typ)
if err != nil {
return llvm.Value{}, err
}
closure = c.builder.CreateInsertValue(closure, f.LLVMFn, 1, "")
closure = c.builder.CreateInsertValue(closure, context, 0, "")
return closure, nil
// Create the closure.
return c.createFuncValue(f.LLVMFn, context, f.Signature)
}

Просмотреть файл

@ -396,14 +396,10 @@ func (c *Compiler) getInvokeCall(frame *Frame, instr *ssa.CallCommon) (llvm.Valu
return llvm.Value{}, nil, err
}
llvmFnType, err := c.getLLVMType(instr.Method.Type())
llvmFnType, err := c.getRawFuncType(instr.Method.Type().(*types.Signature))
if err != nil {
return llvm.Value{}, nil, err
}
// getLLVMType() has created a closure type for us, but we don't actually
// want a closure type as an interface call can never be a closure call. So
// extract the function pointer type from the closure.
llvmFnType = llvmFnType.Subtypes()[1]
typecode := c.builder.CreateExtractValue(itf, 0, "invoke.typecode")
values := []llvm.Value{

Просмотреть файл

@ -43,6 +43,7 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) erro
c.OptimizeStringToBytes()
c.OptimizeAllocs()
c.LowerInterfaces()
c.LowerFuncValues()
// After interfaces are lowered, there are many more opportunities for
// interprocedural optimizations. To get them to work, function
@ -76,6 +77,7 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) erro
} else {
// Must be run at any optimization level.
c.LowerInterfaces()
c.LowerFuncValues()
err := c.LowerGoroutines()
if err != nil {
return err

28
src/runtime/func.go Обычный файл
Просмотреть файл

@ -0,0 +1,28 @@
package runtime
// This file implements some data types that may be useful for some
// implementations of func values.
import (
"unsafe"
)
// funcValue is the underlying type of func values, depending on which func
// value representation was used.
type funcValue struct {
context unsafe.Pointer // function context, for closures and bound methods
id uintptr // ptrtoint of *funcValueWithSignature before lowering, opaque index (non-0) after lowering
}
// funcValueWithSignature is used before the func lowering pass.
type funcValueWithSignature struct {
funcPtr uintptr // ptrtoint of the actual function pointer
signature *uint8 // pointer to identify this signature (the value is undef)
}
// getFuncPtr is a dummy function that may be used if the func lowering pass is
// not used. It is generally too slow but may be a useful fallback to debug the
// func lowering pass.
func getFuncPtr(val funcValue, signature *uint8) uintptr {
return (*funcValueWithSignature)(unsafe.Pointer(val.id)).funcPtr
}