diff --git a/compiler.go b/compiler.go index 661b9ff6..e215e75c 100644 --- a/compiler.go +++ b/compiler.go @@ -426,7 +426,12 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { } rangeValue := llvm.ConstNamedStruct(rangeType, rangeValues) ranges = append(ranges, rangeValue) + methods := make([]*types.Selection, 0, len(meta.Methods)) for _, method := range meta.Methods { + methods = append(methods, method) + } + c.ir.SortMethods(methods) + for _, method := range methods { f := c.ir.GetFunction(program.MethodValue(method)) if f.llvmFn.IsNil() { return errors.New("cannot find function: " + f.LinkName()) @@ -440,6 +445,25 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { startIndex += len(meta.Methods) } + interfaceTypes := c.ir.AllInterfaces() + interfaceLengths := make([]llvm.Value, len(interfaceTypes)) + interfaceMethods := make([]llvm.Value, 0) + for i, itfType := range interfaceTypes { + if itfType.Type.NumMethods() > 0xff { + return errors.New("too many methods for interface " + itfType.Type.String()) + } + interfaceLengths[i] = llvm.ConstInt(llvm.Int8Type(), uint64(itfType.Type.NumMethods()), false) + funcs := make([]*types.Func, itfType.Type.NumMethods()) + for i := range funcs { + funcs[i] = itfType.Type.Method(i) + } + c.ir.SortFuncs(funcs) + for _, f := range funcs { + id := llvm.ConstInt(llvm.Int16Type(), uint64(c.ir.MethodNum(f)), false) + interfaceMethods = append(interfaceMethods, id) + } + } + if len(ranges) >= 1<<16 { return errors.New("method call numbers do not fit in a 16-bit integer") } @@ -469,8 +493,24 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { signatureArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(signatureArrayNewGlobal, signatureArrayOldGlobal.Type())) signatureArrayOldGlobal.EraseFromParentAsGlobal() signatureArrayNewGlobal.SetName("runtime.methodSetSignatures") + interfaceLengthsArray := llvm.ConstArray(llvm.Int8Type(), interfaceLengths) + interfaceLengthsArrayNewGlobal := llvm.AddGlobal(c.mod, interfaceLengthsArray.Type(), "runtime.interfaceLengths.tmp") + interfaceLengthsArrayNewGlobal.SetInitializer(interfaceLengthsArray) + interfaceLengthsArrayNewGlobal.SetLinkage(llvm.InternalLinkage) + interfaceLengthsArrayOldGlobal := c.mod.NamedGlobal("runtime.interfaceLengths") + interfaceLengthsArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(interfaceLengthsArrayNewGlobal, interfaceLengthsArrayOldGlobal.Type())) + interfaceLengthsArrayOldGlobal.EraseFromParentAsGlobal() + interfaceLengthsArrayNewGlobal.SetName("runtime.interfaceLengths") + interfaceMethodsArray := llvm.ConstArray(llvm.Int16Type(), interfaceMethods) + interfaceMethodsArrayNewGlobal := llvm.AddGlobal(c.mod, interfaceMethodsArray.Type(), "runtime.interfaceMethods.tmp") + interfaceMethodsArrayNewGlobal.SetInitializer(interfaceMethodsArray) + interfaceMethodsArrayNewGlobal.SetLinkage(llvm.InternalLinkage) + interfaceMethodsArrayOldGlobal := c.mod.NamedGlobal("runtime.interfaceMethods") + interfaceMethodsArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(interfaceMethodsArrayNewGlobal, interfaceMethodsArrayOldGlobal.Type())) + interfaceMethodsArrayOldGlobal.EraseFromParentAsGlobal() + interfaceMethodsArrayNewGlobal.SetName("runtime.interfaceMethods") - c.mod.NamedGlobal("runtime.firstInterfaceNum").SetInitializer(llvm.ConstInt(llvm.Int16Type(), uint64(c.ir.FirstDynamicType()), false)) + c.mod.NamedGlobal("runtime.firstTypeWithMethods").SetInitializer(llvm.ConstInt(llvm.Int16Type(), uint64(c.ir.FirstDynamicType()), false)) // see: https://reviews.llvm.org/D18355 c.mod.AddNamedMetadataOperand("llvm.module.flags", @@ -2238,25 +2278,43 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if err != nil { return llvm.Value{}, err } - if _, ok := expr.AssertedType.Underlying().(*types.Interface); ok { - // TODO: check whether the type implements the interface. - return llvm.Value{}, errors.New("todo: assert on interface") - } + assertedType, err := c.getLLVMType(expr.AssertedType) if err != nil { return llvm.Value{}, err } - assertedTypeNum, typeExists := c.ir.TypeNum(expr.AssertedType) - if !typeExists { - // Static analysis has determined this type assert will never apply. - return llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.ConstInt(llvm.Int1Type(), 0, false)}, false), nil + valueNil, err := getZeroValue(assertedType) + if err != nil { + return llvm.Value{}, err } - if assertedTypeNum >= 1<<16 { - return llvm.Value{}, errors.New("interface typecodes do not fit in a 16-bit integer") - } - actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type") - commaOk := c.builder.CreateICmp(llvm.IntEQ, llvm.ConstInt(llvm.Int16Type(), uint64(assertedTypeNum), false), actualTypeNum, "") + actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type") + commaOk := llvm.Value{} + if itf, ok := expr.AssertedType.Underlying().(*types.Interface); ok { + // Type assert on interface type. + // This is slightly non-trivial: at runtime the list of methods + // needs to be checked to see whether it implements the interface. + // At the same time, the interface value itself is unchanged. + itfTypeNum := c.ir.InterfaceNum(itf) + itfTypeNumValue := llvm.ConstInt(llvm.Int16Type(), uint64(itfTypeNum), false) + fn := c.mod.NamedFunction("runtime.interfaceImplements") + commaOk = c.builder.CreateCall(fn, []llvm.Value{actualTypeNum, itfTypeNumValue}, "") + + } else { + // Type assert on concrete type. + // This is easy: just compare the type number. + assertedTypeNum, typeExists := c.ir.TypeNum(expr.AssertedType) + if !typeExists { + // Static analysis has determined this type assert will never apply. + return llvm.ConstStruct([]llvm.Value{valueNil, llvm.ConstInt(llvm.Int1Type(), 0, false)}, false), nil + } + if assertedTypeNum >= 1<<16 { + return llvm.Value{}, errors.New("interface typecodes do not fit in a 16-bit integer") + } + + assertedTypeNumValue := llvm.ConstInt(llvm.Int16Type(), uint64(assertedTypeNum), false) + commaOk = c.builder.CreateICmp(llvm.IntEQ, assertedTypeNumValue, actualTypeNum, "") + } // Add 2 new basic blocks (that should get optimized away): one for the // 'ok' case and one for all instructions following this type assert. @@ -2269,11 +2327,6 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { // typeassert should return a zero value, not an incorrectly casted // value. - valueNil, err := getZeroValue(assertedType) - if err != nil { - return llvm.Value{}, err - } - prevBlock := c.builder.GetInsertBlock() okBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "typeassert.ok") nextBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "typeassert.next") @@ -2282,29 +2335,37 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { // Retrieve the value from the interface if the type assert was // successful. c.builder.SetInsertPointAtEnd(okBlock) - valuePtr := c.builder.CreateExtractValue(itf, 1, "typeassert.value.ptr") var valueOk llvm.Value - if c.targetData.TypeAllocSize(assertedType) > c.targetData.TypeAllocSize(c.i8ptrType) { - // Value was stored in an allocated buffer, load it from there. - valuePtrCast := c.builder.CreateBitCast(valuePtr, llvm.PointerType(assertedType, 0), "") - valueOk = c.builder.CreateLoad(valuePtrCast, "typeassert.value.ok") + if _, ok := expr.AssertedType.Underlying().(*types.Interface); ok { + // Type assert on interface type. Easy: just return the same + // interface value. + valueOk = itf } else { - // Value was stored directly in the interface. - switch assertedType.TypeKind() { - case llvm.IntegerTypeKind: - valueOk = c.builder.CreatePtrToInt(valuePtr, assertedType, "typeassert.value.ok") - case llvm.PointerTypeKind: - valueOk = c.builder.CreateBitCast(valuePtr, assertedType, "typeassert.value.ok") - case llvm.StructTypeKind: - // A bitcast would be useful here, but bitcast doesn't allow - // aggregate types. So we'll bitcast it using an alloca. - // Hopefully this will get optimized away. - mem := c.builder.CreateAlloca(c.i8ptrType, "") - c.builder.CreateStore(valuePtr, mem) - memStructPtr := c.builder.CreateBitCast(mem, llvm.PointerType(assertedType, 0), "") - valueOk = c.builder.CreateLoad(memStructPtr, "typeassert.value.ok") - default: - return llvm.Value{}, errors.New("todo: typeassert: bitcast small types") + // Type assert on concrete type. Extract the underlying type from + // the interface (but only after checking it matches). + valuePtr := c.builder.CreateExtractValue(itf, 1, "typeassert.value.ptr") + if c.targetData.TypeAllocSize(assertedType) > c.targetData.TypeAllocSize(c.i8ptrType) { + // Value was stored in an allocated buffer, load it from there. + valuePtrCast := c.builder.CreateBitCast(valuePtr, llvm.PointerType(assertedType, 0), "") + valueOk = c.builder.CreateLoad(valuePtrCast, "typeassert.value.ok") + } else { + // Value was stored directly in the interface. + switch assertedType.TypeKind() { + case llvm.IntegerTypeKind: + valueOk = c.builder.CreatePtrToInt(valuePtr, assertedType, "typeassert.value.ok") + case llvm.PointerTypeKind: + valueOk = c.builder.CreateBitCast(valuePtr, assertedType, "typeassert.value.ok") + case llvm.StructTypeKind: + // A bitcast would be useful here, but bitcast doesn't allow + // aggregate types. So we'll bitcast it using an alloca. + // Hopefully this will get optimized away. + mem := c.builder.CreateAlloca(c.i8ptrType, "") + c.builder.CreateStore(valuePtr, mem) + memStructPtr := c.builder.CreateBitCast(mem, llvm.PointerType(assertedType, 0), "") + valueOk = c.builder.CreateLoad(memStructPtr, "typeassert.value.ok") + default: + return llvm.Value{}, errors.New("todo: typeassert: bitcast small types") + } } } c.builder.CreateBr(nextBlock) diff --git a/ir.go b/ir.go index cf6c0958..52ec5326 100644 --- a/ir.go +++ b/ir.go @@ -25,10 +25,11 @@ type Program struct { NamedTypes []*NamedType needsScheduler bool goCalls []*ssa.Go - typesWithMethods map[string]*InterfaceType // see AnalyseInterfaceConversions - typesWithoutMethods map[string]int // see AnalyseInterfaceConversions - methodSignatureNames map[string]int - fpWithContext map[string]struct{} // see AnalyseFunctionPointers + typesWithMethods map[string]*TypeWithMethods // see AnalyseInterfaceConversions + typesWithoutMethods map[string]int // see AnalyseInterfaceConversions + methodSignatureNames map[string]int // see MethodNum + interfaces map[string]*Interface // see AnalyseInterfaceConversions + fpWithContext map[string]struct{} // see AnalyseFunctionPointers } // Function or method. @@ -60,12 +61,19 @@ type NamedType struct { } // Type that is at some point put in an interface. -type InterfaceType struct { +type TypeWithMethods struct { t types.Type Num int Methods map[string]*types.Selection } +// Interface type that is at some point used in a type assert (to check whether +// it implements another interface). +type Interface struct { + Num int + Type *types.Interface +} + // Create and intialize a new *Program from a *ssa.Program. func NewProgram(program *ssa.Program, mainPath string) *Program { return &Program{ @@ -74,6 +82,7 @@ func NewProgram(program *ssa.Program, mainPath string) *Program { functionMap: make(map[*ssa.Function]*Function), globalMap: make(map[*ssa.Global]*Global), methodSignatureNames: make(map[string]int), + interfaces: make(map[string]*Interface), } } @@ -148,6 +157,18 @@ func (p *Program) GetGlobal(ssaGlobal *ssa.Global) *Global { return p.globalMap[ssaGlobal] } +// SortMethods sorts the list of methods by method ID. +func (p *Program) SortMethods(methods []*types.Selection) { + m := &methodList{methods: methods, program: p} + sort.Sort(m) +} + +// SortFuncs sorts the list of functions by method ID. +func (p *Program) SortFuncs(funcs []*types.Func) { + m := &funcList{funcs: funcs, program: p} + sort.Sort(m) +} + // Parse compiler directives in the preceding comments. func (f *Function) parsePragmas() { if f.fn.Syntax() == nil { @@ -236,3 +257,43 @@ func (g *Global) LinkName() string { func (g *Global) IsExtern() bool { return strings.HasPrefix(g.g.Name(), "_extern_") } + +// Wrapper type to implement sort.Interface for []*types.Selection. +type methodList struct { + methods []*types.Selection + program *Program +} + +func (m *methodList) Len() int { + return len(m.methods) +} + +func (m *methodList) Less(i, j int) bool { + iid := m.program.MethodNum(m.methods[i].Obj().(*types.Func)) + jid := m.program.MethodNum(m.methods[j].Obj().(*types.Func)) + return iid < jid +} + +func (m *methodList) Swap(i, j int) { + m.methods[i], m.methods[j] = m.methods[j], m.methods[i] +} + +// Wrapper type to implement sort.Interface for []*types.Func. +type funcList struct { + funcs []*types.Func + program *Program +} + +func (fl *funcList) Len() int { + return len(fl.funcs) +} + +func (fl *funcList) Less(i, j int) bool { + iid := fl.program.MethodNum(fl.funcs[i]) + jid := fl.program.MethodNum(fl.funcs[j]) + return iid < jid +} + +func (fl *funcList) Swap(i, j int) { + fl.funcs[i], fl.funcs[j] = fl.funcs[j], fl.funcs[i] +} diff --git a/passes.go b/passes.go index eb784f18..82962bf3 100644 --- a/passes.go +++ b/passes.go @@ -2,6 +2,9 @@ package main import ( "go/types" + "sort" + "strings" + "golang.org/x/tools/go/ssa" ) @@ -56,6 +59,18 @@ func Signature(sig *types.Signature) string { return s } +// Convert an interface type to a string of all method strings, separated by +// "; ". For example: "Read([]byte) (int, error); Close() error" +func InterfaceKey(itf *types.Interface) string { + methodStrings := []string{} + for i := 0; i < itf.NumMethods(); i++ { + method := itf.Method(i) + methodStrings = append(methodStrings, MethodSignature(method)) + } + sort.Strings(methodStrings) + return strings.Join(methodStrings, ";") +} + // Fill in parents of all functions. // // All packages need to be added before this pass can run, or it will produce @@ -104,7 +119,7 @@ func (p *Program) AnalyseCallgraph() { func (p *Program) AnalyseInterfaceConversions() { // Clear, if AnalyseTypes has been called before. p.typesWithoutMethods = map[string]int{"nil": 0} - p.typesWithMethods = map[string]*InterfaceType{} + p.typesWithMethods = map[string]*TypeWithMethods{} for _, f := range p.Functions { for _, block := range f.fn.Blocks { @@ -114,7 +129,7 @@ func (p *Program) AnalyseInterfaceConversions() { methods := getAllMethods(f.fn.Prog, instr.X.Type()) name := instr.X.Type().String() if _, ok := p.typesWithMethods[name]; !ok && len(methods) > 0 { - t := &InterfaceType{ + t := &TypeWithMethods{ t: instr.X.Type(), Num: len(p.typesWithMethods), Methods: make(map[string]*types.Selection), @@ -271,7 +286,13 @@ func (p *Program) SimpleDCE() { for _, instr := range block.Instrs { if instr, ok := instr.(*ssa.MakeInterface); ok { for _, sel := range getAllMethods(p.program, instr.X.Type()) { - callee := p.GetFunction(p.program.MethodValue(sel)) + fn := p.program.MethodValue(sel) + callee := p.GetFunction(fn) + if callee == nil { + // TODO: why is this necessary? + p.addFunction(fn) + callee = p.GetFunction(fn) + } if !callee.flag { callee.flag = true worklist = append(worklist, callee.fn) @@ -361,6 +382,19 @@ func (p *Program) TypeNum(typ types.Type) (int, bool) { } } +// InterfaceNum returns the numeric interface ID of this type, for use in type +// asserts. +func (p *Program) InterfaceNum(itfType *types.Interface) int { + key := InterfaceKey(itfType) + if itf, ok := p.interfaces[key]; !ok { + num := len(p.interfaces) + p.interfaces[key] = &Interface{Num: num, Type: itfType} + return num + } else { + return itf.Num + } +} + // MethodNum returns the numeric ID of this method, to be used in method lookups // on interfaces for example. func (p *Program) MethodNum(method *types.Func) int { @@ -381,14 +415,23 @@ func (p *Program) FirstDynamicType() int { } // Return all types with methods, sorted by type ID. -func (p *Program) AllDynamicTypes() []*InterfaceType { - l := make([]*InterfaceType, len(p.typesWithMethods)) +func (p *Program) AllDynamicTypes() []*TypeWithMethods { + l := make([]*TypeWithMethods, len(p.typesWithMethods)) for _, m := range p.typesWithMethods { l[m.Num] = m } return l } +// Return all interface types, sorted by interface ID. +func (p *Program) AllInterfaces() []*Interface { + l := make([]*Interface, len(p.interfaces)) + for _, itf := range p.interfaces { + l[itf.Num] = itf + } + return l +} + func (p *Program) FunctionNeedsContext(f *Function) bool { if !f.addressTaken { return false diff --git a/src/examples/test/test.go b/src/examples/test/test.go index 537bc77d..4d4b83f4 100644 --- a/src/examples/test/test.go +++ b/src/examples/test/test.go @@ -16,6 +16,18 @@ type Stringer interface { String() string } +type Foo int + +type Number int + +func (n Number) Double() int { + return int(n) * 2 +} + +type Doubler interface { + Double() int +} + const SIX = 6 var testmap = map[string]int{"data": 3} @@ -54,6 +66,7 @@ func main() { printItf(*thing) printItf(thing) printItf(Stringer(thing)) + printItf(Number(3)) s := Stringer(thing) println("Stringer.String():", s.String()) @@ -107,6 +120,8 @@ func strlen(s string) int { func printItf(val interface{}) { switch val := val.(type) { + case Doubler: + println("is Doubler:", val.Double()) case int: println("is int:", val) case byte: @@ -117,6 +132,8 @@ func printItf(val interface{}) { println("is Thing:", val.String()) case *Thing: println("is *Thing:", val.String()) + case Foo: + println("is Foo:", val) default: println("is ?") } diff --git a/src/runtime/interface.go b/src/runtime/interface.go index aaaa575d..455788d6 100644 --- a/src/runtime/interface.go +++ b/src/runtime/interface.go @@ -10,10 +10,10 @@ package runtime // signatures as interned strings at compile time. // // The typecode is a small number unique for the Go type. All typecodes < -// firstInterfaceNum do not have any methods and typecodes >= firstInterfaceNum -// all have at least one method. This means that methodSetRanges does not need -// to contain types without methods and is thus indexed starting at a typecode -// with number firstInterfaceNum. +// firstTypeWithMethods do not have any methods and typecodes >= +// firstTypeWithMethods all have at least one method. This means that +// methodSetRanges does not need to contain types without methods and is thus +// indexed starting at a typecode with number firstTypeWithMethods. // // To further conserve some space, the methodSetRange (as the name indicates) // doesn't contain a list of methods and function pointers directly, but instead @@ -36,10 +36,13 @@ type methodSetRange struct { // which is a dummy value, but will be bigger after the compiler has filled them // in. var ( - firstInterfaceNum uint16 // the lowest typecode that has at least one method - methodSetRanges [0]methodSetRange // indexes into methodSetSignatures and methodSetFunctions - methodSetSignatures [0]uint16 // uniqued method ID - methodSetFunctions [0]*uint8 // function pointer of method + firstTypeWithMethods uint16 // the lowest typecode that has at least one method + methodSetRanges [0]methodSetRange // indexes into methodSetSignatures and methodSetFunctions + methodSetSignatures [0]uint16 // uniqued method ID + methodSetFunctions [0]*uint8 // function pointer of method + interfaceIndex [0]uint16 // mapping from interface ID to an index in interfaceMethods + interfaceLengths [0]uint8 // mapping from interface ID to the number of methods it has + interfaceMethods [0]uint16 // the method an interface implements (list of method IDs) ) // Get the function pointer for the method on the interface. @@ -50,7 +53,7 @@ func interfaceMethod(itf _interface, method uint16) *uint8 { // in the list of signatures. The compiler will only emit // runtime.interfaceMethod calls when the method actually exists on this // interface (proven by the typechecker). - i := methodSetRanges[itf.typecode-firstInterfaceNum].index + i := methodSetRanges[itf.typecode-firstTypeWithMethods].index for { if methodSetSignatures[i] == method { return methodSetFunctions[i] @@ -72,3 +75,49 @@ func interfaceEqual(x, y _interface) bool { // TODO: depends on reflection. panic("unimplemented: interface equality") } + +// Return true iff the type implements all methods needed by the interface. This +// means the type satisfies the interface. +// This is a compiler intrinsic. +//go:nobounds +func interfaceImplements(typecode, interfaceNum uint16) bool { + // method set indexes of the concrete type + methodSet := methodSetRanges[typecode-firstTypeWithMethods] + methodIndex := methodSet.index + methodIndexEnd := methodSet.index + methodSet.length + + // method set indexes of the interface + itfIndex := interfaceIndex[interfaceNum] + itfIndexEnd := itfIndex + uint16(interfaceLengths[interfaceNum]) + + // Iterate over all methods of the interface: + for itfIndex < itfIndexEnd { + methodId := interfaceMethods[itfIndex] + if methodIndex >= methodIndexEnd { + // Reached the end of the list of methods, so interface doesn't + // implement this type. + return false + } + if methodId == methodSetSignatures[methodIndex] { + // Found a matching method, continue to the next method. + itfIndex++ + methodIndex++ + continue + } else if methodId > methodSetSignatures[methodIndex] { + // The method didn't match, but method ID of the concrete type was + // lower than that of the interface, so probably it has a method the + // interface doesn't implement. + // Move on to the next method of the concrete type. + methodIndex++ + continue + } else { + // The concrete type is missing a method. This means the type assert + // fails. + return false + } + } + + // Found a method for each expected method in the interface. This type + // assert is successful. + return true +}