diff --git a/compiler/compiler.go b/compiler/compiler.go index 719861c6..8b23e971 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -473,58 +473,8 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { return llvm.Type{}, err } return llvm.PointerType(ptrTo, 0), nil - case *types.Signature: // function pointer - // return value - var err error - var returnType llvm.Type - if typ.Results().Len() == 0 { - returnType = c.ctx.VoidType() - } else if typ.Results().Len() == 1 { - returnType, err = c.getLLVMType(typ.Results().At(0).Type()) - if err != nil { - return llvm.Type{}, err - } - } else { - // Multiple return values. Put them together in a struct. - members := make([]llvm.Type, typ.Results().Len()) - for i := 0; i < typ.Results().Len(); i++ { - returnType, err := c.getLLVMType(typ.Results().At(i).Type()) - if err != nil { - return llvm.Type{}, err - } - members[i] = returnType - } - returnType = c.ctx.StructType(members, false) - } - // param values - var paramTypes []llvm.Type - if typ.Recv() != nil { - recv, err := c.getLLVMType(typ.Recv().Type()) - if err != nil { - return llvm.Type{}, err - } - if recv.StructName() == "runtime._interface" { - // This is a call on an interface, not a concrete type. - // The receiver is not an interface, but a i8* type. - recv = c.i8ptrType - } - paramTypes = append(paramTypes, c.expandFormalParamType(recv)...) - } - params := typ.Params() - for i := 0; i < params.Len(); i++ { - subType, err := c.getLLVMType(params.At(i).Type()) - if err != nil { - return llvm.Type{}, err - } - paramTypes = append(paramTypes, c.expandFormalParamType(subType)...) - } - // make a closure type (with a function pointer type inside): - // {context, funcptr} - paramTypes = append(paramTypes, c.i8ptrType) // context - paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine - ptr := llvm.PointerType(llvm.FunctionType(returnType, paramTypes, false), c.funcPtrAddrSpace) - ptr = c.ctx.StructType([]llvm.Type{c.i8ptrType, ptr}, false) - return ptr, nil + case *types.Signature: // function value + return c.getFuncType(typ) case *types.Slice: elemType, err := c.getLLVMType(typ.Elem()) if err != nil { @@ -1359,20 +1309,20 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon) (llvm.Value, e return llvm.Value{}, c.makeError(instr.Pos(), "undefined function: "+targetFunc.LinkName()) } var context llvm.Value - // This function call is to a (potential) closure, not a regular - // function. See whether it is a closure and if so, call it as such. - // Else, supply a dummy nil pointer as the last parameter. - if targetFunc.IsExported() { - // don't pass a context parameter - } else if mkClosure, ok := instr.Value.(*ssa.MakeClosure); ok { - // closure is {context, function pointer} - closure, err := c.parseExpr(frame, mkClosure) + switch value := instr.Value.(type) { + case *ssa.Function: + // Regular function call. No context is necessary. + context = llvm.Undef(c.i8ptrType) + case *ssa.MakeClosure: + // A call on a func value, but the callee is trivial to find. For + // example: immediately applied functions. + funcValue, err := c.parseExpr(frame, value) if err != nil { return llvm.Value{}, err } - context = c.builder.CreateExtractValue(closure, 0, "") - } else { - context = llvm.Undef(c.i8ptrType) + context = c.extractFuncContext(funcValue) + default: + panic("StaticCallee returned an unexpected value") } return c.parseFunctionCall(frame, instr.Args, targetFunc.LLVMFn, context, targetFunc.IsExported()) } @@ -1386,13 +1336,11 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon) (llvm.Value, e if err != nil { return llvm.Value{}, err } - // 'value' is a closure, not a raw function pointer. - // Extract the function pointer and the context pointer. - // closure: {context, function pointer} - context := c.builder.CreateExtractValue(value, 0, "") - value = c.builder.CreateExtractValue(value, 1, "") - c.emitNilCheck(frame, value, "fpcall") - return c.parseFunctionCall(frame, instr.Args, value, context, false) + // This is a func value, which cannot be called directly. We have to + // extract the function pointer and context first from the func value. + funcPtr, context := c.decodeFuncValue(value) + c.emitNilCheck(frame, funcPtr, "fpcall") + return c.parseFunctionCall(frame, instr.Args, funcPtr, context, false) } } @@ -1560,12 +1508,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if fn.IsExported() { return llvm.Value{}, c.makeError(expr.Pos(), "cannot use an exported function as value") } - // Create closure for function pointer. - // Closure is: {context, function pointer} - return c.ctx.ConstStruct([]llvm.Value{ - llvm.Undef(c.i8ptrType), - fn.LLVMFn, - }, false), nil + return c.createFuncValue(fn.LLVMFn) case *ssa.Global: value := c.ir.GetGlobal(expr).LLVMGlobal if value.IsNil() { @@ -1681,8 +1624,6 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { case *ssa.MakeChan: return c.emitMakeChan(expr) case *ssa.MakeClosure: - // A closure returns a function pointer with context: - // {context, fp} return c.parseMakeClosure(frame, expr) case *ssa.MakeInterface: val, err := c.parseExpr(frame, expr.X) @@ -2194,12 +2135,13 @@ func (c *Compiler) parseBinOp(op token.Token, typ types.Type, x, y llvm.Value, p return llvm.Value{}, c.makeError(pos, "todo: unknown basic type in binop: "+typ.String()) } case *types.Signature: - // Extract function pointers from the function values (closures). + // Get raw scalars from the function value and compare those. + // Function values may be implemented in multiple ways, but they all + // have some way of getting a scalar value identifying the function. // This is safe: function pointers are generally not comparable - // against each other, only against nil. So one or both has to be - // nil, so we can ignore the closure context. - x = c.builder.CreateExtractValue(x, 1, "") - y = c.builder.CreateExtractValue(y, 1, "") + // against each other, only against nil. So one of these has to be nil. + x = c.extractFuncScalar(x) + y = c.extractFuncScalar(y) switch op { case token.EQL: // == return c.builder.CreateICmp(llvm.IntEQ, x, y, ""), nil @@ -2597,81 +2539,6 @@ func (c *Compiler) parseConvert(typeFrom, typeTo types.Type, value llvm.Value, p } } -func (c *Compiler) parseMakeClosure(frame *Frame, expr *ssa.MakeClosure) (llvm.Value, error) { - if len(expr.Bindings) == 0 { - panic("unexpected: MakeClosure without bound variables") - } - f := c.ir.GetFunction(expr.Fn.(*ssa.Function)) - - // Collect all bound variables. - boundVars := make([]llvm.Value, 0, len(expr.Bindings)) - boundVarTypes := make([]llvm.Type, 0, len(expr.Bindings)) - for _, binding := range expr.Bindings { - // The context stores the bound variables. - llvmBoundVar, err := c.parseExpr(frame, binding) - if err != nil { - return llvm.Value{}, err - } - boundVars = append(boundVars, llvmBoundVar) - boundVarTypes = append(boundVarTypes, llvmBoundVar.Type()) - } - contextType := c.ctx.StructType(boundVarTypes, false) - - // Allocate memory for the context. - contextAlloc := llvm.Value{} - contextHeapAlloc := llvm.Value{} - if c.targetData.TypeAllocSize(contextType) <= c.targetData.TypeAllocSize(c.i8ptrType) { - // Context fits in a pointer - e.g. when it is a pointer. Store it - // directly in the stack after a convert. - // Because contextType is a struct and we have to cast it to a *i8, - // store it in an alloca first for bitcasting (store+bitcast+load). - contextAlloc = c.builder.CreateAlloca(contextType, "") - } else { - // Context is bigger than a pointer, so allocate it on the heap. - size := c.targetData.TypeAllocSize(contextType) - sizeValue := llvm.ConstInt(c.uintptrType, size, false) - contextHeapAlloc = c.createRuntimeCall("alloc", []llvm.Value{sizeValue}, "") - contextAlloc = c.builder.CreateBitCast(contextHeapAlloc, llvm.PointerType(contextType, 0), "") - } - - // Store all bound variables in the alloca or heap pointer. - for i, boundVar := range boundVars { - indices := []llvm.Value{ - llvm.ConstInt(c.ctx.Int32Type(), 0, false), - llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false), - } - gep := c.builder.CreateInBoundsGEP(contextAlloc, indices, "") - c.builder.CreateStore(boundVar, gep) - } - - context := llvm.Value{} - if c.targetData.TypeAllocSize(contextType) <= c.targetData.TypeAllocSize(c.i8ptrType) { - // Load value (as *i8) from the alloca. - contextAlloc = c.builder.CreateBitCast(contextAlloc, llvm.PointerType(c.i8ptrType, 0), "") - context = c.builder.CreateLoad(contextAlloc, "") - } else { - // Get the original heap allocation pointer, which already is an - // *i8. - context = contextHeapAlloc - } - - // Get the function signature type, which is a closure type. - // A closure is a tuple of {context, function pointer}. - typ, err := c.getLLVMType(f.Signature) - if err != nil { - return llvm.Value{}, err - } - - // Create the closure, which is a struct: {context, function pointer}. - closure, err := c.getZeroValue(typ) - if err != nil { - return llvm.Value{}, err - } - closure = c.builder.CreateInsertValue(closure, f.LLVMFn, 1, "") - closure = c.builder.CreateInsertValue(closure, context, 0, "") - return closure, nil -} - func (c *Compiler) parseUnOp(frame *Frame, unop *ssa.UnOp) (llvm.Value, error) { x, err := c.parseExpr(frame, unop.X) if err != nil { diff --git a/compiler/func.go b/compiler/func.go new file mode 100644 index 00000000..68db1fc8 --- /dev/null +++ b/compiler/func.go @@ -0,0 +1,189 @@ +package compiler + +// This file implements function values and closures. A func value is +// implemented as a pair of pointers: {context, function pointer}, where the +// context may be a pointer to a heap-allocated struct containing the free +// variables, or it may be undef if the function being pointed to doesn't need a +// context. + +import ( + "go/types" + + "golang.org/x/tools/go/ssa" + "tinygo.org/x/go-llvm" +) + +// createFuncValue creates a function value from a raw function pointer with no +// context. +func (c *Compiler) createFuncValue(funcPtr llvm.Value) (llvm.Value, error) { + // Closure is: {context, function pointer} + return c.ctx.ConstStruct([]llvm.Value{ + llvm.Undef(c.i8ptrType), + funcPtr, + }, false), nil +} + +// extractFuncScalar returns some scalar that can be used in comparisons. It is +// a cheap operation. +func (c *Compiler) extractFuncScalar(funcValue llvm.Value) llvm.Value { + return c.builder.CreateExtractValue(funcValue, 1, "") +} + +// extractFuncContext extracts the context pointer from this function value. It +// is a cheap operation. +func (c *Compiler) extractFuncContext(funcValue llvm.Value) llvm.Value { + return c.builder.CreateExtractValue(funcValue, 0, "") +} + +// decodeFuncValue extracts the context and the function pointer from this func +// value. This may be an expensive operation. +func (c *Compiler) decodeFuncValue(funcValue llvm.Value) (funcPtr, context llvm.Value) { + context = c.builder.CreateExtractValue(funcValue, 0, "") + funcPtr = c.builder.CreateExtractValue(funcValue, 1, "") + return +} + +// getFuncType returns the type of a func value given a signature. +func (c *Compiler) getFuncType(typ *types.Signature) (llvm.Type, error) { + rawPtr, err := c.getRawFuncType(typ) + if err != nil { + return llvm.Type{}, err + } + return c.ctx.StructType([]llvm.Type{c.i8ptrType, rawPtr}, false), nil +} + +// getRawFuncType returns a LLVM function pointer type for a given signature. +func (c *Compiler) getRawFuncType(typ *types.Signature) (llvm.Type, error) { + // Get the return type. + var err error + var returnType llvm.Type + switch typ.Results().Len() { + case 0: + // No return values. + returnType = c.ctx.VoidType() + case 1: + // Just one return value. + returnType, err = c.getLLVMType(typ.Results().At(0).Type()) + if err != nil { + return llvm.Type{}, err + } + default: + // Multiple return values. Put them together in a struct. + // This appears to be the common way to handle multiple return values in + // LLVM. + members := make([]llvm.Type, typ.Results().Len()) + for i := 0; i < typ.Results().Len(); i++ { + returnType, err := c.getLLVMType(typ.Results().At(i).Type()) + if err != nil { + return llvm.Type{}, err + } + members[i] = returnType + } + returnType = c.ctx.StructType(members, false) + } + + // Get the parameter types. + var paramTypes []llvm.Type + if typ.Recv() != nil { + recv, err := c.getLLVMType(typ.Recv().Type()) + if err != nil { + return llvm.Type{}, err + } + if recv.StructName() == "runtime._interface" { + // This is a call on an interface, not a concrete type. + // The receiver is not an interface, but a i8* type. + recv = c.i8ptrType + } + paramTypes = append(paramTypes, c.expandFormalParamType(recv)...) + } + for i := 0; i < typ.Params().Len(); i++ { + subType, err := c.getLLVMType(typ.Params().At(i).Type()) + if err != nil { + return llvm.Type{}, err + } + paramTypes = append(paramTypes, c.expandFormalParamType(subType)...) + } + // All functions take these parameters at the end. + paramTypes = append(paramTypes, c.i8ptrType) // context + paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine + + // Make a func type out of the signature. + return llvm.PointerType(llvm.FunctionType(returnType, paramTypes, false), c.funcPtrAddrSpace), nil +} + +// parseMakeClosure makes a function value (with context) from the given +// closure expression. +func (c *Compiler) parseMakeClosure(frame *Frame, expr *ssa.MakeClosure) (llvm.Value, error) { + if len(expr.Bindings) == 0 { + panic("unexpected: MakeClosure without bound variables") + } + f := c.ir.GetFunction(expr.Fn.(*ssa.Function)) + + // Collect all bound variables. + boundVars := make([]llvm.Value, 0, len(expr.Bindings)) + boundVarTypes := make([]llvm.Type, 0, len(expr.Bindings)) + for _, binding := range expr.Bindings { + // The context stores the bound variables. + llvmBoundVar, err := c.parseExpr(frame, binding) + if err != nil { + return llvm.Value{}, err + } + boundVars = append(boundVars, llvmBoundVar) + boundVarTypes = append(boundVarTypes, llvmBoundVar.Type()) + } + contextType := c.ctx.StructType(boundVarTypes, false) + + // Allocate memory for the context. + contextAlloc := llvm.Value{} + contextHeapAlloc := llvm.Value{} + if c.targetData.TypeAllocSize(contextType) <= c.targetData.TypeAllocSize(c.i8ptrType) { + // Context fits in a pointer - e.g. when it is a pointer. Store it + // directly in the stack after a convert. + // Because contextType is a struct and we have to cast it to a *i8, + // store it in an alloca first for bitcasting (store+bitcast+load). + contextAlloc = c.builder.CreateAlloca(contextType, "") + } else { + // Context is bigger than a pointer, so allocate it on the heap. + size := c.targetData.TypeAllocSize(contextType) + sizeValue := llvm.ConstInt(c.uintptrType, size, false) + contextHeapAlloc = c.createRuntimeCall("alloc", []llvm.Value{sizeValue}, "") + contextAlloc = c.builder.CreateBitCast(contextHeapAlloc, llvm.PointerType(contextType, 0), "") + } + + // Store all bound variables in the alloca or heap pointer. + for i, boundVar := range boundVars { + indices := []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false), + } + gep := c.builder.CreateInBoundsGEP(contextAlloc, indices, "") + c.builder.CreateStore(boundVar, gep) + } + + context := llvm.Value{} + if c.targetData.TypeAllocSize(contextType) <= c.targetData.TypeAllocSize(c.i8ptrType) { + // Load value (as *i8) from the alloca. + contextAlloc = c.builder.CreateBitCast(contextAlloc, llvm.PointerType(c.i8ptrType, 0), "") + context = c.builder.CreateLoad(contextAlloc, "") + } else { + // Get the original heap allocation pointer, which already is an + // *i8. + context = contextHeapAlloc + } + + // Get the function signature type, which is a closure type. + // A closure is a tuple of {context, function pointer}. + typ, err := c.getFuncType(f.Signature) + if err != nil { + return llvm.Value{}, err + } + + // Create the closure, which is a struct: {context, function pointer}. + closure, err := c.getZeroValue(typ) + if err != nil { + return llvm.Value{}, err + } + closure = c.builder.CreateInsertValue(closure, f.LLVMFn, 1, "") + closure = c.builder.CreateInsertValue(closure, context, 0, "") + return closure, nil +}