diff --git a/transform/gc.go b/transform/gc.go index ab080980..eb3520aa 100644 --- a/transform/gc.go +++ b/transform/gc.go @@ -271,37 +271,17 @@ func MakeGCStackSlots(mod llvm.Module) bool { // Make sure this stack object is popped from the linked list of stack // objects at return. for _, ret := range returns { - inst := ret - // Try to do the popping of the stack object earlier, by inserting - // it not right before the return instruction but moving the insert - // position up. - // This is necessary so that the GC stack slot pass doesn't - // interfere with tail calls (in particular, musttail calls). - for { - prevInst := llvm.PrevInstruction(inst) - if prevInst == parent { - break - } - if _, ok := pointerStores[prevInst]; ok { - // Pop the stack object after the last store instruction. - // This can probably be made more efficient: storing to the - // stack chain object and then immediately popping isn't - // useful. - break - } - if prevInst.IsNil() { - // Start of basic block. Pop the stack object here. - break - } - if !prevInst.IsAPHINode().IsNil() { - // Do not insert before a PHI node. PHI nodes must be - // grouped at the beginning of a basic block before any - // other instruction. - break - } - inst = prevInst + // Check for any tail calls at this return. + prev := llvm.PrevInstruction(ret) + if !prev.IsNil() && !prev.IsABitCastInst().IsNil() { + // A bitcast may sit between a tail call and the return, so step back one more instruction. + prev = llvm.PrevInstruction(prev) } - builder.SetInsertPointBefore(inst) + if !prev.IsNil() && !prev.IsACallInst().IsNil() { + // The store inserted before the return follows this call, so it is no longer a tail call.
+ prev.SetTailCall(false) + } + builder.SetInsertPointBefore(ret) builder.CreateStore(parent, stackChainStart) } } diff --git a/transform/testdata/gc-stackslots.ll b/transform/testdata/gc-stackslots.ll index 130b0c99..c217fb9a 100644 --- a/transform/testdata/gc-stackslots.ll +++ b/transform/testdata/gc-stackslots.ll @@ -5,6 +5,7 @@ target triple = "wasm32-unknown-unknown-wasm" @runtime.stackChainStart = external global %runtime.stackChainObject* @someGlobal = global i8 3 +@ptrGlobal = global i8** null declare void @runtime.trackPointer(i8* nocapture readonly) @@ -20,8 +21,6 @@ define i8* @needsStackSlots() { ; so tracking it is not really necessary. %ptr = call i8* @runtime.alloc(i32 4, i8* null) call void @runtime.trackPointer(i8* %ptr) - ; Restoring the stack pointer can happen at this position, before the return. - ; This avoids issues with tail calls. call void @someArbitraryFunction() %val = load i8, i8* @someGlobal ret i8* %ptr @@ -103,3 +102,20 @@ define void @testGEPBitcast() { define void @someArbitraryFunction() { ret void } + +define void @earlyPopRegression() { + %x.alloc = call i8* @runtime.alloc(i32 4, i8* null) + call void @runtime.trackPointer(i8* %x.alloc) + %x = bitcast i8* %x.alloc to i8** + ; At this point the pass used to pop the stack chain, resulting in a potential use-after-free during allocAndSave. 
+ musttail call void @allocAndSave(i8** %x) + ret void +} + +define void @allocAndSave(i8** %x) { + %y = call i8* @runtime.alloc(i32 4, i8* null) + call void @runtime.trackPointer(i8* %y) + store i8* %y, i8** %x + store i8** %x, i8*** @ptrGlobal + ret void +} \ No newline at end of file diff --git a/transform/testdata/gc-stackslots.out.ll b/transform/testdata/gc-stackslots.out.ll index 15fab17e..83d1c841 100644 --- a/transform/testdata/gc-stackslots.out.ll +++ b/transform/testdata/gc-stackslots.out.ll @@ -5,6 +5,7 @@ target triple = "wasm32-unknown-unknown-wasm" @runtime.stackChainStart = internal global %runtime.stackChainObject* null @someGlobal = global i8 3 +@ptrGlobal = global i8** null declare void @runtime.trackPointer(i8* nocapture readonly) @@ -25,9 +26,9 @@ define i8* @needsStackSlots() { %ptr = call i8* @runtime.alloc(i32 4, i8* null) %4 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 2 store i8* %ptr, i8** %4, align 4 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 call void @someArbitraryFunction() %val = load i8, i8* @someGlobal, align 1 + store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 ret i8* %ptr } @@ -75,8 +76,8 @@ define i8* @fibNext(i8* %x, i8* %y) { %out.alloc = call i8* @runtime.alloc(i32 1, i8* null) %4 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 2 store i8* %out.alloc, i8** %4, align 4 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 store i8 %out.val, i8* %out.alloc, align 1 + store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 ret i8* %out.alloc } @@ -141,3 +142,37 @@ define void @testGEPBitcast() { define void @someArbitraryFunction() { ret void } + +define void 
@earlyPopRegression() { + %gc.stackobject = alloca { %runtime.stackChainObject*, i32, i8* }, align 8 + store { %runtime.stackChainObject*, i32, i8* } { %runtime.stackChainObject* null, i32 1, i8* null }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, align 4 + %1 = load %runtime.stackChainObject*, %runtime.stackChainObject** @runtime.stackChainStart, align 4 + %2 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 0 + store %runtime.stackChainObject* %1, %runtime.stackChainObject** %2, align 4 + %3 = bitcast { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject to %runtime.stackChainObject* + store %runtime.stackChainObject* %3, %runtime.stackChainObject** @runtime.stackChainStart, align 4 + %x.alloc = call i8* @runtime.alloc(i32 4, i8* null) + %4 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 2 + store i8* %x.alloc, i8** %4, align 4 + %x = bitcast i8* %x.alloc to i8** + call void @allocAndSave(i8** %x) + store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 + ret void +} + +define void @allocAndSave(i8** %x) { + %gc.stackobject = alloca { %runtime.stackChainObject*, i32, i8* }, align 8 + store { %runtime.stackChainObject*, i32, i8* } { %runtime.stackChainObject* null, i32 1, i8* null }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, align 4 + %1 = load %runtime.stackChainObject*, %runtime.stackChainObject** @runtime.stackChainStart, align 4 + %2 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 0 + store %runtime.stackChainObject* %1, %runtime.stackChainObject** %2, align 4 + %3 = bitcast { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject to %runtime.stackChainObject* + store %runtime.stackChainObject* %3, %runtime.stackChainObject** 
@runtime.stackChainStart, align 4 + %y = call i8* @runtime.alloc(i32 4, i8* null) + %4 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 2 + store i8* %y, i8** %4, align 4 + store i8* %y, i8** %x, align 4 + store i8** %x, i8*** @ptrGlobal, align 4 + store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 + ret void +}