compiler: do not check for impossible type asserts

Previously there was code to avoid impossible type asserts but it wasn't
great and in fact was too aggressive when combined with reflection.

This commit improves this by checking all types that exist in the
program that may appear in an interface (even struct fields and the
like) but without creating runtime.typecodeID objects with the type
assert. This has two advantages:

  * As mentioned, it optimizes impossible type asserts away.
  * It allows methods on types that were only asserted on (in
    runtime.typeAssert) but never used in an interface to be optimized
    away using GlobalDCE. This may have a cascading effect so that other
    parts of the code can be further optimized.

This sometimes massively improves code size and mostly negates the code
size regression of the previous commit.
Этот коммит содержится в:
Ayke van Laethem 2021-03-18 02:56:03 +01:00 коммит произвёл Ron Evans
родитель bbb2909283
коммит 19dec048b0
8 изменённых файлов: 70 добавлений и 25 удалений

Просмотреть файл

@ -23,7 +23,7 @@ import (
// Version of the compiler pacakge. Must be incremented each time the compiler // Version of the compiler pacakge. Must be incremented each time the compiler
// package changes in a way that affects the generated LLVM module. // package changes in a way that affects the generated LLVM module.
// This version is independent of the TinyGo version number. // This version is independent of the TinyGo version number.
const Version = 3 // last change: remove runtime.typeInInterface const Version = 4 // last change: runtime.typeAssert signature
func init() { func init() {
llvm.InitializeAllTargets() llvm.InitializeAllTargets()

Просмотреть файл

@ -338,10 +338,16 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
commaOk = b.createRuntimeCall("interfaceImplements", []llvm.Value{actualTypeNum, methodSet}, "") commaOk = b.createRuntimeCall("interfaceImplements", []llvm.Value{actualTypeNum, methodSet}, "")
} else { } else {
globalName := "reflect/types.type:" + getTypeCodeName(expr.AssertedType) + "$id"
assertedTypeCodeGlobal := b.mod.NamedGlobal(globalName)
if assertedTypeCodeGlobal.IsNil() {
// Create a new typecode global.
assertedTypeCodeGlobal = llvm.AddGlobal(b.mod, b.ctx.Int8Type(), globalName)
assertedTypeCodeGlobal.SetGlobalConstant(true)
}
// Type assert on concrete type. // Type assert on concrete type.
// Call runtime.typeAssert, which will be lowered to a simple icmp or // Call runtime.typeAssert, which will be lowered to a simple icmp or
// const false in the interface lowering pass. // const false in the interface lowering pass.
assertedTypeCodeGlobal := b.getTypeCode(expr.AssertedType)
commaOk = b.createRuntimeCall("typeAssert", []llvm.Value{actualTypeNum, assertedTypeCodeGlobal}, "typecode") commaOk = b.createRuntimeCall("typeAssert", []llvm.Value{actualTypeNum, assertedTypeCodeGlobal}, "typecode")
} }

Просмотреть файл

@ -286,16 +286,10 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent
if r.debug { if r.debug {
fmt.Fprintln(os.Stderr, indent+"typeassert:", operands[1:]) fmt.Fprintln(os.Stderr, indent+"typeassert:", operands[1:])
} }
actualType, err := operands[1].asPointer(r) assertedType := operands[2].toLLVMValue(inst.llvmInst.Operand(1).Type(), &mem)
if err != nil { actualTypePtrToInt := operands[1].toLLVMValue(inst.llvmInst.Operand(0).Type(), &mem)
return nil, mem, r.errorAt(inst, err) actualType := actualTypePtrToInt.Operand(0)
} if actualType.Name()+"$id" == assertedType.Name() {
assertedType, err := operands[2].asPointer(r)
if err != nil {
return nil, mem, r.errorAt(inst, err)
}
result := assertedType.asRawValue(r).equal(actualType.asRawValue(r))
if result {
locals[inst.localIndex] = literalValue{uint8(1)} locals[inst.localIndex] = literalValue{uint8(1)}
} else { } else {
locals[inst.localIndex] = literalValue{uint8(0)} locals[inst.localIndex] = literalValue{uint8(0)}

5
interp/testdata/interface.ll предоставленный
Просмотреть файл

@ -6,10 +6,11 @@ target triple = "x86_64--linux"
@main.v1 = global i1 0 @main.v1 = global i1 0
@"reflect/types.type:named:main.foo" = private constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:int", i64 0, %runtime.interfaceMethodInfo* null } @"reflect/types.type:named:main.foo" = private constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:int", i64 0, %runtime.interfaceMethodInfo* null }
@"reflect/types.type:named:main.foo$id" = external constant i8
@"reflect/types.type:basic:int" = external constant %runtime.typecodeID @"reflect/types.type:basic:int" = external constant %runtime.typecodeID
declare i1 @runtime.typeAssert(i64, %runtime.typecodeID*, i8*, i8*) declare i1 @runtime.typeAssert(i64, i8*, i8*, i8*)
define void @runtime.initAll() unnamed_addr { define void @runtime.initAll() unnamed_addr {
entry: entry:
@ -20,7 +21,7 @@ entry:
define internal void @main.init() unnamed_addr { define internal void @main.init() unnamed_addr {
entry: entry:
; Test type asserts. ; Test type asserts.
%typecode = call i1 @runtime.typeAssert(i64 ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:main.foo" to i64), %runtime.typecodeID* @"reflect/types.type:named:main.foo", i8* undef, i8* null) %typecode = call i1 @runtime.typeAssert(i64 ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:main.foo" to i64), i8* @"reflect/types.type:named:main.foo$id", i8* undef, i8* null)
store i1 %typecode, i1* @main.v1 store i1 %typecode, i1* @main.v1
ret void ret void
} }

Просмотреть файл

@ -124,7 +124,7 @@ type structField struct {
// lowering, to assign the lowest type numbers to the types with the most type // lowering, to assign the lowest type numbers to the types with the most type
// asserts. Also, it is replaced with const false if this type assert can never // asserts. Also, it is replaced with const false if this type assert can never
// happen. // happen.
func typeAssert(actualType uintptr, assertedType *typecodeID) bool func typeAssert(actualType uintptr, assertedType *uint8) bool
// Pseudo function call that returns whether a given type implements all methods // Pseudo function call that returns whether a given type implements all methods
// of the given interface. // of the given interface.

Просмотреть файл

@ -194,8 +194,11 @@ func (p *lowerInterfacesPass) run() error {
typeAssertUses := getUses(typeAssert) typeAssertUses := getUses(typeAssert)
for _, use := range typeAssertUses { for _, use := range typeAssertUses {
typecode := use.Operand(1) typecode := use.Operand(1)
name := typecode.Name() name := typecode.Name() // name with $id suffix
p.types[name].countTypeAsserts++ name = name[:len(name)-len("$id")] // remove $id suffix
if t, ok := p.types[name]; ok {
t.countTypeAsserts++
}
} }
// Find all interface method calls. // Find all interface method calls.
@ -371,13 +374,27 @@ func (p *lowerInterfacesPass) run() error {
} }
} }
// Replace each type assert with an actual type comparison. // Replace each type assert with an actual type comparison or (if the type
// assert is impossible) the constant false.
llvmFalse := llvm.ConstInt(p.ctx.Int1Type(), 0, false)
for _, use := range typeAssertUses { for _, use := range typeAssertUses {
actualType := use.Operand(0) actualType := use.Operand(0)
assertedTypeGlobal := use.Operand(1) name := use.Operand(1).Name() // name with $id suffix
p.builder.SetInsertPointBefore(use) name = name[:len(name)-len("$id")] // remove $id suffix
commaOk := p.builder.CreateICmp(llvm.IntEQ, llvm.ConstPtrToInt(assertedTypeGlobal, p.uintptrType), actualType, "typeassert.ok") if t, ok := p.types[name]; ok {
use.ReplaceAllUsesWith(commaOk) // The type exists in the program, so lower to a regular integer
// comparison.
p.builder.SetInsertPointBefore(use)
commaOk := p.builder.CreateICmp(llvm.IntEQ, llvm.ConstPtrToInt(t.typecode, p.uintptrType), actualType, "typeassert.ok")
use.ReplaceAllUsesWith(commaOk)
} else {
// The type does not exist in the program, so lower to a constant
// false. This is trivially further optimized.
// TODO: eventually it'll be necessary to handle reflect.PtrTo and
// reflect.New calls which create new types not present in the
// original program.
use.ReplaceAllUsesWith(llvmFalse)
}
use.EraseFromParentAsInstruction() use.EraseFromParentAsInstruction()
} }

18
transform/testdata/interface.ll предоставленный
Просмотреть файл

@ -5,6 +5,8 @@ target triple = "armv7m-none-eabi"
%runtime.interfaceMethodInfo = type { i8*, i32 } %runtime.interfaceMethodInfo = type { i8*, i32 }
@"reflect/types.type:basic:uint8" = external constant %runtime.typecodeID @"reflect/types.type:basic:uint8" = external constant %runtime.typecodeID
@"reflect/types.type:basic:uint8$id" = external constant i8
@"reflect/types.type:basic:int16$id" = external constant i8
@"reflect/types.type:basic:int" = external constant %runtime.typecodeID @"reflect/types.type:basic:int" = external constant %runtime.typecodeID
@"func NeverImplementedMethod()" = external constant i8 @"func NeverImplementedMethod()" = external constant i8
@"Unmatched$interface" = private constant [1 x i8*] [i8* @"func NeverImplementedMethod()"] @"Unmatched$interface" = private constant [1 x i8*] [i8* @"func NeverImplementedMethod()"]
@ -14,9 +16,10 @@ target triple = "armv7m-none-eabi"
@"reflect/types.type:named:Number" = private constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:int", i32 0, %runtime.interfaceMethodInfo* getelementptr inbounds ([1 x %runtime.interfaceMethodInfo], [1 x %runtime.interfaceMethodInfo]* @"Number$methodset", i32 0, i32 0) } @"reflect/types.type:named:Number" = private constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:int", i32 0, %runtime.interfaceMethodInfo* getelementptr inbounds ([1 x %runtime.interfaceMethodInfo], [1 x %runtime.interfaceMethodInfo]* @"Number$methodset", i32 0, i32 0) }
declare i1 @runtime.interfaceImplements(i32, i8**) declare i1 @runtime.interfaceImplements(i32, i8**)
declare i1 @runtime.typeAssert(i32, %runtime.typecodeID*) declare i1 @runtime.typeAssert(i32, i8*)
declare i32 @runtime.interfaceMethod(i32, i8**, i8*) declare i32 @runtime.interfaceMethod(i32, i8**, i8*)
declare void @runtime.printuint8(i8) declare void @runtime.printuint8(i8)
declare void @runtime.printint16(i16)
declare void @runtime.printint32(i32) declare void @runtime.printint32(i32)
declare void @runtime.printptr(i32) declare void @runtime.printptr(i32)
declare void @runtime.printnl() declare void @runtime.printnl()
@ -52,7 +55,7 @@ typeswitch.Doubler:
ret void ret void
typeswitch.notDoubler: typeswitch.notDoubler:
%isByte = call i1 @runtime.typeAssert(i32 %typecode, %runtime.typecodeID* nonnull @"reflect/types.type:basic:uint8") %isByte = call i1 @runtime.typeAssert(i32 %typecode, i8* nonnull @"reflect/types.type:basic:uint8$id")
br i1 %isByte, label %typeswitch.byte, label %typeswitch.notByte br i1 %isByte, label %typeswitch.byte, label %typeswitch.notByte
typeswitch.byte: typeswitch.byte:
@ -62,6 +65,17 @@ typeswitch.byte:
ret void ret void
typeswitch.notByte: typeswitch.notByte:
; this is a type assert that always fails
%isInt16 = call i1 @runtime.typeAssert(i32 %typecode, i8* nonnull @"reflect/types.type:basic:int16$id")
br i1 %isInt16, label %typeswitch.int16, label %typeswitch.notInt16
typeswitch.int16:
%int16 = ptrtoint i8* %value to i16
call void @runtime.printint16(i16 %int16)
call void @runtime.printnl()
ret void
typeswitch.notInt16:
ret void ret void
} }

15
transform/testdata/interface.out.ll предоставленный
Просмотреть файл

@ -5,6 +5,8 @@ target triple = "armv7m-none-eabi"
%runtime.interfaceMethodInfo = type { i8*, i32 } %runtime.interfaceMethodInfo = type { i8*, i32 }
@"reflect/types.type:basic:uint8" = external constant %runtime.typecodeID @"reflect/types.type:basic:uint8" = external constant %runtime.typecodeID
@"reflect/types.type:basic:uint8$id" = external constant i8
@"reflect/types.type:basic:int16$id" = external constant i8
@"reflect/types.type:basic:int" = external constant %runtime.typecodeID @"reflect/types.type:basic:int" = external constant %runtime.typecodeID
@"func NeverImplementedMethod()" = external constant i8 @"func NeverImplementedMethod()" = external constant i8
@"func Double() int" = external constant i8 @"func Double() int" = external constant i8
@ -12,12 +14,14 @@ target triple = "armv7m-none-eabi"
declare i1 @runtime.interfaceImplements(i32, i8**) declare i1 @runtime.interfaceImplements(i32, i8**)
declare i1 @runtime.typeAssert(i32, %runtime.typecodeID*) declare i1 @runtime.typeAssert(i32, i8*)
declare i32 @runtime.interfaceMethod(i32, i8**, i8*) declare i32 @runtime.interfaceMethod(i32, i8**, i8*)
declare void @runtime.printuint8(i8) declare void @runtime.printuint8(i8)
declare void @runtime.printint16(i16)
declare void @runtime.printint32(i32) declare void @runtime.printint32(i32)
declare void @runtime.printptr(i32) declare void @runtime.printptr(i32)
@ -63,6 +67,15 @@ typeswitch.byte: ; preds = %typeswitch.notDoubl
ret void ret void
typeswitch.notByte: ; preds = %typeswitch.notDoubler typeswitch.notByte: ; preds = %typeswitch.notDoubler
br i1 false, label %typeswitch.int16, label %typeswitch.notInt16
typeswitch.int16: ; preds = %typeswitch.notByte
%int16 = ptrtoint i8* %value to i16
call void @runtime.printint16(i16 %int16)
call void @runtime.printnl()
ret void
typeswitch.notInt16: ; preds = %typeswitch.notByte
ret void ret void
} }