diff --git a/compiler/interface.go b/compiler/interface.go index beda21a6..e43a15ec 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -53,21 +53,28 @@ func (c *Compiler) getTypeCode(typ types.Type) llvm.Value { // Some type classes contain more information for underlying types or // element types. Store it directly in the typecode global to make // reflect lowering simpler. - var elementType types.Type + var references llvm.Value switch typ := typ.(type) { case *types.Named: - elementType = typ.Underlying() + references = c.getTypeCode(typ.Underlying()) case *types.Chan: - elementType = typ.Elem() + references = c.getTypeCode(typ.Elem()) case *types.Pointer: - elementType = typ.Elem() + references = c.getTypeCode(typ.Elem()) case *types.Slice: - elementType = typ.Elem() + references = c.getTypeCode(typ.Elem()) + case *types.Struct: + // Take a pointer to the typecodeID of the first field (if it exists). + structGlobal := c.makeStructTypeFields(typ) + references = llvm.ConstGEP(structGlobal, []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), 0, false), + llvm.ConstInt(llvm.Int32Type(), 0, false), + }) } - if elementType != nil { + if !references.IsNil() { // Set the 'references' field of the runtime.typecodeID struct. globalValue := c.getZeroValue(global.Type().ElementType()) - globalValue = llvm.ConstInsertValue(globalValue, c.getTypeCode(elementType), []uint32{0}) + globalValue = llvm.ConstInsertValue(globalValue, references, []uint32{0}) global.SetInitializer(globalValue) global.SetLinkage(llvm.PrivateLinkage) } @@ -76,6 +83,48 @@ func (c *Compiler) getTypeCode(typ types.Type) llvm.Value { return global } +// makeStructTypeFields creates a new global that stores all type information +// related to this struct type, and returns the resulting global. This global is +// actually an array of all the fields in the struct. +func (c *Compiler) makeStructTypeFields(typ *types.Struct) llvm.Value { + // The global is an array of runtime.structField structs. 
+ runtimeStructField := c.getLLVMRuntimeType("structField") + structGlobalType := llvm.ArrayType(runtimeStructField, typ.NumFields()) + structGlobal := llvm.AddGlobal(c.mod, structGlobalType, "reflect/types.structFields") + structGlobalValue := c.getZeroValue(structGlobalType) + for i := 0; i < typ.NumFields(); i++ { + fieldGlobalValue := c.getZeroValue(runtimeStructField) + fieldGlobalValue = llvm.ConstInsertValue(fieldGlobalValue, c.getTypeCode(typ.Field(i).Type()), []uint32{0}) + fieldName := c.makeGlobalBytes([]byte(typ.Field(i).Name()), "reflect/types.structFieldName") + fieldName.SetLinkage(llvm.PrivateLinkage) // set on the global itself: the GEP taken below is a constant expression, not a global + fieldName.SetUnnamedAddr(true) + fieldName = llvm.ConstGEP(fieldName, []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), 0, false), + llvm.ConstInt(llvm.Int32Type(), 0, false), + }) + fieldGlobalValue = llvm.ConstInsertValue(fieldGlobalValue, fieldName, []uint32{1}) + if typ.Tag(i) != "" { + fieldTag := c.makeGlobalBytes([]byte(typ.Tag(i)), "reflect/types.structFieldTag") + fieldTag.SetLinkage(llvm.PrivateLinkage) // likewise: configure the global before taking the GEP + fieldTag.SetUnnamedAddr(true) + fieldTag = llvm.ConstGEP(fieldTag, []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), 0, false), + llvm.ConstInt(llvm.Int32Type(), 0, false), + }) + fieldGlobalValue = llvm.ConstInsertValue(fieldGlobalValue, fieldTag, []uint32{2}) + } + if typ.Field(i).Embedded() { + fieldEmbedded := llvm.ConstInt(c.ctx.Int1Type(), 1, false) + fieldGlobalValue = llvm.ConstInsertValue(fieldGlobalValue, fieldEmbedded, []uint32{3}) + } + structGlobalValue = llvm.ConstInsertValue(structGlobalValue, fieldGlobalValue, []uint32{uint32(i)}) + } + structGlobal.SetInitializer(structGlobalValue) + structGlobal.SetUnnamedAddr(true) + structGlobal.SetLinkage(llvm.PrivateLinkage) + return structGlobal +} + // getTypeCodeName returns a name for this type that can be used in the // interface lowering pass to assign type codes as expected by the reflect // package. See getTypeCodeNum. 
diff --git a/compiler/llvm.go b/compiler/llvm.go index fb52b006..79dbbfd1 100644 --- a/compiler/llvm.go +++ b/compiler/llvm.go @@ -152,3 +152,46 @@ func (c *Compiler) splitBasicBlock(afterInst llvm.Value, insertAfter llvm.BasicB return newBlock } + +// makeGlobalBytes creates a new LLVM global with the given name and bytes as +// contents, and returns the global. +// Note that it is left with the default linkage etc., you should set +// linkage/constant/etc properties yourself. +func (c *Compiler) makeGlobalBytes(buf []byte, name string) llvm.Value { + globalType := llvm.ArrayType(c.ctx.Int8Type(), len(buf)) + global := llvm.AddGlobal(c.mod, globalType, name) + value := llvm.Undef(globalType) + for i, ch := range buf { + value = llvm.ConstInsertValue(value, llvm.ConstInt(c.ctx.Int8Type(), uint64(ch), false), []uint32{uint32(i)}) + } + global.SetInitializer(value) + return global +} + +// getGlobalBytes returns the byte slice contained in the i8 array of the +// provided global. It can recover the bytes originally created using +// makeGlobalBytes. +func getGlobalBytes(global llvm.Value) []byte { + value := global.Initializer() + buf := make([]byte, value.Type().ArrayLength()) + for i := range buf { + buf[i] = byte(llvm.ConstExtractValue(value, []uint32{uint32(i)}).ZExtValue()) + } + return buf +} + +// replaceGlobalByteWithArray replaces a global i8 in the module with a byte +// array, using a GEP to make the types match. It is a convenience function used +// for creating reflection sidetables, for example. 
+func (c *Compiler) replaceGlobalByteWithArray(name string, buf []byte) llvm.Value { + global := c.makeGlobalBytes(buf, name+".tmp") + oldGlobal := c.mod.NamedGlobal(name) + gep := llvm.ConstGEP(global, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + }) + oldGlobal.ReplaceAllUsesWith(gep) + oldGlobal.EraseFromParentAsGlobal() + global.SetName(name) + return global +} diff --git a/compiler/reflect.go b/compiler/reflect.go index 3af0bfb4..135e3b60 100644 --- a/compiler/reflect.go +++ b/compiler/reflect.go @@ -28,6 +28,8 @@ package compiler // non-basic types have their underlying type stored in a sidetable. import ( + "encoding/binary" + "go/ast" "math/big" "strings" @@ -65,11 +67,25 @@ type typeCodeAssignmentState struct { // package (or are simply unused in the compiled program). fallbackIndex int + // This is the length of an uintptr. Only used occasionally to know whether + // a given number can be encoded as a varint. + uintptrLen int + // Map of named types to their type code. It is important that named types // get unique IDs for each type. namedBasicTypes map[string]int namedNonBasicTypes map[string]int + // Map of struct types to their type code. + structTypes map[string]int + structTypesSidetable []byte + needsStructTypesSidetable bool + + // Map of struct names and tags to their name string. + structNames map[string]int + structNamesSidetable []byte + needsStructNamesSidetable bool + // This byte array is stored in reflect.namedNonBasicTypesSidetable and is // used at runtime to get details about a named non-basic type. // Entries are varints (see makeVarint below and readVarint in @@ -82,10 +98,6 @@ type typeCodeAssignmentState struct { // needsNamedTypesSidetable. namedNonBasicTypesSidetable []byte - // This is the length of an uintptr. Only used occasionally to know whether - // a given number can be encoded as a varint. 
- uintptrLen int - // This indicates whether namedNonBasicTypesSidetable needs to be created at // all. If it is false, namedNonBasicTypesSidetable will contain simple // monotonically increasing numbers. @@ -109,13 +121,17 @@ func (c *Compiler) assignTypeCodes(typeSlice typeInfoSlice) { // Assign typecodes the way the reflect package expects. state := typeCodeAssignmentState{ fallbackIndex: 1, + uintptrLen: c.uintptrType.IntTypeWidth(), namedBasicTypes: make(map[string]int), namedNonBasicTypes: make(map[string]int), - uintptrLen: c.uintptrType.IntTypeWidth(), + structTypes: make(map[string]int), + structNames: make(map[string]int), needsNamedNonBasicTypesSidetable: len(getUses(c.mod.NamedGlobal("reflect.namedNonBasicTypesSidetable"))) != 0, + needsStructTypesSidetable: len(getUses(c.mod.NamedGlobal("reflect.structTypesSidetable"))) != 0, + needsStructNamesSidetable: len(getUses(c.mod.NamedGlobal("reflect.structNamesSidetable"))) != 0, } for _, t := range typeSlice { - num := c.getTypeCodeNum(t.typecode, &state) + num := state.getTypeCodeNum(t.typecode) if num.BitLen() > c.uintptrType.IntTypeWidth() || !num.IsUint64() { // TODO: support this in some way, using a side table for example. // That's less efficient but better than not working at all. @@ -128,29 +144,26 @@ func (c *Compiler) assignTypeCodes(typeSlice typeInfoSlice) { // Only create this sidetable when it is necessary. if state.needsNamedNonBasicTypesSidetable { - // Create the sidetable and replace the old dummy global with this value. 
- globalType := llvm.ArrayType(c.ctx.Int8Type(), len(state.namedNonBasicTypesSidetable)) - global := llvm.AddGlobal(c.mod, globalType, "reflect.namedNonBasicTypesSidetable.tmp") - value := llvm.Undef(globalType) - for i, ch := range state.namedNonBasicTypesSidetable { - value = llvm.ConstInsertValue(value, llvm.ConstInt(c.ctx.Int8Type(), uint64(ch), false), []uint32{uint32(i)}) - } - global.SetInitializer(value) - oldGlobal := c.mod.NamedGlobal("reflect.namedNonBasicTypesSidetable") - gep := llvm.ConstGEP(global, []llvm.Value{ - llvm.ConstInt(c.ctx.Int32Type(), 0, false), - llvm.ConstInt(c.ctx.Int32Type(), 0, false), - }) - oldGlobal.ReplaceAllUsesWith(gep) - oldGlobal.EraseFromParentAsGlobal() - global.SetName("reflect.namedNonBasicTypesSidetable") + global := c.replaceGlobalByteWithArray("reflect.namedNonBasicTypesSidetable", state.namedNonBasicTypesSidetable) + global.SetLinkage(llvm.InternalLinkage) + global.SetUnnamedAddr(true) + } + if state.needsStructTypesSidetable { + global := c.replaceGlobalByteWithArray("reflect.structTypesSidetable", state.structTypesSidetable) + global.SetLinkage(llvm.InternalLinkage) + global.SetUnnamedAddr(true) + } + if state.needsStructNamesSidetable { + global := c.replaceGlobalByteWithArray("reflect.structNamesSidetable", state.structNamesSidetable) + global.SetLinkage(llvm.InternalLinkage) + global.SetUnnamedAddr(true) } } // getTypeCodeNum returns the typecode for a given type as expected by the // reflect package. Also see getTypeCodeName, which serializes types to a string // based on a types.Type value for this function. -func (c *Compiler) getTypeCodeNum(typecode llvm.Value, state *typeCodeAssignmentState) *big.Int { +func (state *typeCodeAssignmentState) getTypeCodeNum(typecode llvm.Value) *big.Int { // Note: see src/reflect/type.go for bit allocations. 
class, value := getClassAndValueFromTypeCode(typecode) name := "" @@ -186,7 +199,7 @@ func (c *Compiler) getTypeCodeNum(typecode llvm.Value, state *typeCodeAssignment switch class { case "chan": sub := llvm.ConstExtractValue(typecode.Initializer(), []uint32{0}) - num = c.getTypeCodeNum(sub, state) + num = state.getTypeCodeNum(sub) classNumber = 0 case "interface": num = big.NewInt(int64(state.fallbackIndex)) @@ -194,11 +207,11 @@ func (c *Compiler) getTypeCodeNum(typecode llvm.Value, state *typeCodeAssignment classNumber = 1 case "pointer": sub := llvm.ConstExtractValue(typecode.Initializer(), []uint32{0}) - num = c.getTypeCodeNum(sub, state) + num = state.getTypeCodeNum(sub) classNumber = 2 case "slice": sub := llvm.ConstExtractValue(typecode.Initializer(), []uint32{0}) - num = c.getTypeCodeNum(sub, state) + num = state.getTypeCodeNum(sub) classNumber = 3 case "array": num = big.NewInt(int64(state.fallbackIndex)) @@ -213,8 +226,7 @@ func (c *Compiler) getTypeCodeNum(typecode llvm.Value, state *typeCodeAssignment state.fallbackIndex++ classNumber = 6 case "struct": - num = big.NewInt(int64(state.fallbackIndex)) - state.fallbackIndex++ + num = big.NewInt(int64(state.getStructTypeNum(typecode))) classNumber = 7 default: panic("unknown type kind: " + class) @@ -283,36 +295,125 @@ func (state *typeCodeAssignmentState) getNonBasicNamedTypeNum(name string, value return num } -// makeVarint encodes a varint in a way that should be easy to decode. -// It may need to be decoded very quickly at runtime at low-powered processors -// so should be efficient to decode. -// The current algorithm is probably not even close to efficient, but it is easy -// to change as the format is only used inside the same program. -func makeVarint(n uint64) []byte { - // This is the reverse of what src/runtime/sidetables.go does. 
- buf := make([]byte, 0, 8) - for { - c := byte(n & 0x7f << 1) - n >>= 7 - if n != 0 { - c |= 1 +// getStructTypeNum returns the struct type number, which is an index into +// reflect.structTypesSidetable or a unique number for every struct if this +// sidetable is not needed in the to-be-compiled program. +func (state *typeCodeAssignmentState) getStructTypeNum(typecode llvm.Value) int { + name := typecode.Name() + if num, ok := state.structTypes[name]; ok { + // This struct already has an assigned type code. + return num + } + + if !state.needsStructTypesSidetable { + // We don't need struct sidetables, so we can just assign monotonically + // increasing numbers to each struct type. + num := len(state.structTypes) + state.structTypes[name] = num + return num + } + + // Get the fields this struct type contains. + // The struct number will be the start index of its data in the struct types sidetable. + structTypeGlobal := llvm.ConstExtractValue(typecode.Initializer(), []uint32{0}).Operand(0).Initializer() + numFields := structTypeGlobal.Type().ArrayLength() + + // The first data that is stored in the struct sidetable is the number of + // fields this struct contains. This is usually just a single byte because + // most structs don't contain that many fields, but make it a varint just + // to be sure. + buf := makeVarint(uint64(numFields)) + + // Iterate over every field in the struct. + // Every field is stored sequentially in the struct sidetable. Fields can + // be retrieved from this list of fields at runtime by iterating over all + // of them until the right field has been found. + // Perhaps adding some index would speed things up, but it would also make + // the sidetable bigger. + for i := 0; i < numFields; i++ { + // Collect some information about this field. 
+ field := llvm.ConstExtractValue(structTypeGlobal, []uint32{uint32(i)}) + + nameGlobal := llvm.ConstExtractValue(field, []uint32{1}) + if nameGlobal == llvm.ConstPointerNull(nameGlobal.Type()) { + panic("compiler: no name for this struct field") } - buf = append(buf, c) - if n == 0 { - break + fieldNameBytes := getGlobalBytes(nameGlobal.Operand(0)) + fieldNameNumber := state.getStructNameNumber(fieldNameBytes) + + // See whether this struct field has an associated tag, and if so, + // store that tag in the struct names sidetable (shared with field names). + tagGlobal := llvm.ConstExtractValue(field, []uint32{2}) + hasTag := false + tagNumber := 0 + if tagGlobal != llvm.ConstPointerNull(tagGlobal.Type()) { + hasTag = true + tagBytes := getGlobalBytes(tagGlobal.Operand(0)) + tagNumber = state.getStructNameNumber(tagBytes) + } + + // The 'embedded' or 'anonymous' flag for this field. + embedded := llvm.ConstExtractValue(field, []uint32{3}).ZExtValue() != 0 + + // The first byte of every field entry in the struct types sidetable is a + // flags byte with three bits in it: embedded (1), has tag (2), exported (4). + flagsByte := byte(0) + if embedded { + flagsByte |= 1 + } + if hasTag { + flagsByte |= 2 + } + if ast.IsExported(string(fieldNameBytes)) { + flagsByte |= 4 + } + buf = append(buf, flagsByte) + + // Get the type number and add it to the buffer. + // All fields have a type, so include it directly here. + typeNum := state.getTypeCodeNum(llvm.ConstExtractValue(field, []uint32{0})) + if typeNum.BitLen() > state.uintptrLen || !typeNum.IsUint64() { + // TODO: make this a regular error + panic("struct field has a type code that is too big") + } + buf = append(buf, makeVarint(typeNum.Uint64())...) + + // Add the name. + buf = append(buf, makeVarint(uint64(fieldNameNumber))...) + + // Add the tag, if there is one. + if hasTag { + buf = append(buf, makeVarint(uint64(tagNumber))...) } } - reverseBytes(buf) - return buf + + num := len(state.structTypesSidetable) + state.structTypes[name] = num + state.structTypesSidetable = append(state.structTypesSidetable, buf...) 
+ return num } -func reverseBytes(s []byte) { - // Actually copied from https://blog.golang.org/why-generics - first := 0 - last := len(s) - 1 - for first < last { - s[first], s[last] = s[last], s[first] - first++ - last-- +// getStructNameNumber stores this string (name or tag) onto the struct names +// sidetable. The format is a varint of the length of the string, followed by +// the raw bytes of the string. Identical strings are stored only once and share +// the same number for space efficiency. +func (state *typeCodeAssignmentState) getStructNameNumber(nameBytes []byte) int { + name := string(nameBytes) + if n, ok := state.structNames[name]; ok { + // This name was used before, re-use it now (for space efficiency). + return n } + // This name is not yet in the names sidetable. Add it now. + n := len(state.structNamesSidetable) + state.structNames[name] = n + state.structNamesSidetable = append(state.structNamesSidetable, makeVarint(uint64(len(nameBytes)))...) + state.structNamesSidetable = append(state.structNamesSidetable, nameBytes...) + return n +} + +// makeVarint is a small helper function that returns the bytes of the number in +// unsigned varint encoding (see binary.PutUvarint). +func makeVarint(n uint64) []byte { + buf := make([]byte, binary.MaxVarintLen64) + return buf[:binary.PutUvarint(buf, n)] } diff --git a/src/reflect/sidetables.go b/src/reflect/sidetables.go index 75fb80b1..c4012f0a 100644 --- a/src/reflect/sidetables.go +++ b/src/reflect/sidetables.go @@ -10,22 +10,48 @@ import ( //go:extern reflect.namedNonBasicTypesSidetable var namedNonBasicTypesSidetable byte -func readVarint(buf unsafe.Pointer) Type { - var t Type +//go:extern reflect.structTypesSidetable +var structTypesSidetable byte + +//go:extern reflect.structNamesSidetable +var structNamesSidetable byte + +// readStringSidetable reads a string from the given table (like +// structNamesSidetable) and returns this string. 
No heap allocation is +// necessary because it makes the string point directly to the raw bytes of the +// table. +func readStringSidetable(table unsafe.Pointer, index uintptr) string { + nameLen, namePtr := readVarint(unsafe.Pointer(uintptr(table) + index)) + return *(*string)(unsafe.Pointer(&StringHeader{ + Data: uintptr(namePtr), + Len: nameLen, + })) +} + +// readVarint decodes a varint as used in the encoding/binary package. +// It has an input pointer and returns the read varint and the pointer +// incremented to the next field in the data structure, just after the varint. +// +// Details: +// https://github.com/golang/go/blob/e37a1b1c/src/encoding/binary/varint.go#L7-L25 +func readVarint(buf unsafe.Pointer) (uintptr, unsafe.Pointer) { + var n uintptr + shift := uintptr(0) for { - // Read the next byte. + // Read the next byte in the buffer. c := *(*byte)(buf) - // Add this byte to the type code. The upper 7 bits are the value. - t = t<<7 | Type(c>>1) - - // Check whether this is the last byte of this varint. The lower bit - // indicates whether any bytes follow. - if c%1 == 0 { - return t - } + // Decode the bits from this byte and add them to the output number. + n |= uintptr(c&0x7f) << shift + shift += 7 // Increment the buf pointer (pointer arithmetic!). buf = unsafe.Pointer(uintptr(buf) + 1) + + // Check whether this is the last byte of this varint. The upper bit + // (msb) indicates whether any bytes follow. + if c>>7 == 0 { + return n, buf + } } } diff --git a/src/reflect/type.go b/src/reflect/type.go index 56ce0857..91f845e2 100644 --- a/src/reflect/type.go +++ b/src/reflect/type.go @@ -145,23 +145,102 @@ func (t Type) Kind() Kind { func (t Type) Elem() Type { switch t.Kind() { case Chan, Ptr, Slice: - // Look at the 'n' bit in the type code (see the top of this file) to - // see whether this is a named type. - if (t>>4)%2 != 0 { - // This is a named type. The element type is stored in a sidetable. 
- namedTypeNum := t >> 5 - return readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&namedNonBasicTypesSidetable)) + uintptr(namedTypeNum))) - } - // Not a named type, so the element type is stored directly in the type - // code. - return t >> 5 + return t.stripPrefix() default: // not implemented: Array, Map panic("unimplemented: (reflect.Type).Elem()") } } +// stripPrefix removes the "prefix" (the lowest 5 bits of the type code) from +// the type code. If this is a named type, it will resolve the underlying type +// (which is the data for this named type). If it is not, the lower bits are +// simply shifted off. +// +// The behavior is only defined for non-basic types. +func (t Type) stripPrefix() Type { + // Look at the 'n' bit in the type code (see the top of this file) to see + // whether this is a named type. + if (t>>4)%2 != 0 { + // This is a named type. The data is stored in a sidetable. + namedTypeNum := t >> 5 + n, _ := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&namedNonBasicTypesSidetable)) + uintptr(namedTypeNum))) + return Type(n) + } + // Not a named type, so the value is stored directly in the type code. + return t >> 5 +} + +// Field returns a description of the i'th field of this struct type. It panics +// if t is not a struct type or if i is out of range. func (t Type) Field(i int) StructField { - panic("unimplemented: (reflect.Type).Field()") + if t.Kind() != Struct { + panic(&TypeError{"Field"}) + } + structIdentifier := t.stripPrefix() + + numField, p := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&structTypesSidetable)) + uintptr(structIdentifier))) + if uint(i) >= uint(numField) { + panic("reflect: field index out of range") + } + + // Iterate over every field in the struct and update the StructField each + // time, until the target field has been reached. This is very much not + // efficient, but it is easy to implement. + // Adding a jump table at the start to jump to the field directly would + // make this much faster, but that would also impact code size. 
+ field := StructField{} + offset := uintptr(0) + for fieldNum := 0; fieldNum <= i; fieldNum++ { + // Read some flags of this field, like whether the field is an + // embedded field. + flagsByte := *(*uint8)(p) + p = unsafe.Pointer(uintptr(p) + 1) + + // Read the type of this struct field. + var fieldType uintptr + fieldType, p = readVarint(p) + field.Type = Type(fieldType) + + // Move Offset forward to align it to this field's alignment. + // Assume alignment is a power of two. + offset = align(offset, uintptr(field.Type.Align())) + field.Offset = offset + offset += field.Type.Size() // starting (unaligned) offset for next field + + // Read the field name. + var nameNum uintptr + nameNum, p = readVarint(p) + field.Name = readStringSidetable(unsafe.Pointer(&structNamesSidetable), nameNum) + + // The first bit in the flagsByte indicates whether this is an embedded + // field. + field.Anonymous = flagsByte&1 != 0 + + // The second bit indicates whether there is a tag. + if flagsByte&2 != 0 { + // There is a tag. + var tagNum uintptr + tagNum, p = readVarint(p) + field.Tag = readStringSidetable(unsafe.Pointer(&structNamesSidetable), tagNum) + } else { + // There is no tag. + field.Tag = "" + } + + // The third bit indicates whether this field is exported. + if flagsByte&4 != 0 { + // This field is exported. + field.PkgPath = "" + } else { + // This field is unexported. + // TODO: list the real package path here. Storing it should not + // significantly impact binary size as there is only a limited + // number of packages in any program. + field.PkgPath = "" + } + } + + return field } // Bits returns the number of bits that this type uses. It is only valid for @@ -179,10 +258,19 @@ func (t Type) Len() int { panic("unimplemented: (reflect.Type).Len()") } +// NumField returns the number of fields of a struct type. It panics for other +// type kinds. 
func (t Type) NumField() int { - panic("unimplemented: (reflect.Type).NumField()") + if t.Kind() != Struct { + panic(&TypeError{"NumField"}) + } + structIdentifier := t.stripPrefix() + n, _ := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&structTypesSidetable)) + uintptr(structIdentifier))) + return int(n) } +// Size returns the size in bytes of a given type. It is similar to +// unsafe.Sizeof. func (t Type) Size() uintptr { switch t.Kind() { case Bool, Int8, Uint8: @@ -211,6 +299,15 @@ func (t Type) Size() uintptr { return unsafe.Sizeof(uintptr(0)) case Slice: return unsafe.Sizeof(SliceHeader{}) + case Interface: + return unsafe.Sizeof(interfaceHeader{}) + case Struct: + numField := t.NumField() + if numField == 0 { + return 0 + } + lastField := t.Field(numField - 1) + return align(lastField.Offset+lastField.Type.Size(), uintptr(t.Align())) // include trailing padding: a struct's size is a multiple of its alignment default: panic("unimplemented: size of type") } @@ -246,6 +343,18 @@ func (t Type) Align() int { return int(unsafe.Alignof(uintptr(0))) case Slice: return int(unsafe.Alignof(SliceHeader{})) + case Interface: + return int(unsafe.Alignof(interfaceHeader{})) + case Struct: + numField := t.NumField() + alignment := 1 + for i := 0; i < numField; i++ { + fieldAlignment := t.Field(i).Type.Align() + if fieldAlignment > alignment { + alignment = fieldAlignment + } + } + return alignment default: panic("unimplemented: alignment of type") } @@ -269,9 +378,19 @@ func (t Type) AssignableTo(u Type) bool { return false } +// A StructField describes a single field in a struct. type StructField struct { + // Name indicates the field name. Name string - Type Type + + // PkgPath is the package path where the struct containing this field is + // declared for unexported fields, or the empty string for exported fields. 
+ PkgPath string + + Type Type + Tag string + Anonymous bool + Offset uintptr } // TypeError is the error that is used in a panic when invoking a method on a diff --git a/src/reflect/value.go b/src/reflect/value.go index 950e358b..3f36b9ff 100644 --- a/src/reflect/value.go +++ b/src/reflect/value.go @@ -4,10 +4,28 @@ import ( "unsafe" ) +type valueFlags uint8 + +// Flags list some useful flags that contain some extra information not +// contained in an interface{} directly, like whether this value was exported at +// all (it is possible to read unexported fields using reflection, but it is not +// possible to modify them). +const ( + valueFlagIndirect valueFlags = 1 << iota + valueFlagExported +) + type Value struct { typecode Type value unsafe.Pointer - indirect bool + flags valueFlags +} + +// isIndirect returns whether the value pointer in this Value is always a +// pointer to the value. If it is false, it is only a pointer to the value if +// the value is bigger than a pointer. +func (v Value) isIndirect() bool { + return v.flags&valueFlagIndirect != 0 } func Indirect(v Value) Value { @@ -22,6 +40,7 @@ func ValueOf(i interface{}) Value { return Value{ typecode: v.typecode, value: v.value, + flags: valueFlagExported, } } @@ -30,7 +49,7 @@ func (v Value) Interface() interface{} { typecode: v.typecode, value: v.value, } - if v.indirect && v.Type().Size() <= unsafe.Sizeof(uintptr(0)) { + if v.isIndirect() && v.Type().Size() <= unsafe.Sizeof(uintptr(0)) { // Value was indirect but must be put back directly in the interface // value. 
var value uintptr @@ -109,13 +128,13 @@ func (v Value) Addr() Value { } func (v Value) CanSet() bool { - return v.indirect + return v.flags&(valueFlagExported|valueFlagIndirect) == valueFlagExported|valueFlagIndirect } func (v Value) Bool() bool { switch v.Kind() { case Bool: - if v.indirect { + if v.isIndirect() { return *((*bool)(v.value)) } else { return uintptr(v.value) != 0 @@ -128,31 +147,31 @@ func (v Value) Bool() bool { func (v Value) Int() int64 { switch v.Kind() { case Int: - if v.indirect || unsafe.Sizeof(int(0)) > unsafe.Sizeof(uintptr(0)) { + if v.isIndirect() || unsafe.Sizeof(int(0)) > unsafe.Sizeof(uintptr(0)) { return int64(*(*int)(v.value)) } else { return int64(int(uintptr(v.value))) } case Int8: - if v.indirect { + if v.isIndirect() { return int64(*(*int8)(v.value)) } else { return int64(int8(uintptr(v.value))) } case Int16: - if v.indirect { + if v.isIndirect() { return int64(*(*int16)(v.value)) } else { return int64(int16(uintptr(v.value))) } case Int32: - if v.indirect || unsafe.Sizeof(int32(0)) > unsafe.Sizeof(uintptr(0)) { + if v.isIndirect() || unsafe.Sizeof(int32(0)) > unsafe.Sizeof(uintptr(0)) { return int64(*(*int32)(v.value)) } else { return int64(int32(uintptr(v.value))) } case Int64: - if v.indirect || unsafe.Sizeof(int64(0)) > unsafe.Sizeof(uintptr(0)) { + if v.isIndirect() || unsafe.Sizeof(int64(0)) > unsafe.Sizeof(uintptr(0)) { return int64(*(*int64)(v.value)) } else { return int64(int64(uintptr(v.value))) @@ -165,37 +184,37 @@ func (v Value) Int() int64 { func (v Value) Uint() uint64 { switch v.Kind() { case Uintptr: - if v.indirect { + if v.isIndirect() { return uint64(*(*uintptr)(v.value)) } else { return uint64(uintptr(v.value)) } case Uint8: - if v.indirect { + if v.isIndirect() { return uint64(*(*uint8)(v.value)) } else { return uint64(uintptr(v.value)) } case Uint16: - if v.indirect { + if v.isIndirect() { return uint64(*(*uint16)(v.value)) } else { return uint64(uintptr(v.value)) } case Uint: - if v.indirect || 
unsafe.Sizeof(uint(0)) > unsafe.Sizeof(uintptr(0)) { + if v.isIndirect() || unsafe.Sizeof(uint(0)) > unsafe.Sizeof(uintptr(0)) { return uint64(*(*uint)(v.value)) } else { return uint64(uintptr(v.value)) } case Uint32: - if v.indirect || unsafe.Sizeof(uint32(0)) > unsafe.Sizeof(uintptr(0)) { + if v.isIndirect() || unsafe.Sizeof(uint32(0)) > unsafe.Sizeof(uintptr(0)) { return uint64(*(*uint32)(v.value)) } else { return uint64(uintptr(v.value)) } case Uint64: - if v.indirect || unsafe.Sizeof(uint64(0)) > unsafe.Sizeof(uintptr(0)) { + if v.isIndirect() || unsafe.Sizeof(uint64(0)) > unsafe.Sizeof(uintptr(0)) { return uint64(*(*uint64)(v.value)) } else { return uint64(uintptr(v.value)) @@ -208,7 +227,7 @@ func (v Value) Uint() uint64 { func (v Value) Float() float64 { switch v.Kind() { case Float32: - if v.indirect || unsafe.Sizeof(float32(0)) > unsafe.Sizeof(uintptr(0)) { + if v.isIndirect() || unsafe.Sizeof(float32(0)) > unsafe.Sizeof(uintptr(0)) { // The float is stored as an external value on systems with 16-bit // pointers. return float64(*(*float32)(v.value)) @@ -218,7 +237,7 @@ func (v Value) Float() float64 { return float64(*(*float32)(unsafe.Pointer(&v.value))) } case Float64: - if v.indirect || unsafe.Sizeof(float64(0)) > unsafe.Sizeof(uintptr(0)) { + if v.isIndirect() || unsafe.Sizeof(float64(0)) > unsafe.Sizeof(uintptr(0)) { // For systems with 16-bit and 32-bit pointers. return *(*float64)(v.value) } else { @@ -234,7 +253,7 @@ func (v Value) Float() float64 { func (v Value) Complex() complex128 { switch v.Kind() { case Complex64: - if v.indirect || unsafe.Sizeof(complex64(0)) > unsafe.Sizeof(uintptr(0)) { + if v.isIndirect() || unsafe.Sizeof(complex64(0)) > unsafe.Sizeof(uintptr(0)) { // The complex number is stored as an external value on systems with // 16-bit and 32-bit pointers. return complex128(*(*complex64)(v.value)) @@ -295,15 +314,17 @@ func (v Value) Cap() int { } } +// NumField returns the number of fields of this struct. 
It panics for other +// value types. func (v Value) NumField() int { - panic("unimplemented: (reflect.Value).NumField()") + return v.Type().NumField() } func (v Value) Elem() Value { switch v.Kind() { case Ptr: ptr := v.value - if v.indirect { + if v.isIndirect() { ptr = *(*unsafe.Pointer)(ptr) } if ptr == nil { @@ -312,15 +333,77 @@ func (v Value) Elem() Value { return Value{ typecode: v.Type().Elem(), value: ptr, - indirect: true, + flags: v.flags | valueFlagIndirect, } default: // not implemented: Interface panic(&ValueError{"Elem"}) } } +// Field returns the value of the i'th field of this struct. func (v Value) Field(i int) Value { - panic("unimplemented: (reflect.Value).Field()") + structField := v.Type().Field(i) + flags := v.flags + if structField.PkgPath != "" { + // The fact that PkgPath is present means that this field is not + // exported. + flags &^= valueFlagExported + } + + size := v.Type().Size() + fieldSize := structField.Type.Size() + if v.isIndirect() || fieldSize > unsafe.Sizeof(uintptr(0)) { + // v.value was already a pointer to the value and it should stay that + // way. + return Value{ + flags: flags, + typecode: structField.Type, + value: unsafe.Pointer(uintptr(v.value) + structField.Offset), + } + } + + // The fieldSize is smaller than uintptr, which means that the value will + // have to be stored directly in the interface value. + + if fieldSize == 0 { + // The struct field is zero sized. + // This is a rare situation, but because it's undefined behavior + // to shift the size of the value (zeroing the value), handle this + // situation explicitly. + return Value{ + flags: flags, + typecode: structField.Type, + value: unsafe.Pointer(uintptr(0)), + } + } + + if size > unsafe.Sizeof(uintptr(0)) { + // The value was not stored in the interface before but will be + // afterwards, so load the value (from the correct offset) and return + // it. 
+ ptr := unsafe.Pointer(uintptr(v.value) + structField.Offset) + loadedValue := uintptr(0) + shift := uintptr(0) + for i := uintptr(0); i < fieldSize; i++ { + loadedValue |= uintptr(*(*byte)(ptr)) << shift + shift += 8 + ptr = unsafe.Pointer(uintptr(ptr) + 1) + } + return Value{ + flags: 0, + typecode: structField.Type, + value: unsafe.Pointer(loadedValue), + } + } + + // The value was already stored directly in the interface and it still + // is. Cut out the part of the value that we need. + mask := ^uintptr(0) >> ((unsafe.Sizeof(uintptr(0)) - fieldSize) * 8) + return Value{ + flags: flags, + typecode: structField.Type, + value: unsafe.Pointer((uintptr(v.value) >> (structField.Offset * 8)) & mask), + } } func (v Value) Index(i int) Value { @@ -333,7 +416,7 @@ func (v Value) Index(i int) Value { } elem := Value{ typecode: v.Type().Elem(), - indirect: true, + flags: v.flags | valueFlagIndirect, } addr := uintptr(slice.Data) + elem.Type().Size()*uintptr(i) // pointer to new value elem.value = unsafe.Pointer(addr) @@ -385,15 +468,13 @@ func (it *MapIter) Next() bool { } func (v Value) Set(x Value) { - if !v.indirect { - panic("reflect: value is not addressable") - } + v.checkAddressable() if !v.Type().AssignableTo(x.Type()) { panic("reflect: cannot set") } size := v.Type().Size() xptr := x.value - if size <= unsafe.Sizeof(uintptr(0)) && !x.indirect { + if size <= unsafe.Sizeof(uintptr(0)) && !x.isIndirect() { value := x.value xptr = unsafe.Pointer(&value) } @@ -401,9 +482,7 @@ func (v Value) Set(x Value) { } func (v Value) SetBool(x bool) { - if !v.indirect { - panic("reflect: value is not addressable") - } + v.checkAddressable() switch v.Kind() { case Bool: *(*bool)(v.value) = x @@ -413,9 +492,7 @@ func (v Value) SetBool(x bool) { } func (v Value) SetInt(x int64) { - if !v.indirect { - panic("reflect: value is not addressable") - } + v.checkAddressable() switch v.Kind() { case Int: *(*int)(v.value) = int(x) @@ -433,9 +510,7 @@ func (v Value) SetInt(x int64) { } func 
(v Value) SetUint(x uint64) { - if !v.indirect { - panic("reflect: value is not addressable") - } + v.checkAddressable() switch v.Kind() { case Uint: *(*uint)(v.value) = uint(x) @@ -455,9 +530,7 @@ func (v Value) SetUint(x uint64) { } func (v Value) SetFloat(x float64) { - if !v.indirect { - panic("reflect: value is not addressable") - } + v.checkAddressable() switch v.Kind() { case Float32: *(*float32)(v.value) = float32(x) @@ -469,9 +542,7 @@ func (v Value) SetFloat(x float64) { } func (v Value) SetComplex(x complex128) { - if !v.indirect { - panic("reflect: value is not addressable") - } + v.checkAddressable() switch v.Kind() { case Complex64: *(*complex64)(v.value) = complex64(x) @@ -483,9 +554,7 @@ func (v Value) SetComplex(x complex128) { } func (v Value) SetString(x string) { - if !v.indirect { - panic("reflect: value is not addressable") - } + v.checkAddressable() switch v.Kind() { case String: *(*string)(v.value) = x @@ -494,6 +563,12 @@ func (v Value) SetString(x string) { } } +func (v Value) checkAddressable() { + if !v.isIndirect() { + panic("reflect: value is not addressable") + } +} + func MakeSlice(typ Type, len, cap int) Value { panic("unimplemented: reflect.MakeSlice()") } diff --git a/src/runtime/interface.go b/src/runtime/interface.go index daa9d781..c4032db5 100644 --- a/src/runtime/interface.go +++ b/src/runtime/interface.go @@ -50,10 +50,20 @@ type typecodeID struct { // * named type: the underlying type // * interface: null // * chan/pointer/slice: the element type - // * array/func/map/struct: TODO + // * struct: GEP of structField array (to typecode field) + // * array/func/map: TODO references *typecodeID } +// structField is used by the compiler to pass information to the interface +// lowering pass. It is not used in the final binary. 
+type structField struct { + typecode *typecodeID // type of this struct field + name *uint8 // pointer to char array + tag *uint8 // pointer to char array, or nil + embedded bool +} + // Pseudo type used before interface lowering. By using a struct instead of a // function call, this is simpler to reason about during init interpretation // than a function call. Also, by keeping the method set around it is easier to diff --git a/testdata/reflect.go b/testdata/reflect.go index 7c8a10d1..fc28b189 100644 --- a/testdata/reflect.go +++ b/testdata/reflect.go @@ -11,6 +11,17 @@ type ( myslice2 []myint mychan chan int myptr *int + point struct { + X int16 + Y int16 + } + mystruct struct { + n int `foo:"bar"` + some point + zero struct{} + buf []byte + Buf []byte + } ) func main() { @@ -86,6 +97,12 @@ func main() { // structs struct{}{}, struct{ error }{}, + struct { + a uint8 + b int16 + c int8 + }{42, 321, 123}, + mystruct{5, point{-5, 3}, struct{}{}, []byte{'G', 'o'}, []byte{'X'}}, } { showValue(reflect.ValueOf(v), "") } @@ -291,7 +308,14 @@ func showValue(rv reflect.Value, indent string) { showValue(rv.Index(i), indent+" ") } case reflect.Struct: - println(indent + " struct") + println(indent+" struct:", rt.NumField()) + for i := 0; i < rv.NumField(); i++ { + field := rt.Field(i) + println(indent+" field:", i, field.Name) + println(indent+" tag:", field.Tag) + println(indent+" embedded:", field.Anonymous) + showValue(rv.Field(i), indent+" ") + } default: println(indent + " unknown type kind!") } diff --git a/testdata/reflect.txt b/testdata/reflect.txt index f43a6224..9a3e3977 100644 --- a/testdata/reflect.txt +++ b/testdata/reflect.txt @@ -217,9 +217,82 @@ reflect type: map map nil: false reflect type: struct - struct + struct: 0 reflect type: struct - struct + struct: 1 + field: 0 error + tag: + embedded: true + reflect type: interface + interface + nil: true +reflect type: struct + struct: 3 + field: 0 a + tag: + embedded: false + reflect type: uint8 + uint: 42 + 
field: 1 b + tag: + embedded: false + reflect type: int16 + int: 321 + field: 2 c + tag: + embedded: false + reflect type: int8 + int: 123 +reflect type: struct + struct: 5 + field: 0 n + tag: foo:"bar" + embedded: false + reflect type: int + int: 5 + field: 1 some + tag: + embedded: false + reflect type: struct + struct: 2 + field: 0 X + tag: + embedded: false + reflect type: int16 + int: -5 + field: 1 Y + tag: + embedded: false + reflect type: int16 + int: 3 + field: 2 zero + tag: + embedded: false + reflect type: struct + struct: 0 + field: 3 buf + tag: + embedded: false + reflect type: slice + slice: uint8 2 2 + pointer: true + nil: false + indexing: 0 + reflect type: uint8 + uint: 71 + indexing: 1 + reflect type: uint8 + uint: 111 + field: 4 Buf + tag: + embedded: false + reflect type: slice + slice: uint8 1 1 + pointer: true + nil: false + indexing: 0 + reflect type: uint8 settable=true + uint: 88 sizes: int8 1 8