diff --git a/compiler/compiler.go b/compiler/compiler.go index bc3247e4..731d0948 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -41,29 +41,30 @@ type Config struct { type Compiler struct { Config - mod llvm.Module - ctx llvm.Context - builder llvm.Builder - dibuilder *llvm.DIBuilder - cu llvm.Metadata - difiles map[string]llvm.Metadata - ditypes map[string]llvm.Metadata - machine llvm.TargetMachine - targetData llvm.TargetData - intType llvm.Type - i8ptrType llvm.Type // for convenience - uintptrType llvm.Type - lenType llvm.Type - coroIdFunc llvm.Value - coroSizeFunc llvm.Value - coroBeginFunc llvm.Value - coroSuspendFunc llvm.Value - coroEndFunc llvm.Value - coroFreeFunc llvm.Value - initFuncs []llvm.Value - deferFuncs []*ir.Function - ctxDeferFuncs []ContextDeferFunction - ir *ir.Program + mod llvm.Module + ctx llvm.Context + builder llvm.Builder + dibuilder *llvm.DIBuilder + cu llvm.Metadata + difiles map[string]llvm.Metadata + ditypes map[string]llvm.Metadata + machine llvm.TargetMachine + targetData llvm.TargetData + intType llvm.Type + i8ptrType llvm.Type // for convenience + uintptrType llvm.Type + lenType llvm.Type + coroIdFunc llvm.Value + coroSizeFunc llvm.Value + coroBeginFunc llvm.Value + coroSuspendFunc llvm.Value + coroEndFunc llvm.Value + coroFreeFunc llvm.Value + initFuncs []llvm.Value + deferFuncs []*ir.Function + deferInvokeFuncs []InvokeDeferFunction + ctxDeferFuncs []ContextDeferFunction + ir *ir.Program } type Frame struct { @@ -93,6 +94,12 @@ type ContextDeferFunction struct { signature *types.Signature } +// A thunk for a defer that defers calling an interface method. +type InvokeDeferFunction struct { + method *types.Func + valueTypes []llvm.Type +} + func NewCompiler(pkgName string, config Config) (*Compiler, error) { if config.Triple == "" { config.Triple = llvm.DefaultTargetTriple() @@ -373,6 +380,41 @@ func (c *Compiler) Compile(mainPath string) error { c.builder.CreateRetVoid() } + // Create wrapper for deferred interface call. + for _, thunk := range c.deferInvokeFuncs { + // This function gets a single parameter which is a pointer to a struct + // (the defer frame). + // This struct starts with the values of runtime._defer, but after that + // follow the real function parameters. + // The job of this wrapper is to extract these parameters and to call + // the real function with them. + llvmFn := c.mod.NamedFunction(thunk.method.FullName() + "$defer") + llvmFn.SetLinkage(llvm.InternalLinkage) + llvmFn.SetUnnamedAddr(true) + entry := c.ctx.AddBasicBlock(llvmFn, "entry") + c.builder.SetInsertPointAtEnd(entry) + deferRawPtr := llvmFn.Param(0) + + // Get the real param type and cast to it. + deferFrameType := c.ctx.StructType(thunk.valueTypes, false) + deferFramePtr := c.builder.CreateBitCast(deferRawPtr, llvm.PointerType(deferFrameType, 0), "deferFrame") + + // Extract the params from the struct. + forwardParams := []llvm.Value{} + zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) + for i := range thunk.valueTypes[3:] { + gep := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), uint64(i+3), false)}, "gep") + forwardParam := c.builder.CreateLoad(gep, "param") + forwardParams = append(forwardParams, forwardParam) + } + + // Call real function (of which this is a wrapper). + fnGEP := c.builder.CreateGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(c.ctx.Int32Type(), 2, false)}, "fn.gep") + fn := c.builder.CreateLoad(fnGEP, "fn") + c.createCall(fn, forwardParams, "") + c.builder.CreateRetVoid() + } + // Create wrapper for deferred function pointer call. for _, thunk := range c.ctxDeferFuncs { // This function gets a single parameter which is a pointer to a struct @@ -1563,7 +1605,37 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { var values []llvm.Value var valueTypes []llvm.Type - if callee, ok := instr.Call.Value.(*ssa.Function); ok && !instr.Call.IsInvoke() { + if instr.Call.IsInvoke() { + // Function call on an interface. + fnPtr, args, err := c.getInvokeCall(frame, &instr.Call) + if err != nil { + return err + } + + valueTypes = []llvm.Type{llvm.PointerType(deferFuncType, 0), next.Type(), fnPtr.Type()} + for _, param := range args { + valueTypes = append(valueTypes, param.Type()) + } + + // Create a thunk. + deferName := instr.Call.Method.FullName() + "$defer" + callback := c.mod.NamedFunction(deferName) + if callback.IsNil() { + // Not found, have to add it. + callback = llvm.AddFunction(c.mod, deferName, deferFuncType) + thunk := InvokeDeferFunction{ + method: instr.Call.Method, + valueTypes: valueTypes, + } + c.deferInvokeFuncs = append(c.deferInvokeFuncs, thunk) + } + + // Collect all values to be put in the struct (starting with + // runtime._defer fields, followed by the function pointer to be + // called). + values = append([]llvm.Value{callback, next, fnPtr}, args...) + + } else if callee, ok := instr.Call.Value.(*ssa.Function); ok { // Regular function call. fn := c.ir.GetFunction(callee) @@ -1588,6 +1660,7 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { values = append(values, llvmParam) valueTypes = append(valueTypes, llvmParam.Type()) } + } else if makeClosure, ok := instr.Call.Value.(*ssa.MakeClosure); ok { // Immediately applied function literal with free variables. closure, err := c.parseExpr(frame, instr.Call.Value) @@ -1618,6 +1691,7 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { makeClosure.Fn.(*ssa.Function).Signature, } c.ctxDeferFuncs = append(c.ctxDeferFuncs, thunk) + } else { return errors.New("todo: defer on uncommon function call type") } @@ -2035,51 +2109,11 @@ func (c *Compiler) parseFunctionCall(frame *Frame, args []ssa.Value, llvmFn, con func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle llvm.Value) (llvm.Value, error) { if instr.IsInvoke() { - // Call an interface method with dynamic dispatch. - itf, err := c.parseExpr(frame, instr.Value) // interface - if err != nil { - return llvm.Value{}, err - } - - llvmFnType, err := c.getLLVMType(instr.Method.Type()) - if err != nil { - return llvm.Value{}, err - } - if c.ir.SignatureNeedsContext(instr.Method.Type().(*types.Signature)) { - // This is somewhat of a hack. - // getLLVMType() has created a closure type for us, but we don't - // actually want a closure type as an interface call can never be a - // closure call. So extract the function pointer type from the - // closure. - // This happens because somewhere the same function signature is - // used in a closure or bound method. - llvmFnType = llvmFnType.Subtypes()[1] - } - - values := []llvm.Value{ - itf, - llvm.ConstInt(c.ctx.Int16Type(), uint64(c.ir.MethodNum(instr.Method)), false), - } - fn := c.createRuntimeCall("interfaceMethod", values, "invoke.func") - fnCast := c.builder.CreateBitCast(fn, llvmFnType, "invoke.func.cast") - receiverValue := c.builder.CreateExtractValue(itf, 1, "invoke.func.receiver") - - args := []llvm.Value{receiverValue} - for _, arg := range instr.Args { - val, err := c.parseExpr(frame, arg) - if err != nil { - return llvm.Value{}, err - } - args = append(args, val) - } - if c.ir.SignatureNeedsContext(instr.Method.Type().(*types.Signature)) { - // This function takes an extra context parameter. An interface call - // cannot also be a closure but we have to supply the nil pointer - // anyway. - args = append(args, llvm.ConstPointerNull(c.i8ptrType)) - } - // TODO: blocking methods (needs analysis) + fnCast, args, err := c.getInvokeCall(frame, instr) + if err != nil { + return llvm.Value{}, err + } return c.createCall(fnCast, args, ""), nil } @@ -2209,6 +2243,56 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l } } +// getInvokeCall creates and returns the function pointer and parameters of an +// interface call. It can be used in a call or defer instruction. +func (c *Compiler) getInvokeCall(frame *Frame, instr *ssa.CallCommon) (llvm.Value, []llvm.Value, error) { + // Call an interface method with dynamic dispatch. + itf, err := c.parseExpr(frame, instr.Value) // interface + if err != nil { + return llvm.Value{}, nil, err + } + + llvmFnType, err := c.getLLVMType(instr.Method.Type()) + if err != nil { + return llvm.Value{}, nil, err + } + if c.ir.SignatureNeedsContext(instr.Method.Type().(*types.Signature)) { + // This is somewhat of a hack. + // getLLVMType() has created a closure type for us, but we don't + // actually want a closure type as an interface call can never be a + // closure call. So extract the function pointer type from the + // closure. + // This happens because somewhere the same function signature is + // used in a closure or bound method. + llvmFnType = llvmFnType.Subtypes()[1] + } + + values := []llvm.Value{ + itf, + llvm.ConstInt(c.ctx.Int16Type(), uint64(c.ir.MethodNum(instr.Method)), false), + } + fn := c.createRuntimeCall("interfaceMethod", values, "invoke.func") + fnCast := c.builder.CreateBitCast(fn, llvmFnType, "invoke.func.cast") + receiverValue := c.builder.CreateExtractValue(itf, 1, "invoke.func.receiver") + + args := []llvm.Value{receiverValue} + for _, arg := range instr.Args { + val, err := c.parseExpr(frame, arg) + if err != nil { + return llvm.Value{}, nil, err + } + args = append(args, val) + } + if c.ir.SignatureNeedsContext(instr.Method.Type().(*types.Signature)) { + // This function takes an extra context parameter. An interface call + // cannot also be a closure but we have to supply the nil pointer + // anyway. + args = append(args, llvm.ConstPointerNull(c.i8ptrType)) + } + + return fnCast, args, nil +} + func (c *Compiler) emitBoundsCheck(frame *Frame, arrayLen, index llvm.Value, indexType types.Type) { if frame.fn.IsNoBounds() { // The //go:nobounds pragma was added to the function to avoid bounds diff --git a/testdata/calls.go b/testdata/calls.go index 93982fb8..6b4cd6eb 100644 --- a/testdata/calls.go +++ b/testdata/calls.go @@ -8,6 +8,14 @@ func (t Thing) String() string { return t.name } +func (t Thing) Print(arg string) { + println("Thing.Print:", t.name, "arg:", arg) +} + +type Printer interface { + Print(string) +} + func main() { thing := &Thing{"foo"} @@ -49,6 +57,9 @@ func testDefer() { defer deferred("...run as defer", i) i++ + var t Printer = &Thing{"foo"} + defer t.Print("bar") + println("deferring...") } diff --git a/testdata/calls.txt b/testdata/calls.txt index 902b5b4c..6998aebf 100644 --- a/testdata/calls.txt +++ b/testdata/calls.txt @@ -1,5 +1,6 @@ hello from function pointer: 5 deferring... +Thing.Print: foo arg: bar ...run as defer 3 ...run closure deferred: 4 ...run as defer 1