Implement closures and bound methods

Этот коммит содержится в:
Ayke van Laethem 2018-09-02 03:13:39 +02:00
родитель 58b853bbef
коммит 58c87329d4
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: E97FF5335DFDFDED
5 изменённых файлов: 357 добавлений и 49 удалений

Просмотреть файл

@ -48,13 +48,14 @@ Currently supported features:
* slices (partially)
* maps (very rough, unfinished)
* defer (only in trivial cases)
* closures
* bound methods
Not yet supported:
* complex numbers
* garbage collection
* recover
* closures
* channels
* introspection (if it ever gets implemented)
* ...

Просмотреть файл

@ -214,6 +214,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
c.ir.SimpleDCE() // remove most dead code
c.ir.AnalyseCallgraph() // set up callgraph
c.ir.AnalyseInterfaceConversions() // determine which types are converted to an interface
c.ir.AnalyseFunctionPointers() // determine which function pointer signatures need context
c.ir.AnalyseBlockingRecursive() // make all parents of blocking calls blocking (transitively)
c.ir.AnalyseGoCalls() // check whether we need a scheduler
@ -542,8 +543,18 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
}
paramTypes = append(paramTypes, subType)
}
// make a function pointer of it
return llvm.PointerType(llvm.FunctionType(returnType, paramTypes, false), 0), nil
var ptr llvm.Type
if c.ir.SignatureNeedsContext(typ) {
// make a closure type (with a function pointer type inside):
// {context, funcptr}
paramTypes = append(paramTypes, c.i8ptrType)
ptr = llvm.PointerType(llvm.FunctionType(returnType, paramTypes, false), 0)
ptr = c.ctx.StructType([]llvm.Type{c.i8ptrType, ptr}, false)
} else {
// make a simple function pointer
ptr = llvm.PointerType(llvm.FunctionType(returnType, paramTypes, false), 0)
}
return ptr, nil
case *types.Slice:
elemType, err := c.getLLVMType(typ.Elem())
if err != nil {
@ -703,6 +714,12 @@ func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) {
frame.params[param] = i
}
if c.ir.FunctionNeedsContext(f) {
// This function gets an extra parameter: the context pointer (for
// closures and bound methods). Add it as an extra parameter here.
paramTypes = append(paramTypes, c.i8ptrType)
}
fnType := llvm.FunctionType(retType, paramTypes, false)
name := f.LinkName()
@ -781,7 +798,13 @@ func (c *Compiler) getInterpretedValue(value Value) (llvm.Value, error) {
}
return getZeroValue(llvmType)
}
return c.ir.GetFunction(value.Elem).llvmFn, nil
fn := c.ir.GetFunction(value.Elem)
ptr := fn.llvmFn
if c.ir.SignatureNeedsContext(fn.fn.Signature) {
// Create closure value: {context, function pointer}
ptr = llvm.ConstStruct([]llvm.Value{llvm.ConstPointerNull(c.i8ptrType), ptr}, false)
}
return ptr, nil
case *GlobalValue:
zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
@ -1001,6 +1024,54 @@ func (c *Compiler) parseFunc(frame *Frame) error {
frame.locals[param] = llvmParam
}
// Load free variables from the context. This is a closure (or bound
// method).
if len(frame.fn.fn.FreeVars) != 0 {
if !c.ir.FunctionNeedsContext(frame.fn) {
panic("free variables on function without context")
}
c.builder.SetInsertPointAtEnd(frame.blocks[frame.fn.fn.Blocks[0]])
context := frame.fn.llvmFn.Param(len(frame.fn.fn.Params))
// Determine the context type. It's a struct containing all variables.
freeVarTypes := make([]llvm.Type, 0, len(frame.fn.fn.FreeVars))
for _, freeVar := range frame.fn.fn.FreeVars {
typ, err := c.getLLVMType(freeVar.Type())
if err != nil {
return err
}
freeVarTypes = append(freeVarTypes, typ)
}
contextType := llvm.StructType(freeVarTypes, false)
// Get a correctly-typed pointer to the context.
contextAlloc := llvm.Value{}
if c.targetData.TypeAllocSize(contextType) <= c.targetData.TypeAllocSize(c.i8ptrType) {
// Context stored directly in pointer. Load it using an alloca.
contextRawAlloc := c.builder.CreateAlloca(llvm.PointerType(c.i8ptrType, 0), "")
contextRawValue := c.builder.CreateBitCast(context, llvm.PointerType(c.i8ptrType, 0), "")
c.builder.CreateStore(contextRawValue, contextRawAlloc)
contextAlloc = c.builder.CreateBitCast(contextRawAlloc, llvm.PointerType(contextType, 0), "")
} else {
// Context stored in the heap. Bitcast the passed-in pointer to the
// correct pointer type.
contextAlloc = c.builder.CreateBitCast(context, llvm.PointerType(contextType, 0), "")
}
// Load each free variable from the context.
// A free variable is always a pointer when this is a closure, but it
// can be another type when it is a wrapper for a bound method (these
// wrappers are generated by the ssa package).
for i, freeVar := range frame.fn.fn.FreeVars {
indices := []llvm.Value{
llvm.ConstInt(llvm.Int32Type(), 0, false),
llvm.ConstInt(llvm.Int32Type(), uint64(i), false),
}
gep := c.builder.CreateInBoundsGEP(contextAlloc, indices, "")
frame.locals[freeVar] = c.builder.CreateLoad(gep, "")
}
}
if frame.blocking {
// Coroutine initialization.
c.builder.SetInsertPointAtEnd(frame.blocks[frame.fn.fn.Blocks[0]])
@ -1372,7 +1443,7 @@ func (c *Compiler) parseBuiltin(frame *Frame, args []ssa.Value, callName string)
}
}
func (c *Compiler) parseFunctionCall(frame *Frame, args []ssa.Value, llvmFn llvm.Value, blocking bool, parentHandle llvm.Value) (llvm.Value, error) {
func (c *Compiler) parseFunctionCall(frame *Frame, args []ssa.Value, llvmFn, context llvm.Value, blocking bool, parentHandle llvm.Value) (llvm.Value, error) {
var params []llvm.Value
if blocking {
if parentHandle.IsNil() {
@ -1391,6 +1462,12 @@ func (c *Compiler) parseFunctionCall(frame *Frame, args []ssa.Value, llvmFn llvm
params = append(params, val)
}
if !context.IsNil() {
// This function takes a context parameter.
// Add it to the end of the parameter list.
params = append(params, context)
}
if frame.blocking && llvmFn.Name() == "runtime.Sleep" {
// Set task state to TASK_STATE_SLEEP and set the duration.
c.builder.CreateCall(c.mod.NamedFunction("runtime.sleepTask"), []llvm.Value{frame.taskHandle, params[0]}, "")
@ -1443,10 +1520,22 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l
if err != nil {
return llvm.Value{}, err
}
llvmFnType, err := c.getLLVMType(instr.Method.Type())
if err != nil {
return llvm.Value{}, err
}
if c.ir.SignatureNeedsContext(instr.Method.Type().(*types.Signature)) {
// This is somewhat of a hack.
// getLLVMType() has created a closure type for us, but we don't
// actually want a closure type as an interface call can never be a
// closure call. So extract the function pointer type from the
// closure.
// This happens because somewhere the same function signature is
// used in a closure or bound method.
llvmFnType = llvmFnType.Subtypes()[1]
}
values := []llvm.Value{
itf,
llvm.ConstInt(llvm.Int16Type(), uint64(c.ir.MethodNum(instr.Method)), false),
@ -1454,6 +1543,7 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l
fn := c.builder.CreateCall(c.mod.NamedFunction("runtime.interfaceMethod"), values, "invoke.func")
fnCast := c.builder.CreateBitCast(fn, llvmFnType, "invoke.func.cast")
receiverValue := c.builder.CreateExtractValue(itf, 1, "invoke.func.receiver")
args := []llvm.Value{receiverValue}
for _, arg := range instr.Args {
val, err := c.parseExpr(frame, arg)
@ -1462,16 +1552,21 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l
}
args = append(args, val)
}
if c.ir.SignatureNeedsContext(instr.Method.Type().(*types.Signature)) {
// This function takes an extra context parameter. An interface call
// cannot also be a closure but we have to supply the nil pointer
// anyway.
args = append(args, llvm.ConstPointerNull(c.i8ptrType))
}
// TODO: blocking methods (needs analysis)
return c.builder.CreateCall(fnCast, args, ""), nil
}
// Regular function, builtin, or function pointer.
switch call := instr.Value.(type) {
case *ssa.Builtin:
return c.parseBuiltin(frame, instr.Args, call.Name())
case *ssa.Function:
if call.Name() == "Asm" && len(instr.Args) == 1 {
// Try to call the function directly for trivially static calls.
fn := instr.StaticCallee()
if fn != nil {
if fn.Name() == "Asm" && len(instr.Args) == 1 {
// Magic function: insert inline assembly instead of calling it.
if named, ok := instr.Args[0].Type().(*types.Named); ok && named.Obj().Name() == "__asm" {
fnType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{}, false)
@ -1480,20 +1575,52 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l
return c.builder.CreateCall(target, nil, ""), nil
}
}
targetFunc := c.ir.GetFunction(call)
name := targetFunc.LinkName()
llvmFn := c.mod.NamedFunction(name)
if llvmFn.IsNil() {
return llvm.Value{}, errors.New("undefined function: " + name)
targetFunc := c.ir.GetFunction(fn)
if targetFunc.llvmFn.IsNil() {
return llvm.Value{}, errors.New("undefined function: " + targetFunc.LinkName())
}
return c.parseFunctionCall(frame, instr.Args, llvmFn, targetFunc.blocking, parentHandle)
var context llvm.Value
if c.ir.FunctionNeedsContext(targetFunc) {
// This function call is to a (potential) closure, not a regular
// function. See whether it is a closure and if so, call it as such.
// Else, supply a dummy nil pointer as the last parameter.
var err error
if mkClosure, ok := instr.Value.(*ssa.MakeClosure); ok {
// closure is {context, function pointer}
closure, err := c.parseExpr(frame, mkClosure)
if err != nil {
return llvm.Value{}, err
}
context = c.builder.CreateExtractValue(closure, 0, "")
} else {
context, err = getZeroValue(c.i8ptrType)
if err != nil {
return llvm.Value{}, err
}
}
}
return c.parseFunctionCall(frame, instr.Args, targetFunc.llvmFn, context, targetFunc.blocking, parentHandle)
}
// Builtin or function pointer.
switch call := instr.Value.(type) {
case *ssa.Builtin:
return c.parseBuiltin(frame, instr.Args, call.Name())
default: // function pointer
value, err := c.parseExpr(frame, instr.Value)
if err != nil {
return llvm.Value{}, err
}
// TODO: blocking function pointers (needs analysis)
return c.parseFunctionCall(frame, instr.Args, value, false, parentHandle)
var context llvm.Value
if c.ir.SignatureNeedsContext(instr.Signature()) {
// 'value' is a closure, not a raw function pointer.
// Extract the function pointer and the context pointer.
// closure: {context, function pointer}
context = c.builder.CreateExtractValue(value, 0, "")
value = c.builder.CreateExtractValue(value, 1, "")
}
return c.parseFunctionCall(frame, instr.Args, value, context, false, parentHandle)
}
}
@ -1583,7 +1710,17 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
}
return c.builder.CreateGEP(val, indices, ""), nil
case *ssa.Function:
return c.mod.NamedFunction(c.ir.GetFunction(expr).LinkName()), nil
fn := c.ir.GetFunction(expr)
ptr := fn.llvmFn
if c.ir.FunctionNeedsContext(fn) {
// Create closure for function pointer.
// Closure is: {context, function pointer}
ptr = llvm.ConstStruct([]llvm.Value{
llvm.ConstPointerNull(c.i8ptrType),
ptr,
}, false)
}
return ptr, nil
case *ssa.Global:
if strings.HasPrefix(expr.Name(), "__cgofn__cgo_") || strings.HasPrefix(expr.Name(), "_cgo_") {
// Ignore CGo global variables which we don't use.
@ -1731,6 +1868,10 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
default:
panic("unknown lookup type: " + expr.String())
}
case *ssa.MakeClosure:
return c.parseMakeClosure(frame, expr)
case *ssa.MakeInterface:
val, err := c.parseExpr(frame, expr.X)
if err != nil {
@ -2249,6 +2390,85 @@ func (c *Compiler) parseConvert(typeFrom, typeTo types.Type, value llvm.Value) (
}
}
// parseMakeClosure lowers an ssa.MakeClosure instruction to a closure value,
// which is a struct of the form {context, function pointer}.
//
// The bound variables are packed into a struct. When that struct is no bigger
// than a pointer it is stored directly inside the context pointer itself.
// Otherwise it is stored in memory obtained from c.allocFunc and the context
// is the pointer to that memory.
//
// It panics when the instruction has no bindings, or when the target function
// has not been marked as needing a context parameter. Errors from parseExpr,
// getLLVMType and getZeroValue are returned unchanged.
func (c *Compiler) parseMakeClosure(frame *Frame, expr *ssa.MakeClosure) (llvm.Value, error) {
	if len(expr.Bindings) == 0 {
		panic("unexpected: MakeClosure without bound variables")
	}

	f := c.ir.GetFunction(expr.Fn.(*ssa.Function))
	if !c.ir.FunctionNeedsContext(f) {
		// Maybe AnalyseFunctionPointers didn't run?
		panic("MakeClosure on function signature without context")
	}

	// Collect all bound variables. The context stores exactly these values,
	// in binding order.
	boundVars := make([]llvm.Value, 0, len(expr.Bindings))
	boundVarTypes := make([]llvm.Type, 0, len(expr.Bindings))
	for _, binding := range expr.Bindings {
		llvmBoundVar, err := c.parseExpr(frame, binding)
		if err != nil {
			return llvm.Value{}, err
		}
		boundVars = append(boundVars, llvmBoundVar)
		boundVarTypes = append(boundVarTypes, llvmBoundVar.Type())
	}
	contextType := llvm.StructType(boundVarTypes, false)
	contextSize := c.targetData.TypeAllocSize(contextType)
	fitsInPointer := contextSize <= c.targetData.TypeAllocSize(c.i8ptrType)

	// Allocate memory for the context.
	var (
		contextAlloc     llvm.Value // *contextType used to store the bound variables
		packedAlloc      llvm.Value // pointer-sized stack slot, only set when fitsInPointer
		contextHeapAlloc llvm.Value // raw *i8 allocation, only set when !fitsInPointer
	)
	if fitsInPointer {
		// Context fits in a pointer - e.g. when it is a pointer. Pack it
		// into the pointer value itself (store + bitcast + load).
		//
		// The stack slot must be as big as a pointer, not as big as the
		// context: the final load below reads a whole pointer, and the
		// context may be smaller than that. Zero the slot first so that the
		// bytes not covered by the context are defined.
		packedAlloc = c.builder.CreateAlloca(c.i8ptrType, "")
		c.builder.CreateStore(llvm.ConstPointerNull(c.i8ptrType), packedAlloc)
		contextAlloc = c.builder.CreateBitCast(packedAlloc, llvm.PointerType(contextType, 0), "")
	} else {
		// Context is bigger than a pointer, so allocate it on the heap.
		sizeValue := llvm.ConstInt(c.uintptrType, contextSize, false)
		contextHeapAlloc = c.builder.CreateCall(c.allocFunc, []llvm.Value{sizeValue}, "")
		contextAlloc = c.builder.CreateBitCast(contextHeapAlloc, llvm.PointerType(contextType, 0), "")
	}

	// Store all bound variables in the stack slot or heap allocation.
	for i, boundVar := range boundVars {
		indices := []llvm.Value{
			llvm.ConstInt(llvm.Int32Type(), 0, false),
			llvm.ConstInt(llvm.Int32Type(), uint64(i), false),
		}
		gep := c.builder.CreateInBoundsGEP(contextAlloc, indices, "")
		c.builder.CreateStore(boundVar, gep)
	}

	var context llvm.Value
	if fitsInPointer {
		// Load the packed value (as *i8) from the pointer-sized slot.
		context = c.builder.CreateLoad(packedAlloc, "")
	} else {
		// The original heap allocation pointer already is an *i8.
		context = contextHeapAlloc
	}

	// Get the function signature type, which is a closure type.
	// A closure is a tuple of {context, function pointer}.
	typ, err := c.getLLVMType(f.fn.Signature)
	if err != nil {
		return llvm.Value{}, err
	}

	// Create the closure, which is a struct: {context, function pointer}.
	closure, err := getZeroValue(typ)
	if err != nil {
		return llvm.Value{}, err
	}
	closure = c.builder.CreateInsertValue(closure, f.llvmFn, 1, "")
	closure = c.builder.CreateInsertValue(closure, context, 0, "")
	return closure, nil
}
func (c *Compiler) parseMakeInterface(val llvm.Value, typ types.Type, isConst bool) (llvm.Value, error) {
var itfValue llvm.Value
size := c.targetData.TypeAllocSize(val.Type())

10
ir.go
Просмотреть файл

@ -25,9 +25,10 @@ type Program struct {
NamedTypes []*NamedType
needsScheduler bool
goCalls []*ssa.Go
typesWithMethods map[string]*InterfaceType
typesWithoutMethods map[string]int
typesWithMethods map[string]*InterfaceType // see AnalyseInterfaceConversions
typesWithoutMethods map[string]int // see AnalyseInterfaceConversions
methodSignatureNames map[string]int
fpWithContext map[string]struct{} // see AnalyseFunctionPointers
}
// Function or method.
@ -37,8 +38,9 @@ type Function struct {
linkName string
blocking bool
flag bool // used by dead code elimination
addressTaken bool // used as function pointer, calculated by AnalyseFunctionPointers
parents []*Function // calculated by AnalyseCallgraph
children []*Function
children []*Function // calculated by AnalyseCallgraph
}
// Global variable, possibly constant.
@ -117,9 +119,9 @@ func (p *Program) AddPackage(pkg *ssa.Package) {
func (p *Program) addFunction(ssaFn *ssa.Function) {
f := &Function{fn: ssaFn}
f.parsePragmas()
p.Functions = append(p.Functions, f)
p.functionMap[ssaFn] = f
f.parsePragmas()
for _, anon := range ssaFn.AnonFuncs {
p.addFunction(anon)

110
passes.go
Просмотреть файл

@ -5,46 +5,55 @@ import (
"golang.org/x/tools/go/ssa"
)
// This function implements several optimization passes (analysis + transform)
// to optimize code in SSA form before it is compiled to LLVM IR. It is based on
// This file implements several optimization passes (analysis + transform) to
// optimize code in SSA form before it is compiled to LLVM IR. It is based on
// the IR defined in ir.go.
// Make a readable version of the method signature (including the function name,
// Make a readable version of a method signature (including the function name,
// excluding the receiver name). This string is used internally to match
// interfaces and to call the correct method on an interface. Examples:
//
// String() string
// Read([]byte) (int, error)
func MethodName(method *types.Func) string {
sig := method.Type().(*types.Signature)
name := method.Name()
func MethodSignature(method *types.Func) string {
return method.Name() + Signature(method.Type().(*types.Signature))
}
// Make a readable version of a function (pointer) signature. This string is
// used internally to match signatures (like in AnalyseFunctionPointers).
// Examples:
//
// () string
// (string, int) (int, error)
func Signature(sig *types.Signature) string {
s := ""
if sig.Params().Len() == 0 {
name += "()"
s += "()"
} else {
name += "("
s += "("
for i := 0; i < sig.Params().Len(); i++ {
if i > 0 {
name += ", "
s += ", "
}
name += sig.Params().At(i).Type().String()
s += sig.Params().At(i).Type().String()
}
name += ")"
s += ")"
}
if sig.Results().Len() == 0 {
// keep as-is
} else if sig.Results().Len() == 1 {
name += " " + sig.Results().At(0).Type().String()
s += " " + sig.Results().At(0).Type().String()
} else {
name += " ("
s += " ("
for i := 0; i < sig.Results().Len(); i++ {
if i > 0 {
name += ", "
s += ", "
}
name += sig.Results().At(i).Type().String()
s += sig.Results().At(i).Type().String()
}
name += ")"
s += ")"
}
return name
return s
}
// Fill in parents of all functions.
@ -111,7 +120,7 @@ func (p *Program) AnalyseInterfaceConversions() {
Methods: make(map[string]*types.Selection),
}
for _, sel := range methods {
name := MethodName(sel.Obj().(*types.Func))
name := MethodSignature(sel.Obj().(*types.Func))
t.Methods[name] = sel
}
p.typesWithMethods[name] = t
@ -124,6 +133,49 @@ func (p *Program) AnalyseInterfaceConversions() {
}
}
// AnalyseFunctionPointers determines which function pointer signatures need a
// context parameter.
//
// It walks every instruction of every function and does two things:
//
//   - it marks each function that is passed as a call argument, or used as an
//     operand of a non-call instruction, as having its address taken;
//   - it records the signature of every function wrapped by a MakeClosure
//     instruction in fpWithContext.
//
// fpWithContext is reset on every run; addressTaken flags are only ever set.
func (p *Program) AnalyseFunctionPointers() {
	// Clear, if AnalyseFunctionPointers has been called before.
	p.fpWithContext = map[string]struct{}{}

	for _, caller := range p.Functions {
		for _, block := range caller.fn.Blocks {
			for _, instr := range block.Instrs {
				switch v := instr.(type) {
				case ssa.CallInstruction:
					// Only the arguments count here: the callee itself is
					// not treated as an address-taken function.
					for _, arg := range v.Common().Args {
						if fn, ok := arg.(*ssa.Function); ok {
							p.GetFunction(fn).addressTaken = true
						}
					}
				case *ssa.DebugRef:
					// Debug references are ignored.
				default:
					// For anything that isn't a call...
					for _, operand := range v.Operands(nil) {
						if operand == nil || *operand == nil || isCGoInternal((*operand).Name()) {
							continue
						}
						if fn, ok := (*operand).(*ssa.Function); ok {
							p.GetFunction(fn).addressTaken = true
						}
					}
				}

				// Signatures that are used for a closure need a context.
				if mkClosure, ok := instr.(*ssa.MakeClosure); ok {
					target := mkClosure.Fn.(*ssa.Function)
					p.fpWithContext[Signature(target.Signature)] = struct{}{}
				}
			}
		}
	}
}
// Analyse which functions are recursively blocking.
//
// Depends on AnalyseCallgraph.
@ -233,6 +285,12 @@ func (p *Program) SimpleDCE() {
switch operand := (*operand).(type) {
case *ssa.Function:
f := p.GetFunction(operand)
if f == nil {
// FIXME HACK: this function should have been
// discovered already. It is not for bound methods.
p.addFunction(operand)
f = p.GetFunction(operand)
}
if !f.flag {
f.flag = true
worklist = append(worklist, operand)
@ -306,11 +364,11 @@ func (p *Program) TypeNum(typ types.Type) (int, bool) {
// MethodNum returns the numeric ID of this method, to be used in method lookups
// on interfaces for example.
func (p *Program) MethodNum(method *types.Func) int {
name := MethodName(method)
name := MethodSignature(method)
if _, ok := p.methodSignatureNames[name]; !ok {
p.methodSignatureNames[name] = len(p.methodSignatureNames)
}
return p.methodSignatureNames[MethodName(method)]
return p.methodSignatureNames[MethodSignature(method)]
}
// The start index of the first dynamic type that has methods.
@ -330,3 +388,15 @@ func (p *Program) AllDynamicTypes() []*InterfaceType {
}
return l
}
// FunctionNeedsContext reports whether f takes an extra context parameter.
// That is the case only when the address of f is taken and its signature is
// one of the signatures that need a context (see SignatureNeedsContext).
func (p *Program) FunctionNeedsContext(f *Function) bool {
	return f.addressTaken && p.SignatureNeedsContext(f.fn.Signature)
}
// SignatureNeedsContext reports whether the readable form of sig has been
// recorded in fpWithContext.
func (p *Program) SignatureNeedsContext(sig *types.Signature) bool {
	key := Signature(sig)
	if _, found := p.fpWithContext[key]; found {
		return true
	}
	return false
}

Просмотреть файл

@ -27,15 +27,18 @@ func main() {
println("sumrange(100) =", sumrange(100))
println("strlen foo:", strlen("foo"))
// map
m := map[string]int{"answer": 42, "foo": 3}
readMap(m, "answer")
readMap(testmap, "data")
// slice
foo := []int{1, 2, 4, 5}
println("len/cap foo:", len(foo), cap(foo))
println("foo[3]:", foo[3])
println("sum foo:", sum(foo))
// interfaces, pointers
thing := &Thing{"foo"}
println("thing:", thing.String())
printItf(5)
@ -47,8 +50,16 @@ func main() {
s := Stringer(thing)
println("Stringer.String():", s.String())
// unusual calls
runFunc(hello, 5) // must be indirect to avoid obvious inlining
testDefer()
testBound(thing.String)
func() {
println("thing inside closure:", thing.String()) //, len(foo))
}()
runFunc(func(i int) {
println("inside fp closure:", thing.String(), i)
}, 3)
// test library functions
println("lower to upper char:", 'h', "->", unicode.ToUpper('h'))
@ -70,6 +81,10 @@ func deferred(msg string, i int) {
println(msg, i)
}
// testBound calls f once and prints its result. It takes a function value so
// that a bound method can be passed in.
func testBound(f func() string) {
	result := f()
	println("bound method:", result)
}
func readMap(m map[string]int, key string) {
println("map length:", len(m))
println("map read:", key, "=", m[key])