compiler: implement interface assertions

This is a lot harder than 'regular' type assertions as the actual
methods need to be checked.
Этот коммит содержится в:
Ayke van Laethem 2018-09-06 20:18:18 +02:00
родитель 30ac6ec281
коммит 43b8c24226
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: E97FF5335DFDFDED
5 изменённых файлов: 290 добавлений и 59 удалений

Просмотреть файл

@ -426,7 +426,12 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
}
rangeValue := llvm.ConstNamedStruct(rangeType, rangeValues)
ranges = append(ranges, rangeValue)
methods := make([]*types.Selection, 0, len(meta.Methods))
for _, method := range meta.Methods {
methods = append(methods, method)
}
c.ir.SortMethods(methods)
for _, method := range methods {
f := c.ir.GetFunction(program.MethodValue(method))
if f.llvmFn.IsNil() {
return errors.New("cannot find function: " + f.LinkName())
@ -440,6 +445,25 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
startIndex += len(meta.Methods)
}
interfaceTypes := c.ir.AllInterfaces()
interfaceLengths := make([]llvm.Value, len(interfaceTypes))
interfaceMethods := make([]llvm.Value, 0)
for i, itfType := range interfaceTypes {
if itfType.Type.NumMethods() > 0xff {
return errors.New("too many methods for interface " + itfType.Type.String())
}
interfaceLengths[i] = llvm.ConstInt(llvm.Int8Type(), uint64(itfType.Type.NumMethods()), false)
funcs := make([]*types.Func, itfType.Type.NumMethods())
for i := range funcs {
funcs[i] = itfType.Type.Method(i)
}
c.ir.SortFuncs(funcs)
for _, f := range funcs {
id := llvm.ConstInt(llvm.Int16Type(), uint64(c.ir.MethodNum(f)), false)
interfaceMethods = append(interfaceMethods, id)
}
}
if len(ranges) >= 1<<16 {
return errors.New("method call numbers do not fit in a 16-bit integer")
}
@ -469,8 +493,24 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
signatureArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(signatureArrayNewGlobal, signatureArrayOldGlobal.Type()))
signatureArrayOldGlobal.EraseFromParentAsGlobal()
signatureArrayNewGlobal.SetName("runtime.methodSetSignatures")
interfaceLengthsArray := llvm.ConstArray(llvm.Int8Type(), interfaceLengths)
interfaceLengthsArrayNewGlobal := llvm.AddGlobal(c.mod, interfaceLengthsArray.Type(), "runtime.interfaceLengths.tmp")
interfaceLengthsArrayNewGlobal.SetInitializer(interfaceLengthsArray)
interfaceLengthsArrayNewGlobal.SetLinkage(llvm.InternalLinkage)
interfaceLengthsArrayOldGlobal := c.mod.NamedGlobal("runtime.interfaceLengths")
interfaceLengthsArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(interfaceLengthsArrayNewGlobal, interfaceLengthsArrayOldGlobal.Type()))
interfaceLengthsArrayOldGlobal.EraseFromParentAsGlobal()
interfaceLengthsArrayNewGlobal.SetName("runtime.interfaceLengths")
interfaceMethodsArray := llvm.ConstArray(llvm.Int16Type(), interfaceMethods)
interfaceMethodsArrayNewGlobal := llvm.AddGlobal(c.mod, interfaceMethodsArray.Type(), "runtime.interfaceMethods.tmp")
interfaceMethodsArrayNewGlobal.SetInitializer(interfaceMethodsArray)
interfaceMethodsArrayNewGlobal.SetLinkage(llvm.InternalLinkage)
interfaceMethodsArrayOldGlobal := c.mod.NamedGlobal("runtime.interfaceMethods")
interfaceMethodsArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(interfaceMethodsArrayNewGlobal, interfaceMethodsArrayOldGlobal.Type()))
interfaceMethodsArrayOldGlobal.EraseFromParentAsGlobal()
interfaceMethodsArrayNewGlobal.SetName("runtime.interfaceMethods")
c.mod.NamedGlobal("runtime.firstInterfaceNum").SetInitializer(llvm.ConstInt(llvm.Int16Type(), uint64(c.ir.FirstDynamicType()), false))
c.mod.NamedGlobal("runtime.firstTypeWithMethods").SetInitializer(llvm.ConstInt(llvm.Int16Type(), uint64(c.ir.FirstDynamicType()), false))
// see: https://reviews.llvm.org/D18355
c.mod.AddNamedMetadataOperand("llvm.module.flags",
@ -2238,25 +2278,43 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
if err != nil {
return llvm.Value{}, err
}
if _, ok := expr.AssertedType.Underlying().(*types.Interface); ok {
// TODO: check whether the type implements the interface.
return llvm.Value{}, errors.New("todo: assert on interface")
}
assertedType, err := c.getLLVMType(expr.AssertedType)
if err != nil {
return llvm.Value{}, err
}
valueNil, err := getZeroValue(assertedType)
if err != nil {
return llvm.Value{}, err
}
actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type")
commaOk := llvm.Value{}
if itf, ok := expr.AssertedType.Underlying().(*types.Interface); ok {
// Type assert on interface type.
// This is slightly non-trivial: at runtime the list of methods
// needs to be checked to see whether it implements the interface.
// At the same time, the interface value itself is unchanged.
itfTypeNum := c.ir.InterfaceNum(itf)
itfTypeNumValue := llvm.ConstInt(llvm.Int16Type(), uint64(itfTypeNum), false)
fn := c.mod.NamedFunction("runtime.interfaceImplements")
commaOk = c.builder.CreateCall(fn, []llvm.Value{actualTypeNum, itfTypeNumValue}, "")
} else {
// Type assert on concrete type.
// This is easy: just compare the type number.
assertedTypeNum, typeExists := c.ir.TypeNum(expr.AssertedType)
if !typeExists {
// Static analysis has determined this type assert will never apply.
return llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.ConstInt(llvm.Int1Type(), 0, false)}, false), nil
return llvm.ConstStruct([]llvm.Value{valueNil, llvm.ConstInt(llvm.Int1Type(), 0, false)}, false), nil
}
if assertedTypeNum >= 1<<16 {
return llvm.Value{}, errors.New("interface typecodes do not fit in a 16-bit integer")
}
actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type")
commaOk := c.builder.CreateICmp(llvm.IntEQ, llvm.ConstInt(llvm.Int16Type(), uint64(assertedTypeNum), false), actualTypeNum, "")
assertedTypeNumValue := llvm.ConstInt(llvm.Int16Type(), uint64(assertedTypeNum), false)
commaOk = c.builder.CreateICmp(llvm.IntEQ, assertedTypeNumValue, actualTypeNum, "")
}
// Add 2 new basic blocks (that should get optimized away): one for the
// 'ok' case and one for all instructions following this type assert.
@ -2269,11 +2327,6 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
// typeassert should return a zero value, not an incorrectly casted
// value.
valueNil, err := getZeroValue(assertedType)
if err != nil {
return llvm.Value{}, err
}
prevBlock := c.builder.GetInsertBlock()
okBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "typeassert.ok")
nextBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "typeassert.next")
@ -2282,8 +2335,15 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
// Retrieve the value from the interface if the type assert was
// successful.
c.builder.SetInsertPointAtEnd(okBlock)
valuePtr := c.builder.CreateExtractValue(itf, 1, "typeassert.value.ptr")
var valueOk llvm.Value
if _, ok := expr.AssertedType.Underlying().(*types.Interface); ok {
// Type assert on interface type. Easy: just return the same
// interface value.
valueOk = itf
} else {
// Type assert on concrete type. Extract the underlying type from
// the interface (but only after checking it matches).
valuePtr := c.builder.CreateExtractValue(itf, 1, "typeassert.value.ptr")
if c.targetData.TypeAllocSize(assertedType) > c.targetData.TypeAllocSize(c.i8ptrType) {
// Value was stored in an allocated buffer, load it from there.
valuePtrCast := c.builder.CreateBitCast(valuePtr, llvm.PointerType(assertedType, 0), "")
@ -2307,6 +2367,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
return llvm.Value{}, errors.New("todo: typeassert: bitcast small types")
}
}
}
c.builder.CreateBr(nextBlock)
// Continue after the if statement.

67
ir.go
Просмотреть файл

@ -25,9 +25,10 @@ type Program struct {
NamedTypes []*NamedType
needsScheduler bool
goCalls []*ssa.Go
typesWithMethods map[string]*InterfaceType // see AnalyseInterfaceConversions
typesWithMethods map[string]*TypeWithMethods // see AnalyseInterfaceConversions
typesWithoutMethods map[string]int // see AnalyseInterfaceConversions
methodSignatureNames map[string]int
methodSignatureNames map[string]int // see MethodNum
interfaces map[string]*Interface // see AnalyseInterfaceConversions
fpWithContext map[string]struct{} // see AnalyseFunctionPointers
}
@ -60,12 +61,19 @@ type NamedType struct {
}
// Type that is at some point put in an interface.
type InterfaceType struct {
type TypeWithMethods struct {
t types.Type
Num int
Methods map[string]*types.Selection
}
// Interface type that is at some point used in a type assert (to check whether
// it implements another interface).
type Interface struct {
Num int
Type *types.Interface
}
// Create and intialize a new *Program from a *ssa.Program.
func NewProgram(program *ssa.Program, mainPath string) *Program {
return &Program{
@ -74,6 +82,7 @@ func NewProgram(program *ssa.Program, mainPath string) *Program {
functionMap: make(map[*ssa.Function]*Function),
globalMap: make(map[*ssa.Global]*Global),
methodSignatureNames: make(map[string]int),
interfaces: make(map[string]*Interface),
}
}
@ -148,6 +157,18 @@ func (p *Program) GetGlobal(ssaGlobal *ssa.Global) *Global {
return p.globalMap[ssaGlobal]
}
// SortMethods sorts the list of methods by method ID.
func (p *Program) SortMethods(methods []*types.Selection) {
m := &methodList{methods: methods, program: p}
sort.Sort(m)
}
// SortFuncs sorts the list of functions by method ID.
func (p *Program) SortFuncs(funcs []*types.Func) {
m := &funcList{funcs: funcs, program: p}
sort.Sort(m)
}
// Parse compiler directives in the preceding comments.
func (f *Function) parsePragmas() {
if f.fn.Syntax() == nil {
@ -236,3 +257,43 @@ func (g *Global) LinkName() string {
func (g *Global) IsExtern() bool {
return strings.HasPrefix(g.g.Name(), "_extern_")
}
// Wrapper type to implement sort.Interface for []*types.Selection.
type methodList struct {
methods []*types.Selection
program *Program
}
func (m *methodList) Len() int {
return len(m.methods)
}
func (m *methodList) Less(i, j int) bool {
iid := m.program.MethodNum(m.methods[i].Obj().(*types.Func))
jid := m.program.MethodNum(m.methods[j].Obj().(*types.Func))
return iid < jid
}
func (m *methodList) Swap(i, j int) {
m.methods[i], m.methods[j] = m.methods[j], m.methods[i]
}
// Wrapper type to implement sort.Interface for []*types.Func.
type funcList struct {
funcs []*types.Func
program *Program
}
func (fl *funcList) Len() int {
return len(fl.funcs)
}
func (fl *funcList) Less(i, j int) bool {
iid := fl.program.MethodNum(fl.funcs[i])
jid := fl.program.MethodNum(fl.funcs[j])
return iid < jid
}
func (fl *funcList) Swap(i, j int) {
fl.funcs[i], fl.funcs[j] = fl.funcs[j], fl.funcs[i]
}

Просмотреть файл

@ -2,6 +2,9 @@ package main
import (
"go/types"
"sort"
"strings"
"golang.org/x/tools/go/ssa"
)
@ -56,6 +59,18 @@ func Signature(sig *types.Signature) string {
return s
}
// Convert an interface type to a string of all method strings, separated by
// "; ". For example: "Read([]byte) (int, error); Close() error"
func InterfaceKey(itf *types.Interface) string {
methodStrings := []string{}
for i := 0; i < itf.NumMethods(); i++ {
method := itf.Method(i)
methodStrings = append(methodStrings, MethodSignature(method))
}
sort.Strings(methodStrings)
return strings.Join(methodStrings, ";")
}
// Fill in parents of all functions.
//
// All packages need to be added before this pass can run, or it will produce
@ -104,7 +119,7 @@ func (p *Program) AnalyseCallgraph() {
func (p *Program) AnalyseInterfaceConversions() {
// Clear, if AnalyseTypes has been called before.
p.typesWithoutMethods = map[string]int{"nil": 0}
p.typesWithMethods = map[string]*InterfaceType{}
p.typesWithMethods = map[string]*TypeWithMethods{}
for _, f := range p.Functions {
for _, block := range f.fn.Blocks {
@ -114,7 +129,7 @@ func (p *Program) AnalyseInterfaceConversions() {
methods := getAllMethods(f.fn.Prog, instr.X.Type())
name := instr.X.Type().String()
if _, ok := p.typesWithMethods[name]; !ok && len(methods) > 0 {
t := &InterfaceType{
t := &TypeWithMethods{
t: instr.X.Type(),
Num: len(p.typesWithMethods),
Methods: make(map[string]*types.Selection),
@ -271,7 +286,13 @@ func (p *Program) SimpleDCE() {
for _, instr := range block.Instrs {
if instr, ok := instr.(*ssa.MakeInterface); ok {
for _, sel := range getAllMethods(p.program, instr.X.Type()) {
callee := p.GetFunction(p.program.MethodValue(sel))
fn := p.program.MethodValue(sel)
callee := p.GetFunction(fn)
if callee == nil {
// TODO: why is this necessary?
p.addFunction(fn)
callee = p.GetFunction(fn)
}
if !callee.flag {
callee.flag = true
worklist = append(worklist, callee.fn)
@ -361,6 +382,19 @@ func (p *Program) TypeNum(typ types.Type) (int, bool) {
}
}
// InterfaceNum returns the numeric interface ID of this type, for use in type
// asserts.
func (p *Program) InterfaceNum(itfType *types.Interface) int {
key := InterfaceKey(itfType)
if itf, ok := p.interfaces[key]; !ok {
num := len(p.interfaces)
p.interfaces[key] = &Interface{Num: num, Type: itfType}
return num
} else {
return itf.Num
}
}
// MethodNum returns the numeric ID of this method, to be used in method lookups
// on interfaces for example.
func (p *Program) MethodNum(method *types.Func) int {
@ -381,14 +415,23 @@ func (p *Program) FirstDynamicType() int {
}
// Return all types with methods, sorted by type ID.
func (p *Program) AllDynamicTypes() []*InterfaceType {
l := make([]*InterfaceType, len(p.typesWithMethods))
func (p *Program) AllDynamicTypes() []*TypeWithMethods {
l := make([]*TypeWithMethods, len(p.typesWithMethods))
for _, m := range p.typesWithMethods {
l[m.Num] = m
}
return l
}
// Return all interface types, sorted by interface ID.
func (p *Program) AllInterfaces() []*Interface {
l := make([]*Interface, len(p.interfaces))
for _, itf := range p.interfaces {
l[itf.Num] = itf
}
return l
}
func (p *Program) FunctionNeedsContext(f *Function) bool {
if !f.addressTaken {
return false

Просмотреть файл

@ -16,6 +16,18 @@ type Stringer interface {
String() string
}
type Foo int
type Number int
func (n Number) Double() int {
return int(n) * 2
}
type Doubler interface {
Double() int
}
const SIX = 6
var testmap = map[string]int{"data": 3}
@ -54,6 +66,7 @@ func main() {
printItf(*thing)
printItf(thing)
printItf(Stringer(thing))
printItf(Number(3))
s := Stringer(thing)
println("Stringer.String():", s.String())
@ -107,6 +120,8 @@ func strlen(s string) int {
func printItf(val interface{}) {
switch val := val.(type) {
case Doubler:
println("is Doubler:", val.Double())
case int:
println("is int:", val)
case byte:
@ -117,6 +132,8 @@ func printItf(val interface{}) {
println("is Thing:", val.String())
case *Thing:
println("is *Thing:", val.String())
case Foo:
println("is Foo:", val)
default:
println("is ?")
}

Просмотреть файл

@ -10,10 +10,10 @@ package runtime
// signatures as interned strings at compile time.
//
// The typecode is a small number unique for the Go type. All typecodes <
// firstInterfaceNum do not have any methods and typecodes >= firstInterfaceNum
// all have at least one method. This means that methodSetRanges does not need
// to contain types without methods and is thus indexed starting at a typecode
// with number firstInterfaceNum.
// firstTypeWithMethods do not have any methods and typecodes >=
// firstTypeWithMethods all have at least one method. This means that
// methodSetRanges does not need to contain types without methods and is thus
// indexed starting at a typecode with number firstTypeWithMethods.
//
// To further conserve some space, the methodSetRange (as the name indicates)
// doesn't contain a list of methods and function pointers directly, but instead
@ -36,10 +36,13 @@ type methodSetRange struct {
// which is a dummy value, but will be bigger after the compiler has filled them
// in.
var (
firstInterfaceNum uint16 // the lowest typecode that has at least one method
firstTypeWithMethods uint16 // the lowest typecode that has at least one method
methodSetRanges [0]methodSetRange // indexes into methodSetSignatures and methodSetFunctions
methodSetSignatures [0]uint16 // uniqued method ID
methodSetFunctions [0]*uint8 // function pointer of method
interfaceIndex [0]uint16 // mapping from interface ID to an index in interfaceMethods
interfaceLengths [0]uint8 // mapping from interface ID to the number of methods it has
interfaceMethods [0]uint16 // the method an interface implements (list of method IDs)
)
// Get the function pointer for the method on the interface.
@ -50,7 +53,7 @@ func interfaceMethod(itf _interface, method uint16) *uint8 {
// in the list of signatures. The compiler will only emit
// runtime.interfaceMethod calls when the method actually exists on this
// interface (proven by the typechecker).
i := methodSetRanges[itf.typecode-firstInterfaceNum].index
i := methodSetRanges[itf.typecode-firstTypeWithMethods].index
for {
if methodSetSignatures[i] == method {
return methodSetFunctions[i]
@ -72,3 +75,49 @@ func interfaceEqual(x, y _interface) bool {
// TODO: depends on reflection.
panic("unimplemented: interface equality")
}
// Return true iff the type implements all methods needed by the interface. This
// means the type satisfies the interface.
// This is a compiler intrinsic.
//go:nobounds
func interfaceImplements(typecode, interfaceNum uint16) bool {
// method set indexes of the concrete type
methodSet := methodSetRanges[typecode-firstTypeWithMethods]
methodIndex := methodSet.index
methodIndexEnd := methodSet.index + methodSet.length
// method set indexes of the interface
itfIndex := interfaceIndex[interfaceNum]
itfIndexEnd := itfIndex + uint16(interfaceLengths[interfaceNum])
// Iterate over all methods of the interface:
for itfIndex < itfIndexEnd {
methodId := interfaceMethods[itfIndex]
if methodIndex >= methodIndexEnd {
// Reached the end of the list of methods, so interface doesn't
// implement this type.
return false
}
if methodId == methodSetSignatures[methodIndex] {
// Found a matching method, continue to the next method.
itfIndex++
methodIndex++
continue
} else if methodId > methodSetSignatures[methodIndex] {
// The method didn't match, but method ID of the concrete type was
// lower than that of the interface, so probably it has a method the
// interface doesn't implement.
// Move on to the next method of the concrete type.
methodIndex++
continue
} else {
// The concrete type is missing a method. This means the type assert
// fails.
return false
}
}
// Found a method for each expected method in the interface. This type
// assert is successful.
return true
}