From 9bddaae04a1593789a8f3f0bf877298e8d504be8 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Wed, 14 Nov 2018 14:41:40 +0100 Subject: [PATCH] compiler: support any int type in slice indexes Make sure the compiler will correctly compile indexes of type uint64, for example. --- compiler/compiler.go | 74 ++++++++++++++++++++++++++++++++------------ src/runtime/panic.go | 6 ++-- testdata/slice.go | 46 +++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 22 deletions(-) diff --git a/compiler/compiler.go b/compiler/compiler.go index 675da94a..b2ef1b0a 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -2104,21 +2104,31 @@ func (c *Compiler) emitBoundsCheck(frame *Frame, arrayLen, index llvm.Value, ind } } -func (c *Compiler) emitSliceBoundsCheck(frame *Frame, capacity, low, high llvm.Value) { +func (c *Compiler) emitSliceBoundsCheck(frame *Frame, capacity, low, high llvm.Value, lowType, highType *types.Basic) { if frame.fn.IsNoBounds() { // The //go:nobounds pragma was added to the function to avoid bounds // checking. return } - if low.Type().IntTypeWidth() > 32 || high.Type().IntTypeWidth() > 32 { + uintptrWidth := c.uintptrType.IntTypeWidth() + if low.Type().IntTypeWidth() > uintptrWidth || high.Type().IntTypeWidth() > uintptrWidth { if low.Type().IntTypeWidth() < 64 { - low = c.builder.CreateSExt(low, c.ctx.Int64Type(), "") + if lowType.Info()&types.IsUnsigned != 0 { + low = c.builder.CreateZExt(low, c.ctx.Int64Type(), "") + } else { + low = c.builder.CreateSExt(low, c.ctx.Int64Type(), "") + } } if high.Type().IntTypeWidth() < 64 { - high = c.builder.CreateSExt(high, c.ctx.Int64Type(), "") + if highType.Info()&types.IsUnsigned != 0 { + high = c.builder.CreateZExt(high, c.ctx.Int64Type(), "") + } else { + high = c.builder.CreateSExt(high, c.ctx.Int64Type(), "") + } } - c.createRuntimeCall("sliceBoundsCheckLong", []llvm.Value{capacity, low, high}, "") + // TODO: 32-bit or even 16-bit slice bounds checks for 8-bit platforms + c.createRuntimeCall("sliceBoundsCheck64", []llvm.Value{capacity, low, high}, "") } else { c.createRuntimeCall("sliceBoundsCheck", []llvm.Value{capacity, low, high}, "") } @@ -2494,45 +2504,71 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if err != nil { return llvm.Value{}, err } + + var lowType, highType *types.Basic var low, high llvm.Value - if expr.Low == nil { - low = llvm.ConstInt(c.intType, 0, false) - } else { + + if expr.Low != nil { + lowType = expr.Low.Type().(*types.Basic) low, err = c.parseExpr(frame, expr.Low) if err != nil { return llvm.Value{}, nil } + if low.Type().IntTypeWidth() < c.uintptrType.IntTypeWidth() { + if lowType.Info()&types.IsUnsigned != 0 { + low = c.builder.CreateZExt(low, c.uintptrType, "") + } else { + low = c.builder.CreateSExt(low, c.uintptrType, "") + } + } + } else { + lowType = types.Typ[types.Int] + low = llvm.ConstInt(c.intType, 0, false) } + if expr.High != nil { + highType = expr.High.Type().(*types.Basic) high, err = c.parseExpr(frame, expr.High) if err != nil { return llvm.Value{}, nil } + if high.Type().IntTypeWidth() < c.uintptrType.IntTypeWidth() { + if highType.Info()&types.IsUnsigned != 0 { + high = c.builder.CreateZExt(high, c.uintptrType, "") + } else { + high = c.builder.CreateSExt(high, c.uintptrType, "") + } + } + } else { + highType = types.Typ[types.Uintptr] } + switch typ := expr.X.Type().Underlying().(type) { case *types.Pointer: // pointer to array // slice an array length := typ.Elem().(*types.Array).Len() llvmLen := llvm.ConstInt(c.uintptrType, uint64(length), false) - llvmLenInt := llvm.ConstInt(c.intType, uint64(length), false) if high.IsNil() { - high = llvmLenInt + high = llvmLen } indices := []llvm.Value{ llvm.ConstInt(c.ctx.Int32Type(), 0, false), low, } - slicePtr := c.builder.CreateGEP(value, indices, "slice.ptr") - sliceLen := c.builder.CreateSub(high, low, "slice.len") - sliceCap := c.builder.CreateSub(llvmLenInt, low, "slice.cap") // This check is optimized away in most cases. - c.emitSliceBoundsCheck(frame, llvmLen, low, high) + c.emitSliceBoundsCheck(frame, llvmLen, low, high, lowType, highType) - if c.targetData.TypeAllocSize(sliceLen.Type()) > c.targetData.TypeAllocSize(c.uintptrType) { - sliceLen = c.builder.CreateTrunc(sliceLen, c.uintptrType, "") - sliceCap = c.builder.CreateTrunc(sliceCap, c.uintptrType, "") + if c.targetData.TypeAllocSize(high.Type()) > c.targetData.TypeAllocSize(c.uintptrType) { + high = c.builder.CreateTrunc(high, c.uintptrType, "") } + if c.targetData.TypeAllocSize(low.Type()) > c.targetData.TypeAllocSize(c.uintptrType) { + low = c.builder.CreateTrunc(low, c.uintptrType, "") + } + + sliceLen := c.builder.CreateSub(high, low, "slice.len") + slicePtr := c.builder.CreateGEP(value, indices, "slice.ptr") + sliceCap := c.builder.CreateSub(llvmLen, low, "slice.cap") slice := c.ctx.ConstStruct([]llvm.Value{ llvm.Undef(slicePtr.Type()), @@ -2553,7 +2589,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { high = oldLen } - c.emitSliceBoundsCheck(frame, oldCap, low, high) + c.emitSliceBoundsCheck(frame, oldCap, low, high, lowType, highType) if c.targetData.TypeAllocSize(low.Type()) > c.targetData.TypeAllocSize(c.uintptrType) { low = c.builder.CreateTrunc(low, c.uintptrType, "") @@ -2586,7 +2622,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { high = oldLen } - c.emitSliceBoundsCheck(frame, oldLen, low, high) + c.emitSliceBoundsCheck(frame, oldLen, low, high, lowType, highType) newPtr := c.builder.CreateGEP(oldPtr, []llvm.Value{low}, "") newLen := c.builder.CreateSub(high, low, "") diff --git a/src/runtime/panic.go b/src/runtime/panic.go index fbd5d3c5..a71fe141 100644 --- a/src/runtime/panic.go +++ b/src/runtime/panic.go @@ -43,14 +43,14 @@ func lookupBoundsCheckLong(length uintptr, index int64) { } // Check for bounds in *ssa.Slice. -func sliceBoundsCheck(capacity uintptr, low, high uint) { - if !(0 <= low && low <= high && high <= uint(capacity)) { +func sliceBoundsCheck(capacity, low, high uintptr) { + if !(0 <= low && low <= high && high <= capacity) { runtimePanic("slice out of range") } } // Check for bounds in *ssa.Slice. Supports 64-bit indexes. -func sliceBoundsCheckLong(capacity uintptr, low, high uint64) { +func sliceBoundsCheck64(capacity uintptr, low, high uint64) { if !(0 <= low && low <= high && high <= uint64(capacity)) { runtimePanic("slice out of range") } diff --git a/testdata/slice.go b/testdata/slice.go index 8fcbffff..6237a015 100644 --- a/testdata/slice.go +++ b/testdata/slice.go @@ -10,6 +10,46 @@ func main() { printslice("foo[1:2]", foo[1:2]) println("sum foo:", sum(foo)) + // indexing into a slice with uncommon index types + assert(foo[int(2)] == 4) + assert(foo[int8(2)] == 4) + assert(foo[int16(2)] == 4) + assert(foo[int32(2)] == 4) + assert(foo[int64(2)] == 4) + assert(foo[uint(2)] == 4) + assert(foo[uint8(2)] == 4) + assert(foo[uint16(2)] == 4) + assert(foo[uint32(2)] == 4) + assert(foo[uint64(2)] == 4) + assert(foo[uintptr(2)] == 4) + + // slicing with uncommon low, high types + assert(len(foo[int(1):int(3)]) == 2) + assert(len(foo[int8(1):int8(3)]) == 2) + assert(len(foo[int16(1):int16(3)]) == 2) + assert(len(foo[int32(1):int32(3)]) == 2) + assert(len(foo[int64(1):int64(3)]) == 2) + assert(len(foo[uint(1):uint(3)]) == 2) + assert(len(foo[uint8(1):uint8(3)]) == 2) + assert(len(foo[uint16(1):uint16(3)]) == 2) + assert(len(foo[uint32(1):uint32(3)]) == 2) + assert(len(foo[uint64(1):uint64(3)]) == 2) + assert(len(foo[uintptr(1):uintptr(3)]) == 2) + + // slicing an array with uncommon low, high types + arr := [4]int{1, 2, 4, 5} + assert(len(arr[int(1):int(3)]) == 2) + assert(len(arr[int8(1):int8(3)]) == 2) + assert(len(arr[int16(1):int16(3)]) == 2) + assert(len(arr[int32(1):int32(3)]) == 2) + assert(len(arr[int64(1):int64(3)]) == 2) + assert(len(arr[uint(1):uint(3)]) == 2) + assert(len(arr[uint8(1):uint8(3)]) == 2) + assert(len(arr[uint16(1):uint16(3)]) == 2) + assert(len(arr[uint32(1):uint32(3)]) == 2) + assert(len(arr[uint64(1):uint64(3)]) == 2) + assert(len(arr[uintptr(1):uintptr(3)]) == 2) + // copy println("copy foo -> bar:", copy(bar, foo)) printslice("bar", bar) @@ -53,3 +93,9 @@ func sum(l []int) int { } return sum } + +func assert(ok bool) { + if !ok { + panic("assert failed") + } +}