compiler: add dereferenceable_or_null attribute where possible

This gives a hint to the compiler that such parameters are either NULL
or point to a valid object that can be dereferenced. This is not
directly very useful, but is very useful when combined with
https://reviews.llvm.org/D60047 to remove the runtime.isnil hack without
regressing escape analysis.
Этот коммит содержится в:
Ayke van Laethem 2020-03-19 20:15:03 +01:00 коммит произвёл Ron Evans
родитель 980068543a
коммит 85854cd58b
4 изменённых файлов: 98 добавлений и 17 удалений

Просмотреть файл

@ -1,6 +1,8 @@
package compiler
import (
"go/types"
"tinygo.org/x/go-llvm"
)
@ -11,6 +13,16 @@ import (
// a struct contains more fields, it is passed as a struct without expanding.
const MaxFieldsPerParam = 3
// paramFlags identifies parameter attributes for flags. Most importantly, it
// determines which parameters are dereferenceable_or_null and which aren't.
type paramFlags uint8
const (
// Parameter may have the deferenceable_or_null attribute. This attribute
// cannot be applied to unsafe.Pointer and to the data pointer of slices.
paramIsDeferenceableOrNull = 1 << iota
)
// createCall creates a new call to runtime.<fnName> with the given arguments.
func (b *builder) createRuntimeCall(fnName string, args []llvm.Value, name string) llvm.Value {
fullName := "runtime." + fnName
@ -36,19 +48,19 @@ func (b *builder) createCall(fn llvm.Value, args []llvm.Value, name string) llvm
// Expand an argument type to a list that can be used in a function call
// parameter list.
func expandFormalParamType(t llvm.Type) []llvm.Type {
func expandFormalParamType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) {
switch t.TypeKind() {
case llvm.StructTypeKind:
fields := flattenAggregateType(t)
fields, fieldFlags := flattenAggregateType(t, goType)
if len(fields) <= MaxFieldsPerParam {
return fields
return fields, fieldFlags
} else {
// failed to lower
return []llvm.Type{t}
return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)}
}
default:
// TODO: split small arrays
return []llvm.Type{t}
return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)}
}
}
@ -79,7 +91,7 @@ func (b *builder) expandFormalParamOffsets(t llvm.Type) []uint64 {
func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {
switch v.Type().TypeKind() {
case llvm.StructTypeKind:
fieldTypes := flattenAggregateType(v.Type())
fieldTypes, _ := flattenAggregateType(v.Type(), nil)
if len(fieldTypes) <= MaxFieldsPerParam {
fields := b.flattenAggregate(v)
if len(fields) != len(fieldTypes) {
@ -98,17 +110,62 @@ func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {
// Try to flatten a struct type to a list of types. Returns a 1-element slice
// with the passed in type if this is not possible.
func flattenAggregateType(t llvm.Type) []llvm.Type {
func flattenAggregateType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) {
typeFlags := getTypeFlags(goType)
switch t.TypeKind() {
case llvm.StructTypeKind:
fields := make([]llvm.Type, 0, t.StructElementTypesCount())
for _, subfield := range t.StructElementTypes() {
subfields := flattenAggregateType(subfield)
fieldFlags := make([]paramFlags, 0, cap(fields))
for i, subfield := range t.StructElementTypes() {
subfields, subfieldFlags := flattenAggregateType(subfield, extractSubfield(goType, i))
for i := range subfieldFlags {
subfieldFlags[i] |= typeFlags
}
fields = append(fields, subfields...)
fieldFlags = append(fieldFlags, subfieldFlags...)
}
return fields
return fields, fieldFlags
default:
return []llvm.Type{t}
return []llvm.Type{t}, []paramFlags{typeFlags}
}
}
// getTypeFlags returns the type flags for a given type. It will not recurse
// into sub-types (such as in structs).
func getTypeFlags(t types.Type) paramFlags {
if t == nil {
return 0
}
switch t.Underlying().(type) {
case *types.Pointer:
// Pointers in Go must either point to an object or be nil.
return paramIsDeferenceableOrNull
case *types.Chan, *types.Map:
// Channels and maps are implemented as pointers pointing to some
// object, and follow the same rules as *types.Pointer.
return paramIsDeferenceableOrNull
default:
return 0
}
}
// extractSubfield extracts a field from a struct, or returns null if this is
// not a struct and thus no subfield can be obtained.
func extractSubfield(t types.Type, field int) types.Type {
if t == nil {
return nil
}
switch t := t.Underlying().(type) {
case *types.Struct:
return t.Field(field).Type()
case *types.Interface, *types.Slice, *types.Basic, *types.Signature:
// These Go types are (sometimes) implemented as LLVM structs but can't
// really be split further up in Go (with the possible exception of
// complex numbers).
return nil
default:
// This should be unreachable.
panic("cannot split subfield: " + t.String())
}
}
@ -169,7 +226,8 @@ func (b *builder) collapseFormalParam(t llvm.Type, fields []llvm.Value) llvm.Val
func (b *builder) collapseFormalParamInternal(t llvm.Type, fields []llvm.Value) (llvm.Value, []llvm.Value) {
switch t.TypeKind() {
case llvm.StructTypeKind:
if len(flattenAggregateType(t)) <= MaxFieldsPerParam {
flattened, _ := flattenAggregateType(t, nil)
if len(flattened) <= MaxFieldsPerParam {
value := llvm.ConstNull(t)
for i, subtyp := range t.StructElementTypes() {
structField, remaining := b.collapseFormalParamInternal(subtyp, fields)

Просмотреть файл

@ -750,10 +750,12 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
}
var paramTypes []llvm.Type
var paramTypeVariants []paramFlags
for _, param := range f.Params {
paramType := c.getLLVMType(param.Type())
paramTypeFragments := expandFormalParamType(paramType)
paramTypeFragments, paramTypeFragmentVariants := expandFormalParamType(paramType, param.Type())
paramTypes = append(paramTypes, paramTypeFragments...)
paramTypeVariants = append(paramTypeVariants, paramTypeFragmentVariants...)
}
// Add an extra parameter as the function context. This context is used in
@ -761,6 +763,7 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
if !f.IsExported() {
paramTypes = append(paramTypes, c.i8ptrType) // context
paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine
paramTypeVariants = append(paramTypeVariants, 0, 0)
}
fnType := llvm.FunctionType(retType, paramTypes, false)
@ -771,6 +774,23 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
f.LLVMFn = llvm.AddFunction(c.mod, name, fnType)
}
dereferenceableOrNullKind := llvm.AttributeKindID("dereferenceable_or_null")
for i, typ := range paramTypes {
if paramTypeVariants[i]&paramIsDeferenceableOrNull == 0 {
continue
}
if typ.TypeKind() == llvm.PointerTypeKind {
el := typ.ElementType()
size := c.targetData.TypeAllocSize(el)
if size == 0 {
// dereferenceable_or_null(0) appears to be illegal in LLVM.
continue
}
dereferenceableOrNull := c.ctx.CreateEnumAttribute(dereferenceableOrNullKind, size)
f.LLVMFn.AddAttributeAtIndex(i+1, dereferenceableOrNull)
}
}
// External/exported functions may not retain pointer values.
// https://golang.org/cmd/cgo/#hdr-Passing_pointers
if f.IsExported() {
@ -901,7 +921,8 @@ func (b *builder) createFunctionDefinition() {
for _, param := range b.fn.Params {
llvmType := b.getLLVMType(param.Type())
fields := make([]llvm.Value, 0, 1)
for range expandFormalParamType(llvmType) {
fieldFragments, _ := expandFormalParamType(llvmType, nil)
for range fieldFragments {
fields = append(fields, b.fn.LLVMFn.Param(llvmParamIndex))
llvmParamIndex++
}

Просмотреть файл

@ -125,11 +125,13 @@ func (c *compilerContext) getRawFuncType(typ *types.Signature) llvm.Type {
// The receiver is not an interface, but a i8* type.
recv = c.i8ptrType
}
paramTypes = append(paramTypes, expandFormalParamType(recv)...)
recvFragments, _ := expandFormalParamType(recv, nil)
paramTypes = append(paramTypes, recvFragments...)
}
for i := 0; i < typ.Params().Len(); i++ {
subType := c.getLLVMType(typ.Params().At(i).Type())
paramTypes = append(paramTypes, expandFormalParamType(subType)...)
paramTypeFragments, _ := expandFormalParamType(subType, nil)
paramTypes = append(paramTypes, paramTypeFragments...)
}
// All functions take these parameters at the end.
paramTypes = append(paramTypes, c.i8ptrType) // context

Просмотреть файл

@ -437,7 +437,7 @@ func (c *compilerContext) getInterfaceInvokeWrapper(f *ir.Function) llvm.Value {
// Get the expanded receiver type.
receiverType := c.getLLVMType(f.Params[0].Type())
expandedReceiverType := expandFormalParamType(receiverType)
expandedReceiverType, _ := expandFormalParamType(receiverType, nil)
// Does this method even need any wrapping?
if len(expandedReceiverType) == 1 && receiverType.TypeKind() == llvm.PointerTypeKind {