Commit 55f5cee86b

Andrew Kelley <andrew@ziglang.org>
2019-08-15 21:06:05
fix error return traces for async calls of blocking functions
1 parent 13b5a4b
Changed files (5)
src/analyze.cpp
@@ -3819,6 +3819,11 @@ static void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn) {
     } else if (fn->inferred_async_node->type == NodeTypeAwaitExpr) {
         add_error_note(g, msg, fn->inferred_async_node,
             buf_sprintf("await is a suspend point"));
+    } else if (fn->inferred_async_node->type == NodeTypeFnCallExpr &&
+        fn->inferred_async_node->data.fn_call_expr.is_builtin)
+    {
+        add_error_note(g, msg, fn->inferred_async_node,
+            buf_sprintf("@frame() causes function to be async"));
     } else {
         add_error_note(g, msg, fn->inferred_async_node,
             buf_sprintf("suspends here"));
src/codegen.cpp
@@ -3760,6 +3760,23 @@ static LLVMValueRef gen_frame_size(CodeGen *g, LLVMValueRef fn_val) {
     return LLVMBuildLoad(g->builder, prefix_ptr, "");
 }
 
+static void gen_init_stack_trace(CodeGen *g, LLVMValueRef trace_field_ptr, LLVMValueRef addrs_field_ptr) {
+    LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
+    LLVMValueRef zero = LLVMConstNull(usize_type_ref);
+
+    LLVMValueRef index_ptr = LLVMBuildStructGEP(g->builder, trace_field_ptr, 0, "");
+    LLVMBuildStore(g->builder, zero, index_ptr);
+
+    LLVMValueRef addrs_slice_ptr = LLVMBuildStructGEP(g->builder, trace_field_ptr, 1, "");
+    LLVMValueRef addrs_ptr_ptr = LLVMBuildStructGEP(g->builder, addrs_slice_ptr, slice_ptr_index, "");
+    LLVMValueRef indices[] = { LLVMConstNull(usize_type_ref), LLVMConstNull(usize_type_ref) };
+    LLVMValueRef trace_field_addrs_as_ptr = LLVMBuildInBoundsGEP(g->builder, addrs_field_ptr, indices, 2, "");
+    LLVMBuildStore(g->builder, trace_field_addrs_as_ptr, addrs_ptr_ptr);
+
+    LLVMValueRef addrs_len_ptr = LLVMBuildStructGEP(g->builder, addrs_slice_ptr, slice_len_index, "");
+    LLVMBuildStore(g->builder, LLVMConstInt(usize_type_ref, stack_trace_ptr_count, false), addrs_len_ptr);
+}
+
 static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstructionCallGen *instruction) {
     LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
 
@@ -3900,9 +3917,24 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
             if (first_arg_ret) {
                 gen_param_values.append(ret_ptr);
             }
-        }
-        if (prefix_arg_err_ret_stack) {
-            gen_param_values.append(get_cur_err_ret_trace_val(g, instruction->base.scope));
+            if (prefix_arg_err_ret_stack) {
+                // Set up the callee stack trace pointer pointing into the frame.
+                // Then we have to wire up the StackTrace pointers.
+                // Await is responsible for merging error return traces.
+                uint32_t trace_field_index_start = frame_index_trace_arg(g, src_return_type);
+                LLVMValueRef callee_trace_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc,
+                        trace_field_index_start, "");
+                LLVMValueRef trace_field_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc,
+                        trace_field_index_start + 2, "");
+                LLVMValueRef addrs_field_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc,
+                        trace_field_index_start + 3, "");
+
+                LLVMBuildStore(g->builder, trace_field_ptr, callee_trace_ptr_ptr);
+
+                gen_init_stack_trace(g, trace_field_ptr, addrs_field_ptr);
+
+                gen_param_values.append(get_cur_err_ret_trace_val(g, instruction->base.scope));
+            }
         }
     } else {
         if (first_arg_ret) {
@@ -7126,20 +7158,10 @@ static void do_code_gen(CodeGen *g) {
 
                 LLVMValueRef trace_field_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr,
                         trace_field_index_stack, "");
-                LLVMValueRef trace_field_addrs = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr,
+                LLVMValueRef addrs_field_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr,
                         trace_field_index_stack + 1, "");
 
-                LLVMValueRef index_ptr = LLVMBuildStructGEP(g->builder, trace_field_ptr, 0, "");
-                LLVMBuildStore(g->builder, zero, index_ptr);
-
-                LLVMValueRef addrs_slice_ptr = LLVMBuildStructGEP(g->builder, trace_field_ptr, 1, "");
-                LLVMValueRef addrs_ptr_ptr = LLVMBuildStructGEP(g->builder, addrs_slice_ptr, slice_ptr_index, "");
-                LLVMValueRef indices[] = { LLVMConstNull(usize_type_ref), LLVMConstNull(usize_type_ref) };
-                LLVMValueRef trace_field_addrs_as_ptr = LLVMBuildInBoundsGEP(g->builder, trace_field_addrs, indices, 2, "");
-                LLVMBuildStore(g->builder, trace_field_addrs_as_ptr, addrs_ptr_ptr);
-
-                LLVMValueRef addrs_len_ptr = LLVMBuildStructGEP(g->builder, addrs_slice_ptr, slice_len_index, "");
-                LLVMBuildStore(g->builder, LLVMConstInt(usize_type_ref, stack_trace_ptr_count, false), addrs_len_ptr);
+                gen_init_stack_trace(g, trace_field_ptr, addrs_field_ptr);
             }
             render_async_var_decls(g, entry_block->instruction_list.at(0)->scope);
         } else {
src/ir.cpp
@@ -22078,6 +22078,10 @@ static IrInstruction *ir_analyze_instruction_frame_handle(IrAnalyze *ira, IrInst
     ZigFn *fn = exec_fn_entry(ira->new_irb.exec);
     ir_assert(fn != nullptr, &instruction->base);
 
+    if (fn->inferred_async_node == nullptr) {
+        fn->inferred_async_node = instruction->base.source_node;
+    }
+
     ZigType *frame_type = get_fn_frame_type(ira->codegen, fn);
     ZigType *ptr_frame_type = get_pointer_to_type(ira->codegen, frame_type, false);
 
test/stage1/behavior/async_fn.zig
@@ -634,17 +634,30 @@ test "returning a const error from async function" {
 test "async/await typical usage" {
     inline for ([_]bool{false, true}) |b1| {
         inline for ([_]bool{false, true}) |b2| {
-            testAsyncAwaitTypicalUsage(b1, b2).doTheTest();
+            inline for ([_]bool{false, true}) |b3| {
+                inline for ([_]bool{false, true}) |b4| {
+                    testAsyncAwaitTypicalUsage(b1, b2, b3, b4).doTheTest();
+                }
+            }
         }
     }
 }
 
-fn testAsyncAwaitTypicalUsage(comptime simulate_fail_download: bool, comptime simulate_fail_file: bool) type {
+fn testAsyncAwaitTypicalUsage(
+    comptime simulate_fail_download: bool,
+    comptime simulate_fail_file: bool,
+    comptime suspend_download: bool,
+    comptime suspend_file: bool) type
+{
     return struct {
         fn doTheTest() void {
             _ = async amainWrap();
-            resume global_file_frame;
-            resume global_download_frame;
+            if (suspend_file) {
+                resume global_file_frame;
+            }
+            if (suspend_download) {
+                resume global_download_frame;
+            }
         }
         fn amainWrap() void {
             if (amain()) |_| {
@@ -685,20 +698,26 @@ fn testAsyncAwaitTypicalUsage(comptime simulate_fail_download: bool, comptime si
 
         var global_download_frame: anyframe = undefined;
         fn fetchUrl(allocator: *std.mem.Allocator, url: []const u8) anyerror![]u8 {
-            global_download_frame = @frame();
             const result = try std.mem.dupe(allocator, u8, "expected download text");
             errdefer allocator.free(result);
-            suspend;
+            if (suspend_download) {
+                suspend {
+                    global_download_frame = @frame();
+                }
+            }
             if (simulate_fail_download) return error.NoResponse;
             return result;
         }
 
         var global_file_frame: anyframe = undefined;
         fn readFile(allocator: *std.mem.Allocator, filename: []const u8) anyerror![]u8 {
-            global_file_frame = @frame();
             const result = try std.mem.dupe(allocator, u8, "expected file text");
             errdefer allocator.free(result);
-            suspend;
+            if (suspend_file) {
+                suspend {
+                    global_file_frame = @frame();
+                }
+            }
             if (simulate_fail_file) return error.FileNotFound;
             return result;
         }
test/compile_errors.zig
@@ -2,6 +2,18 @@ const tests = @import("tests.zig");
 const builtin = @import("builtin");
 
 pub fn addCases(cases: *tests.CompileErrorContext) void {
+    cases.add(
+        "@frame() causes function to be async",
+        \\export fn entry() void {
+        \\    func();
+        \\}
+        \\fn func() void {
+        \\    _ = @frame();
+        \\}
+    ,
+        "tmp.zig:1:1: error: function with calling convention 'ccc' cannot be async",
+        "tmp.zig:5:9: note: @frame() causes function to be async",
+    );
     cases.add(
         "invalid suspend in exported function",
         \\export fn entry() void {