Commit 7d303ae861

Andrew Kelley <andrew@ziglang.org>
2019-09-06 18:50:51
runtime safety for noasync function calls
See #3157
1 parent 0a3c6db
src/all_types.hpp
@@ -1680,6 +1680,7 @@ enum PanicMsgId {
     PanicMsgIdResumedAnAwaitingFn,
     PanicMsgIdFrameTooSmall,
     PanicMsgIdResumedFnPendingAwait,
+    PanicMsgIdBadNoAsyncCall,
 
     PanicMsgIdCount,
 };
src/codegen.cpp
@@ -923,6 +923,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
             return buf_create_from_str("frame too small");
         case PanicMsgIdResumedFnPendingAwait:
             return buf_create_from_str("resumed an async function which can only be awaited");
+        case PanicMsgIdBadNoAsyncCall:
+            return buf_create_from_str("async function called with noasync suspended");
     }
     zig_unreachable();
 }
@@ -4067,6 +4069,25 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
         } else if (instruction->modifier == CallModifierNoAsync && !fn_is_async(g->cur_fn)) {
             gen_resume(g, fn_val, frame_result_loc, ResumeIdCall);
 
+            if (ir_want_runtime_safety(g, &instruction->base)) {
+                LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc,
+                        frame_awaiter_index, "");
+                LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
+                LLVMValueRef prev_val = gen_maybe_atomic_op(g, LLVMAtomicRMWBinOpXchg, awaiter_ptr,
+                        all_ones, LLVMAtomicOrderingRelease);
+                LLVMValueRef ok_val = LLVMBuildICmp(g->builder, LLVMIntEQ, prev_val, all_ones, "");
+
+                LLVMBasicBlockRef bad_block = LLVMAppendBasicBlock(g->cur_fn_val, "NoAsyncPanic");
+                LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "NoAsyncOk");
+                LLVMBuildCondBr(g->builder, ok_val, ok_block, bad_block);
+
+                // The async function suspended, but this noasync call asserted it wouldn't.
+                LLVMPositionBuilderAtEnd(g->builder, bad_block);
+                gen_safety_crash(g, PanicMsgIdBadNoAsyncCall);
+
+                LLVMPositionBuilderAtEnd(g->builder, ok_block);
+            }
+
             ZigType *result_type = instruction->base.value.type;
             ZigType *ptr_result_type = get_pointer_to_type(g, result_type, true);
             return gen_await_early_return(g, &instruction->base, frame_result_loc,
test/runtime_safety.zig
@@ -1,6 +1,21 @@
 const tests = @import("tests.zig");
 
 pub fn addCases(cases: *tests.CompareOutputContext) void {
+    cases.addRuntimeSafety("noasync function call, callee suspends",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\pub fn main() void {
+        \\    _ = noasync add(101, 100);
+        \\}
+        \\fn add(a: i32, b: i32) i32 {
+        \\    if (a > 100) {
+        \\        suspend;
+        \\    }
+        \\    return a + b;
+        \\}
+    );
+
     cases.addRuntimeSafety("awaiting twice",
         \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
         \\    @import("std").os.exit(126);