Commit 62120e3d0e

Veikka Tuominen <git@vexu.eu>
2022-07-12 12:52:39
Sema: fix non-exhaustive union switch checks
1 parent 4602114
src/Sema.zig
@@ -8247,7 +8247,7 @@ fn zirSwitchCond(
 ) CompileError!Air.Inst.Ref {
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
     const src = inst_data.src();
-    const operand_src = src; // TODO make this point at the switch operand
+    const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = inst_data.src_node };
     const operand_ptr = try sema.resolveInst(inst_data.operand);
     const operand = if (is_ref)
         try sema.analyzeLoad(block, src, operand_ptr, operand_src)
@@ -8345,12 +8345,19 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         },
     };
 
+    const union_originally = blk: {
+        const zir_data = sema.code.instructions.items(.data);
+        const cond_index = Zir.refToIndex(extra.data.operand).?;
+        const raw_operand = sema.resolveInst(zir_data[cond_index].un_node.operand) catch unreachable;
+        break :blk sema.typeOf(raw_operand).zigTypeTag() == .Union;
+    };
+
     const operand_ty = sema.typeOf(operand);
 
     var else_error_ty: ?Type = null;
 
     // Validate usage of '_' prongs.
-    if (special_prong == .under and !operand_ty.isNonexhaustiveEnum()) {
+    if (special_prong == .under and (!operand_ty.isNonexhaustiveEnum() or union_originally)) {
         const msg = msg: {
             const msg = try sema.errMsg(
                 block,
@@ -8375,6 +8382,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
 
     // Validate for duplicate items, missing else prong, and invalid range.
     switch (operand_ty.zigTypeTag()) {
+        .Union => unreachable, // handled in zirSwitchCond
         .Enum => {
             var seen_fields = try gpa.alloc(?Module.SwitchProngSrc, operand_ty.enumFieldCount());
             defer gpa.free(seen_fields);
@@ -8432,60 +8440,54 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             }
             const all_tags_handled = for (seen_fields) |seen_src| {
                 if (seen_src == null) break false;
-            } else !operand_ty.isNonexhaustiveEnum();
-
-            switch (special_prong) {
-                .none => {
-                    if (!all_tags_handled) {
-                        const msg = msg: {
-                            const msg = try sema.errMsg(
-                                block,
-                                src,
-                                "switch must handle all possibilities",
-                                .{},
-                            );
-                            errdefer msg.destroy(sema.gpa);
-                            for (seen_fields) |seen_src, i| {
-                                if (seen_src != null) continue;
+            } else true;
 
-                                const field_name = operand_ty.enumFieldName(i);
-
-                                // TODO have this point to the tag decl instead of here
-                                try sema.errNote(
-                                    block,
-                                    src,
-                                    msg,
-                                    "unhandled enumeration value: '{s}'",
-                                    .{field_name},
-                                );
-                            }
-                            try sema.mod.errNoteNonLazy(
-                                operand_ty.declSrcLoc(sema.mod),
-                                msg,
-                                "enum '{}' declared here",
-                                .{operand_ty.fmt(sema.mod)},
-                            );
-                            break :msg msg;
-                        };
-                        return sema.failWithOwnedErrorMsg(block, msg);
-                    }
-                },
-                .under => {
-                    if (all_tags_handled) return sema.fail(
+            if (special_prong == .@"else") {
+                if (all_tags_handled and !operand_ty.isNonexhaustiveEnum()) return sema.fail(
+                    block,
+                    special_prong_src,
+                    "unreachable else prong; all cases already handled",
+                    .{},
+                );
+            } else if (!all_tags_handled) {
+                const msg = msg: {
+                    const msg = try sema.errMsg(
                         block,
-                        special_prong_src,
-                        "unreachable '_' prong; all cases already handled",
+                        src,
+                        "switch must handle all possibilities",
                         .{},
                     );
-                },
-                .@"else" => {
-                    if (all_tags_handled) return sema.fail(
-                        block,
-                        special_prong_src,
-                        "unreachable else prong; all cases already handled",
-                        .{},
+                    errdefer msg.destroy(sema.gpa);
+                    for (seen_fields) |seen_src, i| {
+                        if (seen_src != null) continue;
+
+                        const field_name = operand_ty.enumFieldName(i);
+
+                        const field_src = src; // TODO better source location
+                        try sema.errNote(
+                            block,
+                            field_src,
+                            msg,
+                            "unhandled enumeration value: '{s}'",
+                            .{field_name},
+                        );
+                    }
+                    try sema.mod.errNoteNonLazy(
+                        operand_ty.declSrcLoc(sema.mod),
+                        msg,
+                        "enum '{}' declared here",
+                        .{operand_ty.fmt(sema.mod)},
                     );
-                },
+                    break :msg msg;
+                };
+                return sema.failWithOwnedErrorMsg(block, msg);
+            } else if (special_prong == .none and operand_ty.isNonexhaustiveEnum() and !union_originally) {
+                return sema.fail(
+                    block,
+                    src,
+                    "switch on non-exhaustive enum must include 'else' or '_' prong",
+                    .{},
+                );
             }
         },
         .ErrorSet => {
@@ -8625,7 +8627,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 else_error_ty = try Type.Tag.error_set_merged.create(sema.arena, names);
             }
         },
-        .Union => return sema.fail(block, src, "TODO validate switch .Union", .{}),
         .Int, .ComptimeInt => {
             var range_set = RangeSet.init(gpa, sema.mod);
             defer range_set.deinit();
test/behavior/union.zig
@@ -1016,7 +1016,6 @@ test "switching on non exhaustive union" {
             switch (a) {
                 .a => |val| try expect(val == 2),
                 .b => return error.Fail,
-                _ => return error.Fail,
             }
         }
     };
test/cases/compile_errors/stage1/test/helpful_return_type_error_message.zig → test/cases/compile_errors/helpful_return_type_error_message.zig
File renamed without changes
test/cases/compile_errors/stage1/test/switching_with_non-exhaustive_enums.zig → test/cases/compile_errors/switching_with_non-exhaustive_enums.zig
@@ -7,16 +7,21 @@ const U = union(E) {
     a: i32,
     b: u32,
 };
-pub export fn entry() void {
+pub export fn entry1() void {
     var e: E = .b;
     switch (e) { // error: switch not handling the tag `b`
         .a => {},
         _ => {},
     }
+}
+pub export fn entry2() void {
+    var e: E = .b;
     switch (e) { // error: switch on non-exhaustive enum must include `else` or `_` prong
         .a => {},
         .b => {},
     }
+}
+pub export fn entry3() void {
     var u = U{.a = 2};
     switch (u) { // error: `_` prong not allowed when switching on tagged union
         .a => {},
@@ -26,10 +31,12 @@ pub export fn entry() void {
 }
 
 // error
-// backend=stage1
+// backend=stage2
 // target=native
-// is_test=1
 //
-// tmp.zig:12:5: error: enumeration value 'E.b' not handled in switch
-// tmp.zig:16:5: error: switch on non-exhaustive enum must include `else` or `_` prong
-// tmp.zig:21:5: error: `_` prong not allowed when switching on tagged union
+// :12:5: error: switch must handle all possibilities
+// :12:5: note: unhandled enumeration value: 'b'
+// :1:11: note: enum 'tmp.E' declared here
+// :19:5: error: switch on non-exhaustive enum must include 'else' or '_' prong
+// :26:5: error: '_' prong only allowed when switching on non-exhaustive enums
+// :29:11: note: '_' prong here