diff --git a/tgo.go b/tgo.go index 16f277e5..11326867 100644 --- a/tgo.go +++ b/tgo.go @@ -7,10 +7,10 @@ import ( "fmt" "go/ast" "go/constant" - "go/parser" "go/token" "os" + "golang.org/x/tools/go/loader" "llvm.org/llvm/bindings/go/llvm" ) @@ -46,23 +46,41 @@ func NewCompiler(path, triplet string) (*Compiler, error) { putsType := llvm.FunctionType(llvm.Int32Type(), []llvm.Type{llvm.PointerType(llvm.Int8Type(), 0)}, false) c.putsFunc = llvm.AddFunction(c.mod, "puts", putsType) - mainType := llvm.FunctionType(llvm.Int32Type(), nil, false) - mainFunc := llvm.AddFunction(c.mod, "main", mainType) - start := c.ctx.AddBasicBlock(mainFunc, "start") - c.builder.SetInsertPointAtEnd(start) - return c, nil } func (c *Compiler) Parse(path string) error { - fset := token.NewFileSet() - file, err := parser.ParseFile(fset, path, nil, 0) + config := loader.Config { + // TODO: TypeChecker.Sizes + // TODO: Build (build.Context) - GOOS, GOARCH, GOPATH, etc + } + config.CreateFromFilenames("main", path) + program, err := config.Load() if err != nil { return err } - fmt.Println("package:", file.Name.Name) + pkgInfo := program.Created[0] // main? + if len(pkgInfo.Errors) != 0 { + return pkgInfo.Errors[0] // TODO: better error checking + } + fmt.Println("package:", pkgInfo.Pkg.Name()) + for _, file := range pkgInfo.Files { + err := c.parseFile(file) + if err != nil { + return err + } + } + return nil +} + +func (c *Compiler) IR() string { + return c.mod.String() +} + + +func (c *Compiler) parseFile(file *ast.File) error { for _, decl := range file.Decls { switch v := decl.(type) { case *ast.FuncDecl: @@ -75,17 +93,17 @@ func (c *Compiler) Parse(path string) error { } } - c.builder.CreateRet(llvm.ConstInt(llvm.Int32Type(), 0, false)) - return nil } -func (c *Compiler) IR() string { - return c.mod.String() -} - func (c *Compiler) parseFunc(f *ast.FuncDecl) error { fmt.Println("func:", f.Name) + + fnType := llvm.FunctionType(llvm.Int32Type(), nil, false) + fn := llvm.AddFunction(c.mod, f.Name.Name, fnType) + start := c.ctx.AddBasicBlock(fn, "start") + c.builder.SetInsertPointAtEnd(start) + // TODO: external functions for _, stmt := range f.Body.List { err := c.parseStmt(stmt) @@ -93,6 +111,8 @@ func (c *Compiler) parseFunc(f *ast.FuncDecl) error { return err } } + + c.builder.CreateRet(llvm.ConstInt(llvm.Int32Type(), 0, false)) return nil }