diff --git a/compiler.go b/compiler.go index 3028494e..51516dad 100644 --- a/compiler.go +++ b/compiler.go @@ -2122,6 +2122,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { return llvm.Value{}, err } if _, ok := expr.AssertedType.Underlying().(*types.Interface); ok { + // TODO: check whether the type implements the interface. return llvm.Value{}, errors.New("todo: assert on interface") } assertedType, err := c.getLLVMType(expr.AssertedType) @@ -2137,19 +2138,46 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { return llvm.Value{}, errors.New("interface typecodes do not fit in a 16-bit integer") } actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type") - valuePtr := c.builder.CreateExtractValue(itf, 1, "interface.value") - var value llvm.Value + + commaOk := c.builder.CreateICmp(llvm.IntEQ, llvm.ConstInt(llvm.Int16Type(), uint64(assertedTypeNum), false), actualTypeNum, "") + + // Add 2 new basic blocks (that should get optimized away): one for the + // 'ok' case and one for all instructions following this type assert. + // This is necessary because we need to insert the casted value or the + // nil value based on whether the assert was successful. Casting before + // this check tells LLVM that it can use this value and may + // speculatively dereference pointers before the check. This can lead to + // a miscompilation resulting in a segfault at runtime. + // Additionally, this is even required by the Go spec: a failed + // typeassert should return a zero value, not an incorrectly casted + // value. + + valueNil, err := getZeroValue(assertedType) + if err != nil { + return llvm.Value{}, err + } + + prevBlock := c.builder.GetInsertBlock() + okBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "typeassert.ok") + nextBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "typeassert.next") + c.builder.CreateCondBr(commaOk, okBlock, nextBlock) + + // Retrieve the value from the interface if the type assert was + // successful. + c.builder.SetInsertPointAtEnd(okBlock) + valuePtr := c.builder.CreateExtractValue(itf, 1, "typeassert.value.ptr") + var valueOk llvm.Value if c.targetData.TypeAllocSize(assertedType) > c.targetData.TypeAllocSize(c.i8ptrType) { // Value was stored in an allocated buffer, load it from there. valuePtrCast := c.builder.CreateBitCast(valuePtr, llvm.PointerType(assertedType, 0), "") - value = c.builder.CreateLoad(valuePtrCast, "") + valueOk = c.builder.CreateLoad(valuePtrCast, "typeassert.value.ok") } else { // Value was stored directly in the interface. switch assertedType.TypeKind() { case llvm.IntegerTypeKind: - value = c.builder.CreatePtrToInt(valuePtr, assertedType, "") + valueOk = c.builder.CreatePtrToInt(valuePtr, assertedType, "typeassert.value.ok") case llvm.PointerTypeKind: - value = c.builder.CreateBitCast(valuePtr, assertedType, "") + valueOk = c.builder.CreateBitCast(valuePtr, assertedType, "typeassert.value.ok") case llvm.StructTypeKind: // A bitcast would be useful here, but bitcast doesn't allow // aggregate types. So we'll bitcast it using an alloca. @@ -2157,16 +2185,20 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { mem := c.builder.CreateAlloca(c.i8ptrType, "") c.builder.CreateStore(valuePtr, mem) memStructPtr := c.builder.CreateBitCast(mem, llvm.PointerType(assertedType, 0), "") - value = c.builder.CreateLoad(memStructPtr, "") + valueOk = c.builder.CreateLoad(memStructPtr, "typeassert.value.ok") default: return llvm.Value{}, errors.New("todo: typeassert: bitcast small types") } } - // TODO: for interfaces, check whether the type implements the - // interface. - commaOk := c.builder.CreateICmp(llvm.IntEQ, llvm.ConstInt(llvm.Int16Type(), uint64(assertedTypeNum), false), actualTypeNum, "") + c.builder.CreateBr(nextBlock) + + // Continue after the if statement. + c.builder.SetInsertPointAtEnd(nextBlock) + phi := c.builder.CreatePHI(assertedType, "typeassert.value") + phi.AddIncoming([]llvm.Value{valueNil, valueOk}, []llvm.BasicBlock{prevBlock, okBlock}) + tuple := llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.Undef(llvm.Int1Type())}, false) // create empty tuple - tuple = c.builder.CreateInsertValue(tuple, value, 0, "") // insert value + tuple = c.builder.CreateInsertValue(tuple, phi, 0, "") // insert value tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "") // insert 'comma ok' boolean return tuple, nil case *ssa.UnOp: