From 88b6b2e7f5a3f513d75efdebf0df872b10b37333 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Sun, 2 Sep 2018 16:24:50 +0200 Subject: [PATCH] Optimize/eliminate bounds checking TODO: do better at it by tracking min/max values of integers. The following straightforward code doesn't have its bounds checks removed: for _, n := range slice { println(n) } --- compiler.go | 46 ++++++++++++++++++++++++++-------------- ir.go | 13 ++++++++++-- src/runtime/hashmap.go | 2 ++ src/runtime/interface.go | 1 + src/runtime/print.go | 2 ++ src/runtime/runtime.go | 2 +- src/runtime/string.go | 1 + 7 files changed, 48 insertions(+), 19 deletions(-) diff --git a/compiler.go b/compiler.go index 31e5de10..dd73dfe4 100644 --- a/compiler.go +++ b/compiler.go @@ -1610,6 +1610,26 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l } } +func (c *Compiler) emitBoundsCheck(frame *Frame, arrayLen, index llvm.Value) { + if frame.fn.nobounds { + // The //go:nobounds pragma was added to the function to avoid bounds + // checking. + return + } + // Optimize away trivial cases. + // LLVM would do this anyway with interprocedural optimizations, but it + // helps to see cases where bounds checking would really help. + if index.IsConstant() && arrayLen.IsConstant() { + index := index.SExtValue() + arrayLen := arrayLen.SExtValue() + if index >= 0 && index < arrayLen { + return + } + } + lookupBoundsCheck := c.mod.NamedFunction("runtime.lookupBoundsCheck") + c.builder.CreateCall(lookupBoundsCheck, []llvm.Value{arrayLen, index}, "") +} + func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if value, ok := frame.locals[expr]; ok { // Value is a local variable that has already been computed. @@ -1730,8 +1750,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { // Check bounds. arrayLen := expr.X.Type().(*types.Array).Len() arrayLenLLVM := llvm.ConstInt(llvm.Int32Type(), uint64(arrayLen), false) - lookupBoundsCheck := c.mod.NamedFunction("runtime.lookupBoundsCheck") - c.builder.CreateCall(lookupBoundsCheck, []llvm.Value{arrayLenLLVM, index}, "") + c.emitBoundsCheck(frame, arrayLenLLVM, index) // Can't load directly from array (as index is non-constant), so have to // do it using an alloca+gep+load. @@ -1771,12 +1790,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { // Bounds check. // LLVM optimizes this away in most cases. - // TODO: runtime.lookupBoundsCheck is undefined in packages imported by - // package runtime, so we have to remove it. This should be fixed. - lookupBoundsCheck := c.mod.NamedFunction("runtime.lookupBoundsCheck") - if !lookupBoundsCheck.IsNil() && frame.fn.llvmFn.Name() != "runtime.interfaceMethod" { - c.builder.CreateCall(lookupBoundsCheck, []llvm.Value{buflen, index}, "") - } + c.emitBoundsCheck(frame, buflen, index) switch expr.X.Type().(type) { case *types.Pointer: @@ -1811,13 +1825,11 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { // Bounds check. // LLVM optimizes this away in most cases. - if frame.fn.llvmFn.Name() != "runtime.lookupBoundsCheck" { - length, err := c.parseBuiltin(frame, []ssa.Value{expr.X}, "len") - if err != nil { - return llvm.Value{}, err // shouldn't happen - } - c.builder.CreateCall(c.mod.NamedFunction("runtime.lookupBoundsCheck"), []llvm.Value{length, index}, "") + length, err := c.parseBuiltin(frame, []ssa.Value{expr.X}, "len") + if err != nil { + return llvm.Value{}, err // shouldn't happen } + c.emitBoundsCheck(frame, length, index) // Lookup byte buf := c.builder.CreateExtractValue(value, 1, "") @@ -1932,8 +1944,10 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { } // This check is optimized away in most cases. - sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheck") - c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{llvmLen, low, high}, "") + if !frame.fn.nobounds { + sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheck") + c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{llvmLen, low, high}, "") + } slice := llvm.ConstNamedStruct(sliceTyp, []llvm.Value{ llvm.Undef(slicePtr.Type()), diff --git a/ir.go b/ir.go index 69737be0..7f5212bf 100644 --- a/ir.go +++ b/ir.go @@ -35,8 +35,9 @@ type Program struct { type Function struct { fn *ssa.Function llvmFn llvm.Value - linkName string - blocking bool + linkName string // go:linkname pragma + nobounds bool // go:nobounds pragma + blocking bool // calculated by AnalyseBlockingRecursive flag bool // used by dead code elimination addressTaken bool // used as function pointer, calculated by AnalyseFunctionPointers parents []*Function // calculated by AnalyseCallgraph @@ -169,6 +170,14 @@ func (f *Function) parsePragmas() { if hasUnsafeImport(f.fn.Pkg.Pkg) { f.linkName = parts[2] } + case "//go:nobounds": + // Skip bounds checking in this function. Useful for some + // runtime functions. + // This is somewhat dangerous and thus only imported in packages + // that import unsafe. + if hasUnsafeImport(f.fn.Pkg.Pkg) { + f.nobounds = true + } } } } diff --git a/src/runtime/hashmap.go b/src/runtime/hashmap.go index 3d6869ad..30cc781e 100644 --- a/src/runtime/hashmap.go +++ b/src/runtime/hashmap.go @@ -66,6 +66,7 @@ func hashmapMake(keySize, valueSize uint8) *hashmap { } // Set a specified key to a given value. Grow the map if necessary. +//go:nobounds func hashmapSet(m *hashmap, key unsafe.Pointer, value unsafe.Pointer, hash uint32, keyEqual func(x, y unsafe.Pointer, n uintptr) bool) { numBuckets := uintptr(1) << m.bucketBits bucketNumber := (uintptr(hash) & (numBuckets - 1)) @@ -114,6 +115,7 @@ func hashmapSet(m *hashmap, key unsafe.Pointer, value unsafe.Pointer, hash uint3 } // Get the value of a specified key, or zero the value if not found. +//go:nobounds func hashmapGet(m *hashmap, key unsafe.Pointer, value unsafe.Pointer, hash uint32, keyEqual func(x, y unsafe.Pointer, n uintptr) bool) { numBuckets := uintptr(1) << m.bucketBits bucketNumber := (uintptr(hash) & (numBuckets - 1)) diff --git a/src/runtime/interface.go b/src/runtime/interface.go index b4d5c231..98d8f954 100644 --- a/src/runtime/interface.go +++ b/src/runtime/interface.go @@ -44,6 +44,7 @@ var ( // Get the function pointer for the method on the interface. // This is a compiler intrinsic. +//go:nobounds func interfaceMethod(itf _interface, method uint16) *uint8 { // This function doesn't do bounds checking as the supplied method must be // in the list of signatures. The compiler will only emit diff --git a/src/runtime/print.go b/src/runtime/print.go index 54d98afd..2a42f99c 100644 --- a/src/runtime/print.go +++ b/src/runtime/print.go @@ -4,6 +4,7 @@ import ( "unsafe" ) +//go:nobounds func printstring(s string) { for i := 0; i < len(s); i++ { putchar(s[i]) @@ -42,6 +43,7 @@ func printint16(n uint16) { printint32(int32(n)) } +//go:nobounds func printuint32(n uint32) { digits := [10]byte{} // enough to hold (2^32)-1 // Fill in all 10 digits. diff --git a/src/runtime/runtime.go b/src/runtime/runtime.go index 07cf049c..94b1f270 100644 --- a/src/runtime/runtime.go +++ b/src/runtime/runtime.go @@ -56,7 +56,7 @@ func _panic(message interface{}) { abort() } -// Check for bounds in *ssa.IndexAddr and *ssa.Lookup. +// Check for bounds in *ssa.Index, *ssa.IndexAddr and *ssa.Lookup. func lookupBoundsCheck(length, index int) { if index < 0 || index >= length { // printstring() here is safe as this function is excluded from bounds diff --git a/src/runtime/string.go b/src/runtime/string.go index 6b56aeb1..64a60b3a 100644 --- a/src/runtime/string.go +++ b/src/runtime/string.go @@ -13,6 +13,7 @@ type _string struct { } // Return true iff the strings match. +//go:nobounds func stringEqual(x, y string) bool { if len(x) != len(y) { return false