compiler: refactor top-level createInstruction function

Этот коммит содержится в:
Ayke van Laethem 2019-12-09 18:16:45 +01:00 коммит произвёл Ron Evans
родитель c1521fe12e
коммит ad992e2456
2 изменённых файлов: 64 добавлений и 62 удалений

Просмотреть файл

@ -1000,7 +1000,7 @@ func (c *Compiler) parseFunc(frame *Frame) {
fmt.Printf("\t%s\n", instr.String()) fmt.Printf("\t%s\n", instr.String())
} }
} }
c.parseInstr(frame, instr) frame.createInstruction(instr)
} }
if frame.fn.Name() == "init" && len(block.Instrs) == 0 { if frame.fn.Name() == "init" && len(block.Instrs) == 0 {
c.builder.CreateRetVoid() c.builder.CreateRetVoid()
@ -1018,69 +1018,71 @@ func (c *Compiler) parseFunc(frame *Frame) {
} }
} }
func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) { // createInstruction builds the LLVM IR equivalent instructions for the
if c.Debug() { // particular Go SSA instruction.
pos := c.ir.Program.Fset.Position(instr.Pos()) func (b *builder) createInstruction(instr ssa.Instruction) {
c.builder.SetCurrentDebugLocation(uint(pos.Line), uint(pos.Column), frame.difunc, llvm.Metadata{}) if b.Debug() {
pos := b.ir.Program.Fset.Position(instr.Pos())
b.SetCurrentDebugLocation(uint(pos.Line), uint(pos.Column), b.difunc, llvm.Metadata{})
} }
switch instr := instr.(type) { switch instr := instr.(type) {
case ssa.Value: case ssa.Value:
if value, err := frame.createExpr(instr); err != nil { if value, err := b.createExpr(instr); err != nil {
// This expression could not be parsed. Add the error to the list // This expression could not be parsed. Add the error to the list
// of diagnostics and continue with an undef value. // of diagnostics and continue with an undef value.
// The resulting IR will be incorrect (but valid). However, // The resulting IR will be incorrect (but valid). However,
// compilation can proceed which is useful because there may be // compilation can proceed which is useful because there may be
// more compilation errors which can then all be shown together to // more compilation errors which can then all be shown together to
// the user. // the user.
c.diagnostics = append(c.diagnostics, err) b.diagnostics = append(b.diagnostics, err)
frame.locals[instr] = llvm.Undef(c.getLLVMType(instr.Type())) b.locals[instr] = llvm.Undef(b.getLLVMType(instr.Type()))
} else { } else {
frame.locals[instr] = value b.locals[instr] = value
if len(*instr.Referrers()) != 0 && c.NeedsStackObjects() { if len(*instr.Referrers()) != 0 && b.NeedsStackObjects() {
c.trackExpr(frame, instr, value) b.trackExpr(instr, value)
} }
} }
case *ssa.DebugRef: case *ssa.DebugRef:
// ignore // ignore
case *ssa.Defer: case *ssa.Defer:
frame.createDefer(instr) b.createDefer(instr)
case *ssa.Go: case *ssa.Go:
// Get all function parameters to pass to the goroutine. // Get all function parameters to pass to the goroutine.
var params []llvm.Value var params []llvm.Value
for _, param := range instr.Call.Args { for _, param := range instr.Call.Args {
params = append(params, frame.getValue(param)) params = append(params, b.getValue(param))
} }
// Start a new goroutine. // Start a new goroutine.
if callee := instr.Call.StaticCallee(); callee != nil { if callee := instr.Call.StaticCallee(); callee != nil {
// Static callee is known. This makes it easier to start a new // Static callee is known. This makes it easier to start a new
// goroutine. // goroutine.
calleeFn := c.ir.GetFunction(callee) calleeFn := b.ir.GetFunction(callee)
var context llvm.Value var context llvm.Value
switch value := instr.Call.Value.(type) { switch value := instr.Call.Value.(type) {
case *ssa.Function: case *ssa.Function:
// Goroutine call is regular function call. No context is necessary. // Goroutine call is regular function call. No context is necessary.
context = llvm.Undef(c.i8ptrType) context = llvm.Undef(b.i8ptrType)
case *ssa.MakeClosure: case *ssa.MakeClosure:
// A goroutine call on a func value, but the callee is trivial to find. For // A goroutine call on a func value, but the callee is trivial to find. For
// example: immediately applied functions. // example: immediately applied functions.
funcValue := frame.getValue(value) funcValue := b.getValue(value)
context = frame.extractFuncContext(funcValue) context = b.extractFuncContext(funcValue)
default: default:
panic("StaticCallee returned an unexpected value") panic("StaticCallee returned an unexpected value")
} }
params = append(params, context) // context parameter params = append(params, context) // context parameter
frame.createGoInstruction(calleeFn.LLVMFn, params) b.createGoInstruction(calleeFn.LLVMFn, params)
} else if !instr.Call.IsInvoke() { } else if !instr.Call.IsInvoke() {
// This is a function pointer. // This is a function pointer.
// At the moment, two extra params are passed to the newly started // At the moment, two extra params are passed to the newly started
// goroutine: // goroutine:
// * The function context, for closures. // * The function context, for closures.
// * The function pointer (for tasks). // * The function pointer (for tasks).
funcPtr, context := frame.decodeFuncValue(frame.getValue(instr.Call.Value), instr.Call.Value.Type().(*types.Signature)) funcPtr, context := b.decodeFuncValue(b.getValue(instr.Call.Value), instr.Call.Value.Type().(*types.Signature))
params = append(params, context) // context parameter params = append(params, context) // context parameter
switch c.Scheduler() { switch b.Scheduler() {
case "none", "coroutines": case "none", "coroutines":
// There are no additional parameters needed for the goroutine start operation. // There are no additional parameters needed for the goroutine start operation.
case "tasks": case "tasks":
@ -1089,58 +1091,58 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) {
default: default:
panic("unknown scheduler type") panic("unknown scheduler type")
} }
frame.createGoInstruction(funcPtr, params) b.createGoInstruction(funcPtr, params)
} else { } else {
c.addError(instr.Pos(), "todo: go on interface call") b.addError(instr.Pos(), "todo: go on interface call")
} }
case *ssa.If: case *ssa.If:
cond := frame.getValue(instr.Cond) cond := b.getValue(instr.Cond)
block := instr.Block() block := instr.Block()
blockThen := frame.blockEntries[block.Succs[0]] blockThen := b.blockEntries[block.Succs[0]]
blockElse := frame.blockEntries[block.Succs[1]] blockElse := b.blockEntries[block.Succs[1]]
c.builder.CreateCondBr(cond, blockThen, blockElse) b.CreateCondBr(cond, blockThen, blockElse)
case *ssa.Jump: case *ssa.Jump:
blockJump := frame.blockEntries[instr.Block().Succs[0]] blockJump := b.blockEntries[instr.Block().Succs[0]]
c.builder.CreateBr(blockJump) b.CreateBr(blockJump)
case *ssa.MapUpdate: case *ssa.MapUpdate:
m := frame.getValue(instr.Map) m := b.getValue(instr.Map)
key := frame.getValue(instr.Key) key := b.getValue(instr.Key)
value := frame.getValue(instr.Value) value := b.getValue(instr.Value)
mapType := instr.Map.Type().Underlying().(*types.Map) mapType := instr.Map.Type().Underlying().(*types.Map)
frame.createMapUpdate(mapType.Key(), m, key, value, instr.Pos()) b.createMapUpdate(mapType.Key(), m, key, value, instr.Pos())
case *ssa.Panic: case *ssa.Panic:
value := frame.getValue(instr.X) value := b.getValue(instr.X)
c.createRuntimeCall("_panic", []llvm.Value{value}, "") b.createRuntimeCall("_panic", []llvm.Value{value}, "")
c.builder.CreateUnreachable() b.CreateUnreachable()
case *ssa.Return: case *ssa.Return:
if len(instr.Results) == 0 { if len(instr.Results) == 0 {
c.builder.CreateRetVoid() b.CreateRetVoid()
} else if len(instr.Results) == 1 { } else if len(instr.Results) == 1 {
c.builder.CreateRet(frame.getValue(instr.Results[0])) b.CreateRet(b.getValue(instr.Results[0]))
} else { } else {
// Multiple return values. Put them all in a struct. // Multiple return values. Put them all in a struct.
retVal := llvm.ConstNull(frame.fn.LLVMFn.Type().ElementType().ReturnType()) retVal := llvm.ConstNull(b.fn.LLVMFn.Type().ElementType().ReturnType())
for i, result := range instr.Results { for i, result := range instr.Results {
val := frame.getValue(result) val := b.getValue(result)
retVal = c.builder.CreateInsertValue(retVal, val, i, "") retVal = b.CreateInsertValue(retVal, val, i, "")
} }
c.builder.CreateRet(retVal) b.CreateRet(retVal)
} }
case *ssa.RunDefers: case *ssa.RunDefers:
frame.createRunDefers() b.createRunDefers()
case *ssa.Send: case *ssa.Send:
frame.createChanSend(instr) b.createChanSend(instr)
case *ssa.Store: case *ssa.Store:
llvmAddr := frame.getValue(instr.Addr) llvmAddr := b.getValue(instr.Addr)
llvmVal := frame.getValue(instr.Val) llvmVal := b.getValue(instr.Val)
frame.createNilCheck(llvmAddr, "store") b.createNilCheck(llvmAddr, "store")
if c.targetData.TypeAllocSize(llvmVal.Type()) == 0 { if b.targetData.TypeAllocSize(llvmVal.Type()) == 0 {
// nothing to store // nothing to store
return return
} }
c.builder.CreateStore(llvmVal, llvmAddr) b.CreateStore(llvmVal, llvmAddr)
default: default:
c.addError(instr.Pos(), "unknown instruction: "+instr.String()) b.addError(instr.Pos(), "unknown instruction: "+instr.String())
} }
} }

Просмотреть файл

@ -12,53 +12,53 @@ import (
// trackExpr inserts pointer tracking intrinsics for the GC if the expression is // trackExpr inserts pointer tracking intrinsics for the GC if the expression is
// one of the expressions that need this. // one of the expressions that need this.
func (c *Compiler) trackExpr(frame *Frame, expr ssa.Value, value llvm.Value) { func (b *builder) trackExpr(expr ssa.Value, value llvm.Value) {
// There are uses of this expression, Make sure the pointers // There are uses of this expression, Make sure the pointers
// are tracked during GC. // are tracked during GC.
switch expr := expr.(type) { switch expr := expr.(type) {
case *ssa.Alloc, *ssa.MakeChan, *ssa.MakeMap: case *ssa.Alloc, *ssa.MakeChan, *ssa.MakeMap:
// These values are always of pointer type in IR. // These values are always of pointer type in IR.
c.trackPointer(value) b.trackPointer(value)
case *ssa.Call, *ssa.Convert, *ssa.MakeClosure, *ssa.MakeInterface, *ssa.MakeSlice, *ssa.Next: case *ssa.Call, *ssa.Convert, *ssa.MakeClosure, *ssa.MakeInterface, *ssa.MakeSlice, *ssa.Next:
if !value.IsNil() { if !value.IsNil() {
c.trackValue(value) b.trackValue(value)
} }
case *ssa.Select: case *ssa.Select:
if alloca, ok := frame.selectRecvBuf[expr]; ok { if alloca, ok := b.selectRecvBuf[expr]; ok {
if alloca.IsAUndefValue().IsNil() { if alloca.IsAUndefValue().IsNil() {
c.trackPointer(alloca) b.trackPointer(alloca)
} }
} }
case *ssa.UnOp: case *ssa.UnOp:
switch expr.Op { switch expr.Op {
case token.MUL: case token.MUL:
// Pointer dereference. // Pointer dereference.
c.trackValue(value) b.trackValue(value)
case token.ARROW: case token.ARROW:
// Channel receive operator. // Channel receive operator.
// It's not necessary to look at commaOk here, because in that // It's not necessary to look at commaOk here, because in that
// case it's just an aggregate and trackValue will extract the // case it's just an aggregate and trackValue will extract the
// pointer in there (if there is one). // pointer in there (if there is one).
c.trackValue(value) b.trackValue(value)
} }
} }
} }
// trackValue locates pointers in a value (possibly an aggregate) and tracks the // trackValue locates pointers in a value (possibly an aggregate) and tracks the
// individual pointers // individual pointers
func (c *Compiler) trackValue(value llvm.Value) { func (b *builder) trackValue(value llvm.Value) {
typ := value.Type() typ := value.Type()
switch typ.TypeKind() { switch typ.TypeKind() {
case llvm.PointerTypeKind: case llvm.PointerTypeKind:
c.trackPointer(value) b.trackPointer(value)
case llvm.StructTypeKind: case llvm.StructTypeKind:
if !typeHasPointers(typ) { if !typeHasPointers(typ) {
return return
} }
numElements := typ.StructElementTypesCount() numElements := typ.StructElementTypesCount()
for i := 0; i < numElements; i++ { for i := 0; i < numElements; i++ {
subValue := c.builder.CreateExtractValue(value, i, "") subValue := b.CreateExtractValue(value, i, "")
c.trackValue(subValue) b.trackValue(subValue)
} }
case llvm.ArrayTypeKind: case llvm.ArrayTypeKind:
if !typeHasPointers(typ) { if !typeHasPointers(typ) {
@ -66,8 +66,8 @@ func (c *Compiler) trackValue(value llvm.Value) {
} }
numElements := typ.ArrayLength() numElements := typ.ArrayLength()
for i := 0; i < numElements; i++ { for i := 0; i < numElements; i++ {
subValue := c.builder.CreateExtractValue(value, i, "") subValue := b.CreateExtractValue(value, i, "")
c.trackValue(subValue) b.trackValue(subValue)
} }
} }
} }