diff --git a/compiler/compiler.go b/compiler/compiler.go index 6cae54a5..8a699224 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -23,7 +23,7 @@ import ( // Version of the compiler pacakge. Must be incremented each time the compiler // package changes in a way that affects the generated LLVM module. // This version is independent of the TinyGo version number. -const Version = 3 // last change: remove runtime.typeInInterface +const Version = 4 // last change: runtime.typeAssert signature func init() { llvm.InitializeAllTargets() diff --git a/compiler/interface.go b/compiler/interface.go index b6ca8492..535e7650 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -338,10 +338,16 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value { commaOk = b.createRuntimeCall("interfaceImplements", []llvm.Value{actualTypeNum, methodSet}, "") } else { + globalName := "reflect/types.type:" + getTypeCodeName(expr.AssertedType) + "$id" + assertedTypeCodeGlobal := b.mod.NamedGlobal(globalName) + if assertedTypeCodeGlobal.IsNil() { + // Create a new typecode global. + assertedTypeCodeGlobal = llvm.AddGlobal(b.mod, b.ctx.Int8Type(), globalName) + assertedTypeCodeGlobal.SetGlobalConstant(true) + } // Type assert on concrete type. // Call runtime.typeAssert, which will be lowered to a simple icmp or // const false in the interface lowering pass. - assertedTypeCodeGlobal := b.getTypeCode(expr.AssertedType) commaOk = b.createRuntimeCall("typeAssert", []llvm.Value{actualTypeNum, assertedTypeCodeGlobal}, "typecode") } diff --git a/interp/interpreter.go b/interp/interpreter.go index 44e5a900..afe0f246 100644 --- a/interp/interpreter.go +++ b/interp/interpreter.go @@ -286,16 +286,10 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent if r.debug { fmt.Fprintln(os.Stderr, indent+"typeassert:", operands[1:]) } - actualType, err := operands[1].asPointer(r) - if err != nil { - return nil, mem, r.errorAt(inst, err) - } - assertedType, err := operands[2].asPointer(r) - if err != nil { - return nil, mem, r.errorAt(inst, err) - } - result := assertedType.asRawValue(r).equal(actualType.asRawValue(r)) - if result { + assertedType := operands[2].toLLVMValue(inst.llvmInst.Operand(1).Type(), &mem) + actualTypePtrToInt := operands[1].toLLVMValue(inst.llvmInst.Operand(0).Type(), &mem) + actualType := actualTypePtrToInt.Operand(0) + if actualType.Name()+"$id" == assertedType.Name() { locals[inst.localIndex] = literalValue{uint8(1)} } else { locals[inst.localIndex] = literalValue{uint8(0)} diff --git a/interp/testdata/interface.ll b/interp/testdata/interface.ll index 5b7798ba..8031632f 100644 --- a/interp/testdata/interface.ll +++ b/interp/testdata/interface.ll @@ -6,10 +6,11 @@ target triple = "x86_64--linux" @main.v1 = global i1 0 @"reflect/types.type:named:main.foo" = private constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:int", i64 0, %runtime.interfaceMethodInfo* null } +@"reflect/types.type:named:main.foo$id" = external constant i8 @"reflect/types.type:basic:int" = external constant %runtime.typecodeID -declare i1 @runtime.typeAssert(i64, %runtime.typecodeID*, i8*, i8*) +declare i1 @runtime.typeAssert(i64, i8*, i8*, i8*) define void @runtime.initAll() unnamed_addr { entry: @@ -20,7 +21,7 @@ entry: define internal void @main.init() unnamed_addr { entry: ; Test type asserts. - %typecode = call i1 @runtime.typeAssert(i64 ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:main.foo" to i64), %runtime.typecodeID* @"reflect/types.type:named:main.foo", i8* undef, i8* null) + %typecode = call i1 @runtime.typeAssert(i64 ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:main.foo" to i64), i8* @"reflect/types.type:named:main.foo$id", i8* undef, i8* null) store i1 %typecode, i1* @main.v1 ret void } diff --git a/src/runtime/interface.go b/src/runtime/interface.go index 0d38389b..c39d7769 100644 --- a/src/runtime/interface.go +++ b/src/runtime/interface.go @@ -124,7 +124,7 @@ type structField struct { // lowering, to assign the lowest type numbers to the types with the most type // asserts. Also, it is replaced with const false if this type assert can never // happen. -func typeAssert(actualType uintptr, assertedType *typecodeID) bool +func typeAssert(actualType uintptr, assertedType *uint8) bool // Pseudo function call that returns whether a given type implements all methods // of the given interface. diff --git a/transform/interface-lowering.go b/transform/interface-lowering.go index 34aa6915..ee384f83 100644 --- a/transform/interface-lowering.go +++ b/transform/interface-lowering.go @@ -194,8 +194,11 @@ func (p *lowerInterfacesPass) run() error { typeAssertUses := getUses(typeAssert) for _, use := range typeAssertUses { typecode := use.Operand(1) - name := typecode.Name() - p.types[name].countTypeAsserts++ + name := typecode.Name() // name with $id suffix + name = name[:len(name)-len("$id")] // remove $id suffix + if t, ok := p.types[name]; ok { + t.countTypeAsserts++ + } } // Find all interface method calls. @@ -371,13 +374,27 @@ func (p *lowerInterfacesPass) run() error { } } - // Replace each type assert with an actual type comparison. + // Replace each type assert with an actual type comparison or (if the type + // assert is impossible) the constant false. + llvmFalse := llvm.ConstInt(p.ctx.Int1Type(), 0, false) for _, use := range typeAssertUses { actualType := use.Operand(0) - assertedTypeGlobal := use.Operand(1) - p.builder.SetInsertPointBefore(use) - commaOk := p.builder.CreateICmp(llvm.IntEQ, llvm.ConstPtrToInt(assertedTypeGlobal, p.uintptrType), actualType, "typeassert.ok") - use.ReplaceAllUsesWith(commaOk) + name := use.Operand(1).Name() // name with $id suffix + name = name[:len(name)-len("$id")] // remove $id suffix + if t, ok := p.types[name]; ok { + // The type exists in the program, so lower to a regular integer + // comparison. + p.builder.SetInsertPointBefore(use) + commaOk := p.builder.CreateICmp(llvm.IntEQ, llvm.ConstPtrToInt(t.typecode, p.uintptrType), actualType, "typeassert.ok") + use.ReplaceAllUsesWith(commaOk) + } else { + // The type does not exist in the program, so lower to a constant + // false. This is trivially further optimized. + // TODO: eventually it'll be necessary to handle reflect.PtrTo and + // reflect.New calls which create new types not present in the + // original program. + use.ReplaceAllUsesWith(llvmFalse) + } use.EraseFromParentAsInstruction() } diff --git a/transform/testdata/interface.ll b/transform/testdata/interface.ll index c58e8a23..487ecbb2 100644 --- a/transform/testdata/interface.ll +++ b/transform/testdata/interface.ll @@ -5,6 +5,8 @@ target triple = "armv7m-none-eabi" %runtime.interfaceMethodInfo = type { i8*, i32 } @"reflect/types.type:basic:uint8" = external constant %runtime.typecodeID +@"reflect/types.type:basic:uint8$id" = external constant i8 +@"reflect/types.type:basic:int16$id" = external constant i8 @"reflect/types.type:basic:int" = external constant %runtime.typecodeID @"func NeverImplementedMethod()" = external constant i8 @"Unmatched$interface" = private constant [1 x i8*] [i8* @"func NeverImplementedMethod()"] @@ -14,9 +16,10 @@ target triple = "armv7m-none-eabi" @"reflect/types.type:named:Number" = private constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:int", i32 0, %runtime.interfaceMethodInfo* getelementptr inbounds ([1 x %runtime.interfaceMethodInfo], [1 x %runtime.interfaceMethodInfo]* @"Number$methodset", i32 0, i32 0) } declare i1 @runtime.interfaceImplements(i32, i8**) -declare i1 @runtime.typeAssert(i32, %runtime.typecodeID*) +declare i1 @runtime.typeAssert(i32, i8*) declare i32 @runtime.interfaceMethod(i32, i8**, i8*) declare void @runtime.printuint8(i8) +declare void @runtime.printint16(i16) declare void @runtime.printint32(i32) declare void @runtime.printptr(i32) declare void @runtime.printnl() @@ -52,7 +55,7 @@ typeswitch.Doubler: ret void typeswitch.notDoubler: - %isByte = call i1 @runtime.typeAssert(i32 %typecode, %runtime.typecodeID* nonnull @"reflect/types.type:basic:uint8") + %isByte = call i1 @runtime.typeAssert(i32 %typecode, i8* nonnull @"reflect/types.type:basic:uint8$id") br i1 %isByte, label %typeswitch.byte, label %typeswitch.notByte typeswitch.byte: @@ -62,6 +65,17 @@ typeswitch.byte: ret void typeswitch.notByte: + ; this is a type assert that always fails + %isInt16 = call i1 @runtime.typeAssert(i32 %typecode, i8* nonnull @"reflect/types.type:basic:int16$id") + br i1 %isInt16, label %typeswitch.int16, label %typeswitch.notInt16 + +typeswitch.int16: + %int16 = ptrtoint i8* %value to i16 + call void @runtime.printint16(i16 %int16) + call void @runtime.printnl() + ret void + +typeswitch.notInt16: ret void } diff --git a/transform/testdata/interface.out.ll b/transform/testdata/interface.out.ll index 12a4c6a4..faff1254 100644 --- a/transform/testdata/interface.out.ll +++ b/transform/testdata/interface.out.ll @@ -5,6 +5,8 @@ target triple = "armv7m-none-eabi" %runtime.interfaceMethodInfo = type { i8*, i32 } @"reflect/types.type:basic:uint8" = external constant %runtime.typecodeID +@"reflect/types.type:basic:uint8$id" = external constant i8 +@"reflect/types.type:basic:int16$id" = external constant i8 @"reflect/types.type:basic:int" = external constant %runtime.typecodeID @"func NeverImplementedMethod()" = external constant i8 @"func Double() int" = external constant i8 @@ -12,12 +14,14 @@ target triple = "armv7m-none-eabi" declare i1 @runtime.interfaceImplements(i32, i8**) -declare i1 @runtime.typeAssert(i32, %runtime.typecodeID*) +declare i1 @runtime.typeAssert(i32, i8*) declare i32 @runtime.interfaceMethod(i32, i8**, i8*) declare void @runtime.printuint8(i8) +declare void @runtime.printint16(i16) + declare void @runtime.printint32(i32) declare void @runtime.printptr(i32) @@ -63,6 +67,15 @@ typeswitch.byte: ; preds = %typeswitch.notDoubl ret void typeswitch.notByte: ; preds = %typeswitch.notDoubler + br i1 false, label %typeswitch.int16, label %typeswitch.notInt16 + +typeswitch.int16: ; preds = %typeswitch.notByte + %int16 = ptrtoint i8* %value to i16 + call void @runtime.printint16(i16 %int16) + call void @runtime.printnl() + ret void + +typeswitch.notInt16: ; preds = %typeswitch.notByte ret void }