diff --git a/compiler/compiler.go b/compiler/compiler.go index 6ead097f..6e64f021 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -419,17 +419,9 @@ func (c *Compiler) Compile(mainPath string) error { if f.LLVMFn.IsNil() { return errors.New("cannot find function: " + f.LinkName()) } - fn := f.LLVMFn - receiverType := fn.Type().ElementType().ParamTypes()[0] - if c.targetData.TypeAllocSize(receiverType) > c.targetData.TypeAllocSize(c.i8ptrType) { - // The receiver value doesn't fit in a pointer. This means that - // the interface contains a pointer to the receiver value - // instead of the value itself. Which means we have to create a - // wrapper function. - fn, err = c.wrapInterfaceInvoke(f) - if err != nil { - return err - } + fn, err := c.wrapInterfaceInvoke(f) + if err != nil { + return err } fnPtr := llvm.ConstBitCast(fn, c.i8ptrType) funcPointers = append(funcPointers, fnPtr) @@ -771,12 +763,25 @@ func (c *Compiler) getDIType(typ types.Type) (llvm.Metadata, error) { // wrapper is only needed when the interface value actually doesn't fit in a // pointer and a pointer to the value must be created. func (c *Compiler) wrapInterfaceInvoke(f *ir.Function) (llvm.Value, error) { + receiverType, err := c.getLLVMType(f.Params[0].Type()) + if err != nil { + return llvm.Value{}, err + } + expandedReceiverType := c.expandFormalParamType(receiverType) + + if c.targetData.TypeAllocSize(receiverType) <= c.targetData.TypeAllocSize(c.i8ptrType) && len(expandedReceiverType) == 1 { + // nothing to wrap + return f.LLVMFn, nil + } + + // create wrapper function fnType := f.LLVMFn.Type().ElementType() - paramTypes := append([]llvm.Type{c.i8ptrType}, fnType.ParamTypes()[1:]...) + paramTypes := append([]llvm.Type{c.i8ptrType}, fnType.ParamTypes()[len(expandedReceiverType):]...) wrapFnType := llvm.FunctionType(fnType.ReturnType(), paramTypes, false) wrapper := llvm.AddFunction(c.mod, f.LinkName()+"$invoke", wrapFnType) wrapper.SetLinkage(llvm.InternalLinkage) + // add debug info pos := c.ir.Program.Fset.Position(f.Pos()) difunc, err := c.attachDebugInfoRaw(f, wrapper, "$invoke", pos.Filename, pos.Line) if err != nil { @@ -784,13 +789,35 @@ func (c *Compiler) wrapInterfaceInvoke(f *ir.Function) (llvm.Value, error) { } c.builder.SetCurrentDebugLocation(uint(pos.Line), uint(pos.Column), difunc, llvm.Metadata{}) + // set up IR builder block := c.ctx.AddBasicBlock(wrapper, "entry") c.builder.SetInsertPointAtEnd(block) - receiverType := fnType.ParamTypes()[0] - receiverPtrType := llvm.PointerType(receiverType, 0) - receiverPtr := c.builder.CreateBitCast(wrapper.Param(0), receiverPtrType, "receiver.ptr") - receiver := c.builder.CreateLoad(receiverPtr, "receiver") - params := append([]llvm.Value{receiver}, wrapper.Params()[1:]...) + + var receiverPtr llvm.Value + if c.targetData.TypeAllocSize(receiverType) > c.targetData.TypeAllocSize(c.i8ptrType) { + // The receiver is passed in using a pointer. We have to load it here + // and pass it by value to the real function. + + // Load the underlying value. + receiverPtrType := llvm.PointerType(receiverType, 0) + receiverPtr = c.builder.CreateBitCast(wrapper.Param(0), receiverPtrType, "receiver.ptr") + } else if len(expandedReceiverType) != 1 { + // The value is stored in the interface, but it is of type struct which + // is expanded to multiple parameters (e.g. {i8, i8}). So we have to + // receive the struct as parameter, expand it, and pass it on to the + // real function. + + // Cast the passed-in i8* to the struct value (using an alloca) and + // extract its values. + alloca := c.builder.CreateAlloca(c.i8ptrType, "receiver.alloca") + c.builder.CreateStore(wrapper.Param(0), alloca) + receiverPtr = c.builder.CreateBitCast(alloca, llvm.PointerType(receiverType, 0), "receiver.ptr") + } else { + panic("unreachable") + } + + receiverValue := c.builder.CreateLoad(receiverPtr, "receiver") + params := append(c.expandFormalParam(receiverValue), wrapper.Params()[1:]...) if fnType.ReturnType().TypeKind() == llvm.VoidTypeKind { c.builder.CreateCall(f.LLVMFn, params, "") c.builder.CreateRetVoid() diff --git a/testdata/interface.go b/testdata/interface.go index f24a995f..701b6e29 100644 --- a/testdata/interface.go +++ b/testdata/interface.go @@ -13,6 +13,8 @@ func main() { printItf(Number(3)) array := Array([4]uint32{1, 7, 11, 13}) printItf(array) + printItf(ArrayStruct{3, array}) + printItf(SmallPair{3, 5}) s := Stringer(thing) println("Stringer.String():", s.String()) var itf interface{} = s @@ -85,3 +87,29 @@ func (a Array) Nth(n int) uint32 { func (a Array) Print() { println("Array len:", len(a)) } + +type ArrayStruct struct { + n int + a Array +} + +func (a ArrayStruct) Nth(n int) uint32 { + return a.a[n] +} + +func (a ArrayStruct) Print() { + println("ArrayStruct.Print:", len(a.a), a.n) +} + +type SmallPair struct { + a byte + b byte +} + +func (p SmallPair) Nth(n int) uint32 { + return uint32(int(p.a)*n + int(p.b)*n) +} + +func (p SmallPair) Print() { + println("SmallPair.Print:", p.a, p.b) +} diff --git a/testdata/interface.txt b/testdata/interface.txt index f16dfbe0..2791425b 100644 --- a/testdata/interface.txt +++ b/testdata/interface.txt @@ -9,5 +9,9 @@ is *Thing: foo is Doubler: 6 is Tuple: 1 7 11 13 Array len: 4 +is Tuple: 1 7 11 13 +ArrayStruct.Print: 4 3 +is Tuple: 0 8 16 24 +SmallPair.Print: 3 5 Stringer.String(): foo Stringer.(*Thing).String(): foo