* Don't skip init function
* Add global variables and constants
* Add unary operations
* Use import path instead of package name (except for main)
* ...more
Этот коммит содержится в:
Ayke van Laethem 2018-04-15 04:49:00 +02:00
родитель 5dfcb5f085
коммит de0ff3b3af

273
tgo.go
Просмотреть файл

@ -5,6 +5,8 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"go/ast"
"go/build"
"go/constant" "go/constant"
"go/token" "go/token"
"go/types" "go/types"
@ -27,6 +29,7 @@ func init() {
} }
type Compiler struct { type Compiler struct {
triple string
mod llvm.Module mod llvm.Module
ctx llvm.Context ctx llvm.Context
builder llvm.Builder builder llvm.Builder
@ -41,12 +44,12 @@ type Compiler struct {
} }
type Frame struct { type Frame struct {
pkgName string pkgPrefix string
name string // full name, including package name string // full name, including package
params map[*ssa.Parameter]int // arguments to the function params map[*ssa.Parameter]int // arguments to the function
locals map[ssa.Value]llvm.Value // local variables locals map[ssa.Value]llvm.Value // local variables
blocks map[*ssa.BasicBlock]llvm.BasicBlock blocks map[*ssa.BasicBlock]llvm.BasicBlock
phis []Phi phis []Phi
} }
type Phi struct { type Phi struct {
@ -54,8 +57,10 @@ type Phi struct {
llvm llvm.Value llvm llvm.Value
} }
func NewCompiler(path, triple string) (*Compiler, error) { func NewCompiler(pkgName, triple string) (*Compiler, error) {
c := &Compiler{} c := &Compiler{
triple: triple,
}
target, err := llvm.GetTargetFromTriple(triple) target, err := llvm.GetTargetFromTriple(triple)
if err != nil { if err != nil {
@ -63,7 +68,7 @@ func NewCompiler(path, triple string) (*Compiler, error) {
} }
c.machine = target.CreateTargetMachine(triple, "", "", llvm.CodeGenLevelDefault, llvm.RelocDefault, llvm.CodeModelDefault) c.machine = target.CreateTargetMachine(triple, "", "", llvm.CodeGenLevelDefault, llvm.RelocDefault, llvm.CodeModelDefault)
c.mod = llvm.NewModule(path) c.mod = llvm.NewModule(pkgName)
c.ctx = c.mod.Context() c.ctx = c.mod.Context()
c.builder = c.ctx.NewBuilder() c.builder = c.ctx.NewBuilder()
@ -86,21 +91,41 @@ func NewCompiler(path, triple string) (*Compiler, error) {
return c, nil return c, nil
} }
func (c *Compiler) Parse(path string) error { func (c *Compiler) Parse(pkgName string) error {
tripleSplit := strings.Split(c.triple, "-")
config := loader.Config { config := loader.Config {
// TODO: TypeChecker.Sizes // TODO: TypeChecker.Sizes
// TODO: Build (build.Context) - GOOS, GOARCH, GOPATH, etc Build: &build.Context {
GOARCH: tripleSplit[0],
GOOS: tripleSplit[2],
GOROOT: ".",
CgoEnabled: true,
UseAllFiles: false,
Compiler: "gc", // must be one of the recognized compilers
BuildTags: []string{"tgo"},
},
AllowErrors: true,
} }
config.CreateFromFilenames("main", path) config.Import(pkgName)
lprogram, err := config.Load() lprogram, err := config.Load()
if err != nil { if err != nil {
return err return err
} }
program := ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions) // TODO: pick the error of the first package, not a random package
for _, pkgInfo := range lprogram.AllPackages {
fmt.Println("package:", pkgInfo.Pkg.Name())
if len(pkgInfo.Errors) != 0 {
return pkgInfo.Errors[0]
}
}
program := ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions | ssa.BareInits)
program.Build() program.Build()
// TODO: order of packages is random
for _, pkg := range program.AllPackages() { for _, pkg := range program.AllPackages() {
fmt.Println("package:", pkg.Pkg.Name()) fmt.Println("package:", pkg.Pkg.Path())
// Make sure we're walking through all members in a constant order every // Make sure we're walking through all members in a constant order every
// run. // run.
@ -115,12 +140,41 @@ func (c *Compiler) Parse(path string) error {
// First, build all function declarations. // First, build all function declarations.
for _, name := range memberNames { for _, name := range memberNames {
member := pkg.Members[name] member := pkg.Members[name]
if member, ok := member.(*ssa.Function); ok {
frame, err := c.parseFuncDecl(pkg.Pkg.Name(), member) pkgPrefix := pkg.Pkg.Path()
if pkg.Pkg.Name() == "main" {
pkgPrefix = "main"
}
switch member := member.(type) {
case *ssa.Function:
frame, err := c.parseFuncDecl(pkgPrefix, member)
if err != nil { if err != nil {
return err return err
} }
frames[member] = frame frames[member] = frame
case *ssa.NamedConst:
val, err := c.parseConst(member.Value)
if err != nil {
return err
}
global := llvm.AddGlobal(c.mod, val.Type(), pkgPrefix + "." + member.Name())
global.SetInitializer(val)
global.SetGlobalConstant(true)
if ast.IsExported(member.Name()) {
global.SetLinkage(llvm.PrivateLinkage)
}
case *ssa.Global:
typ, err := c.getLLVMType(member.Type())
if err != nil {
return err
}
global := llvm.AddGlobal(c.mod, typ, pkgPrefix + "." + member.Name())
if ast.IsExported(member.Name()) {
global.SetLinkage(llvm.PrivateLinkage)
}
default:
return errors.New("todo: member: " + fmt.Sprintf("%#v", member))
} }
} }
@ -128,18 +182,15 @@ func (c *Compiler) Parse(path string) error {
for _, name := range memberNames { for _, name := range memberNames {
member := pkg.Members[name] member := pkg.Members[name]
fmt.Println("member:", member.Token(), member) fmt.Println("member:", member.Token(), member)
if name == "init" {
continue if member, ok := member.(*ssa.Function); ok {
} if member.Blocks == nil {
switch member := member.(type) { continue // external function
case *ssa.Function: }
err := c.parseFunc(frames[member], member) err := c.parseFunc(frames[member], member)
if err != nil { if err != nil {
return err return err
} }
default:
fmt.Println(" TODO")
} }
} }
} }
@ -151,26 +202,43 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
switch typ := goType.(type) { switch typ := goType.(type) {
case *types.Basic: case *types.Basic:
switch typ.Kind() { switch typ.Kind() {
case types.Bool:
return llvm.Int1Type(), nil
case types.Int: case types.Int:
return c.intType, nil return c.intType, nil
case types.Int32: case types.Int32:
return llvm.Int32Type(), nil return llvm.Int32Type(), nil
case types.UnsafePointer:
return llvm.PointerType(llvm.Int8Type(), 0), nil
default: default:
return llvm.Type{}, errors.New("todo: unknown basic type") return llvm.Type{}, errors.New("todo: unknown basic type: " + fmt.Sprintf("%#v", typ))
} }
case *types.Pointer:
ptrTo, err := c.getLLVMType(typ.Elem())
if err != nil {
return llvm.Type{}, err
}
return llvm.PointerType(ptrTo, 0), nil
default: default:
return llvm.Type{}, errors.New("todo: unknown type") return llvm.Type{}, errors.New("todo: unknown type: " + fmt.Sprintf("%#v", goType))
} }
} }
func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error) { func (c *Compiler) getPackageRelativeName(frame *Frame, name string) string {
name := pkgName + "." + f.Name() if strings.IndexByte(name, '.') == -1 {
name = frame.pkgPrefix + "." + name
}
return name
}
func (c *Compiler) parseFuncDecl(pkgPrefix string, f *ssa.Function) (*Frame, error) {
name := pkgPrefix + "." + f.Name()
frame := &Frame{ frame := &Frame{
pkgName: pkgName, pkgPrefix: pkgPrefix,
name: name, name: name,
params: make(map[*ssa.Parameter]int), params: make(map[*ssa.Parameter]int),
locals: make(map[ssa.Value]llvm.Value), locals: make(map[ssa.Value]llvm.Value),
blocks: make(map[*ssa.BasicBlock]llvm.BasicBlock), blocks: make(map[*ssa.BasicBlock]llvm.BasicBlock),
} }
var retType llvm.Type var retType llvm.Type
@ -188,22 +256,12 @@ func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error
var paramTypes []llvm.Type var paramTypes []llvm.Type
for i, param := range f.Params { for i, param := range f.Params {
switch typ := param.Type().(type) { paramType, err := c.getLLVMType(param.Type())
case *types.Basic: if err != nil {
var paramType llvm.Type return nil, err
switch typ.Kind() {
case types.Int:
paramType = c.intType
case types.Int32:
paramType = llvm.Int32Type()
default:
return nil, errors.New("todo: unknown basic param type")
}
paramTypes = append(paramTypes, paramType)
frame.params[param] = i
default:
return nil, errors.New("todo: unknown param type")
} }
paramTypes = append(paramTypes, paramType)
frame.params[param] = i
} }
fnType := llvm.FunctionType(retType, paramTypes, false) fnType := llvm.FunctionType(retType, paramTypes, false)
@ -212,8 +270,6 @@ func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error
} }
func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error { func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error {
// TODO: external functions
// Pre-create all basic blocks in the function. // Pre-create all basic blocks in the function.
llvmFn := c.mod.NamedFunction(frame.name) llvmFn := c.mod.NamedFunction(frame.name)
for _, block := range f.DomPreorder() { for _, block := range f.DomPreorder() {
@ -289,7 +345,7 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error {
} }
func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.Builtin) (llvm.Value, error) { func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.Builtin) (llvm.Value, error) {
fmt.Printf(" builtin: %#v\n", call) fmt.Printf(" builtin: %v\n", call)
name := call.Name() name := call.Name()
switch name { switch name {
@ -324,11 +380,7 @@ func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.B
func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, fn *ssa.Function) (llvm.Value, error) { func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, fn *ssa.Function) (llvm.Value, error) {
fmt.Printf(" function: %s\n", fn) fmt.Printf(" function: %s\n", fn)
name := fn.Name() name := c.getPackageRelativeName(frame, fn.Name())
if strings.IndexByte(name, '.') == -1 {
// TODO: import path instead of pkgName
name = frame.pkgName + "." + name
}
target := c.mod.NamedFunction(name) target := c.mod.NamedFunction(name)
if target.IsNil() { if target.IsNil() {
return llvm.Value{}, errors.New("undefined function: " + name) return llvm.Value{}, errors.New("undefined function: " + name)
@ -359,6 +411,44 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.Call) (llvm.Value, error)
} }
} }
func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
fmt.Printf(" expr: %v\n", expr)
if frame != nil {
if value, ok := frame.locals[expr]; ok {
// Value is a local variable that has already been computed.
fmt.Println(" from local var")
return value, nil
}
}
switch expr := expr.(type) {
case *ssa.Const:
return c.parseConst(expr)
case *ssa.BinOp:
return c.parseBinOp(frame, expr)
case *ssa.Call:
return c.parseCall(frame, expr)
case *ssa.Global:
return c.mod.NamedGlobal(c.getPackageRelativeName(frame, expr.Name())), nil
case *ssa.Parameter:
llvmFn := c.mod.NamedFunction(frame.name)
return llvmFn.Param(frame.params[expr]), nil
case *ssa.Phi:
t, err := c.getLLVMType(expr.Type())
if err != nil {
return llvm.Value{}, err
}
phi := c.builder.CreatePHI(t, "")
frame.phis = append(frame.phis, Phi{expr, phi})
return phi, nil
case *ssa.UnOp:
return c.parseUnOp(frame, expr)
default:
return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr))
}
}
func (c *Compiler) parseBinOp(frame *Frame, binop *ssa.BinOp) (llvm.Value, error) { func (c *Compiler) parseBinOp(frame *Frame, binop *ssa.BinOp) (llvm.Value, error) {
x, err := c.parseExpr(frame, binop.X) x, err := c.parseExpr(frame, binop.X)
if err != nil { if err != nil {
@ -410,47 +500,32 @@ func (c *Compiler) parseBinOp(frame *Frame, binop *ssa.BinOp) (llvm.Value, error
} }
} }
func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { func (c *Compiler) parseConst(expr *ssa.Const) (llvm.Value, error) {
fmt.Printf(" expr: %v\n", expr) switch expr.Value.Kind() {
case constant.String:
if value, ok := frame.locals[expr]; ok { str := constant.StringVal(expr.Value)
// Value is a local variable that has already been computed. strLen := llvm.ConstInt(c.stringLenType, uint64(len(str)), false)
fmt.Println(" from local var") strPtr := c.builder.CreateGlobalStringPtr(str, ".str")
return value, nil strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false)
} return strObj, nil
case constant.Int:
switch expr := expr.(type) { n, _ := constant.Int64Val(expr.Value) // TODO: do something with the 'exact' return value?
case *ssa.Const: return llvm.ConstInt(c.intType, uint64(n), true), nil
switch expr.Value.Kind() {
case constant.String:
str := constant.StringVal(expr.Value)
strLen := llvm.ConstInt(c.stringLenType, uint64(len(str)), false)
strPtr := c.builder.CreateGlobalStringPtr(str, ".str")
strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false)
return strObj, nil
case constant.Int:
n, _ := constant.Int64Val(expr.Value) // TODO: do something with the 'exact' return value?
return llvm.ConstInt(c.intType, uint64(n), true), nil
default:
return llvm.Value{}, errors.New("todo: unknown constant")
}
case *ssa.BinOp:
return c.parseBinOp(frame, expr)
case *ssa.Call:
return c.parseCall(frame, expr)
case *ssa.Parameter:
llvmFn := c.mod.NamedFunction(frame.name)
return llvmFn.Param(frame.params[expr]), nil
case *ssa.Phi:
t, err := c.getLLVMType(expr.Type())
if err != nil {
return llvm.Value{}, err
}
phi := c.builder.CreatePHI(t, "")
frame.phis = append(frame.phis, Phi{expr, phi})
return phi, nil
default: default:
return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr)) return llvm.Value{}, errors.New("todo: unknown constant")
}
}
func (c *Compiler) parseUnOp(frame *Frame, unop *ssa.UnOp) (llvm.Value, error) {
x, err := c.parseExpr(frame, unop.X)
if err != nil {
return llvm.Value{}, err
}
switch unop.Op {
case token.NOT:
return c.builder.CreateNot(x, ""), nil
default:
return llvm.Value{}, errors.New("todo: unknown unop")
} }
} }
@ -495,13 +570,13 @@ func (c *Compiler) EmitObject(path string) error {
} }
// Helper function for Compiler object. // Helper function for Compiler object.
func Compile(inpath, outpath, target string, printIR bool) error { func Compile(pkgName, outpath, target string, printIR bool) error {
c, err := NewCompiler(inpath, target) c, err := NewCompiler(pkgName, target)
if err != nil { if err != nil {
return err return err
} }
parseErr := c.Parse(inpath) parseErr := c.Parse(pkgName)
if printIR { if printIR {
fmt.Println(c.IR()) fmt.Println(c.IR())
} }