Add Jump and Phi support, to enable things like for loops

Этот коммит содержится в:
Ayke van Laethem 2018-04-14 19:07:29 +02:00
родитель ad98a29a6f
коммит 63a545540d
2 изменённых файлов: 65 добавлений и 18 удалений

Просмотреть файл

@ -9,6 +9,7 @@ func main() {
println("5 ** 2 =", square(5)) println("5 ** 2 =", square(5))
println("3 + 12 =", add(3, 12)) println("3 + 12 =", add(3, 12))
println("fib(11) =", fib(11)) println("fib(11) =", fib(11))
println("sumrange(100) =", sumrange(100))
} }
func calculateAnswer() int { func calculateAnswer() int {
@ -28,6 +29,13 @@ func fib(n int) int {
if n <= 2 { if n <= 2 {
return 1 return 1
} }
ret := fib(n - 1) + fib(n - 2) return fib(n - 1) + fib(n - 2)
return ret }
func sumrange(n int) int {
sum := 0
for i := 1; i <= n; i++ {
sum += i
}
return sum
} }

71
tgo.go
Просмотреть файл

@ -47,6 +47,12 @@ type Frame struct {
params map[*ssa.Parameter]int // arguments to the function params map[*ssa.Parameter]int // arguments to the function
locals map[ssa.Value]llvm.Value // local variables locals map[ssa.Value]llvm.Value // local variables
blocks map[*ssa.BasicBlock]llvm.BasicBlock blocks map[*ssa.BasicBlock]llvm.BasicBlock
phis []Phi
}
type Phi struct {
ssa *ssa.Phi
llvm llvm.Value
} }
func NewCompiler(path, triple string) (*Compiler, error) { func NewCompiler(path, triple string) (*Compiler, error) {
@ -143,6 +149,22 @@ func (c *Compiler) Parse(path string) error {
return nil return nil
} }
func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) {
switch typ := goType.(type) {
case *types.Basic:
switch typ.Kind() {
case types.Int:
return c.intType, nil
case types.Int32:
return llvm.Int32Type(), nil
default:
return llvm.Type{}, errors.New("todo: unknown basic type")
}
default:
return llvm.Type{}, errors.New("todo: unknown type")
}
}
func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error) { func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error) {
name := pkgName + "." + f.Name() name := pkgName + "." + f.Name()
frame := &Frame{ frame := &Frame{
@ -157,19 +179,10 @@ func (c *Compiler) parseFuncDecl(pkgName string, f *ssa.Function) (*Frame, error
if f.Signature.Results() == nil { if f.Signature.Results() == nil {
retType = llvm.VoidType() retType = llvm.VoidType()
} else if f.Signature.Results().Len() == 1 { } else if f.Signature.Results().Len() == 1 {
result := f.Signature.Results().At(0) var err error
switch typ := result.Type().(type) { retType, err = c.getLLVMType(f.Signature.Results().At(0).Type())
case *types.Basic: if err != nil {
switch typ.Kind() { return nil, err
case types.Int:
retType = c.intType
case types.Int32:
retType = llvm.Int32Type()
default:
return nil, errors.New("todo: unknown basic return type")
}
default:
return nil, errors.New("todo: unknown return type")
} }
} else { } else {
return nil, errors.New("todo: return values") return nil, errors.New("todo: return values")
@ -205,13 +218,13 @@ func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error {
// Pre-create all basic blocks in the function. // Pre-create all basic blocks in the function.
llvmFn := c.mod.NamedFunction(frame.name) llvmFn := c.mod.NamedFunction(frame.name)
for _, block := range f.Blocks { for _, block := range f.DomPreorder() {
llvmBlock := c.ctx.AddBasicBlock(llvmFn, "block") llvmBlock := c.ctx.AddBasicBlock(llvmFn, block.Comment)
frame.blocks[block] = llvmBlock frame.blocks[block] = llvmBlock
} }
// Fill those blocks with instructions. // Fill those blocks with instructions.
for _, block := range f.Blocks { for _, block := range f.DomPreorder() {
c.builder.SetInsertPointAtEnd(frame.blocks[block]) c.builder.SetInsertPointAtEnd(frame.blocks[block])
for _, instr := range block.Instrs { for _, instr := range block.Instrs {
fmt.Printf(" instr: %v\n", instr) fmt.Printf(" instr: %v\n", instr)
@ -221,6 +234,20 @@ func (c *Compiler) parseFunc(frame *Frame, f *ssa.Function) error {
} }
} }
} }
// Resolve phi nodes
for _, phi := range frame.phis {
block := phi.ssa.Block()
for i, edge := range phi.ssa.Edges {
llvmVal, err := c.parseExpr(frame, edge)
if err != nil {
return err
}
llvmBlock := frame.blocks[block.Preds[i]]
phi.llvm.AddIncoming([]llvm.Value{llvmVal}, []llvm.BasicBlock{llvmBlock})
}
}
return nil return nil
} }
@ -240,6 +267,10 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error {
blockElse := frame.blocks[block.Succs[1]] blockElse := frame.blocks[block.Succs[1]]
c.builder.CreateCondBr(cond, blockThen, blockElse) c.builder.CreateCondBr(cond, blockThen, blockElse)
return nil return nil
case *ssa.Jump:
blockJump := frame.blocks[instr.Block().Succs[0]]
c.builder.CreateBr(blockJump)
return nil
case *ssa.Return: case *ssa.Return:
if len(instr.Results) == 0 { if len(instr.Results) == 0 {
c.builder.CreateRetVoid() c.builder.CreateRetVoid()
@ -415,6 +446,14 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
case *ssa.Parameter: case *ssa.Parameter:
llvmFn := c.mod.NamedFunction(frame.name) llvmFn := c.mod.NamedFunction(frame.name)
return llvmFn.Param(frame.params[expr]), nil return llvmFn.Param(frame.params[expr]), nil
case *ssa.Phi:
t, err := c.getLLVMType(expr.Type())
if err != nil {
return llvm.Value{}, err
}
phi := c.builder.CreatePHI(t, "")
frame.phis = append(frame.phis, Phi{expr, phi})
return phi, nil
default: default:
return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr)) return llvm.Value{}, errors.New("todo: unknown expression: " + fmt.Sprintf("%#v", expr))
} }