diff --git a/compiler.go b/compiler.go index 389f0c02..fdc764cd 100644 --- a/compiler.go +++ b/compiler.go @@ -654,6 +654,8 @@ func (c *Compiler) parseInitFunc(frame *Frame) error { llvmBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "entry") c.builder.SetInsertPointAtEnd(llvmBlock) + allocs := map[ssa.Value]llvm.Value{} + for _, block := range frame.fn.fn.DomPreorder() { if c.dumpSSA { fmt.Printf("%s:\n", block.Comment) @@ -668,6 +670,16 @@ func (c *Compiler) parseInitFunc(frame *Frame) error { } var err error switch instr := instr.(type) { + case *ssa.Alloc: + llvmType, err := c.getLLVMType(instr.Type().Underlying().(*types.Pointer).Elem()) + if err != nil { + return err + } + val, err := getZeroValue(llvmType) + if err != nil { + return err + } + allocs[instr] = val case *ssa.Call, *ssa.Return: err = c.parseInstr(frame, instr) case *ssa.Convert: @@ -677,64 +689,25 @@ func (c *Compiler) parseInitFunc(frame *Frame) error { case *ssa.Store: switch addr := instr.Addr.(type) { case *ssa.Global: - // Regular store, like a global int variable. - if strings.HasPrefix(addr.Name(), "__cgofn__cgo_") || strings.HasPrefix(addr.Name(), "_cgo_") { + if strings.HasPrefix(instr.Addr.Name(), "__cgofn__cgo_") || strings.HasPrefix(instr.Addr.Name(), "_cgo_") { // Ignore CGo global variables which we don't use. continue } - val, err := c.parseExpr(frame, instr.Val) + val, err := c.initParseValue(instr.Val, allocs) if err != nil { return err } llvmAddr := c.ir.GetGlobal(addr).llvmGlobal llvmAddr.SetInitializer(val) - case *ssa.FieldAddr: - // Initialize field of a global struct. - // LLVM does not allow setting an initializer on part of a - // global variable. So we take the current initializer, add - // the field, and replace the initializer with the new - // initializer. - val, err := c.parseExpr(frame, instr.Val) - if err != nil { - return err - } - global := addr.X.(*ssa.Global) - llvmAddr := c.ir.GetGlobal(global).llvmGlobal - llvmValue := llvmAddr.Initializer() - if llvmValue.IsNil() { - llvmValue, err = getZeroValue(llvmAddr.Type().ElementType()) - if err != nil { - return err - } - } - llvmValue = c.builder.CreateInsertValue(llvmValue, val, addr.Field, "") - llvmAddr.SetInitializer(llvmValue) - case *ssa.IndexAddr: - val, err := c.parseExpr(frame, instr.Val) - if err != nil { - return err - } - constIndex := addr.Index.(*ssa.Const) - index, exact := constant.Int64Val(constIndex.Value) - if !exact { - return errors.New("could not get store index: " + constIndex.Value.ExactString()) - } - fieldAddr := addr.X.(*ssa.FieldAddr) - global := fieldAddr.X.(*ssa.Global) - llvmAddr := c.ir.GetGlobal(global).llvmGlobal - llvmValue := llvmAddr.Initializer() - if llvmValue.IsNil() { - llvmValue, err = getZeroValue(llvmAddr.Type().ElementType()) - if err != nil { - return err - } - } - llvmFieldValue := c.builder.CreateExtractValue(llvmValue, fieldAddr.Field, "") - llvmFieldValue = c.builder.CreateInsertValue(llvmFieldValue, val, int(index), "") - llvmValue = c.builder.CreateInsertValue(llvmValue, llvmFieldValue, fieldAddr.Field, "") - llvmAddr.SetInitializer(llvmValue) default: - return errors.New("unknown init store: " + addr.String()) + val, err := c.initParseValue(instr.Val, allocs) + if err != nil { + return err + } + err = c.initStore(instr.Addr, val, allocs) + if err != nil { + return err + } } default: return errors.New("unknown init instruction: " + instr.String()) @@ -747,6 +720,92 @@ func (c *Compiler) parseInitFunc(frame *Frame) error { return nil } +func (c *Compiler) initParseValue(val ssa.Value, allocs map[ssa.Value]llvm.Value) (llvm.Value, error) { + if cnst, ok := val.(*ssa.Const); ok { + return c.parseConst(cnst) + } else if v, ok := val.(*ssa.Convert); ok { + // hopefully the same type under the hood + val, err := c.initParseValue(v.X, allocs) + if err != nil { + return llvm.Value{}, err + } + return c.parseConvert(v.X.Type(), v.Type(), val) + } else if v, ok := val.(*ssa.Global); ok { + global := c.ir.GetGlobal(v) + zero := llvm.ConstInt(llvm.Int32Type(), 0, false) + globalPtr := c.builder.CreateInBoundsGEP(global.llvmGlobal, []llvm.Value{zero}, "") + return globalPtr, nil + } else if v, ok := allocs[val]; ok { + return v, nil + } else { + return llvm.Value{}, errors.New("todo: init value for store: " + val.String()) + } +} + +func (c *Compiler) initStore(addr ssa.Value, val llvm.Value, allocs map[ssa.Value]llvm.Value) error { + switch addr := addr.(type) { + case *ssa.FieldAddr: + return c.initStoreSet(addr.X, val, addr.Field, allocs) + case *ssa.IndexAddr: + if cnst, ok := addr.Index.(*ssa.Const); ok { + index, _ := constant.Int64Val(cnst.Value) + return c.initStoreSet(addr.X, val, int(index), allocs) + } else { + return errors.New("todo: init IndexAddr index: " + addr.Index.String()) + } + default: + return errors.New("todo: init addr: " + addr.String()) + } +} + +func (c *Compiler) initStoreSet(x ssa.Value, val llvm.Value, index int, allocs map[ssa.Value]llvm.Value) error { + if agg, ok := allocs[x]; ok { + agg = c.builder.CreateInsertValue(agg, val, index, "") + allocs[x] = agg + return nil + } else if g, ok := x.(*ssa.Global); ok { + global := c.ir.GetGlobal(g) + agg := global.llvmGlobal.Initializer() + agg = c.builder.CreateInsertValue(agg, val, index, "") + global.llvmGlobal.SetInitializer(agg) + return nil + } else { + agg, err := c.initStoreGet(x, allocs) + if err != nil { + return err + } + agg = c.builder.CreateInsertValue(agg, val, index, "") + return c.initStore(x, agg, allocs) + } +} + +func (c *Compiler) initStoreGet(x ssa.Value, allocs map[ssa.Value]llvm.Value) (llvm.Value, error) { + if val, ok := allocs[x]; ok { + return val, nil + } else if g, ok := x.(*ssa.Global); ok { + return c.ir.GetGlobal(g).llvmGlobal.Initializer(), nil + } else if fa, ok := x.(*ssa.FieldAddr); ok { + val, err := c.initStoreGet(fa.X, allocs) + if err != nil { + return llvm.Value{}, err + } + return c.builder.CreateExtractValue(val, fa.Field, ""), nil + } else if ia, ok := x.(*ssa.IndexAddr); ok { + val, err := c.initStoreGet(ia.X, allocs) + if err != nil { + return llvm.Value{}, err + } + if cnst, ok := ia.Index.(*ssa.Const); ok { + index, _ := constant.Int64Val(cnst.Value) + return c.builder.CreateExtractValue(val, int(index), ""), nil + } else { + return llvm.Value{}, errors.New("initStoreGet: unknown FieldAddr index: " + ia.Index.String()) + } + } else { + return llvm.Value{}, errors.New("initStoreGet: unknown value: " + x.String()) + } +} + func (c *Compiler) parseFunc(frame *Frame) error { if c.dumpSSA { fmt.Printf("\nfunc %s:\n", frame.fn.fn)