From 5cc130bb6ea92882cce391c4251159dc2effd95f Mon Sep 17 00:00:00 2001 From: Jaden Weiss Date: Sat, 28 Mar 2020 12:35:19 -0400 Subject: [PATCH] compiler: implement spec-compliant shifts Previously, the compiler used LLVM's shift instructions directly, which have UB whenever the shifts are large or negative. This commit adds runtime checks for negative shifts, and handles oversized shifts. --- compiler/asserts.go | 13 ++++++++++++ compiler/compiler.go | 47 ++++++++++++++++++++++++++++++-------------- interp/frame.go | 12 +++++++++++ src/runtime/panic.go | 5 +++++ testdata/binop.go | 29 +++++++++++++++++++++++++++ testdata/binop.txt | 8 ++++++++ 6 files changed, 99 insertions(+), 15 deletions(-) diff --git a/compiler/asserts.go b/compiler/asserts.go index f1ae259e..e8b53f7d 100644 --- a/compiler/asserts.go +++ b/compiler/asserts.go @@ -186,6 +186,19 @@ func (b *builder) createNilCheck(inst ssa.Value, ptr llvm.Value, blockPrefix str b.createRuntimeAssert(isnil, blockPrefix, "nilPanic") } +// createNegativeShiftCheck creates an assertion that panics if the given shift value is negative. +// This function assumes that the shift value is signed. +func (b *builder) createNegativeShiftCheck(shift llvm.Value) { + if b.fn.IsNoBounds() { + // Function disabled bounds checking - skip shift check. + return + } + + // isNegative = shift < 0 + isNegative := b.CreateICmp(llvm.IntSLT, shift, llvm.ConstInt(shift.Type(), 0, false), "") + b.createRuntimeAssert(isNegative, "shift", "negativeShiftPanic") +} + // createRuntimeAssert is a common function to create a new branch on an assert // bool, calling an assert func if the assert value is true (1). func (b *builder) createRuntimeAssert(assert llvm.Value, blockPrefix, assertFunc string) { diff --git a/compiler/compiler.go b/compiler/compiler.go index d7609d13..14b45a43 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1476,7 +1476,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { case *ssa.BinOp: x := b.getValue(expr.X) y := b.getValue(expr.Y) - return b.createBinOp(expr.Op, expr.X.Type(), x, y, expr.Pos()) + return b.createBinOp(expr.Op, expr.X.Type(), expr.Y.Type(), x, y, expr.Pos()) case *ssa.Call: return b.createFunctionCall(expr.Common()) case *ssa.ChangeInterface: @@ -1925,7 +1925,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { // same type, even for bitshifts. Also, signedness in Go is encoded in the type // and is encoded in the operation in LLVM IR: this is important for some // operations such as divide. -func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, pos token.Pos) (llvm.Value, error) { +func (b *builder) createBinOp(op token.Token, typ, ytyp types.Type, x, y llvm.Value, pos token.Pos) (llvm.Value, error) { switch typ := typ.Underlying().(type) { case *types.Basic: if typ.Info()&types.IsInteger != 0 { @@ -1957,32 +1957,49 @@ func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, p case token.XOR: // ^ return b.CreateXor(x, y, ""), nil case token.SHL, token.SHR: + if ytyp.Underlying().(*types.Basic).Info()&types.IsUnsigned == 0 { + // Ensure that y is not negative. + b.createNegativeShiftCheck(y) + } + sizeX := b.targetData.TypeAllocSize(x.Type()) sizeY := b.targetData.TypeAllocSize(y.Type()) - if sizeX > sizeY { - // x and y must have equal sizes, make Y bigger in this case. - // y is unsigned, this has been checked by the Go type checker. + + // Check if the shift is bigger than the bit-width of the shifted value. + // This is UB in LLVM, so it needs to be handled seperately. + // The Go spec indirectly defines the result as 0. + // Negative shifts are handled earlier, so we can treat y as unsigned. + overshifted := b.CreateICmp(llvm.IntUGE, y, llvm.ConstInt(y.Type(), 8*sizeX, false), "shift.overflow") + + // Adjust the size of y to match x. + switch { + case sizeX > sizeY: y = b.CreateZExt(y, x.Type(), "") - } else if sizeX < sizeY { - // What about shifting more than the integer width? - // I'm not entirely sure what the Go spec is on that, but as - // Intel CPUs have undefined behavior when shifting more - // than the integer width I'm assuming it is also undefined - // in Go. + case sizeX < sizeY: + // If it gets truncated, overshifted will be true and it will not matter. y = b.CreateTrunc(y, x.Type(), "") } + + // Create a shift operation. + var val llvm.Value switch op { case token.SHL: // << - return b.CreateShl(x, y, ""), nil + val = b.CreateShl(x, y, "") case token.SHR: // >> if signed { + // Arithmetic right shifts work differently, since shifting a negative number right yields -1. + // Cap the shift input rather than selecting the output. + y = b.CreateSelect(overshifted, llvm.ConstInt(y.Type(), 8*sizeX-1, false), y, "shift.offset") return b.CreateAShr(x, y, ""), nil } else { - return b.CreateLShr(x, y, ""), nil + val = b.CreateLShr(x, y, "") } default: panic("unreachable") } + + // Select between the shift result and zero depending on whether there was an overshift. + return b.CreateSelect(overshifted, llvm.ConstInt(val.Type(), 0, false), val, "shift.result"), nil case token.EQL: // == return b.CreateICmp(llvm.IntEQ, x, y, ""), nil case token.NEQ: // != @@ -2218,7 +2235,7 @@ func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, p for i := 0; i < int(typ.Len()); i++ { xField := b.CreateExtractValue(x, i, "") yField := b.CreateExtractValue(y, i, "") - fieldEqual, err := b.createBinOp(token.EQL, typ.Elem(), xField, yField, pos) + fieldEqual, err := b.createBinOp(token.EQL, typ.Elem(), typ.Elem(), xField, yField, pos) if err != nil { return llvm.Value{}, err } @@ -2246,7 +2263,7 @@ func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, p fieldType := typ.Field(i).Type() xField := b.CreateExtractValue(x, i, "") yField := b.CreateExtractValue(y, i, "") - fieldEqual, err := b.createBinOp(token.EQL, fieldType, xField, yField, pos) + fieldEqual, err := b.createBinOp(token.EQL, fieldType, fieldType, xField, yField, pos) if err != nil { return llvm.Value{}, err } diff --git a/interp/frame.go b/interp/frame.go index 97b438ce..299a1778 100644 --- a/interp/frame.go +++ b/interp/frame.go @@ -603,6 +603,18 @@ func (fr *frame) evalBasicBlock(bb, incoming llvm.BasicBlock, indent string) (re } fr.locals[inst] = &LocalValue{fr.Eval, fr.builder.CreateInsertValue(agg.Underlying, val.Value(), int(indices[0]), inst.Name())} } + case !inst.IsASelectInst().IsNil(): + // var result T + // if cond { + // result = x + // } else { + // result = y + // } + // return result + cond := fr.getLocal(inst.Operand(0)).(*LocalValue).Underlying + x := fr.getLocal(inst.Operand(1)).(*LocalValue).Underlying + y := fr.getLocal(inst.Operand(2)).(*LocalValue).Underlying + fr.locals[inst] = &LocalValue{fr.Eval, fr.builder.CreateSelect(cond, x, y, "")} case !inst.IsAReturnInst().IsNil() && inst.OperandsCount() == 0: return nil, nil, nil // ret void diff --git a/src/runtime/panic.go b/src/runtime/panic.go index 2efeed20..a0036493 100644 --- a/src/runtime/panic.go +++ b/src/runtime/panic.go @@ -47,6 +47,11 @@ func chanMakePanic() { runtimePanic("new channel is too big") } +// Panic when a shift value is negative. +func negativeShiftPanic() { + runtimePanic("negative shift") +} + func blockingPanic() { runtimePanic("trying to do blocking operation in exported function") } diff --git a/testdata/binop.go b/testdata/binop.go index f58a94c1..537fb5cd 100644 --- a/testdata/binop.go +++ b/testdata/binop.go @@ -61,6 +61,15 @@ func main() { println(c128 != 3+2i) println(c128 != 4+2i) println(c128 != 3+3i) + + println("shifts") + println(shlSimple == 4) + println(shlOverflow == 0) + println(shrSimple == 1) + println(shrOverflow == 0) + println(ashrNeg == -1) + println(ashrOverflow == 0) + println(ashrNegOverflow == -1) } var x = true @@ -87,3 +96,23 @@ type Struct2 struct { _ float64 i int } + +func shl(x uint, y uint) uint { + return x << y +} + +func shr(x uint, y uint) uint { + return x >> y +} + +func ashr(x int, y uint) int { + return x >> y +} + +var shlSimple = shl(2, 1) +var shlOverflow = shl(2, 1000) +var shrSimple = shr(2, 1) +var shrOverflow = shr(2, 1000000) +var ashrNeg = ashr(-1, 1) +var ashrOverflow = ashr(1, 1000000) +var ashrNegOverflow = ashr(-1, 1000000) diff --git a/testdata/binop.txt b/testdata/binop.txt index 8d1f37be..a8370b19 100644 --- a/testdata/binop.txt +++ b/testdata/binop.txt @@ -54,3 +54,11 @@ false true true true +shifts +true +true +true +true +true +true +true