Commit 2f326f24dd

Veikka Tuominen <git@vexu.eu>
2022-03-25 12:48:11
Sema: implement zirSwitchCapture multi for unions
1 parent 17d214a
Changed files (2)
src
test
behavior
src/Sema.zig
@@ -6938,12 +6938,19 @@ fn zirSwitchCapture(
     const switch_info = zir_datas[capture_info.switch_inst].pl_node;
     const switch_extra = sema.code.extraData(Zir.Inst.SwitchBlock, switch_info.payload_index);
     const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = switch_info.src_node };
+    const switch_src = switch_info.src();
     const operand_is_ref = switch_extra.data.bits.is_ref;
     const cond_inst = Zir.refToIndex(switch_extra.data.operand).?;
     const cond_info = sema.code.instructions.items(.data)[cond_inst].un_node;
     const operand_ptr = sema.resolveInst(cond_info.operand);
     const operand_ptr_ty = sema.typeOf(operand_ptr);
     const operand_ty = if (operand_is_ref) operand_ptr_ty.childType() else operand_ptr_ty;
+    const target = sema.mod.getTarget();
+
+    const operand = if (operand_is_ref)
+        try sema.analyzeLoad(block, operand_src, operand_ptr, operand_src)
+    else
+        operand_ptr;
 
     if (capture_info.prong_index == std.math.maxInt(@TypeOf(capture_info.prong_index))) {
         // It is the else/`_` prong.
@@ -6952,64 +6959,57 @@ fn zirSwitchCapture(
             return operand_ptr;
         }
 
-        const operand = if (operand_is_ref)
-            try sema.analyzeLoad(block, operand_src, operand_ptr, operand_src)
-        else
-            operand_ptr;
-
         switch (operand_ty.zigTypeTag()) {
             .ErrorSet => return sema.bitCast(block, block.switch_else_err_ty.?, operand, operand_src),
             else => return operand,
         }
     }
 
-    if (is_multi) {
-        const items = switch_extra.data.getMultiProng(sema.code, switch_extra.end, capture_info.prong_index).items;
-
-        var names: Module.ErrorSet.NameMap = .{};
-        try names.ensureUnusedCapacity(sema.arena, items.len);
-        for (items) |item| {
-            const item_ref = sema.resolveInst(item);
-            // Previous switch validation ensured this will succeed
-            const item_val = sema.resolveConstValue(block, .unneeded, item_ref) catch unreachable;
-            names.putAssumeCapacityNoClobber(
-                item_val.getError().?,
-                {},
-            );
-        }
-
-        // names must be sorted
-        Module.ErrorSet.sortNames(&names);
-        const else_error_ty = try Type.Tag.error_set_merged.create(sema.arena, names);
-
-        const operand = if (operand_is_ref)
-            try sema.analyzeLoad(block, operand_src, operand_ptr, operand_src)
-        else
-            operand_ptr;
-        return sema.bitCast(block, else_error_ty, operand, operand_src);
-    }
-    const scalar_prong = switch_extra.data.getScalarProng(sema.code, switch_extra.end, capture_info.prong_index);
-    const item = sema.resolveInst(scalar_prong.item);
-    // Previous switch validation ensured this will succeed
-    const item_val = sema.resolveConstValue(block, .unneeded, item) catch unreachable;
-    const target = sema.mod.getTarget();
+    const items = if (is_multi)
+        switch_extra.data.getMultiProng(sema.code, switch_extra.end, capture_info.prong_index).items
+    else
+        &[_]Zir.Inst.Ref{
+            switch_extra.data.getScalarProng(sema.code, switch_extra.end, capture_info.prong_index).item,
+        };
 
     switch (operand_ty.zigTypeTag()) {
         .Union => {
             const union_obj = operand_ty.cast(Type.Payload.Union).?.data;
             const enum_ty = union_obj.tag_ty;
 
-            const field_index_usize = enum_ty.enumTagFieldIndex(item_val, target).?;
-            const field_index = @intCast(u32, field_index_usize);
-            const field = union_obj.fields.values()[field_index];
+            const first_item = sema.resolveInst(items[0]);
+            // Previous switch validation ensured this will succeed
+            const first_item_val = sema.resolveConstValue(block, .unneeded, first_item) catch unreachable;
 
-            // TODO handle multiple union tags which have compatible types
+            const first_field_index = @intCast(u32, enum_ty.enumTagFieldIndex(first_item_val, target).?);
+            const first_field = union_obj.fields.values()[first_field_index];
+
+            for (items[1..]) |item| {
+                const item_ref = sema.resolveInst(item);
+                // Previous switch validation ensured this will succeed
+                const item_val = sema.resolveConstValue(block, .unneeded, item_ref) catch unreachable;
+
+                const field_index = enum_ty.enumTagFieldIndex(item_val, target).?;
+                const field = union_obj.fields.values()[field_index];
+                if (!field.ty.eql(first_field.ty, target)) {
+                    const first_item_src = switch_src; // TODO better source location
+                    const item_src = switch_src;
+                    const msg = msg: {
+                        const msg = try sema.errMsg(block, switch_src, "capture group with incompatible types", .{});
+                        errdefer msg.destroy(sema.gpa);
+                        try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(target)});
+                        try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(target)});
+                        break :msg msg;
+                    };
+                    return sema.failWithOwnedErrorMsg(block, msg);
+                }
+            }
 
             if (is_ref) {
                 assert(operand_is_ref);
 
                 const field_ty_ptr = try Type.ptr(sema.arena, target, .{
-                    .pointee_type = field.ty,
+                    .pointee_type = first_field.ty,
                     .@"addrspace" = .generic,
                     .mutable = operand_ptr_ty.ptrIsMutable(),
                 });
@@ -7020,35 +7020,48 @@ fn zirSwitchCapture(
                         try Value.Tag.field_ptr.create(sema.arena, .{
                             .container_ptr = op_ptr_val,
                             .container_ty = operand_ty,
-                            .field_index = field_index,
+                            .field_index = first_field_index,
                         }),
                     );
                 }
                 try sema.requireRuntimeBlock(block, operand_src);
-                return block.addStructFieldPtr(operand_ptr, field_index, field_ty_ptr);
+                return block.addStructFieldPtr(operand_ptr, first_field_index, field_ty_ptr);
             }
 
-            const operand = if (operand_is_ref)
-                try sema.analyzeLoad(block, operand_src, operand_ptr, operand_src)
-            else
-                operand_ptr;
-
             if (try sema.resolveDefinedValue(block, operand_src, operand)) |operand_val| {
                 return sema.addConstant(
-                    field.ty,
+                    first_field.ty,
                     operand_val.castTag(.@"union").?.data.val,
                 );
             }
             try sema.requireRuntimeBlock(block, operand_src);
-            return block.addStructFieldVal(operand, field_index, field.ty);
+            return block.addStructFieldVal(operand, first_field_index, first_field.ty);
         },
         .ErrorSet => {
-            const item_ty = try Type.Tag.error_set_single.create(sema.arena, item_val.getError().?);
-            const operand = if (operand_is_ref)
-                try sema.analyzeLoad(block, operand_src, operand_ptr, operand_src)
-            else
-                operand_ptr;
-            return sema.bitCast(block, item_ty, operand, operand_src);
+            if (is_multi) {
+                var names: Module.ErrorSet.NameMap = .{};
+                try names.ensureUnusedCapacity(sema.arena, items.len);
+                for (items) |item| {
+                    const item_ref = sema.resolveInst(item);
+                    // Previous switch validation ensured this will succeed
+                    const item_val = sema.resolveConstValue(block, .unneeded, item_ref) catch unreachable;
+                    names.putAssumeCapacityNoClobber(
+                        item_val.getError().?,
+                        {},
+                    );
+                }
+                // names must be sorted
+                Module.ErrorSet.sortNames(&names);
+                const else_error_ty = try Type.Tag.error_set_merged.create(sema.arena, names);
+
+                return sema.bitCast(block, else_error_ty, operand, operand_src);
+            } else {
+                // Previous switch validation ensured this will succeed
+                const item_val = sema.resolveConstValue(block, .unneeded, items[0]) catch unreachable;
+
+                const item_ty = try Type.Tag.error_set_single.create(sema.arena, item_val.getError().?);
+                return sema.bitCast(block, item_ty, operand, operand_src);
+            }
         },
         else => {
             return sema.fail(block, operand_src, "switch on type '{}' provides no capture value", .{
test/behavior/switch.zig
@@ -541,7 +541,8 @@ test "switch with null and T peer types and inferred result location type" {
 }
 
 test "switch prongs with cases with identical payload types" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
 
     const Union = union(enum) {
         A: usize,