From 6a8dc7ca9a5de7bc9054506b7f7840862e956381 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Fri, 13 Apr 2018 02:11:12 +0200 Subject: [PATCH] Support functions with parameters --- hello/hello.go | 10 ++++++ tgo.go | 93 +++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 86 insertions(+), 17 deletions(-) diff --git a/hello/hello.go b/hello/hello.go index c549be39..14fe2e9f 100644 --- a/hello/hello.go +++ b/hello/hello.go @@ -6,9 +6,19 @@ const SIX = 6 func main() { println("Hello world from Go!") println("The answer is:", calculateAnswer()) + println("5 ** 2 =", square(5)) + println("3 + 12 =", add(3, 12)) } func calculateAnswer() int { seven := 7 return SIX * seven } + +func square(n int) int { + return n * n +} + +func add(a, b int) int { + return a + b +} diff --git a/tgo.go b/tgo.go index 1c5cc251..fb4a4b10 100644 --- a/tgo.go +++ b/tgo.go @@ -9,6 +9,7 @@ import ( "go/token" "go/types" "os" + "sort" "strings" "golang.org/x/tools/go/loader" @@ -38,10 +39,18 @@ type Compiler struct { printintFunc llvm.Value printspaceFunc llvm.Value printnlFunc llvm.Value + funcs map[*ssa.Function]*Function +} + +type Function struct { + name string + params map[*ssa.Parameter]int } func NewCompiler(path, triple string) (*Compiler, error) { - c := &Compiler{} + c := &Compiler{ + funcs: make(map[*ssa.Function]*Function), + } target, err := llvm.GetTargetFromTriple(triple) if err != nil { @@ -89,11 +98,17 @@ func (c *Compiler) Parse(path string) error { for _, pkg := range program.AllPackages() { fmt.Println("package:", pkg.Pkg.Name()) + // Make sure we're walking through all members in a constant order every + // run. + memberNames := make([]string, 0) + for name := range pkg.Members { + memberNames = append(memberNames, name) + } + sort.Strings(memberNames) + // First, build all function declarations. - for name, member := range pkg.Members { - if name == "init" { - continue - } + for _, name := range memberNames { + member := pkg.Members[name] if member, ok := member.(*ssa.Function); ok { err := c.parseFuncDecl(pkg.Pkg.Name(), member) if err != nil { @@ -103,8 +118,9 @@ func (c *Compiler) Parse(path string) error { } // Now, add definitions to those declarations. - for name, member := range pkg.Members { - fmt.Println("member:", name, member, member.Token()) + for _, name := range memberNames { + member := pkg.Members[name] + fmt.Println("member:", member.Token(), member) if name == "init" { continue } @@ -125,6 +141,12 @@ func (c *Compiler) Parse(path string) error { } func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) error { + name := pkgName + "." + f.Name() + c.funcs[f] = &Function{ + name: name, + params: make(map[*ssa.Parameter]int), + } + var retType llvm.Type if f.Signature.Results() == nil { retType = llvm.VoidType() @@ -146,16 +168,38 @@ func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) error { } else { return errors.New("todo: return values") } - fnType := llvm.FunctionType(retType, nil, false) - llvm.AddFunction(c.mod, pkgName + "." + f.Name(), fnType) + + var paramTypes []llvm.Type + for i, param := range f.Params { + switch typ := param.Type().(type) { + case *types.Basic: + var paramType llvm.Type + switch typ.Kind() { + case types.Int: + paramType = c.intType + case types.Int32: + paramType = llvm.Int32Type() + default: + return errors.New("todo: unknown basic param type") + } + paramTypes = append(paramTypes, paramType) + c.funcs[f].params[param] = i + default: + return errors.New("todo: unknown param type") + } + } + + fnType := llvm.FunctionType(retType, paramTypes, false) + llvm.AddFunction(c.mod, name, fnType) return nil } func (c *Compiler) parseFunc(pkgName string, f *ssa.Function) error { - fmt.Println("func:", f.Name(), f.Blocks, "len:", len(f.Blocks)) + fmt.Println("func:", f.Name()) - fn := c.mod.NamedFunction(pkgName + "." + f.Name()) - start := c.ctx.AddBasicBlock(fn, "start") + fn := c.funcs[f] + llvmFn := c.mod.NamedFunction(fn.name) + start := c.ctx.AddBasicBlock(llvmFn, "start") c.builder.SetInsertPointAtEnd(start) // TODO: external functions @@ -231,10 +275,10 @@ func (c *Compiler) parseBuiltin(pkgName string, instr *ssa.CallCommon, call *ssa } } -func (c *Compiler) parseFunctionCall(pkgName string, call *ssa.Function) (*llvm.Value, error) { - fmt.Printf(" function: %s\n", call) +func (c *Compiler) parseFunctionCall(pkgName string, call *ssa.CallCommon, fn *ssa.Function) (*llvm.Value, error) { + fmt.Printf(" function: %s\n", fn) - name := call.Name() + name := fn.Name() if strings.IndexByte(name, '.') == -1 { // TODO: import path instead of pkgName name = pkgName + "." + name @@ -243,7 +287,17 @@ func (c *Compiler) parseFunctionCall(pkgName string, call *ssa.Function) (*llvm. if target.IsNil() { return nil, errors.New("undefined function: " + name) } - val := c.builder.CreateCall(target, nil, "") + + var params []llvm.Value + for _, param := range call.Args { + val, err := c.parseExpr(pkgName, param) + if err != nil { + return nil, err + } + params = append(params, *val) + } + + val := c.builder.CreateCall(target, params, "") return &val, nil } @@ -254,7 +308,7 @@ func (c *Compiler) parseCall(pkgName string, instr *ssa.Call) (*llvm.Value, erro case *ssa.Builtin: return c.parseBuiltin(pkgName, instr.Common(), call) case *ssa.Function: - return c.parseFunctionCall(pkgName, call) + return c.parseFunctionCall(pkgName, instr.Common(), call) default: return nil, errors.New("todo: unknown call type: " + fmt.Sprintf("%#v", call)) } @@ -307,6 +361,11 @@ func (c *Compiler) parseExpr(pkgName string, expr ssa.Value) (*llvm.Value, error return c.parseBinOp(pkgName, expr) case *ssa.Call: return c.parseCall(pkgName, expr) + case *ssa.Parameter: + fn := c.funcs[expr.Parent()] + llvmFn := c.mod.NamedFunction(fn.name) + param := llvmFn.Param(fn.params[expr]) + return ¶m, nil default: return nil, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr)) }