From 4957db89f49dfd4aba6afbb8aa4846d54b26f614 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Sun, 7 Oct 2018 02:06:48 +0200 Subject: [PATCH] compiler: fix interface calls for big underlying values When the underlying value of an interface does not fit in a pointer, a pointer to the value was correctly inserted in the heap. However, the receiving method still assumed it got the underlying value instead of a pointer to it leading to a crash. This commit inserts wrapper functions for method calls on interfaces. The bug wasn't obvious as on a 64-bit system, the underlying value was almost always put directly in the interface. However, it led to a crash on the AVR platform where pointer are (usually) just 16 bits making it far more likely that underlying values cannot be directly stored in an interface. --- compiler/compiler.go | 64 +++++++++++++++++++++++++++---- testdata/interface.go | 87 ++++++++++++++++++++++++++++++++++++++++++ testdata/interface.txt | 13 +++++++ 3 files changed, 156 insertions(+), 8 deletions(-) create mode 100644 testdata/interface.go create mode 100644 testdata/interface.txt diff --git a/compiler/compiler.go b/compiler/compiler.go index 4cac848a..6ead097f 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -419,8 +419,20 @@ func (c *Compiler) Compile(mainPath string) error { if f.LLVMFn.IsNil() { return errors.New("cannot find function: " + f.LinkName()) } - fn := llvm.ConstBitCast(f.LLVMFn, c.i8ptrType) - funcPointers = append(funcPointers, fn) + fn := f.LLVMFn + receiverType := fn.Type().ElementType().ParamTypes()[0] + if c.targetData.TypeAllocSize(receiverType) > c.targetData.TypeAllocSize(c.i8ptrType) { + // The receiver value doesn't fit in a pointer. This means that + // the interface contains a pointer to the receiver value + // instead of the value itself. Which means we have to create a + // wrapper function. + fn, err = c.wrapInterfaceInvoke(f) + if err != nil { + return err + } + } + fnPtr := llvm.ConstBitCast(fn, c.i8ptrType) + funcPointers = append(funcPointers, fnPtr) signatureNum := c.ir.MethodNum(method.Obj().(*types.Func)) signature := llvm.ConstInt(llvm.Int16Type(), uint64(signatureNum), false) signatures = append(signatures, signature) @@ -754,6 +766,42 @@ func (c *Compiler) getDIType(typ types.Type) (llvm.Metadata, error) { } } +// Wrap an interface method function pointer. The wrapper takes in a pointer to +// the underlying value, dereferences it, and calls the real method. This +// wrapper is only needed when the interface value actually doesn't fit in a +// pointer and a pointer to the value must be created. +func (c *Compiler) wrapInterfaceInvoke(f *ir.Function) (llvm.Value, error) { + fnType := f.LLVMFn.Type().ElementType() + paramTypes := append([]llvm.Type{c.i8ptrType}, fnType.ParamTypes()[1:]...) + wrapFnType := llvm.FunctionType(fnType.ReturnType(), paramTypes, false) + wrapper := llvm.AddFunction(c.mod, f.LinkName()+"$invoke", wrapFnType) + wrapper.SetLinkage(llvm.InternalLinkage) + + pos := c.ir.Program.Fset.Position(f.Pos()) + difunc, err := c.attachDebugInfoRaw(f, wrapper, "$invoke", pos.Filename, pos.Line) + if err != nil { + return llvm.Value{}, err + } + c.builder.SetCurrentDebugLocation(uint(pos.Line), uint(pos.Column), difunc, llvm.Metadata{}) + + block := c.ctx.AddBasicBlock(wrapper, "entry") + c.builder.SetInsertPointAtEnd(block) + receiverType := fnType.ParamTypes()[0] + receiverPtrType := llvm.PointerType(receiverType, 0) + receiverPtr := c.builder.CreateBitCast(wrapper.Param(0), receiverPtrType, "receiver.ptr") + receiver := c.builder.CreateLoad(receiverPtr, "receiver") + params := append([]llvm.Value{receiver}, wrapper.Params()[1:]...) + if fnType.ReturnType().TypeKind() == llvm.VoidTypeKind { + c.builder.CreateCall(f.LLVMFn, params, "") + c.builder.CreateRetVoid() + } else { + ret := c.builder.CreateCall(f.LLVMFn, params, "ret") + c.builder.CreateRet(ret) + } + + return wrapper, nil +} + func (c *Compiler) parseFuncDecl(f *ir.Function) (*Frame, error) { frame := &Frame{ fn: f, @@ -816,7 +864,7 @@ func (c *Compiler) parseFuncDecl(f *ir.Function) (*Frame, error) { } if c.Debug && f.Synthetic == "package initializer" { - difunc, err := c.attachDebugInfoRaw(f, "", 0) + difunc, err := c.attachDebugInfoRaw(f, f.LLVMFn, "", "", 0) if err != nil { return nil, err } @@ -835,10 +883,10 @@ func (c *Compiler) parseFuncDecl(f *ir.Function) (*Frame, error) { func (c *Compiler) attachDebugInfo(f *ir.Function) (llvm.Metadata, error) { pos := c.ir.Program.Fset.Position(f.Syntax().Pos()) - return c.attachDebugInfoRaw(f, pos.Filename, pos.Line) + return c.attachDebugInfoRaw(f, f.LLVMFn, "", pos.Filename, pos.Line) } -func (c *Compiler) attachDebugInfoRaw(f *ir.Function, filename string, line int) (llvm.Metadata, error) { +func (c *Compiler) attachDebugInfoRaw(f *ir.Function, llvmFn llvm.Value, suffix, filename string, line int) (llvm.Metadata, error) { if _, ok := c.difiles[filename]; !ok { dir, file := filepath.Split(filename) if dir != "" { @@ -862,8 +910,8 @@ func (c *Compiler) attachDebugInfoRaw(f *ir.Function, filename string, line int) Flags: 0, // ? }) difunc := c.dibuilder.CreateFunction(c.difiles[filename], llvm.DIFunction{ - Name: f.RelString(nil), - LinkageName: f.LinkName(), + Name: f.RelString(nil) + suffix, + LinkageName: f.LinkName() + suffix, File: c.difiles[filename], Line: line, Type: diFuncType, @@ -873,7 +921,7 @@ func (c *Compiler) attachDebugInfoRaw(f *ir.Function, filename string, line int) Flags: llvm.FlagPrototyped, Optimized: true, }) - f.LLVMFn.SetSubprogram(difunc) + llvmFn.SetSubprogram(difunc) return difunc, nil } diff --git a/testdata/interface.go b/testdata/interface.go new file mode 100644 index 00000000..f24a995f --- /dev/null +++ b/testdata/interface.go @@ -0,0 +1,87 @@ +package main + +func main() { + thing := &Thing{"foo"} + println("thing:", thing.String()) + thing.Print() + printItf(5) + printItf(byte('x')) + printItf("foo") + printItf(*thing) + printItf(thing) + printItf(Stringer(thing)) + printItf(Number(3)) + array := Array([4]uint32{1, 7, 11, 13}) + printItf(array) + s := Stringer(thing) + println("Stringer.String():", s.String()) + var itf interface{} = s + println("Stringer.(*Thing).String():", itf.(Stringer).String()) +} + +func printItf(val interface{}) { + switch val := val.(type) { + case Doubler: + println("is Doubler:", val.Double()) + case Tuple: + println("is Tuple:", val.Nth(0), val.Nth(1), val.Nth(2), val.Nth(3)) + val.Print() + case int: + println("is int:", val) + case byte: + println("is byte:", val) + case string: + println("is string:", val) + case Thing: + println("is Thing:", val.String()) + case *Thing: + println("is *Thing:", val.String()) + case Foo: + println("is Foo:", val) + default: + println("is ?") + } +} + +type Thing struct { + name string +} + +func (t Thing) String() string { + return t.name +} + +func (t Thing) Print() { + println("Thing.Print:", t.name) +} + +type Stringer interface { + String() string +} + +type Foo int + +type Number int + +func (n Number) Double() int { + return int(n) * 2 +} + +type Doubler interface { + Double() int +} + +type Tuple interface { + Nth(int) uint32 + Print() +} + +type Array [4]uint32 + +func (a Array) Nth(n int) uint32 { + return a[n] +} + +func (a Array) Print() { + println("Array len:", len(a)) +} diff --git a/testdata/interface.txt b/testdata/interface.txt new file mode 100644 index 00000000..f16dfbe0 --- /dev/null +++ b/testdata/interface.txt @@ -0,0 +1,13 @@ +thing: foo +Thing.Print: foo +is int: 5 +is byte: 120 +is string: foo +is Thing: foo +is *Thing: foo +is *Thing: foo +is Doubler: 6 +is Tuple: 1 7 11 13 +Array len: 4 +Stringer.String(): foo +Stringer.(*Thing).String(): foo