diff --git a/compiler.go b/compiler.go index d76a14d0..f7eed444 100644 --- a/compiler.go +++ b/compiler.go @@ -67,6 +67,8 @@ type Phi struct { llvm llvm.Value } +var cgoWrapperError = errors.New("tinygo internal: cgo wrapper") + func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) { c := &Compiler{ dumpSSA: dumpSSA, @@ -283,6 +285,41 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { frames = append(frames, frame) } + // Find and interpret package initializers. + for _, frame := range frames { + if frame.fn.fn.Synthetic == "package initializer" { + c.initFuncs = append(c.initFuncs, frame.fn.llvmFn) + if len(frame.fn.fn.Blocks) != 1 { + panic("unexpected number of basic blocks in package initializer") + } + // Try to interpret as much as possible of the init() function. + // Whenever it hits an instruction that it doesn't understand, it + // bails out and leaves the rest to the compiler (so initialization + // continues at runtime). + // This should only happen when it hits a function call or the end + // of the block, ideally. + err := c.ir.Interpret(frame.fn.fn.Blocks[0]) + if err != nil { + return err + } + err = c.parseFunc(frame) + if err != nil { + return err + } + } + } + + // Set values for globals (after package initializer has been interpreted). + for _, g := range c.ir.Globals { + if g.initializer == nil { + continue + } + err := c.parseGlobalInitializer(g) + if err != nil { + return err + } + } + // Add definitions to declarations. for _, frame := range frames { if frame.fn.CName() != "" { @@ -293,8 +330,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { } var err error if frame.fn.fn.Synthetic == "package initializer" { - c.initFuncs = append(c.initFuncs, frame.fn.llvmFn) - err = c.parseInitFunc(frame) + continue // already done } else { err = c.parseFunc(frame) } @@ -647,280 +683,6 @@ func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) { return frame, nil } -// Special function parser for generated package initializers (which also -// initializes global variables). -// -// What we're doing here is two things: -// * Initialize global variables. The SSA compiler generates alloc/gep/store -// instructions to initialize variables at runtime. But that's inefficient -// in code size and RAM (for const variables) so we're interpreting these -// instructions and store global variables instead. When they're not -// modified, LLVM will automatically make them const. -// In some cases this might even help constant propagation and thus -// performance / code size elsewhere. -// * Call the actual init() functions (init#0(), init#1 etc.) for the package. -func (c *Compiler) parseInitFunc(frame *Frame) error { - if c.dumpSSA { - fmt.Printf("\nfunc %s:\n", frame.fn.fn) - } - frame.fn.llvmFn.SetLinkage(llvm.PrivateLinkage) - llvmBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "entry") - c.builder.SetInsertPointAtEnd(llvmBlock) - - allocs := map[ssa.Value]llvm.Value{} - - for _, block := range frame.fn.fn.DomPreorder() { - if c.dumpSSA { - fmt.Printf("%s:\n", block.Comment) - } - for _, instr := range block.Instrs { - if c.dumpSSA { - if val, ok := instr.(ssa.Value); ok && val.Name() != "" { - fmt.Printf("\t%s = %s\n", val.Name(), val.String()) - } else { - fmt.Printf("\t%s\n", instr.String()) - } - } - var err error - switch instr := instr.(type) { - case *ssa.Alloc: - llvmType, err := c.getLLVMType(instr.Type().Underlying().(*types.Pointer).Elem()) - if err != nil { - return err - } - val, err := getZeroValue(llvmType) - if err != nil { - return err - } - allocs[instr] = val - case *ssa.Call, *ssa.Return: - err = c.parseInstr(frame, instr) - case *ssa.Convert: - // Ignore: CGo pointer conversion. - case *ssa.FieldAddr, *ssa.IndexAddr: - // Ignore: handled below with *ssa.Store. - case *ssa.MakeMap: - // TODO: use the initial map size - mapType := instr.Type().Underlying().(*types.Map) - bucket, keySize, valueSize, err := c.initMapNewBucket(mapType) - if err != nil { - return err - } - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - bucketPtr := llvm.ConstInBoundsGEP(bucket, []llvm.Value{zero}) - hashmapType := c.mod.GetTypeByName("runtime.hashmap") - hashmap := llvm.ConstNamedStruct(hashmapType, []llvm.Value{ - llvm.ConstPointerNull(llvm.PointerType(hashmapType, 0)), // next - llvm.ConstBitCast(bucketPtr, c.i8ptrType), // buckets - llvm.ConstInt(c.lenType, 0, false), // count - llvm.ConstInt(llvm.Int8Type(), keySize, false), // keySize - llvm.ConstInt(llvm.Int8Type(), valueSize, false), // valueSize - llvm.ConstInt(llvm.Int8Type(), 0, false), // bucketBits - }) - allocs[instr] = hashmap - case *ssa.MapUpdate: - // Note: we're assuming here the Go SSA compiler knows what it's - // doing and doesn't insert a key twice. - - // Update hashmap.count. - hashmap := allocs[instr.Map] - count := llvm.ConstExtractValue(hashmap, []uint32{2}).ZExtValue() - count++ - countValue := llvm.ConstInt(c.lenType, count, false) - hashmap = llvm.ConstInsertValue(hashmap, countValue, []uint32{2}) - allocs[instr.Map] = hashmap - - // Select the bucket (chain). - bucketPtr := llvm.ConstExtractValue(hashmap, []uint32{1}) - bucketGlobal := bucketPtr.Operand(0) - - llvmKey, err := c.initParseValue(instr.Key, allocs) - if err != nil { - return err - } - llvmValue, err := c.initParseValue(instr.Value, allocs) - if err != nil { - return err - } - - // Hash for the hashtable. This must be equal to what is - // calculated in the runtime implementation. - key := constant.StringVal(instr.Key.(*ssa.Const).Value) - hash := stringhash(&key) - - // Find an empty spot in the bucket. - done := false - for !done { - bucket := bucketGlobal.Initializer() - for i := uint32(0); i < 8; i++ { - if llvm.ConstExtractValue(bucket, []uint32{0, i}).ZExtValue() != 0 { - // already taken - continue - } - tophashValue := llvm.ConstInt(llvm.Int8Type(), uint64(hashmapTopHash(hash)), false) - bucket = llvm.ConstInsertValue(bucket, tophashValue, []uint32{0, i}) - bucket = llvm.ConstInsertValue(bucket, llvmKey, []uint32{2, i}) - bucket = llvm.ConstInsertValue(bucket, llvmValue, []uint32{3, i}) - bucketGlobal.SetInitializer(bucket) - done = true - break - } - if !done { - nextBucket := llvm.ConstExtractValue(bucket, []uint32{1}) - if nextBucket.IsNull() { - // create new bucket - newBucketGlobal, _, _, err := c.initMapNewBucket(instr.Map.Type().Underlying().(*types.Map)) - if err != nil { - return err - } - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - newBucketPtr := llvm.ConstInBoundsGEP(newBucketGlobal, []llvm.Value{zero}) - newBucketPtrCast := llvm.ConstBitCast(newBucketPtr, c.i8ptrType) - newBucket := newBucketGlobal.Initializer() - // insert pointer into old bucket - bucket = llvm.ConstInsertValue(bucket, newBucketPtrCast, []uint32{1}) - bucketGlobal.SetInitializer(bucket) - // switch to next bucket - bucketGlobal = newBucketGlobal - bucket = newBucket - } else { - bucketGlobal = nextBucket.Operand(0) - bucket = bucketGlobal.Initializer() - } - } - } - case *ssa.Slice: - // Turn a just-allocated array into a slice. - if instr.Low != nil || instr.High != nil || instr.Max != nil { - return errors.New("init: slice expression with bounds") - } - var val llvm.Value - if v, ok := allocs[instr.X]; ok { - val = v - } else { - return errors.New("init: slice operation didn't find source data") - } - switch typ := instr.X.Type().Underlying().(type) { - case *types.Pointer: // pointer to array - // make slice from array - length := typ.Elem().(*types.Array).Len() - llvmLen := llvm.ConstInt(c.lenType, uint64(length), false) - global := llvm.AddGlobal(c.mod, val.Type(), ".array") - global.SetInitializer(val) - global.SetLinkage(llvm.PrivateLinkage) - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - globalPtr := c.builder.CreateInBoundsGEP(global, []llvm.Value{zero, zero}, "") - sliceTyp, err := c.getLLVMType(instr.Type()) - if err != nil { - return err - } - - slice := llvm.ConstNamedStruct(sliceTyp, []llvm.Value{ - llvm.Undef(val.Type()), - llvm.Undef(c.lenType), - llvm.Undef(c.lenType), - }) - slice = c.builder.CreateInsertValue(slice, globalPtr, 0, "") - slice = c.builder.CreateInsertValue(slice, llvmLen, 1, "") - slice = c.builder.CreateInsertValue(slice, llvmLen, 2, "") - allocs[instr] = slice - default: - return errors.New("init: unknown slice type: " + typ.String()) - } - case *ssa.Store: - switch addr := instr.Addr.(type) { - case *ssa.Global: - if strings.HasPrefix(instr.Addr.Name(), "__cgofn__cgo_") || strings.HasPrefix(instr.Addr.Name(), "_cgo_") { - // Ignore CGo global variables which we don't use. - continue - } - val, err := c.initParseValue(instr.Val, allocs) - if err != nil { - return err - } - switch valType := instr.Val.Type().Underlying().(type) { - case *types.Basic: - case *types.Map: - // Store pointer to map instead of the map itself. It is - // a reference type, so every access goes through a - // pointer to the real value. Note that this has to be a - // pointer in some cases as the global value can be - // replaced with a different map. - hashmap := llvm.AddGlobal(c.mod, val.Type(), ".hashmap") - hashmap.SetInitializer(val) - hashmap.SetLinkage(llvm.PrivateLinkage) - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - val = llvm.ConstInBoundsGEP(hashmap, []llvm.Value{zero}) - case *types.Pointer: - // Turn this into a pointer to a global object if it - // isn't already. - if val.IsConstant() && val.IsAGlobalVariable().IsNil() { - obj := llvm.AddGlobal(c.mod, val.Type(), ".obj") - obj.SetInitializer(val) - obj.SetLinkage(llvm.PrivateLinkage) - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - val = llvm.ConstInBoundsGEP(obj, []llvm.Value{zero}) - } - case *types.Slice: - case *types.Struct: - default: - return errors.New("init: unknown store type: " + valType.String()) - } - llvmAddr := c.ir.GetGlobal(addr).llvmGlobal - llvmAddr.SetInitializer(val) - default: - val, err := c.initParseValue(instr.Val, allocs) - if err != nil { - return err - } - err = c.initStore(instr.Addr, val, allocs) - if err != nil { - return err - } - } - case *ssa.UnOp: - if instr.Op != token.MUL || instr.CommaOk { - return errors.New("init: unknown unop: " + instr.String()) - } - valPtr, err := c.initParseValue(instr.X, allocs) - if err != nil { - return err - } - // Assume it's a GEP instruction... - val := valPtr.Operand(0) - allocs[instr] = val - default: - return errors.New("unknown init instruction: " + instr.String()) - } - if err != nil { - return err - } - } - } - return nil -} - -func (c *Compiler) initParseValue(val ssa.Value, allocs map[ssa.Value]llvm.Value) (llvm.Value, error) { - if cnst, ok := val.(*ssa.Const); ok { - return c.parseConst(cnst) - } else if v, ok := val.(*ssa.Convert); ok { - val, err := c.initParseValue(v.X, allocs) - if err != nil { - return llvm.Value{}, err - } - return c.parseConvert(v.X.Type(), v.Type(), val) - } else if v, ok := val.(*ssa.Global); ok { - global := c.ir.GetGlobal(v) - zero := llvm.ConstInt(llvm.Int32Type(), 0, false) - globalPtr := c.builder.CreateInBoundsGEP(global.llvmGlobal, []llvm.Value{zero}, "") - return globalPtr, nil - } else if v, ok := allocs[val]; ok { - return v, nil - } else { - return llvm.Value{}, errors.New("todo: init value for store: " + val.String()) - } -} - // Create a new global hashmap bucket, for map initialization. func (c *Compiler) initMapNewBucket(mapType *types.Map) (llvm.Value, uint64, uint64, error) { llvmKeyType, err := c.getLLVMType(mapType.Key().Underlying()) @@ -949,67 +711,213 @@ func (c *Compiler) initMapNewBucket(mapType *types.Map) (llvm.Value, uint64, uin return bucket, keySize, valueSize, nil } -func (c *Compiler) initStore(addr ssa.Value, val llvm.Value, allocs map[ssa.Value]llvm.Value) error { - switch addr := addr.(type) { - case *ssa.FieldAddr: - return c.initStoreSet(addr.X, val, addr.Field, allocs) - case *ssa.IndexAddr: - if cnst, ok := addr.Index.(*ssa.Const); ok { - index, _ := constant.Int64Val(cnst.Value) - return c.initStoreSet(addr.X, val, int(index), allocs) - } else { - return errors.New("todo: init IndexAddr index: " + addr.Index.String()) +func (c *Compiler) parseGlobalInitializer(g *Global) error { + llvmValue, err := c.getInterpretedValue(g.initializer) + if err != nil { + return err + } + g.llvmGlobal.SetInitializer(llvmValue) + return nil +} + +// Turn a computed Value type (ConstValue, ArrayValue, etc.) into a LLVM value. +// This is used to set the initializer of globals after they have been +// calculated by the package initializer interpreter. +func (c *Compiler) getInterpretedValue(value Value) (llvm.Value, error) { + switch value := value.(type) { + case *ArrayValue: + vals := make([]llvm.Value, len(value.Elems)) + for i, elem := range value.Elems { + val, err := c.getInterpretedValue(elem) + if err != nil { + return llvm.Value{}, err + } + vals[i] = val } + subTyp, err := c.getLLVMType(value.ElemType) + if err != nil { + return llvm.Value{}, err + } + return llvm.ConstArray(subTyp, vals), nil + + case *ConstValue: + return c.parseConst(value.Expr) + + case *GlobalValue: + zero := llvm.ConstInt(llvm.Int32Type(), 0, false) + ptr := llvm.ConstInBoundsGEP(value.Global.llvmGlobal, []llvm.Value{zero}) + return ptr, nil + + case *MapValue: + // Create initial bucket. + firstBucketGlobal, keySize, valueSize, err := c.initMapNewBucket(value.Type) + if err != nil { + return llvm.Value{}, err + } + + // Insert each key/value pair in the hashmap. + bucketGlobal := firstBucketGlobal + for i, key := range value.Keys { + llvmKey, err := c.getInterpretedValue(key) + if err != nil { + return llvm.Value{}, nil + } + llvmValue, err := c.getInterpretedValue(value.Values[i]) + if err != nil { + return llvm.Value{}, nil + } + + keyString := constant.StringVal(key.(*ConstValue).Expr.Value) + hash := stringhash(&keyString) + + if i%8 == 0 && i != 0 { + // Bucket is full, create a new one. + newBucketGlobal, _, _, err := c.initMapNewBucket(value.Type) + if err != nil { + return llvm.Value{}, err + } + zero := llvm.ConstInt(llvm.Int32Type(), 0, false) + newBucketPtr := llvm.ConstInBoundsGEP(newBucketGlobal, []llvm.Value{zero}) + newBucketPtrCast := llvm.ConstBitCast(newBucketPtr, c.i8ptrType) + // insert pointer into old bucket + bucket := bucketGlobal.Initializer() + bucket = llvm.ConstInsertValue(bucket, newBucketPtrCast, []uint32{1}) + bucketGlobal.SetInitializer(bucket) + // switch to next bucket + bucketGlobal = newBucketGlobal + } + + tophashValue := llvm.ConstInt(llvm.Int8Type(), uint64(hashmapTopHash(hash)), false) + bucket := bucketGlobal.Initializer() + bucket = llvm.ConstInsertValue(bucket, tophashValue, []uint32{0, uint32(i % 8)}) + bucket = llvm.ConstInsertValue(bucket, llvmKey, []uint32{2, uint32(i % 8)}) + bucket = llvm.ConstInsertValue(bucket, llvmValue, []uint32{3, uint32(i % 8)}) + bucketGlobal.SetInitializer(bucket) + } + + // Create the hashmap itself. + zero := llvm.ConstInt(llvm.Int32Type(), 0, false) + bucketPtr := llvm.ConstInBoundsGEP(bucketGlobal, []llvm.Value{zero}) + hashmapType := c.mod.GetTypeByName("runtime.hashmap") + hashmap := llvm.ConstNamedStruct(hashmapType, []llvm.Value{ + llvm.ConstPointerNull(llvm.PointerType(hashmapType, 0)), // next + llvm.ConstBitCast(bucketPtr, c.i8ptrType), // buckets + llvm.ConstInt(c.lenType, uint64(len(value.Keys)), false), // count + llvm.ConstInt(llvm.Int8Type(), keySize, false), // keySize + llvm.ConstInt(llvm.Int8Type(), valueSize, false), // valueSize + llvm.ConstInt(llvm.Int8Type(), 0, false), // bucketBits + }) + + // Create a pointer to this hashmap. + hashmapPtr := llvm.AddGlobal(c.mod, hashmap.Type(), ".hashmap") + hashmapPtr.SetInitializer(hashmap) + hashmapPtr.SetLinkage(llvm.PrivateLinkage) + return llvm.ConstInBoundsGEP(hashmapPtr, []llvm.Value{zero}), nil + + case *PointerBitCastValue: + elem, err := c.getInterpretedValue(value.Elem) + if err != nil { + return llvm.Value{}, err + } + llvmType, err := c.getLLVMType(value.Type) + if err != nil { + return llvm.Value{}, err + } + return llvm.ConstBitCast(elem, llvmType), nil + + case *PointerToUintptrValue: + elem, err := c.getInterpretedValue(value.Elem) + if err != nil { + return llvm.Value{}, err + } + return llvm.ConstPtrToInt(elem, c.uintptrType), nil + + case *PointerValue: + elem, err := c.getInterpretedValue(*value.Elem) + if err != nil { + return llvm.Value{}, err + } + + obj := llvm.AddGlobal(c.mod, elem.Type(), ".obj") + obj.SetInitializer(elem) + obj.SetLinkage(llvm.PrivateLinkage) + elem = obj + + zero := llvm.ConstInt(llvm.Int32Type(), 0, false) + ptr := llvm.ConstInBoundsGEP(elem, []llvm.Value{zero}) + return ptr, nil + + case *SliceValue: + var globalPtr llvm.Value + var arrayLength uint64 + if value.Array == nil { + arrayType, err := c.getLLVMType(value.Type.Elem()) + if err != nil { + return llvm.Value{}, err + } + globalPtr = llvm.ConstPointerNull(llvm.PointerType(arrayType, 0)) + } else { + // make array + array, err := c.getInterpretedValue(value.Array) + if err != nil { + return llvm.Value{}, err + } + // make global from array + global := llvm.AddGlobal(c.mod, array.Type(), ".array") + global.SetInitializer(array) + global.SetLinkage(llvm.PrivateLinkage) + + // get pointer to global + zero := llvm.ConstInt(llvm.Int32Type(), 0, false) + globalPtr = c.builder.CreateInBoundsGEP(global, []llvm.Value{zero, zero}, "") + + arrayLength = uint64(len(value.Array.Elems)) + } + + // make slice + sliceTyp, err := c.getLLVMType(value.Type) + if err != nil { + return llvm.Value{}, err + } + llvmLen := llvm.ConstInt(c.lenType, arrayLength, false) + slice := llvm.ConstNamedStruct(sliceTyp, []llvm.Value{ + globalPtr, // ptr + llvmLen, // len + llvmLen, // cap + }) + return slice, nil + + case *StructValue: + fields := make([]llvm.Value, len(value.Fields)) + for i, elem := range value.Fields { + field, err := c.getInterpretedValue(elem) + if err != nil { + return llvm.Value{}, err + } + fields[i] = field + } + switch value.Type.(type) { + case *types.Named: + llvmType, err := c.getLLVMType(value.Type) + if err != nil { + return llvm.Value{}, err + } + return llvm.ConstNamedStruct(llvmType, fields), nil + case *types.Struct: + return llvm.ConstStruct(fields, false), nil + default: + return llvm.Value{}, errors.New("init: unknown struct type: " + value.Type.String()) + } + + case *ZeroBasicValue: + llvmType, err := c.getLLVMType(value.Type) + if err != nil { + return llvm.Value{}, err + } + return getZeroValue(llvmType) + default: - return errors.New("todo: init addr: " + addr.String()) - } -} - -func (c *Compiler) initStoreSet(x ssa.Value, val llvm.Value, index int, allocs map[ssa.Value]llvm.Value) error { - if agg, ok := allocs[x]; ok { - agg = c.builder.CreateInsertValue(agg, val, index, "") - allocs[x] = agg - return nil - } else if g, ok := x.(*ssa.Global); ok { - global := c.ir.GetGlobal(g) - agg := global.llvmGlobal.Initializer() - agg = c.builder.CreateInsertValue(agg, val, index, "") - global.llvmGlobal.SetInitializer(agg) - return nil - } else { - agg, err := c.initStoreGet(x, allocs) - if err != nil { - return err - } - agg = c.builder.CreateInsertValue(agg, val, index, "") - return c.initStore(x, agg, allocs) - } -} - -func (c *Compiler) initStoreGet(x ssa.Value, allocs map[ssa.Value]llvm.Value) (llvm.Value, error) { - if val, ok := allocs[x]; ok { - return val, nil - } else if g, ok := x.(*ssa.Global); ok { - return c.ir.GetGlobal(g).llvmGlobal.Initializer(), nil - } else if fa, ok := x.(*ssa.FieldAddr); ok { - val, err := c.initStoreGet(fa.X, allocs) - if err != nil { - return llvm.Value{}, err - } - return c.builder.CreateExtractValue(val, fa.Field, ""), nil - } else if ia, ok := x.(*ssa.IndexAddr); ok { - val, err := c.initStoreGet(ia.X, allocs) - if err != nil { - return llvm.Value{}, err - } - if cnst, ok := ia.Index.(*ssa.Const); ok { - index, _ := constant.Int64Val(cnst.Value) - return c.builder.CreateExtractValue(val, int(index), ""), nil - } else { - return llvm.Value{}, errors.New("initStoreGet: unknown FieldAddr index: " + ia.Index.String()) - } - } else { - return llvm.Value{}, errors.New("initStoreGet: unknown value: " + x.String()) + return llvm.Value{}, errors.New("init: unknown initializer type: " + fmt.Sprintf("%#v", value)) } } @@ -1110,6 +1018,10 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { switch instr := instr.(type) { case ssa.Value: value, err := c.parseExpr(frame, instr) + if err == cgoWrapperError { + // Ignore CGo global variables which we don't use. + return nil + } frame.locals[instr] = value return err case *ssa.Go: @@ -1215,6 +1127,10 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { } case *ssa.Store: llvmAddr, err := c.parseExpr(frame, instr.Addr) + if err == cgoWrapperError { + // Ignore CGo global variables which we don't use. + return nil + } if err != nil { return err } @@ -1544,6 +1460,10 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { case *ssa.Function: return c.mod.NamedFunction(c.ir.GetFunction(expr).LinkName(false)), nil case *ssa.Global: + if strings.HasPrefix(expr.Name(), "__cgofn__cgo_") || strings.HasPrefix(expr.Name(), "_cgo_") { + // Ignore CGo global variables which we don't use. + return llvm.Value{}, cgoWrapperError + } value := c.ir.GetGlobal(expr).llvmGlobal if value.IsNil() { return llvm.Value{}, errors.New("global not found: " + c.ir.GetGlobal(expr).LinkName()) diff --git a/interpreter.go b/interpreter.go new file mode 100644 index 00000000..f7c618f0 --- /dev/null +++ b/interpreter.go @@ -0,0 +1,294 @@ +package main + +// This file provides functionality to interpret very basic Go SSA, for +// compile-time initialization of globals. + +import ( + "errors" + "go/constant" + "go/token" + "go/types" + "strings" + + "golang.org/x/tools/go/ssa" +) + +// Interpret instructions as far as possible, and drop those instructions from +// the basic block. +func (p *Program) Interpret(block *ssa.BasicBlock) error { + for { + i, err := p.interpret(block.Instrs) + if err == cgoWrapperError { + // skip this instruction + block.Instrs = block.Instrs[i+1:] + continue + } + block.Instrs = block.Instrs[i:] + return err + } +} + +// Interpret instructions as far as possible, and return the index of the first +// unknown instruction. +func (p *Program) interpret(instrs []ssa.Instruction) (int, error) { + locals := map[ssa.Value]Value{} + for i, instr := range instrs { + switch instr := instr.(type) { + case *ssa.Alloc: + alloc, err := p.getZeroValue(instr.Type().Underlying().(*types.Pointer).Elem()) + if err != nil { + return i, err + } + locals[instr] = &PointerValue{&alloc} + case *ssa.Convert: + x, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + switch typ := instr.Type().Underlying().(type) { + case *types.Basic: + if _, ok := instr.X.Type().Underlying().(*types.Pointer); ok && typ.Kind() == types.UnsafePointer { + locals[instr] = &PointerBitCastValue{typ, x} + } else if xtyp, ok := instr.X.Type().Underlying().(*types.Basic); ok && xtyp.Kind() == types.UnsafePointer && typ.Kind() == types.Uintptr { + locals[instr] = &PointerToUintptrValue{x} + } else { + return i, errors.New("todo: init: unknown basic convert: " + instr.String()) + } + case *types.Pointer: + if xtyp, ok := instr.X.Type().Underlying().(*types.Basic); ok && xtyp.Kind() == types.UnsafePointer { + locals[instr] = &PointerBitCastValue{typ, x} + } else { + panic("expected unsafe pointer conversion") + } + default: + return i, errors.New("todo: init: unknown convert: " + instr.String()) + } + case *ssa.FieldAddr: + x, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + var structVal *StructValue + switch x := x.(type) { + case *GlobalValue: + structVal = x.Global.initializer.(*StructValue) + case *PointerValue: + structVal = (*x.Elem).(*StructValue) + default: + panic("expected a pointer") + } + locals[instr] = &PointerValue{&structVal.Fields[instr.Field]} + case *ssa.IndexAddr: + x, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + if cnst, ok := instr.Index.(*ssa.Const); ok { + index, _ := constant.Int64Val(cnst.Value) + switch xPtr := x.(type) { + case *GlobalValue: + x = xPtr.Global.initializer + case *PointerValue: + x = *xPtr.Elem + default: + panic("expected a pointer") + } + switch x := x.(type) { + case *ArrayValue: + locals[instr] = &PointerValue{&x.Elems[index]} + default: + return i, errors.New("todo: init IndexAddr not on an array or struct") + } + } else { + return i, errors.New("todo: init IndexAddr index: " + instr.Index.String()) + } + case *ssa.UnOp: + if instr.Op != token.MUL || instr.CommaOk { + return i, errors.New("init: unknown unop: " + instr.String()) + } + valPtr, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + switch valPtr := valPtr.(type) { + case *GlobalValue: + locals[instr] = valPtr.Global.initializer + case *PointerValue: + locals[instr] = *valPtr.Elem + default: + panic("expected a pointer") + } + case *ssa.MakeMap: + locals[instr] = &MapValue{instr.Type().Underlying().(*types.Map), nil, nil} + case *ssa.MapUpdate: + // Assume no duplicate keys exist. This is most likely true for + // autogenerated code, but may not be true when trying to interpret + // user code. + key, err := p.getValue(instr.Key, locals) + if err != nil { + return i, err + } + value, err := p.getValue(instr.Value, locals) + if err != nil { + return i, err + } + x := locals[instr.Map].(*MapValue) + x.Keys = append(x.Keys, key) + x.Values = append(x.Values, value) + case *ssa.Slice: + // Turn a just-allocated array into a slice. + if instr.Low != nil || instr.High != nil || instr.Max != nil { + return i, errors.New("init: slice expression with bounds") + } + source, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + switch source := source.(type) { + case *PointerValue: // pointer to array + array := (*source.Elem).(*ArrayValue) + locals[instr] = &SliceValue{instr.Type().Underlying().(*types.Slice), array} + default: + return i, errors.New("init: unknown slice type") + } + case *ssa.Store: + if addr, ok := instr.Addr.(*ssa.Global); ok { + if strings.HasPrefix(instr.Addr.Name(), "__cgofn__cgo_") || strings.HasPrefix(instr.Addr.Name(), "_cgo_") { + // Ignore CGo global variables which we don't use. + continue + } + value, err := p.getValue(instr.Val, locals) + if err != nil { + return i, err + } + p.GetGlobal(addr).initializer = value + } else if addr, ok := locals[instr.Addr]; ok { + value, err := p.getValue(instr.Val, locals) + if err != nil { + return i, err + } + if addr, ok := addr.(*PointerValue); ok { + *(addr.Elem) = value + } else { + panic("store to non-pointer") + } + } else { + return i, errors.New("todo: init Store: " + instr.String()) + } + default: + return i, nil + } + } + return len(instrs), nil // normally unreachable +} + +func (p *Program) getValue(value ssa.Value, locals map[ssa.Value]Value) (Value, error) { + switch value := value.(type) { + case *ssa.Const: + return &ConstValue{value}, nil + case *ssa.Global: + if strings.HasPrefix(value.Name(), "__cgofn__cgo_") || strings.HasPrefix(value.Name(), "_cgo_") { + // Ignore CGo global variables which we don't use. + return nil, cgoWrapperError + } + g := p.GetGlobal(value) + if g.initializer == nil { + value, err := p.getZeroValue(value.Type().Underlying().(*types.Pointer).Elem()) + if err != nil { + return nil, err + } + g.initializer = value + } + return &GlobalValue{g}, nil + //return &PointerValue{&g.initializer}, nil + default: + if local, ok := locals[value]; ok { + return local, nil + } else { + return nil, errors.New("todo: init: unknown value: " + value.String()) + } + } +} + +func (p *Program) getZeroValue(t types.Type) (Value, error) { + switch typ := t.Underlying().(type) { + case *types.Array: + elems := make([]Value, typ.Len()) + for i := range elems { + elem, err := p.getZeroValue(typ.Elem()) + if err != nil { + return nil, err + } + elems[i] = elem + } + return &ArrayValue{typ.Elem(), elems}, nil + case *types.Basic: + return &ZeroBasicValue{typ}, nil + case *types.Pointer: + return &PointerValue{nil}, nil + case *types.Struct: + elems := make([]Value, typ.NumFields()) + for i := range elems { + elem, err := p.getZeroValue(typ.Field(i).Type().Underlying()) + if err != nil { + return nil, err + } + elems[i] = elem + } + return &StructValue{t, elems}, nil + case *types.Slice: + return &SliceValue{typ, nil}, nil + default: + return nil, errors.New("todo: init: unknown global type: " + typ.String()) + } +} + +// Boxed value for interpreter. +type Value interface { +} + +type ConstValue struct { + Expr *ssa.Const +} + +type ZeroBasicValue struct { + Type *types.Basic +} + +type PointerValue struct { + Elem *Value +} + +type PointerBitCastValue struct { + Type types.Type + Elem Value +} + +type PointerToUintptrValue struct { + Elem Value +} + +type GlobalValue struct { + Global *Global +} + +type ArrayValue struct { + ElemType types.Type + Elems []Value +} + +type StructValue struct { + Type types.Type // types.Struct or types.Named + Fields []Value +} + +type SliceValue struct { + Type *types.Slice + Array *ArrayValue +} + +type MapValue struct { + Type *types.Map + Keys []Value + Values []Value +} diff --git a/ir.go b/ir.go index 4d628e93..9a2abd82 100644 --- a/ir.go +++ b/ir.go @@ -38,9 +38,10 @@ type Function struct { // Global variable, possibly constant. type Global struct { - g *ssa.Global - llvmGlobal llvm.Value - flag bool // used by dead code elimination + g *ssa.Global + llvmGlobal llvm.Value + flag bool // used by dead code elimination + initializer Value } // Type with a name and possibly methods.