Commit fa6c20a02d

Andrew Kelley <andrew@ziglang.org>
2019-08-25 17:34:07
hook up unions with lazy values
this case works now: ```zig const Expr = union(enum) { Literal: u8, Question: *Expr, }; ```
1 parent d277a11
src/all_types.hpp
@@ -489,10 +489,12 @@ struct TypeEnumField {
 
 struct TypeUnionField {
     Buf *name;
+    ZigType *type_entry; // available after ResolveStatusSizeKnown
+    ConstExprValue *type_val; // available after ResolveStatusZeroBitsKnown
     TypeEnumField *enum_field;
-    ZigType *type_entry;
     AstNode *decl_node;
     uint32_t gen_index;
+    uint32_t align;
 };
 
 enum NodeType {
@@ -1247,7 +1249,7 @@ struct ZigTypeUnion {
     HashMap<Buf *, TypeUnionField *, buf_hash, buf_eql_buf> fields_by_name;
     ZigType *tag_type; // always an enum or null
     LLVMTypeRef union_llvm_type;
-    ZigType *most_aligned_union_member;
+    TypeUnionField *most_aligned_union_member;
     size_t gen_union_index;
     size_t gen_tag_index;
     size_t union_abi_size;
@@ -1262,7 +1264,8 @@ struct ZigTypeUnion {
     // whether any of the fields require comptime
     // the value is not valid until zero_bits_known == true
     bool requires_comptime;
-    bool resolve_loop_flag;
+    bool resolve_loop_flag_zero_bits;
+    bool resolve_loop_flag_other;
 };
 
 struct FnGenParamInfo {
src/analyze.cpp
@@ -961,10 +961,12 @@ static Error type_val_resolve_zero_bits(CodeGen *g, ConstExprValue *type_val, Zi
     Error err;
     if (type_val->special != ConstValSpecialLazy) {
         assert(type_val->special == ConstValSpecialStatic);
-        if (type_val->data.x_type->id == ZigTypeIdStruct &&
-            type_val->data.x_type->data.structure.resolve_loop_flag_zero_bits)
+        if ((type_val->data.x_type->id == ZigTypeIdStruct &&
+            type_val->data.x_type->data.structure.resolve_loop_flag_zero_bits) ||
+            (type_val->data.x_type->id == ZigTypeIdUnion &&
+             type_val->data.x_type->data.unionation.resolve_loop_flag_zero_bits))
         {
-            // Does a struct which contains a pointer field to itself have bits? Yes.
+            // Does a struct/union which contains a pointer field to itself have bits? Yes.
             *is_zero_bits = false;
             return ErrorNone;
         }
@@ -1079,7 +1081,7 @@ static ReqCompTime type_val_resolve_requires_comptime(CodeGen *g, ConstExprValue
     zig_unreachable();
 }
 
-static Error type_val_resolve_abi_align(CodeGen *g, ConstExprValue *type_val, size_t *abi_align) {
+static Error type_val_resolve_abi_align(CodeGen *g, ConstExprValue *type_val, uint32_t *abi_align) {
     Error err;
     if (type_val->special != ConstValSpecialLazy) {
         assert(type_val->special == ConstValSpecialStatic);
@@ -1917,7 +1919,7 @@ static Error resolve_union_alignment(CodeGen *g, ZigType *union_type) {
 
     AstNode *decl_node = union_type->data.structure.decl_node;
 
-    if (union_type->data.unionation.resolve_loop_flag) {
+    if (union_type->data.unionation.resolve_loop_flag_other) {
         if (union_type->data.unionation.resolve_status != ResolveStatusInvalid) {
             union_type->data.unionation.resolve_status = ResolveStatusInvalid;
             g->trace_err = add_node_error(g, decl_node,
@@ -1927,9 +1929,9 @@ static Error resolve_union_alignment(CodeGen *g, ZigType *union_type) {
     }
 
     // set temporary flag
-    union_type->data.unionation.resolve_loop_flag = true;
+    union_type->data.unionation.resolve_loop_flag_other = true;
 
-    ZigType *most_aligned_union_member = nullptr;
+    TypeUnionField *most_aligned_union_member = nullptr;
     uint32_t field_count = union_type->data.unionation.src_field_count;
     bool packed = union_type->data.unionation.layout == ContainerLayoutPacked;
 
@@ -1938,33 +1940,32 @@ static Error resolve_union_alignment(CodeGen *g, ZigType *union_type) {
         if (field->gen_index == UINT32_MAX)
             continue;
 
-        src_assert(field->type_entry != nullptr, decl_node);
-
-        size_t this_field_align;
-        if (packed) {
-            // TODO: https://github.com/ziglang/zig/issues/1512
-            this_field_align = 1;
+        AstNode *align_expr = field->decl_node->data.struct_field.align_expr;
+        if (align_expr != nullptr) {
+            if (!analyze_const_align(g, &union_type->data.unionation.decls_scope->base, align_expr,
+                        &field->align))
+            {
+                union_type->data.unionation.resolve_status = ResolveStatusInvalid;
+                return err;
+            }
+        } else if (packed) {
+            field->align = 1;
         } else {
-            if ((err = type_resolve(g, field->type_entry, ResolveStatusAlignmentKnown))) {
+            if ((err = type_val_resolve_abi_align(g, field->type_val, &field->align))) {
                 union_type->data.unionation.resolve_status = ResolveStatusInvalid;
-                return ErrorSemanticAnalyzeFail;
+                return err;
             }
-
             if (union_type->data.unionation.resolve_status == ResolveStatusInvalid)
                 return ErrorSemanticAnalyzeFail;
-
-            this_field_align = field->type_entry->abi_align;
         }
 
-        if (most_aligned_union_member == nullptr ||
-            this_field_align > most_aligned_union_member->abi_align)
-        {
-            most_aligned_union_member = field->type_entry;
+        if (most_aligned_union_member == nullptr || field->align > most_aligned_union_member->align) {
+            most_aligned_union_member = field;
         }
     }
 
     // unset temporary flag
-    union_type->data.unionation.resolve_loop_flag = false;
+    union_type->data.unionation.resolve_loop_flag_other = false;
     union_type->data.unionation.resolve_status = ResolveStatusAlignmentKnown;
     union_type->data.unionation.most_aligned_union_member = most_aligned_union_member;
 
@@ -1978,18 +1979,18 @@ static Error resolve_union_alignment(CodeGen *g, ZigType *union_type) {
             union_type->abi_align = tag_type->abi_align;
             union_type->data.unionation.gen_tag_index = SIZE_MAX;
             union_type->data.unionation.gen_union_index = SIZE_MAX;
-        } else if (tag_type->abi_align > most_aligned_union_member->abi_align) {
+        } else if (tag_type->abi_align > most_aligned_union_member->align) {
             union_type->abi_align = tag_type->abi_align;
             union_type->data.unionation.gen_tag_index = 0;
             union_type->data.unionation.gen_union_index = 1;
         } else {
-            union_type->abi_align = most_aligned_union_member->abi_align;
+            union_type->abi_align = most_aligned_union_member->align;
             union_type->data.unionation.gen_union_index = 0;
             union_type->data.unionation.gen_tag_index = 1;
         }
     } else {
         assert(most_aligned_union_member != nullptr);
-        union_type->abi_align = most_aligned_union_member->abi_align;
+        union_type->abi_align = most_aligned_union_member->align;
         union_type->data.unionation.gen_union_index = SIZE_MAX;
         union_type->data.unionation.gen_tag_index = SIZE_MAX;
     }
@@ -2016,14 +2017,14 @@ static Error resolve_union_type(CodeGen *g, ZigType *union_type) {
     assert(decl_node->type == NodeTypeContainerDecl);
 
     uint32_t field_count = union_type->data.unionation.src_field_count;
-    ZigType *most_aligned_union_member = union_type->data.unionation.most_aligned_union_member;
+    TypeUnionField *most_aligned_union_member = union_type->data.unionation.most_aligned_union_member;
 
     assert(union_type->data.unionation.fields);
 
     size_t union_abi_size = 0;
     size_t union_size_in_bits = 0;
 
-    if (union_type->data.unionation.resolve_loop_flag) {
+    if (union_type->data.unionation.resolve_loop_flag_other) {
         if (union_type->data.unionation.resolve_status != ResolveStatusInvalid) {
             union_type->data.unionation.resolve_status = ResolveStatusInvalid;
             g->trace_err = add_node_error(g, decl_node,
@@ -2033,11 +2034,18 @@ static Error resolve_union_type(CodeGen *g, ZigType *union_type) {
     }
 
     // set temporary flag
-    union_type->data.unionation.resolve_loop_flag = true;
+    union_type->data.unionation.resolve_loop_flag_other = true;
 
     for (uint32_t i = 0; i < field_count; i += 1) {
+        AstNode *field_source_node = decl_node->data.container_decl.fields.at(i);
         TypeUnionField *union_field = &union_type->data.unionation.fields[i];
-        ZigType *field_type = union_field->type_entry;
+
+        if ((err = ir_resolve_lazy(g, field_source_node, union_field->type_val))) {
+            union_type->data.unionation.resolve_status = ResolveStatusInvalid;
+            return err;
+        }
+        ZigType *field_type = union_field->type_val->data.x_type;
+        union_field->type_entry = field_type;
 
         if ((err = type_resolve(g, field_type, ResolveStatusSizeKnown))) {
             union_type->data.unionation.resolve_status = ResolveStatusInvalid;
@@ -2057,11 +2065,11 @@ static Error resolve_union_type(CodeGen *g, ZigType *union_type) {
     // The union itself for now has to be treated as being independently aligned.
     // See https://github.com/ziglang/zig/issues/2166.
     if (most_aligned_union_member != nullptr) {
-        union_abi_size = align_forward(union_abi_size, most_aligned_union_member->abi_align);
+        union_abi_size = align_forward(union_abi_size, most_aligned_union_member->align);
     }
 
     // unset temporary flag
-    union_type->data.unionation.resolve_loop_flag = false;
+    union_type->data.unionation.resolve_loop_flag_other = false;
     union_type->data.unionation.resolve_status = ResolveStatusSizeKnown;
     union_type->data.unionation.union_abi_size = union_abi_size;
 
@@ -2080,7 +2088,7 @@ static Error resolve_union_type(CodeGen *g, ZigType *union_type) {
             field_sizes[union_type->data.unionation.gen_tag_index] = tag_type->abi_size;
             field_aligns[union_type->data.unionation.gen_tag_index] = tag_type->abi_align;
             field_sizes[union_type->data.unionation.gen_union_index] = union_abi_size;
-            field_aligns[union_type->data.unionation.gen_union_index] = most_aligned_union_member->abi_align;
+            field_aligns[union_type->data.unionation.gen_union_index] = most_aligned_union_member->align;
             size_t field2_offset = next_field_offset(0, union_type->abi_align, field_sizes[0], field_aligns[1]);
             union_type->abi_size = next_field_offset(field2_offset, union_type->abi_align, field_sizes[1], union_type->abi_align);
             union_type->size_in_bits = union_type->abi_size * 8;
@@ -2449,12 +2457,12 @@ static Error resolve_struct_alignment(CodeGen *g, ZigType *struct_type) {
         } else if (packed) {
             field->align = 1;
         } else {
-            size_t result_abi_align;
-            if ((err = type_val_resolve_abi_align(g, field->type_val, &result_abi_align))) {
+            if ((err = type_val_resolve_abi_align(g, field->type_val, &field->align))) {
                 struct_type->data.structure.resolve_status = ResolveStatusInvalid;
                 return err;
             }
-            field->align = result_abi_align;
+            if (struct_type->data.structure.resolve_status == ResolveStatusInvalid)
+                return ErrorSemanticAnalyzeFail;
         }
 
         if (field->align > struct_type->abi_align) {
@@ -2486,7 +2494,7 @@ static Error resolve_union_zero_bits(CodeGen *g, ZigType *union_type) {
     AstNode *decl_node = union_type->data.unionation.decl_node;
     assert(decl_node->type == NodeTypeContainerDecl);
 
-    if (union_type->data.unionation.resolve_loop_flag) {
+    if (union_type->data.unionation.resolve_loop_flag_zero_bits) {
         if (union_type->data.unionation.resolve_status != ResolveStatusInvalid) {
             union_type->data.unionation.resolve_status = ResolveStatusInvalid;
             g->trace_err = add_node_error(g, decl_node,
@@ -2496,7 +2504,7 @@ static Error resolve_union_zero_bits(CodeGen *g, ZigType *union_type) {
         return ErrorSemanticAnalyzeFail;
     }
 
-    union_type->data.unionation.resolve_loop_flag = true;
+    union_type->data.unionation.resolve_loop_flag_zero_bits = true;
 
     assert(union_type->data.unionation.fields == nullptr);
     uint32_t field_count = (uint32_t)decl_node->data.container_decl.fields.length;
@@ -2605,49 +2613,68 @@ static Error resolve_union_zero_bits(CodeGen *g, ZigType *union_type) {
             return ErrorSemanticAnalyzeFail;
         }
 
-        ZigType *field_type;
+        bool field_is_zero_bits;
         if (field_node->data.struct_field.type == nullptr) {
-            if (decl_node->data.container_decl.auto_enum || decl_node->data.container_decl.init_arg_expr != nullptr) {
-                field_type = g->builtin_types.entry_void;
+            if (decl_node->data.container_decl.auto_enum ||
+                decl_node->data.container_decl.init_arg_expr != nullptr)
+            {
+                union_field->type_entry = g->builtin_types.entry_void;
+                field_is_zero_bits = false;
             } else {
                 add_node_error(g, field_node, buf_sprintf("union field missing type"));
                 union_type->data.unionation.resolve_status = ResolveStatusInvalid;
                 return ErrorSemanticAnalyzeFail;
             }
         } else {
-            field_type = analyze_type_expr(g, scope, field_node->data.struct_field.type);
-            if ((err = type_resolve(g, field_type, ResolveStatusAlignmentKnown))) {
+            ConstExprValue *field_type_val = analyze_const_value_allow_lazy(g, scope,
+                    field_node->data.struct_field.type, g->builtin_types.entry_type, nullptr, true);
+            if (type_is_invalid(field_type_val->type)) {
                 union_type->data.unionation.resolve_status = ResolveStatusInvalid;
                 return ErrorSemanticAnalyzeFail;
             }
+            assert(field_type_val->special != ConstValSpecialRuntime);
+            union_field->type_val = field_type_val;
             if (union_type->data.unionation.resolve_status == ResolveStatusInvalid)
                 return ErrorSemanticAnalyzeFail;
-        }
-        union_field->type_entry = field_type;
 
-        if (field_type->id == ZigTypeIdOpaque) {
-            add_node_error(g, field_node->data.struct_field.type,
-                buf_sprintf("opaque types have unknown size and therefore cannot be directly embedded in unions"));
-            union_type->data.unionation.resolve_status = ResolveStatusInvalid;
-            return ErrorSemanticAnalyzeFail;
-        }
+            bool field_is_opaque_type;
+            if ((err = type_val_resolve_is_opaque_type(g, field_type_val, &field_is_opaque_type))) {
+                union_type->data.unionation.resolve_status = ResolveStatusInvalid;
+                return ErrorSemanticAnalyzeFail;
+            }
+            if (field_is_opaque_type) {
+                add_node_error(g, field_node->data.struct_field.type,
+                    buf_create_from_str(
+                        "opaque types have unknown size and therefore cannot be directly embedded in unions"));
+                union_type->data.unionation.resolve_status = ResolveStatusInvalid;
+                return ErrorSemanticAnalyzeFail;
+            }
 
-        switch (type_requires_comptime(g, field_type)) {
-            case ReqCompTimeInvalid:
+            switch (type_val_resolve_requires_comptime(g, field_type_val)) {
+                case ReqCompTimeInvalid:
+                    if (g->trace_err != nullptr) {
+                        g->trace_err = add_error_note(g, g->trace_err, field_node,
+                            buf_create_from_str("while checking this field"));
+                    }
+                    union_type->data.unionation.resolve_status = ResolveStatusInvalid;
+                    return ErrorSemanticAnalyzeFail;
+                case ReqCompTimeYes:
+                    union_type->data.unionation.requires_comptime = true;
+                    break;
+                case ReqCompTimeNo:
+                    break;
+            }
+
+            if ((err = type_val_resolve_zero_bits(g, field_type_val, union_type, nullptr, &field_is_zero_bits))) {
                 union_type->data.unionation.resolve_status = ResolveStatusInvalid;
                 return ErrorSemanticAnalyzeFail;
-            case ReqCompTimeYes:
-                union_type->data.unionation.requires_comptime = true;
-                break;
-            case ReqCompTimeNo:
-                break;
+            }
         }
 
         if (field_node->data.struct_field.value != nullptr && !decl_node->data.container_decl.auto_enum) {
             ErrorMsg *msg = add_node_error(g, field_node->data.struct_field.value,
-                    buf_sprintf("non-enum union field assignment"));
-            add_error_note(g, msg, decl_node,
-                    buf_sprintf("consider 'union(enum)' here"));
+                    buf_create_from_str("untagged union field assignment"));
+            add_error_note(g, msg, decl_node, buf_create_from_str("consider 'union(enum)' here"));
         }
 
         if (create_enum_type) {
@@ -2706,7 +2733,7 @@ static Error resolve_union_zero_bits(CodeGen *g, ZigType *union_type) {
         }
         assert(union_field->enum_field != nullptr);
 
-        if (!type_has_bits(field_type))
+        if (field_is_zero_bits)
             continue;
 
         union_field->gen_index = gen_field_index;
@@ -2783,7 +2810,7 @@ static Error resolve_union_zero_bits(CodeGen *g, ZigType *union_type) {
         return ErrorSemanticAnalyzeFail;
     }
 
-    union_type->data.unionation.resolve_loop_flag = false;
+    union_type->data.unionation.resolve_loop_flag_zero_bits = false;
 
     union_type->data.unionation.gen_field_count = gen_field_index;
     bool zero_bits = gen_field_index == 0 && (field_count < 2 || !src_have_tag);
@@ -5002,6 +5029,10 @@ ReqCompTime type_requires_comptime(CodeGen *g, ZigType *ty) {
                 return ReqCompTimeInvalid;
             return ty->data.structure.requires_comptime ? ReqCompTimeYes : ReqCompTimeNo;
         case ZigTypeIdUnion:
+            if (ty->data.unionation.resolve_loop_flag_zero_bits) {
+                // Does a union which contains a pointer field to itself require comptime? No.
+                return ReqCompTimeNo;
+            }
             if ((err = type_resolve(g, ty, ResolveStatusZeroBitsKnown)))
                 return ReqCompTimeInvalid;
             return ty->data.unionation.requires_comptime ? ReqCompTimeYes : ReqCompTimeNo;
@@ -7308,7 +7339,7 @@ static void resolve_llvm_types_enum(CodeGen *g, ZigType *enum_type, ResolveStatu
 static void resolve_llvm_types_union(CodeGen *g, ZigType *union_type, ResolveStatus wanted_resolve_status) {
     if (union_type->data.unionation.resolve_status >= wanted_resolve_status) return;
 
-    ZigType *most_aligned_union_member = union_type->data.unionation.most_aligned_union_member;
+    TypeUnionField *most_aligned_union_member = union_type->data.unionation.most_aligned_union_member;
     ZigType *tag_type = union_type->data.unionation.tag_type;
     if (most_aligned_union_member == nullptr) {
         union_type->llvm_type = get_llvm_type(g, tag_type);
@@ -7361,17 +7392,17 @@ static void resolve_llvm_types_union(CodeGen *g, ZigType *union_type, ResolveSta
     if (tag_type == nullptr || !type_has_bits(tag_type)) {
         assert(most_aligned_union_member != nullptr);
 
-        size_t padding_bytes = union_type->data.unionation.union_abi_size - most_aligned_union_member->abi_size;
+        size_t padding_bytes = union_type->data.unionation.union_abi_size - most_aligned_union_member->type_entry->abi_size;
         if (padding_bytes > 0) {
             ZigType *u8_type = get_int_type(g, false, 8);
             ZigType *padding_array = get_array_type(g, u8_type, padding_bytes);
             LLVMTypeRef union_element_types[] = {
-                most_aligned_union_member->llvm_type,
+                most_aligned_union_member->type_entry->llvm_type,
                 get_llvm_type(g, padding_array),
             };
             LLVMStructSetBody(union_type->llvm_type, union_element_types, 2, false);
         } else {
-            LLVMStructSetBody(union_type->llvm_type, &most_aligned_union_member->llvm_type, 1, false);
+            LLVMStructSetBody(union_type->llvm_type, &most_aligned_union_member->type_entry->llvm_type, 1, false);
         }
         union_type->data.unionation.union_llvm_type = union_type->llvm_type;
         union_type->data.unionation.gen_tag_index = SIZE_MAX;
@@ -7382,7 +7413,7 @@ static void resolve_llvm_types_union(CodeGen *g, ZigType *union_type, ResolveSta
             ZigLLVMFileToScope(import->data.structure.root_struct->di_file), buf_ptr(&union_type->name),
             import->data.structure.root_struct->di_file, (unsigned)(decl_node->line + 1),
             union_type->data.unionation.union_abi_size * 8,
-            most_aligned_union_member->abi_align * 8,
+            most_aligned_union_member->align * 8,
             ZigLLVM_DIFlags_Zero, union_inner_di_types,
             gen_field_count, 0, "");
 
@@ -7393,14 +7424,14 @@ static void resolve_llvm_types_union(CodeGen *g, ZigType *union_type, ResolveSta
     }
 
     LLVMTypeRef union_type_ref;
-    size_t padding_bytes = union_type->data.unionation.union_abi_size - most_aligned_union_member->abi_size;
+    size_t padding_bytes = union_type->data.unionation.union_abi_size - most_aligned_union_member->type_entry->abi_size;
     if (padding_bytes == 0) {
-        union_type_ref = get_llvm_type(g, most_aligned_union_member);
+        union_type_ref = get_llvm_type(g, most_aligned_union_member->type_entry);
     } else {
         ZigType *u8_type = get_int_type(g, false, 8);
         ZigType *padding_array = get_array_type(g, u8_type, padding_bytes);
         LLVMTypeRef union_element_types[] = {
-            get_llvm_type(g, most_aligned_union_member),
+            get_llvm_type(g, most_aligned_union_member->type_entry),
             get_llvm_type(g, padding_array),
         };
         union_type_ref = LLVMStructType(union_element_types, 2, false);
@@ -7416,7 +7447,7 @@ static void resolve_llvm_types_union(CodeGen *g, ZigType *union_type, ResolveSta
     ZigLLVMDIType *union_di_type = ZigLLVMCreateDebugUnionType(g->dbuilder,
             ZigLLVMTypeToScope(union_type->llvm_di_type), "AnonUnion",
             import->data.structure.root_struct->di_file, (unsigned)(decl_node->line + 1),
-            most_aligned_union_member->size_in_bits, 8*most_aligned_union_member->abi_align,
+            most_aligned_union_member->type_entry->size_in_bits, 8*most_aligned_union_member->align,
             ZigLLVM_DIFlags_Zero, union_inner_di_types, gen_field_count, 0, "");
 
     uint64_t union_offset_in_bits = 8*LLVMOffsetOfElement(g->target_data_ref, union_type->llvm_type,
@@ -7427,8 +7458,8 @@ static void resolve_llvm_types_union(CodeGen *g, ZigType *union_type, ResolveSta
     ZigLLVMDIType *union_member_di_type = ZigLLVMCreateDebugMemberType(g->dbuilder,
             ZigLLVMTypeToScope(union_type->llvm_di_type), "payload",
             import->data.structure.root_struct->di_file, (unsigned)(decl_node->line + 1),
-            most_aligned_union_member->size_in_bits,
-            8*most_aligned_union_member->abi_align,
+            most_aligned_union_member->type_entry->size_in_bits,
+            8*most_aligned_union_member->align,
             union_offset_in_bits,
             ZigLLVM_DIFlags_Zero, union_di_type);
 
src/codegen.cpp
@@ -6568,7 +6568,7 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val, const c
                     uint64_t pad_bytes = type_entry->data.unionation.union_abi_size - field_type_bytes;
                     LLVMValueRef correctly_typed_value = gen_const_val(g, payload_value, "");
                     make_unnamed_struct = is_llvm_value_unnamed_type(g, payload_value->type, correctly_typed_value) ||
-                        payload_value->type != type_entry->data.unionation.most_aligned_union_member;
+                        payload_value->type != type_entry->data.unionation.most_aligned_union_member->type_entry;
 
                     {
                         if (pad_bytes == 0) {