Rewrite init() interpretation to a real interpreter

Instead of mostly heuristics, actually execute the init() instructions in
an interpreter to calculate initializers for globals. This is far more
flexible and extensible, and gives the option of extending the
interpreter to other code and making it a partial evaluator.
Этот коммит содержится в:
Ayke van Laethem 2018-08-25 01:14:33 +02:00
родитель 9b4ac0459b
коммит c25b448758
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: E97FF5335DFDFDED
3 изменённых файлов: 553 добавлений и 338 удалений

Просмотреть файл

@ -67,6 +67,8 @@ type Phi struct {
llvm llvm.Value
}
var cgoWrapperError = errors.New("tinygo internal: cgo wrapper")
func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) {
c := &Compiler{
dumpSSA: dumpSSA,
@ -283,6 +285,41 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
frames = append(frames, frame)
}
// Find and interpret package initializers.
for _, frame := range frames {
if frame.fn.fn.Synthetic == "package initializer" {
c.initFuncs = append(c.initFuncs, frame.fn.llvmFn)
if len(frame.fn.fn.Blocks) != 1 {
panic("unexpected number of basic blocks in package initializer")
}
// Try to interpret as much as possible of the init() function.
// Whenever it hits an instruction that it doesn't understand, it
// bails out and leaves the rest to the compiler (so initialization
// continues at runtime).
// This should only happen when it hits a function call or the end
// of the block, ideally.
err := c.ir.Interpret(frame.fn.fn.Blocks[0])
if err != nil {
return err
}
err = c.parseFunc(frame)
if err != nil {
return err
}
}
}
// Set values for globals (after package initializer has been interpreted).
for _, g := range c.ir.Globals {
if g.initializer == nil {
continue
}
err := c.parseGlobalInitializer(g)
if err != nil {
return err
}
}
// Add definitions to declarations.
for _, frame := range frames {
if frame.fn.CName() != "" {
@ -293,8 +330,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
}
var err error
if frame.fn.fn.Synthetic == "package initializer" {
c.initFuncs = append(c.initFuncs, frame.fn.llvmFn)
err = c.parseInitFunc(frame)
continue // already done
} else {
err = c.parseFunc(frame)
}
@ -647,280 +683,6 @@ func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) {
return frame, nil
}
// Special function parser for generated package initializers (which also
// initializes global variables).
//
// What we're doing here is two things:
// * Initialize global variables. The SSA compiler generates alloc/gep/store
// instructions to initialize variables at runtime. But that's inefficient
// in code size and RAM (for const variables) so we're interpreting these
// instructions and store global variables instead. When they're not
// modified, LLVM will automatically make them const.
// In some cases this might even help constant propagation and thus
// performance / code size elsewhere.
// * Call the actual init() functions (init#0(), init#1 etc.) for the package.
//
// An error is returned as soon as an instruction, store target or value is
// encountered that this function does not know how to handle.
func (c *Compiler) parseInitFunc(frame *Frame) error {
if c.dumpSSA {
fmt.Printf("\nfunc %s:\n", frame.fn.fn)
}
frame.fn.llvmFn.SetLinkage(llvm.PrivateLinkage)
llvmBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "entry")
c.builder.SetInsertPointAtEnd(llvmBlock)
// allocs maps SSA values seen in this initializer to the LLVM value
// computed for them so far.
allocs := map[ssa.Value]llvm.Value{}
for _, block := range frame.fn.fn.DomPreorder() {
if c.dumpSSA {
fmt.Printf("%s:\n", block.Comment)
}
for _, instr := range block.Instrs {
if c.dumpSSA {
if val, ok := instr.(ssa.Value); ok && val.Name() != "" {
fmt.Printf("\t%s = %s\n", val.Name(), val.String())
} else {
fmt.Printf("\t%s\n", instr.String())
}
}
var err error
switch instr := instr.(type) {
case *ssa.Alloc:
// Model the allocation as the zero value of the pointed-to type.
llvmType, err := c.getLLVMType(instr.Type().Underlying().(*types.Pointer).Elem())
if err != nil {
return err
}
val, err := getZeroValue(llvmType)
if err != nil {
return err
}
allocs[instr] = val
case *ssa.Call, *ssa.Return:
// Not interpreted: emitted as regular instructions via parseInstr.
err = c.parseInstr(frame, instr)
case *ssa.Convert:
// Ignore: CGo pointer conversion.
case *ssa.FieldAddr, *ssa.IndexAddr:
// Ignore: handled below with *ssa.Store.
case *ssa.MakeMap:
// TODO: use the initial map size
mapType := instr.Type().Underlying().(*types.Map)
bucket, keySize, valueSize, err := c.initMapNewBucket(mapType)
if err != nil {
return err
}
zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
bucketPtr := llvm.ConstInBoundsGEP(bucket, []llvm.Value{zero})
hashmapType := c.mod.GetTypeByName("runtime.hashmap")
hashmap := llvm.ConstNamedStruct(hashmapType, []llvm.Value{
llvm.ConstPointerNull(llvm.PointerType(hashmapType, 0)), // next
llvm.ConstBitCast(bucketPtr, c.i8ptrType), // buckets
llvm.ConstInt(c.lenType, 0, false), // count
llvm.ConstInt(llvm.Int8Type(), keySize, false), // keySize
llvm.ConstInt(llvm.Int8Type(), valueSize, false), // valueSize
llvm.ConstInt(llvm.Int8Type(), 0, false), // bucketBits
})
allocs[instr] = hashmap
case *ssa.MapUpdate:
// Note: we're assuming here the Go SSA compiler knows what it's
// doing and doesn't insert a key twice.
// Update hashmap.count.
hashmap := allocs[instr.Map]
count := llvm.ConstExtractValue(hashmap, []uint32{2}).ZExtValue()
count++
countValue := llvm.ConstInt(c.lenType, count, false)
hashmap = llvm.ConstInsertValue(hashmap, countValue, []uint32{2})
allocs[instr.Map] = hashmap
// Select the bucket (chain).
bucketPtr := llvm.ConstExtractValue(hashmap, []uint32{1})
bucketGlobal := bucketPtr.Operand(0)
llvmKey, err := c.initParseValue(instr.Key, allocs)
if err != nil {
return err
}
llvmValue, err := c.initParseValue(instr.Value, allocs)
if err != nil {
return err
}
// Hash for the hashtable. This must be equal to what is
// calculated in the runtime implementation.
// NOTE(review): the type assertion below panics when the key is not
// an *ssa.Const, and the key is read as a string constant.
key := constant.StringVal(instr.Key.(*ssa.Const).Value)
hash := stringhash(&key)
// Find an empty spot in the bucket.
done := false
for !done {
bucket := bucketGlobal.Initializer()
// Scan slots 0..7 of this bucket; a slot whose byte at {0, i} is
// zero is treated as free.
for i := uint32(0); i < 8; i++ {
if llvm.ConstExtractValue(bucket, []uint32{0, i}).ZExtValue() != 0 {
// already taken
continue
}
tophashValue := llvm.ConstInt(llvm.Int8Type(), uint64(hashmapTopHash(hash)), false)
bucket = llvm.ConstInsertValue(bucket, tophashValue, []uint32{0, i})
bucket = llvm.ConstInsertValue(bucket, llvmKey, []uint32{2, i})
bucket = llvm.ConstInsertValue(bucket, llvmValue, []uint32{3, i})
bucketGlobal.SetInitializer(bucket)
done = true
break
}
if !done {
// No free slot here: follow the pointer stored at index 1 of the
// bucket, creating a new bucket when that pointer is null.
nextBucket := llvm.ConstExtractValue(bucket, []uint32{1})
if nextBucket.IsNull() {
// create new bucket
newBucketGlobal, _, _, err := c.initMapNewBucket(instr.Map.Type().Underlying().(*types.Map))
if err != nil {
return err
}
zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
newBucketPtr := llvm.ConstInBoundsGEP(newBucketGlobal, []llvm.Value{zero})
newBucketPtrCast := llvm.ConstBitCast(newBucketPtr, c.i8ptrType)
newBucket := newBucketGlobal.Initializer()
// insert pointer into old bucket
bucket = llvm.ConstInsertValue(bucket, newBucketPtrCast, []uint32{1})
bucketGlobal.SetInitializer(bucket)
// switch to next bucket
bucketGlobal = newBucketGlobal
bucket = newBucket
} else {
bucketGlobal = nextBucket.Operand(0)
bucket = bucketGlobal.Initializer()
}
}
}
case *ssa.Slice:
// Turn a just-allocated array into a slice.
if instr.Low != nil || instr.High != nil || instr.Max != nil {
return errors.New("init: slice expression with bounds")
}
var val llvm.Value
if v, ok := allocs[instr.X]; ok {
val = v
} else {
return errors.New("init: slice operation didn't find source data")
}
switch typ := instr.X.Type().Underlying().(type) {
case *types.Pointer: // pointer to array
// make slice from array
length := typ.Elem().(*types.Array).Len()
llvmLen := llvm.ConstInt(c.lenType, uint64(length), false)
// The array contents go into a private global; the slice points at
// its first element and uses the array length as both len and cap.
global := llvm.AddGlobal(c.mod, val.Type(), ".array")
global.SetInitializer(val)
global.SetLinkage(llvm.PrivateLinkage)
zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
globalPtr := c.builder.CreateInBoundsGEP(global, []llvm.Value{zero, zero}, "")
sliceTyp, err := c.getLLVMType(instr.Type())
if err != nil {
return err
}
slice := llvm.ConstNamedStruct(sliceTyp, []llvm.Value{
llvm.Undef(val.Type()),
llvm.Undef(c.lenType),
llvm.Undef(c.lenType),
})
slice = c.builder.CreateInsertValue(slice, globalPtr, 0, "")
slice = c.builder.CreateInsertValue(slice, llvmLen, 1, "")
slice = c.builder.CreateInsertValue(slice, llvmLen, 2, "")
allocs[instr] = slice
default:
return errors.New("init: unknown slice type: " + typ.String())
}
case *ssa.Store:
switch addr := instr.Addr.(type) {
case *ssa.Global:
// Store directly to a global: the computed value becomes the
// global's initializer.
if strings.HasPrefix(instr.Addr.Name(), "__cgofn__cgo_") || strings.HasPrefix(instr.Addr.Name(), "_cgo_") {
// Ignore CGo global variables which we don't use.
continue
}
val, err := c.initParseValue(instr.Val, allocs)
if err != nil {
return err
}
switch valType := instr.Val.Type().Underlying().(type) {
case *types.Basic:
case *types.Map:
// Store pointer to map instead of the map itself. It is
// a reference type, so every access goes through a
// pointer to the real value. Note that this has to be a
// pointer in some cases as the global value can be
// replaced with a different map.
hashmap := llvm.AddGlobal(c.mod, val.Type(), ".hashmap")
hashmap.SetInitializer(val)
hashmap.SetLinkage(llvm.PrivateLinkage)
zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
val = llvm.ConstInBoundsGEP(hashmap, []llvm.Value{zero})
case *types.Pointer:
// Turn this into a pointer to a global object if it
// isn't already.
if val.IsConstant() && val.IsAGlobalVariable().IsNil() {
obj := llvm.AddGlobal(c.mod, val.Type(), ".obj")
obj.SetInitializer(val)
obj.SetLinkage(llvm.PrivateLinkage)
zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
val = llvm.ConstInBoundsGEP(obj, []llvm.Value{zero})
}
case *types.Slice:
case *types.Struct:
default:
return errors.New("init: unknown store type: " + valType.String())
}
llvmAddr := c.ir.GetGlobal(addr).llvmGlobal
llvmAddr.SetInitializer(val)
default:
// Store through any other address is delegated to initStore.
val, err := c.initParseValue(instr.Val, allocs)
if err != nil {
return err
}
err = c.initStore(instr.Addr, val, allocs)
if err != nil {
return err
}
}
case *ssa.UnOp:
// Only a plain pointer dereference (*x, without comma-ok) is supported.
if instr.Op != token.MUL || instr.CommaOk {
return errors.New("init: unknown unop: " + instr.String())
}
valPtr, err := c.initParseValue(instr.X, allocs)
if err != nil {
return err
}
// Assume it's a GEP instruction...
val := valPtr.Operand(0)
allocs[instr] = val
default:
return errors.New("unknown init instruction: " + instr.String())
}
// Reports the error set by the *ssa.Call / *ssa.Return case; the other
// cases return their errors directly.
if err != nil {
return err
}
}
}
return nil
}
// initParseValue resolves an SSA value used by a package initializer to an
// LLVM value. Constants, conversions and globals are computed here; any other
// value must already be present in allocs, otherwise an error is returned.
func (c *Compiler) initParseValue(val ssa.Value, allocs map[ssa.Value]llvm.Value) (llvm.Value, error) {
	switch v := val.(type) {
	case *ssa.Const:
		return c.parseConst(v)
	case *ssa.Convert:
		// Resolve the operand first, then apply the conversion to it.
		operand, err := c.initParseValue(v.X, allocs)
		if err != nil {
			return llvm.Value{}, err
		}
		return c.parseConvert(v.X.Type(), v.Type(), operand)
	case *ssa.Global:
		// A global is represented by a GEP pointing at its LLVM global.
		index := []llvm.Value{llvm.ConstInt(llvm.Int32Type(), 0, false)}
		return c.builder.CreateInBoundsGEP(c.ir.GetGlobal(v).llvmGlobal, index, ""), nil
	}
	if known, ok := allocs[val]; ok {
		return known, nil
	}
	return llvm.Value{}, errors.New("todo: init value for store: " + val.String())
}
// Create a new global hashmap bucket, for map initialization.
func (c *Compiler) initMapNewBucket(mapType *types.Map) (llvm.Value, uint64, uint64, error) {
llvmKeyType, err := c.getLLVMType(mapType.Key().Underlying())
@ -949,67 +711,213 @@ func (c *Compiler) initMapNewBucket(mapType *types.Map) (llvm.Value, uint64, uin
return bucket, keySize, valueSize, nil
}
// initStore stores val at the location described by addr, which must be a
// field address or an index address with a constant index. The actual update
// is performed by initStoreSet on the aggregate the address points into.
func (c *Compiler) initStore(addr ssa.Value, val llvm.Value, allocs map[ssa.Value]llvm.Value) error {
	if field, ok := addr.(*ssa.FieldAddr); ok {
		return c.initStoreSet(field.X, val, field.Field, allocs)
	}
	element, ok := addr.(*ssa.IndexAddr)
	if !ok {
		return errors.New("todo: init addr: " + addr.String())
	}
	position, ok := element.Index.(*ssa.Const)
	if !ok {
		return errors.New("todo: init IndexAddr index: " + element.Index.String())
	}
	// The exactness flag is ignored, as in the rest of the initializer code.
	n, _ := constant.Int64Val(position.Value)
	return c.initStoreSet(element.X, val, int(n), allocs)
}
func (c *Compiler) initStoreSet(x ssa.Value, val llvm.Value, index int, allocs map[ssa.Value]llvm.Value) error {
if agg, ok := allocs[x]; ok {
agg = c.builder.CreateInsertValue(agg, val, index, "")
allocs[x] = agg
return nil
} else if g, ok := x.(*ssa.Global); ok {
global := c.ir.GetGlobal(g)
agg := global.llvmGlobal.Initializer()
agg = c.builder.CreateInsertValue(agg, val, index, "")
global.llvmGlobal.SetInitializer(agg)
return nil
} else {
agg, err := c.initStoreGet(x, allocs)
func (c *Compiler) parseGlobalInitializer(g *Global) error {
llvmValue, err := c.getInterpretedValue(g.initializer)
if err != nil {
return err
}
agg = c.builder.CreateInsertValue(agg, val, index, "")
return c.initStore(x, agg, allocs)
}
g.llvmGlobal.SetInitializer(llvmValue)
return nil
}
func (c *Compiler) initStoreGet(x ssa.Value, allocs map[ssa.Value]llvm.Value) (llvm.Value, error) {
if val, ok := allocs[x]; ok {
return val, nil
} else if g, ok := x.(*ssa.Global); ok {
return c.ir.GetGlobal(g).llvmGlobal.Initializer(), nil
} else if fa, ok := x.(*ssa.FieldAddr); ok {
val, err := c.initStoreGet(fa.X, allocs)
// Turn a computed Value type (ConstValue, ArrayValue, etc.) into a LLVM value.
// This is used to set the initializer of globals after they have been
// calculated by the package initializer interpreter.
func (c *Compiler) getInterpretedValue(value Value) (llvm.Value, error) {
switch value := value.(type) {
case *ArrayValue:
vals := make([]llvm.Value, len(value.Elems))
for i, elem := range value.Elems {
val, err := c.getInterpretedValue(elem)
if err != nil {
return llvm.Value{}, err
}
return c.builder.CreateExtractValue(val, fa.Field, ""), nil
} else if ia, ok := x.(*ssa.IndexAddr); ok {
val, err := c.initStoreGet(ia.X, allocs)
vals[i] = val
}
subTyp, err := c.getLLVMType(value.ElemType)
if err != nil {
return llvm.Value{}, err
}
if cnst, ok := ia.Index.(*ssa.Const); ok {
index, _ := constant.Int64Val(cnst.Value)
return c.builder.CreateExtractValue(val, int(index), ""), nil
} else {
return llvm.Value{}, errors.New("initStoreGet: unknown FieldAddr index: " + ia.Index.String())
return llvm.ConstArray(subTyp, vals), nil
case *ConstValue:
return c.parseConst(value.Expr)
case *GlobalValue:
zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
ptr := llvm.ConstInBoundsGEP(value.Global.llvmGlobal, []llvm.Value{zero})
return ptr, nil
case *MapValue:
	// Build a constant runtime.hashmap from the interpreted key/value pairs.
	// Entries are placed in insertion order, 8 per bucket, with every full
	// bucket linked to a freshly created one.
	// Create initial bucket.
	firstBucketGlobal, keySize, valueSize, err := c.initMapNewBucket(value.Type)
	if err != nil {
		return llvm.Value{}, err
	}
	// Insert each key/value pair in the hashmap.
	bucketGlobal := firstBucketGlobal
	for i, key := range value.Keys {
		llvmKey, err := c.getInterpretedValue(key)
		if err != nil {
			// Propagate the failure instead of returning a zero llvm.Value
			// together with a nil error.
			return llvm.Value{}, err
		}
		llvmValue, err := c.getInterpretedValue(value.Values[i])
		if err != nil {
			return llvm.Value{}, err
		}
		// The hash is computed from the key's constant string value, so the
		// key has to be a constant.
		constKey, ok := key.(*ConstValue)
		if !ok {
			return llvm.Value{}, errors.New("init: map key is not a constant")
		}
		keyString := constant.StringVal(constKey.Expr.Value)
		hash := stringhash(&keyString)
		if i%8 == 0 && i != 0 {
			// Bucket is full, create a new one.
			newBucketGlobal, _, _, err := c.initMapNewBucket(value.Type)
			if err != nil {
				return llvm.Value{}, err
			}
			zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
			newBucketPtr := llvm.ConstInBoundsGEP(newBucketGlobal, []llvm.Value{zero})
			newBucketPtrCast := llvm.ConstBitCast(newBucketPtr, c.i8ptrType)
			// insert pointer into old bucket
			bucket := bucketGlobal.Initializer()
			bucket = llvm.ConstInsertValue(bucket, newBucketPtrCast, []uint32{1})
			bucketGlobal.SetInitializer(bucket)
			// switch to next bucket
			bucketGlobal = newBucketGlobal
		}
		tophashValue := llvm.ConstInt(llvm.Int8Type(), uint64(hashmapTopHash(hash)), false)
		bucket := bucketGlobal.Initializer()
		bucket = llvm.ConstInsertValue(bucket, tophashValue, []uint32{0, uint32(i % 8)})
		bucket = llvm.ConstInsertValue(bucket, llvmKey, []uint32{2, uint32(i % 8)})
		bucket = llvm.ConstInsertValue(bucket, llvmValue, []uint32{3, uint32(i % 8)})
		bucketGlobal.SetInitializer(bucket)
	}
	// Create the hashmap itself. It must point at the FIRST bucket: the
	// later buckets are only reachable through the chain of next pointers
	// written above, so pointing at the last bucket would drop every entry
	// stored before it.
	zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
	bucketPtr := llvm.ConstInBoundsGEP(firstBucketGlobal, []llvm.Value{zero})
	hashmapType := c.mod.GetTypeByName("runtime.hashmap")
	hashmap := llvm.ConstNamedStruct(hashmapType, []llvm.Value{
		llvm.ConstPointerNull(llvm.PointerType(hashmapType, 0)),  // next
		llvm.ConstBitCast(bucketPtr, c.i8ptrType),                // buckets
		llvm.ConstInt(c.lenType, uint64(len(value.Keys)), false), // count
		llvm.ConstInt(llvm.Int8Type(), keySize, false),           // keySize
		llvm.ConstInt(llvm.Int8Type(), valueSize, false),         // valueSize
		llvm.ConstInt(llvm.Int8Type(), 0, false),                 // bucketBits
	})
	// Create a pointer to this hashmap.
	hashmapPtr := llvm.AddGlobal(c.mod, hashmap.Type(), ".hashmap")
	hashmapPtr.SetInitializer(hashmap)
	hashmapPtr.SetLinkage(llvm.PrivateLinkage)
	return llvm.ConstInBoundsGEP(hashmapPtr, []llvm.Value{zero}), nil
case *PointerBitCastValue:
elem, err := c.getInterpretedValue(value.Elem)
if err != nil {
return llvm.Value{}, err
}
llvmType, err := c.getLLVMType(value.Type)
if err != nil {
return llvm.Value{}, err
}
return llvm.ConstBitCast(elem, llvmType), nil
case *PointerToUintptrValue:
elem, err := c.getInterpretedValue(value.Elem)
if err != nil {
return llvm.Value{}, err
}
return llvm.ConstPtrToInt(elem, c.uintptrType), nil
case *PointerValue:
elem, err := c.getInterpretedValue(*value.Elem)
if err != nil {
return llvm.Value{}, err
}
obj := llvm.AddGlobal(c.mod, elem.Type(), ".obj")
obj.SetInitializer(elem)
obj.SetLinkage(llvm.PrivateLinkage)
elem = obj
zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
ptr := llvm.ConstInBoundsGEP(elem, []llvm.Value{zero})
return ptr, nil
case *SliceValue:
var globalPtr llvm.Value
var arrayLength uint64
if value.Array == nil {
arrayType, err := c.getLLVMType(value.Type.Elem())
if err != nil {
return llvm.Value{}, err
}
globalPtr = llvm.ConstPointerNull(llvm.PointerType(arrayType, 0))
} else {
return llvm.Value{}, errors.New("initStoreGet: unknown value: " + x.String())
// make array
array, err := c.getInterpretedValue(value.Array)
if err != nil {
return llvm.Value{}, err
}
// make global from array
global := llvm.AddGlobal(c.mod, array.Type(), ".array")
global.SetInitializer(array)
global.SetLinkage(llvm.PrivateLinkage)
// get pointer to global
zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
globalPtr = c.builder.CreateInBoundsGEP(global, []llvm.Value{zero, zero}, "")
arrayLength = uint64(len(value.Array.Elems))
}
// make slice
sliceTyp, err := c.getLLVMType(value.Type)
if err != nil {
return llvm.Value{}, err
}
llvmLen := llvm.ConstInt(c.lenType, arrayLength, false)
slice := llvm.ConstNamedStruct(sliceTyp, []llvm.Value{
globalPtr, // ptr
llvmLen, // len
llvmLen, // cap
})
return slice, nil
case *StructValue:
fields := make([]llvm.Value, len(value.Fields))
for i, elem := range value.Fields {
field, err := c.getInterpretedValue(elem)
if err != nil {
return llvm.Value{}, err
}
fields[i] = field
}
switch value.Type.(type) {
case *types.Named:
llvmType, err := c.getLLVMType(value.Type)
if err != nil {
return llvm.Value{}, err
}
return llvm.ConstNamedStruct(llvmType, fields), nil
case *types.Struct:
return llvm.ConstStruct(fields, false), nil
default:
return llvm.Value{}, errors.New("init: unknown struct type: " + value.Type.String())
}
case *ZeroBasicValue:
llvmType, err := c.getLLVMType(value.Type)
if err != nil {
return llvm.Value{}, err
}
return getZeroValue(llvmType)
default:
return llvm.Value{}, errors.New("init: unknown initializer type: " + fmt.Sprintf("%#v", value))
}
}
@ -1110,6 +1018,10 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error {
switch instr := instr.(type) {
case ssa.Value:
value, err := c.parseExpr(frame, instr)
if err == cgoWrapperError {
// Ignore CGo global variables which we don't use.
return nil
}
frame.locals[instr] = value
return err
case *ssa.Go:
@ -1215,6 +1127,10 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error {
}
case *ssa.Store:
llvmAddr, err := c.parseExpr(frame, instr.Addr)
if err == cgoWrapperError {
// Ignore CGo global variables which we don't use.
return nil
}
if err != nil {
return err
}
@ -1544,6 +1460,10 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
case *ssa.Function:
return c.mod.NamedFunction(c.ir.GetFunction(expr).LinkName(false)), nil
case *ssa.Global:
if strings.HasPrefix(expr.Name(), "__cgofn__cgo_") || strings.HasPrefix(expr.Name(), "_cgo_") {
// Ignore CGo global variables which we don't use.
return llvm.Value{}, cgoWrapperError
}
value := c.ir.GetGlobal(expr).llvmGlobal
if value.IsNil() {
return llvm.Value{}, errors.New("global not found: " + c.ir.GetGlobal(expr).LinkName())

294
interpreter.go Обычный файл
Просмотреть файл

@ -0,0 +1,294 @@
package main
// This file provides functionality to interpret very basic Go SSA, for
// compile-time initialization of globals.
import (
"errors"
"go/constant"
"go/token"
"go/types"
"strings"
"golang.org/x/tools/go/ssa"
)
// Interpret runs the initializer interpreter over block and removes every
// instruction it handled, so that only the remainder is left for the
// compiler. An instruction that refers to an ignored CGo wrapper global is
// dropped as well, after which interpretation starts again on what follows.
func (p *Program) Interpret(block *ssa.BasicBlock) error {
	stop, err := p.interpret(block.Instrs)
	for err == cgoWrapperError {
		// Skip the offending instruction and carry on with the rest.
		block.Instrs = block.Instrs[stop+1:]
		stop, err = p.interpret(block.Instrs)
	}
	block.Instrs = block.Instrs[stop:]
	return err
}
// interpret executes instrs in order for as long as it understands them. It
// returns the index of the first instruction that was not executed, together
// with the error (if any) that stopped it; a nil error means it simply reached
// an instruction type it does not handle.
func (p *Program) interpret(instrs []ssa.Instruction) (int, error) {
	// Values computed so far, keyed by the SSA value that produced them.
	known := make(map[ssa.Value]Value)
	for pos, instruction := range instrs {
		switch op := instruction.(type) {
		case *ssa.Alloc:
			// An allocation is a pointer to a fresh zero value.
			elemType := op.Type().Underlying().(*types.Pointer).Elem()
			zero, err := p.getZeroValue(elemType)
			if err != nil {
				return pos, err
			}
			known[op] = &PointerValue{Elem: &zero}

		case *ssa.Convert:
			operand, err := p.getValue(op.X, known)
			if err != nil {
				return pos, err
			}
			from := op.X.Type().Underlying()
			switch to := op.Type().Underlying().(type) {
			case *types.Basic:
				_, fromPointer := from.(*types.Pointer)
				fromBasic, fromIsBasic := from.(*types.Basic)
				switch {
				case fromPointer && to.Kind() == types.UnsafePointer:
					// *T -> unsafe.Pointer
					known[op] = &PointerBitCastValue{Type: to, Elem: operand}
				case fromIsBasic && fromBasic.Kind() == types.UnsafePointer && to.Kind() == types.Uintptr:
					// unsafe.Pointer -> uintptr
					known[op] = &PointerToUintptrValue{Elem: operand}
				default:
					return pos, errors.New("todo: init: unknown basic convert: " + op.String())
				}
			case *types.Pointer:
				// unsafe.Pointer -> *T is the only pointer conversion expected.
				fromBasic, ok := from.(*types.Basic)
				if !ok || fromBasic.Kind() != types.UnsafePointer {
					panic("expected unsafe pointer conversion")
				}
				known[op] = &PointerBitCastValue{Type: to, Elem: operand}
			default:
				return pos, errors.New("todo: init: unknown convert: " + op.String())
			}

		case *ssa.FieldAddr:
			base, err := p.getValue(op.X, known)
			if err != nil {
				return pos, err
			}
			var target *StructValue
			switch ptr := base.(type) {
			case *GlobalValue:
				target = ptr.Global.initializer.(*StructValue)
			case *PointerValue:
				target = (*ptr.Elem).(*StructValue)
			default:
				panic("expected a pointer")
			}
			// Point straight into the struct so that stores update it in place.
			known[op] = &PointerValue{Elem: &target.Fields[op.Field]}

		case *ssa.IndexAddr:
			base, err := p.getValue(op.X, known)
			if err != nil {
				return pos, err
			}
			indexConst, ok := op.Index.(*ssa.Const)
			if !ok {
				return pos, errors.New("todo: init IndexAddr index: " + op.Index.String())
			}
			index, _ := constant.Int64Val(indexConst.Value)
			var pointee Value
			switch ptr := base.(type) {
			case *GlobalValue:
				pointee = ptr.Global.initializer
			case *PointerValue:
				pointee = *ptr.Elem
			default:
				panic("expected a pointer")
			}
			array, ok := pointee.(*ArrayValue)
			if !ok {
				return pos, errors.New("todo: init IndexAddr not on an array or struct")
			}
			// Point straight into the array so that stores update it in place.
			known[op] = &PointerValue{Elem: &array.Elems[index]}

		case *ssa.UnOp:
			// Only a plain dereference (*x without comma-ok) is understood.
			if !(op.Op == token.MUL && !op.CommaOk) {
				return pos, errors.New("init: unknown unop: " + op.String())
			}
			operand, err := p.getValue(op.X, known)
			if err != nil {
				return pos, err
			}
			switch ptr := operand.(type) {
			case *GlobalValue:
				known[op] = ptr.Global.initializer
			case *PointerValue:
				known[op] = *ptr.Elem
			default:
				panic("expected a pointer")
			}

		case *ssa.MakeMap:
			// Start with an empty map; entries are appended by MapUpdate.
			known[op] = &MapValue{Type: op.Type().Underlying().(*types.Map)}

		case *ssa.MapUpdate:
			// Duplicate keys are assumed not to occur. That is most likely
			// true for autogenerated code, but may not hold for user code.
			k, err := p.getValue(op.Key, known)
			if err != nil {
				return pos, err
			}
			v, err := p.getValue(op.Value, known)
			if err != nil {
				return pos, err
			}
			m := known[op.Map].(*MapValue)
			m.Keys = append(m.Keys, k)
			m.Values = append(m.Values, v)

		case *ssa.Slice:
			// Only slicing a whole, just-allocated array is supported.
			if instrHasBounds := op.Low != nil || op.High != nil || op.Max != nil; instrHasBounds {
				return pos, errors.New("init: slice expression with bounds")
			}
			operand, err := p.getValue(op.X, known)
			if err != nil {
				return pos, err
			}
			ptr, ok := operand.(*PointerValue) // pointer to array
			if !ok {
				return pos, errors.New("init: unknown slice type")
			}
			backing := (*ptr.Elem).(*ArrayValue)
			sliceType := op.Type().Underlying().(*types.Slice)
			known[op] = &SliceValue{Type: sliceType, Array: backing}

		case *ssa.Store:
			if global, ok := op.Addr.(*ssa.Global); ok {
				name := op.Addr.Name()
				if strings.HasPrefix(name, "__cgofn__cgo_") || strings.HasPrefix(name, "_cgo_") {
					// Ignore CGo global variables which we don't use.
					continue
				}
				stored, err := p.getValue(op.Val, known)
				if err != nil {
					return pos, err
				}
				p.GetGlobal(global).initializer = stored
				continue
			}
			dest, ok := known[op.Addr]
			if !ok {
				return pos, errors.New("todo: init Store: " + op.String())
			}
			stored, err := p.getValue(op.Val, known)
			if err != nil {
				return pos, err
			}
			ptr, ok := dest.(*PointerValue)
			if !ok {
				panic("store to non-pointer")
			}
			*ptr.Elem = stored

		default:
			// First instruction that is not understood: stop here.
			return pos, nil
		}
	}
	return len(instrs), nil // normally unreachable
}
// getValue returns the interpreter value for an SSA operand. Constants are
// wrapped, globals are returned as a GlobalValue (after giving the global a
// zero initializer if it has none yet), and anything else has to be present in
// locals. Globals belonging to ignored CGo wrappers yield cgoWrapperError.
func (p *Program) getValue(value ssa.Value, locals map[ssa.Value]Value) (Value, error) {
	if cnst, ok := value.(*ssa.Const); ok {
		return &ConstValue{Expr: cnst}, nil
	}

	ssaGlobal, ok := value.(*ssa.Global)
	if !ok {
		if known, found := locals[value]; found {
			return known, nil
		}
		return nil, errors.New("todo: init: unknown value: " + value.String())
	}

	name := ssaGlobal.Name()
	if strings.HasPrefix(name, "__cgofn__cgo_") || strings.HasPrefix(name, "_cgo_") {
		// Ignore CGo global variables which we don't use.
		return nil, cgoWrapperError
	}
	global := p.GetGlobal(ssaGlobal)
	if global.initializer == nil {
		// First use of this global: start from the zero value of its type.
		zero, err := p.getZeroValue(ssaGlobal.Type().Underlying().(*types.Pointer).Elem())
		if err != nil {
			return nil, err
		}
		global.initializer = zero
	}
	return &GlobalValue{Global: global}, nil
}
// getZeroValue returns the interpreter representation of the zero value of
// type t. Arrays and structs are built recursively from the zero values of
// their elements and fields. An error is returned for kinds of types that are
// not supported.
func (p *Program) getZeroValue(t types.Type) (Value, error) {
	switch typ := t.Underlying().(type) {
	case *types.Array:
		elems := make([]Value, typ.Len())
		for i := range elems {
			elem, err := p.getZeroValue(typ.Elem())
			if err != nil {
				return nil, err
			}
			elems[i] = elem
		}
		return &ArrayValue{typ.Elem(), elems}, nil
	case *types.Basic:
		return &ZeroBasicValue{typ}, nil
	case *types.Pointer:
		// A nil Elem represents the nil pointer.
		return &PointerValue{nil}, nil
	case *types.Struct:
		elems := make([]Value, typ.NumFields())
		for i := range elems {
			// Pass the declared field type rather than its underlying type.
			// The recursive call takes Underlying() itself, and stripping the
			// name here would make a nested named struct indistinguishable
			// from an anonymous one in the StructValue it produces.
			elem, err := p.getZeroValue(typ.Field(i).Type())
			if err != nil {
				return nil, err
			}
			elems[i] = elem
		}
		// Keep t (not typ) so that a named struct type is preserved.
		return &StructValue{t, elems}, nil
	case *types.Slice:
		// A nil Array represents the nil slice.
		return &SliceValue{typ, nil}, nil
	default:
		return nil, errors.New("todo: init: unknown global type: " + typ.String())
	}
}
// Value model of the initializer interpreter.
type (
	// Value is a boxed value for the interpreter. In this file it always
	// holds a pointer to one of the *Value struct types declared below.
	Value interface{}

	// ConstValue is a constant taken directly from the SSA.
	ConstValue struct {
		Expr *ssa.Const
	}

	// ZeroBasicValue is the zero value of a basic type.
	ZeroBasicValue struct {
		Type *types.Basic
	}

	// PointerValue is a pointer to another interpreter value. The zero value
	// of a pointer type is represented with a nil Elem.
	PointerValue struct {
		Elem *Value
	}

	// PointerBitCastValue is the result of converting the pointer Elem
	// to or from unsafe.Pointer; Type is the type converted to.
	PointerBitCastValue struct {
		Type types.Type
		Elem Value
	}

	// PointerToUintptrValue is the result of converting the unsafe.Pointer
	// Elem to uintptr.
	PointerToUintptrValue struct {
		Elem Value
	}

	// GlobalValue refers to a global variable of the program.
	GlobalValue struct {
		Global *Global
	}

	// ArrayValue is an array with one Value per element.
	ArrayValue struct {
		ElemType types.Type
		Elems    []Value
	}

	// StructValue is a struct with one Value per field.
	StructValue struct {
		Type   types.Type // types.Struct or types.Named
		Fields []Value
	}

	// SliceValue is a slice over the whole of Array. The zero value of a
	// slice type is represented with a nil Array.
	SliceValue struct {
		Type  *types.Slice
		Array *ArrayValue
	}

	// MapValue is a map under construction. Keys and Values are parallel
	// lists in insertion order.
	MapValue struct {
		Type   *types.Map
		Keys   []Value
		Values []Value
	}
)

1
ir.go
Просмотреть файл

@ -41,6 +41,7 @@ type Global struct {
g *ssa.Global
llvmGlobal llvm.Value
flag bool // used by dead code elimination
initializer Value
}
// Type with a name and possibly methods.