Commit eec2978fac

Veikka Tuominen <git@vexu.eu>
2022-08-05 15:49:06
Sema: better safety check on switch on corrupt value
1 parent 18440cb
src/Sema.zig
@@ -9692,7 +9692,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     }
 
     var final_else_body: []const Air.Inst.Index = &.{};
-    if (special.body.len != 0 or !is_first) {
+    if (special.body.len != 0 or !is_first or case_block.wantSafety()) {
         var wip_captures = try WipCaptureScope.init(gpa, sema.perm_arena, child_block.wip_capture_scope);
         defer wip_captures.deinit();
 
@@ -9715,9 +9715,11 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         } else {
             // We still need a terminator in this block, but we have proven
             // that it is unreachable.
-            // TODO this should be a special safety panic other than unreachable, something
-            // like "panic: switch operand had corrupt value not allowed by the type"
-            try case_block.addUnreachable(src, true);
+            if (case_block.wantSafety()) {
+                _ = try sema.safetyPanic(&case_block, src, .corrupt_switch);
+            } else {
+                _ = try case_block.addNoOp(.unreach);
+            }
         }
 
         try wip_captures.finalize();
@@ -19970,6 +19972,7 @@ pub const PanicId = enum {
     /// TODO make this call `std.builtin.panicInactiveUnionField`.
     inactive_union_field,
     integer_part_out_of_bounds,
+    corrupt_switch,
 };
 
 fn addSafetyCheck(
@@ -20265,6 +20268,7 @@ fn safetyPanic(
         .exact_division_remainder => "exact division produced remainder",
         .inactive_union_field => "access of inactive union field",
         .integer_part_out_of_bounds => "integer part of floating point value out of bounds",
+        .corrupt_switch => "switch on corrupt value",
     };
 
     const msg_inst = msg_inst: {
test/behavior/switch.zig
@@ -531,6 +531,7 @@ test "switch with null and T peer types and inferred result location type" {
 test "switch prongs with cases with identical payload types" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
 
     const Union = union(enum) {
         A: usize,
test/cases/safety/switch on corrupted enum value.zig
@@ -2,7 +2,7 @@ const std = @import("std");
 
 pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
     _ = stack_trace;
-    if (std.mem.eql(u8, message, "reached unreachable code")) {
+    if (std.mem.eql(u8, message, "switch on corrupt value")) {
         std.process.exit(0);
     }
     std.process.exit(1);
@@ -10,17 +10,18 @@ pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noretur
 
 const E = enum(u32) {
     X = 1,
+    Y = 2,
 };
 
 pub fn main() !void {
     var e: E = undefined;
     @memset(@ptrCast([*]u8, &e), 0x55, @sizeOf(E));
     switch (e) {
-        .X => @breakpoint(),
+        .X, .Y => @breakpoint(),
     }
     return error.TestFailed;
 }
 
 // run
-// backend=stage1
+// backend=llvm
 // target=native
test/cases/safety/switch on corrupted union value.zig
@@ -2,7 +2,7 @@ const std = @import("std");
 
 pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
     _ = stack_trace;
-    if (std.mem.eql(u8, message, "reached unreachable code")) {
+    if (std.mem.eql(u8, message, "switch on corrupt value")) {
         std.process.exit(0);
     }
     std.process.exit(1);
@@ -10,17 +10,18 @@ pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noretur
 
 const U = union(enum(u32)) {
     X: u8,
+    Y: i8,
 };
 
 pub fn main() !void {
     var u: U = undefined;
     @memset(@ptrCast([*]u8, &u), 0x55, @sizeOf(U));
     switch (u) {
-        .X => @breakpoint(),
+        .X, .Y => @breakpoint(),
     }
     return error.TestFailed;
 }
 
 // run
-// backend=stage1
+// backend=llvm
 // target=native