compiler: zero struct padding during map operations

Fixes #3358
Этот коммит содержится в:
Damian Gryski 2023-02-25 13:40:08 -08:00 коммит произвёл GitHub
родитель 7b44fcd865
коммит 476621736c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 286 добавлений и 1 удалений

Просмотреть файл

@ -49,6 +49,7 @@ func TestCompiler(t *testing.T) {
{"goroutine.go", "cortex-m-qemu", "tasks"},
{"channel.go", "", ""},
{"gc.go", "", ""},
{"zeromap.go", "", ""},
}
if goMinor >= 20 {
tests = append(tests, testCase{"go1.20.go", "", ""})

Просмотреть файл

@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val
// growth.
mapKeyAlloca, mapKeyPtr, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
b.CreateStore(key, mapKeyAlloca)
b.zeroUndefBytes(b.getLLVMType(keyType), mapKeyAlloca)
// Fetch the value from the hashmap.
params := []llvm.Value{m, mapKeyPtr, mapValuePtr, mapValueSize}
commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "")
@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value,
// key can be compared with runtime.memequal
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
b.CreateStore(key, keyAlloca)
b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
params := []llvm.Value{m, keyPtr, valuePtr}
b.createRuntimeCall("hashmapBinarySet", params, "")
b.emitLifetimeEnd(keyPtr, keySize)
@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok
} else if hashmapIsBinaryKey(keyType) {
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
b.CreateStore(key, keyAlloca)
b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
params := []llvm.Value{m, keyPtr}
b.createRuntimeCall("hashmapBinaryDelete", params, "")
b.emitLifetimeEnd(keyPtr, keySize)
@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv
}
// Returns true if this key type does not contain strings, interfaces etc., so
// can be compared with runtime.memequal.
// can be compared with runtime.memequal. Note that padding bytes are undef
// and can alter two "equal" structs being equal when compared with memequal.
func hashmapIsBinaryKey(keyType types.Type) bool {
switch keyType := keyType.(type) {
case *types.Basic:
@ -263,3 +267,76 @@ func hashmapIsBinaryKey(keyType types.Type) bool {
return false
}
}
func (b *builder) zeroUndefBytes(llvmType llvm.Type, ptr llvm.Value) error {
// We know that hashmapIsBinaryKey is true, so we only have to handle those types that can show up there.
// To zero all undefined bytes, we iterate over all the fields in the type. For each element, compute the
// offset of that element. If it's Basic type, there are no internal padding bytes. For compound types, we recurse to ensure
// we handle nested types. Next, we determine if there are any padding bytes before the next
// element and zero those as well.
zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false)
switch llvmType.TypeKind() {
case llvm.IntegerTypeKind:
// no padding bytes
return nil
case llvm.PointerTypeKind:
// mo padding bytes
return nil
case llvm.ArrayTypeKind:
llvmArrayType := llvmType
llvmElemType := llvmType.ElementType()
for i := 0; i < llvmArrayType.ArrayLength(); i++ {
idx := llvm.ConstInt(b.uintptrType, uint64(i), false)
elemPtr := b.CreateInBoundsGEP(llvmArrayType, ptr, []llvm.Value{zero, idx}, "")
// zero any padding bytes in this element
b.zeroUndefBytes(llvmElemType, elemPtr)
}
case llvm.StructTypeKind:
llvmStructType := llvmType
numFields := llvmStructType.StructElementTypesCount()
llvmElementTypes := llvmStructType.StructElementTypes()
for i := 0; i < numFields; i++ {
idx := llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false)
elemPtr := b.CreateInBoundsGEP(llvmStructType, ptr, []llvm.Value{zero, idx}, "")
// zero any padding bytes in this field
llvmElemType := llvmElementTypes[i]
b.zeroUndefBytes(llvmElemType, elemPtr)
// zero any padding bytes before the next field, if any
offset := b.targetData.ElementOffset(llvmStructType, i)
storeSize := b.targetData.TypeStoreSize(llvmElemType)
fieldEndOffset := offset + storeSize
var nextOffset uint64
if i < numFields-1 {
nextOffset = b.targetData.ElementOffset(llvmStructType, i+1)
} else {
// Last field? Next offset is the total size of the allcoate struct.
nextOffset = b.targetData.TypeAllocSize(llvmStructType)
}
if fieldEndOffset != nextOffset {
n := llvm.ConstInt(b.uintptrType, nextOffset-fieldEndOffset, false)
llvmStoreSize := llvm.ConstInt(b.uintptrType, storeSize, false)
gepPtr := elemPtr
if gepPtr.Type() != b.i8ptrType {
gepPtr = b.CreateBitCast(gepPtr, b.i8ptrType, "") // LLVM 14
}
paddingStart := b.CreateInBoundsGEP(b.ctx.Int8Type(), gepPtr, []llvm.Value{llvmStoreSize}, "")
if paddingStart.Type() != b.i8ptrType {
paddingStart = b.CreateBitCast(paddingStart, b.i8ptrType, "") // LLVM 14
}
b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
}
}
}
return nil
}

37
compiler/testdata/zeromap.go предоставленный Обычный файл
Просмотреть файл

@ -0,0 +1,37 @@
package main
type hasPadding struct {
b1 bool
i int
b2 bool
}
type nestedPadding struct {
b bool
hasPadding
i int
}
//go:noinline
func testZeroGet(m map[hasPadding]int, s hasPadding) int {
return m[s]
}
//go:noinline
func testZeroSet(m map[hasPadding]int, s hasPadding) {
m[s] = 5
}
//go:noinline
func testZeroArrayGet(m map[[2]hasPadding]int, s [2]hasPadding) int {
return m[s]
}
//go:noinline
func testZeroArraySet(m map[[2]hasPadding]int, s [2]hasPadding) {
m[s] = 5
}
func main() {
}

170
compiler/testdata/zeromap.ll предоставленный Обычный файл
Просмотреть файл

@ -0,0 +1,170 @@
; ModuleID = 'zeromap.go'
source_filename = "zeromap.go"
target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20"
target triple = "wasm32-unknown-wasi"
%main.hasPadding = type { i1, i32, i1 }
declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0
declare void @runtime.trackPointer(ptr nocapture readonly, ptr, ptr) #0
; Function Attrs: nounwind
define hidden void @main.init(ptr %context) unnamed_addr #1 {
entry:
ret void
}
; Function Attrs: noinline nounwind
define hidden i32 @main.testZeroGet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 {
entry:
%hashmap.key = alloca %main.hasPadding, align 8
%hashmap.value = alloca i32, align 4
%s = alloca %main.hasPadding, align 8
%0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0
%1 = insertvalue %main.hasPadding %0, i32 %s.i, 1
%2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2
%stackalloc = alloca i8, align 1
store %main.hasPadding zeroinitializer, ptr %s, align 8
call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4
store %main.hasPadding %2, ptr %s, align 8
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key)
store %main.hasPadding %2, ptr %hashmap.key, align 8
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
%4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4
%5 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4
call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key)
%6 = load i32, ptr %hashmap.value, align 4
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
ret i32 %6
}
; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3
declare void @runtime.memzero(ptr, i32, ptr) #0
declare i1 @runtime.hashmapBinaryGet(ptr dereferenceable_or_null(40), ptr, ptr, i32, ptr) #0
; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #3
; Function Attrs: noinline nounwind
define hidden void @main.testZeroSet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 {
entry:
%hashmap.key = alloca %main.hasPadding, align 8
%hashmap.value = alloca i32, align 4
%s = alloca %main.hasPadding, align 8
%0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0
%1 = insertvalue %main.hasPadding %0, i32 %s.i, 1
%2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2
%stackalloc = alloca i8, align 1
store %main.hasPadding zeroinitializer, ptr %s, align 8
call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4
store %main.hasPadding %2, ptr %s, align 8
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
store i32 5, ptr %hashmap.value, align 4
call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key)
store %main.hasPadding %2, ptr %hashmap.key, align 8
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
%4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4
call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4
call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key)
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
ret void
}
declare void @runtime.hashmapBinarySet(ptr dereferenceable_or_null(40), ptr, ptr, ptr) #0
; Function Attrs: noinline nounwind
define hidden i32 @main.testZeroArrayGet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 {
entry:
%hashmap.key = alloca [2 x %main.hasPadding], align 8
%hashmap.value = alloca i32, align 4
%s1 = alloca [2 x %main.hasPadding], align 8
%stackalloc = alloca i8, align 1
store %main.hasPadding zeroinitializer, ptr %s1, align 8
%s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4
call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4
%s.elt = extractvalue [2 x %main.hasPadding] %s, 0
store %main.hasPadding %s.elt, ptr %s1, align 8
%s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
%s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1
store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key)
%s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0
store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8
%hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1
%s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1
store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4
%0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4
%1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4
%2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13
call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
%4 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4
call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key)
%5 = load i32, ptr %hashmap.value, align 4
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
ret i32 %5
}
; Function Attrs: noinline nounwind
define hidden void @main.testZeroArraySet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 {
entry:
%hashmap.key = alloca [2 x %main.hasPadding], align 8
%hashmap.value = alloca i32, align 4
%s1 = alloca [2 x %main.hasPadding], align 8
%stackalloc = alloca i8, align 1
store %main.hasPadding zeroinitializer, ptr %s1, align 8
%s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4
call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4
%s.elt = extractvalue [2 x %main.hasPadding] %s, 0
store %main.hasPadding %s.elt, ptr %s1, align 8
%s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
%s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1
store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
store i32 5, ptr %hashmap.value, align 4
call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key)
%s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0
store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8
%hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1
%s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1
store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4
%0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4
%1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4
%2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13
call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4
call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key)
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
ret void
}
; Function Attrs: nounwind
define hidden void @main.main(ptr %context) unnamed_addr #1 {
entry:
ret void
}
attributes #0 = { "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
attributes #1 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
attributes #2 = { noinline nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
attributes #3 = { argmemonly nocallback nofree nosync nounwind willreturn }
attributes #4 = { nounwind }