diff --git a/compiler/interface-lowering.go b/compiler/interface-lowering.go index cab5ee33..3df1c644 100644 --- a/compiler/interface-lowering.go +++ b/compiler/interface-lowering.go @@ -527,11 +527,40 @@ func (p *lowerInterfacesPass) replaceInvokeWithCall(use llvm.Value, typ *typeInf } inttoptr := inttoptrs[0] function := typ.getMethod(signature).function - if inttoptr.Type() != function.Type() { - p.builder.SetInsertPointBefore(use) - function = p.builder.CreateBitCast(function, inttoptr.Type(), "") + if inttoptr.Type() == function.Type() { + // Easy case: the types are the same. Simply replace the inttoptr + // result (which is directly called) with the actual function. + inttoptr.ReplaceAllUsesWith(function) + } else { + // Harder case: the type is not actually the same. Go through each call + // (of which there should be only one), extract the receiver params for + // this call and replace the call with a direct call to the target + // function. + for _, call := range getUses(inttoptr) { + if call.IsACallInst().IsNil() || call.CalledValue() != inttoptr { + panic("expected the inttoptr to be called as a method, this is not a method call") + } + operands := make([]llvm.Value, call.OperandsCount()-1) + for i := range operands { + operands[i] = call.Operand(i) + } + paramTypes := function.Type().ElementType().ParamTypes() + receiverParamTypes := paramTypes[:len(paramTypes)-(len(operands)-1)] + methodParamTypes := paramTypes[len(paramTypes)-(len(operands)-1):] + for i, methodParamType := range methodParamTypes { + if methodParamType != operands[i+1].Type() { + panic("expected method call param type and function param type to be the same") + } + } + p.builder.SetInsertPointBefore(call) + receiverParams := p.emitPointerUnpack(operands[0], receiverParamTypes) + result := p.builder.CreateCall(function, append(receiverParams, operands[1:]...), "") + if result.Type().TypeKind() != llvm.VoidTypeKind { + call.ReplaceAllUsesWith(result) + } + call.EraseFromParentAsInstruction() + } } - inttoptr.ReplaceAllUsesWith(function) inttoptr.EraseFromParentAsInstruction() use.EraseFromParentAsInstruction() } diff --git a/testdata/coroutines.go b/testdata/coroutines.go index 68a1ae24..0247c862 100644 --- a/testdata/coroutines.go +++ b/testdata/coroutines.go @@ -21,6 +21,10 @@ func main() { go nowait() time.Sleep(time.Millisecond) println("done with non-blocking goroutine") + + var printer Printer + printer = &myPrinter{} + printer.Print() } func sub() { @@ -38,3 +42,15 @@ func wait() { func nowait() { println("non-blocking goroutine") } + +type Printer interface { + Print() +} + +type myPrinter struct{ +} + +func (i *myPrinter) Print() { + time.Sleep(time.Millisecond) + println("async interface method call") +} diff --git a/testdata/coroutines.txt b/testdata/coroutines.txt index d5e2d74f..2e3db6df 100644 --- a/testdata/coroutines.txt +++ b/testdata/coroutines.txt @@ -9,3 +9,4 @@ wait: end waiting non-blocking goroutine done with non-blocking goroutine +async interface method call