diff --git a/transform/func-lowering.go b/transform/func-lowering.go index ed385b5f..caef3842 100644 --- a/transform/func-lowering.go +++ b/transform/func-lowering.go @@ -6,6 +6,7 @@ package transform import ( "sort" "strconv" + "strings" "github.com/tinygo-org/tinygo/compiler/llvmutil" "tinygo.org/x/go-llvm" @@ -55,17 +56,30 @@ func LowerFuncValues(mod llvm.Module) { funcValueWithSignaturePtr := llvm.PointerType(mod.GetTypeByName("runtime.funcValueWithSignature"), 0) signatures := map[string]*funcSignatureInfo{} for global := mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) { - if global.Type() != funcValueWithSignaturePtr { + var sig, funcVal llvm.Value + switch { + case global.Type() == funcValueWithSignaturePtr: + sig = llvm.ConstExtractValue(global.Initializer(), []uint32{1}) + funcVal = global + case strings.HasPrefix(global.Name(), "reflect/types.type:func:{"): + sig = global + default: continue } - sig := llvm.ConstExtractValue(global.Initializer(), []uint32{1}) + name := sig.Name() + var funcValueWithSignatures []llvm.Value + if funcVal.IsNil() { + funcValueWithSignatures = []llvm.Value{} + } else { + funcValueWithSignatures = []llvm.Value{funcVal} + } if info, ok := signatures[name]; ok { - info.funcValueWithSignatures = append(info.funcValueWithSignatures, global) + info.funcValueWithSignatures = append(info.funcValueWithSignatures, funcValueWithSignatures...) } else { signatures[name] = &funcSignatureInfo{ sig: sig, - funcValueWithSignatures: []llvm.Value{global}, + funcValueWithSignatures: funcValueWithSignatures, } } } @@ -123,95 +137,64 @@ func LowerFuncValues(mod llvm.Module) { panic("expected all call uses to be runtime.getFuncPtr") } funcID := getFuncPtrCall.Operand(1) - switch len(functions) { - case 0: - // There are no functions used in a func value that implement - // this signature. The only possible value is a nil value. - for _, inttoptr := range getUses(getFuncPtrCall) { - if inttoptr.IsAIntToPtrInst().IsNil() { - panic("expected inttoptr") - } - nilptr := llvm.ConstPointerNull(inttoptr.Type()) - inttoptr.ReplaceAllUsesWith(nilptr) - inttoptr.EraseFromParentAsInstruction() - } - getFuncPtrCall.EraseFromParentAsInstruction() - case 1: - // There is exactly one function with this signature that is - // used in a func value. The func value itself can be either nil - // or this one function. - builder.SetInsertPointBefore(getFuncPtrCall) - zero := llvm.ConstInt(uintptrType, 0, false) - isnil := builder.CreateICmp(llvm.IntEQ, funcID, zero, "") - funcPtrNil := llvm.ConstPointerNull(functions[0].funcPtr.Type()) - funcPtr := builder.CreateSelect(isnil, funcPtrNil, functions[0].funcPtr, "") - for _, inttoptr := range getUses(getFuncPtrCall) { - if inttoptr.IsAIntToPtrInst().IsNil() { - panic("expected inttoptr") - } - inttoptr.ReplaceAllUsesWith(funcPtr) - inttoptr.EraseFromParentAsInstruction() - } - getFuncPtrCall.EraseFromParentAsInstruction() - default: - // There are multiple functions used in a func value that - // implement this signature. - // What we'll do is transform the following: - // rawPtr := runtime.getFuncPtr(func.ptr) - // if rawPtr == nil { - // runtime.nilPanic() - // } - // result := rawPtr(...args, func.context) - // into this: - // if false { - // runtime.nilPanic() - // } - // var result // Phi - // switch fn.id { - // case 0: - // runtime.nilPanic() - // case 1: - // result = call first implementation... - // case 2: - // result = call second implementation... - // default: - // unreachable - // } - // Remove some casts, checks, and the old call which we're going - // to replace. - for _, callIntPtr := range getUses(getFuncPtrCall) { - if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "internal/task.start" { - // Special case for goroutine starts. - addFuncLoweringSwitch(mod, builder, funcID, callIntPtr, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { - i8ptrType := llvm.PointerType(ctx.Int8Type(), 0) - calleeValue := builder.CreatePtrToInt(funcPtr, uintptrType, "") - start := mod.NamedFunction("internal/task.start") - builder.CreateCall(start, []llvm.Value{calleeValue, callIntPtr.Operand(1), llvm.Undef(i8ptrType), llvm.ConstNull(i8ptrType)}, "") - return llvm.Value{} // void so no return value - }, functions) - callIntPtr.EraseFromParentAsInstruction() - continue - } - if callIntPtr.IsAIntToPtrInst().IsNil() { - panic("expected inttoptr") - } - for _, ptrUse := range getUses(callIntPtr) { - if !ptrUse.IsAICmpInst().IsNil() { - ptrUse.ReplaceAllUsesWith(llvm.ConstInt(ctx.Int1Type(), 0, false)) - } else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr { - addFuncLoweringSwitch(mod, builder, funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { - return builder.CreateCall(funcPtr, params, "") - }, functions) - } else { - panic("unexpected getFuncPtrCall") - } - ptrUse.EraseFromParentAsInstruction() - } + // There are functions used in a func value that + // implement this signature. + // What we'll do is transform the following: + // rawPtr := runtime.getFuncPtr(func.ptr) + // if rawPtr == nil { + // runtime.nilPanic() + // } + // result := rawPtr(...args, func.context) + // into this: + // if false { + // runtime.nilPanic() + // } + // var result // Phi + // switch fn.id { + // case 0: + // runtime.nilPanic() + // case 1: + // result = call first implementation... + // case 2: + // result = call second implementation... + // default: + // unreachable + // } + + // Remove some casts, checks, and the old call which we're going + // to replace. + for _, callIntPtr := range getUses(getFuncPtrCall) { + if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "internal/task.start" { + // Special case for goroutine starts. + addFuncLoweringSwitch(mod, builder, funcID, callIntPtr, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { + i8ptrType := llvm.PointerType(ctx.Int8Type(), 0) + calleeValue := builder.CreatePtrToInt(funcPtr, uintptrType, "") + start := mod.NamedFunction("internal/task.start") + builder.CreateCall(start, []llvm.Value{calleeValue, callIntPtr.Operand(1), llvm.Undef(i8ptrType), llvm.ConstNull(i8ptrType)}, "") + return llvm.Value{} // void so no return value + }, functions) callIntPtr.EraseFromParentAsInstruction() + continue } - getFuncPtrCall.EraseFromParentAsInstruction() + if callIntPtr.IsAIntToPtrInst().IsNil() { + panic("expected inttoptr") + } + for _, ptrUse := range getUses(callIntPtr) { + if !ptrUse.IsAICmpInst().IsNil() { + ptrUse.ReplaceAllUsesWith(llvm.ConstInt(ctx.Int1Type(), 0, false)) + } else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr { + addFuncLoweringSwitch(mod, builder, funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { + return builder.CreateCall(funcPtr, params, "") + }, functions) + } else { + panic("unexpected getFuncPtrCall") + } + ptrUse.EraseFromParentAsInstruction() + } + callIntPtr.EraseFromParentAsInstruction() } + getFuncPtrCall.EraseFromParentAsInstruction() } } } @@ -270,13 +253,18 @@ func addFuncLoweringSwitch(mod llvm.Module, builder llvm.Builder, funcID, call l phiBlocks[i] = bb phiValues[i] = result } - // Create the PHI node so that the call result flows into the - // next block (after the split). This is only necessary when the - // call produced a value. if call.Type().TypeKind() != llvm.VoidTypeKind { - builder.SetInsertPointBefore(nextBlock.FirstInstruction()) - phi := builder.CreatePHI(call.Type(), "") - phi.AddIncoming(phiValues, phiBlocks) - call.ReplaceAllUsesWith(phi) + if len(functions) > 0 { + // Create the PHI node so that the call result flows into the + // next block (after the split). This is only necessary when the + // call produced a value. + builder.SetInsertPointBefore(nextBlock.FirstInstruction()) + phi := builder.CreatePHI(call.Type(), "") + phi.AddIncoming(phiValues, phiBlocks) + call.ReplaceAllUsesWith(phi) + } else { + // This is always a nil panic, so replace the call result with undef. + call.ReplaceAllUsesWith(llvm.Undef(call.Type())) + } } } diff --git a/transform/testdata/func-lowering.ll b/transform/testdata/func-lowering.ll index b5692c70..04241dc3 100644 --- a/transform/testdata/func-lowering.ll +++ b/transform/testdata/func-lowering.ll @@ -4,10 +4,9 @@ target triple = "wasm32-unknown-unknown-wasm" %runtime.typecodeID = type { %runtime.typecodeID*, i32 } %runtime.funcValueWithSignature = type { i32, %runtime.typecodeID* } -@"reflect/types.type:func:{basic:int8}{}" = external constant %runtime.typecodeID @"reflect/types.type:func:{basic:uint8}{}" = external constant %runtime.typecodeID @"reflect/types.type:func:{basic:int}{}" = external constant %runtime.typecodeID -@"funcInt8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @funcInt8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:int8}{}" } +@"reflect/types.type:func:{}{basic:uint32}" = external constant %runtime.typecodeID @"func1Uint8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @func1Uint8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:uint8}{}" } @"func2Uint8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @func2Uint8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:uint8}{}" } @"main$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i32, i8*, i8*)* @"main$1" to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:int}{}" } @@ -23,29 +22,26 @@ declare void @"main$1"(i32, i8*, i8*) declare void @"main$2"(i32, i8*, i8*) -declare void @funcInt8(i8, i8*, i8*) - declare void @func1Uint8(i8, i8*, i8*) declare void @func2Uint8(i8, i8*, i8*) -; Call a function of which only one function with this signature is used as a -; function value. This means that lowering it to IR is trivial: simply check -; whether the func value is nil, and if not, call that one function directly. -define void @runFunc1(i8*, i32, i8, i8* %context, i8* %parentHandle) { +; There are no functions with this signature used in a func value. +; This means that this should unconditionally nil panic. +define i32 @runFuncNone(i8*, i32, i8* %context, i8* %parentHandle) { entry: - %3 = call i32 @runtime.getFuncPtr(i8* %0, i32 %1, %runtime.typecodeID* @"reflect/types.type:func:{basic:int8}{}", i8* undef, i8* null) - %4 = inttoptr i32 %3 to void (i8, i8*, i8*)* - %5 = icmp eq void (i8, i8*, i8*)* %4, null - br i1 %5, label %fpcall.nil, label %fpcall.next + %2 = call i32 @runtime.getFuncPtr(i8* %0, i32 %1, %runtime.typecodeID* @"reflect/types.type:func:{}{basic:uint32}", i8* undef, i8* null) + %3 = inttoptr i32 %2 to i32 (i8*, i8*)* + %4 = icmp eq i32 (i8*, i8*)* %3, null + br i1 %4, label %fpcall.nil, label %fpcall.next fpcall.nil: call void @runtime.nilPanic(i8* undef, i8* null) unreachable fpcall.next: - call void %4(i8 %2, i8* %0, i8* undef) - ret void + %5 = call i32 %3(i8* %0, i8* undef) + ret i32 %5 } ; There are two functions with this signature used in a func value. That means diff --git a/transform/testdata/func-lowering.out.ll b/transform/testdata/func-lowering.out.ll index 50558fc8..2f46baad 100644 --- a/transform/testdata/func-lowering.out.ll +++ b/transform/testdata/func-lowering.out.ll @@ -4,10 +4,9 @@ target triple = "wasm32-unknown-unknown-wasm" %runtime.typecodeID = type { %runtime.typecodeID*, i32 } %runtime.funcValueWithSignature = type { i32, %runtime.typecodeID* } -@"reflect/types.type:func:{basic:int8}{}" = external constant %runtime.typecodeID @"reflect/types.type:func:{basic:uint8}{}" = external constant %runtime.typecodeID @"reflect/types.type:func:{basic:int}{}" = external constant %runtime.typecodeID -@"funcInt8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @funcInt8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:int8}{}" } +@"reflect/types.type:func:{}{basic:uint32}" = external constant %runtime.typecodeID @"func1Uint8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @func1Uint8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:uint8}{}" } @"func2Uint8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @func2Uint8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:uint8}{}" } @"main$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i32, i8*, i8*)* @"main$1" to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:int}{}" } @@ -23,26 +22,32 @@ declare void @"main$1"(i32, i8*, i8*) declare void @"main$2"(i32, i8*, i8*) -declare void @funcInt8(i8, i8*, i8*) - declare void @func1Uint8(i8, i8*, i8*) declare void @func2Uint8(i8, i8*, i8*) -define void @runFunc1(i8* %0, i32 %1, i8 %2, i8* %context, i8* %parentHandle) { +define i32 @runFuncNone(i8* %0, i32 %1, i8* %context, i8* %parentHandle) { entry: - %3 = icmp eq i32 %1, 0 - %4 = select i1 %3, void (i8, i8*, i8*)* null, void (i8, i8*, i8*)* @funcInt8 - %5 = icmp eq void (i8, i8*, i8*)* %4, null - br i1 %5, label %fpcall.nil, label %fpcall.next + br i1 false, label %fpcall.nil, label %fpcall.next fpcall.nil: ; preds = %entry call void @runtime.nilPanic(i8* undef, i8* null) unreachable fpcall.next: ; preds = %entry - call void %4(i8 %2, i8* %0, i8* undef) - ret void + switch i32 %1, label %func.default [ + i32 0, label %func.nil + ] + +func.nil: ; preds = %fpcall.next + call void @runtime.nilPanic(i8* undef, i8* null) + unreachable + +func.next: ; No predecessors! + ret i32 undef + +func.default: ; preds = %fpcall.next + unreachable } define void @runFunc2(i8* %0, i32 %1, i8 %2, i8* %context, i8* %parentHandle) {