diff --git a/tgo.go b/tgo.go index 0a17ad4a..84b5d975 100644 --- a/tgo.go +++ b/tgo.go @@ -5,6 +5,8 @@ import ( "errors" "flag" "fmt" + "go/ast" + "go/build" "go/constant" "go/token" "go/types" @@ -27,6 +29,7 @@ func init() { } type Compiler struct { + triple string mod llvm.Module ctx llvm.Context builder llvm.Builder @@ -41,12 +44,12 @@ type Compiler struct { } type Frame struct { - pkgName string - name string // full name, including package - params map[*ssa.Parameter]int // arguments to the function - locals map[ssa.Value]llvm.Value // local variables - blocks map[*ssa.BasicBlock]llvm.BasicBlock - phis []Phi + pkgPrefix string + name string // full name, including package + params map[*ssa.Parameter]int // arguments to the function + locals map[ssa.Value]llvm.Value // local variables + blocks map[*ssa.BasicBlock]llvm.BasicBlock + phis []Phi } type Phi struct { @@ -54,8 +57,10 @@ type Phi struct { llvm llvm.Value } -func NewCompiler(path, triple string) (*Compiler, error) { - c := &Compiler{} +func NewCompiler(pkgName, triple string) (*Compiler, error) { + c := &Compiler{ + triple: triple, + } target, err := llvm.GetTargetFromTriple(triple) if err != nil { @@ -63,7 +68,7 @@ func NewCompiler(path, triple string) (*Compiler, error) { } c.machine = target.CreateTargetMachine(triple, "", "", llvm.CodeGenLevelDefault, llvm.RelocDefault, llvm.CodeModelDefault) - c.mod = llvm.NewModule(path) + c.mod = llvm.NewModule(pkgName) c.ctx = c.mod.Context() c.builder = c.ctx.NewBuilder() @@ -86,21 +91,41 @@ func NewCompiler(path, triple string) (*Compiler, error) { return c, nil } -func (c *Compiler) Parse(path string) error { +func (c *Compiler) Parse(pkgName string) error { + tripleSplit := strings.Split(c.triple, "-") + config := loader.Config { // TODO: TypeChecker.Sizes - // TODO: Build (build.Context) - GOOS, GOARCH, GOPATH, etc + Build: &build.Context { + GOARCH: tripleSplit[0], + GOOS: tripleSplit[2], + GOROOT: ".", + CgoEnabled: true, + UseAllFiles: false, + Compiler: "gc", // must be one of the recognized compilers + BuildTags: []string{"tgo"}, + }, + AllowErrors: true, } - config.CreateFromFilenames("main", path) + config.Import(pkgName) lprogram, err := config.Load() if err != nil { return err } - program := ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions) + // TODO: pick the error of the first package, not a random package + for _, pkgInfo := range lprogram.AllPackages { + fmt.Println("package:", pkgInfo.Pkg.Name()) + if len(pkgInfo.Errors) != 0 { + return pkgInfo.Errors[0] + } + } + + program := ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions | ssa.BareInits) program.Build() + // TODO: order of packages is random for _, pkg := range program.AllPackages() { - fmt.Println("package:", pkg.Pkg.Name()) + fmt.Println("package:", pkg.Pkg.Path()) // Make sure we're walking through all members in a constant order every // run. @@ -115,12 +140,41 @@ func (c *Compiler) Parse(path string) error { // First, build all function declarations. for _, name := range memberNames { member := pkg.Members[name] - if member, ok := member.(*ssa.Function); ok { - frame, err := c.parseFuncDecl(pkg.Pkg.Name(), member) + + pkgPrefix := pkg.Pkg.Path() + if pkg.Pkg.Name() == "main" { + pkgPrefix = "main" + } + + switch member := member.(type) { + case *ssa.Function: + frame, err := c.parseFuncDecl(pkgPrefix, member) if err != nil { return err } frames[member] = frame + case *ssa.NamedConst: + val, err := c.parseConst(member.Value) + if err != nil { + return err + } + global := llvm.AddGlobal(c.mod, val.Type(), pkgPrefix + "." + member.Name()) + global.SetInitializer(val) + global.SetGlobalConstant(true) + if ast.IsExported(member.Name()) { + global.SetLinkage(llvm.PrivateLinkage) + } + case *ssa.Global: + typ, err := c.getLLVMType(member.Type()) + if err != nil { + return err + } + global := llvm.AddGlobal(c.mod, typ, pkgPrefix + "." + member.Name()) + if ast.IsExported(member.Name()) { + global.SetLinkage(llvm.PrivateLinkage) + } + default: + return errors.New("todo: member: " + fmt.Sprintf("%#v", member)) } } @@ -128,18 +182,15 @@ func (c *Compiler) Parse(path string) error { for _, name := range memberNames { member := pkg.Members[name] fmt.Println("member:", member.Token(), member) - if name == "init" { - continue - } - switch member := member.(type) { - case *ssa.Function: + + if member, ok := member.(*ssa.Function); ok { + if member.Blocks == nil { + continue // external function + } err := c.parseFunc(frames[member], member) if err != nil { return err } - - default: - fmt.Println(" TODO") } } } @@ -151,26 +202,43 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { switch typ := goType.(type) { case *types.Basic: switch typ.Kind() { + case types.Bool: + return llvm.Int1Type(), nil case types.Int: return c.intType, nil case types.Int32: return llvm.Int32Type(), nil + case types.UnsafePointer: + return llvm.PointerType(llvm.Int8Type(), 0), nil default: - return llvm.Type{}, errors.New("todo: unknown basic type") + return llvm.Type{}, errors.New("todo: unknown basic type: " + fmt.Sprintf("%#v", typ)) } + case *types.Pointer: + ptrTo, err := c.getLLVMType(typ.Elem()) + if err != nil { + return llvm.Type{}, err + } + return llvm.PointerType(ptrTo, 0), nil default: - return llvm.Type{}, errors.New("todo: unknown type") + return llvm.Type{}, errors.New("todo: unknown type: " + fmt.Sprintf("%#v", goType)) } } -func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error) { - name := pkgName + "." + f.Name() +func (c *Compiler) getPackageRelativeName(frame *Frame, name string) string { + if strings.IndexByte(name, '.') == -1 { + name = frame.pkgPrefix + "." + name + } + return name +} + +func (c *Compiler) parseFuncDecl(pkgPrefix string, f *ssa.Function) (*Frame, error) { + name := pkgPrefix + "." + f.Name() frame := &Frame{ - pkgName: pkgName, - name: name, - params: make(map[*ssa.Parameter]int), - locals: make(map[ssa.Value]llvm.Value), - blocks: make(map[*ssa.BasicBlock]llvm.BasicBlock), + pkgPrefix: pkgPrefix, + name: name, + params: make(map[*ssa.Parameter]int), + locals: make(map[ssa.Value]llvm.Value), + blocks: make(map[*ssa.BasicBlock]llvm.BasicBlock), } var retType llvm.Type @@ -188,22 +256,12 @@ func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error var paramTypes []llvm.Type for i, param := range f.Params { - switch typ := param.Type().(type) { - case *types.Basic: - var paramType llvm.Type - switch typ.Kind() { - case types.Int: - paramType = c.intType - case types.Int32: - paramType = llvm.Int32Type() - default: - return nil, errors.New("todo: unknown basic param type") - } - paramTypes = append(paramTypes, paramType) - frame.params[param] = i - default: - return nil, errors.New("todo: unknown param type") + paramType, err := c.getLLVMType(param.Type()) + if err != nil { + return nil, err } + paramTypes = append(paramTypes, paramType) + frame.params[param] = i } fnType := llvm.FunctionType(retType, paramTypes, false) @@ -212,8 +270,6 @@ func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error } func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error { - // TODO: external functions - // Pre-create all basic blocks in the function. llvmFn := c.mod.NamedFunction(frame.name) for _, block := range f.DomPreorder() { @@ -289,7 +345,7 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { } func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.Builtin) (llvm.Value, error) { - fmt.Printf(" builtin: %#v\n", call) + fmt.Printf(" builtin: %v\n", call) name := call.Name() switch name { @@ -324,11 +380,7 @@ func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.B func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, fn *ssa.Function) (llvm.Value, error) { fmt.Printf(" function: %s\n", fn) - name := fn.Name() - if strings.IndexByte(name, '.') == -1 { - // TODO: import path instead of pkgName - name = frame.pkgName + "." + name - } + name := c.getPackageRelativeName(frame, fn.Name()) target := c.mod.NamedFunction(name) if target.IsNil() { return llvm.Value{}, errors.New("undefined function: " + name) @@ -359,6 +411,44 @@ func (c *Compiler) parseCall(frame *Frame, instr *ssa.Call) (llvm.Value, error) } } +func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { + fmt.Printf(" expr: %v\n", expr) + + if frame != nil { + if value, ok := frame.locals[expr]; ok { + // Value is a local variable that has already been computed. + fmt.Println(" from local var") + return value, nil + } + } + + switch expr := expr.(type) { + case *ssa.Const: + return c.parseConst(expr) + case *ssa.BinOp: + return c.parseBinOp(frame, expr) + case *ssa.Call: + return c.parseCall(frame, expr) + case *ssa.Global: + return c.mod.NamedGlobal(c.getPackageRelativeName(frame, expr.Name())), nil + case *ssa.Parameter: + llvmFn := c.mod.NamedFunction(frame.name) + return llvmFn.Param(frame.params[expr]), nil + case *ssa.Phi: + t, err := c.getLLVMType(expr.Type()) + if err != nil { + return llvm.Value{}, err + } + phi := c.builder.CreatePHI(t, "") + frame.phis = append(frame.phis, Phi{expr, phi}) + return phi, nil + case *ssa.UnOp: + return c.parseUnOp(frame, expr) + default: + return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr)) + } +} + func (c *Compiler) parseBinOp(frame *Frame, binop *ssa.BinOp) (llvm.Value, error) { x, err := c.parseExpr(frame, binop.X) if err != nil { @@ -410,47 +500,32 @@ func (c *Compiler) parseBinOp(frame *Frame, binop *ssa.BinOp) (llvm.Value, error } } -func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { - fmt.Printf(" expr: %v\n", expr) - - if value, ok := frame.locals[expr]; ok { - // Value is a local variable that has already been computed. - fmt.Println(" from local var") - return value, nil - } - - switch expr := expr.(type) { - case *ssa.Const: - switch expr.Value.Kind() { - case constant.String: - str := constant.StringVal(expr.Value) - strLen := llvm.ConstInt(c.stringLenType, uint64(len(str)), false) - strPtr := c.builder.CreateGlobalStringPtr(str, ".str") - strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false) - return strObj, nil - case constant.Int: - n, _ := constant.Int64Val(expr.Value) // TODO: do something with the 'exact' return value? - return llvm.ConstInt(c.intType, uint64(n), true), nil - default: - return llvm.Value{}, errors.New("todo: unknown constant") - } - case *ssa.BinOp: - return c.parseBinOp(frame, expr) - case *ssa.Call: - return c.parseCall(frame, expr) - case *ssa.Parameter: - llvmFn := c.mod.NamedFunction(frame.name) - return llvmFn.Param(frame.params[expr]), nil - case *ssa.Phi: - t, err := c.getLLVMType(expr.Type()) - if err != nil { - return llvm.Value{}, err - } - phi := c.builder.CreatePHI(t, "") - frame.phis = append(frame.phis, Phi{expr, phi}) - return phi, nil +func (c *Compiler) parseConst(expr *ssa.Const) (llvm.Value, error) { + switch expr.Value.Kind() { + case constant.String: + str := constant.StringVal(expr.Value) + strLen := llvm.ConstInt(c.stringLenType, uint64(len(str)), false) + strPtr := c.builder.CreateGlobalStringPtr(str, ".str") + strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false) + return strObj, nil + case constant.Int: + n, _ := constant.Int64Val(expr.Value) // TODO: do something with the 'exact' return value? + return llvm.ConstInt(c.intType, uint64(n), true), nil default: - return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr)) + return llvm.Value{}, errors.New("todo: unknown constant") + } +} + +func (c *Compiler) parseUnOp(frame *Frame, unop *ssa.UnOp) (llvm.Value, error) { + x, err := c.parseExpr(frame, unop.X) + if err != nil { + return llvm.Value{}, err + } + switch unop.Op { + case token.NOT: + return c.builder.CreateNot(x, ""), nil + default: + return llvm.Value{}, errors.New("todo: unknown unop") } } @@ -495,13 +570,13 @@ func (c *Compiler) EmitObject(path string) error { } // Helper function for Compiler object. -func Compile(inpath, outpath, target string, printIR bool) error { - c, err := NewCompiler(inpath, target) +func Compile(pkgName, outpath, target string, printIR bool) error { + c, err := NewCompiler(pkgName, target) if err != nil { return err } - parseErr := c.Parse(inpath) + parseErr := c.Parse(pkgName) if printIR { fmt.Println(c.IR()) }