Commit c1c9bc0c41

mlugg <mlugg@mlugg.co.uk>
2023-10-28 02:22:30
Sema: do not assume switch item indices align with union field indices
Resolves: #17754
1 parent 5257643
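
Context for the fix: a prong such as `.b, .a => |x|` has case item 0 referring to union field index 1, so indexing field types (and, worse, the emitted union field accesses) by an item's position in the prong picks the wrong field whenever the prong lists fields out of declaration order. Below is a minimal sketch of the affected pattern, illustrative only and using a hypothetical union `U`; the behavior test added at the bottom of this commit is the actual regression test, which additionally poisons the unused payload bytes so a capture that picked up those bits would fail its `expect`.

const std = @import("std");

const U = union(enum) {
    a: u32, // union field index 0
    b: u64, // union field index 1
};

test "prong lists fields out of declaration order (illustrative sketch)" {
    var val: U = .{ .a = 123 };
    // Case item 0 is `.b` (union field 1) and item 1 is `.a` (union field 0),
    // so item indices and union field indices do not line up in this prong.
    const capture = switch (val) {
        .b, .a => |payload| payload, // peer-resolved capture type is u64
    };
    try std.testing.expect(capture == 123);
}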
Changed files (2)
src/Sema.zig
@@ -10789,23 +10789,24 @@ const SwitchProngAnalysis = struct {
                 const first_field_index: u32 = mod.unionTagFieldIndex(union_obj, first_item_val).?;
                 const first_field_ty = union_obj.field_types.get(ip)[first_field_index].toType();
 
-                const field_tys = try sema.arena.alloc(Type, case_vals.len);
-                for (case_vals, field_tys) |item, *field_ty| {
+                const field_indices = try sema.arena.alloc(u32, case_vals.len);
+                for (case_vals, field_indices) |item, *field_idx| {
                     const item_val = sema.resolveConstDefinedValue(block, .unneeded, item, undefined) catch unreachable;
-                    const field_idx = mod.unionTagFieldIndex(union_obj, item_val).?;
-                    field_ty.* = union_obj.field_types.get(ip)[field_idx].toType();
+                    field_idx.* = mod.unionTagFieldIndex(union_obj, item_val).?;
                 }
 
                 // Fast path: if all the operands are the same type already, we don't need to hit
                 // PTR! This will also allow us to emit simpler code.
-                const same_types = for (field_tys[1..]) |field_ty| {
-                    if (!field_ty.eql(field_tys[0], sema.mod)) break false;
+                const same_types = for (field_indices[1..]) |field_idx| {
+                    const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
+                    if (!field_ty.eql(first_field_ty, sema.mod)) break false;
                 } else true;
 
-                const capture_ty = if (same_types) field_tys[0] else capture_ty: {
+                const capture_ty = if (same_types) first_field_ty else capture_ty: {
                     // We need values to run PTR on, so make a bunch of undef constants.
                     const dummy_captures = try sema.arena.alloc(Air.Inst.Ref, case_vals.len);
-                    for (dummy_captures, field_tys) |*dummy, field_ty| {
+                    for (dummy_captures, field_indices) |*dummy, field_idx| {
+                        const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
                         dummy.* = try mod.undefRef(field_ty);
                     }
 
@@ -10852,7 +10853,8 @@ const SwitchProngAnalysis = struct {
                     // By-ref captures of hetereogeneous types are only allowed if each field
                     // pointer type is in-memory coercible to the capture pointer type.
                     if (!same_types) {
-                        for (field_tys, 0..) |field_ty, i| {
+                        for (field_indices, 0..) |field_idx, i| {
+                            const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
                             const field_ptr_ty = try sema.ptrType(.{
                                 .child = field_ty.toIntern(),
                                 .flags = .{
@@ -10915,7 +10917,8 @@ const SwitchProngAnalysis = struct {
                 // We may have to emit a switch block which coerces the operand to the capture type.
                 // If we can, try to avoid that using in-memory coercions.
                 const first_non_imc = in_mem: {
-                    for (field_tys, 0..) |field_ty, i| {
+                    for (field_indices, 0..) |field_idx, i| {
+                        const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
                         if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
                             break :in_mem i;
                         }
@@ -10933,11 +10936,12 @@ const SwitchProngAnalysis = struct {
                 // be several, and we can squash all of these cases into the same switch prong using
                 // a simple bitcast. We'll make this the 'else' prong.
 
-                var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_tys.len);
+                var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_indices.len);
                 in_mem_coercible.unset(first_non_imc);
                 {
                     const next = first_non_imc + 1;
-                    for (field_tys[next..], next..) |field_ty, i| {
+                    for (field_indices[next..], next..) |field_idx, i| {
+                        const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
                         if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
                             in_mem_coercible.unset(i);
                         }
@@ -10954,7 +10958,7 @@ const SwitchProngAnalysis = struct {
                     },
                 });
 
-                const prong_count = field_tys.len - in_mem_coercible.count();
+                const prong_count = field_indices.len - in_mem_coercible.count();
 
                 const estimated_extra = prong_count * 6; // 2 for Case, 1 item, probably 3 insts
                 var cases_extra = try std.ArrayList(u32).initCapacity(sema.gpa, estimated_extra);
@@ -10967,7 +10971,9 @@ const SwitchProngAnalysis = struct {
                         var coerce_block = block.makeSubBlock();
                         defer coerce_block.instructions.deinit(sema.gpa);
 
-                        const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(idx), field_tys[idx]);
+                        const field_idx = field_indices[idx];
+                        const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
+                        const uncoerced = try coerce_block.addStructFieldVal(spa.operand, field_idx, field_ty);
                         const coerced = sema.coerce(&coerce_block, capture_ty, uncoerced, .unneeded) catch |err| switch (err) {
                             error.NeededSourceLocation => {
                                 const multi_idx = raw_capture_src.multi_capture;
@@ -10993,8 +10999,10 @@ const SwitchProngAnalysis = struct {
                     var coerce_block = block.makeSubBlock();
                     defer coerce_block.instructions.deinit(sema.gpa);
 
-                    const first_imc = in_mem_coercible.findFirstSet().?;
-                    const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(first_imc), field_tys[first_imc]);
+                    const first_imc_item_idx = in_mem_coercible.findFirstSet().?;
+                    const first_imc_field_idx = field_indices[first_imc_item_idx];
+                    const first_imc_field_ty = union_obj.field_types.get(ip)[first_imc_field_idx].toType();
+                    const uncoerced = try coerce_block.addStructFieldVal(spa.operand, first_imc_field_idx, first_imc_field_ty);
                     const coerced = try coerce_block.addBitCast(capture_ty, uncoerced);
                     _ = try coerce_block.addBr(capture_block_inst, coerced);
 
test/behavior/switch.zig
@@ -800,3 +800,26 @@ test "nested break ignores switch conditions and breaks instead" {
     // Originally reported at https://github.com/ziglang/zig/issues/10196
     try expect(0x01 == try S.register_to_address("a0"));
 }
+
+test "peer type resolution on switch captures ignores unused payload bits" {
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+
+    const Foo = union(enum) {
+        a: u32,
+        b: u64,
+    };
+
+    var val: Foo = undefined;
+    @memset(std.mem.asBytes(&val), 0xFF);
+
+    // This is runtime-known so the following store isn't comptime-known.
+    var rt: u32 = 123;
+    val = .{ .a = rt }; // will not necessarily zero remaining payload memory
+
+    // Fields intentionally backwards here
+    const x = switch (val) {
+        .b, .a => |x| x,
+    };
+
+    try expect(x == 123);
+}