compiler: refactor IR parts into separate package

Этот коммит содержится в:
Ayke van Laethem 2018-09-22 20:25:50 +02:00
родитель 473e71b573
коммит b75a02e66d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: E97FF5335DFDFDED
5 изменённых файлов: 182 добавлений и 169 удалений

Просмотреть файл

@ -54,7 +54,7 @@ clean:
@rm -rf build @rm -rf build
fmt: fmt:
@go fmt . ./src/examples/* ./src/machine ./src/runtime ./src/sync @go fmt . ./ir ./src/examples/* ./src/machine ./src/runtime ./src/sync
gen-device: gen-device-nrf gen-device-avr gen-device: gen-device-nrf gen-device-avr

Просмотреть файл

@ -14,6 +14,7 @@ import (
"strings" "strings"
"github.com/aykevl/llvm/bindings/go/llvm" "github.com/aykevl/llvm/bindings/go/llvm"
"github.com/aykevl/tinygo/ir"
"go/parser" "go/parser"
"golang.org/x/tools/go/loader" "golang.org/x/tools/go/loader"
"golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa"
@ -53,12 +54,12 @@ type Compiler struct {
coroEndFunc llvm.Value coroEndFunc llvm.Value
coroFreeFunc llvm.Value coroFreeFunc llvm.Value
initFuncs []llvm.Value initFuncs []llvm.Value
deferFuncs []*Function deferFuncs []*ir.Function
ir *Program ir *ir.Program
} }
type Frame struct { type Frame struct {
fn *Function fn *ir.Function
params map[*ssa.Parameter]int // arguments to the function params map[*ssa.Parameter]int // arguments to the function
locals map[ssa.Value]llvm.Value // local variables locals map[ssa.Value]llvm.Value // local variables
blocks map[*ssa.BasicBlock]llvm.BasicBlock blocks map[*ssa.BasicBlock]llvm.BasicBlock
@ -77,8 +78,6 @@ type Phi struct {
llvm llvm.Value llvm llvm.Value
} }
var cgoWrapperError = errors.New("tinygo internal: cgo wrapper")
func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) { func NewCompiler(pkgName, triple string, dumpSSA bool) (*Compiler, error) {
c := &Compiler{ c := &Compiler{
dumpSSA: dumpSSA, dumpSSA: dumpSSA,
@ -178,7 +177,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
} }
} }
c.ir = NewProgram(lprogram, mainPath) c.ir = ir.NewProgram(lprogram, mainPath)
// Make a list of packages in import order. // Make a list of packages in import order.
packageList := []*ssa.Package{} packageList := []*ssa.Package{}
@ -186,7 +185,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
worklist := []string{"runtime", mainPath} worklist := []string{"runtime", mainPath}
for len(worklist) != 0 { for len(worklist) != 0 {
pkgPath := worklist[0] pkgPath := worklist[0]
pkg := c.ir.program.ImportedPackage(pkgPath) pkg := c.ir.Program.ImportedPackage(pkgPath)
if pkg == nil { if pkg == nil {
// Non-SSA package (e.g. cgo). // Non-SSA package (e.g. cgo).
packageSet[pkgPath] = struct{}{} packageSet[pkgPath] = struct{}{}
@ -244,22 +243,22 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
// Declare all named struct types. // Declare all named struct types.
for _, t := range c.ir.NamedTypes { for _, t := range c.ir.NamedTypes {
if named, ok := t.t.Type().(*types.Named); ok { if named, ok := t.Type.Type().(*types.Named); ok {
if _, ok := named.Underlying().(*types.Struct); ok { if _, ok := named.Underlying().(*types.Struct); ok {
t.llvmType = c.ctx.StructCreateNamed(named.Obj().Pkg().Path() + "." + named.Obj().Name()) t.LLVMType = c.ctx.StructCreateNamed(named.Obj().Pkg().Path() + "." + named.Obj().Name())
} }
} }
} }
// Define all named struct types. // Define all named struct types.
for _, t := range c.ir.NamedTypes { for _, t := range c.ir.NamedTypes {
if named, ok := t.t.Type().(*types.Named); ok { if named, ok := t.Type.Type().(*types.Named); ok {
if st, ok := named.Underlying().(*types.Struct); ok { if st, ok := named.Underlying().(*types.Struct); ok {
llvmType, err := c.getLLVMType(st) llvmType, err := c.getLLVMType(st)
if err != nil { if err != nil {
return err return err
} }
t.llvmType.StructSetBody(llvmType.StructElementTypes(), false) t.LLVMType.StructSetBody(llvmType.StructElementTypes(), false)
} }
} }
} }
@ -267,7 +266,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
// Declare all globals. These will get an initializer when parsing "package // Declare all globals. These will get an initializer when parsing "package
// initializer" functions. // initializer" functions.
for _, g := range c.ir.Globals { for _, g := range c.ir.Globals {
typ := g.g.Type().(*types.Pointer).Elem() typ := g.Type().(*types.Pointer).Elem()
llvmType, err := c.getLLVMType(typ) llvmType, err := c.getLLVMType(typ)
if err != nil { if err != nil {
return err return err
@ -276,7 +275,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
if global.IsNil() { if global.IsNil() {
global = llvm.AddGlobal(c.mod, llvmType, g.LinkName()) global = llvm.AddGlobal(c.mod, llvmType, g.LinkName())
} }
g.llvmGlobal = global g.LLVMGlobal = global
if !g.IsExtern() { if !g.IsExtern() {
global.SetLinkage(llvm.InternalLinkage) global.SetLinkage(llvm.InternalLinkage)
initializer, err := getZeroValue(llvmType) initializer, err := getZeroValue(llvmType)
@ -298,9 +297,9 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
// Find and interpret package initializers. // Find and interpret package initializers.
for _, frame := range frames { for _, frame := range frames {
if frame.fn.fn.Synthetic == "package initializer" { if frame.fn.Synthetic == "package initializer" {
c.initFuncs = append(c.initFuncs, frame.fn.llvmFn) c.initFuncs = append(c.initFuncs, frame.fn.LLVMFn)
if len(frame.fn.fn.Blocks) != 1 { if len(frame.fn.Blocks) != 1 {
panic("unexpected number of basic blocks in package initializer") panic("unexpected number of basic blocks in package initializer")
} }
// Try to interpret as much as possible of the init() function. // Try to interpret as much as possible of the init() function.
@ -309,7 +308,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
// continues at runtime). // continues at runtime).
// This should only happen when it hits a function call or the end // This should only happen when it hits a function call or the end
// of the block, ideally. // of the block, ideally.
err := c.ir.Interpret(frame.fn.fn.Blocks[0], c.dumpSSA) err := c.ir.Interpret(frame.fn.Blocks[0], c.dumpSSA)
if err != nil { if err != nil {
return err return err
} }
@ -322,7 +321,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
// Set values for globals (after package initializer has been interpreted). // Set values for globals (after package initializer has been interpreted).
for _, g := range c.ir.Globals { for _, g := range c.ir.Globals {
if g.initializer == nil { if g.Initializer() == nil {
continue continue
} }
err := c.parseGlobalInitializer(g) err := c.parseGlobalInitializer(g)
@ -336,11 +335,11 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
if frame.fn.CName() != "" { if frame.fn.CName() != "" {
continue continue
} }
if frame.fn.fn.Blocks == nil { if frame.fn.Blocks == nil {
continue // external function continue // external function
} }
var err error var err error
if frame.fn.fn.Synthetic == "package initializer" { if frame.fn.Synthetic == "package initializer" {
continue // already done continue // already done
} else { } else {
err = c.parseFunc(frame) err = c.parseFunc(frame)
@ -364,7 +363,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
// Get the real param type and cast to it. // Get the real param type and cast to it.
valueTypes := []llvm.Type{llvmFn.Type(), llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)} valueTypes := []llvm.Type{llvmFn.Type(), llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)}
for _, param := range fn.fn.Params { for _, param := range fn.Params {
llvmType, err := c.getLLVMType(param.Type()) llvmType, err := c.getLLVMType(param.Type())
if err != nil { if err != nil {
return err return err
@ -377,14 +376,14 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
// Extract the params from the struct. // Extract the params from the struct.
forwardParams := []llvm.Value{} forwardParams := []llvm.Value{}
zero := llvm.ConstInt(llvm.Int32Type(), 0, false) zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
for i := range fn.fn.Params { for i := range fn.Params {
gep := c.builder.CreateGEP(contextPtr, []llvm.Value{zero, llvm.ConstInt(llvm.Int32Type(), uint64(i+2), false)}, "gep") gep := c.builder.CreateGEP(contextPtr, []llvm.Value{zero, llvm.ConstInt(llvm.Int32Type(), uint64(i+2), false)}, "gep")
forwardParam := c.builder.CreateLoad(gep, "param") forwardParam := c.builder.CreateLoad(gep, "param")
forwardParams = append(forwardParams, forwardParam) forwardParams = append(forwardParams, forwardParam)
} }
// Call real function (of which this is a wrapper). // Call real function (of which this is a wrapper).
c.builder.CreateCall(fn.llvmFn, forwardParams, "") c.builder.CreateCall(fn.LLVMFn, forwardParams, "")
c.builder.CreateRetVoid() c.builder.CreateRetVoid()
} }
@ -400,7 +399,7 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
c.builder.CreateRetVoid() c.builder.CreateRetVoid()
// Adjust main function. // Adjust main function.
realMain := c.mod.NamedFunction(c.ir.mainPkg.Pkg.Path() + ".main") realMain := c.mod.NamedFunction(c.ir.MainPkg().Pkg.Path() + ".main")
if c.ir.NeedsScheduler() { if c.ir.NeedsScheduler() {
c.mod.NamedFunction("runtime.main_mainAsync").ReplaceAllUsesWith(realMain) c.mod.NamedFunction("runtime.main_mainAsync").ReplaceAllUsesWith(realMain)
} else { } else {
@ -441,11 +440,11 @@ func (c *Compiler) Parse(mainPath string, buildTags []string) error {
} }
c.ir.SortMethods(methods) c.ir.SortMethods(methods)
for _, method := range methods { for _, method := range methods {
f := c.ir.GetFunction(c.ir.program.MethodValue(method)) f := c.ir.GetFunction(c.ir.Program.MethodValue(method))
if f.llvmFn.IsNil() { if f.LLVMFn.IsNil() {
return errors.New("cannot find function: " + f.LinkName()) return errors.New("cannot find function: " + f.LinkName())
} }
fn := llvm.ConstBitCast(f.llvmFn, c.i8ptrType) fn := llvm.ConstBitCast(f.LLVMFn, c.i8ptrType)
funcPointers = append(funcPointers, fn) funcPointers = append(funcPointers, fn)
signatureNum := c.ir.MethodNum(method.Obj().(*types.Func)) signatureNum := c.ir.MethodNum(method.Obj().(*types.Func))
signature := llvm.ConstInt(llvm.Int16Type(), uint64(signatureNum), false) signature := llvm.ConstInt(llvm.Int16Type(), uint64(signatureNum), false)
@ -780,29 +779,7 @@ func (c *Compiler) getDIType(typ types.Type) (llvm.Metadata, error) {
} }
} }
// Get all methods of a type. func (c *Compiler) parseFuncDecl(f *ir.Function) (*Frame, error) {
func getAllMethods(prog *ssa.Program, typ types.Type) []*types.Selection {
ms := prog.MethodSets.MethodSet(typ)
methods := make([]*types.Selection, ms.Len())
for i := 0; i < ms.Len(); i++ {
methods[i] = ms.At(i)
}
return methods
}
// Return true if this is a CGo-internal function that can be ignored.
func isCGoInternal(name string) bool {
if strings.HasPrefix(name, "_Cgo_") || strings.HasPrefix(name, "_cgo") {
// _Cgo_ptr, _Cgo_use, _cgoCheckResult, _cgo_runtime_cgocall
return true // CGo-internal functions
}
if strings.HasPrefix(name, "__cgofn__cgo_") {
return true // CGo function pointer in global scope
}
return false
}
func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) {
frame := &Frame{ frame := &Frame{
fn: f, fn: f,
params: make(map[*ssa.Parameter]int), params: make(map[*ssa.Parameter]int),
@ -813,22 +790,22 @@ func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) {
var retType llvm.Type var retType llvm.Type
if frame.blocking { if frame.blocking {
if f.fn.Signature.Results() != nil { if f.Signature.Results() != nil {
return nil, errors.New("todo: return values in blocking function") return nil, errors.New("todo: return values in blocking function")
} }
retType = c.i8ptrType retType = c.i8ptrType
} else if f.fn.Signature.Results() == nil { } else if f.Signature.Results() == nil {
retType = llvm.VoidType() retType = llvm.VoidType()
} else if f.fn.Signature.Results().Len() == 1 { } else if f.Signature.Results().Len() == 1 {
var err error var err error
retType, err = c.getLLVMType(f.fn.Signature.Results().At(0).Type()) retType, err = c.getLLVMType(f.Signature.Results().At(0).Type())
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
results := make([]llvm.Type, 0, f.fn.Signature.Results().Len()) results := make([]llvm.Type, 0, f.Signature.Results().Len())
for i := 0; i < f.fn.Signature.Results().Len(); i++ { for i := 0; i < f.Signature.Results().Len(); i++ {
typ, err := c.getLLVMType(f.fn.Signature.Results().At(i).Type()) typ, err := c.getLLVMType(f.Signature.Results().At(i).Type())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -841,7 +818,7 @@ func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) {
if frame.blocking { if frame.blocking {
paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine
} }
for i, param := range f.fn.Params { for i, param := range f.Params {
paramType, err := c.getLLVMType(param.Type()) paramType, err := c.getLLVMType(param.Type())
if err != nil { if err != nil {
return nil, err return nil, err
@ -859,22 +836,22 @@ func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) {
fnType := llvm.FunctionType(retType, paramTypes, false) fnType := llvm.FunctionType(retType, paramTypes, false)
name := f.LinkName() name := f.LinkName()
frame.fn.llvmFn = c.mod.NamedFunction(name) frame.fn.LLVMFn = c.mod.NamedFunction(name)
if frame.fn.llvmFn.IsNil() { if frame.fn.LLVMFn.IsNil() {
frame.fn.llvmFn = llvm.AddFunction(c.mod, name, fnType) frame.fn.LLVMFn = llvm.AddFunction(c.mod, name, fnType)
} }
if c.debug && f.fn.Syntax() != nil && len(f.fn.Blocks) != 0 { if c.debug && f.Syntax() != nil && len(f.Blocks) != 0 {
// Create debug info file if needed. // Create debug info file if needed.
pos := c.ir.program.Fset.Position(f.fn.Syntax().Pos()) pos := c.ir.Program.Fset.Position(f.Syntax().Pos())
if _, ok := c.difiles[pos.Filename]; !ok { if _, ok := c.difiles[pos.Filename]; !ok {
dir, file := filepath.Split(pos.Filename) dir, file := filepath.Split(pos.Filename)
c.difiles[pos.Filename] = c.dibuilder.CreateFile(file, dir[:len(dir)-1]) c.difiles[pos.Filename] = c.dibuilder.CreateFile(file, dir[:len(dir)-1])
} }
// Debug info for this function. // Debug info for this function.
diparams := make([]llvm.Metadata, 0, len(f.fn.Params)) diparams := make([]llvm.Metadata, 0, len(f.Params))
for _, param := range f.fn.Params { for _, param := range f.Params {
ditype, err := c.getDIType(param.Type()) ditype, err := c.getDIType(param.Type())
if err != nil { if err != nil {
return nil, err return nil, err
@ -887,7 +864,7 @@ func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) {
Flags: 0, // ? Flags: 0, // ?
}) })
frame.difunc = c.dibuilder.CreateFunction(c.difiles[pos.Filename], llvm.DIFunction{ frame.difunc = c.dibuilder.CreateFunction(c.difiles[pos.Filename], llvm.DIFunction{
Name: f.fn.RelString(nil), Name: f.RelString(nil),
LinkageName: f.LinkName(), LinkageName: f.LinkName(),
File: c.difiles[pos.Filename], File: c.difiles[pos.Filename],
Line: pos.Line, Line: pos.Line,
@ -898,7 +875,7 @@ func (c *Compiler) parseFuncDecl(f *Function) (*Frame, error) {
Flags: llvm.FlagPrototyped, Flags: llvm.FlagPrototyped,
Optimized: true, Optimized: true,
}) })
frame.fn.llvmFn.SetSubprogram(frame.difunc) frame.fn.LLVMFn.SetSubprogram(frame.difunc)
} }
return frame, nil return frame, nil
@ -932,24 +909,24 @@ func (c *Compiler) initMapNewBucket(prefix string, mapType *types.Map) (llvm.Val
return bucket, keySize, valueSize, nil return bucket, keySize, valueSize, nil
} }
func (c *Compiler) parseGlobalInitializer(g *Global) error { func (c *Compiler) parseGlobalInitializer(g *ir.Global) error {
if g.IsExtern() { if g.IsExtern() {
return nil return nil
} }
llvmValue, err := c.getInterpretedValue(g.LinkName(), g.initializer) llvmValue, err := c.getInterpretedValue(g.LinkName(), g.Initializer())
if err != nil { if err != nil {
return err return err
} }
g.llvmGlobal.SetInitializer(llvmValue) g.LLVMGlobal.SetInitializer(llvmValue)
return nil return nil
} }
// Turn a computed Value type (ConstValue, ArrayValue, etc.) into a LLVM value. // Turn a computed Value type (ConstValue, ArrayValue, etc.) into a LLVM value.
// This is used to set the initializer of globals after they have been // This is used to set the initializer of globals after they have been
// calculated by the package initializer interpreter. // calculated by the package initializer interpreter.
func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value, error) { func (c *Compiler) getInterpretedValue(prefix string, value ir.Value) (llvm.Value, error) {
switch value := value.(type) { switch value := value.(type) {
case *ArrayValue: case *ir.ArrayValue:
vals := make([]llvm.Value, len(value.Elems)) vals := make([]llvm.Value, len(value.Elems))
for i, elem := range value.Elems { for i, elem := range value.Elems {
val, err := c.getInterpretedValue(prefix+"$arrayval", elem) val, err := c.getInterpretedValue(prefix+"$arrayval", elem)
@ -964,10 +941,10 @@ func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value,
} }
return llvm.ConstArray(subTyp, vals), nil return llvm.ConstArray(subTyp, vals), nil
case *ConstValue: case *ir.ConstValue:
return c.parseConst(prefix, value.Expr) return c.parseConst(prefix, value.Expr)
case *FunctionValue: case *ir.FunctionValue:
if value.Elem == nil { if value.Elem == nil {
llvmType, err := c.getLLVMType(value.Type) llvmType, err := c.getLLVMType(value.Type)
if err != nil { if err != nil {
@ -976,19 +953,19 @@ func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value,
return getZeroValue(llvmType) return getZeroValue(llvmType)
} }
fn := c.ir.GetFunction(value.Elem) fn := c.ir.GetFunction(value.Elem)
ptr := fn.llvmFn ptr := fn.LLVMFn
if c.ir.SignatureNeedsContext(fn.fn.Signature) { if c.ir.SignatureNeedsContext(fn.Signature) {
// Create closure value: {context, function pointer} // Create closure value: {context, function pointer}
ptr = llvm.ConstStruct([]llvm.Value{llvm.ConstPointerNull(c.i8ptrType), ptr}, false) ptr = llvm.ConstStruct([]llvm.Value{llvm.ConstPointerNull(c.i8ptrType), ptr}, false)
} }
return ptr, nil return ptr, nil
case *GlobalValue: case *ir.GlobalValue:
zero := llvm.ConstInt(llvm.Int32Type(), 0, false) zero := llvm.ConstInt(llvm.Int32Type(), 0, false)
ptr := llvm.ConstInBoundsGEP(value.Global.llvmGlobal, []llvm.Value{zero}) ptr := llvm.ConstInBoundsGEP(value.Global.LLVMGlobal, []llvm.Value{zero})
return ptr, nil return ptr, nil
case *InterfaceValue: case *ir.InterfaceValue:
underlying := llvm.ConstPointerNull(c.i8ptrType) // could be any 0 value underlying := llvm.ConstPointerNull(c.i8ptrType) // could be any 0 value
if value.Elem != nil { if value.Elem != nil {
elem, err := c.getInterpretedValue(prefix, value.Elem) elem, err := c.getInterpretedValue(prefix, value.Elem)
@ -999,7 +976,7 @@ func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value,
} }
return c.parseMakeInterface(underlying, value.Type, prefix) return c.parseMakeInterface(underlying, value.Type, prefix)
case *MapValue: case *ir.MapValue:
// Create initial bucket. // Create initial bucket.
firstBucketGlobal, keySize, valueSize, err := c.initMapNewBucket(prefix, value.Type) firstBucketGlobal, keySize, valueSize, err := c.initMapNewBucket(prefix, value.Type)
if err != nil { if err != nil {
@ -1018,7 +995,7 @@ func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value,
return llvm.Value{}, nil return llvm.Value{}, nil
} }
constVal := key.(*ConstValue).Expr constVal := key.(*ir.ConstValue).Expr
var keyBuf []byte var keyBuf []byte
switch constVal.Type().Underlying().(*types.Basic).Kind() { switch constVal.Type().Underlying().(*types.Basic).Kind() {
case types.String: case types.String:
@ -1079,7 +1056,7 @@ func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value,
hashmapPtr.SetLinkage(llvm.InternalLinkage) hashmapPtr.SetLinkage(llvm.InternalLinkage)
return llvm.ConstInBoundsGEP(hashmapPtr, []llvm.Value{zero}), nil return llvm.ConstInBoundsGEP(hashmapPtr, []llvm.Value{zero}), nil
case *PointerBitCastValue: case *ir.PointerBitCastValue:
elem, err := c.getInterpretedValue(prefix, value.Elem) elem, err := c.getInterpretedValue(prefix, value.Elem)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
@ -1090,14 +1067,14 @@ func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value,
} }
return llvm.ConstBitCast(elem, llvmType), nil return llvm.ConstBitCast(elem, llvmType), nil
case *PointerToUintptrValue: case *ir.PointerToUintptrValue:
elem, err := c.getInterpretedValue(prefix, value.Elem) elem, err := c.getInterpretedValue(prefix, value.Elem)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
return llvm.ConstPtrToInt(elem, c.uintptrType), nil return llvm.ConstPtrToInt(elem, c.uintptrType), nil
case *PointerValue: case *ir.PointerValue:
if value.Elem == nil { if value.Elem == nil {
typ, err := c.getLLVMType(value.Type) typ, err := c.getLLVMType(value.Type)
if err != nil { if err != nil {
@ -1119,7 +1096,7 @@ func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value,
ptr := llvm.ConstInBoundsGEP(elem, []llvm.Value{zero}) ptr := llvm.ConstInBoundsGEP(elem, []llvm.Value{zero})
return ptr, nil return ptr, nil
case *SliceValue: case *ir.SliceValue:
var globalPtr llvm.Value var globalPtr llvm.Value
var arrayLength uint64 var arrayLength uint64
if value.Array == nil { if value.Array == nil {
@ -1159,7 +1136,7 @@ func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value,
}) })
return slice, nil return slice, nil
case *StructValue: case *ir.StructValue:
fields := make([]llvm.Value, len(value.Fields)) fields := make([]llvm.Value, len(value.Fields))
for i, elem := range value.Fields { for i, elem := range value.Fields {
field, err := c.getInterpretedValue(prefix, elem) field, err := c.getInterpretedValue(prefix, elem)
@ -1181,7 +1158,7 @@ func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value,
return llvm.Value{}, errors.New("init: unknown struct type: " + value.Type.String()) return llvm.Value{}, errors.New("init: unknown struct type: " + value.Type.String())
} }
case *ZeroBasicValue: case *ir.ZeroBasicValue:
llvmType, err := c.getLLVMType(value.Type) llvmType, err := c.getLLVMType(value.Type)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
@ -1195,36 +1172,36 @@ func (c *Compiler) getInterpretedValue(prefix string, value Value) (llvm.Value,
func (c *Compiler) parseFunc(frame *Frame) error { func (c *Compiler) parseFunc(frame *Frame) error {
if c.dumpSSA { if c.dumpSSA {
fmt.Printf("\nfunc %s:\n", frame.fn.fn) fmt.Printf("\nfunc %s:\n", frame.fn.Function)
} }
if !frame.fn.IsExported() { if !frame.fn.IsExported() {
frame.fn.llvmFn.SetLinkage(llvm.InternalLinkage) frame.fn.LLVMFn.SetLinkage(llvm.InternalLinkage)
} }
if c.debug { if c.debug {
pos := c.ir.program.Fset.Position(frame.fn.fn.Pos()) pos := c.ir.Program.Fset.Position(frame.fn.Pos())
c.builder.SetCurrentDebugLocation(uint(pos.Line), uint(pos.Column), frame.difunc, llvm.Metadata{}) c.builder.SetCurrentDebugLocation(uint(pos.Line), uint(pos.Column), frame.difunc, llvm.Metadata{})
} }
// Pre-create all basic blocks in the function. // Pre-create all basic blocks in the function.
for _, block := range frame.fn.fn.DomPreorder() { for _, block := range frame.fn.DomPreorder() {
llvmBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, block.Comment) llvmBlock := c.ctx.AddBasicBlock(frame.fn.LLVMFn, block.Comment)
frame.blocks[block] = llvmBlock frame.blocks[block] = llvmBlock
} }
if frame.blocking { if frame.blocking {
frame.cleanupBlock = c.ctx.AddBasicBlock(frame.fn.llvmFn, "task.cleanup") frame.cleanupBlock = c.ctx.AddBasicBlock(frame.fn.LLVMFn, "task.cleanup")
frame.suspendBlock = c.ctx.AddBasicBlock(frame.fn.llvmFn, "task.suspend") frame.suspendBlock = c.ctx.AddBasicBlock(frame.fn.LLVMFn, "task.suspend")
} }
entryBlock := frame.blocks[frame.fn.fn.Blocks[0]] entryBlock := frame.blocks[frame.fn.Blocks[0]]
// Load function parameters // Load function parameters
for i, param := range frame.fn.fn.Params { for i, param := range frame.fn.Params {
llvmParam := frame.fn.llvmFn.Param(frame.params[param]) llvmParam := frame.fn.LLVMFn.Param(frame.params[param])
frame.locals[param] = llvmParam frame.locals[param] = llvmParam
// Add debug information to this parameter (if available) // Add debug information to this parameter (if available)
if c.debug && frame.fn.fn.Syntax() != nil { if c.debug && frame.fn.Syntax() != nil {
pos := c.ir.program.Fset.Position(frame.fn.fn.Syntax().Pos()) pos := c.ir.Program.Fset.Position(frame.fn.Syntax().Pos())
dityp, err := c.getDIType(param.Type()) dityp, err := c.getDIType(param.Type())
if err != nil { if err != nil {
return err return err
@ -1243,16 +1220,16 @@ func (c *Compiler) parseFunc(frame *Frame) error {
// Load free variables from the context. This is a closure (or bound // Load free variables from the context. This is a closure (or bound
// method). // method).
if len(frame.fn.fn.FreeVars) != 0 { if len(frame.fn.FreeVars) != 0 {
if !c.ir.FunctionNeedsContext(frame.fn) { if !c.ir.FunctionNeedsContext(frame.fn) {
panic("free variables on function without context") panic("free variables on function without context")
} }
c.builder.SetInsertPointAtEnd(entryBlock) c.builder.SetInsertPointAtEnd(entryBlock)
context := frame.fn.llvmFn.Param(len(frame.fn.fn.Params)) context := frame.fn.LLVMFn.Param(len(frame.fn.Params))
// Determine the context type. It's a struct containing all variables. // Determine the context type. It's a struct containing all variables.
freeVarTypes := make([]llvm.Type, 0, len(frame.fn.fn.FreeVars)) freeVarTypes := make([]llvm.Type, 0, len(frame.fn.FreeVars))
for _, freeVar := range frame.fn.fn.FreeVars { for _, freeVar := range frame.fn.FreeVars {
typ, err := c.getLLVMType(freeVar.Type()) typ, err := c.getLLVMType(freeVar.Type())
if err != nil { if err != nil {
return err return err
@ -1279,7 +1256,7 @@ func (c *Compiler) parseFunc(frame *Frame) error {
// A free variable is always a pointer when this is a closure, but it // A free variable is always a pointer when this is a closure, but it
// can be another type when it is a wrapper for a bound method (these // can be another type when it is a wrapper for a bound method (these
// wrappers are generated by the ssa package). // wrappers are generated by the ssa package).
for i, freeVar := range frame.fn.fn.FreeVars { for i, freeVar := range frame.fn.FreeVars {
indices := []llvm.Value{ indices := []llvm.Value{
llvm.ConstInt(llvm.Int32Type(), 0, false), llvm.ConstInt(llvm.Int32Type(), 0, false),
llvm.ConstInt(llvm.Int32Type(), uint64(i), false), llvm.ConstInt(llvm.Int32Type(), uint64(i), false),
@ -1291,7 +1268,7 @@ func (c *Compiler) parseFunc(frame *Frame) error {
c.builder.SetInsertPointAtEnd(entryBlock) c.builder.SetInsertPointAtEnd(entryBlock)
if frame.fn.fn.Recover != nil { if frame.fn.Recover != nil {
// Create defer list pointer. // Create defer list pointer.
deferType := llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0) deferType := llvm.PointerType(c.mod.GetTypeByName("runtime._defer"), 0)
frame.deferPtr = c.builder.CreateAlloca(deferType, "deferPtr") frame.deferPtr = c.builder.CreateAlloca(deferType, "deferPtr")
@ -1322,7 +1299,7 @@ func (c *Compiler) parseFunc(frame *Frame) error {
mem := c.builder.CreateCall(c.coroFreeFunc, []llvm.Value{id, frame.taskHandle}, "task.data.free") mem := c.builder.CreateCall(c.coroFreeFunc, []llvm.Value{id, frame.taskHandle}, "task.data.free")
c.builder.CreateCall(c.freeFunc, []llvm.Value{mem}, "") c.builder.CreateCall(c.freeFunc, []llvm.Value{mem}, "")
// re-insert parent coroutine // re-insert parent coroutine
c.builder.CreateCall(c.mod.NamedFunction("runtime.yieldToScheduler"), []llvm.Value{frame.fn.llvmFn.FirstParam()}, "") c.builder.CreateCall(c.mod.NamedFunction("runtime.yieldToScheduler"), []llvm.Value{frame.fn.LLVMFn.FirstParam()}, "")
c.builder.CreateBr(frame.suspendBlock) c.builder.CreateBr(frame.suspendBlock)
// Coroutine suspend. A call to llvm.coro.suspend() will branch here. // Coroutine suspend. A call to llvm.coro.suspend() will branch here.
@ -1332,7 +1309,7 @@ func (c *Compiler) parseFunc(frame *Frame) error {
} }
// Fill blocks with instructions. // Fill blocks with instructions.
for _, block := range frame.fn.fn.DomPreorder() { for _, block := range frame.fn.DomPreorder() {
if c.dumpSSA { if c.dumpSSA {
fmt.Printf("%s:\n", block.Comment) fmt.Printf("%s:\n", block.Comment)
} }
@ -1354,7 +1331,7 @@ func (c *Compiler) parseFunc(frame *Frame) error {
return err return err
} }
} }
if frame.fn.fn.Name() == "init" && len(block.Instrs) == 0 { if frame.fn.Name() == "init" && len(block.Instrs) == 0 {
c.builder.CreateRetVoid() c.builder.CreateRetVoid()
} }
} }
@ -1377,14 +1354,14 @@ func (c *Compiler) parseFunc(frame *Frame) error {
func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error {
if c.debug { if c.debug {
pos := c.ir.program.Fset.Position(instr.Pos()) pos := c.ir.Program.Fset.Position(instr.Pos())
c.builder.SetCurrentDebugLocation(uint(pos.Line), uint(pos.Column), frame.difunc, llvm.Metadata{}) c.builder.SetCurrentDebugLocation(uint(pos.Line), uint(pos.Column), frame.difunc, llvm.Metadata{})
} }
switch instr := instr.(type) { switch instr := instr.(type) {
case ssa.Value: case ssa.Value:
value, err := c.parseExpr(frame, instr) value, err := c.parseExpr(frame, instr)
if err == cgoWrapperError { if err == ir.ErrCGoWrapper {
// Ignore CGo global variables which we don't use. // Ignore CGo global variables which we don't use.
return nil return nil
} }
@ -1551,7 +1528,7 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error {
return nil return nil
} else { } else {
// Multiple return values. Put them all in a struct. // Multiple return values. Put them all in a struct.
retVal, err := getZeroValue(frame.fn.llvmFn.Type().ElementType().ReturnType()) retVal, err := getZeroValue(frame.fn.LLVMFn.Type().ElementType().ReturnType())
if err != nil { if err != nil {
return err return err
} }
@ -1573,7 +1550,7 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error {
return nil return nil
case *ssa.Store: case *ssa.Store:
llvmAddr, err := c.parseExpr(frame, instr.Addr) llvmAddr, err := c.parseExpr(frame, instr.Addr)
if err == cgoWrapperError { if err == ir.ErrCGoWrapper {
// Ignore CGo global variables which we don't use. // Ignore CGo global variables which we don't use.
return nil return nil
} }
@ -1854,7 +1831,7 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l
} }
} }
targetFunc := c.ir.GetFunction(fn) targetFunc := c.ir.GetFunction(fn)
if targetFunc.llvmFn.IsNil() { if targetFunc.LLVMFn.IsNil() {
return llvm.Value{}, errors.New("undefined function: " + targetFunc.LinkName()) return llvm.Value{}, errors.New("undefined function: " + targetFunc.LinkName())
} }
var context llvm.Value var context llvm.Value
@ -1877,7 +1854,7 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l
} }
} }
} }
return c.parseFunctionCall(frame, instr.Args, targetFunc.llvmFn, context, c.ir.IsBlocking(targetFunc), parentHandle) return c.parseFunctionCall(frame, instr.Args, targetFunc.LLVMFn, context, c.ir.IsBlocking(targetFunc), parentHandle)
} }
// Builtin or function pointer. // Builtin or function pointer.
@ -1903,7 +1880,7 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.CallCommon, parentHandle l
} }
func (c *Compiler) emitBoundsCheck(frame *Frame, arrayLen, index llvm.Value) { func (c *Compiler) emitBoundsCheck(frame *Frame, arrayLen, index llvm.Value) {
if frame.fn.nobounds { if frame.fn.IsNoBounds() {
// The //go:nobounds pragma was added to the function to avoid bounds // The //go:nobounds pragma was added to the function to avoid bounds
// checking. // checking.
return return
@ -2016,7 +1993,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
return c.builder.CreateGEP(val, indices, ""), nil return c.builder.CreateGEP(val, indices, ""), nil
case *ssa.Function: case *ssa.Function:
fn := c.ir.GetFunction(expr) fn := c.ir.GetFunction(expr)
ptr := fn.llvmFn ptr := fn.LLVMFn
if c.ir.FunctionNeedsContext(fn) { if c.ir.FunctionNeedsContext(fn) {
// Create closure for function pointer. // Create closure for function pointer.
// Closure is: {context, function pointer} // Closure is: {context, function pointer}
@ -2029,9 +2006,9 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
case *ssa.Global: case *ssa.Global:
if strings.HasPrefix(expr.Name(), "__cgofn__cgo_") || strings.HasPrefix(expr.Name(), "_cgo_") { if strings.HasPrefix(expr.Name(), "__cgofn__cgo_") || strings.HasPrefix(expr.Name(), "_cgo_") {
// Ignore CGo global variables which we don't use. // Ignore CGo global variables which we don't use.
return llvm.Value{}, cgoWrapperError return llvm.Value{}, ir.ErrCGoWrapper
} }
value := c.ir.GetGlobal(expr).llvmGlobal value := c.ir.GetGlobal(expr).LLVMGlobal
if value.IsNil() { if value.IsNil() {
return llvm.Value{}, errors.New("global not found: " + c.ir.GetGlobal(expr).LinkName()) return llvm.Value{}, errors.New("global not found: " + c.ir.GetGlobal(expr).LinkName())
} }
@ -2209,7 +2186,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
elemSize := c.targetData.TypeAllocSize(llvmElemType) elemSize := c.targetData.TypeAllocSize(llvmElemType)
// Bounds checking. // Bounds checking.
if !frame.fn.nobounds { if !frame.fn.IsNoBounds() {
sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheckMake") sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheckMake")
c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{sliceLen, sliceCap}, "") c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{sliceLen, sliceCap}, "")
} }
@ -2343,7 +2320,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
sliceCap := c.builder.CreateSub(llvmLenInt, low, "slice.cap") sliceCap := c.builder.CreateSub(llvmLenInt, low, "slice.cap")
// This check is optimized away in most cases. // This check is optimized away in most cases.
if !frame.fn.nobounds { if !frame.fn.IsNoBounds() {
sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheck") sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheck")
c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{llvmLen, low, high}, "") c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{llvmLen, low, high}, "")
} }
@ -2372,7 +2349,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
high = oldLen high = oldLen
} }
if !frame.fn.nobounds { if !frame.fn.IsNoBounds() {
sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheck") sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheck")
c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{oldLen, low, high}, "") c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{oldLen, low, high}, "")
} }
@ -2408,7 +2385,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
high = oldLen high = oldLen
} }
if !frame.fn.nobounds { if !frame.fn.IsNoBounds() {
sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheck") sliceBoundsCheck := c.mod.NamedFunction("runtime.sliceBoundsCheck")
c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{oldLen, low, high}, "") c.builder.CreateCall(sliceBoundsCheck, []llvm.Value{oldLen, low, high}, "")
} }
@ -2491,8 +2468,8 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
// value. // value.
prevBlock := c.builder.GetInsertBlock() prevBlock := c.builder.GetInsertBlock()
okBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "typeassert.ok") okBlock := c.ctx.AddBasicBlock(frame.fn.LLVMFn, "typeassert.ok")
nextBlock := c.ctx.AddBasicBlock(frame.fn.llvmFn, "typeassert.next") nextBlock := c.ctx.AddBasicBlock(frame.fn.LLVMFn, "typeassert.next")
frame.blocks[frame.currentBlock] = nextBlock // adjust outgoing block for phi nodes frame.blocks[frame.currentBlock] = nextBlock // adjust outgoing block for phi nodes
c.builder.CreateCondBr(commaOk, okBlock, nextBlock) c.builder.CreateCondBr(commaOk, okBlock, nextBlock)
@ -3034,7 +3011,7 @@ func (c *Compiler) parseMakeClosure(frame *Frame, expr *ssa.MakeClosure) (llvm.V
// Get the function signature type, which is a closure type. // Get the function signature type, which is a closure type.
// A closure is a tuple of {context, function pointer}. // A closure is a tuple of {context, function pointer}.
typ, err := c.getLLVMType(f.fn.Signature) typ, err := c.getLLVMType(f.Signature)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
@ -3044,7 +3021,7 @@ func (c *Compiler) parseMakeClosure(frame *Frame, expr *ssa.MakeClosure) (llvm.V
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
closure = c.builder.CreateInsertValue(closure, f.llvmFn, 1, "") closure = c.builder.CreateInsertValue(closure, f.LLVMFn, 1, "")
closure = c.builder.CreateInsertValue(closure, context, 0, "") closure = c.builder.CreateInsertValue(closure, context, 0, "")
return closure, nil return closure, nil
} }

Просмотреть файл

@ -1,4 +1,4 @@
package main package ir
// This file provides functionality to interpret very basic Go SSA, for // This file provides functionality to interpret very basic Go SSA, for
// compile-time initialization of globals. // compile-time initialization of globals.
@ -14,6 +14,8 @@ import (
"golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa"
) )
var ErrCGoWrapper = errors.New("tinygo internal: cgo wrapper") // a signal, not an error
// Ignore these calls (replace with a zero return value) when encountered during // Ignore these calls (replace with a zero return value) when encountered during
// interpretation. // interpretation.
var ignoreInitCalls = map[string]struct{}{ var ignoreInitCalls = map[string]struct{}{
@ -33,7 +35,7 @@ func (p *Program) Interpret(block *ssa.BasicBlock, dumpSSA bool) error {
} }
for { for {
i, err := p.interpret(block.Instrs, nil, nil, nil, dumpSSA) i, err := p.interpret(block.Instrs, nil, nil, nil, dumpSSA)
if err == cgoWrapperError { if err == ErrCGoWrapper {
// skip this instruction // skip this instruction
block.Instrs = block.Instrs[i+1:] block.Instrs = block.Instrs[i+1:]
continue continue
@ -410,7 +412,7 @@ func (p *Program) getValue(value ssa.Value, locals map[ssa.Value]Value) (Value,
case *ssa.Global: case *ssa.Global:
if strings.HasPrefix(value.Name(), "__cgofn__cgo_") || strings.HasPrefix(value.Name(), "_cgo_") { if strings.HasPrefix(value.Name(), "__cgofn__cgo_") || strings.HasPrefix(value.Name(), "_cgo_") {
// Ignore CGo global variables which we don't use. // Ignore CGo global variables which we don't use.
return nil, cgoWrapperError return nil, ErrCGoWrapper
} }
g := p.GetGlobal(value) g := p.GetGlobal(value)
if g.initializer == nil { if g.initializer == nil {

Просмотреть файл

@ -1,4 +1,4 @@
package main package ir
import ( import (
"go/ast" "go/ast"
@ -19,7 +19,7 @@ import (
// View on all functions, types, and globals in a program, with analysis // View on all functions, types, and globals in a program, with analysis
// results. // results.
type Program struct { type Program struct {
program *ssa.Program Program *ssa.Program
mainPkg *ssa.Package mainPkg *ssa.Package
Functions []*Function Functions []*Function
functionMap map[*ssa.Function]*Function functionMap map[*ssa.Function]*Function
@ -38,8 +38,8 @@ type Program struct {
// Function or method. // Function or method.
type Function struct { type Function struct {
fn *ssa.Function *ssa.Function
llvmFn llvm.Value LLVMFn llvm.Value
linkName string // go:linkname or go:export pragma linkName string // go:linkname or go:export pragma
exported bool // go:export exported bool // go:export
nobounds bool // go:nobounds pragma nobounds bool // go:nobounds pragma
@ -52,9 +52,9 @@ type Function struct {
// Global variable, possibly constant. // Global variable, possibly constant.
type Global struct { type Global struct {
*ssa.Global
program *Program program *Program
g *ssa.Global LLVMGlobal llvm.Value
llvmGlobal llvm.Value
linkName string // go:extern linkName string // go:extern
extern bool // go:extern extern bool // go:extern
initializer Value initializer Value
@ -62,8 +62,8 @@ type Global struct {
// Type with a name and possibly methods. // Type with a name and possibly methods.
type NamedType struct { type NamedType struct {
t *ssa.Type *ssa.Type
llvmType llvm.Type LLVMType llvm.Type
} }
// Type that is at some point put in an interface. // Type that is at some point put in an interface.
@ -110,7 +110,7 @@ func NewProgram(lprogram *loader.Program, mainPath string) *Program {
program.Build() program.Build()
return &Program{ return &Program{
program: program, Program: program,
mainPkg: program.ImportedPackage(mainPath), mainPkg: program.ImportedPackage(mainPath),
functionMap: make(map[*ssa.Function]*Function), functionMap: make(map[*ssa.Function]*Function),
globalMap: make(map[*ssa.Global]*Global), globalMap: make(map[*ssa.Global]*Global),
@ -141,7 +141,7 @@ func (p *Program) AddPackage(pkg *ssa.Package) {
} }
p.addFunction(member) p.addFunction(member)
case *ssa.Type: case *ssa.Type:
t := &NamedType{t: member} t := &NamedType{Type: member}
p.NamedTypes = append(p.NamedTypes, t) p.NamedTypes = append(p.NamedTypes, t)
methods := getAllMethods(pkg.Prog, member.Type()) methods := getAllMethods(pkg.Prog, member.Type())
if !types.IsInterface(member.Type()) { if !types.IsInterface(member.Type()) {
@ -151,8 +151,8 @@ func (p *Program) AddPackage(pkg *ssa.Package) {
} }
} }
case *ssa.Global: case *ssa.Global:
g := &Global{program: p, g: member} g := &Global{program: p, Global: member}
doc := p.comments[g.g.RelString(nil)] doc := p.comments[g.RelString(nil)]
if doc != nil { if doc != nil {
g.parsePragmas(doc) g.parsePragmas(doc)
} }
@ -167,7 +167,7 @@ func (p *Program) AddPackage(pkg *ssa.Package) {
} }
func (p *Program) addFunction(ssaFn *ssa.Function) { func (p *Program) addFunction(ssaFn *ssa.Function) {
f := &Function{fn: ssaFn} f := &Function{Function: ssaFn}
f.parsePragmas() f.parsePragmas()
p.Functions = append(p.Functions, f) p.Functions = append(p.Functions, f)
p.functionMap[ssaFn] = f p.functionMap[ssaFn] = f
@ -207,12 +207,16 @@ func (p *Program) SortFuncs(funcs []*types.Func) {
sort.Sort(m) sort.Sort(m)
} }
func (p *Program) MainPkg() *ssa.Package {
return p.mainPkg
}
// Parse compiler directives in the preceding comments. // Parse compiler directives in the preceding comments.
func (f *Function) parsePragmas() { func (f *Function) parsePragmas() {
if f.fn.Syntax() == nil { if f.Syntax() == nil {
return return
} }
if decl, ok := f.fn.Syntax().(*ast.FuncDecl); ok && decl.Doc != nil { if decl, ok := f.Syntax().(*ast.FuncDecl); ok && decl.Doc != nil {
for _, comment := range decl.Doc.List { for _, comment := range decl.Doc.List {
if !strings.HasPrefix(comment.Text, "//go:") { if !strings.HasPrefix(comment.Text, "//go:") {
continue continue
@ -220,14 +224,14 @@ func (f *Function) parsePragmas() {
parts := strings.Fields(comment.Text) parts := strings.Fields(comment.Text)
switch parts[0] { switch parts[0] {
case "//go:linkname": case "//go:linkname":
if len(parts) != 3 || parts[1] != f.fn.Name() { if len(parts) != 3 || parts[1] != f.Name() {
continue continue
} }
// Only enable go:linkname when the package imports "unsafe". // Only enable go:linkname when the package imports "unsafe".
// This is a slightly looser requirement than what gc uses: gc // This is a slightly looser requirement than what gc uses: gc
// requires the file to import "unsafe", not the package as a // requires the file to import "unsafe", not the package as a
// whole. // whole.
if hasUnsafeImport(f.fn.Pkg.Pkg) { if hasUnsafeImport(f.Pkg.Pkg) {
f.linkName = parts[2] f.linkName = parts[2]
} }
case "//go:nobounds": case "//go:nobounds":
@ -235,7 +239,7 @@ func (f *Function) parsePragmas() {
// runtime functions. // runtime functions.
// This is somewhat dangerous and thus only imported in packages // This is somewhat dangerous and thus only imported in packages
// that import unsafe. // that import unsafe.
if hasUnsafeImport(f.fn.Pkg.Pkg) { if hasUnsafeImport(f.Pkg.Pkg) {
f.nobounds = true f.nobounds = true
} }
case "//go:export": case "//go:export":
@ -249,6 +253,10 @@ func (f *Function) parsePragmas() {
} }
} }
func (f *Function) IsNoBounds() bool {
return f.nobounds
}
// Return true iff this function is externally visible. // Return true iff this function is externally visible.
func (f *Function) IsExported() bool { func (f *Function) IsExported() bool {
return f.exported return f.exported
@ -259,16 +267,16 @@ func (f *Function) LinkName() string {
if f.linkName != "" { if f.linkName != "" {
return f.linkName return f.linkName
} }
if f.fn.Signature.Recv() != nil { if f.Signature.Recv() != nil {
// Method on a defined type (which may be a pointer). // Method on a defined type (which may be a pointer).
return f.fn.RelString(nil) return f.RelString(nil)
} else { } else {
// Bare function. // Bare function.
if name := f.CName(); name != "" { if name := f.CName(); name != "" {
// Name CGo functions directly. // Name CGo functions directly.
return name return name
} else { } else {
return f.fn.RelString(nil) return f.RelString(nil)
} }
} }
} }
@ -276,7 +284,7 @@ func (f *Function) LinkName() string {
// Return the name of the C function if this is a CGo wrapper. Otherwise, return // Return the name of the C function if this is a CGo wrapper. Otherwise, return
// a zero-length string. // a zero-length string.
func (f *Function) CName() string { func (f *Function) CName() string {
name := f.fn.Name() name := f.Name()
if strings.HasPrefix(name, "_Cfunc_") { if strings.HasPrefix(name, "_Cfunc_") {
return name[len("_Cfunc_"):] return name[len("_Cfunc_"):]
} }
@ -305,13 +313,17 @@ func (g *Global) LinkName() string {
if g.linkName != "" { if g.linkName != "" {
return g.linkName return g.linkName
} }
return g.g.RelString(nil) return g.RelString(nil)
} }
func (g *Global) IsExtern() bool { func (g *Global) IsExtern() bool {
return g.extern return g.extern
} }
func (g *Global) Initializer() Value {
return g.initializer
}
// Wrapper type to implement sort.Interface for []*types.Selection. // Wrapper type to implement sort.Interface for []*types.Selection.
type methodList struct { type methodList struct {
methods []*types.Selection methods []*types.Selection
@ -351,3 +363,25 @@ func (fl *funcList) Less(i, j int) bool {
func (fl *funcList) Swap(i, j int) { func (fl *funcList) Swap(i, j int) {
fl.funcs[i], fl.funcs[j] = fl.funcs[j], fl.funcs[i] fl.funcs[i], fl.funcs[j] = fl.funcs[j], fl.funcs[i]
} }
// Return true if this is a CGo-internal function that can be ignored.
func isCGoInternal(name string) bool {
if strings.HasPrefix(name, "_Cgo_") || strings.HasPrefix(name, "_cgo") {
// _Cgo_ptr, _Cgo_use, _cgoCheckResult, _cgo_runtime_cgocall
return true // CGo-internal functions
}
if strings.HasPrefix(name, "__cgofn__cgo_") {
return true // CGo function pointer in global scope
}
return false
}
// Get all methods of a type.
func getAllMethods(prog *ssa.Program, typ types.Type) []*types.Selection {
ms := prog.MethodSets.MethodSet(typ)
methods := make([]*types.Selection, ms.Len())
for i := 0; i < ms.Len(); i++ {
methods[i] = ms.At(i)
}
return methods
}

Просмотреть файл

@ -1,4 +1,4 @@
package main package ir
import ( import (
"go/types" "go/types"
@ -81,7 +81,7 @@ func (p *Program) AnalyseCallgraph() {
f.children = nil f.children = nil
f.parents = nil f.parents = nil
for _, block := range f.fn.Blocks { for _, block := range f.Blocks {
for _, instr := range block.Instrs { for _, instr := range block.Instrs {
switch instr := instr.(type) { switch instr := instr.(type) {
case *ssa.Call: case *ssa.Call:
@ -99,7 +99,7 @@ func (p *Program) AnalyseCallgraph() {
if child.CName() != "" { if child.CName() != "" {
continue // assume non-blocking continue // assume non-blocking
} }
if child.fn.RelString(nil) == "time.Sleep" { if child.RelString(nil) == "time.Sleep" {
f.blocking = true f.blocking = true
} }
f.children = append(f.children, child) f.children = append(f.children, child)
@ -122,11 +122,11 @@ func (p *Program) AnalyseInterfaceConversions() {
p.typesWithMethods = map[string]*TypeWithMethods{} p.typesWithMethods = map[string]*TypeWithMethods{}
for _, f := range p.Functions { for _, f := range p.Functions {
for _, block := range f.fn.Blocks { for _, block := range f.Blocks {
for _, instr := range block.Instrs { for _, instr := range block.Instrs {
switch instr := instr.(type) { switch instr := instr.(type) {
case *ssa.MakeInterface: case *ssa.MakeInterface:
methods := getAllMethods(f.fn.Prog, instr.X.Type()) methods := getAllMethods(f.Prog, instr.X.Type())
name := instr.X.Type().String() name := instr.X.Type().String()
if _, ok := p.typesWithMethods[name]; !ok && len(methods) > 0 { if _, ok := p.typesWithMethods[name]; !ok && len(methods) > 0 {
t := &TypeWithMethods{ t := &TypeWithMethods{
@ -155,7 +155,7 @@ func (p *Program) AnalyseFunctionPointers() {
p.fpWithContext = map[string]struct{}{} p.fpWithContext = map[string]struct{}{}
for _, f := range p.Functions { for _, f := range p.Functions {
for _, block := range f.fn.Blocks { for _, block := range f.Blocks {
for _, instr := range block.Instrs { for _, instr := range block.Instrs {
switch instr := instr.(type) { switch instr := instr.(type) {
case ssa.CallInstruction: case ssa.CallInstruction:
@ -229,7 +229,7 @@ func (p *Program) AnalyseBlockingRecursive() {
func (p *Program) AnalyseGoCalls() { func (p *Program) AnalyseGoCalls() {
p.goCalls = nil p.goCalls = nil
for _, f := range p.Functions { for _, f := range p.Functions {
for _, block := range f.fn.Blocks { for _, block := range f.Blocks {
for _, instr := range block.Instrs { for _, instr := range block.Instrs {
switch instr := instr.(type) { switch instr := instr.(type) {
case *ssa.Go: case *ssa.Go:
@ -262,16 +262,16 @@ func (p *Program) SimpleDCE() {
// Initial set of live functions. Include main.main, *.init and runtime.* // Initial set of live functions. Include main.main, *.init and runtime.*
// functions. // functions.
main := p.mainPkg.Members["main"].(*ssa.Function) main := p.mainPkg.Members["main"].(*ssa.Function)
runtimePkg := p.program.ImportedPackage("runtime") runtimePkg := p.Program.ImportedPackage("runtime")
p.GetFunction(main).flag = true p.GetFunction(main).flag = true
worklist := []*ssa.Function{main} worklist := []*ssa.Function{main}
for _, f := range p.Functions { for _, f := range p.Functions {
if f.fn.Synthetic == "package initializer" || f.fn.Pkg == runtimePkg { if f.Synthetic == "package initializer" || f.Pkg == runtimePkg {
if f.flag || isCGoInternal(f.fn.Name()) { if f.flag || isCGoInternal(f.Name()) {
continue continue
} }
f.flag = true f.flag = true
worklist = append(worklist, f.fn) worklist = append(worklist, f.Function)
} }
} }
@ -282,8 +282,8 @@ func (p *Program) SimpleDCE() {
for _, block := range f.Blocks { for _, block := range f.Blocks {
for _, instr := range block.Instrs { for _, instr := range block.Instrs {
if instr, ok := instr.(*ssa.MakeInterface); ok { if instr, ok := instr.(*ssa.MakeInterface); ok {
for _, sel := range getAllMethods(p.program, instr.X.Type()) { for _, sel := range getAllMethods(p.Program, instr.X.Type()) {
fn := p.program.MethodValue(sel) fn := p.Program.MethodValue(sel)
callee := p.GetFunction(fn) callee := p.GetFunction(fn)
if callee == nil { if callee == nil {
// TODO: why is this necessary? // TODO: why is this necessary?
@ -292,7 +292,7 @@ func (p *Program) SimpleDCE() {
} }
if !callee.flag { if !callee.flag {
callee.flag = true callee.flag = true
worklist = append(worklist, callee.fn) worklist = append(worklist, callee.Function)
} }
} }
} }
@ -325,7 +325,7 @@ func (p *Program) SimpleDCE() {
if f.flag { if f.flag {
livefunctions = append(livefunctions, f) livefunctions = append(livefunctions, f)
} else { } else {
delete(p.functionMap, f.fn) delete(p.functionMap, f.Function)
} }
} }
p.Functions = livefunctions p.Functions = livefunctions
@ -418,7 +418,7 @@ func (p *Program) FunctionNeedsContext(f *Function) bool {
if !f.addressTaken { if !f.addressTaken {
return false return false
} }
return p.SignatureNeedsContext(f.fn.Signature) return p.SignatureNeedsContext(f.Signature)
} }
func (p *Program) SignatureNeedsContext(sig *types.Signature) bool { func (p *Program) SignatureNeedsContext(sig *types.Signature) bool {