compiler: correctly generate code for local named types
It is possible to create function-local named types: func foo() any { type named int return named(0) } This patch makes sure they don't alias with named types declared at the package scope. Bug originally found by Damian Gryski while working on reflect support.
Этот коммит содержится в:
родитель
17f5fb1071
коммит
523c6c0e3b
4 изменённых файлов: 116 добавлений и 24 удалений
|
@ -71,6 +71,7 @@ type compilerContext struct {
|
|||
difiles map[string]llvm.Metadata
|
||||
ditypes map[types.Type]llvm.Metadata
|
||||
llvmTypes typeutil.Map
|
||||
interfaceTypes typeutil.Map
|
||||
machine llvm.TargetMachine
|
||||
targetData llvm.TargetData
|
||||
intType llvm.Type
|
||||
|
|
|
@ -32,7 +32,8 @@ func (c *compilerContext) createFuncValue(builder llvm.Builder, funcPtr, context
|
|||
// global reference is not real, it is only used during func lowering to assign
|
||||
// signature types to functions and will then be removed.
|
||||
func (c *compilerContext) getFuncSignatureID(sig *types.Signature) llvm.Value {
|
||||
sigGlobalName := "reflect/types.funcid:" + getTypeCodeName(sig)
|
||||
s, _ := getTypeCodeName(sig)
|
||||
sigGlobalName := "reflect/types.funcid:" + s
|
||||
sigGlobal := c.mod.NamedGlobal(sigGlobalName)
|
||||
if sigGlobal.IsNil() {
|
||||
sigGlobal = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), sigGlobalName)
|
||||
|
|
|
@ -118,8 +118,23 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
|
|||
if _, ok := typ.Underlying().(*types.Interface); ok {
|
||||
hasMethodSet = false
|
||||
}
|
||||
globalName := "reflect/types.type:" + getTypeCodeName(typ)
|
||||
global := c.mod.NamedGlobal(globalName)
|
||||
typeCodeName, isLocal := getTypeCodeName(typ)
|
||||
globalName := "reflect/types.type:" + typeCodeName
|
||||
var global llvm.Value
|
||||
if isLocal {
|
||||
// This type is a named type inside a function, like this:
|
||||
//
|
||||
// func foo() any {
|
||||
// type named int
|
||||
// return named(0)
|
||||
// }
|
||||
if obj := c.interfaceTypes.At(typ); obj != nil {
|
||||
global = obj.(llvm.Value)
|
||||
}
|
||||
} else {
|
||||
// Regular type (named or otherwise).
|
||||
global = c.mod.NamedGlobal(globalName)
|
||||
}
|
||||
if global.IsNil() {
|
||||
var typeFields []llvm.Value
|
||||
// Define the type fields. These must match the structs in
|
||||
|
@ -203,6 +218,9 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
|
|||
}
|
||||
globalType := types.NewStruct(typeFieldTypes, nil)
|
||||
global = llvm.AddGlobal(c.mod, c.getLLVMType(globalType), globalName)
|
||||
if isLocal {
|
||||
c.interfaceTypes.Set(typ, global)
|
||||
}
|
||||
metabyte := getTypeKind(typ)
|
||||
switch typ := typ.(type) {
|
||||
case *types.Basic:
|
||||
|
@ -330,7 +348,11 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
|
|||
alignment := c.targetData.TypeAllocSize(c.i8ptrType)
|
||||
globalValue := c.ctx.ConstStruct(typeFields, false)
|
||||
global.SetInitializer(globalValue)
|
||||
global.SetLinkage(llvm.LinkOnceODRLinkage)
|
||||
if isLocal {
|
||||
global.SetLinkage(llvm.InternalLinkage)
|
||||
} else {
|
||||
global.SetLinkage(llvm.LinkOnceODRLinkage)
|
||||
}
|
||||
global.SetGlobalConstant(true)
|
||||
global.SetAlignment(int(alignment))
|
||||
if c.Debug {
|
||||
|
@ -411,57 +433,84 @@ var basicTypeNames = [...]string{
|
|||
// getTypeCodeName returns a name for this type that can be used in the
|
||||
// interface lowering pass to assign type codes as expected by the reflect
|
||||
// package. See getTypeCodeNum.
|
||||
func getTypeCodeName(t types.Type) string {
|
||||
func getTypeCodeName(t types.Type) (string, bool) {
|
||||
switch t := t.(type) {
|
||||
case *types.Named:
|
||||
return "named:" + t.String()
|
||||
// Note: check for `t.Obj().Pkg() != nil` for Go 1.18 only.
|
||||
if t.Obj().Pkg() != nil && t.Obj().Parent() != t.Obj().Pkg().Scope() {
|
||||
return "named:" + t.String() + "$local", true
|
||||
}
|
||||
return "named:" + t.String(), false
|
||||
case *types.Array:
|
||||
return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + getTypeCodeName(t.Elem())
|
||||
s, isLocal := getTypeCodeName(t.Elem())
|
||||
return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + s, isLocal
|
||||
case *types.Basic:
|
||||
return "basic:" + basicTypeNames[t.Kind()]
|
||||
return "basic:" + basicTypeNames[t.Kind()], false
|
||||
case *types.Chan:
|
||||
return "chan:" + getTypeCodeName(t.Elem())
|
||||
s, isLocal := getTypeCodeName(t.Elem())
|
||||
return "chan:" + s, isLocal
|
||||
case *types.Interface:
|
||||
isLocal := false
|
||||
methods := make([]string, t.NumMethods())
|
||||
for i := 0; i < t.NumMethods(); i++ {
|
||||
name := t.Method(i).Name()
|
||||
if !token.IsExported(name) {
|
||||
name = t.Method(i).Pkg().Path() + "." + name
|
||||
}
|
||||
methods[i] = name + ":" + getTypeCodeName(t.Method(i).Type())
|
||||
s, local := getTypeCodeName(t.Method(i).Type())
|
||||
if local {
|
||||
isLocal = true
|
||||
}
|
||||
methods[i] = name + ":" + s
|
||||
}
|
||||
return "interface:" + "{" + strings.Join(methods, ",") + "}"
|
||||
return "interface:" + "{" + strings.Join(methods, ",") + "}", isLocal
|
||||
case *types.Map:
|
||||
keyType := getTypeCodeName(t.Key())
|
||||
elemType := getTypeCodeName(t.Elem())
|
||||
return "map:" + "{" + keyType + "," + elemType + "}"
|
||||
keyType, keyLocal := getTypeCodeName(t.Key())
|
||||
elemType, elemLocal := getTypeCodeName(t.Elem())
|
||||
return "map:" + "{" + keyType + "," + elemType + "}", keyLocal || elemLocal
|
||||
case *types.Pointer:
|
||||
return "pointer:" + getTypeCodeName(t.Elem())
|
||||
s, isLocal := getTypeCodeName(t.Elem())
|
||||
return "pointer:" + s, isLocal
|
||||
case *types.Signature:
|
||||
isLocal := false
|
||||
params := make([]string, t.Params().Len())
|
||||
for i := 0; i < t.Params().Len(); i++ {
|
||||
params[i] = getTypeCodeName(t.Params().At(i).Type())
|
||||
s, local := getTypeCodeName(t.Params().At(i).Type())
|
||||
if local {
|
||||
isLocal = true
|
||||
}
|
||||
params[i] = s
|
||||
}
|
||||
results := make([]string, t.Results().Len())
|
||||
for i := 0; i < t.Results().Len(); i++ {
|
||||
results[i] = getTypeCodeName(t.Results().At(i).Type())
|
||||
s, local := getTypeCodeName(t.Results().At(i).Type())
|
||||
if local {
|
||||
isLocal = true
|
||||
}
|
||||
results[i] = s
|
||||
}
|
||||
return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}"
|
||||
return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}", isLocal
|
||||
case *types.Slice:
|
||||
return "slice:" + getTypeCodeName(t.Elem())
|
||||
s, isLocal := getTypeCodeName(t.Elem())
|
||||
return "slice:" + s, isLocal
|
||||
case *types.Struct:
|
||||
elems := make([]string, t.NumFields())
|
||||
isLocal := false
|
||||
for i := 0; i < t.NumFields(); i++ {
|
||||
embedded := ""
|
||||
if t.Field(i).Embedded() {
|
||||
embedded = "#"
|
||||
}
|
||||
elems[i] = embedded + t.Field(i).Name() + ":" + getTypeCodeName(t.Field(i).Type())
|
||||
s, local := getTypeCodeName(t.Field(i).Type())
|
||||
if local {
|
||||
isLocal = true
|
||||
}
|
||||
elems[i] = embedded + t.Field(i).Name() + ":" + s
|
||||
if t.Tag(i) != "" {
|
||||
elems[i] += "`" + t.Tag(i) + "`"
|
||||
}
|
||||
}
|
||||
return "struct:" + "{" + strings.Join(elems, ",") + "}"
|
||||
return "struct:" + "{" + strings.Join(elems, ",") + "}", isLocal
|
||||
default:
|
||||
panic("unknown type: " + t.String())
|
||||
}
|
||||
|
@ -564,7 +613,11 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
|
|||
commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "")
|
||||
|
||||
} else {
|
||||
globalName := "reflect/types.typeid:" + getTypeCodeName(expr.AssertedType)
|
||||
assertedTypeGlobal := b.getTypeCode(expr.AssertedType)
|
||||
if !assertedTypeGlobal.IsAConstantExpr().IsNil() {
|
||||
assertedTypeGlobal = assertedTypeGlobal.Operand(0) // resolve the GEP operation
|
||||
}
|
||||
globalName := "reflect/types.typeid:" + strings.TrimPrefix(assertedTypeGlobal.Name(), "reflect/types.type:")
|
||||
assertedTypeCodeGlobal := b.mod.NamedGlobal(globalName)
|
||||
if assertedTypeCodeGlobal.IsNil() {
|
||||
// Create a new typecode global.
|
||||
|
@ -640,7 +693,8 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string {
|
|||
// getInterfaceImplementsfunc returns a declared function that works as a type
|
||||
// switch. The interface lowering pass will define this function.
|
||||
func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value {
|
||||
fnName := getTypeCodeName(assertedType.Underlying()) + ".$typeassert"
|
||||
s, _ := getTypeCodeName(assertedType.Underlying())
|
||||
fnName := s + ".$typeassert"
|
||||
llvmFn := c.mod.NamedFunction(fnName)
|
||||
if llvmFn.IsNil() {
|
||||
llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptrType}, false)
|
||||
|
@ -656,7 +710,8 @@ func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) ll
|
|||
// thunk is declared, not defined: it will be defined by the interface lowering
|
||||
// pass.
|
||||
func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value {
|
||||
fnName := getTypeCodeName(instr.Value.Type().Underlying()) + "." + instr.Method.Name() + "$invoke"
|
||||
s, _ := getTypeCodeName(instr.Value.Type().Underlying())
|
||||
fnName := s + "." + instr.Method.Name() + "$invoke"
|
||||
llvmFn := c.mod.NamedFunction(fnName)
|
||||
if llvmFn.IsNil() {
|
||||
sig := instr.Method.Type().(*types.Signature)
|
||||
|
|
35
testdata/interface.go
предоставленный
35
testdata/interface.go
предоставленный
|
@ -93,6 +93,12 @@ func main() {
|
|||
a int
|
||||
b int
|
||||
}{3, 6}},
|
||||
{true, named1(), named1()},
|
||||
{true, named2(), named2()},
|
||||
{false, named1(), named2()},
|
||||
{false, named2(), named3()},
|
||||
{true, namedptr1(), namedptr1()},
|
||||
{false, namedptr1(), namedptr2()},
|
||||
}
|
||||
for i, tc := range interfaceEqualTests {
|
||||
if (tc.lhs == tc.rhs) != tc.equal {
|
||||
|
@ -277,3 +283,32 @@ func (f FooByte) Byte() byte { return byte(f) }
|
|||
type Byter interface {
|
||||
Byte() uint8
|
||||
}
|
||||
|
||||
// Make sure that named types inside functions do not alias with any other named
|
||||
// functions.
|
||||
|
||||
type named int
|
||||
|
||||
func named1() any {
|
||||
return named(0)
|
||||
}
|
||||
|
||||
func named2() any {
|
||||
type named int
|
||||
return named(0)
|
||||
}
|
||||
|
||||
func named3() any {
|
||||
type named int
|
||||
return named(0)
|
||||
}
|
||||
|
||||
func namedptr1() interface{} {
|
||||
type Test int
|
||||
return (*Test)(nil)
|
||||
}
|
||||
|
||||
func namedptr2() interface{} {
|
||||
type Test byte
|
||||
return (*Test)(nil)
|
||||
}
|
||||
|
|
Загрузка…
Создание таблицы
Сослаться в новой задаче