diff --git a/tgo.go b/tgo.go index d0821db5..eab5af36 100644 --- a/tgo.go +++ b/tgo.go @@ -5,12 +5,13 @@ import ( "errors" "flag" "fmt" - "go/ast" "go/constant" "go/token" "os" "golang.org/x/tools/go/loader" + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/go/ssa/ssautil" "llvm.org/llvm/bindings/go/llvm" ) @@ -70,138 +71,164 @@ func (c *Compiler) Parse(path string) error { // TODO: Build (build.Context) - GOOS, GOARCH, GOPATH, etc } config.CreateFromFilenames("main", path) - program, err := config.Load() + lprogram, err := config.Load() if err != nil { return err } - pkgInfo := program.Created[0] // main? - if len(pkgInfo.Errors) != 0 { - return pkgInfo.Errors[0] // TODO: better error checking - } - fmt.Println("package:", pkgInfo.Pkg.Name()) - for _, file := range pkgInfo.Files { - err := c.parseFile(pkgInfo.Pkg.Name(), file) - if err != nil { - return err - } - } - - return nil -} - -func (c *Compiler) parseFile(pkgName string, file *ast.File) error { - for _, decl := range file.Decls { - switch v := decl.(type) { - case *ast.FuncDecl: - err := c.parseFunc(pkgName, v) - if err != nil { - return err + program := ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions) + program.Build() + for _, pkg := range program.AllPackages() { + fmt.Println("package:", pkg.Pkg.Name()) + for name, member := range pkg.Members { + fmt.Println("member:", name, member, member.Token()) + if member.Name() == "init" { + continue + } + switch member := member.(type) { + case *ssa.Function: + err := c.parseFunc(pkg.Pkg.Name(), member) + if err != nil { + return err + } + + default: + fmt.Println(" TODO") } - default: - return errors.New("unknown declaration") } } return nil } -func (c *Compiler) parseFunc(pkgName string, f *ast.FuncDecl) error { - fmt.Println("func:", f.Name) +func (c *Compiler) parseFunc(pkgName string, f *ssa.Function) error { + fmt.Println("func:", f.Name(), f.Blocks, "len:", len(f.Blocks)) var fnType llvm.Type - if f.Type.Results == nil { + if f.Signature.Results() == nil { fnType = llvm.FunctionType(llvm.VoidType(), nil, false) } else { return errors.New("todo: return values") } - fn := llvm.AddFunction(c.mod, pkgName + "." + f.Name.Name, fnType) + fn := llvm.AddFunction(c.mod, pkgName + "." + f.Name(), fnType) start := c.ctx.AddBasicBlock(fn, "start") c.builder.SetInsertPointAtEnd(start) // TODO: external functions - for _, stmt := range f.Body.List { - err := c.parseStmt(stmt) - if err != nil { - return err + for _, block := range f.Blocks { + for _, instr := range block.Instrs { + fmt.Printf(" instr: %v\n", instr) + err := c.parseInstr(instr) + if err != nil { + return err + } } } - - if f.Type.Results == nil { - c.builder.CreateRetVoid() - //} else if len(f.Type.Results.List) == 1 { - // c.builder.CreateRet(llvm.ConstInt(llvm.Int32Type(), 0, false)) - } return nil } -func (c *Compiler) parseStmt(stmt ast.Node) error { - switch v := stmt.(type) { - case *ast.ExprStmt: - err := c.parseExpr(v.X) - if err != nil { - return err +func (c *Compiler) parseInstr(instr ssa.Instruction) error { + switch instr := instr.(type) { + case *ssa.Call: + switch call := instr.Common().Value.(type) { + case *ssa.Builtin: + return c.parseBuiltin(instr.Common(), call) + default: + return errors.New("todo: unknown call type: " + fmt.Sprintf("%#v", call)) } - default: - return errors.New("unknown stmt") - } - return nil -} - -func (c *Compiler) parseExpr(expr ast.Expr) error { - switch v := expr.(type) { - case *ast.CallExpr: - name := v.Fun.(*ast.Ident).Name - fmt.Printf(" call: %s\n", name) - - printnl := false - if name == "println" { - printnl = true - } else if name == "print" { + case *ssa.Return: + if len(instr.Results) == 0 { + c.builder.CreateRetVoid() + return nil } else { - return errors.New("todo: call anything other than println()") + return errors.New("todo: return value") } + case *ssa.BinOp: + return c.parseBinOp(instr) + default: + return errors.New("unknown instruction: " + fmt.Sprintf("%#v", instr)) + } +} - for i, arg := range v.Args { +func (c *Compiler) parseBuiltin(instr *ssa.CallCommon, call *ssa.Builtin) error { + fmt.Printf(" builtin: %#v\n", call) + name := call.Name() + + switch name { + case "print", "println": + for i, arg := range instr.Args { if i >= 1 { c.builder.CreateCall(c.printspaceFunc, nil, "") } - switch arg := arg.(type) { - case *ast.BasicLit: - fmt.Printf(" arg: %s\n", arg.Value) - val := constant.MakeFromLiteral(arg.Value, arg.Kind, 0) - switch arg.Kind { - case token.STRING: - str := constant.StringVal(val) - strVal := c.ctx.ConstString(str, false) - strLen := llvm.ConstInt(llvm.Int32Type(), uint64(len(str)), false) - strObj := llvm.ConstStruct([]llvm.Value{strLen, strVal}, false) - ptr := llvm.AddGlobal(c.mod, strObj.Type(), ".str") - ptr.SetInitializer(strObj) - ptr.SetLinkage(llvm.InternalLinkage) - ptrCast := llvm.ConstPointerCast(ptr, c.stringPtrType) - c.builder.CreateCall(c.printstringFunc, []llvm.Value{ptrCast}, "") - case token.INT: - n, _ := constant.Int64Val(val) // TODO: do something with the 'exact' return value? - val := llvm.ConstInt(llvm.Int32Type(), uint64(n), true) - c.builder.CreateCall(c.printintFunc, []llvm.Value{val}, "") - default: - return errors.New("todo: print anything other than strings") - } + fmt.Printf(" arg: %s\n", arg); + expr, err := c.parseExpr(arg) + if err != nil { + return err + } + switch expr.Type() { + case c.stringPtrType: + c.builder.CreateCall(c.printstringFunc, []llvm.Value{*expr}, "") + case llvm.Int32Type(): + c.builder.CreateCall(c.printintFunc, []llvm.Value{*expr}, "") default: return errors.New("unknown arg type") } } - if printnl { + if name == "println" { c.builder.CreateCall(c.printnlFunc, nil, "") } - default: - return errors.New("unknown expr") } + return nil } +func (c *Compiler) parseBinOp(binop *ssa.BinOp) error { + x, err := c.parseExpr(binop.X) + if err != nil { + return err + } + y, err := c.parseExpr(binop.Y) + if err != nil { + return err + } + switch binop.Op { + case token.ADD: + c.builder.CreateBinOp(llvm.Add, *x, *y, "") + return nil + case token.MUL: + c.builder.CreateBinOp(llvm.Mul, *x, *y, "") + return nil + } + return errors.New("todo: unknown binop") +} + +func (c *Compiler) parseExpr(expr ssa.Value) (*llvm.Value, error) { + fmt.Printf(" expr: %v\n", expr) + switch expr := expr.(type) { + case *ssa.Const: + switch expr.Value.Kind() { + case constant.String: + str := constant.StringVal(expr.Value) + strVal := c.ctx.ConstString(str, false) + strLen := llvm.ConstInt(llvm.Int32Type(), uint64(len(str)), false) + strObj := llvm.ConstStruct([]llvm.Value{strLen, strVal}, false) + ptr := llvm.AddGlobal(c.mod, strObj.Type(), ".str") + ptr.SetInitializer(strObj) + ptr.SetLinkage(llvm.InternalLinkage) + ptrCast := llvm.ConstPointerCast(ptr, c.stringPtrType) + return &ptrCast, nil + case constant.Int: + n, _ := constant.Int64Val(expr.Value) // TODO: do something with the 'exact' return value? + val := llvm.ConstInt(llvm.Int32Type(), uint64(n), true) + return &val, nil + default: + return nil, errors.New("todo: unknown constant") + } + } + return nil, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr)) +} + // IR returns the whole IR as a human-readable string. func (c *Compiler) IR() string { return c.mod.String()