compiler: implement most math/bits functions
These functions can be implemented more efficiently using LLVM intrinsics. That makes them the Go equivalent of functions like __builtin_clz which are also implemented using these LLVM intrinsics. I believe the Go compiler does something very similar: IIRC it converts calls to these functions into optimal instructions for the given architecture. I tested these by running `tinygo test math/bits` after uncommenting the tests that would always fail (the *PanicZero and *PanicOverflow tests).
Этот коммит содержится в:
родитель
568c2a4363
коммит
464ebc4fe1
3 изменённых файлов: 133 добавлений и 0 удалений
|
@ -845,6 +845,11 @@ func (c *compilerContext) createPackage(irbuilder llvm.Builder, pkg *ssa.Package
|
||||||
b.defineMathOp()
|
b.defineMathOp()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if ok := b.defineMathBitsIntrinsic(); ok {
|
||||||
|
// Like a math intrinsic, the body of this function was replaced
|
||||||
|
// with an LLVM intrinsic.
|
||||||
|
continue
|
||||||
|
}
|
||||||
if member.Blocks == nil {
|
if member.Blocks == nil {
|
||||||
// Try to define this as an intrinsic function.
|
// Try to define this as an intrinsic function.
|
||||||
b.defineIntrinsicFunction()
|
b.defineIntrinsicFunction()
|
||||||
|
|
|
@ -154,3 +154,118 @@ func (b *builder) defineMathOp() {
|
||||||
result := b.CreateCall(llvmFn.GlobalValueType(), llvmFn, args, "")
|
result := b.CreateCall(llvmFn.GlobalValueType(), llvmFn, args, "")
|
||||||
b.CreateRet(result)
|
b.CreateRet(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Implement most math/bits functions.
|
||||||
|
//
|
||||||
|
// This implements all the functions that operate on bits. It does not yet
|
||||||
|
// implement the arithmetic functions (like bits.Add), which also have LLVM
|
||||||
|
// intrinsics.
|
||||||
|
func (b *builder) defineMathBitsIntrinsic() bool {
|
||||||
|
if b.fn.Pkg.Pkg.Path() != "math/bits" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
name := b.fn.Name()
|
||||||
|
switch name {
|
||||||
|
case "LeadingZeros", "LeadingZeros8", "LeadingZeros16", "LeadingZeros32", "LeadingZeros64",
|
||||||
|
"TrailingZeros", "TrailingZeros8", "TrailingZeros16", "TrailingZeros32", "TrailingZeros64":
|
||||||
|
b.createFunctionStart(true)
|
||||||
|
param := b.getValue(b.fn.Params[0], b.fn.Pos())
|
||||||
|
valueType := param.Type()
|
||||||
|
var intrinsicName string
|
||||||
|
if strings.HasPrefix(name, "Leading") { // LeadingZeros
|
||||||
|
intrinsicName = "llvm.ctlz.i" + strconv.Itoa(valueType.IntTypeWidth())
|
||||||
|
} else { // TrailingZeros
|
||||||
|
intrinsicName = "llvm.cttz.i" + strconv.Itoa(valueType.IntTypeWidth())
|
||||||
|
}
|
||||||
|
llvmFn := b.mod.NamedFunction(intrinsicName)
|
||||||
|
llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, b.ctx.Int1Type()}, false)
|
||||||
|
if llvmFn.IsNil() {
|
||||||
|
llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
|
||||||
|
}
|
||||||
|
result := b.createCall(llvmFnType, llvmFn, []llvm.Value{
|
||||||
|
param,
|
||||||
|
llvm.ConstInt(b.ctx.Int1Type(), 0, false),
|
||||||
|
}, "")
|
||||||
|
result = b.createZExtOrTrunc(result, b.intType)
|
||||||
|
b.CreateRet(result)
|
||||||
|
return true
|
||||||
|
case "Len", "Len8", "Len16", "Len32", "Len64":
|
||||||
|
// bits.Len can be implemented as:
|
||||||
|
// (unsafe.Sizeof(v) * 8) - bits.LeadingZeros(n)
|
||||||
|
// Not sure why this isn't already done in the standard library, as it
|
||||||
|
// is much simpler than a lookup table.
|
||||||
|
b.createFunctionStart(true)
|
||||||
|
param := b.getValue(b.fn.Params[0], b.fn.Pos())
|
||||||
|
valueType := param.Type()
|
||||||
|
valueBits := valueType.IntTypeWidth()
|
||||||
|
intrinsicName := "llvm.ctlz.i" + strconv.Itoa(valueBits)
|
||||||
|
llvmFn := b.mod.NamedFunction(intrinsicName)
|
||||||
|
llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, b.ctx.Int1Type()}, false)
|
||||||
|
if llvmFn.IsNil() {
|
||||||
|
llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
|
||||||
|
}
|
||||||
|
result := b.createCall(llvmFnType, llvmFn, []llvm.Value{
|
||||||
|
param,
|
||||||
|
llvm.ConstInt(b.ctx.Int1Type(), 0, false),
|
||||||
|
}, "")
|
||||||
|
result = b.createZExtOrTrunc(result, b.intType)
|
||||||
|
maxLen := llvm.ConstInt(b.intType, uint64(valueBits), false) // number of bits in the value
|
||||||
|
result = b.CreateSub(maxLen, result, "")
|
||||||
|
b.CreateRet(result)
|
||||||
|
return true
|
||||||
|
case "OnesCount", "OnesCount8", "OnesCount16", "OnesCount32", "OnesCount64":
|
||||||
|
b.createFunctionStart(true)
|
||||||
|
param := b.getValue(b.fn.Params[0], b.fn.Pos())
|
||||||
|
valueType := param.Type()
|
||||||
|
intrinsicName := "llvm.ctpop.i" + strconv.Itoa(valueType.IntTypeWidth())
|
||||||
|
llvmFn := b.mod.NamedFunction(intrinsicName)
|
||||||
|
llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType}, false)
|
||||||
|
if llvmFn.IsNil() {
|
||||||
|
llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
|
||||||
|
}
|
||||||
|
result := b.createCall(llvmFnType, llvmFn, []llvm.Value{param}, "")
|
||||||
|
result = b.createZExtOrTrunc(result, b.intType)
|
||||||
|
b.CreateRet(result)
|
||||||
|
return true
|
||||||
|
case "Reverse", "Reverse8", "Reverse16", "Reverse32", "Reverse64",
|
||||||
|
"ReverseBytes", "ReverseBytes16", "ReverseBytes32", "ReverseBytes64":
|
||||||
|
b.createFunctionStart(true)
|
||||||
|
param := b.getValue(b.fn.Params[0], b.fn.Pos())
|
||||||
|
valueType := param.Type()
|
||||||
|
var intrinsicName string
|
||||||
|
if strings.HasPrefix(name, "ReverseBytes") {
|
||||||
|
intrinsicName = "llvm.bswap.i" + strconv.Itoa(valueType.IntTypeWidth())
|
||||||
|
} else { // Reverse
|
||||||
|
intrinsicName = "llvm.bitreverse.i" + strconv.Itoa(valueType.IntTypeWidth())
|
||||||
|
}
|
||||||
|
llvmFn := b.mod.NamedFunction(intrinsicName)
|
||||||
|
llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType}, false)
|
||||||
|
if llvmFn.IsNil() {
|
||||||
|
llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
|
||||||
|
}
|
||||||
|
result := b.createCall(llvmFnType, llvmFn, []llvm.Value{param}, "")
|
||||||
|
b.CreateRet(result)
|
||||||
|
return true
|
||||||
|
case "RotateLeft", "RotateLeft8", "RotateLeft16", "RotateLeft32", "RotateLeft64":
|
||||||
|
// Warning: the documentation says these functions must be constant time.
|
||||||
|
// I do not think LLVM guarantees this, but there's a good chance LLVM
|
||||||
|
// already recognized the rotate instruction so it probably won't get
|
||||||
|
// any _worse_ by implementing these rotate functions.
|
||||||
|
b.createFunctionStart(true)
|
||||||
|
x := b.getValue(b.fn.Params[0], b.fn.Pos())
|
||||||
|
k := b.getValue(b.fn.Params[1], b.fn.Pos())
|
||||||
|
valueType := x.Type()
|
||||||
|
intrinsicName := "llvm.fshl.i" + strconv.Itoa(valueType.IntTypeWidth())
|
||||||
|
llvmFn := b.mod.NamedFunction(intrinsicName)
|
||||||
|
llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, valueType, valueType}, false)
|
||||||
|
if llvmFn.IsNil() {
|
||||||
|
llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
|
||||||
|
}
|
||||||
|
k = b.createZExtOrTrunc(k, valueType)
|
||||||
|
result := b.createCall(llvmFnType, llvmFn, []llvm.Value{x, x, k}, "")
|
||||||
|
b.CreateRet(result)
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -464,6 +464,19 @@ func (b *builder) readStackPointer() llvm.Value {
|
||||||
return b.CreateCall(stacksave.GlobalValueType(), stacksave, nil, "")
|
return b.CreateCall(stacksave.GlobalValueType(), stacksave, nil, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createZExtOrTrunc lets the input value fit in the output type bits, by zero
|
||||||
|
// extending or truncating the integer.
|
||||||
|
func (b *builder) createZExtOrTrunc(value llvm.Value, t llvm.Type) llvm.Value {
|
||||||
|
valueBits := value.Type().IntTypeWidth()
|
||||||
|
resultBits := t.IntTypeWidth()
|
||||||
|
if valueBits > resultBits {
|
||||||
|
value = b.CreateTrunc(value, t, "")
|
||||||
|
} else if valueBits < resultBits {
|
||||||
|
value = b.CreateZExt(value, t, "")
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
// Reverse a slice of bytes. From the wiki:
|
// Reverse a slice of bytes. From the wiki:
|
||||||
// https://github.com/golang/go/wiki/SliceTricks#reversing
|
// https://github.com/golang/go/wiki/SliceTricks#reversing
|
||||||
func reverseBytes(buf []byte) {
|
func reverseBytes(buf []byte) {
|
||||||
|
|
Загрузка…
Создание таблицы
Сослаться в новой задаче