compiler: Implement interface calls

This is a big combined change. Other changes in this commit:

  * Analyze makeinterface and make sure type switches don't include
    unnecessary cases.
  * Do not include CGo wrapper functions in the analyzer callgraph.
    This also avoids some unnecessary type IDs.
  * Give all Go named structs a name in LLVM.
  * Use such a named struct for compiler-generated task data.
  * Use the type and function names defined by the ssa and types
    package instead of generating our own.
  * Some improvements to function pointers.
  * A few other minor improvements.

The one thing lacking here is interface-to-interface assertions.
Этот коммит содержится в:
Ayke van Laethem 2018-06-10 00:36:39 +02:00
родитель 62325eab40
коммит a97ca91c1f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: E97FF5335DFDFDED
5 изменённых файлов: 478 добавлений и 119 удалений

Просмотреть файл

@ -72,3 +72,6 @@ Implemented analysis passes:
sleep, chan send, etc. It's parents are also blocking. sleep, chan send, etc. It's parents are also blocking.
* Check whether the scheduler is needed. It is only needed when there are `go` * Check whether the scheduler is needed. It is only needed when there are `go`
statements for blocking functions. statements for blocking functions.
* Check whether a given type switch or type assert is possible with
[type-based alias analysis](https://en.wikipedia.org/wiki/Alias_analysis#Type-based_alias_analysis).
I would like to use flow-based alias analysis in the future.

Просмотреть файл

@ -9,9 +9,12 @@ import (
// Analysis results over a whole program. // Analysis results over a whole program.
type Analysis struct { type Analysis struct {
functions map[*ssa.Function]*FuncMeta functions map[*ssa.Function]*FuncMeta
needsScheduler bool needsScheduler bool
goCalls []*ssa.Go goCalls []*ssa.Go
typesWithMethods map[string]*TypeMeta
typesWithoutMethods map[string]int
methodSignatureNames map[string]int
} }
// Some analysis results of a single function. // Some analysis results of a single function.
@ -22,10 +25,19 @@ type FuncMeta struct {
children []*ssa.Function children []*ssa.Function
} }
type TypeMeta struct {
t types.Type
Num int
Methods map[string]*types.Selection
}
// Return a new Analysis object. // Return a new Analysis object.
func NewAnalysis() *Analysis { func NewAnalysis() *Analysis {
return &Analysis{ return &Analysis{
functions: make(map[*ssa.Function]*FuncMeta), functions: make(map[*ssa.Function]*FuncMeta),
typesWithMethods: make(map[string]*TypeMeta),
typesWithoutMethods: make(map[string]int),
methodSignatureNames: make(map[string]int),
} }
} }
@ -34,12 +46,19 @@ func (a *Analysis) AddPackage(pkg *ssa.Package) {
for _, member := range pkg.Members { for _, member := range pkg.Members {
switch member := member.(type) { switch member := member.(type) {
case *ssa.Function: case *ssa.Function:
if isCGoInternal(member.Name()) || getCName(member.Name()) != "" {
continue
}
a.addFunction(member) a.addFunction(member)
case *ssa.Type: case *ssa.Type:
ms := pkg.Prog.MethodSets.MethodSet(member.Type()) methods := getAllMethods(pkg.Prog, member.Type())
if !types.IsInterface(member.Type()) { if types.IsInterface(member.Type()) {
for i := 0; i < ms.Len(); i++ { for _, method := range methods {
a.addFunction(pkg.Prog.MethodValue(ms.At(i))) a.MethodName(method.Obj().(*types.Func))
}
} else { // named type
for _, method := range methods {
a.addFunction(pkg.Prog.MethodValue(method))
} }
} }
} }
@ -54,13 +73,39 @@ func (a *Analysis) addFunction(f *ssa.Function) {
for _, instr := range block.Instrs { for _, instr := range block.Instrs {
switch instr := instr.(type) { switch instr := instr.(type) {
case *ssa.Call: case *ssa.Call:
switch call := instr.Call.Value.(type) { if instr.Common().IsInvoke() {
case *ssa.Function: name := a.MethodName(instr.Common().Method)
name := getFunctionName(call, false) a.methodSignatureNames[name] = len(a.methodSignatureNames)
if name == "runtime.Sleep" { } else {
fm.blocking = true switch call := instr.Call.Value.(type) {
case *ssa.Builtin:
// ignore
case *ssa.Function:
if isCGoInternal(call.Name()) || getCName(call.Name()) != "" {
continue
}
name := getFunctionName(call, false)
if name == "runtime.Sleep" {
fm.blocking = true
}
fm.children = append(fm.children, call)
} }
fm.children = append(fm.children, call) }
case *ssa.MakeInterface:
methods := getAllMethods(f.Prog, instr.X.Type())
if _, ok := a.typesWithMethods[instr.X.Type().String()]; !ok && len(methods) > 0 {
meta := &TypeMeta{
t: instr.X.Type(),
Num: len(a.typesWithMethods),
Methods: make(map[string]*types.Selection),
}
for _, sel := range methods {
name := a.MethodName(sel.Obj().(*types.Func))
meta.Methods[name] = sel
}
a.typesWithMethods[instr.X.Type().String()] = meta
} else if _, ok := a.typesWithoutMethods[instr.X.Type().String()]; !ok && len(methods) == 0 {
a.typesWithoutMethods[instr.X.Type().String()] = len(a.typesWithoutMethods)
} }
case *ssa.Go: case *ssa.Go:
a.goCalls = append(a.goCalls, instr) a.goCalls = append(a.goCalls, instr)
@ -74,6 +119,44 @@ func (a *Analysis) addFunction(f *ssa.Function) {
} }
} }
// Make a readable version of the method signature (including the function name,
// excluding the receiver name). This string is used internally to match
// interfaces and to call the correct method on an interface. Examples:
//
// String() string
// Read([]byte) (int, error)
func (a *Analysis) MethodName(method *types.Func) string {
sig := method.Type().(*types.Signature)
name := method.Name()
if sig.Params().Len() == 0 {
name += "()"
} else {
name += "("
for i := 0; i < sig.Params().Len(); i++ {
if i > 0 {
name += ", "
}
name += sig.Params().At(i).Type().String()
}
name += ")"
}
if sig.Results().Len() == 0 {
// keep as-is
} else if sig.Results().Len() == 1 {
name += " " + sig.Results().At(0).Type().String()
} else {
name += " ("
for i := 0; i < sig.Results().Len(); i++ {
if i > 0 {
name += ", "
}
name += sig.Results().At(i).Type().String()
}
name += ")"
}
return name
}
// Fill in parents of all functions. // Fill in parents of all functions.
// //
// All packages need to be added before this pass can run, or it will produce // All packages need to be added before this pass can run, or it will produce
@ -83,7 +166,7 @@ func (a *Analysis) AnalyseCallgraph() {
for _, child := range fm.children { for _, child := range fm.children {
childRes, ok := a.functions[child] childRes, ok := a.functions[child]
if !ok { if !ok {
print("child not found: " + child.Pkg.Pkg.Path() + "." + child.Name() + ", function: " + f.Name()) println("child not found: " + child.Pkg.Pkg.Path() + "." + child.Name() + ", function: " + f.Name())
continue continue
} }
childRes.parents = append(childRes.parents, f) childRes.parents = append(childRes.parents, f)
@ -163,3 +246,45 @@ func (a *Analysis) isBlocking(f ssa.Value) bool {
panic("Analysis.IsBlocking on unknown type") panic("Analysis.IsBlocking on unknown type")
} }
} }
// Return the type number and whether this type is actually used. Used in
// interface conversions (type is always used) and type asserts (type may not be
// used, meaning assert is always false in this program).
//
// May only be used after all packages have been added to the analyser.
func (a *Analysis) TypeNum(typ types.Type) (int, bool) {
if n, ok := a.typesWithoutMethods[typ.String()]; ok {
return n, true
} else if meta, ok := a.typesWithMethods[typ.String()]; ok {
return len(a.typesWithoutMethods) + meta.Num, true
} else {
return -1, false // type is never put in an interface
}
}
// MethodNum returns the numeric ID of this method, to be used in method lookups
// on interfaces for example.
func (a *Analysis) MethodNum(method *types.Func) int {
if n, ok := a.methodSignatureNames[a.MethodName(method)]; ok {
return n
}
return -1 // signal error
}
// The start index of the first dynamic type that has methods.
// Types without methods always have a lower ID and types with methods have this
// or a higher ID.
//
// May only be used after all packages have been added to the analyser.
func (a *Analysis) FirstDynamicType() int {
return len(a.typesWithoutMethods)
}
// Return all types with methods, sorted by type ID.
func (a *Analysis) AllDynamicTypes() []*TypeMeta {
l := make([]*TypeMeta, len(a.typesWithMethods))
for _, m := range a.typesWithMethods {
l[m.Num] = m
}
return l
}

Просмотреть файл

@ -9,6 +9,10 @@ func (t Thing) String() string {
return t.name return t.name
} }
type Stringer interface {
String() string
}
const SIX = 6 const SIX = 6
func main() { func main() {
@ -20,21 +24,26 @@ func main() {
println("sumrange(100) =", sumrange(100)) println("sumrange(100) =", sumrange(100))
println("strlen foo:", strlen("foo")) println("strlen foo:", strlen("foo"))
thing := Thing{"foo"} thing := &Thing{"foo"}
println("thing:", thing.String()) println("thing:", thing.String())
printItf(5) printItf(5)
printItf(byte('x')) printItf(byte('x'))
printItf("foo") printItf("foo")
printItf(*thing)
printItf(thing)
printItf(Stringer(thing))
s := Stringer(thing)
println("Stringer.String():", s.String())
runFunc(hello) // must be indirect to avoid obvious inlining runFunc(hello, 5) // must be indirect to avoid obvious inlining
} }
func runFunc(f func()) { func runFunc(f func(int), arg int) {
f() f(arg)
} }
func hello() { func hello(n int) {
println("hello from function pointer!") println("hello from function pointer:", n)
} }
func strlen(s string) int { func strlen(s string) int {
@ -49,6 +58,10 @@ func printItf(val interface{}) {
println("is byte:", val) println("is byte:", val)
case string: case string:
println("is string:", val) println("is string:", val)
case Thing:
println("is Thing:", val.String())
case *Thing:
println("is *Thing:", val.String())
default: default:
println("is ?") println("is ?")
} }

Просмотреть файл

@ -1,27 +1,69 @@
source_filename = "runtime/runtime.ll" source_filename = "runtime/runtime.ll"
%interface = type { i32, i8* }
declare void @runtime.initAll() declare void @runtime.initAll()
declare void @main.main() declare void @main.main()
declare i8* @main.main$async(i8*) declare i8* @main.main$async(i8*)
declare void @runtime.scheduler(i8*) declare void @runtime.scheduler(i8*)
; Will be changed to true if there are 'go' statements in the compiled program. ; Will be changed to true if there are 'go' statements in the compiled program.
@.has_scheduler = private unnamed_addr constant i1 false @has_scheduler = private unnamed_addr constant i1 false
; Will be changed by the compiler to the first type number with methods.
@first_interface_num = private unnamed_addr constant i32 0
; Will be filled by the compiler with runtime type information.
%interface_tuple = type { i32, i32 } ; { index, len }
@interface_tuples = external global [0 x %interface_tuple]
@interface_signatures = external global [0 x i32] ; array of method IDs
@interface_functions = external global [0 x i8*] ; array of function pointers
define i32 @main() { define i32 @main() {
call void @runtime.initAll() call void @runtime.initAll()
%has_scheduler = load i1, i1* @.has_scheduler %has_scheduler = load i1, i1* @has_scheduler
; This branch will be optimized away. Only one of the targets will remain. ; This branch will be optimized away. Only one of the targets will remain.
br i1 %has_scheduler, label %with_scheduler, label %without_scheduler br i1 %has_scheduler, label %with_scheduler, label %without_scheduler
with_scheduler: with_scheduler:
; Initialize main and run the scheduler. ; Initialize main and run the scheduler.
%main = call i8* @main.main$async(i8* null) %main = call i8* @main.main$async(i8* null)
call void @runtime.scheduler(i8* %main) call void @runtime.scheduler(i8* %main)
ret i32 0 ret i32 0
without_scheduler: without_scheduler:
; No scheduler is necessary. Call main directly. ; No scheduler is necessary. Call main directly.
call void @main.main() call void @main.main()
ret i32 0 ret i32 0
}
; Get the function pointer for the method on the interface.
; This function only reads constant global data and it's own arguments so it can
; be 'readnone' (a pure function).
define i8* @itfmethod(%interface %itf, i32 %method) noinline readnone {
entry:
; Calculate the index in @interface_tuples
%concrete_type_num = extractvalue %interface %itf, 0
%first_interface_num = load i32, i32* @first_interface_num
%index = sub i32 %concrete_type_num, %first_interface_num
; Calculate the index for @interface_signatures and @interface_functions
%itf_index_ptr = getelementptr inbounds [0 x %interface_tuple], [0 x %interface_tuple]* @interface_tuples, i32 0, i32 %index, i32 0
%itf_index = load i32, i32* %itf_index_ptr
br label %find_method
; This is a while loop until the method has been found.
; It must be in here, so avoid checking the length.
find_method:
%itf_index.phi = phi i32 [ %itf_index, %entry], [ %itf_index.phi.next, %find_method]
%m_ptr = getelementptr inbounds [0 x i32], [0 x i32]* @interface_signatures, i32 0, i32 %itf_index.phi
%m = load i32, i32* %m_ptr
%found = icmp eq i32 %m, %method
%itf_index.phi.next = add i32 %itf_index.phi, 1
br i1 %found, label %found_method, label %find_method
found_method:
%fp_ptr = getelementptr inbounds [0 x i8*], [0 x i8*]* @interface_functions, i32 0, i32 %itf_index.phi
%fp = load i8*, i8** %fp_ptr
ret i8* %fp
} }

348
tgo.go
Просмотреть файл

@ -39,7 +39,6 @@ type Compiler struct {
i8ptrType llvm.Type // for convenience i8ptrType llvm.Type // for convenience
uintptrType llvm.Type uintptrType llvm.Type
stringLenType llvm.Type stringLenType llvm.Type
taskDataType llvm.Type
allocFunc llvm.Value allocFunc llvm.Value
freeFunc llvm.Value freeFunc llvm.Value
coroIdFunc llvm.Value coroIdFunc llvm.Value
@ -48,8 +47,8 @@ type Compiler struct {
coroSuspendFunc llvm.Value coroSuspendFunc llvm.Value
coroEndFunc llvm.Value coroEndFunc llvm.Value
coroFreeFunc llvm.Value coroFreeFunc llvm.Value
itfTypeNumbers map[types.Type]uint64 program *ssa.Program
itfTypes []types.Type mainPkg *ssa.Package
initFuncs []llvm.Value initFuncs []llvm.Value
analysis *Analysis analysis *Analysis
} }
@ -62,19 +61,11 @@ type Frame struct {
blocks map[*ssa.BasicBlock]llvm.BasicBlock blocks map[*ssa.BasicBlock]llvm.BasicBlock
phis []Phi phis []Phi
blocking bool blocking bool
taskState llvm.Value
taskHandle llvm.Value taskHandle llvm.Value
cleanupBlock llvm.BasicBlock cleanupBlock llvm.BasicBlock
suspendBlock llvm.BasicBlock suspendBlock llvm.BasicBlock
} }
func pkgPrefix(pkg *ssa.Package) string {
if pkg.Pkg.Name() == "main" {
return "main"
}
return pkg.Pkg.Path()
}
type Phi struct { type Phi struct {
ssa *ssa.Phi ssa *ssa.Phi
llvm llvm.Value llvm llvm.Value
@ -82,10 +73,9 @@ type Phi struct {
func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) { func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) {
c := &Compiler{ c := &Compiler{
dumpSSA: dumpSSA, dumpSSA: dumpSSA,
triple: triple, triple: triple,
itfTypeNumbers: make(map[types.Type]uint64), analysis: NewAnalysis(),
analysis: NewAnalysis(),
} }
target, err := llvm.GetTargetFromTriple(triple) target, err := llvm.GetTargetFromTriple(triple)
@ -109,13 +99,6 @@ func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) {
t := c.ctx.StructCreateNamed("string") t := c.ctx.StructCreateNamed("string")
t.StructSetBody([]llvm.Type{c.stringLenType, c.i8ptrType}, false) t.StructSetBody([]llvm.Type{c.stringLenType, c.i8ptrType}, false)
// Go interface: tuple of (type, ptr)
t = c.ctx.StructCreateNamed("interface")
t.StructSetBody([]llvm.Type{llvm.Int32Type(), c.i8ptrType}, false)
// Goroutine / task data: {i8 state, i32 data, i8* next}
c.taskDataType = llvm.StructType([]llvm.Type{llvm.Int8Type(), llvm.Int32Type(), c.i8ptrType}, false)
allocType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.uintptrType}, false) allocType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.uintptrType}, false)
c.allocFunc = llvm.AddFunction(c.mod, "runtime.alloc", allocType) c.allocFunc = llvm.AddFunction(c.mod, "runtime.alloc", allocType)
@ -178,8 +161,10 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
} }
} }
program := ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions | ssa.BareInits) c.program = ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions | ssa.BareInits)
program.Build() c.program.Build()
c.mainPkg = c.program.ImportedPackage(mainPath)
// Make a list of packages in import order. // Make a list of packages in import order.
packageList := []*ssa.Package{} packageList := []*ssa.Package{}
@ -187,7 +172,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
worklist := []string{"runtime", mainPath} worklist := []string{"runtime", mainPath}
for len(worklist) != 0 { for len(worklist) != 0 {
pkgPath := worklist[0] pkgPath := worklist[0]
pkg := program.ImportedPackage(pkgPath) pkg := c.program.ImportedPackage(pkgPath)
if pkg == nil { if pkg == nil {
// Non-SSA package (e.g. cgo). // Non-SSA package (e.g. cgo).
packageSet[pkgPath] = struct{}{} packageSet[pkgPath] = struct{}{}
@ -231,7 +216,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
// Transform each package into LLVM IR. // Transform each package into LLVM IR.
for _, pkg := range packageList { for _, pkg := range packageList {
err := c.parsePackage(program, pkg) err := c.parsePackage(pkg)
if err != nil { if err != nil {
return err return err
} }
@ -252,23 +237,85 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
} }
c.builder.CreateRetVoid() c.builder.CreateRetVoid()
// Set functions referenced in runtime.ll to internal linkage, to improve // Adjust main function.
// optimization (hopefully).
main := c.mod.NamedFunction("main.main") main := c.mod.NamedFunction("main.main")
if !main.IsDeclaration() { realMain := c.mod.NamedFunction(c.mainPkg.Pkg.Path() + ".main")
main.SetLinkage(llvm.PrivateLinkage) if !realMain.IsNil() {
main.ReplaceAllUsesWith(realMain)
} }
mainAsync := c.mod.NamedFunction("main.main$async") mainAsync := c.mod.NamedFunction("main.main$async")
if !mainAsync.IsDeclaration() { realMainAsync := c.mod.NamedFunction(c.mainPkg.Pkg.Path() + ".main$async")
mainAsync.SetLinkage(llvm.PrivateLinkage) if !realMainAsync.IsNil() {
mainAsync.ReplaceAllUsesWith(realMainAsync)
} }
// Set functions referenced in runtime.ll to internal linkage, to improve
// optimization (hopefully).
c.mod.NamedFunction("runtime.scheduler").SetLinkage(llvm.PrivateLinkage) c.mod.NamedFunction("runtime.scheduler").SetLinkage(llvm.PrivateLinkage)
// Only use a scheduler when necessary.
if c.analysis.NeedsScheduler() { if c.analysis.NeedsScheduler() {
// Enable the scheduler. // Enable the scheduler.
c.mod.NamedGlobal(".has_scheduler").SetInitializer(llvm.ConstInt(llvm.Int1Type(), 1, false)) c.mod.NamedGlobal("has_scheduler").SetInitializer(llvm.ConstInt(llvm.Int1Type(), 1, false))
} }
// Initialize runtime type information, for interfaces.
dynamicTypes := c.analysis.AllDynamicTypes()
numDynamicTypes := 0
for _, meta := range dynamicTypes {
numDynamicTypes += len(meta.Methods)
}
tuples := make([]llvm.Value, 0, len(dynamicTypes))
funcPointers := make([]llvm.Value, 0, numDynamicTypes)
signatures := make([]llvm.Value, 0, numDynamicTypes)
startIndex := 0
tupleType := c.mod.GetTypeByName("interface_tuple")
for _, meta := range dynamicTypes {
tupleValues := []llvm.Value{
llvm.ConstInt(llvm.Int32Type(), uint64(startIndex), false),
llvm.ConstInt(llvm.Int32Type(), uint64(len(meta.Methods)), false),
}
tuple := llvm.ConstNamedStruct(tupleType, tupleValues)
tuples = append(tuples, tuple)
for _, method := range meta.Methods {
fnName := getFunctionName(c.program.MethodValue(method), false)
llvmFn := c.mod.NamedFunction(fnName)
if llvmFn.IsNil() {
return errors.New("cannot find function: " + fnName)
}
fn := llvm.ConstBitCast(llvmFn, c.i8ptrType)
funcPointers = append(funcPointers, fn)
signatureNum := c.analysis.MethodNum(method.Obj().(*types.Func))
signature := llvm.ConstInt(llvm.Int32Type(), uint64(signatureNum), false)
signatures = append(signatures, signature)
}
startIndex += len(meta.Methods)
}
// Replace the pre-created arrays with the generated arrays.
tupleArray := llvm.ConstArray(tupleType, tuples)
tupleArrayNewGlobal := llvm.AddGlobal(c.mod, tupleArray.Type(), "interface_tuples.tmp")
tupleArrayNewGlobal.SetInitializer(tupleArray)
tupleArrayOldGlobal := c.mod.NamedGlobal("interface_tuples")
tupleArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(tupleArrayNewGlobal, tupleArrayOldGlobal.Type()))
tupleArrayOldGlobal.EraseFromParentAsGlobal()
tupleArrayNewGlobal.SetName("interface_tuples")
funcArray := llvm.ConstArray(c.i8ptrType, funcPointers)
funcArrayNewGlobal := llvm.AddGlobal(c.mod, funcArray.Type(), "interface_functions.tmp")
funcArrayNewGlobal.SetInitializer(funcArray)
funcArrayOldGlobal := c.mod.NamedGlobal("interface_functions")
funcArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(funcArrayNewGlobal, funcArrayOldGlobal.Type()))
funcArrayOldGlobal.EraseFromParentAsGlobal()
funcArrayNewGlobal.SetName("interface_functions")
signatureArray := llvm.ConstArray(llvm.Int32Type(), signatures)
signatureArrayNewGlobal := llvm.AddGlobal(c.mod, signatureArray.Type(), "interface_signatures.tmp")
signatureArrayNewGlobal.SetInitializer(signatureArray)
signatureArrayOldGlobal := c.mod.NamedGlobal("interface_signatures")
signatureArrayOldGlobal.ReplaceAllUsesWith(llvm.ConstBitCast(signatureArrayNewGlobal, signatureArrayOldGlobal.Type()))
signatureArrayOldGlobal.EraseFromParentAsGlobal()
signatureArrayNewGlobal.SetName("interface_signatures")
c.mod.NamedGlobal("first_interface_num").SetInitializer(llvm.ConstInt(llvm.Int32Type(), uint64(c.analysis.FirstDynamicType()), false))
return nil return nil
} }
@ -306,6 +353,13 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
case *types.Interface: case *types.Interface:
return c.mod.GetTypeByName("interface"), nil return c.mod.GetTypeByName("interface"), nil
case *types.Named: case *types.Named:
if _, ok := typ.Underlying().(*types.Struct); ok {
llvmType := c.mod.GetTypeByName(typ.Obj().Pkg().Path() + "." + typ.Obj().Name())
if llvmType.IsNil() {
return llvm.Type{}, errors.New("type not found: " + typ.Obj().Pkg().Path() + "." + typ.Obj().Name())
}
return llvmType, nil
}
return c.getLLVMType(typ.Underlying()) return c.getLLVMType(typ.Underlying())
case *types.Pointer: case *types.Pointer:
ptrTo, err := c.getLLVMType(typ.Elem()) ptrTo, err := c.getLLVMType(typ.Elem())
@ -329,6 +383,16 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
} }
// param values // param values
var paramTypes []llvm.Type var paramTypes []llvm.Type
if typ.Recv() != nil {
recv, err := c.getLLVMType(typ.Recv().Type())
if err != nil {
return llvm.Type{}, err
}
if recv.StructName() == "interface" {
recv = c.i8ptrType
}
paramTypes = append(paramTypes, recv)
}
params := typ.Params() params := typ.Params()
for i := 0; i < params.Len(); i++ { for i := 0; i < params.Len(); i++ {
subType, err := c.getLLVMType(params.At(i).Type()) subType, err := c.getLLVMType(params.At(i).Type())
@ -354,13 +418,17 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
} }
} }
func (c *Compiler) getZeroValue(typ llvm.Type) (llvm.Value, error) { // Return a zero LLVM value for any LLVM type. Setting this value as an
// initializer has the same effect as setting 'zeroinitializer' on a value.
// Sadly, I haven't found a way to do it directly with the Go API but this works
// just fine.
func getZeroValue(typ llvm.Type) (llvm.Value, error) {
switch typ.TypeKind() { switch typ.TypeKind() {
case llvm.ArrayTypeKind: case llvm.ArrayTypeKind:
subTyp := typ.ElementType() subTyp := typ.ElementType()
vals := make([]llvm.Value, typ.ArrayLength()) vals := make([]llvm.Value, typ.ArrayLength())
for i := range vals { for i := range vals {
val, err := c.getZeroValue(subTyp) val, err := getZeroValue(subTyp)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
@ -375,7 +443,7 @@ func (c *Compiler) getZeroValue(typ llvm.Type) (llvm.Value, error) {
types := typ.StructElementTypes() types := typ.StructElementTypes()
vals := make([]llvm.Value, len(types)) vals := make([]llvm.Value, len(types))
for i, subTyp := range types { for i, subTyp := range types {
val, err := c.getZeroValue(subTyp) val, err := getZeroValue(subTyp)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
@ -391,15 +459,6 @@ func (c *Compiler) getZeroValue(typ llvm.Type) (llvm.Value, error) {
} }
} }
func (c *Compiler) getInterfaceType(typ types.Type) llvm.Value {
if _, ok := c.itfTypeNumbers[typ]; !ok {
num := uint64(len(c.itfTypes))
c.itfTypes = append(c.itfTypes, typ)
c.itfTypeNumbers[typ] = num
}
return llvm.ConstInt(llvm.Int32Type(), c.itfTypeNumbers[typ], false)
}
// Is this a pointer type of some sort? Can be unsafe.Pointer or any *T pointer. // Is this a pointer type of some sort? Can be unsafe.Pointer or any *T pointer.
func isPointer(typ types.Type) bool { func isPointer(typ types.Type) bool {
if _, ok := typ.(*types.Pointer); ok { if _, ok := typ.(*types.Pointer); ok {
@ -411,22 +470,40 @@ func isPointer(typ types.Type) bool {
} }
} }
// Get all methods of a type: both value receivers and pointer receivers.
func getAllMethods(prog *ssa.Program, typ types.Type) []*types.Selection {
var methods []*types.Selection
// value receivers
ms := prog.MethodSets.MethodSet(typ)
for i := 0; i < ms.Len(); i++ {
methods = append(methods, ms.At(i))
}
// pointer receivers
ms = prog.MethodSets.MethodSet(types.NewPointer(typ))
for i := 0; i < ms.Len(); i++ {
methods = append(methods, ms.At(i))
}
return methods
}
func getFunctionName(fn *ssa.Function, blocking bool) string { func getFunctionName(fn *ssa.Function, blocking bool) string {
suffix := "" suffix := ""
if blocking { if blocking {
suffix = "$async" suffix = "$async"
} }
if fn.Signature.Recv() != nil { if fn.Signature.Recv() != nil {
// Method on a defined type. // Method on a defined type (which may be a pointer).
typeName := fn.Params[0].Type().(*types.Named).Obj().Name() return fn.RelString(nil) + suffix
return pkgPrefix(fn.Pkg) + "." + typeName + "." + fn.Name() + suffix
} else { } else {
// Bare function. // Bare function.
if strings.HasPrefix(fn.Name(), "_Cfunc_") { if name := getCName(fn.Name()); name != "" {
// Name CGo functions directly. // Name CGo functions directly.
return fn.Name()[len("_Cfunc_"):] return name
} else { } else {
name := pkgPrefix(fn.Pkg) + "." + fn.Name() + suffix name := fn.RelString(nil) + suffix
if fn.Pkg.Pkg.Path() == "runtime" && strings.HasPrefix(fn.Name(), "_llvm_") { if fn.Pkg.Pkg.Path() == "runtime" && strings.HasPrefix(fn.Name(), "_llvm_") {
// Special case for LLVM intrinsics in the runtime. // Special case for LLVM intrinsics in the runtime.
name = "llvm." + strings.Replace(fn.Name()[len("_llvm_"):], "_", ".", -1) name = "llvm." + strings.Replace(fn.Name()[len("_llvm_"):], "_", ".", -1)
@ -440,21 +517,38 @@ func getGlobalName(global *ssa.Global) string {
if strings.HasPrefix(global.Name(), "_extern_") { if strings.HasPrefix(global.Name(), "_extern_") {
return global.Name()[len("_extern_"):] return global.Name()[len("_extern_"):]
} else { } else {
return pkgPrefix(global.Pkg) + "." + global.Name() return global.RelString(nil)
} }
} }
func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error { // Return true if this is a CGo-internal function that can be ignored.
func isCGoInternal(name string) bool {
if strings.HasPrefix(name, "_Cgo_") || strings.HasPrefix(name, "_cgo") {
// _Cgo_ptr, _Cgo_use, _cgoCheckResult, _cgo_runtime_cgocall
return true // CGo-internal functions
}
if strings.HasPrefix(name, "__cgofn__cgo_") {
return true // CGo function pointer in global scope
}
return false
}
// Return the name of the C function if this is a CGo call. Otherwise, return a
// zero-length string.
func getCName(name string) string {
if strings.HasPrefix(name, "_Cfunc_") {
return name[len("_Cfunc_"):]
}
return ""
}
func (c *Compiler) parsePackage(pkg *ssa.Package) error {
// Make sure we're walking through all members in a constant order every // Make sure we're walking through all members in a constant order every
// run. // run, and skip cgo wrapper functions/globals which we don't need.
memberNames := make([]string, 0) memberNames := make([]string, 0)
for name := range pkg.Members { for name := range pkg.Members {
if strings.HasPrefix(name, "_Cgo_") || strings.HasPrefix(name, "_cgo") { if isCGoInternal(name) {
// _Cgo_ptr, _Cgo_use, _cgoCheckResult, _cgo_runtime_cgocall continue
continue // CGo-internal functions
}
if strings.HasPrefix(name, "__cgofn__cgo_") {
continue // CGo function pointer in global scope
} }
memberNames = append(memberNames, name) memberNames = append(memberNames, name)
} }
@ -462,7 +556,26 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error {
frames := make(map[*ssa.Function]*Frame) frames := make(map[*ssa.Function]*Frame)
// First, build all function declarations. // First, declare all named (struct) types.
for _, name := range memberNames {
member := pkg.Members[name]
switch member := member.(type) {
case *ssa.Type:
if named, ok := member.Type().(*types.Named); ok {
if st, ok := named.Underlying().(*types.Struct); ok {
llvmType, err := c.getLLVMType(st)
if err != nil {
return err
}
llvmNamedType := c.ctx.StructCreateNamed(named.Obj().Pkg().Path() + "." + named.Obj().Name())
llvmNamedType.StructSetBody(llvmType.StructElementTypes(), false)
}
}
}
}
// With the types defined, build all function declarations.
for _, name := range memberNames { for _, name := range memberNames {
member := pkg.Members[name] member := pkg.Members[name]
@ -514,7 +627,7 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error {
global.SetInitializer(llvm.ConstInt(llvm.Int8Type(), uint64(bitness), false)) global.SetInitializer(llvm.ConstInt(llvm.Int8Type(), uint64(bitness), false))
global.SetGlobalConstant(true) global.SetGlobalConstant(true)
} else { } else {
initializer, err := c.getZeroValue(llvmType) initializer, err := getZeroValue(llvmType)
if err != nil { if err != nil {
return err return err
} }
@ -523,9 +636,8 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error {
} }
case *ssa.Type: case *ssa.Type:
if !types.IsInterface(member.Type()) { if !types.IsInterface(member.Type()) {
ms := program.MethodSets.MethodSet(member.Type()) for _, sel := range getAllMethods(c.program, member.Type()) {
for i := 0; i < ms.Len(); i++ { fn := c.program.MethodValue(sel)
fn := program.MethodValue(ms.At(i))
frame, err := c.parseFuncDecl(fn) frame, err := c.parseFuncDecl(fn)
if err != nil { if err != nil {
return err return err
@ -543,7 +655,7 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error {
member := pkg.Members[name] member := pkg.Members[name]
switch member := member.(type) { switch member := member.(type) {
case *ssa.Function: case *ssa.Function:
if strings.HasPrefix(name, "_Cfunc_") { if getCName(name) != "" {
// CGo function. Don't implement it's body. // CGo function. Don't implement it's body.
continue continue
} }
@ -561,9 +673,8 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error {
} }
case *ssa.Type: case *ssa.Type:
if !types.IsInterface(member.Type()) { if !types.IsInterface(member.Type()) {
ms := program.MethodSets.MethodSet(member.Type()) for _, sel := range getAllMethods(c.program, member.Type()) {
for i := 0; i < ms.Len(); i++ { fn := c.program.MethodValue(sel)
fn := program.MethodValue(ms.At(i))
err := c.parseFunc(frames[fn], fn) err := c.parseFunc(frames[fn], fn)
if err != nil { if err != nil {
return err return err
@ -671,7 +782,7 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error {
llvmAddr := c.mod.NamedGlobal(getGlobalName(global)) llvmAddr := c.mod.NamedGlobal(getGlobalName(global))
llvmValue := llvmAddr.Initializer() llvmValue := llvmAddr.Initializer()
if llvmValue.IsNil() { if llvmValue.IsNil() {
llvmValue, err = c.getZeroValue(llvmAddr.Type().ElementType()) llvmValue, err = getZeroValue(llvmAddr.Type().ElementType())
if err != nil { if err != nil {
return err return err
} }
@ -693,7 +804,7 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error {
llvmAddr := c.mod.NamedGlobal(getGlobalName(global)) llvmAddr := c.mod.NamedGlobal(getGlobalName(global))
llvmValue := llvmAddr.Initializer() llvmValue := llvmAddr.Initializer()
if llvmValue.IsNil() { if llvmValue.IsNil() {
llvmValue, err = c.getZeroValue(llvmAddr.Type().ElementType()) llvmValue, err = getZeroValue(llvmAddr.Type().ElementType())
if err != nil { if err != nil {
return err return err
} }
@ -741,8 +852,8 @@ func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error {
if frame.blocking { if frame.blocking {
// Coroutine initialization. // Coroutine initialization.
c.builder.SetInsertPointAtEnd(frame.blocks[f.Blocks[0]]) c.builder.SetInsertPointAtEnd(frame.blocks[f.Blocks[0]])
frame.taskState = c.builder.CreateAlloca(c.taskDataType, "task.state") taskState := c.builder.CreateAlloca(c.mod.GetTypeByName("runtime.taskState"), "task.state")
stateI8 := c.builder.CreateBitCast(frame.taskState, c.i8ptrType, "task.state.i8") stateI8 := c.builder.CreateBitCast(taskState, c.i8ptrType, "task.state.i8")
id := c.builder.CreateCall(c.coroIdFunc, []llvm.Value{ id := c.builder.CreateCall(c.coroIdFunc, []llvm.Value{
llvm.ConstInt(llvm.Int32Type(), 0, false), llvm.ConstInt(llvm.Int32Type(), 0, false),
stateI8, stateI8,
@ -978,12 +1089,15 @@ func (c *Compiler) parseBuiltin(frame *Frame, args []ssa.Value, callName string)
default: default:
return llvm.Value{}, errors.New("todo: len: unknown type") return llvm.Value{}, errors.New("todo: len: unknown type")
} }
case "ssa:wrapnilchk":
// TODO: do an actual nil check?
return c.parseExpr(frame, args[0])
default: default:
return llvm.Value{}, errors.New("todo: builtin: " + callName) return llvm.Value{}, errors.New("todo: builtin: " + callName)
} }
} }
func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, llvmFn llvm.Value, blocking bool, parentHandle llvm.Value) (llvm.Value, error) { func (c *Compiler) parseFunctionCall(frame *Frame, args []ssa.Value, llvmFn llvm.Value, blocking bool, parentHandle llvm.Value) (llvm.Value, error) {
var params []llvm.Value var params []llvm.Value
if blocking { if blocking {
if parentHandle.IsNil() { if parentHandle.IsNil() {
@ -994,7 +1108,7 @@ func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, llvmFn
params = append(params, parentHandle) params = append(params, parentHandle)
} }
} }
for _, param := range call.Args { for _, param := range args {
val, err := c.parseExpr(frame, param) val, err := c.parseExpr(frame, param)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
@ -1048,6 +1162,36 @@ func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, llvmFn
} }
func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle llvm.Value) (llvm.Value, error) { func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle llvm.Value) (llvm.Value, error) {
if instr.IsInvoke() {
// Call an interface method with dynamic dispatch.
itf, err := c.parseExpr(frame, instr.Value) // interface
if err != nil {
return llvm.Value{}, err
}
llvmFnType, err := c.getLLVMType(instr.Method.Type())
if err != nil {
return llvm.Value{}, err
}
values := []llvm.Value{
itf,
llvm.ConstInt(llvm.Int32Type(), uint64(c.analysis.MethodNum(instr.Method)), false),
}
fn := c.builder.CreateCall(c.mod.NamedFunction("itfmethod"), values, "invoke.func")
fnCast := c.builder.CreateBitCast(fn, llvmFnType, "invoke.func.cast")
receiverValue := c.builder.CreateExtractValue(itf, 1, "invoke.func.receiver")
args := []llvm.Value{receiverValue}
for _, arg := range instr.Args {
val, err := c.parseExpr(frame, arg)
if err != nil {
return llvm.Value{}, err
}
args = append(args, val)
}
// TODO: blocking methods (needs analysis)
return c.builder.CreateCall(fnCast, args, ""), nil
}
// Regular function, builtin, or function pointer.
switch call := instr.Value.(type) { switch call := instr.Value.(type) {
case *ssa.Builtin: case *ssa.Builtin:
return c.parseBuiltin(frame, instr.Args, call.Name()) return c.parseBuiltin(frame, instr.Args, call.Name())
@ -1072,14 +1216,14 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l
return llvm.Value{}, errors.New("undefined function: " + name) return llvm.Value{}, errors.New("undefined function: " + name)
} }
} }
return c.parseFunctionCall(frame, instr, llvmFn, targetBlocks, parentHandle) return c.parseFunctionCall(frame, instr.Args, llvmFn, targetBlocks, parentHandle)
default: // function pointer default: // function pointer
value, err := c.parseExpr(frame, instr.Value) value, err := c.parseExpr(frame, instr.Value)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
// TODO: blocking function pointers (needs analysis) // TODO: blocking function pointers (needs analysis)
return c.parseFunctionCall(frame, instr, value, false, parentHandle) return c.parseFunctionCall(frame, instr.Args, value, false, parentHandle)
} }
} }
@ -1108,7 +1252,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
buf = c.builder.CreateBitCast(buf, llvm.PointerType(typ, 0), "") buf = c.builder.CreateBitCast(buf, llvm.PointerType(typ, 0), "")
} else { } else {
buf = c.builder.CreateAlloca(typ, expr.Comment) buf = c.builder.CreateAlloca(typ, expr.Comment)
zero, err := c.getZeroValue(typ) zero, err := getZeroValue(typ)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
@ -1253,11 +1397,25 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
c.builder.CreateStore(val, itfValueCast) c.builder.CreateStore(val, itfValueCast)
} else { } else {
// Directly place the value in the interface. // Directly place the value in the interface.
// TODO: non-integers switch val.Type().TypeKind() {
itfValue = c.builder.CreateIntToPtr(val, c.i8ptrType, "") case llvm.IntegerTypeKind:
itfValue = c.builder.CreateIntToPtr(val, c.i8ptrType, "")
case llvm.PointerTypeKind:
itfValue = c.builder.CreateBitCast(val, c.i8ptrType, "")
case llvm.StructTypeKind:
// A bitcast would be useful here, but bitcast doesn't allow
// aggregate types. So we'll bitcast it using an alloca.
// Hopefully this will get optimized away.
mem := c.builder.CreateAlloca(c.i8ptrType, "")
memStructPtr := c.builder.CreateBitCast(mem, llvm.PointerType(val.Type(), 0), "")
c.builder.CreateStore(val, memStructPtr)
itfValue = c.builder.CreateLoad(mem, "")
default:
return llvm.Value{}, errors.New("todo: makeinterface: cast small type to i8*")
}
} }
itfTypeNum := c.getInterfaceType(expr.X.Type()) itfTypeNum, _ := c.analysis.TypeNum(expr.X.Type())
itf := llvm.ConstNamedStruct(c.mod.GetTypeByName("interface"), []llvm.Value{itfTypeNum, llvm.Undef(c.i8ptrType)}) itf := llvm.ConstNamedStruct(c.mod.GetTypeByName("interface"), []llvm.Value{llvm.ConstInt(llvm.Int32Type(), uint64(itfTypeNum), false), llvm.Undef(c.i8ptrType)})
itf = c.builder.CreateInsertValue(itf, itfValue, 1, "") itf = c.builder.CreateInsertValue(itf, itfValue, 1, "")
return itf, nil return itf, nil
case *ssa.Phi: case *ssa.Phi:
@ -1280,7 +1438,11 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
assertedTypeNum := c.getInterfaceType(expr.AssertedType) assertedTypeNum, typeExists := c.analysis.TypeNum(expr.AssertedType)
if !typeExists {
// Static analysis has determined this type assert will never apply.
return llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.ConstInt(llvm.Int1Type(), 0, false)}, false), nil
}
actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type") actualTypeNum := c.builder.CreateExtractValue(itf, 0, "interface.type")
valuePtr := c.builder.CreateExtractValue(itf, 1, "interface.value") valuePtr := c.builder.CreateExtractValue(itf, 1, "interface.value")
var value llvm.Value var value llvm.Value
@ -1290,12 +1452,26 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
value = c.builder.CreateLoad(valuePtrCast, "") value = c.builder.CreateLoad(valuePtrCast, "")
} else { } else {
// Value was stored directly in the interface. // Value was stored directly in the interface.
// TODO: non-integer values. switch assertedType.TypeKind() {
value = c.builder.CreatePtrToInt(valuePtr, assertedType, "") case llvm.IntegerTypeKind:
value = c.builder.CreatePtrToInt(valuePtr, assertedType, "")
case llvm.PointerTypeKind:
value = c.builder.CreateBitCast(valuePtr, assertedType, "")
case llvm.StructTypeKind:
// A bitcast would be useful here, but bitcast doesn't allow
// aggregate types. So we'll bitcast it using an alloca.
// Hopefully this will get optimized away.
mem := c.builder.CreateAlloca(c.i8ptrType, "")
c.builder.CreateStore(valuePtr, mem)
memStructPtr := c.builder.CreateBitCast(mem, llvm.PointerType(assertedType, 0), "")
value = c.builder.CreateLoad(memStructPtr, "")
default:
return llvm.Value{}, errors.New("todo: typeassert: bitcast small types")
}
} }
// TODO: for interfaces, check whether the type implements the // TODO: for interfaces, check whether the type implements the
// interface. // interface.
commaOk := c.builder.CreateICmp(llvm.IntEQ, assertedTypeNum, actualTypeNum, "") commaOk := c.builder.CreateICmp(llvm.IntEQ, llvm.ConstInt(llvm.Int32Type(), uint64(assertedTypeNum), false), actualTypeNum, "")
tuple := llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.Undef(llvm.Int1Type())}, false) // create empty tuple tuple := llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.Undef(llvm.Int1Type())}, false) // create empty tuple
tuple = c.builder.CreateInsertValue(tuple, value, 0, "") // insert value tuple = c.builder.CreateInsertValue(tuple, value, 0, "") // insert value
tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "") // insert 'comma ok' boolean tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "") // insert 'comma ok' boolean