compiler: add dereferenceable_or_null attribute where possible
This gives a hint to the compiler that such parameters are either NULL or point to a valid object that can be dereferenced. This is not directly very useful, but is very useful when combined with https://reviews.llvm.org/D60047 to remove the runtime.isnil hack without regressing escape analysis.
Этот коммит содержится в:
родитель
980068543a
коммит
85854cd58b
4 изменённых файлов: 98 добавлений и 17 удалений
|
@ -1,6 +1,8 @@
|
|||
package compiler
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
|
||||
"tinygo.org/x/go-llvm"
|
||||
)
|
||||
|
||||
|
@ -11,6 +13,16 @@ import (
|
|||
// a struct contains more fields, it is passed as a struct without expanding.
|
||||
const MaxFieldsPerParam = 3
|
||||
|
||||
// paramFlags identifies parameter attributes for flags. Most importantly, it
|
||||
// determines which parameters are dereferenceable_or_null and which aren't.
|
||||
type paramFlags uint8
|
||||
|
||||
const (
|
||||
// Parameter may have the deferenceable_or_null attribute. This attribute
|
||||
// cannot be applied to unsafe.Pointer and to the data pointer of slices.
|
||||
paramIsDeferenceableOrNull = 1 << iota
|
||||
)
|
||||
|
||||
// createCall creates a new call to runtime.<fnName> with the given arguments.
|
||||
func (b *builder) createRuntimeCall(fnName string, args []llvm.Value, name string) llvm.Value {
|
||||
fullName := "runtime." + fnName
|
||||
|
@ -36,19 +48,19 @@ func (b *builder) createCall(fn llvm.Value, args []llvm.Value, name string) llvm
|
|||
|
||||
// Expand an argument type to a list that can be used in a function call
|
||||
// parameter list.
|
||||
func expandFormalParamType(t llvm.Type) []llvm.Type {
|
||||
func expandFormalParamType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) {
|
||||
switch t.TypeKind() {
|
||||
case llvm.StructTypeKind:
|
||||
fields := flattenAggregateType(t)
|
||||
fields, fieldFlags := flattenAggregateType(t, goType)
|
||||
if len(fields) <= MaxFieldsPerParam {
|
||||
return fields
|
||||
return fields, fieldFlags
|
||||
} else {
|
||||
// failed to lower
|
||||
return []llvm.Type{t}
|
||||
return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)}
|
||||
}
|
||||
default:
|
||||
// TODO: split small arrays
|
||||
return []llvm.Type{t}
|
||||
return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,7 +91,7 @@ func (b *builder) expandFormalParamOffsets(t llvm.Type) []uint64 {
|
|||
func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {
|
||||
switch v.Type().TypeKind() {
|
||||
case llvm.StructTypeKind:
|
||||
fieldTypes := flattenAggregateType(v.Type())
|
||||
fieldTypes, _ := flattenAggregateType(v.Type(), nil)
|
||||
if len(fieldTypes) <= MaxFieldsPerParam {
|
||||
fields := b.flattenAggregate(v)
|
||||
if len(fields) != len(fieldTypes) {
|
||||
|
@ -98,17 +110,62 @@ func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {
|
|||
|
||||
// Try to flatten a struct type to a list of types. Returns a 1-element slice
|
||||
// with the passed in type if this is not possible.
|
||||
func flattenAggregateType(t llvm.Type) []llvm.Type {
|
||||
func flattenAggregateType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) {
|
||||
typeFlags := getTypeFlags(goType)
|
||||
switch t.TypeKind() {
|
||||
case llvm.StructTypeKind:
|
||||
fields := make([]llvm.Type, 0, t.StructElementTypesCount())
|
||||
for _, subfield := range t.StructElementTypes() {
|
||||
subfields := flattenAggregateType(subfield)
|
||||
fieldFlags := make([]paramFlags, 0, cap(fields))
|
||||
for i, subfield := range t.StructElementTypes() {
|
||||
subfields, subfieldFlags := flattenAggregateType(subfield, extractSubfield(goType, i))
|
||||
for i := range subfieldFlags {
|
||||
subfieldFlags[i] |= typeFlags
|
||||
}
|
||||
fields = append(fields, subfields...)
|
||||
fieldFlags = append(fieldFlags, subfieldFlags...)
|
||||
}
|
||||
return fields
|
||||
return fields, fieldFlags
|
||||
default:
|
||||
return []llvm.Type{t}
|
||||
return []llvm.Type{t}, []paramFlags{typeFlags}
|
||||
}
|
||||
}
|
||||
|
||||
// getTypeFlags returns the type flags for a given type. It will not recurse
|
||||
// into sub-types (such as in structs).
|
||||
func getTypeFlags(t types.Type) paramFlags {
|
||||
if t == nil {
|
||||
return 0
|
||||
}
|
||||
switch t.Underlying().(type) {
|
||||
case *types.Pointer:
|
||||
// Pointers in Go must either point to an object or be nil.
|
||||
return paramIsDeferenceableOrNull
|
||||
case *types.Chan, *types.Map:
|
||||
// Channels and maps are implemented as pointers pointing to some
|
||||
// object, and follow the same rules as *types.Pointer.
|
||||
return paramIsDeferenceableOrNull
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// extractSubfield extracts a field from a struct, or returns null if this is
|
||||
// not a struct and thus no subfield can be obtained.
|
||||
func extractSubfield(t types.Type, field int) types.Type {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
switch t := t.Underlying().(type) {
|
||||
case *types.Struct:
|
||||
return t.Field(field).Type()
|
||||
case *types.Interface, *types.Slice, *types.Basic, *types.Signature:
|
||||
// These Go types are (sometimes) implemented as LLVM structs but can't
|
||||
// really be split further up in Go (with the possible exception of
|
||||
// complex numbers).
|
||||
return nil
|
||||
default:
|
||||
// This should be unreachable.
|
||||
panic("cannot split subfield: " + t.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -169,7 +226,8 @@ func (b *builder) collapseFormalParam(t llvm.Type, fields []llvm.Value) llvm.Val
|
|||
func (b *builder) collapseFormalParamInternal(t llvm.Type, fields []llvm.Value) (llvm.Value, []llvm.Value) {
|
||||
switch t.TypeKind() {
|
||||
case llvm.StructTypeKind:
|
||||
if len(flattenAggregateType(t)) <= MaxFieldsPerParam {
|
||||
flattened, _ := flattenAggregateType(t, nil)
|
||||
if len(flattened) <= MaxFieldsPerParam {
|
||||
value := llvm.ConstNull(t)
|
||||
for i, subtyp := range t.StructElementTypes() {
|
||||
structField, remaining := b.collapseFormalParamInternal(subtyp, fields)
|
||||
|
|
|
@ -750,10 +750,12 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
|
|||
}
|
||||
|
||||
var paramTypes []llvm.Type
|
||||
var paramTypeVariants []paramFlags
|
||||
for _, param := range f.Params {
|
||||
paramType := c.getLLVMType(param.Type())
|
||||
paramTypeFragments := expandFormalParamType(paramType)
|
||||
paramTypeFragments, paramTypeFragmentVariants := expandFormalParamType(paramType, param.Type())
|
||||
paramTypes = append(paramTypes, paramTypeFragments...)
|
||||
paramTypeVariants = append(paramTypeVariants, paramTypeFragmentVariants...)
|
||||
}
|
||||
|
||||
// Add an extra parameter as the function context. This context is used in
|
||||
|
@ -761,6 +763,7 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
|
|||
if !f.IsExported() {
|
||||
paramTypes = append(paramTypes, c.i8ptrType) // context
|
||||
paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine
|
||||
paramTypeVariants = append(paramTypeVariants, 0, 0)
|
||||
}
|
||||
|
||||
fnType := llvm.FunctionType(retType, paramTypes, false)
|
||||
|
@ -771,6 +774,23 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
|
|||
f.LLVMFn = llvm.AddFunction(c.mod, name, fnType)
|
||||
}
|
||||
|
||||
dereferenceableOrNullKind := llvm.AttributeKindID("dereferenceable_or_null")
|
||||
for i, typ := range paramTypes {
|
||||
if paramTypeVariants[i]¶mIsDeferenceableOrNull == 0 {
|
||||
continue
|
||||
}
|
||||
if typ.TypeKind() == llvm.PointerTypeKind {
|
||||
el := typ.ElementType()
|
||||
size := c.targetData.TypeAllocSize(el)
|
||||
if size == 0 {
|
||||
// dereferenceable_or_null(0) appears to be illegal in LLVM.
|
||||
continue
|
||||
}
|
||||
dereferenceableOrNull := c.ctx.CreateEnumAttribute(dereferenceableOrNullKind, size)
|
||||
f.LLVMFn.AddAttributeAtIndex(i+1, dereferenceableOrNull)
|
||||
}
|
||||
}
|
||||
|
||||
// External/exported functions may not retain pointer values.
|
||||
// https://golang.org/cmd/cgo/#hdr-Passing_pointers
|
||||
if f.IsExported() {
|
||||
|
@ -901,7 +921,8 @@ func (b *builder) createFunctionDefinition() {
|
|||
for _, param := range b.fn.Params {
|
||||
llvmType := b.getLLVMType(param.Type())
|
||||
fields := make([]llvm.Value, 0, 1)
|
||||
for range expandFormalParamType(llvmType) {
|
||||
fieldFragments, _ := expandFormalParamType(llvmType, nil)
|
||||
for range fieldFragments {
|
||||
fields = append(fields, b.fn.LLVMFn.Param(llvmParamIndex))
|
||||
llvmParamIndex++
|
||||
}
|
||||
|
|
|
@ -125,11 +125,13 @@ func (c *compilerContext) getRawFuncType(typ *types.Signature) llvm.Type {
|
|||
// The receiver is not an interface, but a i8* type.
|
||||
recv = c.i8ptrType
|
||||
}
|
||||
paramTypes = append(paramTypes, expandFormalParamType(recv)...)
|
||||
recvFragments, _ := expandFormalParamType(recv, nil)
|
||||
paramTypes = append(paramTypes, recvFragments...)
|
||||
}
|
||||
for i := 0; i < typ.Params().Len(); i++ {
|
||||
subType := c.getLLVMType(typ.Params().At(i).Type())
|
||||
paramTypes = append(paramTypes, expandFormalParamType(subType)...)
|
||||
paramTypeFragments, _ := expandFormalParamType(subType, nil)
|
||||
paramTypes = append(paramTypes, paramTypeFragments...)
|
||||
}
|
||||
// All functions take these parameters at the end.
|
||||
paramTypes = append(paramTypes, c.i8ptrType) // context
|
||||
|
|
|
@ -437,7 +437,7 @@ func (c *compilerContext) getInterfaceInvokeWrapper(f *ir.Function) llvm.Value {
|
|||
|
||||
// Get the expanded receiver type.
|
||||
receiverType := c.getLLVMType(f.Params[0].Type())
|
||||
expandedReceiverType := expandFormalParamType(receiverType)
|
||||
expandedReceiverType, _ := expandFormalParamType(receiverType, nil)
|
||||
|
||||
// Does this method even need any wrapping?
|
||||
if len(expandedReceiverType) == 1 && receiverType.TypeKind() == llvm.PointerTypeKind {
|
||||
|
|
Загрузка…
Создание таблицы
Сослаться в новой задаче