diff --git a/Gopkg.lock b/Gopkg.lock index e42c4f7f..2e8e47ce 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -3,11 +3,11 @@ [[projects]] branch = "master" - digest = "1:924b93e20c37a913b61d5cdb895e3c058d5e88e0baebdfa5813dcba4025679c9" + digest = "1:f250e2a6d7e4f9ebc5ba37e5e2ec91b46eb1399ee43f2fdaeb20cd4fd1aeee59" name = "github.com/aykevl/go-llvm" packages = ["."] pruneopts = "UT" - revision = "34571cdf380c5426708115e647b6fe0bb3bee3b5" + revision = "d8539684f173a591ea9474d6262ac47ef2277d64" [[projects]] branch = "master" diff --git a/Makefile b/Makefile index 01476b32..283a1414 100644 --- a/Makefile +++ b/Makefile @@ -62,7 +62,7 @@ clean: @rm -rf build fmt: - @go fmt . ./compiler ./ir ./src/device/arm ./src/examples/* ./src/machine ./src/runtime ./src/sync + @go fmt . ./compiler ./interp ./ir ./src/device/arm ./src/examples/* ./src/machine ./src/runtime ./src/sync @go fmt ./testdata/*.go test: diff --git a/compiler/compiler.go b/compiler/compiler.go index c030418c..efbdbccd 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -31,12 +31,13 @@ func init() { // Configure the compiler. type Config struct { - Triple string // LLVM target triple, e.g. x86_64-unknown-linux-gnu (empty string means default) - DumpSSA bool // dump Go SSA, for compiler debugging - Debug bool // add debug symbols for gdb - RootDir string // GOROOT for TinyGo - GOPATH string // GOPATH, like `go env GOPATH` - BuildTags []string // build tags for TinyGo (empty means {runtime.GOOS/runtime.GOARCH}) + Triple string // LLVM target triple, e.g. x86_64-unknown-linux-gnu (empty string means default) + DumpSSA bool // dump Go SSA, for compiler debugging + Debug bool // add debug symbols for gdb + RootDir string // GOROOT for TinyGo + GOPATH string // GOPATH, like `go env GOPATH` + BuildTags []string // build tags for TinyGo (empty means {runtime.GOOS/runtime.GOARCH}) + InitInterp bool // use new init interpretation, meaning the old one is disabled } type Compiler struct { @@ -164,6 +165,11 @@ func (c *Compiler) Module() llvm.Module { return c.mod } +// Return the LLVM target data object. Only valid after a successful compile. +func (c *Compiler) TargetData() llvm.TargetData { + return c.targetData +} + // Compile the given package path or .go file path. Return an error when this // fails (in any stage). func (c *Compiler) Compile(mainPath string) error { @@ -295,9 +301,11 @@ func (c *Compiler) Compile(mainPath string) error { // continues at runtime). // This should only happen when it hits a function call or the end // of the block, ideally. - err := c.ir.Interpret(frame.fn.Blocks[0], c.DumpSSA) - if err != nil { - return err + if !c.InitInterp { + err := c.ir.Interpret(frame.fn.Blocks[0], c.DumpSSA) + if err != nil { + return err + } } err = c.parseFunc(frame) if err != nil { diff --git a/interp/README.md b/interp/README.md new file mode 100644 index 00000000..7a2d4184 --- /dev/null +++ b/interp/README.md @@ -0,0 +1,64 @@ +# Partial evaluation of initialization code in Go + +For several reasons related to code size and memory consumption (see below), it +is best to try to evaluate as much initialization code at compile time as +possible and only run unknown expressions (e.g. external calls) at runtime. This +is in practice a partial evaluator of the `runtime.initAll` function, which +calls each package initializer. + +It works by directly interpreting LLVM IR: + + * Almost all operations work directly on constants, and are implemented using + the llvm.Const* set of functions that are evaluated directly. + * External function calls and some other operations (inline assembly, volatile + load, volatile store) are seen as having limited side effects. Limited in + the sense that it is known at compile time which globals it affects, which + then are marked 'dirty' (meaning, further operations on it must be done at + runtime). These operations are emitted directly in the `runtime.initAll` + function. Return values are also considered 'dirty'. + * Such 'dirty' objects and local values must be executed at runtime instead of + at compile time. This dirtyness propagates further through the IR, for + example storing a dirty local value to a global also makes the global dirty, + meaning that the global may not be read or written at compile time as it's + contents at that point during interpretation is unknown. + * There are some heuristics in place to avoid doing too much with dirty + values. For example, a branch based on a dirty local marks the whole + function itself as having side effect (as if it is an external function). + However, all globals it touches are still taken into account and when a call + is inserted in `runtime.initAll`, all globals it references are also marked + dirty. + * Heap allocation (`runtime.alloc`) is emulated by creating new objects. The + value in the allocation is the initializer of the global, the zero value is + the zero initializer. + * Stack allocation (`alloca`) is often emulated using a fake alloca object, + until the address of the alloca is taken in which case it is also created as + a real `alloca` in `runtime.initAll` and marked dirty. This may be necessary + when calling an external function with the given alloca as paramter. + +## Why is this necessary? + +A partial evaluator is hard to get right, so why go through all the trouble of +writing one? + +The main reason is that the previous attempt wasn't complete and wasn't sound. +It simply tried to evaluate Go SSA directly, which was good but more difficult +than necessary. An IR based interpreter needs to understand fewer instructions +as the LLVM IR simply has less (complex) instructions than Go SSA. Also, LLVM +provides some useful tools like easily getting all uses of a function or global, +which Go SSA does not provide. + +But why is it necessary at all? The answer is that globals with initializers are +much easier to optimize by LLVM than initialization code. Also, there are a few +other benefits: + + * Dead globals are trivial to optimize away. + * Constant globals are easier to detect. Remember that Go does not have global + constants in the same sense as that C has them. Constants are useful because + they can be propagated and provide some opportunities for other + optimizations (like dead code elimination when branching on the contents of + a global). + * Constants are much more efficent on microcontrollers, as they can be + allocated in flash instead of RAM. + +For more details, see [this section of the +documentation](https://tinygo.readthedocs.io/en/latest/internals.html#differences-from-go). diff --git a/interp/errors.go b/interp/errors.go new file mode 100644 index 00000000..d720bbf0 --- /dev/null +++ b/interp/errors.go @@ -0,0 +1,17 @@ +package interp + +// This file provides useful types for errors encountered during IR evaluation. + +import ( + "github.com/aykevl/go-llvm" +) + +type Unsupported struct { + Inst llvm.Value +} + +func (e Unsupported) Error() string { + // TODO: how to return the actual instruction string? + // It looks like LLVM provides no function for that... + return "interp: unsupported instruction" +} diff --git a/interp/frame.go b/interp/frame.go new file mode 100644 index 00000000..860ec02d --- /dev/null +++ b/interp/frame.go @@ -0,0 +1,440 @@ +package interp + +// This file implements the core interpretation routines, interpreting single +// functions. + +import ( + "errors" + "strings" + + "github.com/aykevl/go-llvm" +) + +type frame struct { + *Eval + fn llvm.Value + pkgName string + locals map[llvm.Value]Value +} + +// evalBasicBlock evaluates a single basic block, returning the return value (if +// ending with a ret instruction), a list of outgoing basic blocks (if not +// ending with a ret instruction), or an error on failure. +// Most of it works at compile time. Some calls get translated into calls to be +// executed at runtime: calls to functions with side effects, external calls, +// and operations on the result of such instructions. +func (fr *frame) evalBasicBlock(bb, incoming llvm.BasicBlock, indent string) (retval Value, outgoing []llvm.Value, err error) { + for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { + if fr.Debug { + print(indent) + inst.Dump() + println() + } + switch { + case !inst.IsABinaryOperator().IsNil(): + lhs := fr.getLocal(inst.Operand(0)).(*LocalValue).Underlying + rhs := fr.getLocal(inst.Operand(1)).(*LocalValue).Underlying + + switch inst.InstructionOpcode() { + // Standard binary operators + case llvm.Add: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstAdd(lhs, rhs)} + case llvm.FAdd: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstFAdd(lhs, rhs)} + case llvm.Sub: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstSub(lhs, rhs)} + case llvm.FSub: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstFSub(lhs, rhs)} + case llvm.Mul: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstMul(lhs, rhs)} + case llvm.FMul: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstFMul(lhs, rhs)} + case llvm.UDiv: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstUDiv(lhs, rhs)} + case llvm.SDiv: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstSDiv(lhs, rhs)} + case llvm.FDiv: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstFDiv(lhs, rhs)} + case llvm.URem: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstURem(lhs, rhs)} + case llvm.SRem: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstSRem(lhs, rhs)} + case llvm.FRem: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstFRem(lhs, rhs)} + + // Logical operators + case llvm.Shl: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstShl(lhs, rhs)} + case llvm.LShr: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstLShr(lhs, rhs)} + case llvm.AShr: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstAShr(lhs, rhs)} + case llvm.And: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstAnd(lhs, rhs)} + case llvm.Or: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstOr(lhs, rhs)} + case llvm.Xor: + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstXor(lhs, rhs)} + + default: + return nil, nil, &Unsupported{inst} + } + + // Memory operators + case !inst.IsAAllocaInst().IsNil(): + fr.locals[inst] = &AllocaValue{ + Underlying: getZeroValue(inst.Type().ElementType()), + Dirty: false, + Eval: fr.Eval, + } + case !inst.IsALoadInst().IsNil(): + operand := fr.getLocal(inst.Operand(0)) + var value llvm.Value + if inst.IsVolatile() { + value = fr.builder.CreateLoad(operand.Value(), inst.Name()) + } else { + value = operand.Load() + } + if value.Type() != inst.Type() { + panic("interp: load: type does not match") + } + fr.locals[inst] = fr.getValue(value) + case !inst.IsAStoreInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + ptr := fr.getLocal(inst.Operand(1)) + if inst.IsVolatile() { + fr.builder.CreateStore(value.Value(), ptr.Value()) + } else { + ptr.Store(value.Value()) + } + case !inst.IsAGetElementPtrInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + llvmIndices := make([]llvm.Value, inst.OperandsCount()-1) + for i := range llvmIndices { + llvmIndices[i] = inst.Operand(i + 1) + } + indices := make([]uint32, len(llvmIndices)) + for i, operand := range llvmIndices { + if !operand.IsConstant() { + // not a constant operation, emit a low-level GEP + gep := fr.builder.CreateGEP(value.Value(), llvmIndices, inst.Name()) + fr.locals[inst] = &LocalValue{fr.Eval, gep} + continue + } + indices[i] = uint32(operand.ZExtValue()) + } + result := value.GetElementPtr(indices) + if result.Type() != inst.Type() { + println(" expected:", inst.Type().String()) + println(" actual: ", result.Type().String()) + panic("interp: gep: type does not match") + } + fr.locals[inst] = result + + // Cast operators + case !inst.IsATruncInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstTrunc(value.(*LocalValue).Value(), inst.Type())} + case !inst.IsAZExtInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstZExt(value.(*LocalValue).Value(), inst.Type())} + case !inst.IsASExtInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstSExt(value.(*LocalValue).Value(), inst.Type())} + case !inst.IsAFPToUIInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstFPToUI(value.(*LocalValue).Value(), inst.Type())} + case !inst.IsAFPToSIInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstFPToSI(value.(*LocalValue).Value(), inst.Type())} + case !inst.IsAUIToFPInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstUIToFP(value.(*LocalValue).Value(), inst.Type())} + case !inst.IsASIToFPInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstSIToFP(value.(*LocalValue).Value(), inst.Type())} + case !inst.IsAFPTruncInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstFPTrunc(value.(*LocalValue).Value(), inst.Type())} + case !inst.IsAFPExtInst().IsNil(): + value := fr.getLocal(inst.Operand(0)) + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstFPExt(value.(*LocalValue).Value(), inst.Type())} + case !inst.IsABitCastInst().IsNil() && inst.Type().TypeKind() == llvm.PointerTypeKind: + operand := inst.Operand(0) + if !operand.IsACallInst().IsNil() { + fn := operand.CalledValue() + if !fn.IsAFunction().IsNil() && fn.Name() == "runtime.alloc" { + continue // special case: bitcast of alloc + } + } + value := fr.getLocal(operand) + if bc, ok := value.(*PointerCastValue); ok { + value = bc.Underlying // avoid double bitcasts + } + fr.locals[inst] = &PointerCastValue{Eval: fr.Eval, Underlying: value, CastType: inst.Type()} + + // Other operators + case !inst.IsAICmpInst().IsNil(): + lhs := fr.getLocal(inst.Operand(0)).(*LocalValue).Underlying + rhs := fr.getLocal(inst.Operand(1)).(*LocalValue).Underlying + predicate := inst.IntPredicate() + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstICmp(predicate, lhs, rhs)} + case !inst.IsAFCmpInst().IsNil(): + lhs := fr.getLocal(inst.Operand(0)).(*LocalValue).Underlying + rhs := fr.getLocal(inst.Operand(1)).(*LocalValue).Underlying + predicate := inst.FloatPredicate() + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstFCmp(predicate, lhs, rhs)} + case !inst.IsAPHINode().IsNil(): + for i := 0; i < inst.IncomingCount(); i++ { + if inst.IncomingBlock(i) == incoming { + fr.locals[inst] = fr.getLocal(inst.IncomingValue(i)) + } + } + case !inst.IsACallInst().IsNil(): + callee := inst.CalledValue() + switch { + case callee.Name() == "runtime.alloc": + // heap allocation + users := getUses(inst) + var resultInst = inst + if len(users) == 1 && !users[0].IsABitCastInst().IsNil() { + // happens when allocating something other than i8* + resultInst = users[0] + } + size := fr.getLocal(inst.Operand(0)).(*LocalValue).Underlying.ZExtValue() + allocType := resultInst.Type().ElementType() + typeSize := fr.TargetData.TypeAllocSize(allocType) + elementCount := 1 + if size != typeSize { + // allocate an array + if size%typeSize != 0 { + return nil, nil, &Unsupported{inst} + } + elementCount = int(size / typeSize) + allocType = llvm.ArrayType(allocType, elementCount) + } + alloc := llvm.AddGlobal(fr.Mod, allocType, fr.pkgName+"$alloc") + alloc.SetInitializer(getZeroValue(allocType)) + result := &GlobalValue{ + Underlying: alloc, + Eval: fr.Eval, + } + if elementCount == 1 { + fr.locals[resultInst] = result + } else { + fr.locals[resultInst] = result.GetElementPtr([]uint32{0, 0}) + } + case callee.Name() == "runtime.hashmapMake": + // create a map + keySize := inst.Operand(0).ZExtValue() + valueSize := inst.Operand(1).ZExtValue() + fr.locals[inst] = &MapValue{ + Eval: fr.Eval, + PkgName: fr.pkgName, + KeySize: int(keySize), + ValueSize: int(valueSize), + } + case callee.Name() == "runtime.hashmapStringSet": + // set a string key in the map + m := fr.getLocal(inst.Operand(0)).(*MapValue) + keyBuf := fr.getLocal(inst.Operand(1)) + keyLen := fr.getLocal(inst.Operand(2)) + valPtr := fr.getLocal(inst.Operand(3)) + m.PutString(keyBuf, keyLen, valPtr) + case callee.Name() == "runtime.hashmapBinarySet": + // set a binary (int etc.) key in the map + // TODO: unimplemented + case callee.Name() == "runtime.stringConcat": + // adding two strings together + buf1Ptr := fr.getLocal(inst.Operand(0)) + buf1Len := fr.getLocal(inst.Operand(1)) + buf2Ptr := fr.getLocal(inst.Operand(2)) + buf2Len := fr.getLocal(inst.Operand(3)) + buf1 := getStringBytes(buf1Ptr, buf1Len.Value()) + buf2 := getStringBytes(buf2Ptr, buf2Len.Value()) + result := []byte(string(buf1) + string(buf2)) + vals := make([]llvm.Value, len(result)) + for i := range vals { + vals[i] = llvm.ConstInt(fr.Mod.Context().Int8Type(), uint64(result[i]), false) + } + globalType := llvm.ArrayType(fr.Mod.Context().Int8Type(), len(result)) + globalValue := llvm.ConstArray(fr.Mod.Context().Int8Type(), vals) + global := llvm.AddGlobal(fr.Mod, globalType, fr.pkgName+"$stringconcat") + global.SetInitializer(globalValue) + global.SetLinkage(llvm.InternalLinkage) + global.SetGlobalConstant(true) + global.SetUnnamedAddr(true) + stringType := fr.Mod.GetTypeByName("runtime._string") + retPtr := llvm.ConstGEP(global, getLLVMIndices(fr.Mod.Context().Int32Type(), []uint32{0, 0})) + retLen := llvm.ConstInt(stringType.StructElementTypes()[1], uint64(len(result)), false) + ret := getZeroValue(stringType) + ret = llvm.ConstInsertValue(ret, retPtr, []uint32{0}) + ret = llvm.ConstInsertValue(ret, retLen, []uint32{1}) + fr.locals[inst] = &LocalValue{fr.Eval, ret} + case callee.Name() == "runtime.stringToBytes": + // convert a string to a []byte + bufPtr := fr.getLocal(inst.Operand(0)) + bufLen := fr.getLocal(inst.Operand(1)) + result := getStringBytes(bufPtr, bufLen.Value()) + vals := make([]llvm.Value, len(result)) + for i := range vals { + vals[i] = llvm.ConstInt(fr.Mod.Context().Int8Type(), uint64(result[i]), false) + } + globalType := llvm.ArrayType(fr.Mod.Context().Int8Type(), len(result)) + globalValue := llvm.ConstArray(fr.Mod.Context().Int8Type(), vals) + global := llvm.AddGlobal(fr.Mod, globalType, fr.pkgName+"$bytes") + global.SetInitializer(globalValue) + global.SetLinkage(llvm.InternalLinkage) + global.SetGlobalConstant(true) + global.SetUnnamedAddr(true) + sliceType := inst.Type() + retPtr := llvm.ConstGEP(global, getLLVMIndices(fr.Mod.Context().Int32Type(), []uint32{0, 0})) + retLen := llvm.ConstInt(sliceType.StructElementTypes()[1], uint64(len(result)), false) + ret := getZeroValue(sliceType) + ret = llvm.ConstInsertValue(ret, retPtr, []uint32{0}) // ptr + ret = llvm.ConstInsertValue(ret, retLen, []uint32{1}) // len + ret = llvm.ConstInsertValue(ret, retLen, []uint32{2}) // cap + fr.locals[inst] = &LocalValue{fr.Eval, ret} + case strings.HasPrefix(callee.Name(), "runtime.print") || callee.Name() == "runtime._panic": + // all print instructions, which necessarily have side + // effects but no results + var params []llvm.Value + for i := 0; i < inst.OperandsCount()-1; i++ { + operand := fr.getLocal(inst.Operand(i)).Value() + fr.markDirty(operand) + params = append(params, operand) + } + // TODO: accurate debug info, including call chain + fr.builder.CreateCall(callee, params, inst.Name()) + case !callee.IsAFunction().IsNil() && callee.IsDeclaration(): + // external functions + var params []llvm.Value + for i := 0; i < inst.OperandsCount()-1; i++ { + operand := fr.getLocal(inst.Operand(i)).Value() + fr.markDirty(operand) + params = append(params, operand) + } + // TODO: accurate debug info, including call chain + result := fr.builder.CreateCall(callee, params, inst.Name()) + if inst.Type().TypeKind() != llvm.VoidTypeKind { + fr.markDirty(result) + fr.locals[inst] = &LocalValue{fr.Eval, result} + } + case !callee.IsAFunction().IsNil(): + // regular function + var params []Value + for i := 0; i < inst.OperandsCount()-1; i++ { + params = append(params, fr.getLocal(inst.Operand(i))) + } + var ret Value + scanResult := fr.Eval.hasSideEffects(callee) + if scanResult.severity == sideEffectLimited { + // Side effect is bounded. This means the operation invokes + // side effects (like calling an external function) but it + // is known at compile time which side effects it invokes. + // This means the function can be called at runtime and the + // affected globals can be marked dirty at compile time. + llvmParams := make([]llvm.Value, len(params)) + for i, param := range params { + llvmParams[i] = param.Value() + } + result := fr.builder.CreateCall(callee, llvmParams, inst.Name()) + ret = &LocalValue{fr.Eval, result} + // mark all mentioned globals as dirty + for global := range scanResult.mentionsGlobals { + fr.markDirty(global) + } + } else { + // Side effect is one of: + // * None: no side effects, can be fully interpreted at + // compile time. + // * Unbounded: cannot call at runtime so we'll try to + // interpret anyway and hope for the best. + ret, err = fr.function(callee, params, fr.pkgName, indent+" ") + if err != nil { + return nil, nil, err + } + } + if inst.Type().TypeKind() != llvm.VoidTypeKind { + fr.locals[inst] = ret + } + default: + // function pointers, etc. + return nil, nil, &Unsupported{inst} + } + case !inst.IsAExtractValueInst().IsNil(): + agg := fr.getLocal(inst.Operand(0)).(*LocalValue) // must be constant + indices := inst.Indices() + if agg.Underlying.IsConstant() { + newValue := llvm.ConstExtractValue(agg.Underlying, indices) + fr.locals[inst] = fr.getValue(newValue) + } else { + if len(indices) != 1 { + return nil, nil, errors.New("cannot handle extractvalue with not exactly 1 index") + } + fr.locals[inst] = &LocalValue{fr.Eval, fr.builder.CreateExtractValue(agg.Underlying, int(indices[0]), inst.Name())} + } + case !inst.IsAInsertValueInst().IsNil(): + agg := fr.getLocal(inst.Operand(0)).(*LocalValue) // must be constant + val := fr.getLocal(inst.Operand(1)) + indices := inst.Indices() + if agg.IsConstant() && val.IsConstant() { + newValue := llvm.ConstInsertValue(agg.Underlying, val.Value(), indices) + fr.locals[inst] = &LocalValue{fr.Eval, newValue} + } else { + if len(indices) != 1 { + return nil, nil, errors.New("cannot handle insertvalue with not exactly 1 index") + } + fr.locals[inst] = &LocalValue{fr.Eval, fr.builder.CreateInsertValue(agg.Underlying, val.Value(), int(indices[0]), inst.Name())} + } + + case !inst.IsAReturnInst().IsNil() && inst.OperandsCount() == 0: + return nil, nil, nil // ret void + case !inst.IsAReturnInst().IsNil() && inst.OperandsCount() == 1: + return fr.getLocal(inst.Operand(0)), nil, nil + case !inst.IsABranchInst().IsNil() && inst.OperandsCount() == 3: + // conditional branch (if/then/else) + cond := fr.getLocal(inst.Operand(0)).Value() + if cond.Type() != fr.Mod.Context().Int1Type() { + panic("expected an i1 in a branch instruction") + } + thenBB := inst.Operand(1) + elseBB := inst.Operand(2) + if !cond.IsConstant() { + return nil, nil, errors.New("interp: branch on a non-constant") + } else { + switch cond.ZExtValue() { + case 0: // false + return nil, []llvm.Value{thenBB}, nil // then + case 1: // true + return nil, []llvm.Value{elseBB}, nil // else + default: + panic("branch was not true or false") + } + } + case !inst.IsABranchInst().IsNil() && inst.OperandsCount() == 1: + // unconditional branch (goto) + return nil, []llvm.Value{inst.Operand(0)}, nil + case !inst.IsAUnreachableInst().IsNil(): + // unreachable was reached (e.g. after a call to panic()) + // assume this is actually unreachable when running + return &LocalValue{fr.Eval, llvm.Undef(fr.fn.Type())}, nil, nil + + default: + return nil, nil, &Unsupported{inst} + } + } + + panic("interp: reached end of basic block without terminator") +} + +// Get the Value for an operand, which is a constant value of some sort. +func (fr *frame) getLocal(v llvm.Value) Value { + if ret, ok := fr.locals[v]; ok { + return ret + } else if value := fr.getValue(v); value != nil { + return value + } else { + panic("cannot find value") + } +} diff --git a/interp/interp.go b/interp/interp.go new file mode 100644 index 00000000..65a60855 --- /dev/null +++ b/interp/interp.go @@ -0,0 +1,145 @@ +// Package interp interprets Go package initializers as much as possible. This +// avoid running them at runtime, improving code size and making other +// optimizations possible. +package interp + +// This file provides the overarching Eval object with associated (utility) +// methods. + +import ( + "errors" + "strings" + + "github.com/aykevl/go-llvm" +) + +type Eval struct { + Mod llvm.Module + TargetData llvm.TargetData + Debug bool + builder llvm.Builder + dibuilder *llvm.DIBuilder + dirtyGlobals map[llvm.Value]struct{} + sideEffectFuncs map[llvm.Value]*sideEffectResult // cache of side effect scan results +} + +// Run evaluates the function with the given name and then eliminates all +// callers. +func Run(mod llvm.Module, targetData llvm.TargetData, debug bool) error { + if debug { + println("\ncompile-time evaluation:") + } + + name := "runtime.initAll" + e := &Eval{ + Mod: mod, + TargetData: targetData, + Debug: debug, + dirtyGlobals: map[llvm.Value]struct{}{}, + } + e.builder = mod.Context().NewBuilder() + e.dibuilder = llvm.NewDIBuilder(mod) + + initAll := mod.NamedFunction(name) + bb := initAll.EntryBasicBlock() + e.builder.SetInsertPointBefore(bb.LastInstruction()) + e.builder.SetInstDebugLocation(bb.FirstInstruction()) + var initCalls []llvm.Value + for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { + if !inst.IsAReturnInst().IsNil() { + break // ret void + } + if inst.IsACallInst().IsNil() || inst.CalledValue().IsAFunction().IsNil() { + return errors.New("expected all instructions in " + name + " to be direct calls") + } + initCalls = append(initCalls, inst) + } + + // Do this in a separate step to avoid corrupting the iterator above. + for _, call := range initCalls { + initName := call.CalledValue().Name() + if !strings.HasSuffix(initName, ".init") { + return errors.New("expected all instructions in " + name + " to be *.init() calls") + } + pkgName := initName[:len(initName)-5] + _, err := e.Function(call.CalledValue(), nil, pkgName) + if err != nil { + return err + } + call.EraseFromParentAsInstruction() + } + + return nil +} + +func (e *Eval) Function(fn llvm.Value, params []Value, pkgName string) (Value, error) { + return e.function(fn, params, pkgName, "") +} + +func (e *Eval) function(fn llvm.Value, params []Value, pkgName, indent string) (Value, error) { + fr := frame{ + Eval: e, + fn: fn, + pkgName: pkgName, + locals: make(map[llvm.Value]Value), + } + for i, param := range fn.Params() { + fr.locals[param] = params[i] + } + + bb := fn.EntryBasicBlock() + var lastBB llvm.BasicBlock + for { + retval, outgoing, err := fr.evalBasicBlock(bb, lastBB, indent) + if outgoing == nil { + // returned something (a value or void, or an error) + return retval, err + } + if len(outgoing) > 1 { + panic("unimplemented: multiple outgoing blocks") + } + next := outgoing[0] + if next.IsABasicBlock().IsNil() { + panic("did not switch to a basic block") + } + lastBB = bb + bb = next.AsBasicBlock() + } +} + +// getValue determines what kind of LLVM value it gets and returns the +// appropriate Value type. +func (e *Eval) getValue(v llvm.Value) Value { + if !v.IsAGlobalVariable().IsNil() { + return &GlobalValue{e, v} + } else { + return &LocalValue{e, v} + } +} + +// markDirty marks the passed-in LLVM value dirty, recursively. For example, +// when it encounters a constant GEP on a global, it marks the global dirty. +func (e *Eval) markDirty(v llvm.Value) { + if !v.IsAGlobalVariable().IsNil() { + if v.IsGlobalConstant() { + return + } + if _, ok := e.dirtyGlobals[v]; !ok { + e.dirtyGlobals[v] = struct{}{} + e.sideEffectFuncs = nil // re-calculate all side effects + } + } else if v.IsConstant() { + if v.OperandsCount() >= 2 && !v.Operand(0).IsAGlobalVariable().IsNil() { + // looks like a constant getelementptr of a global. + // TODO: find a way to make sure it really is: v.Opcode() returns 0. + e.markDirty(v.Operand(0)) + return + } + return // nothing to mark + } else if !v.IsAGetElementPtrInst().IsNil() { + panic("interp: todo: GEP") + } else { + // Not constant and not a global or GEP so doesn't have to be marked + // non-constant. + } +} diff --git a/interp/scan.go b/interp/scan.go new file mode 100644 index 00000000..42a17e59 --- /dev/null +++ b/interp/scan.go @@ -0,0 +1,176 @@ +package interp + +import ( + "github.com/aykevl/go-llvm" +) + +type sideEffectSeverity int + +const ( + sideEffectInProgress sideEffectSeverity = iota // computing side effects is in progress (for recursive functions) + sideEffectNone // no side effects at all (pure) + sideEffectLimited // has side effects, but the effects are known + sideEffectAll // has unknown side effects +) + +// sideEffectResult contains the scan results after scanning a function for side +// effects (recursively). +type sideEffectResult struct { + severity sideEffectSeverity + mentionsGlobals map[llvm.Value]struct{} +} + +// hasSideEffects scans this function and all descendants, recursively. It +// returns whether this function has side effects and if it does, which globals +// it mentions anywhere in this function or any called functions. +func (e *Eval) hasSideEffects(fn llvm.Value) *sideEffectResult { + if e.sideEffectFuncs == nil { + e.sideEffectFuncs = make(map[llvm.Value]*sideEffectResult) + } + if se, ok := e.sideEffectFuncs[fn]; ok { + return se + } + result := &sideEffectResult{ + severity: sideEffectInProgress, + mentionsGlobals: map[llvm.Value]struct{}{}, + } + e.sideEffectFuncs[fn] = result + dirtyLocals := map[llvm.Value]struct{}{} + for bb := fn.EntryBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) { + for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { + if inst.IsAInstruction().IsNil() { + panic("not an instruction") + } + + // Check for any globals mentioned anywhere in the function. Assume + // any mentioned globals may be read from or written to when + // executed, thus must be marked dirty with a call. + for i := 0; i < inst.OperandsCount(); i++ { + operand := inst.Operand(i) + if !operand.IsAGlobalVariable().IsNil() { + result.mentionsGlobals[operand] = struct{}{} + } + } + + switch inst.InstructionOpcode() { + case llvm.IndirectBr, llvm.Invoke: + // Not emitted by the compiler. + panic("unknown instructions") + case llvm.Call: + child := inst.CalledValue() + if !child.IsAInlineAsm().IsNil() { + // Inline assembly. This most likely has side effects. + // Assume they're only limited side effects, similar to + // external function calls. + result.updateSeverity(sideEffectLimited) + continue + } + if child.IsAFunction().IsNil() { + // Indirect call? + // In any case, we can't know anything here about what it + // affects exactly so mark this function as invoking all + // possible side effects. + result.updateSeverity(sideEffectAll) + continue + } + if child.IsDeclaration() { + // External function call. Assume only limited side effects + // (no affected globals, etc.). + if result.hasLocalSideEffects(dirtyLocals, inst) { + result.updateSeverity(sideEffectLimited) + } + continue + } + childSideEffects := e.hasSideEffects(fn) + switch childSideEffects.severity { + case sideEffectInProgress, sideEffectNone: + // no side effects or recursive function - continue scanning + default: + result.update(childSideEffects) + } + case llvm.Load, llvm.Store: + if inst.IsVolatile() { + result.updateSeverity(sideEffectLimited) + } + default: + // Ignore most instructions. + // Check this list for completeness: + // https://godoc.org/github.com/llvm-mirror/llvm/bindings/go/llvm#Opcode + } + } + } + + if result.severity == sideEffectInProgress { + // No side effect was reported for this function. + result.severity = sideEffectNone + } + return result +} + +// hasLocalSideEffects checks whether the given instruction flows into a branch +// or return instruction, in which case the whole function must be marked as +// having side effects and be called at runtime. +func (r *sideEffectResult) hasLocalSideEffects(dirtyLocals map[llvm.Value]struct{}, inst llvm.Value) bool { + if _, ok := dirtyLocals[inst]; ok { + // It is already known that this local is dirty. + return true + } + + for use := inst.FirstUse(); !use.IsNil(); use = use.NextUse() { + user := use.User() + if user.IsAInstruction().IsNil() { + panic("user not an instruction") + } + switch user.InstructionOpcode() { + case llvm.Br, llvm.Switch: + // A branch on a dirty value makes this function dirty: it cannot be + // interpreted at compile time so has to be run at runtime. It is + // marked as having side effects for this reason. + return true + case llvm.Ret: + // This function returns a dirty value so it is itself marked as + // dirty to make sure it is called at runtime. + return true + case llvm.Store: + ptr := user.Operand(1) + if !ptr.IsAGlobalVariable().IsNil() { + // Store to a global variable. + // Already handled in (*Eval).hasSideEffects. + continue + } + // But a store might also store to an alloca, in which case all uses + // of the alloca (possibly indirect through a GEP, bitcast, etc.) + // must be marked dirty. + panic("todo: store") + default: + // All instructions that take 0 or more operands (1 or more if it + // was a use) and produce a result. + // For a list: + // https://godoc.org/github.com/llvm-mirror/llvm/bindings/go/llvm#Opcode + dirtyLocals[user] = struct{}{} + if r.hasLocalSideEffects(dirtyLocals, user) { + return true + } + } + } + + // No side effects found. + return false +} + +// updateSeverity sets r.severity to the max of r.severity and severity, +// conservatively assuming the worst severity. +func (r *sideEffectResult) updateSeverity(severity sideEffectSeverity) { + if severity > r.severity { + r.severity = severity + } +} + +// updateSeverity updates the severity with the severity of the child severity, +// like in a function call. This means it also copies the mentioned globals. +func (r *sideEffectResult) update(child *sideEffectResult) { + r.updateSeverity(child.severity) + for global := range child.mentionsGlobals { + r.mentionsGlobals[global] = struct{}{} + } +} diff --git a/interp/utils.go b/interp/utils.go new file mode 100644 index 00000000..669afa41 --- /dev/null +++ b/interp/utils.go @@ -0,0 +1,93 @@ +package interp + +import ( + "github.com/aykevl/go-llvm" +) + +// Return a list of values (actually, instructions) where this value is used as +// an operand. +func getUses(value llvm.Value) []llvm.Value { + var uses []llvm.Value + use := value.FirstUse() + for !use.IsNil() { + uses = append(uses, use.User()) + use = use.NextUse() + } + return uses +} + +// Return a zero LLVM value for any LLVM type. Setting this value as an +// initializer has the same effect as setting 'zeroinitializer' on a value. +// Sadly, I haven't found a way to do it directly with the Go API but this works +// just fine. +func getZeroValue(typ llvm.Type) llvm.Value { + switch typ.TypeKind() { + case llvm.ArrayTypeKind: + subTyp := typ.ElementType() + subVal := getZeroValue(subTyp) + vals := make([]llvm.Value, typ.ArrayLength()) + for i := range vals { + vals[i] = subVal + } + return llvm.ConstArray(subTyp, vals) + case llvm.FloatTypeKind, llvm.DoubleTypeKind: + return llvm.ConstFloat(typ, 0.0) + case llvm.IntegerTypeKind: + return llvm.ConstInt(typ, 0, false) + case llvm.PointerTypeKind: + return llvm.ConstPointerNull(typ) + case llvm.StructTypeKind: + types := typ.StructElementTypes() + vals := make([]llvm.Value, len(types)) + for i, subTyp := range types { + val := getZeroValue(subTyp) + vals[i] = val + } + if typ.StructName() != "" { + return llvm.ConstNamedStruct(typ, vals) + } else { + return typ.Context().ConstStruct(vals, false) + } + case llvm.VectorTypeKind: + zero := getZeroValue(typ.ElementType()) + vals := make([]llvm.Value, typ.VectorSize()) + for i := range vals { + vals[i] = zero + } + return llvm.ConstVector(vals, false) + default: + panic("interp: unknown LLVM type: " + typ.String()) + } +} + +// getStringBytes loads the byte slice of a Go string represented as a +// {ptr, len} pair. +func getStringBytes(strPtr Value, strLen llvm.Value) []byte { + buf := make([]byte, strLen.ZExtValue()) + for i := range buf { + c := strPtr.GetElementPtr([]uint32{uint32(i)}).Load() + buf[i] = byte(c.ZExtValue()) + } + return buf +} + +// getLLVMIndices converts an []uint32 into an []llvm.Value, for use in +// llvm.ConstGEP. +func getLLVMIndices(int32Type llvm.Type, indices []uint32) []llvm.Value { + llvmIndices := make([]llvm.Value, len(indices)) + for i, index := range indices { + llvmIndices[i] = llvm.ConstInt(int32Type, uint64(index), false) + } + return llvmIndices +} + +// Return true if this type is a scalar value (integer or floating point), false +// otherwise. +func isScalar(t llvm.Type) bool { + switch t.TypeKind() { + case llvm.IntegerTypeKind, llvm.FloatTypeKind, llvm.DoubleTypeKind: + return true + default: + return false + } +} diff --git a/interp/values.go b/interp/values.go new file mode 100644 index 00000000..ffadda30 --- /dev/null +++ b/interp/values.go @@ -0,0 +1,588 @@ +package interp + +// This file provides a litte bit of abstraction around LLVM values. + +import ( + "strconv" + + "github.com/aykevl/go-llvm" +) + +// A Value is a LLVM value with some extra methods attached for easier +// interpretation. +type Value interface { + Value() llvm.Value // returns a LLVM value + Type() llvm.Type // equal to Value().Type() + IsConstant() bool // returns true if this value is a constant value + Load() llvm.Value // dereference a pointer + Store(llvm.Value) // store to a pointer + GetElementPtr([]uint32) Value // returns an interior pointer + String() string // string representation, for debugging +} + +// A type that simply wraps a LLVM constant value. +type LocalValue struct { + Eval *Eval + Underlying llvm.Value +} + +// Value implements Value by returning the constant value itself. +func (v *LocalValue) Value() llvm.Value { + return v.Underlying +} + +func (v *LocalValue) Type() llvm.Type { + return v.Underlying.Type() +} + +func (v *LocalValue) IsConstant() bool { + return v.Underlying.IsConstant() +} + +// Load loads a constant value if this is a constant GEP, otherwise it panics. +func (v *LocalValue) Load() llvm.Value { + switch v.Underlying.Opcode() { + case llvm.GetElementPtr: + indices := v.getConstGEPIndices() + if indices[0] != 0 { + panic("invalid GEP") + } + global := v.Eval.getValue(v.Underlying.Operand(0)) + agg := global.Load() + return llvm.ConstExtractValue(agg, indices[1:]) + default: + panic("interp: load from a constant") + } +} + +// Store stores to the underlying value if the value type is a constant GEP, +// otherwise it panics. +func (v *LocalValue) Store(value llvm.Value) { + switch v.Underlying.Opcode() { + case llvm.GetElementPtr: + indices := v.getConstGEPIndices() + if indices[0] != 0 { + panic("invalid GEP") + } + global := &GlobalValue{v.Eval, v.Underlying.Operand(0)} + agg := global.Load() + agg = llvm.ConstInsertValue(agg, value, indices[1:]) + global.Store(agg) + return + default: + panic("interp: store on a constant") + } +} + +// GetElementPtr returns a constant GEP when the underlying value is also a +// constant GEP. It panics when the underlying value is not a constant GEP: +// getting the pointer to a constant is not possible. +func (v *LocalValue) GetElementPtr(indices []uint32) Value { + switch v.Underlying.Opcode() { + case llvm.GetElementPtr, llvm.IntToPtr: + int32Type := v.Underlying.Type().Context().Int32Type() + llvmIndices := getLLVMIndices(int32Type, indices) + return &LocalValue{v.Eval, llvm.ConstGEP(v.Underlying, llvmIndices)} + default: + panic("interp: GEP on a constant") + } +} + +func (v *LocalValue) String() string { + isConstant := "false" + if v.IsConstant() { + isConstant = "true" + } + return "&LocalValue{Type: " + v.Type().String() + ", IsConstant: " + isConstant + "}" +} + +// getConstGEPIndices returns indices of this constant GEP, if this is a GEP +// instruction. If it is not, the behavior is undefined. +func (v *LocalValue) getConstGEPIndices() []uint32 { + indices := make([]uint32, v.Underlying.OperandsCount()-1) + for i := range indices { + operand := v.Underlying.Operand(i + 1) + indices[i] = uint32(operand.ZExtValue()) + } + return indices +} + +// GlobalValue wraps a LLVM global variable. +type GlobalValue struct { + Eval *Eval + Underlying llvm.Value +} + +// Value returns the initializer for this global variable. +func (v *GlobalValue) Value() llvm.Value { + return v.Underlying +} + +// Type returns the type of this global variable, which is a pointer type. Use +// Type().ElementType() to get the actual global variable type. +func (v *GlobalValue) Type() llvm.Type { + return v.Underlying.Type() +} + +// IsConstant returns true if this global is not dirty, false otherwise. +func (v *GlobalValue) IsConstant() bool { + if _, ok := v.Eval.dirtyGlobals[v.Underlying]; ok { + return true + } + return false +} + +// Load returns the initializer of the global variable. +func (v *GlobalValue) Load() llvm.Value { + return v.Underlying.Initializer() +} + +// Store sets the initializer of the global variable. +func (v *GlobalValue) Store(value llvm.Value) { + if !value.IsConstant() { + v.MarkDirty() + v.Eval.builder.CreateStore(value, v.Underlying) + } else { + v.Underlying.SetInitializer(value) + } +} + +// GetElementPtr returns a constant GEP on this global, which can be used in +// load and store instructions. +func (v *GlobalValue) GetElementPtr(indices []uint32) Value { + int32Type := v.Underlying.Type().Context().Int32Type() + gep := llvm.ConstGEP(v.Underlying, getLLVMIndices(int32Type, indices)) + return &LocalValue{v.Eval, gep} +} + +func (v *GlobalValue) String() string { + return "&GlobalValue{" + v.Underlying.Name() + "}" +} + +// MarkDirty marks this global as dirty, meaning that every load from and store +// to this global (from now on) must be performed at runtime. +func (v *GlobalValue) MarkDirty() { + if !v.IsConstant() { + return // already dirty + } + v.Eval.dirtyGlobals[v.Underlying] = struct{}{} +} + +// An alloca represents a local alloca, which is a stack allocated variable. +// It is emulated by storing the constant of the alloca. +type AllocaValue struct { + Eval *Eval + Underlying llvm.Value // the constant value itself if not dirty, otherwise the alloca instruction + Dirty bool // this value must be evaluated at runtime +} + +// Value turns this alloca into a runtime alloca instead of a compile-time +// constant (if not already converted), and returns the alloca itself. +func (v *AllocaValue) Value() llvm.Value { + if !v.Dirty { + // Mark this alloca a dirty, meaning it is run at runtime instead of + // compile time. + alloca := v.Eval.builder.CreateAlloca(v.Underlying.Type(), "") + v.Eval.builder.CreateStore(v.Underlying, alloca) + v.Dirty = true + v.Underlying = alloca + } + return v.Underlying +} + +// Type returns the type of this alloca, which is always a pointer. +func (v *AllocaValue) Type() llvm.Type { + if v.Dirty { + return v.Underlying.Type() + } else { + return llvm.PointerType(v.Underlying.Type(), 0) + } +} + +func (v *AllocaValue) IsConstant() bool { + return !v.Dirty +} + +// Load returns the value this alloca contains, which may be evaluated at +// runtime. +func (v *AllocaValue) Load() llvm.Value { + if v.Dirty { + ret := v.Eval.builder.CreateLoad(v.Underlying, "") + if ret.IsNil() { + panic("alloca is nil") + } + return ret + } else { + if v.Underlying.IsNil() { + panic("alloca is nil") + } + return v.Underlying + } +} + +// Store updates the value of this alloca. +func (v *AllocaValue) Store(value llvm.Value) { + if v.Underlying.Type() != value.Type() { + panic("interp: trying to store to an alloca with a different type") + } + if v.Dirty || !value.IsConstant() { + v.Eval.builder.CreateStore(value, v.Value()) + } else { + v.Underlying = value + } +} + +// GetElementPtr returns a value (a *GetElementPtrValue) that keeps a reference +// to this alloca, so that Load() and Store() continue to work. +func (v *AllocaValue) GetElementPtr(indices []uint32) Value { + return &GetElementPtrValue{v, indices} +} + +func (v *AllocaValue) String() string { + return "&AllocaValue{Type: " + v.Type().String() + "}" +} + +// GetElementPtrValue wraps an alloca, keeping track of what the GEP points to +// so it can be used as a pointer value (with Load() and Store()). +type GetElementPtrValue struct { + Alloca *AllocaValue + Indices []uint32 +} + +// Type returns the type of this GEP, which is always of type pointer. +func (v *GetElementPtrValue) Type() llvm.Type { + if v.Alloca.Dirty { + return v.Value().Type() + } else { + return llvm.PointerType(v.Load().Type(), 0) + } +} + +func (v *GetElementPtrValue) IsConstant() bool { + return v.Alloca.IsConstant() +} + +// Value creates the LLVM GEP instruction of this GetElementPtrValue wrapper and +// returns it. +func (v *GetElementPtrValue) Value() llvm.Value { + if v.Alloca.Dirty { + alloca := v.Alloca.Value() + int32Type := v.Alloca.Type().Context().Int32Type() + llvmIndices := getLLVMIndices(int32Type, v.Indices) + return v.Alloca.Eval.builder.CreateGEP(alloca, llvmIndices, "") + } else { + panic("interp: todo: pointer to alloca gep") + } +} + +// Load deferences the pointer this GEP points to. For a constant GEP, it +// extracts the value from the underlying alloca. +func (v *GetElementPtrValue) Load() llvm.Value { + if v.Alloca.Dirty { + gep := v.Value() + return v.Alloca.Eval.builder.CreateLoad(gep, "") + } else { + underlying := v.Alloca.Load() + indices := v.Indices + if indices[0] != 0 { + panic("invalid GEP") + } + return llvm.ConstExtractValue(underlying, indices[1:]) + } +} + +// Store stores to the pointer this GEP points to. For a constant GEP, it +// updates the underlying allloca. +func (v *GetElementPtrValue) Store(value llvm.Value) { + if v.Alloca.Dirty || !value.IsConstant() { + alloca := v.Alloca.Value() + int32Type := v.Alloca.Type().Context().Int32Type() + llvmIndices := getLLVMIndices(int32Type, v.Indices) + gep := v.Alloca.Eval.builder.CreateGEP(alloca, llvmIndices, "") + v.Alloca.Eval.builder.CreateStore(value, gep) + } else { + underlying := v.Alloca.Load() + indices := v.Indices + if indices[0] != 0 { + panic("invalid GEP") + } + underlying = llvm.ConstInsertValue(underlying, value, indices[1:]) + v.Alloca.Store(underlying) + } +} + +func (v *GetElementPtrValue) GetElementPtr(indices []uint32) Value { + if v.Alloca.Dirty { + panic("interp: todo: gep on a dirty gep") + } else { + combined := append([]uint32{}, v.Indices...) + combined[len(combined)-1] += indices[0] + combined = append(combined, indices[1:]...) + return &GetElementPtrValue{v.Alloca, combined} + } +} + +func (v *GetElementPtrValue) String() string { + indices := "" + for _, n := range v.Indices { + if indices != "" { + indices += ", " + } + indices += strconv.Itoa(int(n)) + } + return "&GetElementPtrValue{Alloca: " + v.Alloca.String() + ", Indices: [" + indices + "]}" +} + +// PointerCastValue represents a bitcast operation on a pointer. +type PointerCastValue struct { + Eval *Eval + Underlying Value + CastType llvm.Type +} + +// Value returns a constant bitcast value. +func (v *PointerCastValue) Value() llvm.Value { + from := v.Underlying.Value() + return llvm.ConstBitCast(from, v.CastType) +} + +// Type returns the type this pointer has been cast to. +func (v *PointerCastValue) Type() llvm.Type { + return v.CastType +} + +func (v *PointerCastValue) IsConstant() bool { + return v.Underlying.IsConstant() +} + +// Load tries to load and bitcast the given value. If this value cannot be +// bitcasted, Load panics. +func (v *PointerCastValue) Load() llvm.Value { + if v.Underlying.IsConstant() { + typeFrom := v.Underlying.Type().ElementType() + typeTo := v.CastType.ElementType() + if isScalar(typeFrom) && isScalar(typeTo) && v.Eval.TargetData.TypeAllocSize(typeFrom) == v.Eval.TargetData.TypeAllocSize(typeTo) { + return llvm.ConstBitCast(v.Underlying.Load(), v.CastType.ElementType()) + } + } + + panic("interp: load from a pointer bitcast: " + v.String()) +} + +// Store panics: it is not (yet) possible to store directly to a bitcast. +func (v *PointerCastValue) Store(value llvm.Value) { + panic("interp: store on a pointer bitcast") +} + +// GetElementPtr panics: it is not (yet) possible to do a GEP operation on a +// bitcast. +func (v *PointerCastValue) GetElementPtr(indices []uint32) Value { + panic("interp: GEP on a pointer bitcast") +} + +func (v *PointerCastValue) String() string { + return "&PointerCastValue{Value: " + v.Underlying.String() + ", CastType: " + v.CastType.String() + "}" +} + +// MapValue implements a Go map which is created at compile time and stored as a +// global variable. +type MapValue struct { + Eval *Eval + PkgName string + Underlying llvm.Value + Keys []Value + Values []Value + KeySize int + ValueSize int + KeyType llvm.Type + ValueType llvm.Type +} + +func (v *MapValue) newBucket() llvm.Value { + ctx := v.Eval.Mod.Context() + i8ptrType := llvm.PointerType(ctx.Int8Type(), 0) + bucketType := ctx.StructType([]llvm.Type{ + llvm.ArrayType(ctx.Int8Type(), 8), // tophash + i8ptrType, // next bucket + llvm.ArrayType(v.KeyType, 8), // key type + llvm.ArrayType(v.ValueType, 8), // value type + }, false) + bucketValue := getZeroValue(bucketType) + bucket := llvm.AddGlobal(v.Eval.Mod, bucketType, v.PkgName+"$mapbucket") + bucket.SetInitializer(bucketValue) + bucket.SetLinkage(llvm.InternalLinkage) + bucket.SetUnnamedAddr(true) + return bucket +} + +// Value returns a global variable which is a pointer to the actual hashmap. +func (v *MapValue) Value() llvm.Value { + if !v.Underlying.IsNil() { + return v.Underlying + } + + ctx := v.Eval.Mod.Context() + i8ptrType := llvm.PointerType(ctx.Int8Type(), 0) + + var firstBucketGlobal llvm.Value + if len(v.Keys) == 0 { + // there are no buckets + firstBucketGlobal = llvm.ConstPointerNull(i8ptrType) + } else { + // create initial bucket + firstBucketGlobal = v.newBucket() + } + + // Insert each key/value pair in the hashmap. + bucketGlobal := firstBucketGlobal + for i, key := range v.Keys { + var keyBuf []byte + llvmKey := key.Value() + llvmValue := v.Values[i].Value() + if key.Type().TypeKind() == llvm.StructTypeKind && key.Type().StructName() == "runtime._string" { + keyPtr := llvm.ConstExtractValue(llvmKey, []uint32{0}) + keyLen := llvm.ConstExtractValue(llvmKey, []uint32{1}) + keyPtrVal := v.Eval.getValue(keyPtr) + keyBuf = getStringBytes(keyPtrVal, keyLen) + } else if key.Type().TypeKind() == llvm.IntegerTypeKind { + keyBuf = make([]byte, v.Eval.TargetData.TypeAllocSize(key.Type())) + n := key.Value().ZExtValue() + for i := range keyBuf { + keyBuf[i] = byte(n) + n >>= 8 + } + } else { + panic("interp: map key type not implemented: " + key.Type().String()) + } + hash := v.hash(keyBuf) + + if i%8 == 0 && i != 0 { + // Bucket is full, create a new one. + newBucketGlobal := v.newBucket() + zero := llvm.ConstInt(ctx.Int32Type(), 0, false) + newBucketPtr := llvm.ConstInBoundsGEP(newBucketGlobal, []llvm.Value{zero}) + newBucketPtrCast := llvm.ConstBitCast(newBucketPtr, i8ptrType) + // insert pointer into old bucket + bucket := bucketGlobal.Initializer() + bucket = llvm.ConstInsertValue(bucket, newBucketPtrCast, []uint32{1}) + bucketGlobal.SetInitializer(bucket) + // switch to next bucket + bucketGlobal = newBucketGlobal + } + + tophashValue := llvm.ConstInt(ctx.Int8Type(), uint64(v.topHash(hash)), false) + bucket := bucketGlobal.Initializer() + bucket = llvm.ConstInsertValue(bucket, tophashValue, []uint32{0, uint32(i % 8)}) + bucket = llvm.ConstInsertValue(bucket, llvmKey, []uint32{2, uint32(i % 8)}) + bucket = llvm.ConstInsertValue(bucket, llvmValue, []uint32{3, uint32(i % 8)}) + bucketGlobal.SetInitializer(bucket) + } + + // Create the hashmap itself. + zero := llvm.ConstInt(ctx.Int32Type(), 0, false) + bucketPtr := llvm.ConstInBoundsGEP(firstBucketGlobal, []llvm.Value{zero}) + hashmapType := v.Type() + hashmap := llvm.ConstNamedStruct(hashmapType, []llvm.Value{ + llvm.ConstPointerNull(llvm.PointerType(hashmapType, 0)), // next + llvm.ConstBitCast(bucketPtr, i8ptrType), // buckets + llvm.ConstInt(hashmapType.StructElementTypes()[2], uint64(len(v.Keys)), false), // count + llvm.ConstInt(ctx.Int8Type(), uint64(v.KeySize), false), // keySize + llvm.ConstInt(ctx.Int8Type(), uint64(v.ValueSize), false), // valueSize + llvm.ConstInt(ctx.Int8Type(), 0, false), // bucketBits + }) + + // Create a pointer to this hashmap. + hashmapPtr := llvm.AddGlobal(v.Eval.Mod, hashmap.Type(), v.PkgName+"$map") + hashmapPtr.SetInitializer(hashmap) + hashmapPtr.SetLinkage(llvm.InternalLinkage) + hashmapPtr.SetUnnamedAddr(true) + v.Underlying = llvm.ConstInBoundsGEP(hashmapPtr, []llvm.Value{zero}) + return v.Underlying +} + +// Type returns type runtime.hashmap, which is the actual hashmap type. +func (v *MapValue) Type() llvm.Type { + return v.Eval.Mod.GetTypeByName("runtime.hashmap") +} + +func (v *MapValue) IsConstant() bool { + return true // TODO: dirty maps +} + +// Load panics: maps are of reference type so cannot be dereferenced. +func (v *MapValue) Load() llvm.Value { + panic("interp: load from a map") +} + +// Store panics: maps are of reference type so cannot be stored to. +func (v *MapValue) Store(value llvm.Value) { + panic("interp: store on a map") +} + +// GetElementPtr panics: maps are of reference type so their (interior) +// addresses cannot be calculated. +func (v *MapValue) GetElementPtr(indices []uint32) Value { + panic("interp: GEP on a map") +} + +// PutString does a map assign operation, assuming that the map is of type +// map[string]T. +func (v *MapValue) PutString(keyBuf, keyLen, valPtr Value) { + if !v.Underlying.IsNil() { + panic("map already created") + } + + var value llvm.Value + switch valPtr := valPtr.(type) { + case *PointerCastValue: + value = valPtr.Underlying.Load() + if v.ValueType.IsNil() { + v.ValueType = value.Type() + if int(v.Eval.TargetData.TypeAllocSize(v.ValueType)) != v.ValueSize { + panic("interp: map store value type has the wrong size") + } + } else { + if value.Type() != v.ValueType { + panic("interp: map store value type is inconsistent") + } + } + default: + panic("interp: todo: handle map value pointer") + } + + keyType := v.Eval.Mod.GetTypeByName("runtime._string") + v.KeyType = keyType + key := getZeroValue(keyType) + key = llvm.ConstInsertValue(key, keyBuf.Value(), []uint32{0}) + key = llvm.ConstInsertValue(key, keyLen.Value(), []uint32{1}) + + // TODO: avoid duplicate keys + v.Keys = append(v.Keys, &LocalValue{v.Eval, key}) + v.Values = append(v.Values, &LocalValue{v.Eval, value}) +} + +// Get FNV-1a hash of this string. +// +// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function#FNV-1a_hash +func (v *MapValue) hash(data []byte) uint32 { + var result uint32 = 2166136261 // FNV offset basis + for _, c := range data { + result ^= uint32(c) + result *= 16777619 // FNV prime + } + return result +} + +// Get the topmost 8 bits of the hash, without using a special value (like 0). +func (v *MapValue) topHash(hash uint32) uint8 { + tophash := uint8(hash >> 24) + if tophash < 1 { + // 0 means empty slot, so make it bigger. + tophash += 1 + } + return tophash +} + +func (v *MapValue) String() string { + return "&MapValue{KeySize: " + strconv.Itoa(v.KeySize) + ", ValueSize: " + strconv.Itoa(v.ValueSize) + "}" +} diff --git a/main.go b/main.go index e4259160..f92f2ae9 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ import ( "github.com/aykevl/go-llvm" "github.com/aykevl/tinygo/compiler" + "github.com/aykevl/tinygo/interp" ) var commands = map[string]string{ @@ -28,17 +29,19 @@ type BuildConfig struct { dumpSSA bool debug bool printSizes string + initInterp bool } // Helper function for Compiler object. func Compile(pkgName, outpath string, spec *TargetSpec, config *BuildConfig, action func(string) error) error { compilerConfig := compiler.Config{ - Triple: spec.Triple, - Debug: config.debug, - DumpSSA: config.dumpSSA, - RootDir: sourceDir(), - GOPATH: getGopath(), - BuildTags: append(spec.BuildTags, "tinygo"), + Triple: spec.Triple, + Debug: config.debug, + DumpSSA: config.dumpSSA, + RootDir: sourceDir(), + GOPATH: getGopath(), + BuildTags: append(spec.BuildTags, "tinygo"), + InitInterp: config.initInterp, } c, err := compiler.NewCompiler(pkgName, compilerConfig) if err != nil { @@ -58,6 +61,16 @@ func Compile(pkgName, outpath string, spec *TargetSpec, config *BuildConfig, act return err } + if config.initInterp { + err = interp.Run(c.Module(), c.TargetData(), config.dumpSSA) + if err != nil { + return err + } + if err := c.Verify(); err != nil { + return err + } + } + c.ApplyFunctionSections() // -ffunction-sections if err := c.Verify(); err != nil { return err @@ -399,6 +412,20 @@ func usage() { flag.PrintDefaults() } +func handleCompilerError(err error) { + if err != nil { + if errUnsupported, ok := err.(*interp.Unsupported); ok { + // hit an unknown/unsupported instruction + fmt.Fprintln(os.Stderr, "unsupported instruction during init evaluation:") + errUnsupported.Inst.Dump() + fmt.Fprintln(os.Stderr) + } else { + fmt.Fprintln(os.Stderr, "error:", err) + } + os.Exit(1) + } +} + func main() { outpath := flag.String("o", "", "output filename") opt := flag.String("opt", "z", "optimization level: 0, 1, 2, s, z") @@ -408,6 +435,7 @@ func main() { printSize := flag.String("size", "", "print sizes (none, short, full)") nodebug := flag.Bool("no-debug", false, "disable DWARF debug symbol generation") ocdOutput := flag.Bool("ocd-output", false, "print OCD daemon output during debug") + initInterp := flag.Bool("interp", false, "enable experimental partial evaluator of generated IR") port := flag.String("port", "/dev/ttyACM0", "flash port") if len(os.Args) < 2 { @@ -424,6 +452,7 @@ func main() { dumpSSA: *dumpSSA, debug: !*nodebug, printSizes: *printSize, + initInterp: *initInterp, } os.Setenv("CC", "clang -target="+*target) @@ -445,30 +474,24 @@ func main() { target = "wasm" } err := Build(flag.Arg(0), *outpath, target, config) - if err != nil { - fmt.Fprintln(os.Stderr, "error:", err) - os.Exit(1) - } + handleCompilerError(err) case "flash", "gdb": if *outpath != "" { fmt.Fprintln(os.Stderr, "Output cannot be specified with the flash command.") usage() os.Exit(1) } - var err error if command == "flash" { - err = Flash(flag.Arg(0), *target, *port, config) + err := Flash(flag.Arg(0), *target, *port, config) + handleCompilerError(err) } else { if !config.debug { fmt.Fprintln(os.Stderr, "Debug disabled while running gdb?") usage() os.Exit(1) } - err = FlashGDB(flag.Arg(0), *target, *port, *ocdOutput, config) - } - if err != nil { - fmt.Fprintln(os.Stderr, "error:", err) - os.Exit(1) + err := FlashGDB(flag.Arg(0), *target, *port, *ocdOutput, config) + handleCompilerError(err) } case "run": if flag.NArg() != 1 { @@ -476,15 +499,12 @@ func main() { usage() os.Exit(1) } - var err error if *target == "" { - err = Run(flag.Arg(0)) + err := Run(flag.Arg(0)) + handleCompilerError(err) } else { - err = Emulate(flag.Arg(0), *target, config) - } - if err != nil { - fmt.Fprintln(os.Stderr, "error:", err) - os.Exit(1) + err := Emulate(flag.Arg(0), *target, config) + handleCompilerError(err) } case "clean": // remove cache directory diff --git a/main_test.go b/main_test.go index afebb9a8..c02881c6 100644 --- a/main_test.go +++ b/main_test.go @@ -65,6 +65,7 @@ func runTest(path, tmpdir string, target string, t *testing.T) { dumpSSA: false, debug: false, printSizes: "", + initInterp: false, } binary := filepath.Join(tmpdir, "test") err = Build(path, binary, target, config) diff --git a/testdata/init.go b/testdata/init.go index f90c1192..3233acf2 100644 --- a/testdata/init.go +++ b/testdata/init.go @@ -12,6 +12,7 @@ func main() { println("v4:", len(v4), v4 == nil) println("v5:", len(v5), v5 == nil) println("v6:", v6) + println("v7:", cap(v7), string(v7)) } type ( @@ -28,4 +29,5 @@ var ( v4 map[string]int v5 = map[string]int{} v6 = float64(v1) < 2.6 + v7 = []byte("foo") ) diff --git a/testdata/init.txt b/testdata/init.txt index 18d99b82..7ac27b98 100644 --- a/testdata/init.txt +++ b/testdata/init.txt @@ -6,3 +6,4 @@ v3: 4 4 2 7 v4: 0 true v5: 0 false v6: false +v7: 3 foo