From abca3132a98159960a21e167a3aa2e606cff83f4 Mon Sep 17 00:00:00 2001
From: Jaden Weiss
Date: Thu, 12 Sep 2019 16:33:14 -0400
Subject: [PATCH] fix bugs found by LLVM assertions

---
 compiler/check.go              | 145 +++++++++++++++++++++++++++++++++
 compiler/compiler.go           |   3 +-
 compiler/defer.go              |  10 +--
 compiler/func-lowering.go      |   6 +-
 compiler/gc.go                 |   4 +-
 compiler/goroutine.go          |   6 +-
 compiler/interface-lowering.go |  12 ++--
 compiler/interface.go          |   4 +-
 compiler/optimizer.go          |   5 ++
 compiler/syscall.go            |   4 +-
 main.go                        |   4 ++
 main_test.go                   |   1 +
 12 files changed, 180 insertions(+), 24 deletions(-)
 create mode 100644 compiler/check.go

diff --git a/compiler/check.go b/compiler/check.go
new file mode 100644
index 00000000..356d0296
--- /dev/null
+++ b/compiler/check.go
@@ -0,0 +1,145 @@
+package compiler
+
+// This file implements a set of sanity checks for the IR that is generated.
+// It can catch some mistakes that LLVM's verifier cannot.
+
+import (
+	"errors"
+	"fmt"
+
+	"tinygo.org/x/go-llvm"
+)
+
+// checkType verifies that t, and every type reachable from it, was created in
+// the compiler's own LLVM context (c.ctx).
+//
+// checked records the types that were already visited, so that shared and
+// self-referential types are inspected only once. specials records the one
+// instance seen so far of each void/label/token/metadata type kind, so that a
+// second, distinct instance of the same kind can be reported.
+//
+// A nil type or any violation is reported by panicking.
+func (c *Compiler) checkType(t llvm.Type, checked map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) {
+	if t.IsNil() {
+		panic(t)
+	}
+
+	// prevent infinite recursion for self-referential types
+	if _, ok := checked[t]; ok {
+		return
+	}
+	checked[t] = struct{}{}
+
+	// check for any context mismatches
+	switch {
+	case t.Context() == c.ctx:
+		// this is correct
+	case t.Context() == llvm.GlobalContext():
+		// somewhere we accidentally used the global context instead of a real context
+		panic(fmt.Errorf("type %q uses global context", t.String()))
+	default:
+		// we used some other context by accident
+		panic(fmt.Errorf("type %q uses context %v instead of the main context %v", t.String(), t.Context(), c.ctx))
+	}
+
+	// if this is a composite type, check the components of the type
+	switch t.TypeKind() {
+	case llvm.VoidTypeKind, llvm.LabelTypeKind, llvm.TokenTypeKind, llvm.MetadataTypeKind:
+		// there should only be one of any of these
+		if s, ok := specials[t.TypeKind()]; !ok {
+			specials[t.TypeKind()] = t
+		} else if s != t {
+			panic(fmt.Errorf("duplicate special type %q: %v and %v", t.TypeKind().String(), t, s))
+		}
+	case llvm.FloatTypeKind, llvm.DoubleTypeKind, llvm.X86_FP80TypeKind, llvm.FP128TypeKind, llvm.PPC_FP128TypeKind:
+		// floating point numbers are primitives - nothing to recurse
+	case llvm.IntegerTypeKind:
+		// integers are primitives - nothing to recurse
+	case llvm.FunctionTypeKind:
+		// check arguments and return(s)
+		for _, v := range t.ParamTypes() {
+			c.checkType(v, checked, specials)
+		}
+		c.checkType(t.ReturnType(), checked, specials)
+	case llvm.StructTypeKind:
+		// check all elements
+		for _, v := range t.StructElementTypes() {
+			c.checkType(v, checked, specials)
+		}
+	case llvm.ArrayTypeKind:
+		// check element type
+		c.checkType(t.ElementType(), checked, specials)
+	case llvm.PointerTypeKind:
+		// check underlying type
+		c.checkType(t.ElementType(), checked, specials)
+	case llvm.VectorTypeKind:
+		// check element type
+		c.checkType(t.ElementType(), checked, specials)
+	}
+}
+
+// checkValue verifies the type of v; see checkType for the meaning of types
+// and specials.
+func (c *Compiler) checkValue(v llvm.Value, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) {
+	// check type
+	c.checkType(v.Type(), types, specials)
+}
+
+// checkInstruction verifies the type of inst and of each of its operands.
+func (c *Compiler) checkInstruction(inst llvm.Value, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) {
+	// check value properties
+	c.checkValue(inst, types, specials)
+
+	// check operands
+	for i := 0; i < inst.OperandsCount(); i++ {
+		c.checkValue(inst.Operand(i), types, specials)
+	}
+}
+
+// checkBasicBlock verifies the block's own value and every instruction in bb.
+func (c *Compiler) checkBasicBlock(bb llvm.BasicBlock, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) {
+	// check basic block value and type
+	c.checkValue(bb.AsValue(), types, specials)
+
+	// check instructions
+	for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) {
+		c.checkInstruction(inst, types, specials)
+	}
+}
+
+// checkFunction verifies the function value fn and all of its basic blocks.
+func (c *Compiler) checkFunction(fn llvm.Value, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) {
+	// check function value and type
+	c.checkValue(fn, types, specials)
+
+	// check basic blocks
+	for bb := fn.FirstBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) {
+		c.checkBasicBlock(bb, types, specials)
+	}
+}
+
+// checkModule runs the sanity checks on the whole module: the module itself
+// must live in c.ctx, and so must the types of all functions (including their
+// instructions and operands) and globals. It panics on the first violation.
+func (c *Compiler) checkModule() {
+	// check for any context mismatches
+	switch {
+	case c.mod.Context() == c.ctx:
+		// this is correct
+	case c.mod.Context() == llvm.GlobalContext():
+		// somewhere we accidentally used the global context instead of a real context
+		panic(errors.New("module uses global context"))
+	default:
+		// we used some other context by accident
+		panic(fmt.Errorf("module uses context %v instead of the main context %v", c.mod.Context(), c.ctx))
+	}
+
+	types := map[llvm.Type]struct{}{}
+	specials := map[llvm.TypeKind]llvm.Type{}
+	for fn := c.mod.FirstFunction(); !fn.IsNil(); fn = llvm.NextFunction(fn) {
+		c.checkFunction(fn, types, specials)
+	}
+	for g := c.mod.FirstGlobal(); !g.IsNil(); g = llvm.NextGlobal(g) {
+		c.checkValue(g, types, specials)
+	}
+}
diff --git a/compiler/compiler.go b/compiler/compiler.go index 050e7107..de0193de 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -59,6 +59,7 @@ type Config struct { LDFlags []string // ldflags to pass to cgo ClangHeaders string // Clang built-in header include path DumpSSA bool // dump Go SSA, for compiler debugging + VerifyIR bool // run extra checks on the IR Debug bool // add debug symbols for gdb GOROOT string // GOROOT TINYGOROOT string // GOROOT for TinyGo @@ -2712,7 +2713,7 @@ func (c *Compiler) ExternalInt64AsPtr() error { // correct calling convention. 
fn.SetLinkage(llvm.InternalLinkage) fn.SetUnnamedAddr(true) - entryBlock := llvm.AddBasicBlock(externalFn, "entry") + entryBlock := c.ctx.AddBasicBlock(externalFn, "entry") c.builder.SetInsertPointAtEnd(entryBlock) var callParams []llvm.Value if fnType.ReturnType() == int64Type { diff --git a/compiler/defer.go b/compiler/defer.go index c408b6e7..00664cfe 100644 --- a/compiler/defer.go +++ b/compiler/defer.go @@ -157,10 +157,10 @@ func (c *Compiler) emitRunDefers(frame *Frame) { // } // Create loop. - loophead := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.loophead") - loop := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.loop") - unreachable := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.default") - end := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.end") + loophead := c.ctx.AddBasicBlock(frame.fn.LLVMFn, "rundefers.loophead") + loop := c.ctx.AddBasicBlock(frame.fn.LLVMFn, "rundefers.loop") + unreachable := c.ctx.AddBasicBlock(frame.fn.LLVMFn, "rundefers.default") + end := c.ctx.AddBasicBlock(frame.fn.LLVMFn, "rundefers.end") c.builder.CreateBr(loophead) // Create loop head: @@ -192,7 +192,7 @@ func (c *Compiler) emitRunDefers(frame *Frame) { // Create switch case, for example: // case 0: // // run first deferred call - block := llvm.AddBasicBlock(frame.fn.LLVMFn, "rundefers.callback") + block := c.ctx.AddBasicBlock(frame.fn.LLVMFn, "rundefers.callback") sw.AddCase(llvm.ConstInt(c.uintptrType, uint64(i), false), block) c.builder.SetInsertPointAtEnd(block) switch callback := callback.(type) { diff --git a/compiler/func-lowering.go b/compiler/func-lowering.go index 9abb2963..f1af3ab3 100644 --- a/compiler/func-lowering.go +++ b/compiler/func-lowering.go @@ -234,7 +234,7 @@ func (c *Compiler) addFuncLoweringSwitch(funcID, call llvm.Value, createCall fun // The block that cannot be reached with correct funcValues (to help the // optimizer). 
c.builder.SetInsertPointBefore(call) - defaultBlock := llvm.AddBasicBlock(call.InstructionParent().Parent(), "func.default") + defaultBlock := c.ctx.AddBasicBlock(call.InstructionParent().Parent(), "func.default") c.builder.SetInsertPointAtEnd(defaultBlock) c.builder.CreateUnreachable() @@ -247,7 +247,7 @@ func (c *Compiler) addFuncLoweringSwitch(funcID, call llvm.Value, createCall fun nextBlock := c.splitBasicBlock(sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next") // The 0 case, which is actually a nil check. - nilBlock := llvm.InsertBasicBlock(nextBlock, "func.nil") + nilBlock := c.ctx.InsertBasicBlock(nextBlock, "func.nil") c.builder.SetInsertPointAtEnd(nilBlock) c.createRuntimeCall("nilPanic", nil, "") c.builder.CreateUnreachable() @@ -265,7 +265,7 @@ func (c *Compiler) addFuncLoweringSwitch(funcID, call llvm.Value, createCall fun phiValues := make([]llvm.Value, len(functions)) for i, fn := range functions { // Insert a switch case. - bb := llvm.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id)) + bb := c.ctx.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id)) c.builder.SetInsertPointAtEnd(bb) result := createCall(fn.funcPtr, callParams) c.builder.CreateBr(nextBlock) diff --git a/compiler/gc.go b/compiler/gc.go index 87fed6b1..763902ff 100644 --- a/compiler/gc.go +++ b/compiler/gc.go @@ -398,10 +398,10 @@ func (c *Compiler) addGlobalsBitmap() bool { for i, b := range bitmapBytes { bitmapValues[len(bitmapBytes)-i-1] = llvm.ConstInt(c.ctx.Int8Type(), uint64(b), false) } - bitmapArray := llvm.ConstArray(llvm.ArrayType(c.ctx.Int8Type(), len(bitmapBytes)), bitmapValues) + bitmapArray := llvm.ConstArray(c.ctx.Int8Type(), bitmapValues) bitmapNew := llvm.AddGlobal(c.mod, bitmapArray.Type(), "runtime.trackedGlobalsBitmap.tmp") bitmapOld := c.mod.NamedGlobal("runtime.trackedGlobalsBitmap") - bitmapOld.ReplaceAllUsesWith(bitmapNew) + bitmapOld.ReplaceAllUsesWith(llvm.ConstBitCast(bitmapNew, bitmapOld.Type())) 
bitmapNew.SetInitializer(bitmapArray) bitmapNew.SetName("runtime.trackedGlobalsBitmap") diff --git a/compiler/goroutine.go b/compiler/goroutine.go index 9c8376c5..59c5a383 100644 --- a/compiler/goroutine.go +++ b/compiler/goroutine.go @@ -57,7 +57,7 @@ func (c *Compiler) createGoroutineStartWrapper(fn llvm.Value) llvm.Value { name := fn.Name() wrapper = c.mod.NamedFunction(name + "$gowrapper") if !wrapper.IsNil() { - return c.builder.CreateIntToPtr(wrapper, c.uintptrType, "") + return c.builder.CreatePtrToInt(wrapper, c.uintptrType, "") } // Save the current position in the IR builder. @@ -69,7 +69,7 @@ func (c *Compiler) createGoroutineStartWrapper(fn llvm.Value) llvm.Value { wrapper = llvm.AddFunction(c.mod, name+"$gowrapper", wrapperType) wrapper.SetLinkage(llvm.PrivateLinkage) wrapper.SetUnnamedAddr(true) - entry := llvm.AddBasicBlock(wrapper, "entry") + entry := c.ctx.AddBasicBlock(wrapper, "entry") c.builder.SetInsertPointAtEnd(entry) // Create the list of params for the call. @@ -107,7 +107,7 @@ func (c *Compiler) createGoroutineStartWrapper(fn llvm.Value) llvm.Value { wrapper = llvm.AddFunction(c.mod, ".gowrapper", wrapperType) wrapper.SetLinkage(llvm.InternalLinkage) wrapper.SetUnnamedAddr(true) - entry := llvm.AddBasicBlock(wrapper, "entry") + entry := c.ctx.AddBasicBlock(wrapper, "entry") c.builder.SetInsertPointAtEnd(entry) // Get the list of parameters, with the extra parameters at the end. diff --git a/compiler/interface-lowering.go b/compiler/interface-lowering.go index 6a756410..b8e57278 100644 --- a/compiler/interface-lowering.go +++ b/compiler/interface-lowering.go @@ -603,9 +603,9 @@ func (p *lowerInterfacesPass) createInterfaceImplementsFunc(itf *interfaceInfo) // TODO: debug info // Create all used basic blocks. 
- entry := llvm.AddBasicBlock(fn, "entry") - thenBlock := llvm.AddBasicBlock(fn, "then") - elseBlock := llvm.AddBasicBlock(fn, "else") + entry := p.ctx.AddBasicBlock(fn, "entry") + thenBlock := p.ctx.AddBasicBlock(fn, "then") + elseBlock := p.ctx.AddBasicBlock(fn, "else") // Add all possible types as cases. p.builder.SetInsertPointAtEnd(entry) @@ -661,11 +661,11 @@ func (p *lowerInterfacesPass) createInterfaceMethodFunc(itf *interfaceInfo, sign // TODO: debug info // Create entry block. - entry := llvm.AddBasicBlock(fn, "entry") + entry := p.ctx.AddBasicBlock(fn, "entry") // Create default block and make it unreachable (which it is, because all // possible types are checked). - defaultBlock := llvm.AddBasicBlock(fn, "default") + defaultBlock := p.ctx.AddBasicBlock(fn, "default") p.builder.SetInsertPointAtEnd(defaultBlock) p.builder.CreateUnreachable() @@ -684,7 +684,7 @@ func (p *lowerInterfacesPass) createInterfaceMethodFunc(itf *interfaceInfo, sign // Define all possible functions that can be called. for _, typ := range itf.types { - bb := llvm.AddBasicBlock(fn, typ.name) + bb := p.ctx.AddBasicBlock(fn, typ.name) sw.AddCase(llvm.ConstInt(p.uintptrType, typ.num, false), bb) // The function we will redirect to when the interface has this type. 
diff --git a/compiler/interface.go b/compiler/interface.go index 672f7a17..c7a257ef 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -110,12 +110,12 @@ func (c *Compiler) makeStructTypeFields(typ *types.Struct) llvm.Value { fieldGlobalValue = llvm.ConstInsertValue(fieldGlobalValue, fieldName, []uint32{1}) if typ.Tag(i) != "" { fieldTag := c.makeGlobalArray([]byte(typ.Tag(i)), "reflect/types.structFieldTag", c.ctx.Int8Type()) + fieldTag.SetLinkage(llvm.PrivateLinkage) + fieldTag.SetUnnamedAddr(true) fieldTag = llvm.ConstGEP(fieldTag, []llvm.Value{ llvm.ConstInt(llvm.Int32Type(), 0, false), llvm.ConstInt(llvm.Int32Type(), 0, false), }) - fieldTag.SetLinkage(llvm.PrivateLinkage) - fieldTag.SetUnnamedAddr(true) fieldGlobalValue = llvm.ConstInsertValue(fieldGlobalValue, fieldTag, []uint32{2}) } if typ.Field(i).Embedded() { diff --git a/compiler/optimizer.go b/compiler/optimizer.go index 5d3f2ecd..669e089c 100644 --- a/compiler/optimizer.go +++ b/compiler/optimizer.go @@ -23,6 +23,11 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) erro c.replacePanicsWithTrap() // -panic=trap } + // run a check of all of our code + if c.VerifyIR { + c.checkModule() + } + // Run function passes for each function. 
funcPasses := llvm.NewFunctionPassManagerForModule(c.mod) defer funcPasses.Dispose() diff --git a/compiler/syscall.go b/compiler/syscall.go index 293480e6..9b99fcb6 100644 --- a/compiler/syscall.go +++ b/compiler/syscall.go @@ -165,7 +165,7 @@ func (c *Compiler) emitSyscall(frame *Frame, call *ssa.CallCommon) (llvm.Value, inrange2 := c.builder.CreateICmp(llvm.IntSGT, syscallResult, llvm.ConstInt(c.uintptrType, 0xfffffffffffff000, true), "") // -4096 hasError := c.builder.CreateAnd(inrange1, inrange2, "") errResult := c.builder.CreateSelect(hasError, c.builder.CreateSub(zero, syscallResult, ""), zero, "syscallError") - retval := llvm.Undef(llvm.StructType([]llvm.Type{c.uintptrType, c.uintptrType, c.uintptrType}, false)) + retval := llvm.Undef(c.ctx.StructType([]llvm.Type{c.uintptrType, c.uintptrType, c.uintptrType}, false)) retval = c.builder.CreateInsertValue(retval, syscallResult, 0, "") retval = c.builder.CreateInsertValue(retval, zero, 1, "") retval = c.builder.CreateInsertValue(retval, errResult, 2, "") @@ -181,7 +181,7 @@ func (c *Compiler) emitSyscall(frame *Frame, call *ssa.CallCommon) (llvm.Value, zero := llvm.ConstInt(c.uintptrType, 0, false) hasError := c.builder.CreateICmp(llvm.IntNE, syscallResult, llvm.ConstInt(c.uintptrType, 0, false), "") errResult := c.builder.CreateSelect(hasError, syscallResult, zero, "syscallError") - retval := llvm.Undef(llvm.StructType([]llvm.Type{c.uintptrType, c.uintptrType, c.uintptrType}, false)) + retval := llvm.Undef(c.ctx.StructType([]llvm.Type{c.uintptrType, c.uintptrType, c.uintptrType}, false)) retval = c.builder.CreateInsertValue(retval, syscallResult, 0, "") retval = c.builder.CreateInsertValue(retval, zero, 1, "") retval = c.builder.CreateInsertValue(retval, errResult, 2, "") diff --git a/main.go b/main.go index 6d29328b..64dfcbaf 100644 --- a/main.go +++ b/main.go @@ -50,6 +50,7 @@ type BuildConfig struct { scheduler string printIR bool dumpSSA bool + verifyIR bool debug bool printSizes string cFlags []string @@ 
-116,6 +117,7 @@ func Compile(pkgName, outpath string, spec *TargetSpec, config *BuildConfig, act ClangHeaders: getClangHeaderPath(root), Debug: config.debug, DumpSSA: config.dumpSSA, + VerifyIR: config.verifyIR, TINYGOROOT: root, GOROOT: goroot, GOPATH: getGopath(), @@ -621,6 +623,7 @@ func main() { scheduler := flag.String("scheduler", "", "which scheduler to use (coroutines, tasks)") printIR := flag.Bool("printir", false, "print LLVM IR") dumpSSA := flag.Bool("dumpssa", false, "dump internal Go SSA") + verifyIR := flag.Bool("verifyir", false, "run extra verification steps on LLVM IR") tags := flag.String("tags", "", "a space-separated list of extra build tags") target := flag.String("target", "", "LLVM target | .json file with TargetSpec") printSize := flag.String("size", "", "print sizes (none, short, full)") @@ -647,6 +650,7 @@ func main() { scheduler: *scheduler, printIR: *printIR, dumpSSA: *dumpSSA, + verifyIR: *verifyIR, debug: !*nodebug, printSizes: *printSize, tags: *tags, diff --git a/main_test.go b/main_test.go index fa85a82c..9b406d2b 100644 --- a/main_test.go +++ b/main_test.go @@ -116,6 +116,7 @@ func runTest(path, tmpdir string, target string, t *testing.T) { opt: "z", printIR: false, dumpSSA: false, + verifyIR: true, debug: false, printSizes: "", wasmAbi: "js",