Commit 5bc877017e

scurest <scurest@users.noreply.github.com>
2017-06-17 18:30:29
use most_aligned_member+padding to represent enum unions
1 parent e726925
Changed files (4)
src/all_types.hpp
@@ -977,7 +977,7 @@ struct TypeTableEntryEnum {
     TypeEnumField *fields;
     bool is_invalid; // true if any fields are invalid
     TypeTableEntry *tag_type;
-    TypeTableEntry *union_type;
+    LLVMTypeRef union_type_ref;
 
     ScopeDecls *decls_scope;
 
@@ -1633,7 +1633,7 @@ struct ScopeDecls {
 struct ScopeBlock {
     Scope base;
 
-    HashMap<Buf *, LabelTableEntry *, buf_hash, buf_eql_buf> label_table; 
+    HashMap<Buf *, LabelTableEntry *, buf_hash, buf_eql_buf> label_table;
     bool safety_off;
     AstNode *safety_set_node;
     bool fast_math_off;
src/analyze.cpp
@@ -1245,9 +1245,10 @@ static void resolve_enum_type(CodeGen *g, TypeTableEntry *enum_type) {
     uint32_t gen_field_count = enum_type->data.enumeration.gen_field_count;
     ZigLLVMDIType **union_inner_di_types = allocate<ZigLLVMDIType*>(gen_field_count);
 
-    TypeTableEntry *biggest_union_member = nullptr;
+    TypeTableEntry *most_aligned_union_member = nullptr;
+    uint64_t size_of_most_aligned_member_in_bits = 0;
     uint64_t biggest_align_in_bits = 0;
-    uint64_t biggest_union_member_size_in_bits = 0;
+    uint64_t biggest_size_in_bits = 0;
 
     Scope *scope = &enum_type->data.enumeration.decls_scope->base;
     ImportTableEntry *import = get_scope_import(scope);
@@ -1272,7 +1273,7 @@ static void resolve_enum_type(CodeGen *g, TypeTableEntry *enum_type) {
             continue;
 
         uint64_t debug_size_in_bits = 8*LLVMStoreSizeOfType(g->target_data_ref, field_type->type_ref);
-        uint64_t debug_align_in_bits = 8*LLVMABISizeOfType(g->target_data_ref, field_type->type_ref);
+        uint64_t debug_align_in_bits = 8*LLVMABIAlignmentOfType(g->target_data_ref, field_type->type_ref);
 
         assert(debug_size_in_bits > 0);
         assert(debug_align_in_bits > 0);
@@ -1285,13 +1286,14 @@ static void resolve_enum_type(CodeGen *g, TypeTableEntry *enum_type) {
                 0,
                 0, field_type->di_type);
 
-        biggest_align_in_bits = max(biggest_align_in_bits, debug_align_in_bits);
+        biggest_size_in_bits = max(biggest_size_in_bits, debug_size_in_bits);
 
-        if (!biggest_union_member ||
-            debug_size_in_bits > biggest_union_member_size_in_bits)
+        if (!most_aligned_union_member ||
+            debug_align_in_bits > biggest_align_in_bits)
         {
-            biggest_union_member = field_type;
-            biggest_union_member_size_in_bits = debug_size_in_bits;
+            most_aligned_union_member = field_type;
+            biggest_align_in_bits = debug_align_in_bits;
+            size_of_most_aligned_member_in_bits = debug_size_in_bits;
         }
     }
 
@@ -1300,16 +1302,34 @@ static void resolve_enum_type(CodeGen *g, TypeTableEntry *enum_type) {
     enum_type->data.enumeration.complete = true;
 
     if (!enum_type->data.enumeration.is_invalid) {
-        enum_type->data.enumeration.union_type = biggest_union_member;
-
         TypeTableEntry *tag_int_type = get_smallest_unsigned_int_type(g, field_count);
         TypeTableEntry *tag_type_entry = create_enum_tag_type(g, enum_type, tag_int_type);
         enum_type->data.enumeration.tag_type = tag_type_entry;
 
-        if (biggest_union_member) {
+        if (most_aligned_union_member) {
             // create llvm type for union
-            LLVMTypeRef union_element_type = biggest_union_member->type_ref;
-            LLVMTypeRef union_type_ref = LLVMStructType(&union_element_type, 1, false);
+            uint64_t padding_in_bits = biggest_size_in_bits - size_of_most_aligned_member_in_bits;
+            LLVMTypeRef union_type_ref;
+            if (padding_in_bits > 0) {
+                TypeTableEntry *u8_type = get_int_type(g, false, 8);
+                TypeTableEntry *padding_array = get_array_type(g, u8_type, padding_in_bits / 8);
+                LLVMTypeRef union_element_types[] = {
+                    most_aligned_union_member->type_ref,
+                    padding_array->type_ref,
+                };
+                union_type_ref = LLVMStructType(union_element_types, 2, false);
+            } else {
+                LLVMTypeRef union_element_types[] = {
+                    most_aligned_union_member->type_ref,
+                };
+                union_type_ref = LLVMStructType(union_element_types, 1, false);
+            }
+            enum_type->data.enumeration.union_type_ref = union_type_ref;
+
+            assert(8*LLVMABIAlignmentOfType(g->target_data_ref, union_type_ref) >=
+                    biggest_align_in_bits);
+            assert(8*LLVMABISizeOfType(g->target_data_ref, union_type_ref) >=
+                    biggest_size_in_bits);
 
             // create llvm type for root struct
             LLVMTypeRef root_struct_element_types[] = {
@@ -1331,7 +1351,7 @@ static void resolve_enum_type(CodeGen *g, TypeTableEntry *enum_type) {
             ZigLLVMDIType *union_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder,
                     ZigLLVMTypeToScope(enum_type->di_type), "AnonUnion",
                     import->di_file, (unsigned)(decl_node->line + 1),
-                    biggest_union_member_size_in_bits, biggest_align_in_bits, 0, union_inner_di_types,
+                    biggest_size_in_bits, biggest_align_in_bits, 0, union_inner_di_types,
                     gen_field_count, 0, "");
 
             // create debug types for members of root struct
@@ -1348,7 +1368,7 @@ static void resolve_enum_type(CodeGen *g, TypeTableEntry *enum_type) {
             ZigLLVMDIType *union_member_di_type = ZigLLVMCreateDebugMemberType(g->dbuilder,
                     ZigLLVMTypeToScope(enum_type->di_type), "union_field",
                     import->di_file, (unsigned)(decl_node->line + 1),
-                    biggest_union_member_size_in_bits,
+                    biggest_size_in_bits,
                     biggest_align_in_bits,
                     union_offset_in_bits,
                     0, union_di_type);
@@ -2541,7 +2561,7 @@ bool types_match_const_cast_only(TypeTableEntry *expected_type, TypeTableEntry *
         if (expected_type->data.fn.is_generic != actual_type->data.fn.is_generic) {
             return false;
         }
-        if (!expected_type->data.fn.is_generic && 
+        if (!expected_type->data.fn.is_generic &&
             actual_type->data.fn.fn_type_id.return_type->id != TypeTableEntryIdUnreachable &&
             !types_match_const_cast_only(
                 expected_type->data.fn.fn_type_id.return_type,
src/codegen.cpp
@@ -3663,13 +3663,13 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) {
                 if (type_entry->data.enumeration.gen_field_count == 0) {
                     return tag_value;
                 } else {
-                    TypeTableEntry *union_type = type_entry->data.enumeration.union_type;
+                    LLVMTypeRef union_type_ref = type_entry->data.enumeration.union_type_ref;
                     TypeEnumField *enum_field = &type_entry->data.enumeration.fields[const_val->data.x_enum.tag];
                     assert(enum_field->value == const_val->data.x_enum.tag);
                     LLVMValueRef union_value;
                     if (type_has_bits(enum_field->type_entry)) {
                         uint64_t union_type_bytes = LLVMStoreSizeOfType(g->target_data_ref,
-                                union_type->type_ref);
+                                union_type_ref);
                         uint64_t field_type_bytes = LLVMStoreSizeOfType(g->target_data_ref,
                                 enum_field->type_entry->type_ref);
                         uint64_t pad_bytes = union_type_bytes - field_type_bytes;
@@ -3685,7 +3685,7 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) {
                             union_value = LLVMConstStruct(fields, 2, false);
                         }
                     } else {
-                        union_value = LLVMGetUndef(union_type->type_ref);
+                        union_value = LLVMGetUndef(union_type_ref);
                     }
                     LLVMValueRef fields[] = {
                         tag_value,
test/cases/enum.zig
@@ -120,3 +120,14 @@ const BareNumber = enum {
     Two,
     Three,
 };
+
+
+test "enum alignment" {
+    comptime assert(@alignOf(AlignTestEnum) >= @alignOf([9]u8));
+    comptime assert(@alignOf(AlignTestEnum) >= @alignOf(u64));
+}
+
+const AlignTestEnum = enum {
+    A: [9]u8,
+    B: u64,
+};