Commit 14f0b70570

Veikka Tuominen <git@vexu.eu>
2022-08-01 18:33:18
Sema: add safety for sentinel slice
1 parent 292906f
lib/std/builtin.zig
@@ -846,6 +846,17 @@ pub fn default_panic(msg: []const u8, error_return_trace: ?*StackTrace) noreturn
     }
 }
 
+pub fn checkNonScalarSentinel(expected: anytype, actual: @TypeOf(expected)) void {
+    if (!std.meta.eql(expected, actual)) {
+        panicSentinelMismatch(expected, actual);
+    }
+}
+
+pub fn panicSentinelMismatch(expected: anytype, actual: @TypeOf(expected)) noreturn {
+    @setCold(true);
+    std.debug.panic("sentinel mismatch: expected {any}, found {any}", .{ expected, actual });
+}
+
 pub fn panicUnwrapError(st: ?*StackTrace, err: anyerror) noreturn {
     @setCold(true);
     std.debug.panicExtra(st, "attempt to unwrap error: {s}", .{@errorName(err)});
src/Sema.zig
@@ -20148,6 +20148,77 @@ fn panicIndexOutOfBounds(
     try sema.addSafetyCheckExtra(parent_block, ok, &fail_block);
 }
 
+fn panicSentinelMismatch(
+    sema: *Sema,
+    parent_block: *Block,
+    src: LazySrcLoc,
+    maybe_sentinel: ?Value,
+    sentinel_ty: Type,
+    ptr: Air.Inst.Ref,
+    sentinel_index: Air.Inst.Ref,
+) !void {
+    const expected_sentinel_val = maybe_sentinel orelse return;
+    const expected_sentinel = try sema.addConstant(sentinel_ty, expected_sentinel_val);
+
+    const ptr_ty = sema.typeOf(ptr);
+    const actual_sentinel = if (ptr_ty.isSlice())
+        try parent_block.addBinOp(.slice_elem_val, ptr, sentinel_index)
+    else blk: {
+        const elem_ptr_ty = try sema.elemPtrType(ptr_ty, null);
+        const sentinel_ptr = try parent_block.addPtrElemPtr(ptr, sentinel_index, elem_ptr_ty);
+        break :blk try parent_block.addTyOp(.load, sentinel_ty, sentinel_ptr);
+    };
+
+    const ok = if (sentinel_ty.zigTypeTag() == .Vector) ok: {
+        const eql =
+            try parent_block.addCmpVector(expected_sentinel, actual_sentinel, .eq, try sema.addType(sentinel_ty));
+        break :ok try parent_block.addInst(.{
+            .tag = .reduce,
+            .data = .{ .reduce = .{
+                .operand = eql,
+                .operation = .And,
+            } },
+        });
+    } else if (sentinel_ty.isSelfComparable(true))
+        try parent_block.addBinOp(.cmp_eq, expected_sentinel, actual_sentinel)
+    else {
+        const panic_fn = try sema.getBuiltin(parent_block, src, "checkNonScalarSentinel");
+        const args: [2]Air.Inst.Ref = .{ expected_sentinel, actual_sentinel };
+        _ = try sema.analyzeCall(parent_block, panic_fn, src, src, .auto, false, &args, null);
+        return;
+    };
+    const gpa = sema.gpa;
+
+    var fail_block: Block = .{
+        .parent = parent_block,
+        .sema = sema,
+        .src_decl = parent_block.src_decl,
+        .namespace = parent_block.namespace,
+        .wip_capture_scope = parent_block.wip_capture_scope,
+        .instructions = .{},
+        .inlining = parent_block.inlining,
+        .is_comptime = parent_block.is_comptime,
+    };
+
+    defer fail_block.instructions.deinit(gpa);
+
+    {
+        const this_feature_is_implemented_in_the_backend =
+            sema.mod.comp.bin_file.options.use_llvm;
+
+        if (!this_feature_is_implemented_in_the_backend) {
+            // TODO implement this feature in all the backends and then delete this branch
+            _ = try fail_block.addNoOp(.breakpoint);
+            _ = try fail_block.addNoOp(.unreach);
+        } else {
+            const panic_fn = try sema.getBuiltin(&fail_block, src, "panicSentinelMismatch");
+            const args: [2]Air.Inst.Ref = .{ expected_sentinel, actual_sentinel };
+            _ = try sema.analyzeCall(&fail_block, panic_fn, src, src, .auto, false, &args, null);
+        }
+    }
+    try sema.addSafetyCheckExtra(parent_block, ok, &fail_block);
+}
+
 fn safetyPanic(
     sema: *Sema,
     block: *Block,
@@ -25368,6 +25439,7 @@ fn analyzeSlice(
         }
         break :s null;
     };
+    const slice_sentinel = if (sentinel_opt != .none) sentinel else null;
 
     // requirement: start <= end
     if (try sema.resolveDefinedValue(block, end_src, end)) |end_val| {
@@ -25447,7 +25519,12 @@ fn analyzeSlice(
 
         const opt_new_ptr_val = try sema.resolveMaybeUndefVal(block, ptr_src, new_ptr);
         const new_ptr_val = opt_new_ptr_val orelse {
-            return block.addBitCast(return_ty, new_ptr);
+            const result = try block.addBitCast(return_ty, new_ptr);
+            if (block.wantSafety()) {
+                // requirement: result[new_len] == slice_sentinel
+                try sema.panicSentinelMismatch(block, src, slice_sentinel, elem_ty, result, new_len);
+            }
+            return result;
         };
 
         if (!new_ptr_val.isUndef()) {
@@ -25511,7 +25588,7 @@ fn analyzeSlice(
         // requirement: start <= end
         try sema.panicIndexOutOfBounds(block, src, start, end, .cmp_lte);
     }
-    return block.addInst(.{
+    const result = try block.addInst(.{
         .tag = .slice,
         .data = .{ .ty_pl = .{
             .ty = try sema.addType(return_ty),
@@ -25521,6 +25598,11 @@ fn analyzeSlice(
             }),
         } },
     });
+    if (block.wantSafety()) {
+        // requirement: result[new_len] == slice_sentinel
+        try sema.panicSentinelMismatch(block, src, slice_sentinel, elem_ty, result, new_len);
+    }
+    return result;
 }
 
 /// Asserts that lhs and rhs types are both numeric.
test/cases/compile_errors/reify_type_for_tagged_union_with_extra_union_field.zig
@@ -31,5 +31,5 @@ export fn entry() void {
 // backend=stage2
 // target=native
 //
-// :13:16: error: no field named 'arst' in enum 'tmp.Tag__enum_264'
+// :13:16: error: no field named 'arst' in enum 'tmp.Tag__enum_266'
 // :1:13: note: enum declared here
test/cases/safety/array slice sentinel mismatch non-scalar.zig
@@ -0,0 +1,21 @@
+const std = @import("std");
+
+pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
+    _ = stack_trace;
+    if (std.mem.eql(u8, message, "sentinel mismatch: expected tmp.main.S{ .a = 1 }, found tmp.main.S{ .a = 2 }")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
+}
+
+pub fn main() !void {
+    const S = struct { a: u32 };
+    var arr = [_]S{ .{ .a = 1 }, .{ .a = 2 } };
+    var s = arr[0..1 :.{ .a = 1 }];
+    _ = s;
+    return error.TestFailed;
+}
+
+// run
+// backend=llvm
+// target=native
test/cases/safety/array slice sentinel mismatch vector.zig
@@ -0,0 +1,19 @@
+const std = @import("std");
+
+pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
+    _ = stack_trace;
+    if (std.mem.eql(u8, message, "sentinel mismatch: expected { 0, 0 }, found { 4, 4 }")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
+}
+
+pub fn main() !void {
+    var buf: [4]@Vector(2, u32) = .{ .{ 1, 1 }, .{ 2, 2 }, .{ 3, 3 }, .{ 4, 4 } };
+    const slice = buf[0..3 :.{ 0, 0 }];
+    _ = slice;
+    return error.TestFailed;
+}
+// run
+// backend=llvm
+// target=native
test/cases/safety/array slice sentinel mismatch.zig
@@ -2,17 +2,18 @@ const std = @import("std");
 
 pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
     _ = stack_trace;
-    if (std.mem.eql(u8, message, "sentinel mismatch")) {
+    if (std.mem.eql(u8, message, "sentinel mismatch: expected 0, found 4")) {
         std.process.exit(0);
     }
     std.process.exit(1);
 }
+
 pub fn main() !void {
-    var buf: [4]u8 = undefined;
+    var buf: [4]u8 = .{ 1, 2, 3, 4 };
     const slice = buf[0..3 :0];
     _ = slice;
     return error.TestFailed;
 }
 // run
-// backend=stage1
+// backend=llvm
 // target=native