diff --git a/arm.ld b/arm.ld index c3c1e266..36a20982 100644 --- a/arm.ld +++ b/arm.ld @@ -75,3 +75,8 @@ SECTIONS __etext = _etext; __data_start__ = _sdata; __bss_start__ = _sbss; + +/* For the memory allocator. */ +_heap_start = _ebss; +_heap_end = ORIGIN(RAM) + LENGTH(RAM); +runtime.heapptr = _heap_start; /* necessary? */ diff --git a/src/runtime/gc.go b/src/runtime/gc.go new file mode 100644 index 00000000..7596af82 --- /dev/null +++ b/src/runtime/gc.go @@ -0,0 +1,31 @@ + +// +build !linux + +package runtime + +import ( + "unsafe" +) + +var ( + _extern__heap_start unsafe.Pointer // defined by the linker + heapptr = uintptr(unsafe.Pointer(&_extern__heap_start)) +) + +func alloc(size uintptr) unsafe.Pointer { + // TODO: this can be optimized by not casting between pointers and ints so + // much. And by using platform-native data types (e.g. *uint8 for 8-bit + // systems). + size = align(size) + addr := heapptr + heapptr += size + for i := uintptr(0); i < uintptr(size); i += 4 { + ptr := (*uint32)(unsafe.Pointer(addr + i)) + *ptr = 0 + } + return unsafe.Pointer(addr) +} + +func free(ptr unsafe.Pointer) { + // TODO: use a GC +} diff --git a/src/runtime/runtime_unix.go b/src/runtime/runtime_unix.go index 9f9fe9db..5ba18fef 100644 --- a/src/runtime/runtime_unix.go +++ b/src/runtime/runtime_unix.go @@ -3,6 +3,10 @@ package runtime +import ( + "unsafe" +) + // #include // #include // #include @@ -21,3 +25,15 @@ func Sleep(d Duration) { func abort() { C.abort() } + +func alloc(size uintptr) unsafe.Pointer { + buf := C.calloc(1, C.size_t(size)) + if buf == nil { + panic("cannot allocate memory") + } + return buf +} + +func free(ptr unsafe.Pointer) { + C.free(ptr) +} diff --git a/tgo.go b/tgo.go index 26b2a785..8f035b1f 100644 --- a/tgo.go +++ b/tgo.go @@ -35,11 +35,14 @@ type Compiler struct { machine llvm.TargetMachine targetData llvm.TargetData intType llvm.Type + i8ptrType llvm.Type // for convenience uintptrType llvm.Type stringLenType llvm.Type stringType llvm.Type interfaceType llvm.Type typeassertType llvm.Type + allocFunc llvm.Value + freeFunc llvm.Value itfTypeNumbers map[types.Type]uint64 itfTypes []types.Type initFuncs []llvm.Value @@ -86,15 +89,22 @@ func NewCompiler(pkgName, triple string) (*Compiler, error) { c.intType = llvm.Int32Type() c.stringLenType = llvm.Int32Type() c.uintptrType = c.targetData.IntPtrType() + c.i8ptrType = llvm.PointerType(llvm.Int8Type(), 0) // Go string: tuple of (len, ptr) - c.stringType = llvm.StructType([]llvm.Type{c.stringLenType, llvm.PointerType(llvm.Int8Type(), 0)}, false) + c.stringType = llvm.StructType([]llvm.Type{c.stringLenType, c.i8ptrType}, false) // Go interface: tuple of (type, ptr) - c.interfaceType = llvm.StructType([]llvm.Type{llvm.Int32Type(), llvm.PointerType(llvm.Int8Type(), 0)}, false) + c.interfaceType = llvm.StructType([]llvm.Type{llvm.Int32Type(), c.i8ptrType}, false) // Go typeassert result: tuple of (ptr, bool) - c.typeassertType = llvm.StructType([]llvm.Type{llvm.PointerType(llvm.Int8Type(), 0), llvm.Int1Type()}, false) + c.typeassertType = llvm.StructType([]llvm.Type{c.i8ptrType, llvm.Int1Type()}, false) + + allocType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.uintptrType}, false) + c.allocFunc = llvm.AddFunction(c.mod, "runtime.alloc", allocType) + + freeType := llvm.FunctionType(llvm.VoidType(), []llvm.Type{c.i8ptrType}, false) + c.freeFunc = llvm.AddFunction(c.mod, "runtime.free", freeType) return c, nil } @@ -224,7 +234,7 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { case types.Uintptr: return c.uintptrType, nil case types.UnsafePointer: - return llvm.PointerType(llvm.Int8Type(), 0), nil + return c.i8ptrType, nil default: return llvm.Type{}, errors.New("todo: unknown basic type: " + fmt.Sprintf("%#v", typ)) } @@ -295,6 +305,16 @@ func (c *Compiler) getInterfaceType(typ types.Type) llvm.Value { return llvm.ConstInt(llvm.Int32Type(), c.itfTypeNumbers[typ], false) } +func (c *Compiler) isPointer(typ types.Type) bool { + if _, ok := typ.(*types.Pointer); ok { + return true + } else if typ, ok := typ.(*types.Basic); ok && typ.Kind() == types.UnsafePointer { + return true + } else { + return false + } +} + func (c *Compiler) getFunctionName(fn *ssa.Function) string { if fn.Signature.Recv() != nil { // Method on a defined type. @@ -311,6 +331,14 @@ func (c *Compiler) getFunctionName(fn *ssa.Function) string { } } +func (c *Compiler) getGlobalName(global *ssa.Global) string { + if strings.HasPrefix(global.Name(), "_extern_") { + return global.Name()[len("_extern_"):] + } else { + return pkgPrefix(global.Pkg) + "." + global.Name() + } +} + func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error { fmt.Println("\npackage:", pkg.Pkg.Path()) @@ -359,13 +387,15 @@ func (c *Compiler) parsePackage(program *ssa.Program, pkg *ssa.Package) error { if err != nil { return err } - global := llvm.AddGlobal(c.mod, llvmType, pkgPrefix(member.Pkg) + "." + member.Name()) - global.SetLinkage(llvm.PrivateLinkage) - initializer, err := c.getZeroValue(llvmType) - if err != nil { - return err + global := llvm.AddGlobal(c.mod, llvmType, c.getGlobalName(member)) + if !strings.HasPrefix(member.Name(), "_extern_") { + global.SetLinkage(llvm.PrivateLinkage) + initializer, err := c.getZeroValue(llvmType) + if err != nil { + return err + } + global.SetInitializer(initializer) } - global.SetInitializer(initializer) case *ssa.Type: ms := program.MethodSets.MethodSet(member.Type()) for i := 0; i < ms.Len(); i++ { @@ -489,8 +519,7 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { if err != nil { return err } - fullName := pkgPrefix(addr.Pkg) + "." + addr.Name() - llvmAddr := c.mod.NamedGlobal(fullName) + llvmAddr := c.mod.NamedGlobal(c.getGlobalName(addr)) llvmAddr.SetInitializer(val) case *ssa.FieldAddr: // Initialize field of a global struct. @@ -503,8 +532,14 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { return err } global := addr.X.(*ssa.Global) - llvmAddr := c.mod.NamedGlobal(pkgPrefix(global.Pkg) + "." + global.Name()) + llvmAddr := c.mod.NamedGlobal(c.getGlobalName(global)) llvmValue := llvmAddr.Initializer() + if llvmValue.IsNil() { + llvmValue, err = c.getZeroValue(llvmAddr.Type().ElementType()) + if err != nil { + return err + } + } llvmValue = c.builder.CreateInsertValue(llvmValue, val, addr.Field, "") llvmAddr.SetInitializer(llvmValue) case *ssa.IndexAddr: @@ -519,8 +554,14 @@ func (c *Compiler) parseInitFunc(frame *Frame, f *ssa.Function) error { } fieldAddr := addr.X.(*ssa.FieldAddr) global := fieldAddr.X.(*ssa.Global) - llvmAddr := c.mod.NamedGlobal(pkgPrefix(global.Pkg) + "." + global.Name()) - llvmValue := c.mod.NamedGlobal(pkgPrefix(global.Pkg) + "." + global.Name()).Initializer() + llvmAddr := c.mod.NamedGlobal(c.getGlobalName(global)) + llvmValue := llvmAddr.Initializer() + if llvmValue.IsNil() { + llvmValue, err = c.getZeroValue(llvmAddr.Type().ElementType()) + if err != nil { + return err + } + } llvmFieldValue := c.builder.CreateExtractValue(llvmValue, fieldAddr.Field, "") llvmFieldValue = c.builder.CreateInsertValue(llvmFieldValue, val, int(index), "") llvmValue = c.builder.CreateInsertValue(llvmValue, llvmFieldValue, fieldAddr.Field, "") @@ -774,18 +815,17 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { var buf llvm.Value if expr.Heap { // TODO: escape analysis - buf = c.builder.CreateMalloc(typ, expr.Comment) + size := llvm.ConstInt(c.uintptrType, c.targetData.TypeAllocSize(typ), false) + buf = c.builder.CreateCall(c.allocFunc, []llvm.Value{size}, expr.Comment) + buf = c.builder.CreateBitCast(buf, llvm.PointerType(typ, 0), "") } else { buf = c.builder.CreateAlloca(typ, expr.Comment) + zero, err := c.getZeroValue(typ) + if err != nil { + return llvm.Value{}, err + } + c.builder.CreateStore(zero, buf) // zero-initialize var } - if err != nil { - return llvm.Value{}, err - } - zero, err := c.getZeroValue(typ) - if err != nil { - return llvm.Value{}, err - } - c.builder.CreateStore(zero, buf) // zero-initialize var return buf, nil case *ssa.BinOp: return c.parseBinOp(frame, expr) @@ -815,7 +855,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { } return c.builder.CreateGEP(val, indices, ""), nil case *ssa.Global: - fullName := pkgPrefix(expr.Pkg) + "." + expr.Name() + fullName := c.getGlobalName(expr) value := c.mod.NamedGlobal(fullName) if value.IsNil() { return llvm.Value{}, errors.New("global not found: " + fullName) @@ -905,14 +945,15 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { var itfValue llvm.Value switch typ := expr.X.Type().(type) { case *types.Basic: - itfValueType := llvm.PointerType(llvm.Int8Type(), 0) if typ.Info() & types.IsInteger != 0 { // TODO: 64-bit int on 32-bit platform - itfValue = c.builder.CreateIntToPtr(val, itfValueType, "") + itfValue = c.builder.CreateIntToPtr(val, c.i8ptrType, "") } else if typ.Kind() == types.String { // TODO: escape analysis - itfValue = c.builder.CreateMalloc(c.stringType, "") - c.builder.CreateStore(val, itfValue) - itfValue = c.builder.CreateBitCast(itfValue, itfValueType, "") + size := c.targetData.TypeAllocSize(c.stringType) + sizeValue := llvm.ConstInt(c.uintptrType, size, false) + itfValue = c.builder.CreateCall(c.allocFunc, []llvm.Value{sizeValue}, "") + itfValueCast := c.builder.CreateBitCast(itfValue, llvm.PointerType(c.stringType, 0), "") + c.builder.CreateStore(val, itfValueCast) } else { return llvm.Value{}, errors.New("todo: make interface: unknown basic type") } @@ -920,7 +961,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { return llvm.Value{}, errors.New("todo: make interface: unknown type") } itfType := c.getInterfaceType(expr.X.Type()) - itf := c.ctx.ConstStruct([]llvm.Value{itfType, llvm.Undef(llvm.PointerType(llvm.Int8Type(), 0))}, false) + itf := c.ctx.ConstStruct([]llvm.Value{itfType, llvm.Undef(c.i8ptrType)}, false) itf = c.builder.CreateInsertValue(itf, itfValue, 1, "") return itf, nil case *ssa.Phi: @@ -1057,34 +1098,34 @@ func (c *Compiler) parseBinOp(frame *Frame, binop *ssa.BinOp) (llvm.Value, error } func (c *Compiler) parseConst(expr *ssa.Const) (llvm.Value, error) { - switch expr.Value.Kind() { - case constant.Bool, constant.Int: - return c.parseConstInt(expr, expr.Type()) - case constant.String: - str := constant.StringVal(expr.Value) - strLen := llvm.ConstInt(c.stringLenType, uint64(len(str)), false) - strPtr := c.builder.CreateGlobalStringPtr(str, ".str") // TODO: remove \0 at end - strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false) - return strObj, nil - default: - return llvm.Value{}, errors.New("todo: unknown constant: " + fmt.Sprintf("%#v", expr.Value.Kind())) + typ := expr.Type() + if named, ok := typ.(*types.Named); ok { + typ = named.Underlying() } -} - -func (c *Compiler) parseConstInt(expr *ssa.Const, typ types.Type) (llvm.Value, error) { switch typ := typ.(type) { case *types.Basic: llvmType, err := c.getLLVMType(typ) if err != nil { return llvm.Value{}, err } - if typ.Info() & types.IsBoolean != 0 { + if typ.Kind() == types.Bool { b := constant.BoolVal(expr.Value) n := uint64(0) if b { n = 1 } return llvm.ConstInt(llvmType, n, false), nil + } else if typ.Kind() == types.String { + str := constant.StringVal(expr.Value) + strLen := llvm.ConstInt(c.stringLenType, uint64(len(str)), false) + strPtr := c.builder.CreateGlobalStringPtr(str, ".str") // TODO: remove \0 at end + strObj := llvm.ConstStruct([]llvm.Value{strLen, strPtr}, false) + return strObj, nil + } else if typ.Kind() == types.UnsafePointer { + if !expr.IsNil() { + return llvm.Value{}, errors.New("todo: non-null constant pointer") + } + return llvm.ConstNull(c.i8ptrType), nil } else if typ.Info() & types.IsUnsigned != 0 { n, _ := constant.Uint64Val(expr.Value) return llvm.ConstInt(llvmType, n, false), nil @@ -1092,34 +1133,40 @@ func (c *Compiler) parseConstInt(expr *ssa.Const, typ types.Type) (llvm.Value, e n, _ := constant.Int64Val(expr.Value) return llvm.ConstInt(llvmType, uint64(n), true), nil } else { - return llvm.Value{}, errors.New("unknown integer constant") + return llvm.Value{}, errors.New("todo: unknown constant: " + fmt.Sprintf("%v", typ)) } - case *types.Named: - return c.parseConstInt(expr, typ.Underlying()) default: return llvm.Value{}, errors.New("todo: unknown constant: " + fmt.Sprintf("%#v", typ)) } } func (c *Compiler) parseConvert(frame *Frame, typeTo types.Type, x ssa.Value) (llvm.Value, error) { + value, err := c.parseExpr(frame, x) + if err != nil { + return value, nil + } + + llvmTypeFrom, err := c.getLLVMType(x.Type()) + if err != nil { + return llvm.Value{}, err + } + llvmTypeTo, err := c.getLLVMType(typeTo) + if err != nil { + return llvm.Value{}, err + } + switch typeTo := typeTo.(type) { case *types.Basic: - value, err := c.parseExpr(frame, x) - if err != nil { - return value, nil + isPtrFrom := c.isPointer(x.Type()) + isPtrTo := c.isPointer(typeTo) + if isPtrFrom && !isPtrTo { + return c.builder.CreatePtrToInt(value, llvmTypeTo, ""), nil + } else if !isPtrFrom && isPtrTo { + return c.builder.CreateIntToPtr(value, llvmTypeTo, ""), nil } - llvmTypeFrom, err := c.getLLVMType(x.Type()) - if err != nil { - return llvm.Value{}, err - } sizeFrom := c.targetData.TypeAllocSize(llvmTypeFrom) - llvmTypeTo, err := c.getLLVMType(typeTo) - if err != nil { - return llvm.Value{}, err - } sizeTo := c.targetData.TypeAllocSize(llvmTypeTo) - if sizeFrom == sizeTo { return c.builder.CreateBitCast(value, llvmTypeTo, ""), nil } @@ -1137,6 +1184,8 @@ func (c *Compiler) parseConvert(frame *Frame, typeTo types.Type, x ssa.Value) (l } case *types.Named: return c.parseConvert(frame, typeTo.Underlying(), x) + case *types.Pointer: + return c.builder.CreateBitCast(value, llvmTypeTo, ""), nil default: return llvm.Value{}, errors.New("todo: convert: extend non-basic type: " + fmt.Sprintf("%#v", typeTo)) } @@ -1158,7 +1207,7 @@ func (c *Compiler) parseUnOp(frame *Frame, unop *ssa.UnOp) (llvm.Value, error) { // Magic type name: treat the value as a register pointer. register := unop.X.(*ssa.FieldAddr) global := register.X.(*ssa.Global) - llvmGlobal := c.mod.NamedGlobal(pkgPrefix(global.Pkg) + "." + global.Name()) + llvmGlobal := c.mod.NamedGlobal(c.getGlobalName(global)) llvmAddr := c.builder.CreateExtractValue(llvmGlobal.Initializer(), register.Field, "") ptr := llvm.ConstIntToPtr(llvmAddr, x.Type()) load := c.builder.CreateLoad(ptr, "")