diff --git a/compiler/asserts.go b/compiler/asserts.go index f1ae259e..e8b53f7d 100644 --- a/compiler/asserts.go +++ b/compiler/asserts.go @@ -186,6 +186,19 @@ func (b *builder) createNilCheck(inst ssa.Value, ptr llvm.Value, blockPrefix str b.createRuntimeAssert(isnil, blockPrefix, "nilPanic") } +// createNegativeShiftCheck creates an assertion that panics if the given shift value is negative. +// This function assumes that the shift value is signed. +func (b *builder) createNegativeShiftCheck(shift llvm.Value) { + if b.fn.IsNoBounds() { + // Function disabled bounds checking - skip shift check. + return + } + + // isNegative = shift < 0 + isNegative := b.CreateICmp(llvm.IntSLT, shift, llvm.ConstInt(shift.Type(), 0, false), "") + b.createRuntimeAssert(isNegative, "shift", "negativeShiftPanic") +} + // createRuntimeAssert is a common function to create a new branch on an assert // bool, calling an assert func if the assert value is true (1). func (b *builder) createRuntimeAssert(assert llvm.Value, blockPrefix, assertFunc string) { diff --git a/compiler/compiler.go b/compiler/compiler.go index d7609d13..14b45a43 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1476,7 +1476,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { case *ssa.BinOp: x := b.getValue(expr.X) y := b.getValue(expr.Y) - return b.createBinOp(expr.Op, expr.X.Type(), x, y, expr.Pos()) + return b.createBinOp(expr.Op, expr.X.Type(), expr.Y.Type(), x, y, expr.Pos()) case *ssa.Call: return b.createFunctionCall(expr.Common()) case *ssa.ChangeInterface: @@ -1925,7 +1925,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { // same type, even for bitshifts. Also, signedness in Go is encoded in the type // and is encoded in the operation in LLVM IR: this is important for some // operations such as divide. -func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, pos token.Pos) (llvm.Value, error) { +func (b *builder) createBinOp(op token.Token, typ, ytyp types.Type, x, y llvm.Value, pos token.Pos) (llvm.Value, error) { switch typ := typ.Underlying().(type) { case *types.Basic: if typ.Info()&types.IsInteger != 0 { @@ -1957,32 +1957,49 @@ func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, p case token.XOR: // ^ return b.CreateXor(x, y, ""), nil case token.SHL, token.SHR: + if ytyp.Underlying().(*types.Basic).Info()&types.IsUnsigned == 0 { + // Ensure that y is not negative. + b.createNegativeShiftCheck(y) + } + sizeX := b.targetData.TypeAllocSize(x.Type()) sizeY := b.targetData.TypeAllocSize(y.Type()) - if sizeX > sizeY { - // x and y must have equal sizes, make Y bigger in this case. - // y is unsigned, this has been checked by the Go type checker. + + // Check if the shift is bigger than the bit-width of the shifted value. + // This is UB in LLVM, so it needs to be handled seperately. + // The Go spec indirectly defines the result as 0. + // Negative shifts are handled earlier, so we can treat y as unsigned. + overshifted := b.CreateICmp(llvm.IntUGE, y, llvm.ConstInt(y.Type(), 8*sizeX, false), "shift.overflow") + + // Adjust the size of y to match x. + switch { + case sizeX > sizeY: y = b.CreateZExt(y, x.Type(), "") - } else if sizeX < sizeY { - // What about shifting more than the integer width? - // I'm not entirely sure what the Go spec is on that, but as - // Intel CPUs have undefined behavior when shifting more - // than the integer width I'm assuming it is also undefined - // in Go. + case sizeX < sizeY: + // If it gets truncated, overshifted will be true and it will not matter. y = b.CreateTrunc(y, x.Type(), "") } + + // Create a shift operation. + var val llvm.Value switch op { case token.SHL: // << - return b.CreateShl(x, y, ""), nil + val = b.CreateShl(x, y, "") case token.SHR: // >> if signed { + // Arithmetic right shifts work differently, since shifting a negative number right yields -1. + // Cap the shift input rather than selecting the output. + y = b.CreateSelect(overshifted, llvm.ConstInt(y.Type(), 8*sizeX-1, false), y, "shift.offset") return b.CreateAShr(x, y, ""), nil } else { - return b.CreateLShr(x, y, ""), nil + val = b.CreateLShr(x, y, "") } default: panic("unreachable") } + + // Select between the shift result and zero depending on whether there was an overshift. + return b.CreateSelect(overshifted, llvm.ConstInt(val.Type(), 0, false), val, "shift.result"), nil case token.EQL: // == return b.CreateICmp(llvm.IntEQ, x, y, ""), nil case token.NEQ: // != @@ -2218,7 +2235,7 @@ func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, p for i := 0; i < int(typ.Len()); i++ { xField := b.CreateExtractValue(x, i, "") yField := b.CreateExtractValue(y, i, "") - fieldEqual, err := b.createBinOp(token.EQL, typ.Elem(), xField, yField, pos) + fieldEqual, err := b.createBinOp(token.EQL, typ.Elem(), typ.Elem(), xField, yField, pos) if err != nil { return llvm.Value{}, err } @@ -2246,7 +2263,7 @@ func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, p fieldType := typ.Field(i).Type() xField := b.CreateExtractValue(x, i, "") yField := b.CreateExtractValue(y, i, "") - fieldEqual, err := b.createBinOp(token.EQL, fieldType, xField, yField, pos) + fieldEqual, err := b.createBinOp(token.EQL, fieldType, fieldType, xField, yField, pos) if err != nil { return llvm.Value{}, err } diff --git a/interp/frame.go b/interp/frame.go index 97b438ce..299a1778 100644 --- a/interp/frame.go +++ b/interp/frame.go @@ -603,6 +603,18 @@ func (fr *frame) evalBasicBlock(bb, incoming llvm.BasicBlock, indent string) (re } fr.locals[inst] = &LocalValue{fr.Eval, fr.builder.CreateInsertValue(agg.Underlying, val.Value(), int(indices[0]), inst.Name())} } + case !inst.IsASelectInst().IsNil(): + // var result T + // if cond { + // result = x + // } else { + // result = y + // } + // return result + cond := fr.getLocal(inst.Operand(0)).(*LocalValue).Underlying + x := fr.getLocal(inst.Operand(1)).(*LocalValue).Underlying + y := fr.getLocal(inst.Operand(2)).(*LocalValue).Underlying + fr.locals[inst] = &LocalValue{fr.Eval, fr.builder.CreateSelect(cond, x, y, "")} case !inst.IsAReturnInst().IsNil() && inst.OperandsCount() == 0: return nil, nil, nil // ret void diff --git a/src/runtime/panic.go b/src/runtime/panic.go index 2efeed20..a0036493 100644 --- a/src/runtime/panic.go +++ b/src/runtime/panic.go @@ -47,6 +47,11 @@ func chanMakePanic() { runtimePanic("new channel is too big") } +// Panic when a shift value is negative. +func negativeShiftPanic() { + runtimePanic("negative shift") +} + func blockingPanic() { runtimePanic("trying to do blocking operation in exported function") } diff --git a/testdata/binop.go b/testdata/binop.go index f58a94c1..537fb5cd 100644 --- a/testdata/binop.go +++ b/testdata/binop.go @@ -61,6 +61,15 @@ func main() { println(c128 != 3+2i) println(c128 != 4+2i) println(c128 != 3+3i) + + println("shifts") + println(shlSimple == 4) + println(shlOverflow == 0) + println(shrSimple == 1) + println(shrOverflow == 0) + println(ashrNeg == -1) + println(ashrOverflow == 0) + println(ashrNegOverflow == -1) } var x = true @@ -87,3 +96,23 @@ type Struct2 struct { _ float64 i int } + +func shl(x uint, y uint) uint { + return x << y +} + +func shr(x uint, y uint) uint { + return x >> y +} + +func ashr(x int, y uint) int { + return x >> y +} + +var shlSimple = shl(2, 1) +var shlOverflow = shl(2, 1000) +var shrSimple = shr(2, 1) +var shrOverflow = shr(2, 1000000) +var ashrNeg = ashr(-1, 1) +var ashrOverflow = ashr(1, 1000000) +var ashrNegOverflow = ashr(-1, 1000000) diff --git a/testdata/binop.txt b/testdata/binop.txt index 8d1f37be..a8370b19 100644 --- a/testdata/binop.txt +++ b/testdata/binop.txt @@ -54,3 +54,11 @@ false true true true +shifts +true +true +true +true +true +true +true