From d2b3a5486c2c7c902f1f021638f327571373c381 Mon Sep 17 00:00:00 2001 From: Ayke van Laethem Date: Thu, 11 Apr 2019 23:14:10 +0200 Subject: [PATCH] cgo: implement C unions Unions are somewhat hard to implement in Go because they are not a native type. But it is actually possible with some compiler magic. This commit inserts a special "C union" field at the start of a struct to indicate that it is a union. As such a field cannot be written directly in Go, this is a useful way to distinguish structs and unions. --- compiler/compiler.go | 62 +++++++++++++++++++++++++++++++++++++++---- compiler/interface.go | 4 +++ compiler/sizes.go | 39 ++++++++++++++++++++++++--- loader/libclang.go | 50 +++++++++++++++++++++++++++------- testdata/cgo/main.c | 19 ++++++++++++- testdata/cgo/main.go | 25 ++++++++++++++++- testdata/cgo/main.h | 19 ++++++++++--- testdata/cgo/out.txt | 10 +++++++ 8 files changed, 206 insertions(+), 22 deletions(-) diff --git a/compiler/compiler.go b/compiler/compiler.go index 46a991d6..d1fabd9a 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -546,6 +546,33 @@ func (c *Compiler) getLLVMType(goType types.Type) (llvm.Type, error) { } members[i] = member } + if len(members) > 2 && typ.Field(0).Name() == "C union" { + // Not a normal struct but a C union emitted by cgo. + // Such a field name cannot be entered in regular Go code, this must + // be manually inserted in the AST so this is safe. 
+ maxAlign := 0 + maxSize := uint64(0) + mainType := members[0] + for _, member := range members { + align := c.targetData.ABITypeAlignment(member) + size := c.targetData.TypeAllocSize(member) + if align > maxAlign { + maxAlign = align + mainType = member + } else if align == maxAlign && size > maxSize { + maxAlign = align + maxSize = size + mainType = member + } else if size > maxSize { + maxSize = size + } + } + members = []llvm.Type{mainType} + mainTypeSize := c.targetData.TypeAllocSize(mainType) + if mainTypeSize < maxSize { + members = append(members, llvm.ArrayType(c.ctx.Int8Type(), int(maxSize-mainTypeSize))) + } + } return c.ctx.StructType(members, false), nil case *types.Tuple: members := make([]llvm.Type, typ.Len()) @@ -1592,6 +1619,19 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if err != nil { return llvm.Value{}, err } + if s := expr.X.Type().Underlying().(*types.Struct); s.NumFields() > 2 && s.Field(0).Name() == "C union" { + // Extract a field from a CGo union. + // This could be done directly, but as this is a very infrequent + // operation it's much easier to bitcast it through an alloca. 
+ resultType, err := c.getLLVMType(expr.Type()) + if err != nil { + return llvm.Value{}, err + } + alloca := c.builder.CreateAlloca(value.Type(), "") + c.builder.CreateStore(value, alloca) + bitcast := c.builder.CreateBitCast(alloca, llvm.PointerType(resultType, 0), "") + return c.builder.CreateLoad(bitcast, ""), nil + } result := c.builder.CreateExtractValue(value, expr.Field, "") return result, nil case *ssa.FieldAddr: @@ -1599,16 +1639,28 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) { if err != nil { return llvm.Value{}, err } - indices := []llvm.Value{ - llvm.ConstInt(c.ctx.Int32Type(), 0, false), - llvm.ConstInt(c.ctx.Int32Type(), uint64(expr.Field), false), - } // Check for nil pointer before calculating the address, from the spec: // > For an operand x of type T, the address operation &x generates a // > pointer of type *T to x. [...] If the evaluation of x would cause a // > run-time panic, then the evaluation of &x does too. c.emitNilCheck(frame, val, "gep") - return c.builder.CreateGEP(val, indices, ""), nil + if s := expr.X.Type().(*types.Pointer).Elem().Underlying().(*types.Struct); s.NumFields() > 2 && s.Field(0).Name() == "C union" { + // This is not a regular struct but actually a union. + // That simplifies things, as we can just bitcast the pointer to the + // right type. + ptrType, err := c.getLLVMType(expr.Type()) + if err != nil { + return llvm.Value{}, nil + } + return c.builder.CreateBitCast(val, ptrType, ""), nil + } else { + // Do a GEP on the pointer to get the field address. 
+ indices := []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), uint64(expr.Field), false), + } + return c.builder.CreateGEP(val, indices, ""), nil + } case *ssa.Function: fn := c.ir.GetFunction(expr) if fn.IsExported() { diff --git a/compiler/interface.go b/compiler/interface.go index c6cbad5c..ad55945f 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -170,6 +170,10 @@ func getTypeCodeName(t types.Type) string { return "slice:" + name + getTypeCodeName(t.Elem()) case *types.Struct: elems := make([]string, t.NumFields()) + if t.NumFields() > 2 && t.Field(0).Name() == "C union" { + // TODO: report this as a normal error instead of panicking. + panic("cgo unions are not allowed in interfaces") + } for i := 0; i < t.NumFields(); i++ { elems[i] = getTypeCodeName(t.Field(i).Type()) } diff --git a/compiler/sizes.go b/compiler/sizes.go index 97f0c929..0d28712e 100644 --- a/compiler/sizes.go +++ b/compiler/sizes.go @@ -63,6 +63,12 @@ func (s *StdSizes) Alignof(T types.Type) int64 { func (s *StdSizes) Offsetsof(fields []*types.Var) []int64 { offsets := make([]int64, len(fields)) + if len(fields) > 1 && fields[0].Name() == "C union" { + // This struct contains the magic "C union" field which indicates that + // this is actually a union from CGo. + // All fields in the union start at 0 so return that. + return offsets // all fields are still set to 0 + } var o int64 for i, f := range fields { a := s.Alignof(f.Type()) @@ -125,11 +131,38 @@ func (s *StdSizes) Sizeof(T types.Type) int64 { return 0 } fields := make([]*types.Var, t.NumFields()) + maxAlign := int64(1) for i := range fields { - fields[i] = t.Field(i) + field := t.Field(i) + fields[i] = field + al := s.Alignof(field.Type()) + if al > maxAlign { + maxAlign = al + } + } + if fields[0].Name() == "C union" { + // Magic field that indicates this is a CGo union and not a struct. 
+ // The size is the biggest element, aligned to the element with the + // biggest alignment. This is not necessarily the same, for example + // in the following union: + // union { int32_t l; int16_t s[3] } + maxSize := int64(0) + for _, field := range fields[1:] { + si := s.Sizeof(field.Type()) + if si > maxSize { + maxSize = si + } + } + return align(maxSize, maxAlign) + } else { + // This is a regular struct. + // Pick the size that fits this struct and add some alignment. Some + // structs have some extra padding at the end which should also be + // taken care of: + // struct { int32 n; byte b } + offsets := s.Offsetsof(fields) + return align(offsets[n-1]+s.Sizeof(fields[n-1].Type()), maxAlign) } - offsets := s.Offsetsof(fields) - return offsets[n-1] + s.Sizeof(fields[n-1].Type()) case *types.Interface: return s.PtrSize * 2 case *types.Pointer: diff --git a/loader/libclang.go b/loader/libclang.go index ea459085..a7495432 100644 --- a/loader/libclang.go +++ b/loader/libclang.go @@ -306,18 +306,50 @@ func (info *fileInfo) makeASTType(typ C.CXType) ast.Expr { return info.makeASTType(underlying) case C.CXType_Record: cursor := C.clang_getTypeDeclaration(typ) + fieldList := &ast.FieldList{ + Opening: info.importCPos, + Closing: info.importCPos, + } + ref := refMap.Put(struct { + fieldList *ast.FieldList + info *fileInfo + }{fieldList, info}) + defer refMap.Remove(ref) + C.clang_visitChildren(cursor, C.CXCursorVisitor(C.tinygo_clang_struct_visitor), C.CXClientData(uintptr(ref))) switch C.clang_getCursorKind(cursor) { case C.CXCursor_StructDecl: - fieldList := &ast.FieldList{ - Opening: info.importCPos, - Closing: info.importCPos, + return &ast.StructType{ + Struct: info.importCPos, + Fields: fieldList, + } + case C.CXCursor_UnionDecl: + if len(fieldList.List) > 1 { + // Insert a special field at the front (of zero width) as a + // marker that this struct is actually a union. 
This is done + // by giving the field a name that cannot be expressed directly + // in Go. + // Other parts of the compiler look at the first element in a + // struct (of size > 2) to know whether this is a union. + // Note that we don't have to insert it for single-element + // unions as they're basically equivalent to a struct. + unionMarker := &ast.Field{ + Type: &ast.StructType{ + Struct: info.importCPos, + }, + } + unionMarker.Names = []*ast.Ident{ + &ast.Ident{ + NamePos: info.importCPos, + Name: "C union", + Obj: &ast.Object{ + Kind: ast.Var, + Name: "C union", + Decl: unionMarker, + }, + }, + } + fieldList.List = append([]*ast.Field{unionMarker}, fieldList.List...) } - ref := refMap.Put(struct { - fieldList *ast.FieldList - info *fileInfo - }{fieldList, info}) - defer refMap.Remove(ref) - C.clang_visitChildren(cursor, C.CXCursorVisitor(C.tinygo_clang_struct_visitor), C.CXClientData(uintptr(ref))) return &ast.StructType{ Struct: info.importCPos, Fields: fieldList, diff --git a/testdata/cgo/main.c b/testdata/cgo/main.c index 78a065f9..cec89e47 100644 --- a/testdata/cgo/main.c +++ b/testdata/cgo/main.c @@ -8,8 +8,11 @@ double globalDouble = 3.2; _Complex float globalComplexFloat = 4.1+3.3i; _Complex double globalComplexDouble = 4.2+3.4i; _Complex double globalComplexLongDouble = 4.3+3.5i; -collection_t globalStruct = {256, -123456, 3.14}; +collection_t globalStruct = {256, -123456, 3.14, 88}; +int globalStructSize = sizeof(globalStruct); short globalArray[3] = {5, 6, 7}; +joined_t globalUnion; +int globalUnionSize = sizeof(globalUnion); int fortytwo() { return 42; @@ -26,3 +29,17 @@ int doCallback(int a, int b, binop_t callback) { void store(int value, int *ptr) { *ptr = value; } + +void unionSetShort(short s) { + globalUnion.s = s; +} + +void unionSetFloat(float f) { + globalUnion.f = f; +} + +void unionSetData(short f0, short f1, short f2) { + globalUnion.data[0] = 5; + globalUnion.data[1] = 8; + globalUnion.data[2] = 1; +} diff --git a/testdata/cgo/main.go 
b/testdata/cgo/main.go index 327d38dd..3f1da8b6 100644 --- a/testdata/cgo/main.go +++ b/testdata/cgo/main.go @@ -9,6 +9,10 @@ import "C" import "unsafe" +func (s C.myint) Int() int { + return int(s) +} + func main() { println("fortytwo:", C.fortytwo()) println("add:", C.add(C.int(3), 5)) @@ -36,9 +40,28 @@ func main() { println("complex float:", C.globalComplexFloat) println("complex double:", C.globalComplexDouble) println("complex long double:", C.globalComplexLongDouble) - println("struct:", C.globalStruct.s, C.globalStruct.l, C.globalStruct.f) + + // complex types + println("struct:", C.int(unsafe.Sizeof(C.globalStruct)) == C.globalStructSize, C.globalStruct.s, C.globalStruct.l, C.globalStruct.f) var _ [3]C.short = C.globalArray println("array:", C.globalArray[0], C.globalArray[1], C.globalArray[2]) + println("union:", C.int(unsafe.Sizeof(C.globalUnion)) == C.globalUnionSize) + C.unionSetShort(22) + println("union s:", C.globalUnion.s) + C.unionSetFloat(3.14) + println("union f:", C.globalUnion.f) + C.unionSetData(5, 8, 1) + println("union global data:", C.globalUnion.data[0], C.globalUnion.data[1], C.globalUnion.data[2]) + println("union field:", printUnion(C.globalUnion).f) +} + +func printUnion(union C.joined_t) C.joined_t { + println("union local data: ", union.data[0], union.data[1], union.data[2]) + union.s = -33 + println("union s method:", union.s.Int(), union.data[0] == 5) + union.f = 6.28 + println("union f:", union.f) + return union } //export mul diff --git a/testdata/cgo/main.h b/testdata/cgo/main.h index 61b2b343..6330b223 100644 --- a/testdata/cgo/main.h +++ b/testdata/cgo/main.h @@ -6,11 +6,21 @@ typedef int * intPointer; void store(int value, int *ptr); typedef struct collection { - short s; - long l; - float f; + short s; + long l; + float f; + unsigned char c; } collection_t; +typedef union joined { + myint s; + float f; + short data[3]; +} joined_t; +void unionSetShort(short s); +void unionSetFloat(float f); +void unionSetData(short f0, 
short f1, short f2); + // test globals extern int global; extern _Bool globalBool; @@ -21,7 +31,10 @@ extern _Complex float globalComplexFloat; extern _Complex double globalComplexDouble; extern _Complex double globalComplexLongDouble; extern collection_t globalStruct; +extern int globalStructSize; extern short globalArray[3]; +extern joined_t globalUnion; +extern int globalUnionSize; // test duplicate definitions int add(int a, int b); diff --git a/testdata/cgo/out.txt b/testdata/cgo/out.txt index 79f8cc63..6fd5ee86 100644 --- a/testdata/cgo/out.txt +++ b/testdata/cgo/out.txt @@ -14,3 +14,13 @@ double: +3.200000e+000 complex float: (+4.100000e+000+3.300000e+000i) complex double: (+4.200000e+000+3.400000e+000i) complex long double: (+4.300000e+000+3.500000e+000i) +struct: true 256 -123456 +3.140000e+000 +array: 5 6 7 +union: true +union s: 22 +union f: +3.140000e+000 +union global data: 5 8 1 +union local data: 5 8 1 +union s method: -33 false +union f: +6.280000e+000 +union field: +6.280000e+000