Commit 2315e1b410

Veikka Tuominen <git@vexu.eu>
2022-10-06 16:25:00
safety: add safety check for hitting else branch on a corrupt enum value
Closes #7053
1 parent 29ae651
src/Sema.zig
@@ -9998,6 +9998,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         return sema.resolveBlockBody(block, src, &child_block, special.body, inst, merges);
     }
 
+    const backend_supports_is_named_enum = sema.mod.comp.bin_file.options.use_llvm;
+
     if (scalar_cases_len + multi_cases_len == 0 and !special.is_inline) {
         if (empty_enum) {
             return Air.Inst.Ref.void_value;
@@ -10008,6 +10010,12 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         if (err_set and try sema.maybeErrorUnwrap(block, special.body, operand)) {
             return Air.Inst.Ref.unreachable_value;
         }
+        if (backend_supports_is_named_enum and block.wantSafety() and operand_ty.zigTypeTag() == .Enum and
+            (!operand_ty.isNonexhaustiveEnum() or union_originally))
+        {
+            const ok = try block.addUnOp(.is_named_enum_value, operand);
+            try sema.addSafetyCheck(block, ok, .corrupt_switch);
+        }
         return sema.resolveBlockBody(block, src, &child_block, special.body, inst, merges);
     }
 
@@ -10465,6 +10473,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         case_block.wip_capture_scope = wip_captures.scope;
         case_block.inline_case_capture = .none;
 
+        if (backend_supports_is_named_enum and special.body.len != 0 and block.wantSafety() and
+            operand_ty.zigTypeTag() == .Enum and (!operand_ty.isNonexhaustiveEnum() or union_originally))
+        {
+            const ok = try case_block.addUnOp(.is_named_enum_value, operand);
+            try sema.addSafetyCheck(&case_block, ok, .corrupt_switch);
+        }
+
         const analyze_body = if (union_originally and !special.is_inline)
             for (seen_enum_fields) |seen_field, index| {
                 if (seen_field != null) continue;
test/cases/safety/switch else on corrupt enum value - one prong.zig
@@ -0,0 +1,24 @@
+const std = @import("std");
+
+pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace, _: ?usize) noreturn {
+    _ = stack_trace;
+    if (std.mem.eql(u8, message, "switch on corrupt value")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
+}
+const E = enum(u32) {
+    one = 1,
+    two = 2,
+};
+pub fn main() !void {
+    var a: E = undefined;
+    @ptrCast(*u32, &a).* = 255;
+    switch (a) {
+        .one => @panic("one"),
+        else => @panic("else"),
+    }
+}
+// run
+// backend=llvm
+// target=native
test/cases/safety/switch else on corrupt enum value - union.zig
@@ -0,0 +1,29 @@
+const std = @import("std");
+
+pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace, _: ?usize) noreturn {
+    _ = stack_trace;
+    if (std.mem.eql(u8, message, "switch on corrupt value")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
+}
+const E = enum(u16) {
+    one = 1,
+    two = 2,
+    _,
+};
+const U = union(E) {
+    one: u16,
+    two: u16,
+};
+pub fn main() !void {
+    var a: U = undefined;
+    @ptrCast(*align(@alignOf(U)) u32, &a).* = 0xFFFF_FFFF;
+    switch (a) {
+        .one => @panic("one"),
+        else => @panic("else"),
+    }
+}
+// run
+// backend=llvm
+// target=native
test/cases/safety/switch else on corrupt enum value.zig
@@ -0,0 +1,23 @@
+const std = @import("std");
+
+pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace, _: ?usize) noreturn {
+    _ = stack_trace;
+    if (std.mem.eql(u8, message, "switch on corrupt value")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
+}
+const E = enum(u32) {
+    one = 1,
+    two = 2,
+};
+pub fn main() !void {
+    var a: E = undefined;
+    @ptrCast(*u32, &a).* = 255;
+    switch (a) {
+        else => @panic("else"),
+    }
+}
+// run
+// backend=llvm
+// target=native