Commit 5d2ba056c8

Andrew Kelley <superjoe30@gmail.com>
2017-11-17 04:06:08
fix codegen for union init with runtime value
see #144
1 parent e26ccd5
Changed files (2)
src
test
cases
src/codegen.cpp
@@ -3414,17 +3414,34 @@ static LLVMValueRef ir_render_struct_init(CodeGen *g, IrExecutable *executable,
 static LLVMValueRef ir_render_union_init(CodeGen *g, IrExecutable *executable, IrInstructionUnionInit *instruction) {
     TypeUnionField *type_union_field = instruction->field;
 
-    assert(type_has_bits(type_union_field->type_entry));
-
-    LLVMValueRef field_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, (unsigned)0, "");
-    LLVMValueRef value = ir_llvm_value(g, instruction->init_value);
+    if (!type_has_bits(type_union_field->type_entry))
+        return nullptr;
 
     uint32_t field_align_bytes = get_abi_alignment(g, type_union_field->type_entry);
-
     TypeTableEntry *ptr_type = get_pointer_to_type_extra(g, type_union_field->type_entry,
             false, false, field_align_bytes,
             0, 0);
 
+    LLVMValueRef uncasted_union_ptr;
+    // Even if safety is off in this block, if the union type has the safety field, we have to populate it
+    // correctly. Otherwise safety code somewhere other than here could fail.
+    TypeTableEntry *union_type = instruction->union_type;
+    if (union_type->data.unionation.gen_tag_index != SIZE_MAX) {
+        LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr,
+                union_type->data.unionation.gen_tag_index, "");
+        LLVMValueRef tag_value = LLVMConstInt(union_type->data.unionation.tag_type->type_ref,
+                type_union_field->value, false);
+        gen_store_untyped(g, tag_value, tag_field_ptr, 0, false);
+
+        uncasted_union_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr,
+                (unsigned)union_type->data.unionation.gen_union_index, "");
+    } else {
+        uncasted_union_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, (unsigned)0, "");
+    }
+
+    LLVMValueRef field_ptr = LLVMBuildBitCast(g->builder, uncasted_union_ptr, ptr_type->type_ref, "");
+    LLVMValueRef value = ir_llvm_value(g, instruction->init_value);
+
     gen_assign_raw(g, field_ptr, ptr_type, value);
 
     return instruction->tmp_ptr;
test/cases/union.zig
@@ -45,6 +45,23 @@ test "basic unions" {
     assert(foo.float == 12.34);
 }
 
+test "init union with runtime value" {
+    var foo: Foo = undefined;
+
+    setFloat(&foo, 12.34);
+    assert(foo.float == 12.34);
+
+    setInt(&foo, 42);
+    assert(foo.int == 42);
+}
+
+fn setFloat(foo: &Foo, x: f64) {
+    *foo = Foo { .float = x };
+}
+
+fn setInt(foo: &Foo, x: i32) {
+    *foo = Foo { .int = x };
+}
 
 const FooExtern = extern union {
     float: f64,
@@ -57,3 +74,4 @@ test "basic extern unions" {
     foo.float = 12.34;
     assert(foo.float == 12.34);
 }
+