diff --git a/compiler/check.go b/compiler/check.go index 356d0296..97bb4459 100644 --- a/compiler/check.go +++ b/compiler/check.go @@ -10,14 +10,10 @@ import ( "tinygo.org/x/go-llvm" ) -func (c *Compiler) checkType(t llvm.Type, checked map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) { - if t.IsNil() { - panic(t) - } - +func (c *Compiler) checkType(t llvm.Type, checked map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) error { // prevent infinite recursion for self-referential types if _, ok := checked[t]; ok { - return + return nil } checked[t] = struct{}{} @@ -27,10 +23,10 @@ func (c *Compiler) checkType(t llvm.Type, checked map[llvm.Type]struct{}, specia // this is correct case t.Context() == llvm.GlobalContext(): // somewhere we accidentally used the global context instead of a real context - panic(fmt.Errorf("type %q uses global context", t.String())) + return fmt.Errorf("type %q uses global context", t.String()) default: // we used some other context by accident - panic(fmt.Errorf("type %q uses context %v instead of the main context %v", t.Context(), c.ctx)) + return fmt.Errorf("type %q uses context %v instead of the main context %v", t.Context(), c.ctx) } // if this is a composite type, check the components of the type @@ -40,7 +36,7 @@ func (c *Compiler) checkType(t llvm.Type, checked map[llvm.Type]struct{}, specia if s, ok := specials[t.TypeKind()]; !ok { specials[t.TypeKind()] = t } else if s != t { - panic(fmt.Errorf("duplicate special type %q: %v and %v", t.TypeKind().String(), t, s)) + return fmt.Errorf("duplicate special type %q: %v and %v", t.TypeKind().String(), t, s) } case llvm.FloatTypeKind, llvm.DoubleTypeKind, llvm.X86_FP80TypeKind, llvm.FP128TypeKind, llvm.PPC_FP128TypeKind: // floating point numbers are primitives - nothing to recurse @@ -48,81 +44,130 @@ func (c *Compiler) checkType(t llvm.Type, checked map[llvm.Type]struct{}, specia // integers are primitives - nothing to recurse case llvm.FunctionTypeKind: // check arguments and return(s) - for _, v := range t.ParamTypes() { - c.checkType(v, checked, specials) + for i, v := range t.ParamTypes() { + if err := c.checkType(v, checked, specials); err != nil { + return fmt.Errorf("failed to verify argument %d of type %s: %s", i, t.String(), err.Error()) + } + } + if err := c.checkType(t.ReturnType(), checked, specials); err != nil { + return fmt.Errorf("failed to verify return type of type %s: %s", t.String(), err.Error()) } - c.checkType(t.ReturnType(), checked, specials) case llvm.StructTypeKind: // check all elements - for _, v := range t.StructElementTypes() { - c.checkType(v, checked, specials) + for i, v := range t.StructElementTypes() { + if err := c.checkType(v, checked, specials); err != nil { + return fmt.Errorf("failed to verify type of field %d of struct type %s: %s", i, t.String(), err.Error()) + } } case llvm.ArrayTypeKind: // check element type - c.checkType(t.ElementType(), checked, specials) + if err := c.checkType(t.ElementType(), checked, specials); err != nil { + return fmt.Errorf("failed to verify element type of array type %s: %s", t.String(), err.Error()) + } case llvm.PointerTypeKind: // check underlying type - c.checkType(t.ElementType(), checked, specials) + if err := c.checkType(t.ElementType(), checked, specials); err != nil { + return fmt.Errorf("failed to verify underlying type of pointer type %s: %s", t.String(), err.Error()) + } case llvm.VectorTypeKind: // check element type - c.checkType(t.ElementType(), checked, specials) + if err := c.checkType(t.ElementType(), checked, specials); err != nil { + return fmt.Errorf("failed to verify element type of vector type %s: %s", t.String(), err.Error()) + } + default: + return fmt.Errorf("unrecognized kind %q of type %s", t.TypeKind(), t.String()) } + + return nil } -func (c *Compiler) checkValue(v llvm.Value, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) { +func (c *Compiler) checkValue(v llvm.Value, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) error { // check type - c.checkType(v.Type(), types, specials) + if err := c.checkType(v.Type(), types, specials); err != nil { + return fmt.Errorf("failed to verify type of value: %s", err.Error()) + } + + // check if this is an undefined void + if v.IsUndef() && v.Type().TypeKind() == llvm.VoidTypeKind { + return errors.New("encountered undefined void value") + } + + return nil } -func (c *Compiler) checkInstruction(inst llvm.Value, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) { +func (c *Compiler) checkInstruction(inst llvm.Value, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) error { // check value properties - c.checkValue(inst, types, specials) + if err := c.checkValue(inst, types, specials); err != nil { + return fmt.Errorf("failed to validate value of instruction %q: %s", inst.Name(), err.Error()) + } // check operands for i := 0; i < inst.OperandsCount(); i++ { - c.checkValue(inst.Operand(i), types, specials) + if err := c.checkValue(inst.Operand(i), types, specials); err != nil { + return fmt.Errorf("failed to validate argument %d of instruction %q: %s", i, inst.Name(), err.Error()) + } } + + return nil } -func (c *Compiler) checkBasicBlock(bb llvm.BasicBlock, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) { +func (c *Compiler) checkBasicBlock(bb llvm.BasicBlock, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) error { // check basic block value and type - c.checkValue(bb.AsValue(), types, specials) + if err := c.checkValue(bb.AsValue(), types, specials); err != nil { + return fmt.Errorf("failed to validate value of basic block %s: %s", bb.AsValue().Name(), err.Error()) + } // check instructions for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) { - c.checkInstruction(inst, types, specials) + if err := c.checkInstruction(inst, types, specials); err != nil { + return fmt.Errorf("failed to validate basic block %q: %s", bb.AsValue().Name(), err.Error()) + } } + + return nil } -func (c *Compiler) checkFunction(fn llvm.Value, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) { +func (c *Compiler) checkFunction(fn llvm.Value, types map[llvm.Type]struct{}, specials map[llvm.TypeKind]llvm.Type) error { // check function value and type - c.checkValue(fn, types, specials) + if err := c.checkValue(fn, types, specials); err != nil { + return fmt.Errorf("failed to validate value of function %s: %s", fn.Name(), err.Error()) + } // check basic blocks for bb := fn.FirstBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) { - c.checkBasicBlock(bb, types, specials) + if err := c.checkBasicBlock(bb, types, specials); err != nil { + return fmt.Errorf("failed to validate basic block of function %s: %s", fn.Name(), err.Error()) + } } + + return nil } -func (c *Compiler) checkModule() { +func (c *Compiler) checkModule() error { // check for any context mismatches switch { case c.mod.Context() == c.ctx: // this is correct case c.mod.Context() == llvm.GlobalContext(): // somewhere we accidentally used the global context instead of a real context - panic(errors.New("module uses global context")) + return errors.New("module uses global context") default: // we used some other context by accident - panic(fmt.Errorf("module uses context %v instead of the main context %v", c.mod.Context(), c.ctx)) + return fmt.Errorf("module uses context %v instead of the main context %v", c.mod.Context(), c.ctx) } types := map[llvm.Type]struct{}{} specials := map[llvm.TypeKind]llvm.Type{} for fn := c.mod.FirstFunction(); !fn.IsNil(); fn = llvm.NextFunction(fn) { - c.checkFunction(fn, types, specials) + if err := c.checkFunction(fn, types, specials); err != nil { + return err + } } for g := c.mod.FirstGlobal(); !g.IsNil(); g = llvm.NextGlobal(g) { - c.checkValue(g, types, specials) + if err := c.checkValue(g, types, specials); err != nil { + return fmt.Errorf("failed to verify global %s of module: %s", g.Name(), err.Error()) + } } + + return nil } diff --git a/compiler/optimizer.go b/compiler/optimizer.go index dee33447..480d3283 100644 --- a/compiler/optimizer.go +++ b/compiler/optimizer.go @@ -25,7 +25,10 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) erro // run a check of all of our code if c.VerifyIR { - c.checkModule() + err := c.checkModule() + if err != nil { + return err + } } // Run function passes for each function. @@ -93,6 +96,11 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) erro return err } } + if c.VerifyIR { + if err := c.checkModule(); err != nil { + return err + } + } if err := c.Verify(); err != nil { return errors.New("optimizations caused a verification failure") }