compiler: add dereferenceable_or_null attribute where possible
This gives a hint to the compiler that such parameters are either NULL or point to a valid object that can be dereferenced. This is not directly very useful, but is very useful when combined with https://reviews.llvm.org/D60047 to remove the runtime.isnil hack without regressing escape analysis.
Этот коммит содержится в:
		
							родитель
							
								
									980068543a
								
							
						
					
					
						коммит
						85854cd58b
					
				
					 4 изменённых файлов: 98 добавлений и 17 удалений
				
			
		| 
						 | 
					@ -1,6 +1,8 @@
 | 
				
			||||||
package compiler
 | 
					package compiler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"go/types"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"tinygo.org/x/go-llvm"
 | 
						"tinygo.org/x/go-llvm"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -11,6 +13,16 @@ import (
 | 
				
			||||||
// a struct contains more fields, it is passed as a struct without expanding.
 | 
					// a struct contains more fields, it is passed as a struct without expanding.
 | 
				
			||||||
const MaxFieldsPerParam = 3
 | 
					const MaxFieldsPerParam = 3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// paramFlags identifies parameter attributes for flags. Most importantly, it
 | 
				
			||||||
 | 
					// determines which parameters are dereferenceable_or_null and which aren't.
 | 
				
			||||||
 | 
					type paramFlags uint8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						// Parameter may have the deferenceable_or_null attribute. This attribute
 | 
				
			||||||
 | 
						// cannot be applied to unsafe.Pointer and to the data pointer of slices.
 | 
				
			||||||
 | 
						paramIsDeferenceableOrNull = 1 << iota
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// createCall creates a new call to runtime.<fnName> with the given arguments.
 | 
					// createCall creates a new call to runtime.<fnName> with the given arguments.
 | 
				
			||||||
func (b *builder) createRuntimeCall(fnName string, args []llvm.Value, name string) llvm.Value {
 | 
					func (b *builder) createRuntimeCall(fnName string, args []llvm.Value, name string) llvm.Value {
 | 
				
			||||||
	fullName := "runtime." + fnName
 | 
						fullName := "runtime." + fnName
 | 
				
			||||||
| 
						 | 
					@ -36,19 +48,19 @@ func (b *builder) createCall(fn llvm.Value, args []llvm.Value, name string) llvm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Expand an argument type to a list that can be used in a function call
 | 
					// Expand an argument type to a list that can be used in a function call
 | 
				
			||||||
// parameter list.
 | 
					// parameter list.
 | 
				
			||||||
func expandFormalParamType(t llvm.Type) []llvm.Type {
 | 
					func expandFormalParamType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) {
 | 
				
			||||||
	switch t.TypeKind() {
 | 
						switch t.TypeKind() {
 | 
				
			||||||
	case llvm.StructTypeKind:
 | 
						case llvm.StructTypeKind:
 | 
				
			||||||
		fields := flattenAggregateType(t)
 | 
							fields, fieldFlags := flattenAggregateType(t, goType)
 | 
				
			||||||
		if len(fields) <= MaxFieldsPerParam {
 | 
							if len(fields) <= MaxFieldsPerParam {
 | 
				
			||||||
			return fields
 | 
								return fields, fieldFlags
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			// failed to lower
 | 
								// failed to lower
 | 
				
			||||||
			return []llvm.Type{t}
 | 
								return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		// TODO: split small arrays
 | 
							// TODO: split small arrays
 | 
				
			||||||
		return []llvm.Type{t}
 | 
							return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -79,7 +91,7 @@ func (b *builder) expandFormalParamOffsets(t llvm.Type) []uint64 {
 | 
				
			||||||
func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {
 | 
					func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {
 | 
				
			||||||
	switch v.Type().TypeKind() {
 | 
						switch v.Type().TypeKind() {
 | 
				
			||||||
	case llvm.StructTypeKind:
 | 
						case llvm.StructTypeKind:
 | 
				
			||||||
		fieldTypes := flattenAggregateType(v.Type())
 | 
							fieldTypes, _ := flattenAggregateType(v.Type(), nil)
 | 
				
			||||||
		if len(fieldTypes) <= MaxFieldsPerParam {
 | 
							if len(fieldTypes) <= MaxFieldsPerParam {
 | 
				
			||||||
			fields := b.flattenAggregate(v)
 | 
								fields := b.flattenAggregate(v)
 | 
				
			||||||
			if len(fields) != len(fieldTypes) {
 | 
								if len(fields) != len(fieldTypes) {
 | 
				
			||||||
| 
						 | 
					@ -98,17 +110,62 @@ func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Try to flatten a struct type to a list of types. Returns a 1-element slice
 | 
					// Try to flatten a struct type to a list of types. Returns a 1-element slice
 | 
				
			||||||
// with the passed in type if this is not possible.
 | 
					// with the passed in type if this is not possible.
 | 
				
			||||||
func flattenAggregateType(t llvm.Type) []llvm.Type {
 | 
					func flattenAggregateType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) {
 | 
				
			||||||
 | 
						typeFlags := getTypeFlags(goType)
 | 
				
			||||||
	switch t.TypeKind() {
 | 
						switch t.TypeKind() {
 | 
				
			||||||
	case llvm.StructTypeKind:
 | 
						case llvm.StructTypeKind:
 | 
				
			||||||
		fields := make([]llvm.Type, 0, t.StructElementTypesCount())
 | 
							fields := make([]llvm.Type, 0, t.StructElementTypesCount())
 | 
				
			||||||
		for _, subfield := range t.StructElementTypes() {
 | 
							fieldFlags := make([]paramFlags, 0, cap(fields))
 | 
				
			||||||
			subfields := flattenAggregateType(subfield)
 | 
							for i, subfield := range t.StructElementTypes() {
 | 
				
			||||||
			fields = append(fields, subfields...)
 | 
								subfields, subfieldFlags := flattenAggregateType(subfield, extractSubfield(goType, i))
 | 
				
			||||||
 | 
								for i := range subfieldFlags {
 | 
				
			||||||
 | 
									subfieldFlags[i] |= typeFlags
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		return fields
 | 
								fields = append(fields, subfields...)
 | 
				
			||||||
 | 
								fieldFlags = append(fieldFlags, subfieldFlags...)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return fields, fieldFlags
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return []llvm.Type{t}
 | 
							return []llvm.Type{t}, []paramFlags{typeFlags}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// getTypeFlags returns the type flags for a given type. It will not recurse
 | 
				
			||||||
 | 
					// into sub-types (such as in structs).
 | 
				
			||||||
 | 
					func getTypeFlags(t types.Type) paramFlags {
 | 
				
			||||||
 | 
						if t == nil {
 | 
				
			||||||
 | 
							return 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						switch t.Underlying().(type) {
 | 
				
			||||||
 | 
						case *types.Pointer:
 | 
				
			||||||
 | 
							// Pointers in Go must either point to an object or be nil.
 | 
				
			||||||
 | 
							return paramIsDeferenceableOrNull
 | 
				
			||||||
 | 
						case *types.Chan, *types.Map:
 | 
				
			||||||
 | 
							// Channels and maps are implemented as pointers pointing to some
 | 
				
			||||||
 | 
							// object, and follow the same rules as *types.Pointer.
 | 
				
			||||||
 | 
							return paramIsDeferenceableOrNull
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// extractSubfield extracts a field from a struct, or returns null if this is
 | 
				
			||||||
 | 
					// not a struct and thus no subfield can be obtained.
 | 
				
			||||||
 | 
					func extractSubfield(t types.Type, field int) types.Type {
 | 
				
			||||||
 | 
						if t == nil {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						switch t := t.Underlying().(type) {
 | 
				
			||||||
 | 
						case *types.Struct:
 | 
				
			||||||
 | 
							return t.Field(field).Type()
 | 
				
			||||||
 | 
						case *types.Interface, *types.Slice, *types.Basic, *types.Signature:
 | 
				
			||||||
 | 
							// These Go types are (sometimes) implemented as LLVM structs but can't
 | 
				
			||||||
 | 
							// really be split further up in Go (with the possible exception of
 | 
				
			||||||
 | 
							// complex numbers).
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							// This should be unreachable.
 | 
				
			||||||
 | 
							panic("cannot split subfield: " + t.String())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -169,7 +226,8 @@ func (b *builder) collapseFormalParam(t llvm.Type, fields []llvm.Value) llvm.Val
 | 
				
			||||||
func (b *builder) collapseFormalParamInternal(t llvm.Type, fields []llvm.Value) (llvm.Value, []llvm.Value) {
 | 
					func (b *builder) collapseFormalParamInternal(t llvm.Type, fields []llvm.Value) (llvm.Value, []llvm.Value) {
 | 
				
			||||||
	switch t.TypeKind() {
 | 
						switch t.TypeKind() {
 | 
				
			||||||
	case llvm.StructTypeKind:
 | 
						case llvm.StructTypeKind:
 | 
				
			||||||
		if len(flattenAggregateType(t)) <= MaxFieldsPerParam {
 | 
							flattened, _ := flattenAggregateType(t, nil)
 | 
				
			||||||
 | 
							if len(flattened) <= MaxFieldsPerParam {
 | 
				
			||||||
			value := llvm.ConstNull(t)
 | 
								value := llvm.ConstNull(t)
 | 
				
			||||||
			for i, subtyp := range t.StructElementTypes() {
 | 
								for i, subtyp := range t.StructElementTypes() {
 | 
				
			||||||
				structField, remaining := b.collapseFormalParamInternal(subtyp, fields)
 | 
									structField, remaining := b.collapseFormalParamInternal(subtyp, fields)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -750,10 +750,12 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var paramTypes []llvm.Type
 | 
						var paramTypes []llvm.Type
 | 
				
			||||||
 | 
						var paramTypeVariants []paramFlags
 | 
				
			||||||
	for _, param := range f.Params {
 | 
						for _, param := range f.Params {
 | 
				
			||||||
		paramType := c.getLLVMType(param.Type())
 | 
							paramType := c.getLLVMType(param.Type())
 | 
				
			||||||
		paramTypeFragments := expandFormalParamType(paramType)
 | 
							paramTypeFragments, paramTypeFragmentVariants := expandFormalParamType(paramType, param.Type())
 | 
				
			||||||
		paramTypes = append(paramTypes, paramTypeFragments...)
 | 
							paramTypes = append(paramTypes, paramTypeFragments...)
 | 
				
			||||||
 | 
							paramTypeVariants = append(paramTypeVariants, paramTypeFragmentVariants...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Add an extra parameter as the function context. This context is used in
 | 
						// Add an extra parameter as the function context. This context is used in
 | 
				
			||||||
| 
						 | 
					@ -761,6 +763,7 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
 | 
				
			||||||
	if !f.IsExported() {
 | 
						if !f.IsExported() {
 | 
				
			||||||
		paramTypes = append(paramTypes, c.i8ptrType) // context
 | 
							paramTypes = append(paramTypes, c.i8ptrType) // context
 | 
				
			||||||
		paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine
 | 
							paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine
 | 
				
			||||||
 | 
							paramTypeVariants = append(paramTypeVariants, 0, 0)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fnType := llvm.FunctionType(retType, paramTypes, false)
 | 
						fnType := llvm.FunctionType(retType, paramTypes, false)
 | 
				
			||||||
| 
						 | 
					@ -771,6 +774,23 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
 | 
				
			||||||
		f.LLVMFn = llvm.AddFunction(c.mod, name, fnType)
 | 
							f.LLVMFn = llvm.AddFunction(c.mod, name, fnType)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dereferenceableOrNullKind := llvm.AttributeKindID("dereferenceable_or_null")
 | 
				
			||||||
 | 
						for i, typ := range paramTypes {
 | 
				
			||||||
 | 
							if paramTypeVariants[i]¶mIsDeferenceableOrNull == 0 {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if typ.TypeKind() == llvm.PointerTypeKind {
 | 
				
			||||||
 | 
								el := typ.ElementType()
 | 
				
			||||||
 | 
								size := c.targetData.TypeAllocSize(el)
 | 
				
			||||||
 | 
								if size == 0 {
 | 
				
			||||||
 | 
									// dereferenceable_or_null(0) appears to be illegal in LLVM.
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								dereferenceableOrNull := c.ctx.CreateEnumAttribute(dereferenceableOrNullKind, size)
 | 
				
			||||||
 | 
								f.LLVMFn.AddAttributeAtIndex(i+1, dereferenceableOrNull)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// External/exported functions may not retain pointer values.
 | 
						// External/exported functions may not retain pointer values.
 | 
				
			||||||
	// https://golang.org/cmd/cgo/#hdr-Passing_pointers
 | 
						// https://golang.org/cmd/cgo/#hdr-Passing_pointers
 | 
				
			||||||
	if f.IsExported() {
 | 
						if f.IsExported() {
 | 
				
			||||||
| 
						 | 
					@ -901,7 +921,8 @@ func (b *builder) createFunctionDefinition() {
 | 
				
			||||||
	for _, param := range b.fn.Params {
 | 
						for _, param := range b.fn.Params {
 | 
				
			||||||
		llvmType := b.getLLVMType(param.Type())
 | 
							llvmType := b.getLLVMType(param.Type())
 | 
				
			||||||
		fields := make([]llvm.Value, 0, 1)
 | 
							fields := make([]llvm.Value, 0, 1)
 | 
				
			||||||
		for range expandFormalParamType(llvmType) {
 | 
							fieldFragments, _ := expandFormalParamType(llvmType, nil)
 | 
				
			||||||
 | 
							for range fieldFragments {
 | 
				
			||||||
			fields = append(fields, b.fn.LLVMFn.Param(llvmParamIndex))
 | 
								fields = append(fields, b.fn.LLVMFn.Param(llvmParamIndex))
 | 
				
			||||||
			llvmParamIndex++
 | 
								llvmParamIndex++
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -125,11 +125,13 @@ func (c *compilerContext) getRawFuncType(typ *types.Signature) llvm.Type {
 | 
				
			||||||
			// The receiver is not an interface, but a i8* type.
 | 
								// The receiver is not an interface, but a i8* type.
 | 
				
			||||||
			recv = c.i8ptrType
 | 
								recv = c.i8ptrType
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		paramTypes = append(paramTypes, expandFormalParamType(recv)...)
 | 
							recvFragments, _ := expandFormalParamType(recv, nil)
 | 
				
			||||||
 | 
							paramTypes = append(paramTypes, recvFragments...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for i := 0; i < typ.Params().Len(); i++ {
 | 
						for i := 0; i < typ.Params().Len(); i++ {
 | 
				
			||||||
		subType := c.getLLVMType(typ.Params().At(i).Type())
 | 
							subType := c.getLLVMType(typ.Params().At(i).Type())
 | 
				
			||||||
		paramTypes = append(paramTypes, expandFormalParamType(subType)...)
 | 
							paramTypeFragments, _ := expandFormalParamType(subType, nil)
 | 
				
			||||||
 | 
							paramTypes = append(paramTypes, paramTypeFragments...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// All functions take these parameters at the end.
 | 
						// All functions take these parameters at the end.
 | 
				
			||||||
	paramTypes = append(paramTypes, c.i8ptrType) // context
 | 
						paramTypes = append(paramTypes, c.i8ptrType) // context
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -437,7 +437,7 @@ func (c *compilerContext) getInterfaceInvokeWrapper(f *ir.Function) llvm.Value {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get the expanded receiver type.
 | 
						// Get the expanded receiver type.
 | 
				
			||||||
	receiverType := c.getLLVMType(f.Params[0].Type())
 | 
						receiverType := c.getLLVMType(f.Params[0].Type())
 | 
				
			||||||
	expandedReceiverType := expandFormalParamType(receiverType)
 | 
						expandedReceiverType, _ := expandFormalParamType(receiverType, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Does this method even need any wrapping?
 | 
						// Does this method even need any wrapping?
 | 
				
			||||||
	if len(expandedReceiverType) == 1 && receiverType.TypeKind() == llvm.PointerTypeKind {
 | 
						if len(expandedReceiverType) == 1 && receiverType.TypeKind() == llvm.PointerTypeKind {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Загрузка…
	
	Создание таблицы
		
		Сослаться в новой задаче