compiler: correctly generate code for local named types

It is possible to create function-local named types:

    func foo() any {
        type named int
        return named(0)
    }

This patch makes sure they don't alias with named types declared at the
package scope.

Bug originally found by Damian Gryski while working on reflect support.
Этот коммит содержится в:
Ayke van Laethem 2023-03-16 15:06:01 +01:00 коммит произвёл Ron Evans
родитель 17f5fb1071
коммит 523c6c0e3b
4 изменённых файлов: 116 добавлений и 24 удалений

Просмотреть файл

@ -71,6 +71,7 @@ type compilerContext struct {
difiles map[string]llvm.Metadata
ditypes map[types.Type]llvm.Metadata
llvmTypes typeutil.Map
interfaceTypes typeutil.Map
machine llvm.TargetMachine
targetData llvm.TargetData
intType llvm.Type

Просмотреть файл

@ -32,7 +32,8 @@ func (c *compilerContext) createFuncValue(builder llvm.Builder, funcPtr, context
// global reference is not real, it is only used during func lowering to assign
// signature types to functions and will then be removed.
func (c *compilerContext) getFuncSignatureID(sig *types.Signature) llvm.Value {
sigGlobalName := "reflect/types.funcid:" + getTypeCodeName(sig)
s, _ := getTypeCodeName(sig)
sigGlobalName := "reflect/types.funcid:" + s
sigGlobal := c.mod.NamedGlobal(sigGlobalName)
if sigGlobal.IsNil() {
sigGlobal = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), sigGlobalName)

Просмотреть файл

@ -118,8 +118,23 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
if _, ok := typ.Underlying().(*types.Interface); ok {
hasMethodSet = false
}
globalName := "reflect/types.type:" + getTypeCodeName(typ)
global := c.mod.NamedGlobal(globalName)
typeCodeName, isLocal := getTypeCodeName(typ)
globalName := "reflect/types.type:" + typeCodeName
var global llvm.Value
if isLocal {
// This type is a named type inside a function, like this:
//
// func foo() any {
// type named int
// return named(0)
// }
if obj := c.interfaceTypes.At(typ); obj != nil {
global = obj.(llvm.Value)
}
} else {
// Regular type (named or otherwise).
global = c.mod.NamedGlobal(globalName)
}
if global.IsNil() {
var typeFields []llvm.Value
// Define the type fields. These must match the structs in
@ -203,6 +218,9 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
}
globalType := types.NewStruct(typeFieldTypes, nil)
global = llvm.AddGlobal(c.mod, c.getLLVMType(globalType), globalName)
if isLocal {
c.interfaceTypes.Set(typ, global)
}
metabyte := getTypeKind(typ)
switch typ := typ.(type) {
case *types.Basic:
@ -330,7 +348,11 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
alignment := c.targetData.TypeAllocSize(c.i8ptrType)
globalValue := c.ctx.ConstStruct(typeFields, false)
global.SetInitializer(globalValue)
global.SetLinkage(llvm.LinkOnceODRLinkage)
if isLocal {
global.SetLinkage(llvm.InternalLinkage)
} else {
global.SetLinkage(llvm.LinkOnceODRLinkage)
}
global.SetGlobalConstant(true)
global.SetAlignment(int(alignment))
if c.Debug {
@ -411,57 +433,84 @@ var basicTypeNames = [...]string{
// getTypeCodeName returns a name for this type that can be used in the
// interface lowering pass to assign type codes as expected by the reflect
// package. See getTypeCodeNum.
func getTypeCodeName(t types.Type) string {
func getTypeCodeName(t types.Type) (string, bool) {
switch t := t.(type) {
case *types.Named:
return "named:" + t.String()
// Note: check for `t.Obj().Pkg() != nil` for Go 1.18 only.
if t.Obj().Pkg() != nil && t.Obj().Parent() != t.Obj().Pkg().Scope() {
return "named:" + t.String() + "$local", true
}
return "named:" + t.String(), false
case *types.Array:
return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + getTypeCodeName(t.Elem())
s, isLocal := getTypeCodeName(t.Elem())
return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + s, isLocal
case *types.Basic:
return "basic:" + basicTypeNames[t.Kind()]
return "basic:" + basicTypeNames[t.Kind()], false
case *types.Chan:
return "chan:" + getTypeCodeName(t.Elem())
s, isLocal := getTypeCodeName(t.Elem())
return "chan:" + s, isLocal
case *types.Interface:
isLocal := false
methods := make([]string, t.NumMethods())
for i := 0; i < t.NumMethods(); i++ {
name := t.Method(i).Name()
if !token.IsExported(name) {
name = t.Method(i).Pkg().Path() + "." + name
}
methods[i] = name + ":" + getTypeCodeName(t.Method(i).Type())
s, local := getTypeCodeName(t.Method(i).Type())
if local {
isLocal = true
}
methods[i] = name + ":" + s
}
return "interface:" + "{" + strings.Join(methods, ",") + "}"
return "interface:" + "{" + strings.Join(methods, ",") + "}", isLocal
case *types.Map:
keyType := getTypeCodeName(t.Key())
elemType := getTypeCodeName(t.Elem())
return "map:" + "{" + keyType + "," + elemType + "}"
keyType, keyLocal := getTypeCodeName(t.Key())
elemType, elemLocal := getTypeCodeName(t.Elem())
return "map:" + "{" + keyType + "," + elemType + "}", keyLocal || elemLocal
case *types.Pointer:
return "pointer:" + getTypeCodeName(t.Elem())
s, isLocal := getTypeCodeName(t.Elem())
return "pointer:" + s, isLocal
case *types.Signature:
isLocal := false
params := make([]string, t.Params().Len())
for i := 0; i < t.Params().Len(); i++ {
params[i] = getTypeCodeName(t.Params().At(i).Type())
s, local := getTypeCodeName(t.Params().At(i).Type())
if local {
isLocal = true
}
params[i] = s
}
results := make([]string, t.Results().Len())
for i := 0; i < t.Results().Len(); i++ {
results[i] = getTypeCodeName(t.Results().At(i).Type())
s, local := getTypeCodeName(t.Results().At(i).Type())
if local {
isLocal = true
}
results[i] = s
}
return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}"
return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}", isLocal
case *types.Slice:
return "slice:" + getTypeCodeName(t.Elem())
s, isLocal := getTypeCodeName(t.Elem())
return "slice:" + s, isLocal
case *types.Struct:
elems := make([]string, t.NumFields())
isLocal := false
for i := 0; i < t.NumFields(); i++ {
embedded := ""
if t.Field(i).Embedded() {
embedded = "#"
}
elems[i] = embedded + t.Field(i).Name() + ":" + getTypeCodeName(t.Field(i).Type())
s, local := getTypeCodeName(t.Field(i).Type())
if local {
isLocal = true
}
elems[i] = embedded + t.Field(i).Name() + ":" + s
if t.Tag(i) != "" {
elems[i] += "`" + t.Tag(i) + "`"
}
}
return "struct:" + "{" + strings.Join(elems, ",") + "}"
return "struct:" + "{" + strings.Join(elems, ",") + "}", isLocal
default:
panic("unknown type: " + t.String())
}
@ -564,7 +613,11 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "")
} else {
globalName := "reflect/types.typeid:" + getTypeCodeName(expr.AssertedType)
assertedTypeGlobal := b.getTypeCode(expr.AssertedType)
if !assertedTypeGlobal.IsAConstantExpr().IsNil() {
assertedTypeGlobal = assertedTypeGlobal.Operand(0) // resolve the GEP operation
}
globalName := "reflect/types.typeid:" + strings.TrimPrefix(assertedTypeGlobal.Name(), "reflect/types.type:")
assertedTypeCodeGlobal := b.mod.NamedGlobal(globalName)
if assertedTypeCodeGlobal.IsNil() {
// Create a new typecode global.
@ -640,7 +693,8 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string {
// getInterfaceImplementsfunc returns a declared function that works as a type
// switch. The interface lowering pass will define this function.
func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value {
fnName := getTypeCodeName(assertedType.Underlying()) + ".$typeassert"
s, _ := getTypeCodeName(assertedType.Underlying())
fnName := s + ".$typeassert"
llvmFn := c.mod.NamedFunction(fnName)
if llvmFn.IsNil() {
llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptrType}, false)
@ -656,7 +710,8 @@ func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) ll
// thunk is declared, not defined: it will be defined by the interface lowering
// pass.
func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value {
fnName := getTypeCodeName(instr.Value.Type().Underlying()) + "." + instr.Method.Name() + "$invoke"
s, _ := getTypeCodeName(instr.Value.Type().Underlying())
fnName := s + "." + instr.Method.Name() + "$invoke"
llvmFn := c.mod.NamedFunction(fnName)
if llvmFn.IsNil() {
sig := instr.Method.Type().(*types.Signature)

35
testdata/interface.go предоставленный
Просмотреть файл

@ -93,6 +93,12 @@ func main() {
a int
b int
}{3, 6}},
{true, named1(), named1()},
{true, named2(), named2()},
{false, named1(), named2()},
{false, named2(), named3()},
{true, namedptr1(), namedptr1()},
{false, namedptr1(), namedptr2()},
}
for i, tc := range interfaceEqualTests {
if (tc.lhs == tc.rhs) != tc.equal {
@ -277,3 +283,32 @@ func (f FooByte) Byte() byte { return byte(f) }
type Byter interface {
Byte() uint8
}
// Make sure that named types inside functions do not alias with any other named
// functions.
type named int
func named1() any {
return named(0)
}
func named2() any {
type named int
return named(0)
}
func named3() any {
type named int
return named(0)
}
func namedptr1() interface{} {
type Test int
return (*Test)(nil)
}
func namedptr2() interface{} {
type Test byte
return (*Test)(nil)
}