Commit c558de6655

Veikka Tuominen <git@vexu.eu>
2022-08-30 15:01:28
stage2 llvm: use tag value instead of field index in airUnionInit
Closes #12656
1 parent d3b4b2e
Changed files (3)
src
test
behavior
src/codegen/llvm.zig
@@ -8527,12 +8527,26 @@ pub const FuncGen = struct {
         const union_llvm_ty = try self.dg.lowerType(union_ty);
         const target = self.dg.module.getTarget();
         const layout = union_ty.unionGetLayout(target);
+        const union_obj = union_ty.cast(Type.Payload.Union).?.data;
+        const tag_int = blk: {
+            const tag_ty = union_ty.unionTagTypeHypothetical();
+            const union_field_name = union_obj.fields.keys()[extra.field_index];
+            const enum_field_index = tag_ty.enumFieldIndex(union_field_name).?;
+            var tag_val_payload: Value.Payload.U32 = .{
+                .base = .{ .tag = .enum_field_index },
+                .data = @intCast(u32, enum_field_index),
+            };
+            const tag_val = Value.initPayload(&tag_val_payload.base);
+            var int_payload: Value.Payload.U64 = undefined;
+            const tag_int_val = tag_val.enumToInt(tag_ty, &int_payload);
+            break :blk tag_int_val.toUnsignedInt(target);
+        };
         if (layout.payload_size == 0) {
             if (layout.tag_size == 0) {
                 return null;
             }
             assert(!isByRef(union_ty));
-            return union_llvm_ty.constInt(extra.field_index, .False);
+            return union_llvm_ty.constInt(tag_int, .False);
         }
         assert(isByRef(union_ty));
         // The llvm type of the alloca will the the named LLVM union type, which will not
@@ -8541,7 +8555,6 @@ pub const FuncGen = struct {
         // then set the fields appropriately.
         const result_ptr = self.buildAlloca(union_llvm_ty);
         const llvm_payload = try self.resolveInst(extra.init);
-        const union_obj = union_ty.cast(Type.Payload.Union).?.data;
         assert(union_obj.haveFieldTypes());
         const field = union_obj.fields.values()[extra.field_index];
         const field_llvm_ty = try self.dg.lowerType(field.ty);
@@ -8625,7 +8638,7 @@ pub const FuncGen = struct {
             };
             const field_ptr = self.builder.buildInBoundsGEP(casted_ptr, &indices, indices.len, "");
             const tag_llvm_ty = try self.dg.lowerType(union_obj.tag_ty);
-            const llvm_tag = tag_llvm_ty.constInt(extra.field_index, .False);
+            const llvm_tag = tag_llvm_ty.constInt(tag_int, .False);
             const store_inst = self.builder.buildStore(llvm_tag, field_ptr);
             store_inst.setAlignment(union_obj.tag_ty.abiAlignment(target));
         }
src/Sema.zig
@@ -21933,6 +21933,7 @@ fn unionFieldPtr(
         .mutable = union_ptr_ty.ptrIsMutable(),
         .@"addrspace" = union_ptr_ty.ptrAddressSpace(),
     });
+    const enum_field_index = @intCast(u32, union_obj.tag_ty.enumFieldIndex(field_name).?);
 
     if (initializing and field.ty.zigTypeTag() == .NoReturn) {
         const msg = msg: {
@@ -21954,11 +21955,10 @@ fn unionFieldPtr(
                 if (union_val.isUndef()) {
                     return sema.failWithUseOfUndef(block, src);
                 }
-                const enum_field_index = union_obj.tag_ty.enumFieldIndex(field_name).?;
                 const tag_and_val = union_val.castTag(.@"union").?.data;
                 var field_tag_buf: Value.Payload.U32 = .{
                     .base = .{ .tag = .enum_field_index },
-                    .data = @intCast(u32, enum_field_index),
+                    .data = enum_field_index,
                 };
                 const field_tag = Value.initPayload(&field_tag_buf.base);
                 const tag_matches = tag_and_val.tag.eql(field_tag, union_obj.tag_ty, sema.mod);
@@ -21990,7 +21990,7 @@ fn unionFieldPtr(
     if (!initializing and union_obj.layout == .Auto and block.wantSafety() and
         union_ty.unionTagTypeSafety() != null and union_obj.fields.count() > 1)
     {
-        const wanted_tag_val = try Value.Tag.enum_field_index.create(sema.arena, field_index);
+        const wanted_tag_val = try Value.Tag.enum_field_index.create(sema.arena, enum_field_index);
         const wanted_tag = try sema.addConstant(union_obj.tag_ty, wanted_tag_val);
         // TODO would it be better if get_union_tag supported pointers to unions?
         const union_val = try block.addTyOp(.load, union_ty, union_ptr);
@@ -22020,15 +22020,15 @@ fn unionFieldVal(
     const union_obj = union_ty.cast(Type.Payload.Union).?.data;
     const field_index = try sema.unionFieldIndex(block, union_ty, field_name, field_name_src);
     const field = union_obj.fields.values()[field_index];
+    const enum_field_index = @intCast(u32, union_obj.tag_ty.enumFieldIndex(field_name).?);
 
     if (try sema.resolveMaybeUndefVal(block, src, union_byval)) |union_val| {
         if (union_val.isUndef()) return sema.addConstUndef(field.ty);
 
         const tag_and_val = union_val.castTag(.@"union").?.data;
-        const enum_field_index = union_obj.tag_ty.enumFieldIndex(field_name).?;
         var field_tag_buf: Value.Payload.U32 = .{
             .base = .{ .tag = .enum_field_index },
-            .data = @intCast(u32, enum_field_index),
+            .data = enum_field_index,
         };
         const field_tag = Value.initPayload(&field_tag_buf.base);
         const tag_matches = tag_and_val.tag.eql(field_tag, union_obj.tag_ty, sema.mod);
@@ -22064,7 +22064,7 @@ fn unionFieldVal(
     if (union_obj.layout == .Auto and block.wantSafety() and
         union_ty.unionTagTypeSafety() != null and union_obj.fields.count() > 1)
     {
-        const wanted_tag_val = try Value.Tag.enum_field_index.create(sema.arena, field_index);
+        const wanted_tag_val = try Value.Tag.enum_field_index.create(sema.arena, enum_field_index);
         const wanted_tag = try sema.addConstant(union_obj.tag_ty, wanted_tag_val);
         const active_tag = try block.addTyOp(.get_union_tag, union_obj.tag_ty, union_byval);
         const ok = try block.addBinOp(.cmp_eq, active_tag, wanted_tag);
test/behavior/union.zig
@@ -1324,3 +1324,31 @@ test "union and enum field order doesn't match" {
     x = .b;
     try expect(x == .b);
 }
+
+test "@unionInit uses tag value instead of field index" {
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+
+    const E = enum(u8) {
+        b = 255,
+        a = 3,
+    };
+    const U = union(E) {
+        a: usize,
+        b: isize,
+    };
+    var i: isize = -1;
+    var u = @unionInit(U, "b", i);
+    {
+        var a = u.b;
+        try expect(a == i);
+    }
+    {
+        var a = &u.b;
+        try expect(a.* == i);
+    }
+    try expect(@enumToInt(u) == 255);
+}