From 85854cd58b09ee026cecaa590a23013d43ba74e5 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Thu, 19 Mar 2020 20:15:03 +0100 Subject: [PATCH] compiler: add dereferenceable_or_null attribute where possible This gives a hint to the compiler that such parameters are either NULL or point to a valid object that can be dereferenced. This is not directly very useful, but is very useful when combined with https://reviews.llvm.org/D60047 to remove the runtime.isnil hack without regressing escape analysis. --- compiler/calls.go | 82 ++++++++++++++++++++++++++++++++++++------- compiler/compiler.go | 25 +++++++++++-- compiler/func.go | 6 ++-- compiler/interface.go | 2 +- 4 files changed, 98 insertions(+), 17 deletions(-) diff --git a/compiler/calls.go b/compiler/calls.go index 1de28b24..e5e1f44c 100644 --- a/compiler/calls.go +++ b/compiler/calls.go @@ -1,6 +1,8 @@ package compiler import ( + "go/types" + "tinygo.org/x/go-llvm" ) @@ -11,6 +13,16 @@ import ( // a struct contains more fields, it is passed as a struct without expanding. const MaxFieldsPerParam = 3 +// paramFlags identifies parameter attributes for flags. Most importantly, it +// determines which parameters are dereferenceable_or_null and which aren't. +type paramFlags uint8 + +const ( + // Parameter may have the dereferenceable_or_null attribute. This attribute + // cannot be applied to unsafe.Pointer and to the data pointer of slices. + paramIsDeferenceableOrNull = 1 << iota +) + // createCall creates a new call to runtime. with the given arguments. func (b *builder) createRuntimeCall(fnName string, args []llvm.Value, name string) llvm.Value { fullName := "runtime." + fnName @@ -36,19 +48,19 @@ func (b *builder) createCall(fn llvm.Value, args []llvm.Value, name string) llvm // Expand an argument type to a list that can be used in a function call // parameter list. 
-func expandFormalParamType(t llvm.Type) []llvm.Type { +func expandFormalParamType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) { switch t.TypeKind() { case llvm.StructTypeKind: - fields := flattenAggregateType(t) + fields, fieldFlags := flattenAggregateType(t, goType) if len(fields) <= MaxFieldsPerParam { - return fields + return fields, fieldFlags } else { // failed to lower - return []llvm.Type{t} + return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)} } default: // TODO: split small arrays - return []llvm.Type{t} + return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)} } } @@ -79,7 +91,7 @@ func (b *builder) expandFormalParamOffsets(t llvm.Type) []uint64 { func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value { switch v.Type().TypeKind() { case llvm.StructTypeKind: - fieldTypes := flattenAggregateType(v.Type()) + fieldTypes, _ := flattenAggregateType(v.Type(), nil) if len(fieldTypes) <= MaxFieldsPerParam { fields := b.flattenAggregate(v) if len(fields) != len(fieldTypes) { @@ -98,17 +110,62 @@ func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value { // Try to flatten a struct type to a list of types. Returns a 1-element slice // with the passed in type if this is not possible. -func flattenAggregateType(t llvm.Type) []llvm.Type { +func flattenAggregateType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) { + typeFlags := getTypeFlags(goType) switch t.TypeKind() { case llvm.StructTypeKind: fields := make([]llvm.Type, 0, t.StructElementTypesCount()) - for _, subfield := range t.StructElementTypes() { - subfields := flattenAggregateType(subfield) + fieldFlags := make([]paramFlags, 0, cap(fields)) + for i, subfield := range t.StructElementTypes() { + subfields, subfieldFlags := flattenAggregateType(subfield, extractSubfield(goType, i)) + for i := range subfieldFlags { + subfieldFlags[i] |= typeFlags + } fields = append(fields, subfields...) + fieldFlags = append(fieldFlags, subfieldFlags...) 
+ } - return fields + return fields, fieldFlags default: - return []llvm.Type{t} + return []llvm.Type{t}, []paramFlags{typeFlags} + } +} + +// getTypeFlags returns the type flags for a given type. It will not recurse +// into sub-types (such as in structs). +func getTypeFlags(t types.Type) paramFlags { + if t == nil { + return 0 + } + switch t.Underlying().(type) { + case *types.Pointer: + // Pointers in Go must either point to an object or be nil. + return paramIsDeferenceableOrNull + case *types.Chan, *types.Map: + // Channels and maps are implemented as pointers pointing to some + // object, and follow the same rules as *types.Pointer. + return paramIsDeferenceableOrNull + default: + return 0 + } +} + +// extractSubfield extracts a field from a struct, or returns nil if this is +// not a struct and thus no subfield can be obtained. +func extractSubfield(t types.Type, field int) types.Type { + if t == nil { + return nil + } + switch t := t.Underlying().(type) { + case *types.Struct: + return t.Field(field).Type() + case *types.Interface, *types.Slice, *types.Basic, *types.Signature: + // These Go types are (sometimes) implemented as LLVM structs but can't + // really be split further up in Go (with the possible exception of + // complex numbers). + return nil + default: + // This should be unreachable. 
+ panic("cannot split subfield: " + t.String()) } } @@ -169,7 +226,8 @@ func (b *builder) collapseFormalParam(t llvm.Type, fields []llvm.Value) llvm.Val func (b *builder) collapseFormalParamInternal(t llvm.Type, fields []llvm.Value) (llvm.Value, []llvm.Value) { switch t.TypeKind() { case llvm.StructTypeKind: - if len(flattenAggregateType(t)) <= MaxFieldsPerParam { + flattened, _ := flattenAggregateType(t, nil) + if len(flattened) <= MaxFieldsPerParam { value := llvm.ConstNull(t) for i, subtyp := range t.StructElementTypes() { structField, remaining := b.collapseFormalParamInternal(subtyp, fields) diff --git a/compiler/compiler.go b/compiler/compiler.go index 966b0f62..6a17da79 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -750,10 +750,12 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) { } var paramTypes []llvm.Type + var paramTypeVariants []paramFlags for _, param := range f.Params { paramType := c.getLLVMType(param.Type()) - paramTypeFragments := expandFormalParamType(paramType) + paramTypeFragments, paramTypeFragmentVariants := expandFormalParamType(paramType, param.Type()) paramTypes = append(paramTypes, paramTypeFragments...) + paramTypeVariants = append(paramTypeVariants, paramTypeFragmentVariants...) } // Add an extra parameter as the function context. 
This context is used in @@ -761,6 +763,7 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) { if !f.IsExported() { paramTypes = append(paramTypes, c.i8ptrType) // context paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine + paramTypeVariants = append(paramTypeVariants, 0, 0) } fnType := llvm.FunctionType(retType, paramTypes, false) @@ -771,6 +774,23 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) { f.LLVMFn = llvm.AddFunction(c.mod, name, fnType) } + dereferenceableOrNullKind := llvm.AttributeKindID("dereferenceable_or_null") + for i, typ := range paramTypes { + if paramTypeVariants[i]&paramIsDeferenceableOrNull == 0 { + continue + } + if typ.TypeKind() == llvm.PointerTypeKind { + el := typ.ElementType() + size := c.targetData.TypeAllocSize(el) + if size == 0 { + // dereferenceable_or_null(0) appears to be illegal in LLVM. + continue + } + dereferenceableOrNull := c.ctx.CreateEnumAttribute(dereferenceableOrNullKind, size) + f.LLVMFn.AddAttributeAtIndex(i+1, dereferenceableOrNull) + } + } + // External/exported functions may not retain pointer values. // https://golang.org/cmd/cgo/#hdr-Passing_pointers if f.IsExported() { @@ -901,7 +921,8 @@ func (b *builder) createFunctionDefinition() { for _, param := range b.fn.Params { llvmType := b.getLLVMType(param.Type()) fields := make([]llvm.Value, 0, 1) - for range expandFormalParamType(llvmType) { + fieldFragments, _ := expandFormalParamType(llvmType, nil) + for range fieldFragments { fields = append(fields, b.fn.LLVMFn.Param(llvmParamIndex)) llvmParamIndex++ } diff --git a/compiler/func.go b/compiler/func.go index 544f3e5f..2d14d47a 100644 --- a/compiler/func.go +++ b/compiler/func.go @@ -125,11 +125,13 @@ func (c *compilerContext) getRawFuncType(typ *types.Signature) llvm.Type { // The receiver is not an interface, but a i8* type. recv = c.i8ptrType } - paramTypes = append(paramTypes, expandFormalParamType(recv)...) 
+ recvFragments, _ := expandFormalParamType(recv, nil) + paramTypes = append(paramTypes, recvFragments...) } for i := 0; i < typ.Params().Len(); i++ { subType := c.getLLVMType(typ.Params().At(i).Type()) - paramTypes = append(paramTypes, expandFormalParamType(subType)...) + paramTypeFragments, _ := expandFormalParamType(subType, nil) + paramTypes = append(paramTypes, paramTypeFragments...) } // All functions take these parameters at the end. paramTypes = append(paramTypes, c.i8ptrType) // context diff --git a/compiler/interface.go b/compiler/interface.go index c6a8a61a..4613e2e3 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -437,7 +437,7 @@ func (c *compilerContext) getInterfaceInvokeWrapper(f *ir.Function) llvm.Value { // Get the expanded receiver type. receiverType := c.getLLVMType(f.Params[0].Type()) - expandedReceiverType := expandFormalParamType(receiverType) + expandedReceiverType, _ := expandFormalParamType(receiverType, nil) // Does this method even need any wrapping? if len(expandedReceiverType) == 1 && receiverType.TypeKind() == llvm.PointerTypeKind {