Commit d48af541c7

Veikka Tuominen <git@vexu.eu>
2022-08-21 11:25:19
Sema: handle union and enum field order being different
Closes #12543
1 parent e8102d8
Changed files (6)
src/codegen/c.zig
@@ -835,7 +835,6 @@ pub const DeclGen = struct {
             },
             .Union => {
                 const union_obj = val.castTag(.@"union").?.data;
-                const union_ty = ty.cast(Type.Payload.Union).?.data;
                 const layout = ty.unionGetLayout(target);
 
                 try writer.writeAll("(");
@@ -851,7 +850,7 @@ pub const DeclGen = struct {
                     try writer.writeAll(".payload = {");
                 }
 
-                const index = union_ty.tag_ty.enumTagFieldIndex(union_obj.tag, dg.module).?;
+                const index = ty.unionTagFieldIndex(union_obj.tag, dg.module).?;
                 const field_ty = ty.unionFields().values()[index].ty;
                 const field_name = ty.unionFields().keys()[index];
                 if (field_ty.hasRuntimeBits()) {
src/codegen/llvm.zig
@@ -3502,7 +3502,7 @@ pub const DeclGen = struct {
                     });
                 }
                 const union_obj = tv.ty.cast(Type.Payload.Union).?.data;
-                const field_index = union_obj.tag_ty.enumTagFieldIndex(tag_and_val.tag, dg.module).?;
+                const field_index = tv.ty.unionTagFieldIndex(tag_and_val.tag, dg.module).?;
                 assert(union_obj.haveFieldTypes());
 
                 // Sometimes we must make an unnamed struct because LLVM does
src/codegen.zig
@@ -607,7 +607,7 @@ pub fn generateSymbol(
 
             const union_ty = typed_value.ty.cast(Type.Payload.Union).?.data;
             const mod = bin_file.options.module.?;
-            const field_index = union_ty.tag_ty.enumTagFieldIndex(union_obj.tag, mod).?;
+            const field_index = typed_value.ty.unionTagFieldIndex(union_obj.tag, mod).?;
             assert(union_ty.haveFieldTypes());
             const field_ty = union_ty.fields.values()[field_index].ty;
             if (!field_ty.hasRuntimeBits()) {
src/Sema.zig
@@ -3615,8 +3615,6 @@ fn validateUnionInit(
     union_ptr: Air.Inst.Ref,
     is_comptime: bool,
 ) CompileError!void {
-    const union_obj = union_ty.cast(Type.Payload.Union).?.data;
-
     if (instrs.len != 1) {
         const msg = msg: {
             const msg = try sema.errMsg(
@@ -3650,7 +3648,8 @@ fn validateUnionInit(
     const field_src: LazySrcLoc = .{ .node_offset_initializer = field_ptr_data.src_node };
     const field_ptr_extra = sema.code.extraData(Zir.Inst.Field, field_ptr_data.payload_index).data;
     const field_name = sema.code.nullTerminatedString(field_ptr_extra.field_name_start);
-    const field_index = try sema.unionFieldIndex(block, union_ty, field_name, field_src);
+    // Validate the field access but ignore the index since we want the tag enum field index.
+    _ = try sema.unionFieldIndex(block, union_ty, field_name, field_src);
     const air_tags = sema.air_instructions.items(.tag);
     const air_datas = sema.air_instructions.items(.data);
     const field_ptr_air_ref = sema.inst_map.get(field_ptr).?;
@@ -3709,7 +3708,9 @@ fn validateUnionInit(
         break;
     }
 
-    const tag_val = try Value.Tag.enum_field_index.create(sema.arena, field_index);
+    const tag_ty = union_ty.unionTagTypeHypothetical();
+    const enum_field_index = @intCast(u32, tag_ty.enumFieldIndex(field_name).?);
+    const tag_val = try Value.Tag.enum_field_index.create(sema.arena, enum_field_index);
 
     if (init_val) |val| {
         // Our task is to delete all the `field_ptr` and `store` instructions, and insert
@@ -3726,7 +3727,7 @@ fn validateUnionInit(
     }
 
     try sema.requireFunctionBlock(block, init_src);
-    const new_tag = try sema.addConstant(union_obj.tag_ty, tag_val);
+    const new_tag = try sema.addConstant(tag_ty, tag_val);
     _ = try block.addBinOp(.set_union_tag, union_ptr, new_tag);
 }
 
@@ -8838,13 +8839,11 @@ fn zirSwitchCapture(
     switch (operand_ty.zigTypeTag()) {
         .Union => {
             const union_obj = operand_ty.cast(Type.Payload.Union).?.data;
-            const enum_ty = union_obj.tag_ty;
-
             const first_item = try sema.resolveInst(items[0]);
             // Previous switch validation ensured this will succeed
             const first_item_val = sema.resolveConstValue(block, .unneeded, first_item, undefined) catch unreachable;
 
-            const first_field_index = @intCast(u32, enum_ty.enumTagFieldIndex(first_item_val, sema.mod).?);
+            const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, sema.mod).?);
             const first_field = union_obj.fields.values()[first_field_index];
 
             for (items[1..]) |item, i| {
@@ -8852,7 +8851,7 @@ fn zirSwitchCapture(
                 // Previous switch validation ensured this will succeed
                 const item_val = sema.resolveConstValue(block, .unneeded, item_ref, undefined) catch unreachable;
 
-                const field_index = enum_ty.enumTagFieldIndex(item_val, sema.mod).?;
+                const field_index = operand_ty.unionTagFieldIndex(item_val, sema.mod).?;
                 const field = union_obj.fields.values()[field_index];
                 if (!field.ty.eql(first_field.ty, sema.mod)) {
                     const msg = msg: {
@@ -15585,7 +15584,9 @@ fn unionInit(
     const init = try sema.coerce(block, field.ty, uncasted_init, init_src);
 
     if (try sema.resolveMaybeUndefVal(block, init_src, init)) |init_val| {
-        const tag_val = try Value.Tag.enum_field_index.create(sema.arena, field_index);
+        const tag_ty = union_ty.unionTagTypeHypothetical();
+        const enum_field_index = @intCast(u32, tag_ty.enumFieldIndex(field_name).?);
+        const tag_val = try Value.Tag.enum_field_index.create(sema.arena, enum_field_index);
         return sema.addConstant(union_ty, try Value.Tag.@"union".create(sema.arena, .{
             .tag = tag_val,
             .val = init_val,
@@ -15683,7 +15684,9 @@ fn zirStructInit(
         const field_type_extra = sema.code.extraData(Zir.Inst.FieldType, field_type_data.payload_index).data;
         const field_name = sema.code.nullTerminatedString(field_type_extra.name_start);
         const field_index = try sema.unionFieldIndex(block, resolved_ty, field_name, field_src);
-        const tag_val = try Value.Tag.enum_field_index.create(sema.arena, field_index);
+        const tag_ty = resolved_ty.unionTagTypeHypothetical();
+        const enum_field_index = @intCast(u32, tag_ty.enumFieldIndex(field_name).?);
+        const tag_val = try Value.Tag.enum_field_index.create(sema.arena, enum_field_index);
 
         const init_inst = try sema.resolveInst(item.data.init);
         if (try sema.resolveMaybeUndefVal(block, field_src, init_inst)) |val| {
@@ -16448,9 +16451,8 @@ fn zirReify(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData, in
     const type_info = try sema.coerce(block, type_info_ty, uncasted_operand, operand_src);
     const val = try sema.resolveConstValue(block, operand_src, type_info, "operand to @Type must be comptime known");
     const union_val = val.cast(Value.Payload.Union).?.data;
-    const tag_ty = type_info_ty.unionTagType().?;
     const target = mod.getTarget();
-    const tag_index = tag_ty.enumTagFieldIndex(union_val.tag, mod).?;
+    const tag_index = type_info_ty.unionTagFieldIndex(union_val.tag, mod).?;
     if (union_val.val.anyUndef()) return sema.failWithUseOfUndef(block, src);
     switch (@intToEnum(std.builtin.TypeId, tag_index)) {
         .Type => return Air.Inst.Ref.type_type,
@@ -25155,8 +25157,7 @@ fn coerceEnumToUnion(
 
     const enum_tag = try sema.coerce(block, tag_ty, inst, inst_src);
     if (try sema.resolveDefinedValue(block, inst_src, enum_tag)) |val| {
-        const union_obj = union_ty.cast(Type.Payload.Union).?.data;
-        const field_index = union_obj.tag_ty.enumTagFieldIndex(val, sema.mod) orelse {
+        const field_index = union_ty.unionTagFieldIndex(val, sema.mod) orelse {
             const msg = msg: {
                 const msg = try sema.errMsg(block, inst_src, "union '{}' has no tag with value '{}'", .{
                     union_ty.fmt(sema.mod), val.fmtValue(tag_ty, sema.mod),
@@ -25167,6 +25168,8 @@ fn coerceEnumToUnion(
             };
             return sema.failWithOwnedErrorMsg(msg);
         };
+
+        const union_obj = union_ty.cast(Type.Payload.Union).?.data;
         const field = union_obj.fields.values()[field_index];
         const field_ty = try sema.resolveTypeFields(block, inst_src, field.ty);
         if (field_ty.zigTypeTag() == .NoReturn) {
src/type.zig
@@ -4285,11 +4285,18 @@ pub const Type = extern union {
 
     pub fn unionFieldType(ty: Type, enum_tag: Value, mod: *Module) Type {
         const union_obj = ty.cast(Payload.Union).?.data;
-        const index = union_obj.tag_ty.enumTagFieldIndex(enum_tag, mod).?;
+        const index = ty.unionTagFieldIndex(enum_tag, mod).?;
         assert(union_obj.haveFieldTypes());
         return union_obj.fields.values()[index].ty;
     }
 
+    pub fn unionTagFieldIndex(ty: Type, enum_tag: Value, mod: *Module) ?usize {
+        const union_obj = ty.cast(Payload.Union).?.data;
+        const index = union_obj.tag_ty.enumTagFieldIndex(enum_tag, mod) orelse return null;
+        const name = union_obj.tag_ty.enumFieldName(index);
+        return union_obj.fields.getIndex(name);
+    }
+
     pub fn unionHasAllZeroBitFieldTypes(ty: Type) bool {
         return ty.cast(Payload.Union).?.data.hasAllZeroBitFieldTypes();
     }
test/behavior/union.zig
@@ -1301,3 +1301,27 @@ test "noreturn field in union" {
     }
     try expect(count == 5);
 }
+
+test "union and enum field order doesn't match" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+
+    const MyTag = enum(u32) {
+        b = 1337,
+        a = 1666,
+    };
+    const MyUnion = union(MyTag) {
+        a: f32,
+        b: void,
+    };
+    var x: MyUnion = .{ .a = 666 };
+    switch (x) {
+        .a => |my_f32| {
+            try expect(@TypeOf(my_f32) == f32);
+        },
+        .b => unreachable,
+    }
+    x = .b;
+    try expect(x == .b);
+}