compiler: correctly generate code for local named types
It is possible to create function-local named types: func foo() any { type named int return named(0) } This patch makes sure they don't alias with named types declared at the package scope. Bug originally found by Damian Gryski while working on reflect support.
Этот коммит содержится в:
родитель
17f5fb1071
коммит
523c6c0e3b
4 изменённых файлов: 116 добавлений и 24 удалений
|
@ -71,6 +71,7 @@ type compilerContext struct {
|
||||||
difiles map[string]llvm.Metadata
|
difiles map[string]llvm.Metadata
|
||||||
ditypes map[types.Type]llvm.Metadata
|
ditypes map[types.Type]llvm.Metadata
|
||||||
llvmTypes typeutil.Map
|
llvmTypes typeutil.Map
|
||||||
|
interfaceTypes typeutil.Map
|
||||||
machine llvm.TargetMachine
|
machine llvm.TargetMachine
|
||||||
targetData llvm.TargetData
|
targetData llvm.TargetData
|
||||||
intType llvm.Type
|
intType llvm.Type
|
||||||
|
|
|
@ -32,7 +32,8 @@ func (c *compilerContext) createFuncValue(builder llvm.Builder, funcPtr, context
|
||||||
// global reference is not real, it is only used during func lowering to assign
|
// global reference is not real, it is only used during func lowering to assign
|
||||||
// signature types to functions and will then be removed.
|
// signature types to functions and will then be removed.
|
||||||
func (c *compilerContext) getFuncSignatureID(sig *types.Signature) llvm.Value {
|
func (c *compilerContext) getFuncSignatureID(sig *types.Signature) llvm.Value {
|
||||||
sigGlobalName := "reflect/types.funcid:" + getTypeCodeName(sig)
|
s, _ := getTypeCodeName(sig)
|
||||||
|
sigGlobalName := "reflect/types.funcid:" + s
|
||||||
sigGlobal := c.mod.NamedGlobal(sigGlobalName)
|
sigGlobal := c.mod.NamedGlobal(sigGlobalName)
|
||||||
if sigGlobal.IsNil() {
|
if sigGlobal.IsNil() {
|
||||||
sigGlobal = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), sigGlobalName)
|
sigGlobal = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), sigGlobalName)
|
||||||
|
|
|
@ -118,8 +118,23 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
|
||||||
if _, ok := typ.Underlying().(*types.Interface); ok {
|
if _, ok := typ.Underlying().(*types.Interface); ok {
|
||||||
hasMethodSet = false
|
hasMethodSet = false
|
||||||
}
|
}
|
||||||
globalName := "reflect/types.type:" + getTypeCodeName(typ)
|
typeCodeName, isLocal := getTypeCodeName(typ)
|
||||||
global := c.mod.NamedGlobal(globalName)
|
globalName := "reflect/types.type:" + typeCodeName
|
||||||
|
var global llvm.Value
|
||||||
|
if isLocal {
|
||||||
|
// This type is a named type inside a function, like this:
|
||||||
|
//
|
||||||
|
// func foo() any {
|
||||||
|
// type named int
|
||||||
|
// return named(0)
|
||||||
|
// }
|
||||||
|
if obj := c.interfaceTypes.At(typ); obj != nil {
|
||||||
|
global = obj.(llvm.Value)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Regular type (named or otherwise).
|
||||||
|
global = c.mod.NamedGlobal(globalName)
|
||||||
|
}
|
||||||
if global.IsNil() {
|
if global.IsNil() {
|
||||||
var typeFields []llvm.Value
|
var typeFields []llvm.Value
|
||||||
// Define the type fields. These must match the structs in
|
// Define the type fields. These must match the structs in
|
||||||
|
@ -203,6 +218,9 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
|
||||||
}
|
}
|
||||||
globalType := types.NewStruct(typeFieldTypes, nil)
|
globalType := types.NewStruct(typeFieldTypes, nil)
|
||||||
global = llvm.AddGlobal(c.mod, c.getLLVMType(globalType), globalName)
|
global = llvm.AddGlobal(c.mod, c.getLLVMType(globalType), globalName)
|
||||||
|
if isLocal {
|
||||||
|
c.interfaceTypes.Set(typ, global)
|
||||||
|
}
|
||||||
metabyte := getTypeKind(typ)
|
metabyte := getTypeKind(typ)
|
||||||
switch typ := typ.(type) {
|
switch typ := typ.(type) {
|
||||||
case *types.Basic:
|
case *types.Basic:
|
||||||
|
@ -330,7 +348,11 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
|
||||||
alignment := c.targetData.TypeAllocSize(c.i8ptrType)
|
alignment := c.targetData.TypeAllocSize(c.i8ptrType)
|
||||||
globalValue := c.ctx.ConstStruct(typeFields, false)
|
globalValue := c.ctx.ConstStruct(typeFields, false)
|
||||||
global.SetInitializer(globalValue)
|
global.SetInitializer(globalValue)
|
||||||
global.SetLinkage(llvm.LinkOnceODRLinkage)
|
if isLocal {
|
||||||
|
global.SetLinkage(llvm.InternalLinkage)
|
||||||
|
} else {
|
||||||
|
global.SetLinkage(llvm.LinkOnceODRLinkage)
|
||||||
|
}
|
||||||
global.SetGlobalConstant(true)
|
global.SetGlobalConstant(true)
|
||||||
global.SetAlignment(int(alignment))
|
global.SetAlignment(int(alignment))
|
||||||
if c.Debug {
|
if c.Debug {
|
||||||
|
@ -411,57 +433,84 @@ var basicTypeNames = [...]string{
|
||||||
// getTypeCodeName returns a name for this type that can be used in the
|
// getTypeCodeName returns a name for this type that can be used in the
|
||||||
// interface lowering pass to assign type codes as expected by the reflect
|
// interface lowering pass to assign type codes as expected by the reflect
|
||||||
// package. See getTypeCodeNum.
|
// package. See getTypeCodeNum.
|
||||||
func getTypeCodeName(t types.Type) string {
|
func getTypeCodeName(t types.Type) (string, bool) {
|
||||||
switch t := t.(type) {
|
switch t := t.(type) {
|
||||||
case *types.Named:
|
case *types.Named:
|
||||||
return "named:" + t.String()
|
// Note: check for `t.Obj().Pkg() != nil` for Go 1.18 only.
|
||||||
|
if t.Obj().Pkg() != nil && t.Obj().Parent() != t.Obj().Pkg().Scope() {
|
||||||
|
return "named:" + t.String() + "$local", true
|
||||||
|
}
|
||||||
|
return "named:" + t.String(), false
|
||||||
case *types.Array:
|
case *types.Array:
|
||||||
return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + getTypeCodeName(t.Elem())
|
s, isLocal := getTypeCodeName(t.Elem())
|
||||||
|
return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + s, isLocal
|
||||||
case *types.Basic:
|
case *types.Basic:
|
||||||
return "basic:" + basicTypeNames[t.Kind()]
|
return "basic:" + basicTypeNames[t.Kind()], false
|
||||||
case *types.Chan:
|
case *types.Chan:
|
||||||
return "chan:" + getTypeCodeName(t.Elem())
|
s, isLocal := getTypeCodeName(t.Elem())
|
||||||
|
return "chan:" + s, isLocal
|
||||||
case *types.Interface:
|
case *types.Interface:
|
||||||
|
isLocal := false
|
||||||
methods := make([]string, t.NumMethods())
|
methods := make([]string, t.NumMethods())
|
||||||
for i := 0; i < t.NumMethods(); i++ {
|
for i := 0; i < t.NumMethods(); i++ {
|
||||||
name := t.Method(i).Name()
|
name := t.Method(i).Name()
|
||||||
if !token.IsExported(name) {
|
if !token.IsExported(name) {
|
||||||
name = t.Method(i).Pkg().Path() + "." + name
|
name = t.Method(i).Pkg().Path() + "." + name
|
||||||
}
|
}
|
||||||
methods[i] = name + ":" + getTypeCodeName(t.Method(i).Type())
|
s, local := getTypeCodeName(t.Method(i).Type())
|
||||||
|
if local {
|
||||||
|
isLocal = true
|
||||||
|
}
|
||||||
|
methods[i] = name + ":" + s
|
||||||
}
|
}
|
||||||
return "interface:" + "{" + strings.Join(methods, ",") + "}"
|
return "interface:" + "{" + strings.Join(methods, ",") + "}", isLocal
|
||||||
case *types.Map:
|
case *types.Map:
|
||||||
keyType := getTypeCodeName(t.Key())
|
keyType, keyLocal := getTypeCodeName(t.Key())
|
||||||
elemType := getTypeCodeName(t.Elem())
|
elemType, elemLocal := getTypeCodeName(t.Elem())
|
||||||
return "map:" + "{" + keyType + "," + elemType + "}"
|
return "map:" + "{" + keyType + "," + elemType + "}", keyLocal || elemLocal
|
||||||
case *types.Pointer:
|
case *types.Pointer:
|
||||||
return "pointer:" + getTypeCodeName(t.Elem())
|
s, isLocal := getTypeCodeName(t.Elem())
|
||||||
|
return "pointer:" + s, isLocal
|
||||||
case *types.Signature:
|
case *types.Signature:
|
||||||
|
isLocal := false
|
||||||
params := make([]string, t.Params().Len())
|
params := make([]string, t.Params().Len())
|
||||||
for i := 0; i < t.Params().Len(); i++ {
|
for i := 0; i < t.Params().Len(); i++ {
|
||||||
params[i] = getTypeCodeName(t.Params().At(i).Type())
|
s, local := getTypeCodeName(t.Params().At(i).Type())
|
||||||
|
if local {
|
||||||
|
isLocal = true
|
||||||
|
}
|
||||||
|
params[i] = s
|
||||||
}
|
}
|
||||||
results := make([]string, t.Results().Len())
|
results := make([]string, t.Results().Len())
|
||||||
for i := 0; i < t.Results().Len(); i++ {
|
for i := 0; i < t.Results().Len(); i++ {
|
||||||
results[i] = getTypeCodeName(t.Results().At(i).Type())
|
s, local := getTypeCodeName(t.Results().At(i).Type())
|
||||||
|
if local {
|
||||||
|
isLocal = true
|
||||||
|
}
|
||||||
|
results[i] = s
|
||||||
}
|
}
|
||||||
return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}"
|
return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}", isLocal
|
||||||
case *types.Slice:
|
case *types.Slice:
|
||||||
return "slice:" + getTypeCodeName(t.Elem())
|
s, isLocal := getTypeCodeName(t.Elem())
|
||||||
|
return "slice:" + s, isLocal
|
||||||
case *types.Struct:
|
case *types.Struct:
|
||||||
elems := make([]string, t.NumFields())
|
elems := make([]string, t.NumFields())
|
||||||
|
isLocal := false
|
||||||
for i := 0; i < t.NumFields(); i++ {
|
for i := 0; i < t.NumFields(); i++ {
|
||||||
embedded := ""
|
embedded := ""
|
||||||
if t.Field(i).Embedded() {
|
if t.Field(i).Embedded() {
|
||||||
embedded = "#"
|
embedded = "#"
|
||||||
}
|
}
|
||||||
elems[i] = embedded + t.Field(i).Name() + ":" + getTypeCodeName(t.Field(i).Type())
|
s, local := getTypeCodeName(t.Field(i).Type())
|
||||||
|
if local {
|
||||||
|
isLocal = true
|
||||||
|
}
|
||||||
|
elems[i] = embedded + t.Field(i).Name() + ":" + s
|
||||||
if t.Tag(i) != "" {
|
if t.Tag(i) != "" {
|
||||||
elems[i] += "`" + t.Tag(i) + "`"
|
elems[i] += "`" + t.Tag(i) + "`"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return "struct:" + "{" + strings.Join(elems, ",") + "}"
|
return "struct:" + "{" + strings.Join(elems, ",") + "}", isLocal
|
||||||
default:
|
default:
|
||||||
panic("unknown type: " + t.String())
|
panic("unknown type: " + t.String())
|
||||||
}
|
}
|
||||||
|
@ -564,7 +613,11 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
|
||||||
commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "")
|
commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "")
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
globalName := "reflect/types.typeid:" + getTypeCodeName(expr.AssertedType)
|
assertedTypeGlobal := b.getTypeCode(expr.AssertedType)
|
||||||
|
if !assertedTypeGlobal.IsAConstantExpr().IsNil() {
|
||||||
|
assertedTypeGlobal = assertedTypeGlobal.Operand(0) // resolve the GEP operation
|
||||||
|
}
|
||||||
|
globalName := "reflect/types.typeid:" + strings.TrimPrefix(assertedTypeGlobal.Name(), "reflect/types.type:")
|
||||||
assertedTypeCodeGlobal := b.mod.NamedGlobal(globalName)
|
assertedTypeCodeGlobal := b.mod.NamedGlobal(globalName)
|
||||||
if assertedTypeCodeGlobal.IsNil() {
|
if assertedTypeCodeGlobal.IsNil() {
|
||||||
// Create a new typecode global.
|
// Create a new typecode global.
|
||||||
|
@ -640,7 +693,8 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string {
|
||||||
// getInterfaceImplementsfunc returns a declared function that works as a type
|
// getInterfaceImplementsfunc returns a declared function that works as a type
|
||||||
// switch. The interface lowering pass will define this function.
|
// switch. The interface lowering pass will define this function.
|
||||||
func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value {
|
func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value {
|
||||||
fnName := getTypeCodeName(assertedType.Underlying()) + ".$typeassert"
|
s, _ := getTypeCodeName(assertedType.Underlying())
|
||||||
|
fnName := s + ".$typeassert"
|
||||||
llvmFn := c.mod.NamedFunction(fnName)
|
llvmFn := c.mod.NamedFunction(fnName)
|
||||||
if llvmFn.IsNil() {
|
if llvmFn.IsNil() {
|
||||||
llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptrType}, false)
|
llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptrType}, false)
|
||||||
|
@ -656,7 +710,8 @@ func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) ll
|
||||||
// thunk is declared, not defined: it will be defined by the interface lowering
|
// thunk is declared, not defined: it will be defined by the interface lowering
|
||||||
// pass.
|
// pass.
|
||||||
func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value {
|
func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value {
|
||||||
fnName := getTypeCodeName(instr.Value.Type().Underlying()) + "." + instr.Method.Name() + "$invoke"
|
s, _ := getTypeCodeName(instr.Value.Type().Underlying())
|
||||||
|
fnName := s + "." + instr.Method.Name() + "$invoke"
|
||||||
llvmFn := c.mod.NamedFunction(fnName)
|
llvmFn := c.mod.NamedFunction(fnName)
|
||||||
if llvmFn.IsNil() {
|
if llvmFn.IsNil() {
|
||||||
sig := instr.Method.Type().(*types.Signature)
|
sig := instr.Method.Type().(*types.Signature)
|
||||||
|
|
35
testdata/interface.go
предоставленный
35
testdata/interface.go
предоставленный
|
@ -93,6 +93,12 @@ func main() {
|
||||||
a int
|
a int
|
||||||
b int
|
b int
|
||||||
}{3, 6}},
|
}{3, 6}},
|
||||||
|
{true, named1(), named1()},
|
||||||
|
{true, named2(), named2()},
|
||||||
|
{false, named1(), named2()},
|
||||||
|
{false, named2(), named3()},
|
||||||
|
{true, namedptr1(), namedptr1()},
|
||||||
|
{false, namedptr1(), namedptr2()},
|
||||||
}
|
}
|
||||||
for i, tc := range interfaceEqualTests {
|
for i, tc := range interfaceEqualTests {
|
||||||
if (tc.lhs == tc.rhs) != tc.equal {
|
if (tc.lhs == tc.rhs) != tc.equal {
|
||||||
|
@ -277,3 +283,32 @@ func (f FooByte) Byte() byte { return byte(f) }
|
||||||
type Byter interface {
|
type Byter interface {
|
||||||
Byte() uint8
|
Byte() uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make sure that named types inside functions do not alias with any other named
|
||||||
|
// functions.
|
||||||
|
|
||||||
|
type named int
|
||||||
|
|
||||||
|
func named1() any {
|
||||||
|
return named(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func named2() any {
|
||||||
|
type named int
|
||||||
|
return named(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func named3() any {
|
||||||
|
type named int
|
||||||
|
return named(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func namedptr1() interface{} {
|
||||||
|
type Test int
|
||||||
|
return (*Test)(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func namedptr2() interface{} {
|
||||||
|
type Test byte
|
||||||
|
return (*Test)(nil)
|
||||||
|
}
|
||||||
|
|
Загрузка…
Создание таблицы
Сослаться в новой задаче