Commit e444e737b7

Andrew Kelley <andrew@ziglang.org>
2019-08-03 08:11:52
add runtime safety for resuming an awaiting function
1 parent 24d7817
src/all_types.hpp
@@ -1552,6 +1552,7 @@ enum PanicMsgId {
     PanicMsgIdBadResume,
     PanicMsgIdBadAwait,
     PanicMsgIdBadReturn,
+    PanicMsgIdResumedAnAwaitingFn,
 
     PanicMsgIdCount,
 };
src/codegen.cpp
@@ -877,6 +877,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
             return buf_create_from_str("async function awaited twice");
         case PanicMsgIdBadReturn:
             return buf_create_from_str("async function returned twice");
+        case PanicMsgIdResumedAnAwaitingFn:
+            return buf_create_from_str("awaiting function resumed");
     }
     zig_unreachable();
 }
@@ -2018,7 +2020,10 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrIns
             }
             result_ptr_as_usize = LLVMBuildPtrToInt(g->builder, result_ptr, usize_type_ref, "");
         } else {
-            result_ptr_as_usize = LLVMGetUndef(usize_type_ref);
+            // For debug safety, this value has to be anything other than all 1's, which signals
+            // that it is being resumed. 0 is a bad choice since null pointers are special.
+            result_ptr_as_usize = ir_want_runtime_safety(g, &return_instruction->base) ?
+                LLVMConstInt(usize_type_ref, 1, false) : LLVMGetUndef(usize_type_ref);
         }
         LLVMValueRef zero = LLVMConstNull(usize_type_ref);
         LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
@@ -3582,8 +3587,9 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
             LLVMBuildStore(g->builder, gen_param_values.at(arg_i), arg_ptr);
         }
     }
+    LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
     if (instruction->is_async) {
-        LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(g->builtin_types.entry_usize->llvm_type)};
+        LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(usize_type_ref)};
         ZigLLVMBuildCall(g->builder, fn_val, args, 2, llvm_cc, fn_inline, "");
         return nullptr;
     } else if (callee_is_async) {
@@ -3591,8 +3597,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
         LLVMValueRef split_llvm_fn = make_fn_llvm_value(g, g->cur_fn);
         LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_fn_ptr_index, "");
         LLVMBuildStore(g->builder, split_llvm_fn, fn_ptr_ptr);
-
-        LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(g->builtin_types.entry_usize->llvm_type)};
+        LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(usize_type_ref)};
         LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, fn_val, args, 2, llvm_cc, fn_inline, "");
         ZigLLVMSetTailCall(call_inst);
         LLVMBuildRetVoid(g->builder);
@@ -3601,6 +3606,21 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
         g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0);
         LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "CallResume");
         LLVMPositionBuilderAtEnd(g->builder, call_bb);
+
+        if (ir_want_runtime_safety(g, &instruction->base)) {
+            LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "BadResume");
+            LLVMBasicBlockRef ok_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "OkResume");
+            LLVMValueRef arg_val = LLVMGetParam(split_llvm_fn, 1);
+            LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
+            LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntNE, arg_val, all_ones, "");
+            LLVMBuildCondBr(g->builder, ok_bit, ok_resume_block, bad_resume_block);
+
+            LLVMPositionBuilderAtEnd(g->builder, bad_resume_block);
+            gen_safety_crash(g, PanicMsgIdResumedAnAwaitingFn);
+
+            LLVMPositionBuilderAtEnd(g->builder, ok_resume_block);
+        }
+
         render_async_var_decls(g, instruction->base.scope);
 
         if (type_has_bits(src_return_type)) {
@@ -5139,6 +5159,21 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
     g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0);
     LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "AwaitResume");
     LLVMPositionBuilderAtEnd(g->builder, call_bb);
+
+    if (ir_want_runtime_safety(g, &instruction->base)) {
+        LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "BadResume");
+        LLVMBasicBlockRef ok_resume_block = LLVMAppendBasicBlock(split_llvm_fn, "OkResume");
+        LLVMValueRef arg_val = LLVMGetParam(split_llvm_fn, 1);
+        LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
+        LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntNE, arg_val, all_ones, "");
+        LLVMBuildCondBr(g->builder, ok_bit, ok_resume_block, bad_resume_block);
+
+        LLVMPositionBuilderAtEnd(g->builder, bad_resume_block);
+        gen_safety_crash(g, PanicMsgIdResumedAnAwaitingFn);
+
+        LLVMPositionBuilderAtEnd(g->builder, ok_resume_block);
+    }
+
     render_async_var_decls(g, instruction->base.scope);
 
     if (type_has_bits(result_type)) {
@@ -5178,7 +5213,9 @@ static LLVMValueRef ir_render_coro_resume(CodeGen *g, IrExecutable *executable,
     LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame, coro_fn_ptr_index, "");
     LLVMValueRef uncasted_fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, "");
     LLVMValueRef fn_val = LLVMBuildIntToPtr(g->builder, uncasted_fn_val, anyframe_fn_type(g), "");
-    LLVMValueRef args[] = {frame, LLVMGetUndef(usize_type_ref)};
+    LLVMValueRef arg_val = ir_want_runtime_safety(g, &instruction->base) ?
+        LLVMConstAllOnes(usize_type_ref) : LLVMGetUndef(usize_type_ref);
+    LLVMValueRef args[] = {frame, arg_val};
     ZigLLVMBuildCall(g->builder, fn_val, args, 2, LLVMFastCallConv, ZigLLVM_FnInlineAuto, "");
     return nullptr;
 }
test/stage1/behavior/coroutines.zig
@@ -89,7 +89,7 @@ test "calling an inferred async function" {
         var other_frame: *@Frame(other) = undefined;
 
         fn doTheTest() void {
-            const p = async first();
+            _ = async first();
             expect(x == 1);
             resume other_frame.*;
             expect(x == 2);
test/runtime_safety.zig
@@ -1,6 +1,38 @@
 const tests = @import("tests.zig");
 
 pub fn addCases(cases: *tests.CompareOutputContext) void {
+    cases.addRuntimeSafety("resuming a function which is awaiting a frame",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\pub fn main() void {
+        \\    var frame = async first();
+        \\    resume frame;
+        \\}
+        \\fn first() void {
+        \\    var frame = async other();
+        \\    await frame;
+        \\}
+        \\fn other() void {
+        \\    suspend;
+        \\}
+    );
+    cases.addRuntimeSafety("resuming a function which is awaiting a call",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\pub fn main() void {
+        \\    var frame = async first();
+        \\    resume frame;
+        \\}
+        \\fn first() void {
+        \\    other();
+        \\}
+        \\fn other() void {
+        \\    suspend;
+        \\}
+    );
+
     cases.addRuntimeSafety("invalid resume of async function",
         \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
         \\    @import("std").os.exit(126);
BRANCH_TODO
@@ -6,7 +6,6 @@
  * @asyncCall with an async function pointer
  * cancel
  * defer and errdefer
- * safety for resuming when it is awaiting
  * safety for double await
  * implicit cast of normal function to async function should be allowed when it is inferred to be async
  * go over the commented out tests
@@ -19,3 +18,6 @@
  * make sure there are safety tests for all the new safety features (search the new PanicFnId enum values)
  * error return tracing
  * compile error for casting a function to a non-async function pointer, but then later it gets inferred to be an async function
+ * compile error for copying a frame
+ * compile error for resuming a const frame pointer
+ * runtime safety enabling/disabling scope has to be coordinated across resume/await/calls/return