compiler: implement most math/bits functions

These functions can be implemented more efficiently using LLVM
intrinsics. That makes them the Go equivalent of functions like
__builtin_clz which are also implemented using these LLVM intrinsics.

I believe the Go compiler does something very similar: IIRC it converts
calls to these functions into optimal instructions for the given
architecture.

I tested these by running `tinygo test math/bits` after uncommenting the
tests that would always fail (the *PanicZero and *PanicOverflow tests).
Этот коммит содержится в:
Ayke van Laethem 2023-03-29 18:33:24 +02:00 коммит произвёл Ron Evans
родитель 568c2a4363
коммит 464ebc4fe1
3 изменённых файлов: 133 добавлений и 0 удалений

Просмотреть файл

@ -845,6 +845,11 @@ func (c *compilerContext) createPackage(irbuilder llvm.Builder, pkg *ssa.Package
b.defineMathOp()
continue
}
if ok := b.defineMathBitsIntrinsic(); ok {
// Like a math intrinsic, the body of this function was replaced
// with a LLVM intrinsic.
continue
}
if member.Blocks == nil {
// Try to define this as an intrinsic function.
b.defineIntrinsicFunction()

Просмотреть файл

@ -154,3 +154,118 @@ func (b *builder) defineMathOp() {
result := b.CreateCall(llvmFn.GlobalValueType(), llvmFn, args, "")
b.CreateRet(result)
}
// Implement most math/bits functions.
//
// This implements all the functions that operate on bits. It does not yet
// implement the arithmetic functions (like bits.Add), which also have LLVM
// intrinsics.
func (b *builder) defineMathBitsIntrinsic() bool {
if b.fn.Pkg.Pkg.Path() != "math/bits" {
return false
}
name := b.fn.Name()
switch name {
case "LeadingZeros", "LeadingZeros8", "LeadingZeros16", "LeadingZeros32", "LeadingZeros64",
"TrailingZeros", "TrailingZeros8", "TrailingZeros16", "TrailingZeros32", "TrailingZeros64":
b.createFunctionStart(true)
param := b.getValue(b.fn.Params[0], b.fn.Pos())
valueType := param.Type()
var intrinsicName string
if strings.HasPrefix(name, "Leading") { // LeadingZeros
intrinsicName = "llvm.ctlz.i" + strconv.Itoa(valueType.IntTypeWidth())
} else { // TrailingZeros
intrinsicName = "llvm.cttz.i" + strconv.Itoa(valueType.IntTypeWidth())
}
llvmFn := b.mod.NamedFunction(intrinsicName)
llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, b.ctx.Int1Type()}, false)
if llvmFn.IsNil() {
llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
}
result := b.createCall(llvmFnType, llvmFn, []llvm.Value{
param,
llvm.ConstInt(b.ctx.Int1Type(), 0, false),
}, "")
result = b.createZExtOrTrunc(result, b.intType)
b.CreateRet(result)
return true
case "Len", "Len8", "Len16", "Len32", "Len64":
// bits.Len can be implemented as:
// (unsafe.Sizeof(v) * 8) - bits.LeadingZeros(n)
// Not sure why this isn't already done in the standard library, as it
// is much simpler than a lookup table.
b.createFunctionStart(true)
param := b.getValue(b.fn.Params[0], b.fn.Pos())
valueType := param.Type()
valueBits := valueType.IntTypeWidth()
intrinsicName := "llvm.ctlz.i" + strconv.Itoa(valueBits)
llvmFn := b.mod.NamedFunction(intrinsicName)
llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, b.ctx.Int1Type()}, false)
if llvmFn.IsNil() {
llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
}
result := b.createCall(llvmFnType, llvmFn, []llvm.Value{
param,
llvm.ConstInt(b.ctx.Int1Type(), 0, false),
}, "")
result = b.createZExtOrTrunc(result, b.intType)
maxLen := llvm.ConstInt(b.intType, uint64(valueBits), false) // number of bits in the value
result = b.CreateSub(maxLen, result, "")
b.CreateRet(result)
return true
case "OnesCount", "OnesCount8", "OnesCount16", "OnesCount32", "OnesCount64":
b.createFunctionStart(true)
param := b.getValue(b.fn.Params[0], b.fn.Pos())
valueType := param.Type()
intrinsicName := "llvm.ctpop.i" + strconv.Itoa(valueType.IntTypeWidth())
llvmFn := b.mod.NamedFunction(intrinsicName)
llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType}, false)
if llvmFn.IsNil() {
llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
}
result := b.createCall(llvmFnType, llvmFn, []llvm.Value{param}, "")
result = b.createZExtOrTrunc(result, b.intType)
b.CreateRet(result)
return true
case "Reverse", "Reverse8", "Reverse16", "Reverse32", "Reverse64",
"ReverseBytes", "ReverseBytes16", "ReverseBytes32", "ReverseBytes64":
b.createFunctionStart(true)
param := b.getValue(b.fn.Params[0], b.fn.Pos())
valueType := param.Type()
var intrinsicName string
if strings.HasPrefix(name, "ReverseBytes") {
intrinsicName = "llvm.bswap.i" + strconv.Itoa(valueType.IntTypeWidth())
} else { // Reverse
intrinsicName = "llvm.bitreverse.i" + strconv.Itoa(valueType.IntTypeWidth())
}
llvmFn := b.mod.NamedFunction(intrinsicName)
llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType}, false)
if llvmFn.IsNil() {
llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
}
result := b.createCall(llvmFnType, llvmFn, []llvm.Value{param}, "")
b.CreateRet(result)
return true
case "RotateLeft", "RotateLeft8", "RotateLeft16", "RotateLeft32", "RotateLeft64":
// Warning: the documentation says these functions must be constant time.
// I do not think LLVM guarantees this, but there's a good chance LLVM
// already recognized the rotate instruction so it probably won't get
// any _worse_ by implementing these rotate functions.
b.createFunctionStart(true)
x := b.getValue(b.fn.Params[0], b.fn.Pos())
k := b.getValue(b.fn.Params[1], b.fn.Pos())
valueType := x.Type()
intrinsicName := "llvm.fshl.i" + strconv.Itoa(valueType.IntTypeWidth())
llvmFn := b.mod.NamedFunction(intrinsicName)
llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, valueType, valueType}, false)
if llvmFn.IsNil() {
llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
}
k = b.createZExtOrTrunc(k, valueType)
result := b.createCall(llvmFnType, llvmFn, []llvm.Value{x, x, k}, "")
b.CreateRet(result)
return true
default:
return false
}
}

Просмотреть файл

@ -464,6 +464,19 @@ func (b *builder) readStackPointer() llvm.Value {
return b.CreateCall(stacksave.GlobalValueType(), stacksave, nil, "")
}
// createZExtOrTrunc lets the input value fit in the output type bits, by zero
// extending or truncating the integer.
func (b *builder) createZExtOrTrunc(value llvm.Value, t llvm.Type) llvm.Value {
valueBits := value.Type().IntTypeWidth()
resultBits := t.IntTypeWidth()
if valueBits > resultBits {
value = b.CreateTrunc(value, t, "")
} else if valueBits < resultBits {
value = b.CreateZExt(value, t, "")
}
return value
}
// Reverse a slice of bytes. From the wiki:
// https://github.com/golang/go/wiki/SliceTricks#reversing
func reverseBytes(buf []byte) {