Commit 018cbff438

Andrew Kelley <superjoe30@gmail.com>
2017-11-16 04:52:47
unions have a secret field for the type
See #144
1 parent f276fd0
Changed files (4)
src/all_types.hpp
@@ -1037,6 +1037,9 @@ struct TypeTableEntryEnumTag {
     LLVMValueRef name_table;
 };
 
+uint32_t type_ptr_hash(const TypeTableEntry *ptr);
+bool type_ptr_eql(const TypeTableEntry *a, const TypeTableEntry *b);
+
 struct TypeTableEntryUnion {
     AstNode *decl_node;
     ContainerLayout layout;
@@ -1044,6 +1047,8 @@ struct TypeTableEntryUnion {
     uint32_t gen_field_count;
     TypeUnionField *fields;
     bool is_invalid; // true if any fields are invalid
+    TypeTableEntry *tag_type;
+    LLVMTypeRef union_type_ref;
 
     ScopeDecls *decls_scope;
 
@@ -1057,8 +1062,13 @@ struct TypeTableEntryUnion {
     bool zero_bits_known;
     uint32_t abi_alignment; // also figured out with zero_bits pass
 
-    uint32_t size_bytes;
+    size_t gen_union_index;
+    size_t gen_tag_index;
+
+    uint32_t union_size_bytes;
     TypeTableEntry *most_aligned_union_member;
+
+    HashMap<const TypeTableEntry *, uint32_t, type_ptr_hash, type_ptr_eql> distinct_types = {};
 };
 
 struct FnGenParamInfo {
src/analyze.cpp
@@ -992,26 +992,23 @@ TypeTableEntry *get_partial_container_type(CodeGen *g, Scope *scope, ContainerKi
     TypeTableEntryId type_id = container_to_type(kind);
     TypeTableEntry *entry = new_container_type_entry(type_id, decl_node, scope);
 
-    unsigned dwarf_kind;
     switch (kind) {
         case ContainerKindStruct:
             entry->data.structure.decl_node = decl_node;
             entry->data.structure.layout = layout;
-            dwarf_kind = ZigLLVMTag_DW_structure_type();
             break;
         case ContainerKindEnum:
             entry->data.enumeration.decl_node = decl_node;
             entry->data.enumeration.layout = layout;
-            dwarf_kind = ZigLLVMTag_DW_structure_type();
             break;
         case ContainerKindUnion:
             entry->data.unionation.decl_node = decl_node;
             entry->data.unionation.layout = layout;
-            dwarf_kind = ZigLLVMTag_DW_union_type();
             break;
     }
 
     size_t line = decl_node ? decl_node->line : 0;
+    unsigned dwarf_kind = ZigLLVMTag_DW_structure_type();
 
     ImportTableEntry *import = get_scope_import(scope);
     entry->type_ref = LLVMStructCreateNamed(LLVMGetGlobalContext(), name);
@@ -1873,6 +1870,11 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) {
     uint64_t biggest_align_in_bits = 0;
     uint64_t biggest_size_in_bits = 0;
 
+    bool auto_layout = (union_type->data.unionation.layout == ContainerLayoutAuto);
+    ZigLLVMDIEnumerator **di_enumerators = allocate<ZigLLVMDIEnumerator*>(field_count);
+    auto distinct_types = &union_type->data.unionation.distinct_types;
+    distinct_types->init(4);
+
     Scope *scope = &union_type->data.unionation.decls_scope->base;
     ImportTableEntry *import = get_scope_import(scope);
 
@@ -1893,6 +1895,11 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) {
         if (!type_has_bits(field_type))
             continue;
 
+        size_t distinct_type_index = distinct_types->size();
+        if (distinct_types->put_unique(field_type, distinct_type_index) == nullptr) {
+            di_enumerators[i] = ZigLLVMCreateDebugEnumerator(g->dbuilder, buf_ptr(&field_type->name), distinct_type_index);
+        }
+
         uint64_t store_size_in_bits = 8*LLVMStoreSizeOfType(g->target_data_ref, field_type->type_ref);
         uint64_t abi_align_in_bits = 8*LLVMABIAlignmentOfType(g->target_data_ref, field_type->type_ref);
 
@@ -1919,7 +1926,7 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) {
     // unset temporary flag
     union_type->data.unionation.embedded_in_current = false;
     union_type->data.unionation.complete = true;
-    union_type->data.unionation.size_bytes = biggest_size_in_bits / 8;
+    union_type->data.unionation.union_size_bytes = biggest_size_in_bits / 8;
     union_type->data.unionation.most_aligned_union_member = most_aligned_union_member;
 
     if (union_type->data.unionation.is_invalid)
@@ -1947,8 +1954,42 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) {
 
     assert(most_aligned_union_member != nullptr);
 
-    // create llvm type for union
+    bool want_safety = (distinct_types->size() > 1) && auto_layout;
     uint64_t padding_in_bits = biggest_size_in_bits - size_of_most_aligned_member_in_bits;
+
+
+    if (!want_safety) {
+        if (padding_in_bits > 0) {
+            TypeTableEntry *u8_type = get_int_type(g, false, 8);
+            TypeTableEntry *padding_array = get_array_type(g, u8_type, padding_in_bits / 8);
+            LLVMTypeRef union_element_types[] = {
+                most_aligned_union_member->type_ref,
+                padding_array->type_ref,
+            };
+            LLVMStructSetBody(union_type->type_ref, union_element_types, 2, false);
+        } else {
+            LLVMStructSetBody(union_type->type_ref, &most_aligned_union_member->type_ref, 1, false);
+        }
+        union_type->data.unionation.union_type_ref = union_type->type_ref;
+        union_type->data.unionation.gen_tag_index = SIZE_MAX;
+        union_type->data.unionation.gen_union_index = SIZE_MAX;
+
+        assert(8*LLVMABIAlignmentOfType(g->target_data_ref, union_type->type_ref) >= biggest_align_in_bits);
+        assert(8*LLVMStoreSizeOfType(g->target_data_ref, union_type->type_ref) >= biggest_size_in_bits);
+
+        // create debug type for union
+        ZigLLVMDIType *replacement_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder,
+            ZigLLVMFileToScope(import->di_file), buf_ptr(&union_type->name),
+            import->di_file, (unsigned)(decl_node->line + 1),
+            biggest_size_in_bits, biggest_align_in_bits, 0, union_inner_di_types,
+            gen_field_count, 0, "");
+
+        ZigLLVMReplaceTemporary(g->dbuilder, union_type->di_type, replacement_di_type);
+        union_type->di_type = replacement_di_type;
+        return;
+    }
+
+    LLVMTypeRef union_type_ref;
     if (padding_in_bits > 0) {
         TypeTableEntry *u8_type = get_int_type(g, false, 8);
         TypeTableEntry *padding_array = get_array_type(g, u8_type, padding_in_bits / 8);
@@ -1956,20 +1997,87 @@ static void resolve_union_type(CodeGen *g, TypeTableEntry *union_type) {
             most_aligned_union_member->type_ref,
             padding_array->type_ref,
         };
-        LLVMStructSetBody(union_type->type_ref, union_element_types, 2, false);
+        union_type_ref = LLVMStructType(union_element_types, 2, false);
+    } else {
+        union_type_ref = most_aligned_union_member->type_ref;
+    }
+    union_type->data.unionation.union_type_ref = union_type_ref;
+
+    assert(8*LLVMABIAlignmentOfType(g->target_data_ref, union_type_ref) >= biggest_align_in_bits);
+    assert(8*LLVMStoreSizeOfType(g->target_data_ref, union_type_ref) >= biggest_size_in_bits);
+
+    // create llvm type for root struct
+    TypeTableEntry *tag_int_type = get_smallest_unsigned_int_type(g, distinct_types->size() - 1);
+    TypeTableEntry *tag_type_entry = tag_int_type;
+    union_type->data.unionation.tag_type = tag_type_entry;
+    uint64_t align_of_tag_in_bits = 8*LLVMABIAlignmentOfType(g->target_data_ref, tag_int_type->type_ref);
+
+    if (align_of_tag_in_bits >= biggest_align_in_bits) {
+        union_type->data.unionation.gen_tag_index = 0;
+        union_type->data.unionation.gen_union_index = 1;
     } else {
-        LLVMStructSetBody(union_type->type_ref, &most_aligned_union_member->type_ref, 1, false);
+        union_type->data.unionation.gen_union_index = 0;
+        union_type->data.unionation.gen_tag_index = 1;
     }
 
-    assert(8*LLVMABIAlignmentOfType(g->target_data_ref, union_type->type_ref) >= biggest_align_in_bits);
-    assert(8*LLVMStoreSizeOfType(g->target_data_ref, union_type->type_ref) >= biggest_size_in_bits);
+    LLVMTypeRef root_struct_element_types[2];
+    root_struct_element_types[union_type->data.unionation.gen_tag_index] = tag_type_entry->type_ref;
+    root_struct_element_types[union_type->data.unionation.gen_union_index] = union_type_ref;
+    LLVMStructSetBody(union_type->type_ref, root_struct_element_types, 2, false);
+
+
+    // create debug type for root struct
+
+    // create debug type for tag
+    uint64_t tag_debug_size_in_bits = 8*LLVMStoreSizeOfType(g->target_data_ref, tag_type_entry->type_ref);
+    uint64_t tag_debug_align_in_bits = 8*LLVMABIAlignmentOfType(g->target_data_ref, tag_type_entry->type_ref);
+    ZigLLVMDIType *tag_di_type = ZigLLVMCreateDebugEnumerationType(g->dbuilder,
+            ZigLLVMTypeToScope(union_type->di_type), "AnonEnum",
+            import->di_file, (unsigned)(decl_node->line + 1),
+            tag_debug_size_in_bits, tag_debug_align_in_bits, di_enumerators, distinct_types->size(),
+            tag_type_entry->di_type, "");
 
     // create debug type for union
-    ZigLLVMDIType *replacement_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder,
-            ZigLLVMFileToScope(import->di_file), buf_ptr(&union_type->name),
+    ZigLLVMDIType *union_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder,
+            ZigLLVMTypeToScope(union_type->di_type), "AnonUnion",
             import->di_file, (unsigned)(decl_node->line + 1),
             biggest_size_in_bits, biggest_align_in_bits, 0, union_inner_di_types,
             gen_field_count, 0, "");
+
+    uint64_t union_offset_in_bits = 8*LLVMOffsetOfElement(g->target_data_ref, union_type->type_ref,
+            union_type->data.unionation.gen_union_index);
+    uint64_t tag_offset_in_bits = 8*LLVMOffsetOfElement(g->target_data_ref, union_type->type_ref,
+            union_type->data.unionation.gen_tag_index);
+
+    ZigLLVMDIType *union_member_di_type = ZigLLVMCreateDebugMemberType(g->dbuilder,
+            ZigLLVMTypeToScope(union_type->di_type), "union_field",
+            import->di_file, (unsigned)(decl_node->line + 1),
+            biggest_size_in_bits,
+            biggest_align_in_bits,
+            union_offset_in_bits,
+            0, union_di_type);
+    ZigLLVMDIType *tag_member_di_type = ZigLLVMCreateDebugMemberType(g->dbuilder,
+            ZigLLVMTypeToScope(union_type->di_type), "tag_field",
+            import->di_file, (unsigned)(decl_node->line + 1),
+            tag_debug_size_in_bits,
+            tag_debug_align_in_bits,
+            tag_offset_in_bits,
+            0, tag_di_type);
+
+    ZigLLVMDIType *di_root_members[2];
+    di_root_members[union_type->data.unionation.gen_tag_index] = tag_member_di_type;
+    di_root_members[union_type->data.unionation.gen_union_index] = union_member_di_type;
+
+    uint64_t debug_size_in_bits = 8*LLVMStoreSizeOfType(g->target_data_ref, union_type->type_ref);
+    uint64_t debug_align_in_bits = 8*LLVMABISizeOfType(g->target_data_ref, union_type->type_ref);
+    ZigLLVMDIType *replacement_di_type = ZigLLVMCreateDebugStructType(g->dbuilder,
+            ZigLLVMFileToScope(import->di_file),
+            buf_ptr(&union_type->name),
+            import->di_file, (unsigned)(decl_node->line + 1),
+            debug_size_in_bits,
+            debug_align_in_bits,
+            0, nullptr, di_root_members, 2, 0, nullptr, "");
+
     ZigLLVMReplaceTemporary(g->dbuilder, union_type->di_type, replacement_di_type);
     union_type->di_type = replacement_di_type;
 }
@@ -5140,3 +5248,11 @@ TypeTableEntry *get_align_amt_type(CodeGen *g) {
     }
     return g->align_amt_type;
 }
+
+uint32_t type_ptr_hash(const TypeTableEntry *ptr) { // hash callback for HashMaps keyed by TypeTableEntry pointers
+    return hash_ptr((void*)ptr); // passes the pointer itself (not the pointee) to hash_ptr
+}
+
+bool type_ptr_eql(const TypeTableEntry *a, const TypeTableEntry *b) { // equality callback paired with type_ptr_hash
+    return a == b; // pointer identity: equal only when both refer to the same entry object
+}
src/codegen.cpp
@@ -2408,9 +2408,15 @@ static LLVMValueRef ir_render_union_field_ptr(CodeGen *g, IrExecutable *executab
 
     LLVMValueRef union_ptr = ir_llvm_value(g, instruction->union_ptr);
     LLVMTypeRef field_type_ref = LLVMPointerType(field->type_entry->type_ref, 0);
-    LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, 0, "");
-    LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, "");
 
+    if (union_type->data.unionation.gen_tag_index == SIZE_MAX) {
+        LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, 0, "");
+        LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, "");
+        return bitcasted_union_field_ptr;
+    }
+
+    LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, union_type->data.unionation.gen_union_index, "");
+    LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, "");
     return bitcasted_union_field_ptr;
 }
 
@@ -3955,7 +3961,7 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) {
             }
         case TypeTableEntryIdUnion:
             {
-                LLVMTypeRef union_type_ref = type_entry->type_ref;
+                LLVMTypeRef union_type_ref = type_entry->data.unionation.union_type_ref;
                 ConstExprValue *payload_value = const_val->data.x_union.value;
                 assert(payload_value != nullptr);
 
@@ -3964,29 +3970,48 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) {
                 }
 
                 uint64_t field_type_bytes = LLVMStoreSizeOfType(g->target_data_ref, payload_value->type->type_ref);
-                uint64_t pad_bytes = type_entry->data.unionation.size_bytes - field_type_bytes;
-
+                uint64_t pad_bytes = type_entry->data.unionation.union_size_bytes - field_type_bytes;
                 LLVMValueRef correctly_typed_value = gen_const_val(g, payload_value);
-
                 bool make_unnamed_struct = is_llvm_value_unnamed_type(payload_value->type, correctly_typed_value) ||
                     payload_value->type != type_entry->data.unionation.most_aligned_union_member;
 
-                unsigned field_count;
-                LLVMValueRef fields[2];
-                fields[0] = correctly_typed_value;
-                if (pad_bytes == 0) {
-                    field_count = 1;
-                } else {
+                LLVMValueRef union_value_ref;
+                {
+                    unsigned field_count;
+                    LLVMValueRef fields[2];
                     fields[0] = correctly_typed_value;
-                    fields[1] = LLVMGetUndef(LLVMArrayType(LLVMInt8Type(), (unsigned)pad_bytes));
-                    field_count = 2;
+                    if (pad_bytes == 0) {
+                        field_count = 1;
+                    } else {
+                        fields[0] = correctly_typed_value;
+                        fields[1] = LLVMGetUndef(LLVMArrayType(LLVMInt8Type(), (unsigned)pad_bytes));
+                        field_count = 2;
+                    }
+
+                    if (make_unnamed_struct || type_entry->data.unionation.gen_tag_index != SIZE_MAX) {
+                        union_value_ref = LLVMConstStruct(fields, field_count, false);
+                    } else {
+                        union_value_ref = LLVMConstNamedStruct(union_type_ref, fields, field_count);
+                    }
+                }
+
+                if (type_entry->data.unionation.gen_tag_index == SIZE_MAX) {
+                    return union_value_ref;
                 }
 
+                size_t distinct_type_index = type_entry->data.unionation.distinct_types.get(const_val->data.x_union.value->type);
+                LLVMValueRef tag_value = LLVMConstInt(type_entry->data.unionation.tag_type->type_ref, distinct_type_index, false);
+
+                LLVMValueRef fields[2];
+                fields[type_entry->data.unionation.gen_union_index] = union_value_ref;
+                fields[type_entry->data.unionation.gen_tag_index] = tag_value;
+
                 if (make_unnamed_struct) {
-                    return LLVMConstStruct(fields, field_count, false);
+                    return LLVMConstStruct(fields, 2, false);
                 } else {
-                    return LLVMConstNamedStruct(type_entry->type_ref, fields, field_count);
+                    return LLVMConstNamedStruct(type_entry->type_ref, fields, 2);
                 }
+
             }
         case TypeTableEntryIdEnum:
             {
test/cases/union.zig
@@ -44,3 +44,16 @@ test "basic unions" {
     foo.float = 12.34;
     assert(foo.float == 12.34);
 }
+
+
+const FooExtern = extern union { // union declared with extern layout; exercised by the test below
+    float: f64,
+    int: i32,
+};
+
+test "basic extern unions" {
+    var foo = FooExtern { .int = 1 }; // initialize through the i32 field
+    assert(foo.int == 1);
+    foo.float = 12.34; // then store through the f64 field of the same union value
+    assert(foo.float == 12.34);
+}