Make expression statement more generic; introduce *Frame for func state

Этот коммит содержится в:
Ayke van Laethem 2018-04-13 20:19:54 +02:00
родитель 9d8d0b9e03
коммит e0e04b88cb

94
tgo.go
Просмотреть файл

@ -39,18 +39,17 @@ type Compiler struct {
printintFunc llvm.Value printintFunc llvm.Value
printspaceFunc llvm.Value printspaceFunc llvm.Value
printnlFunc llvm.Value printnlFunc llvm.Value
funcs map[*ssa.Function]*Function
} }
type Function struct { type Frame struct {
name string pkgName string
params map[*ssa.Parameter]int name string // full name, including package
params map[*ssa.Parameter]int // arguments to the function
values map[ssa.Value]llvm.Value // local variables
} }
func NewCompiler(path, triple string) (*Compiler, error) { func NewCompiler(path, triple string) (*Compiler, error) {
c := &Compiler{ c := &Compiler{}
funcs: make(map[*ssa.Function]*Function),
}
target, err := llvm.GetTargetFromTriple(triple) target, err := llvm.GetTargetFromTriple(triple)
if err != nil { if err != nil {
@ -106,14 +105,17 @@ func (c *Compiler) Parse(path string) error {
} }
sort.Strings(memberNames) sort.Strings(memberNames)
frames := make(map[*ssa.Function]*Frame)
// First, build all function declarations. // First, build all function declarations.
for _, name := range memberNames { for _, name := range memberNames {
member := pkg.Members[name] member := pkg.Members[name]
if member, ok := member.(*ssa.Function); ok { if member, ok := member.(*ssa.Function); ok {
err := c.parseFuncDecl(pkg.Pkg.Name(), member) frame, err := c.parseFuncDecl(pkg.Pkg.Name(), member)
if err != nil { if err != nil {
return err return err
} }
frames[member] = frame
} }
} }
@ -126,7 +128,7 @@ func (c *Compiler) Parse(path string) error {
} }
switch member := member.(type) { switch member := member.(type) {
case *ssa.Function: case *ssa.Function:
err := c.parseFunc(pkg.Pkg.Name(), member) err := c.parseFunc(frames[member], member)
if err != nil { if err != nil {
return err return err
} }
@ -140,11 +142,13 @@ func (c *Compiler) Parse(path string) error {
return nil return nil
} }
func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) error { func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error) {
name := pkgName + "." + f.Name() name := pkgName + "." + f.Name()
c.funcs[f] = &Function{ frame := &Frame{
name: name, pkgName: pkgName,
params: make(map[*ssa.Parameter]int), name: name,
params: make(map[*ssa.Parameter]int),
values: make(map[ssa.Value]llvm.Value),
} }
var retType llvm.Type var retType llvm.Type
@ -160,13 +164,13 @@ func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) error {
case types.Int32: case types.Int32:
retType = llvm.Int32Type() retType = llvm.Int32Type()
default: default:
return errors.New("todo: unknown basic return type") return nil, errors.New("todo: unknown basic return type")
} }
default: default:
return errors.New("todo: unknown return type") return nil, errors.New("todo: unknown return type")
} }
} else { } else {
return errors.New("todo: return values") return nil, errors.New("todo: return values")
} }
var paramTypes []llvm.Type var paramTypes []llvm.Type
@ -180,25 +184,24 @@ func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) error {
case types.Int32: case types.Int32:
paramType = llvm.Int32Type() paramType = llvm.Int32Type()
default: default:
return errors.New("todo: unknown basic param type") return nil, errors.New("todo: unknown basic param type")
} }
paramTypes = append(paramTypes, paramType) paramTypes = append(paramTypes, paramType)
c.funcs[f].params[param] = i frame.params[param] = i
default: default:
return errors.New("todo: unknown param type") return nil, errors.New("todo: unknown param type")
} }
} }
fnType := llvm.FunctionType(retType, paramTypes, false) fnType := llvm.FunctionType(retType, paramTypes, false)
llvm.AddFunction(c.mod, name, fnType) llvm.AddFunction(c.mod, name, fnType)
return nil return frame, nil
} }
func (c *Compiler) parseFunc(pkgName string, f *ssa.Function) error { func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error {
fmt.Println("func:", f.Name()) fmt.Println("func:", f.Name())
fn := c.funcs[f] llvmFn := c.mod.NamedFunction(frame.name)
llvmFn := c.mod.NamedFunction(fn.name)
start := c.ctx.AddBasicBlock(llvmFn, "start") start := c.ctx.AddBasicBlock(llvmFn, "start")
c.builder.SetInsertPointAtEnd(start) c.builder.SetInsertPointAtEnd(start)
@ -206,7 +209,7 @@ func (c *Compiler) parseFunc(pkgName string, f *ssa.Function) error {
for _, block := range f.Blocks { for _, block := range f.Blocks {
for _, instr := range block.Instrs { for _, instr := range block.Instrs {
fmt.Printf(" instr: %v\n", instr) fmt.Printf(" instr: %v\n", instr)
err := c.parseInstr(pkgName, instr) err := c.parseInstr(frame, instr)
if err != nil { if err != nil {
return err return err
} }
@ -215,17 +218,17 @@ func (c *Compiler) parseFunc(pkgName string, f *ssa.Function) error {
return nil return nil
} }
func (c *Compiler) parseInstr(pkgName string, instr ssa.Instruction) error { func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error {
switch instr := instr.(type) { switch instr := instr.(type) {
case *ssa.Call: case *ssa.Call:
_, err := c.parseCall(pkgName, instr) _, err := c.parseCall(frame, instr)
return err return err
case *ssa.Return: case *ssa.Return:
if len(instr.Results) == 0 { if len(instr.Results) == 0 {
c.builder.CreateRetVoid() c.builder.CreateRetVoid()
return nil return nil
} else if len(instr.Results) == 1 { } else if len(instr.Results) == 1 {
val, err := c.parseExpr(pkgName, instr.Results[0]) val, err := c.parseExpr(frame, instr.Results[0])
if err != nil { if err != nil {
return err return err
} }
@ -234,15 +237,15 @@ func (c *Compiler) parseInstr(pkgName string, instr ssa.Instruction) error {
} else { } else {
return errors.New("todo: return value") return errors.New("todo: return value")
} }
case *ssa.BinOp: case ssa.Value:
_, err := c.parseBinOp(pkgName, instr) _, err := c.parseExpr(frame, instr)
return err return err
default: default:
return errors.New("unknown instruction: " + fmt.Sprintf("%#v", instr)) return errors.New("unknown instruction: " + fmt.Sprintf("%#v", instr))
} }
} }
func (c *Compiler) parseBuiltin(pkgName string, instr *ssa.CallCommon, call *ssa.Builtin) (llvm.Value, error) { func (c *Compiler) parseBuiltin(frame *Frame, instr *ssa.CallCommon, call *ssa.Builtin) (llvm.Value, error) {
fmt.Printf(" builtin: %#v\n", call) fmt.Printf(" builtin: %#v\n", call)
name := call.Name() name := call.Name()
@ -253,7 +256,7 @@ func (c *Compiler) parseBuiltin(pkgName string, instr *ssa.CallCommon, call *ssa
c.builder.CreateCall(c.printspaceFunc, nil, "") c.builder.CreateCall(c.printspaceFunc, nil, "")
} }
fmt.Printf(" arg: %s\n", arg); fmt.Printf(" arg: %s\n", arg);
expr, err := c.parseExpr(pkgName, arg) expr, err := c.parseExpr(frame, arg)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
@ -275,13 +278,13 @@ func (c *Compiler) parseBuiltin(pkgName string, instr *ssa.CallCommon, call *ssa
} }
} }
func (c *Compiler) parseFunctionCall(pkgName string, call *ssa.CallCommon, fn *ssa.Function) (llvm.Value, error) { func (c *Compiler) parseFunctionCall(frame *Frame, call *ssa.CallCommon, fn *ssa.Function) (llvm.Value, error) {
fmt.Printf(" function: %s\n", fn) fmt.Printf(" function: %s\n", fn)
name := fn.Name() name := fn.Name()
if strings.IndexByte(name, '.') == -1 { if strings.IndexByte(name, '.') == -1 {
// TODO: import path instead of pkgName // TODO: import path instead of pkgName
name = pkgName + "." + name name = frame.pkgName + "." + name
} }
target := c.mod.NamedFunction(name) target := c.mod.NamedFunction(name)
if target.IsNil() { if target.IsNil() {
@ -290,7 +293,7 @@ func (c *Compiler) parseFunctionCall(pkgName string, call *ssa.CallCommon, fn *s
var params []llvm.Value var params []llvm.Value
for _, param := range call.Args { for _, param := range call.Args {
val, err := c.parseExpr(pkgName, param) val, err := c.parseExpr(frame, param)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
@ -300,25 +303,25 @@ func (c *Compiler) parseFunctionCall(pkgName string, call *ssa.CallCommon, fn *s
return c.builder.CreateCall(target, params, ""), nil return c.builder.CreateCall(target, params, ""), nil
} }
func (c *Compiler) parseCall(pkgName string, instr *ssa.Call) (llvm.Value, error) { func (c *Compiler) parseCall(frame *Frame, instr *ssa.Call) (llvm.Value, error) {
fmt.Printf(" call: %s\n", instr) fmt.Printf(" call: %s\n", instr)
switch call := instr.Common().Value.(type) { switch call := instr.Common().Value.(type) {
case *ssa.Builtin: case *ssa.Builtin:
return c.parseBuiltin(pkgName, instr.Common(), call) return c.parseBuiltin(frame, instr.Common(), call)
case *ssa.Function: case *ssa.Function:
return c.parseFunctionCall(pkgName, instr.Common(), call) return c.parseFunctionCall(frame, instr.Common(), call)
default: default:
return llvm.Value{}, errors.New("todo: unknown call type: " + fmt.Sprintf("%#v", call)) return llvm.Value{}, errors.New("todo: unknown call type: " + fmt.Sprintf("%#v", call))
} }
} }
func (c *Compiler) parseBinOp(pkgName string, binop *ssa.BinOp) (llvm.Value, error) { func (c *Compiler) parseBinOp(frame *Frame, binop *ssa.BinOp) (llvm.Value, error) {
x, err := c.parseExpr(pkgName, binop.X) x, err := c.parseExpr(frame, binop.X)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
y, err := c.parseExpr(pkgName, binop.Y) y, err := c.parseExpr(frame, binop.Y)
if err != nil { if err != nil {
return llvm.Value{}, err return llvm.Value{}, err
} }
@ -332,7 +335,7 @@ func (c *Compiler) parseBinOp(pkgName string, binop *ssa.BinOp) (llvm.Value, err
} }
} }
func (c *Compiler) parseExpr(pkgName string, expr ssa.Value) (llvm.Value, error) { func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
fmt.Printf(" expr: %v\n", expr) fmt.Printf(" expr: %v\n", expr)
switch expr := expr.(type) { switch expr := expr.(type) {
case *ssa.Const: case *ssa.Const:
@ -353,13 +356,12 @@ func (c *Compiler) parseExpr(pkgName string, expr ssa.Value) (llvm.Value, error)
return llvm.Value{}, errors.New("todo: unknown constant") return llvm.Value{}, errors.New("todo: unknown constant")
} }
case *ssa.BinOp: case *ssa.BinOp:
return c.parseBinOp(pkgName, expr) return c.parseBinOp(frame, expr)
case *ssa.Call: case *ssa.Call:
return c.parseCall(pkgName, expr) return c.parseCall(frame, expr)
case *ssa.Parameter: case *ssa.Parameter:
fn := c.funcs[expr.Parent()] llvmFn := c.mod.NamedFunction(frame.name)
llvmFn := c.mod.NamedFunction(fn.name) return llvmFn.Param(frame.params[expr]), nil
return llvmFn.Param(fn.params[expr]), nil
default: default:
return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr)) return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr))
} }