From a97ca91c1fa5fda6d70376538ca3d6a4de9cf07b Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Sun, 10 Jun 2018 00:36:39 +0200 Subject: [PATCH] compiler: Implement interface calls This is a big combined change. Other changes in this commit: * Analyze makeinterface and make sure type switches don't include unnecessary cases. * Do not include CGo wrapper functions in the analyzer callgraph. This also avoids some unnecessary type IDs. * Give all Go named structs a name in LLVM. * Use such a named struct for compiler-generated task data. * Use the type and function names defined by the ssa and types package instead of generating our own. * Some improvements to function pointers. * A few other minor improvements. The one thing lacking here is interface-to-interface assertions. --- README.markdown | 3 + analysis.go | 155 ++++++++++++++-- src/examples/hello/hello.go | 25 ++- src/runtime/runtime.ll | 66 +++++-- tgo.go | 348 +++++++++++++++++++++++++++--------- 5 files changed, 478 insertions(+), 119 deletions(-) diff --git a/README.markdown b/README.markdown index e67195b0..22a4a069 100644 --- a/README.markdown +++ b/README.markdown @@ -72,3 +72,6 @@ Implemented analysis passes: sleep, chan send, etc. It's parents are also blocking. * Check whether the scheduler is needed. It is only needed when there are `go` statements for blocking functions. + * Check whether a given type switch or type assert is possible with + [type-based alias analysis](https://en.wikipedia.org/wiki/Alias_analysis#Type-based_alias_analysis). + I would like to use flow-based alias analysis in the future. diff --git a/analysis.go b/analysis.go index 6a5dd8a8..6ad79dbb 100644 --- a/analysis.go +++ b/analysis.go @@ -9,9 +9,12 @@ import ( // Analysis results over a whole program. type Analysis struct { - functions map[*ssa.Function]*FuncMeta - needsScheduler bool - goCalls []*ssa.Go + functions map[*ssa.Function]*FuncMeta + needsScheduler bool + goCalls []*ssa.Go + typesWithMethods map[string]*TypeMeta + typesWithoutMethods map[string]int + methodSignatureNames map[string]int } // Some analysis results of a single function. @@ -22,10 +25,19 @@ type FuncMeta struct { children []*ssa.Function } +type TypeMeta struct { + t types.Type + Num int + Methods map[string]*types.Selection +} + // Return a new Analysis object. func NewAnalysis() *Analysis { return &Analysis{ - functions: make(map[*ssa.Function]*FuncMeta), + functions: make(map[*ssa.Function]*FuncMeta), + typesWithMethods: make(map[string]*TypeMeta), + typesWithoutMethods: make(map[string]int), + methodSignatureNames: make(map[string]int), } } @@ -34,12 +46,19 @@ func (a *Analysis) AddPackage(pkg *ssa.Package) { for _, member := range pkg.Members { switch member := member.(type) { case *ssa.Function: + if isCGoInternal(member.Name()) || getCName(member.Name()) != "" { + continue + } a.addFunction(member) case *ssa.Type: - ms := pkg.Prog.MethodSets.MethodSet(member.Type()) - if !types.IsInterface(member.Type()) { - for i := 0; i < ms.Len(); i++ { - a.addFunction(pkg.Prog.MethodValue(ms.At(i))) + methods := getAllMethods(pkg.Prog, member.Type()) + if types.IsInterface(member.Type()) { + for _, method := range methods { + a.MethodName(method.Obj().(*types.Func)) + } + } else { // named type + for _, method := range methods { + a.addFunction(pkg.Prog.MethodValue(method)) } } } @@ -54,13 +73,39 @@ func (a *Analysis) addFunction(f *ssa.Function) { for _, instr := range block.Instrs { switch instr := instr.(type) { case *ssa.Call: - switch call := instr.Call.Value.(type) { - case *ssa.Function: - name := getFunctionName(call, false) - if name == "runtime.Sleep" { - fm.blocking = true + if instr.Common().IsInvoke() { + name := a.MethodName(instr.Common().Method) + a.methodSignatureNames[name] = len(a.methodSignatureNames) + } else { + switch call := instr.Call.Value.(type) { + case *ssa.Builtin: + // ignore + case *ssa.Function: + if isCGoInternal(call.Name()) || getCName(call.Name()) != "" { + continue + } + name := getFunctionName(call, false) + if name == "runtime.Sleep" { + fm.blocking = true + } + fm.children = append(fm.children, call) } - fm.children = append(fm.children, call) + } + case *ssa.MakeInterface: + methods := getAllMethods(f.Prog, instr.X.Type()) + if _, ok := a.typesWithMethods[instr.X.Type().String()]; !ok && len(methods) > 0 { + meta := &TypeMeta{ + t: instr.X.Type(), + Num: len(a.typesWithMethods), + Methods: make(map[string]*types.Selection), + } + for _, sel := range methods { + name := a.MethodName(sel.Obj().(*types.Func)) + meta.Methods[name] = sel + } + a.typesWithMethods[instr.X.Type().String()] = meta + } else if _, ok := a.typesWithoutMethods[instr.X.Type().String()]; !ok && len(methods) == 0 { + a.typesWithoutMethods[instr.X.Type().String()] = len(a.typesWithoutMethods) } case *ssa.Go: a.goCalls = append(a.goCalls, instr) @@ -74,6 +119,44 @@ func (a *Analysis) addFunction(f *ssa.Function) { } } +// Make a readable version of the method signature (including the function name, +// excluding the receiver name). This string is used internally to match +// interfaces and to call the correct method on an interface. Examples: +// +// String() string +// Read([]byte) (int, error) +func (a *Analysis) MethodName(method *types.Func) string { + sig := method.Type().(*types.Signature) + name := method.Name() + if sig.Params().Len() == 0 { + name += "()" + } else { + name += "(" + for i := 0; i < sig.Params().Len(); i++ { + if i > 0 { + name += ", " + } + name += sig.Params().At(i).Type().String() + } + name += ")" + } + if sig.Results().Len() == 0 { + // keep as-is + } else if sig.Results().Len() == 1 { + name += " " + sig.Results().At(0).Type().String() + } else { + name += " (" + for i := 0; i < sig.Results().Len(); i++ { + if i > 0 { + name += ", " + } + name += sig.Results().At(i).Type().String() + } + name += ")" + } + return name +} + // Fill in parents of all functions. // // All packages need to be added before this pass can run, or it will produce @@ -83,7 +166,7 @@ func (a *Analysis) AnalyseCallgraph() { for _, child := range fm.children { childRes, ok := a.functions[child] if !ok { - print("child not found: " + child.Pkg.Pkg.Path() + "." + child.Name() + ", function: " + f.Name()) + println("child not found: " + child.Pkg.Pkg.Path() + "." + child.Name() + ", function: " + f.Name()) continue } childRes.parents = append(childRes.parents, f) @@ -163,3 +246,45 @@ func (a *Analysis) isBlocking(f ssa.Value) bool { panic("Analysis.IsBlocking on unknown type") } } + +// Return the type number and whether this type is actually used. Used in +// interface conversions (type is always used) and type asserts (type may not be +// used, meaning assert is always false in this program). +// +// May only be used after all packages have been added to the analyser. +func (a *Analysis) TypeNum(typ types.Type) (int, bool) { + if n, ok := a.typesWithoutMethods[typ.String()]; ok { + return n, true + } else if meta, ok := a.typesWithMethods[typ.String()]; ok { + return len(a.typesWithoutMethods) + meta.Num, true + } else { + return -1, false // type is never put in an interface + } +} + +// MethodNum returns the numeric ID of this method, to be used in method lookups +// on interfaces for example. +func (a *Analysis) MethodNum(method *types.Func) int { + if n, ok := a.methodSignatureNames[a.MethodName(method)]; ok { + return n + } + return -1 // signal error +} + +// The start index of the first dynamic type that has methods. +// Types without methods always have a lower ID and types with methods have this +// or a higher ID. +// +// May only be used after all packages have been added to the analyser. +func (a *Analysis) FirstDynamicType() int { + return len(a.typesWithoutMethods) +} + +// Return all types with methods, sorted by type ID. +func (a *Analysis) AllDynamicTypes() []*TypeMeta { + l := make([]*TypeMeta, len(a.typesWithMethods)) + for _, m := range a.typesWithMethods { + l[m.Num] = m + } + return l +} diff --git a/src/examples/hello/hello.go b/src/examples/hello/hello.go index a9762d65..a860b4e6 100644 --- a/src/examples/hello/hello.go +++ b/src/examples/hello/hello.go @@ -9,6 +9,10 @@ func (t Thing) String() string { return t.name } +type Stringer interface { + String() string +} + const SIX = 6 func main() { @@ -20,21 +24,26 @@ func main() { println("sumrange(100) =", sumrange(100)) println("strlen foo:", strlen("foo")) - thing := Thing{"foo"} + thing := &Thing{"foo"} println("thing:", thing.String()) printItf(5) printItf(byte('x')) printItf("foo") + printItf(*thing) + printItf(thing) + printItf(Stringer(thing)) + s := Stringer(thing) + println("Stringer.String():", s.String()) - runFunc(hello) // must be indirect to avoid obvious inlining + runFunc(hello, 5) // must be indirect to avoid obvious inlining } -func runFunc(f func()) { - f() +func runFunc(f func(int), arg int) { + f(arg) } -func hello() { - println("hello from function pointer!") +func hello(n int) { + println("hello from function pointer:", n) } func strlen(s string) int { @@ -49,6 +58,10 @@ func printItf(val interface{}) { println("is byte:", val) case string: println("is string:", val) + case Thing: + println("is Thing:", val.String()) + case *Thing: + println("is *Thing:", val.String()) default: println("is ?") } diff --git a/src/runtime/runtime.ll b/src/runtime/runtime.ll index 677e5152..5b797aea 100644 --- a/src/runtime/runtime.ll +++ b/src/runtime/runtime.ll @@ -1,27 +1,69 @@ source_filename = "runtime/runtime.ll" +%interface = type { i32, i8* } + declare void @runtime.initAll() declare void @main.main() declare i8* @main.main$async(i8*) declare void @runtime.scheduler(i8*) ; Will be changed to true if there are 'go' statements in the compiled program. -@.has_scheduler = private unnamed_addr constant i1 false +@has_scheduler = private unnamed_addr constant i1 false + +; Will be changed by the compiler to the first type number with methods. +@first_interface_num = private unnamed_addr constant i32 0 + +; Will be filled by the compiler with runtime type information. +%interface_tuple = type { i32, i32 } ; { index, len } +@interface_tuples = external global [0 x %interface_tuple] +@interface_signatures = external global [0 x i32] ; array of method IDs +@interface_functions = external global [0 x i8*] ; array of function pointers define i32 @main() { - call void @runtime.initAll() - %has_scheduler = load i1, i1* @.has_scheduler - ; This branch will be optimized away. Only one of the targets will remain. - br i1 %has_scheduler, label %with_scheduler, label %without_scheduler + call void @runtime.initAll() + %has_scheduler = load i1, i1* @has_scheduler + ; This branch will be optimized away. Only one of the targets will remain. + br i1 %has_scheduler, label %with_scheduler, label %without_scheduler with_scheduler: - ; Initialize main and run the scheduler. - %main = call i8* @main.main$async(i8* null) - call void @runtime.scheduler(i8* %main) - ret i32 0 + ; Initialize main and run the scheduler. + %main = call i8* @main.main$async(i8* null) + call void @runtime.scheduler(i8* %main) + ret i32 0 without_scheduler: - ; No scheduler is necessary. Call main directly. - call void @main.main() - ret i32 0 + ; No scheduler is necessary. Call main directly. + call void @main.main() + ret i32 0 +} + +; Get the function pointer for the method on the interface. +; This function only reads constant global data and it's own arguments so it can +; be 'readnone' (a pure function). +define i8* @itfmethod(%interface %itf, i32 %method) noinline readnone { +entry: + ; Calculate the index in @interface_tuples + %concrete_type_num = extractvalue %interface %itf, 0 + %first_interface_num = load i32, i32* @first_interface_num + %index = sub i32 %concrete_type_num, %first_interface_num + + ; Calculate the index for @interface_signatures and @interface_functions + %itf_index_ptr = getelementptr inbounds [0 x %interface_tuple], [0 x %interface_tuple]* @interface_tuples, i32 0, i32 %index, i32 0 + %itf_index = load i32, i32* %itf_index_ptr + br label %find_method + + ; This is a while loop until the method has been found. + ; It must be in here, so avoid checking the length. +find_method: + %itf_index.phi = phi i32 [ %itf_index, %entry], [ %itf_index.phi.next, %find_method] + %m_ptr = getelementptr inbounds [0 x i32], [0 x i32]* @interface_signatures, i32 0, i32 %itf_index.phi + %m = load i32, i32* %m_ptr + %found = icmp eq i32 %m, %method + %itf_index.phi.next = add i32 %itf_index.phi, 1 + br i1 %found, label %found_method, label %find_method + +found_method: + %fp_ptr = getelementptr inbounds [0 x i8*], [0 x i8*]* @interface_functions, i32 0, i32 %itf_index.phi + %fp = load i8*, i8** %fp_ptr + ret i8* %fp } diff --git a/tgo.go b/tgo.go index 2c9a1f82..1eb50836 100644 --- a/tgo.go +++ b/tgo.go @@ -39,7 +39,6 @@ type Compiler struct { i8ptrType llvm.Type // for convenience uintptrType llvm.Type stringLenType llvm.Type - taskDataType llvm.Type allocFunc llvm.Value freeFunc llvm.Value coroIdFunc llvm.Value @@ -48,8 +47,8 @@ type Compiler struct { coroSuspendFunc llvm.Value coroEndFunc llvm.Value coroFreeFunc llvm.Value - itfTypeNumbers map[types.Type]uint64 - itfTypes []types.Type + program *ssa.Program + mainPkg *ssa.Package initFuncs []llvm.Value analysis *Analysis } @@ -62,19 +61,11 @@ type Frame struct { blocks map[*ssa.BasicBlock]llvm.BasicBlock phis []Phi blocking bool - taskState llvm.Value taskHandle llvm.Value cleanupBlock llvm.BasicBlock suspendBlock llvm.BasicBlock } -func pkgPrefix(pkg *ssa.Package) string { - if pkg.Pkg.Name() == "main" { - return "main" - } - return pkg.Pkg.Path() -} - type Phi struct { ssa *ssa.Phi llvm llvm.Value @@ -82,10 +73,9 @@ type Phi struct { func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) { c := &Compiler{ - dumpSSA: dumpSSA, - triple: triple, - itfTypeNumbers: make(map[types.Type]uint64), - analysis: NewAnalysis(), + dumpSSA: dumpSSA, + triple: triple, + analysis: NewAnalysis(), } target, err := llvm.GetTargetFromTriple(triple) @@ -109,13 +99,6 @@ func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) { t := c.ctx.StructCreateNamed("string") t.StructSetBody([]llvm.Type{c.stringLenType, c.i8ptrType}, false) - // Go interface: tuple of (type, ptr) - t = c.ctx.StructCreateNamed("interface") - t.StructSetBody([]llvm.Type{llvm.Int32Type(), c.i8ptrType}, false) - - // Goroutine / task data: {i8 state, i32 data, i8* next} - c.taskDataType = llvm.StructType([]llvm.Type{llvm.Int8Type(), llvm.Int32Type(), c.i8ptrType}, false) - allocType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.uintptrType}, false) c.allocFunc = llvm.AddFunction(c.mod, "runtime.alloc", allocType) @@ -178,8 +161,10 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { } } - program := ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions | ssa.BareInits) - program.Build() + c.program = ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions | ssa.BareInits) + c.program.Build() + + c.mainPkg = c.program.ImportedPackage(mainPath) // Make a list of packages in import order. packageList := []*ssa.Package{} @@ -187,7 +172,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { worklist := []string{"runtime", mainPath} for len(worklist) != 0 { pkgPath := worklist[0] - pkg := program.ImportedPackage(pkgPath) + pkg := c.program.ImportedPackage(pkgPath) if pkg == nil { // Non-SSA package (e.g. cgo). packageSet[pkgPath] = struct{}{} @@ -231,7 +216,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { // Transform each package into LLVM IR. for _, pkg := range packageList { - err := c.parsePackage(program, pkg) + err := c.parsePackage(pkg) if err != nil { return err } @@ -252,23 +237,85 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { } c.builder.CreateRetVoid() - // Set functions referenced in runtime.ll to internal linkage, to improve - // optimization (hopefully). + // Adjust main function. main := c.mod.NamedFunction("main.main") - if !main.IsDeclaration() { - main.SetLinkage(llvm.PrivateLinkage) + realMain := c.mod.NamedFunction(c.mainPkg.Pkg.Path() + ".main") + if !realMain.IsNil() { + main.ReplaceAllUsesWith(realMain) } mainAsync := c.mod.NamedFunction("main.main$async") - if !mainAsync.IsDeclaration() { - mainAsync.SetLinkage(llvm.PrivateLinkage) + realMainAsync := c.mod.NamedFunction(c.mainPkg.Pkg.Path() + ".main$async") + if !realMainAsync.IsNil() { + mainAsync.ReplaceAllUsesWith(realMainAsync) } + + // Set functions referenced in runtime.ll to internal linkage, to improve + // optimization (hopefully). c.mod.NamedFunction("runtime.scheduler").SetLinkage(llvm.PrivateLinkage) + // Only use a scheduler when necessary. if c.analysis.NeedsScheduler() { // Enable the scheduler. - c.mod.NamedGlobal(".has_scheduler").SetInitializer(llvm.ConstInt(llvm.Int1Type(), 1, false)) + c.mod.NamedGlobal("has_scheduler").SetInitializer(llvm.ConstInt(llvm.Int1Type(), 1, false)) } + // Initialize runtime type information, for interfaces. + dynamicTypes := c.analysis.AllDynamicTypes() + numDynamicTypes := 0 + for _, meta := range dynamicTypes { + numDynamicTypes += len(meta.Methods) + } + tuples := make([]llvm.Value, 0, len(dynamicTypes)) + funcPointers := make([]llvm.Value, 0, numDynamicTypes) + signatures := make([]llvm.Value, 0, numDynamicTypes) + startIndex := 0 + tupleType := c.mod.GetTypeByName("interface_tuple") + for _, meta := range dynamicTypes { + tupleValues := []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), uint64(startIndex), false), + llvm.ConstInt(llvm.Int32Type(), uint64(len(meta.Methods)), false), + } + tuple := llvm.ConstNamedStruct(tupleType, tupleValues) + tuples = append(tuples, tuple) + for _, method := range meta.Methods { + fnName := getFunctionName(c.program.MethodValue(method), false) + llvmFn := c.mod.NamedFunction(fnName) + if llvmFn.IsNil() { + return errors.New("cannot find function: " + fnName) + } + fn := llvm.ConstBitCast(llvmFn, c.i8ptrType) + funcPointers = append(funcPointers, fn) + signatureNum := c.analysis.MethodNum(method.Obj().(*types.Func)) + signature := llvm.ConstInt(llvm.Int32Type(), uint64(signatureNum), false) + signatures = append(signatures, signature) + } + startIndex += len(meta.Methods) + } + // Replace the pre-created arrays with the generated arrays. + tupleArray := llvm.ConstArray(tupleType, tuples) + tupleArrayNewGlobal := llvm.AddGlobal(c.mod, tupleArray.Type(), "interface_tuples.tmp") + tupleArrayNewGlobal.SetInitializer(tupleArray) + tupleArrayOldGlobal := c.mod.NamedGlobal("interface_tuples") + tupleArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(tupleArrayNewGlobal, tupleArrayOldGlobal.Type())) + tupleArrayOldGlobal.EraseFromParentAsGlobal() + tupleArrayNewGlobal.SetName("interface_tuples") + funcArray := llvm.ConstArray(c.i8ptrType, funcPointers) + funcArrayNewGlobal := llvm.AddGlobal(c.mod, funcArray.Type(), "interface_functions.tmp") + funcArrayNewGlobal.SetInitializer(funcArray) + funcArrayOldGlobal := c.mod.NamedGlobal("interface_functions") + funcArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(funcArrayNewGlobal, funcArrayOldGlobal.Type())) + funcArrayOldGlobal.EraseFromParentAsGlobal() + funcArrayNewGlobal.SetName("interface_functions") + signatureArray := llvm.ConstArray(llvm.Int32Type(), signatures) + signatureArrayNewGlobal := llvm.AddGlobal(c.mod, signatureArray.Type(), "interface_signatures.tmp") + signatureArrayNewGlobal.SetInitializer(signatureArray) + signatureArrayOldGlobal := c.mod.NamedGlobal("interface_signatures") + signatureArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(signatureArrayNewGlobal, signatureArrayOldGlobal.Type())) + signatureArrayOldGlobal.EraseFromParentAsGlobal() + signatureArrayNewGlobal.SetName("interface_signatures") + + c.mod.NamedGlobal("first_interface_num").SetInitializer(llvm.ConstInt(llvm.Int32Type(), uint64(c.analysis.FirstDynamicType()), false)) + return nil } @@ -306,6 +353,13 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { case *types.Interface: return c.mod.GetTypeByName("interface"), nil case *types.Named: + if _, ok := typ.Underlying().(*types.Struct); ok { + llvmType := c.mod.GetTypeByName(typ.Obj().Pkg().Path() + "." + typ.Obj().Name()) + if llvmType.IsNil() { + return llvm.Type{}, errors.New("type not found: " + typ.Obj().Pkg().Path() + "." + typ.Obj().Name()) + } + return llvmType, nil + } return c.getLLVMType(typ.Underlying()) case *types.Pointer: ptrTo, err := c.getLLVMType(typ.Elem()) @@ -329,6 +383,16 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { } // param values var paramTypes []llvm.Type + if typ.Recv() != nil { + recv, err := c.getLLVMType(typ.Recv().Type()) + if err != nil { + return llvm.Type{}, err + } + if recv.StructName() == "interface" { + recv = c.i8ptrType + } + paramTypes = append(paramTypes, recv) + } params := typ.Params() for i := 0; i < params.Len(); i++ { subType, err := c.getLLVMType(params.At(i).Type()) @@ -354,13 +418,17 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { } } -func (c *Compiler) getZeroValue(typ llvm.Type) (llvm.Value, error) { +// Return a zero LLVM value for any LLVM type. Setting this value as an +// initializer has the same effect as setting 'zeroinitializer' on a value. +// Sadly, I haven't found a way to do it directly with the Go API but this works +// just fine. +func getZeroValue(typ llvm.Type) (llvm.Value, error) { switch typ.TypeKind() { case llvm.ArrayTypeKind: subTyp := typ.ElementType() vals := make([]llvm.Value, typ.ArrayLength()) for i := range vals { - val, err := c.getZeroValue(subTyp) + val, err := getZeroValue(subTyp) if err != nil { return llvm.Value{}, err } @@ -375,7 +443,7 @@ func (c *Compiler) getZeroValue(typ llvm.Type) (llvm.Value, error) { types := typ.StructElementTypes() vals := make([]llvm.Value, len(types)) for i, subTyp := range types { - val, err := c.getZeroValue(subTyp) + val, err := getZeroValue(subTyp) if err != nil { return llvm.Value{}, err } @@ -391,15 +459,6 @@ func (c *Compiler) getZeroValue(typ llvm.Type) (llvm.Value, error) { } } -func (c *Compiler) getInterfaceType(typ types.Type) llvm.Value { - if _, ok := c.itfTypeNumbers[typ]; !ok { - num := uint64(len(c.itfTypes)) - c.itfTypes = append(c.itfTypes, typ) - c.itfTypeNumbers[typ] = num - } - return llvm.ConstInt(llvm.Int32Type(), c.itfTypeNumbers[typ], false) -} - // Is this a pointer type of some sort? Can be unsafe.Pointer or any *T pointer. func isPointer(typ types.Type) bool { if _, ok := typ.(*types.Pointer); ok { @@ -411,22 +470,40 @@ func isPointer(typ types.Type) bool { } } +// Get all methods of a type: both value receivers and pointer receivers. +func getAllMethods(prog *ssa.Program, typ types.Type) []*types.Selection { + var methods []*types.Selection + + // value receivers + ms := prog.MethodSets.MethodSet(typ) + for i := 0; i < ms.Len(); i++ { + methods = append(methods, ms.At(i)) + } + + // pointer receivers + ms = prog.MethodSets.MethodSet(types.NewPointer(typ)) + for i := 0; i < ms.Len(); i++ { + methods = append(methods, ms.At(i)) + } + + return methods +} + func getFunctionName(fn *ssa.Function, blocking bool) string { suffix := "" if blocking { suffix = "$async" } if fn.Signature.Recv() != nil { - // Method on a defined type. - typeName := fn.Params[0].Type().(*types.Named).Obj().Name() - return pkgPrefix(fn.Pkg) + "." + typeName + "." + fn.Name() + suffix + // Method on a defined type (which may be a pointer). + return fn.RelString(nil) + suffix } else { // Bare function. - if strings.HasPrefix(fn.Name(), "_Cfunc_") { + if name := getCName(fn.Name()); name != "" { // Name CGo functions directly. - return fn.Name()[len("_Cfunc_"):] + return name } else { - name := pkgPrefix(fn.Pkg) + "." + fn.Name() + suffix + name := fn.RelString(nil) + suffix if fn.Pkg.Pkg.Path() == "runtime" && strings.HasPrefix(fn.Name(), "_llvm_") { // Special case for LLVM intrinsics in the runtime. name = "llvm." + strings.Replace(fn.Name()[len("_llvm_"):], "_", ".", -1) @@ -440,21 +517,38 @@ func getGlobalName(global *ssa.Global) string { if strings.HasPrefix(global.Name(), "_extern_") { return global.Name()[len("_extern_"):] } else { - return pkgPrefix(global.Pkg) + "." + global.Name() + return global.RelString(nil) } } -func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error { +// Return true if this is a CGo-internal function that can be ignored. +func isCGoInternal(name string) bool { + if strings.HasPrefix(name, "_Cgo_") || strings.HasPrefix(name, "_cgo") { + // _Cgo_ptr, _Cgo_use, _cgoCheckResult, _cgo_runtime_cgocall + return true // CGo-internal functions + } + if strings.HasPrefix(name, "__cgofn__cgo_") { + return true // CGo function pointer in global scope + } + return false +} + +// Return the name of the C function if this is a CGo call. Otherwise, return a +// zero-length string. +func getCName(name string) string { + if strings.HasPrefix(name, "_Cfunc_") { + return name[len("_Cfunc_"):] + } + return "" +} + +func (c *Compiler) parsePackage(pkg *ssa.Package) error { // Make sure we're walking through all members in a constant order every - // run. + // run, and skip cgo wrapper functions/globals which we don't need. memberNames := make([]string, 0) for name := range pkg.Members { - if strings.HasPrefix(name, "_Cgo_") || strings.HasPrefix(name, "_cgo") { - // _Cgo_ptr, _Cgo_use, _cgoCheckResult, _cgo_runtime_cgocall - continue // CGo-internal functions - } - if strings.HasPrefix(name, "__cgofn__cgo_") { - continue // CGo function pointer in global scope + if isCGoInternal(name) { + continue } memberNames = append(memberNames, name) } @@ -462,7 +556,26 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error { frames := make(map[*ssa.Function]*Frame) - // First, build all function declarations. + // First, declare all named (struct) types. + for _, name := range memberNames { + member := pkg.Members[name] + + switch member := member.(type) { + case *ssa.Type: + if named, ok := member.Type().(*types.Named); ok { + if st, ok := named.Underlying().(*types.Struct); ok { + llvmType, err := c.getLLVMType(st) + if err != nil { + return err + } + llvmNamedType := c.ctx.StructCreateNamed(named.Obj().Pkg().Path() + "." + named.Obj().Name()) + llvmNamedType.StructSetBody(llvmType.StructElementTypes(), false) + } + } + } + } + + // With the types defined, build all function declarations. for _, name := range memberNames { member := pkg.Members[name] @@ -514,7 +627,7 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error { global.SetInitializer(llvm.ConstInt(llvm.Int8Type(), uint64(bitness), false)) global.SetGlobalConstant(true) } else { - initializer, err := c.getZeroValue(llvmType) + initializer, err := getZeroValue(llvmType) if err != nil { return err } @@ -523,9 +636,8 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error { } case *ssa.Type: if !types.IsInterface(member.Type()) { - ms := program.MethodSets.MethodSet(member.Type()) - for i := 0; i < ms.Len(); i++ { - fn := program.MethodValue(ms.At(i)) + for _, sel := range getAllMethods(c.program, member.Type()) { + fn := c.program.MethodValue(sel) frame, err := c.parseFuncDecl(fn) if err != nil { return err @@ -543,7 +655,7 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error { member := pkg.Members[name] switch member := member.(type) { case *ssa.Function: - if strings.HasPrefix(name, "_Cfunc_") { + if getCName(name) != "" { // CGo function. Don't implement it's body. continue } @@ -561,9 +673,8 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error { } case *ssa.Type: if !types.IsInterface(member.Type()) { - ms := program.MethodSets.MethodSet(member.Type()) - for i := 0; i < ms.Len(); i++ { - fn := program.MethodValue(ms.At(i)) + for _, sel := range getAllMethods(c.program, member.Type()) { + fn := c.program.MethodValue(sel) err := c.parseFunc(frames[fn], fn) if err != nil { return err @@ -671,7 +782,7 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { llvmAddr := c.mod.NamedGlobal(getGlobalName(global)) llvmValue := llvmAddr.Initializer() if llvmValue.IsNil() { - llvmValue, err = c.getZeroValue(llvmAddr.Type().ElementType()) + llvmValue, err = getZeroValue(llvmAddr.Type().ElementType()) if err != nil { return err } @@ -693,7 +804,7 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { llvmAddr := c.mod.NamedGlobal(getGlobalName(global)) llvmValue := llvmAddr.Initializer() if llvmValue.IsNil() { - llvmValue, err = c.getZeroValue(llvmAddr.Type().ElementType()) + llvmValue, err = getZeroValue(llvmAddr.Type().ElementType()) if err != nil { return err } @@ -741,8 +852,8 @@ func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error { if frame.blocking { // Coroutine initialization. c.builder.SetInsertPointAtEnd(frame.blocks[f.Blocks[0]]) - frame.taskState = c.builder.CreateAlloca(c.taskDataType, "task.state") - stateI8 := c.builder.CreateBitCast(frame.taskState, c.i8ptrType, "task.state.i8") + taskState := c.builder.CreateAlloca(c.mod.GetTypeByName("runtime.taskState"), "task.state") + stateI8 := c.builder.CreateBitCast(taskState, c.i8ptrType, "task.state.i8") id := c.builder.CreateCall(c.coroIdFunc, []llvm.Value{ llvm.ConstInt(llvm.Int32Type(), 0, false), stateI8, @@ -978,12 +1089,15 @@ func (c *Compiler) parseBuiltin(frame *Frame, args []ssa.Value, callName string) default: return llvm.Value{}, errors.New("todo: len: unknown type") } + case "ssa:wrapnilchk": + // TODO: do an actual nil check? + return c.parseExpr(frame, args[0]) default: return llvm.Value{}, errors.New("todo: builtin: " + callName) } } -func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, llvmFn llvm.Value, blocking bool, parentHandle llvm.Value) (llvm.Value, error) { +func (c *Compiler) parseFunctionCall(frame *Frame, args []ssa.Value, llvmFn llvm.Value, blocking bool, parentHandle llvm.Value) (llvm.Value, error) { var params []llvm.Value if blocking { if parentHandle.IsNil() { @@ -994,7 +1108,7 @@ func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, llvmFn params = append(params, parentHandle) } } - for _, param := range call.Args { + for _, param := range args { val, err := c.parseExpr(frame, param) if err != nil { return llvm.Value{}, err @@ -1048,6 +1162,36 @@ func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, llvmFn } func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle llvm.Value) (llvm.Value, error) { + if instr.IsInvoke() { + // Call an interface method with dynamic dispatch. + itf, err := c.parseExpr(frame, instr.Value) // interface + if err != nil { + return llvm.Value{}, err + } + llvmFnType, err := c.getLLVMType(instr.Method.Type()) + if err != nil { + return llvm.Value{}, err + } + values := []llvm.Value{ + itf, + llvm.ConstInt(llvm.Int32Type(), uint64(c.analysis.MethodNum(instr.Method)), false), + } + fn := c.builder.CreateCall(c.mod.NamedFunction("itfmethod"), values, "invoke.func") + fnCast := c.builder.CreateBitCast(fn, llvmFnType, "invoke.func.cast") + receiverValue := c.builder.CreateExtractValue(itf, 1, "invoke.func.receiver") + args := []llvm.Value{receiverValue} + for _, arg := range instr.Args { + val, err := c.parseExpr(frame, arg) + if err != nil { + return llvm.Value{}, err + } + args = append(args, val) + } + // TODO: blocking methods (needs analysis) + return c.builder.CreateCall(fnCast, args, ""), nil + } + + // Regular function, builtin, or function pointer. switch call := instr.Value.(type) { case *ssa.Builtin: return c.parseBuiltin(frame, instr.Args, call.Name()) @@ -1072,14 +1216,14 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l return llvm.Value{}, errors.New("undefined function: " + name) } } - return c.parseFunctionCall(frame, instr, llvmFn, targetBlocks, parentHandle) + return c.parseFunctionCall(frame, instr.Args, llvmFn, targetBlocks, parentHandle) default: // function pointer value, err := c.parseExpr(frame, instr.Value) if err != nil { return llvm.Value{}, err } // TODO: blocking function pointers (needs analysis) - return c.parseFunctionCall(frame, instr, value, false, parentHandle) + return c.parseFunctionCall(frame, instr.Args, value, false, parentHandle) } } @@ -1108,7 +1252,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { buf = c.builder.CreateBitCast(buf, llvm.PointerType(typ, 0), "") } else { buf = c.builder.CreateAlloca(typ, expr.Comment) - zero, err := c.getZeroValue(typ) + zero, err := getZeroValue(typ) if err != nil { return llvm.Value{}, err } @@ -1253,11 +1397,25 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { c.builder.CreateStore(val, itfValueCast) } else { // Directly place the value in the interface. - // TODO: non-integers - itfValue = c.builder.CreateIntToPtr(val, c.i8ptrType, "") + switch val.Type().TypeKind() { + case llvm.IntegerTypeKind: + itfValue = c.builder.CreateIntToPtr(val, c.i8ptrType, "") + case llvm.PointerTypeKind: + itfValue = c.builder.CreateBitCast(val, c.i8ptrType, "") + case llvm.StructTypeKind: + // A bitcast would be useful here, but bitcast doesn't allow + // aggregate types. So we'll bitcast it using an alloca. + // Hopefully this will get optimized away. + mem := c.builder.CreateAlloca(c.i8ptrType, "") + memStructPtr := c.builder.CreateBitCast(mem, llvm.PointerType(val.Type(), 0), "") + c.builder.CreateStore(val, memStructPtr) + itfValue = c.builder.CreateLoad(mem, "") + default: + return llvm.Value{}, errors.New("todo: makeinterface: cast small type to i8*") + } } - itfTypeNum := c.getInterfaceType(expr.X.Type()) - itf := llvm.ConstNamedStruct(c.mod.GetTypeByName("interface"), []llvm.Value{itfTypeNum, llvm.Undef(c.i8ptrType)}) + itfTypeNum, _ := c.analysis.TypeNum(expr.X.Type()) + itf := llvm.ConstNamedStruct(c.mod.GetTypeByName("interface"), []llvm.Value{llvm.ConstInt(llvm.Int32Type(), uint64(itfTypeNum), false), llvm.Undef(c.i8ptrType)}) itf = c.builder.CreateInsertValue(itf, itfValue, 1, "") return itf, nil case *ssa.Phi: @@ -1280,7 +1438,11 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if err != nil { return llvm.Value{}, err } - assertedTypeNum := c.getInterfaceType(expr.AssertedType) + assertedTypeNum, typeExists := c.analysis.TypeNum(expr.AssertedType) + if !typeExists { + // Static analysis has determined this type assert will never apply. + return llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.ConstInt(llvm.Int1Type(), 0, false)}, false), nil + } actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type") valuePtr := c.builder.CreateExtractValue(itf, 1, "interface.value") var value llvm.Value @@ -1290,12 +1452,26 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { value = c.builder.CreateLoad(valuePtrCast, "") } else { // Value was stored directly in the interface. - // TODO: non-integer values. - value = c.builder.CreatePtrToInt(valuePtr, assertedType, "") + switch assertedType.TypeKind() { + case llvm.IntegerTypeKind: + value = c.builder.CreatePtrToInt(valuePtr, assertedType, "") + case llvm.PointerTypeKind: + value = c.builder.CreateBitCast(valuePtr, assertedType, "") + case llvm.StructTypeKind: + // A bitcast would be useful here, but bitcast doesn't allow + // aggregate types. So we'll bitcast it using an alloca. + // Hopefully this will get optimized away. + mem := c.builder.CreateAlloca(c.i8ptrType, "") + c.builder.CreateStore(valuePtr, mem) + memStructPtr := c.builder.CreateBitCast(mem, llvm.PointerType(assertedType, 0), "") + value = c.builder.CreateLoad(memStructPtr, "") + default: + return llvm.Value{}, errors.New("todo: typeassert: bitcast small types") + } } // TODO: for interfaces, check whether the type implements the // interface. - commaOk := c.builder.CreateICmp(llvm.IntEQ, assertedTypeNum, actualTypeNum, "") + commaOk := c.builder.CreateICmp(llvm.IntEQ, llvm.ConstInt(llvm.Int32Type(), uint64(assertedTypeNum), false), actualTypeNum, "") tuple := llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.Undef(llvm.Int1Type())}, false) // create empty tuple tuple = c.builder.CreateInsertValue(tuple, value, 0, "") // insert value tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "") // insert 'comma ok' boolean