From 464ebc4fe1122085a60f0a376189281b2e391eba Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Wed, 29 Mar 2023 18:33:24 +0200 Subject: [PATCH] compiler: implement most math/bits functions These functions can be implemented more efficiently using LLVM intrinsics. That makes them the Go equivalent of functions like __builtin_clz which are also implemented using these LLVM intrinsics. I believe the Go compiler does something very similar: IIRC it converts calls to these functions into optimal instructions for the given architecture. I tested these by running `tinygo test math/bits` after uncommenting the tests that would always fail (the *PanicZero and *PanicOverflow tests). --- compiler/compiler.go | 5 ++ compiler/intrinsics.go | 115 +++++++++++++++++++++++++++++++++++++++++ compiler/llvm.go | 13 +++++ 3 files changed, 133 insertions(+) diff --git a/compiler/compiler.go b/compiler/compiler.go index 53576fc3..e90d9dca 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -845,6 +845,11 @@ func (c *compilerContext) createPackage(irbuilder llvm.Builder, pkg *ssa.Package b.defineMathOp() continue } + if ok := b.defineMathBitsIntrinsic(); ok { + // Like a math intrinsic, the body of this function was replaced + // with a LLVM intrinsic. + continue + } if member.Blocks == nil { // Try to define this as an intrinsic function. b.defineIntrinsicFunction() diff --git a/compiler/intrinsics.go b/compiler/intrinsics.go index c196b60d..5761a438 100644 --- a/compiler/intrinsics.go +++ b/compiler/intrinsics.go @@ -154,3 +154,118 @@ func (b *builder) defineMathOp() { result := b.CreateCall(llvmFn.GlobalValueType(), llvmFn, args, "") b.CreateRet(result) } + +// Implement most math/bits functions. +// +// This implements all the functions that operate on bits. It does not yet +// implement the arithmetic functions (like bits.Add), which also have LLVM +// intrinsics. +func (b *builder) defineMathBitsIntrinsic() bool { + if b.fn.Pkg.Pkg.Path() != "math/bits" { + return false + } + name := b.fn.Name() + switch name { + case "LeadingZeros", "LeadingZeros8", "LeadingZeros16", "LeadingZeros32", "LeadingZeros64", + "TrailingZeros", "TrailingZeros8", "TrailingZeros16", "TrailingZeros32", "TrailingZeros64": + b.createFunctionStart(true) + param := b.getValue(b.fn.Params[0], b.fn.Pos()) + valueType := param.Type() + var intrinsicName string + if strings.HasPrefix(name, "Leading") { // LeadingZeros + intrinsicName = "llvm.ctlz.i" + strconv.Itoa(valueType.IntTypeWidth()) + } else { // TrailingZeros + intrinsicName = "llvm.cttz.i" + strconv.Itoa(valueType.IntTypeWidth()) + } + llvmFn := b.mod.NamedFunction(intrinsicName) + llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, b.ctx.Int1Type()}, false) + if llvmFn.IsNil() { + llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType) + } + result := b.createCall(llvmFnType, llvmFn, []llvm.Value{ + param, + llvm.ConstInt(b.ctx.Int1Type(), 0, false), + }, "") + result = b.createZExtOrTrunc(result, b.intType) + b.CreateRet(result) + return true + case "Len", "Len8", "Len16", "Len32", "Len64": + // bits.Len can be implemented as: + // (unsafe.Sizeof(v) * 8) - bits.LeadingZeros(n) + // Not sure why this isn't already done in the standard library, as it + // is much simpler than a lookup table. + b.createFunctionStart(true) + param := b.getValue(b.fn.Params[0], b.fn.Pos()) + valueType := param.Type() + valueBits := valueType.IntTypeWidth() + intrinsicName := "llvm.ctlz.i" + strconv.Itoa(valueBits) + llvmFn := b.mod.NamedFunction(intrinsicName) + llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, b.ctx.Int1Type()}, false) + if llvmFn.IsNil() { + llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType) + } + result := b.createCall(llvmFnType, llvmFn, []llvm.Value{ + param, + llvm.ConstInt(b.ctx.Int1Type(), 0, false), + }, "") + result = b.createZExtOrTrunc(result, b.intType) + maxLen := llvm.ConstInt(b.intType, uint64(valueBits), false) // number of bits in the value + result = b.CreateSub(maxLen, result, "") + b.CreateRet(result) + return true + case "OnesCount", "OnesCount8", "OnesCount16", "OnesCount32", "OnesCount64": + b.createFunctionStart(true) + param := b.getValue(b.fn.Params[0], b.fn.Pos()) + valueType := param.Type() + intrinsicName := "llvm.ctpop.i" + strconv.Itoa(valueType.IntTypeWidth()) + llvmFn := b.mod.NamedFunction(intrinsicName) + llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType}, false) + if llvmFn.IsNil() { + llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType) + } + result := b.createCall(llvmFnType, llvmFn, []llvm.Value{param}, "") + result = b.createZExtOrTrunc(result, b.intType) + b.CreateRet(result) + return true + case "Reverse", "Reverse8", "Reverse16", "Reverse32", "Reverse64", + "ReverseBytes", "ReverseBytes16", "ReverseBytes32", "ReverseBytes64": + b.createFunctionStart(true) + param := b.getValue(b.fn.Params[0], b.fn.Pos()) + valueType := param.Type() + var intrinsicName string + if strings.HasPrefix(name, "ReverseBytes") { + intrinsicName = "llvm.bswap.i" + strconv.Itoa(valueType.IntTypeWidth()) + } else { // Reverse + intrinsicName = "llvm.bitreverse.i" + strconv.Itoa(valueType.IntTypeWidth()) + } + llvmFn := b.mod.NamedFunction(intrinsicName) + llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType}, false) + if llvmFn.IsNil() { + llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType) + } + result := b.createCall(llvmFnType, llvmFn, []llvm.Value{param}, "") + b.CreateRet(result) + return true + case "RotateLeft", "RotateLeft8", "RotateLeft16", "RotateLeft32", "RotateLeft64": + // Warning: the documentation says these functions must be constant time. + // I do not think LLVM guarantees this, but there's a good chance LLVM + // already recognized the rotate instruction so it probably won't get + // any _worse_ by implementing these rotate functions. + b.createFunctionStart(true) + x := b.getValue(b.fn.Params[0], b.fn.Pos()) + k := b.getValue(b.fn.Params[1], b.fn.Pos()) + valueType := x.Type() + intrinsicName := "llvm.fshl.i" + strconv.Itoa(valueType.IntTypeWidth()) + llvmFn := b.mod.NamedFunction(intrinsicName) + llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, valueType, valueType}, false) + if llvmFn.IsNil() { + llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType) + } + k = b.createZExtOrTrunc(k, valueType) + result := b.createCall(llvmFnType, llvmFn, []llvm.Value{x, x, k}, "") + b.CreateRet(result) + return true + default: + return false + } +} diff --git a/compiler/llvm.go b/compiler/llvm.go index 33c6603e..0d33ab56 100644 --- a/compiler/llvm.go +++ b/compiler/llvm.go @@ -464,6 +464,19 @@ func (b *builder) readStackPointer() llvm.Value { return b.CreateCall(stacksave.GlobalValueType(), stacksave, nil, "") } +// createZExtOrTrunc lets the input value fit in the output type bits, by zero +// extending or truncating the integer. +func (b *builder) createZExtOrTrunc(value llvm.Value, t llvm.Type) llvm.Value { + valueBits := value.Type().IntTypeWidth() + resultBits := t.IntTypeWidth() + if valueBits > resultBits { + value = b.CreateTrunc(value, t, "") + } else if valueBits < resultBits { + value = b.CreateZExt(value, t, "") + } + return value +} + // Reverse a slice of bytes. From the wiki: // https://github.com/golang/go/wiki/SliceTricks#reversing func reverseBytes(buf []byte) {