Commit aa2995ee39

Andrew Kelley <superjoe30@gmail.com>
2018-03-25 04:05:29
fix invalid codegen for error return traces across suspend points
See #821 Now the code works correctly, but error return traces are missing the frames from coroutines.
1 parent a43c7af
src/all_types.hpp
@@ -1621,7 +1621,8 @@ struct CodeGen {
     FnTableEntry *panic_fn;
     LLVMValueRef cur_ret_ptr;
     LLVMValueRef cur_fn_val;
-    LLVMValueRef cur_err_ret_trace_val;
+    LLVMValueRef cur_err_ret_trace_val_arg;
+    LLVMValueRef cur_err_ret_trace_val_stack;
     bool c_want_stdint;
     bool c_want_stdbool;
     AstNode *root_export_decl;
@@ -1760,6 +1761,7 @@ enum ScopeId {
     ScopeIdLoop,
     ScopeIdFnDef,
     ScopeIdCompTime,
+    ScopeIdCoroPrelude,
 };
 
 struct Scope {
@@ -1867,6 +1869,12 @@ struct ScopeFnDef {
     FnTableEntry *fn_entry;
 };
 
+// This scope is created to indicate that the code in the scope
+// is auto-generated coroutine prelude stuff.
+struct ScopeCoroPrelude {
+    Scope base;
+};
+
 // synchronized with code in define_builtin_compile_vars
 enum AtomicOrder {
     AtomicOrderUnordered,
src/analyze.cpp
@@ -170,6 +170,12 @@ Scope *create_comptime_scope(AstNode *node, Scope *parent) {
     return &scope->base;
 }
 
+Scope *create_coro_prelude_scope(AstNode *node, Scope *parent) {
+    ScopeCoroPrelude *scope = allocate<ScopeCoroPrelude>(1);
+    init_scope(&scope->base, ScopeIdCoroPrelude, node, parent);
+    return &scope->base;
+}
+
 ImportTableEntry *get_scope_import(Scope *scope) {
     while (scope) {
         if (scope->id == ScopeIdDecls) {
@@ -3592,6 +3598,7 @@ FnTableEntry *scope_get_fn_if_root(Scope *scope) {
             case ScopeIdCImport:
             case ScopeIdLoop:
             case ScopeIdCompTime:
+            case ScopeIdCoroPrelude:
                 scope = scope->parent;
                 continue;
             case ScopeIdFnDef:
src/analyze.hpp
@@ -107,6 +107,7 @@ ScopeLoop *create_loop_scope(AstNode *node, Scope *parent);
 ScopeFnDef *create_fndef_scope(AstNode *node, Scope *parent, FnTableEntry *fn_entry);
 ScopeDecls *create_decls_scope(AstNode *node, Scope *parent, TypeTableEntry *container_type, ImportTableEntry *import);
 Scope *create_comptime_scope(AstNode *node, Scope *parent);
+Scope *create_coro_prelude_scope(AstNode *node, Scope *parent);
 
 void init_const_str_lit(CodeGen *g, ConstExprValue *const_val, Buf *str);
 ConstExprValue *create_const_str_lit(CodeGen *g, Buf *str);
src/codegen.cpp
@@ -653,6 +653,7 @@ static ZigLLVMDIScope *get_di_scope(CodeGen *g, Scope *scope) {
         case ScopeIdDeferExpr:
         case ScopeIdLoop:
         case ScopeIdCompTime:
+        case ScopeIdCoroPrelude:
             return get_di_scope(g, scope->parent);
     }
     zig_unreachable();
@@ -1318,9 +1319,34 @@ static LLVMValueRef get_safety_crash_err_fn(CodeGen *g) {
     return fn_val;
 }
 
-static void gen_safety_crash_for_err(CodeGen *g, LLVMValueRef err_val) {
+static bool is_coro_prelude_scope(Scope *scope) {
+    while (scope != nullptr) {
+        if (scope->id == ScopeIdCoroPrelude) {
+            return true;
+        } else if (scope->id == ScopeIdFnDef) {
+            break;
+        }
+        scope = scope->parent;
+    }
+    return false;
+}
+
+static LLVMValueRef get_cur_err_ret_trace_val(CodeGen *g, Scope *scope) {
+    if (!g->have_err_ret_tracing) {
+        return nullptr;
+    }
+    if (g->cur_fn->type_entry->data.fn.fn_type_id.cc == CallingConventionAsync) {
+        return is_coro_prelude_scope(scope) ? g->cur_err_ret_trace_val_arg : g->cur_err_ret_trace_val_stack;
+    }
+    if (g->cur_err_ret_trace_val_stack != nullptr) {
+        return g->cur_err_ret_trace_val_stack;
+    }
+    return g->cur_err_ret_trace_val_arg;
+}
+
+static void gen_safety_crash_for_err(CodeGen *g, LLVMValueRef err_val, Scope *scope) {
     LLVMValueRef safety_crash_err_fn = get_safety_crash_err_fn(g);
-    LLVMValueRef err_ret_trace_val = g->cur_err_ret_trace_val;
+    LLVMValueRef err_ret_trace_val = get_cur_err_ret_trace_val(g, scope);
     if (err_ret_trace_val == nullptr) {
         TypeTableEntry *ptr_to_stack_trace_type = get_ptr_to_stack_trace_type(g);
         err_ret_trace_val = LLVMConstNull(ptr_to_stack_trace_type->type_ref);
@@ -1614,7 +1640,7 @@ static LLVMValueRef ir_render_save_err_ret_addr(CodeGen *g, IrExecutable *execut
 
     LLVMValueRef return_err_fn = get_return_err_fn(g);
     LLVMValueRef args[] = {
-        g->cur_err_ret_trace_val,
+        get_cur_err_ret_trace_val(g, save_err_ret_addr_instruction->base.scope),
     };
     LLVMValueRef call_instruction = ZigLLVMBuildCall(g->builder, return_err_fn, args, 1,
             get_llvm_cc(g, CallingConventionUnspecified), ZigLLVM_FnInlineAuto, "");
@@ -2725,7 +2751,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
         gen_param_index += 1;
     }
     if (prefix_arg_err_ret_stack) {
-        gen_param_values[gen_param_index] = g->cur_err_ret_trace_val;
+        gen_param_values[gen_param_index] = get_cur_err_ret_trace_val(g, instruction->base.scope);
         gen_param_index += 1;
     }
     if (instruction->is_async) {
@@ -3292,11 +3318,12 @@ static LLVMValueRef ir_render_align_cast(CodeGen *g, IrExecutable *executable, I
 static LLVMValueRef ir_render_error_return_trace(CodeGen *g, IrExecutable *executable,
         IrInstructionErrorReturnTrace *instruction)
 {
-    if (g->cur_err_ret_trace_val == nullptr) {
+    LLVMValueRef cur_err_ret_trace_val = get_cur_err_ret_trace_val(g, instruction->base.scope);
+    if (cur_err_ret_trace_val == nullptr) {
         TypeTableEntry *ptr_to_stack_trace_type = get_ptr_to_stack_trace_type(g);
         return LLVMConstNull(ptr_to_stack_trace_type->type_ref);
     }
-    return g->cur_err_ret_trace_val;
+    return cur_err_ret_trace_val;
 }
 
 static LLVMValueRef ir_render_cancel(CodeGen *g, IrExecutable *executable, IrInstructionCancel *instruction) {
@@ -3726,7 +3753,7 @@ static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutable *execu
         LLVMBuildCondBr(g->builder, cond_val, ok_block, err_block);
 
         LLVMPositionBuilderAtEnd(g->builder, err_block);
-        gen_safety_crash_for_err(g, err_val);
+        gen_safety_crash_for_err(g, err_val, instruction->base.scope);
 
         LLVMPositionBuilderAtEnd(g->builder, ok_block);
     }
@@ -3918,7 +3945,7 @@ static LLVMValueRef ir_render_container_init_list(CodeGen *g, IrExecutable *exec
 }
 
 static LLVMValueRef ir_render_panic(CodeGen *g, IrExecutable *executable, IrInstructionPanic *instruction) {
-    gen_panic(g, ir_llvm_value(g, instruction->msg), g->cur_err_ret_trace_val);
+    gen_panic(g, ir_llvm_value(g, instruction->msg), get_cur_err_ret_trace_val(g, instruction->base.scope));
     return nullptr;
 }
 
@@ -5279,9 +5306,17 @@ static void do_code_gen(CodeGen *g) {
         clear_debug_source_node(g);
 
         uint32_t err_ret_trace_arg_index = get_err_ret_trace_arg_index(g, fn_table_entry);
-        if (err_ret_trace_arg_index != UINT32_MAX) {
-            g->cur_err_ret_trace_val = LLVMGetParam(fn, err_ret_trace_arg_index);
-        } else if (g->have_err_ret_tracing && fn_table_entry->calls_or_awaits_errorable_fn) {
+        bool have_err_ret_trace_arg = err_ret_trace_arg_index != UINT32_MAX;
+        if (have_err_ret_trace_arg) {
+            g->cur_err_ret_trace_val_arg = LLVMGetParam(fn, err_ret_trace_arg_index);
+        } else {
+            g->cur_err_ret_trace_val_arg = nullptr;
+        }
+
+        bool is_async = fn_table_entry->type_entry->data.fn.fn_type_id.cc == CallingConventionAsync;
+        bool have_err_ret_trace_stack = g->have_err_ret_tracing && fn_table_entry->calls_or_awaits_errorable_fn &&
+            (is_async || !have_err_ret_trace_arg);
+        if (have_err_ret_trace_stack) {
             // TODO call graph analysis to find out what this number needs to be for every function
             static const size_t stack_trace_ptr_count = 30;
 
@@ -5289,13 +5324,13 @@ static void do_code_gen(CodeGen *g) {
             TypeTableEntry *array_type = get_array_type(g, usize, stack_trace_ptr_count);
             LLVMValueRef err_ret_array_val = build_alloca(g, array_type, "error_return_trace_addresses",
                     get_abi_alignment(g, array_type));
-            g->cur_err_ret_trace_val = build_alloca(g, g->stack_trace_type, "error_return_trace", get_abi_alignment(g, g->stack_trace_type));
+            g->cur_err_ret_trace_val_stack = build_alloca(g, g->stack_trace_type, "error_return_trace", get_abi_alignment(g, g->stack_trace_type));
             size_t index_field_index = g->stack_trace_type->data.structure.fields[0].gen_index;
-            LLVMValueRef index_field_ptr = LLVMBuildStructGEP(g->builder, g->cur_err_ret_trace_val, (unsigned)index_field_index, "");
+            LLVMValueRef index_field_ptr = LLVMBuildStructGEP(g->builder, g->cur_err_ret_trace_val_stack, (unsigned)index_field_index, "");
             gen_store_untyped(g, LLVMConstNull(usize->type_ref), index_field_ptr, 0, false);
 
             size_t addresses_field_index = g->stack_trace_type->data.structure.fields[1].gen_index;
-            LLVMValueRef addresses_field_ptr = LLVMBuildStructGEP(g->builder, g->cur_err_ret_trace_val, (unsigned)addresses_field_index, "");
+            LLVMValueRef addresses_field_ptr = LLVMBuildStructGEP(g->builder, g->cur_err_ret_trace_val_stack, (unsigned)addresses_field_index, "");
 
             TypeTableEntry *slice_type = g->stack_trace_type->data.structure.fields[1].type_entry;
             size_t ptr_field_index = slice_type->data.structure.fields[slice_ptr_index].gen_index;
@@ -5311,7 +5346,7 @@ static void do_code_gen(CodeGen *g) {
             LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, addresses_field_ptr, (unsigned)len_field_index, "");
             gen_store(g, LLVMConstInt(usize->type_ref, stack_trace_ptr_count, false), len_field_ptr, get_pointer_to_type(g, usize, false));
         } else {
-            g->cur_err_ret_trace_val = nullptr;
+            g->cur_err_ret_trace_val_stack = nullptr;
         }
 
         // allocate temporary stack data
src/ir.cpp
@@ -6412,60 +6412,61 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
     VariableTableEntry *coro_size_var;
     if (is_async) {
         // create the coro promise
-        const_bool_false = ir_build_const_bool(irb, scope, node, false);
-        VariableTableEntry *promise_var = ir_create_var(irb, node, scope, nullptr, false, false, true, const_bool_false);
+        Scope *coro_scope = create_coro_prelude_scope(node, scope);
+        const_bool_false = ir_build_const_bool(irb, coro_scope, node, false);
+        VariableTableEntry *promise_var = ir_create_var(irb, node, coro_scope, nullptr, false, false, true, const_bool_false);
 
         return_type = fn_entry->type_entry->data.fn.fn_type_id.return_type;
-        IrInstruction *promise_init = ir_build_const_promise_init(irb, scope, node, return_type);
-        ir_build_var_decl(irb, scope, node, promise_var, nullptr, nullptr, promise_init);
-        IrInstruction *coro_promise_ptr = ir_build_var_ptr(irb, scope, node, promise_var, false, false);
+        IrInstruction *promise_init = ir_build_const_promise_init(irb, coro_scope, node, return_type);
+        ir_build_var_decl(irb, coro_scope, node, promise_var, nullptr, nullptr, promise_init);
+        IrInstruction *coro_promise_ptr = ir_build_var_ptr(irb, coro_scope, node, promise_var, false, false);
 
-        VariableTableEntry *await_handle_var = ir_create_var(irb, node, scope, nullptr, false, false, true, const_bool_false);
-        IrInstruction *null_value = ir_build_const_null(irb, scope, node);
-        IrInstruction *await_handle_type_val = ir_build_const_type(irb, scope, node,
+        VariableTableEntry *await_handle_var = ir_create_var(irb, node, coro_scope, nullptr, false, false, true, const_bool_false);
+        IrInstruction *null_value = ir_build_const_null(irb, coro_scope, node);
+        IrInstruction *await_handle_type_val = ir_build_const_type(irb, coro_scope, node,
                 get_maybe_type(irb->codegen, irb->codegen->builtin_types.entry_promise));
-        ir_build_var_decl(irb, scope, node, await_handle_var, await_handle_type_val, nullptr, null_value);
-        irb->exec->await_handle_var_ptr = ir_build_var_ptr(irb, scope, node,
+        ir_build_var_decl(irb, coro_scope, node, await_handle_var, await_handle_type_val, nullptr, null_value);
+        irb->exec->await_handle_var_ptr = ir_build_var_ptr(irb, coro_scope, node,
                 await_handle_var, false, false);
 
-        u8_ptr_type = ir_build_const_type(irb, scope, node,
+        u8_ptr_type = ir_build_const_type(irb, coro_scope, node,
                 get_pointer_to_type(irb->codegen, irb->codegen->builtin_types.entry_u8, false));
-        IrInstruction *promise_as_u8_ptr = ir_build_ptr_cast(irb, scope, node, u8_ptr_type, coro_promise_ptr);
-        coro_id = ir_build_coro_id(irb, scope, node, promise_as_u8_ptr);
-        coro_size_var = ir_create_var(irb, node, scope, nullptr, false, false, true, const_bool_false);
-        IrInstruction *coro_size = ir_build_coro_size(irb, scope, node);
-        ir_build_var_decl(irb, scope, node, coro_size_var, nullptr, nullptr, coro_size);
-        IrInstruction *implicit_allocator_ptr = ir_build_get_implicit_allocator(irb, scope, node,
+        IrInstruction *promise_as_u8_ptr = ir_build_ptr_cast(irb, coro_scope, node, u8_ptr_type, coro_promise_ptr);
+        coro_id = ir_build_coro_id(irb, coro_scope, node, promise_as_u8_ptr);
+        coro_size_var = ir_create_var(irb, node, coro_scope, nullptr, false, false, true, const_bool_false);
+        IrInstruction *coro_size = ir_build_coro_size(irb, coro_scope, node);
+        ir_build_var_decl(irb, coro_scope, node, coro_size_var, nullptr, nullptr, coro_size);
+        IrInstruction *implicit_allocator_ptr = ir_build_get_implicit_allocator(irb, coro_scope, node,
                 ImplicitAllocatorIdArg);
-        irb->exec->coro_allocator_var = ir_create_var(irb, node, scope, nullptr, true, true, true, const_bool_false);
-        ir_build_var_decl(irb, scope, node, irb->exec->coro_allocator_var, nullptr, nullptr, implicit_allocator_ptr);
+        irb->exec->coro_allocator_var = ir_create_var(irb, node, coro_scope, nullptr, true, true, true, const_bool_false);
+        ir_build_var_decl(irb, coro_scope, node, irb->exec->coro_allocator_var, nullptr, nullptr, implicit_allocator_ptr);
         Buf *alloc_field_name = buf_create_from_str(ASYNC_ALLOC_FIELD_NAME);
-        IrInstruction *alloc_fn_ptr = ir_build_field_ptr(irb, scope, node, implicit_allocator_ptr, alloc_field_name);
-        IrInstruction *alloc_fn = ir_build_load_ptr(irb, scope, node, alloc_fn_ptr);
-        IrInstruction *maybe_coro_mem_ptr = ir_build_coro_alloc_helper(irb, scope, node, alloc_fn, coro_size);
-        IrInstruction *alloc_result_is_ok = ir_build_test_nonnull(irb, scope, node, maybe_coro_mem_ptr);
-        IrBasicBlock *alloc_err_block = ir_create_basic_block(irb, scope, "AllocError");
-        IrBasicBlock *alloc_ok_block = ir_create_basic_block(irb, scope, "AllocOk");
-        ir_build_cond_br(irb, scope, node, alloc_result_is_ok, alloc_ok_block, alloc_err_block, const_bool_false);
+        IrInstruction *alloc_fn_ptr = ir_build_field_ptr(irb, coro_scope, node, implicit_allocator_ptr, alloc_field_name);
+        IrInstruction *alloc_fn = ir_build_load_ptr(irb, coro_scope, node, alloc_fn_ptr);
+        IrInstruction *maybe_coro_mem_ptr = ir_build_coro_alloc_helper(irb, coro_scope, node, alloc_fn, coro_size);
+        IrInstruction *alloc_result_is_ok = ir_build_test_nonnull(irb, coro_scope, node, maybe_coro_mem_ptr);
+        IrBasicBlock *alloc_err_block = ir_create_basic_block(irb, coro_scope, "AllocError");
+        IrBasicBlock *alloc_ok_block = ir_create_basic_block(irb, coro_scope, "AllocOk");
+        ir_build_cond_br(irb, coro_scope, node, alloc_result_is_ok, alloc_ok_block, alloc_err_block, const_bool_false);
 
         ir_set_cursor_at_end_and_append_block(irb, alloc_err_block);
         // we can return undefined here, because the caller passes a pointer to the error struct field
         // in the error union result, and we populate it in case of allocation failure.
-        IrInstruction *undef = ir_build_const_undefined(irb, scope, node);
-        ir_build_return(irb, scope, node, undef);
+        IrInstruction *undef = ir_build_const_undefined(irb, coro_scope, node);
+        ir_build_return(irb, coro_scope, node, undef);
 
         ir_set_cursor_at_end_and_append_block(irb, alloc_ok_block);
-        IrInstruction *coro_mem_ptr = ir_build_ptr_cast(irb, scope, node, u8_ptr_type, maybe_coro_mem_ptr);
-        irb->exec->coro_handle = ir_build_coro_begin(irb, scope, node, coro_id, coro_mem_ptr);
+        IrInstruction *coro_mem_ptr = ir_build_ptr_cast(irb, coro_scope, node, u8_ptr_type, maybe_coro_mem_ptr);
+        irb->exec->coro_handle = ir_build_coro_begin(irb, coro_scope, node, coro_id, coro_mem_ptr);
 
         Buf *awaiter_handle_field_name = buf_create_from_str(AWAITER_HANDLE_FIELD_NAME);
-        irb->exec->coro_awaiter_field_ptr = ir_build_field_ptr(irb, scope, node, coro_promise_ptr,
+        irb->exec->coro_awaiter_field_ptr = ir_build_field_ptr(irb, coro_scope, node, coro_promise_ptr,
                 awaiter_handle_field_name);
         Buf *result_field_name = buf_create_from_str(RESULT_FIELD_NAME);
-        irb->exec->coro_result_field_ptr = ir_build_field_ptr(irb, scope, node, coro_promise_ptr, result_field_name);
+        irb->exec->coro_result_field_ptr = ir_build_field_ptr(irb, coro_scope, node, coro_promise_ptr, result_field_name);
         result_ptr_field_name = buf_create_from_str(RESULT_PTR_FIELD_NAME);
-        irb->exec->coro_result_ptr_field_ptr = ir_build_field_ptr(irb, scope, node, coro_promise_ptr, result_ptr_field_name);
-        ir_build_store_ptr(irb, scope, node, irb->exec->coro_result_ptr_field_ptr, irb->exec->coro_result_field_ptr);
+        irb->exec->coro_result_ptr_field_ptr = ir_build_field_ptr(irb, coro_scope, node, coro_promise_ptr, result_ptr_field_name);
+        ir_build_store_ptr(irb, coro_scope, node, irb->exec->coro_result_ptr_field_ptr, irb->exec->coro_result_field_ptr);
 
 
         irb->exec->coro_early_final = ir_create_basic_block(irb, scope, "CoroEarlyFinal");
test/runtime_safety.zig
@@ -281,4 +281,34 @@ pub fn addCases(cases: &tests.CompareOutputContext) void {
         \\    f.float = 12.34;
         \\}
     );
+
+    // This case makes sure that the code compiles and runs. There is not actually a special
+    // runtime safety check having to do specifically with error return traces across suspend points.
+    cases.addRuntimeSafety("error return trace across suspend points",
+        \\const std = @import("std");
+        \\
+        \\pub fn panic(message: []const u8, stack_trace: ?&@import("builtin").StackTrace) noreturn {
+        \\    std.os.exit(126);
+        \\}
+        \\
+        \\pub fn main() void {
+        \\    const p = nonFailing();
+        \\    resume p;
+        \\    const p2 = async<std.debug.global_allocator> printTrace(p) catch unreachable;
+        \\    cancel p2;
+        \\}
+        \\
+        \\fn nonFailing() promise->error!void {
+        \\    return async<std.debug.global_allocator> failing() catch unreachable;
+        \\}
+        \\
+        \\async fn failing() error!void {
+        \\    suspend;
+        \\    return error.Fail;
+        \\}
+        \\
+        \\async fn printTrace(p: promise->error!void) void {
+        \\    (await p) catch unreachable;
+        \\}
+    );
 }