diff --git a/transform/allocs_test.go b/transform/allocs_test.go index e6327827..f4c86f66 100644 --- a/transform/allocs_test.go +++ b/transform/allocs_test.go @@ -2,7 +2,6 @@ package transform_test import ( "go/token" - "go/types" "io/ioutil" "path/filepath" "regexp" @@ -11,9 +10,6 @@ import ( "strings" "testing" - "github.com/tinygo-org/tinygo/compileopts" - "github.com/tinygo-org/tinygo/compiler" - "github.com/tinygo-org/tinygo/loader" "github.com/tinygo-org/tinygo/transform" "tinygo.org/x/go-llvm" ) @@ -39,52 +35,7 @@ func (out allocsTestOutput) String() string { func TestAllocs2(t *testing.T) { t.Parallel() - target, err := compileopts.LoadTarget("i686--linux") - if err != nil { - t.Fatal("failed to load target:", err) - } - config := &compileopts.Config{ - Options: &compileopts.Options{}, - Target: target, - } - compilerConfig := &compiler.Config{ - Triple: config.Triple(), - GOOS: config.GOOS(), - GOARCH: config.GOARCH(), - CodeModel: config.CodeModel(), - RelocationModel: config.RelocationModel(), - Scheduler: config.Scheduler(), - FuncImplementation: config.FuncImplementation(), - AutomaticStackSize: config.AutomaticStackSize(), - Debug: true, - } - machine, err := compiler.NewTargetMachine(compilerConfig) - if err != nil { - t.Fatal("failed to create target machine:", err) - } - - // Load entire program AST into memory. - lprogram, err := loader.Load(config, []string{"./testdata/allocs2.go"}, config.ClangHeaders, types.Config{ - Sizes: compiler.Sizes(machine), - }) - if err != nil { - t.Fatal("failed to create target machine:", err) - } - err = lprogram.Parse() - if err != nil { - t.Fatal("could not parse", err) - } - - // Compile AST to IR. - program := lprogram.LoadSSA() - pkg := lprogram.MainPkg() - mod, errs := compiler.CompilePackage("allocs2.go", pkg, program.Package(pkg.Pkg), machine, compilerConfig, false) - if errs != nil { - for _, err := range errs { - t.Error(err) - } - return - } + mod := compileGoFileForTesting(t, "./testdata/allocs2.go") // Run functionattrs pass, which is necessary for escape analysis. pm := llvm.NewPassManager() diff --git a/transform/interface-lowering.go b/transform/interface-lowering.go index c39765f6..79b3361a 100644 --- a/transform/interface-lowering.go +++ b/transform/interface-lowering.go @@ -74,12 +74,10 @@ type methodInfo struct { // typeInfo describes a single concrete Go type, which can be a basic or a named // type. If it is a named type, it may have methods. type typeInfo struct { - name string - typecode llvm.Value - methodSet llvm.Value - num uint64 // the type number after lowering - countTypeAsserts int // how often a type assert happens on this method - methods []*methodInfo + name string + typecode llvm.Value + methodSet llvm.Value + methods []*methodInfo } // getMethod looks up the method on this type with the given signature and @@ -94,27 +92,13 @@ func (t *typeInfo) getMethod(signature *signatureInfo) *methodInfo { panic("could not find method") } -// typeInfoSlice implements sort.Slice, sorting the most commonly used types -// first. -type typeInfoSlice []*typeInfo - -func (t typeInfoSlice) Len() int { return len(t) } -func (t typeInfoSlice) Less(i, j int) bool { - // Try to sort the most commonly used types first. - if t[i].countTypeAsserts != t[j].countTypeAsserts { - return t[i].countTypeAsserts < t[j].countTypeAsserts - } - return t[i].name < t[j].name -} -func (t typeInfoSlice) Swap(i, j int) { t[i], t[j] = t[j], t[i] } - // interfaceInfo keeps information about a Go interface type, including all // methods it has. type interfaceInfo struct { name string // name with $interface suffix methodSet llvm.Value // global which this interfaceInfo describes signatures []*signatureInfo // method set - types typeInfoSlice // types this interface implements + types []*typeInfo // types this interface implements assertFunc llvm.Value // runtime.interfaceImplements replacement methodFuncs map[*signatureInfo]llvm.Value // runtime.interfaceMethod replacements for each signature } @@ -163,7 +147,6 @@ func LowerInterfaces(mod llvm.Module, sizeLevel int) error { // run runs the pass itself. func (p *lowerInterfacesPass) run() error { // Collect all type codes. - var typecodeIDs []llvm.Value for global := p.mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) { if strings.HasPrefix(global.Name(), "reflect/types.type:") { // Retrieve Go type information based on an opaque global variable. @@ -171,7 +154,6 @@ func (p *lowerInterfacesPass) run() error { // discarded afterwards. name := strings.TrimPrefix(global.Name(), "reflect/types.type:") if _, ok := p.types[name]; !ok { - typecodeIDs = append(typecodeIDs, global) t := &typeInfo{ name: name, typecode: global, @@ -187,18 +169,6 @@ func (p *lowerInterfacesPass) run() error { } } - // Count per type how often it is type asserted on (e.g. in a switch - // statement). - typeAssert := p.mod.NamedFunction("runtime.typeAssert") - typeAssertUses := getUses(typeAssert) - for _, use := range typeAssertUses { - typecode := use.Operand(1) - name := strings.TrimPrefix(typecode.Name(), "reflect/types.typeid:") - if t, ok := p.types[name]; ok { - t.countTypeAsserts++ - } - } - // Find all interface method calls. interfaceMethod := p.mod.NamedFunction("runtime.interfaceMethod") interfaceMethodUses := getUses(interfaceMethod) @@ -274,10 +244,11 @@ func (p *lowerInterfacesPass) run() error { } } - // Sort all types added to the interfaces, to check for more common types - // first. + // Sort all types added to the interfaces. for _, itf := range p.interfaces { - sort.Sort(itf.types) + sort.Slice(itf.types, func(i, j int) bool { + return itf.types[i].name > itf.types[j].name + }) } // Replace all interface methods with their uses, if possible. @@ -339,43 +310,10 @@ func (p *lowerInterfacesPass) run() error { use.EraseFromParentAsInstruction() } - // Make a slice of types sorted by frequency of use. - typeSlice := make(typeInfoSlice, 0, len(p.types)) - for _, t := range p.types { - typeSlice = append(typeSlice, t) - } - sort.Sort(sort.Reverse(typeSlice)) - - // Assign a type code for each type. - assignTypeCodes(p.mod, typeSlice) - - // Replace each use of a ptrtoint runtime.typecodeID with the constant type - // code. - for _, global := range typecodeIDs { - for _, use := range getUses(global) { - if use.IsAConstantExpr().IsNil() { - continue - } - t := p.types[strings.TrimPrefix(global.Name(), "reflect/types.type:")] - typecode := llvm.ConstInt(p.uintptrType, t.num, false) - switch use.Opcode() { - case llvm.PtrToInt: - // Already of the correct type. - case llvm.BitCast: - // Could happen when stored in an interface (which is of type - // i8*). - typecode = llvm.ConstIntToPtr(typecode, use.Type()) - default: - panic("unexpected constant expression") - } - use.ReplaceAllUsesWith(typecode) - } - } - // Replace each type assert with an actual type comparison or (if the type // assert is impossible) the constant false. llvmFalse := llvm.ConstInt(p.ctx.Int1Type(), 0, false) - for _, use := range typeAssertUses { + for _, use := range getUses(p.mod.NamedFunction("runtime.typeAssert")) { actualType := use.Operand(0) name := strings.TrimPrefix(use.Operand(1).Name(), "reflect/types.typeid:") if t, ok := p.types[name]; ok { @@ -395,54 +333,13 @@ func (p *lowerInterfacesPass) run() error { use.EraseFromParentAsInstruction() } - // Fill in each helper function for type asserts on interfaces - // (interface-to-interface matches). - for _, itf := range p.interfaces { - if !itf.assertFunc.IsNil() { - p.createInterfaceImplementsFunc(itf) - } - for signature := range itf.methodFuncs { - p.createInterfaceMethodFunc(itf, signature) - } - } - - // Replace all ptrtoint typecode placeholders with their final type code - // numbers. - for _, typ := range p.types { - for _, use := range getUses(typ.typecode) { - if !use.IsAConstantExpr().IsNil() && use.Opcode() == llvm.PtrToInt { - use.ReplaceAllUsesWith(llvm.ConstInt(p.uintptrType, typ.num, false)) - } - } - } - - // Remove most objects created for interface and reflect lowering. - // Unnecessary, but cleans up the IR for inspection and testing. - for _, typ := range p.types { - // Only some typecodes have an initializer. - initializer := typ.typecode.Initializer() - if !initializer.IsNil() { - references := llvm.ConstExtractValue(initializer, []uint32{0}) - typ.typecode.SetInitializer(llvm.ConstNull(initializer.Type())) - if strings.HasPrefix(typ.name, "reflect/types.type:struct:") { - // Structs have a 'references' field that is not a typecode but - // a pointer to a runtime.structField array and therefore a - // bitcast. This global should be erased separately, otherwise - // typecode objects cannot be erased. - structFields := references.Operand(0) - structFields.EraseFromParentAsGlobal() - } - } - - if !typ.methodSet.IsNil() { - typ.methodSet.EraseFromParentAsGlobal() - typ.methodSet = llvm.Value{} - } - } - for _, itf := range p.interfaces { - // Remove method sets of interfaces. - itf.methodSet.EraseFromParentAsGlobal() - itf.methodSet = llvm.Value{} + // Remove all method sets, which are now unnecessary and inhibit later + // optimizations if they are left in place. + for _, t := range p.types { + initializer := t.typecode.Initializer() + methodSet := llvm.ConstExtractValue(initializer, []uint32{2}) + initializer = llvm.ConstInsertValue(initializer, llvm.ConstNull(methodSet.Type()), []uint32{2}) + t.typecode.SetInitializer(initializer) } return nil @@ -559,6 +456,10 @@ func (p *lowerInterfacesPass) replaceInvokeWithCall(use llvm.Value, typ *typeInf // getInterfaceImplementsFunc returns a function that checks whether a given // interface type implements a given interface, by checking all possible types // that implement this interface. +// +// The type match is implemented using an if/else chain over all possible types. +// This if/else chain is easily converted to a big switch over all possible +// types by the LLVM simplifycfg pass. func (p *lowerInterfacesPass) getInterfaceImplementsFunc(itf *interfaceInfo) llvm.Value { if !itf.assertFunc.IsNil() { return itf.assertFunc @@ -568,60 +469,49 @@ func (p *lowerInterfacesPass) getInterfaceImplementsFunc(itf *interfaceInfo) llv // TODO: debug info fnName := itf.id() + "$typeassert" fnType := llvm.FunctionType(p.ctx.Int1Type(), []llvm.Type{p.uintptrType}, false) - itf.assertFunc = llvm.AddFunction(p.mod, fnName, fnType) - itf.assertFunc.Param(0).SetName("actualType") - - // Type asserts will be made for each type, so increment the counter for - // those. - for _, typ := range itf.types { - typ.countTypeAsserts++ - } - - return itf.assertFunc -} - -// createInterfaceImplementsFunc finishes the work of -// getInterfaceImplementsFunc, because it needs to run after types have a type -// code assigned. -// -// The type match is implemented using a big type switch over all possible -// types. -func (p *lowerInterfacesPass) createInterfaceImplementsFunc(itf *interfaceInfo) { - fn := itf.assertFunc + fn := llvm.AddFunction(p.mod, fnName, fnType) + itf.assertFunc = fn + fn.Param(0).SetName("actualType") fn.SetLinkage(llvm.InternalLinkage) fn.SetUnnamedAddr(true) if p.sizeLevel >= 2 { fn.AddFunctionAttr(p.ctx.CreateEnumAttribute(llvm.AttributeKindID("optsize"), 0)) } - // TODO: debug info - - // Create all used basic blocks. + // Start the if/else chain at the entry block. entry := p.ctx.AddBasicBlock(fn, "entry") thenBlock := p.ctx.AddBasicBlock(fn, "then") - elseBlock := p.ctx.AddBasicBlock(fn, "else") - - // Add all possible types as cases. p.builder.SetInsertPointAtEnd(entry) + + // Iterate over all possible types. Each iteration creates a new branch + // either to the 'then' block (success) or the .next block, for the next + // check. actualType := fn.Param(0) - sw := p.builder.CreateSwitch(actualType, elseBlock, len(itf.types)) for _, typ := range itf.types { - sw.AddCase(llvm.ConstInt(p.uintptrType, typ.num, false), thenBlock) + nextBlock := p.ctx.AddBasicBlock(fn, typ.name+".next") + cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, llvm.ConstPtrToInt(typ.typecode, p.uintptrType), typ.name+".icmp") + p.builder.CreateCondBr(cmp, thenBlock, nextBlock) + p.builder.SetInsertPointAtEnd(nextBlock) } + // The builder is now inserting at the last *.next block. Once we reach + // this point, all types have been checked so the type assert will have + // failed. + p.builder.CreateRet(llvm.ConstInt(p.ctx.Int1Type(), 0, false)) + // Fill 'then' block (type assert was successful). p.builder.SetInsertPointAtEnd(thenBlock) p.builder.CreateRet(llvm.ConstInt(p.ctx.Int1Type(), 1, false)) - // Fill 'else' block (type asserted failed). - p.builder.SetInsertPointAtEnd(elseBlock) - p.builder.CreateRet(llvm.ConstInt(p.ctx.Int1Type(), 0, false)) + return itf.assertFunc } // getInterfaceMethodFunc returns a thunk for calling a method on an interface. -// It only declares the function, createInterfaceMethodFunc actually defines the -// function. -func (p *lowerInterfacesPass) getInterfaceMethodFunc(itf *interfaceInfo, signature *signatureInfo, returnType llvm.Type, params []llvm.Type) llvm.Value { +// +// Matching the actual type is implemented using an if/else chain over all +// possible types. This is later converted to a switch statement by the LLVM +// simplifycfg pass. +func (p *lowerInterfacesPass) getInterfaceMethodFunc(itf *interfaceInfo, signature *signatureInfo, returnType llvm.Type, paramTypes []llvm.Type) llvm.Value { if fn, ok := itf.methodFuncs[signature]; ok { // This function has already been created. return fn @@ -634,22 +524,11 @@ func (p *lowerInterfacesPass) getInterfaceMethodFunc(itf *interfaceInfo, signatu // Construct the function name, which is of the form: // (main.Stringer).String fnName := "(" + itf.id() + ")." + signature.methodName() - fnType := llvm.FunctionType(returnType, append(params, llvm.PointerType(p.ctx.Int8Type(), 0)), false) + fnType := llvm.FunctionType(returnType, append(paramTypes, llvm.PointerType(p.ctx.Int8Type(), 0)), false) fn := llvm.AddFunction(p.mod, fnName, fnType) llvm.PrevParam(fn.LastParam()).SetName("actualType") fn.LastParam().SetName("parentHandle") itf.methodFuncs[signature] = fn - return fn -} - -// createInterfaceMethodFunc finishes the work of getInterfaceMethodFunc, -// because it needs to run after type codes have been assigned to concrete -// types. -// -// Matching the actual type is implemented using a big type switch over all -// possible types. -func (p *lowerInterfacesPass) createInterfaceMethodFunc(itf *interfaceInfo, signature *signatureInfo) { - fn := itf.methodFuncs[signature] fn.SetLinkage(llvm.InternalLinkage) fn.SetUnnamedAddr(true) if p.sizeLevel >= 2 { @@ -658,29 +537,6 @@ func (p *lowerInterfacesPass) createInterfaceMethodFunc(itf *interfaceInfo, sign // TODO: debug info - // Create entry block. - entry := p.ctx.AddBasicBlock(fn, "entry") - - // Create default block and call runtime.nilPanic. - // The only other possible value remaining is nil for nil interfaces. We - // could panic with a different message here such as "nil interface" but - // that would increase code size and "nil panic" is close enough. Most - // importantly, it avoids undefined behavior when accidentally calling a - // method on a nil interface. - defaultBlock := p.ctx.AddBasicBlock(fn, "default") - p.builder.SetInsertPointAtEnd(defaultBlock) - nilPanic := p.mod.NamedFunction("runtime.nilPanic") - p.builder.CreateCall(nilPanic, []llvm.Value{ - llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)), - llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)), - }, "") - p.builder.CreateUnreachable() - - // Create type switch in entry block. - p.builder.SetInsertPointAtEnd(entry) - actualType := llvm.PrevParam(fn.LastParam()) - sw := p.builder.CreateSwitch(actualType, defaultBlock, len(itf.types)) - // Collect the params that will be passed to the functions to call. // These params exclude the receiver (which may actually consist of multiple // parts). @@ -689,10 +545,18 @@ func (p *lowerInterfacesPass) createInterfaceMethodFunc(itf *interfaceInfo, sign params[i] = fn.Param(i + 1) } + // Start chain in the entry block. + entry := p.ctx.AddBasicBlock(fn, "entry") + p.builder.SetInsertPointAtEnd(entry) + // Define all possible functions that can be called. + actualType := llvm.PrevParam(fn.LastParam()) for _, typ := range itf.types { + // Create type check (if/else). bb := p.ctx.AddBasicBlock(fn, typ.name) - sw.AddCase(llvm.ConstInt(p.uintptrType, typ.num, false), bb) + next := p.ctx.AddBasicBlock(fn, typ.name+".next") + cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, llvm.ConstPtrToInt(typ.typecode, p.uintptrType), typ.name+".icmp") + p.builder.CreateCondBr(cmp, bb, next) // The function we will redirect to when the interface has this type. function := typ.getMethod(signature).function @@ -725,5 +589,25 @@ func (p *lowerInterfacesPass) createInterfaceMethodFunc(itf *interfaceInfo, sign } else { p.builder.CreateRet(retval) } + + // Start next comparison in the 'next' block (which is jumped to when + // the type doesn't match). + p.builder.SetInsertPointAtEnd(next) } + + // The builder now points to the last *.then block, after all types have + // been checked. Call runtime.nilPanic here. + // The only other possible value remaining is nil for nil interfaces. We + // could panic with a different message here such as "nil interface" but + // that would increase code size and "nil panic" is close enough. Most + // importantly, it avoids undefined behavior when accidentally calling a + // method on a nil interface. + nilPanic := p.mod.NamedFunction("runtime.nilPanic") + p.builder.CreateCall(nilPanic, []llvm.Value{ + llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)), + llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)), + }, "") + p.builder.CreateUnreachable() + + return fn } diff --git a/transform/interface-lowering_test.go b/transform/interface-lowering_test.go index e1d56e48..145b778f 100644 --- a/transform/interface-lowering_test.go +++ b/transform/interface-lowering_test.go @@ -14,5 +14,10 @@ func TestInterfaceLowering(t *testing.T) { if err != nil { t.Error(err) } + + pm := llvm.NewPassManager() + defer pm.Dispose() + pm.AddGlobalDCEPass() + pm.Run(mod) }) } diff --git a/transform/optimizer.go b/transform/optimizer.go index dca4a2bd..91e44604 100644 --- a/transform/optimizer.go +++ b/transform/optimizer.go @@ -63,7 +63,7 @@ func Optimize(mod llvm.Module, config *compileopts.Config, optLevel, sizeLevel i goPasses.AddFunctionAttrsPass() goPasses.Run(mod) - // Run Go-specific optimization passes. + // Run TinyGo-specific optimization passes. OptimizeMaps(mod) OptimizeStringToBytes(mod) OptimizeReflectImplements(mod) @@ -88,6 +88,7 @@ func Optimize(mod llvm.Module, config *compileopts.Config, optLevel, sizeLevel i goPasses.Run(mod) // Run TinyGo-specific interprocedural optimizations. + LowerReflect(mod) OptimizeAllocs(mod, config.Options.PrintAllocs, func(pos token.Position, msg string) { fmt.Fprintln(os.Stderr, pos.String()+": "+msg) }) @@ -100,6 +101,7 @@ func Optimize(mod llvm.Module, config *compileopts.Config, optLevel, sizeLevel i if err != nil { return []error{err} } + LowerReflect(mod) if config.FuncImplementation() == "switch" { LowerFuncValues(mod) } diff --git a/transform/reflect.go b/transform/reflect.go index fcab9a06..c4ecae77 100644 --- a/transform/reflect.go +++ b/transform/reflect.go @@ -31,6 +31,7 @@ import ( "encoding/binary" "go/ast" "math/big" + "sort" "strings" "tinygo.org/x/go-llvm" @@ -122,14 +123,45 @@ type typeCodeAssignmentState struct { needsNamedNonBasicTypesSidetable bool } -// assignTypeCodes is used to assign a type code to each type in the program +// LowerReflect is used to assign a type code to each type in the program // that is ever stored in an interface. It tries to use the smallest possible // numbers to make the code that works with interfaces as small as possible. -func assignTypeCodes(mod llvm.Module, typeSlice typeInfoSlice) { +func LowerReflect(mod llvm.Module) { // if reflect were not used, we could skip generating the sidetable // this does not help in practice, and is difficult to do correctly + // Obtain slice of all types in the program. + type typeInfo struct { + typecode llvm.Value + name string + numUses int + } + var types []*typeInfo + for global := mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) { + if strings.HasPrefix(global.Name(), "reflect/types.type:") { + types = append(types, &typeInfo{ + typecode: global, + name: global.Name(), + numUses: len(getUses(global)), + }) + } + } + + // Sort the slice in a way that often used types are assigned a type code + // first. + sort.Slice(types, func(i, j int) bool { + if types[i].numUses != types[j].numUses { + return types[i].numUses < types[j].numUses + } + // It would make more sense to compare the name in the other direction, + // but for some reason that increases binary size. Could be a fluke, but + // could also have some good reason (and possibly hint at a small + // optimization). + return types[i].name > types[j].name + }) + // Assign typecodes the way the reflect package expects. + uintptrType := mod.Context().IntType(llvm.NewTargetData(mod.DataLayout()).PointerSize() * 8) state := typeCodeAssignmentState{ fallbackIndex: 1, uintptrLen: llvm.NewTargetData(mod.DataLayout()).PointerSize() * 8, @@ -143,7 +175,7 @@ func assignTypeCodes(mod llvm.Module, typeSlice typeInfoSlice) { needsStructNamesSidetable: len(getUses(mod.NamedGlobal("reflect.structNamesSidetable"))) != 0, needsArrayTypesSidetable: len(getUses(mod.NamedGlobal("reflect.arrayTypesSidetable"))) != 0, } - for _, t := range typeSlice { + for _, t := range types { num := state.getTypeCodeNum(t.typecode) if num.BitLen() > state.uintptrLen || !num.IsUint64() { // TODO: support this in some way, using a side table for example. @@ -152,7 +184,25 @@ func assignTypeCodes(mod llvm.Module, typeSlice typeInfoSlice) { // AVR). panic("compiler: could not store type code number inside interface type code") } - t.num = num.Uint64() + + // Replace each use of the type code global with the constant type code. + for _, use := range getUses(t.typecode) { + if use.IsAConstantExpr().IsNil() { + continue + } + typecode := llvm.ConstInt(uintptrType, num.Uint64(), false) + switch use.Opcode() { + case llvm.PtrToInt: + // Already of the correct type. + case llvm.BitCast: + // Could happen when stored in an interface (which is of type + // i8*). + typecode = llvm.ConstIntToPtr(typecode, use.Type()) + default: + panic("unexpected constant expression") + } + use.ReplaceAllUsesWith(typecode) + } } // Only create this sidetable when it is necessary. @@ -180,6 +230,23 @@ func assignTypeCodes(mod llvm.Module, typeSlice typeInfoSlice) { global.SetUnnamedAddr(true) global.SetGlobalConstant(true) } + + // Remove most objects created for interface and reflect lowering. + // They would normally be removed anyway in later passes, but not always. + // It also cleans up the IR for testing. + for _, typ := range types { + initializer := typ.typecode.Initializer() + references := llvm.ConstExtractValue(initializer, []uint32{0}) + typ.typecode.SetInitializer(llvm.ConstNull(initializer.Type())) + if strings.HasPrefix(typ.name, "reflect/types.type:struct:") { + // Structs have a 'references' field that is not a typecode but + // a pointer to a runtime.structField array and therefore a + // bitcast. This global should be erased separately, otherwise + // typecode objects cannot be erased. + structFields := references.Operand(0) + structFields.EraseFromParentAsGlobal() + } + } } // getTypeCodeNum returns the typecode for a given type as expected by the diff --git a/transform/reflect_test.go b/transform/reflect_test.go new file mode 100644 index 00000000..06c30406 --- /dev/null +++ b/transform/reflect_test.go @@ -0,0 +1,77 @@ +package transform_test + +import ( + "testing" + + "github.com/tinygo-org/tinygo/transform" + "tinygo.org/x/go-llvm" +) + +type reflectAssert struct { + call llvm.Value + name string + expectedNumber uint64 +} + +// Test reflect lowering. This code looks at IR like this: +// +// call void @main.assertType(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:basic:int" to i32), i8* inttoptr (i32 3 to i8*), i32 4, i8* undef, i8* undef) +// +// and verifies that the ptrtoint constant (the first parameter of +// @main.assertType) is replaced with the correct type code. The expected +// output is this: +// +// call void @main.assertType(i32 4, i8* inttoptr (i32 3 to i8*), i32 4, i8* undef, i8* undef) +// +// The first and third parameter are compared and must match, the second +// parameter is ignored. +func TestReflect(t *testing.T) { + t.Parallel() + + mod := compileGoFileForTesting(t, "./testdata/reflect.go") + + // Run the instcombine pass, to clean up the IR a bit (especially + // insertvalue/extractvalue instructions). + pm := llvm.NewPassManager() + defer pm.Dispose() + pm.AddInstructionCombiningPass() + pm.Run(mod) + + // Get a list of all the asserts in the source code. + assertType := mod.NamedFunction("main.assertType") + var asserts []reflectAssert + for user := assertType.FirstUse(); !user.IsNil(); user = user.NextUse() { + use := user.User() + if use.IsACallInst().IsNil() { + t.Fatal("expected call use of main.assertType") + } + global := use.Operand(0).Operand(0) + expectedNumber := use.Operand(2).ZExtValue() + asserts = append(asserts, reflectAssert{ + call: use, + name: global.Name(), + expectedNumber: expectedNumber, + }) + } + + // Sanity check to show that the test is actually testing anything. + if len(asserts) < 3 { + t.Errorf("expected at least 3 test cases, got %d", len(asserts)) + } + + // Now lower the type codes. + transform.LowerReflect(mod) + + // Check whether the values are as expected. + for _, assert := range asserts { + actualNumberValue := assert.call.Operand(0) + if actualNumberValue.IsAConstantInt().IsNil() { + t.Errorf("expected to see a constant for %s, got something else", assert.name) + continue + } + actualNumber := actualNumberValue.ZExtValue() + if actualNumber != assert.expectedNumber { + t.Errorf("%s: expected number 0b%b, got 0b%b", assert.name, assert.expectedNumber, actualNumber) + } + } +} diff --git a/transform/testdata/interface.ll b/transform/testdata/interface.ll index e00d1b23..c67595a2 100644 --- a/transform/testdata/interface.ll +++ b/transform/testdata/interface.ll @@ -4,10 +4,10 @@ target triple = "armv7m-none-eabi" %runtime.typecodeID = type { %runtime.typecodeID*, i32, %runtime.interfaceMethodInfo* } %runtime.interfaceMethodInfo = type { i8*, i32 } -@"reflect/types.type:basic:uint8" = external constant %runtime.typecodeID +@"reflect/types.type:basic:uint8" = private constant %runtime.typecodeID zeroinitializer @"reflect/types.typeid:basic:uint8" = external constant i8 @"reflect/types.typeid:basic:int16" = external constant i8 -@"reflect/types.type:basic:int" = external constant %runtime.typecodeID +@"reflect/types.type:basic:int" = private constant %runtime.typecodeID zeroinitializer @"func NeverImplementedMethod()" = external constant i8 @"Unmatched$interface" = private constant [1 x i8*] [i8* @"func NeverImplementedMethod()"] @"func Double() int" = external constant i8 diff --git a/transform/testdata/interface.out.ll b/transform/testdata/interface.out.ll index 44d3145e..69ccfb6d 100644 --- a/transform/testdata/interface.out.ll +++ b/transform/testdata/interface.out.ll @@ -4,19 +4,9 @@ target triple = "armv7m-none-eabi" %runtime.typecodeID = type { %runtime.typecodeID*, i32, %runtime.interfaceMethodInfo* } %runtime.interfaceMethodInfo = type { i8*, i32 } -@"reflect/types.type:basic:uint8" = external constant %runtime.typecodeID -@"reflect/types.typeid:basic:uint8" = external constant i8 -@"reflect/types.typeid:basic:int16" = external constant i8 -@"reflect/types.type:basic:int" = external constant %runtime.typecodeID -@"func NeverImplementedMethod()" = external constant i8 -@"func Double() int" = external constant i8 -@"reflect/types.type:named:Number" = private constant %runtime.typecodeID zeroinitializer - -declare i1 @runtime.interfaceImplements(i32, i8**) - -declare i1 @runtime.typeAssert(i32, i8*) - -declare i32 @runtime.interfaceMethod(i32, i8**, i8*) +@"reflect/types.type:basic:uint8" = private constant %runtime.typecodeID zeroinitializer +@"reflect/types.type:basic:int" = private constant %runtime.typecodeID zeroinitializer +@"reflect/types.type:named:Number" = private constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:int", i32 0, %runtime.interfaceMethodInfo* null } declare void @runtime.printuint8(i8) @@ -31,9 +21,9 @@ declare void @runtime.printnl() declare void @runtime.nilPanic(i8*, i8*) define void @printInterfaces() { - call void @printInterface(i32 4, i8* inttoptr (i32 5 to i8*)) - call void @printInterface(i32 16, i8* inttoptr (i8 120 to i8*)) - call void @printInterface(i32 68, i8* inttoptr (i32 3 to i8*)) + call void @printInterface(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:basic:int" to i32), i8* inttoptr (i32 5 to i8*)) + call void @printInterface(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:basic:uint8" to i32), i8* inttoptr (i8 120 to i8*)) + call void @printInterface(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:Number" to i32), i8* inttoptr (i32 3 to i8*)) ret void } @@ -57,7 +47,7 @@ typeswitch.Doubler: ; preds = %typeswitch.notUnmat ret void typeswitch.notDoubler: ; preds = %typeswitch.notUnmatched - %typeassert.ok2 = icmp eq i32 16, %typecode + %typeassert.ok2 = icmp eq i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:basic:uint8" to i32), %typecode br i1 %typeassert.ok2, label %typeswitch.byte, label %typeswitch.notByte typeswitch.byte: ; preds = %typeswitch.notDoubler @@ -92,40 +82,34 @@ define i32 @"(Number).Double$invoke"(i8* %receiverPtr, i8* %parentHandle) { define internal i32 @"(Doubler).Double"(i8* %0, i8* %1, i32 %actualType, i8* %parentHandle) unnamed_addr { entry: - switch i32 %actualType, label %default [ - i32 68, label %"named:Number" - ] - -default: ; preds = %entry - call void @runtime.nilPanic(i8* undef, i8* undef) - unreachable + %"named:Number.icmp" = icmp eq i32 %actualType, ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:Number" to i32) + br i1 %"named:Number.icmp", label %"named:Number", label %"named:Number.next" "named:Number": ; preds = %entry %2 = call i32 @"(Number).Double$invoke"(i8* %0, i8* %1) ret i32 %2 + +"named:Number.next": ; preds = %entry + call void @runtime.nilPanic(i8* undef, i8* undef) + unreachable } define internal i1 @"Doubler$typeassert"(i32 %actualType) unnamed_addr { entry: - switch i32 %actualType, label %else [ - i32 68, label %then - ] + %"named:Number.icmp" = icmp eq i32 %actualType, ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:Number" to i32) + br i1 %"named:Number.icmp", label %then, label %"named:Number.next" then: ; preds = %entry ret i1 true -else: ; preds = %entry +"named:Number.next": ; preds = %entry ret i1 false } define internal i1 @"Unmatched$typeassert"(i32 %actualType) unnamed_addr { entry: - switch i32 %actualType, label %else [ - ] + ret i1 false then: ; No predecessors! ret i1 true - -else: ; preds = %entry - ret i1 false } diff --git a/transform/testdata/reflect.go b/transform/testdata/reflect.go new file mode 100644 index 00000000..e444949c --- /dev/null +++ b/transform/testdata/reflect.go @@ -0,0 +1,56 @@ +package main + +// This file tests the type codes assigned by the reflect lowering pass. +// This test is not complete, most importantly, sidetables are not currently +// being tested. + +import ( + "reflect" + "unsafe" +) + +const ( + // See the top of src/reflect/type.go + prefixChan = 0b0001 + prefixInterface = 0b0011 + prefixPtr = 0b0101 + prefixSlice = 0b0111 + prefixArray = 0b1001 + prefixFunc = 0b1011 + prefixMap = 0b1101 + prefixStruct = 0b1111 +) + +func main() { + // Check for some basic types. + assertType(3, uintptr(reflect.Int)<<1) + assertType(uint8(3), uintptr(reflect.Uint8)<<1) + assertType(byte(3), uintptr(reflect.Uint8)<<1) + assertType(int64(3), uintptr(reflect.Int64)<<1) + assertType("", uintptr(reflect.String)<<1) + assertType(3.5, uintptr(reflect.Float64)<<1) + assertType(unsafe.Pointer(nil), uintptr(reflect.UnsafePointer)<<1) + + // Check for named types: they are given names in order. + // They are sorted in reverse, for no good reason. + const intNum = uintptr(reflect.Int) << 1 + assertType(namedInt1(0), (3<<6)|intNum) + assertType(namedInt2(0), (2<<6)|intNum) + assertType(namedInt3(0), (1<<6)|intNum) + + // Check for some "prefix-style" types. + assertType(make(chan int), (intNum<<5)|prefixChan) + assertType(new(int), (intNum<<5)|prefixPtr) + assertType([]int{}, (intNum<<5)|prefixSlice) +} + +type ( + namedInt1 int + namedInt2 int + namedInt3 int +) + +// Pseudo call that is being checked by the code in reflect_test.go. +// After reflect lowering, the type code as part of the interface should match +// the asserted type code. +func assertType(itf interface{}, assertedTypeCode uintptr) diff --git a/transform/transform_test.go b/transform/transform_test.go index 584ad435..0b4cc0d2 100644 --- a/transform/transform_test.go +++ b/transform/transform_test.go @@ -4,13 +4,19 @@ package transform_test import ( "flag" + "go/token" + "go/types" "io/ioutil" "os" + "path/filepath" "regexp" "strconv" "strings" "testing" + "github.com/tinygo-org/tinygo/compileopts" + "github.com/tinygo-org/tinygo/compiler" + "github.com/tinygo-org/tinygo/loader" "tinygo.org/x/go-llvm" ) @@ -128,3 +134,86 @@ func filterIrrelevantIRLines(lines []string) []string { } return out } + +// compileGoFileForTesting compiles the given Go file to run tests against. +// Only the given Go file is compiled (no dependencies) and no optimizations are +// run. +// If there are any errors, they are reported via the *testing.T instance. +func compileGoFileForTesting(t *testing.T, filename string) llvm.Module { + target, err := compileopts.LoadTarget("i686--linux") + if err != nil { + t.Fatal("failed to load target:", err) + } + config := &compileopts.Config{ + Options: &compileopts.Options{}, + Target: target, + } + compilerConfig := &compiler.Config{ + Triple: config.Triple(), + GOOS: config.GOOS(), + GOARCH: config.GOARCH(), + CodeModel: config.CodeModel(), + RelocationModel: config.RelocationModel(), + Scheduler: config.Scheduler(), + FuncImplementation: config.FuncImplementation(), + AutomaticStackSize: config.AutomaticStackSize(), + Debug: true, + } + machine, err := compiler.NewTargetMachine(compilerConfig) + if err != nil { + t.Fatal("failed to create target machine:", err) + } + + // Load entire program AST into memory. + lprogram, err := loader.Load(config, []string{filename}, config.ClangHeaders, types.Config{ + Sizes: compiler.Sizes(machine), + }) + if err != nil { + t.Fatal("failed to create target machine:", err) + } + err = lprogram.Parse() + if err != nil { + t.Fatal("could not parse", err) + } + + // Compile AST to IR. + program := lprogram.LoadSSA() + pkg := lprogram.MainPkg() + mod, errs := compiler.CompilePackage(filename, pkg, program.Package(pkg.Pkg), machine, compilerConfig, false) + if errs != nil { + for _, err := range errs { + t.Error(err) + } + t.FailNow() + } + return mod +} + +// getPosition returns the position information for the given value, as far as +// it is available. +func getPosition(val llvm.Value) token.Position { + if !val.IsAInstruction().IsNil() { + loc := val.InstructionDebugLoc() + if loc.IsNil() { + return token.Position{} + } + file := loc.LocationScope().ScopeFile() + return token.Position{ + Filename: filepath.Join(file.FileDirectory(), file.FileFilename()), + Line: int(loc.LocationLine()), + Column: int(loc.LocationColumn()), + } + } else if !val.IsAFunction().IsNil() { + loc := val.Subprogram() + if loc.IsNil() { + return token.Position{} + } + file := loc.ScopeFile() + return token.Position{ + Filename: filepath.Join(file.FileDirectory(), file.FileFilename()), + Line: int(loc.SubprogramLine()), + } + } else { + return token.Position{} + } +}