From 7cea40bcb5ebff1ee87a489316280ee9a2a46bca Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Sun, 23 Sep 2018 03:01:10 +0200 Subject: [PATCH] compiler: small cleanup in call handling code --- compiler/calls.go | 39 +++++++++++++++++---------------------- compiler/compiler.go | 15 ++++++++------- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/compiler/calls.go b/compiler/calls.go index dcfc4375..c64f83e1 100644 --- a/compiler/calls.go +++ b/compiler/calls.go @@ -1,10 +1,7 @@ package compiler import ( - "go/types" - "github.com/aykevl/llvm/bindings/go/llvm" - "github.com/aykevl/tinygo/ir" "golang.org/x/tools/go/ssa" ) @@ -21,13 +18,16 @@ import ( // Note that all native Go data types that don't exist in LLVM (string, // slice, interface, fat function pointer) can be expanded this way, making // the work of LLVM optimizers easier. -// * Closures have an extra paramter appended at the end of the argument list, -// which is a pointer to a struct containing free variables. +// * Closures have an extra context paramter appended at the end of the +// argument list. // * Blocking functions have a coroutine pointer prepended to the argument // list, see src/runtime/scheduler.go for details. +// The maximum number of arguments that can be expanded from a single struct. If +// a struct contains more fields, it is passed as value. const MaxFieldsPerParam = 3 +// Shortcut: create a call to runtime. with the given arguments. func (c *Compiler) createRuntimeCall(fnName string, args []llvm.Value, name string) llvm.Value { runtimePkg := c.ir.Program.ImportedPackage("runtime") member := runtimePkg.Members[fnName] @@ -35,14 +35,11 @@ func (c *Compiler) createRuntimeCall(fnName string, args []llvm.Value, name stri panic("trying to call runtime." + fnName) } fn := c.ir.GetFunction(member.(*ssa.Function)) - return c.createCall(fn, args, name) + return c.createCall(fn.LLVMFn, args, name) } -func (c *Compiler) createCall(fn *ir.Function, args []llvm.Value, name string) llvm.Value { - return c.createIndirectCall(fn.Signature, fn.LLVMFn, args, name) -} - -func (c *Compiler) createIndirectCall(sig *types.Signature, fn llvm.Value, args []llvm.Value, name string) llvm.Value { +// Create a call to the given function with the arguments possibly expanded. +func (c *Compiler) createCall(fn llvm.Value, args []llvm.Value, name string) llvm.Value { expanded := make([]llvm.Value, 0, len(args)) for _, arg := range args { fragments := c.expandFormalParam(arg) @@ -51,14 +48,8 @@ func (c *Compiler) createIndirectCall(sig *types.Signature, fn llvm.Value, args return c.builder.CreateCall(fn, expanded, name) } -func (c *Compiler) getLLVMParamTypes(t types.Type) ([]llvm.Type, error) { - llvmType, err := c.getLLVMType(t) - if err != nil { - return nil, err - } - return c.expandFormalParamType(llvmType), nil -} - +// Expand an argument type to a list that can be used in a function call +// paramter list. func (c *Compiler) expandFormalParamType(t llvm.Type) []llvm.Type { switch t.TypeKind() { case llvm.StructTypeKind: @@ -75,7 +66,7 @@ func (c *Compiler) expandFormalParamType(t llvm.Type) []llvm.Type { } } -// Convert an argument to one that can be passed in a parameter. +// Equivalent of expandFormalParamType for parameter values. func (c *Compiler) expandFormalParam(v llvm.Value) []llvm.Value { switch v.Type().TypeKind() { case llvm.StructTypeKind: @@ -96,6 +87,8 @@ func (c *Compiler) expandFormalParam(v llvm.Value) []llvm.Value { } } +// Try to flatten a struct type to a list of types. Returns a 1-element slice +// with the passed in type if this is not possible. func (c *Compiler) flattenAggregateType(t llvm.Type) []llvm.Type { switch t.TypeKind() { case llvm.StructTypeKind: @@ -110,7 +103,8 @@ func (c *Compiler) flattenAggregateType(t llvm.Type) []llvm.Type { } } -// Break down a struct into its elementary types for argument passing. +// Break down a struct into its elementary types for argument passing. The value +// equivalent of flattenAggregateType func (c *Compiler) flattenAggregate(v llvm.Value) []llvm.Value { switch v.Type().TypeKind() { case llvm.StructTypeKind: @@ -126,6 +120,7 @@ func (c *Compiler) flattenAggregate(v llvm.Value) []llvm.Value { } } +// Collapse a list of fields into its original value. func (c *Compiler) collapseFormalParam(t llvm.Type, fields []llvm.Value) llvm.Value { param, remaining := c.collapseFormalParamInternal(t, fields) if len(remaining) != 0 { @@ -134,7 +129,7 @@ func (c *Compiler) collapseFormalParam(t llvm.Type, fields []llvm.Value) llvm.Va return param } -// Returns (value, remainingFields). +// Returns (value, remainingFields). Used by collapseFormalParam. func (c *Compiler) collapseFormalParamInternal(t llvm.Type, fields []llvm.Value) (llvm.Value, []llvm.Value) { switch t.TypeKind() { case llvm.StructTypeKind: diff --git a/compiler/compiler.go b/compiler/compiler.go index 48b0f874..efe3b2fb 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -378,7 +378,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { } // Call real function (of which this is a wrapper). - c.createCall(fn, forwardParams, "") + c.createCall(fn.LLVMFn, forwardParams, "") c.builder.CreateRetVoid() } @@ -813,10 +813,11 @@ func (c *Compiler) parseFuncDecl(f *ir.Function) (*Frame, error) { paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine } for _, param := range f.Params { - paramTypeFragments, err := c.getLLVMParamTypes(param.Type()) + paramType, err := c.getLLVMType(param.Type()) if err != nil { return nil, err } + paramTypeFragments := c.expandFormalParamType(paramType) paramTypes = append(paramTypes, paramTypeFragments...) } @@ -1690,7 +1691,7 @@ func (c *Compiler) parseBuiltin(frame *Frame, args []ssa.Value, callName string) } } -func (c *Compiler) parseFunctionCall(frame *Frame, args []ssa.Value, fnType *types.Signature, llvmFn, context llvm.Value, blocking bool, parentHandle llvm.Value) (llvm.Value, error) { +func (c *Compiler) parseFunctionCall(frame *Frame, args []ssa.Value, llvmFn, context llvm.Value, blocking bool, parentHandle llvm.Value) (llvm.Value, error) { var params []llvm.Value if blocking { if parentHandle.IsNil() { @@ -1733,7 +1734,7 @@ func (c *Compiler) parseFunctionCall(frame *Frame, args []ssa.Value, fnType *typ return llvm.Value{}, nil } - result := c.createIndirectCall(fnType, llvmFn, params, "") + result := c.createCall(llvmFn, params, "") if blocking && !parentHandle.IsNil() { // Calling a blocking function as a regular function call. // This is done by passing the current coroutine as a parameter to the @@ -1807,7 +1808,7 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l } // TODO: blocking methods (needs analysis) - return c.createIndirectCall(instr.Method.Type().(*types.Signature), fnCast, args, ""), nil + return c.createCall(fnCast, args, ""), nil } // Try to call the function directly for trivially static calls. @@ -1845,7 +1846,7 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l } } } - return c.parseFunctionCall(frame, instr.Args, targetFunc.Signature, targetFunc.LLVMFn, context, c.ir.IsBlocking(targetFunc), parentHandle) + return c.parseFunctionCall(frame, instr.Args, targetFunc.LLVMFn, context, c.ir.IsBlocking(targetFunc), parentHandle) } // Builtin or function pointer. @@ -1866,7 +1867,7 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l context = c.builder.CreateExtractValue(value, 0, "") value = c.builder.CreateExtractValue(value, 1, "") } - return c.parseFunctionCall(frame, instr.Args, instr.Value.Type().(*types.Signature), value, context, false, parentHandle) + return c.parseFunctionCall(frame, instr.Args, value, context, false, parentHandle) } }