diff --git a/analysis.go b/analysis.go index 6ad79dbb..dfda38bd 100644 --- a/analysis.go +++ b/analysis.go @@ -7,125 +7,13 @@ import ( "golang.org/x/tools/go/ssa" ) -// Analysis results over a whole program. -type Analysis struct { - functions map[*ssa.Function]*FuncMeta - needsScheduler bool - goCalls []*ssa.Go - typesWithMethods map[string]*TypeMeta - typesWithoutMethods map[string]int - methodSignatureNames map[string]int -} - -// Some analysis results of a single function. -type FuncMeta struct { - f *ssa.Function - blocking bool - parents []*ssa.Function // calculated by AnalyseCallgraph - children []*ssa.Function -} - -type TypeMeta struct { - t types.Type - Num int - Methods map[string]*types.Selection -} - -// Return a new Analysis object. -func NewAnalysis() *Analysis { - return &Analysis{ - functions: make(map[*ssa.Function]*FuncMeta), - typesWithMethods: make(map[string]*TypeMeta), - typesWithoutMethods: make(map[string]int), - methodSignatureNames: make(map[string]int), - } -} - -// Add a given package to the analyzer, to be analyzed later. -func (a *Analysis) AddPackage(pkg *ssa.Package) { - for _, member := range pkg.Members { - switch member := member.(type) { - case *ssa.Function: - if isCGoInternal(member.Name()) || getCName(member.Name()) != "" { - continue - } - a.addFunction(member) - case *ssa.Type: - methods := getAllMethods(pkg.Prog, member.Type()) - if types.IsInterface(member.Type()) { - for _, method := range methods { - a.MethodName(method.Obj().(*types.Func)) - } - } else { // named type - for _, method := range methods { - a.addFunction(pkg.Prog.MethodValue(method)) - } - } - } - } -} - -// Analyze the given function quickly without any recursion, and add it to the -// list of functions in the analyzer. -func (a *Analysis) addFunction(f *ssa.Function) { - fm := &FuncMeta{} - for _, block := range f.Blocks { - for _, instr := range block.Instrs { - switch instr := instr.(type) { - case *ssa.Call: - if instr.Common().IsInvoke() { - name := a.MethodName(instr.Common().Method) - a.methodSignatureNames[name] = len(a.methodSignatureNames) - } else { - switch call := instr.Call.Value.(type) { - case *ssa.Builtin: - // ignore - case *ssa.Function: - if isCGoInternal(call.Name()) || getCName(call.Name()) != "" { - continue - } - name := getFunctionName(call, false) - if name == "runtime.Sleep" { - fm.blocking = true - } - fm.children = append(fm.children, call) - } - } - case *ssa.MakeInterface: - methods := getAllMethods(f.Prog, instr.X.Type()) - if _, ok := a.typesWithMethods[instr.X.Type().String()]; !ok && len(methods) > 0 { - meta := &TypeMeta{ - t: instr.X.Type(), - Num: len(a.typesWithMethods), - Methods: make(map[string]*types.Selection), - } - for _, sel := range methods { - name := a.MethodName(sel.Obj().(*types.Func)) - meta.Methods[name] = sel - } - a.typesWithMethods[instr.X.Type().String()] = meta - } else if _, ok := a.typesWithoutMethods[instr.X.Type().String()]; !ok && len(methods) == 0 { - a.typesWithoutMethods[instr.X.Type().String()] = len(a.typesWithoutMethods) - } - case *ssa.Go: - a.goCalls = append(a.goCalls, instr) - } - } - } - a.functions[f] = fm - - for _, child := range f.AnonFuncs { - a.addFunction(child) - } -} - // Make a readable version of the method signature (including the function name, // excluding the receiver name). This string is used internally to match // interfaces and to call the correct method on an interface. Examples: // // String() string // Read([]byte) (int, error) -func (a *Analysis) MethodName(method *types.Func) string { +func MethodName(method *types.Func) string { sig := method.Type().(*types.Signature) name := method.Name() if sig.Params().Len() == 0 { @@ -161,15 +49,75 @@ func (a *Analysis) MethodName(method *types.Func) string { // // All packages need to be added before this pass can run, or it will produce // incorrect results. -func (a *Analysis) AnalyseCallgraph() { - for f, fm := range a.functions { - for _, child := range fm.children { - childRes, ok := a.functions[child] - if !ok { - println("child not found: " + child.Pkg.Pkg.Path() + "." + child.Name() + ", function: " + f.Name()) - continue +func (p *Program) AnalyseCallgraph() { + for _, f := range p.Functions { + // Clear, if AnalyseCallgraph has been called before. + f.children = nil + f.parents = nil + + for _, block := range f.fn.Blocks { + for _, instr := range block.Instrs { + switch instr := instr.(type) { + case *ssa.Call: + if instr.Common().IsInvoke() { + continue + } + switch call := instr.Call.Value.(type) { + case *ssa.Builtin: + // ignore + case *ssa.Function: + if isCGoInternal(call.Name()) { + continue + } + child := p.GetFunction(call) + if child.CName() != "" { + continue // assume non-blocking + } + if child.Name(false) == "runtime.Sleep" { + f.blocking = true + } + f.children = append(f.children, child) + } + } + } + } + } + for _, f := range p.Functions { + for _, child := range f.children { + child.parents = append(child.parents, f) + } + } +} + +// Find all types that are put in an interface. +func (p *Program) AnalyseInterfaceConversions() { + // Clear, if AnalyseTypes has been called before. + p.typesWithMethods = make(map[string]*InterfaceType) + p.typesWithoutMethods = make(map[string]int) + + for _, f := range p.Functions { + for _, block := range f.fn.Blocks { + for _, instr := range block.Instrs { + switch instr := instr.(type) { + case *ssa.MakeInterface: + methods := getAllMethods(f.fn.Prog, instr.X.Type()) + name := instr.X.Type().String() + if _, ok := p.typesWithMethods[name]; !ok && len(methods) > 0 { + t := &InterfaceType{ + t: instr.X.Type(), + Num: len(p.typesWithMethods), + Methods: make(map[string]*types.Selection), + } + for _, sel := range methods { + name := MethodName(sel.Obj().(*types.Func)) + t.Methods[name] = sel + } + p.typesWithMethods[instr.X.Type().String()] = t + } else if _, ok := p.typesWithoutMethods[name]; !ok && len(methods) == 0 { + p.typesWithoutMethods[name] = len(p.typesWithoutMethods) + } + } } - childRes.parents = append(childRes.parents, f) } } } @@ -177,13 +125,13 @@ func (a *Analysis) AnalyseCallgraph() { // Analyse which functions are recursively blocking. // // Depends on AnalyseCallgraph. -func (a *Analysis) AnalyseBlockingRecursive() { - worklist := make([]*FuncMeta, 0) +func (p *Program) AnalyseBlockingRecursive() { + worklist := make([]*Function, 0) // Fill worklist with directly blocking functions. - for _, fm := range a.functions { - if fm.blocking { - worklist = append(worklist, fm) + for _, f := range p.Functions { + if f.blocking { + worklist = append(worklist, f) } } @@ -193,13 +141,12 @@ func (a *Analysis) AnalyseBlockingRecursive() { // The work items are then grey objects. for len(worklist) != 0 { // Pick the topmost. - fm := worklist[len(worklist)-1] + f := worklist[len(worklist)-1] worklist = worklist[:len(worklist)-1] - for _, parent := range fm.parents { - parentfm := a.functions[parent] - if !parentfm.blocking { - parentfm.blocking = true - worklist = append(worklist, parentfm) + for _, parent := range f.parents { + if !parent.blocking { + parent.blocking = true + worklist = append(worklist, parent) } } } @@ -210,10 +157,27 @@ func (a *Analysis) AnalyseBlockingRecursive() { // function can be turned into a regular function call). // // Depends on AnalyseBlockingRecursive. -func (a *Analysis) AnalyseGoCalls() { - for _, instr := range a.goCalls { - if a.isBlocking(instr.Call.Value) { - a.needsScheduler = true +func (p *Program) AnalyseGoCalls() { + p.goCalls = nil + for _, f := range p.Functions { + for _, block := range f.fn.Blocks { + for _, instr := range block.Instrs { + switch instr := instr.(type) { + case *ssa.Go: + p.goCalls = append(p.goCalls, instr) + } + } + } + } + for _, instr := range p.goCalls { + switch instr := instr.Call.Value.(type) { + case *ssa.Builtin: + case *ssa.Function: + if p.functionMap[instr].blocking { + p.needsScheduler = true + } + default: + panic("unknown go call function type") } } } @@ -221,30 +185,19 @@ func (a *Analysis) AnalyseGoCalls() { // Whether this function needs a scheduler. // // Depends on AnalyseGoCalls. -func (a *Analysis) NeedsScheduler() bool { - return a.needsScheduler +func (p *Program) NeedsScheduler() bool { + return p.needsScheduler } // Whether this function blocks. Builtins are also accepted for convenience. // They will always be non-blocking. // // Depends on AnalyseBlockingRecursive. -func (a *Analysis) IsBlocking(f ssa.Value) bool { - if !a.needsScheduler { +func (p *Program) IsBlocking(f *Function) bool { + if !p.needsScheduler { return false } - return a.isBlocking(f) -} - -func (a *Analysis) isBlocking(f ssa.Value) bool { - switch f := f.(type) { - case *ssa.Builtin: - return false - case *ssa.Function: - return a.functions[f].blocking - default: - panic("Analysis.IsBlocking on unknown type") - } + return f.blocking } // Return the type number and whether this type is actually used. Used in @@ -252,11 +205,11 @@ func (a *Analysis) isBlocking(f ssa.Value) bool { // used, meaning assert is always false in this program). // // May only be used after all packages have been added to the analyser. -func (a *Analysis) TypeNum(typ types.Type) (int, bool) { - if n, ok := a.typesWithoutMethods[typ.String()]; ok { +func (p *Program) TypeNum(typ types.Type) (int, bool) { + if n, ok := p.typesWithoutMethods[typ.String()]; ok { return n, true - } else if meta, ok := a.typesWithMethods[typ.String()]; ok { - return len(a.typesWithoutMethods) + meta.Num, true + } else if meta, ok := p.typesWithMethods[typ.String()]; ok { + return len(p.typesWithoutMethods) + meta.Num, true } else { return -1, false // type is never put in an interface } @@ -264,11 +217,12 @@ func (a *Analysis) TypeNum(typ types.Type) (int, bool) { // MethodNum returns the numeric ID of this method, to be used in method lookups // on interfaces for example. -func (a *Analysis) MethodNum(method *types.Func) int { - if n, ok := a.methodSignatureNames[a.MethodName(method)]; ok { - return n +func (p *Program) MethodNum(method *types.Func) int { + name := MethodName(method) + if _, ok := p.methodSignatureNames[name]; !ok { + p.methodSignatureNames[name] = len(p.methodSignatureNames) } - return -1 // signal error + return p.methodSignatureNames[MethodName(method)] } // The start index of the first dynamic type that has methods. @@ -276,14 +230,14 @@ func (a *Analysis) MethodNum(method *types.Func) int { // or a higher ID. // // May only be used after all packages have been added to the analyser. -func (a *Analysis) FirstDynamicType() int { - return len(a.typesWithoutMethods) +func (p *Program) FirstDynamicType() int { + return len(p.typesWithoutMethods) } // Return all types with methods, sorted by type ID. -func (a *Analysis) AllDynamicTypes() []*TypeMeta { - l := make([]*TypeMeta, len(a.typesWithMethods)) - for _, m := range a.typesWithMethods { +func (p *Program) AllDynamicTypes() []*InterfaceType { + l := make([]*InterfaceType, len(p.typesWithMethods)) + for _, m := range p.typesWithMethods { l[m.Num] = m } return l diff --git a/ir.go b/ir.go new file mode 100644 index 00000000..76dde3f3 --- /dev/null +++ b/ir.go @@ -0,0 +1,156 @@ +package main + +import ( + "go/types" + "sort" + "strings" + + "github.com/aykevl/llvm/bindings/go/llvm" + "golang.org/x/tools/go/ssa" +) + +// View on all functions, types, and globals in a program, with analysis +// results. +type Program struct { + Functions []*Function + functionMap map[*ssa.Function]*Function + Globals []*Global + globalMap map[*ssa.Global]*Global + NamedTypes []*NamedType + needsScheduler bool + goCalls []*ssa.Go + typesWithMethods map[string]*InterfaceType + typesWithoutMethods map[string]int + methodSignatureNames map[string]int +} + +// Function or method. +type Function struct { + fn *ssa.Function + llvmFn llvm.Value + blocking bool + parents []*Function // calculated by AnalyseCallgraph + children []*Function +} + +// Global variable, possibly constant. +type Global struct { + g *ssa.Global + llvmGlobal llvm.Value +} + +// Type with a name and possibly methods. +type NamedType struct { + t *ssa.Type +} + +// Type that is at some point put in an interface. +type InterfaceType struct { + t types.Type + Num int + Methods map[string]*types.Selection +} + +func NewProgram() *Program { + return &Program{ + functionMap: make(map[*ssa.Function]*Function), + globalMap: make(map[*ssa.Global]*Global), + methodSignatureNames: make(map[string]int), + } +} + +// Add a package to this Program. All packages need to be added first before any +// analysis is done for correct results. +func (p *Program) AddPackage(pkg *ssa.Package) { + memberNames := make([]string, 0) + for name := range pkg.Members { + if isCGoInternal(name) { + continue + } + memberNames = append(memberNames, name) + } + sort.Strings(memberNames) + + for _, name := range memberNames { + member := pkg.Members[name] + switch member := member.(type) { + case *ssa.Function: + if isCGoInternal(member.Name()) { + continue + } + p.addFunction(member) + case *ssa.Type: + t := &NamedType{t: member} + p.NamedTypes = append(p.NamedTypes, t) + methods := getAllMethods(pkg.Prog, member.Type()) + if !types.IsInterface(member.Type()) { + // named type + for _, method := range methods { + p.addFunction(pkg.Prog.MethodValue(method)) + } + } + case *ssa.Global: + g := &Global{g: member} + p.Globals = append(p.Globals, g) + p.globalMap[member] = g + } + } +} + +func (p *Program) addFunction(ssaFn *ssa.Function) { + f := &Function{fn: ssaFn} + p.Functions = append(p.Functions, f) + p.functionMap[ssaFn] = f +} + +func (p *Program) GetFunction(ssaFn *ssa.Function) *Function { + return p.functionMap[ssaFn] +} + +func (p *Program) GetGlobal(ssaGlobal *ssa.Global) *Global { + return p.globalMap[ssaGlobal] +} + +// Return the link name for this function. +func (f *Function) Name(blocking bool) string { + suffix := "" + if blocking { + suffix = "$async" + } + if f.fn.Signature.Recv() != nil { + // Method on a defined type (which may be a pointer). + return f.fn.RelString(nil) + suffix + } else { + // Bare function. + if name := f.CName(); name != "" { + // Name CGo functions directly. + return name + } else { + name := f.fn.RelString(nil) + suffix + if f.fn.Pkg.Pkg.Path() == "runtime" && strings.HasPrefix(f.fn.Name(), "_llvm_") { + // Special case for LLVM intrinsics in the runtime. + name = "llvm." + strings.Replace(f.fn.Name()[len("_llvm_"):], "_", ".", -1) + } + return name + } + } +} + +// Return the name of the C function if this is a CGo wrapper. Otherwise, return +// a zero-length string. +func (f *Function) CName() string { + name := f.fn.Name() + if strings.HasPrefix(name, "_Cfunc_") { + return name[len("_Cfunc_"):] + } + return "" +} + +// Return the link name for this global. +func (g *Global) Name() string { + if strings.HasPrefix(g.g.Name(), "_extern_") { + return g.g.Name()[len("_extern_"):] + } else { + return g.g.RelString(nil) + } +} diff --git a/tgo.go b/tgo.go index 1eb50836..8a445b83 100644 --- a/tgo.go +++ b/tgo.go @@ -10,7 +10,6 @@ import ( "go/token" "go/types" "os" - "sort" "strings" "github.com/aykevl/llvm/bindings/go/llvm" @@ -50,12 +49,11 @@ type Compiler struct { program *ssa.Program mainPkg *ssa.Package initFuncs []llvm.Value - analysis *Analysis + ir *Program } type Frame struct { - fn *ssa.Function - llvmFn llvm.Value + fn *Function params map[*ssa.Parameter]int // arguments to the function locals map[ssa.Value]llvm.Value // local variables blocks map[*ssa.BasicBlock]llvm.BasicBlock @@ -73,9 +71,9 @@ type Phi struct { func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) { c := &Compiler{ - dumpSSA: dumpSSA, - triple: triple, - analysis: NewAnalysis(), + dumpSSA: dumpSSA, + triple: triple, + ir: NewProgram(), } target, err := llvm.GetTargetFromTriple(triple) @@ -208,15 +206,91 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { } for _, pkg := range packageList { - c.analysis.AddPackage(pkg) + c.ir.AddPackage(pkg) } - c.analysis.AnalyseCallgraph() // set up callgraph - c.analysis.AnalyseBlockingRecursive() // make all parents of blocking calls blocking (transitively) - c.analysis.AnalyseGoCalls() // check whether we need a scheduler + c.ir.AnalyseCallgraph() // set up callgraph + c.ir.AnalyseInterfaceConversions() // determine which types are converted to an interface + c.ir.AnalyseBlockingRecursive() // make all parents of blocking calls blocking (transitively) + c.ir.AnalyseGoCalls() // check whether we need a scheduler - // Transform each package into LLVM IR. - for _, pkg := range packageList { - err := c.parsePackage(pkg) + var frames []*Frame + + // Declare all named (struct) types. + for _, t := range c.ir.NamedTypes { + if named, ok := t.t.Type().(*types.Named); ok { + if st, ok := named.Underlying().(*types.Struct); ok { + llvmType, err := c.getLLVMType(st) + if err != nil { + return err + } + llvmNamedType := c.ctx.StructCreateNamed(named.Obj().Pkg().Path() + "." + named.Obj().Name()) + llvmNamedType.StructSetBody(llvmType.StructElementTypes(), false) + } + } + } + + // Declare all globals. These will get an initializer when parsing "package + // initializer" packages. + for _, g := range c.ir.Globals { + typ := g.g.Type() + if typPtr, ok := typ.(*types.Pointer); ok { + typ = typPtr.Elem() + } else { + return errors.New("global is not a pointer") + } + llvmType, err := c.getLLVMType(typ) + if err != nil { + return err + } + global := llvm.AddGlobal(c.mod, llvmType, g.Name()) + g.llvmGlobal = global + if !strings.HasPrefix(g.Name(), "_extern_") { + global.SetLinkage(llvm.PrivateLinkage) + if g.Name() == "runtime.TargetBits" { + bitness := c.targetData.PointerSize() * 8 + if bitness < 32 { + // Only 8 and 32+ architectures supported at the moment. + // On 8 bit architectures, pointers are normally bigger + // than 8 bits to do anything meaningful. + // TODO: clean up this hack to support 16-bit + // architectures. + bitness = 8 + } + global.SetInitializer(llvm.ConstInt(llvm.Int8Type(), uint64(bitness), false)) + global.SetGlobalConstant(true) + } else { + initializer, err := getZeroValue(llvmType) + if err != nil { + return err + } + global.SetInitializer(initializer) + } + } + } + + // Declare all functions. + for _, f := range c.ir.Functions { + frame, err := c.parseFuncDecl(f) + if err != nil { + return err + } + frames = append(frames, frame) + } + + // Add definitions to declarations. + for _, frame := range frames { + if frame.fn.CName() != "" { + continue + } + if frame.fn.fn.Blocks == nil { + continue // external function + } + var err error + if frame.fn.fn.Synthetic == "package initializer" { + err = c.parseInitFunc(frame) + } else { + err = c.parseFunc(frame) + } if err != nil { return err } @@ -254,13 +328,13 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { c.mod.NamedFunction("runtime.scheduler").SetLinkage(llvm.PrivateLinkage) // Only use a scheduler when necessary. - if c.analysis.NeedsScheduler() { + if c.ir.NeedsScheduler() { // Enable the scheduler. c.mod.NamedGlobal("has_scheduler").SetInitializer(llvm.ConstInt(llvm.Int1Type(), 1, false)) } // Initialize runtime type information, for interfaces. - dynamicTypes := c.analysis.AllDynamicTypes() + dynamicTypes := c.ir.AllDynamicTypes() numDynamicTypes := 0 for _, meta := range dynamicTypes { numDynamicTypes += len(meta.Methods) @@ -278,14 +352,13 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { tuple := llvm.ConstNamedStruct(tupleType, tupleValues) tuples = append(tuples, tuple) for _, method := range meta.Methods { - fnName := getFunctionName(c.program.MethodValue(method), false) - llvmFn := c.mod.NamedFunction(fnName) - if llvmFn.IsNil() { - return errors.New("cannot find function: " + fnName) + f := c.ir.GetFunction(c.program.MethodValue(method)) + if f.llvmFn.IsNil() { + return errors.New("cannot find function: " + f.Name(false)) } - fn := llvm.ConstBitCast(llvmFn, c.i8ptrType) + fn := llvm.ConstBitCast(f.llvmFn, c.i8ptrType) funcPointers = append(funcPointers, fn) - signatureNum := c.analysis.MethodNum(method.Obj().(*types.Func)) + signatureNum := c.ir.MethodNum(method.Obj().(*types.Func)) signature := llvm.ConstInt(llvm.Int32Type(), uint64(signatureNum), false) signatures = append(signatures, signature) } @@ -314,7 +387,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error { signatureArrayOldGlobal.EraseFromParentAsGlobal() signatureArrayNewGlobal.SetName("interface_signatures") - c.mod.NamedGlobal("first_interface_num").SetInitializer(llvm.ConstInt(llvm.Int32Type(), uint64(c.analysis.FirstDynamicType()), false)) + c.mod.NamedGlobal("first_interface_num").SetInitializer(llvm.ConstInt(llvm.Int32Type(), uint64(c.ir.FirstDynamicType()), false)) return nil } @@ -489,38 +562,6 @@ func getAllMethods(prog *ssa.Program, typ types.Type) []*types.Selection { return methods } -func getFunctionName(fn *ssa.Function, blocking bool) string { - suffix := "" - if blocking { - suffix = "$async" - } - if fn.Signature.Recv() != nil { - // Method on a defined type (which may be a pointer). - return fn.RelString(nil) + suffix - } else { - // Bare function. - if name := getCName(fn.Name()); name != "" { - // Name CGo functions directly. - return name - } else { - name := fn.RelString(nil) + suffix - if fn.Pkg.Pkg.Path() == "runtime" && strings.HasPrefix(fn.Name(), "_llvm_") { - // Special case for LLVM intrinsics in the runtime. - name = "llvm." + strings.Replace(fn.Name()[len("_llvm_"):], "_", ".", -1) - } - return name - } - } -} - -func getGlobalName(global *ssa.Global) string { - if strings.HasPrefix(global.Name(), "_extern_") { - return global.Name()[len("_extern_"):] - } else { - return global.RelString(nil) - } -} - // Return true if this is a CGo-internal function that can be ignored. func isCGoInternal(name string) bool { if strings.HasPrefix(name, "_Cgo_") || strings.HasPrefix(name, "_cgo") { @@ -533,180 +574,26 @@ func isCGoInternal(name string) bool { return false } -// Return the name of the C function if this is a CGo call. Otherwise, return a -// zero-length string. -func getCName(name string) string { - if strings.HasPrefix(name, "_Cfunc_") { - return name[len("_Cfunc_"):] - } - return "" -} - -func (c *Compiler) parsePackage(pkg *ssa.Package) error { - // Make sure we're walking through all members in a constant order every - // run, and skip cgo wrapper functions/globals which we don't need. - memberNames := make([]string, 0) - for name := range pkg.Members { - if isCGoInternal(name) { - continue - } - memberNames = append(memberNames, name) - } - sort.Strings(memberNames) - - frames := make(map[*ssa.Function]*Frame) - - // First, declare all named (struct) types. - for _, name := range memberNames { - member := pkg.Members[name] - - switch member := member.(type) { - case *ssa.Type: - if named, ok := member.Type().(*types.Named); ok { - if st, ok := named.Underlying().(*types.Struct); ok { - llvmType, err := c.getLLVMType(st) - if err != nil { - return err - } - llvmNamedType := c.ctx.StructCreateNamed(named.Obj().Pkg().Path() + "." + named.Obj().Name()) - llvmNamedType.StructSetBody(llvmType.StructElementTypes(), false) - } - } - } - } - - // With the types defined, build all function declarations. - for _, name := range memberNames { - member := pkg.Members[name] - - switch member := member.(type) { - case *ssa.Function: - frame, err := c.parseFuncDecl(member) - if err != nil { - return err - } - frames[member] = frame - if member.Synthetic == "package initializer" { - c.initFuncs = append(c.initFuncs, frame.llvmFn) - } - // TODO: recursively anonymous functions - for _, child := range member.AnonFuncs { - frame, err := c.parseFuncDecl(child) - if err != nil { - return err - } - frames[child] = frame - } - case *ssa.NamedConst: - // Ignore package-level untyped constants. The SSA form doesn't need - // them. - case *ssa.Global: - typ := member.Type() - if typPtr, ok := typ.(*types.Pointer); ok { - typ = typPtr.Elem() - } else { - return errors.New("global is not a pointer") - } - llvmType, err := c.getLLVMType(typ) - if err != nil { - return err - } - global := llvm.AddGlobal(c.mod, llvmType, getGlobalName(member)) - if !strings.HasPrefix(member.Name(), "_extern_") { - global.SetLinkage(llvm.PrivateLinkage) - if getGlobalName(member) == "runtime.TargetBits" { - bitness := c.targetData.PointerSize() * 8 - if bitness < 32 { - // Only 8 and 32+ architectures supported at the moment. - // On 8 bit architectures, pointers are normally bigger - // than 8 bits to do anything meaningful. - // TODO: clean up this hack to support 16-bit - // architectures. - bitness = 8 - } - global.SetInitializer(llvm.ConstInt(llvm.Int8Type(), uint64(bitness), false)) - global.SetGlobalConstant(true) - } else { - initializer, err := getZeroValue(llvmType) - if err != nil { - return err - } - global.SetInitializer(initializer) - } - } - case *ssa.Type: - if !types.IsInterface(member.Type()) { - for _, sel := range getAllMethods(c.program, member.Type()) { - fn := c.program.MethodValue(sel) - frame, err := c.parseFuncDecl(fn) - if err != nil { - return err - } - frames[fn] = frame - } - } - default: - return errors.New("todo: member: " + fmt.Sprintf("%#v", member)) - } - } - - // Now, add definitions to those declarations. - for _, name := range memberNames { - member := pkg.Members[name] - switch member := member.(type) { - case *ssa.Function: - if getCName(name) != "" { - // CGo function. Don't implement it's body. - continue - } - if member.Blocks == nil { - continue // external function - } - var err error - if member.Synthetic == "package initializer" { - err = c.parseInitFunc(frames[member], member) - } else { - err = c.parseFunc(frames[member], member) - } - if err != nil { - return err - } - case *ssa.Type: - if !types.IsInterface(member.Type()) { - for _, sel := range getAllMethods(c.program, member.Type()) { - fn := c.program.MethodValue(sel) - err := c.parseFunc(frames[fn], fn) - if err != nil { - return err - } - } - } - } - } - - return nil -} - -func (c *Compiler) parseFuncDecl(f *ssa.Function) (*Frame, error) { +func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) { frame := &Frame{ fn: f, params: make(map[*ssa.Parameter]int), locals: make(map[ssa.Value]llvm.Value), blocks: make(map[*ssa.BasicBlock]llvm.BasicBlock), - blocking: c.analysis.IsBlocking(f), + blocking: c.ir.IsBlocking(f), } var retType llvm.Type if frame.blocking { - if f.Signature.Results() != nil { + if f.fn.Signature.Results() != nil { return nil, errors.New("todo: return values in blocking function") } retType = c.i8ptrType - } else if f.Signature.Results() == nil { + } else if f.fn.Signature.Results() == nil { retType = llvm.VoidType() - } else if f.Signature.Results().Len() == 1 { + } else if f.fn.Signature.Results().Len() == 1 { var err error - retType, err = c.getLLVMType(f.Signature.Results().At(0).Type()) + retType, err = c.getLLVMType(f.fn.Signature.Results().At(0).Type()) if err != nil { return nil, err } @@ -718,7 +605,7 @@ func (c *Compiler) parseFuncDecl(f *ssa.Function) (*Frame, error) { if frame.blocking { paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine } - for i, param := range f.Params { + for i, param := range f.fn.Params { paramType, err := c.getLLVMType(param.Type()) if err != nil { return nil, err @@ -729,22 +616,22 @@ func (c *Compiler) parseFuncDecl(f *ssa.Function) (*Frame, error) { fnType := llvm.FunctionType(retType, paramTypes, false) - name := getFunctionName(f, frame.blocking) - frame.llvmFn = c.mod.NamedFunction(name) - if frame.llvmFn.IsNil() { - frame.llvmFn = llvm.AddFunction(c.mod, name, fnType) + name := f.Name(frame.blocking) + frame.fn.llvmFn = c.mod.NamedFunction(name) + if frame.fn.llvmFn.IsNil() { + frame.fn.llvmFn = llvm.AddFunction(c.mod, name, fnType) } return frame, nil } // Special function parser for generated package initializers (which also // initializes global variables). -func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { - frame.llvmFn.SetLinkage(llvm.PrivateLinkage) - llvmBlock := c.ctx.AddBasicBlock(frame.llvmFn, "entry") +func (c *Compiler) parseInitFunc(frame *Frame) error { + frame.fn.llvmFn.SetLinkage(llvm.PrivateLinkage) + llvmBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "entry") c.builder.SetInsertPointAtEnd(llvmBlock) - for _, block := range f.DomPreorder() { + for _, block := range frame.fn.fn.DomPreorder() { for _, instr := range block.Instrs { var err error switch instr := instr.(type) { @@ -766,7 +653,7 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { if err != nil { return err } - llvmAddr := c.mod.NamedGlobal(getGlobalName(addr)) + llvmAddr := c.ir.GetGlobal(addr).llvmGlobal llvmAddr.SetInitializer(val) case *ssa.FieldAddr: // Initialize field of a global struct. @@ -779,7 +666,7 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { return err } global := addr.X.(*ssa.Global) - llvmAddr := c.mod.NamedGlobal(getGlobalName(global)) + llvmAddr := c.ir.GetGlobal(global).llvmGlobal llvmValue := llvmAddr.Initializer() if llvmValue.IsNil() { llvmValue, err = getZeroValue(llvmAddr.Type().ElementType()) @@ -801,7 +688,7 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { } fieldAddr := addr.X.(*ssa.FieldAddr) global := fieldAddr.X.(*ssa.Global) - llvmAddr := c.mod.NamedGlobal(getGlobalName(global)) + llvmAddr := c.ir.GetGlobal(global).llvmGlobal llvmValue := llvmAddr.Initializer() if llvmValue.IsNil() { llvmValue, err = getZeroValue(llvmAddr.Type().ElementType()) @@ -827,31 +714,31 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { return nil } -func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error { +func (c *Compiler) parseFunc(frame *Frame) error { if c.dumpSSA { - fmt.Printf("\nfunc %s:\n", f) + fmt.Printf("\nfunc %s:\n", frame.fn.fn) } - frame.llvmFn.SetLinkage(llvm.PrivateLinkage) + frame.fn.llvmFn.SetLinkage(llvm.PrivateLinkage) // Pre-create all basic blocks in the function. - for _, block := range f.DomPreorder() { - llvmBlock := c.ctx.AddBasicBlock(frame.llvmFn, block.Comment) + for _, block := range frame.fn.fn.DomPreorder() { + llvmBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, block.Comment) frame.blocks[block] = llvmBlock } if frame.blocking { - frame.cleanupBlock = c.ctx.AddBasicBlock(frame.llvmFn, "task.cleanup") - frame.suspendBlock = c.ctx.AddBasicBlock(frame.llvmFn, "task.suspend") + frame.cleanupBlock = c.ctx.AddBasicBlock(frame.fn.llvmFn, "task.cleanup") + frame.suspendBlock = c.ctx.AddBasicBlock(frame.fn.llvmFn, "task.suspend") } // Load function parameters - for _, param := range f.Params { - llvmParam := frame.llvmFn.Param(frame.params[param]) + for _, param := range frame.fn.fn.Params { + llvmParam := frame.fn.llvmFn.Param(frame.params[param]) frame.locals[param] = llvmParam } if frame.blocking { // Coroutine initialization. - c.builder.SetInsertPointAtEnd(frame.blocks[f.Blocks[0]]) + c.builder.SetInsertPointAtEnd(frame.blocks[frame.fn.fn.Blocks[0]]) taskState := c.builder.CreateAlloca(c.mod.GetTypeByName("runtime.taskState"), "task.state") stateI8 := c.builder.CreateBitCast(taskState, c.i8ptrType, "task.state.i8") id := c.builder.CreateCall(c.coroIdFunc, []llvm.Value{ @@ -874,7 +761,7 @@ func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error { mem := c.builder.CreateCall(c.coroFreeFunc, []llvm.Value{id, frame.taskHandle}, "task.data.free") c.builder.CreateCall(c.freeFunc, []llvm.Value{mem}, "") // re-insert parent coroutine - c.builder.CreateCall(c.mod.NamedFunction("runtime.scheduleTask"), []llvm.Value{frame.llvmFn.FirstParam()}, "") + c.builder.CreateCall(c.mod.NamedFunction("runtime.scheduleTask"), []llvm.Value{frame.fn.llvmFn.FirstParam()}, "") c.builder.CreateBr(frame.suspendBlock) // Coroutine suspend. A call to llvm.coro.suspend() will branch here. @@ -884,7 +771,7 @@ func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error { } // Fill blocks with instructions. - for _, block := range f.DomPreorder() { + for _, block := range frame.fn.fn.DomPreorder() { if c.dumpSSA { fmt.Printf("%s:\n", block.Comment) } @@ -933,7 +820,7 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { // Execute non-blocking calls (including builtins) directly. // parentHandle param is ignored. - if !c.analysis.IsBlocking(instr.Common().Value) { + if !c.ir.IsBlocking(c.ir.GetFunction(instr.Common().Value.(*ssa.Function))) { _, err := c.parseCall(frame, instr.Common(), llvm.Value{}) return err // probably nil } @@ -1174,7 +1061,7 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l } values := []llvm.Value{ itf, - llvm.ConstInt(llvm.Int32Type(), uint64(c.analysis.MethodNum(instr.Method)), false), + llvm.ConstInt(llvm.Int32Type(), uint64(c.ir.MethodNum(instr.Method)), false), } fn := c.builder.CreateCall(c.mod.NamedFunction("itfmethod"), values, "invoke.func") fnCast := c.builder.CreateBitCast(fn, llvmFnType, "invoke.func.cast") @@ -1206,11 +1093,11 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l } } targetBlocks := false - name := getFunctionName(call, targetBlocks) + name := c.ir.GetFunction(call).Name(targetBlocks) llvmFn := c.mod.NamedFunction(name) if llvmFn.IsNil() { targetBlocks = true - nameAsync := getFunctionName(call, targetBlocks) + nameAsync := c.ir.GetFunction(call).Name(targetBlocks) llvmFn = c.mod.NamedFunction(nameAsync) if llvmFn.IsNil() { return llvm.Value{}, errors.New("undefined function: " + name) @@ -1297,12 +1184,11 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { } return c.builder.CreateGEP(val, indices, ""), nil case *ssa.Function: - return c.mod.NamedFunction(getFunctionName(expr, false)), nil + return c.mod.NamedFunction(c.ir.GetFunction(expr).Name(false)), nil case *ssa.Global: - fullName := getGlobalName(expr) - value := c.mod.NamedGlobal(fullName) + value := c.ir.GetGlobal(expr).llvmGlobal if value.IsNil() { - return llvm.Value{}, errors.New("global not found: " + fullName) + return llvm.Value{}, errors.New("global not found: " + c.ir.GetGlobal(expr).Name()) } return value, nil case *ssa.IndexAddr: @@ -1365,7 +1251,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { // Bounds check. // LLVM optimizes this away in most cases. - if frame.llvmFn.Name() != "runtime.boundsCheck" { + if frame.fn.llvmFn.Name() != "runtime.boundsCheck" { constZero := llvm.ConstInt(c.intType, 0, false) isNegative := c.builder.CreateICmp(llvm.IntSLT, index, constZero, "") // index < 0 strlen, err := c.parseBuiltin(frame, []ssa.Value{expr.X}, "len") @@ -1414,7 +1300,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { return llvm.Value{}, errors.New("todo: makeinterface: cast small type to i8*") } } - itfTypeNum, _ := c.analysis.TypeNum(expr.X.Type()) + itfTypeNum, _ := c.ir.TypeNum(expr.X.Type()) itf := llvm.ConstNamedStruct(c.mod.GetTypeByName("interface"), []llvm.Value{llvm.ConstInt(llvm.Int32Type(), uint64(itfTypeNum), false), llvm.Undef(c.i8ptrType)}) itf = c.builder.CreateInsertValue(itf, itfValue, 1, "") return itf, nil @@ -1438,7 +1324,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if err != nil { return llvm.Value{}, err } - assertedTypeNum, typeExists := c.analysis.TypeNum(expr.AssertedType) + assertedTypeNum, typeExists := c.ir.TypeNum(expr.AssertedType) if !typeExists { // Static analysis has determined this type assert will never apply. return llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.ConstInt(llvm.Int1Type(), 0, false)}, false), nil @@ -1473,8 +1359,8 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { // interface. commaOk := c.builder.CreateICmp(llvm.IntEQ, llvm.ConstInt(llvm.Int32Type(), uint64(assertedTypeNum), false), actualTypeNum, "") tuple := llvm.ConstStruct([]llvm.Value{llvm.Undef(assertedType), llvm.Undef(llvm.Int1Type())}, false) // create empty tuple - tuple = c.builder.CreateInsertValue(tuple, value, 0, "") // insert value - tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "") // insert 'comma ok' boolean + tuple = c.builder.CreateInsertValue(tuple, value, 0, "") // insert value + tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "") // insert 'comma ok' boolean return tuple, nil case *ssa.UnOp: return c.parseUnOp(frame, expr) @@ -1682,7 +1568,7 @@ func (c *Compiler) parseUnOp(frame *Frame, unop *ssa.UnOp) (llvm.Value, error) { // Magic type name: treat the value as a register pointer. register := unop.X.(*ssa.FieldAddr) global := register.X.(*ssa.Global) - llvmGlobal := c.mod.NamedGlobal(getGlobalName(global)) + llvmGlobal := c.ir.GetGlobal(global).llvmGlobal llvmAddr := c.builder.CreateExtractValue(llvmGlobal.Initializer(), register.Field, "") ptr := llvm.ConstIntToPtr(llvmAddr, x.Type()) load := c.builder.CreateLoad(ptr, "")