Commit e98e5dda52

Andrew Kelley <andrew@ziglang.org>
2019-10-23 05:43:27
implement safety for resuming non-suspended function
closes #3469
1 parent 1dcf540
src/all_types.hpp
@@ -1715,6 +1715,7 @@ enum PanicMsgId {
     PanicMsgIdFrameTooSmall,
     PanicMsgIdResumedFnPendingAwait,
     PanicMsgIdBadNoAsyncCall,
+    PanicMsgIdResumeNotSuspendedFn,
 
     PanicMsgIdCount,
 };
@@ -1886,6 +1887,7 @@ struct CodeGen {
     size_t cur_resume_block_count;
     LLVMValueRef cur_err_ret_trace_val_arg;
     LLVMValueRef cur_err_ret_trace_val_stack;
+    LLVMValueRef cur_bad_not_suspended_index;
     LLVMValueRef memcpy_fn_val;
     LLVMValueRef memset_fn_val;
     LLVMValueRef trap_fn_val;
src/codegen.cpp
@@ -933,6 +933,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
             return buf_create_from_str("resumed an async function which can only be awaited");
         case PanicMsgIdBadNoAsyncCall:
             return buf_create_from_str("async function called with noasync suspended");
+        case PanicMsgIdResumeNotSuspendedFn:
+            return buf_create_from_str("resumed a non-suspended function");
     }
     zig_unreachable();
 }
@@ -2234,6 +2236,12 @@ static void gen_assert_resume_id(CodeGen *g, IrInstruction *source_instr, Resume
         LLVMBasicBlockRef end_bb)
 {
     LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
+
+    if (ir_want_runtime_safety(g, source_instr)) {
+        // Write a value to the resume index which indicates the function was resumed while not suspended.
+        LLVMBuildStore(g->builder, g->cur_bad_not_suspended_index, g->cur_async_resume_index_ptr);
+    }
+
     LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(g->cur_fn_val, "BadResume");
     if (end_bb == nullptr) end_bb = LLVMAppendBasicBlock(g->cur_fn_val, "OkResume");
     LLVMValueRef expected_value = LLVMConstSub(LLVMConstAllOnes(usize_type_ref),
@@ -5764,6 +5772,9 @@ static LLVMValueRef ir_render_suspend_finish(CodeGen *g, IrExecutable *executabl
     LLVMBuildRetVoid(g->builder);
 
     LLVMPositionBuilderAtEnd(g->builder, instruction->begin->resume_bb);
+    if (ir_want_runtime_safety(g, &instruction->base)) {
+        LLVMBuildStore(g->builder, g->cur_bad_not_suspended_index, g->cur_async_resume_index_ptr);
+    }
     render_async_var_decls(g, instruction->base.scope);
     return nullptr;
 }
@@ -7542,7 +7553,20 @@ static void do_code_gen(CodeGen *g) {
             IrBasicBlock *entry_block = executable->basic_block_list.at(0);
             LLVMAddCase(switch_instr, zero, entry_block->llvm_block);
             g->cur_resume_block_count += 1;
+
+            {
+                LLVMBasicBlockRef bad_not_suspended_bb = LLVMAppendBasicBlock(g->cur_fn_val, "NotSuspended");
+                size_t new_block_index = g->cur_resume_block_count;
+                g->cur_resume_block_count += 1;
+                g->cur_bad_not_suspended_index = LLVMConstInt(usize_type_ref, new_block_index, false);
+                LLVMAddCase(g->cur_async_switch_instr, g->cur_bad_not_suspended_index, bad_not_suspended_bb);
+
+                LLVMPositionBuilderAtEnd(g->builder, bad_not_suspended_bb);
+                gen_assertion_scope(g, PanicMsgIdResumeNotSuspendedFn, fn_table_entry->child_scope);
+            }
+
             LLVMPositionBuilderAtEnd(g->builder, entry_block->llvm_block);
+            LLVMBuildStore(g->builder, g->cur_bad_not_suspended_index, g->cur_async_resume_index_ptr);
             if (trace_field_index_stack != UINT32_MAX) {
                 if (codegen_fn_has_err_ret_tracing_arg(g, fn_type_id->return_type)) {
                     LLVMValueRef trace_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr,
test/runtime_safety.zig
@@ -1,6 +1,54 @@
 const tests = @import("tests.zig");
 
 pub fn addCases(cases: *tests.CompareOutputContext) void {
+    cases.addRuntimeSafety("resuming a non-suspended function which never been suspended",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\fn foo() void {
+        \\    var f = async bar(@frame());
+        \\    @import("std").os.exit(0);
+        \\}
+        \\
+        \\fn bar(frame: anyframe) void {
+        \\    suspend {
+        \\        resume frame;
+        \\    }
+        \\    @import("std").os.exit(0);
+        \\}
+        \\
+        \\pub fn main() void {
+        \\    _ = async foo();
+        \\}
+    );
+
+    cases.addRuntimeSafety("resuming a non-suspended function which has been suspended and resumed",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\fn foo() void {
+        \\    suspend {
+        \\        global_frame = @frame();
+        \\    }
+        \\    var f = async bar(@frame());
+        \\    @import("std").os.exit(0);
+        \\}
+        \\
+        \\fn bar(frame: anyframe) void {
+        \\    suspend {
+        \\        resume frame;
+        \\    }
+        \\    @import("std").os.exit(0);
+        \\}
+        \\
+        \\var global_frame: anyframe = undefined;
+        \\pub fn main() void {
+        \\    _ = async foo();
+        \\    resume global_frame;
+        \\    @import("std").os.exit(0);
+        \\}
+    );
+
     cases.addRuntimeSafety("noasync function call, callee suspends",
         \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
         \\    @import("std").os.exit(126);