diff --git a/tgo.go b/tgo.go index ab75cddb..ab53ab5d 100644 --- a/tgo.go +++ b/tgo.go @@ -48,6 +48,7 @@ type Compiler struct { printbyteFunc llvm.Value printspaceFunc llvm.Value printnlFunc llvm.Value + memsetIntrinsic llvm.Value itfTypeNumbers map[types.Type]uint64 itfTypes []types.Type } @@ -107,6 +108,16 @@ func NewCompiler(pkgName, triple string) (*Compiler, error) { printnlType := llvm.FunctionType(llvm.VoidType(), nil, false) c.printnlFunc = llvm.AddFunction(c.mod, "runtime.printnl", printnlType) + // Intrinsic functions + memsetType := llvm.FunctionType( + llvm.VoidType(), []llvm.Type{ + llvm.PointerType(llvm.Int8Type(), 0), + llvm.Int8Type(), + llvm.Int32Type(), + llvm.Int1Type(), + }, false) + c.memsetIntrinsic = llvm.AddFunction(c.mod, "llvm.memset.p0i8.i32", memsetType) + return c, nil } @@ -244,6 +255,12 @@ func (c *Compiler) Parse(pkgName string) error { func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { fmt.Println(" type:", goType) switch typ := goType.(type) { + case *types.Array: + elemType, err := c.getLLVMType(typ.Elem()) + if err != nil { + return llvm.Type{}, err + } + return llvm.ArrayType(elemType, int(typ.Len())), nil case *types.Basic: switch typ.Kind() { case types.Bool: @@ -563,12 +580,28 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if err != nil { return llvm.Value{}, err } + var buf llvm.Value if expr.Heap { // TODO: escape analysis - return c.builder.CreateMalloc(typ, expr.Comment), nil + buf = c.builder.CreateMalloc(typ, expr.Comment) } else { - return c.builder.CreateAlloca(typ, expr.Comment), nil + buf = c.builder.CreateAlloca(typ, expr.Comment) } + width := c.targetData.TypeAllocSize(typ) + if err != nil { + return llvm.Value{}, err + } + llvmWidth := llvm.ConstInt(llvm.Int32Type(), width, false) + bufBytes := c.builder.CreateBitCast(buf, llvm.PointerType(llvm.Int8Type(), 0), "") + c.builder.CreateCall( + c.memsetIntrinsic, + []llvm.Value{ + bufBytes, + llvm.ConstInt(llvm.Int8Type(), 0, false), // value to set (zero) + llvmWidth, // size to zero + llvm.ConstInt(llvm.Int1Type(), 0, false), // volatile + }, "") + return buf, nil case *ssa.BinOp: return c.parseBinOp(frame, expr) case *ssa.Call: @@ -603,6 +636,21 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { return llvm.Value{}, errors.New("global not found: " + expr.Name()) } return value, nil + case *ssa.IndexAddr: + val, err := c.parseExpr(frame, expr.X) + if err != nil { + return llvm.Value{}, err + } + index, err := c.parseExpr(frame, expr.Index) + if err != nil { + return llvm.Value{}, err + } + // TODO: bounds check + indices := []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), 0, false), + index, + } + return c.builder.CreateGEP(val, indices, ""), nil case *ssa.Lookup: if expr.CommaOk { return llvm.Value{}, errors.New("todo: lookup with comma-ok")