Commit 966c9ea63c

Andrew Kelley <andrew@ziglang.org>
2019-08-07 00:47:09
error return trace across suspend points
1 parent 17199b0
src/all_types.hpp
@@ -1737,6 +1737,7 @@ struct CodeGen {
     LLVMValueRef stacksave_fn_val;
     LLVMValueRef stackrestore_fn_val;
     LLVMValueRef write_register_fn_val;
+    LLVMValueRef merge_err_ret_traces_fn_val;
     LLVMValueRef sp_md_node;
     LLVMValueRef err_name_table;
     LLVMValueRef safety_crash_err_fn;
src/codegen.cpp
@@ -2026,18 +2026,159 @@ void walk_function_params(CodeGen *g, ZigType *fn_type, FnWalk *fn_walk) {
     }
 }
 
+static LLVMValueRef get_merge_err_ret_traces_fn_val(CodeGen *g) {
+    if (g->merge_err_ret_traces_fn_val)
+        return g->merge_err_ret_traces_fn_val;
+
+    assert(g->stack_trace_type != nullptr);
+
+    LLVMTypeRef param_types[] = {
+        get_llvm_type(g, get_ptr_to_stack_trace_type(g)),
+        get_llvm_type(g, get_ptr_to_stack_trace_type(g)),
+    };
+    LLVMTypeRef fn_type_ref = LLVMFunctionType(LLVMVoidType(), param_types, 2, false);
+
+    Buf *fn_name = get_mangled_name(g, buf_create_from_str("__zig_merge_error_return_traces"), false);
+    LLVMValueRef fn_val = LLVMAddFunction(g->module, buf_ptr(fn_name), fn_type_ref);
+    LLVMSetLinkage(fn_val, LLVMInternalLinkage);
+    LLVMSetFunctionCallConv(fn_val, get_llvm_cc(g, CallingConventionUnspecified));
+    addLLVMFnAttr(fn_val, "nounwind");
+    add_uwtable_attr(g, fn_val);
+    // Error return trace memory is in the stack, which is impossible to be at address 0
+    // on any architecture.
+    addLLVMArgAttr(fn_val, (unsigned)0, "nonnull");
+    addLLVMArgAttr(fn_val, (unsigned)0, "noalias");
+    addLLVMArgAttr(fn_val, (unsigned)0, "writeonly");
+    // Error return trace memory is in the stack, which is impossible to be at address 0
+    // on any architecture.
+    addLLVMArgAttr(fn_val, (unsigned)1, "nonnull");
+    addLLVMArgAttr(fn_val, (unsigned)1, "noalias");
+    addLLVMArgAttr(fn_val, (unsigned)1, "readonly");
+    if (g->build_mode == BuildModeDebug) {
+        ZigLLVMAddFunctionAttr(fn_val, "no-frame-pointer-elim", "true");
+        ZigLLVMAddFunctionAttr(fn_val, "no-frame-pointer-elim-non-leaf", nullptr);
+    }
+
+    // this is above the ZigLLVMClearCurrentDebugLocation
+    LLVMValueRef add_error_return_trace_addr_fn_val = get_add_error_return_trace_addr_fn(g);
+
+    LLVMBasicBlockRef entry_block = LLVMAppendBasicBlock(fn_val, "Entry");
+    LLVMBasicBlockRef prev_block = LLVMGetInsertBlock(g->builder);
+    LLVMValueRef prev_debug_location = LLVMGetCurrentDebugLocation(g->builder);
+    LLVMPositionBuilderAtEnd(g->builder, entry_block);
+    ZigLLVMClearCurrentDebugLocation(g->builder);
+
+    // var frame_index: usize = undefined;
+    // var frames_left: usize = undefined;
+    // if (src_stack_trace.index < src_stack_trace.instruction_addresses.len) {
+    //     frame_index = 0;
+    //     frames_left = src_stack_trace.index;
+    //     if (frames_left == 0) return;
+    // } else {
+    //     frame_index = (src_stack_trace.index + 1) % src_stack_trace.instruction_addresses.len;
+    //     frames_left = src_stack_trace.instruction_addresses.len;
+    // }
+    // while (true) {
+    //     __zig_add_err_ret_trace_addr(dest_stack_trace, src_stack_trace.instruction_addresses[frame_index]);
+    //     frames_left -= 1;
+    //     if (frames_left == 0) return;
+    //     frame_index = (frame_index + 1) % src_stack_trace.instruction_addresses.len;
+    // }
+    LLVMBasicBlockRef return_block = LLVMAppendBasicBlock(fn_val, "Return");
+
+    LLVMValueRef frame_index_ptr = LLVMBuildAlloca(g->builder, g->builtin_types.entry_usize->llvm_type, "frame_index");
+    LLVMValueRef frames_left_ptr = LLVMBuildAlloca(g->builder, g->builtin_types.entry_usize->llvm_type, "frames_left");
+
+    LLVMValueRef dest_stack_trace_ptr = LLVMGetParam(fn_val, 0);
+    LLVMValueRef src_stack_trace_ptr = LLVMGetParam(fn_val, 1);
+
+    size_t src_index_field_index = g->stack_trace_type->data.structure.fields[0].gen_index;
+    size_t src_addresses_field_index = g->stack_trace_type->data.structure.fields[1].gen_index;
+    LLVMValueRef src_index_field_ptr = LLVMBuildStructGEP(g->builder, src_stack_trace_ptr,
+            (unsigned)src_index_field_index, "");
+    LLVMValueRef src_addresses_field_ptr = LLVMBuildStructGEP(g->builder, src_stack_trace_ptr,
+            (unsigned)src_addresses_field_index, "");
+    ZigType *slice_type = g->stack_trace_type->data.structure.fields[1].type_entry;
+    size_t ptr_field_index = slice_type->data.structure.fields[slice_ptr_index].gen_index;
+    LLVMValueRef src_ptr_field_ptr = LLVMBuildStructGEP(g->builder, src_addresses_field_ptr, (unsigned)ptr_field_index, "");
+    size_t len_field_index = slice_type->data.structure.fields[slice_len_index].gen_index;
+    LLVMValueRef src_len_field_ptr = LLVMBuildStructGEP(g->builder, src_addresses_field_ptr, (unsigned)len_field_index, "");
+    LLVMValueRef src_index_val = LLVMBuildLoad(g->builder, src_index_field_ptr, "");
+    LLVMValueRef src_ptr_val = LLVMBuildLoad(g->builder, src_ptr_field_ptr, "");
+    LLVMValueRef src_len_val = LLVMBuildLoad(g->builder, src_len_field_ptr, "");
+    LLVMValueRef no_wrap_bit = LLVMBuildICmp(g->builder, LLVMIntULT, src_index_val, src_len_val, "");
+    LLVMBasicBlockRef no_wrap_block = LLVMAppendBasicBlock(fn_val, "NoWrap");
+    LLVMBasicBlockRef yes_wrap_block = LLVMAppendBasicBlock(fn_val, "YesWrap");
+    LLVMBasicBlockRef loop_block = LLVMAppendBasicBlock(fn_val, "Loop");
+    LLVMBuildCondBr(g->builder, no_wrap_bit, no_wrap_block, yes_wrap_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, no_wrap_block);
+    LLVMValueRef usize_zero = LLVMConstNull(g->builtin_types.entry_usize->llvm_type);
+    LLVMBuildStore(g->builder, usize_zero, frame_index_ptr);
+    LLVMBuildStore(g->builder, src_index_val, frames_left_ptr);
+    LLVMValueRef frames_left_eq_zero_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, src_index_val, usize_zero, "");
+    LLVMBuildCondBr(g->builder, frames_left_eq_zero_bit, return_block, loop_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, yes_wrap_block);
+    LLVMValueRef usize_one = LLVMConstInt(g->builtin_types.entry_usize->llvm_type, 1, false);
+    LLVMValueRef plus_one = LLVMBuildNUWAdd(g->builder, src_index_val, usize_one, "");
+    LLVMValueRef mod_len = LLVMBuildURem(g->builder, plus_one, src_len_val, "");
+    LLVMBuildStore(g->builder, mod_len, frame_index_ptr);
+    LLVMBuildStore(g->builder, src_len_val, frames_left_ptr);
+    LLVMBuildBr(g->builder, loop_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, loop_block);
+    LLVMValueRef ptr_index = LLVMBuildLoad(g->builder, frame_index_ptr, "");
+    LLVMValueRef addr_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr_val, &ptr_index, 1, "");
+    LLVMValueRef this_addr_val = LLVMBuildLoad(g->builder, addr_ptr, "");
+    LLVMValueRef args[] = {dest_stack_trace_ptr, this_addr_val};
+    ZigLLVMBuildCall(g->builder, add_error_return_trace_addr_fn_val, args, 2, get_llvm_cc(g, CallingConventionUnspecified), ZigLLVM_FnInlineAlways, "");
+    LLVMValueRef prev_frames_left = LLVMBuildLoad(g->builder, frames_left_ptr, "");
+    LLVMValueRef new_frames_left = LLVMBuildNUWSub(g->builder, prev_frames_left, usize_one, "");
+    LLVMValueRef done_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, new_frames_left, usize_zero, "");
+    LLVMBasicBlockRef continue_block = LLVMAppendBasicBlock(fn_val, "Continue");
+    LLVMBuildCondBr(g->builder, done_bit, return_block, continue_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, return_block);
+    LLVMBuildRetVoid(g->builder);
+
+    LLVMPositionBuilderAtEnd(g->builder, continue_block);
+    LLVMBuildStore(g->builder, new_frames_left, frames_left_ptr);
+    LLVMValueRef prev_index = LLVMBuildLoad(g->builder, frame_index_ptr, "");
+    LLVMValueRef index_plus_one = LLVMBuildNUWAdd(g->builder, prev_index, usize_one, "");
+    LLVMValueRef index_mod_len = LLVMBuildURem(g->builder, index_plus_one, src_len_val, "");
+    LLVMBuildStore(g->builder, index_mod_len, frame_index_ptr);
+    LLVMBuildBr(g->builder, loop_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, prev_block);
+    if (!g->strip_debug_symbols) {
+        LLVMSetCurrentDebugLocation(g->builder, prev_debug_location);
+    }
+
+    g->merge_err_ret_traces_fn_val = fn_val;
+    return fn_val;
+
+}
 static LLVMValueRef ir_render_save_err_ret_addr(CodeGen *g, IrExecutable *executable,
         IrInstructionSaveErrRetAddr *save_err_ret_addr_instruction)
 {
     assert(g->have_err_ret_tracing);
 
     LLVMValueRef return_err_fn = get_return_err_fn(g);
-    LLVMValueRef args[] = {
-        get_cur_err_ret_trace_val(g, save_err_ret_addr_instruction->base.scope),
-    };
-    LLVMValueRef call_instruction = ZigLLVMBuildCall(g->builder, return_err_fn, args, 1,
+    LLVMValueRef my_err_trace_val = get_cur_err_ret_trace_val(g, save_err_ret_addr_instruction->base.scope);
+    ZigLLVMBuildCall(g->builder, return_err_fn, &my_err_trace_val, 1,
             get_llvm_cc(g, CallingConventionUnspecified), ZigLLVM_FnInlineAuto, "");
-    return call_instruction;
+
+    if (fn_is_async(g->cur_fn) && g->cur_fn->calls_or_awaits_errorable_fn &&
+        codegen_fn_has_err_ret_tracing_arg(g, g->cur_fn->type_entry->data.fn.fn_type_id.return_type))
+    {
+        LLVMValueRef dest_trace_ptr = LLVMBuildLoad(g->builder, g->cur_err_ret_trace_val_arg, "");
+        LLVMValueRef args[] = { dest_trace_ptr, my_err_trace_val };
+        ZigLLVMBuildCall(g->builder, get_merge_err_ret_traces_fn_val(g), args, 2,
+                get_llvm_cc(g, CallingConventionUnspecified), ZigLLVM_FnInlineAuto, "");
+    }
+
+    return nullptr;
 }
 
 static void gen_assert_resume_id(CodeGen *g, IrInstruction *source_instr, ResumeId resume_id, PanicMsgId msg_id,
test/runtime_safety.zig
@@ -544,23 +544,29 @@ pub fn addCases(cases: *tests.CompareOutputContext) void {
         \\    std.os.exit(126);
         \\}
         \\
+        \\var failing_frame: @Frame(failing) = undefined;
+        \\
         \\pub fn main() void {
         \\    const p = nonFailing();
         \\    resume p;
-        \\    const p2 = async<std.debug.global_allocator> printTrace(p) catch unreachable;
-        \\    cancel p2;
+        \\    const p2 = async printTrace(p);
         \\}
         \\
-        \\fn nonFailing() promise->anyerror!void {
-        \\    return async<std.debug.global_allocator> failing() catch unreachable;
+        \\fn nonFailing() anyframe->anyerror!void {
+        \\    failing_frame = async failing();
+        \\    return &failing_frame;
         \\}
         \\
         \\async fn failing() anyerror!void {
         \\    suspend;
+        \\    return second();
+        \\}
+        \\
+        \\async fn second() anyerror!void {
         \\    return error.Fail;
         \\}
         \\
-        \\async fn printTrace(p: promise->anyerror!void) void {
+        \\async fn printTrace(p: anyframe->anyerror!void) void {
         \\    (await p) catch unreachable;
         \\}
     );