diff --git a/compiler/asserts.go b/compiler/asserts.go index 92083bc6..acd605fa 100644 --- a/compiler/asserts.go +++ b/compiler/asserts.go @@ -15,22 +15,16 @@ import ( // createLookupBoundsCheck emits a bounds check before doing a lookup into a // slice. This is required by the Go language spec: an index out of bounds must // cause a panic. -func (b *builder) createLookupBoundsCheck(arrayLen, index llvm.Value, indexType types.Type) { +// The caller should make sure that index is at least as big as arrayLen. +func (b *builder) createLookupBoundsCheck(arrayLen, index llvm.Value) { if b.info.nobounds { // The //go:nobounds pragma was added to the function to avoid bounds // checking. return } - if index.Type().IntTypeWidth() < arrayLen.Type().IntTypeWidth() { - // Sometimes, the index can be e.g. an uint8 or int8, and we have to - // correctly extend that type. - if indexType.Underlying().(*types.Basic).Info()&types.IsUnsigned == 0 { - index = b.CreateZExt(index, arrayLen.Type(), "") - } else { - index = b.CreateSExt(index, arrayLen.Type(), "") - } - } else if index.Type().IntTypeWidth() > arrayLen.Type().IntTypeWidth() { + // Extend arrayLen if it's too small. + if index.Type().IntTypeWidth() > arrayLen.Type().IntTypeWidth() { // The index is bigger than the array length type, so extend it. arrayLen = b.CreateZExt(arrayLen, index.Type(), "") } @@ -70,27 +64,9 @@ func (b *builder) createSliceBoundsCheck(capacity, low, high, max llvm.Value, lo } // Extend low and high to be the same size as capacity. - if low.Type().IntTypeWidth() < capacityType.IntTypeWidth() { - if lowType.Info()&types.IsUnsigned != 0 { - low = b.CreateZExt(low, capacityType, "") - } else { - low = b.CreateSExt(low, capacityType, "") - } - } - if high.Type().IntTypeWidth() < capacityType.IntTypeWidth() { - if highType.Info()&types.IsUnsigned != 0 { - high = b.CreateZExt(high, capacityType, "") - } else { - high = b.CreateSExt(high, capacityType, "") - } - } - if max.Type().IntTypeWidth() < capacityType.IntTypeWidth() { - if maxType.Info()&types.IsUnsigned != 0 { - max = b.CreateZExt(max, capacityType, "") - } else { - max = b.CreateSExt(max, capacityType, "") - } - } + low = b.extendInteger(low, lowType, capacityType) + high = b.extendInteger(high, highType, capacityType) + max = b.extendInteger(max, maxType, capacityType) // Now do the bounds check: low > high || high > capacity outOfBounds1 := b.CreateICmp(llvm.IntUGT, low, high, "slice.lowhigh") @@ -125,13 +101,7 @@ func (b *builder) createUnsafeSliceCheck(ptr, len llvm.Value, lenType *types.Bas // using an unsiged greater than. // Make sure the len value is at least as big as a uintptr. - if len.Type().IntTypeWidth() < b.uintptrType.IntTypeWidth() { - if lenType.Info()&types.IsUnsigned != 0 { - len = b.CreateZExt(len, b.uintptrType, "") - } else { - len = b.CreateSExt(len, b.uintptrType, "") - } - } + len = b.extendInteger(len, lenType, b.uintptrType) // Determine the maximum slice size, and therefore the maximum value of the // len parameter. @@ -159,17 +129,8 @@ func (b *builder) createChanBoundsCheck(elementSize uint64, bufSize llvm.Value, return } - // Check whether the bufSize parameter must be cast to a wider integer for - // comparison. - if bufSize.Type().IntTypeWidth() < b.uintptrType.IntTypeWidth() { - if bufSizeType.Info()&types.IsUnsigned != 0 { - // Unsigned, so zero-extend to uint type. - bufSize = b.CreateZExt(bufSize, b.intType, "") - } else { - // Signed, so sign-extend to int type. - bufSize = b.CreateSExt(bufSize, b.intType, "") - } - } + // Make sure bufSize is at least as big as maxBufSize (an uintptr). + bufSize = b.extendInteger(bufSize, bufSizeType, b.uintptrType) // Calculate (^uintptr(0)) >> 1, which is the max value that fits in an // uintptr if uintptrs were signed. @@ -294,3 +255,19 @@ func (b *builder) createRuntimeAssert(assert llvm.Value, blockPrefix, assertFunc // Ok: assert didn't trigger so continue normally. b.SetInsertPointAtEnd(nextBlock) } + +// extendInteger extends the value to at least targetType using a zero or sign +// extend. The resulting value is not truncated: it may still be bigger than +// targetType. +func (b *builder) extendInteger(value llvm.Value, valueType types.Type, targetType llvm.Type) llvm.Value { + if value.Type().IntTypeWidth() < targetType.IntTypeWidth() { + if valueType.Underlying().(*types.Basic).Info()&types.IsUnsigned != 0 { + // Unsigned, so zero-extend to the target type. + value = b.CreateZExt(value, targetType, "") + } else { + // Signed, so sign-extend to the target type. + value = b.CreateSExt(value, targetType, "") + } + } + return value +} diff --git a/compiler/compiler.go b/compiler/compiler.go index c279800c..ae1ba07b 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1622,10 +1622,14 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { array := b.getValue(expr.X) index := b.getValue(expr.Index) + // Extend index to at least uintptr size, because getelementptr assumes + // index is a signed integer. + index = b.extendInteger(index, expr.Index.Type(), b.uintptrType) + // Check bounds. arrayLen := expr.X.Type().Underlying().(*types.Array).Len() arrayLenLLVM := llvm.ConstInt(b.uintptrType, uint64(arrayLen), false) - b.createLookupBoundsCheck(arrayLenLLVM, index, expr.Index.Type()) + b.createLookupBoundsCheck(arrayLenLLVM, index) // Can't load directly from array (as index is non-constant), so have to // do it using an alloca+gep+load. @@ -1666,8 +1670,12 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { return llvm.Value{}, b.makeError(expr.Pos(), "todo: indexaddr: "+ptrTyp.String()) } + // Make sure index is at least the size of uintptr becuase getelementptr + // assumes index is a signed integer. + index = b.extendInteger(index, expr.Index.Type(), b.uintptrType) + // Bounds check. - b.createLookupBoundsCheck(buflen, index, expr.Index.Type()) + b.createLookupBoundsCheck(buflen, index) switch expr.X.Type().Underlying().(type) { case *types.Pointer: @@ -1691,9 +1699,17 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { panic("lookup on non-string?") } + // Sometimes, the index can be e.g. an uint8 or int8, and we have to + // correctly extend that type for two reasons: + // 1. The lookup bounds check expects an index of at least uintptr + // size. + // 2. getelementptr has signed operands, and therefore s[uint8(x)] + // can be lowered as s[int8(x)]. That would be a bug. + index = b.extendInteger(index, expr.Index.Type(), b.uintptrType) + // Bounds check. length := b.CreateExtractValue(value, 1, "len") - b.createLookupBoundsCheck(length, index, expr.Index.Type()) + b.createLookupBoundsCheck(length, index) // Lookup byte buf := b.CreateExtractValue(value, 0, "") @@ -1819,13 +1835,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { if expr.Low != nil { lowType = expr.Low.Type().Underlying().(*types.Basic) low = b.getValue(expr.Low) - if low.Type().IntTypeWidth() < b.uintptrType.IntTypeWidth() { - if lowType.Info()&types.IsUnsigned != 0 { - low = b.CreateZExt(low, b.uintptrType, "") - } else { - low = b.CreateSExt(low, b.uintptrType, "") - } - } + low = b.extendInteger(low, lowType, b.uintptrType) } else { lowType = types.Typ[types.Uintptr] low = llvm.ConstInt(b.uintptrType, 0, false) @@ -1834,13 +1844,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { if expr.High != nil { highType = expr.High.Type().Underlying().(*types.Basic) high = b.getValue(expr.High) - if high.Type().IntTypeWidth() < b.uintptrType.IntTypeWidth() { - if highType.Info()&types.IsUnsigned != 0 { - high = b.CreateZExt(high, b.uintptrType, "") - } else { - high = b.CreateSExt(high, b.uintptrType, "") - } - } + high = b.extendInteger(high, highType, b.uintptrType) } else { highType = types.Typ[types.Uintptr] } @@ -1848,13 +1852,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { if expr.Max != nil { maxType = expr.Max.Type().Underlying().(*types.Basic) max = b.getValue(expr.Max) - if max.Type().IntTypeWidth() < b.uintptrType.IntTypeWidth() { - if maxType.Info()&types.IsUnsigned != 0 { - max = b.CreateZExt(max, b.uintptrType, "") - } else { - max = b.CreateSExt(max, b.uintptrType, "") - } - } + max = b.extendInteger(max, maxType, b.uintptrType) } else { maxType = types.Typ[types.Uintptr] } diff --git a/compiler/testdata/string.go b/compiler/testdata/string.go index 2c37df45..56ad1170 100644 --- a/compiler/testdata/string.go +++ b/compiler/testdata/string.go @@ -27,3 +27,8 @@ func stringCompareUnequal(s1, s2 string) bool { func stringCompareLarger(s1, s2 string) bool { return s1 > s2 } + +func stringLookup(s string, x uint8) byte { + // Test that x is correctly extended to an uint before comparison. + return s[x] +} diff --git a/compiler/testdata/string.ll b/compiler/testdata/string.ll index 75812865..5bfe0255 100644 --- a/compiler/testdata/string.ll +++ b/compiler/testdata/string.ll @@ -78,4 +78,21 @@ entry: declare i1 @runtime.stringLess(i8*, i32, i8*, i32, i8*, i8*) +; Function Attrs: nounwind +define hidden i8 @main.stringLookup(i8* %s.data, i32 %s.len, i8 %x, i8* %context, i8* %parentHandle) unnamed_addr #0 { +entry: + %0 = zext i8 %x to i32 + %.not = icmp ult i32 %0, %s.len + br i1 %.not, label %lookup.next, label %lookup.throw + +lookup.throw: ; preds = %entry + call void @runtime.lookupPanic(i8* undef, i8* null) #0 + unreachable + +lookup.next: ; preds = %entry + %1 = getelementptr inbounds i8, i8* %s.data, i32 %0 + %2 = load i8, i8* %1, align 1 + ret i8 %2 +} + attributes #0 = { nounwind }