Commit 8918cb06fc

Andrew Kelley <andrew@ziglang.org>
2019-12-20 23:48:45
sentinel slicing improvements
* add runtime safety for slicing pointers, arrays, and slices.
* slicing without a sentinel value results in non-sentineled slice
* improved `std.debug.panic` handling of panic-during-panic
1 parent 26f3c2d
lib/std/debug.zig
@@ -219,7 +219,7 @@ pub fn panic(comptime format: []const u8, args: var) noreturn {
 }
 
 /// TODO multithreaded awareness
-var panicking: u8 = 0; // TODO make this a bool
+var panicking: u8 = 0;
 
 pub fn panicExtra(trace: ?*const builtin.StackTrace, first_trace_addr: ?usize, comptime format: []const u8, args: var) noreturn {
     @setCold(true);
@@ -230,21 +230,25 @@ pub fn panicExtra(trace: ?*const builtin.StackTrace, first_trace_addr: ?usize, c
         resetSegfaultHandler();
     }
 
-    if (@atomicRmw(u8, &panicking, builtin.AtomicRmwOp.Xchg, 1, builtin.AtomicOrder.SeqCst) == 1) {
-        // Panicked during a panic.
-
-        // TODO detect if a different thread caused the panic, because in that case
-        // we would want to return here instead of calling abort, so that the thread
-        // which first called panic can finish printing a stack trace.
-        os.abort();
-    }
-    const stderr = getStderrStream();
-    stderr.print(format ++ "\n", args) catch os.abort();
-    if (trace) |t| {
-        dumpStackTrace(t.*);
+    switch (@atomicRmw(u8, &panicking, .Add, 1, .SeqCst)) {
+        0 => {
+            const stderr = getStderrStream();
+            stderr.print(format ++ "\n", args) catch os.abort();
+            if (trace) |t| {
+                dumpStackTrace(t.*);
+            }
+            dumpCurrentStackTrace(first_trace_addr);
+        },
+        1 => {
+            // TODO detect if a different thread caused the panic, because in that case
+            // we would want to return here instead of calling abort, so that the thread
+            // which first called panic can finish printing a stack trace.
+            warn("Panicked during a panic. Aborting.\n", .{});
+        },
+        else => {
+            // Panicked while printing "Panicked during a panic."
+        },
     }
-    dumpCurrentStackTrace(first_trace_addr);
-
     os.abort();
 }
 
lib/std/mem.zig
@@ -364,11 +364,11 @@ pub fn len(comptime T: type, ptr: [*:0]const T) usize {
 }
 
 pub fn toSliceConst(comptime T: type, ptr: [*:0]const T) [:0]const T {
-    return ptr[0..len(T, ptr)];
+    return ptr[0..len(T, ptr) :0];
 }
 
 pub fn toSlice(comptime T: type, ptr: [*:0]T) [:0]T {
-    return ptr[0..len(T, ptr)];
+    return ptr[0..len(T, ptr) :0];
 }
 
 /// Returns true if all elements in a slice are equal to the scalar value provided
src/all_types.hpp
@@ -1779,6 +1779,7 @@ enum PanicMsgId {
     PanicMsgIdResumedFnPendingAwait,
     PanicMsgIdBadNoAsyncCall,
     PanicMsgIdResumeNotSuspendedFn,
+    PanicMsgIdBadSentinel,
 
     PanicMsgIdCount,
 };
src/codegen.cpp
@@ -941,6 +941,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
             return buf_create_from_str("async function called with noasync suspended");
         case PanicMsgIdResumeNotSuspendedFn:
             return buf_create_from_str("resumed a non-suspended function");
+        case PanicMsgIdBadSentinel:
+            return buf_create_from_str("sentinel mismatch");
     }
     zig_unreachable();
 }
@@ -1419,6 +1421,22 @@ static void add_bounds_check(CodeGen *g, LLVMValueRef target_val,
     LLVMPositionBuilderAtEnd(g->builder, ok_block);
 }
 
+static void add_sentinel_check(CodeGen *g, LLVMValueRef sentinel_elem_ptr, ZigValue *sentinel) {
+    LLVMValueRef expected_sentinel = gen_const_val(g, sentinel, "");
+
+    LLVMValueRef actual_sentinel = gen_load_untyped(g, sentinel_elem_ptr, 0, false, "");
+    LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, actual_sentinel, expected_sentinel, "");
+
+    LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "SentinelFail");
+    LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "SentinelOk");
+    LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, fail_block);
+    gen_safety_crash(g, PanicMsgIdBadSentinel);
+
+    LLVMPositionBuilderAtEnd(g->builder, ok_block);
+}
+
 static LLVMValueRef gen_assert_zero(CodeGen *g, LLVMValueRef expr_val, ZigType *int_type) {
     LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, int_type));
     LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, zero, "");
@@ -5244,6 +5262,9 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutable *executable, IrInst
 
     bool want_runtime_safety = instruction->safety_check_on && ir_want_runtime_safety(g, &instruction->base);
 
+    ZigType *res_slice_ptr_type = instruction->base.value->type->data.structure.fields[slice_ptr_index]->type_entry;
+    ZigValue *sentinel = res_slice_ptr_type->data.pointer.sentinel;
+
     if (array_type->id == ZigTypeIdArray ||
         (array_type->id == ZigTypeIdPointer && array_type->data.pointer.ptr_len == PtrLenSingle))
     {
@@ -5265,6 +5286,15 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutable *executable, IrInst
                 LLVMValueRef array_end = LLVMConstInt(g->builtin_types.entry_usize->llvm_type,
                         array_type->data.array.len, false);
                 add_bounds_check(g, end_val, LLVMIntEQ, nullptr, LLVMIntULE, array_end);
+
+                if (sentinel != nullptr) {
+                    LLVMValueRef indices[] = {
+                        LLVMConstNull(g->builtin_types.entry_usize->llvm_type),
+                        end_val,
+                    };
+                    LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, indices, 2, "");
+                    add_sentinel_check(g, sentinel_elem_ptr, sentinel);
+                }
             }
         }
         if (!type_has_bits(array_type)) {
@@ -5297,6 +5327,10 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutable *executable, IrInst
 
         if (want_runtime_safety) {
             add_bounds_check(g, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val);
+            if (sentinel != nullptr) {
+                LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, &end_val, 1, "");
+                add_sentinel_check(g, sentinel_elem_ptr, sentinel);
+            }
         }
 
         if (type_has_bits(array_type)) {
@@ -5337,18 +5371,24 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutable *executable, IrInst
             end_val = prev_end;
         }
 
+        LLVMValueRef src_ptr_ptr = LLVMBuildStructGEP(g->builder, array_ptr, (unsigned)ptr_index, "");
+        LLVMValueRef src_ptr = gen_load_untyped(g, src_ptr_ptr, 0, false, "");
+
         if (want_runtime_safety) {
             assert(prev_end);
             add_bounds_check(g, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val);
             if (instruction->end) {
                 add_bounds_check(g, end_val, LLVMIntEQ, nullptr, LLVMIntULE, prev_end);
+
+                if (sentinel != nullptr) {
+                    LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr, &end_val, 1, "");
+                    add_sentinel_check(g, sentinel_elem_ptr, sentinel);
+                }
             }
         }
 
-        LLVMValueRef src_ptr_ptr = LLVMBuildStructGEP(g->builder, array_ptr, (unsigned)ptr_index, "");
-        LLVMValueRef src_ptr = gen_load_untyped(g, src_ptr_ptr, 0, false, "");
         LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, (unsigned)ptr_index, "");
-        LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr, &start_val, (unsigned)len_index, "");
+        LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr, &start_val, 1, "");
         gen_store_untyped(g, slice_start_ptr, ptr_field_ptr, 0, false);
 
         LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, (unsigned)len_index, "");
src/ir.cpp
@@ -25122,14 +25122,16 @@ static IrInstruction *ir_analyze_instruction_slice(IrAnalyze *ira, IrInstruction
             if (array_type->data.pointer.ptr_len == PtrLenC) {
                 array_type = adjust_ptr_len(ira->codegen, array_type, PtrLenUnknown);
             }
-            non_sentinel_slice_ptr_type = array_type;
+            ZigType *maybe_sentineled_slice_ptr_type = array_type;
+            non_sentinel_slice_ptr_type = adjust_ptr_sentinel(ira->codegen, maybe_sentineled_slice_ptr_type, nullptr);
             if (!end) {
                 ir_add_error(ira, &instruction->base, buf_sprintf("slice of pointer must include end value"));
                 return ira->codegen->invalid_instruction;
             }
         }
     } else if (is_slice(array_type)) {
-        non_sentinel_slice_ptr_type = array_type->data.structure.fields[slice_ptr_index]->type_entry;
+        ZigType *maybe_sentineled_slice_ptr_type = array_type->data.structure.fields[slice_ptr_index]->type_entry;
+        non_sentinel_slice_ptr_type = adjust_ptr_sentinel(ira->codegen, maybe_sentineled_slice_ptr_type, nullptr);
         elem_type = non_sentinel_slice_ptr_type->data.pointer.child_type;
     } else {
         ir_add_error(ira, &instruction->base,
test/runtime_safety.zig
@@ -1,12 +1,57 @@
 const tests = @import("tests.zig");
 
 pub fn addCases(cases: *tests.CompareOutputContext) void {
+    cases.addRuntimeSafety("pointer slice sentinel mismatch",
+        \\const std = @import("std");
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    if (std.mem.eql(u8, message, "sentinel mismatch")) {
+        \\        std.process.exit(126); // good
+        \\    }
+        \\    std.process.exit(0); // test failed
+        \\}
+        \\pub fn main() void {
+        \\    var buf: [4]u8 = undefined;
+        \\    const ptr = buf[0..].ptr;
+        \\    const slice = ptr[0..3 :0];
+        \\}
+    );
+
+    cases.addRuntimeSafety("slice slice sentinel mismatch",
+        \\const std = @import("std");
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    if (std.mem.eql(u8, message, "sentinel mismatch")) {
+        \\        std.process.exit(126); // good
+        \\    }
+        \\    std.process.exit(0); // test failed
+        \\}
+        \\pub fn main() void {
+        \\    var buf: [4]u8 = undefined;
+        \\    const slice = buf[0..];
+        \\    const slice2 = slice[0..3 :0];
+        \\}
+    );
+
+    cases.addRuntimeSafety("array slice sentinel mismatch",
+        \\const std = @import("std");
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    if (std.mem.eql(u8, message, "sentinel mismatch")) {
+        \\        std.process.exit(126); // good
+        \\    }
+        \\    std.process.exit(0); // test failed
+        \\}
+        \\pub fn main() void {
+        \\    var buf: [4]u8 = undefined;
+        \\    const slice = buf[0..3 :0];
+        \\}
+    );
+
     cases.addRuntimeSafety("intToPtr with misaligned address",
+        \\const std = @import("std");
         \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
-        \\    if (@import("std").mem.eql(u8, message, "incorrect alignment")) {
-        \\        @import("std").os.exit(126); // good
+        \\    if (std.mem.eql(u8, message, "incorrect alignment")) {
+        \\        std.os.exit(126); // good
         \\    }
-        \\    @import("std").os.exit(0); // test failed
+        \\    std.os.exit(0); // test failed
         \\}
         \\pub fn main() void {
         \\    var x: usize = 5;