From a867b56e5f9e6778f48a33d75b6543b0ca92dcc1 Mon Sep 17 00:00:00 2001 From: Nia Weiss Date: Fri, 15 Jan 2021 10:48:07 -0500 Subject: [PATCH] compiler: saturate float-to-int conversions This works around some UB in LLVM, where an out-of-bounds conversion would produce a poison value. The selected behavior is saturating, except that NaN is mapped to the minimum value. --- compiler/compiler.go | 65 ++++++++++++++++++++++++++++++++++++++++++-- testdata/float.go | 7 ++++- testdata/float.txt | 2 +- 3 files changed, 70 insertions(+), 4 deletions(-) diff --git a/compiler/compiler.go b/compiler/compiler.go index c5b7be65..35d42eef 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -8,6 +8,7 @@ import ( "go/constant" "go/token" "go/types" + "math/bits" "path/filepath" "sort" "strconv" @@ -2552,10 +2553,70 @@ func (b *builder) createConvert(typeFrom, typeTo types.Type, value llvm.Value, p if typeFrom.Info()&types.IsFloat != 0 && typeTo.Info()&types.IsInteger != 0 { // Conversion from float to int. + // Passing an out-of-bounds float to LLVM would cause UB, so that UB is trapped by select instructions. + // The Go specification says that this should be implementation-defined behavior. + // This implements saturating behavior, except that NaN is mapped to the minimum value. + var significandBits int + switch typeFrom.Kind() { + case types.Float32: + significandBits = 23 + case types.Float64: + significandBits = 52 + } if typeTo.Info()&types.IsUnsigned != 0 { // if unsigned - return b.CreateFPToUI(value, llvmTypeTo, ""), nil + // Select the maximum value for this unsigned integer type. + max := ^(^uint64(0) << uint(llvmTypeTo.IntTypeWidth())) + maxFloat := float64(max) + if bits.Len64(max) > significandBits { + // Round the max down to fit within the significand. + maxFloat = float64(max & ^uint64(0) << uint(bits.Len64(max)-significandBits)) + } + + // Check if the value is in-bounds (0 <= value <= max). + positive := b.CreateFCmp(llvm.FloatOLE, llvm.ConstNull(llvmTypeFrom), value, "positive") + withinMax := b.CreateFCmp(llvm.FloatOLE, value, llvm.ConstFloat(llvmTypeFrom, maxFloat), "withinmax") + inBounds := b.CreateAnd(positive, withinMax, "inbounds") + + // Assuming that the value is out-of-bounds, select a saturated value. + saturated := b.CreateSelect(positive, + llvm.ConstInt(llvmTypeTo, max, false), // value > max + llvm.ConstNull(llvmTypeTo), // value < 0 (or NaN) + "saturated", + ) + + // Do a normal conversion. + normal := b.CreateFPToUI(value, llvmTypeTo, "normal") + + return b.CreateSelect(inBounds, normal, saturated, ""), nil } else { // if signed - return b.CreateFPToSI(value, llvmTypeTo, ""), nil + // Select the minimum value for this signed integer type. + min := uint64(1) << uint(llvmTypeTo.IntTypeWidth()-1) + minFloat := -float64(min) + + // Select the maximum value for this signed integer type. + max := ^(^uint64(0) << uint(llvmTypeTo.IntTypeWidth()-1)) + maxFloat := float64(max) + if bits.Len64(max) > significandBits { + // Round the max down to fit within the significand. + maxFloat = float64(max & ^uint64(0) << uint(bits.Len64(max)-significandBits)) + } + + // Check if the value is in-bounds (min <= value <= max). + aboveMin := b.CreateFCmp(llvm.FloatOLE, llvm.ConstFloat(llvmTypeFrom, minFloat), value, "abovemin") + belowMax := b.CreateFCmp(llvm.FloatOLE, value, llvm.ConstFloat(llvmTypeFrom, maxFloat), "belowmax") + inBounds := b.CreateAnd(aboveMin, belowMax, "inbounds") + + // Assuming that the value is out-of-bounds, select a saturated value. + saturated := b.CreateSelect(aboveMin, + llvm.ConstInt(llvmTypeTo, max, false), // value > max + llvm.ConstInt(llvmTypeTo, min, false), // value < min (or NaN) + "saturated", + ) + + // Do a normal conversion. + normal := b.CreateFPToSI(value, llvmTypeTo, "normal") + + return b.CreateSelect(inBounds, normal, saturated, ""), nil } } diff --git a/testdata/float.go b/testdata/float.go index 8c33cf11..6d2d1bca 100644 --- a/testdata/float.go +++ b/testdata/float.go @@ -29,7 +29,12 @@ func main() { var f2 float32 = 5.7 var f3 float32 = -2.3 var f4 float32 = -11.8 - println(int32(f1), int32(f2), int32(f3), int32(f4)) + var f5 float32 = -1 + var f6 float32 = 256 + var f7 float32 = -129 + var f8 float32 = 0 + f8 /= 0 + println(int32(f1), int32(f2), int32(f3), int32(f4), uint8(f5), uint8(f6), int8(f7), int8(f6), uint8(f8), int8(f8)) // int -> float var i1 int32 = 53 diff --git a/testdata/float.txt b/testdata/float.txt index dbc36fb8..12c48598 100644 --- a/testdata/float.txt +++ b/testdata/float.txt @@ -11,7 +11,7 @@ +3.333333e-001 +6.666667e-001 +6.666667e-001 -3 5 -2 -11 +3 5 -2 -11 0 255 -128 127 0 -128 +5.300000e+001 -8.000000e+000 +2.000000e+001 (+6.666667e-001+1.200000e+000i) +6.666667e-001