diff --git a/tgo.go b/tgo.go index 2716d05f..830845eb 100644 --- a/tgo.go +++ b/tgo.go @@ -34,6 +34,7 @@ type Compiler struct { ctx llvm.Context builder llvm.Builder machine llvm.TargetMachine + targetData llvm.TargetData intType llvm.Type stringLenType llvm.Type stringType llvm.Type @@ -73,6 +74,7 @@ func NewCompiler(pkgName, triple string) (*Compiler, error) { return nil, err } c.machine = target.CreateTargetMachine(triple, "", "", llvm.CodeGenLevelDefault, llvm.RelocDefault, llvm.CodeModelDefault) + c.targetData = c.machine.CreateTargetData() c.mod = llvm.NewModule(pkgName) c.ctx = c.mod.Context() @@ -269,6 +271,26 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { } } +func (c *Compiler) getTypeWidth(typ types.Type) (int, error) { + switch typ := typ.(type) { + case *types.Basic: + if typ.Kind() == types.UnsafePointer { + return c.targetData.PointerSize(), nil + } + llvmType, err := c.getLLVMType(typ) + if err != nil { + return 0, err + } + return llvmType.IntTypeWidth(), nil + case *types.Named: + return c.getTypeWidth(typ.Underlying()) + case *types.Pointer: + return c.targetData.PointerSize(), nil + default: + return 0, errors.New("todo: type width") + } +} + func (c *Compiler) getInterfaceType(typ types.Type) llvm.Value { if _, ok := c.itfTypeNumbers[typ]; !ok { num := uint64(len(c.itfTypes)) @@ -716,7 +738,7 @@ func (c *Compiler) parseConvert(frame *Frame, expr *ssa.Convert) (llvm.Value, er return value, nil } - typeFrom, err := c.getLLVMType(expr.X.Type()) + sizeFrom, err := c.getTypeWidth(expr.X.Type()) if err != nil { return llvm.Value{}, err } @@ -724,11 +746,15 @@ func (c *Compiler) parseConvert(frame *Frame, expr *ssa.Convert) (llvm.Value, er if err != nil { return llvm.Value{}, err } - sizeFrom := typeFrom.IntTypeWidth() - sizeTo := typeTo.IntTypeWidth() + sizeTo, err := c.getTypeWidth(expr.Type()) + if err != nil { + return llvm.Value{}, err + } - if sizeFrom >= sizeTo { - return c.builder.CreateTruncOrBitCast(value, typeTo, ""), nil + if sizeFrom > sizeTo { + return c.builder.CreateTrunc(value, typeTo, ""), nil + } else if sizeFrom == sizeTo { + return c.builder.CreateBitCast(value, typeTo, ""), nil } else { // sizeFrom < sizeTo: extend switch typ := expr.X.Type().(type) { // typeFrom case *types.Basic: