From b4c90f36776fcd2f8a2ab30b50928442c973afc5 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Fri, 9 Nov 2018 17:16:36 +0100 Subject: [PATCH] compiler: lower interfaces in a separate pass This commit changes many things: * Most interface-related operations are moved into an optimization pass for more modularity. IR construction creates pseudo-calls which are lowered in this pass. * Type codes are assigned in this interface lowering pass, after DCE. * Type codes are sorted by usage: types more often used in type asserts are assigned lower numbers to ease jump table construction during machine code generation. * Interface assertions are optimized: they are replaced by constant false, comparison against a constant, or a typeswitch with only concrete types in the general case. * Interface calls are replaced with unreachable, direct calls, or a concrete type switch with direct calls depending on the number of implementing types. This hopefully makes some interface patterns zero-cost. These changes lead to a ~0.5K reduction in code size on Cortex-M for testdata/interface.go. It appears that a major cause for this is the replacement of function pointers with direct calls, which are far more susceptible to optimization. Also, not having a fixed global array of function pointers greatly helps dead code elimination. This change also makes future optimizations easier, like optimizations on interface value comparisons. 
--- compiler/compiler.go | 74 ++-- compiler/interface-lowering.go | 715 +++++++++++++++++++++++++++++++++ compiler/interface.go | 368 +++++++++-------- compiler/optimizer.go | 17 +- interp/frame.go | 2 + interp/scan.go | 5 + ir/interpreter.go | 10 - ir/ir.go | 91 +---- ir/passes.go | 102 +---- main.go | 32 +- src/runtime/interface.go | 142 ++----- src/runtime/print.go | 9 +- testdata/interface.go | 7 + 13 files changed, 1039 insertions(+), 535 deletions(-) create mode 100644 compiler/interface-lowering.go diff --git a/compiler/compiler.go b/compiler/compiler.go index 816f91d8..1dfcde07 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -43,29 +43,30 @@ type Config struct { type Compiler struct { Config - mod llvm.Module - ctx llvm.Context - builder llvm.Builder - dibuilder *llvm.DIBuilder - cu llvm.Metadata - difiles map[string]llvm.Metadata - ditypes map[string]llvm.Metadata - machine llvm.TargetMachine - targetData llvm.TargetData - intType llvm.Type - i8ptrType llvm.Type // for convenience - uintptrType llvm.Type - coroIdFunc llvm.Value - coroSizeFunc llvm.Value - coroBeginFunc llvm.Value - coroSuspendFunc llvm.Value - coroEndFunc llvm.Value - coroFreeFunc llvm.Value - initFuncs []llvm.Value - deferFuncs []*ir.Function - deferInvokeFuncs []InvokeDeferFunction - ctxDeferFuncs []ContextDeferFunction - ir *ir.Program + mod llvm.Module + ctx llvm.Context + builder llvm.Builder + dibuilder *llvm.DIBuilder + cu llvm.Metadata + difiles map[string]llvm.Metadata + ditypes map[string]llvm.Metadata + machine llvm.TargetMachine + targetData llvm.TargetData + intType llvm.Type + i8ptrType llvm.Type // for convenience + uintptrType llvm.Type + coroIdFunc llvm.Value + coroSizeFunc llvm.Value + coroBeginFunc llvm.Value + coroSuspendFunc llvm.Value + coroEndFunc llvm.Value + coroFreeFunc llvm.Value + initFuncs []llvm.Value + deferFuncs []*ir.Function + deferInvokeFuncs []InvokeDeferFunction + ctxDeferFuncs []ContextDeferFunction + interfaceInvokeWrappers 
[]interfaceInvokeWrapper + ir *ir.Program } type Frame struct { @@ -489,6 +490,15 @@ func (c *Compiler) Compile(mainPath string) error { c.builder.CreateRetVoid() } + // Define the already declared functions that wrap methods for use in + // interfaces. + for _, state := range c.interfaceInvokeWrappers { + err = c.createInterfaceInvokeWrapper(state) + if err != nil { + return err + } + } + // After all packages are imported, add a synthetic initializer function // that calls the initializer of each package. initFn := c.ir.GetFunction(c.ir.Program.ImportedPackage("runtime").Members["initAll"].(*ssa.Function)) @@ -534,13 +544,6 @@ func (c *Compiler) Compile(mainPath string) error { } c.builder.CreateRetVoid() - // Add runtime type information for interfaces: interface calls and type - // asserts. - err = c.createInterfaceRTTI() - if err != nil { - return err - } - // see: https://reviews.llvm.org/D18355 c.mod.AddNamedMetadataOperand("llvm.module.flags", c.ctx.MDNode([]llvm.Metadata{ @@ -995,17 +998,6 @@ func (c *Compiler) getInterpretedValue(prefix string, value ir.Value) (llvm.Valu ptr := llvm.ConstInBoundsGEP(value.Global.LLVMGlobal, []llvm.Value{zero}) return ptr, nil - case *ir.InterfaceValue: - underlying := llvm.ConstPointerNull(c.i8ptrType) // could be any 0 value - if value.Elem != nil { - elem, err := c.getInterpretedValue(prefix, value.Elem) - if err != nil { - return llvm.Value{}, err - } - underlying = elem - } - return c.parseMakeInterface(underlying, value.Type, prefix) - case *ir.MapValue: // Create initial bucket. firstBucketGlobal, keySize, valueSize, err := c.initMapNewBucket(prefix, value.Type) diff --git a/compiler/interface-lowering.go b/compiler/interface-lowering.go new file mode 100644 index 00000000..8108f435 --- /dev/null +++ b/compiler/interface-lowering.go @@ -0,0 +1,715 @@ +package compiler + +// This file provides function to lower interface intrinsics to their final LLVM +// form, optimizing them in the process. 
+// +// During SSA construction, the following pseudo-calls are created: +// runtime.makeInterface(typecode, methodSet) +// runtime.typeAssert(typecode, assertedType) +// runtime.interfaceImplements(typecode, interfaceMethodSet) +// runtime.interfaceMethod(typecode, interfaceMethodSet, signature) +// See src/runtime/interface.go for details. +// These calls are to declared but not defined functions, so the optimizer will +// leave them alone. +// +// This pass lowers the above functions to their final form: +// +// makeInterface: +// Replaced with a constant typecode. +// +// typeAssert: +// Replaced with an icmp instruction so it can be directly used in a type +// switch. This is very easy to optimize for LLVM: it will often translate a +// type switch into a regular switch statement. +// When this type assert is not possible (the type is never used in an +// interface with makeInterface), this call is replaced with a constant +// false to optimize the type assert away completely. +// +// interfaceImplements: +// This call is translated into a call that checks whether the underlying +// type is one of the types implementing this interface. +// When there is only one type implementing this interface, the check is +// replaced with a simple icmp instruction, just like a type assert. +// When there is no type at all that implements this interface, it is +// replaced with a constant false to optimize it completely. +// +// interfaceMethod: +// This call is replaced with a call to a function that calls the +// appropriate method depending on the underlying type. +// When there is only one type implementing this interface, this call is +// translated into a direct call of that method. +// When there is no type implementing this interface, this code is marked +// unreachable as there is no way such an interface could be constructed. +// +// Note that this way of implementing interfaces is very different from how the +// main Go compiler implements them. 
For more details on how the main Go +// compiler does it: https://research.swtch.com/interfaces + +import ( + "sort" + "strings" + + "github.com/aykevl/go-llvm" +) + +// signatureInfo is a Go signature of an interface method. It does not represent +// any method in particular. +type signatureInfo struct { + name string + methods []*methodInfo + interfaces []*interfaceInfo +} + +// methodName takes a method name like "func String()" and returns only the +// name, which is "String" in this case. +func (s *signatureInfo) methodName() string { + if !strings.HasPrefix(s.name, "func ") { + panic("signature must start with \"func \"") + } + methodName := s.name[len("func "):] + if openingParen := strings.IndexByte(methodName, '('); openingParen < 0 { + panic("no opening paren in signature name") + } else { + return methodName[:openingParen] + } +} + +// methodInfo describes a single method on a concrete type. +type methodInfo struct { + *signatureInfo + function llvm.Value +} + +// typeInfo describes a single concrete Go type, which can be a basic or a named +// type. If it is a named type, it may have methods. +type typeInfo struct { + name string + typecode llvm.Value + methodSet llvm.Value + num uint64 // the type number after lowering + countMakeInterfaces int // how often this type is used in an interface + countTypeAsserts int // how often a type assert happens on this method + methods []*methodInfo +} + +// getMethod looks up the method on this type with the given signature and +// returns it. The method must exist on this type, otherwise getMethod will +// panic. +func (t *typeInfo) getMethod(signature *signatureInfo) *methodInfo { + for _, method := range t.methods { + if method.signatureInfo == signature { + return method + } + } + panic("could not find method") +} + +// id returns the fully-qualified type name including import path, removing the +// $type suffix. 
+func (t *typeInfo) id() string { + if !strings.HasSuffix(t.name, "$type") { + panic("concrete type does not have $type suffix: " + t.name) + } + return t.name[:len(t.name)-len("$type")] +} + +// typeInfoSlice implements sort.Interface, sorting the most commonly used types +// first. +type typeInfoSlice []*typeInfo + +func (t typeInfoSlice) Len() int { return len(t) } +func (t typeInfoSlice) Less(i, j int) bool { + // Try to sort the most commonly used types first: a higher count sorts earlier. + if t[i].countTypeAsserts != t[j].countTypeAsserts { + return t[i].countTypeAsserts > t[j].countTypeAsserts + } + if t[i].countMakeInterfaces != t[j].countMakeInterfaces { + return t[i].countMakeInterfaces > t[j].countMakeInterfaces + } + return t[i].name < t[j].name +} +func (t typeInfoSlice) Swap(i, j int) { t[i], t[j] = t[j], t[i] } + +// interfaceInfo keeps information about a Go interface type, including all +// methods it has. +type interfaceInfo struct { + name string // name with $interface suffix + signatures []*signatureInfo // method set + types typeInfoSlice // types that implement this interface + assertFunc llvm.Value // runtime.interfaceImplements replacement + methodFuncs map[*signatureInfo]llvm.Value // runtime.interfaceMethod replacements for each signature +} + +// id removes the $interface suffix from the name and returns the clean +// interface name including import path. +func (itf *interfaceInfo) id() string { + if !strings.HasSuffix(itf.name, "$interface") { + panic("interface type does not have $interface suffix: " + itf.name) + } + return itf.name[:len(itf.name)-len("$interface")] +} + +// lowerInterfacesPass keeps state related to the interface lowering pass. The +// pass has been implemented as an object type because of its complexity, but +// should be seen as a regular function call (see LowerInterfaces). 
+type lowerInterfacesPass struct { + *Compiler + types map[string]*typeInfo + signatures map[string]*signatureInfo + interfaces map[string]*interfaceInfo +} + +// Lower all interface functions. They are emitted by the compiler as +// higher-level intrinsics that need some lowering before LLVM can work on them. +// This is done so that a few cleanup passes can run before assigning the final +// type codes. +func (c *Compiler) LowerInterfaces() { + p := &lowerInterfacesPass{ + Compiler: c, + types: make(map[string]*typeInfo), + signatures: make(map[string]*signatureInfo), + interfaces: make(map[string]*interfaceInfo), + } + p.run() +} + +// run runs the pass itself. +func (p *lowerInterfacesPass) run() { + // Count per type how often it is put in an interface. Also, collect all + // methods this type has (if it is named). + makeInterface := p.mod.NamedFunction("runtime.makeInterface") + makeInterfaceUses := getUses(makeInterface) + for _, use := range makeInterfaceUses { + typecode := use.Operand(0) + name := typecode.Name() + if t, ok := p.types[name]; !ok { + // This is the first time this type has been seen, add it to the + // list of types. + t = p.addType(typecode) + p.addTypeMethods(t, use.Operand(1)) + } else { + p.addTypeMethods(t, use.Operand(1)) + } + + // Count the number of MakeInterface instructions, for sorting the + // typecodes later. + p.types[name].countMakeInterfaces++ + } + + // Count per type how often it is type asserted on (e.g. in a switch + // statement). + typeAssert := p.mod.NamedFunction("runtime.typeAssert") + typeAssertUses := getUses(typeAssert) + for _, use := range typeAssertUses { + typecode := use.Operand(1) + name := typecode.Name() + if _, ok := p.types[name]; !ok { + p.addType(typecode) + } + p.types[name].countTypeAsserts++ + } + + // Find all interface method calls. 
+ interfaceMethod := p.mod.NamedFunction("runtime.interfaceMethod") + interfaceMethodUses := getUses(interfaceMethod) + for _, use := range interfaceMethodUses { + methodSet := use.Operand(1).Operand(0) + name := methodSet.Name() + if _, ok := p.interfaces[name]; !ok { + p.addInterface(methodSet) + } + } + + // Find all interface type asserts. + interfaceImplements := p.mod.NamedFunction("runtime.interfaceImplements") + interfaceImplementsUses := getUses(interfaceImplements) + for _, use := range interfaceImplementsUses { + methodSet := use.Operand(1).Operand(0) + name := methodSet.Name() + if _, ok := p.interfaces[name]; !ok { + p.addInterface(methodSet) + } + } + + // Find all the interfaces that are implemented per type. + for _, t := range p.types { + // This type has no methods, so don't spend time calculating them. + if len(t.methods) == 0 { + continue + } + + // Pre-calculate a set of signatures that this type has, for easy + // lookup/check. + typeSignatureSet := make(map[*signatureInfo]struct{}) + for _, method := range t.methods { + typeSignatureSet[method.signatureInfo] = struct{}{} + } + + // A set of interfaces, mapped from the name to the info. + // When the name maps to a nil pointer, one of the methods of this type + // exists in the given interface but not all of them so this type + // doesn't implement the interface. 
+ satisfiesInterfaces := make(map[string]*interfaceInfo) + + for _, method := range t.methods { + for _, itf := range method.interfaces { + if _, ok := satisfiesInterfaces[itf.name]; ok { + // interface already checked with a different method + continue + } + // check whether this type satisfies this interface + satisfies := true + for _, itfSignature := range itf.signatures { + if _, ok := typeSignatureSet[itfSignature]; !ok { + satisfiesInterfaces[itf.name] = nil // does not satisfy + satisfies = false + break + } + } + if !satisfies { + continue + } + satisfiesInterfaces[itf.name] = itf + } + } + + // Add this type to all interfaces that this type satisfies. + for _, itf := range satisfiesInterfaces { + if itf == nil { + // This type does not implement the interface, but one of the + // methods on this type also exists on the interface. + continue + } + itf.types = append(itf.types, t) + } + } + + // Sort all types added to the interfaces, to check for more common types + // first. + for _, itf := range p.interfaces { + sort.Sort(itf.types) + } + + // Replace all interface methods with their uses, if possible. + for _, use := range interfaceMethodUses { + typecode := use.Operand(0) + signature := p.signatures[use.Operand(2).Name()] + + // If the interface was created in the same function, we can insert a + // direct call. This may not happen often but it is an easy + // optimization so let's do it anyway. + if !typecode.IsACallInst().IsNil() && typecode.CalledValue() == makeInterface { + name := typecode.Operand(0).Name() + typ := p.types[name] + p.replaceInvokeWithCall(use, typ, signature) + continue + } + + methodSet := use.Operand(1).Operand(0) // global variable + itf := p.interfaces[methodSet.Name()] + if len(itf.types) == 0 { + // This method call is impossible: no type implements this + // interface. In fact, the previous type assert that got this + // interface value should already have returned false. 
+ // Replace the function pointer with undef (which will then be + // called), indicating to the optimizer this code is unreachable. + use.ReplaceAllUsesWith(llvm.Undef(p.i8ptrType)) + use.EraseFromParentAsInstruction() + } else if len(itf.types) == 1 { + // There is only one implementation of the given type. + // Call that function directly. + p.replaceInvokeWithCall(use, itf.types[0], signature) + } else { + // There are multiple types implementing this interface, thus there + // are multiple possible functions to call. Delegate calling the + // right function to a special wrapper function. + bitcasts := getUses(use) + if len(bitcasts) != 1 || bitcasts[0].IsABitCastInst().IsNil() { + panic("expected exactly one bitcast use of runtime.interfaceMethod") + } + bitcast := bitcasts[0] + calls := getUses(bitcast) + if len(calls) != 1 || calls[0].IsACallInst().IsNil() { + panic("expected exactly one call use of runtime.interfaceMethod") + } + call := calls[0] + + // Set up parameters for the call. First copy the regular params... + params := make([]llvm.Value, call.OperandsCount()) + paramTypes := make([]llvm.Type, len(params)) + for i := 0; i < len(params)-1; i++ { + params[i] = call.Operand(i) + paramTypes[i] = params[i].Type() + } + // then add the typecode to the end of the list. + params[len(params)-1] = typecode + paramTypes[len(params)-1] = p.uintptrType + + // Create a function that redirects the call to the destination + // call, after selecting the right concrete type. + redirector := p.getInterfaceMethodFunc(itf, signature, call.Type(), paramTypes) + + // Replace the old lookup/bitcast/call with the new call. 
+ p.builder.SetInsertPointBefore(call) + retval := p.builder.CreateCall(redirector, params, "") + if retval.Type().TypeKind() != llvm.VoidTypeKind { + call.ReplaceAllUsesWith(retval) + } + call.EraseFromParentAsInstruction() + bitcast.EraseFromParentAsInstruction() + use.EraseFromParentAsInstruction() + } + } + + // Replace all typeasserts on interface types with matches on their concrete + // types, if possible. + for _, use := range interfaceImplementsUses { + actualType := use.Operand(0) + if !actualType.IsACallInst().IsNil() && actualType.CalledValue() == makeInterface { + // Type assert is in the same function that creates the interface + // value. This means the underlying type is already known so match + // on that. + // This may not happen often but it is an easy optimization. + name := actualType.Operand(0).Name() + typ := p.types[name] + p.builder.SetInsertPointBefore(use) + assertedType := p.builder.CreatePtrToInt(typ.typecode, p.uintptrType, "typeassert.typecode") + commaOk := p.builder.CreateICmp(llvm.IntEQ, assertedType, actualType, "typeassert.ok") + use.ReplaceAllUsesWith(commaOk) + use.EraseFromParentAsInstruction() + continue + } + + methodSet := use.Operand(1).Operand(0) // global variable + itf := p.interfaces[methodSet.Name()] + if len(itf.types) == 0 { + // There are no types implementing this interface, so this assert + // can never succeed. + // Signal this to the optimizer by branching on constant false. It + // should remove the "then" block. + use.ReplaceAllUsesWith(llvm.ConstInt(p.ctx.Int1Type(), 0, false)) + use.EraseFromParentAsInstruction() + } else if len(itf.types) == 1 { + // There is only one type implementing this interface. + // Transform this interface assert into comparison against a + // constant. 
+ p.builder.SetInsertPointBefore(use) + assertedType := p.builder.CreatePtrToInt(itf.types[0].typecode, p.uintptrType, "typeassert.typecode") + commaOk := p.builder.CreateICmp(llvm.IntEQ, assertedType, actualType, "typeassert.ok") + use.ReplaceAllUsesWith(commaOk) + use.EraseFromParentAsInstruction() + } else { + // There are multiple possible types implementing this interface. + // Create a function that does a type switch on all available types + // that implement this interface. + fn := p.getInterfaceImplementsFunc(itf) + p.builder.SetInsertPointBefore(use) + commaOk := p.builder.CreateCall(fn, []llvm.Value{actualType}, "typeassert.ok") + use.ReplaceAllUsesWith(commaOk) + use.EraseFromParentAsInstruction() + } + } + + // Make a slice of types sorted by frequency of use. + typeSlice := make(typeInfoSlice, 0, len(p.types)) + for _, t := range p.types { + typeSlice = append(typeSlice, t) + } + sort.Sort(typeSlice) + + // A type code must fit in 16 bits. + if len(typeSlice) >= 1<<16 { + panic("typecode does not fit in a uint16: too many types in this program") + } + + // Assign a type code for each type. + for i, t := range typeSlice { + t.num = uint64(i + 1) + } + + // Replace each call to runtime.makeInterface with the constant type code. + for _, use := range makeInterfaceUses { + global := use.Operand(0) + t := p.types[global.Name()] + use.ReplaceAllUsesWith(llvm.ConstPtrToInt(t.typecode, p.uintptrType)) + use.EraseFromParentAsInstruction() + } + + // Replace each type assert with an actual type comparison or (if the type + // assert is impossible) the constant false. 
+ for _, use := range typeAssertUses { + actualType := use.Operand(0) + assertedTypeGlobal := use.Operand(1) + t := p.types[assertedTypeGlobal.Name()] + var commaOk llvm.Value + if t.countMakeInterfaces == 0 { + // impossible type assert: optimize accordingly + commaOk = llvm.ConstInt(p.ctx.Int1Type(), 0, false) + } else { + // regular type assert + p.builder.SetInsertPointBefore(use) + commaOk = p.builder.CreateICmp(llvm.IntEQ, llvm.ConstPtrToInt(assertedTypeGlobal, p.uintptrType), actualType, "typeassert.ok") + } + use.ReplaceAllUsesWith(commaOk) + use.EraseFromParentAsInstruction() + } + + // Fill in each helper function for type asserts on interfaces + // (interface-to-interface matches). + for _, itf := range p.interfaces { + if !itf.assertFunc.IsNil() { + p.createInterfaceImplementsFunc(itf) + } + for signature := range itf.methodFuncs { + p.createInterfaceMethodFunc(itf, signature) + } + } + + // Replace all ptrtoint typecode placeholders with their final type code + // numbers. + for _, typ := range p.types { + for _, use := range getUses(typ.typecode) { + if use.IsConstant() && use.Opcode() == llvm.PtrToInt { + use.ReplaceAllUsesWith(llvm.ConstInt(p.uintptrType, typ.num, false)) + } + } + } + + // Remove method sets of types. Unnecessary, but cleans up the IR for + // inspection. + for _, typ := range p.types { + if !typ.methodSet.IsNil() { + typ.methodSet.EraseFromParentAsGlobal() + typ.methodSet = llvm.Value{} + } + } +} + +// addType retrieves Go type information based on a i16 global variable. +// Only the name of the i16 is relevant, the object itself is const-propagated +// and discarded afterwards. +func (p *lowerInterfacesPass) addType(typecode llvm.Value) *typeInfo { + name := typecode.Name() + t := &typeInfo{ + name: name, + typecode: typecode, + } + p.types[name] = t + return t +} + +// addTypeMethods reads the method set of the given type info struct. 
It +// retrieves the signatures and the references to the method functions +// themselves for later type<->interface matching. +func (p *lowerInterfacesPass) addTypeMethods(t *typeInfo, methodSet llvm.Value) { + if !t.methodSet.IsNil() || methodSet.IsNull() { + // no methods or methods already read + return + } + methodSet = methodSet.Operand(0) // get global from GEP + + // This type has methods, collect all methods of this type. + t.methodSet = methodSet + set := methodSet.Initializer() // get value from global + for i := 0; i < set.Type().ArrayLength(); i++ { + methodData := llvm.ConstExtractValue(set, []uint32{uint32(i)}) + signatureName := llvm.ConstExtractValue(methodData, []uint32{0}).Name() + function := llvm.ConstExtractValue(methodData, []uint32{1}).Operand(0) + signature := p.getSignature(signatureName) + method := &methodInfo{ + function: function, + signatureInfo: signature, + } + signature.methods = append(signature.methods, method) + t.methods = append(t.methods, method) + } +} + +// addInterface reads information about an interface, which is the +// fully-qualified name and the signatures of all methods it has. +func (p *lowerInterfacesPass) addInterface(methodSet llvm.Value) { + name := methodSet.Name() + t := &interfaceInfo{ + name: name, + } + p.interfaces[name] = t + methodSet = methodSet.Initializer() // get global value from getelementptr + for i := 0; i < methodSet.Type().ArrayLength(); i++ { + signatureName := llvm.ConstExtractValue(methodSet, []uint32{uint32(i)}).Name() + signature := p.getSignature(signatureName) + signature.interfaces = append(signature.interfaces, t) + t.signatures = append(t.signatures, signature) + } +} + +// getSignature returns a new *signatureInfo, creating it if it doesn't already +// exist. 
+func (p *lowerInterfacesPass) getSignature(name string) *signatureInfo { + if _, ok := p.signatures[name]; !ok { + p.signatures[name] = &signatureInfo{ + name: name, + } + } + return p.signatures[name] +} + +// replaceInvokeWithCall replaces a runtime.interfaceMethod + bitcast with a +// concrete method. This can be done when only one type implements the +// interface. +func (p *lowerInterfacesPass) replaceInvokeWithCall(use llvm.Value, typ *typeInfo, signature *signatureInfo) { + bitcasts := getUses(use) + if len(bitcasts) != 1 || bitcasts[0].IsABitCastInst().IsNil() { + panic("expected exactly one bitcast use of runtime.interfaceMethod") + } + bitcast := bitcasts[0] + function := typ.getMethod(signature).function + if bitcast.Type() != function.Type() { + p.builder.SetInsertPointBefore(use) + function = p.builder.CreateBitCast(function, bitcast.Type(), "") + } + bitcast.ReplaceAllUsesWith(function) + bitcast.EraseFromParentAsInstruction() + use.EraseFromParentAsInstruction() +} + +// getInterfaceImplementsFunc returns a function that checks whether a given +// interface type implements a given interface, by checking all possible types +// that implement this interface. +func (p *lowerInterfacesPass) getInterfaceImplementsFunc(itf *interfaceInfo) llvm.Value { + if !itf.assertFunc.IsNil() { + return itf.assertFunc + } + + // Create the function and function signature. + // TODO: debug info + fnName := itf.id() + "$typeassert" + fnType := llvm.FunctionType(p.ctx.Int1Type(), []llvm.Type{p.uintptrType}, false) + itf.assertFunc = llvm.AddFunction(p.mod, fnName, fnType) + itf.assertFunc.Param(0).SetName("actualType") + + // Type asserts will be made for each type, so increment the counter for + // those. + for _, typ := range itf.types { + typ.countTypeAsserts++ + } + + return itf.assertFunc +} + +// createInterfaceImplementsFunc finishes the work of +// getInterfaceImplementsFunc, because it needs to run after types have a type +// code assigned. 
+// +// The type match is implemented using a big type switch over all possible +// types. +func (p *lowerInterfacesPass) createInterfaceImplementsFunc(itf *interfaceInfo) { + fn := itf.assertFunc + fn.SetLinkage(llvm.InternalLinkage) + fn.SetUnnamedAddr(true) + + // TODO: debug info + + // Create all used basic blocks. + entry := llvm.AddBasicBlock(fn, "entry") + thenBlock := llvm.AddBasicBlock(fn, "then") + elseBlock := llvm.AddBasicBlock(fn, "else") + + // Add all possible types as cases. + p.builder.SetInsertPointAtEnd(entry) + actualType := fn.Param(0) + sw := p.builder.CreateSwitch(actualType, elseBlock, len(itf.types)) + for _, typ := range itf.types { + sw.AddCase(llvm.ConstInt(p.uintptrType, typ.num, false), thenBlock) + } + + // Fill 'then' block (type assert was successful). + p.builder.SetInsertPointAtEnd(thenBlock) + p.builder.CreateRet(llvm.ConstInt(p.ctx.Int1Type(), 1, false)) + + // Fill 'else' block (type assert failed). + p.builder.SetInsertPointAtEnd(elseBlock) + p.builder.CreateRet(llvm.ConstInt(p.ctx.Int1Type(), 0, false)) +} + +// getInterfaceMethodFunc returns a function that redirects a method call on +// an interface to the matching concrete method. It only declares the function, +// createInterfaceMethodFunc actually defines the function. +func (p *lowerInterfacesPass) getInterfaceMethodFunc(itf *interfaceInfo, signature *signatureInfo, returnType llvm.Type, params []llvm.Type) llvm.Value { + if fn, ok := itf.methodFuncs[signature]; ok { + // This function has already been created. + return fn + } + if itf.methodFuncs == nil { + // initialize the above map + itf.methodFuncs = make(map[*signatureInfo]llvm.Value) + } + + // Construct the function name, which is of the form: + // (main.Stringer).String + fnName := "(" + itf.id() + ")." 
+ signature.methodName() + fnType := llvm.FunctionType(returnType, params, false) + fn := llvm.AddFunction(p.mod, fnName, fnType) + fn.LastParam().SetName("actualType") + itf.methodFuncs[signature] = fn + return fn +} + +// createInterfaceMethodFunc finishes the work of getInterfaceMethodFunc, +// because it needs to run after type codes have been assigned to concrete +// types. +// +// Matching the actual type is implemented using a big type switch over all +// possible types. +func (p *lowerInterfacesPass) createInterfaceMethodFunc(itf *interfaceInfo, signature *signatureInfo) { + fn := itf.methodFuncs[signature] + fn.SetLinkage(llvm.InternalLinkage) + fn.SetUnnamedAddr(true) + + // TODO: debug info + + // Create entry block. + entry := llvm.AddBasicBlock(fn, "entry") + + // Create default block and make it unreachable (which it is, because all + // possible types are checked). + defaultBlock := llvm.AddBasicBlock(fn, "default") + p.builder.SetInsertPointAtEnd(defaultBlock) + p.builder.CreateUnreachable() + + // Create type switch in entry block. + p.builder.SetInsertPointAtEnd(entry) + actualType := fn.LastParam() + sw := p.builder.CreateSwitch(actualType, defaultBlock, len(itf.types)) + + // Collect the params that will be passed to the functions to call. + // These params exclude the receiver (which may actually consist of multiple + // parts). + params := make([]llvm.Value, fn.ParamsCount()-2) + for i := range params { + params[i] = fn.Param(i + 1) + } + + // Define all possible functions that can be called. + for _, typ := range itf.types { + bb := llvm.AddBasicBlock(fn, typ.id()) + sw.AddCase(llvm.ConstInt(p.uintptrType, typ.num, false), bb) + + // The function we will redirect to when the interface has this type. + function := typ.getMethod(signature).function + + p.builder.SetInsertPointAtEnd(bb) + receiver := fn.FirstParam() + if receiver.Type() != function.FirstParam().Type() { + // When the receiver is a pointer, it is not wrapped. 
This means the + // i8* has to be cast to the correct pointer type of the target + // function. + receiver = p.builder.CreateBitCast(receiver, function.FirstParam().Type(), "") + } + retval := p.builder.CreateCall(function, append([]llvm.Value{receiver}, params...), "") + if retval.Type().TypeKind() == llvm.VoidTypeKind { + p.builder.CreateRetVoid() + } else { + p.builder.CreateRet(retval) + } + } +} diff --git a/compiler/interface.go b/compiler/interface.go index 129ceb1b..6517f2ed 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -1,5 +1,10 @@ package compiler +// This file transforms interface-related instructions (*ssa.MakeInterface, +// *ssa.TypeAssert, calls on interface types) to an intermediate IR form, to be +// lowered to the final form by the interface lowering pass. See +// interface-lowering.go for more details. + import ( "errors" "go/types" @@ -32,10 +37,10 @@ func (c *Compiler) parseMakeInterface(val llvm.Value, typ types.Type, global str // Allocate on the heap and put a pointer in the interface. // TODO: escape analysis. sizeValue := llvm.ConstInt(c.uintptrType, size, false) - alloc := c.createRuntimeCall("alloc", []llvm.Value{sizeValue}, "") - itfValueCast := c.builder.CreateBitCast(alloc, llvm.PointerType(val.Type(), 0), "") + alloc := c.createRuntimeCall("alloc", []llvm.Value{sizeValue}, "makeinterface.alloc") + itfValueCast := c.builder.CreateBitCast(alloc, llvm.PointerType(val.Type(), 0), "makeinterface.cast.value") c.builder.CreateStore(val, itfValueCast) - itfValue = c.builder.CreateBitCast(itfValueCast, c.i8ptrType, "") + itfValue = c.builder.CreateBitCast(itfValueCast, c.i8ptrType, "makeinterface.cast.i8ptr") } } else if size == 0 { itfValue = llvm.ConstPointerNull(c.i8ptrType) @@ -43,38 +48,136 @@ func (c *Compiler) parseMakeInterface(val llvm.Value, typ types.Type, global str // Directly place the value in the interface. 
switch val.Type().TypeKind() { case llvm.IntegerTypeKind: - itfValue = c.builder.CreateIntToPtr(val, c.i8ptrType, "") + itfValue = c.builder.CreateIntToPtr(val, c.i8ptrType, "makeinterface.cast.int") case llvm.PointerTypeKind: - itfValue = c.builder.CreateBitCast(val, c.i8ptrType, "") + itfValue = c.builder.CreateBitCast(val, c.i8ptrType, "makeinterface.cast.ptr") case llvm.StructTypeKind: // A bitcast would be useful here, but bitcast doesn't allow // aggregate types. So we'll bitcast it using an alloca. // Hopefully this will get optimized away. - mem := c.builder.CreateAlloca(c.i8ptrType, "") - memStructPtr := c.builder.CreateBitCast(mem, llvm.PointerType(val.Type(), 0), "") + mem := c.builder.CreateAlloca(c.i8ptrType, "makeinterface.cast.struct") + memStructPtr := c.builder.CreateBitCast(mem, llvm.PointerType(val.Type(), 0), "makeinterface.cast.struct.cast") c.builder.CreateStore(val, memStructPtr) - itfValue = c.builder.CreateLoad(mem, "") + itfValue = c.builder.CreateLoad(mem, "makeinterface.cast.load") default: return llvm.Value{}, errors.New("todo: makeinterface: cast small type to i8*") } } - itfTypeNum, _ := c.ir.TypeNum(typ) - if itfTypeNum >= 1<<16 { - return llvm.Value{}, errors.New("interface typecodes do not fit in a 16-bit integer") + itfTypeCodeGlobal := c.getTypeCode(typ) + itfMethodSetGlobal, err := c.getTypeMethodSet(typ) + if err != nil { + return llvm.Value{}, nil } - itf := llvm.ConstNamedStruct(c.mod.GetTypeByName("runtime._interface"), []llvm.Value{llvm.ConstInt(c.ctx.Int16Type(), uint64(itfTypeNum), false), llvm.Undef(c.i8ptrType)}) + itfTypeCode := c.createRuntimeCall("makeInterface", []llvm.Value{itfTypeCodeGlobal, itfMethodSetGlobal}, "makeinterface.typecode") + itf := llvm.Undef(c.mod.GetTypeByName("runtime._interface")) + itf = c.builder.CreateInsertValue(itf, itfTypeCode, 0, "") itf = c.builder.CreateInsertValue(itf, itfValue, 1, "") return itf, nil } +// getTypeCode returns a reference to a type code. 
+// It returns a pointer to an external global which should be replaced with the +// real type in the interface lowering pass. +func (c *Compiler) getTypeCode(typ types.Type) llvm.Value { + global := c.mod.NamedGlobal(typ.String() + "$type") + if global.IsNil() { + global = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), typ.String()+"$type") + global.SetGlobalConstant(true) + } + return global +} + +// getTypeMethodSet returns a reference (GEP) to a global method set. This +// method set should be unreferenced after the interface lowering pass. +func (c *Compiler) getTypeMethodSet(typ types.Type) (llvm.Value, error) { + global := c.mod.NamedGlobal(typ.String() + "$methodset") + zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) + if !global.IsNil() { + // the method set already exists + return llvm.ConstGEP(global, []llvm.Value{zero, zero}), nil + } + + ms := c.ir.Program.MethodSets.MethodSet(typ) + if ms.Len() == 0 { + // no methods, so can leave that one out + return llvm.ConstPointerNull(llvm.PointerType(c.mod.GetTypeByName("runtime.interfaceMethodInfo"), 0)), nil + } + + methods := make([]llvm.Value, ms.Len()) + interfaceMethodInfoType := c.mod.GetTypeByName("runtime.interfaceMethodInfo") + for i := 0; i < ms.Len(); i++ { + method := ms.At(i) + signatureGlobal := c.getMethodSignature(method.Obj().(*types.Func)) + f := c.ir.GetFunction(c.ir.Program.MethodValue(method)) + if f.LLVMFn.IsNil() { + // compiler error, so panic + panic("cannot find function: " + f.LinkName()) + } + fn, err := c.getInterfaceInvokeWrapper(f) + if err != nil { + return llvm.Value{}, err + } + methodInfo := llvm.ConstNamedStruct(interfaceMethodInfoType, []llvm.Value{ + signatureGlobal, + llvm.ConstBitCast(fn, c.i8ptrType), + }) + methods[i] = methodInfo + } + arrayType := llvm.ArrayType(interfaceMethodInfoType, len(methods)) + value := llvm.ConstArray(interfaceMethodInfoType, methods) + global = llvm.AddGlobal(c.mod, arrayType, typ.String()+"$methodset") + global.SetInitializer(value) + 
global.SetGlobalConstant(true) + global.SetLinkage(llvm.PrivateLinkage) + return llvm.ConstGEP(global, []llvm.Value{zero, zero}), nil +} + +// getInterfaceMethodSet returns a global variable with the method set of the +// given named interface type. This method set is used by the interface lowering +// pass. +func (c *Compiler) getInterfaceMethodSet(typ *types.Named) llvm.Value { + global := c.mod.NamedGlobal(typ.String() + "$interface") + zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) + if !global.IsNil() { + // method set already exists, return it + return llvm.ConstGEP(global, []llvm.Value{zero, zero}) + } + + // Every method is a *i16 reference indicating the signature of this method. + methods := make([]llvm.Value, typ.Underlying().(*types.Interface).NumMethods()) + for i := range methods { + method := typ.Underlying().(*types.Interface).Method(i) + methods[i] = c.getMethodSignature(method) + } + + value := llvm.ConstArray(methods[0].Type(), methods) + global = llvm.AddGlobal(c.mod, value.Type(), typ.String()+"$interface") + global.SetInitializer(value) + global.SetGlobalConstant(true) + global.SetLinkage(llvm.PrivateLinkage) + return llvm.ConstGEP(global, []llvm.Value{zero, zero}) +} + +// getMethodSignature returns a global variable which is a reference to an +// external *i16 indicating the signature of this method. It is +// used during the interface lowering pass.
+func (c *Compiler) getMethodSignature(method *types.Func) llvm.Value { + signature := ir.MethodSignature(method) + signatureGlobal := c.mod.NamedGlobal("func " + signature) + if signatureGlobal.IsNil() { + signatureGlobal = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), "func "+signature) + signatureGlobal.SetGlobalConstant(true) + } + return signatureGlobal +} + // parseTypeAssert will emit the code for a typeassert, used in if statements -// and in switch statements (Go SSA does not have type switches, only if/else +// and in type switches (Go SSA does not have type switches, only if/else // chains). Note that even though the Go SSA does not contain type switches, // LLVM will recognize the pattern and make it a real switch in many cases. // // Type asserts on concrete types are trivial: just compare type numbers. Type -// asserts on interfaces are more difficult to implement and so are delegated to -// a runtime library function. +// asserts on interfaces are more difficult, see the comments in the function. func (c *Compiler) parseTypeAssert(frame *Frame, expr *ssa.TypeAssert) (llvm.Value, error) { itf, err := c.parseExpr(frame, expr.X) if err != nil { @@ -91,38 +194,24 @@ func (c *Compiler) parseTypeAssert(frame *Frame, expr *ssa.TypeAssert) (llvm.Val actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type") commaOk := llvm.Value{} - if itf, ok := expr.AssertedType.Underlying().(*types.Interface); ok { + if _, ok := expr.AssertedType.Underlying().(*types.Interface); ok { // Type assert on interface type. - // This is slightly non-trivial: at runtime the list of methods - // needs to be checked to see whether it implements the interface. - // At the same time, the interface value itself is unchanged. 
- itfTypeNum := c.ir.InterfaceNum(itf) - itfTypeNumValue := llvm.ConstInt(c.ctx.Int16Type(), uint64(itfTypeNum), false) - commaOk = c.createRuntimeCall("interfaceImplements", []llvm.Value{actualTypeNum, itfTypeNumValue}, "") + // This pseudo call will be lowered in the interface lowering pass to a + // real call which checks whether the provided typecode is any of the + // concrete types that implements this interface. + // This is very different from how interface asserts are implemented in + // the main Go compiler, where the runtime checks whether the type + // implements each method of the interface. See: + // https://research.swtch.com/interfaces + methodSet := c.getInterfaceMethodSet(expr.AssertedType.(*types.Named)) + commaOk = c.createRuntimeCall("interfaceImplements", []llvm.Value{actualTypeNum, methodSet}, "") } else { // Type assert on concrete type. - // This is easy: just compare the type number. - assertedTypeNum, typeExists := c.ir.TypeNum(expr.AssertedType) - if !typeExists { - // Static analysis has determined this type assert will never apply. - // Using undef here so that LLVM knows we'll never get here and - // can optimize accordingly. - undef := llvm.Undef(assertedType) - commaOk := llvm.ConstInt(c.ctx.Int1Type(), 0, false) - if expr.CommaOk { - return c.ctx.ConstStruct([]llvm.Value{undef, commaOk}, false), nil - } else { - c.createRuntimeCall("interfaceTypeAssert", []llvm.Value{commaOk}, "") - return undef, nil - } - } - if assertedTypeNum >= 1<<16 { - return llvm.Value{}, errors.New("interface typecodes do not fit in a 16-bit integer") - } - - assertedTypeNumValue := llvm.ConstInt(c.ctx.Int16Type(), uint64(assertedTypeNum), false) - commaOk = c.builder.CreateICmp(llvm.IntEQ, assertedTypeNumValue, actualTypeNum, "") + // Call runtime.typeAssert, which will be lowered to a simple icmp or + // const false in the interface lowering pass. 
+ assertedTypeCodeGlobal := c.getTypeCode(expr.AssertedType) + commaOk = c.createRuntimeCall("typeAssert", []llvm.Value{actualTypeNum, assertedTypeCodeGlobal}, "typecode") } // Add 2 new basic blocks (that should get optimized away): one for the @@ -171,16 +260,14 @@ func (c *Compiler) parseTypeAssert(frame *Frame, expr *ssa.TypeAssert) (llvm.Val valueOk = c.builder.CreatePtrToInt(valuePtr, assertedType, "typeassert.value.ok") case llvm.PointerTypeKind: valueOk = c.builder.CreateBitCast(valuePtr, assertedType, "typeassert.value.ok") - case llvm.StructTypeKind: + default: // struct, float, etc. // A bitcast would be useful here, but bitcast doesn't allow // aggregate types. So we'll bitcast it using an alloca. // Hopefully this will get optimized away. mem := c.builder.CreateAlloca(c.i8ptrType, "") c.builder.CreateStore(valuePtr, mem) - memStructPtr := c.builder.CreateBitCast(mem, llvm.PointerType(assertedType, 0), "") - valueOk = c.builder.CreateLoad(memStructPtr, "typeassert.value.ok") - default: - return llvm.Value{}, errors.New("todo: typeassert: bitcast small types") + memCast := c.builder.CreateBitCast(mem, llvm.PointerType(assertedType, 0), "") + valueOk = c.builder.CreateLoad(memCast, "typeassert.value.ok") } } } @@ -231,7 +318,8 @@ func (c *Compiler) getInvokeCall(frame *Frame, instr *ssa.CallCommon) (llvm.Valu typecode := c.builder.CreateExtractValue(itf, 0, "invoke.typecode") values := []llvm.Value{ typecode, - llvm.ConstInt(c.ctx.Int16Type(), uint64(c.ir.MethodNum(instr.Method)), false), + c.getInterfaceMethodSet(instr.Value.Type().(*types.Named)), + c.getMethodSignature(instr.Method), } fn := c.createRuntimeCall("interfaceMethod", values, "invoke.func") fnCast := c.builder.CreateBitCast(fn, llvmFnType, "invoke.func.cast") @@ -255,142 +343,40 @@ func (c *Compiler) getInvokeCall(frame *Frame, instr *ssa.CallCommon) (llvm.Valu return fnCast, args, nil } -// Initialize runtime type information, for interfaces. 
-// See src/runtime/interface.go for more details. -func (c *Compiler) createInterfaceRTTI() error { - dynamicTypes := c.ir.AllDynamicTypes() - numDynamicTypes := 0 - for _, meta := range dynamicTypes { - numDynamicTypes += len(meta.Methods) - } - ranges := make([]llvm.Value, 0, len(dynamicTypes)) - funcPointers := make([]llvm.Value, 0, numDynamicTypes) - signatures := make([]llvm.Value, 0, numDynamicTypes) - startIndex := 0 - rangeType := c.mod.GetTypeByName("runtime.methodSetRange") - for _, meta := range dynamicTypes { - rangeValues := []llvm.Value{ - llvm.ConstInt(c.ctx.Int16Type(), uint64(startIndex), false), - llvm.ConstInt(c.ctx.Int16Type(), uint64(len(meta.Methods)), false), - } - rangeValue := llvm.ConstNamedStruct(rangeType, rangeValues) - ranges = append(ranges, rangeValue) - methods := make([]*types.Selection, 0, len(meta.Methods)) - for _, method := range meta.Methods { - methods = append(methods, method) - } - c.ir.SortMethods(methods) - for _, method := range methods { - f := c.ir.GetFunction(c.ir.Program.MethodValue(method)) - if f.LLVMFn.IsNil() { - return errors.New("cannot find function: " + f.LinkName()) - } - fn, err := c.wrapInterfaceInvoke(f) - if err != nil { - return err - } - fnPtr := llvm.ConstBitCast(fn, c.i8ptrType) - funcPointers = append(funcPointers, fnPtr) - signatureNum := c.ir.MethodNum(method.Obj().(*types.Func)) - signature := llvm.ConstInt(c.ctx.Int16Type(), uint64(signatureNum), false) - signatures = append(signatures, signature) - } - startIndex += len(meta.Methods) - } - - interfaceTypes := c.ir.AllInterfaces() - interfaceIndex := make([]llvm.Value, len(interfaceTypes)) - interfaceLengths := make([]llvm.Value, len(interfaceTypes)) - interfaceMethods := make([]llvm.Value, 0) - for i, itfType := range interfaceTypes { - if itfType.Type.NumMethods() > 0xff { - return errors.New("too many methods for interface " + itfType.Type.String()) - } - interfaceIndex[i] = llvm.ConstInt(c.ctx.Int16Type(), uint64(i), false) - 
interfaceLengths[i] = llvm.ConstInt(c.ctx.Int8Type(), uint64(itfType.Type.NumMethods()), false) - funcs := make([]*types.Func, itfType.Type.NumMethods()) - for i := range funcs { - funcs[i] = itfType.Type.Method(i) - } - c.ir.SortFuncs(funcs) - for _, f := range funcs { - id := llvm.ConstInt(c.ctx.Int16Type(), uint64(c.ir.MethodNum(f)), false) - interfaceMethods = append(interfaceMethods, id) - } - } - - if len(ranges) >= 1<<16 { - return errors.New("method call numbers do not fit in a 16-bit integer") - } - - // Replace the pre-created arrays with the generated arrays. - rangeArray := llvm.ConstArray(rangeType, ranges) - rangeArrayNewGlobal := llvm.AddGlobal(c.mod, rangeArray.Type(), "runtime.methodSetRanges.tmp") - rangeArrayNewGlobal.SetInitializer(rangeArray) - rangeArrayNewGlobal.SetLinkage(llvm.InternalLinkage) - rangeArrayOldGlobal := c.mod.NamedGlobal("runtime.methodSetRanges") - rangeArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(rangeArrayNewGlobal, rangeArrayOldGlobal.Type())) - rangeArrayOldGlobal.EraseFromParentAsGlobal() - rangeArrayNewGlobal.SetName("runtime.methodSetRanges") - funcArray := llvm.ConstArray(c.i8ptrType, funcPointers) - funcArrayNewGlobal := llvm.AddGlobal(c.mod, funcArray.Type(), "runtime.methodSetFunctions.tmp") - funcArrayNewGlobal.SetInitializer(funcArray) - funcArrayNewGlobal.SetLinkage(llvm.InternalLinkage) - funcArrayOldGlobal := c.mod.NamedGlobal("runtime.methodSetFunctions") - funcArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(funcArrayNewGlobal, funcArrayOldGlobal.Type())) - funcArrayOldGlobal.EraseFromParentAsGlobal() - funcArrayNewGlobal.SetName("runtime.methodSetFunctions") - signatureArray := llvm.ConstArray(c.ctx.Int16Type(), signatures) - signatureArrayNewGlobal := llvm.AddGlobal(c.mod, signatureArray.Type(), "runtime.methodSetSignatures.tmp") - signatureArrayNewGlobal.SetInitializer(signatureArray) - signatureArrayNewGlobal.SetLinkage(llvm.InternalLinkage) - signatureArrayOldGlobal := 
c.mod.NamedGlobal("runtime.methodSetSignatures") - signatureArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(signatureArrayNewGlobal, signatureArrayOldGlobal.Type())) - signatureArrayOldGlobal.EraseFromParentAsGlobal() - signatureArrayNewGlobal.SetName("runtime.methodSetSignatures") - interfaceIndexArray := llvm.ConstArray(c.ctx.Int16Type(), interfaceIndex) - interfaceIndexArrayNewGlobal := llvm.AddGlobal(c.mod, interfaceIndexArray.Type(), "runtime.interfaceIndex.tmp") - interfaceIndexArrayNewGlobal.SetInitializer(interfaceIndexArray) - interfaceIndexArrayNewGlobal.SetLinkage(llvm.InternalLinkage) - interfaceIndexArrayOldGlobal := c.mod.NamedGlobal("runtime.interfaceIndex") - interfaceIndexArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(interfaceIndexArrayNewGlobal, interfaceIndexArrayOldGlobal.Type())) - interfaceIndexArrayOldGlobal.EraseFromParentAsGlobal() - interfaceIndexArrayNewGlobal.SetName("runtime.interfaceIndex") - interfaceLengthsArray := llvm.ConstArray(c.ctx.Int8Type(), interfaceLengths) - interfaceLengthsArrayNewGlobal := llvm.AddGlobal(c.mod, interfaceLengthsArray.Type(), "runtime.interfaceLengths.tmp") - interfaceLengthsArrayNewGlobal.SetInitializer(interfaceLengthsArray) - interfaceLengthsArrayNewGlobal.SetLinkage(llvm.InternalLinkage) - interfaceLengthsArrayOldGlobal := c.mod.NamedGlobal("runtime.interfaceLengths") - interfaceLengthsArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(interfaceLengthsArrayNewGlobal, interfaceLengthsArrayOldGlobal.Type())) - interfaceLengthsArrayOldGlobal.EraseFromParentAsGlobal() - interfaceLengthsArrayNewGlobal.SetName("runtime.interfaceLengths") - interfaceMethodsArray := llvm.ConstArray(c.ctx.Int16Type(), interfaceMethods) - interfaceMethodsArrayNewGlobal := llvm.AddGlobal(c.mod, interfaceMethodsArray.Type(), "runtime.interfaceMethods.tmp") - interfaceMethodsArrayNewGlobal.SetInitializer(interfaceMethodsArray) - interfaceMethodsArrayNewGlobal.SetLinkage(llvm.InternalLinkage) - 
interfaceMethodsArrayOldGlobal := c.mod.NamedGlobal("runtime.interfaceMethods") - interfaceMethodsArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(interfaceMethodsArrayNewGlobal, interfaceMethodsArrayOldGlobal.Type())) - interfaceMethodsArrayOldGlobal.EraseFromParentAsGlobal() - interfaceMethodsArrayNewGlobal.SetName("runtime.interfaceMethods") - - c.mod.NamedGlobal("runtime.firstTypeWithMethods").SetInitializer(llvm.ConstInt(c.ctx.Int16Type(), uint64(c.ir.FirstDynamicType()), false)) - - return nil +// interfaceInvokeWrapper keeps some state between getInterfaceInvokeWrapper and +// createInterfaceInvokeWrapper. The former is called during IR construction +// itself and the latter is called when finishing up the IR. +type interfaceInvokeWrapper struct { + fn *ir.Function + wrapper llvm.Value + receiverType llvm.Type } // Wrap an interface method function pointer. The wrapper takes in a pointer to // the underlying value, dereferences it, and calls the real method. This // wrapper is only needed when the interface value actually doesn't fit in a // pointer and a pointer to the value must be created. -func (c *Compiler) wrapInterfaceInvoke(f *ir.Function) (llvm.Value, error) { +func (c *Compiler) getInterfaceInvokeWrapper(f *ir.Function) (llvm.Value, error) { + wrapperName := f.LinkName() + "$invoke" + wrapper := c.mod.NamedFunction(wrapperName) + if !wrapper.IsNil() { + // Wrapper already created. Return it directly. + return wrapper, nil + } + + // Get the expanded receiver type. receiverType, err := c.getLLVMType(f.Params[0].Type()) if err != nil { return llvm.Value{}, err } expandedReceiverType := c.expandFormalParamType(receiverType) - if c.targetData.TypeAllocSize(receiverType) <= c.targetData.TypeAllocSize(c.i8ptrType) && len(expandedReceiverType) == 1 { - // nothing to wrap + // Does this method even need any wrapping? + if len(expandedReceiverType) == 1 && receiverType.TypeKind() == llvm.PointerTypeKind { + // Nothing to wrap. 
+ // Casting a function signature to a different signature and calling it + // with a receiver pointer bitcasted to *i8 (as done in calls on an + // interface) is hopefully a safe (defined) operation. return f.LLVMFn, nil } @@ -398,16 +384,30 @@ func (c *Compiler) wrapInterfaceInvoke(f *ir.Function) (llvm.Value, error) { fnType := f.LLVMFn.Type().ElementType() paramTypes := append([]llvm.Type{c.i8ptrType}, fnType.ParamTypes()[len(expandedReceiverType):]...) wrapFnType := llvm.FunctionType(fnType.ReturnType(), paramTypes, false) - wrapper := llvm.AddFunction(c.mod, f.LinkName()+"$invoke", wrapFnType) + wrapper = llvm.AddFunction(c.mod, wrapperName, wrapFnType) + c.interfaceInvokeWrappers = append(c.interfaceInvokeWrappers, interfaceInvokeWrapper{ + fn: f, + wrapper: wrapper, + receiverType: receiverType, + }) + return wrapper, nil +} + +// createInterfaceInvokeWrapper finishes the work of getInterfaceInvokeWrapper, +// see that function for details. +func (c *Compiler) createInterfaceInvokeWrapper(state interfaceInvokeWrapper) error { + wrapper := state.wrapper + fn := state.fn + receiverType := state.receiverType wrapper.SetLinkage(llvm.InternalLinkage) wrapper.SetUnnamedAddr(true) - // add debug info + // add debug info if needed if c.Debug { - pos := c.ir.Program.Fset.Position(f.Pos()) - difunc, err := c.attachDebugInfoRaw(f, wrapper, "$invoke", pos.Filename, pos.Line) + pos := c.ir.Program.Fset.Position(fn.Pos()) + difunc, err := c.attachDebugInfoRaw(fn, wrapper, "$invoke", pos.Filename, pos.Line) if err != nil { - return llvm.Value{}, err + return err } c.builder.SetCurrentDebugLocation(uint(pos.Line), uint(pos.Column), difunc, llvm.Metadata{}) } @@ -424,7 +424,7 @@ func (c *Compiler) wrapInterfaceInvoke(f *ir.Function) (llvm.Value, error) { // Load the underlying value. 
receiverPtrType := llvm.PointerType(receiverType, 0) receiverPtr = c.builder.CreateBitCast(wrapper.Param(0), receiverPtrType, "receiver.ptr") - } else if len(expandedReceiverType) != 1 { + } else { // The value is stored in the interface, but it is of type struct which // is expanded to multiple parameters (e.g. {i8, i8}). So we have to // receive the struct as parameter, expand it, and pass it on to the @@ -435,19 +435,17 @@ func (c *Compiler) wrapInterfaceInvoke(f *ir.Function) (llvm.Value, error) { alloca := c.builder.CreateAlloca(c.i8ptrType, "receiver.alloca") c.builder.CreateStore(wrapper.Param(0), alloca) receiverPtr = c.builder.CreateBitCast(alloca, llvm.PointerType(receiverType, 0), "receiver.ptr") - } else { - panic("unreachable") } receiverValue := c.builder.CreateLoad(receiverPtr, "receiver") params := append(c.expandFormalParam(receiverValue), wrapper.Params()[1:]...) - if fnType.ReturnType().TypeKind() == llvm.VoidTypeKind { - c.builder.CreateCall(f.LLVMFn, params, "") + if fn.LLVMFn.Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind { + c.builder.CreateCall(fn.LLVMFn, params, "") c.builder.CreateRetVoid() } else { - ret := c.builder.CreateCall(f.LLVMFn, params, "ret") + ret := c.builder.CreateCall(fn.LLVMFn, params, "ret") c.builder.CreateRet(ret) } - return wrapper, nil + return nil } diff --git a/compiler/optimizer.go b/compiler/optimizer.go index 06c0aae6..08d759b1 100644 --- a/compiler/optimizer.go +++ b/compiler/optimizer.go @@ -1,12 +1,14 @@ package compiler import ( + "errors" + "github.com/aykevl/go-llvm" ) // Run the LLVM optimizer over the module. // The inliner can be disabled (if necessary) by passing 0 to the inlinerThreshold. 
-func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) { +func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) error { builder := llvm.NewPassManagerBuilder() defer builder.Dispose() builder.SetOptLevel(optLevel) @@ -40,7 +42,13 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) { c.OptimizeMaps() c.OptimizeStringToBytes() c.OptimizeAllocs() - c.Verify() + c.LowerInterfaces() + } else { + // Must be run at any optimization level. + c.LowerInterfaces() + } + if err := c.Verify(); err != nil { + return errors.New("optimizations caused a verification failure") } // Run module passes. @@ -48,6 +56,8 @@ func (c *Compiler) Optimize(optLevel, sizeLevel int, inlinerThreshold uint) { defer modPasses.Dispose() builder.Populate(modPasses) modPasses.Run(c.mod) + + return nil } // Eliminate created but not used maps. @@ -299,6 +309,9 @@ func (c *Compiler) hasFlag(call, param llvm.Value, kind string) bool { // Return a list of values (actually, instructions) where this value is used as // an operand. 
func getUses(value llvm.Value) []llvm.Value { + if value.IsNil() { + return nil + } var uses []llvm.Value use := value.FirstUse() for !use.IsNil() { diff --git a/interp/frame.go b/interp/frame.go index b3ba843b..f0218cd7 100644 --- a/interp/frame.go +++ b/interp/frame.go @@ -300,6 +300,8 @@ func (fr *frame) evalBasicBlock(bb, incoming llvm.BasicBlock, indent string) (re ret = llvm.ConstInsertValue(ret, retLen, []uint32{1}) // len ret = llvm.ConstInsertValue(ret, retLen, []uint32{2}) // cap fr.locals[inst] = &LocalValue{fr.Eval, ret} + case callee.Name() == "runtime.makeInterface": + fr.locals[inst] = &LocalValue{fr.Eval, llvm.ConstPtrToInt(inst.Operand(0), fr.TargetData.IntPtrType())} case strings.HasPrefix(callee.Name(), "runtime.print") || callee.Name() == "runtime._panic": // all print instructions, which necessarily have side // effects but no results diff --git a/interp/scan.go b/interp/scan.go index b0d2d3aa..6265f57a 100644 --- a/interp/scan.go +++ b/interp/scan.go @@ -73,7 +73,12 @@ func (e *Eval) hasSideEffects(fn llvm.Value) *sideEffectResult { result.updateSeverity(sideEffectAll) continue } + name := child.Name() if child.IsDeclaration() { + if name == "runtime.makeInterface" { + // Can be interpreted so does not have side effects. + continue + } // External function call. Assume only limited side effects // (no affected globals, etc.). 
if result.hasLocalSideEffects(dirtyLocals, inst) { diff --git a/ir/interpreter.go b/ir/interpreter.go index ac95a92c..ed81478a 100644 --- a/ir/interpreter.go +++ b/ir/interpreter.go @@ -278,8 +278,6 @@ func (p *Program) interpret(instrs []ssa.Instruction, paramKeys []*ssa.Parameter } else { return i, errors.New("todo: init IndexAddr index: " + instr.Index.String()) } - case *ssa.MakeInterface: - locals[instr] = &InterfaceValue{instr.X.Type(), locals[instr.X]} case *ssa.MakeMap: locals[instr] = &MapValue{instr.Type().Underlying().(*types.Map), nil, nil} case *ssa.MapUpdate: @@ -388,7 +386,6 @@ func canInterpret(callee *ssa.Function) bool { case *ssa.Extract: case *ssa.FieldAddr: case *ssa.IndexAddr: - case *ssa.MakeInterface: case *ssa.MakeMap: case *ssa.MapUpdate: case *ssa.Return: @@ -447,8 +444,6 @@ func (p *Program) getZeroValue(t types.Type) (Value, error) { return &ZeroBasicValue{typ}, nil case *types.Signature: return &FunctionValue{typ, nil}, nil - case *types.Interface: - return &InterfaceValue{typ, nil}, nil case *types.Map: return &MapValue{typ, nil, nil}, nil case *types.Pointer: @@ -492,11 +487,6 @@ type FunctionValue struct { Elem *ssa.Function } -type InterfaceValue struct { - Type types.Type - Elem Value -} - type PointerBitCastValue struct { Type types.Type Elem Value diff --git a/ir/ir.go b/ir/ir.go index 9f838a99..5da2e480 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -19,21 +19,18 @@ import ( // View on all functions, types, and globals in a program, with analysis // results. 
type Program struct { - Program *ssa.Program - mainPkg *ssa.Package - Functions []*Function - functionMap map[*ssa.Function]*Function - Globals []*Global - globalMap map[*ssa.Global]*Global - comments map[string]*ast.CommentGroup - NamedTypes []*NamedType - needsScheduler bool - goCalls []*ssa.Go - typesWithMethods map[string]*TypeWithMethods // see AnalyseInterfaceConversions - typesWithoutMethods map[string]int // see AnalyseInterfaceConversions - methodSignatureNames map[string]int // see MethodNum - interfaces map[string]*Interface // see AnalyseInterfaceConversions - fpWithContext map[string]struct{} // see AnalyseFunctionPointers + Program *ssa.Program + mainPkg *ssa.Package + Functions []*Function + functionMap map[*ssa.Function]*Function + Globals []*Global + globalMap map[*ssa.Global]*Global + comments map[string]*ast.CommentGroup + NamedTypes []*NamedType + needsScheduler bool + goCalls []*ssa.Go + typesInInterfaces map[string]struct{} // see AnalyseInterfaceConversions + fpWithContext map[string]struct{} // see AnalyseFunctionPointers } // Function or method. @@ -179,13 +176,11 @@ func NewProgram(lprogram *loader.Program, mainPath string) *Program { } p := &Program{ - Program: program, - mainPkg: mainPkg, - functionMap: make(map[*ssa.Function]*Function), - globalMap: make(map[*ssa.Global]*Global), - methodSignatureNames: make(map[string]int), - interfaces: make(map[string]*Interface), - comments: comments, + Program: program, + mainPkg: mainPkg, + functionMap: make(map[*ssa.Function]*Function), + globalMap: make(map[*ssa.Global]*Global), + comments: comments, } for _, pkg := range packageList { @@ -270,18 +265,6 @@ func (p *Program) GetGlobal(ssaGlobal *ssa.Global) *Global { return p.globalMap[ssaGlobal] } -// SortMethods sorts the list of methods by method ID. -func (p *Program) SortMethods(methods []*types.Selection) { - m := &methodList{methods: methods, program: p} - sort.Sort(m) -} - -// SortFuncs sorts the list of functions by method ID. 
-func (p *Program) SortFuncs(funcs []*types.Func) { - m := &funcList{funcs: funcs, program: p} - sort.Sort(m) -} - func (p *Program) MainPkg() *ssa.Package { return p.mainPkg } @@ -442,46 +425,6 @@ func (p *Program) IsVolatile(t types.Type) bool { } } -// Wrapper type to implement sort.Interface for []*types.Selection. -type methodList struct { - methods []*types.Selection - program *Program -} - -func (m *methodList) Len() int { - return len(m.methods) -} - -func (m *methodList) Less(i, j int) bool { - iid := m.program.MethodNum(m.methods[i].Obj().(*types.Func)) - jid := m.program.MethodNum(m.methods[j].Obj().(*types.Func)) - return iid < jid -} - -func (m *methodList) Swap(i, j int) { - m.methods[i], m.methods[j] = m.methods[j], m.methods[i] -} - -// Wrapper type to implement sort.Interface for []*types.Func. -type funcList struct { - funcs []*types.Func - program *Program -} - -func (fl *funcList) Len() int { - return len(fl.funcs) -} - -func (fl *funcList) Less(i, j int) bool { - iid := fl.program.MethodNum(fl.funcs[i]) - jid := fl.program.MethodNum(fl.funcs[j]) - return iid < jid -} - -func (fl *funcList) Swap(i, j int) { - fl.funcs[i], fl.funcs[j] = fl.funcs[j], fl.funcs[i] -} - // Return true if this is a CGo-internal function that can be ignored. func isCGoInternal(name string) bool { if strings.HasPrefix(name, "_Cgo_") || strings.HasPrefix(name, "_cgo") { diff --git a/ir/passes.go b/ir/passes.go index 46133a10..e2dbbb88 100644 --- a/ir/passes.go +++ b/ir/passes.go @@ -2,8 +2,6 @@ package ir import ( "go/types" - "sort" - "strings" "golang.org/x/tools/go/ssa" ) @@ -59,18 +57,6 @@ func Signature(sig *types.Signature) string { return s } -// Convert an interface type to a string of all method strings, separated by -// "; ". 
For example: "Read([]byte) (int, error); Close() error" -func InterfaceKey(itf *types.Interface) string { - methodStrings := []string{} - for i := 0; i < itf.NumMethods(); i++ { - method := itf.Method(i) - methodStrings = append(methodStrings, MethodSignature(method)) - } - sort.Strings(methodStrings) - return strings.Join(methodStrings, ";") -} - // Fill in parents of all functions. // // All packages need to be added before this pass can run, or it will produce @@ -117,30 +103,17 @@ func (p *Program) AnalyseCallgraph() { // Find all types that are put in an interface. func (p *Program) AnalyseInterfaceConversions() { - // Clear, if AnalyseTypes has been called before. - p.typesWithoutMethods = map[string]int{"nil": 0} - p.typesWithMethods = map[string]*TypeWithMethods{} + // Clear, if AnalyseInterfaceConversions has been called before. + p.typesInInterfaces = map[string]struct{}{} for _, f := range p.Functions { for _, block := range f.Blocks { for _, instr := range block.Instrs { switch instr := instr.(type) { case *ssa.MakeInterface: - methods := getAllMethods(f.Prog, instr.X.Type()) name := instr.X.Type().String() - if _, ok := p.typesWithMethods[name]; !ok && len(methods) > 0 { - t := &TypeWithMethods{ - t: instr.X.Type(), - Num: len(p.typesWithMethods), - Methods: make(map[string]*types.Selection), - } - for _, sel := range methods { - name := MethodSignature(sel.Obj().(*types.Func)) - t.Methods[name] = sel - } - p.typesWithMethods[name] = t - } else if _, ok := p.typesWithoutMethods[name]; !ok && len(methods) == 0 { - p.typesWithoutMethods[name] = len(p.typesWithoutMethods) + if _, ok := p.typesInInterfaces[name]; !ok { + p.typesInInterfaces[name] = struct{}{} } } } @@ -349,75 +322,10 @@ func (p *Program) IsBlocking(f *Function) bool { return f.blocking } -// Return the type number and whether this type is actually used. 
Used in -// interface conversions (type is always used) and type asserts (type may not be -// used, meaning assert is always false in this program). -// -// May only be used after all packages have been added to the analyser. -func (p *Program) TypeNum(typ types.Type) (int, bool) { - if n, ok := p.typesWithoutMethods[typ.String()]; ok { - return n, true - } else if meta, ok := p.typesWithMethods[typ.String()]; ok { - return len(p.typesWithoutMethods) + meta.Num, true - } else { - return -1, false // type is never put in an interface - } -} - -// InterfaceNum returns the numeric interface ID of this type, for use in type -// asserts. -func (p *Program) InterfaceNum(itfType *types.Interface) int { - key := InterfaceKey(itfType) - if itf, ok := p.interfaces[key]; !ok { - num := len(p.interfaces) - p.interfaces[key] = &Interface{Num: num, Type: itfType} - return num - } else { - return itf.Num - } -} - -// MethodNum returns the numeric ID of this method, to be used in method lookups -// on interfaces for example. -func (p *Program) MethodNum(method *types.Func) int { - name := MethodSignature(method) - if _, ok := p.methodSignatureNames[name]; !ok { - p.methodSignatureNames[name] = len(p.methodSignatureNames) - } - return p.methodSignatureNames[MethodSignature(method)] -} - -// The start index of the first dynamic type that has methods. -// Types without methods always have a lower ID and types with methods have this -// or a higher ID. -// -// May only be used after all packages have been added to the analyser. -func (p *Program) FirstDynamicType() int { - return len(p.typesWithoutMethods) -} - -// Return all types with methods, sorted by type ID. -func (p *Program) AllDynamicTypes() []*TypeWithMethods { - l := make([]*TypeWithMethods, len(p.typesWithMethods)) - for _, m := range p.typesWithMethods { - l[m.Num] = m - } - return l -} - -// Return all interface types, sorted by interface ID. 
-func (p *Program) AllInterfaces() []*Interface { - l := make([]*Interface, len(p.interfaces)) - for _, itf := range p.interfaces { - l[itf.Num] = itf - } - return l -} - func (p *Program) FunctionNeedsContext(f *Function) bool { if !f.addressTaken { if f.Signature.Recv() != nil { - _, hasInterfaceConversion := p.TypeNum(f.Signature.Recv().Type()) + _, hasInterfaceConversion := p.typesInInterfaces[f.Signature.Recv().Type().String()] if hasInterfaceConversion && p.SignatureNeedsContext(f.Signature) { return true } diff --git a/main.go b/main.go index 7466b8ad..d85554ac 100644 --- a/main.go +++ b/main.go @@ -64,7 +64,7 @@ func Compile(pkgName, outpath string, spec *TargetSpec, config *BuildConfig, act fmt.Println(c.IR()) } if err := c.Verify(); err != nil { - return err + return errors.New("verification error after IR construction") } if config.initInterp { @@ -73,13 +73,13 @@ func Compile(pkgName, outpath string, spec *TargetSpec, config *BuildConfig, act return err } if err := c.Verify(); err != nil { - return err + return errors.New("verification error after interpreting runtime.initAll") } } c.ApplyFunctionSections() // -ffunction-sections if err := c.Verify(); err != nil { - return err + return errors.New("verification error after applying function sections") } // Browsers cannot handle external functions that have type i64 because it @@ -92,7 +92,7 @@ func Compile(pkgName, outpath string, spec *TargetSpec, config *BuildConfig, act return err } if err := c.Verify(); err != nil { - return err + return errors.New("verification error after running the wasm i64 hack") } } @@ -100,20 +100,23 @@ func Compile(pkgName, outpath string, spec *TargetSpec, config *BuildConfig, act // exactly. 
switch config.opt { case "none:", "0": - c.Optimize(0, 0, 0) // -O0 + err = c.Optimize(0, 0, 0) // -O0 case "1": - c.Optimize(1, 0, 0) // -O1 + err = c.Optimize(1, 0, 0) // -O1 case "2": - c.Optimize(2, 0, 225) // -O2 + err = c.Optimize(2, 0, 225) // -O2 case "s": - c.Optimize(2, 1, 225) // -Os + err = c.Optimize(2, 1, 225) // -Os case "z": - c.Optimize(2, 2, 5) // -Oz, default + err = c.Optimize(2, 2, 5) // -Oz, default default: - return errors.New("unknown optimization level: -opt=" + config.opt) + err = errors.New("unknown optimization level: -opt=" + config.opt) + } + if err != nil { + return err } if err := c.Verify(); err != nil { - return err + return errors.New("verification failure after LLVM optimization passes") } // On the AVR, pointers can point either to flash or to RAM, but we don't @@ -124,7 +127,7 @@ func Compile(pkgName, outpath string, spec *TargetSpec, config *BuildConfig, act if strings.HasPrefix(spec.Triple, "avr") { c.NonConstGlobals() if err := c.Verify(); err != nil { - return err + return errors.New("verification error after making all globals non-constant on AVR") } } @@ -382,7 +385,10 @@ func Run(pkgName string) error { // -Oz, which is the fastest optimization level (faster than -O0, -O1, -O2 // and -Os). Turn off the inliner, as the inliner increases optimization // time. - c.Optimize(2, 2, 0) + err = c.Optimize(2, 2, 0) + if err != nil { + return err + } engine, err := llvm.NewExecutionEngine(c.Module()) if err != nil { diff --git a/src/runtime/interface.go b/src/runtime/interface.go index b29b0766..7eba2a79 100644 --- a/src/runtime/interface.go +++ b/src/runtime/interface.go @@ -4,64 +4,12 @@ package runtime // // Interfaces are represented as a pair of {typecode, value}, where value can be // anything (including non-pointers). 
-// -// Signatures itself are not matched on strings, but on uniqued numbers that -// contain the name and the signature of the function (to save space), think of -// signatures as interned strings at compile time. -// -// The typecode is a small number unique for the Go type. All typecodes < -// firstTypeWithMethods do not have any methods and typecodes >= -// firstTypeWithMethods all have at least one method. This means that -// methodSetRanges does not need to contain types without methods and is thus -// indexed starting at a typecode with number firstTypeWithMethods. -// -// To further conserve some space, the methodSetRange (as the name indicates) -// doesn't contain a list of methods and function pointers directly, but instead -// just indexes into methodSetSignatures and methodSetFunctions which contains -// the mapping from uniqued signature to function pointer. type _interface struct { - typecode uint16 + typecode uintptr value *uint8 } -// This struct indicates the range of methods in the methodSetSignatures and -// methodSetFunctions arrays that belong to this named type. -type methodSetRange struct { - index uint16 // start index into interfaceSignatures and interfaceFunctions - length uint16 // number of methods -} - -// Global constants that will be set by the compiler. The arrays are of size 0, -// which is a dummy value, but will be bigger after the compiler has filled them -// in. 
-var ( - firstTypeWithMethods uint16 // the lowest typecode that has at least one method - methodSetRanges [0]methodSetRange // indices into methodSetSignatures and methodSetFunctions - methodSetSignatures [0]uint16 // uniqued method ID - methodSetFunctions [0]*uint8 // function pointer of method - interfaceIndex [0]uint16 // mapping from interface ID to an index in interfaceMethods - interfaceLengths [0]uint8 // mapping from interface ID to the number of methods it has - interfaceMethods [0]uint16 // the method an interface implements (list of method IDs) -) - -// Get the function pointer for the method on the interface. -// This is a compiler intrinsic. -//go:nobounds -func interfaceMethod(typecode uint16, method uint16) *uint8 { - // This function doesn't do bounds checking as the supplied method must be - // in the list of signatures. The compiler will only emit - // runtime.interfaceMethod calls when the method actually exists on this - // interface (proven by the typechecker). - i := methodSetRanges[typecode-firstTypeWithMethods].index - for { - if methodSetSignatures[i] == method { - return methodSetFunctions[i] - } - i++ - } -} - // Return true iff both interfaces are equal. func interfaceEqual(x, y _interface) bool { if x.typecode != y.typecode { @@ -76,67 +24,37 @@ func interfaceEqual(x, y _interface) bool { panic("unimplemented: interface equality") } -// Return true iff the type implements all methods needed by the interface. This -// means the type satisfies the interface. -// This is a compiler intrinsic. -//go:nobounds -func interfaceImplements(typecode, interfaceNum uint16) bool { - // method set indices of the interface - itfIndex := interfaceIndex[interfaceNum] - itfIndexEnd := itfIndex + uint16(interfaceLengths[interfaceNum]) - - if itfIndex == itfIndexEnd { - // This interface has no methods, so it satisfies all types. 
- // TODO: this should be figured out at compile time (as it is known at - // compile time), so that this check is unnecessary at runtime. - return true - } - - if typecode < firstTypeWithMethods { - // Type has no methods while the interface has (checked above), so this - // type does not satisfy this interface. - return false - } - - // method set indices of the concrete type - methodSet := methodSetRanges[typecode-firstTypeWithMethods] - methodIndex := methodSet.index - methodIndexEnd := methodSet.index + methodSet.length - - // Iterate over all methods of the interface: - for itfIndex < itfIndexEnd { - methodId := interfaceMethods[itfIndex] - if methodIndex >= methodIndexEnd { - // Reached the end of the list of methods, so interface doesn't - // implement this type. - return false - } - if methodId == methodSetSignatures[methodIndex] { - // Found a matching method, continue to the next method. - itfIndex++ - methodIndex++ - continue - } else if methodId > methodSetSignatures[methodIndex] { - // The method didn't match, but method ID of the concrete type was - // lower than that of the interface, so probably it has a method the - // interface doesn't implement. - // Move on to the next method of the concrete type. - methodIndex++ - continue - } else { - // The concrete type is missing a method. This means the type assert - // fails. - return false - } - } - - // Found a method for each expected method in the interface. This type - // assert is successful. - return true -} - +// interfaceTypeAssert is called when a type assert without comma-ok still +// returns false. func interfaceTypeAssert(ok bool) { if !ok { runtimePanic("type assert failed") } } + +// The following declarations are only used during IR construction. They are +// lowered to inline IR in the interface lowering pass. +// See compiler/interface-lowering.go for details. 
+ +type interfaceMethodInfo struct { + signature *uint8 // external *i8 with a name identifying the Go function signature + funcptr *uint8 // bitcast from the actual function pointer +} + +// Pseudo function call used while putting a concrete value in an interface, +// that must be lowered to a constant uintptr. +func makeInterface(typecode *uint8, methodSet *interfaceMethodInfo) uintptr + +// Pseudo function call used during a type assert. It is used during interface +// lowering, to assign the lowest type numbers to the types with the most type +// asserts. Also, it is replaced with const false if this type assert can never +// happen. +func typeAssert(actualType uintptr, assertedType *uint8) bool + +// Pseudo function call that returns whether a given type implements all methods +// of the given interface. +func interfaceImplements(typecode uintptr, interfaceMethodSet **uint8) bool + +// Pseudo function that returns a function pointer to the method to call. +// See the interface lowering pass for how this is lowered to a real call. 
+func interfaceMethod(typecode uintptr, interfaceMethodSet **uint8, signature *uint8) *uint8 diff --git a/src/runtime/print.go b/src/runtime/print.go index abde9a20..e0d6fc43 100644 --- a/src/runtime/print.go +++ b/src/runtime/print.go @@ -205,7 +205,14 @@ func printitf(msg interface{}) { // cast to underlying type itf := *(*_interface)(unsafe.Pointer(&msg)) putchar('(') - print(itf.typecode) + switch unsafe.Sizeof(itf.typecode) { + case 2: + printuint16(uint16(itf.typecode)) + case 4: + printuint32(uint32(itf.typecode)) + case 8: + printuint64(uint64(itf.typecode)) + } putchar(':') print(itf.value) putchar(')') diff --git a/testdata/interface.go b/testdata/interface.go index e931934a..3c6871d8 100644 --- a/testdata/interface.go +++ b/testdata/interface.go @@ -26,6 +26,8 @@ func main() { func printItf(val interface{}) { switch val := val.(type) { + case Unmatched: + panic("matched the unmatchable") case Doubler: println("is Doubler:", val.Double()) case Tuple: @@ -127,3 +129,8 @@ func (p SmallPair) Nth(n int) uint32 { func (p SmallPair) Print() { println("SmallPair.Print:", p.a, p.b) } + +// There is no type that matches this method. +type Unmatched interface { + NeverImplementedMethod() +}