diff --git a/tgo.go b/tgo.go index 2eb2188a..c51feeab 100644 --- a/tgo.go +++ b/tgo.go @@ -39,18 +39,17 @@ type Compiler struct { printintFunc llvm.Value printspaceFunc llvm.Value printnlFunc llvm.Value - funcs map[*ssa.Function]*Function } -type Function struct { - name string - params map[*ssa.Parameter]int +type Frame struct { + pkgName string + name string // full name, including package + params map[*ssa.Parameter]int // arguments to the function + values map[ssa.Value]llvm.Value // local variables } func NewCompiler(path, triple string) (*Compiler, error) { - c := &Compiler{ - funcs: make(map[*ssa.Function]*Function), - } + c := &Compiler{} target, err := llvm.GetTargetFromTriple(triple) if err != nil { @@ -106,14 +105,17 @@ func (c *Compiler) Parse(path string) error { } sort.Strings(memberNames) + frames := make(map[*ssa.Function]*Frame) + // First, build all function declarations. for _, name := range memberNames { member := pkg.Members[name] if member, ok := member.(*ssa.Function); ok { - err := c.parseFuncDecl(pkg.Pkg.Name(), member) + frame, err := c.parseFuncDecl(pkg.Pkg.Name(), member) if err != nil { return err } + frames[member] = frame } } @@ -126,7 +128,7 @@ func (c *Compiler) Parse(path string) error { } switch member := member.(type) { case *ssa.Function: - err := c.parseFunc(pkg.Pkg.Name(), member) + err := c.parseFunc(frames[member], member) if err != nil { return err } @@ -140,11 +142,13 @@ func (c *Compiler) Parse(path string) error { return nil } -func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) error { +func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error) { name := pkgName + "." + f.Name() - c.funcs[f] = &Function{ - name: name, - params: make(map[*ssa.Parameter]int), + frame := &Frame{ + pkgName: pkgName, + name: name, + params: make(map[*ssa.Parameter]int), + values: make(map[ssa.Value]llvm.Value), } var retType llvm.Type @@ -160,13 +164,13 @@ func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) error { case types.Int32: retType = llvm.Int32Type() default: - return errors.New("todo: unknown basic return type") + return nil, errors.New("todo: unknown basic return type") } default: - return errors.New("todo: unknown return type") + return nil, errors.New("todo: unknown return type") } } else { - return errors.New("todo: return values") + return nil, errors.New("todo: return values") } var paramTypes []llvm.Type @@ -180,25 +184,24 @@ func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) error { case types.Int32: paramType = llvm.Int32Type() default: - return errors.New("todo: unknown basic param type") + return nil, errors.New("todo: unknown basic param type") } paramTypes = append(paramTypes, paramType) - c.funcs[f].params[param] = i + frame.params[param] = i default: - return errors.New("todo: unknown param type") + return nil, errors.New("todo: unknown param type") } } fnType := llvm.FunctionType(retType, paramTypes, false) llvm.AddFunction(c.mod, name, fnType) - return nil + return frame, nil } -func (c *Compiler) parseFunc(pkgName string, f *ssa.Function) error { +func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error { fmt.Println("func:", f.Name()) - fn := c.funcs[f] - llvmFn := c.mod.NamedFunction(fn.name) + llvmFn := c.mod.NamedFunction(frame.name) start := c.ctx.AddBasicBlock(llvmFn, "start") c.builder.SetInsertPointAtEnd(start) @@ -206,7 +209,7 @@ func (c *Compiler) parseFunc(pkgName string, f *ssa.Function) error { for _, block := range f.Blocks { for _, instr := range block.Instrs { fmt.Printf(" instr: %v\n", instr) - err := c.parseInstr(pkgName, instr) + err := c.parseInstr(frame, instr) if err != nil { return err } @@ -215,17 +218,17 @@ func (c *Compiler) parseFunc(pkgName string, f *ssa.Function) error { return nil } -func (c *Compiler) parseInstr(pkgName string, instr ssa.Instruction) error { +func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error { switch instr := instr.(type) { case *ssa.Call: - _, err := c.parseCall(pkgName, instr) + _, err := c.parseCall(frame, instr) return err case *ssa.Return: if len(instr.Results) == 0 { c.builder.CreateRetVoid() return nil } else if len(instr.Results) == 1 { - val, err := c.parseExpr(pkgName, instr.Results[0]) + val, err := c.parseExpr(frame, instr.Results[0]) if err != nil { return err } @@ -234,15 +237,15 @@ func (c *Compiler) parseInstr(pkgName string, instr ssa.Instruction) error { } else { return errors.New("todo: return value") } - case *ssa.BinOp: - _, err := c.parseBinOp(pkgName, instr) + case ssa.Value: + _, err := c.parseExpr(frame, instr) return err default: return errors.New("unknown instruction: " + fmt.Sprintf("%#v", instr)) } } -func (c *Compiler) parseBuiltin(pkgName string, instr *ssa.CallCommon, call *ssa.Builtin) (llvm.Value, error) { +func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.Builtin) (llvm.Value, error) { fmt.Printf(" builtin: %#v\n", call) name := call.Name() @@ -253,7 +256,7 @@ func (c *Compiler) parseBuiltin(pkgName string, instr *ssa.CallCommon, call *ssa c.builder.CreateCall(c.printspaceFunc, nil, "") } fmt.Printf(" arg: %s\n", arg); - expr, err := c.parseExpr(pkgName, arg) + expr, err := c.parseExpr(frame, arg) if err != nil { return llvm.Value{}, err } @@ -275,13 +278,13 @@ func (c *Compiler) parseBuiltin(pkgName string, instr *ssa.CallCommon, call *ssa } } -func (c *Compiler) parseFunctionCall(pkgName string, call *ssa.CallCommon, fn *ssa.Function) (llvm.Value, error) { +func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, fn *ssa.Function) (llvm.Value, error) { fmt.Printf(" function: %s\n", fn) name := fn.Name() if strings.IndexByte(name, '.') == -1 { // TODO: import path instead of pkgName - name = pkgName + "." + name + name = frame.pkgName + "." + name } target := c.mod.NamedFunction(name) if target.IsNil() { @@ -290,7 +293,7 @@ func (c *Compiler) parseFunctionCall(pkgName string, call *ssa.CallCommon, fn *s var params []llvm.Value for _, param := range call.Args { - val, err := c.parseExpr(pkgName, param) + val, err := c.parseExpr(frame, param) if err != nil { return llvm.Value{}, err } @@ -300,25 +303,25 @@ func (c *Compiler) parseFunctionCall(pkgName string, call *ssa.CallCommon, fn *s return c.builder.CreateCall(target, params, ""), nil } -func (c *Compiler) parseCall(pkgName string, instr *ssa.Call) (llvm.Value, error) { +func (c *Compiler) parseCall(frame *Frame, instr *ssa.Call) (llvm.Value, error) { fmt.Printf(" call: %s\n", instr) switch call := instr.Common().Value.(type) { case *ssa.Builtin: - return c.parseBuiltin(pkgName, instr.Common(), call) + return c.parseBuiltin(frame, instr.Common(), call) case *ssa.Function: - return c.parseFunctionCall(pkgName, instr.Common(), call) + return c.parseFunctionCall(frame, instr.Common(), call) default: return llvm.Value{}, errors.New("todo: unknown call type: " + fmt.Sprintf("%#v", call)) } } -func (c *Compiler) parseBinOp(pkgName string, binop *ssa.BinOp) (llvm.Value, error) { - x, err := c.parseExpr(pkgName, binop.X) +func (c *Compiler) parseBinOp(frame *Frame, binop *ssa.BinOp) (llvm.Value, error) { + x, err := c.parseExpr(frame, binop.X) if err != nil { return llvm.Value{}, err } - y, err := c.parseExpr(pkgName, binop.Y) + y, err := c.parseExpr(frame, binop.Y) if err != nil { return llvm.Value{}, err } @@ -332,7 +335,7 @@ func (c *Compiler) parseBinOp(pkgName string, binop *ssa.BinOp) (llvm.Value, err } } -func (c *Compiler) parseExpr(pkgName string, expr ssa.Value) (llvm.Value, error) { +func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { fmt.Printf(" expr: %v\n", expr) switch expr := expr.(type) { case *ssa.Const: @@ -353,13 +356,12 @@ func (c *Compiler) parseExpr(pkgName string, expr ssa.Value) (llvm.Value, error) return llvm.Value{}, errors.New("todo: unknown constant") } case *ssa.BinOp: - return c.parseBinOp(pkgName, expr) + return c.parseBinOp(frame, expr) case *ssa.Call: - return c.parseCall(pkgName, expr) + return c.parseCall(frame, expr) case *ssa.Parameter: - fn := c.funcs[expr.Parent()] - llvmFn := c.mod.NamedFunction(fn.name) - return llvmFn.Param(fn.params[expr]), nil + llvmFn := c.mod.NamedFunction(frame.name) + return llvmFn.Param(frame.params[expr]), nil default: return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr)) }