From c25b44875801ccba3d3eb8a3223fa3de3b8d10b4 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Sat, 25 Aug 2018 01:14:33 +0200 Subject: [PATCH] Rewrite init() interpretation to a real interpreter Instead of mostly heuristics, actually execute the init() instructions in an interpreter to calculate initializers for globals. This is far more flexible and extensible, and gives the option of extending the interpreter to other code and making it a partial evaluator. --- compiler.go | 590 +++++++++++++++++++++---------------------------- interpreter.go | 294 ++++++++++++++++++++++++ ir.go | 7 +- 3 files changed, 553 insertions(+), 338 deletions(-) create mode 100644 interpreter.go diff --git a/compiler.go b/compiler.go index d76a14d0..f7eed444 100644 --- a/compiler.go +++ b/compiler.go @@ -67,6 +67,8 @@ type Phi struct { llvm llvm.Value } +var cgoWrapperError = errors.New("tinygo internal: cgo wrapper") + func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) { c := &Compiler{ dumpSSA: dumpSSA, @@ -283,6 +285,41 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { frames = append(frames, frame) } + // Find and interpret package initializers. + for _, frame := range frames { + if frame.fn.fn.Synthetic == "package initializer" { + c.initFuncs = append(c.initFuncs, frame.fn.llvmFn) + if len(frame.fn.fn.Blocks) != 1 { + panic("unexpected number of basic blocks in package initializer") + } + // Try to interpret as much as possible of the init() function. + // Whenever it hits an instruction that it doesn't understand, it + // bails out and leaves the rest to the compiler (so initialization + // continues at runtime). + // This should only happen when it hits a function call or the end + // of the block, ideally. 
+ err := c.ir.Interpret(frame.fn.fn.Blocks[0]) + if err != nil { + return err + } + err = c.parseFunc(frame) + if err != nil { + return err + } + } + } + + // Set values for globals (after package initializer has been interpreted). + for _, g := range c.ir.Globals { + if g.initializer == nil { + continue + } + err := c.parseGlobalInitializer(g) + if err != nil { + return err + } + } + // Add definitions to declarations. for _, frame := range frames { if frame.fn.CName() != "" { @@ -293,8 +330,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { } var err error if frame.fn.fn.Synthetic == "package initializer" { - c.initFuncs = append(c.initFuncs, frame.fn.llvmFn) - err = c.parseInitFunc(frame) + continue // already done } else { err = c.parseFunc(frame) } @@ -647,280 +683,6 @@ func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) { return frame, nil } -// Special function parser for generated package initializers (which also -// initializes global variables). -// -// What we're doing here is two things: -// * Initialize global variables. The SSA compiler generates alloc/gep/store -// instructions to initialize variables at runtime. But that's inefficient -// in code size and RAM (for const variables) so we're interpreting these -// instructions and store global variables instead. When they're not -// modified, LLVM will automatically make them const. -// In some cases this might even help constant propagation and thus -// performance / code size elsewhere. -// * Call the actual init() functions (init#0(), init#1 etc.) for the package. 
-func (c *Compiler) parseInitFunc(frame *Frame) error { - if c.dumpSSA { - fmt.Printf("\nfunc %s:\n", frame.fn.fn) - } - frame.fn.llvmFn.SetLinkage(llvm.PrivateLinkage) - llvmBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "entry") - c.builder.SetInsertPointAtEnd(llvmBlock) - - allocs := map[ssa.Value]llvm.Value{} - - for _, block := range frame.fn.fn.DomPreorder() { - if c.dumpSSA { - fmt.Printf("%s:\n", block.Comment) - } - for _, instr := range block.Instrs { - if c.dumpSSA { - if val, ok := instr.(ssa.Value); ok && val.Name() != "" { - fmt.Printf("\t%s = %s\n", val.Name(), val.String()) - } else { - fmt.Printf("\t%s\n", instr.String()) - } - } - var err error - switch instr := instr.(type) { - case *ssa.Alloc: - llvmType, err := c.getLLVMType(instr.Type().Underlying().(*types.Pointer).Elem()) - if err != nil { - return err - } - val, err := getZeroValue(llvmType) - if err != nil { - return err - } - allocs[instr] = val - case *ssa.Call, *ssa.Return: - err = c.parseInstr(frame, instr) - case *ssa.Convert: - // Ignore: CGo pointer conversion. - case *ssa.FieldAddr, *ssa.IndexAddr: - // Ignore: handled below with *ssa.Store. 
- case *ssa.MakeMap: - // TODO: use the initial map size - mapType := instr.Type().Underlying().(*types.Map) - bucket, keySize, valueSize, err := c.initMapNewBucket(mapType) - if err != nil { - return err - } - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - bucketPtr := llvm.ConstInBoundsGEP(bucket, []llvm.Value{zero}) - hashmapType := c.mod.GetTypeByName("runtime.hashmap") - hashmap := llvm.ConstNamedStruct(hashmapType, []llvm.Value{ - llvm.ConstPointerNull(llvm.PointerType(hashmapType, 0)), // next - llvm.ConstBitCast(bucketPtr, c.i8ptrType), // buckets - llvm.ConstInt(c.lenType, 0, false), // count - llvm.ConstInt(llvm.Int8Type(), keySize, false), // keySize - llvm.ConstInt(llvm.Int8Type(), valueSize, false), // valueSize - llvm.ConstInt(llvm.Int8Type(), 0, false), // bucketBits - }) - allocs[instr] = hashmap - case *ssa.MapUpdate: - // Note: we're assuming here the Go SSA compiler knows what it's - // doing and doesn't insert a key twice. - - // Update hashmap.count. - hashmap := allocs[instr.Map] - count := llvm.ConstExtractValue(hashmap, []uint32{2}).ZExtValue() - count++ - countValue := llvm.ConstInt(c.lenType, count, false) - hashmap = llvm.ConstInsertValue(hashmap, countValue, []uint32{2}) - allocs[instr.Map] = hashmap - - // Select the bucket (chain). - bucketPtr := llvm.ConstExtractValue(hashmap, []uint32{1}) - bucketGlobal := bucketPtr.Operand(0) - - llvmKey, err := c.initParseValue(instr.Key, allocs) - if err != nil { - return err - } - llvmValue, err := c.initParseValue(instr.Value, allocs) - if err != nil { - return err - } - - // Hash for the hashtable. This must be equal to what is - // calculated in the runtime implementation. - key := constant.StringVal(instr.Key.(*ssa.Const).Value) - hash := stringhash(&key) - - // Find an empty spot in the bucket. 
- done := false - for !done { - bucket := bucketGlobal.Initializer() - for i := uint32(0); i < 8; i++ { - if llvm.ConstExtractValue(bucket, []uint32{0, i}).ZExtValue() != 0 { - // already taken - continue - } - tophashValue := llvm.ConstInt(llvm.Int8Type(), uint64(hashmapTopHash(hash)), false) - bucket = llvm.ConstInsertValue(bucket, tophashValue, []uint32{0, i}) - bucket = llvm.ConstInsertValue(bucket, llvmKey, []uint32{2, i}) - bucket = llvm.ConstInsertValue(bucket, llvmValue, []uint32{3, i}) - bucketGlobal.SetInitializer(bucket) - done = true - break - } - if !done { - nextBucket := llvm.ConstExtractValue(bucket, []uint32{1}) - if nextBucket.IsNull() { - // create new bucket - newBucketGlobal, _, _, err := c.initMapNewBucket(instr.Map.Type().Underlying().(*types.Map)) - if err != nil { - return err - } - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - newBucketPtr := llvm.ConstInBoundsGEP(newBucketGlobal, []llvm.Value{zero}) - newBucketPtrCast := llvm.ConstBitCast(newBucketPtr, c.i8ptrType) - newBucket := newBucketGlobal.Initializer() - // insert pointer into old bucket - bucket = llvm.ConstInsertValue(bucket, newBucketPtrCast, []uint32{1}) - bucketGlobal.SetInitializer(bucket) - // switch to next bucket - bucketGlobal = newBucketGlobal - bucket = newBucket - } else { - bucketGlobal = nextBucket.Operand(0) - bucket = bucketGlobal.Initializer() - } - } - } - case *ssa.Slice: - // Turn a just-allocated array into a slice. 
- if instr.Low != nil || instr.High != nil || instr.Max != nil { - return errors.New("init: slice expression with bounds") - } - var val llvm.Value - if v, ok := allocs[instr.X]; ok { - val = v - } else { - return errors.New("init: slice operation didn't find source data") - } - switch typ := instr.X.Type().Underlying().(type) { - case *types.Pointer: // pointer to array - // make slice from array - length := typ.Elem().(*types.Array).Len() - llvmLen := llvm.ConstInt(c.lenType, uint64(length), false) - global := llvm.AddGlobal(c.mod, val.Type(), ".array") - global.SetInitializer(val) - global.SetLinkage(llvm.PrivateLinkage) - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - globalPtr := c.builder.CreateInBoundsGEP(global, []llvm.Value{zero, zero}, "") - sliceTyp, err := c.getLLVMType(instr.Type()) - if err != nil { - return err - } - - slice := llvm.ConstNamedStruct(sliceTyp, []llvm.Value{ - llvm.Undef(val.Type()), - llvm.Undef(c.lenType), - llvm.Undef(c.lenType), - }) - slice = c.builder.CreateInsertValue(slice, globalPtr, 0, "") - slice = c.builder.CreateInsertValue(slice, llvmLen, 1, "") - slice = c.builder.CreateInsertValue(slice, llvmLen, 2, "") - allocs[instr] = slice - default: - return errors.New("init: unknown slice type: " + typ.String()) - } - case *ssa.Store: - switch addr := instr.Addr.(type) { - case *ssa.Global: - if strings.HasPrefix(instr.Addr.Name(), "__cgofn__cgo_") || strings.HasPrefix(instr.Addr.Name(), "_cgo_") { - // Ignore CGo global variables which we don't use. - continue - } - val, err := c.initParseValue(instr.Val, allocs) - if err != nil { - return err - } - switch valType := instr.Val.Type().Underlying().(type) { - case *types.Basic: - case *types.Map: - // Store pointer to map instead of the map itself. It is - // a reference type, so every access goes through a - // pointer to the real value. Note that this has to be a - // pointer in some cases as the global value can be - // replaced with a different map. 
- hashmap := llvm.AddGlobal(c.mod, val.Type(), ".hashmap") - hashmap.SetInitializer(val) - hashmap.SetLinkage(llvm.PrivateLinkage) - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - val = llvm.ConstInBoundsGEP(hashmap, []llvm.Value{zero}) - case *types.Pointer: - // Turn this into a pointer to a global object if it - // isn't already. - if val.IsConstant() && val.IsAGlobalVariable().IsNil() { - obj := llvm.AddGlobal(c.mod, val.Type(), ".obj") - obj.SetInitializer(val) - obj.SetLinkage(llvm.PrivateLinkage) - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - val = llvm.ConstInBoundsGEP(obj, []llvm.Value{zero}) - } - case *types.Slice: - case *types.Struct: - default: - return errors.New("init: unknown store type: " + valType.String()) - } - llvmAddr := c.ir.GetGlobal(addr).llvmGlobal - llvmAddr.SetInitializer(val) - default: - val, err := c.initParseValue(instr.Val, allocs) - if err != nil { - return err - } - err = c.initStore(instr.Addr, val, allocs) - if err != nil { - return err - } - } - case *ssa.UnOp: - if instr.Op != token.MUL || instr.CommaOk { - return errors.New("init: unknown unop: " + instr.String()) - } - valPtr, err := c.initParseValue(instr.X, allocs) - if err != nil { - return err - } - // Assume it's a GEP instruction... 
- val := valPtr.Operand(0) - allocs[instr] = val - default: - return errors.New("unknown init instruction: " + instr.String()) - } - if err != nil { - return err - } - } - } - return nil -} - -func (c *Compiler) initParseValue(val ssa.Value, allocs map[ssa.Value]llvm.Value) (llvm.Value, error) { - if cnst, ok := val.(*ssa.Const); ok { - return c.parseConst(cnst) - } else if v, ok := val.(*ssa.Convert); ok { - val, err := c.initParseValue(v.X, allocs) - if err != nil { - return llvm.Value{}, err - } - return c.parseConvert(v.X.Type(), v.Type(), val) - } else if v, ok := val.(*ssa.Global); ok { - global := c.ir.GetGlobal(v) - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - globalPtr := c.builder.CreateInBoundsGEP(global.llvmGlobal, []llvm.Value{zero}, "") - return globalPtr, nil - } else if v, ok := allocs[val]; ok { - return v, nil - } else { - return llvm.Value{}, errors.New("todo: init value for store: " + val.String()) - } -} - // Create a new global hashmap bucket, for map initialization. 
func (c *Compiler) initMapNewBucket(mapType *types.Map) (llvm.Value, uint64, uint64, error) { llvmKeyType, err := c.getLLVMType(mapType.Key().Underlying()) @@ -949,67 +711,213 @@ func (c *Compiler) initMapNewBucket(mapType *types.Map) (llvm.Value, uint64, uin return bucket, keySize, valueSize, nil } -func (c *Compiler) initStore(addr ssa.Value, val llvm.Value, allocs map[ssa.Value]llvm.Value) error { - switch addr := addr.(type) { - case *ssa.FieldAddr: - return c.initStoreSet(addr.X, val, addr.Field, allocs) - case *ssa.IndexAddr: - if cnst, ok := addr.Index.(*ssa.Const); ok { - index, _ := constant.Int64Val(cnst.Value) - return c.initStoreSet(addr.X, val, int(index), allocs) - } else { - return errors.New("todo: init IndexAddr index: " + addr.Index.String()) +func (c *Compiler) parseGlobalInitializer(g *Global) error { + llvmValue, err := c.getInterpretedValue(g.initializer) + if err != nil { + return err + } + g.llvmGlobal.SetInitializer(llvmValue) + return nil +} + +// Turn a computed Value type (ConstValue, ArrayValue, etc.) into a LLVM value. +// This is used to set the initializer of globals after they have been +// calculated by the package initializer interpreter. +func (c *Compiler) getInterpretedValue(value Value) (llvm.Value, error) { + switch value := value.(type) { + case *ArrayValue: + vals := make([]llvm.Value, len(value.Elems)) + for i, elem := range value.Elems { + val, err := c.getInterpretedValue(elem) + if err != nil { + return llvm.Value{}, err + } + vals[i] = val } + subTyp, err := c.getLLVMType(value.ElemType) + if err != nil { + return llvm.Value{}, err + } + return llvm.ConstArray(subTyp, vals), nil + + case *ConstValue: + return c.parseConst(value.Expr) + + case *GlobalValue: + zero := llvm.ConstInt(llvm.Int32Type(), 0, false) + ptr := llvm.ConstInBoundsGEP(value.Global.llvmGlobal, []llvm.Value{zero}) + return ptr, nil + + case *MapValue: + // Create initial bucket. 
+		firstBucketGlobal, keySize, valueSize, err := c.initMapNewBucket(value.Type)
+		if err != nil {
+			return llvm.Value{}, err
+		}
+
+		// Insert each key/value pair in the hashmap. Pairs fill buckets in order, 8 slots per bucket; a full bucket is chained to a new one through field 1.
+		bucketGlobal := firstBucketGlobal
+		for i, key := range value.Keys {
+			llvmKey, err := c.getInterpretedValue(key)
+			if err != nil {
+				return llvm.Value{}, err // propagate the failure instead of returning a zero value with a nil error
+			}
+			llvmValue, err := c.getInterpretedValue(value.Values[i])
+			if err != nil {
+				return llvm.Value{}, err // propagate the failure instead of returning a zero value with a nil error
+			}
+
+			keyString := constant.StringVal(key.(*ConstValue).Expr.Value) // assumes a constant string key; any other key panics here
+			hash := stringhash(&keyString) // must match the hash computed by the runtime implementation
+
+			if i%8 == 0 && i != 0 {
+				// Bucket is full, create a new one.
+				newBucketGlobal, _, _, err := c.initMapNewBucket(value.Type)
+				if err != nil {
+					return llvm.Value{}, err
+				}
+				zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
+				newBucketPtr := llvm.ConstInBoundsGEP(newBucketGlobal, []llvm.Value{zero})
+				newBucketPtrCast := llvm.ConstBitCast(newBucketPtr, c.i8ptrType)
+				// insert pointer into old bucket
+				bucket := bucketGlobal.Initializer()
+				bucket = llvm.ConstInsertValue(bucket, newBucketPtrCast, []uint32{1})
+				bucketGlobal.SetInitializer(bucket)
+				// switch to next bucket
+				bucketGlobal = newBucketGlobal
+			}
+
+			tophashValue := llvm.ConstInt(llvm.Int8Type(), uint64(hashmapTopHash(hash)), false)
+			bucket := bucketGlobal.Initializer()
+			bucket = llvm.ConstInsertValue(bucket, tophashValue, []uint32{0, uint32(i % 8)}) // field 0: tophash slots
+			bucket = llvm.ConstInsertValue(bucket, llvmKey, []uint32{2, uint32(i % 8)})      // field 2: key slots
+			bucket = llvm.ConstInsertValue(bucket, llvmValue, []uint32{3, uint32(i % 8)})    // field 3: value slots
+			bucketGlobal.SetInitializer(bucket)
+		}
+
+		// Create the hashmap itself. It must point at the first bucket: later buckets are only reachable through the chain.
+		zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
+		bucketPtr := llvm.ConstInBoundsGEP(firstBucketGlobal, []llvm.Value{zero}) // head of the chain, not the last bucket filled
+		hashmapType := c.mod.GetTypeByName("runtime.hashmap")
+		hashmap := llvm.ConstNamedStruct(hashmapType, []llvm.Value{
+			llvm.ConstPointerNull(llvm.PointerType(hashmapType, 0)),  // next
+			llvm.ConstBitCast(bucketPtr, c.i8ptrType),                // buckets
+			llvm.ConstInt(c.lenType, uint64(len(value.Keys)), false), // count
+			llvm.ConstInt(llvm.Int8Type(), keySize, false),           // keySize
+			llvm.ConstInt(llvm.Int8Type(), valueSize, false),         // valueSize
+			llvm.ConstInt(llvm.Int8Type(), 0, false),                 // bucketBits
+		})
+
+		// Create a pointer to this hashmap: the map value handed out is the address of a private global.
+		hashmapPtr := llvm.AddGlobal(c.mod, hashmap.Type(), ".hashmap")
+		hashmapPtr.SetInitializer(hashmap)
+		hashmapPtr.SetLinkage(llvm.PrivateLinkage)
+		return llvm.ConstInBoundsGEP(hashmapPtr, []llvm.Value{zero}), nil
+
+	case *PointerBitCastValue:
+		elem, err := c.getInterpretedValue(value.Elem)
+		if err != nil {
+			return llvm.Value{}, err
+		}
+		llvmType, err := c.getLLVMType(value.Type)
+		if err != nil {
+			return llvm.Value{}, err
+		}
+		return llvm.ConstBitCast(elem, llvmType), nil
+
+	case *PointerToUintptrValue:
+		elem, err := c.getInterpretedValue(value.Elem)
+		if err != nil {
+			return llvm.Value{}, err
+		}
+		return llvm.ConstPtrToInt(elem, c.uintptrType), nil
+
+	case *PointerValue:
+		elem, err := c.getInterpretedValue(*value.Elem) // NOTE(review): Elem is nil for zero-valued pointers (see getZeroValue) and would panic here
+		if err != nil {
+			return llvm.Value{}, err
+		}
+
+		obj := llvm.AddGlobal(c.mod, elem.Type(), ".obj") // the pointee lives in its own private global
+		obj.SetInitializer(elem)
+		obj.SetLinkage(llvm.PrivateLinkage)
+		elem = obj
+
+		zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
+		ptr := llvm.ConstInBoundsGEP(elem, []llvm.Value{zero})
+		return ptr, nil
+
+	case *SliceValue:
+		var globalPtr llvm.Value
+		var arrayLength uint64
+		if value.Array == nil { // nil slice: null data pointer, length stays 0
+			arrayType, err := c.getLLVMType(value.Type.Elem())
+			if err != nil {
+				return llvm.Value{}, err
+			}
+			globalPtr = llvm.ConstPointerNull(llvm.PointerType(arrayType, 0))
+		} else {
+ // make array + array, err := c.getInterpretedValue(value.Array) + if err != nil { + return llvm.Value{}, err + } + // make global from array + global := llvm.AddGlobal(c.mod, array.Type(), ".array") + global.SetInitializer(array) + global.SetLinkage(llvm.PrivateLinkage) + + // get pointer to global + zero := llvm.ConstInt(llvm.Int32Type(), 0, false) + globalPtr = c.builder.CreateInBoundsGEP(global, []llvm.Value{zero, zero}, "") + + arrayLength = uint64(len(value.Array.Elems)) + } + + // make slice + sliceTyp, err := c.getLLVMType(value.Type) + if err != nil { + return llvm.Value{}, err + } + llvmLen := llvm.ConstInt(c.lenType, arrayLength, false) + slice := llvm.ConstNamedStruct(sliceTyp, []llvm.Value{ + globalPtr, // ptr + llvmLen, // len + llvmLen, // cap + }) + return slice, nil + + case *StructValue: + fields := make([]llvm.Value, len(value.Fields)) + for i, elem := range value.Fields { + field, err := c.getInterpretedValue(elem) + if err != nil { + return llvm.Value{}, err + } + fields[i] = field + } + switch value.Type.(type) { + case *types.Named: + llvmType, err := c.getLLVMType(value.Type) + if err != nil { + return llvm.Value{}, err + } + return llvm.ConstNamedStruct(llvmType, fields), nil + case *types.Struct: + return llvm.ConstStruct(fields, false), nil + default: + return llvm.Value{}, errors.New("init: unknown struct type: " + value.Type.String()) + } + + case *ZeroBasicValue: + llvmType, err := c.getLLVMType(value.Type) + if err != nil { + return llvm.Value{}, err + } + return getZeroValue(llvmType) + default: - return errors.New("todo: init addr: " + addr.String()) - } -} - -func (c *Compiler) initStoreSet(x ssa.Value, val llvm.Value, index int, allocs map[ssa.Value]llvm.Value) error { - if agg, ok := allocs[x]; ok { - agg = c.builder.CreateInsertValue(agg, val, index, "") - allocs[x] = agg - return nil - } else if g, ok := x.(*ssa.Global); ok { - global := c.ir.GetGlobal(g) - agg := global.llvmGlobal.Initializer() - agg = 
c.builder.CreateInsertValue(agg, val, index, "") - global.llvmGlobal.SetInitializer(agg) - return nil - } else { - agg, err := c.initStoreGet(x, allocs) - if err != nil { - return err - } - agg = c.builder.CreateInsertValue(agg, val, index, "") - return c.initStore(x, agg, allocs) - } -} - -func (c *Compiler) initStoreGet(x ssa.Value, allocs map[ssa.Value]llvm.Value) (llvm.Value, error) { - if val, ok := allocs[x]; ok { - return val, nil - } else if g, ok := x.(*ssa.Global); ok { - return c.ir.GetGlobal(g).llvmGlobal.Initializer(), nil - } else if fa, ok := x.(*ssa.FieldAddr); ok { - val, err := c.initStoreGet(fa.X, allocs) - if err != nil { - return llvm.Value{}, err - } - return c.builder.CreateExtractValue(val, fa.Field, ""), nil - } else if ia, ok := x.(*ssa.IndexAddr); ok { - val, err := c.initStoreGet(ia.X, allocs) - if err != nil { - return llvm.Value{}, err - } - if cnst, ok := ia.Index.(*ssa.Const); ok { - index, _ := constant.Int64Val(cnst.Value) - return c.builder.CreateExtractValue(val, int(index), ""), nil - } else { - return llvm.Value{}, errors.New("initStoreGet: unknown FieldAddr index: " + ia.Index.String()) - } - } else { - return llvm.Value{}, errors.New("initStoreGet: unknown value: " + x.String()) + return llvm.Value{}, errors.New("init: unknown initializer type: " + fmt.Sprintf("%#v", value)) } } @@ -1110,6 +1018,10 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { switch instr := instr.(type) { case ssa.Value: value, err := c.parseExpr(frame, instr) + if err == cgoWrapperError { + // Ignore CGo global variables which we don't use. + return nil + } frame.locals[instr] = value return err case *ssa.Go: @@ -1215,6 +1127,10 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { } case *ssa.Store: llvmAddr, err := c.parseExpr(frame, instr.Addr) + if err == cgoWrapperError { + // Ignore CGo global variables which we don't use. 
+			return nil
+		}
 		if err != nil {
 			return err
 		}
@@ -1544,6 +1460,10 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
 	case *ssa.Function:
 		return c.mod.NamedFunction(c.ir.GetFunction(expr).LinkName(false)), nil
 	case *ssa.Global:
+		if strings.HasPrefix(expr.Name(), "__cgofn__cgo_") || strings.HasPrefix(expr.Name(), "_cgo_") {
+			// Ignore CGo global variables which we don't use.
+			return llvm.Value{}, cgoWrapperError
+		}
 		value := c.ir.GetGlobal(expr).llvmGlobal
 		if value.IsNil() {
 			return llvm.Value{}, errors.New("global not found: " + c.ir.GetGlobal(expr).LinkName())
diff --git a/interpreter.go b/interpreter.go
new file mode 100644
index 00000000..f7c618f0
--- /dev/null
+++ b/interpreter.go
@@ -0,0 +1,294 @@
+package main
+
+// This file provides functionality to interpret very basic Go SSA, for
+// compile-time initialization of globals.
+
+import (
+	"errors"
+	"go/constant"
+	"go/token"
+	"go/types"
+	"strings"
+
+	"golang.org/x/tools/go/ssa"
+)
+
+// Interpret interprets the instructions of block as far as possible and removes the interpreted ones from block.Instrs, so only the remainder is left.
+// It returns the error that stopped interpretation, or nil when it merely reached an instruction it does not handle (or the end of the block).
+func (p *Program) Interpret(block *ssa.BasicBlock) error {
+	for {
+		i, err := p.interpret(block.Instrs)
+		if err == cgoWrapperError {
+			// The instruction at i touches an unused CGo wrapper global: drop it too and restart after it (interpret starts again with empty locals).
+			block.Instrs = block.Instrs[i+1:]
+			continue
+		}
+		block.Instrs = block.Instrs[i:] // keep the instruction at index i and everything after it
+		return err
+	}
+}
+
+// interpret executes instrs in order until it reaches an instruction it does not handle, and returns the index of that instruction.
+// On failure it returns the index of the failing instruction together with the error.
+func (p *Program) interpret(instrs []ssa.Instruction) (int, error) {
+	locals := map[ssa.Value]Value{} // interpreted result of every SSA value produced so far in this call
+	for i, instr := range instrs {
+		switch instr := instr.(type) {
+		case *ssa.Alloc:
+			alloc, err := p.getZeroValue(instr.Type().Underlying().(*types.Pointer).Elem())
+			if err != nil {
+				return i, err
+			}
+			locals[instr] = &PointerValue{&alloc} // an alloc is a pointer to a fresh zero value of the element type
+		case *ssa.Convert:
+			x, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			switch typ := instr.Type().Underlying().(type) {
+			case *types.Basic:
+				if _, ok := instr.X.Type().Underlying().(*types.Pointer); ok && typ.Kind() == types.UnsafePointer {
+					locals[instr] = &PointerBitCastValue{typ, x} // pointer -> unsafe.Pointer
+				} else if xtyp, ok := instr.X.Type().Underlying().(*types.Basic); ok && xtyp.Kind() == types.UnsafePointer && typ.Kind() == types.Uintptr {
+					locals[instr] = &PointerToUintptrValue{x} // unsafe.Pointer -> uintptr
+				} else {
+					return i, errors.New("todo: init: unknown basic convert: " + instr.String())
+				}
+			case *types.Pointer:
+				if xtyp, ok := instr.X.Type().Underlying().(*types.Basic); ok && xtyp.Kind() == types.UnsafePointer {
+					locals[instr] = &PointerBitCastValue{typ, x} // unsafe.Pointer -> pointer
+				} else {
+					panic("expected unsafe pointer conversion")
+				}
+			default:
+				return i, errors.New("todo: init: unknown convert: " + instr.String())
+			}
+		case *ssa.FieldAddr:
+			x, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			var structVal *StructValue
+			switch x := x.(type) {
+			case *GlobalValue:
+				structVal = x.Global.initializer.(*StructValue) // panics if the global's initializer is not a struct value
+			case *PointerValue:
+				structVal = (*x.Elem).(*StructValue) // panics if the pointee is not a struct value
+			default:
+				panic("expected a pointer")
+			}
+			locals[instr] = &PointerValue{&structVal.Fields[instr.Field]} // points into the struct, so a later Store updates the field in place
+		case *ssa.IndexAddr:
+			x, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			if cnst, ok := instr.Index.(*ssa.Const); ok {
+				index, _ := constant.Int64Val(cnst.Value) // the "exact" result is ignored
+				switch xPtr := x.(type) {
+				case *GlobalValue:
+					x = xPtr.Global.initializer
+				case *PointerValue:
+					x = *xPtr.Elem
+				default:
+					panic("expected a pointer")
+				}
+				switch x := x.(type) {
+				case *ArrayValue:
+					locals[instr] = &PointerValue{&x.Elems[index]} // no bounds check: an out-of-range index panics
+				default:
+					return i, errors.New("todo: init IndexAddr not on an array or struct")
+				}
+			} else {
+				return i, errors.New("todo: init IndexAddr index: " + instr.Index.String())
+			}
+		case *ssa.UnOp:
+			if instr.Op != token.MUL || instr.CommaOk { // only a plain pointer load (*x) is supported
+				return i, errors.New("init: unknown unop: " + instr.String())
+			}
+			valPtr, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			switch valPtr := valPtr.(type) {
+			case *GlobalValue:
+				locals[instr] = valPtr.Global.initializer // no copy is made: the loaded Value is the one stored in the global
+			case *PointerValue:
+				locals[instr] = *valPtr.Elem // no copy is made: the loaded Value is the pointee itself
+			default:
+				panic("expected a pointer")
+			}
+		case *ssa.MakeMap:
+			locals[instr] = &MapValue{instr.Type().Underlying().(*types.Map), nil, nil} // empty map; MapUpdate appends the entries
+		case *ssa.MapUpdate:
+			// Assume no duplicate keys exist. This is most likely true for
+			// autogenerated code, but may not be true when trying to interpret
+			// user code.
+			key, err := p.getValue(instr.Key, locals)
+			if err != nil {
+				return i, err
+			}
+			value, err := p.getValue(instr.Value, locals)
+			if err != nil {
+				return i, err
+			}
+			x := locals[instr.Map].(*MapValue) // panics if the map operand is not a *MapValue known in locals
+			x.Keys = append(x.Keys, key)
+			x.Values = append(x.Values, value)
+		case *ssa.Slice:
+			// Turn a just-allocated array into a slice.
+			if instr.Low != nil || instr.High != nil || instr.Max != nil {
+				return i, errors.New("init: slice expression with bounds")
+			}
+			source, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			switch source := source.(type) {
+			case *PointerValue: // pointer to array
+				array := (*source.Elem).(*ArrayValue) // panics if the pointee is not an array value
+				locals[instr] = &SliceValue{instr.Type().Underlying().(*types.Slice), array}
+			default:
+				return i, errors.New("init: unknown slice type")
+			}
+		case *ssa.Store:
+			if addr, ok := instr.Addr.(*ssa.Global); ok {
+				if strings.HasPrefix(instr.Addr.Name(), "__cgofn__cgo_") || strings.HasPrefix(instr.Addr.Name(), "_cgo_") {
+					// Ignore CGo global variables which we don't use.
+					continue
+				}
+				value, err := p.getValue(instr.Val, locals)
+				if err != nil {
+					return i, err
+				}
+				p.GetGlobal(addr).initializer = value // a store straight to a global replaces its whole initializer
+			} else if addr, ok := locals[instr.Addr]; ok {
+				value, err := p.getValue(instr.Val, locals)
+				if err != nil {
+					return i, err
+				}
+				if addr, ok := addr.(*PointerValue); ok {
+					*(addr.Elem) = value // write through the pointer (alloc, struct field or array element)
+				} else {
+					panic("store to non-pointer")
+				}
+			} else {
+				return i, errors.New("todo: init Store: " + instr.String())
+			}
+		default:
+			return i, nil // instruction not handled here: stop without error and report where
+		}
+	}
+	return len(instrs), nil // normally unreachable
+}
+// getValue returns the interpreter value of an SSA operand: constants and globals are resolved directly, anything else must already be in locals.
+func (p *Program) getValue(value ssa.Value, locals map[ssa.Value]Value) (Value, error) {
+	switch value := value.(type) {
+	case *ssa.Const:
+		return &ConstValue{value}, nil
+	case *ssa.Global:
+		if strings.HasPrefix(value.Name(), "__cgofn__cgo_") || strings.HasPrefix(value.Name(), "_cgo_") {
+			// Ignore CGo global variables which we don't use.
+			return nil, cgoWrapperError
+		}
+		g := p.GetGlobal(value)
+		if g.initializer == nil {
+			value, err := p.getZeroValue(value.Type().Underlying().(*types.Pointer).Elem()) // a *ssa.Global has pointer type; zero its element type
+			if err != nil {
+				return nil, err
+			}
+			g.initializer = value // first use: give the global a zero value so field/index addresses and loads have something to refer to
+		}
+		return &GlobalValue{g}, nil
+		//return &PointerValue{&g.initializer}, nil
+	default:
+		if local, ok := locals[value]; ok {
+			return local, nil
+		} else {
+			return nil, errors.New("todo: init: unknown value: " + value.String())
+		}
+	}
+}
+// getZeroValue returns the interpreter representation of the zero value of t, or an error for kinds of types it does not support yet.
+func (p *Program) getZeroValue(t types.Type) (Value, error) {
+	switch typ := t.Underlying().(type) {
+	case *types.Array:
+		elems := make([]Value, typ.Len())
+		for i := range elems {
+			elem, err := p.getZeroValue(typ.Elem())
+			if err != nil {
+				return nil, err
+			}
+			elems[i] = elem
+		}
+		return &ArrayValue{typ.Elem(), elems}, nil
+	case *types.Basic:
+		return &ZeroBasicValue{typ}, nil
+	case *types.Pointer:
+		return &PointerValue{nil}, nil // nil pointer: Elem is nil
+	case *types.Struct:
+		elems := make([]Value, typ.NumFields())
+		for i := range elems {
+			elem, err := p.getZeroValue(typ.Field(i).Type()) // keep a *types.Named field type intact so a nested named struct stays a named struct
+			if err != nil {
+				return nil, err
+			}
+			elems[i] = elem
+		}
+		return &StructValue{t, elems}, nil // t, not typ: StructValue.Type may be *types.Named
+	case *types.Slice:
+		return &SliceValue{typ, nil}, nil // nil slice: no backing array
+	default:
+		return nil, errors.New("todo: init: unknown global type: " + typ.String())
+	}
+}
+
+// Value is a boxed value for the interpreter; the concrete types are the *...Value structs in this file.
+type Value interface {
+}
+// ConstValue is a constant taken directly from the SSA form.
+type ConstValue struct {
+	Expr *ssa.Const
+}
+// ZeroBasicValue is the zero value of a basic type.
+type ZeroBasicValue struct {
+	Type *types.Basic
+}
+// PointerValue is a pointer to another interpreter value; writes through Elem change the pointee in place.
+type PointerValue struct {
+	Elem *Value // nil for a nil pointer
+}
+// PointerBitCastValue is the pointer Elem converted to the pointer-like type Type (to or from unsafe.Pointer).
+type PointerBitCastValue struct {
+	Type types.Type
+	Elem Value
+}
+// PointerToUintptrValue is the pointer Elem converted to a uintptr.
+type PointerToUintptrValue struct {
+	Elem Value
+}
+// GlobalValue is the address of a global variable.
+type GlobalValue struct {
+	Global *Global
+}
+// ArrayValue is an array whose elements are known individually.
+type ArrayValue struct {
+	ElemType types.Type
+	Elems    []Value
+}
+// StructValue is a struct whose fields are known individually, in declaration order.
+type StructValue struct {
+	Type   types.Type // types.Struct or types.Named
+	Fields []Value
+}
+// SliceValue is a slice covering the whole of Array.
+type SliceValue struct {
+	Type  *types.Slice
+	Array *ArrayValue // nil for a nil slice
+}
+// MapValue is a map stored as parallel key and value lists, in insertion order.
+type MapValue struct {
+	Type   *types.Map
+	Keys   []Value
+	Values []Value // Values[i] belongs to Keys[i]
+}
diff --git a/ir.go b/ir.go
index 4d628e93..9a2abd82 100644
--- a/ir.go
+++ b/ir.go
@@ -38,9 +38,10 @@ type Function struct {
 
 // Global variable, possibly constant.
 type Global struct {
-	g          *ssa.Global
-	llvmGlobal llvm.Value
-	flag       bool // used by dead code elimination
+	g           *ssa.Global
+	llvmGlobal  llvm.Value
+	flag        bool  // used by dead code elimination
+	initializer Value // value computed by the init interpreter; nil when none was computed
 }
 
 // Type with a name and possibly methods.